Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
compression._apply_compress_scheme = compression.apply_lz4_compression
elif compress_scheme == compression.ZSTD:
compression._apply_compress_scheme = compression.apply_zstd_compression
else:
compression._apply_compress_scheme = compression.apply_no_compression
else:
compression._apply_compress_scheme = compression.apply_no_compression
t = Tensor(numpy.ones((100, 100)))
t_serialized = syft.serde.serialize(t)
assert (
t_serialized[0] == compress_scheme
if compress
else t_serialized[0] == compression.NO_COMPRESSION
)
t_serialized_deserialized = syft.serde.deserialize(t_serialized)
assert (t == t_serialized_deserialized).all()
def test_bytearray(compress):
    """Round-trip ``bytearray`` payloads through ``syft.serde``.

    Args:
        compress: truthy to install the LZ4 compression scheme, falsy to
            install the no-compression scheme (presumably supplied by a
            pytest fixture defined elsewhere -- confirm).

    The previously installed ``compression._apply_compress_scheme`` is
    restored on exit so this module-level setting does not leak into other
    tests, even when an assertion fails.
    """
    original_scheme = compression._apply_compress_scheme
    if compress:
        compression._apply_compress_scheme = compression.apply_lz4_compression
    else:
        compression._apply_compress_scheme = compression.apply_no_compression
    try:
        payloads = (
            # Short UTF-8 text payload.
            bytearray("This is a teststring", "utf-8"),
            # Larger binary payload: the raw buffer of a random 100x100 array.
            bytearray(numpy.random.random((100, 100))),
        )
        for bytearr in payloads:
            bytearr_serialized = syft.serde.serialize(bytearr)
            bytearr_serialized_deserialized = syft.serde.deserialize(
                bytearr_serialized
            )
            assert bytearr == bytearr_serialized_deserialized
    finally:
        # Undo the global mutation regardless of the test outcome.
        compression._apply_compress_scheme = original_scheme
def test_ndarray_serde(compress):
    """Round-trip a NumPy ``ndarray`` through ``syft.serde``.

    Args:
        compress: truthy to install the LZ4 compression scheme, falsy to
            install the no-compression scheme (presumably supplied by a
            pytest fixture defined elsewhere -- confirm).

    The previously installed ``compression._apply_compress_scheme`` is
    restored on exit so this module-level setting does not leak into other
    tests, even when an assertion fails.
    """
    original_scheme = compression._apply_compress_scheme
    if compress:
        compression._apply_compress_scheme = compression.apply_lz4_compression
    else:
        compression._apply_compress_scheme = compression.apply_no_compression
    try:
        arr = numpy.random.random((100, 100))
        arr_serialized = syft.serde.serialize(arr)
        arr_serialized_deserialized = syft.serde.deserialize(arr_serialized)
        # Exact equality: serialization must be lossless.
        assert numpy.array_equal(arr, arr_serialized_deserialized)
    finally:
        # Undo the global mutation regardless of the test outcome.
        compression._apply_compress_scheme = original_scheme
def test_torch_Tensor(compress):
    """Round-trip a ``Tensor`` built from random data through ``syft.serde``.

    Args:
        compress: truthy to install the LZ4 compression scheme, falsy to
            install the no-compression scheme (presumably supplied by a
            pytest fixture defined elsewhere -- confirm).

    The previously installed ``compression._apply_compress_scheme`` is
    restored on exit so this module-level setting does not leak into other
    tests, even when an assertion fails.
    """
    original_scheme = compression._apply_compress_scheme
    if compress:
        compression._apply_compress_scheme = compression.apply_lz4_compression
    else:
        compression._apply_compress_scheme = compression.apply_no_compression
    try:
        t = Tensor(numpy.random.random((100, 100)))
        t_serialized = syft.serde.serialize(t)
        t_serialized_deserialized = syft.serde.deserialize(t_serialized)
        # Element-wise comparison reduced to a single truth value.
        assert (t == t_serialized_deserialized).all()
    finally:
        # Undo the global mutation regardless of the test outcome.
        compression._apply_compress_scheme = original_scheme
"""
if _is_model_in_cache(model_id):
# Model already exists
cache_model = _get_model_from_cache(model_id)
if cache_model.allow_remote_inference:
return {"success": True, "model": cache_model.model_obj}
else:
return {
"success": False,
"not_allowed": True,
"error": NOT_ALLOWED_TO_RUN_INFERENCE_MSG,
}
try:
result = _get_model_from_db(model_id)
if result:
model = sy.serde.deserialize(result.model)
# If the model is a Plan we also need to retrieve
# the state tensors
if isinstance(model, sy.Plan):
_retrieve_state_from_db(model)
# Save model in cache
_save_model_to_cache(
model,
model_id,
result.allow_download,
result.allow_remote_inference,
serialized=False,
)
if result.allow_remote_inference:
return {"success": True, "model": model}
def object(self):
    """Return the Python object obtained by deserializing ``self.data`` with syft serde."""
    serialized_payload = self.data
    restored = sy.serde.deserialize(serialized_payload)
    return restored
def object(self):
    """Decode the serialized bytes held in ``self.data`` back into an object via ``sy.serde``."""
    raw_bytes = self.data
    decoded = sy.serde.deserialize(raw_bytes)
    return decoded
message (bin) : PySyft binary message.
Returns:
response (bin) : PySyft binary response.
"""
try:
## If worker is empty, load previous database tensors.
if not current_user.worker._objects:
recover_objects(current_user.worker)
# Process message
decoded_response = current_user.worker._recv_msg(message)
# Save worker state at database
snapshot(current_user.worker)
except GetNotPermittedError as e:
message = sy.serde.deserialize(message, worker=current_user.worker)
# Register this request into tensor owner account.
if hasattr(current_user, "save_tensor_request"):
current_user.save_request(message._contents)
decoded_response = sy.serde.serialize(e)
return decoded_response
Returns:
The deserialized form of message from the worker at specified
location.
"""
if self.verbose:
print(f"worker {self} sending {message} to {location}")
# Step 1: serialize the message to a binary
bin_message = sy.serde.serialize(message, worker=self)
# Step 2: send the message and wait for a response
bin_response = self._send_msg(bin_message, location)
# Step 3: deserialize the response
response = sy.serde.deserialize(bin_response, worker=self)
return response