# (extraction residue — scraper banner, kept as a comment so the file parses)
# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this chunk is the interior of a policy-graph constructor;
# the enclosing `def` is outside this view, and indentation was restored
# from a whitespace-mangled paste (the if/else bodies were flat).
# Collect every trainable variable created under the current variable
# scope so later optimizer / weight-sync code can reference them.
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  tf.get_variable_scope().name)
# --- actual loss computation ---
# Imitation (behavior-cloning) term: minimize the negative mean
# log-probability of the sampled actions under the current policy.
imitation_loss = -tf.reduce_mean(actions_logp)
tm_values = self.model.values
# Values for timesteps 0..T-1; the last entry (tm_values[-1]) is used
# below as the bootstrap value for the TD(lambda) targets.
baseline_values = tm_values[:-1]
if config.get("soft_horizon"):
    # Soft horizon: never cut bootstrapping at episode boundaries, so a
    # scalar gamma is used as the discount for every step.
    discounts = config["gamma"]
else:
    # Zero the discount wherever the episode terminated. Assumes `dones`
    # is a boolean tensor (`~` is logical NOT) — TODO confirm; a 0/1
    # float tensor would make `~` bitwise and silently wrong.
    discounts = tf.to_float(~dones[:-1]) * config["gamma"]
# trfl.td_lambda returns a namedtuple whose `.loss` is the squared
# TD(lambda) error summed over the time dimension.
td_lambda = trfl.td_lambda(
    state_values=baseline_values,
    rewards=rewards[:-1],
    pcontinues=discounts,
    bootstrap_value=tm_values[-1],
    lambda_=config.get("lambda", 1.))
# td_lambda.loss has shape [B] (already summed over time), so dividing
# by T yields a per-timestep value-function loss.
vf_loss = tf.reduce_mean(td_lambda.loss) / T
# NOTE(review): both `config` and `self.config` are used in this chunk —
# presumably they refer to the same dict; verify against the enclosing
# constructor before unifying them.
self.total_loss = imitation_loss + self.config["vf_loss_coeff"] * vf_loss
# Initialize TFPolicyGraph: (name -> tensor) pairs naming the loss
# input placeholders.
loss_in = [
    (SampleBatch.ACTIONS, actions),
    (SampleBatch.DONES, dones),
    # (BEHAVIOUR_LOGITS, behaviour_logits),