Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def PRZS(*size):
    """
    Generate a Pseudo-random Sharing of Zero for binary (XOR) shares.

    Two random k-bit tensors of shape `size` are drawn, one from each of the
    communicator's generators `g0` and `g1`, and XORed together to form this
    party's share.

    NOTE(review): the shares only XOR to zero across parties if every
    generator stream is held by exactly two parties (this party's `g1`
    matching a neighbour's `g0`) -- that seeding happens outside this
    function; confirm against the communicator setup.

    Args:
        *size: dimensions of the share to generate, forwarded unchanged to
            `generate_kbit_random_tensor`.

    Returns:
        A `BinarySharedTensor` whose `share` is the XOR of the two draws.
        No encoder is set on it here.
    """
    generators = comm.get()
    from_current = generate_kbit_random_tensor(*size, generator=generators.g0)
    from_next = generate_kbit_random_tensor(*size, generator=comm.get().g1)

    # SENTINEL makes the constructor return an empty shell to fill in below.
    zero_sharing = BinarySharedTensor(src=SENTINEL)
    zero_sharing.share = from_current ^ from_next
    return zero_sharing
def __init__(self, tensor=None, size=None, src=0):
    """
    Build a binary secret-shared tensor from data held by party `src`.

    Every party starts from a pseudo-random sharing of zero; the party whose
    rank equals `src` then XORs its encoded data into its own share.

    Args:
        tensor: the data to share. Must not be None on the source party.
        size: shape of the share; overwritten by the encoded tensor's size
            whenever `tensor` is given.
        src: rank of the party providing the data, or SENTINEL to skip
            initialization entirely (no `share` or `encoder` is set).

    Raises:
        AssertionError: if `src` is not a valid rank, if the source party
            passes no tensor, or if `tensor.src` exists and differs from
            `src`.
    """
    # SENTINEL asks for an uninitialized instance; callers fill in the fields.
    if src == SENTINEL:
        return

    src_is_valid_rank = (
        isinstance(src, int) and 0 <= src < comm.get().get_world_size()
    )
    assert src_is_valid_rank, "invalid tensor source"

    # Default to 0 bits of precision; a different encoder may be assigned
    # by the caller after construction.
    self.encoder = FixedPointEncoder(precision_bits=0)
    if tensor is not None:
        tensor = self.encoder.encode(tensor)
        size = tensor.size()

    # Start from a pseudo-random sharing of zero of the requested shape.
    self.share = BinarySharedTensor.PRZS(size).share

    # Only the source party folds the actual data into its share.
    if self.rank != src:
        return

    assert tensor is not None, "Source must provide a data tensor"
    if hasattr(tensor, "src"):
        assert (
            tensor.src == src
        ), "Source of data tensor must match source of encryption"
    self.share ^= tensor
def PRZS(*size):
    """
    Generate a Pseudo-random Sharing of Zero for binary (XOR) shares.

    This party's share is the XOR of two random k-bit tensors of shape
    `size`: one drawn from the communicator's `g0` generator and one from
    its `g1` generator.

    NOTE(review): for the shares of all parties to XOR to zero, each
    generator stream must be shared by exactly two parties; that pairing is
    established outside this function -- confirm against the communicator.

    Args:
        *size: dimensions of the share, passed straight through to
            `generate_kbit_random_tensor`.

    Returns:
        A `BinarySharedTensor` holding the XOR of the two draws as its
        `share`. No encoder is assigned here.
    """
    # Draw order (g0 first, then g1) is kept so generator state advances
    # exactly as before.
    g0_draw = generate_kbit_random_tensor(*size, generator=comm.get().g0)
    g1_draw = generate_kbit_random_tensor(*size, generator=comm.get().g1)

    # Constructing with SENTINEL skips normal initialization.
    sharing = BinarySharedTensor(src=SENTINEL)
    sharing.share = g0_draw ^ g1_draw
    return sharing
def shallow_copy(self):
    """
    Return a new BinarySharedTensor that references the same data.

    The copy points at the very same `encoder` and `share` objects as
    `self`; nothing is cloned, so in-place changes to the share are visible
    through both tensors.
    """
    # SENTINEL yields an empty instance without running normal initialization.
    clone = BinarySharedTensor(src=SENTINEL)
    clone.encoder, clone.share = self.encoder, self.share
    return clone
