if isinstance(encr_module, crypten.nn.Graph):
for encr_node in encr_module.modules():
if hasattr(encr_node, key):
encr_param = getattr(encr_node, key)
break
# or get it from the crypten Module directly:
else:
encr_param = getattr(encr_module, key)
# compare with reference:
# NOTE: Because some parameters are initialized randomly
# with different values on each process, we only want to
# check that they are consistent with source parameter value
reference = getattr(module, key)
src_reference = comm.get().broadcast(reference, src=0)
msg = "parameter %s in %s incorrect" % (key, module_name)
if not encrypted:
encr_param = crypten.cryptensor(encr_param)
self._check(encr_param, src_reference, msg)
# compare model outputs:
self.assertTrue(encr_module.training, "training value incorrect")
reference = module(input)
encr_output = encr_module(encr_input)
self._check(encr_output, reference, "%s forward failed" % module_name)
# test backward pass:
reference.backward(torch.ones(reference.size()))
encr_output.backward()
if wrap: # you cannot get input gradients on MPCTensor inputs
self._check(
    encr_input.grad, input.grad, "%s backward on input failed" % module_name
)
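# A minimal sketch of the broadcast-then-compare pattern used in the test
# above, assuming crypten.init() has already set up the communicator. Every
# party replaces its locally initialized value with rank 0's copy, giving one
# canonical value that all parties can be checked against.
local_value = torch.randn(3)  # initialized differently on each process
canonical = comm.get().broadcast(local_value, src=0)  # identical on all parties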
# NOTE: The chosen seed can be any number, but we choose a random 64-bit
# integer here so other parties cannot guess its value.
# We sometimes get here from a forked process, which causes all parties
# to have the same RNG state. Reset the seed to make sure the RNG streams
# differ across parties. We use numpy's random here since seeding it with
# None produces different seeds even in forked processes.
import numpy
numpy.random.seed(seed=None)
next_seed = torch.tensor(numpy.random.randint(-2 ** 63, 2 ** 63 - 1, (1,)))
prev_seed = torch.LongTensor([0]) # placeholder
# Send random seed to next party, receive random seed from prev party
world_size = comm.get().get_world_size()
rank = comm.get().get_rank()
if world_size >= 2: # Otherwise sending seeds will segfault.
next_rank = (rank + 1) % world_size
prev_rank = (rank - 1) % world_size
req0 = comm.get().isend(tensor=next_seed, dst=next_rank)
req1 = comm.get().irecv(tensor=prev_seed, src=prev_rank)
req0.wait()
req1.wait()
else:
prev_seed = next_seed
# Seed Generators
comm.get().g0.manual_seed(next_seed.item())
comm.get().g1.manual_seed(prev_seed.item())
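# A minimal sketch of what the paired generators enable: party i seeds g0 with
# the seed it sent forward and g1 with the seed it received, so every seed is
# held by exactly two adjacent parties. If each party samples
# sample(g0) - sample(g1), the samples cancel pairwise around the ring and the
# shares sum to zero, i.e. a pseudo-random zero sharing (PRZS).
# przs_share is a hypothetical helper, not part of the snippet above.
def przs_share(size, g0, g1):
    sample0 = torch.randint(-(2 ** 63), 2 ** 63 - 1, size, generator=g0)
    sample1 = torch.randint(-(2 ** 63), 2 ** 63 - 1, size, generator=g1)
    return sample0 - sample1  # shares sum to zero over all parties
# e.g. przs_share((3,), comm.get().g0, comm.get().g1)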
def save_checkpoint(
    state, is_best, filename="checkpoint.pth.tar", model_best="model_best.pth.tar"
):
    """Save a training checkpoint; if it is the best so far, copy it to model_best."""
# TODO: use crypten.save() in future.
rank = comm.get().get_rank()
# only save for process rank = 0
if rank == 0:
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, model_best)
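# A usage sketch with hypothetical values; since only rank 0 writes the files,
# the call is safe to make from every process.
state = {"epoch": 10, "state_dict": {"weight": torch.zeros(3, 3)}}
save_checkpoint(state, is_best=True)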
def reveal(self):
"""Get plaintext without any downscaling"""
shares = comm.get().all_gather(self.share)
result = shares[0]
for x in shares[1:]:
result = result ^ x
return result
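# A minimal plaintext sketch of the XOR reconstruction performed by reveal():
# with shares s0 and s1 of x, choosing s1 = s0 ^ x makes each share uniformly
# random on its own, while s0 ^ s1 recovers x exactly.
x = torch.tensor([0b1011])
s0 = torch.randint(0, 2 ** 16, (1,))
s1 = s0 ^ x
assert torch.equal(s0 ^ s1, x)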
def generate_additive_triple(size0, size1, op, *args, **kwargs):
"""Generate multiplicative triples of given sizes"""
generator = TTPClient.get().generator
a = generate_random_ring_element(size0, generator=generator)
b = generate_random_ring_element(size1, generator=generator)
if comm.get().get_rank() == 0:
# Request c from TTP
c = TTPClient.get().ttp_request(
"additive", size0, size1, op, *args, **kwargs
)
else:
# TODO: Compute size without executing computation
c_size = getattr(torch, op)(a, b, *args, **kwargs).size()
c = generate_random_ring_element(c_size, generator=generator)
a = ArithmeticSharedTensor.from_shares(a, precision=0)
b = ArithmeticSharedTensor.from_shares(b, precision=0)
c = ArithmeticSharedTensor.from_shares(c, precision=0)
return a, b, c
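# A plaintext sketch (no secret sharing) of how such a triple is consumed for
# multiplication: because a and b are uniformly random, epsilon and delta can
# be opened without leaking x or y, and the identity below yields x * y.
a, b = 17, 42
c = a * b
x, y = 5, 7                    # secret inputs
epsilon, delta = x - a, y - b  # the values the parties open
assert c + epsilon * b + delta * a + epsilon * delta == x * y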
def wraps(x):
    """Privately computes the number of wraparounds for a set of shares.

    To do so, we note that:
        [theta_x] = theta_z + [beta_xr] - [theta_r] - [eta_xr]

    where [theta_i] is the wraps for a variable i,
          [beta_ij] is the differential wraps for variables i and j, and
          [eta_ij] is the plaintext wraps for variables i and j.

    Note: Since [eta_xr] = 0 with probability 1 - |x| / Q for modulus Q, we
    can assume [eta_xr] = 0 with high probability.
    """
provider = crypten.mpc.get_default_provider()
r, theta_r = provider.wrap_rng(x.size())
beta_xr = theta_r.clone()
beta_xr._tensor = count_wraps([x._tensor, r._tensor])
z = x + r
theta_z = comm.get().gather(z._tensor, 0)
theta_x = beta_xr - theta_r
# TODO: Incorporate eta_xr
if x.rank == 0:
theta_z = count_wraps(theta_z)
theta_x._tensor += theta_z
return theta_x
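# A minimal sketch of what count_wraps computes for a pair of shares: the
# signed number of times their int64 sum overflows or underflows the ring.
# count_wraps_pair is a hypothetical stand-in for the real count_wraps helper,
# restricted to two shares for clarity.
def count_wraps_pair(s0, s1):
    total = s0 + s1  # int64 addition wraps around silently
    overflow = ((s0 > 0) & (s1 > 0) & (total < 0)).long()
    underflow = ((s0 < 0) & (s1 < 0) & (total > 0)).long()
    return overflow - underflow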