else:
fpe = FixedPointEncoder(precision_bits=0)
tensor = get_test_tensor(float=float)
decoded = fpe.decode(fpe.encode(tensor))
self._check(
decoded,
tensor,
"Encoding/decoding a %s failed." % "float" if float else "long",
)
# Make sure encoding a subclass of CrypTensor is a no-op
crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
crypten.init()
tensor = get_test_tensor(float=True)
encrypted_tensor = crypten.cryptensor(tensor)
encrypted_tensor = fpe.encode(encrypted_tensor)
self._check(
encrypted_tensor.get_plain_text(),
tensor,
"Encoding an EncryptedTensor failed.",
)
# Try a few other types.
fpe = FixedPointEncoder(precision_bits=0)
for dtype in [torch.uint8, torch.int8, torch.int16]:
tensor = torch.zeros(5, dtype=dtype).random_()
decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
self._check(decoded, tensor, "Encoding/decoding a %s failed." % dtype)
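# Hedged, self-contained sketch of the fixed-point round trip exercised
# above. It assumes FixedPointEncoder is importable from crypten.encoder
# and that precision_bits=16 (CrypTen's default) scales values by 2 ** 16.
import torch
from crypten.encoder import FixedPointEncoder

fpe = FixedPointEncoder(precision_bits=16)
x = torch.tensor([1.25, -0.5, 3.0])
encoded = fpe.encode(x)        # long tensor: x scaled by 2 ** 16
decoded = fpe.decode(encoded)  # float tensor: encoded divided by 2 ** 16
assert torch.allclose(decoded, x, atol=2 ** -15)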
"""Tests the square autograd function; the plaintext reference uses
pow(2) since PyTorch does not implement square().
"""
for size in SIZES:
tensor = get_random_test_tensor(size=size, is_float=True)
tensor.requires_grad = True
tensor_encr = AutogradCrypTensor(
crypten.cryptensor(tensor), requires_grad=True
)
out = tensor.pow(2)
out_encr = tensor_encr.square()
self._check(out_encr, out, f"square forward failed with size {size}")
grad_output = get_random_test_tensor(size=out.shape, is_float=True)
out.backward(grad_output)
out_encr.backward(crypten.cryptensor(grad_output))
self._check(
tensor_encr.grad,
tensor.grad,
f"square backward failed with size {size}",
)
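# Plain-PyTorch sanity check of the identity the test relies on: for
# y = x ** 2, the gradient is dy/dx = 2 * x, so backward with grad_output
# accumulates 2 * x * grad_output into x.grad.
import torch

x = torch.randn(4, requires_grad=True)
y = x.pow(2)
grad_output = torch.ones_like(y)
y.backward(grad_output)
assert torch.allclose(x.grad, 2 * x.detach() * grad_output)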
module_list = [
crypten.nn.Linear(num_feat, num_feat - 1) for num_feat in layer_idx
]
sequential = crypten.nn.Sequential(module_list)
sequential.encrypt()
# check container:
self.assertTrue(sequential.encrypted, "nn.Sequential not encrypted")
for module in sequential.modules():
self.assertTrue(module.encrypted, "module not encrypted")
assert sum(1 for _ in sequential.modules()) == len(
module_list
), "nn.Sequential contains incorrect number of modules"
# construct test input and run through sequential container:
input = get_random_test_tensor(size=input_size, is_float=True)
encr_input = crypten.cryptensor(input)
if wrap:
encr_input = AutogradCrypTensor(encr_input)
encr_output = sequential(encr_input)
# compute reference output:
encr_reference = encr_input
for module in sequential.modules():
encr_reference = module(encr_reference)
reference = encr_reference.get_plain_text()
# compare output to reference:
self._check(encr_output, reference, "nn.Sequential forward failed")
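# Hedged, self-contained sketch of the container pattern tested above,
# mirroring the list-based Sequential constructor used there; the layer
# widths and batch size are illustrative only.
import torch
import crypten
import crypten.nn as cnn

crypten.init()
layers = [cnn.Linear(8, 4), cnn.ReLU(), cnn.Linear(4, 2)]
model = cnn.Sequential(layers)  # runs modules in list order
model.encrypt()                 # secret-shares every parameter

x_enc = crypten.cryptensor(torch.randn(3, 8))
y_enc = model(x_enc)                 # forward pass stays encrypted
print(y_enc.get_plain_text().shape)  # torch.Size([3, 2])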
encr_param = getattr(encr_node, key)
break
# or get it from the crypten Module directly:
else:
encr_param = getattr(encr_module, key)
# compare with reference:
# NOTE: Because some parameters are initialized randomly
# with different values on each process, we only want to
# check that they are consistent with source parameter value
reference = getattr(module, key)
src_reference = comm.get().broadcast(reference, src=0)
msg = "parameter %s in %s incorrect" % (key, module_name)
if not encrypted:
encr_param = crypten.cryptensor(encr_param)
self._check(encr_param, src_reference, msg)
# compare model outputs:
self.assertTrue(encr_module.training, "training value incorrect")
reference = module(input)
encr_output = encr_module(encr_input)
self._check(encr_output, reference, "%s forward failed" % module_name)
# test backward pass:
reference.backward(torch.ones(reference.size()))
encr_output.backward()
if wrap: # you cannot get input gradients on MPCTensor inputs
self._check(
encr_input.grad,
input.grad,
"%s backward on input failed" % module_name,
encr_module.encrypt()
self.assertTrue(encr_module.encrypted, "module not encrypted")
# generate inputs:
inputs, encr_inputs = None, None
ex_zero_values = module_name in ex_zero_modules
if module_name in binary_modules:
inputs = [
get_random_test_tensor(
size=input_sizes[module_name],
is_float=True,
ex_zero=ex_zero_values,
)
for _ in range(2)
]
encr_inputs = [crypten.cryptensor(input) for input in inputs]
elif module_name not in no_input_modules:
inputs = get_random_test_tensor(
size=input_sizes[module_name], is_float=True, ex_zero=ex_zero_values
)
encr_inputs = crypten.cryptensor(inputs)
# some modules take additional indices as input:
if module_name in additional_inputs:
if not isinstance(inputs, (list, tuple)):
inputs, encr_inputs = [inputs], [encr_inputs]
inputs.append(additional_inputs[module_name])
encr_inputs.append(crypten.cryptensor(inputs[-1]))
# compare model outputs:
reference = module_lambdas[module_name](inputs)
encr_output = encr_module(encr_inputs)
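# Self-contained sketch of the plaintext-vs-encrypted comparison above,
# using ReLU as a stand-in for the module under test; any module with a
# plaintext counterpart follows the same pattern.
import torch
import crypten
import crypten.nn as cnn

crypten.init()
x = torch.randn(4, 5)
reference = torch.nn.ReLU()(x)  # plaintext forward
encr_module = cnn.ReLU()
encr_module.encrypt()
encr_output = encr_module(crypten.cryptensor(x))
assert torch.allclose(encr_output.get_plain_text(), reference, atol=1e-3)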
input = AutogradCrypTensor(inputs[0])
reference = [True, True, False]
for func_name in ["min", "max"]:
outputs = [None] * 3
outputs[0] = getattr(input, func_name)()
outputs[1], outputs[2] = getattr(input, func_name)(dim=0)
for idx, output in enumerate(outputs):
self.assertEqual(
output.requires_grad,
reference[idx],
"value of requires_grad is incorrect",
)
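# Plain-PyTorch analogue of the pattern checked above: a full reduction
# returns one differentiable value; a dim-wise reduction returns a
# (values, indices) pair, and integer indices never require grad.
import torch

x = torch.randn(3, 4, requires_grad=True)
full = x.max()
values, indices = x.max(dim=0)
assert full.requires_grad and values.requires_grad
assert not indices.requires_grad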
# check the behavior of max_pool2d in which indices are returned:
input = get_random_test_tensor(size=(1, 3, 8, 8), is_float=True)
input = AutogradCrypTensor(crypten.cryptensor(input))
reference = [True, True, False]
outputs = [None] * 3
outputs[0] = input.max_pool2d(2, return_indices=False)
outputs[1], outputs[2] = input.max_pool2d(2, return_indices=True)
for idx, output in enumerate(outputs):
self.assertEqual(
output.requires_grad,
reference[idx],
"value of requires_grad is incorrect",
)
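# Same pattern in plain PyTorch for max_pool2d: with return_indices=True
# the pooled values are differentiable while the index tensor is not.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, requires_grad=True)
out, idx = F.max_pool2d(x, 2, return_indices=True)
assert out.requires_grad and not idx.requires_grad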
def encrypt(self, mode=True, src=0):
"""Encrypts the model."""
if mode != self.encrypted:
# encrypt / decrypt parameters:
self.encrypted = mode
for name, param in self.named_parameters(recurse=False):
requires_grad = param.requires_grad
if mode: # encrypt parameter
self.set_parameter(
name,
AutogradCrypTensor(
crypten.cryptensor(param, src=src),
requires_grad=requires_grad,
),
)
else: # decrypt parameter
self.set_parameter(name, param.get_plain_text())
self._parameters[name].requires_grad = requires_grad
# encrypt / decrypt buffers:
for name, buffer in self.named_buffers(recurse=False):
if mode: # encrypt buffer
self.set_buffer(
name,
AutogradCrypTensor(
crypten.cryptensor(buffer, src=src),
requires_grad=False,
),
)
else:  # decrypt buffer
self.set_buffer(name, buffer.get_plain_text())
# apply encryption recursively:
return self._apply(lambda m: m.encrypt(mode=mode, src=src))
return self
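# Usage sketch for the toggle above (assumes crypten.init() has been
# called); encrypt(mode=False) round-trips parameters back to plaintext.
import crypten
import crypten.nn as cnn

crypten.init()
model = cnn.Linear(4, 2)
model.encrypt()  # parameters become encrypted AutogradCrypTensors
assert model.encrypted
model.encrypt(mode=False)  # parameters restored to plaintext
assert not model.encrypted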
# get rank of the current party
rank = comm.get().get_rank()
# get world size
world_size = comm.get().get_world_size()
if world_size > 1:
# party 1 gets the actual tensor; remaining parties get dummy tensor
src_id = 1
else:
# party 0 gets the actual tensor since world size is 1
src_id = 0
if rank == src_id:
input_upd = input
else:
input_upd = torch.empty(input.size())
private_input = crypten.cryptensor(input_upd, src=src_id)
return private_input
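# The same logic wrapped as a reusable helper; the name
# `make_private_input` is hypothetical and not part of the original code.
import torch
import crypten
import crypten.communicator as comm

def make_private_input(input):
    rank = comm.get().get_rank()
    world_size = comm.get().get_world_size()
    # party 1 holds the real data when multiple parties exist
    src_id = 1 if world_size > 1 else 0
    if rank != src_id:
        input = torch.empty(input.size())  # dummy values on other parties
    return crypten.cryptensor(input, src=src_id)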