Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
shape = activations_batch.shape
reshaped_batch = activations_batch.reshape(shape[0], shape[1], shape[2] * shape[3])
activations_batch, _ = torch.max(reshaped_batch, dim=2) # channel max
elif self.conv_method == 'mean':
activations_batch = torch.mean(activations_batch, dim=(2, 3))
elif self.conv_method == 'flatten':
activations_batch = activations_batch.view(activations_batch.size(0), -1)
elif self.conv_method == 'channelwise':
reshaped_batch: torch.Tensor = activations_batch.permute([1, 0, 2, 3])
shape = reshaped_batch.shape
reshaped_batch: torch.Tensor = reshaped_batch.flatten(1)
reshaped_batch: torch.Tensor = reshaped_batch.permute([1, 0])
activations_batch = reshaped_batch
if layer.name not in self.logs[f'{training_state}-{stat}']:
self.logs[f'{training_state}-{stat}'][layer.name] = TorchCovarianceMatrix(device=self.device)
self.logs[f'{training_state}-{stat}'][layer.name].update(activations_batch, lstm_ae)