del kwargs

# Mask padded timesteps and recover the valid length of each sequence.
base_ops.assert_rank_and_shape_compatibility([weights], 2)
sequence_lengths = math_ops.reduce_sum(weights, axis=1)
total_num = math_ops.reduce_sum(sequence_lengths)

# Baseline values V(s_t), squeezed to [batch, time] and masked.
baseline_values = array_ops.squeeze(
    self.value(states, training=True),
    axis=-1) * weights
base_ops.assert_rank_and_shape_compatibility([rewards, baseline_values], 2)

# Per-step discounts and lambda, zeroed out on padded steps.
pcontinues = decay * weights
lambda_ = lambda_ * weights

# Bootstrap from the value at the last valid timestep of each sequence.
bootstrap_values = indexing_ops.batched_index(
    baseline_values, math_ops.cast(sequence_lengths - 1, dtypes.int32))

# TD(lambda) expects time-major tensors; it returns the baseline loss and
# the temporal differences that serve as advantages.
baseline_loss, td_lambda = value_ops.td_lambda(
    parray_ops.swap_time_major(baseline_values),
    parray_ops.swap_time_major(rewards),
    parray_ops.swap_time_major(pcontinues),
    bootstrap_values,
    parray_ops.swap_time_major(lambda_))

advantages = parray_ops.swap_time_major(td_lambda.temporal_differences)
if normalize_advantages:
    advantages = normalization_ops.normalize_by_moments(advantages, weights)
advantages = gen_array_ops.check_numerics(advantages, 'advantages')

# Policy-gradient loss: -log pi(a_t|s_t) weighted by the (stopped) advantages.
policy = self.policy(states, training=True)
log_prob = policy.log_prob(actions)
policy_gradient_loss = gen_array_ops.stop_gradient(advantages) * -log_prob
self.policy_gradient_loss = losses_impl.compute_weighted_loss(
    policy_gradient_loss,
    weights=weights)  # assumption: the original snippet is truncated at this call
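# --- Added illustration (not part of the original source) -------------------
# A minimal NumPy sketch of the quantity the td_lambda call above computes:
# lambda-returns bootstrapped from the final value, whose difference from the
# baseline gives the advantages that weight -log pi(a|s). The function and
# variable names below are hypothetical; the snippet's own helpers
# (value_ops.td_lambda, parray_ops.swap_time_major) operate on batched,
# time-major tensors rather than a single 1-D sequence.
import numpy as np

def lambda_returns(rewards, pcontinues, values, bootstrap_value, lambda_=0.95):
    """Backward recursion over one sequence of length T."""
    T = len(rewards)
    returns = np.zeros(T)
    next_return = bootstrap_value   # G_T = V(s_T)
    next_value = bootstrap_value    # V(s_T)
    for t in reversed(range(T)):
        returns[t] = rewards[t] + pcontinues[t] * (
            (1.0 - lambda_) * next_value + lambda_ * next_return)
        next_return = returns[t]
        next_value = values[t]
    return returns

rewards = np.array([1.0, 0.0, 1.0])
pcontinues = np.array([0.99, 0.99, 0.99])
values = np.array([0.5, 0.4, 0.3])        # V(s_0..s_2)
returns = lambda_returns(rewards, pcontinues, values, bootstrap_value=0.2)
advantages = returns - values             # the "temporal_differences" above
baseline_loss = 0.5 * np.sum(advantages ** 2)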
# --- Separate snippet: building a simple actor-critic training graph (TF1).
# `shapes`, `obs_spec`, `network`, `optimizer`, `obs`, `action`,
# `agent_discount` and `td_lambda` are defined elsewhere in the source.
dtypes = [obs_spec.dtype, np.int32, np.float32, np.float32]

# Time-major placeholders (and matching numpy buffers) for one sequence of
# observations, actions, rewards and discounts.
placeholders = [
    tf.placeholder(shape=(self._sequence_length, 1) + shape, dtype=dtype)
    for shape, dtype in zip(shapes, dtypes)]
observations, actions, rewards, discounts = placeholders
self.arrays = [
    np.zeros(shape=(self._sequence_length, 1) + shape, dtype=dtype)
    for shape, dtype in zip(shapes, dtypes)]
# Build actor and critic losses.
logits, values = snt.BatchApply(network)(observations)
# Bootstrap value for the state following the last transition in the sequence.
_, bootstrap_value = network(tf.expand_dims(obs, 0))

# Critic: TD(lambda) loss on the state values; the temporal differences
# double as advantages for the actor update.
critic_loss, (advantages, _) = td_lambda_loss(
    state_values=values,
    rewards=rewards,
    pcontinues=agent_discount * discounts,
    bootstrap_value=bootstrap_value,
    lambda_=td_lambda)

# Actor: policy-gradient loss weighted by the advantages.
actor_loss = discrete_policy_gradient_loss(logits, actions, advantages)
train_op = optimizer.minimize(actor_loss + critic_loss)
# Create TF session and callables.
session = tf.Session()
self._policy_fn = session.make_callable(action, [obs])
self._update_fn = session.make_callable(train_op, placeholders + [obs])
session.run(tf.global_variables_initializer())
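# --- Added illustration (not part of the original source) -------------------
# A minimal, self-contained sketch of the Session.make_callable pattern used
# above: the returned Python callable caches the feed/fetch plumbing, so
# repeated policy steps or training updates avoid per-call session.run setup.
# The tiny graph below is hypothetical and only mirrors the shape of the
# calls; it assumes TF 1.x (or tf.compat.v1 with v2 behaviour disabled).
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 4])
w = tf.Variable(tf.zeros([4, 2]))
y = tf.matmul(x, w)                                     # stand-in for the policy head

session = tf.Session()
session.run(tf.global_variables_initializer())
predict_fn = session.make_callable(y, [x])              # same pattern as _policy_fn
print(predict_fn(np.zeros([1, 4], np.float32)).shape)   # -> (1, 2)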