# Common imports assumed by the tianshou (0.2.x-era) snippets below.
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
# The SAC snippet also uses the policy's DiagGaussian distribution class,
# whose import path varies across tianshou versions.


def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    losses = []
    r = batch.returns
    # Optionally normalize the discounted returns before the update.
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = to_torch_as(b.act, dist.logits)
            r = to_torch_as(b.returns, dist.logits)
            # REINFORCE: minimize -(log pi(a|s) * return).
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.item())
    return {'loss': losses}
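For reference, the objective minimized above can be reproduced on toy tensors with plain PyTorch. The standalone sketch below is illustrative only (the tensors and shapes are made up, not part of tianshou); it shows the -E[log pi(a|s) * G] policy-gradient loss.

import torch
from torch.distributions import Categorical

# Toy batch: 4 states with 3 discrete actions, plus their observed returns.
logits = torch.randn(4, 3, requires_grad=True)   # stand-in for the actor output
actions = torch.tensor([0, 2, 1, 1])             # actions taken in the batch
returns = torch.tensor([1.0, 0.5, -0.2, 2.0])    # discounted returns G_t

dist = Categorical(logits=logits)
# Same objective as the snippet: minimize -(log pi(a|s) * G).
loss = -(dist.log_prob(actions) * returns).sum()
loss.backward()
print(loss.item(), logits.grad.shape)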
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    self._batch = batch_size
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    v = []
    old_log_prob = []
    # Cache old value estimates and log-probabilities before any update.
    with torch.no_grad():
        for b in batch.split(batch_size, shuffle=False):
            v.append(self.critic(b.obs))
            old_log_prob.append(self(b).dist.log_prob(
                to_torch_as(b.act, v[0])))
    batch.v = torch.cat(v, dim=0).squeeze(-1)  # old value
    batch.act = to_torch_as(batch.act, v[0])
    batch.logp_old = torch.cat(old_log_prob, dim=0).reshape(batch.v.shape)
    batch.returns = to_torch_as(batch.returns, v[0])
    if self._rew_norm:
        mean, std = batch.returns.mean(), batch.returns.std()
        if not np.isclose(std.item(), 0):
            batch.returns = (batch.returns - mean) / std
    batch.adv = batch.returns - batch.v
    if self._rew_norm:
        mean, std = batch.adv.mean(), batch.adv.std()
        if not np.isclose(std.item(), 0):
            batch.adv = (batch.adv - mean) / std
    for _ in range(repeat):
        for b in batch.split(batch_size):
            dist = self(b).dist
            value = self.critic(b.obs).squeeze(-1)
            # Importance ratio pi_new(a|s) / pi_old(a|s).
            ratio = (dist.log_prob(b.act).reshape(value.shape) - b.logp_old
                     ).exp().float()
            surr1 = ratio * b.adv
            # The source snippet cuts off here; the rest is a reconstruction
            # of the standard clipped-surrogate PPO update.
            surr2 = ratio.clamp(1. - self._eps_clip,
                                1. + self._eps_clip) * b.adv
            clip_loss = -torch.min(surr1, surr2).mean()
            vf_loss = F.mse_loss(b.returns, value)
            ent_loss = dist.entropy().mean()
            loss = clip_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            clip_losses.append(clip_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            losses.append(loss.item())
    return {
        'loss': losses,
        'loss/clip': clip_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
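The heart of this update is the clipped surrogate objective. The standalone sketch below uses hypothetical per-sample values and an arbitrarily chosen eps_clip of 0.2; it is not library code, just the same computation in isolation.

import torch

eps_clip = 0.2
# Hypothetical per-sample quantities from a minibatch.
logp_new = torch.tensor([-0.9, -1.2, -0.4], requires_grad=True)
logp_old = torch.tensor([-1.0, -1.0, -0.5])
adv = torch.tensor([0.7, -0.3, 1.5])

ratio = (logp_new - logp_old).exp()              # pi_new / pi_old
surr1 = ratio * adv                              # unclipped surrogate
surr2 = ratio.clamp(1 - eps_clip, 1 + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()      # pessimistic (clipped) objective
clip_loss.backward()
print(clip_loss.item())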
def learn(self, batch: Batch, batch_size: int, repeat: int,
          **kwargs) -> Dict[str, List[float]]:
    self._batch = batch_size
    r = batch.returns
    if self._rew_norm and not np.isclose(r.std(), 0):
        batch.returns = (r - r.mean()) / r.std()
    losses, actor_losses, vf_losses, ent_losses = [], [], [], []
    for _ in range(repeat):
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            v = self.critic(b.obs).squeeze(-1)
            a = to_torch_as(b.act, v)
            r = to_torch_as(b.returns, v)
            # Policy gradient weighted by the detached advantage (r - v).
            a_loss = -(dist.log_prob(a).reshape(v.shape) * (r - v).detach()
                       ).mean()
            vf_loss = F.mse_loss(r, v)
            ent_loss = dist.entropy().mean()
            loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
            loss.backward()
            if self._grad_norm is not None:
                nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) +
                    list(self.critic.parameters()),
                    max_norm=self._grad_norm)
            self.optim.step()
            actor_losses.append(a_loss.item())
            vf_losses.append(vf_loss.item())
            ent_losses.append(ent_loss.item())
            # The source cuts off here; the bookkeeping and return value
            # below are the natural completion.
            losses.append(loss.item())
    return {
        'loss': losses,
        'loss/actor': actor_losses,
        'loss/vf': vf_losses,
        'loss/ent': ent_losses,
    }
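The A2C loss above is a weighted sum of three terms: policy loss, value loss, and an entropy bonus. A minimal standalone sketch with made-up tensors and arbitrary weights (w_vf, w_ent are illustrative placeholders):

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

w_vf, w_ent = 0.5, 0.01                        # illustrative loss weights
logits = torch.randn(4, 3, requires_grad=True)
values = torch.randn(4, requires_grad=True)    # critic output V(s)
actions = torch.tensor([0, 2, 1, 0])
returns = torch.tensor([1.0, 0.3, -0.5, 0.8])

dist = Categorical(logits=logits)
adv = (returns - values).detach()              # advantage, no gradient through the critic here
a_loss = -(dist.log_prob(actions) * adv).mean()
vf_loss = F.mse_loss(returns, values)
ent_loss = dist.entropy().mean()
loss = a_loss + w_vf * vf_loss - w_ent * ent_loss
loss.backward()
print(loss.item())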
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            model: str = 'actor',
            input: str = 'obs',
            explorating: bool = True,
            **kwargs) -> Batch:
    """Compute action over the given batch data.

    :return: A :class:`~tianshou.data.Batch` which has 2 keys:

        * ``act`` the action.
        * ``state`` the hidden state.

    .. seealso::
        Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
        more detailed explanation.
    """
    # NOTE: the signature above is reconstructed; only the docstring tail
    # and the body appear in the source snippet.
    model = getattr(self, model)
    obs = getattr(batch, input)
    actions, h = model(obs, state=state, info=batch.info)
    actions += self._action_bias
    # Add exploration noise during training, then clip to the action range.
    if self.training and explorating:
        actions += to_torch_as(self._noise(actions.shape), actions)
    actions = actions.clamp(self._range[0], self._range[1])
    return Batch(act=actions, state=h)
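The last few lines of this forward pass add exploration noise and clip to the action bounds. The same pattern in isolation, with an arbitrary Gaussian noise scale and action range chosen only for illustration:

import torch

action_low, action_high = -1.0, 1.0                 # illustrative action range
actions = torch.tensor([[0.3, -0.9], [0.95, 0.1]])  # deterministic actor output
noise = 0.1 * torch.randn_like(actions)             # simple Gaussian exploration noise
explored = (actions + noise).clamp(action_low, action_high)
print(explored)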
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
    # Periodically sync the target network with the online network.
    if self._target and self._cnt % self._freq == 0:
        self.sync_weight()
    self.optim.zero_grad()
    q = self(batch).logits
    q = q[np.arange(len(q)), batch.act]  # Q(s, a) for the taken actions
    r = to_torch_as(batch.returns, q)
    if hasattr(batch, 'update_weight'):
        # Prioritized replay: refresh priorities with the TD error and
        # weight the squared error by the importance-sampling weights.
        td = r - q
        batch.update_weight(batch.indice, to_numpy(td))
        impt_weight = to_torch_as(batch.impt_weight, q)
        loss = (td.pow(2) * impt_weight).mean()
    else:
        loss = F.mse_loss(q, r)
    loss.backward()
    self.optim.step()
    self._cnt += 1
    return {'loss': loss.item()}
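Both branches of the DQN loss reduce to a (possibly weighted) squared TD error. A standalone sketch with hypothetical tensors, outside any policy class:

import torch
import torch.nn.functional as F

q_all = torch.randn(4, 5, requires_grad=True)   # Q(s, .) for 4 states, 5 actions
acts = torch.tensor([1, 0, 4, 2])
q = q_all[torch.arange(4), acts]                # Q(s, a) for the taken actions
returns = torch.tensor([1.2, 0.0, -0.4, 0.9])   # n-step targets

# Uniform replay: plain MSE.
loss_uniform = F.mse_loss(q, returns)
# Prioritized replay: weight the squared TD error by importance weights.
impt_weight = torch.tensor([0.8, 1.0, 0.5, 1.2])
td = returns - q
loss_per = (td.pow(2) * impt_weight).mean()
print(loss_uniform.item(), loss_per.item())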
# Fragment of the n-step return helper (BasePolicy.compute_nstep_return).
# The reward-normalization prelude is reconstructed; only the code from
# `if np.isclose(std, 0)` onward appears in the source.
rew = buffer.rew
if rew_norm:
    bfr = rew[:min(len(buffer), 1000)]  # estimate statistics on a slice
    mean, std = bfr.mean(), bfr.std()
    if np.isclose(std, 0):
        mean, std = 0, 1
else:
    mean, std = 0, 1
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + n_step
done, buf_len = buffer.done, len(buffer)
# Walk backwards over the n-step window, restarting at episode ends.
for n in range(n_step - 1, -1, -1):
    now = (indice + n) % buf_len
    gammas[done[now] > 0] = n
    returns[done[now] > 0] = 0
    returns = (rew[now] - mean) / std + gamma * returns
terminal = (indice + n_step - 1) % buf_len
target_q = target_q_fn(buffer, terminal).flatten()  # shape: [bsz, ]
# Zero the bootstrap term where the episode ended inside the window.
target_q[gammas != n_step] = 0
returns = to_torch_as(returns, target_q)
gammas = to_torch_as(gamma ** gammas, target_q)
batch.returns = target_q * gammas + returns
return batch
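A tiny NumPy example makes the recursion concrete: with n_step = 3 and no episode boundary, the target is r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + gamma^3*Q_target(s_{t+3}). All numbers below are made up for illustration.

import numpy as np

gamma, n_step = 0.9, 3
rew = np.array([1.0, 0.0, 2.0])   # r_t, r_{t+1}, r_{t+2}
target_q = 5.0                    # bootstrap value at s_{t+n}

# Backward accumulation, mirroring the loop in the snippet above.
ret = 0.0
for n in range(n_step - 1, -1, -1):
    ret = rew[n] + gamma * ret
nstep_return = ret + (gamma ** n_step) * target_q
print(nstep_return)   # 1.0 + 0.9*0.0 + 0.81*2.0 + 0.729*5.0 = 6.265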
def forward(self, batch: Batch,
            state: Optional[Union[dict, Batch, np.ndarray]] = None,
            input: str = 'obs',
            explorating: bool = True,
            **kwargs) -> Batch:
    obs = getattr(batch, input)
    logits, h = self.actor(obs, state=state, info=batch.info)
    assert isinstance(logits, tuple)
    dist = DiagGaussian(*logits)
    x = dist.rsample()  # reparameterized sample
    y = torch.tanh(x)
    act = y * self._action_scale + self._action_bias
    # Change-of-variables correction for the tanh squashing.
    y = self._action_scale * (1 - y.pow(2)) + self.__eps
    log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
    if self._noise is not None and self.training and explorating:
        act += to_torch_as(self._noise(act.shape), act)
    act = act.clamp(self._range[0], self._range[1])
    return Batch(
        logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
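The log_prob line applies the change-of-variables correction for the tanh squashing: log pi(a|s) = log rho(x|s) - sum log(scale * (1 - tanh(x)^2)). A standalone check with a plain Normal distribution (the scale and shapes are illustrative, not tianshou's):

import torch
from torch.distributions import Normal

eps = torch.finfo(torch.float32).eps
action_scale = 2.0                    # illustrative scale
mu, sigma = torch.zeros(3), torch.ones(3)
dist = Normal(mu, sigma)

x = dist.rsample()                    # pre-squash sample
y = torch.tanh(x)
act = y * action_scale                # squashed, scaled action
# Jacobian correction: d(act)/dx = action_scale * (1 - tanh(x)^2).
log_prob = dist.log_prob(x).sum(-1)
log_prob = log_prob - torch.log(action_scale * (1 - y.pow(2)) + eps).sum(-1)
print(act, log_prob)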
def _target_q(self, buffer: ReplayBuffer,
              indice: np.ndarray) -> torch.Tensor:
    batch = buffer[indice]  # batch.obs: s_{t+n}
    with torch.no_grad():
        obs_next_result = self(batch, input='obs_next', explorating=False)
        a_ = obs_next_result.act
        batch.act = to_torch_as(batch.act, a_)
        # Soft target: min of the twin target critics minus the entropy term.
        target_q = torch.min(
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_),
        ) - self._alpha * obs_next_result.log_prob
    return target_q
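The returned target combines the pessimistic twin-critic estimate with the entropy bonus, min(Q1, Q2) - alpha * log pi(a'|s'). In isolation, with toy values and an arbitrarily chosen alpha:

import torch

alpha = 0.2                                  # illustrative entropy coefficient
q1 = torch.tensor([[3.0], [1.5]])            # critic1_old(s', a')
q2 = torch.tensor([[2.5], [1.8]])            # critic2_old(s', a')
log_prob = torch.tensor([[-1.2], [-0.7]])    # log pi(a'|s')

target_q = torch.min(q1, q2) - alpha * log_prob
print(target_q)   # [[2.74], [1.64]]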