Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_OR(self):
    """Exercise the `|` operator with a BinarySharedTensor on the left.

    The right-hand operand is tried both as the raw tensor (identity
    wrapper) and as a BinarySharedTensor. The last benchmarked result is
    handed to `self._check` together with the plaintext `|` result.
    """
    wrappers = (lambda x: x, BinarySharedTensor)
    for tensor_type in wrappers:
        lhs = get_random_test_tensor(is_float=False)
        rhs = get_random_test_tensor(is_float=False)

        enc_lhs = BinarySharedTensor(lhs)
        enc_rhs = tensor_type(rhs)
        expected = lhs | rhs
        failure_msg = "%s OR failed" % tensor_type

        with self.benchmark(tensor_type=tensor_type.__name__) as bench:
            for _ in bench.iters:
                result = enc_lhs | enc_rhs
        self._check(result, expected, failure_msg)
def test_add(self):
    """Tests add using binary shares.

    The left operand is always a BinarySharedTensor; the right operand is
    either the plain tensor (identity wrapper) or a BinarySharedTensor.
    The result of `+` is passed to `self._check` together with the
    plaintext sum of the two input tensors.
    """
    for tensor_type in [lambda x: x, BinarySharedTensor]:
        tensor = get_random_test_tensor(is_float=False)
        tensor2 = get_random_test_tensor(is_float=False)
        reference = tensor + tensor2
        encrypted_tensor = BinarySharedTensor(tensor)
        encrypted_tensor2 = tensor_type(tensor2)
        with self.benchmark(tensor_type=tensor_type.__name__) as bench:
            for _ in bench.iters:
                encrypted_out = encrypted_tensor + encrypted_tensor2
        # The failure message names the operation under test (addition).
        self._check(encrypted_out, reference, "%s add failed" % tensor_type)
def test_get_set(self):
    """Tests __getitem__ and __setitem__ on BinarySharedTensor.

    For square tensors of side 1..4, a column and a row are read through
    indexing and checked against the same slices of the plaintext tensor.
    Then the first column is overwritten with a second tensor, supplied
    either as a plain tensor (identity wrapper) or as a
    BinarySharedTensor, and the result is checked against a plaintext
    clone that received the same assignment.
    """
    for tensor_type in [lambda x: x, BinarySharedTensor]:
        for size in range(1, 5):
            # Test __getitem__
            tensor = get_random_test_tensor(size=(size, size), is_float=False)
            reference = tensor[:, 0]
            encrypted_tensor = BinarySharedTensor(tensor)
            encrypted_out = encrypted_tensor[:, 0]
            self._check(encrypted_out, reference, "getitem failed")

            reference = tensor[0, :]
            encrypted_out = encrypted_tensor[0, :]
            self._check(encrypted_out, reference, "getitem failed")

            # Test __setitem__
            tensor2 = get_random_test_tensor(size=(size,), is_float=False)
            reference = tensor.clone()
            reference[:, 0] = tensor2
            encrypted_out = BinarySharedTensor(tensor)
            encrypted2 = tensor_type(tensor2)
            encrypted_out[:, 0] = encrypted2
            # Without this check the assignment above is never verified.
            self._check(
                encrypted_out, reference, "%s setitem failed" % tensor_type
            )
def test_ptype(self):
    """Check that `ptype` selects the matching underlying shared-tensor class.

    Each ptype is passed to `crypten.cryptensor`, and the `_tensor`
    attribute of the result must be an instance of the paired class.
    """
    expected_by_ptype = (
        (crypten.arithmetic, ArithmeticSharedTensor),
        (crypten.binary, BinarySharedTensor),
    )
    for ptype, expected_cls in expected_by_ptype:
        plain = get_random_test_tensor(is_float=False)
        wrapped = crypten.cryptensor(plain, ptype=ptype)
        assert isinstance(wrapped._tensor, expected_cls), "ptype test failed"
def test_XOR(self):
    """Exercise the `^` operator with a BinarySharedTensor on the left.

    The right-hand operand is tried both as the raw tensor (identity
    wrapper) and as a BinarySharedTensor. The last benchmarked result is
    handed to `self._check` together with the plaintext `^` result.
    """
    wrappers = (lambda x: x, BinarySharedTensor)
    for tensor_type in wrappers:
        lhs = get_random_test_tensor(is_float=False)
        rhs = get_random_test_tensor(is_float=False)

        enc_lhs = BinarySharedTensor(lhs)
        enc_rhs = tensor_type(rhs)
        expected = lhs ^ rhs
        failure_msg = "%s XOR failed" % tensor_type

        with self.benchmark(tensor_type=tensor_type.__name__) as bench:
            for _ in bench.iters:
                result = enc_lhs ^ enc_rhs
        self._check(result, expected, failure_msg)
