Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_batchnorm_module(self):
"""Test that the BatchNorm module correctly sets and updates running stats."""
# NOTE(review): this block looks spliced from two different tests. The lines up
# to `stats = [...]` set up a BatchNorm comparison, but nothing after them reads
# enc_model, plain_model, stats, weight, bias, stats_dimensions or
# encrypted_input. The lines from `sequential = ...` onward exercise
# crypten.nn.Sequential instead and read module_list, input_size and wrap, none
# of which are defined in this function. Restore from upstream before relying
# on this test.
batchnorm_fn_and_size = (
("BatchNorm1d", (500, 10, 3)),
("BatchNorm2d", (600, 7, 4, 20)),
("BatchNorm3d", (800, 5, 4, 8, 15)),
)
for batchnorm_fn, size in batchnorm_fn_and_size:
# NOTE(review): is_trainning is not read anywhere in the visible body.
for is_trainning in (True, False):
tensor = get_random_test_tensor(size=size, is_float=True)
tensor.requires_grad = True
encrypted_input = AutogradCrypTensor(crypten.cryptensor(tensor))
# channel count; the (N, C, +) layout noted below is assumed
C = size[1]
weight = get_random_test_tensor(size=[C], max_value=1, is_float=True)
bias = get_random_test_tensor(size=[C], max_value=1, is_float=True)
weight.requires_grad = True
bias.requires_grad = True
# dimensions for mean and variance
stats_dimensions = list(range(tensor.dim()))
# keep every dimension except C (dim 1) for a tensor of shape (N, C, +)
stats_dimensions.pop(1)
# build encrypted and plaintext models to compare initial running stats
# (the comparison itself is not present in this block)
enc_model = getattr(crypten.nn.module, batchnorm_fn)(C).encrypt()
plain_model = getattr(torch.nn.modules, batchnorm_fn)(C)
stats = ["running_var", "running_mean"]
# NOTE(review): splice point — the code below belongs to a Sequential-container
# test, not to the BatchNorm test above.
sequential = crypten.nn.Sequential(module_list)
sequential.encrypt()
# check container:
self.assertTrue(sequential.encrypted, "nn.Sequential not encrypted")
for module in sequential.modules():
self.assertTrue(module.encrypted, "module not encrypted")
assert sum(1 for _ in sequential.modules()) == len(
module_list
), "nn.Sequential contains incorrect number of modules"
# construct test input and run through sequential container:
input = get_random_test_tensor(size=input_size, is_float=True)
encr_input = crypten.cryptensor(input)
if wrap:
encr_input = AutogradCrypTensor(encr_input)
encr_output = sequential(encr_input)
# compute reference output by applying each contained module in turn:
encr_reference = encr_input
for module in sequential.modules():
encr_reference = module(encr_reference)
reference = encr_reference.get_plain_text()
# compare output to reference:
self._check(encr_output, reference, "nn.Sequential forward failed")
def test_square(self):
    """Verify AutogradCrypTensor.square() against PyTorch, forward and backward.

    The plaintext reference is ``tensor.pow(2)``. For every shape in ``SIZES``,
    both the squared values and the input gradients must match.
    """
    for size in SIZES:
        # plaintext leaf and its encrypted, gradient-tracking twin
        plain_in = get_random_test_tensor(size=size, is_float=True)
        plain_in.requires_grad = True
        encrypted = crypten.cryptensor(plain_in)
        secret_in = AutogradCrypTensor(encrypted, requires_grad=True)

        # forward pass
        plain_sq = plain_in.pow(2)
        secret_sq = secret_in.square()
        self._check(secret_sq, plain_sq, f"square forward failed with size {size}")

        # backward pass, seeded with the same random upstream gradient on both sides
        upstream = get_random_test_tensor(size=plain_sq.shape, is_float=True)
        plain_sq.backward(upstream)
        secret_sq.backward(crypten.cryptensor(upstream))
        backward_msg = f"square backward failed with size {size}"
        self._check(secret_in.grad, plain_in.grad, backward_msg)
