test_flag = 0
total_steps = 0
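# training loop: collect a batch of trajectories, update the policy and value networks, then log metrics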
while total_steps < args.train_total_steps:
trajectories = collect_trajectories(
env, agent, scaler, episodes=args.episodes_per_batch)
total_steps += sum([t['obs'].shape[0] for t in trajectories])
total_train_rewards = sum([np.sum(t['rewards']) for t in trajectories])
train_obs, train_actions, train_advantages, train_discount_sum_rewards = build_train_data(
trajectories, agent)
policy_loss, kl = agent.policy_learn(train_obs, train_actions,
train_advantages)
value_loss = agent.value_learn(train_obs, train_discount_sum_rewards)
logger.info(
'Steps {}, Train reward: {}, Policy loss: {}, KL: {}, Value loss: {}'
.format(total_steps, total_train_rewards / args.episodes_per_batch,
policy_loss, kl, value_loss))
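# evaluate the policy every `test_every_steps` environment steps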
if total_steps // args.test_every_steps >= test_flag:
while total_steps // args.test_every_steps >= test_flag:
test_flag += 1
eval_reward = run_evaluate_episode(env, agent, scaler)
logger.info('Steps {}, Evaluate reward: {}'.format(
total_steps, eval_reward))
def _new_ready_actor(self):
"""
The actor is ready to start a new episode,
but blocks until the training thread calls actor_ready_event.set()
"""
actor_ready_event = threading.Event()
self.ready_actor_queue.put(actor_ready_event)
logger.info(
"[new_avaliabe_actor] approximate size of ready actors:{}".format(
self.ready_actor_queue.qsize()))
actor_ready_event.wait()
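# save the distributed python files to the environment directory (envdir)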
for file, code in pyfiles['python_files'].items():
file = os.path.join(envdir, file)
with open(file, 'wb') as code_file:
code_file.write(code)
# save other files to current directory
for file, content in pyfiles['other_files'].items():
# create directory (i.e. ./rom_files/)
if '/' in file:
try:
os.makedirs(os.path.join(*file.rsplit('/')[:-1]))
except OSError:  # the directory may already exist
pass
with open(file, 'wb') as f:
f.write(content)
logger.info('[job] reply')
reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return envdir
else:
logger.error("NotImplementedError:{}, received tag:{}".format(
job_address, ))
raise NotImplementedError
Returns:
gpu_count (int): number of GPU devices available to the current process.
"""
gpu_count = 0
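# prefer CUDA_VISIBLE_DEVICES when it is set; otherwise fall back to `nvidia-smi -L`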
env_cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env_cuda_devices is not None:
assert isinstance(env_cuda_devices, str)
try:
if not env_cuda_devices:
return 0
gpu_count = len(
[x for x in env_cuda_devices.split(',') if int(x) >= 0])
logger.info(
'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
except Exception:
logger.info('Cannot find available GPU devices, using CPU now.')
gpu_count = 0
else:
try:
gpu_count = str(subprocess.check_output(["nvidia-smi",
"-L"])).count('UUID')
logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
except Exception:
logger.info('Cannot find available GPU devices, using CPU now.')
gpu_count = 0
return gpu_count
elif stage == 3:
assert change_num >= 3
interval = 3.0 / self.discrete_bin
discrete_id = np.random.randint(self.discrete_bin)
min_vel = -0.25 + discrete_id * interval
max_vel = -0.25 + (discrete_id + 1) * interval
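# rejection sampling: redraw the velocity sequence until the fourth target velocity falls inside the sampled bin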
while True:
target_vels = [1.25]
for i in range(change_num):
target_vels.append(target_vels[-1] +
random.uniform(-0.5, 0.5))
if target_vels[3] >= min_vel and target_vels[3] <= max_vel:
break
else:
raise NotImplementedError
logger.info('[CustomR2Env] stage: {}, target_vels: {}'.format(
stage, target_vels))
return target_vels
[remote_constants.NORMAL_TAG, status])
# `xparl status` command line API
elif tag == remote_constants.STATUS_TAG:
status_info = self.cluster_monitor.get_status_info()
self.client_socket.send_multipart(
[remote_constants.NORMAL_TAG,
to_byte(status_info)])
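# a worker finishes initialization and registers with the master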
elif tag == remote_constants.WORKER_INITIALIZED_TAG:
initialized_worker = cloudpickle.loads(message[1])
worker_address = initialized_worker.worker_address
self.job_center.add_worker(initialized_worker)
hostname = self.job_center.get_hostname(worker_address)
self.cluster_monitor.add_worker_status(worker_address, hostname)
logger.info("A new worker {} is added, ".format(worker_address) +
"the cluster has {} CPUs.\n".format(self.cpu_num))
# a thread for sending heartbeat signals to `worker.address`
thread = threading.Thread(
target=self._create_worker_monitor,
args=(initialized_worker.worker_address, ))
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
# a client connects to the master
elif tag == remote_constants.CLIENT_CONNECT_TAG:
client_heartbeat_address = to_str(message[1])
client_hostname = to_str(message[2])
self.client_hostname[client_heartbeat_address] = client_hostname
logger.info(
# start training
if rpm.size() > MEMORY_WARMUP_SIZE:
if steps % UPDATE_FREQ == 0:
batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
args.batch_size)
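# the first CONTEXT_LEN frames form the current state; shifting the window by one frame gives the next state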
batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
batch_next_state = batch_all_state[:, 1:, :, :]
cost = agent.learn(batch_state, batch_action, batch_reward,
batch_next_state, batch_isOver)
all_cost.append(float(cost))
total_reward += reward
state = next_state
if isOver:
break
if all_cost:
logger.info('[Train] total_reward: {}, mean_cost: {}'.format(
total_reward, np.mean(all_cost)))
return total_reward, steps, np.mean(all_cost)
def disconnect():
"""Disconnect the global client from the master node."""
global GLOBAL_CLIENT
if GLOBAL_CLIENT is not None:
GLOBAL_CLIENT.client_is_alive = False
GLOBAL_CLIENT = None
else:
logger.info(
"No client to be released. Please make sure that you have call `parl.connect`"
)
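# A minimal usage sketch (assumptions: a master was started with `xparl start --port 8010`;
# the address and the import path of `disconnect` are illustrative, not taken from the code above).
import parl
from parl.remote.client import disconnect

parl.connect("localhost:8010")  # create the global client and attach it to the running master
# ... run tasks through @parl.remote_class decorated classes here ...
disconnect()  # release the global client once all remote work is finished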