How to use the delve.pca_layers.Conv2DPCALayer class in delve

To help you get started, we’ve selected a few delve examples based on popular ways this class is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: delve-team / delve / delve / pca_layers.py (View on GitHub)
def __init__(
    self,
    in_filters,
    threshold: float = 0.99,
    verbose: bool = True,
    gradient_epoch_start: int = 20,
    centering: bool = False,
    downsampling: int = None,
):
    """Build a PCA layer that works on the channel axis of 2D feature maps.

    The shared PCA bookkeeping is delegated to the parent constructor, called
    with ``in_features=in_filters`` and ``keepdim=True``. This constructor then
    adds the two pointwise (1x1) convolutions used by the convolutional variant.

    Args:
        in_filters: Number of channels. Both convolutions map ``in_filters``
            channels to ``in_filters`` channels.
        threshold: Forwarded unchanged to the parent constructor.
        verbose: Forwarded to the parent constructor. When true, a confirmation
            message is also printed.
        gradient_epoch_start: Forwarded unchanged to the parent constructor.
        centering: Forwarded unchanged to the parent constructor.
        downsampling: Stored as ``self.downsampling`` (may be ``None``, the
            default). How it is consumed is not visible from here.
    """
    super(Conv2DPCALayer, self).__init__(
        centering=centering,
        in_features=in_filters,
        threshold=threshold,
        keepdim=True,
        verbose=verbose,
        gradient_epoch_start=gradient_epoch_start,
    )
    if verbose:
        print('Added Conv2D PCA Layer')

    # Both convolutions are 1x1, stride 1, with bias, and preserve the channel
    # count, so they are built from one shared set of constructor arguments.
    pointwise_kwargs = dict(
        in_channels=in_filters,
        out_channels=in_filters,
        kernel_size=1,
        stride=1,
        bias=True,
    )
    # Keep this construction order: Conv2d samples its default initial
    # parameters randomly, so reordering would change the initial values.
    self.convolution = torch.nn.Conv2d(**pointwise_kwargs)
    self.mean_subtracting_convolution = torch.nn.Conv2d(**pointwise_kwargs)

    # Replace the default kernel with an all-zero kernel of shape
    # (in_filters, in_filters, 1, 1). Only the weight is replaced here; the
    # bias keeps its default initialisation.
    zero_kernel = torch.zeros(in_filters, in_filters, 1, 1)
    self.mean_subtracting_convolution.weight = torch.nn.Parameter(zero_kernel)

    self.downsampling = downsampling
GitHub: delve-team / delve / delve / pca_layers.py (View on GitHub)
names = []
    lc = {'lin': 0, 'conv': 0}
    for module in network.modules():
        if isinstance(module, LinearPCALayer):
            module.threshold = threshold
            fake_base = rvs(module.fs_dim)[:, :module.in_dim]
            in_dims.append(module.in_dim)
            fs_dims.append(module.fs_dim)
            sat.append(module.sat)
            fake_projection = fake_base @ fake_base.T
            module.transformation_matrix.data = torch.from_numpy(fake_projection.astype('float32')).to(device)
            names.append(f'Linear-{lc["lin"]}')
            lc["lin"] += 1
            if verbose:
                print(f'Changed threshold for layer {module} to {threshold}')
        elif isinstance(module, Conv2DPCALayer):
            module.threshold = threshold
            in_dims.append(module.in_dim)
            fs_dims.append(module.fs_dim)
            sat.append(module.sat)
            fake_base = rvs(module.fs_dim)[:, :module.in_dim]
            fake_projection = fake_base @ fake_base.T
            module.transformation_matrix.data = torch.from_numpy(fake_projection.astype('float32')).to(device)
            weight = torch.nn.Parameter(module.transformation_matrix.unsqueeze(2).unsqueeze(3))
            module.convolution.weight = weight
            names.append(f'Conv-{lc["conv"]}')
            lc['conv'] += 1
            if verbose:
                print(f'Changed threshold for layer {module} to {threshold}')
    if include_names:
        return sat, in_dims, fs_dims, names
    return sat, in_dims, fs_dims