Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
multi_advantages.append(advantages)
advantages = math_ops.add_n(multi_advantages) # A = A[0] + A[1] + ...
if normalize_advantages:
advantages = normalization_ops.normalize_by_moments(advantages, weights)
advantages = gen_array_ops.stop_gradient(advantages)
policy = self.policy(states, training=True)
log_prob = policy.log_prob(actions)
policy_gradient_loss = advantages * -log_prob
self.policy_gradient_loss = losses_impl.compute_weighted_loss(
policy_gradient_loss,
weights=weights)
entropy_loss = policy_gradient_ops.policy_entropy_loss(
policy,
self.policy.trainable_variables,
lambda policies: entropy_scale).loss
self.policy_gradient_entropy_loss = losses_impl.compute_weighted_loss(
entropy_loss,
weights=weights)
self.total_loss = math_ops.add_n([
math_ops.add_n(self.value_loss),
self.policy_gradient_loss,
self.policy_gradient_entropy_loss])
return self.total_loss