def compute_loss(self, rollouts, delay=.999):
  """Implements the Q-learning loss.

  The loss is `0.5` times the squared difference between `q_tm1[a_tm1]` and
  the target `r_t + pcont_t * max q_t`.

  See "Reinforcement Learning: An Introduction" by Sutton and Barto
  (http://incompleteideas.net/book/ebook/node65.html).
  """
  # `math_ops` and `action_value_ops` are modules imported elsewhere; the
  # latter provides the `qlearning` TD op used below.
  # Per-step continuation probabilities: the discount `delay` scaled by the
  # rollout weights.
  pcontinues = delay * rollouts.weights
  # Q-values for the current and the successor states.
  action_values = self.value(rollouts.states, training=True)
  next_action_values = self.value(rollouts.next_states, training=True)
  # Batch mean of the per-transition Q-learning loss.
  self.value_loss = math_ops.reduce_mean(
      action_value_ops.qlearning(
          action_values,
          rollouts.actions,
          rollouts.rewards,
          pcontinues,
          next_action_values).loss,
      axis=0)
  self.total_loss = self.value_loss
  return self.total_loss
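
# For reference, a minimal NumPy sketch of what the loss above computes; the
# batch values below are invented purely for illustration and are not part of
# the original code.
import numpy as np

q_tm1 = np.array([[1.0, 2.0], [3.0, 1.0]])    # Q(s_tm1, .) for a batch of two
a_tm1 = np.array([0, 1])                      # actions taken at s_tm1
r_t = np.array([1.0, 0.0])                    # rewards received
pcont_t = np.array([0.999, 0.0])              # delay * weight (0 at episode end)
q_t = np.array([[1.5, 0.5], [2.0, 4.0]])      # Q(s_t, .) for the successor states

target = r_t + pcont_t * q_t.max(axis=1)      # r_t + pcont_t * max_a Q(s_t, a)
qa_tm1 = q_tm1[np.arange(len(a_tm1)), a_tm1]  # Q(s_tm1, a_tm1)
loss = 0.5 * np.square(target - qa_tm1)       # per-transition Q-learning loss
print(loss.mean())                            # batch mean, as in compute_loss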

# `qlearning` and `periodic_target_update` are TD/target-update ops imported
# elsewhere; the networks, specs and hyperparameters come from the enclosing
# constructor.
tf.set_random_seed(seed)
self._rng = np.random.RandomState(seed)

# Make the TensorFlow graph.
# Single-observation placeholder and Q-values, used for action selection.
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
q = online_network(tf.expand_dims(o, 0))

# Placeholders for a batch of transitions (o_tm1, a_tm1, r_t, d_t, o_t).
o_tm1 = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)
a_tm1 = tf.placeholder(shape=(None,), dtype=action_spec.dtype)
r_t = tf.placeholder(shape=(None,), dtype=tf.float32)
d_t = tf.placeholder(shape=(None,), dtype=tf.float32)
o_t = tf.placeholder(shape=(None,) + obs_spec.shape, dtype=obs_spec.dtype)

# Q-learning loss: the online network evaluates o_tm1, the target network o_t.
q_tm1 = online_network(o_tm1)
q_t = target_network(o_t)
loss = qlearning(q_tm1, a_tm1, r_t, discount * d_t, q_t).loss
train_op = self._optimizer.minimize(loss)

# After each SGD step, periodically copy the online weights into the target
# network.
with tf.control_dependencies([train_op]):
  train_op = periodic_target_update(
      target_variables=target_network.variables,
      source_variables=online_network.variables,
      update_period=target_update_period)

# Make session and callables.
session = tf.Session()
self._sgd_fn = session.make_callable(
    train_op, [o_tm1, a_tm1, r_t, d_t, o_t])
self._value_fn = session.make_callable(q, [o])
session.run(tf.global_variables_initializer())
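
# Illustrative sketch only: one way the two callables above could be driven by
# the surrounding agent. The method names `select_action` and `update`, the
# `epsilon` parameter, and `num_actions` (the size of the discrete action
# space) are assumptions, not part of the original snippet.
def select_action(self, observation, epsilon=0.05):
  # Epsilon-greedy over the online network's Q-values for one observation.
  if self._rng.rand() < epsilon:
    return self._rng.randint(num_actions)
  q_values = self._value_fn(observation)    # shape [1, num_actions]
  return int(np.argmax(q_values, axis=1)[0])

def update(self, o_tm1, a_tm1, r_t, d_t, o_t):
  # One SGD step on a batch of transitions; the target network is refreshed
  # periodically through the control dependency wired into `train_op`.
  self._sgd_fn(o_tm1, a_tm1, r_t, d_t, o_t)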