import tensorflow as tf

# NOTE: IRLPolicy and Discriminator are provided by the surrounding
# module (tf2rl-style layout) and are not defined in this excerpt.


class GAIfO(IRLPolicy):
    def __init__(
            self,
            state_shape,
            units=[32, 32],
            lr=0.001,
            enable_sn=False,
            name="GAIfO",
            **kwargs):
        IRLPolicy.__init__(self, name=name, n_training=1, **kwargs)
        self.disc = Discriminator(
            state_shape=state_shape,
            units=units, enable_sn=enable_sn)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=0.5)
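

# A minimal sketch (an assumption, not this library's actual API) of how
# a GAIfO-style discriminator can be turned into an imitation reward:
# GAIfO discriminates on state transitions (s, s') rather than
# state-action pairs, so the reward depends only on observations.
def gaifo_transition_reward(disc, states, next_states):
    """Hypothetical helper: reward from a sigmoid-output discriminator."""
    inputs = tf.concat([states, next_states], axis=1)
    probs = disc(inputs)  # assumed to output probabilities in (0, 1)
    # Higher reward when the transition looks expert-like to D.
    return -tf.math.log(1.0 - probs + 1e-8)
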
class WGAIL(IRLPolicy):
    def __init__(
            self,
            state_shape,
            action_dim,
            units=[32, 32],
            lr=0.001,
            enable_sn=False,
            enable_gp=True,
            enable_gc=False,
            name="WGAIL",
            **kwargs):
        """
        :param enable_sn (bool): If true, add spectral normalization in Dense layer
        :param enable_gp (bool): If true, add gradient penalty to loss function
        :param enable_gc (bool): If true, apply gradient clipping while training
        """
        assert enable_gp != enable_gc, \
            "You must choose either Gradient Penalty or Gradient Clipping. " \
            "Both at the same time is not supported."
        IRLPolicy.__init__(
            self, name=name, **kwargs)
        self.disc = Discriminator(
            state_shape=state_shape, action_dim=action_dim,
            units=units, enable_sn=enable_sn, output_activation="linear")
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=0.5)
        self._enable_gp = enable_gp
        self._enable_gc = enable_gc
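

# A minimal sketch of the WGAN-style gradient penalty that enable_gp
# refers to. The interpolation scheme and coefficient below are the
# common defaults from the WGAN-GP paper, not necessarily what this
# library uses: the penalty drives the critic's gradient norm toward 1
# on points sampled between expert and agent inputs.
def gradient_penalty(disc, real_inputs, fake_inputs, coef=10.0):
    """Hypothetical helper computing the gradient penalty term."""
    eps = tf.random.uniform([tf.shape(real_inputs)[0], 1], 0.0, 1.0)
    interpolated = eps * real_inputs + (1.0 - eps) * fake_inputs
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        critic_out = disc(interpolated)
    grads = tape.gradient(critic_out, interpolated)
    grad_norm = tf.norm(grads, axis=1)
    return coef * tf.reduce_mean(tf.square(grad_norm - 1.0))
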
class VAIL(IRLPolicy):
    def __init__(
            self,
            state_shape,
            action_dim,
            units=[32, 32],
            n_latent_unit=32,
            lr=5e-5,
            kl_target=0.5,
            reg_param=0.,
            enable_sn=False,
            enable_gp=False,
            name="VAIL",
            **kwargs):
        """
        :param enable_sn (bool): If true, add spectral normalization in Dense layer
        :param enable_gp (bool): If true, add gradient penalty to loss function
        """
        IRLPolicy.__init__(
            self, name=name, n_training=10, **kwargs)
        self.disc = Discriminator(
            state_shape=state_shape, action_dim=action_dim,
            units=units, n_latent_unit=n_latent_unit,
            enable_sn=enable_sn)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
        self._kl_target = kl_target
        self._reg_param = tf.Variable(reg_param, dtype=tf.float32)
        self._step_reg_param = tf.constant(1e-5, dtype=tf.float32)
        self._enable_gp = enable_gp
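

# A minimal sketch of how VAIL's Lagrange multiplier (_reg_param) is
# typically adapted, following the dual update in the VAIL paper:
# beta <- max(0, beta + step * (KL - kl_target)). The training step
# that computes mean_kl is outside this excerpt, so this helper is an
# assumption about usage, not the library's exact implementation.
def update_reg_param(reg_param, step, mean_kl, kl_target):
    """Hypothetical helper: dual gradient step on the KL constraint."""
    reg_param.assign(tf.maximum(
        0.0, reg_param + step * (mean_kl - kl_target)))


# Example usage inside a discriminator training step:
#     update_reg_param(self._reg_param, self._step_reg_param,
#                      mean_kl, self._kl_target)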