# Body of the `forward_hook` closure registered below. Note that the cloning
# behaviour of `eval_tsr` differs when `forward_hook_with_return` is set to
# True; otherwise `backward()` on the last output layer won't execute.
if forward_hook_with_return:
    saved_layer[eval_tsr.device] = eval_tsr
    eval_tsr_to_return = eval_tsr.clone()
    return (eval_tsr_to_return,) if is_tuple else eval_tsr_to_return
else:
    saved_layer[eval_tsr.device] = eval_tsr.clone()

if attribute_to_layer_input:
    hook = layer.register_forward_pre_hook(forward_hook)
else:
    hook = layer.register_forward_hook(forward_hook)
output = _run_forward(
    forward_fn,
    inputs,
    target=target_ind,
    additional_forward_args=additional_forward_args,
)
hook.remove()
if len(saved_layer) == 0:
    raise AssertionError("Forward hook did not obtain any outputs for given layer")

if forward_hook_with_return:
    return saved_layer, output
return saved_layer
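
# The fragment above captures a layer's activation by registering a hook and
# then running the model once via `_run_forward`. A minimal, self-contained
# sketch of the same pattern in plain PyTorch (the model, layer choice and
# tensor shapes below are illustrative assumptions, not taken from the
# fragment):
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
layer = model[1]          # capture the ReLU output
saved_layer = {}

def forward_hook(module, inp, out):
    # store a detached clone so later graph operations cannot mutate it
    saved_layer[out.device] = out.clone().detach()

handle = layer.register_forward_hook(forward_hook)
output = model(torch.randn(2, 4))   # the forward pass triggers the hook
handle.remove()

assert len(saved_layer) > 0, "Forward hook did not obtain any outputs"
print({device: t.shape for device, t in saved_layer.items()})
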
transformed_inputs[feature_i] = self._transform(
    feature.input_transforms, transformed_inputs[feature_i], True
)
if feature.baseline_transforms is not None:
    assert baseline_transforms_len == len(
        feature.baseline_transforms
    ), "Must have same number of baselines across all features"
    for baseline_i, baseline_transform in enumerate(
        feature.baseline_transforms
    ):
        baselines[baseline_i][feature_i] = self._transform(
            baseline_transform, transformed_inputs[feature_i], True
        )

outputs = _run_forward(
    net,
    tuple(transformed_inputs),
    additional_forward_args=additional_forward_args,
)

if self.score_func is not None:
    outputs = self.score_func(outputs)

if outputs.nelement() == 1:
    scores = outputs
    predicted = scores.round().to(torch.int)
else:
    scores, predicted = outputs.topk(min(4, outputs.shape[-1]))

scores = scores.cpu().squeeze(0)
predicted = predicted.cpu().squeeze(0)
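
# After `_run_forward` returns the raw model outputs, the fragment above picks
# the top-k scores and predicted class indices. A small sketch of that scoring
# step in isolation (the softmax score function and the class count are
# assumptions made for illustration):
import torch

outputs = torch.randn(1, 10)                      # logits for one example
score_func = torch.nn.Softmax(dim=1)              # assumed score function
outputs = score_func(outputs)

if outputs.nelement() == 1:
    # single-output (e.g. regression / binary score) case
    scores = outputs
    predicted = scores.round().to(torch.int)
else:
    # keep at most 4 classes, fewer if the model has fewer outputs
    scores, predicted = outputs.topk(min(4, outputs.shape[-1]))

scores = scores.cpu().squeeze(0)
predicted = predicted.cpu().squeeze(0)
print(scores, predicted)
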
"Attributions tensor and the end_point must match on the first"
" dimension but found attribution: {} and end_point: {}".format(
attribution.shape[0], end_point_tnsr.shape[0]
)
)
num_samples = end_point[0].shape[0]
_validate_input(end_point, start_point)
_validate_target(num_samples, target)
def _sum_rows(input):
return input.view(input.shape[0], -1).sum(1)
with torch.no_grad():
start_point = _sum_rows(
_run_forward(
self.forward_func, start_point, target, additional_forward_args
)
)
end_point = _sum_rows(
_run_forward(
self.forward_func, end_point, target, additional_forward_args
)
)
row_sums = [_sum_rows(attribution) for attribution in attributions]
attr_sum = torch.stack([sum(row_sum) for row_sum in zip(*row_sums)])
return attr_sum - (end_point - start_point)
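
# The fragment above implements the completeness (convergence delta) check
#     delta = sum(attributions) - (F(end_point) - F(start_point))
# where each example's attributions are summed over all non-batch dimensions.
# A worked sketch with a linear model, where gradient * input attributions
# satisfy completeness exactly against a zero baseline (the model and shapes
# are illustrative):
import torch

weight = torch.randn(1, 5)
forward_func = lambda x: x @ weight.t()           # F(x), shape (batch, 1)

inputs = torch.randn(3, 5, requires_grad=True)
baselines = torch.zeros_like(inputs)

out = forward_func(inputs).sum()
grads, = torch.autograd.grad(out, inputs)
attributions = grads * (inputs - baselines)       # gradient * input

def _sum_rows(t):
    return t.view(t.shape[0], -1).sum(1)

with torch.no_grad():
    start = _sum_rows(forward_func(baselines))
    end = _sum_rows(forward_func(inputs))
    delta = _sum_rows(attributions) - (end - start)

print(delta)   # approximately zero for a linear model
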
for (
    current_inputs,
    current_add_args,
    current_target,
    current_mask,
) in self._ablation_generator(
    i,
    inputs,
    additional_forward_args,
    target,
    baselines,
    feature_mask,
    ablations_per_eval,
    **kwargs
):
    # modified_eval dimensions: 1D tensor with length
    # equal to #num_examples * #features in batch
    modified_eval = _run_forward(
        self.forward_func,
        current_inputs,
        current_target,
        current_add_args,
    )
    # eval_diff dimensions: (#features in batch, #num_examples, 1, ..., 1)
    # (contains 1 more dimension than inputs). The extra dimensions of 1
    # make the tensor broadcastable with the inputs tensor.
    if single_output_mode:
        eval_diff = initial_eval - modified_eval
    else:
        eval_diff = (
            initial_eval - modified_eval.reshape(-1, num_examples)
        ).reshape((-1, num_examples) + (len(inputs[i].shape) - 1) * (1,))
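
# The loop above evaluates the model on ablated inputs produced by
# `_ablation_generator` and measures how the output changes relative to the
# unmodified evaluation. A compact sketch of the underlying idea, ablating one
# input feature at a time against a scalar baseline (the model and feature
# layout are illustrative assumptions):
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
inputs = torch.randn(2, 4)        # 2 examples, 4 features
baseline = 0.0

with torch.no_grad():
    initial_eval = model(inputs).squeeze(-1)          # shape (2,)
    eval_diffs = []
    for feature_i in range(inputs.shape[1]):
        ablated = inputs.clone()
        ablated[:, feature_i] = baseline              # ablate one feature
        modified_eval = model(ablated).squeeze(-1)
        eval_diffs.append(initial_eval - modified_eval)

# shape (#features, #examples): per-feature contribution estimates
attr = torch.stack(eval_diffs)
print(attr)
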
gradient_mask = apply_gradient_requirements(inputs)
_validate_input(inputs, baselines)

# set hooks for baselines
warnings.warn(
    """Setting forward, backward hooks and attributes on non-linear
    activations. The hooks and attributes will be removed
    after the attribution is finished"""
)
self.model.apply(self._register_hooks_ref)
baselines = _tensorize_baseline(inputs, baselines)
_run_forward(
    self.model,
    baselines,
    target=target,
    additional_forward_args=additional_forward_args,
)

# remove forward hook set for baselines
for forward_handles_ref in self.forward_handles_refs:
    forward_handles_ref.remove()

self.model.apply(self._register_hooks)
gradients = self.gradient_func(
    self.model,
    inputs,
    target_ind=target,
    additional_forward_args=additional_forward_args,
)
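
# The fragment above is internal plumbing: register reference hooks, run the
# baselines forward with `_run_forward`, remove the hooks, then compute the
# gradients that drive the attribution. The helper names suggest a DeepLift
# implementation; if this is Captum's DeepLift, the equivalent user-level call
# looks roughly like the following (model, shapes and target are illustrative
# assumptions):
import torch
import torch.nn as nn
from captum.attr import DeepLift

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
model.eval()

inputs = torch.randn(2, 4)
baselines = torch.zeros_like(inputs)

dl = DeepLift(model)
attributions = dl.attribute(inputs, baselines=baselines, target=0)
print(attributions.shape)   # same shape as inputs
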
Args:

    forward_fn: forward function. This can be, for example, a model's
                forward function.
    inputs:     Input at which gradients are evaluated; it will be
                passed to forward_fn.
    target_ind: Index of the target class for which gradients
                must be computed (classification only).
    additional_forward_args: Additional input arguments that the forward
                function requires. It takes an empty tuple (no additional
                arguments) if no additional arguments are required.
"""
with torch.autograd.set_grad_enabled(True):
    # runs the forward pass
    outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
    assert outputs[0].numel() == 1, (
        "Target not provided when necessary, cannot"
        " take gradient with respect to multiple outputs."
    )
    # torch.unbind(outputs) is a tuple of scalar tensors containing
    # batch_size * #steps elements
    grads = torch.autograd.grad(torch.unbind(outputs), inputs)
return grads
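
# A minimal usage sketch of the gradient computation above, selecting the
# target class by hand instead of going through `_run_forward` (the model,
# shapes and target index are illustrative assumptions):
import torch
import torch.nn as nn

model = nn.Linear(4, 3)
inputs = torch.randn(2, 4, requires_grad=True)
target_ind = 1

with torch.autograd.set_grad_enabled(True):
    outputs = model(inputs)[:, target_ind]            # one scalar per example
    grads = torch.autograd.grad(torch.unbind(outputs), inputs)

print(grads[0].shape)   # same shape as inputs
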