Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_with_inputs(func, input):
    """Compare ``MPCTensor(input).<func>()`` with ``input.<func>()``.

    Closure helper: ``self`` (the enclosing test case) is taken from the
    enclosing scope, not passed in.

    Args:
        func: name of a zero-argument method available on both the plaintext
            tensor and ``MPCTensor``.
        input: plaintext tensor that is encrypted and evaluated.
    """
    encrypted_tensor = MPCTensor(input)
    # The reference must come from the same plaintext that was encrypted.
    # This previously read a free variable ``tensor`` from the enclosing scope,
    # so any call with a different ``input`` was checked against the wrong
    # reference.
    reference = getattr(input, func)()
    with self.benchmark(niters=10, func=func) as bench:
        for _ in bench.iters:
            encrypted_out = getattr(encrypted_tensor, func)()
    self._check(encrypted_out, reference, "%s failed" % func)
def test_set(self):
    """Check that ``MPCTensor.set`` replaces the encrypted shares.

    For every size, an encrypted tensor is overwritten once from another
    ``MPCTensor`` and once (starting again from a fresh encryption) from a
    plaintext tensor; both times it must decrypt to the replacement values.
    """
    for size in [(1, 5), (5, 10), (15, 10, 5)]:
        original = get_random_test_tensor(size=size, is_float=True)
        encrypted = MPCTensor(original)
        replacement = get_random_test_tensor(size=size, is_float=True)
        encrypted_replacement = MPCTensor(replacement)

        # Overwrite from an encrypted source.
        encrypted.set(encrypted_replacement)
        self._check(
            encrypted, replacement, f"set with encrypted other failed with size {size}"
        )

        # Overwrite from a plaintext source, starting from a fresh encryption.
        encrypted = MPCTensor(original)
        encrypted.set(replacement)
        self._check(
            encrypted,
            replacement,
            f"set with unencrypted other failed with size {size}",
        )
def test_narrow(self):
    """Compare ``MPCTensor.narrow`` with the plaintext ``narrow``.

    Sweeps every dimension, every start offset except the last two of that
    dimension, and every length that stays in bounds.
    """
    for size in [(5, 6), (5, 6, 7), (6, 7, 8, 9)]:
        plain = get_random_test_tensor(size=size, is_float=True)
        encrypted = MPCTensor(plain)
        for dim, extent in enumerate(size):
            for start in range(extent - 2):
                for length in range(1, extent - start):
                    expected = plain.narrow(dim, start, length)
                    actual = encrypted.narrow(dim, start, length)
                    self._check(
                        actual,
                        expected,
                        "narrow failed along dimension %d" % dim,
                    )
def test_get_set(self):
    """Tests element setting and getting by index.

    ``__getitem__`` is checked for a column slice and a row slice.
    ``__setitem__`` is checked for a column assignment whose right-hand side is
    either a plaintext tensor (identity ``tensor_type``) or an ``MPCTensor``.
    """
    for tensor_type in [lambda x: x, MPCTensor]:
        for size in range(1, 5):
            # Test __getitem__
            tensor = get_random_test_tensor(size=(size, size), is_float=True)
            reference = tensor[:, 0]
            encrypted_tensor = MPCTensor(tensor)
            encrypted_out = encrypted_tensor[:, 0]
            self._check(encrypted_out, reference, "getitem failed")
            reference = tensor[0, :]
            encrypted_out = encrypted_tensor[0, :]
            self._check(encrypted_out, reference, "getitem failed")
            # Test __setitem__
            tensor2 = get_random_test_tensor(size=(size,), is_float=True)
            reference = tensor.clone()
            reference[:, 0] = tensor2
            encrypted_out = MPCTensor(tensor)
            encrypted2 = tensor_type(tensor2)
            encrypted_out[:, 0] = encrypted2
            # Without this check the setitem section above asserted nothing.
            self._check(
                encrypted_out, reference, "%s setitem failed" % type(encrypted2)
            )
