How to use the tf2rl.experiments.on_policy_trainer.OnPolicyTrainer.get_argument function in tf2rl

To help you get started, we've selected a few tf2rl examples that show popular ways this function is used in public projects.
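
Every example below follows the same pattern: OnPolicyTrainer.get_argument() builds an argparse.ArgumentParser carrying the trainer's common flags, the algorithm's own get_argument(parser) extends it, and the parsed args namespace is handed straight to the trainer. Here is a minimal sketch of that flow, assuming a continuous-control environment; anything beyond the calls shown in the examples below is illustrative, not part of the trainer API:

import gym

from tf2rl.algos.ppo import PPO
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim

if __name__ == '__main__':
    # 1. Start from the trainer's parser, then let the algorithm add its flags.
    parser = OnPolicyTrainer.get_argument()
    parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    args = parser.parse_args()

    # 2. Build the environment and the policy from the parsed arguments.
    env = gym.make(args.env_name)
    policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space))

    # 3. The same args namespace configures the trainer; calling the
    #    trainer instance starts the training loop.
    trainer = OnPolicyTrainer(policy, env, args)
    trainer()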

keiohta/tf2rl/examples/run_ppo.py
import gym

from tf2rl.algos.ppo import PPO
from tf2rl.policies.categorical_actor import CategoricalActorCritic
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    parser.set_defaults(test_interval=20480)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=2048)
    parser.set_defaults(batch_size=64)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)

    policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        # The source snippet is truncated here; the remaining arguments
        # below follow the usual tf2rl example pattern.
        batch_size=args.batch_size,
        horizon=args.horizon,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
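
Because get_argument returns a standard argparse parser, the set_defaults calls above simply override the defaults registered by OnPolicyTrainer.get_argument and PPO.get_argument; every value can still be overridden again on the command line, for example:

python examples/run_ppo.py --env-name CartPole-v0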

keiohta/tf2rl/examples/run_vpg.py
import gym

from tf2rl.algos.vpg import VPG
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = VPG.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    parser.add_argument('--normalize-adv', action='store_true')
    parser.add_argument('--enable-gae', action='store_true')
    parser.set_defaults(test_interval=5000)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=1000)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = VPG(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        # The source snippet is truncated here; forwarding the flags defined
        # above to the agent is the usual pattern.
        normalize_adv=args.normalize_adv,
        enable_gae=args.enable_gae,
        horizon=args.horizon,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
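
The two store_true flags added here, --normalize-adv and --enable-gae, default to False; passing them on the command line turns on advantage normalization and generalized advantage estimation respectively, for example:

python examples/run_vpg.py --enable-gae --normalize-adv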

keiohta/tf2rl/examples/run_ppo_atari.py
import gym

import numpy as np
import tensorflow as tf

from tf2rl.algos.ppo import PPO
from tf2rl.envs.atari_wrapper import wrap_dqn
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.networks.atari_model import AtariCategoricalActorCritic


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="PongNoFrameskip-v4")
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(horizon=1024)
    parser.set_defaults(test_interval=204800)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=2048000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    args = parser.parse_args()

    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)

    state_shape = env.observation_space.shape
    n_actions = env.action_space.n
    # The source snippet is truncated here. A typical continuation builds the
    # shared convolutional actor-critic and hands it to PPO; the constructor
    # arguments below are an assumption, not verbatim from the repository.
    policy = PPO(
        state_shape=state_shape,
        action_dim=n_actions,
        is_discrete=True,
        actor_critic=AtariCategoricalActorCritic(state_shape, n_actions),
        horizon=args.horizon,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()

keiohta/tf2rl/examples/run_ppo_pendulum.py
import gym

from tf2rl.algos.ppo import PPO
from tf2rl.experiments.on_policy_trainer import OnPolicyTrainer
from tf2rl.envs.utils import is_discrete, get_act_dim


if __name__ == '__main__':
    parser = OnPolicyTrainer.get_argument()
    parser = PPO.get_argument(parser)
    parser.add_argument('--env-name', type=str,
                        default="Pendulum-v0")
    parser.set_defaults(test_interval=10240)
    parser.set_defaults(max_steps=int(1e7))
    parser.set_defaults(horizon=512)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(gpu=-1)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = PPO(
        state_shape=env.observation_space.shape,
        action_dim=get_act_dim(env.action_space),
        is_discrete=is_discrete(env.action_space),
        # The source snippet is truncated here; the remaining arguments
        # below follow the same pattern as the other examples.
        batch_size=args.batch_size,
        horizon=args.horizon,
        gpu=args.gpu)
    trainer = OnPolicyTrainer(policy, env, args, test_env=test_env)
    trainer()
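
In every example the namespace returned by parser.parse_args() does double duty: it supplies the policy's hyperparameters and is passed unchanged to OnPolicyTrainer, which reads its training-loop settings (max_steps, test_interval, and so on) from it. That is why each script starts from OnPolicyTrainer.get_argument() rather than a fresh ArgumentParser.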