baseline_scale: scalar or Tensor of shape `[B, T]` containing the baseline loss scale.
**kwargs: keyword arguments (unused)
Returns:
the total loss Tensor of shape [].
"""
del kwargs
sequence_length = math_ops.reduce_sum(weights, axis=1)
total_num = math_ops.reduce_sum(sequence_length)
policy = self.policy(states, training=True)
behavioral_policy = self.behavioral_policy(states)
baseline_values = array_ops.squeeze(
self.value(states, training=True),
axis=-1) * weights
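# Bootstrap values: the baseline value at the last valid timestep of each
# sequence (index sequence_length - 1), used by V-trace below as the value
# estimate past the end of the unrolled sequence.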
bootstrap_values = indexing_ops.batched_index(
baseline_values, math_ops.cast(sequence_length - 1, dtypes.int32))
baseline_values = parray_ops.swap_time_major(baseline_values)
pcontinues = parray_ops.swap_time_major(decay * weights)
log_prob = policy.log_prob(actions)
log_rhos = parray_ops.swap_time_major(log_prob) - parray_ops.swap_time_major(
behavioral_policy.log_prob(actions))
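# log_rhos are the time-major log importance ratios log pi(a|s) - log mu(a|s)
# between the learner policy and the behavioural policy; V-trace uses them to
# compute off-policy corrected value targets and policy-gradient advantages
# (Espeholt et al., 2018, IMPALA).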
vtrace_returns = vtrace_ops.vtrace_from_importance_weights(
log_rhos,
pcontinues,
parray_ops.swap_time_major(rewards),
baseline_values,
bootstrap_values)
advantages = parray_ops.swap_time_major(vtrace_returns.pg_advantages)
if normalize_advantages:
# Define:
# T_tm1 = T(x_{t-1}, a_{t-1})
# T_t = T(x_t, a_t)
# exp_q_t = 𝔼_π Q(x_{t+1},.)
# qa_t = Q(x_t, a_t)
# Hence:
# T_tm1 = (r_t + γ * (exp_q_t - c_t * qa_t)) + γ * c_t * T_t
# Define:
# current = r_t + γ * (exp_q_t - c_t * qa_t)
# Thus:
# T_tm1 = scan_discounted_sum(current, γ * c_t, reverse=True)
args = [r_t, pcont_t, target_policy_t, c_t, q_t, a_t]
with tf.name_scope(
name, 'general_returns_based_off_policy_target', values=args):
exp_q_t = tf.reduce_sum(target_policy_t * q_t, axis=2)
qa_t = indexing_ops.batched_index(q_t, a_t)
current = r_t + pcont_t * (exp_q_t - c_t * qa_t)
initial_value = qa_t[-1]
return sequence_ops.scan_discounted_sum(
current,
pcont_t * c_t,
initial_value,
reverse=True,
back_prop=back_prop)
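# Hedged reference (NumPy, illustrative only; not part of this module) of what
# scan_discounted_sum(current, discount, initial_value, reverse=True) computes
# for a single [T]-shaped sequence: out[t] = current[t] + discount[t] * out[t+1],
# bootstrapped with initial_value past the final step.
import numpy as np

def _reverse_discounted_scan_reference(current, discount, initial_value):
    acc = initial_value
    out = np.zeros_like(current)
    for t in reversed(range(len(current))):
        acc = current[t] + discount[t] * acc
        out[t] = acc
    return out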
the total loss Tensor of shape [].
"""
del kwargs
sequence_length = math_ops.reduce_sum(weights, axis=1)
total_num = math_ops.reduce_sum(sequence_length)
policy = self.policy(states, training=True)
behavioral_policy = self.behavioral_policy(states)
baseline_values = array_ops.squeeze(
self.value(states, training=True),
axis=-1) * weights
pcontinues = decay * weights
lambda_ = lambda_ * weights
bootstrap_values = indexing_ops.batched_index(
baseline_values, math_ops.cast(sequence_length - 1, dtypes.int32))
baseline_loss, td_lambda = value_ops.td_lambda(
parray_ops.swap_time_major(baseline_values),
parray_ops.swap_time_major(rewards),
parray_ops.swap_time_major(pcontinues),
bootstrap_values,
parray_ops.swap_time_major(lambda_))
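# td_lambda returns a lambda-return regression loss for the baseline and, in
# its extra fields, the temporal differences (lambda-returns minus baseline
# values), which are reused below as policy-gradient advantages.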
advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
advantages = normalization_ops.normalize_by_moments(advantages, weights)
advantages = gen_array_ops.stop_gradient(advantages)
ratio = gen_math_ops.exp(
policy.log_prob(actions) - gen_array_ops.stop_gradient(
behavioral_policy.log_prob(actions)))
clipped_ratio = clip_ops.clip_by_value(ratio, 1. - ratio_epsilon, 1. + ratio_epsilon)
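# Hedged note (not shown in this snippet): in PPO-style objectives the ratio
# and clipped_ratio above typically combine with the advantages as a clipped
# surrogate, e.g. -minimum(ratio * advantages, clipped_ratio * advantages),
# reduced with the sequence weights.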
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1, q_t], [a_tm1, a_t, r_t, pcont_t]], [2, 1], name)
# SARSA op.
with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):
# Select head to update and build target.
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
qa_t = indexing_ops.batched_index(q_t, a_t)
target = tf.stop_gradient(r_t + pcont_t * qa_t)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - qa_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, QExtra(target, td_error))
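# Hedged usage sketch: toy shapes only, and it assumes the op above is exposed
# as `sarsa` (e.g. trfl.sarsa). The loss has shape [B]; extra.target and
# extra.td_error are the regression target and TD error per batch element.
import tensorflow as tf
import trfl

q_tm1 = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # [B=2, num_actions=2]
a_tm1 = tf.constant([0, 1])                    # actions taken at t-1
r_t = tf.constant([0.5, 1.0])                  # rewards
pcont_t = tf.constant([0.99, 0.0])             # discount; 0 at episode end
q_t = tf.constant([[1.5, 0.5], [2.0, 1.0]])    # Q-values at t
a_t = tf.constant([1, 0])                      # actions taken at t
loss, extra = trfl.sarsa(q_tm1, a_tm1, r_t, pcont_t, q_t, a_t)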
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1, q_t], [a_tm1, r_t, pcont_t]], [2, 1], name)
# Q-learning op.
with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):
# Build target and select head to update.
with tf.name_scope("target"):
target = tf.stop_gradient(
r_t + pcont_t * tf.reduce_max(q_t, axis=1))
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - qa_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, QExtra(target, td_error))
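# Worked toy example (illustrative, NumPy): with r_t = 1.0, pcont_t = 0.9 and
# q_t = [2.0, 3.0], the Q-learning target is r_t + pcont_t * max(q_t) = 3.7,
# whereas the SARSA target above would use q_t[a_t] instead of the max.
import numpy as np

target_example = 1.0 + 0.9 * np.max(np.array([2.0, 3.0]))  # == 3.7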
shape `[T, B]`.
* `target`: Tensor containing target action values, shape `[T, B]`.
"""
all_args = [
lambda_, q_tm1, a_tm1, r_t, pcont_t, target_policy_t, behaviour_policy_t,
targnet_q_t, a_t
]
with tf.name_scope(name, 'RetraceCore', all_args):
(lambda_, q_tm1, a_tm1, r_t, pcont_t, target_policy_t, behaviour_policy_t,
targnet_q_t, a_t) = (
tf.convert_to_tensor(arg) for arg in all_args)
# Evaluate importance weights.
c_t = _retrace_weights(
indexing_ops.batched_index(target_policy_t, a_t),
behaviour_policy_t) * lambda_
# Targets are evaluated by using only Q values from the target network.
# This provides fixed regression targets until the next target network
# update.
target = _general_off_policy_corrected_multistep_target(
r_t, pcont_t, target_policy_t, c_t, targnet_q_t, a_t,
not stop_targnet_gradients)
if stop_targnet_gradients:
target = tf.stop_gradient(target)
# Regress Q values of the learning network towards the targets evaluated
# by using the target network.
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
delta = target - qa_tm1
loss = 0.5 * tf.square(delta)
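# Hedged sketch (NumPy, not part of the op above) of the truncated importance
# weights Retrace uses: c_t = lambda * min(1, pi(a_t|x_t) / mu(a_t|x_t)),
# assuming _retrace_weights clips the ratio of target to behaviour
# probabilities at 1.
import numpy as np

def _retrace_weights_reference(target_probs, behaviour_probs, lambda_=1.0):
    return lambda_ * np.minimum(1.0, target_probs / behaviour_probs)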
def _discounted_returns(rewards, decay, weights):
"""Compute the discounted returns given the decay factor."""
sequence_lengths = math_ops.reduce_sum(weights, axis=1)
bootstrap_values = indexing_ops.batched_index(
rewards, math_ops.cast(sequence_lengths - 1, dtypes.int32))
multi_step_returns = sequence_ops.scan_discounted_sum(
parray_ops.swap_time_major(rewards * weights),
parray_ops.swap_time_major(decay * weights),
bootstrap_values,
reverse=True,
back_prop=False)
return parray_ops.swap_time_major(multi_step_returns)
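# Hedged reference (NumPy, illustrative) of the batch-major discounted return
# the helper above produces: R[:, t] = r[:, t] + decay[:, t] * R[:, t + 1],
# with rewards and decay masked by `weights`. The bootstrap term the op above
# adds at the last valid step is omitted here for simplicity.
import numpy as np

def _discounted_returns_reference(rewards, decay, weights):
    """rewards, decay, weights: float arrays of shape [B, T]."""
    returns = np.zeros_like(rewards)
    acc = np.zeros(rewards.shape[0])
    for t in reversed(range(rewards.shape[1])):
        acc = rewards[:, t] * weights[:, t] + decay[:, t] * weights[:, t] * acc
        returns[:, t] = acc
    return returns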
ValueError: If tensors are empty or fail the rank and mutual
compatibility asserts.
"""
del kwargs
base_ops.assert_rank_and_shape_compatibility([weights], 2)
sequence_lengths = math_ops.reduce_sum(weights, axis=1)
total_num = math_ops.reduce_sum(sequence_lengths)
baseline_values = array_ops.squeeze(
self.value(states, training=True),
axis=-1) * weights
base_ops.assert_rank_and_shape_compatibility([rewards, baseline_values], 2)
pcontinues = decay * weights
lambda_ = lambda_ * weights
bootstrap_values = indexing_ops.batched_index(
baseline_values, math_ops.cast(sequence_lengths - 1, dtypes.int32))
baseline_loss, td_lambda = value_ops.td_lambda(
parray_ops.swap_time_major(baseline_values),
parray_ops.swap_time_major(rewards),
parray_ops.swap_time_major(pcontinues),
bootstrap_values,
parray_ops.swap_time_major(lambda_))
advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
if normalize_advantages:
advantages = normalization_ops.normalize_by_moments(advantages, weights)
advantages = gen_array_ops.check_numerics(advantages, 'advantages')
policy = self.policy(states, training=True)
log_prob = policy.log_prob(actions)