Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
# Tail of a policy-optimisation loop; the enclosing loop(s) are not part of this excerpt.
# Apply the Adam step with the base step size scaled by the current learning-rate multiplier.
adam.update(g, optim_stepsize * cur_lrmult)
losses.append(newlosses)
# Log the column-wise mean of the collected minibatch losses (13 is presumably the column width).
logger.log(fmt_row(13, np.mean(losses, axis=0)))
logger.log("Evaluating losses...")
losses = []
# One pass over dataset `d` that only calls compute_losses; no optimiser update happens in this loop.
for batch in d.iterate_once(optim_batchsize):
newlosses = compute_losses(batch["obs"], batch["actions"], batch["atarg"], batch["vtarg"], cur_lrmult)
losses.append(newlosses)
# mpi_moments returns three values; only the first is kept (presumably the mean across MPI workers -- confirm).
meanlosses, _, _ = mpi_moments(losses, axis=0)
logger.log(fmt_row(13, meanlosses))
# Record each mean loss as "loss_<name>"; zipsame presumably requires both sequences to have equal length.
for (lossval, name) in zipsame(meanlosses, loss_names):
logger.record_tabular("loss_" + name, lossval)
# Log how well the pre-update value predictions (vpredbefore) explain the tdlamret targets.
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
# Gather (episode lengths, episode returns) from every MPI worker, transpose the per-worker
# tuples with zip(*...), and flatten each field into one list.
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
lens, rews = map(flatten_lists, zip(*listoflrpairs))
lenbuffer.extend(lens)
rewbuffer.extend(rews)
logger.record_tabular("EpThisIter", len(lens))
# Advance the global counters with the episodes completed on all workers during this iteration.
episodes_so_far += len(lens)
timesteps_so_far += sum(lens)
iters_so_far += 1
logger.record_tabular("EpisodesSoFar", episodes_so_far)
logger.record_tabular("TimestepsSoFar", timesteps_so_far)
logger.record_tabular("TimeElapsed", time.time() - tstart)
# NOTE(review): if no episode has finished yet the buffers are empty and np.mean returns nan (with a RuntimeWarning).
logger.record_tabular("EpLenMean", np.mean(lenbuffer))
logger.record_tabular("EpRewMean", np.mean(rewbuffer))
# Only rank 0 writes out the accumulated tabular row.
if MPI.COMM_WORLD.Get_rank() == 0:
logger.dump_tabular()
# Update the reward giver's observation running statistics with the policy batch and the expert
# batch stacked along axis 0 (reward_giver is presumably the GAIL discriminator -- confirm).
self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
# Reshape actions if needed when using discrete actions:
# a 2-D action array is reduced to its first column.
if isinstance(self.action_space, gym.spaces.Discrete):
if len(ac_batch.shape) == 2:
ac_batch = ac_batch[:, 0]
if len(ac_expert.shape) == 2:
ac_expert = ac_expert[:, 0]
# lossandgrad returns the loss values followed by the gradient as its last element.
*newlosses, grad = self.reward_giver.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert)
# The gradient goes through allmean (presumably an average over MPI workers) before the Adam step.
self.d_adam.update(self.allmean(grad), self.d_stepsize)
d_losses.append(newlosses)
logger.log(fmt_row(13, np.mean(d_losses, axis=0)))
# Excerpt starts inside a branch whose `if` header is not shown. The `else:` below presumably pairs
# with it and handles segments without "ep_true_rets".
# lr: lengths and rewards (plus true returns in this branch)
lr_local = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]) # local values
list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local) # list of tuples
# Transpose the per-worker tuples and flatten each field into a single list.
lens, rews, true_rets = map(flatten_lists, zip(*list_lr_pairs))
true_reward_buffer.extend(true_rets)
else:
# lr: lengths and rewards
lr_local = (seg["ep_lens"], seg["ep_rets"]) # local values
list_lr_pairs = MPI.COMM_WORLD.allgather(lr_local) # list of tuples
lens, rews = map(flatten_lists, zip(*list_lr_pairs))
# NOTE(review): with indentation lost, it is unclear whether these two extends sit inside the
# `else:` or after the whole if/else.
len_buffer.extend(lens)
reward_buffer.extend(rews)
# Skip the means while no episode has finished, avoiding np.mean of an empty buffer.
if len(len_buffer) > 0:
logger.record_tabular("EpLenMean", np.mean(len_buffer))
logger.record_tabular("EpRewMean", np.mean(reward_buffer))
if self.using_gail:
logger.record_tabular("EpTrueRewMean", np.mean(true_reward_buffer))
logger.record_tabular("EpThisIter", len(lens))
# Record each mean loss under its own name (no "loss_" prefix here).
for (lossname, lossval) in zip(loss_names, meanlosses):
logger.record_tabular(lossname, lossval)
# Fit the value function for vf_iters passes over (observation, seg["tdlamret"]) minibatches of 64.
# Any final partial batch is dropped. Each gradient goes through allmean before the Adam update.
with timed("vf"):
for _ in range(vf_iters):
for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
include_final_partial_batch=False, batch_size=64):
g = allmean(compute_vflossandgrad(mbob, mbret))
vfadam.update(g, vf_stepsize)
# Log how well the pre-update value predictions (vpredbefore) explain the tdlamret targets.
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
# Gather (episode lengths, episode returns) from every MPI worker and flatten each field.
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
lens, rews = map(flatten_lists, zip(*listoflrpairs))
lenbuffer.extend(lens)
rewbuffer.extend(rews)
# NOTE(review): np.mean of an empty buffer returns nan (with a RuntimeWarning) before the first episode ends.
logger.record_tabular("EpLenMean", np.mean(lenbuffer))
logger.record_tabular("EpRewMean", np.mean(rewbuffer))
logger.record_tabular("EpThisIter", len(lens))
# Advance the global counters with the episodes completed on all workers.
episodes_so_far += len(lens)
timesteps_so_far += sum(lens)
iters_so_far += 1
logger.record_tabular("EpisodesSoFar", episodes_so_far)
logger.record_tabular("TimestepsSoFar", timesteps_so_far)
logger.record_tabular("TimeElapsed", time.time() - tstart)
# NOTE(review): this excerpt ends before the body of this rank-0 branch.
if rank==0:
# Tail of a policy-optimisation loop; the enclosing loop(s) are not part of this excerpt.
losses.append(newlosses)
# Log the column-wise mean of the collected minibatch losses (13 is presumably the column width).
logger.log(fmt_row(13, np.mean(losses, axis=0)))
logger.log("Evaluating losses...")
losses = []
# One pass over dataset `d` that only calls compute_losses (batches keyed "ob"/"ac" here).
for batch in d.iterate_once(optim_batchsize):
newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"],
cur_lrmult)
losses.append(newlosses)
# Only the first of mpi_moments' three results is kept (presumably the cross-worker mean -- confirm).
meanlosses, _, _ = mpi_moments(losses, axis=0)
logger.log(fmt_row(13, meanlosses))
# Record each mean loss as "loss_<name>".
for (lossval, name) in zipsame(meanlosses, loss_names):
logger.record_tabular("loss_" + name, lossval)
# Log how well the pre-update value predictions (vpredbefore) explain the tdlamret targets.
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
# Gather (episode lengths, episode returns) from every MPI worker and flatten each field.
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
lens, rews = map(flatten_lists, zip(*listoflrpairs))
lenbuffer.extend(lens)
rewbuffer.extend(rews)
# NOTE(review): np.mean of an empty buffer returns nan (with a RuntimeWarning) before the first episode ends.
logger.record_tabular("EpLenMean", np.mean(lenbuffer))
logger.record_tabular("EpRewMean", np.mean(rewbuffer))
logger.record_tabular("EpThisIter", len(lens))
# Advance the global counters with the episodes completed on all workers.
episodes_so_far += len(lens)
timesteps_so_far += sum(lens)
iters_so_far += 1
logger.record_tabular("EpisodesSoFar", episodes_so_far)
logger.record_tabular("TimestepsSoFar", timesteps_so_far)
logger.record_tabular("TimeElapsed", time.time() - tstart)
# Only rank 0 writes out the accumulated tabular row.
if MPI.COMM_WORLD.Get_rank() == 0:
logger.dump_tabular()
#print(iters_so_far, save_per_acts)
# Tail of a policy-optimisation loop; the enclosing loop(s) are not part of this excerpt.
losses.append(newlosses)
# Log the column-wise mean of the collected minibatch losses (13 is presumably the column width).
logger.log(fmt_row(13, np.mean(losses, axis=0)))
logger.log("Evaluating losses...")
losses = []
# One pass over dataset `d` that only calls compute_losses (batches keyed "ob"/"ac" here).
for batch in d.iterate_once(optim_batchsize):
newlosses = compute_losses(batch["ob"], batch["ac"], batch["atarg"], batch["vtarg"],
cur_lrmult)
losses.append(newlosses)
# Only the first of mpi_moments' three results is kept (presumably the cross-worker mean -- confirm).
meanlosses, _, _ = mpi_moments(losses, axis=0)
logger.log(fmt_row(13, meanlosses))
# Record each mean loss as "loss_<name>".
for (lossval, name) in zipsame(meanlosses, loss_names):
logger.record_tabular("loss_" + name, lossval)
# Log how well the pre-update value predictions (vpredbefore) explain the tdlamret targets.
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
# Gather (episode lengths, episode returns) from every MPI worker and flatten each field.
lrlocal = (seg["ep_lens"], seg["ep_rets"]) # local values
listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples
lens, rews = map(flatten_lists, zip(*listoflrpairs))
lenbuffer.extend(lens)
rewbuffer.extend(rews)
# NOTE(review): np.mean of an empty buffer returns nan (with a RuntimeWarning) before the first episode ends.
logger.record_tabular("EpLenMean", np.mean(lenbuffer))
logger.record_tabular("EpRewMean", np.mean(rewbuffer))
logger.record_tabular("EpThisIter", len(lens))
# Advance the global counters with the episodes completed on all workers.
episodes_so_far += len(lens)
timesteps_so_far += sum(lens)
iters_so_far += 1
logger.record_tabular("EpisodesSoFar", episodes_so_far)
logger.record_tabular("TimestepsSoFar", timesteps_so_far)
logger.record_tabular("TimeElapsed", time.time() - tstart)
# Only rank 0 writes out the accumulated tabular row.
if MPI.COMM_WORLD.Get_rank() == 0:
logger.dump_tabular()
#print(iters_so_far, save_per_acts)
def update(self):
    """Synchronise exploration statistics across MPI workers and log them on rank 0.

    Every worker contributes its ``local_rooms``, ``scores``, ``local_best_ret``,
    the mean of ``self.I.statlists["eprew"]`` and ``len(self.local_rooms)``.
    Afterwards ``self.rooms`` and ``self.scores`` hold the sorted, de-duplicated
    union over all workers, and ``self.best_ret`` holds the largest
    ``local_best_ret``. Rank 0 also logs these values together with
    ``ext_coeff``, ``gamma`` and ``gamma_ext``.

    Every rank must call this method, because it performs MPI ``allgather``
    collectives.
    """
    comm = MPI.COMM_WORLD
    # Some logic gathering best ret, rooms etc using MPI.
    # Flatten the per-worker lists with a comprehension rather than sum(lists, []),
    # which re-copies the growing accumulator on every addition (quadratic).
    self.rooms = sorted({room for rooms in comm.allgather(self.local_rooms) for room in rooms})
    self.scores = sorted({score for scores in comm.allgather(self.scores) for score in scores})
    # A single gather of local_best_ret serves both best_ret and the per-worker log
    # line; the same value used to be gathered by two separate collectives.
    local_best_rets = comm.allgather(self.local_best_ret)
    self.best_ret = max(local_best_rets)
    # NOTE(review): np.mean of an empty eprew history is nan (with a RuntimeWarning).
    eprews = comm.allgather(np.mean(list(self.I.statlists["eprew"])))
    n_rooms = comm.allgather(len(self.local_rooms))
    if comm.Get_rank() == 0:
        logger.info(f"Rooms visited {self.rooms}")
        logger.info(f"Best return {self.best_ret}")
        logger.info(f"Best local return {sorted(local_best_rets)}")
        logger.info(f"eprews {sorted(eprews)}")
        logger.info(f"n_rooms {sorted(n_rooms)}")
        logger.info(f"Extrinsic coefficient {self.ext_coeff}")
        logger.info(f"Gamma {self.gamma}")
        logger.info(f"Gamma ext {self.gamma_ext}")
        logger.info(f"All scores {sorted(self.scores)}")
    # Normalize intrinsic rewards.
    # NOTE(review): the normalization step this comment introduces is not part of this excerpt.
