Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def setUp(self):
    """Make TrustedThirdParty the default MPC provider for this test.

    The provider that was active beforehand is stored on the instance as
    ``_original_provider`` so that ``tearDown`` can restore it.
    """
    # Remember the current provider before replacing it.
    self._original_provider = crypten.mpc.get_default_provider()
    ttp_provider = crypten.mpc.provider.TrustedThirdParty
    crypten.mpc.set_default_provider(ttp_provider)
    # The parent fixture runs only after the provider has been swapped.
    super(TestTTP, self).setUp()
def tearDown(self):
    """Restore the provider saved in ``_original_provider``, then run the parent tearDown."""
    saved_provider = self._original_provider
    crypten.mpc.set_default_provider(saved_provider)
    super(TestTTP, self).tearDown()
# Build a BenchmarkHelper and publish it on the test class so it is shared
# through ``cls.benchmark_helper``.
cls = self.__class__
queue = self.mp_context.Queue()
cls.benchmark_helper = BenchmarkHelper(
    self.benchmarks_enabled, self.benchmark_iters, queue
)
# If this instance already carries a ``benchmark_queue``, use it in place of
# the queue created above.
if hasattr(self, "benchmark_queue"):
    cls.benchmark_helper.queue = self.benchmark_queue
# This code also runs in the child processes, to give subclasses a chance to
# initialize themselves in the new process; only the main rank spawns.
if self.rank == self.MAIN_PROCESS_RANK:
    # Only the generated path is stored; the file object itself is not kept.
    self.file = tempfile.NamedTemporaryFile(delete=True).name
    # One spawned process per rank in [0, world_size).
    self.processes = [
        self._spawn_process(rank) for rank in range(int(self.world_size))
    ]
    # NOTE(review): nesting under the main-rank branch is reconstructed from
    # the fact that ``self.processes`` is only assigned there — confirm.
    if crypten.mpc.ttp_required():
        self.processes += [self._spawn_ttp()]
# The different private type attributes of an MPC encrypted tensor, bound
# here as short aliases for ``ptype.arithmetic`` and ``ptype.binary``.
arithmetic = ptype.arithmetic
binary = ptype.binary
def print_communication_stats():
    """Print communication statistics via the object returned by ``comm.get()``."""
    communicator = comm.get()
    communicator.print_communication_stats()
def reset_communication_stats():
    """Reset communication statistics via the object returned by ``comm.get()``."""
    communicator = comm.get()
    communicator.reset_communication_stats()
# Set backend
# Backends accepted by set_default_backend(); the first entry is the
# initial default returned by get_default_backend().
__SUPPORTED_BACKENDS = [crypten.mpc]
__default_backend = __SUPPORTED_BACKENDS[0]
def set_default_backend(new_default_backend):
    """Set the default cryptensor backend.

    Args:
        new_default_backend: the backend to make the default; it must be one
            of the entries in ``__SUPPORTED_BACKENDS``.

    Raises:
        AssertionError: if ``new_default_backend`` is not a supported backend.
    """
    global __default_backend
    if new_default_backend not in __SUPPORTED_BACKENDS:
        # Raised explicitly rather than through ``assert`` so the check is
        # not stripped when Python runs with -O. The 1-tuple keeps the
        # %-formatting valid even when the rejected value is itself a tuple.
        raise AssertionError(
            "Backend %s is not supported" % (new_default_backend,)
        )
    __default_backend = new_default_backend
def get_default_backend():
    """Return the backend currently stored as the default cryptensor backend."""
    return __default_backend
def bernoulli(self):
    """Return the elementwise result of ``self > rand(self.size())``.

    A random tensor with the same size as ``self`` is drawn and ``self`` is
    compared against it, so the chance that an output entry is 1 is governed
    by the corresponding entry of ``self``; it is not a fixed 0.5.

    NOTE(review): presumably ``crypten.mpc.rand`` samples uniformly from
    [0, 1), in which case entry ``i`` is 1 with probability ``self[i]`` for
    values in [0, 1] — confirm against the ``rand`` implementation.
    """
    random_draw = crypten.mpc.rand(self.size())
    return self > random_draw
# Use a uuid1-based rendezvous file name so multiple jobs can be run
# simultaneously without sharing a rendezvous file
INIT_METHOD = "file:///tmp/crypten-rendezvous-{}".format(uuid.uuid1())
env["RENDEZVOUS"] = INIT_METHOD
# Create (but do not start here) one process per rank in [0, world_size);
# each one targets the class's _run_process.
self.processes = []
for rank in range(world_size):
    process_name = "process " + str(rank)
    process = multiprocessing.Process(
        target=self.__class__._run_process,
        name=process_name,
        args=(rank, world_size, env, run_process_fn, fn_args),
    )
    self.processes.append(process)
# NOTE(review): this branch is placed after the loop so that a single TTP
# process is added; the original indentation is not visible — confirm.
if crypten.mpc.ttp_required():
    ttp_process = multiprocessing.Process(
        target=self.__class__._run_process,
        name="TTP",
        args=(
            # Same slot as `rank` above: the TTP uses world_size, one past
            # the last rank produced by the loop.
            world_size,
            world_size,
            env,
            crypten.mpc.provider.TTPServer,
            None,
        ),
    )
    self.processes.append(ttp_process)