@classmethod
def setUpClass(cls):
    super().setUpClass()
    cls.agent = DDPG(
        state_shape=cls.continuous_env.observation_space.shape,
        action_dim=cls.continuous_env.action_space.low.size,
        batch_size=cls.batch_size,
        sigma=0.5,  # Use larger exploration noise so the behaviour is easier to test
        gpu=-1)
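# A minimal sketch of the surrounding test class the fixture above assumes.
# The concrete environment ("Pendulum-v1") and batch size are illustrative
# guesses, not values taken from the original test suite.
import unittest

import gym
from tf2rl.algos.ddpg import DDPG


class TestDDPG(unittest.TestCase):
    continuous_env = gym.make("Pendulum-v1")  # supplies the observation/action spaces
    batch_size = 32                           # consumed by setUpClass above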
from tf2rl.algos.ddpg import DDPG


def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
    return DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        n_warmup=500,
        gpu=-1)
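# Hypothetical direct call of policy_fn above; in the tf2rl Ape-X examples a
# callable like this is handed to the distributed runner, which supplies the
# environment and a worker name. The environment and name here are illustrative.
import gym

env = gym.make("Pendulum-v1")
policy = policy_fn(env, name="learner")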
import gym

from tf2rl.algos.ddpg import DDPG
from tf2rl.algos.vail import VAIL
from tf2rl.experiments.irl_trainer import IRLTrainer
from tf2rl.experiments.utils import restore_latest_n_traj

if __name__ == '__main__':
    parser = IRLTrainer.get_argument()
    parser = VAIL.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="RoboschoolReacher-v1")
    args = parser.parse_args()

    if args.expert_path_dir is None:
        print("Please generate demonstrations first")
        print("python examples/run_sac.py --env-name=RoboschoolReacher-v1 --save-test-path --test-interval=50000")
        exit()

    units = [400, 300]
    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        max_action=env.action_space.high[0],
        gpu=args.gpu,
        actor_units=units,
        critic_units=units,
        n_warmup=10000,
        batch_size=100)
    irl = VAIL(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        units=units,
        enable_sn=args.enable_sn,
        batch_size=32,
        gpu=args.gpu)
    # Restore the expert demonstrations saved by run_sac.py (the remaining
    # arguments of this call were cut off in the original snippet).
    expert_trajs = restore_latest_n_traj(args.expert_path_dir)
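    # In the original tf2rl VAIL example the restored demonstrations are then
    # handed to an IRLTrainer. A rough sketch of that step follows; the argument
    # order mirrors the tf2rl IRLTrainer signature but should be treated as an
    # assumption, not the verbatim example code.
    trainer = IRLTrainer(policy, env, args, irl,
                         expert_trajs["obses"], expert_trajs["next_obses"],
                         expert_trajs["acts"], test_env)
    trainer()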
import gym

from tf2rl.algos.ddpg import DDPG
from tf2rl.experiments.trainer import Trainer

if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DDPG.get_argument(parser)
    parser.add_argument('--env-name', type=str, default="RoboschoolAnt-v1")
    parser.set_defaults(batch_size=100)
    parser.set_defaults(n_warmup=10000)
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)
    policy = DDPG(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.high.size,
        gpu=args.gpu,
        memory_capacity=args.memory_capacity,
        max_action=env.action_space.high[0],
        batch_size=args.batch_size,
        n_warmup=args.n_warmup)
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
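    # Hypothetical quick check after training (not part of the original example):
    # roll out one deterministic episode with the trained policy on the test env.
    obs = test_env.reset()
    done, episode_return = False, 0.
    while not done:
        action = policy.get_action(obs, test=True)
        obs, reward, done, _ = test_env.step(action)
        episode_return += reward
    print("episode return:", episode_return)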
import tensorflow as tf

from tf2rl.algos.ddpg import DDPG as DDPG_tf2rl
from tf2rl.algos.sac import SAC as SAC_tf2rl
# `DDPG` and `SAC` below refer to the local baseline modules this benchmark
# script imports elsewhere; they are left untouched here.


def make_policy(env, name, tf2rl=False):
    """Build a training policy plus a second instance used for saving/restoring."""
    dim_state, dim_action = env.observation_space.shape[0], env.action_space.shape[0]
    max_action = env.action_space.high
    with tf.device("/gpu:0"):
        if name == "DDPG":
            if tf2rl:
                policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                    max_action=max_action[0], max_grad=1.)
                saved_policy = DDPG_tf2rl(state_shape=(dim_state,), action_dim=dim_action,
                                          max_action=max_action[0], max_grad=1.)
            else:
                policy = DDPG.DDPG(dim_state, dim_action, max_action, training=True)
                saved_policy = DDPG.DDPG(dim_state, dim_action, max_action, training=False)
        elif name == "SAC":
            if tf2rl:
                policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
                saved_policy = SAC_tf2rl(state_shape=(dim_state,), action_dim=dim_action, max_action=max_action[0])
            else:
                policy = SAC.SAC(dim_state, dim_action, max_action, training=True)
                saved_policy = SAC.SAC(dim_state, dim_action, max_action, training=False)
        else:
            raise ValueError("invalid policy")
    return policy, saved_policy
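# Hypothetical usage of make_policy above; the environment name is only an
# illustrative choice of a continuous-control task.
import gym

env = gym.make("Pendulum-v1")
policy, saved_policy = make_policy(env, "DDPG", tf2rl=True)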