def test_src_match_input_data(self):
    """Constructing a BinarySharedTensor with a bad `src` must fail.

    Each candidate `src` (None, a string, a negative integer, and
    `self.world_size`) is expected to raise AssertionError.
    """
    plain = get_random_test_tensor(is_float=True)
    plain.src = 0

    invalid_sources = (None, "abc", -2, self.world_size)
    for bad_src in invalid_sources:
        with self.assertRaises(AssertionError):
            BinarySharedTensor(plain, src=bad_src)
def test_control_flow_failure(self):
    """Using an encrypted tensor as a branch condition must raise RuntimeError.

    Three forms are covered: a plain `if`, a conditional expression, and
    an `elif` reached after a false `if`.
    """
    secret = BinarySharedTensor(get_random_test_tensor(is_float=False))

    # Plain `if` statement.
    with self.assertRaises(RuntimeError):
        if secret:
            pass

    # Conditional (ternary) expression.
    with self.assertRaises(RuntimeError):
        _ = 5 if secret else 0

    # `elif` branch evaluated after the leading `if` is false.
    with self.assertRaises(RuntimeError):
        if False:
            pass
        elif secret:
            pass
def test_AND(self):
    """Exercise the `&` operator with a BinarySharedTensor on the left.

    The right-hand operand is tried both as the raw tensor (identity
    wrapper) and as a BinarySharedTensor. The last benchmarked result is
    handed to `self._check` together with the plaintext `&` result.
    """
    wrappers = (lambda x: x, BinarySharedTensor)
    for tensor_type in wrappers:
        lhs = get_random_test_tensor(is_float=False)
        rhs = get_random_test_tensor(is_float=False)

        enc_lhs = BinarySharedTensor(lhs)
        enc_rhs = tensor_type(rhs)
        expected = lhs & rhs
        failure_msg = "%s AND failed" % tensor_type

        with self.benchmark(tensor_type=tensor_type.__name__) as bench:
            for _ in bench.iters:
                result = enc_lhs & enc_rhs
        self._check(result, expected, failure_msg)
def test_sum(self):
    """Exercise `sum` on a 5x5x5 BinarySharedTensor.

    The full reduction is checked first, then the reduction along each of
    the three dimensions is benchmarked and checked against the plaintext
    `sum` along the same dimension.
    """
    plain = get_random_test_tensor(size=(5, 5, 5), is_float=False)
    shared = BinarySharedTensor(plain)

    total_shared = shared.sum()
    total_plain = plain.sum()
    self._check(total_shared, total_plain, "sum failed")

    for axis in range(3):
        expected = plain.sum(axis)
        with self.benchmark(type="sum", dim=axis) as bench:
            for _ in bench.iters:
                result = shared.sum(axis)
        self._check(result, expected, "sum failed")
def generate_binary_triple(size):
    """Build the three BinarySharedTensors (a, b, c) of a binary triple.

    `a` and `b` are drawn with `generate_kbit_random_tensor` from the
    TTPClient generator. On rank 0, `c` is obtained from the TTP through
    `ttp_request("binary", size)`; on every other rank it is drawn from
    the same generator. Each value is then wrapped with
    `BinarySharedTensor.from_shares`.

    NOTE(review): presumably the TTP response makes the shares of `c`
    consistent with `a` and `b`; that relation is not visible here.
    """
    rng = TTPClient.get().generator

    # Draw order matters: a, then b, then (on non-zero ranks) c.
    share_a = generate_kbit_random_tensor(size, generator=rng)
    share_b = generate_kbit_random_tensor(size, generator=rng)
    share_c = (
        TTPClient.get().ttp_request("binary", size)  # rank 0 asks the TTP
        if comm.get().get_rank() == 0
        else generate_kbit_random_tensor(size, generator=rng)
    )

    # Wrap the raw shares, preserving a, b, c order.
    return tuple(
        BinarySharedTensor.from_shares(share)
        for share in (share_a, share_b, share_c)
    )