Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_saliency(model, raw_input, input, label, method='gradcam', layer_path=None):
    """Compute a min-max normalized saliency map using a named explainer.

    The device is CUDA when available, CPU otherwise. Stale gradients on the
    given tensors and the model are cleared, and the explainer built by
    ``get_explainer(method, model, layer_path)`` is asked to explain
    ``input``. Its result is reduced to a single map and scaled to [0, 1].

    Args:
        model: Network to explain. As a side effect it is moved to the
            selected device via ``model.to(device)``, put in eval mode and
            has its gradients zeroed.
        raw_input: Passed through unchanged as the third argument of
            ``exp.explain``; its expected form depends on the explainer.
        input: Input tensor; moved to the selected device before explaining.
        label: Optional target tensor (or None); moved alongside ``input``.
        method: Explainer name forwarded to ``get_explainer``.
        layer_path: Optional layer selector forwarded to ``get_explainer``.

    Returns:
        A NumPy array with values in [0, 1], or None when the explainer
        returns None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Clear stale gradients on the caller's tensors *before* the device move.
    # When .to() has to copy to another device it returns a new tensor whose
    # .grad is None (and, if it requires grad, is a non-leaf whose .grad access
    # warns), so checking after the move skipped the reset on that path.
    if input.grad is not None:
        input.grad.zero_()
    if label is not None and label.grad is not None:
        label.grad.zero_()

    input = input.to(device)
    if label is not None:
        label = label.to(device)

    model.eval()
    model.zero_grad()

    exp = get_explainer(method, model, layer_path)
    saliency = exp.explain(input, label, raw_input)
    if saliency is None:
        return None

    # The post-processing below needs no autograd history; detaching first
    # keeps it out of the graph.
    # NOTE(review): assumes the explainer returns (batch, channels, ...) —
    # absolute values are summed over dim 1 and only batch item 0 is kept.
    saliency = saliency.detach().abs().sum(dim=1)[0].squeeze()
    if not saliency.is_floating_point():
        # torch.finfo below needs a floating dtype; promote integer maps.
        saliency = saliency.float()

    saliency = saliency - saliency.min()
    # 1e-20 underflows to 0 in float16, which turns a constant map into 0/0
    # (NaN); never use an epsilon below the dtype's smallest normal value.
    # For float32/float64/bfloat16 this is still exactly 1e-20.
    eps = max(1e-20, torch.finfo(saliency.dtype).tiny)
    saliency = saliency / (saliency.max() + eps)
    return saliency.cpu().numpy()