How to use the tf2rl.algos.dqn.DQN class in tf2rl

To help you get started, we’ve selected a few tf2rl examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github keiohta / tf2rl / tests / experiments / test_utils.py View on Github external
def setUpClass(cls):
        """Prepare shared test fixtures.

        Creates a CartPole-v0 env, a tiny-capacity DQN policy, the matching
        replay buffer, and an output directory next to this test file.
        """
        cls.env = gym.make("CartPole-v0")
        # Tiny memory_capacity keeps the fixture cheap for tests.
        policy = DQN(
            state_shape=cls.env.observation_space.shape,
            action_dim=cls.env.action_space.n,
            memory_capacity=2**4)
        cls.replay_buffer = get_replay_buffer(policy, cls.env)
        cls.output_dir = os.path.join(os.path.dirname(__file__), "tests")
        # exist_ok avoids the separate isdir() check; no-op when dir exists.
        os.makedirs(cls.output_dir, exist_ok=True)
github keiohta / tf2rl / tests / algos / test_dqn.py View on Github external
def setUpClass(cls):
        """Build the categorical-DQN agent under test on the shared discrete env."""
        super().setUpClass()
        # Collect constructor options once, then unpack into DQN.
        agent_kwargs = dict(
            state_shape=cls.discrete_env.observation_space.shape,
            action_dim=cls.discrete_env.action_space.n,
            batch_size=cls.batch_size,
            enable_categorical_dqn=True,
            epsilon=0.,   # greedy actions: deterministic behavior in tests
            gpu=-1)       # force CPU
        cls.agent = DQN(**agent_kwargs)
github keiohta / tf2rl / tests / algos / test_dqn.py View on Github external
def setUpClass(cls):
        """Build a double+dueling DQN agent on the shared discrete env."""
        super().setUpClass()
        env = cls.discrete_env
        cls.agent = DQN(
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            batch_size=cls.batch_size,
            # Exercise both extensions together.
            enable_double_dqn=True,
            enable_dueling_dqn=True,
            epsilon=0.,   # greedy actions: deterministic behavior in tests
            gpu=-1)       # force CPU
github keiohta / tf2rl / tests / algos / test_dqn.py View on Github external
def setUpClass(cls):
        """Build a noisy-network DQN agent on the shared discrete env."""
        super().setUpClass()
        env = cls.discrete_env
        cls.agent = DQN(
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            batch_size=cls.batch_size,
            enable_noisy_dqn=True,
            epsilon=0.,   # exploration comes from the noisy layers instead
            gpu=-1)       # force CPU
github keiohta / tf2rl / examples / run_dqn.py View on Github external
if __name__ == '__main__':
    # Build the CLI: trainer options first, then DQN-specific flags on top.
    parser = DQN.get_argument(Trainer.get_argument())
    # argparse.set_defaults accepts many kwargs at once.
    parser.set_defaults(
        test_interval=2000,
        max_steps=100000,
        gpu=-1,
        n_warmup=500,
        batch_size=32,
        memory_capacity=int(1e4))
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    # Separate envs for training and evaluation.
    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)

    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        enable_categorical_dqn=args.enable_categorical_dqn,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        target_replace_interval=300,
        discount=0.99,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        batch_size=args.batch_size,
        n_warmup=args.n_warmup)

    Trainer(policy, env, args, test_env=test_env)()
github keiohta / tf2rl / tf2rl / algos / categorical_dqn.py View on Github external
def call(self, inputs):
        """Return per-action probability distributions over the value atoms.

        Output shape is [batch_size, action_dim, n_atoms]; each action's atom
        probabilities are softmax-normalized and clipped away from 0 and 1.
        """
        hidden = self.l1(tf.concat(inputs, axis=1))
        hidden = self.l2(hidden)
        if self._enable_dueling_dqn:
            # Dueling variant is not supported for the categorical head.
            raise NotImplementedError
        logits = self.l3(hidden)  # [batch_size, action_dim * n_atoms]
        atoms = tf.reshape(
            logits, (-1, self._action_dim, self._n_atoms))  # [batch_size, action_dim, n_atoms]
        probs = tf.keras.activations.softmax(atoms, axis=2)
        # Clip so downstream log/ratio computations never see exact 0 or 1.
        return tf.clip_by_value(probs, 1e-8, 1.0 - 1e-8)


class CategoricalDQN(DQN):
    def __init__(self, *args, **kwargs):
        """Construct a categorical (C51-style) DQN.

        Forces the categorical Q-function, then precomputes the fixed support
        of atom values z in [v_min, v_max] and its per-action broadcast copy.
        """
        kwargs["q_func"] = CategoricalQFunc
        super().__init__(*args, **kwargs)
        self._v_max, self._v_min = 10., -10.
        n_atoms = self.q_func._n_atoms
        # Evenly spaced support: n_atoms points from v_min to v_max inclusive.
        self._delta_z = (self._v_max - self._v_min) / (n_atoms - 1)
        support = [self._v_min + i * self._delta_z for i in range(n_atoms)]
        self._z_list = tf.constant(support, dtype=tf.float64)
        # One copy of the support per action: shape [action_dim, n_atoms].
        self._z_list_broadcasted = tf.tile(
            tf.reshape(self._z_list, [1, n_atoms]),
            tf.constant([self._action_dim, 1]))

    def get_action(self, state, test=False):
        if isinstance(state, LazyFrames):
            state = np.array(state)
        assert isinstance(state, np.ndarray)
github keiohta / tf2rl / examples / run_apex_dqn.py View on Github external
def policy_fn(env, name, memory_capacity=int(1e6),
                  gpu=-1, noise_level=0.3):
        return DQN(
            name=name,
            enable_double_dqn=args.enable_double_dqn,
            enable_dueling_dqn=args.enable_dueling_dqn,
            enable_noisy_dqn=args.enable_noisy_dqn,
            enable_categorical_dqn=args.enable_categorical_dqn,
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            n_warmup=n_warmup,
            target_replace_interval=target_replace_interval,
            batch_size=batch_size,
            memory_capacity=memory_capacity,
            discount=0.99,
            epsilon=1.,
            epsilon_min=0.1,
            epsilon_decay_step=epsilon_decay_rate,
            optimizer=optimizer,