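# NOTE: imports assumed for this excerpt. The TensorFlow modules below are the
# internal ones referenced by name in the code, and trfl provides the V-trace
# and batched-index ops. `parray_ops` and `normalization_ops` are project-local
# helpers whose import paths are not shown here.
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import losses_impl

from trfl import indexing_ops
from trfl import vtrace_ops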
# Total number of valid (unpadded) timesteps across the batch.
total_num = math_ops.reduce_sum(sequence_length)

# Distributions from the current (trained) and behavioral policies.
policy = self.policy(states, training=True)
behavioral_policy = self.behavioral_policy(states)

# Baseline values V(s) per timestep, masked by the padding weights.
baseline_values = array_ops.squeeze(
    self.value(states, training=True),
    axis=-1) * weights

# Bootstrap from the baseline at the last valid timestep of each sequence.
bootstrap_values = indexing_ops.batched_index(
    baseline_values, math_ops.cast(sequence_length - 1, dtypes.int32))

# The V-trace op expects time-major [T, B] inputs.
baseline_values = parray_ops.swap_time_major(baseline_values)
pcontinues = parray_ops.swap_time_major(decay * weights)

# Log importance ratios log(pi(a|s) / mu(a|s)).
log_prob = policy.log_prob(actions)
log_rhos = parray_ops.swap_time_major(log_prob) - parray_ops.swap_time_major(
    behavioral_policy.log_prob(actions))

# V-trace corrected returns and policy-gradient advantages.
vtrace_returns = vtrace_ops.vtrace_from_importance_weights(
    log_rhos,
    pcontinues,
    parray_ops.swap_time_major(rewards),
    baseline_values,
    bootstrap_values)

# Back to batch-major; optionally normalize, then hold the advantages
# constant so only the log-probability term is differentiated.
advantages = parray_ops.swap_time_major(vtrace_returns.pg_advantages)
if normalize_advantages:
    advantages = normalization_ops.normalize_by_moments(advantages, weights)
advantages = gen_array_ops.stop_gradient(advantages)

# Weighted policy-gradient loss: -advantage * log pi(a|s).
policy_gradient_loss = advantages * -log_prob
self.policy_gradient_loss = losses_impl.compute_weighted_loss(
    policy_gradient_loss,
    weights=weights)
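# `parray_ops.swap_time_major` is a project-local helper that is not defined in
# this excerpt. A minimal stand-in under that assumption (not the project's own
# implementation) simply swaps the leading batch and time axes:
import tensorflow as tf

def _swap_time_major(tensor):
    # [Batch, Time, ...] <-> [Time, Batch, ...]; trailing axes are untouched.
    perm = [1, 0] + list(range(2, tensor.shape.ndims))
    return tf.transpose(tensor, perm)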
def compute_loss(self, rollouts, decay=.999, lambda_=1., entropy_scale=.2, baseline_scale=1., ratio_epsilon=.2):
    # Distributions from the current (trained) and behavioral policies.
    policy = self.policy(rollouts.states, training=True)
    behavioral_policy = self.behavioral_policy(rollouts.states)

    # Baseline values V(s), time-major [T, B] as expected by the V-trace op;
    # bootstrap from the value at the final timestep.
    baseline_values = parray_ops.swap_time_major(
        array_ops.squeeze(self.value(rollouts.states, training=True), axis=-1))
    pcontinues = parray_ops.swap_time_major(decay * rollouts.weights)
    bootstrap_values = baseline_values[-1, :]

    # Log importance ratios log(pi(a|s) / mu(a|s)); the behavioral policy's
    # log-probabilities are treated as constants.
    log_prob = policy.log_prob(rollouts.actions)
    behavioral_log_prob = behavioral_policy.log_prob(rollouts.actions)
    log_rhos = parray_ops.swap_time_major(
        log_prob - gen_array_ops.stop_gradient(behavioral_log_prob))

    # V-trace corrected returns and policy-gradient advantages.
    vtrace_returns = vtrace_ops.vtrace_from_importance_weights(
        log_rhos,
        pcontinues,
        parray_ops.swap_time_major(rollouts.rewards),
        baseline_values,
        bootstrap_values)

    # Batch-major, normalized advantages, held constant for the policy loss.
    advantages = parray_ops.swap_time_major(vtrace_returns.pg_advantages)
    advantages = normalization_ops.weighted_moments_normalize(
        advantages, rollouts.weights)
    advantages = gen_array_ops.stop_gradient(advantages)

    # PPO-style clipped surrogate objective on the importance ratio.
    ratio = parray_ops.swap_time_major(gen_math_ops.exp(log_rhos))
    clipped_ratio = clip_ops.clip_by_value(ratio, 1. - ratio_epsilon, 1. + ratio_epsilon)
    self.policy_gradient_loss = -losses_impl.compute_weighted_loss(
        gen_math_ops.minimum(advantages * ratio, advantages * clipped_ratio),
        weights=rollouts.weights)
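# Sketch of the clipped-surrogate step in isolation (an illustration under
# assumed toy inputs, reusing the `tf` import above, not the class method):
# the per-step objective is min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A),
# and the masked mean here is one stand-in for the weighted-loss reduction.
def clipped_surrogate_loss(log_rhos, advantages, weights, ratio_epsilon=0.2):
    # Importance ratio pi(a|s) / mu(a|s) recovered from the log ratio.
    ratio = tf.exp(log_rhos)
    clipped_ratio = tf.clip_by_value(ratio, 1. - ratio_epsilon, 1. + ratio_epsilon)
    # Pessimistic (clipped) surrogate, masked and averaged over valid steps.
    surrogate = tf.minimum(advantages * ratio, advantages * clipped_ratio)
    return -tf.reduce_sum(surrogate * weights) / tf.maximum(tf.reduce_sum(weights), 1.)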