Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""Tests tensor encoding and decoding."""
for float in [False, True]:
if float:
fpe = FixedPointEncoder(precision_bits=16)
else:
fpe = FixedPointEncoder(precision_bits=0)
tensor = get_test_tensor(float=float)
decoded = fpe.decode(fpe.encode(tensor))
self._check(
decoded,
tensor,
"Encoding/decoding a %s failed." % "float" if float else "long",
)
# Make sure encoding a subclass of CrypTensor is a no-op
crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
crypten.init()
tensor = get_test_tensor(float=True)
encrypted_tensor = crypten.cryptensor(tensor)
encrypted_tensor = fpe.encode(encrypted_tensor)
self._check(
encrypted_tensor.get_plain_text(),
tensor,
"Encoding an EncryptedTensor failed.",
)
# Try a few other types.
fpe = FixedPointEncoder(precision_bits=0)
for dtype in [torch.uint8, torch.int8, torch.int16]:
tensor = torch.zeros(5, dtype=dtype).random_()
decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
def test_in_first(self):
    """Check world sizes when CrypTen is initialised in the parent first.

    Per the inline comments, the parent inits with world size 1, then
    ``test_rank_func`` forks two children that must init with world size 2,
    and the parent's communicator must still report world size 1 afterwards.

    The global default MPC provider is switched to TrustedFirstParty for the
    duration of the test. It is restored afterwards so later tests are not
    affected by this test's global change.
    """
    # TODO: Make this work with TTP provider
    original_provider = crypten.mpc.get_default_provider()
    crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
    try:
        # This will cause the parent process to init with world-size 1
        crypten.init()
        self.assertEqual(comm.get().get_world_size(), 1)

        # This will fork 2 children which will have to init with world-size 2
        self.assertEqual(test_rank_func(), [0, 1])

        # Make sure everything is the same in the parent
        self.assertEqual(comm.get().get_world_size(), 1)
    finally:
        # Undo the global provider change even if an assertion failed.
        crypten.mpc.set_default_provider(original_provider)
def setUp(self):
    """Make TrustedFirstParty the default MPC provider for this test.

    The provider that was active beforehand is stored in
    ``self._original_provider`` (this class's tearDown restores from it).
    """
    mpc = crypten.mpc
    previous_provider = mpc.get_default_provider()
    self._original_provider = previous_provider
    # The provider is switched before the parent setUp runs, as in the
    # original ordering -- presumably the parent setup relies on it.
    mpc.set_default_provider(mpc.provider.TrustedFirstParty)
    super(TestTFP, self).setUp()
def setUp(self):
    """Make TrustedThirdParty the default MPC provider for this test.

    The provider that was active beforehand is stored in
    ``self._original_provider`` so it can be restored later.
    """
    mpc = crypten.mpc
    previous_provider = mpc.get_default_provider()
    self._original_provider = previous_provider
    # The provider is switched before the parent setUp runs, as in the
    # original ordering -- presumably the parent setup relies on it.
    mpc.set_default_provider(mpc.provider.TrustedThirdParty)
    super(TestTTP, self).setUp()
def tearDown(self):
    """Put back the MPC provider recorded in setUp, then run parent teardown."""
    saved_provider = self._original_provider
    crypten.mpc.set_default_provider(saved_provider)
    super(TestTFP, self).tearDown()
def setUp(self):
    """Make TrustedFirstParty the default MPC provider for this test.

    The provider that was active beforehand is stored in
    ``self._original_provider`` so it can be restored later.
    """
    mpc = crypten.mpc
    previous_provider = mpc.get_default_provider()
    self._original_provider = previous_provider
    # The provider is switched before the parent setUp runs, as in the
    # original ordering -- presumably the parent setup relies on it.
    mpc.set_default_provider(mpc.provider.TrustedFirstParty)
    super(TestTFP, self).setUp()
def setUp(self):
    """Make TrustedThirdParty the default MPC provider for this test.

    The provider that was active beforehand is stored in
    ``self._original_provider`` so it can be restored later.
    """
    mpc = crypten.mpc
    previous_provider = mpc.get_default_provider()
    self._original_provider = previous_provider
    # The provider is switched before the parent setUp runs, as in the
    # original ordering -- presumably the parent setup relies on it.
    mpc.set_default_provider(mpc.provider.TrustedThirdParty)
    super(TestTTP, self).setUp()
def tearDown(self):
    """Put back the MPC provider recorded in setUp, then run parent teardown."""
    saved_provider = self._original_provider
    crypten.mpc.set_default_provider(saved_provider)
    super(TestTTP, self).tearDown()
def setUp(self):
    """Make TrustedThirdParty the default MPC provider for this test.

    The provider that was active beforehand is stored in
    ``self._original_provider`` so it can be restored later.
    """
    mpc = crypten.mpc
    previous_provider = mpc.get_default_provider()
    self._original_provider = previous_provider
    # The provider is switched before the parent setUp runs, as in the
    # original ordering -- presumably the parent setup relies on it.
    mpc.set_default_provider(mpc.provider.TrustedThirdParty)
    super(TestTTP, self).setUp()