Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def PRZS(*size):
    """
    Generate a Pseudo-random Sharing of Zero (using arithmetic shares)

    This function does so by generating `n` numbers across `n` parties with
    each number being held by exactly 2 parties. One of these parties adds
    this number while the other subtracts this number.

    Args:
        *size: the shape, either as separate ints (``PRZS(2, 3)``) or as a
            single ``torch.Size`` / tuple / list (``PRZS(x.size())``).

    Returns:
        An ArithmeticSharedTensor whose local share is one ring element drawn
        from ``comm.get().g0`` minus one ring element drawn from
        ``comm.get().g1``.
    """
    # Pass exactly one shape object to generate_random_ring_element, as every
    # other call site in this module does. Unpacking ``*size`` directly would
    # hand the 2nd, 3rd, ... dimensions to that helper as extra positional
    # arguments (NOTE(review): presumably its ring_size parameter - confirm).
    # torch.Size subclasses tuple, so a single torch.Size is passed through
    # unchanged, exactly as before.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        shape = size[0]
    else:
        shape = size

    tensor = ArithmeticSharedTensor(src=SENTINEL)
    current_share = generate_random_ring_element(shape, generator=comm.get().g0)
    next_share = generate_random_ring_element(shape, generator=comm.get().g1)
    tensor.share = current_share - next_share
    return tensor
def rand(*sizes, encoder=None):
    """Generate random ArithmeticSharedTensor uniform on [0, 1]

    Rank 0 obtains its share from the TTP with a ``"rand"`` request. Every
    other rank draws its share locally with the TTP client's generator.

    Args:
        *sizes: the shape, either as separate ints (``rand(2, 3)``) or as a
            single ``torch.Size`` / tuple / list (``rand(x.size())``).
            Calling with no arguments behaves like ``rand(torch.Size([]))``.
        encoder: forwarded to the TTP request on rank 0 only.

    Returns:
        ArithmeticSharedTensor wrapping this party's share.
    """
    generator = TTPClient.get().generator

    # ``*sizes`` is always a plain tuple, so only its first element can be a
    # shape object. torch.Size subclasses tuple, so this one check covers
    # torch.Size, tuple and list. Unpacking all three keeps rank 0 (which
    # forwards ``*sizes`` to the TTP) and the other ranks (which pass ``sizes``
    # to generate_random_ring_element) working from the same flat shape.
    # The ``sizes and`` guard avoids an IndexError when called with no sizes.
    if sizes and isinstance(sizes[0], (tuple, list)):
        sizes = tuple(sizes[0])

    if comm.get().get_rank() == 0:
        # Request samples from TTP
        samples = TTPClient.get().ttp_request("rand", *sizes, encoder=encoder)
    else:
        samples = generate_random_ring_element(sizes, generator=generator)
    return ArithmeticSharedTensor.from_shares(samples)
def share(secret, num_parties=2):
    """Create an arithmetic (additive) sharing from a secret

    For ``num_parties >= 2`` the result is a tuple of ``num_parties`` tensors.
    The first ``num_parties - 1`` are random masks from
    ``generate_random_ring_element`` (no explicit generator). The last is
    whatever remains of ``secret`` after subtracting them, so adding every
    piece reconstructs ``secret``.

    For ``num_parties < 2`` the secret itself is returned, not wrapped in a
    tuple.
    """
    # For single process, do not encrypt (for debugging purposes)
    if num_parties < 2:
        return secret

    mask = generate_random_ring_element(secret.size())
    remainder = secret - mask
    if num_parties == 2:
        tail = (remainder,)
    else:
        # Split the remainder further among the remaining parties.
        tail = share(remainder, num_parties=num_parties - 1)
    return (mask, *tail)
def generate_additive_triple(size0, size1, op, *args, **kwargs):
    """Generate multiplicative triples of given sizes

    Every party draws its share of ``a`` (shape ``size0``) and then of ``b``
    (shape ``size1``) from the TTP client's generator.

    For ``c``:
    - Rank 0 receives its share from the TTP through an ``"additive"``
      request that carries ``size0``, ``size1``, ``op`` and the extra
      arguments.
    - Every other rank draws a random share whose shape matches
      ``torch.<op>(a, b, *args, **kwargs)``.

    NOTE(review): presumably the TTP arranges the shared ``c`` to equal
    ``op(a, b)`` on the plaintexts - confirm against the TTP server.

    Returns:
        Tuple ``(a, b, c)`` of ArithmeticSharedTensors built with precision 0.
    """
    generator = TTPClient.get().generator
    a_share = generate_random_ring_element(size0, generator=generator)
    b_share = generate_random_ring_element(size1, generator=generator)

    requests_from_ttp = comm.get().get_rank() == 0
    if not requests_from_ttp:
        # TODO: Compute size without executing computation
        op_fn = getattr(torch, op)
        product_size = op_fn(a_share, b_share, *args, **kwargs).size()
        c_share = generate_random_ring_element(product_size, generator=generator)
    else:
        # Request c from TTP
        c_share = TTPClient.get().ttp_request(
            "additive", size0, size1, op, *args, **kwargs
        )

    return tuple(
        ArithmeticSharedTensor.from_shares(piece, precision=0)
        for piece in (a_share, b_share, c_share)
    )
