def test_random_policy():
    # single continuous-action environment
    env = make_gym_env('Pendulum-v0', 0)
    env_spec = EnvSpec(env)
    policy = RandomPolicy(None, env_spec)
    out = policy(env.reset())
    assert isinstance(out, dict)
    assert 'action' in out and out['action'].shape == (1,)

    # vectorized discrete-action environment with 3 sub-environments
    venv = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v0', 3, 0)
    env_spec = EnvSpec(venv)
    policy = RandomPolicy(None, env_spec)
    out = policy(venv.reset())
    assert isinstance(out, dict)
    assert 'action' in out and len(out['action']) == 3 and isinstance(out['action'][0], int)
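For orientation, a random policy of this kind can be approximated with nothing but the action space's own sampler. The sketch below is a hypothetical stand-in (SimpleRandomPolicy is not lagom's RandomPolicy) and assumes a Gym-style action_space.sample():

import gym

# Minimal sketch of a random policy (hypothetical, not lagom's RandomPolicy):
# the observation is ignored and an action is sampled from the action space.
class SimpleRandomPolicy:
    def __init__(self, action_space):
        self.action_space = action_space

    def __call__(self, observation):
        return {'action': self.action_space.sample()}

env = gym.make('CartPole-v1')
policy = SimpleRandomPolicy(env.action_space)
out = policy(env.reset())
assert 'action' in out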
def test_vec_env(vec_env_class):
    # unpack the (id, class) pair provided by the fixture
    v_id, vec_env_class = vec_env_class
    venv = make_vec_env(vec_env_class, make_gym_env, 'CartPole-v1', 5, 1, True)
    assert isinstance(venv, VecEnv)
    assert v_id in [0, 1]
    if v_id == 0:
        assert isinstance(venv, SerialVecEnv)
    elif v_id == 1:
        assert isinstance(venv, ParallelVecEnv)
    assert venv.num_env == 5
    assert not venv.closed and venv.viewer is None
    assert venv.unwrapped is venv
    assert isinstance(venv.observation_space, Box)
    assert isinstance(venv.action_space, Discrete)
    # CartPole-v1 specifics: 500-step horizon, 475.0 solving threshold
    assert venv.T == 500
    assert venv.max_episode_reward == 475.0
    assert venv.reward_range == (-float('inf'), float('inf'))

    obs = venv.reset()
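The batching idea exercised by this test can be illustrated without lagom at all. The sketch below is a toy serial vectorized environment (TinySerialVecEnv is illustrative, not lagom's SerialVecEnv) and assumes the classic Gym API (env.seed and 4-tuple step) used elsewhere in these snippets:

import gym

# Minimal sketch of a serial "vectorized" environment: step a list of Gym
# environments one after another and batch their outputs.
class TinySerialVecEnv:
    def __init__(self, env_id, num_env, init_seed=0):
        self.envs = [gym.make(env_id) for _ in range(num_env)]
        for i, env in enumerate(self.envs):
            env.seed(init_seed + i)

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        results = [env.step(a) for env, a in zip(self.envs, actions)]
        observations, rewards, dones, infos = zip(*results)
        return list(observations), list(rewards), list(dones), list(infos)

venv = TinySerialVecEnv('CartPole-v1', 5, init_seed=1)
observations = venv.reset()
assert len(observations) == 5  # one observation per sub-environment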
def test_categorical_head():
    # a categorical head requires a discrete action space, so Pendulum must fail
    with pytest.raises(AssertionError):
        env = make_gym_env('Pendulum-v0', 0)
        env_spec = EnvSpec(env)
        CategoricalHead(None, None, 30, env_spec)

    env = make_gym_env('CartPole-v1', 0)
    env_spec = EnvSpec(env)
    head = CategoricalHead(None, None, 30, env_spec)
    assert head.feature_dim == 30
    assert isinstance(head.action_head, nn.Linear)
    assert head.action_head.in_features == 30 and head.action_head.out_features == 2
    dist = head(torch.randn(3, 30))
    assert isinstance(dist, Categorical)
    assert list(dist.batch_shape) == [3]
    assert list(dist.probs.shape) == [3, 2]
    action = dist.sample()
    assert action.shape == (3,)
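The head under test maps a feature vector to a torch Categorical distribution. A minimal sketch of that idea in plain PyTorch (TinyCategoricalHead is illustrative, not lagom's CategoricalHead):

import torch
import torch.nn as nn
from torch.distributions import Categorical

# A single linear layer maps features to logits of a Categorical distribution.
class TinyCategoricalHead(nn.Module):
    def __init__(self, feature_dim, num_actions):
        super().__init__()
        self.action_head = nn.Linear(feature_dim, num_actions)

    def forward(self, features):
        return Categorical(logits=self.action_head(features))

head = TinyCategoricalHead(feature_dim=30, num_actions=2)
dist = head(torch.randn(3, 30))
action = dist.sample()            # shape (3,)
log_prob = dist.log_prob(action)  # shape (3,), e.g. for policy-gradient losses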
def test_reward_scale():
    env = make_gym_env(env_id='CartPole-v1', seed=0)
    env = RewardScale(env, scale=0.02)
    env.reset()
    observation, reward, done, info = env.step(env.action_space.sample())
    # CartPole gives a reward of 1.0 per step, so the scaled reward is 0.02
    assert reward == 0.02
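Reward scaling of this sort can be written as a standard Gym reward wrapper. The sketch below is illustrative (SimpleRewardScale is not necessarily how lagom's RewardScale is implemented) and assumes the classic 4-tuple Gym step API used in these tests:

import gym

# Every reward returned by step() is multiplied by a constant factor.
class SimpleRewardScale(gym.RewardWrapper):
    def __init__(self, env, scale=0.02):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        return self.scale * reward

env = SimpleRewardScale(gym.make('CartPole-v1'), scale=0.02)
env.reset()
_, reward, _, _ = env.step(env.action_space.sample())
assert reward == 0.02  # CartPole's per-step reward of 1.0, scaled down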
def test_diag_gaussian_head():
    # a diagonal Gaussian head requires a continuous action space, so CartPole must fail
    with pytest.raises(AssertionError):
        env = make_gym_env('CartPole-v1', 0)
        env_spec = EnvSpec(env)
        DiagGaussianHead(None, None, 30, env_spec)

    env = make_gym_env('Pendulum-v0', 0)
    env_spec = EnvSpec(env)
    head = DiagGaussianHead(None, None, 30, env_spec)
    assert head.feature_dim == 30
    assert isinstance(head.mean_head, nn.Linear)
    assert isinstance(head.logstd_head, nn.Parameter)
    assert head.mean_head.in_features == 30 and head.mean_head.out_features == 1
    assert list(head.logstd_head.shape) == [1]
    # initial std of 0.6, i.e. logstd == log(0.6) ~ -0.5108
    assert torch.eq(head.logstd_head, torch.tensor(-0.510825624))
    dist = head(torch.randn(3, 30))
    assert isinstance(dist, Independent) and isinstance(dist.base_dist, Normal)
    assert list(dist.batch_shape) == [3]
    action = dist.sample()
    assert list(action.shape) == [3, 1]

    head = DiagGaussianHead(None, None, 30, env_spec, std_style='softplus')
    dist = head(torch.randn(3, 30))
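The head produces an Independent-wrapped Normal, and the 'softplus' style keeps the standard deviation positive via a softplus transform. A minimal sketch of that parameterization in plain PyTorch (TinyDiagGaussianHead is illustrative, not lagom's DiagGaussianHead):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Independent

# Diagonal Gaussian head with a softplus-parameterized standard deviation.
class TinyDiagGaussianHead(nn.Module):
    def __init__(self, feature_dim, action_dim):
        super().__init__()
        self.mean_head = nn.Linear(feature_dim, action_dim)
        # unconstrained parameter; softplus keeps the std strictly positive
        self.std_param = nn.Parameter(torch.zeros(action_dim))

    def forward(self, features):
        mean = self.mean_head(features)
        std = F.softplus(self.std_param).expand_as(mean)
        # Independent(..., 1) folds the action dimension into the event shape
        return Independent(Normal(mean, std), 1)

head = TinyDiagGaussianHead(feature_dim=30, action_dim=1)
dist = head(torch.randn(3, 30))
assert list(dist.batch_shape) == [3]
action = dist.sample()  # shape (3, 1)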
def __call__(self, config, seed, device):
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)

    env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['train.N'], seed)
    eval_env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['eval.N'], seed)
    if config['env.standardize']:  # running averages of observation and reward
        env = VecStandardize(venv=env,
                             use_obs=True,
                             use_reward=True,
                             clip_obs=10.,
                             clip_reward=10.,
                             gamma=0.99,
                             eps=1e-8)
        eval_env = VecStandardize(venv=eval_env,
                                  use_obs=True,
                                  use_reward=False,  # do not process rewards, no training
                                  clip_obs=env.clip_obs,
                                  clip_reward=env.clip_reward,
                                  gamma=env.gamma,
                                  eps=env.eps,
def _prepare(self, config):
    self.env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['train.N'], 0)
    self.env = VecClipAction(self.env)
    if config['env.standardize']:
        self.env = VecStandardize(self.env,
                                  use_obs=True,
                                  use_reward=False,
                                  clip_obs=10.0,
                                  clip_reward=10.0,
                                  gamma=0.99,
                                  eps=1e-08)
    self.env_spec = EnvSpec(self.env)

    self.device = torch.device('cpu')
    self.agent = Agent(config, self.env_spec, self.device)
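The VecStandardize arguments used above (clip_obs, eps, running averages of observations) follow the usual observation-normalization recipe. The sketch below shows that recipe in isolation; it is an assumption about the general technique, not lagom's VecStandardize implementation, and RunningObsNorm is a hypothetical name:

import numpy as np

# Keep a running mean/variance over observations, normalize each batch, then clip.
class RunningObsNorm:
    def __init__(self, shape, clip_obs=10.0, eps=1e-8):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = eps
        self.clip_obs = clip_obs
        self.eps = eps

    def __call__(self, obs_batch):
        obs_batch = np.asarray(obs_batch, dtype=np.float64)
        # update running statistics from the batch (one row per sub-environment)
        batch_mean = obs_batch.mean(axis=0)
        batch_var = obs_batch.var(axis=0)
        batch_count = obs_batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        self.mean = self.mean + delta * batch_count / total
        self.var = (self.var * self.count + batch_var * batch_count
                    + delta ** 2 * self.count * batch_count / total) / total
        self.count = total
        # normalize and clip
        normalized = (obs_batch - self.mean) / np.sqrt(self.var + self.eps)
        return np.clip(normalized, -self.clip_obs, self.clip_obs)

norm = RunningObsNorm(shape=(4,))  # e.g. CartPole observations
batch = np.random.randn(3, 4)      # 3 sub-environments
print(norm(batch).shape)           # (3, 4)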
def __call__(self, config, seed, device):
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)

    if config['env.time_aware_obs']:
        kwargs = {'extra_wrapper': [TimeAwareObservation]}
    else:
        kwargs = {}
    env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['train.N'], seed, monitor=True, **kwargs)
    if config['eval.independent']:
        eval_env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['eval.N'], seed)

    if config['env.clip_action']:
        env = VecClipAction(env)
        if config['eval.independent']:
            eval_env = VecClipAction(eval_env)

    if config['env.standardize']:  # running averages of observation and reward
        env = VecStandardize(venv=env,
                             use_obs=True,
                             use_reward=True,
                             clip_obs=10.,
                             clip_reward=10.,
                             gamma=0.99,
                             eps=1e-8)

    env_spec = EnvSpec(env)
    agent = Agent(config, env_spec, device)
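The TimeAwareObservation extra wrapper above augments observations with time information. A minimal sketch of that idea (SimpleTimeAwareObservation is illustrative, not necessarily the wrapper used above), assuming the classic Gym API with a 4-tuple step:

import gym
import numpy as np

# Append the current timestep to each observation.
class SimpleTimeAwareObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.t = 0

    def reset(self, **kwargs):
        self.t = 0
        return self.observation(self.env.reset(**kwargs))

    def step(self, action):
        self.t += 1
        observation, reward, done, info = self.env.step(action)
        return self.observation(observation), reward, done, info

    def observation(self, observation):
        # note: the observation_space is left untouched in this sketch
        return np.append(observation, self.t)

env = SimpleTimeAwareObservation(gym.make('CartPole-v1'))
obs = env.reset()
assert obs.shape == (5,)  # 4 CartPole features plus the timestep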
def __call__(self, config, seed, device):
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)

    env = make_vec_env(vec_env_class=SerialVecEnv,
                       make_env=make_gym_env,
                       env_id=config['env.id'],
                       num_env=config['train.N'],  # batched environment
                       init_seed=seed)
    eval_env = make_vec_env(vec_env_class=SerialVecEnv,
                            make_env=make_gym_env,
                            env_id=config['env.id'],
                            num_env=config['eval.N'],
                            init_seed=seed)
    if config['env.standardize']:  # running averages of observation and reward
        env = VecStandardize(venv=env,
                             use_obs=True,
                             use_reward=True,
                             clip_obs=10.,
                             clip_reward=10.,
                             gamma=0.99,
                             eps=1e-8)
        eval_env = VecStandardize(venv=eval_env,
                                  use_obs=True,
                                  use_reward=False,  # do not process rewards, no training
                                  clip_obs=env.clip_obs,
                                  clip_reward=env.clip_reward,