def train(self, batch, actor_update=True, critic_update=True):
    states_t, actions_t, rewards_t, states_tp1, done_t = \
        batch["state"], batch["action"], batch["reward"], \
        batch["next_state"], batch["done"]

    states_t = utils.any2device(states_t, device=self._device)
    actions_t = utils.any2device(actions_t, device=self._device)
    rewards_t = utils.any2device(
        rewards_t, device=self._device
    ).unsqueeze(1)
    states_tp1 = utils.any2device(states_tp1, device=self._device)
    done_t = utils.any2device(done_t, device=self._device).unsqueeze(1)
    # states_t:   [bs; history_len; observation_len]
    # actions_t:  [bs; action_len]
    # rewards_t:  [bs; 1]
    # states_tp1: [bs; history_len; observation_len]
    # done_t:     [bs; 1]

    policy_loss, value_loss = self._loss_fn(
        states_t, actions_t, rewards_t, states_tp1, done_t
    )
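
    # A minimal sketch of how the `actor_update` / `critic_update` flags are
    # typically consumed; `self.actor_optimizer` and `self.critic_optimizer`
    # are assumed attribute names, not taken from this snippet.
    metrics = {}
    if actor_update:
        self.actor_optimizer.zero_grad()
        policy_loss.backward(retain_graph=critic_update)
        self.actor_optimizer.step()
        metrics["loss_actor"] = policy_loss.item()
    if critic_update:
        self.critic_optimizer.zero_grad()
        value_loss.backward()
        self.critic_optimizer.step()
        metrics["loss_critic"] = value_loss.item()
    return metrics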
def get_rollout(self, states, actions, rewards, dones):
    assert len(states) == len(actions) == len(rewards) == len(dones)

    trajectory_len = \
        rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
    states_len = states.shape[0]

    states = utils.any2device(states, device=self._device)
    actions = utils.any2device(actions, device=self._device)
    rewards = np.array(rewards)[:trajectory_len]
    values = torch.zeros(
        (states_len + 1, self._num_heads, self._num_atoms)
    ).to(self._device)
    values[:states_len, ...] = self.critic(states).squeeze_(dim=2)
    # each column corresponds to a different gamma
    values = values.cpu().numpy()[:trajectory_len + 1, ...]
    _, logprobs = self.actor(states, logprob=actions)
    logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]
    # one-step TD residuals: [len; num_heads; num_atoms]
    deltas = rewards[:, None, None] \
        + self._gammas[:, None] * values[1:] - values[:-1]
    # for each gamma in the list of gammas, compute
    # the advantages and returns: [len; num_heads; num_atoms]
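    # A sketch of the advantage/return computation referenced above, assuming
    # standard GAE; `self._gae_lambda` is an assumed attribute name, not taken
    # from this snippet.
    advantages = np.zeros_like(deltas)
    last_gae = np.zeros_like(deltas[0])
    for t in reversed(range(trajectory_len)):
        last_gae = \
            deltas[t] + self._gammas[:, None] * self._gae_lambda * last_gae
        advantages[t] = last_gae
    returns = advantages + values[:trajectory_len]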
self._gammas = \
    utils.hyperbolic_gammas(
        self._gamma,
        self._hyperbolic_constant,
        self._num_heads
    )
self._gammas = utils.any2device(self._gammas, device=self._device)

# pick the critic loss according to the critic's output distribution
assert critic_distribution in [None, "categorical", "quantile"]
if critic_distribution == "categorical":
    self.num_atoms = self.critic.num_atoms
    values_range = self.critic.values_range
    self.v_min, self.v_max = values_range
    self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
    z = torch.linspace(
        start=self.v_min, end=self.v_max, steps=self.num_atoms
    )
    self.z = utils.any2device(z, device=self._device)
    self._loss_fn = self._categorical_loss
elif critic_distribution == "quantile":
    self.num_atoms = self.critic.num_atoms
    tau_min = 1 / (2 * self.num_atoms)
    tau_max = 1 - tau_min
    tau = torch.linspace(
        start=tau_min, end=tau_max, steps=self.num_atoms
    )
    self.tau = utils.any2device(tau, device=self._device)
    self._loss_fn = self._quantile_loss
else:
    assert self.critic_criterion is not None
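# For example, with num_atoms=51 and values_range=(-10.0, 10.0) (illustrative
# numbers, not taken from this snippet), the categorical support z consists of
# 51 evenly spaced atoms with delta_z = 20 / 50 = 0.4, while the quantile
# critic's tau covers the quantile midpoints 1/102, 3/102, ..., 101/102.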
def train(self, batch, **kwargs):
    (
        states_t, actions_t, returns_t,
        advantages_t, action_logprobs_t
    ) = (
        batch["state"], batch["action"], batch["return"],
        batch["advantage"], batch["action_logprob"]
    )

    states_t = utils.any2device(states_t, device=self._device)
    actions_t = utils.any2device(actions_t, device=self._device)
    returns_t = utils.any2device(
        returns_t, device=self._device
    ).unsqueeze_(-1)
    advantages_t = utils.any2device(advantages_t, device=self._device)
    action_logprobs_t = utils.any2device(
        action_logprobs_t, device=self._device
    )

    # critic loss
    values_tp0 = self.critic(states_t).squeeze_(dim=2)
    advantages_tp0 = returns_t - values_tp0
    value_loss = 0.5 * advantages_tp0.pow(2).mean()

    # actor loss
    _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)
    policy_loss = -(advantages_t.detach() * action_logprobs_tp0).mean()
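    # In a typical advantage actor-critic step the two losses are then combined,
    #     loss = policy_loss + value_loss_coef * value_loss,
    # before a single optimizer step; `value_loss_coef` is an illustrative name.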
self.clip_eps = clip_eps
self.entropy_regularization = entropy_regularization

critic_distribution = self.critic.distribution
self._value_loss_fn = self._base_value_loss
self._num_atoms = self.critic.num_atoms
self._num_heads = self.critic.num_heads
self._hyperbolic_constant = self.critic.hyperbolic_constant
self._gammas = \
    utils.hyperbolic_gammas(
        self._gamma,
        self._hyperbolic_constant,
        self._num_heads
    )
# 1 x num_heads x 1
self._gammas_torch = utils.any2device(
    self._gammas, device=self._device
)[None, :, None]

if critic_distribution == "categorical":
    self.num_atoms = self.critic.num_atoms
    values_range = self.critic.values_range
    self.v_min, self.v_max = values_range
    self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
    z = torch.linspace(
        start=self.v_min, end=self.v_max, steps=self._num_atoms
    )
    self.z = utils.any2device(z, device=self._device)
    self._value_loss_fn = self._categorical_value_loss
elif critic_distribution == "quantile":
    self.num_atoms = self.critic.num_atoms
    tau_min = 1 / (2 * self.num_atoms)
    tau_max = 1 - tau_min
    tau = torch.linspace(
        start=tau_min, end=tau_max, steps=self.num_atoms
    )
    self.tau = utils.any2device(tau, device=self._device)
    # assumed quantile analogue of `self._categorical_value_loss` above
    self._value_loss_fn = self._quantile_value_loss
else:
    assert self.critic_criterion is not None
def train(self, batch, **kwargs):
    (
        states_t, actions_t, returns_t, states_tp1, done_t, values_t,
        advantages_t, action_logprobs_t
    ) = (
        batch["state"], batch["action"], batch["return"],
        batch["state_tp1"], batch["done"], batch["value"],
        batch["advantage"], batch["action_logprob"]
    )

    states_t = utils.any2device(states_t, device=self._device)
    actions_t = utils.any2device(actions_t, device=self._device)
    returns_t = utils.any2device(
        returns_t, device=self._device
    ).unsqueeze_(-1)
    states_tp1 = utils.any2device(states_tp1, device=self._device)
    # done_t: [bs; 1; 1]
    done_t = utils.any2device(done_t, device=self._device)[:, None, None]
    values_t = utils.any2device(values_t, device=self._device)
    advantages_t = utils.any2device(advantages_t, device=self._device)
    action_logprobs_t = utils.any2device(
        action_logprobs_t, device=self._device
    )

    # critic loss
    # states_t - [bs; {state_shape}]
    # values_t - [bs; num_heads; num_atoms]
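    # A sketch of the standard clipped PPO losses that the tensors above feed
    # into, assuming a plain single-head, non-distributional critic for
    # brevity; the snippet's _init selects a dedicated `_value_loss_fn` for
    # distributional critics instead.
    values_tp0 = self.critic(states_t).squeeze_(dim=2)
    values_clip = values_t + torch.clamp(
        values_tp0 - values_t, -self.clip_eps, self.clip_eps
    )
    value_loss = 0.5 * torch.max(
        (values_tp0 - returns_t).pow(2),
        (values_clip - returns_t).pow(2)
    ).mean()

    # actor loss
    _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)
    ratio = torch.exp(action_logprobs_tp0 - action_logprobs_t)
    policy_loss = -torch.min(
        advantages_t * ratio,
        advantages_t * torch.clamp(
            ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
        )
    ).mean()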
def train(self, batch, **kwargs):
    states, actions, returns, action_logprobs = \
        batch["state"], batch["action"], batch["return"], \
        batch["action_logprob"]

    states = utils.any2device(states, device=self._device)
    actions = utils.any2device(actions, device=self._device)
    returns = utils.any2device(returns, device=self._device)
    old_logprobs = utils.any2device(action_logprobs, device=self._device)

    # actor loss
    _, logprobs = self.actor(states, logprob=actions)
    # REINFORCE objective function
    policy_loss = -torch.mean(logprobs * returns)

    if self.entropy_regularization is not None:
        entropy = -(torch.exp(logprobs) * logprobs).mean()
        entropy_loss = self.entropy_regularization * entropy
        policy_loss = policy_loss + entropy_loss

    # actor update
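    # A minimal sketch of the update implied by the comment above;
    # `self.actor_optimizer` is an assumed attribute name.
    self.actor_optimizer.zero_grad()
    policy_loss.backward()
    self.actor_optimizer.step()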