actions: Tensor of shape `[B, T, ...]` containing actions.
rewards: Tensor of shape `[B, T, V]` containing rewards.
weights: Tensor of shape `[B, T]` containing weights (1. or 0.).
decay: scalar, 1-D Tensor of shape `[V]`, or Tensor of shape
    `[B, T]` or `[B, T, V]` containing decays/discounts.
lambda_: scalar, 1-D Tensor of shape `[V]`, or Tensor of shape
    `[B, T]` or `[B, T, V]` containing the generalized lambda parameter.
entropy_scale: scalar or Tensor of shape `[B, T]` containing the entropy loss scale.
baseline_scale: scalar or Tensor of shape `[B, T]` containing the baseline loss scale.
**kwargs: keyword arguments (unused).
Returns:
    the total loss Tensor of shape [].
"""
del kwargs
base_ops.assert_rank_and_shape_compatibility([weights], 2)
sequence_lengths = math_ops.reduce_sum(weights, axis=1)
total_num = math_ops.reduce_sum(sequence_lengths)
multi_advantages = []
self.value_loss = []
multi_baseline_values = self.value(states, training=True) * array_ops.expand_dims(
    weights, axis=-1)
base_ops.assert_rank_and_shape_compatibility(
    [rewards, multi_baseline_values], 3)
multi_baseline_values = array_ops.unstack(multi_baseline_values, axis=-1)
num_values = len(multi_baseline_values)
base_shape = rewards.shape
decay = self._least_fit(decay, base_shape)
lambda_ = self._least_fit(lambda_, base_shape)
baseline_scale = self._least_fit(baseline_scale, base_shape)
for i in range(num_values):
    pcontinues = decay[..., i] * weights
    lambdas = lambda_[..., i] * weights
    bootstrap_values = indexing_ops.batched_index(
        multi_baseline_values[i], math_ops.cast(sequence_lengths - 1, dtypes.int32))
    baseline_loss, td_lambda = value_ops.td_lambda(
        parray_ops.swap_time_major(multi_baseline_values[i]),
        # The remaining arguments are assumed to mirror the single-value
        # td_lambda call further below, applied per value head.
        parray_ops.swap_time_major(rewards[..., i]),
        parray_ops.swap_time_major(pcontinues),
        bootstrap_values,
        parray_ops.swap_time_major(lambdas))
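# The losses above rely on `weights` being a 0/1 padding mask over `[B, T]`.
# A minimal sketch (the `episode_lengths` tensor and the shapes are
# illustrative assumptions, not part of the original code) of how such a mask
# is typically built, why reduce_sum(weights, axis=1) recovers the
# per-sequence lengths, and where the bootstrap index comes from:
import tensorflow as tf

episode_lengths = tf.constant([3, 5])  # valid steps per batch element, B = 2
max_steps = 5  # T
weights = tf.cast(tf.sequence_mask(episode_lengths, max_steps), tf.float32)  # `[B, T]` of 1./0.
sequence_lengths = tf.reduce_sum(weights, axis=1)  # equals episode_lengths (as floats)
bootstrap_index = tf.cast(sequence_lengths - 1, tf.int32)  # last valid step, used for bootstrap values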
Returns:
    the total loss Tensor of shape [].
Raises:
    ValueError: If tensors are empty or fail the rank and mutual
        compatibility asserts.
"""
del kwargs
base_ops.assert_rank_and_shape_compatibility([weights], 2)
sequence_lengths = math_ops.reduce_sum(weights, axis=1)
total_num = math_ops.reduce_sum(sequence_lengths)
baseline_values = array_ops.squeeze(
    self.value(states, training=True),
    axis=-1) * weights
base_ops.assert_rank_and_shape_compatibility([rewards, baseline_values], 2)
pcontinues = decay * weights
lambda_ = lambda_ * weights
bootstrap_values = indexing_ops.batched_index(
    baseline_values, math_ops.cast(sequence_lengths - 1, dtypes.int32))
baseline_loss, td_lambda = value_ops.td_lambda(
    parray_ops.swap_time_major(baseline_values),
    parray_ops.swap_time_major(rewards),
    parray_ops.swap_time_major(pcontinues),
    bootstrap_values,
    parray_ops.swap_time_major(lambda_))
advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
if normalize_advantages:
    advantages = normalization_ops.normalize_by_moments(advantages, weights)
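# `normalization_ops.normalize_by_moments` standardizes the advantages before
# they enter the policy-gradient term. A minimal sketch of masked moment
# normalization under that assumption (an illustration, not necessarily the
# library's exact implementation), with statistics computed only over valid
# (weight == 1.) steps:
import tensorflow as tf

def masked_normalize(x, weights, epsilon=1e-8):
    """Normalize `x` of shape [B, T] to zero mean / unit variance over masked entries."""
    total = tf.reduce_sum(weights)
    mean = tf.reduce_sum(x * weights) / total
    variance = tf.reduce_sum(tf.square(x - mean) * weights) / total
    return (x - mean) * tf.math.rsqrt(variance + epsilon) * weights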
def check_rank(tensors, ranks):
    for i, (tensor, rank) in enumerate(zip(tensors, ranks)):
        if tensor.get_shape():
            base_ops.assert_rank_and_shape_compatibility([tensor], rank)
        else:
            tf.logging.error(
                'Tensor "%s", which was offered as Retrace parameter %d, has '
                'no rank at construction time, so Retrace can\'t verify that '
                'it has the necessary rank of %d', tensor.name, i + 1, rank)
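# A hypothetical usage of the helper above when validating Retrace inputs; the
# tensor names, shapes, and expected ranks are assumptions made for this
# example (they are not taken from the original call site), and `base_ops` and
# `tf` are assumed to be imported as in the surrounding module.
q_values = tf.zeros([7, 16, 4])  # [T, B, num_actions], expected rank 3
rewards = tf.zeros([7, 16])      # [T, B], expected rank 2
check_rank([q_values, rewards], [3, 2])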