Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: a tensor containing the values that the distribution of `q_tm1`
at actions `a_tm1` is regressed towards, shape `[B, num_atoms]`.
Raises:
ValueError: If the tensors do not have the correct rank or compatibility.
"""
# Rank and compatibility checks.
assertion_lists = [[logits_q_tm1, logits_q_t], [a_tm1, r_t, pcont_t],
[atoms_tm1, atoms_t]]
base_ops.wrap_rank_shape_assert(assertion_lists, [3, 1, 1], name)
# Categorical distributional Q-learning op.
with tf.name_scope(
name,
values=[
atoms_tm1, logits_q_tm1, a_tm1, r_t, pcont_t, atoms_t, logits_q_t
]):
with tf.name_scope("target"):
# Scale and shift time-t distribution atoms by discount and reward.
target_z = r_t[:, None] + pcont_t[:, None] * atoms_t[None, :]
# Convert logits to distribution, then find greedy action in state s_t.
q_t_probs = tf.nn.softmax(logits_q_t)
q_t_mean = tf.reduce_sum(q_t_probs * atoms_t, 2)
pi_t = tf.argmax(q_t_mean, 1, output_type=tf.int32)
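
# Illustrative sketch (not part of the snippet above): the same target-atom
# construction and greedy-action selection, redone in NumPy on a made-up toy
# batch so the broadcasting and shapes are easy to inspect.
import numpy as np

r_t = np.array([0.0, 1.0])               # rewards, shape [B]
pcont_t = np.array([0.9, 0.0])           # pcontinue/discount, shape [B]
atoms_t = np.array([-1.0, 0.0, 1.0])     # support, shape [num_atoms]
logits_q_t = np.random.randn(2, 4, 3)    # shape [B, num_actions, num_atoms]

# Scale and shift the support by discount and reward: shape [B, num_atoms].
target_z = r_t[:, None] + pcont_t[:, None] * atoms_t[None, :]

# Softmax over atoms, mean value per action, then the greedy action per row.
q_t_probs = np.exp(logits_q_t) / np.exp(logits_q_t).sum(-1, keepdims=True)
q_t_mean = (q_t_probs * atoms_t).sum(-1)  # shape [B, num_actions]
pi_t = q_t_mean.argmax(axis=1)            # shape [B]
print(target_z.shape, pi_t.shape)         # (2, 3) (2,)
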
r_t: Tensor holding rewards, shape `[B]`.
pcont_t: Tensor holding pcontinue values, shape `[B]`.
q_t: Tensor holding Q-values for second timestep in a batch of
transitions, shape `[B x num_actions]`.
name: name to prefix ops created within this op.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1, q_t], [a_tm1, r_t, pcont_t]], [2, 1], name)
# Q-learning op.
with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):
# Build target and select head to update.
with tf.name_scope("target"):
target = tf.stop_gradient(
r_t + pcont_t * tf.reduce_max(q_t, axis=1))
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - qa_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, QExtra(target, td_error))
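
# Illustrative sketch (not part of the library code above): the same
# Q-learning target and TD error reproduced in NumPy on a made-up toy batch,
# to make the arithmetic and shapes explicit.
import numpy as np

q_tm1 = np.array([[1.0, 2.0], [3.0, 0.5]])   # shape [B, num_actions]
a_tm1 = np.array([0, 1])                     # actions taken, shape [B]
r_t = np.array([1.0, -1.0])                  # rewards, shape [B]
pcont_t = np.array([0.99, 0.0])              # pcontinue/discount, shape [B]
q_t = np.array([[0.0, 4.0], [1.0, 2.0]])     # shape [B, num_actions]

target = r_t + pcont_t * q_t.max(axis=1)       # bootstrap from max_a Q(s', a)
qa_tm1 = q_tm1[np.arange(len(a_tm1)), a_tm1]   # Q(s, a) of the taken actions
td_error = target - qa_tm1                     # [3.96, -1.5]
loss = 0.5 * td_error ** 2                     # per-element loss, shape [B]
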
v_tm1: Tensor holding values at previous timestep, shape `[B]`.
r_t: Tensor holding rewards, shape `[B]`.
pcont_t: Tensor holding pcontinue values, shape `[B]`.
v_t: Tensor holding values at current timestep, shape `[B]`.
name: name to prefix ops created by this function.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `v_tm1`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert([[v_tm1, v_t, r_t, pcont_t]], [1], name)
# TD(0)-learning op.
with tf.name_scope(name, values=[v_tm1, r_t, pcont_t, v_t]):
# Build target.
target = tf.stop_gradient(r_t + pcont_t * v_t)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - v_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, TDExtra(target, td_error))
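
# Illustrative sketch (not part of the library code above): the TD(0) target
# and error in NumPy. Note d(0.5 * td_error**2) / d(v_tm1) = -(target - v_tm1),
# so the gradient magnitude equals the TD error, which is why the squared
# error is scaled by 0.5.
import numpy as np

v_tm1 = np.array([0.5, 2.0])      # V(s), shape [B]
r_t = np.array([1.0, 0.0])        # rewards, shape [B]
pcont_t = np.array([0.9, 0.9])    # pcontinue/discount, shape [B]
v_t = np.array([1.0, 1.0])        # V(s'), shape [B]

target = r_t + pcont_t * v_t      # bootstrapped one-step return, [1.9, 0.9]
td_error = target - v_tm1         # [1.4, -1.1]
loss = 0.5 * td_error ** 2
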
transitions, shape `[B x num_actions]`.
These values are used for estimating the value of the best action. In
DQN they come from the target network.
action_gap_scale: coefficient in [0, 1] for scaling the action gap term.
name: name to prefix ops created within this op.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1, q_t], [a_tm1, r_t, pcont_t]], [2, 1], name)
base_ops.assert_arg_bounded(action_gap_scale, 0, 1, name, "action_gap_scale")
# persistent Q-learning op.
with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t]):
# Build target and select head to update.
with tf.name_scope("target"):
max_q_t = tf.reduce_max(q_t, axis=1)
qa_t = indexing_ops.batched_index(q_t, a_tm1)
corrected_q_t = (1 - action_gap_scale) * max_q_t + action_gap_scale * qa_t
target = tf.stop_gradient(r_t + pcont_t * corrected_q_t)
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
    # Temporal difference error and loss.
    # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
    td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(loss, QExtra(target, td_error))
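
# Illustrative sketch (not part of the library code above): the persistent
# Q-learning bootstrap in NumPy on a made-up toy batch, interpolating between
# max_a Q(s', a) and Q(s', a_tm1) with action_gap_scale.
import numpy as np

q_t = np.array([[1.0, 3.0], [2.0, 0.0]])   # shape [B, num_actions]
a_tm1 = np.array([0, 0])                   # previous actions, shape [B]
r_t = np.array([0.0, 1.0])                 # rewards, shape [B]
pcont_t = np.array([0.9, 0.9])             # pcontinue/discount, shape [B]
action_gap_scale = 0.5

max_q_t = q_t.max(axis=1)                         # [3.0, 2.0]
qa_t = q_t[np.arange(len(a_tm1)), a_tm1]          # [1.0, 2.0]
corrected_q_t = ((1 - action_gap_scale) * max_q_t
                 + action_gap_scale * qa_t)       # [2.0, 2.0]
target = r_t + pcont_t * corrected_q_t            # [1.8, 2.8]
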
v_tm1: Tensor holding values at previous timestep, shape `[B]`.
r_t: Tensor holding rewards, shape `[B]`.
pcont_t: Tensor holding pcontinue values, shape `[B]`.
q_t: Tensor of action values at current timestep, shape `[B, num_actions]`.
name: name to prefix ops created by this function.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `v_tm1`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert([[v_tm1, r_t, pcont_t], [q_t]], [1, 2], name)
# The QVMAX op.
with tf.name_scope(name, values=[v_tm1, r_t, pcont_t, q_t]):
# Build target.
target = tf.stop_gradient(r_t + pcont_t * tf.reduce_max(q_t, axis=1))
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - v_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, TDExtra(target, td_error))
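
# Illustrative sketch (not part of the library code above): the QVMAX target
# in NumPy; the state value v_tm1 is regressed towards a bootstrap built from
# the maximum action value at the next step.
import numpy as np

v_tm1 = np.array([0.0, 1.0])                 # V(s), shape [B]
r_t = np.array([1.0, 0.0])                   # rewards, shape [B]
pcont_t = np.array([0.9, 0.9])               # pcontinue/discount, shape [B]
q_t = np.array([[0.5, 2.0], [1.0, -1.0]])    # shape [B, num_actions]

target = r_t + pcont_t * q_t.max(axis=1)     # [2.8, 0.9]
td_error = target - v_tm1                    # [2.8, -0.1]
loss = 0.5 * td_error ** 2
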
used to estimate the value of the best action, shape `[B x num_actions]`.
q_t_selector: Tensor of Q-values for second timestep in a batch of
transitions used to estimate the best action, shape `[B x num_actions]`.
name: name to prefix ops created within this op.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
* `best_action`: batch of greedy actions wrt `q_t_selector`, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1, q_t_value, q_t_selector], [a_tm1, r_t, pcont_t]], [2, 1], name)
# double Q-learning op.
with tf.name_scope(
name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector]):
# Build target and select head to update.
best_action = tf.argmax(q_t_selector, 1, output_type=tf.int32)
double_q_bootstrapped = indexing_ops.batched_index(q_t_value, best_action)
target = tf.stop_gradient(r_t + pcont_t * double_q_bootstrapped)
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - qa_tm1
    loss = 0.5 * tf.square(td_error)
    return base_ops.LossOutput(
        loss, DoubleQExtra(target, td_error, best_action))
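
# Illustrative sketch (not part of the library code above): the double
# Q-learning bootstrap in NumPy; the greedy action is chosen with
# q_t_selector (e.g. the online network) and evaluated with q_t_value
# (e.g. the target network).
import numpy as np

q_t_selector = np.array([[1.0, 2.0], [3.0, 0.0]])   # shape [B, num_actions]
q_t_value = np.array([[0.5, 1.5], [2.5, 1.0]])      # shape [B, num_actions]
r_t = np.array([0.0, 1.0])                          # rewards, shape [B]
pcont_t = np.array([0.9, 0.9])                      # pcontinue/discount, [B]

best_action = q_t_selector.argmax(axis=1)                        # [1, 0]
double_q = q_t_value[np.arange(len(best_action)), best_action]   # [1.5, 2.5]
target = r_t + pcont_t * double_q                                # [1.35, 3.25]
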
pcont_t: Tensor holding pcontinue values, shape `[B]`.
q_t: Tensor holding Q-values for second timestep in a batch of
transitions, shape `[B x num_actions]`.
a_t: Tensor holding action indices for second timestep, shape `[B]`.
name: name to prefix ops created within this op.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1, q_t], [a_t, r_t, pcont_t]], [2, 1], name)
# SARSA op.
with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, q_t, a_t]):
# Select head to update and build target.
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
qa_t = indexing_ops.batched_index(q_t, a_t)
target = tf.stop_gradient(r_t + pcont_t * qa_t)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - qa_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, QExtra(target, td_error))
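
# Illustrative sketch (not part of the library code above): the SARSA
# bootstrap in NumPy; unlike Q-learning, the target uses Q(s', a_t) for the
# next action actually taken rather than the max over actions.
import numpy as np

q_tm1 = np.array([[1.0, 0.0], [0.5, 0.5]])   # shape [B, num_actions]
a_tm1 = np.array([0, 1])                     # previous actions, shape [B]
r_t = np.array([1.0, 0.0])                   # rewards, shape [B]
pcont_t = np.array([0.9, 0.9])               # pcontinue/discount, shape [B]
q_t = np.array([[2.0, 3.0], [1.0, 4.0]])     # shape [B, num_actions]
a_t = np.array([0, 0])                       # next actions taken, shape [B]

qa_tm1 = q_tm1[np.arange(len(a_tm1)), a_tm1]   # [1.0, 0.5]
qa_t = q_t[np.arange(len(a_t)), a_t]           # [2.0, 1.0]
target = r_t + pcont_t * qa_t                  # [2.8, 0.9]
td_error = target - qa_tm1                     # [1.8, 0.4]
loss = 0.5 * td_error ** 2
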
Returns:
A namedtuple with fields:
* `loss`: Tensor containing the batch of losses, shape `[B]`.
* `extra`: A namedtuple with fields:
* `target`: Tensor containing the values that the distribution of `v_tm1`
is regressed towards, shape `[B, num_atoms]`.
Raises:
ValueError: If the tensors do not have the correct rank or compatibility.
"""
# Rank and compatibility checks.
assertion_lists = [[logits_v_tm1, logits_v_t], [r_t, pcont_t],
[atoms_tm1, atoms_t]]
base_ops.wrap_rank_shape_assert(assertion_lists, [2, 1, 1], name)
# Categorical distributional TD-learning op.
with tf.name_scope(
name, values=[atoms_tm1, logits_v_tm1, r_t, pcont_t, atoms_t,
logits_v_t]):
with tf.name_scope("target"):
# Scale and shift time-t distribution atoms by discount and reward.
target_z = r_t[:, None] + pcont_t[:, None] * atoms_t[None, :]
v_t_probs = tf.nn.softmax(logits_v_t)
# Project using the Cramer distance
target = tf.stop_gradient(_l2_project(target_z, v_t_probs, atoms_tm1))
loss = tf.nn.softmax_cross_entropy_with_logits(
logits=logits_v_tm1, labels=target)
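
# Illustrative sketch (not the library's _l2_project): a rough NumPy version
# of a categorical (C51-style) projection onto an evenly spaced support,
# meant to convey what mapping target_z back onto the fixed atoms does; each
# shifted atom's mass is clipped to the support and split between its two
# neighbouring atoms. The library's projection is more general than this.
import numpy as np

def project_to_support(target_z, probs, atoms):
  """target_z, probs: [B, n]; atoms: [m], evenly spaced; returns [B, m]."""
  vmin, vmax = atoms[0], atoms[-1]
  dz = atoms[1] - atoms[0]
  b = (np.clip(target_z, vmin, vmax) - vmin) / dz     # fractional atom index
  lower, upper = np.floor(b).astype(int), np.ceil(b).astype(int)
  out = np.zeros((probs.shape[0], len(atoms)))
  for i in range(probs.shape[0]):
    for j in range(probs.shape[1]):
      l, u = lower[i, j], upper[i, j]
      if l == u:                       # lands exactly on an atom
        out[i, l] += probs[i, j]
      else:                            # split mass between the neighbours
        out[i, l] += probs[i, j] * (u - b[i, j])
        out[i, u] += probs[i, j] * (b[i, j] - l)
  return out

atoms = np.array([-1.0, 0.0, 1.0])
target_z = np.array([[-0.5, 0.25, 2.0]])      # 2.0 gets clipped to 1.0
probs = np.array([[0.2, 0.5, 0.3]])
print(project_to_support(target_z, probs, atoms))   # [[0.1, 0.475, 0.425]]
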
r_t: Tensor holding rewards, shape `[B]`.
pcont_t: Tensor holding pcontinue values, shape `[B]`.
v_t: Tensor holding state-values for second timestep in a batch of
transitions, shape `[B]`.
name: name to prefix ops created within this op.
Returns:
A namedtuple with fields:
* `loss`: a tensor containing the batch of losses, shape `[B]`.
* `extra`: a namedtuple with fields:
* `target`: batch of target values for `q_tm1[a_tm1]`, shape `[B]`.
* `td_error`: batch of temporal difference errors, shape `[B]`.
"""
# Rank and compatibility checks.
base_ops.wrap_rank_shape_assert(
[[q_tm1], [a_tm1, r_t, pcont_t, v_t]], [2, 1], name)
# QV op.
with tf.name_scope(name, values=[q_tm1, a_tm1, r_t, pcont_t, v_t]):
# Build target and select head to update.
with tf.name_scope("target"):
target = tf.stop_gradient(r_t + pcont_t * v_t)
qa_tm1 = indexing_ops.batched_index(q_tm1, a_tm1)
# Temporal difference error and loss.
# Loss is MSE scaled by 0.5, so the gradient is equal to the TD error.
td_error = target - qa_tm1
loss = 0.5 * tf.square(td_error)
return base_ops.LossOutput(loss, QExtra(target, td_error))
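
# Illustrative sketch (not part of the library code above): the QV-learning
# target in NumPy; the action value Q(s, a) is regressed towards a bootstrap
# built from the separately learned state value V(s').
import numpy as np

q_tm1 = np.array([[1.0, 2.0], [0.0, 3.0]])   # shape [B, num_actions]
a_tm1 = np.array([1, 0])                     # actions taken, shape [B]
r_t = np.array([0.0, 1.0])                   # rewards, shape [B]
pcont_t = np.array([0.9, 0.9])               # pcontinue/discount, shape [B]
v_t = np.array([2.0, 1.0])                   # V(s'), shape [B]

target = r_t + pcont_t * v_t                   # [1.8, 1.9]
qa_tm1 = q_tm1[np.arange(len(a_tm1)), a_tm1]   # [2.0, 0.0]
td_error = target - qa_tm1                     # [-0.2, 1.9]
loss = 0.5 * td_error ** 2
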