Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Use pretrained ResNet-18 provided by PyTorch (torchvision).
# NOTE: `pretrained=True` is deprecated in newer torchvision releases in favour
# of the `weights=` argument; kept here for compatibility with older versions.
model = models.resnet18(pretrained=True)
model = model.to(device)
# torchvision modules start in training mode. Switch to inference mode so that
# BatchNorm uses its stored running statistics: a saliency map for one image
# should not depend on the other images in the same batch. This is done
# explicitly, before the saliency wrappers are built, instead of relying on any
# wrapper to change the model's mode as a side effect.
model.eval()

# Initialize saliency methods. The dict key is the method name that
# compute_saliency_and_save() appends to each output file name.
saliency_methods = {
    # FullGrad-based methods
    'fullgrad': FullGrad(model),
    'simple_fullgrad': SimpleFullGrad(model),
    'smooth_fullgrad': SmoothFullGrad(model),
    # Other saliency methods from literature
    'gradcam': GradCAM(model),
    'inputgrad': InputGradient(model),
    'smoothgrad': SmoothGrad(model),
}
def compute_saliency_and_save():
    """Compute every saliency method on every image in ``sample_loader`` and save the maps.

    For each batch from ``sample_loader``, each entry of ``saliency_methods`` is
    run once on the whole batch. One ``.jpg`` per (image, method) pair is then
    written via ``save_saliency_map`` to::

        save_path + "<image_id>_<method_name>.jpg"

    ``image_id`` is a 1-based running count of images across all batches, so
    every image gets a distinct file name. The earlier scheme,
    ``(batch_idx + 1) * (i + 1)``, was not unique (for example, batch 0/image 1
    and batch 1/image 0 both produced ``2``), so later maps silently overwrote
    earlier ones. Names for the first batch are the same as before.
    """
    # 1-based so that first-batch file names match the previous naming scheme.
    next_image_id = 1
    for data, _ in sample_loader:
        # Track gradients w.r.t. the input so saliency methods can differentiate
        # with respect to it.
        data = data.to(device).requires_grad_()
        batch_size = data.size(0)

        # The de-normalised image does not depend on the saliency method, so
        # build it once per image rather than once per (image, method) pair.
        # NOTE(review): assumes `unnormalize` does not modify its argument in
        # place (it receives the same tensor the methods consume) -- confirm.
        images = [unnormalize(data[i].cpu()) for i in range(batch_size)]

        # Compute saliency maps for the input data, one method at a time.
        for method_name, method in saliency_methods.items():
            saliency_map = method.saliency(data)

            # Save one map per image in the batch.
            for i in range(batch_size):
                filename = save_path + f"{next_image_id + i}_{method_name}.jpg"
                save_saliency_map(images[i], saliency_map[i], filename)

        next_image_id += batch_size