Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
grad_kernel.size(1) // batch_size,
grad_kernel.size(2),
grad_kernel.size(3),
)
grad_kernel = (
grad_kernel.sum(dim=0)
.view(in_channels, out_channels, grad_kernel.size(2), grad_kernel.size(3))
.transpose(0, 1)
)
grad_kernel = grad_kernel.narrow(2, 0, kernel_size_y)
grad_kernel = grad_kernel.narrow(3, 0, kernel_size_x)
return (grad_input, grad_kernel)
@register_function("batchnorm")
class AutogradBatchNorm(AutogradFunction):
@staticmethod
def forward(
ctx,
input,
running_mean=None,
running_var=None,
training=False,
eps=1e-05,
momentum=0.1,
):
"""
Computes forward step of batch norm by normalizing x
and returning weight * x_norm + bias.
Running mean and var are computed over the `C` dimension for an input
of size `(N, C, +)`.
@register_function("cumsum")
class AutogradCumsum(AutogradFunction):
    """Forward and backward rules for the function registered as "cumsum"."""

    @staticmethod
    def forward(ctx, input):
        """
        Unpack ``input`` as ``(tensor, dim)``, store ``dim`` on ``ctx`` and
        return the cumulative sum of ``tensor`` along ``dim``.
        """
        tensor, axis = input
        ctx.save_for_backward(axis)
        result = tensor.cumsum(axis)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the reversed cumulative sum of ``grad_output``: flip along the
        saved dimension, accumulate along it, then flip back.
        """
        (axis,) = ctx.saved_tensors
        flipped = grad_output.flip(axis)
        accumulated = flipped.cumsum(axis)
        return accumulated.flip(axis)
@register_function("trace")
class AutogradTrace(AutogradFunction):
    """Forward and backward rules for the function registered as "trace"."""

    @staticmethod
    def forward(ctx, input):
        """
        Store the first entry of ``input.size()`` on ``ctx`` and return
        ``input.trace()``.
        """
        leading_size = input.size()[0]
        ctx.save_for_backward(leading_size)
        return input.trace()

    @staticmethod
    def backward(ctx, grad_output):
        """
        Build an identity matrix of the saved size, wrap it with
        ``grad_output.new`` and multiply it in place by ``grad_output``.
        """
        (leading_size,) = ctx.saved_tensors
        identity = torch.eye(leading_size)
        grad_input = grad_output.new(identity)
        # The in-place multiply only touches the freshly created tensor.
        return grad_input.mul_(grad_output)
@register_function("mean")
class AutogradMean(AutogradFunction):
@staticmethod
def forward(ctx, input, dim=None, keepdim=False):
ctx.save_multiple_for_backward((input.size(), dim, keepdim))
@register_function("sin")
class AutogradSin(AutogradFunction):
    """Forward and backward rules for the function registered as "sin"."""

    @staticmethod
    def forward(ctx, input):
        """
        Call ``input.cossin()`` once; keep element 0 (used as the cosine in
        ``backward``) on ``ctx`` and return element 1.
        """
        cos_and_sin = input.cossin()
        cosine = cos_and_sin[0]
        ctx.save_for_backward(cosine)
        sine = cos_and_sin[1]
        return sine

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` multiplied by the cosine saved in forward."""
        (cosine,) = ctx.saved_tensors
        grad_input = grad_output.mul(cosine)
        return grad_input
@register_function("cos")
class AutogradCos(AutogradFunction):
    """Forward and backward rules for the function registered as "cos"."""

    @staticmethod
    def forward(ctx, input):
        """
        Call ``input.cossin()`` once; keep element 1 (used as the sine in
        ``backward``) on ``ctx`` and return element 0.
        """
        cossin = input.cossin()
        ctx.save_for_backward(cossin[1])
        return cossin[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` multiplied by the negated saved sine."""
        (sin,) = ctx.saved_tensors
        # Negate out of place: `neg_()` would overwrite the tensor stored on
        # ctx, so a second backward pass over the same context would see the
        # sign flipped and return the wrong gradient.
        return grad_output.mul(sin.neg())
@register_function("abs")
class AutogradAbs(AutogradFunction):
@staticmethod
def forward(ctx, input):
@staticmethod
def backward(ctx, grad_output):
reciprocal, other = ctx.saved_tensors
grad_input = reciprocal.square().mul(other).mul(grad_output).neg()
grad_input = _inverse_broadcast(grad_input, reciprocal.size())
if torch.is_tensor(other) or crypten.is_encrypted_tensor(other):
grad_other = reciprocal.mul(grad_output)
grad_other = _inverse_broadcast(grad_other, other.size())
return (grad_input, grad_other)
else:
return grad_input
@register_function("pow")
class AutogradPow(AutogradFunction):
    """Forward and backward rules for the function registered as "pow"."""

    @staticmethod
    def forward(ctx, input):
        """
        Save the whole ``input`` pair on ``ctx``, then unpack it as
        ``(base, power)`` and return ``base.pow(power)``.
        """
        ctx.save_for_backward(input)
        base, exponent = input
        return base.pow(exponent)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return ``base.pow(power - 1.0)`` multiplied in place by ``power`` and
        then by ``grad_output``.
        """
        (saved_pair,) = ctx.saved_tensors
        base, exponent = saved_pair
        # `pow` yields a new tensor, so the in-place multiplies below do not
        # modify the saved base.
        grad_input = base.pow(exponent - 1.0)
        return grad_input.mul_(exponent).mul_(grad_output)
