How to use the tf2rl.algos.policy_base.IRLPolicy class in tf2rl

To help you get started, we’ve selected a few tf2rl examples that show how IRLPolicy is used in public projects.

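Both examples below follow the same pattern: subclass IRLPolicy, build a small discriminator network in __init__, and expose a learned reward to the outer RL loop. A minimal skeleton of that pattern (MyIRL and its constructor arguments are illustrative, not part of tf2rl):

from tf2rl.algos.policy_base import IRLPolicy


class MyIRL(IRLPolicy):  # hypothetical subclass, shown for shape only
    def __init__(self, state_shape, action_dim, name="MyIRL", **kwargs):
        super().__init__(name=name, n_training=1, **kwargs)
        # Build discriminator / reward networks here, as GAIL and AIRL do below.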

github keiohta / tf2rl / tf2rl / algos / gail.py (view on GitHub)
def get_argument(parser=None):
    # Extend the base IRLPolicy argument parser with GAIL-specific flags
    # (this is a static method on the GAIL class shown below).
    parser = IRLPolicy.get_argument(parser)
    parser.add_argument('--enable-sn', action='store_true')
    return parser
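Wired into a training script, this gives a standard argparse flow. A short sketch (it uses the GAIL class shown further down; the description string and the parse_args input are made up, and it assumes the flags added by the base IRLPolicy parser all have defaults):

import argparse

parser = argparse.ArgumentParser(description="Train GAIL on expert demos")
parser = GAIL.get_argument(parser)  # registers --enable-sn plus IRLPolicy's own flags
args = parser.parse_args(['--enable-sn'])
print(args.enable_sn)  # True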
github keiohta / tf2rl / tf2rl / algos / gail.py (view on GitHub)
        dummy_state = tf.constant(
            np.zeros(shape=(1,) + state_shape, dtype=np.float32))
        dummy_action = tf.constant(
            np.zeros(shape=[1, action_dim], dtype=np.float32))
        with tf.device("/cpu:0"):
            # Dummy forward pass so Keras builds the layer weights eagerly.
            self([dummy_state, dummy_action])

    def call(self, inputs):
        # Concatenate state and action, then score the pair with the MLP.
        features = tf.concat(inputs, axis=1)
        features = self.l1(features)
        features = self.l2(features)
        return self.l3(features)

    def compute_reward(self, inputs):
        # GAIL reward: log D(s, a); the epsilon guards against log(0).
        return tf.math.log(self(inputs) + 1e-8)


class GAIL(IRLPolicy):
    def __init__(
            self,
            state_shape,
            action_dim,
            units=[32, 32],
            lr=0.001,
            enable_sn=False,
            name="GAIL",
            **kwargs):
        super().__init__(name=name, n_training=1, **kwargs)
        self.disc = Discriminator(
            state_shape=state_shape, action_dim=action_dim,
            units=units, enable_sn=enable_sn)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=0.5)
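With the class above in scope, relabeling a batch of rollout transitions with learned rewards is a single discriminator call. A hedged sketch (the shapes and the random batch are invented, and it assumes IRLPolicy's remaining constructor arguments all have workable defaults):

import numpy as np

state_shape, action_dim = (4,), 2
gail = GAIL(state_shape=state_shape, action_dim=action_dim)

# Fake batch of 8 (state, action) pairs standing in for agent rollouts.
states = np.random.randn(8, *state_shape).astype(np.float32)
actions = np.random.randn(8, action_dim).astype(np.float32)

# log D(s, a): larger where the discriminator finds the pair expert-like.
rewards = gail.disc.compute_reward([states, actions])
print(rewards.shape)  # (8, 1)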
github keiohta / tf2rl / tf2rl / algos / airl.py (view on GitHub)
        # DenseClass is either a plain Dense layer or a spectrally
        # normalized variant, selected via enable_sn earlier in __init__.
        self.l1 = DenseClass(units[0], name="L1", activation="relu")
        self.l2 = DenseClass(units[1], name="L2", activation="relu")
        self.l3 = DenseClass(1, name="L3", activation=output_activation)

        dummy_state = tf.constant(
            np.zeros(shape=(1,) + state_shape, dtype=np.float32))
        with tf.device("/cpu:0"):
            # Dummy forward pass so Keras builds the layer weights eagerly.
            self(dummy_state)

    def call(self, inputs):
        # Score a batch of states; used by AIRL's state-only reward net below.
        features = self.l1(inputs)
        features = self.l2(features)
        return self.l3(features)


class AIRL(IRLPolicy):
    def __init__(
            self,
            state_shape,
            action_dim,
            state_only=True,
            units=[32, 32],
            lr=0.001,
            enable_sn=False,
            name="AIRL",
            **kwargs):
        super().__init__(name=name, n_training=1, **kwargs)
        self._state_only = state_only
        if state_only:
            self.rew_net = StateModel(
                state_shape=state_shape, units=units,
                name="reward_net", enable_sn=enable_sn, output_activation="linear")