Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def __init__(self, in_filters, threshold: float = 0.99, verbose: bool = True, gradient_epoch_start: int = 20, centering: bool = False, downsampling: int = None):
    """Set up the Conv2D PCA layer on top of the shared PCA base class.

    The PCA bookkeeping is delegated to the parent constructor (with
    ``keepdim=True``). This constructor then adds two square 1x1 convolutions
    over ``in_filters`` channels.

    Args:
        in_filters: Number of input channels. It is passed to the parent as
            ``in_features`` and used as both the input and output channel count
            of the two 1x1 convolutions.
        threshold: Forwarded unchanged to the parent constructor.
        verbose: If true, print a construction message; also forwarded to the
            parent constructor.
        gradient_epoch_start: Forwarded unchanged to the parent constructor.
        centering: Forwarded unchanged to the parent constructor.
        downsampling: Stored as ``self.downsampling``; not otherwise used here.
    """
    parent_kwargs = dict(
        centering=centering,
        in_features=in_filters,
        threshold=threshold,
        keepdim=True,
        verbose=verbose,
        gradient_epoch_start=gradient_epoch_start,
    )
    super(Conv2DPCALayer, self).__init__(**parent_kwargs)
    if verbose:
        print('Added Conv2D PCA Layer')

    # Both layers are square 1x1, stride-1 convolutions: each spatial position
    # is mapped channel-to-channel, and the channel count is unchanged.
    pointwise_spec = dict(
        in_channels=in_filters,
        out_channels=in_filters,
        kernel_size=1,
        stride=1,
        bias=True,
    )
    self.convolution = torch.nn.Conv2d(**pointwise_spec)
    self.mean_subtracting_convolution = torch.nn.Conv2d(**pointwise_spec)

    # Replace the second conv's weight with zeros.
    # Conv2d weight layout is (out, in, kH, kW) -> (in_filters, in_filters, 1, 1).
    # NOTE(review): only the weight is zeroed. The bias keeps Conv2d's default
    # initialisation here; presumably it is overwritten elsewhere — confirm.
    zero_kernel = torch.zeros((in_filters, in_filters))[:, :, None, None]
    self.mean_subtracting_convolution.weight = torch.nn.Parameter(zero_kernel)

    self.downsampling = downsampling
# NOTE(review): this is the tail of a function whose `def` line is not part of
# this snippet. `network`, `threshold`, `device`, `verbose`, `include_names`,
# `rvs`, `LinearPCALayer`, `Conv2DPCALayer`, `in_dims`, `fs_dims` and `sat` are
# all bound outside the visible code. in_dims, fs_dims and sat must be
# list-like, since `.append` is called on them. Confirm against the full source.
#
# Visible behaviour: for every Linear/Conv2D PCA layer in `network`:
#   - set its threshold;
#   - record its in_dim, fs_dim and sat;
#   - overwrite its transformation matrix with a randomly generated projection.
# Finally return the collected lists, plus layer labels if `include_names`.
names = []  # one label per PCA layer, e.g. 'Linear-0', 'Conv-0'
lc = {'lin': 0, 'conv': 0}  # running counters used to number the labels per layer kind
for module in network.modules():
    if isinstance(module, LinearPCALayer):
        module.threshold = threshold
        # Take the first in_dim columns of rvs(fs_dim).
        # NOTE(review): presumably rvs is scipy.stats.ortho_group.rvs (random
        # orthogonal fs_dim x fs_dim numpy matrix) — confirm.
        fake_base = rvs(module.fs_dim)[:, :module.in_dim]
        in_dims.append(module.in_dim)
        fs_dims.append(module.fs_dim)
        sat.append(module.sat)
        # If the selected columns are orthonormal, this is the orthogonal
        # projector onto their span.
        fake_projection = fake_base @ fake_base.T
        # Convert the numpy result to float32, then move it to `device`.
        module.transformation_matrix.data = torch.from_numpy(fake_projection.astype('float32')).to(device)
        names.append(f'Linear-{lc["lin"]}')
        lc["lin"] += 1
        if verbose:
            print(f'Changed threshold for layer {module} to {threshold}')
    elif isinstance(module, Conv2DPCALayer):
        module.threshold = threshold
        in_dims.append(module.in_dim)
        fs_dims.append(module.fs_dim)
        sat.append(module.sat)
        # Same random projection as in the linear branch above.
        fake_base = rvs(module.fs_dim)[:, :module.in_dim]
        fake_projection = fake_base @ fake_base.T
        module.transformation_matrix.data = torch.from_numpy(fake_projection.astype('float32')).to(device)
        # Conv layers apply the matrix through their 1x1 convolution, so the
        # matrix is also installed as the conv weight. unsqueeze(2).unsqueeze(3)
        # adds the two trailing size-1 kernel dims: (out, in) -> (out, in, 1, 1).
        weight = torch.nn.Parameter(module.transformation_matrix.unsqueeze(2).unsqueeze(3))
        module.convolution.weight = weight
        names.append(f'Conv-{lc["conv"]}')
        lc['conv'] += 1
        if verbose:
            print(f'Changed threshold for layer {module} to {threshold}')
if include_names:
    return sat, in_dims, fs_dims, names
return sat, in_dims, fs_dims