def set_weights_fn(policy, weights):
    # Unpack the synced weights and hard-copy (tau=1.) each list into the
    # corresponding network on the given policy.
    actor_weights, critic_weights, critic_target_weights = weights
    update_target_variables(
        policy.actor.weights, actor_weights, tau=1.)
    update_target_variables(
        policy.critic.weights, critic_weights, tau=1.)
    update_target_variables(
        policy.critic_target.weights, critic_target_weights, tau=1.)
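set_weights_fn consumes a 3-tuple of weight lists; a plausible counterpart that produces that tuple might look like the following (a hypothetical sketch mirroring the unpacking above, not copied from the library):

def get_weights_fn(policy):
    # Package the weights in the same order set_weights_fn unpacks them.
    return [policy.actor.weights,
            policy.critic.weights,
            policy.critic_target.weights]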
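The training fragment below begins mid-step, after critic_grad has already been computed. A minimal sketch of the critic update it presumes, following the standard DDPG TD(0) target (gamma/self.discount, rewards, dones, and next_states are assumed names; the library's actual train body may differ):

    with tf.GradientTape() as tape:
        # Bootstrap from the target networks: y = r + gamma * (1 - done) * Q'(s', pi'(s'))
        target_q = self.critic_target(
            [next_states, self.actor_target(next_states)])
        target_q = rewards + (1. - dones) * self.discount * target_q
        current_q = self.critic([states, actions])
        td_errors = tf.stop_gradient(target_q) - current_q
        critic_loss = tf.reduce_mean(tf.square(td_errors))
    critic_grad = tape.gradient(
        critic_loss, self.critic.trainable_variables)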
    self.critic_optimizer.apply_gradients(
        zip(critic_grad, self.critic.trainable_variables))

    # Actor update: maximize the critic's value of the actor's action
    # by minimizing its negation.
    with tf.GradientTape() as tape:
        next_action = self.actor(states)
        actor_loss = -tf.reduce_mean(self.critic([states, next_action]))
    actor_grad = tape.gradient(
        actor_loss, self.actor.trainable_variables)
    self.actor_optimizer.apply_gradients(
        zip(actor_grad, self.actor.trainable_variables))

    # Soft-update the target networks toward the online networks.
    update_target_variables(
        self.critic_target.weights, self.critic.weights, self.tau)
    update_target_variables(
        self.actor_target.weights, self.actor.weights, self.tau)
    return actor_loss, critic_loss, td_errors
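update_target_variables performs Polyak averaging of one variable list toward another. A minimal stand-in showing its effect (hypothetical helper name; the real function also validates variable pairing and works under tf.function):

def soft_update(target_vars, source_vars, tau):
    # target <- tau * source + (1 - tau) * target; tau=1. is a hard copy.
    for t, s in zip(target_vars, source_vars):
        t.assign(tau * s + (1. - tau) * t)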
    super().__init__(name=name, memory_capacity=memory_capacity,
                     n_warmup=n_warmup, **kwargs)

    # Define and initialize Actor network
    self.actor = Actor(state_shape, action_dim, max_action, actor_units)
    self.actor_target = Actor(
        state_shape, action_dim, max_action, actor_units)
    self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_actor)
    # Start the target network as an exact copy of the online one (tau=1.).
    update_target_variables(self.actor_target.weights,
                            self.actor.weights, tau=1.)

    # Define and initialize Critic network
    self.critic = Critic(state_shape, action_dim, critic_units)
    self.critic_target = Critic(state_shape, action_dim, critic_units)
    self.critic_optimizer = tf.keras.optimizers.Adam(
        learning_rate=lr_critic)
    update_target_variables(
        self.critic_target.weights, self.critic.weights, tau=1.)

    # Set hyperparameters
    self.sigma = sigma  # exploration-noise scale
    self.tau = tau      # soft-update coefficient for target networks
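Pieced together from the constructor fragment above, a hypothetical instantiation might read as follows (the class name DDPG, the unit sizes, and the learning rates are illustrative assumptions, not taken from this snippet):

agent = DDPG(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.high.size,
    max_action=env.action_space.high[0],
    actor_units=(400, 300),
    critic_units=(400, 300),
    lr_actor=1e-3,
    lr_critic=1e-3,
    sigma=0.1,    # exploration-noise scale
    tau=0.005)    # soft-update coefficient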
def _setup_critic_v(self, state_shape, critic_units, lr):
    self.vf = CriticV(state_shape, critic_units)
    self.vf_target = CriticV(state_shape, critic_units)
    # Initialize the target V-network as an exact copy of the online one.
    update_target_variables(self.vf_target.weights,
                            self.vf.weights, tau=1.)
    self.vf_optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
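This helper mirrors the actor/critic setup above for a state-value network, as in the original SAC formulation, where the target V-network supplies the Q-function's bootstrap value. A one-line sketch of where vf_target typically enters the update (gamma, rewards, dones, and next_states are assumed names):

    # Q target: y = r + gamma * (1 - done) * V_target(s')
    target_q = rewards + (1. - dones) * gamma * self.vf_target(next_states)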
    # Actor loss is computed under a tape so the delayed update below can
    # take its gradient when the update actually fires.
    with tf.GradientTape() as tape:
        next_actions = self.actor(states)
        actor_loss = -tf.reduce_mean(self.critic([states, next_actions]))

    # Delayed policy updates: step the actor only once every
    # _actor_update_freq critic updates.
    remainder = tf.math.mod(self._it, self._actor_update_freq)

    def optimize_actor():
        actor_grad = tape.gradient(
            actor_loss, self.actor.trainable_variables)
        return self.actor_optimizer.apply_gradients(
            zip(actor_grad, self.actor.trainable_variables))

    tf.cond(pred=tf.equal(remainder, 0),
            true_fn=optimize_actor, false_fn=tf.no_op)

    # Update target networks
    update_target_variables(
        self.critic_target.weights, self.critic.weights, self.tau)
    update_target_variables(
        self.actor_target.weights, self.actor.weights, self.tau)
    return actor_loss, critic_loss, tf.abs(td_error1) + tf.abs(td_error2)
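Outside a graph-compiled train step, the tf.cond gating above could equivalently be written as a plain conditional (a sketch; attribute names are taken from the fragment, not verified against the full class):

    if int(self._it) % self._actor_update_freq == 0:
        actor_grad = tape.gradient(actor_loss, self.actor.trainable_variables)
        self.actor_optimizer.apply_gradients(
            zip(actor_grad, self.actor.trainable_variables))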