# NOTE(review): fragment of a convolution-style test. The enclosing
# function/loop header is not visible in this excerpt, and indentation was
# lost in extraction. `matrix_width`, `kernel_width`, `kernel_type`,
# `func_name` and `padding` must be bound by that missing context.
# sample input:
matrix_size = (5, matrix_width)
matrix = get_random_test_tensor(
size=matrix_size, is_float=True
)
# Add two leading size-1 dims: (5, matrix_width) -> (1, 1, 5, matrix_width);
# presumably the (batch, channel, ...) layout the conv function expects - confirm.
matrix = matrix.unsqueeze(0).unsqueeze(0)
# sample filtering kernel:
kernel_size = (kernel_width, kernel_width)
kernel = get_random_test_tensor(
size=kernel_size, is_float=True
)
# Same two leading size-1 dims: -> (1, 1, kernel_width, kernel_width).
kernel = kernel.unsqueeze(0).unsqueeze(0)
# perform filtering:
encr_matrix = MPCTensor(matrix)
# `kernel_type` wraps the plaintext kernel; presumably the enclosing loop
# varies it between a plain and an encrypted type (assumption - confirm).
encr_kernel = kernel_type(kernel)
with self.benchmark(
kernel_type=kernel_type.__name__,
matrix_width=matrix_width,
) as bench:
for _ in bench.iters:
# Invoke the method named `func_name` on the encrypted matrix.
encr_conv = getattr(encr_matrix, func_name)(
encr_kernel, padding=padding
)
# check that result is correct:
# Reference: the same-named function on `F` (presumably
# torch.nn.functional - confirm), applied to the plaintext inputs.
reference = getattr(F, func_name)(
matrix, kernel, padding=padding
)
self._check(encr_conv, reference, "%s failed" % func_name)
"""Tests padding"""
sizes = [(1,), (5,), (1, 1), (5, 5), (5, 5, 5), (5, 3, 32, 32)]
pads = [
(0, 0, 0, 0),
(1, 0, 0, 0),
(0, 1, 0, 0),
(0, 0, 1, 0),
(0, 0, 0, 1),
(1, 1, 1, 1),
(2, 2, 1, 1),
(2, 2, 2, 2),
]
for size in sizes:
tensor = get_random_test_tensor(size=size, is_float=True)
encrypted_tensor = MPCTensor(tensor)
for pad in pads:
for value in [0, 1, 10]:
for tensor_type in [lambda x: x, MPCTensor]:
if tensor.dim() < 2:
pad = pad[:2]
reference = torch.nn.functional.pad(tensor, pad, value=value)
encrypted_value = tensor_type(value)
with self.benchmark(tensor_type=tensor_type.__name__) as bench:
for _ in bench.iters:
encrypted_out = encrypted_tensor.pad(
pad, value=encrypted_value
)
self._check(encrypted_out, reference, "pad failed")
def test_pow(self):
    """Check ``pow`` and ``pow_`` on encrypted tensors against plaintext.

    Integer powers from -3 to 3 are tried for both variants. The out-of-place
    variant must leave its encrypted input untouched, while the in-place
    variant must leave it equal to the reference.

    NOTE(review): for ``pow_`` the encrypted call sits inside the benchmark
    loop; if ``bench.iters`` yields more than once the power compounds -
    confirm how many iterations the benchmark helper runs.
    """
    powers = [-3, -2, -1, 0, 1, 2, 3]
    for pow_fn in ["pow", "pow_"]:
        in_place = pow_fn.endswith("_")
        for power in powers:
            tensor = get_random_test_tensor(is_float=True)
            encrypted_tensor = MPCTensor(tensor)
            reference = getattr(tensor, pow_fn)(power)
            with self.benchmark(niters=10, func=pow_fn, power=power) as bench:
                for _ in bench.iters:
                    encrypted_out = getattr(encrypted_tensor, pow_fn)(power)
            self._check(encrypted_out, reference, "pow failed with power %s" % power)

            # Verify whether the encrypted input was (or was not) mutated.
            if in_place:
                expected, message = reference, "in-place pow_ does not modify input"
            else:
                expected, message = tensor, "out-of-place pow modifies input"
            self._check(encrypted_tensor, expected, message)
def test_mean(self):
    """Compare ``MPCTensor.mean`` with the plaintext mean.

    Checks the full reduction and the reduction along each of the three
    dimensions of a (5, 10, 15) tensor.
    """
    plain = get_random_test_tensor(size=(5, 10, 15), is_float=True)
    encrypted = MPCTensor(plain)
    self._check(encrypted.mean(), plain.mean(), "mean failed")
    for dim in range(plain.dim()):
        expected = plain.mean(dim)
        actual = encrypted.mean(dim)
        self._check(actual, expected, "mean failed")
def test_bernoulli(self):
    """Tests bernoulli sampling.

    For several sizes, checks that samples keep the shape of the probability
    tensor and only take values 0 or 1. Then draws a large sample at a fixed
    probability and checks that the observed fraction of zeros matches
    ``1 - p`` within a statistically sound tolerance.
    """
    for size in [(10,), (10, 10), (10, 10, 10)]:
        probs = MPCTensor(torch.rand(size))
        with self.benchmark(size=size) as bench:
            for _ in bench.iters:
                randvec = probs.bernoulli()
        self.assertTrue(randvec.size() == size, "Incorrect size")
        tensor = randvec.get_plain_text()
        # Every entry must be exactly 0 or 1 (logical OR of the two masks).
        self.assertTrue(((tensor == 0) | (tensor == 1)).all(), "Invalid values")

    num_samples = int(1e6)
    prob = 0.2
    probs = MPCTensor(torch.full((num_samples,), prob))
    randvec = probs.bernoulli().get_plain_text()
    frac_zero = float((randvec == 0).sum()) / randvec.nelement()
    # The fraction of zeros over n independent draws has standard deviation
    # sqrt(p * (1 - p) / n), about 4e-4 here. The former fixed tolerance of
    # 1e-3 was only ~2.5 sigma and failed spuriously ~1% of the time.
    # A 5-sigma bound keeps the check meaningful without flaking.
    tolerance = 5 * math.sqrt(prob * (1 - prob) / num_samples)
    self.assertTrue(
        math.isclose(frac_zero, 1 - prob, rel_tol=0.0, abs_tol=tolerance),
        f"fraction of zeros {frac_zero} deviates from expected {1 - prob}",
    )