def update(self):
    """Synchronise exploration statistics across MPI workers; rank 0 logs the visited rooms.

    Afterwards ``self.rooms`` and ``self.scores`` hold the sorted, de-duplicated
    union of every worker's ``local_rooms`` / ``scores``, and ``self.best_ret``
    holds the largest ``local_best_ret`` over all workers.

    Every rank must call this method, because it performs MPI ``allgather``
    collectives.
    """
    comm = MPI.COMM_WORLD
    # Some logic gathering best ret, rooms etc using MPI.
    # Flatten the per-worker lists with a comprehension rather than sum(lists, []),
    # which re-copies the growing accumulator on every addition (quadratic).
    self.rooms = sorted({room for rooms in comm.allgather(self.local_rooms) for room in rooms})
    self.scores = sorted({score for scores in comm.allgather(self.scores) for score in scores})
    # Gather local_best_ret once and derive best_ret from it; the same value used
    # to be gathered by two separate collectives.
    local_best_rets = comm.allgather(self.local_best_ret)
    self.best_ret = max(local_best_rets)
    # NOTE(review): eprews and n_rooms are not read within this excerpt. They are kept
    # so the collective calls stay as before, and in case code beyond this excerpt uses them.
    eprews = comm.allgather(np.mean(list(self.I.statlists["eprew"])))
    n_rooms = comm.allgather(len(self.local_rooms))
    if comm.Get_rank() == 0:
        logger.info(f"Rooms visited {self.rooms}")