@register_function("pos_pow")
class AutogradPosPow(AutogradFunction):
@staticmethod
@register_function("dot")
class AutogradDot(AutogradFunction):
    """Forward and backward rules for the function registered as "dot"."""

    @staticmethod
    def forward(ctx, input):
        """
        Save ``input`` on ``ctx`` and return the dot product of its first
        element with its second element.
        """
        ctx.save_for_backward(input)
        left = input[0]
        right = input[1]
        return left.dot(right)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return a pair of gradients: ``grad_output`` times the second operand,
        and ``grad_output`` times the first operand.
        """
        (operands,) = ctx.saved_tensors
        left, right = operands
        grad_left = grad_output.mul(right)
        grad_right = grad_output.mul(left)
        return (grad_left, grad_right)
@register_function("ger")
class AutogradGer(AutogradFunction):
    """Forward and backward rules for the function registered as "ger"."""

    @staticmethod
    def forward(ctx, input):
        """
        Save ``input`` on ``ctx`` and return ``ger`` of its first element with
        its second element.
        """
        ctx.save_for_backward(input)
        left = input[0]
        right = input[1]
        return left.ger(right)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return a pair of gradients: ``grad_output.matmul(second operand)`` and
        ``(first operand).matmul(grad_output)``.
        """
        (operands,) = ctx.saved_tensors
        left, right = operands
        grad_left = grad_output.matmul(right)
        grad_right = left.matmul(grad_output)
        return (grad_left, grad_right)
@register_function("sin")
class AutogradSin(AutogradFunction):
@staticmethod
def forward(ctx, input):
grad_flat = grad.flatten()
flat_index = index.flatten()
grad_output_flat = grad_output.flatten()
grad_flat[flat_index] = grad_output_flat
grad = grad_flat.reshape(size)
else:
flat_index = index.flatten()
grad_output_flat = grad_output.flatten(
start_dim=dimension, end_dim=(dimension + index.dim() - 1)
)
grad.index_add_(dimension, flat_index, grad_output_flat)
return grad
@register_function("gather")
class AutogradGather(AutogradFunction):
    """Forward and backward rules for the function registered as "gather"."""

    @staticmethod
    def forward(ctx, input):
        """
        Unpack ``input`` as ``(tensor, dim, index)``, save the tensor's size
        together with ``dim`` and ``index``, and return
        ``tensor.gather(dim, index)``.
        """
        tensor, axis, positions = input
        saved_state = [tensor.size(), axis, positions]
        ctx.save_multiple_for_backward(saved_state)
        return tensor.gather(axis, positions)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Create zeros of the saved input size via ``grad_output.new`` and
        scatter-add ``grad_output`` into them at the saved index along the
        saved dimension.
        """
        input_size, axis, positions = ctx.saved_tensors
        zeros = torch.zeros(input_size)
        grad_input = grad_output.new(zeros)
        return grad_input.scatter_add_(axis, positions, grad_output)
@register_function("scatter")
class AutogradScatter(AutogradFunction):
@staticmethod
def forward(ctx, input):
for i, shift in enumerate(shifts):
shifts[i] = -shift
shifts.reverse()
else:
shifts = -shifts
# Reverse dims
if isinstance(dims, (tuple, list)):
dims = list(dims)
dims.reverse()
return grad_output.roll(shifts, dims)
@register_function("squeeze")
class AutogradSqueeze(AutogradFunction):
@staticmethod
def forward(ctx, input):
# preprocess inputs:
dim = None
if isinstance(input, (tuple, list)) and len(input) == 1:
input, = input # no dimension to squeeze specified
elif isinstance(input, (tuple, list)):
input, dim = input # dimension to squeeze specified
# perform the actual squeeze:
output = input.squeeze() if dim is None else input.squeeze(dim)
# keep correct dimensions for backward pass:
if dim is None:
dims = [idx for idx, sz in enumerate(output.size()) if sz == 1]
@register_function("gather")
class AutogradGather(AutogradFunction):
    """Forward and backward rules for the function registered as "gather"."""

    @staticmethod
    def forward(ctx, input):
        """
        Treat ``input`` as the triple ``(tensor, dim, index)``; record the
        tensor size, ``dim`` and ``index`` on ``ctx`` and return the gathered
        values.
        """
        source, gather_dim, gather_index = input
        ctx.save_multiple_for_backward(
            [source.size(), gather_dim, gather_index]
        )
        gathered = source.gather(gather_dim, gather_index)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        """
        Scatter-add ``grad_output`` into a zero tensor of the saved input size
        (wrapped with ``grad_output.new``) using the saved dimension and index.
        """
        source_size, gather_dim, gather_index = ctx.saved_tensors
        accumulator = grad_output.new(torch.zeros(source_size))
        return accumulator.scatter_add_(gather_dim, gather_index, grad_output)
@register_function("scatter")
class AutogradScatter(AutogradFunction):
    """Forward and backward rules for the function registered as "scatter"."""

    @staticmethod
    def forward(ctx, input):
        """
        Unpack ``input`` as ``(tensor, dim, index, src)``, compute
        ``tensor.scatter(dim, index, src)``, save ``dim`` and ``index`` on
        ``ctx`` and return the scattered result.
        """
        tensor, axis, positions, source = input
        scattered = tensor.scatter(axis, positions, source)
        ctx.save_multiple_for_backward([axis, positions])
        return scattered

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return ``(input_grad, src_grad)``.

        ``input_grad`` is ``grad_output`` multiplied by a mask that is zero at
        the scattered positions and one elsewhere; ``src_grad`` is
        ``grad_output`` gathered at the saved index.
        """
        axis, positions = ctx.saved_tensors
        shape = grad_output.size()
        # Ones everywhere, with zeros written at the scattered positions.
        keep_mask = torch.ones(shape).scatter(axis, positions, torch.zeros(shape))
        keep_mask = keep_mask.long()
        grad_for_input = grad_output.mul(keep_mask)
        grad_for_src = grad_output.gather(axis, positions)
        return (grad_for_input, grad_for_src)