# NOTE(review): fragment — the enclosing test's `def` line is not part of this
# chunk, and the code below stops partway through the per-op loop. `tests` and
# `binary_functions` are defined outside the visible lines.
# PyTorch test case: each entry of `tests` is (number_of_inputs, ops).
for test in tests:
# get test case:
number_of_inputs, ops = test
# each input is a random float tensor of shape (12, 5)
inputs = [
get_random_test_tensor(size=(12, 5), is_float=True)
for _ in range(number_of_inputs)
]
encr_inputs = [crypten.cryptensor(input) for input in inputs]
# get autograd variables:
# NOTE: `input` shadows the Python builtin of the same name in this scope.
for input in inputs:
input.requires_grad = True
encr_inputs = [AutogradCrypTensor(encr_input) for encr_input in encr_inputs]
# perform forward pass, logging all intermediate outputs:
outputs, encr_outputs = [inputs], [encr_inputs]
for op in ops:
# get inputs for current operation (the most recently logged outputs):
input, output = outputs[-1], []
encr_input, encr_output = encr_outputs[-1], []
# apply current operation:
if op in binary_functions: # combine outputs via operation
output.append(getattr(input[0], op)(input[1]))
encr_output.append(getattr(encr_input[0], op)(encr_input[1]))
else:
# unary op: apply it to every current input (the encrypted counterpart
# and the appending to `outputs` are cut off from this chunk)
for idx in range(len(input)):
output.append(getattr(input[idx], op)())
"Softmax": (5, 5, 5),
"LogSoftmax": (5, 5, 5),
}
# loop over all modules:
for module_name in module_args.keys():
for wrap in [True, False]:
# generate inputs:
input = get_random_test_tensor(
size=input_sizes[module_name], is_float=True
)
input.requires_grad = True
encr_input = crypten.cryptensor(input)
if wrap:
encr_input = AutogradCrypTensor(encr_input)
# create PyTorch module:
module = getattr(torch.nn, module_name)(*module_args[module_name])
module.train()
# create encrypted CrypTen module:
encr_module = crypten.nn.from_pytorch(module, input)
# check that module properly encrypts / decrypts and
# check that encrypting with current mode properly performs no-op
for encrypted in [False, True, True, False, True]:
encr_module.encrypt(mode=encrypted)
if encrypted:
self.assertTrue(encr_module.encrypted, "module not encrypted")
else:
self.assertFalse(encr_module.encrypted, "module encrypted")
"""Checks forward and backward against PyTorch
Args:
func_name (str): PyTorch/CrypTen function name
input_tensor (torch.tensor): primary input
args (list): contains arguments for function
msg (str): additional message for mismatch
kwargs (list): keyword arguments for function
"""
if msg is None:
msg = f"{func_name} grad_fn incorrect"
input = input_tensor.clone()
input.requires_grad = True
input_encr = AutogradCrypTensor(crypten.cryptensor(input), requires_grad=True)
for private in [False, True]:
input.grad = None
input_encr.grad = None
args = self._set_grad_to_zero(args)
args_encr = self._set_grad_to_zero(list(args), make_private=private)
# obtain torch function
if torch_func_name is not None:
torch_func = self._get_torch_func(torch_func_name)
else:
torch_func = self._get_torch_func(func_name)
reference = torch_func(input, *args, **kwargs)
encrypted_out = getattr(input_encr, func_name)(*args_encr, **kwargs)
def detach(self):
    """Return a new leaf AutogradCrypTensor holding a copy of this tensor's data.

    The underlying tensor is cloned, and the copy is wrapped with
    ``requires_grad=False``, so the result is cut off from the autograd graph.
    """
    copied = self._tensor.clone()
    return AutogradCrypTensor(copied, requires_grad=False)
"""
# convert tuples to lists to allow changes:
convert_to_tuple = False
if isinstance(args, tuple):
args = list(args)
convert_to_tuple = True
# wrap all input tensors in AutogradCrypTensor:
for idx in range(len(args)):
if isinstance(args[idx], (list, tuple)): # input may be list of tensors
args[idx] = _to_autograd(args[idx])
elif isinstance(args[idx], AutogradCrypTensor) or args[idx] is None:
pass
elif isinstance(args[idx], crypten.CrypTensor):
args[idx] = AutogradCrypTensor(args[idx])
else:
raise ValueError(
"Cannot convert type {} to AutogradCrypTensor.".format(type(args[idx]))
)
# return:
if convert_to_tuple:
args = tuple(args)
return args
# NOTE(review): fragment — this chunk starts partway through a
# `self.set_parameter(` call inside an encrypt(mode, src) method. The method
# header, the parameter loop and the definition of `requires_grad` are not here.
name,
AutogradCrypTensor(
# `**{"src": src}` is equivalent to passing src=src
crypten.cryptensor(param, **{"src": src}),
requires_grad=requires_grad,
),
)
else: # decrypt parameter
self.set_parameter(name, param.get_plain_text())
# re-apply the requires_grad flag (captured before this chunk) to the plaintext
self._parameters[name].requires_grad = requires_grad
# encrypt / decrypt buffers (direct buffers only, recurse=False):
for name, buffer in self.named_buffers(recurse=False):
if mode: # encrypt buffer
# unlike parameters, encrypted buffers never require gradients
self.set_buffer(
name,
AutogradCrypTensor(
crypten.cryptensor(buffer, **{"src": src}),
requires_grad=False,
),
)
else: # decrypt buffer
self.set_buffer(name, buffer.get_plain_text())
# apply encryption recursively (presumably _apply invokes the lambda on submodules):
return self._apply(lambda m: m.encrypt(mode=mode, src=src))
# NOTE(review): with the indentation lost it is unclear which branch this
# `return self` belongs to; directly after the return above it would be unreachable.
return self
def stack(tensors, dim=0):
    """Stack a list of tensors along a new dimension ``dim``.

    A single-element list is handled with ``unsqueeze``. If any element is an
    ``AutogradCrypTensor``, stacking is delegated to the first element's
    ``stack`` method; that element is wrapped as a non-grad
    ``AutogradCrypTensor`` when needed. Otherwise the default backend's
    ``stack`` is used. The caller's list is never modified.

    Args:
        tensors (list): tensors to stack.
        dim (int): index of the new dimension.

    Returns:
        The stacked tensor.

    Raises:
        AssertionError: if ``tensors`` is not a list.
    """
    assert isinstance(tensors, list), "input to stack must be a list"
    if len(tensors) == 1:
        return tensors[0].unsqueeze(dim)

    # Deferred import — presumably to avoid an import cycle; confirm.
    from .autograd_cryptensor import AutogradCrypTensor

    if any(isinstance(t, AutogradCrypTensor) for t in tensors):
        # Wrap a local reference instead of assigning into ``tensors[0]``, so the
        # caller's list is not mutated as a side effect of stacking.
        first, *rest = tensors
        if not isinstance(first, AutogradCrypTensor):
            first = AutogradCrypTensor(first, requires_grad=False)
        return first.stack(*rest, dim=dim)
    return get_default_backend().stack(tensors, dim=dim)