def decrement_starting_point(self, n_points_to_shift):
    """Forward a starting-point decrement to the env and refresh the global maximum.

    Calls ``self.env.decrement_starting_point(n_points_to_shift)``. It then
    gathers every MPI worker's ``starting_point`` values (read via
    ``self.env.recursive_getattr``), flattens them, and stores the largest in
    ``self.max_starting_point``. Every rank must call this, since it performs
    an MPI ``allgather``.
    """
    self.env.decrement_starting_point(n_points_to_shift)
    per_worker_points = MPI.COMM_WORLD.allgather(
        self.env.recursive_getattr('starting_point')
    )
    self.max_starting_point = max(flatten_lists(per_worker_points))
# Verdict of a backtracking line search; the loop header is not part of this excerpt.
# A step is rejected on non-finite losses, a KL above 1.5 * max_kl, or a non-improving
# surrogate; otherwise it is accepted and the loop is left via `break`.
if not np.isfinite(mean_losses).all():
logger.log("Got non-finite value of losses -- bad!")
elif kl_loss > self.max_kl * 1.5:
logger.log("violated KL constraint. shrinking step.")
elif improve < 0:
logger.log("surrogate didn't improve. shrinking step.")
else:
logger.log("Stepsize OK!")
break
# Rejected step: halve the step size before the next attempt.
stepsize *= .5
# Presumably the loop's `else`, which runs only when no attempt hit `break`:
# restore the parameters saved before the search (thbefore).
else:
logger.log("couldn't compute a good step")
self.set_from_flat(thbefore)
# Every 20 iterations with more than one worker, check that all workers agree on the sums
# of thnew and of the value-function Adam's flat vector.
# NOTE: this check is an `assert`, so it is skipped under `python -O`.
if self.nworkers > 1 and iters_so_far % 20 == 0:
# list of tuples
paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), self.vfadam.getflat().sum()))
assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
# Fit the value function for self.vf_iters passes over (observation, seg["tdlamret"]) minibatches
# of 128. Any final partial batch is dropped.
with self.timed("vf"):
for _ in range(self.vf_iters):
for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]),
include_final_partial_batch=False,
batch_size=128):
# mbob is passed twice; presumably the compiled function takes two observation inputs -- confirm.
grad = self.allmean(self.compute_vflossandgrad(mbob, mbob, mbret, sess=self.sess))
self.vfadam.update(grad, self.vf_stepsize)
# Record each mean loss under its own name.
for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
logger.record_tabular(loss_name, loss_val)
logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret))
# NOTE(review): the body of this `if` is not visible. The lines that follow look like a separate
# excerpt (a line-search tail), so whether they nest under this `if` cannot be confirmed.
if self.using_gail:
# Line-search verdict (loop header not shown): reject non-finite losses, a KL above
# 1.5 * max_kl, or a non-improving surrogate; otherwise accept the step and `break`.
if not np.isfinite(mean_losses).all():
logger.log("Got non-finite value of losses -- bad!")
elif kl_loss > self.max_kl * 1.5:
logger.log("violated KL constraint. shrinking step.")
elif improve < 0:
logger.log("surrogate didn't improve. shrinking step.")
else:
logger.log("Stepsize OK!")
break
# Rejected step: halve the step size before the next attempt.
stepsize *= .5
# Presumably the loop's `else` (no accepted step): restore the saved parameters thbefore.
else:
logger.log("couldn't compute a good step")
self.set_from_flat(thbefore)
# Every 20 iterations with more than one worker, check that all workers agree on the sums
# of thnew and of the value-function Adam's flat vector (an `assert`, skipped under -O).
if self.nworkers > 1 and iters_so_far % 20 == 0:
# list of tuples
paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), self.vfadam.getflat().sum()))
assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:])
# Record each mean loss under its own name.
for (loss_name, loss_val) in zip(self.loss_names, mean_losses):
logger.record_tabular(loss_name, loss_val)
# Fit the value function on shuffled minibatches of 128 (observations keyed "observations" here).
# Any final partial batch is dropped.
with self.timed("vf"):
for _ in range(self.vf_iters):
# NOTE: for recurrent policies, use shuffle=False?
for (mbob, mbret) in dataset.iterbatches((seg["observations"], seg["tdlamret"]),
include_final_partial_batch=False,
batch_size=128,
shuffle=True):
# mbob is passed twice; presumably the compiled function takes two observation inputs -- confirm.
grad = self.allmean(self.compute_vflossandgrad(mbob, mbob, mbret, sess=self.sess))
self.vfadam.update(grad, self.vf_stepsize)
# NOTE(review): this call is cut off at the end of the excerpt.
logger.record_tabular("explained_variance_tdlam_before",