Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_argument(parser=None):
    """Extend the base IRL argument parser with this policy's options.

    Args:
        parser: Parser to extend. It is forwarded unchanged to
            ``IRLPolicy.get_argument``; how ``None`` is handled is decided
            there, not here.

    Returns:
        The parser returned by ``IRLPolicy.get_argument`` with the
        ``--enable-sn`` option registered on it.
    """
    parser = IRLPolicy.get_argument(parser)
    # Assuming an argparse-style parser, `store_true` makes this a boolean
    # switch that is False unless the flag is present on the command line.
    parser.add_argument('--enable-sn', action='store_true')
    return parser
# NOTE(review): these statements use `self`, `action_dim` and `dummy_state`,
# none of which are defined in this excerpt. Presumably this is the tail of a
# model constructor whose `def` line is not shown; confirm against the full file.
# Build an all-zero float32 action tensor of shape [1, action_dim].
dummy_action = tf.constant(
np.zeros(shape=[1, action_dim], dtype=np.float32))
# Call the model once on the [state, action] pair with ops placed on the CPU
# (presumably to create the layer weights up front; TODO confirm).
with tf.device("/cpu:0"):
self([dummy_state, dummy_action])
def call(self, inputs):
    """Run the forward pass on a sequence of input tensors.

    Args:
        inputs: Sequence of tensors that can be concatenated along axis 1.

    Returns:
        The output of ``self.l3`` after the concatenated input has passed
        through ``self.l1`` and ``self.l2`` in that order.
    """
    # Merge all inputs into one feature tensor along axis 1.
    merged = tf.concat(inputs, axis=1)
    hidden = self.l1(merged)
    hidden = self.l2(hidden)
    return self.l3(hidden)
def compute_reward(self, inputs):
    """Return the logarithm of the model's output for ``inputs``.

    Args:
        inputs: Whatever this object accepts when called; it is passed to
            ``self(...)`` unchanged.

    Returns:
        ``tf.math.log`` applied to the model output shifted by ``1e-8``.
    """
    # The small offset keeps the logarithm finite when the output is exactly 0.
    shifted_output = self(inputs) + 1e-8
    return tf.math.log(shifted_output)
class GAIL(IRLPolicy):
    """IRL policy that owns a ``Discriminator`` and an Adam optimizer."""

    def __init__(
            self,
            state_shape,
            action_dim,
            units=None,
            lr=0.001,
            enable_sn=False,
            name="GAIL",
            **kwargs):
        """Create the discriminator network and its optimizer.

        Args:
            state_shape: Forwarded to ``Discriminator`` as ``state_shape``.
            action_dim: Forwarded to ``Discriminator`` as ``action_dim``.
            units: Layer sizes forwarded to ``Discriminator``. ``None``
                (the default) means ``[32, 32]``.
            lr: Learning rate of the Adam optimizer.
            enable_sn: Forwarded to ``Discriminator`` as ``enable_sn``.
            name: Name passed to the ``IRLPolicy`` constructor.
            **kwargs: Extra keyword arguments passed to the ``IRLPolicy``
                constructor.
        """
        # Resolve the default here rather than in the signature so every
        # instance gets its own list instead of sharing one mutable default.
        if units is None:
            units = [32, 32]
        super().__init__(name=name, n_training=1, **kwargs)
        self.disc = Discriminator(
            state_shape=state_shape, action_dim=action_dim,
            units=units, enable_sn=enable_sn)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=0.5)
# NOTE(review): these statements use `self`, `DenseClass`, `units`,
# `output_activation` and `state_shape`, none of which are defined in this
# excerpt. Presumably this is the tail of a model constructor whose `def` line
# is not shown; confirm against the full file.
# Two ReLU layers sized by units[0] and units[1], then a single-unit output
# layer that uses the caller-chosen activation.
self.l1 = DenseClass(units[0], name="L1", activation="relu")
self.l2 = DenseClass(units[1], name="L2", activation="relu")
self.l3 = DenseClass(1, name="L3", activation=output_activation)
# All-zero float32 state with a leading dimension of 1; `(1,)+state_shape`
# requires `state_shape` to be a tuple.
dummy_state = tf.constant(
np.zeros(shape=(1,)+state_shape, dtype=np.float32))
# Call the model once with ops placed on the CPU (presumably to create the
# layer weights up front; TODO confirm).
with tf.device("/cpu:0"):
self(dummy_state)
def call(self, inputs):
    """Run the forward pass for a single input.

    Args:
        inputs: Value passed directly to ``self.l1``.

    Returns:
        The output of ``self.l3`` after ``inputs`` has passed through
        ``self.l1`` and ``self.l2`` in that order.
    """
    hidden = self.l1(inputs)
    hidden = self.l2(hidden)
    return self.l3(hidden)
class AIRL(IRLPolicy):
def __init__(
self,
state_shape,
action_dim,
state_only=True,
units=[32, 32],
lr=0.001,
enable_sn=False,
name="AIRL",
**kwargs):
super().__init__(name=name, n_training=1, **kwargs)
self._state_only = state_only
if state_only:
self.rew_net = StateModel(
state_shape=state_shape, units=units,
name="reward_net", enable_sn=enable_sn, output_activation="linear")