def test_execute_command_self(hook):
    sy.VirtualWorker.mocked_function = MethodType(
        mock.Mock(return_value="bob_mocked_function"), sy.VirtualWorker
    )

    bob = sy.VirtualWorker(hook, "bob")
    x = th.tensor([1, 2, 3]).send(bob)

    message = bob.create_message_execute_command(
        command_name="mocked_function", command_owner="self"
    )

    serialized_message = sy.serde.serialize(message)
    response = bob._recv_msg(serialized_message)
    response = sy.serde.deserialize(response)

    assert response == "bob_mocked_function"
    bob.mocked_function.assert_called()
def _recv_msg(self, message: bin) -> bin:
    # Sends the message to the server
    self.sio.emit("message", message)
    self.wait_for_client_event = True

    # Wait until the server gets back with a result or an ACK
    while self.wait_for_client_event:
        time.sleep(0.1)

    # Return the result
    if self.response_from_client == "ACK":
        # Empty result for the serialiser to continue
        return sy.serde.serialize(b"")
    return self.response_from_client
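The polling loop above depends on a socket.io callback elsewhere in the client that stores the server's reply and flips wait_for_client_event. A minimal companion sketch, assuming python-socketio; the class and handler names below are illustrative, not the library's own:

import socketio

class MinimalSocketIOClient:
    # Hypothetical sketch of the receiving side assumed by _recv_msg above:
    # the "message" handler stores the reply and releases the wait loop.
    def __init__(self, url):
        self.sio = socketio.Client()
        self.wait_for_client_event = False
        self.response_from_client = None
        self.sio.on("message", self._on_message)
        self.sio.connect(url)

    def _on_message(self, data):
        self.response_from_client = data      # serialized bytes or the "ACK" marker
        self.wait_for_client_event = False    # unblocks the while loop in _recv_msg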
def run_remote_inference(self, model_id, data):
    """Run a dataset inference using a remote model.

    Args:
        model_id (str): Model ID.
        data (Tensor): Dataset to be inferred.

    Returns:
        inference (Tensor): Inference result.

    Raises:
        RuntimeError: If an unexpected behavior happens, it will forward the error message.
    """
    serialized_data = sy.serde.serialize(data).decode(self._encoding)
    message = {
        REQUEST_MSG.TYPE_FIELD: REQUEST_MSG.RUN_INFERENCE,
        "model_id": model_id,
        "data": serialized_data,
        "encoding": self._encoding,
    }
    response = self._forward_json_to_websocket_server_worker(message)
    return self._return_bool_result(response, RESPONSE_MSG.INFERENCE_RESULT)
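A hedged usage sketch for the method above. The client instance (grid_worker) and the model id are illustrative; they assume a worker that is already connected and a model already hosted remotely:

import torch as th

# grid_worker is a hypothetical, already-connected instance of the client class above
data = th.tensor([1.0, 2.0, 3.0, 4.0])
prediction = grid_worker.run_remote_inference(model_id="mnist-cnn", data=data)
print(prediction)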
def _recv_msg(self, message: bin) -> bin:
    try:
        return self.recv_msg(message)
    except (ResponseSignatureError, GetNotPermittedError) as e:
        return sy.serde.serialize(e)
def serialize(self):  # check serde.py to see how to provide compression schemes
    """Serializes the tensor on which it's called.

    This is the high level convenience function for serializing torch
    tensors. It includes three steps, Simplify, Serialize, and Compress,
    as described in serde.py. By default, serde compresses using LZ4.

    Returns:
        The serialized form of the tensor.

    For example:
        x = torch.Tensor([1,2,3,4,5])
        x.serialize() # returns a serialized object
    """
    return sy.serde.serialize(self)
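A short round-trip sketch of the module-level serde API used throughout these snippets, assuming a PySyft 0.2.x-style setup like the test at the top of this page:

import torch as th
import syft as sy

hook = sy.TorchHook(th)                    # hooks torch so tensors gain .serialize()

x = th.tensor([1, 2, 3, 4, 5])
binary = sy.serde.serialize(x)             # simplify -> serialize -> compress (LZ4 by default)
x_again = sy.serde.deserialize(binary)     # reverses the three steps
assert (x == x_again).all()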
"""Returns a serialized_model with given model id.
Args:
model_id (str): The unique identifier associated with the model.
Returns:
A dict with structure: {"success": Bool, "model": serialized model object}.
On error returns dict: {"success": Bool, "error": error message }.
"""
if _is_model_in_cache(model_id):
# Model already exists
cache_model = _get_model_from_cache(model_id)
if cache_model.allow_download:
return {
"success": True,
"serialized_model": sy.serde.serialize(cache_model.model_obj),
}
else:
return {
"success": False,
"not_allowed": True,
"error": NOT_ALLOWED_TO_DOWNLOAD_MSG,
}
try:
result = _get_model_from_db(model_id)
if result:
model = sy.serde.deserialize(result.model)
# If the model is a Plan we also need to retrieve
# the state tensors
if isinstance(model, sy.Plan):
_retrieve_state_from_db(model)
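A sketch of how a caller might consume the dict contract documented above. The function name get_serialized_model_with_id is hypothetical (the snippet omits its signature), and the success key follows the cache branch shown above:

import syft as sy

# Hypothetical name for the handler above; the model id is illustrative.
result = get_serialized_model_with_id("my-model-id")

if result["success"]:
    model = sy.serde.deserialize(result["serialized_model"])
else:
    print(result.get("error"))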
def _send_msg_and_deserialize(self, command_name: str, *args, **kwargs):
    message = self.create_message_execute_command(
        command_name=command_name, command_owner="self", *args, **kwargs
    )

    # Send the message and return the deserialized response.
    serialized_message = sy.serde.serialize(message)
    response = self._recv_msg(serialized_message)
    return sy.serde.deserialize(response)
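A hedged usage sketch for the helper above. The worker instance is illustrative, and the command name is the same mocked_function exercised in the test at the top of this page:

# worker is a hypothetical client exposing _send_msg_and_deserialize
result = worker._send_msg_and_deserialize(command_name="mocked_function")
print(result)  # e.g. "bob_mocked_function" with the mock from the test above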
Args:
    msg_type: An integer representing the message type.
    message: A Message object.
    location: A BaseWorker instance that lets you provide the
        destination to send the message.

Returns:
    The deserialized form of the message from the worker at the
    specified location.
"""
if self.verbose:
    print(f"worker {self} sending {message} to {location}")

# Step 1: serialize the message to a binary
bin_message = sy.serde.serialize(message, worker=self)

# Step 2: send the message and wait for a response
bin_response = self._send_msg(bin_message, location)

# Step 3: deserialize the response
response = sy.serde.deserialize(bin_response, worker=self)

return response
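The serialize/send/deserialize steps above are the path that higher-level PySyft calls take under the hood. A hedged end-to-end sketch using the VirtualWorker API already shown in the test snippet (worker ids are illustrative):

import torch as th
import syft as sy

hook = sy.TorchHook(th)
bob = sy.VirtualWorker(hook, "bob")

x = th.tensor([1, 2, 3]).send(bob)   # serializes a message and sends it to bob
y = x.get()                          # bob's reply is deserialized back into a tensor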