def div_(self, y):
    """Divide two tensors element-wise"""
    # TODO: Add test coverage for this code path (next 4 lines)
    if isinstance(y, float) and int(y) == y:
        y = int(y)
    if is_float_tensor(y) and y.frac().eq(0).all():
        y = y.long()

    if isinstance(y, int) or is_int_tensor(y):
        # Truncate protocol for dividing by public integers:
        if comm.get().get_world_size() > 2:
            wraps = self.wraps()
            self.share /= y
            # NOTE: The multiplication here must be split into two parts
            # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
            # larger than the largest long integer.
            self -= wraps * 4 * (int(2 ** 62) // y)
        else:
            self.share /= y
        return self

    # Otherwise multiply by reciprocal
    if isinstance(y, float):
        y = torch.FloatTensor([y])
    assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
    return self.mul_(y.reciprocal())
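
# A minimal usage sketch (not part of the method above; assumes CrypTen is
# installed and that this runs inside an initialized CrypTen process): dividing an
# encrypted tensor by a public integer exercises the truncation path of div_.
import crypten
import torch

crypten.init()
x_enc = crypten.cryptensor(torch.tensor([4.0, 6.0, 8.0]))
y_enc = x_enc.div(2)                   # public integer divisor
print(y_enc.get_plain_text())          # approximately tensor([2., 3., 4.])
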
def test_save_load(self):
    """Test that crypten.save and crypten.load properly save and load tensors"""
    import tempfile

    comm = crypten.communicator
    filename = tempfile.NamedTemporaryFile(delete=True).name
    for dimensions in range(1, 5):
        # Create tensors with different sizes on each rank
        size = [self.rank + 1] * dimensions
        size = tuple(size)
        tensor = torch.randn(size=size)

        for src in range(comm.get().get_world_size()):
            crypten.save(tensor, filename, src=src)
            encrypted_load = crypten.load(filename, src=src)

            reference_size = tuple([src + 1] * dimensions)
            self.assertEqual(encrypted_load.size(), reference_size)

            size_out = [src + 1] * dimensions
            reference = tensor if self.rank == src else torch.empty(size=size_out)
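
# Hedged usage sketch following the crypten.save / crypten.load API exercised by
# the test above (exact signatures may differ across CrypTen versions): the src
# rank writes its tensor to disk and every rank loads it back from that source.
import crypten
import torch

crypten.init()
plain = torch.randn(3, 3)
crypten.save(plain, "/tmp/example_tensor.pt", src=0)    # only rank 0's data is saved
loaded = crypten.load("/tmp/example_tensor.pt", src=0)  # every rank loads from rank 0
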
def test_generator_func():
    t0 = torch.randint(-(2 ** 63), 2 ** 63 - 1, (1,), generator=comm.get().g0).item()
    t1 = torch.randint(-(2 ** 63), 2 ** 63 - 1, (1,), generator=comm.get().g1).item()
    return (t0, t1)


rank = comm.get().get_rank()
world_size = comm.get().get_world_size()

if world_size >= 2:  # Otherwise sending seeds will segfault.
    # Exchange seeds in a ring: send this party's locally sampled seed (next_seed)
    # to the next rank and receive the previous rank's seed into prev_seed.
    next_rank = (rank + 1) % world_size
    prev_rank = (next_rank - 2) % world_size

    req0 = comm.get().isend(tensor=next_seed, dst=next_rank)
    req1 = comm.get().irecv(tensor=prev_seed, src=prev_rank)

    req0.wait()
    req1.wait()
else:
    prev_seed = next_seed

# Seed generators: g0 uses this party's seed, g1 uses the previous party's seed
comm.get().g0.manual_seed(next_seed.item())
comm.get().g1.manual_seed(prev_seed.item())
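
# Plain-PyTorch sketch (no communication) of the property the seed exchange above
# establishes: this party's g0 and the next party's g1 are seeded identically, so
# they generate the same pseudorandom stream and can produce shares that cancel.
import torch

shared_seed = 123456789
g0_mine = torch.Generator().manual_seed(shared_seed)       # this party's g0
g1_neighbor = torch.Generator().manual_seed(shared_seed)   # next party's g1

r_mine = torch.randint(-(2 ** 63), 2 ** 63 - 1, (3,), generator=g0_mine)
r_neighbor = torch.randint(-(2 ** 63), 2 ** 63 - 1, (3,), generator=g1_neighbor)
assert torch.equal(r_mine, r_neighbor)   # identical streams: r and -r sum to zero
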
def reveal(self, dst=None):
    """Get plaintext without any downscaling"""
    tensor = self.share.clone()
    if dst is None:
        return comm.get().all_reduce(tensor)
    else:
        return comm.get().reduce(tensor, dst=dst)
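
# Conceptual sketch (plain tensors, no MPC) of what reveal() computes: all_reduce
# sums the parties' additive shares, which reconstructs the underlying value.
import torch

secret = torch.tensor([5, 7, -3])
share_party0 = torch.randint(-1000, 1000, (3,))
share_party1 = secret - share_party0     # each share on its own looks random
assert torch.equal(share_party0 + share_party1, secret)
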
def encrypt_data_tensor_with_src(input):
    """Encrypt data tensor for multi-party setting"""
    # get rank of current process
    rank = comm.get().get_rank()
    # get world size
    world_size = comm.get().get_world_size()

    if world_size > 1:
        # party 1 gets the actual tensor; remaining parties get dummy tensor
        src_id = 1
    else:
        # party 0 gets the actual tensor since world size is 1
        src_id = 0

    if rank == src_id:
        input_upd = input
    else:
        input_upd = torch.empty(input.size())
    private_input = crypten.cryptensor(input_upd, src=src_id)
    return private_input
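
# Hypothetical usage sketch: every party calls the helper with a tensor of the
# same shape; only the source party's values matter, because
# crypten.cryptensor(..., src=...) ignores the contents passed by other parties.
#
#   private_input = encrypt_data_tensor_with_src(torch.randn(10, 5))
#   print(private_input.get_plain_text())   # equals the source party's tensor
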
def _A2B(arithmetic_tensor):
    """Convert an arithmetic secret-shared tensor to a binary (XOR) secret-shared
    tensor: each party re-shares its arithmetic share in binary form, and the
    binary-shared values are then summed."""
    binary_tensor = BinarySharedTensor.stack(
        [
            BinarySharedTensor(arithmetic_tensor.share, src=i)
            for i in range(comm.get().get_world_size())
        ]
    )
    binary_tensor = binary_tensor.sum(dim=0)
    binary_tensor.encoder = arithmetic_tensor.encoder
    return binary_tensor
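
# Toy illustration (plain Python, two parties, no MPC) of the A2B idea above: each
# arithmetic share is re-shared in XOR form, and adding the binary-shared values
# yields the original secret. The real protocol performs this addition with a
# secure binary adder circuit instead of reconstructing anything in the clear.
import random

secret = 42
share_a = random.randrange(256)
share_b = (secret - share_a) % 256           # arithmetic shares of the secret mod 256

mask_a = random.randrange(256)
xor_shares_a = (mask_a, mask_a ^ share_a)    # XOR sharing of party A's arithmetic share
mask_b = random.randrange(256)
xor_shares_b = (mask_b, mask_b ^ share_b)    # XOR sharing of party B's arithmetic share

reconstructed_a = xor_shares_a[0] ^ xor_shares_a[1]
reconstructed_b = xor_shares_b[0] ^ xor_shares_b[1]
assert (reconstructed_a + reconstructed_b) % 256 == secret
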
def reveal(self, dst=None):
    """Get plaintext without any downscaling"""
    if dst is None:
        shares = comm.get().all_gather(self.share)
    else:
        shares = comm.get().gather(self.share, dst=dst)
    # XOR the gathered shares together (reduce is functools.reduce)
    return reduce(lambda x, y: x ^ y, shares)
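
# Minimal sketch of the XOR reconstruction performed by the binary reveal() above,
# using functools.reduce over plain tensors in place of gathered shares.
from functools import reduce

import torch

secret = torch.tensor([0b1011, 0b0110])
mask = torch.tensor([0b0101, 0b1100])
shares = [mask, mask ^ secret]               # XOR shares held by two parties
assert torch.equal(reduce(lambda x, y: x ^ y, shares), secret)
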
if print_time:
    pt_time = AverageMeter()
    end = time.time()

for epoch in range(epochs):
    # Forward
    label_predictions = w.matmul(features).add(b).sign()

    # Compute accuracy: correct is +1/-1, so add(1).div(2) maps it to 1/0
    correct = label_predictions.mul(labels)
    accuracy = correct.add(1).div(2).mean()
    if crypten.is_encrypted_tensor(accuracy):
        accuracy = accuracy.get_plain_text()

    # Print accuracy once (only on rank 0)
    if crypten.communicator.get().get_rank() == 0:
        logging.info(
            f"Epoch {epoch} --- Training Accuracy %.2f%%" % (accuracy.item() * 100)
        )

    # Backward: hinge-loss style (sub)gradient, nonzero only for misclassified points
    loss_grad = -labels * (1 - correct) * 0.5
    b_grad = loss_grad.mean()
    w_grad = loss_grad.matmul(features.t()).div(loss_grad.size(1))

    # Update
    w -= w_grad * lr
    b -= b_grad * lr

    if print_time:
        iter_time = time.time() - end
        pt_time.add(iter_time)
        end = time.time()  # reset the timer for the next epoch
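
# Plain-PyTorch sketch of a single update with the rule used above (toy data; the
# names w, b, lr, features, labels mirror the training loop, but the values here
# are made up for illustration).
import torch

torch.manual_seed(0)
d, n, lr = 3, 8, 0.1
features = torch.randn(d, n)
labels = torch.randn(1, n).sign()
w, b = torch.randn(1, d), torch.randn(1)

predictions = w.matmul(features).add(b).sign()
correct = predictions.mul(labels)            # +1 where correct, -1 where wrong
loss_grad = -labels * (1 - correct) * 0.5    # nonzero only for misclassified points
w = w - loss_grad.matmul(features.t()).div(n) * lr
b = b - loss_grad.mean() * lr
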
def evaluate_linear_svm(features, labels, w, b):
    """Compute accuracy on a test set"""
    predictions = w.matmul(features).add(b).sign()
    correct = predictions.mul(labels)
    accuracy = correct.add(1).div(2).mean().get_plain_text()
    if crypten.communicator.get().get_rank() == 0:
        print("Test accuracy %.2f%%" % (accuracy.item() * 100))