Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""Test convolution of encrypted tensor with public/private tensors."""
for kernel_type in [lambda x: x, ArithmeticSharedTensor]:
for matrix_width in range(2, 5):
for kernel_width in range(1, matrix_width):
for padding in range(kernel_width // 2 + 1):
matrix_size = (5, matrix_width)
matrix = get_random_test_tensor(size=matrix_size, is_float=True)
kernel_size = (kernel_width, kernel_width)
kernel = get_random_test_tensor(size=kernel_size, is_float=True)
matrix = matrix.unsqueeze(0).unsqueeze(0)
kernel = kernel.unsqueeze(0).unsqueeze(0)
reference = F.conv2d(matrix, kernel, padding=padding)
encrypted_matrix = ArithmeticSharedTensor(matrix)
encrypted_kernel = kernel_type(kernel)
with self.benchmark(
kernel_type=kernel_type.__name__, matrix_width=matrix_width
) as bench:
for _ in bench.iters:
encrypted_conv = encrypted_matrix.conv2d(
encrypted_kernel, padding=padding
)
self._check(encrypted_conv, reference, "conv2d failed")
def test_inplace(self):
"""Test inplace vs. out-of-place functions"""
for op in ["add", "sub", "mul", "div"]:
for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
tensor1 = get_random_test_tensor(is_float=True)
tensor2 = get_random_test_tensor(is_float=True)
# ArithmeticSharedTensors can't divide by negative
# private values - MPCTensor overrides this to allow negatives
if op == "div" and tensor_type == ArithmeticSharedTensor:
continue
reference = getattr(torch, op)(tensor1, tensor2)
encrypted1 = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)
input_plain_id = id(encrypted1.share)
input_encrypted_id = id(encrypted1)
# Test that out-of-place functions do not modify the input
private = isinstance(encrypted2, ArithmeticSharedTensor)
encrypted_out = getattr(encrypted1, op)(encrypted2)
self._check(
encrypted1,
tensor1,
"%s out-of-place %s modifies input"
% ("private" if private else "public", op),
)
self._check(
encrypted_out,
def test_arithmetic(self):
    """Benchmark and verify element-wise arithmetic on ArithmeticSharedTensor.

    For each function name in ``arithmetic_functions``:

    - Each operation is benchmarked over ``self.sizes`` / ``self.float_operands``.
      Both a public (plain) right-hand operand and a private
      (``ArithmeticSharedTensor``) right-hand operand are used.
    - A vector combined with an encrypted one-element tensor is then checked
      against the plaintext result.

    Afterwards, ``square()`` and the reflected operators (``2 + x``, ``2 - x``,
    ``2 * x``) are checked once each.
    """
    arithmetic_functions = ["add", "add_", "sub", "sub_", "mul", "mul_"]
    for func in arithmetic_functions:
        for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
            for size, (a, b) in zip(self.sizes, self.float_operands):
                # A fresh operand pair per benchmark iteration. This is presumably
                # because the trailing-underscore variants modify the left operand
                # in place. TODO confirm.
                encrypted_as = [
                    ArithmeticSharedTensor(a) for _ in range(self.benchmark_iters)
                ]
                encrypted_bs = [tensor_type(b) for _ in range(self.benchmark_iters)]
                data = list(zip(encrypted_as, encrypted_bs))
                with self.benchmark(
                    data=data, func=func, float=True, size=size
                ) as bench:
                    for encrypted_a, encrypted_b in bench.data:
                        encrypted_out = getattr(encrypted_a, func)(encrypted_b)
                self.assertTrue(encrypted_out is not None)

        # Vector op with an encrypted size-(1,) tensor. tensor1 must be assigned
        # here; the original block read it without ever defining it.
        tensor1 = get_random_test_tensor(is_float=True)
        tensor2 = get_random_test_tensor(is_float=True, size=(1,))
        encrypted1 = ArithmeticSharedTensor(tensor1)
        encrypted2 = ArithmeticSharedTensor(tensor2)
        reference = getattr(tensor1, func)(tensor2)
        encrypted_out = getattr(encrypted1, func)(encrypted2)
        self._check(encrypted_out, reference, "private %s failed" % func)

    # Squaring does not depend on func, so check it once.
    tensor = get_random_test_tensor(is_float=True)
    reference = tensor * tensor
    encrypted = ArithmeticSharedTensor(tensor)
    encrypted_out = encrypted.square()
    self._check(encrypted_out, reference, "square failed")

    # Test radd, rsub, and rmul. Use a fresh operand rather than one leaked
    # from (and possibly mutated by) the in-place loop above.
    tensor1 = get_random_test_tensor(is_float=True)
    reference = 2 + tensor1
    encrypted = ArithmeticSharedTensor(tensor1)
    encrypted_out = 2 + encrypted
    self._check(encrypted_out, reference, "right add failed")
    reference = 2 - tensor1
    encrypted_out = 2 - encrypted
    self._check(encrypted_out, reference, "right sub failed")
    reference = 2 * tensor1
    encrypted_out = 2 * encrypted
    self._check(encrypted_out, reference, "right mul failed")
def test_sum(self):
    """Compare ArithmeticSharedTensor.sum() with the plaintext sum.

    The full reduction is checked once. Each of the reductions along dims
    0, 1 and 2 is timed inside ``self.benchmark``. The value from the final
    timed call is then compared with the plaintext reference.
    """
    plain = get_random_test_tensor(size=(5, 100, 100), is_float=True)
    shared = ArithmeticSharedTensor(plain)

    # Reduction over every element.
    self._check(shared.sum(), plain.sum(), "sum failed")

    # Reductions along a single dimension.
    for axis in (0, 1, 2):
        expected = plain.sum(axis)
        with self.benchmark(type="sum", dim=axis) as timer:
            for _ in timer.iters:
                result = shared.sum(axis)
        self._check(result, expected, "sum failed")
def test_scatter(self):
"""Test scatter/scatter_add function of encrypted tensor"""
funcs = ["scatter", "scatter_", "scatter_add", "scatter_add_"]
sizes = [(5, 5), (5, 5, 5), (5, 5, 5, 5)]
for func in funcs:
for size in sizes:
for tensor_type in [lambda x: x, ArithmeticSharedTensor]:
for dim in range(len(size)):
tensor1 = get_random_test_tensor(size=size, is_float=True)
tensor2 = get_random_test_tensor(size=size, is_float=True)
index = get_random_test_tensor(size=size, is_float=False)
index = index.abs().clamp(0, 4)
encrypted = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)
reference = getattr(tensor1, func)(dim, index, tensor2)
encrypted_out = getattr(encrypted, func)(dim, index, encrypted2)
private = tensor_type == ArithmeticSharedTensor
self._check(
encrypted_out,
reference,
"%s %s failed" % ("private" if private else "public", func),
)
if func.endswith("_"):
# Check in-place scatter/scatter-add worked
self._check(
encrypted,
reference,
"%s %s failed"
% ("private" if private else "public", func),
def test_take(self):
    """Check ArithmeticSharedTensor.take against plaintext references.

    Two call forms are exercised:

    - With an explicit dimension, the result is compared with ``numpy.take``
      along that axis.
    - With the default dimension, the result is compared with the plaintext
      tensor's own ``take``.
    """
    # Explicit dimension: gather along each of the four axes of a 5x5x5x5 tensor.
    gather_idx = torch.tensor(
        [[[1, 2], [3, 4]], [[4, 2], [1, 3]]], dtype=torch.long
    )
    source = get_random_test_tensor(size=[5, 5, 5, 5], is_float=True)
    for axis in range(4):
        expected = torch.from_numpy(source.numpy().take(gather_idx, axis))
        result = ArithmeticSharedTensor(source).take(gather_idx, axis)
        self._check(result, expected, "take function failed: dimension set")

    # Default dimension (None). Every index below is < 15, the smallest numel used.
    flat_positions = ([0], [10], [0, 5, 10])
    for shape in [(15,), (5, 10), (15, 10, 5)]:
        source = get_random_test_tensor(size=shape, is_float=True)
        shared = ArithmeticSharedTensor(source)
        for positions in flat_positions:
            pos = torch.tensor(positions)
            self._check(
                shared.take(pos),
                source.take(pos),
                f"take failed with indices {pos}",
            )
encrypted_out,
reference,
"%s %s broadcast failed"
% ("private" if private else "public", func),
)
for size in matmul_sizes:
for batch1, batch2 in itertools.combinations(batch_dims, 2):
size1 = (*batch1, *size)
size2 = (*batch2, *size)
tensor1 = get_random_test_tensor(size=size1, is_float=True)
tensor2 = get_random_test_tensor(size=size2, is_float=True)
tensor2 = tensor1.transpose(-2, -1)
encrypted1 = ArithmeticSharedTensor(tensor1)
encrypted2 = tensor_type(tensor2)
reference = tensor1.matmul(tensor2)
encrypted_out = encrypted1.matmul(encrypted2)
private = isinstance(encrypted2, ArithmeticSharedTensor)
self._check(
encrypted_out,
reference,
"%s matmul broadcast failed"
% ("private" if private else "public"),
)
reference = tensor[:, 0]
encrypted_tensor = ArithmeticSharedTensor(tensor)
encrypted_out = encrypted_tensor[:, 0]
self._check(encrypted_out, reference, "getitem failed")
reference = tensor[0, :]
encrypted_out = encrypted_tensor[0, :]
self._check(encrypted_out, reference, "getitem failed")
# Test __setitem__
tensor2 = get_random_test_tensor(size=(size,), is_float=True)
reference = tensor.clone()
reference[:, 0] = tensor2
encrypted_out = ArithmeticSharedTensor(tensor)
encrypted2 = tensor_type(tensor2)
encrypted_out[:, 0] = encrypted2
self._check(
encrypted_out, reference, "%s setitem failed" % type(encrypted2)
)
reference = tensor.clone()
reference[0, :] = tensor2
encrypted_out = ArithmeticSharedTensor(tensor)
encrypted2 = tensor_type(tensor2)
encrypted_out[0, :] = encrypted2
self._check(
encrypted_out, reference, "%s setitem failed" % type(encrypted2)
def to_tensor(self):
    """Return the secret-shared tensor class that corresponds to this member.

    A ``value`` of 0 maps to ``ArithmeticSharedTensor`` and 1 maps to
    ``BinarySharedTensor``.

    Raises:
        ValueError: if ``value`` is neither 0 nor 1.
    """
    if self.value not in (0, 1):
        raise ValueError("Cannot convert %s to encrypted tensor" % (self.name))
    return ArithmeticSharedTensor if self.value == 0 else BinarySharedTensor