Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, in_features: int,
             threshold: float = .99,
             keepdim: bool = True,
             verbose: bool = False,
             gradient_epoch_start: int = 20,
             centering: bool = True):
    """Register the statistics buffers and store the layer configuration.

    Args:
        in_features: Size of the feature dimension. Every statistics buffer
            is allocated with this length, or as an
            ``in_features x in_features`` matrix.
        threshold: Stored as a one-element float64 buffer ``_threshold``.
        keepdim: Stored as ``self.keepdim``.
        verbose: Stored as ``self.verbose``.
        gradient_epoch_start: Stored as ``self.gradient_epoch``.
        centering: Stored as ``self._centering``.
    """
    super(LinearPCALayer, self).__init__()
    # Zero-initialised buffers (not parameters). All are float64 except
    # ``mean``, which is float32.
    self.register_buffer('eigenvalues', torch.zeros(in_features, dtype=torch.float64))
    self.register_buffer('eigenvectors', torch.zeros((in_features, in_features), dtype=torch.float64))
    # Build the float64 tensor directly. ``torch.Tensor([threshold]).type(torch.float64)``
    # would first materialise the value in torch's default dtype (float32 unless
    # changed), so e.g. 0.99 would be stored as 0.9900000095367432.
    self.register_buffer('_threshold', torch.tensor([threshold], dtype=torch.float64))
    self.register_buffer('sum_squares', torch.zeros((in_features, in_features), dtype=torch.float64))
    self.register_buffer('seen_samples', torch.zeros(1, dtype=torch.float64))
    self.register_buffer('running_sum', torch.zeros(in_features, dtype=torch.float64))
    self.register_buffer('mean', torch.zeros(in_features, dtype=torch.float32))
    self.keepdim: bool = keepdim
    self.verbose: bool = verbose
    self.pca_computed: bool = True
    self.gradient_epoch = gradient_epoch_start
    self.epoch = 0
    # Per-instance name taken from a counter stored on LinearPCALayer itself.
    # Subclasses that call this __init__ share and bump the same counter.
    self.name = f'pca{LinearPCALayer.num}'
    LinearPCALayer.num += 1
    self._centering = centering
    self.data_dtype = None
def change_all_pca_layer_thresholds(threshold: float, network: Module, verbose: bool = False):
    """Set ``threshold`` on every PCA layer in ``network`` and gather their stats.

    Each ``LinearPCALayer`` / ``Conv2DPCALayer`` yielded by ``network.modules()``
    first gets ``threshold`` assigned. Its ``in_dim``, ``fs_dim`` and ``sat``
    attributes are then read, in that order.

    Args:
        threshold: Value assigned to each PCA layer's ``threshold`` attribute.
        network: Module whose ``modules()`` are scanned.
        verbose: If true, print one line per updated layer.

    Returns:
        ``(sat, in_dims, fs_dims, names)``: parallel lists in traversal order.
        Names are ``'Conv-<i>'`` or ``'Lin-<i>'``, each kind with its own
        running index starting at 0.
    """
    saturations = []
    input_dims = []
    feature_dims = []
    layer_names = []
    lc = {'lin': 0, 'conv': 0}
    # Lazily filtered, so each isinstance check still runs just before the
    # corresponding loop body, exactly as in a plain for/if.
    pca_layers = (
        module for module in network.modules()
        if isinstance(module, (Conv2DPCALayer, LinearPCALayer))
    )
    for module in pca_layers:
        # Assign first; the stats below are read after the update (they may
        # depend on it, which is not visible here).
        module.threshold = threshold
        input_dims.append(module.in_dim)
        feature_dims.append(module.fs_dim)
        saturations.append(module.sat)
        if not isinstance(module, Conv2DPCALayer):
            layer_names.append(f"Lin-{lc['lin']}")
            lc["lin"] += 1
        else:
            layer_names.append(f'Conv-{lc["conv"]}')
            lc['conv'] += 1
        if verbose:
            print(f'Changed threshold for layer {module} to {threshold}')
    return saturations, input_dims, feature_dims, layer_names
# NOTE(review): fragment. The enclosing `def` (apparently the layer's forward
# pass, taking input `x`) is not visible here.
# Increments self.epoch each time this code runs.
self.epoch += 1
if self.keepdim:
# keepdim branch: multiply by the transposed full transformation matrix.
# NOTE(review): reads `self.centering`, while __init__ stores `self._centering`;
# presumably a property defined elsewhere. Confirm.
if not self.centering:
return x @ self.transformation_matrix.t()
else:
# Only this branch moves the mean and the matrix to x's device before use.
self.mean = self.mean.to(x.device)
self.transformation_matrix = self.transformation_matrix.to(x.device)
# Subtract the mean, project, then add the mean back.
return ((x-self.mean) @ self.transformation_matrix.t()) + self.mean
else:
# Reduced branch: multiply by reduced_transformation_matrix (no transpose).
# NOTE(review): unlike the keepdim/centering branch, nothing is moved to x.device here.
if not self.centering:
return x @ self.reduced_transformation_matrix
else:
# NOTE(review): `self.mean` is allocated with length in_features; adding it back
# after a dimension-reducing projection only broadcasts if the projected width
# equals in_features. Confirm the shape of reduced_transformation_matrix.
return ((x-self.mean) @ self.reduced_transformation_matrix) + self.mean
# PCA layer variant for Conv2d feature maps. It reuses LinearPCALayer with
# `in_filters` as the feature count and adds two 1x1 convolutions.
class Conv2DPCALayer(LinearPCALayer):
# Defaults differ from LinearPCALayer (verbose=True, centering=False), and
# keepdim is always forced to True in the super() call below.
def __init__(self, in_filters: int, threshold: float = 0.99, verbose: bool = True, gradient_epoch_start: int = 20, centering: bool = False, downsampling: int = None):
super(Conv2DPCALayer, self).__init__(centering=centering, in_features=in_filters, threshold=threshold, keepdim=True, verbose=verbose, gradient_epoch_start=gradient_epoch_start)
if verbose:
print('Added Conv2D PCA Layer')
# 1x1 convolution, in_filters -> in_filters channels, with bias. Its weights
# are not set here (presumably filled from the PCA transform later; confirm).
self.convolution = torch.nn.Conv2d(in_channels=in_filters,
out_channels=in_filters,
kernel_size=1, stride=1, bias=True)
# Second 1x1 convolution. Its weight is replaced below with zeros of shape
# (in_filters, in_filters, 1, 1); its bias is not modified here.
self.mean_subtracting_convolution = torch.nn.Conv2d(in_channels=in_filters,
out_channels=in_filters,
kernel_size=1, stride=1, bias=True)
self.mean_subtracting_convolution.weight = torch.nn.Parameter(
torch.zeros((in_filters, in_filters)).unsqueeze(2).unsqueeze(3)
)
# NOTE(review): the default None makes this effectively Optional[int]; how it is
# used is not visible here.
self.downsampling = downsampling