Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def setUp(self):
    """Run the parent fixture setup, then start CrypTen on non-main ranks."""
    super().setUp()
    # The main process (rank -1) must not initialize the communicator.
    if self.rank < 0:
        return
    crypten.init()
# Round-trip a long tensor (precision_bits=0) and a float tensor
# (precision_bits=16) through the fixed-point encoder.
# NOTE(review): the enclosing `def` header and the original indentation are
# not visible here; the statements after this loop are assumed to run once,
# after the loop — confirm against the original test.
for is_float in [False, True]:
    fpe = FixedPointEncoder(precision_bits=16 if is_float else 0)
    tensor = get_test_tensor(float=is_float)
    decoded = fpe.decode(fpe.encode(tensor))
    self._check(
        decoded,
        tensor,
        # The parentheses matter: `%` binds tighter than `... if ... else ...`,
        # so without them the long case would report just "long".
        "Encoding/decoding a %s failed." % ("float" if is_float else "long"),
    )

# Make sure encoding a subclass of CrypTensor is a no-op
crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty)
crypten.init()

tensor = get_test_tensor(float=True)
encrypted_tensor = crypten.cryptensor(tensor)
# `fpe` is the encoder left over from the last loop iteration above.
encrypted_tensor = fpe.encode(encrypted_tensor)
self._check(
    encrypted_tensor.get_plain_text(),
    tensor,
    "Encoding an EncryptedTensor failed.",
)

# Try a few other types.
fpe = FixedPointEncoder(precision_bits=0)
for dtype in [torch.uint8, torch.int8, torch.int16]:
    tensor = torch.zeros(5, dtype=dtype).random_()
    # Cast back to the source dtype so the comparison is like-for-like.
    decoded = fpe.decode(fpe.encode(tensor)).type(dtype)
    self._check(decoded, tensor, "Encoding/decoding a %s failed." % dtype)
def setUp(self):
    """Run the parent setup; on ranks >= 0 start CrypTen and select the MPC backend."""
    super().setUp()
    is_worker = self.rank >= 0
    if is_worker:
        crypten.init()
        # NOTE(review): source indentation was lost; this call is assumed to
        # sit inside the rank check together with init() — confirm.
        crypten.set_default_backend(crypten.mpc)
def setUp(self):
    """Run the parent fixture setup and initialize CrypTen when rank >= 0."""
    super().setUp()
    # Negative ranks skip CrypTen initialization entirely.
    if self.rank < 0:
        return
    crypten.init()
def setUp(self):
    """Prepare the test case: parent setup first, then CrypTen on non-main ranks."""
    super().setUp()
    # Only ranks >= 0 initialize the communicator; the main process
    # (rank -1) is deliberately left uninitialized.
    should_init = self.rank >= 0
    if should_init:
        crypten.init()
def setUp(self):
super().setUp()
# We don't want the main process (rank -1) to initialize the communcator
if self.rank == self.MAIN_PROCESS_RANK:
return
crypten.init()
torch.manual_seed(0)
self.sizes = [(1, 8), (1, 16), (1, 32)]
self.int_tensors = [
get_random_test_tensor(size=size, is_float=False) for size in self.sizes
]
self.int_operands = [
(
get_random_test_tensor(size=size, is_float=False),
get_random_test_tensor(size=size, is_float=False),
)
for size in self.sizes
]
self.float_tensors = [
def run_mpc_linear_svm(
    epochs=50, examples=50, features=100, lr=0.5, skip_plaintext=False
):
    """Generate a synthetic two-class dataset and optionally fit a plaintext SVM.

    Args:
        epochs: Not used by the statements visible in this block.
        examples: Number of columns of the random feature matrix.
        features: Number of rows of the random feature matrix.
        lr: Learning rate forwarded to ``train_linear_svm``.
        skip_plaintext: When true, the PyTorch training step is skipped.
    """
    crypten.init()

    # Fixed seed so the synthetic data is reproducible.
    torch.manual_seed(1)

    # Draw order (x, then w_true, then b_true) is kept so the random stream
    # is consumed exactly as before.
    x = torch.randn(features, examples)
    w_true = torch.randn(1, features)
    b_true = torch.randn(1)
    # Labels are the sign of the ground-truth affine score.
    y = (w_true.matmul(x) + b_true).sign()

    if not skip_plaintext:
        for banner_line in (
            "==================",
            "PyTorch Training",
            "==================",
        ):
            logging.info(banner_line)
        w_torch, b_torch = train_linear_svm(x, y, lr=lr, print_time=True)
def _run_experiment(args):
    """Build the learner for the selected bandit module, start CrypTen, and run it.

    Args:
        args: Options object; ``args.plaintext`` selects which bandit module
            is imported, and the whole object is forwarded to ``build_learner``.
    """
    # Choose the bandit implementation module based on the plaintext flag.
    if not args.plaintext:
        import private_contextual_bandits as bandits
    else:
        import plain_contextual_bandits as bandits

    run_learner = build_learner(args, bandits, download_mnist)

    # CrypTen is imported and initialized only after the learner is built.
    import crypten

    crypten.init()
    run_learner()
start_epoch=0,
batch_size=256,
lr=0.01,
momentum=0.9,
weight_decay=1e-6,
print_freq=10,
resume="",
evaluate=True,
seed=None,
skip_plaintext=False,
save_checkpoint_dir="/tmp/tfe_benchmarks",
save_modelbest_dir="/tmp/tfe_benchmarks_best",
context_manager=None,
mnist_dir=None,
):
crypten.init()
if seed is not None:
random.seed(seed)
torch.manual_seed(seed)
# create model
model = create_benchmark_model(network)
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
)
# optionally resume from a checkpoint
def _run_process(cls, rank, world_size, env, run_process_fn, fn_args):
    """Configure the environment for one rank, initialize CrypTen, and run a callable.

    Args:
        rank: Rank of this process; exported as the ``RANK`` environment variable.
        world_size: Not used by this function.
        env: Mapping of environment variable names to values copied into
            ``os.environ`` before initialization.
        run_process_fn: Callable to execute after CrypTen is initialized.
        fn_args: Single argument passed to ``run_process_fn``; when ``None``
            the callable is invoked with no arguments.
    """
    for env_key, env_value in env.items():
        os.environ[env_key] = env_value
    os.environ["RANK"] = str(rank)

    # Raise the root logger to INFO only for the duration of crypten.init().
    root_logger = logging.getLogger()
    orig_logging_level = root_logger.level
    root_logger.setLevel(logging.INFO)
    try:
        crypten.init()
    finally:
        # Restore the level even if initialization raises, so a failed init
        # does not leave the root logger permanently at INFO.
        root_logger.setLevel(orig_logging_level)

    if fn_args is None:
        run_process_fn()
    else:
        run_process_fn(fn_args)