def PRZS(*size):
    """
    Generate a Pseudo-random Sharing of Zero (using arithmetic shares)

    This function does so by generating `n` numbers across `n` parties with
    each number being held by exactly 2 parties. One of these parties adds
    this number while the other subtracts this number.

    Args:
        *size: the shape, either as separate ints (``PRZS(2, 3)``) or as a
            single ``torch.Size`` / tuple / list (``PRZS(x.size())``).

    Returns:
        An ArithmeticSharedTensor whose local share is one ring element drawn
        from ``comm.get().g0`` minus one ring element drawn from
        ``comm.get().g1``.
    """
    # Pass exactly one shape object to generate_random_ring_element, as every
    # other call site in this module does. Unpacking ``*size`` directly would
    # hand the 2nd, 3rd, ... dimensions to that helper as extra positional
    # arguments (NOTE(review): presumably its ring_size parameter - confirm).
    # torch.Size subclasses tuple, so a single torch.Size is passed through
    # unchanged, exactly as before.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        shape = size[0]
    else:
        shape = size

    tensor = ArithmeticSharedTensor(src=SENTINEL)
    current_share = generate_random_ring_element(shape, generator=comm.get().g0)
    next_share = generate_random_ring_element(shape, generator=comm.get().g1)
    tensor.share = current_share - next_share
    return tensor
def generate_additive_triple(size0, size1, op, *args, **kwargs):
    """Generate multiplicative triples of given sizes

    ``a`` (shape ``size0``) and ``b`` (shape ``size1``) are drawn, in that
    order, from the TTP client's generator on every rank.

    The share of ``c`` is obtained differently by rank:
    - Rank 0 asks the TTP for it with an ``"additive"`` request.
    - Other ranks draw it locally, sized like
      ``torch.<op>(a, b, *args, **kwargs)``.

    Returns:
        ``(a, b, c)`` as ArithmeticSharedTensors with precision 0.
    """
    gen = TTPClient.get().generator
    # Generator expression keeps the draw order: a first, then b.
    a, b = (generate_random_ring_element(shape, generator=gen) for shape in (size0, size1))

    if comm.get().get_rank() != 0:
        # TODO: Compute size without executing computation
        c_size = getattr(torch, op)(a, b, *args, **kwargs).size()
        c = generate_random_ring_element(c_size, generator=gen)
    else:
        # Request c from TTP
        c = TTPClient.get().ttp_request("additive", size0, size1, op, *args, **kwargs)

    a, b, c = (ArithmeticSharedTensor.from_shares(t, precision=0) for t in (a, b, c))
    return a, b, c
def _get_additive_PRSS(self, size, remove_rank=False):
    """
    Return the elementwise sum of one random ring element per generator.

    One ring element of shape ``size`` is drawn from each generator in
    ``self.generators``, skipping ``self.generators[0]`` when
    ``remove_rank`` is true. The draws are stacked and summed over the
    stacking dimension. The result is the plaintext that the per-generator
    additive pieces add up to.

    NOTE(review): generators[0] is presumably this party's own generator -
    confirm where ``self.generators`` is populated.

    Raises:
        RuntimeError: from ``torch.stack`` when no generators remain.
    """
    generators = self.generators
    if remove_rank:
        generators = generators[1:]
    draws = [generate_random_ring_element(size, generator=gen) for gen in generators]
    return torch.stack(draws).sum(0)
def wrap_rng(size):
    """Generate random shared tensor of given size and sharing of its wraps

    Every rank draws its share of ``r`` (shape ``size``) from the TTP
    client's generator.

    The share of ``theta_r`` is obtained differently by rank:
    - Rank 0 requests it from the TTP with a ``"wraps"`` request carrying
      ``size``.
    - Every other rank draws it from the same generator, immediately after
      ``r``.

    Returns:
        ``(r, theta_r)`` as ArithmeticSharedTensors with precision 0.
    """
    generator = TTPClient.get().generator
    r_share = generate_random_ring_element(size, generator=generator)

    if comm.get().get_rank() != 0:
        theta_share = generate_random_ring_element(size, generator=generator)
    else:
        # Request theta_r from TTP
        theta_share = TTPClient.get().ttp_request("wraps", size)

    return (
        ArithmeticSharedTensor.from_shares(r_share, precision=0),
        ArithmeticSharedTensor.from_shares(theta_share, precision=0),
    )
def PRZS(*size):
    """
    Generate a Pseudo-random Sharing of Zero (using arithmetic shares)

    This function does so by generating `n` numbers across `n` parties with
    each number being held by exactly 2 parties. One of these parties adds
    this number while the other subtracts this number.

    Args:
        *size: the shape, either as separate ints (``PRZS(2, 3)``) or as a
            single ``torch.Size`` / tuple / list (``PRZS(x.size())``).

    Returns:
        An ArithmeticSharedTensor whose local share is one ring element drawn
        from ``comm.get().g0`` minus one ring element drawn from
        ``comm.get().g1``.
    """
    # Pass exactly one shape object to generate_random_ring_element, as every
    # other call site in this module does. Unpacking ``*size`` directly would
    # hand the 2nd, 3rd, ... dimensions to that helper as extra positional
    # arguments (NOTE(review): presumably its ring_size parameter - confirm).
    # torch.Size subclasses tuple, so a single torch.Size is passed through
    # unchanged, exactly as before.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        shape = size[0]
    else:
        shape = size

    tensor = ArithmeticSharedTensor(src=SENTINEL)
    current_share = generate_random_ring_element(shape, generator=comm.get().g0)
    next_share = generate_random_ring_element(shape, generator=comm.get().g1)
    tensor.share = current_share - next_share
    return tensor