Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
desc="unsuper", **self.tqdm_params_):
# select one input vector & calculate best matching unit (BMU)
dp = np.random.randint(low=0, high=len(self.X_))
bmu_pos = self.get_bmu(self.X_[dp], self.unsuper_som_)
# calculate learning rate and neighborhood function
learning_rate = self.calc_learning_rate(
curr_it=it, mode=self.learn_mode_unsupervised)
nbh_func = self.calc_neighborhood_func(
curr_it=it, mode=self.neighborhood_mode_unsupervised)
# calculate distance weight matrix and update weights
dist_weight_matrix = self.get_nbh_distance_weight_matrix(
nbh_func, bmu_pos)
self.unsuper_som_ = modify_weight_matrix_online(
self.unsuper_som_, dist_weight_matrix,
true_vector=self.X_[dp],
learningrate=learning_rate*self.sample_weights_[dp])
elif self.train_mode_unsupervised == "batch":
for it in tqdm(range(self.n_iter_unsupervised),
desc="unsuper", **self.tqdm_params_):
# calculate BMUs
bmus = self.get_bmus(self.X_)
# calculate neighborhood function
nbh_func = self.calc_neighborhood_func(
curr_it=it, mode=self.neighborhood_mode_unsupervised)
# calculate distance weight matrix for all datapoints
dist_weight_matrix : np.array of float
Current distance weight of the SOM for the specific node
data : np.array, optional
True vector(s)
learningrate : float, optional
Current learning rate of the SOM
Returns
-------
modify_weight_matrix : np.array
Weight vector of the SOM after the modification
"""
modify_weight_matrix = None
if self.train_mode_supervised == "online":
modify_weight_matrix = modify_weight_matrix_online(
self.super_som_, dist_weight_matrix, true_vector=true_vector,
learningrate=learningrate)
elif self.train_mode_supervised == "batch":
modify_weight_matrix = self.modify_weight_matrix_batch(
som_array=self.super_som_,
dist_weight_matrix=dist_weight_matrix[self.labeled_indices_],
data=self.y_[self.labeled_indices_])
else:
raise ValueError("Invalid train_mode_supervised: "+str(
self.train_mode_supervised))
return modify_weight_matrix