def _A2B(arithmetic_tensor):
    """
    Convert an arithmetic shared tensor into a binary shared tensor.

    For every party rank, the local arithmetic share is wrapped in a
    `BinarySharedTensor` with that rank as source. The per-rank tensors are
    stacked and then reduced with `sum` along the stacking dimension.
    (Presumably `BinarySharedTensor.sum` performs the addition on binary
    shares -- it is defined outside this block.)

    Args:
        arithmetic_tensor: object exposing `share` and `encoder`.

    Returns:
        A `BinarySharedTensor` carrying the encoder of `arithmetic_tensor`.
    """
    party_count = comm.get().get_world_size()
    reshared = [
        BinarySharedTensor(arithmetic_tensor.share, src=rank)
        for rank in range(party_count)
    ]

    stacked = BinarySharedTensor.stack(reshared)
    converted = stacked.sum(dim=0)

    # Keep the original encoding so the result decodes the same way.
    converted.encoder = arithmetic_tensor.encoder
    return converted
def _add_regular_function(function_name):
    """
    Attach a pass-through method named `function_name` to BinarySharedTensor.

    The generated method makes a shallow copy of the tensor, calls the
    method of the same name on the copy's underlying `share` with the given
    arguments, stores the outcome as the copy's new `share`, and returns the
    copy.

    Args:
        function_name: name of the method to look up on `share` and to
            install on `BinarySharedTensor`.
    """

    def regular_func(self, *args, **kwargs):
        result = self.shallow_copy()
        share_method = getattr(result.share, function_name)
        result.share = share_method(*args, **kwargs)
        return result

    setattr(BinarySharedTensor, function_name, regular_func)
def AND(x, y):
    """
    Compute the AND of binary secret-shared tensors x and y (Beaver protocol).

    Steps:
    1. Obtain a random triple [a], [b], [c] from the default provider,
       sized like x (the provider is expected to return [c] = [a & b]).
    2. Mask the inputs: [epsilon] = [x] ^ [a] and [delta] = [y] ^ [b].
    3. Reveal epsilon and delta together in a single stacked reveal.
    4. Return ([b] & epsilon) ^ ([a] & delta) ^ (epsilon & delta) ^ [c].
    """
    # Imported here rather than at module level, as in the original code.
    from .binary import BinarySharedTensor

    triple_provider = crypten.mpc.get_default_provider()
    a, b, c = triple_provider.generate_binary_triple(x.size())

    # Stack both masked values so that a single reveal opens them together.
    masked_inputs = BinarySharedTensor.stack([x ^ a, y ^ b])
    opened = masked_inputs.reveal()
    epsilon, delta = opened[0], opened[1]

    # Combine the terms left to right, in the same order as the formula.
    combined = b & epsilon
    combined = combined ^ (a & delta)
    combined = combined ^ (epsilon & delta)
    return combined ^ c
def __setitem__(self, index, value):
    """
    Assign `value` into this tensor's share at `index`.

    Plain torch tensors and lists are first wrapped in a new
    `BinarySharedTensor`; the wrapped tensor's share is then written into
    `self.share` at the given index.

    Raises:
        AssertionError: if `value` is neither a tensor, a list, nor a
            `BinarySharedTensor`.
    """
    needs_wrapping = torch.is_tensor(value) or isinstance(value, list)
    if needs_wrapping:
        value = BinarySharedTensor(value)

    assert isinstance(
        value, BinarySharedTensor
    ), "Unsupported input type %s for __setitem__" % type(value)

    self.share[index] = value.share
def scatter_(self, dim, index, src):
    """
    In-place scatter of `src` into this tensor's share.

    The values of `src` are written into `self` at the positions given by
    `index`: along dimension `dim` the destination is taken from `index`,
    along every other dimension it is the value's own position in `src`.
    A plain torch tensor passed as `src` is wrapped in a
    `BinarySharedTensor` first.

    Returns:
        `self`, after its share has been modified in place.

    Raises:
        AssertionError: if `src` is neither a torch tensor nor a
            `BinarySharedTensor`.
    """
    source = BinarySharedTensor(src) if torch.is_tensor(src) else src

    assert isinstance(
        source, BinarySharedTensor
    ), "Unrecognized scatter src type: %s" % type(source)

    self.share.scatter_(dim, index, source.share)
    return self