class AutogradSum(AutogradFunction):
    """Forward and backward rules for ``sum``."""

    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False):
        """
        Save ``(input.size(), dim, keepdim)`` on ``ctx`` and return the sum of
        ``input``: over everything when ``dim`` is None, otherwise along
        ``dim`` with the given ``keepdim``.
        """
        saved_state = (input.size(), dim, keepdim)
        ctx.save_multiple_for_backward(saved_state)
        if dim is None:
            return input.sum()
        return input.sum(dim=dim, keepdim=keepdim)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Expand ``grad_output`` to the saved input size by multiplying it with a
        tensor of ones, re-inserting the reduced dimension first when it was
        dropped in forward.
        """
        input_size, dim, keepdim = ctx.saved_tensors
        reduced_dim_dropped = dim is not None and not keepdim
        if reduced_dim_dropped:
            grad_output = grad_output.unsqueeze(dim)
        ones = torch.ones(input_size)
        return grad_output.mul(ones)
@register_function("cumsum")
class AutogradCumsum(AutogradFunction):
    """Forward and backward rules for the function registered as "cumsum"."""

    @staticmethod
    def forward(ctx, input):
        """
        Treat ``input`` as the pair ``(tensor, dim)``; record ``dim`` on
        ``ctx`` and return ``tensor.cumsum(dim)``.
        """
        values, cumsum_dim = input
        ctx.save_for_backward(cumsum_dim)
        return values.cumsum(cumsum_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Flip ``grad_output`` along the saved dimension, take its cumulative
        sum along that dimension, and flip the result back.
        """
        (cumsum_dim,) = ctx.saved_tensors
        reversed_grad = grad_output.flip(cumsum_dim)
        reversed_cumsum = reversed_grad.cumsum(cumsum_dim)
        grad_input = reversed_cumsum.flip(cumsum_dim)
        return grad_input
@register_function("trace")
class AutogradTrace(AutogradFunction):
@staticmethod
def forward(ctx, input):
@register_function("stack")
class AutogradStack(AutogradFunction):
    """Forward and backward rules for the function registered as "stack"."""

    @staticmethod
    def forward(ctx, input, dim=0):
        """
        Save ``dim`` on ``ctx`` and return ``crypten.stack(input, dim=dim)``.
        """
        ctx.save_for_backward(dim)
        stacked = crypten.stack(input, dim=dim)
        return stacked

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unbound along the dimension saved in forward."""
        (stack_dim,) = ctx.saved_tensors
        pieces = grad_output.unbind(dim=stack_dim)
        return pieces
@register_function("view")
class AutogradView(AutogradFunction):
    """Forward and backward rules for the function registered as "view"."""

    @staticmethod
    def forward(ctx, input):
        """
        Split ``input`` into a tensor followed by the target size arguments,
        save the tensor on ``ctx`` and return ``tensor.view(*size)``.
        """
        tensor, *target_size = input
        ctx.save_for_backward(tensor)
        reshaped = tensor.view(*target_size)
        return reshaped

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` viewed with the size of the saved tensor."""
        (tensor,) = ctx.saved_tensors
        original_size = tensor.size()
        return grad_output.view(original_size)
@register_function("reshape")
class AutogradReshape(AutogradFunction):
@staticmethod
def forward(ctx, input):