Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_discount_cumsum(self):
    """Check discount_cumsum against a hand-computed discounted suffix sum.

    For x = [1, 1, 1] and discount d, element i of the result is expected
    to be sum_{k>=0} d**k * x[i + k], i.e. [1 + d + d**2, 1 + d, 1].
    """
    x = np.array([1., 1., 1.])
    discount = 0.99
    expected = [
        1. + 1.*discount**1 + 1.*discount**2,
        1. + 1.*discount**1,
        1.]
    results = discount_cumsum(x, discount)
    # discount_cumsum may accumulate in a different order than the
    # closed-form expressions above, so compare with a floating-point
    # tolerance instead of demanding bit-for-bit equality.
    np.testing.assert_allclose(results, expected)
the targets for the value function.
The "last_val" argument should be 0 if the trajectory ended
because the agent reached a terminal state (died), and otherwise
should be V(s_T), the value function estimated for the last state.
This allows us to bootstrap the reward-to-go calculation to account
for timesteps beyond the arbitrary episode horizon (or epoch cutoff).
"""
# Fetch every transition currently held in the local buffer; presumably this
# is the trajectory being finished (TODO confirm against the caller).
samples = self.local_buffer._encode_sample(
np.arange(self.local_buffer.get_stored_size()))
# Append the bootstrap value (0 at a terminal state, otherwise V(s_T)) so
# rews and vals each have one more entry than the stored transitions.
rews = np.append(samples["rew"], last_val)
vals = np.append(samples["val"], last_val)
# GAE-Lambda advantage calculation
# One-step TD residuals: delta_t = r_t + discount * V(s_{t+1}) - V(s_t).
deltas = rews[:-1] + self._policy.discount * vals[1:] - vals[:-1]
if self._policy.enable_gae:
# GAE: discounted suffix sum of the residuals with factor discount * lam.
advs = discount_cumsum(
deltas, self._policy.discount * self._policy.lam)
else:
# Without GAE the advantage is just the one-step TD residual.
advs = deltas
# Rewards-to-go, to be targets for the value function
# Discounted suffix sums over rews (bootstrap included); [:-1] drops the
# entry for the appended bootstrap so lengths match the stored samples.
rets = discount_cumsum(rews, self._policy.discount)[:-1]
# Store obs/act/done/logp together with the computed returns and advantages
# in replay_buffer. np.squeeze removes size-1 axes from logp (assumes it is
# stored with a trailing axis of size 1, e.g. (n, 1) — confirm).
self.replay_buffer.add(
obs=samples["obs"], act=samples["act"], done=samples["done"],
ret=rets, adv=advs, logp=np.squeeze(samples["logp"]))
# Empty the local buffer so later samples do not mix with these.
self.local_buffer.clear()
"""
# Fetch every transition currently held in the local buffer; presumably this
# is the trajectory being finished (TODO confirm against the caller).
samples = self.local_buffer._encode_sample(
np.arange(self.local_buffer.get_stored_size()))
# Append the bootstrap value last_val so rews and vals each have one more
# entry than the stored transitions.
rews = np.append(samples["rew"], last_val)
vals = np.append(samples["val"], last_val)
# GAE-Lambda advantage calculation
# One-step TD residuals: delta_t = r_t + discount * V(s_{t+1}) - V(s_t).
deltas = rews[:-1] + self._policy.discount * vals[1:] - vals[:-1]
if self._policy.enable_gae:
# GAE: discounted suffix sum of the residuals with factor discount * lam.
advs = discount_cumsum(
deltas, self._policy.discount * self._policy.lam)
else:
# Without GAE the advantage is just the one-step TD residual.
advs = deltas
# Rewards-to-go, to be targets for the value function
# Discounted suffix sums over rews (bootstrap included); [:-1] drops the
# entry for the appended bootstrap so lengths match the stored samples.
rets = discount_cumsum(rews, self._policy.discount)[:-1]
# Store obs/act/done/logp together with the computed returns and advantages
# in replay_buffer. np.squeeze removes size-1 axes from logp (assumes it is
# stored with a trailing axis of size 1, e.g. (n, 1) — confirm).
self.replay_buffer.add(
obs=samples["obs"], act=samples["act"], done=samples["done"],
ret=rets, adv=advs, logp=np.squeeze(samples["logp"]))
# Empty the local buffer so later samples do not mix with these.
self.local_buffer.clear()