# Mask covering the first (length - 1) valid steps of each trajectory; the
# final valid step and all padded steps are zeroed out.
mask = array_ops.expand_dims(
    array_ops.sequence_mask(
        gen_math_ops.maximum(sequence_length - 1, 0),
        maxlen=states.shape[1],
        dtype=dtypes.float32),
    axis=-1)
# Per-step discounts, zeroed wherever the step weights are zero.
pcontinues = decay * weights
action_values = self.value(states, training=True) * mask
# Bootstrap values come from the target network and are excluded from the
# gradient.
next_action_values = gen_array_ops.stop_gradient(
    self.target_value(next_states, training=True))
lambda_ = gen_array_ops.broadcast_to(lambda_, array_ops.shape(rewards))
# qlambda expects time-major [T, B, ...] inputs, so swap from batch-major.
baseline_loss = action_value_ops.qlambda(
    parray_ops.swap_time_major(action_values),
    parray_ops.swap_time_major(actions),
    parray_ops.swap_time_major(rewards),
    parray_ops.swap_time_major(pcontinues),
    parray_ops.swap_time_major(next_action_values),
    parray_ops.swap_time_major(lambda_)).loss
# Weighted reduction over valid steps, scaled by the baseline coefficient.
self.value_loss = baseline_scale * losses_impl.compute_weighted_loss(
    baseline_loss,
    parray_ops.swap_time_major(weights))
self.total_loss = self.value_loss
return self.total_loss
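# Hypothetical usage sketch (not from the original source). Assuming
# action_value_ops is trfl's module, qlambda takes time-major [T, B, ...]
# tensors and returns a LossOutput namedtuple with a per-step loss; the
# shapes and values below are illustrative only.
import tensorflow as tf
from trfl import action_value_ops

T, B, num_actions = 5, 2, 3
q_tm1 = tf.random.normal([T, B, num_actions])   # online Q-values at each step
a_tm1 = tf.random.uniform([T, B], maxval=num_actions, dtype=tf.int32)
r_t = tf.random.normal([T, B])
pcont_t = tf.fill([T, B], 0.99)                 # per-step discount/continuation
q_t = tf.random.normal([T, B, num_actions])     # bootstrap Q-values
lambda_ = tf.fill([T, B], 0.95)

qlambda_output = action_value_ops.qlambda(q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_)
print(qlambda_output.loss.shape)                # (5, 2): one loss per step and batch element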
sequence_length = math_ops.reduce_sum(weights, axis=1)
mask = array_ops.expand_dims(
    array_ops.sequence_mask(
        gen_math_ops.maximum(sequence_length - 1, 0),
        maxlen=states.shape[1],
        dtype=dtypes.float32),
    axis=-1)
pcontinues = decay * weights
action_values = self.value(states, training=True) * mask
next_action_values = gen_array_ops.stop_gradient(
    self.value(next_states, training=True))
lambda_ = gen_array_ops.broadcast_to(lambda_, array_ops.shape(rewards))
baseline_loss = action_value_ops.qlambda(
    parray_ops.swap_time_major(action_values),
    parray_ops.swap_time_major(actions),
    parray_ops.swap_time_major(rewards),
    parray_ops.swap_time_major(pcontinues),
    parray_ops.swap_time_major(next_action_values),
    parray_ops.swap_time_major(lambda_)).loss
self.value_loss = baseline_scale * losses_impl.compute_weighted_loss(
    baseline_loss,
    parray_ops.swap_time_major(weights))
self.total_loss = self.value_loss
return self.total_loss
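# Hypothetical helper sketch (not from the original source). The
# parray_ops.swap_time_major calls above are assumed to simply permute the
# leading batch and time axes; an equivalent stand-in could look like this.
import tensorflow as tf

def swap_time_major(tensor):
    # Swap [batch, time, ...] to [time, batch, ...] (or back), leaving any
    # trailing feature axes untouched.
    perm = [1, 0] + list(range(2, tensor.shape.ndims))
    return tf.transpose(tensor, perm)

rewards_batch_major = tf.zeros([4, 10])             # [B=4, T=10]
print(swap_time_major(rewards_batch_major).shape)   # (10, 4), i.e. time-major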
discount_factor = tf.convert_to_tensor(discount_factor)
if discount_factor.shape.ndims == 0:
    pcont_t = tf.reshape(discount_factor, [1, 1])  # [1,1].
    pcont_t = tf.tile(pcont_t, tf.shape(a_tm1))  # [T,BHW].
elif discount_factor.shape.ndims == 2:
    tiled_pcont = tf.tile(
        tf.expand_dims(tf.expand_dims(discount_factor, -1), -1),
        [1, 1] + height_width)
    pcont_t = tf.reshape(tiled_pcont, [sequence_length, -1])
else:
    raise ValueError(
        "The discount_factor must be a scalar or a tensor of rank 2; "
        "instead it is a tensor of shape {}".format(
            discount_factor.shape.as_list()))
# Compute a QLambda loss of shape [T,BHW].
loss, _ = action_value_ops.qlambda(q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_=1)
# Take sum over sequence, sum over cells.
expanded_shape = [sequence_length, batch_size] + height_width
spatial_loss = tf.reshape(loss, expanded_shape)  # [T,B,H,W].
# Return.
extra = PixelControlExtra(
    spatial_loss=spatial_loss, pseudo_rewards=pseudo_rewards)
return base_ops.LossOutput(
    tf.reduce_sum(spatial_loss, axis=[0, 2, 3]), extra)  # [B]
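# Illustrative shape check (not from the original source). In the
# scalar-discount branch above, the single discount value is reshaped to
# [1, 1] and tiled to match a_tm1 of shape [T, B*H*W]; the sizes here are
# arbitrary toy values.
import tensorflow as tf

T, BHW = 4, 6
a_tm1 = tf.zeros([T, BHW], dtype=tf.int32)
discount_factor = tf.convert_to_tensor(0.9)

pcont_t = tf.reshape(discount_factor, [1, 1])
pcont_t = tf.tile(pcont_t, tf.shape(a_tm1))
print(pcont_t.shape)  # (4, 6), every entry equal to 0.9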
def compute_loss(self, rollouts, delay=.999):
    """Implements the double Q-learning loss.

    The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
    the target `r_t + pcont_t * q_t_value[argmax q_t_selector]`.

    See "Double Q-learning" by van Hasselt
    (https://papers.nips.cc/paper/3964-double-q-learning.pdf).
    """
    pcontinues = delay * rollouts.weights
    action_values = self.value(rollouts.states, training=True)
    next_action_values = self.value(rollouts.next_states, training=True)
    self.value_loss = math_ops.reduce_mean(
        action_value_ops.double_qlearning(
            action_values,
            rollouts.actions,
            rollouts.rewards,
            pcontinues,
            next_action_values,
            next_action_values).loss,
        axis=0)
    self.total_loss = self.value_loss
    return self.total_loss
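# Hypothetical usage sketch (not from the original source). Assuming
# action_value_ops is trfl's module, double_qlearning takes separate value
# and selector Q-values for the bootstrap step (the compute_loss above passes
# the same next_action_values tensor for both roles) and builds the target
# r_t + pcont_t * q_t_value[argmax q_t_selector]; shapes and values below are
# illustrative only.
import tensorflow as tf
from trfl import action_value_ops

B, num_actions = 4, 3
q_tm1 = tf.random.normal([B, num_actions])         # online Q(s_tm1, .)
a_tm1 = tf.random.uniform([B], maxval=num_actions, dtype=tf.int32)
r_t = tf.random.normal([B])
pcont_t = tf.fill([B], 0.99)                       # discount * (1 - done)
q_t_value = tf.random.normal([B, num_actions])     # evaluates the chosen bootstrap action
q_t_selector = tf.random.normal([B, num_actions])  # picks the bootstrap action

dq_output = action_value_ops.double_qlearning(
    q_tm1, a_tm1, r_t, pcont_t, q_t_value, q_t_selector)
loss = tf.reduce_mean(dq_output.loss)              # scalar loss averaged over the batch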