import torch

from captum.attr import InputXGradient, NoiseTunnel, Saliency

# SoftmaxModel and assertArraysAlmostEqual are helpers defined elsewhere in the
# test suite: a small softmax classifier and an elementwise almost-equal check.


def _saliency_base_assert(
    self, model, inputs, expected, additional_forward_args=None, nt_type="vanilla"
):
    saliency = Saliency(model)
    if nt_type == "vanilla":
        attributions = saliency.attribute(
            inputs, additional_forward_args=additional_forward_args
        )
    else:
        nt = NoiseTunnel(saliency)
        attributions = nt.attribute(
            inputs,
            nt_type=nt_type,
            n_samples=10,
            stdevs=0.0000002,
            additional_forward_args=additional_forward_args,
        )
    if isinstance(attributions, tuple):
        # Multi-input case: check each attribution against its expected tensor.
        for input, attribution, expected_attr in zip(
            inputs, attributions, expected
        ):
            if nt_type == "vanilla":
                self._assert_attribution(attribution, expected_attr)
            self.assertEqual(input.shape, attribution.shape)
    else:
        # Single-tensor case: mirror the checks of the tuple branch.
        if nt_type == "vanilla":
            self._assert_attribution(attributions, expected)
        self.assertEqual(inputs.shape, attributions.shape)

def _input_x_gradient_base_assert(
    self,
    model,
    inputs,
    expected_grads,
    additional_forward_args=None,
    nt_type="vanilla",
):
    input_x_grad = InputXGradient(model)
    if nt_type == "vanilla":
        attributions = input_x_grad.attribute(
            inputs, additional_forward_args=additional_forward_args
        )
    else:
        nt = NoiseTunnel(input_x_grad)
        attributions = nt.attribute(
            inputs,
            nt_type=nt_type,
            n_samples=10,
            stdevs=0.0002,
            additional_forward_args=additional_forward_args,
        )
    if isinstance(attributions, tuple):
        # Multi-input case: each attribution should equal expected_grad * input.
        for input, attribution, expected_grad in zip(
            inputs, attributions, expected_grads
        ):
            if nt_type == "vanilla":
                assertArraysAlmostEqual(
                    attribution.reshape(-1), (expected_grad * input).reshape(-1)
                )
            self.assertEqual(input.shape, attribution.shape)
    else:
        # Single-tensor case: mirror the checks of the tuple branch.
        if nt_type == "vanilla":
            assertArraysAlmostEqual(
                attributions.reshape(-1), (expected_grads * inputs).reshape(-1)
            )
        self.assertEqual(inputs.shape, attributions.shape)
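
# Standalone sketch (not from the test above): the identity the helper checks.
# For a single linear layer, the gradient of output[:, target] with respect to
# the input is just the target's weight row, so the InputXGradient attribution
# has the closed form input * weight[target]. Model and values are illustrative.
import torch
import torch.nn as nn
from captum.attr import InputXGradient

lin = nn.Linear(3, 5, bias=False)
x = torch.randn(2, 3, requires_grad=True)

attr = InputXGradient(lin).attribute(x, target=1)
expected = x * lin.weight[1]                      # d output[:, 1] / d x == weight[1]
assert torch.allclose(attr, expected)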

def _input_x_gradient_classification_assert(self, nt_type="vanilla"):
    # Small example setup: a 5-feature input row and a fixed target class.
    num_in = 5
    input = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]], requires_grad=True)
    target = torch.tensor(5)
    # 10-class classification model
    model = SoftmaxModel(num_in, 20, 10)
    input_x_grad = InputXGradient(model.forward)
    if nt_type == "vanilla":
        attributions = input_x_grad.attribute(input, target)
        # Reference value: backprop the target output and multiply the gradient by the input.
        output = model(input)[:, target]
        output.backward()
        expected = input.grad * input
        self.assertEqual(
            expected.detach().numpy().tolist(),
            attributions.detach().numpy().tolist(),
        )
    else:
        nt = NoiseTunnel(input_x_grad)
        attributions = nt.attribute(
            input, nt_type=nt_type, n_samples=10, stdevs=1.0, target=target
        )
        self.assertEqual(attributions.shape, input.shape)
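
# Standalone sketch (not from the test above): the behaviour the non-vanilla branch
# asserts. Every noise-tunnel variant returns a smoothed attribution with the same
# shape as the input. Uses the older Captum `n_samples` argument (newer releases
# call it `nt_samples`); the classifier is a plain nn.Sequential stand-in for
# SoftmaxModel.
import torch
import torch.nn as nn
from captum.attr import InputXGradient, NoiseTunnel

clf = nn.Sequential(nn.Linear(5, 20), nn.ReLU(), nn.Linear(20, 10))
x = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]], requires_grad=True)

for variant in ("smoothgrad", "smoothgrad_sq", "vargrad"):
    smoothed = NoiseTunnel(InputXGradient(clf)).attribute(
        x, nt_type=variant, n_samples=10, stdevs=1.0, target=5
    )
    assert smoothed.shape == x.shape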

def _saliency_classification_assert(self, nt_type="vanilla"):
    # Small example setup: a 5-feature input row and a fixed target class.
    num_in = 5
    input = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]], requires_grad=True)
    target = torch.tensor(5)
    # 10-class classification model
    model = SoftmaxModel(num_in, 20, 10)
    saliency = Saliency(model)
    if nt_type == "vanilla":
        attributions = saliency.attribute(input, target)
        # Reference value: Saliency is the absolute value of the input gradient.
        output = model(input)[:, target]
        output.backward()
        expected = torch.abs(input.grad)
        self.assertEqual(
            expected.detach().numpy().tolist(),
            attributions.detach().numpy().tolist(),
        )
    else:
        nt = NoiseTunnel(saliency)
        attributions = nt.attribute(
            input, nt_type=nt_type, n_samples=10, stdevs=0.0002, target=target
        )
        self.assertEqual(input.shape, attributions.shape)
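
# Standalone sketch (not from the test above): the |gradient| convention the
# vanilla branch relies on. Saliency takes the absolute value of the input
# gradient by default; passing abs=False returns the signed gradient instead.
# The classifier below is a plain nn.Sequential stand-in for SoftmaxModel.
import torch
import torch.nn as nn
from captum.attr import Saliency

clf = nn.Sequential(nn.Linear(5, 20), nn.ReLU(), nn.Linear(20, 10))
x = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]], requires_grad=True)

signed = Saliency(clf).attribute(x, target=5, abs=False)
absolute = Saliency(clf).attribute(x, target=5)   # abs=True is the default
assert torch.allclose(absolute, signed.abs())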