import numpy as np
import ray

@ray.remote
def multiple_dependency(i, arg1, arg2, arg3):
    # Arrays fetched from the object store are read-only; copy before mutating.
    arg1 = np.copy(arg1)
    arg1[0] = i
    return arg1
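
# Usage sketch (assumed driver code, not from the source): passing the same
# object ref to several tasks lets Ray resolve the dependency once per task
# without re-serializing the array on the driver.
ray.init()
base = ray.put(np.zeros(4))
refs = [multiple_dependency.remote(i, base, base, base) for i in range(3)]
print(ray.get(refs))  # each result is a copy with element 0 set to i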
import inspect
import ray

def run_func(func, *args, **kwargs):
    """Helper function for running examples."""
    ray.init()
    func = ray.remote(func)
    # NOTE: kwargs are not allowed for now.
    result = ray.get(func.remote(*args))
    # Inspect the stack to get the name of the calling example.
    caller = inspect.stack()[1][3]
    print("%s: %s" % (caller, str(result)))
    return result
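
# Usage sketch (assumed driver code): any plain function can be wrapped
# and executed as a Ray task through the helper.
def add(a, b):
    return a + b

run_func(add, 1, 2)  # prints e.g. "<module>: 3", depending on the caller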
import pickle
from collections import defaultdict

import torch
import ray

import projekt
from forge import trinity
from forge.trinity.timed import runtime
from forge.ethyr.io import Stimulus, Action, utils
from forge.ethyr.torch import optim
from forge.ethyr.experience import RolloutManager

@ray.remote(num_gpus=1)
class God(trinity.God):
    '''Server level God API demo

    This server level optimizer node aggregates experience
    across all core level rollout worker nodes, and uses the
    aggregated experience to compute gradients.

    This is effectively a lightweight variant of the Rapid
    computation model, with the notable potential difference
    that we also recompute the forward pass from small
    observation buffers rather than communicating large
    activation tensors.

    This demo builds up the ExperienceBuffer utility,
    which handles rollout batching.'''
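
# A minimal sketch (not the forge/trinity API; all names below are
# placeholders) of the pattern the docstring describes: recompute the
# forward pass from stored observations and aggregate gradients across
# worker rollouts, instead of shipping activation tensors between nodes.
import torch

def optim_step(policy, optimizer, rollout_batches):
    optimizer.zero_grad()
    for obs, actions, advantages in rollout_batches:
        # Recompute the forward pass from the small observation buffer.
        logits = policy(obs)
        logp = torch.distributions.Categorical(logits=logits).log_prob(actions)
        # Accumulate a policy-gradient loss over all worker rollouts.
        (-(logp * advantages).mean()).backward()
    optimizer.step()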
import argparse
import time

import ray
import model

parser = argparse.ArgumentParser(description="Run the asynchronous parameter "
                                 "server example.")
parser.add_argument("--num-workers", default=4, type=int,
                    help="The number of workers to use.")
parser.add_argument("--redis-address", default=None, type=str,
                    help="The Redis address of the cluster.")
@ray.remote
class ParameterServer(object):
    def __init__(self, keys, values):
        # These values will be mutated, so we must create a copy that is
        # not backed by the object store.
        values = [value.copy() for value in values]
        self.weights = dict(zip(keys, values))

    def push(self, keys, values):
        for key, value in zip(keys, values):
            self.weights[key] += value

    def pull(self, keys):
        return [self.weights[key] for key in keys]
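
# Driver-side usage sketch (assumed): workers push additive updates and
# pull current weights; the actor serializes concurrent access.
import numpy as np
import ray

ray.init()
ps = ParameterServer.remote(["w"], [np.zeros(3)])
ps.push.remote(["w"], [np.ones(3)])
print(ray.get(ps.pull.remote(["w"])))  # [array([1., 1., 1.])]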
@DeveloperAPI
@classmethod
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
    return ray.remote(
        num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls)
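
# Usage sketch (SomeWorker is a placeholder class exposing as_remote):
# wrapping with ray.remote at call time lets resource requirements be
# chosen per deployment rather than fixed at class definition.
RemoteWorker = SomeWorker.as_remote(num_cpus=2, num_gpus=0)
worker = RemoteWorker.remote()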
from ray.experimental.serve.policy import RoutePolicy
from ray.experimental.serve.server import HTTPActor

def start_initial_state(kv_store_connector):
    nursery_handle = ActorNursery.remote()
    ray.experimental.register_actor(SERVE_NURSERY_NAME, nursery_handle)
    ray.get(
        nursery_handle.store_bootstrap_state.remote(
            BOOTSTRAP_KV_STORE_CONN_KEY, kv_store_connector))
    return nursery_handle
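
# Bootstrap usage sketch (connector name is a placeholder; not runnable
# standalone since it relies on serve internals):
# nursery = start_initial_state(SomeKVStoreConnector)
# The same handle is later retrievable via
# ray.experimental.get_actor(SERVE_NURSERY_NAME).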
@ray.remote
class ActorNursery:
    """Initialize and store all actor handles.

    Note:
        This actor is necessary because Ray will destroy actors when the
        original actor handle goes out of scope (when the driver exits).
        Therefore we need to initialize and store actor handles in a
        separate actor.
    """

    def __init__(self):
        # Dict: actor handle -> tag
        self.actor_handles = dict()
        self.bootstrap_state = dict()

    def start_actor(self, actor_cls, tag, init_args=(), init_kwargs={}):
        # Minimal completion (the snippet is truncated here): start the
        # actor and keep its handle alive inside the nursery.
        handle = actor_cls.remote(*init_args, **init_kwargs)
        self.actor_handles[handle] = tag
        return [handle]
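
# Keep-alive sketch (Echo is a placeholder actor class): a handle stored
# inside the nursery actor survives the driver that created it.
@ray.remote
class Echo:
    def ping(self):
        return "pong"

nursery = ActorNursery.remote()
handles = ray.get(nursery.start_actor.remote(Echo, tag="echo"))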
# The following appears to be a `setup` method on a RayContext-style
# object; the signature is reconstructed from the parameters used below.
def setup(self, func_for_row, func_for_rows=None):
    if self.is_setup:
        raise ValueError("setup can only be invoked once")
    self.is_setup = True
    import ray
    if not self.is_in_mlsql:
        if func_for_rows is not None:
            func = ray.remote(func_for_rows)
            return ray.get(func.remote(self.mock_data))
        else:
            func = ray.remote(func_for_row)

            def iter_all(rows):
                return [ray.get(func.remote(row)) for row in rows]

            iter_all_func = ray.remote(iter_all)
            return ray.get(iter_all_func.remote(self.mock_data))
    buffer = []
    for server_info in self.build_servers_in_ray():
        server = ray.experimental.get_actor(server_info.server_id)
        buffer.append(ray.get(server.connect_info.remote()))
        server.serve.remote(func_for_row, func_for_rows)
    items = [vars(server) for server in buffer]
    self.python_context.build_result(items, 1024)
    return buffer
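
# Usage sketch (RayContext internals assumed; not runnable standalone):
# context.setup(func_for_row=lambda row: {"value": row["value"] * 2})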
@ray.remote
def slogdet(a):
    raise NotImplementedError
def _create_single_worker(self, config):
    logger.warning("AVAIL CLUSTER RES {}".format(ray.cluster_resources()))
    RemotePyTorchRunner = ray.remote(num_gpus=1)(PyTorchRunner)
    worker = RemotePyTorchRunner.remote(
        config["batch_per_device"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        model_creator=config["model_creator"],
        data_creator=config["data_creator"],
        loss_creator=config["loss_creator"],
        verbose=config["verbose"],
        use_nccl=config["use_nccl"],
        worker_config=config["worker_config"],
        lr_config=config["lr_config"],
    )
    worker.set_device.remote()
    return worker
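
# Equivalent standalone pattern (Runner and its arguments are assumed):
# wrap an existing class with ray.remote at runtime instead of decorating
# its definition, so the same class can run locally or as a GPU actor.
import ray

class Runner(object):
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def set_device(self):
        pass  # e.g. pin to the GPU that Ray assigned this actor

RemoteRunner = ray.remote(num_gpus=1)(Runner)
runner = RemoteRunner.remote(batch_size=64)
runner.set_device.remote()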
try:
    write_with_length(ex.encode("utf-8"), out)
    out.flush()
    read_int(infile)
except IOError:
    # The JVM closed the socket.
    pass
except Exception:
    # Write the error to stderr if it happened while serializing.
    print("Py worker failed with exception:")
    print(traceback.format_exc())
conn.close()
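
# Sketch of the framing helper assumed above (length-prefixed writes in
# the Spark-style worker protocol; the project's real helper may differ):
import struct

def write_with_length(data, stream):
    stream.write(struct.pack("!i", len(data)))  # 4-byte big-endian length
    stream.write(data)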
@ray.remote
class RayDataServer(object):
    def __init__(self, server_id, java_server, port=0, timezone="Asia/Harbin"):
        self.server = OnceServer(
            self.get_address(), port, java_server.timezone)
        try:
            (rel_host, rel_port) = self.server.bind()
        except Exception:
            print(traceback.format_exc())
            # Re-raise so rel_host/rel_port are never used unbound below.
            raise
        self.host = rel_host
        self.port = rel_port
        self.timezone = timezone
        self.server_id = server_id
        self.java_server = java_server
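
# Usage sketch (constructor arguments are placeholders; OnceServer and the
# java_server descriptor come from the surrounding project):
# data_server = RayDataServer.remote("server-0", java_server, 0, "Asia/Harbin")
# info = ray.get(data_server.connect_info.remote())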