Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_tuple():
    """Exercise construction, sampling and flatten/unflatten of a Tuple space.

    The Tuple holds a Discrete(5), a (2, 3) float32 Box and a Dict with a
    Discrete(2) 'success' entry and a (1, 3) Box 'velocity' entry.
    """
    # A single space (not a list of spaces) must be rejected.
    with pytest.raises(AssertionError):
        Tuple(Discrete(10))

    space = Tuple([Discrete(5),
                   Box(-1.0, 1.0, np.float32, shape=(2, 3)),
                   Dict({'success': Discrete(2), 'velocity': Box(-1, 1, np.float32, shape=(1, 3))})])
    assert len(space.spaces) == 3
    assert space.spaces[0] == Discrete(5)
    assert space.spaces[1] == Box(-1.0, 1.0, np.float32, shape=(2, 3))
    assert space.spaces[2] == Dict({'success': Discrete(2), 'velocity': Box(-1, 1, np.float32, shape=(1, 3))})

    sample = space.sample()
    assert isinstance(sample, tuple) and len(sample) == 3
    assert sample[0] in Discrete(5)
    assert sample[1] in Box(-1.0, 1.0, np.float32, shape=(2, 3))
    assert sample[2] in Dict({'success': Discrete(2), 'velocity': Box(-1, 1, np.float32, shape=(1, 3))})

    # Flat size: Discrete(5) -> 5, Box(2, 3) -> 6, Dict -> Discrete(2) 2 + Box(1, 3) 3.
    assert space.flat_dim == 5 + 2 * 3 + 2 + 3
    assert space.flatten(sample).shape == (16,)

    # Every component must survive a flatten/unflatten round-trip,
    # including the Dict's 'velocity' entry, which was previously unchecked.
    sample2 = space.unflatten(space.flatten(sample))
    assert sample[0] == sample2[0]
    assert np.allclose(sample[1], sample2[1])
    assert sample[2]['success'] == sample2[2]['success']
    assert np.allclose(sample[2]['velocity'], sample2[2]['velocity'])
def test_box():
    """Check Box argument validation and the invariants of a (2, 3) float32 Box."""
    # dtype=None must be rejected.
    with pytest.raises(AssertionError):
        Box(-1.0, 1.0, dtype=None)
    # A scalar `low` combined with a list-valued `high` must be rejected.
    with pytest.raises(AssertionError):
        Box(-1.0, [1.0, 2.0], np.float32, shape=(2,))
    # An ndarray `low` paired with a plain-list `high` is expected to raise AttributeError.
    with pytest.raises(AttributeError):
        Box(np.array([-1.0, -2.0]), [3.0, 4.0], np.float32)

    def check(box):
        """Assert that `box` is a float32 Box of shape (2, 3) with bounds [-1, 1]."""
        assert all(dtype == np.float32 for dtype in (box.dtype, box.low.dtype, box.high.dtype))
        assert all(s == (2, 3) for s in (box.shape, box.low.shape, box.high.shape))
        assert np.allclose(box.low, np.full([2, 3], -1.0))
        assert np.allclose(box.high, np.full([2, 3], 1.0))

        sample = box.sample()
        assert sample.shape == (2, 3) and sample.dtype == np.float32
        assert sample in box

        # flat_dim must be a plain Python int, and flatten/unflatten must round-trip.
        assert box.flat_dim == 6 and isinstance(box.flat_dim, int)
        assert box.flatten(sample).shape == (6,)
        assert np.allclose(sample, box.unflatten(box.flatten(sample)))

        assert str(box) == 'Box(2, 3)'
        assert box == Box(-1.0, 1.0, np.float32, shape=[2, 3])

    # `check` was previously defined but never invoked, so none of its
    # assertions ran. Exercise it with both list and tuple `shape` arguments.
    for shape in ([2, 3], (2, 3)):
        check(Box(-1.0, 1.0, np.float32, shape=shape))
def test_tuple():
    """Verify a Tuple of Discrete, Box and Dict spaces end to end.

    Covers rejection of a non-list argument, component storage, sampling,
    flat dimensionality and a lossless flatten/unflatten round-trip.
    """
    def fresh_components():
        # Build brand-new component instances so that every comparison below
        # is against an independently constructed space.
        return [
            Discrete(5),
            Box(-1.0, 1.0, np.float32, shape=(2, 3)),
            Dict({'success': Discrete(2), 'velocity': Box(-1, 1, np.float32, shape=(1, 3))}),
        ]

    with pytest.raises(AssertionError):
        Tuple(Discrete(10))

    space = Tuple(fresh_components())
    assert len(space.spaces) == 3
    for stored, expected in zip(space.spaces, fresh_components()):
        assert stored == expected

    sample = space.sample()
    assert isinstance(sample, tuple) and len(sample) == 3
    for part, component in zip(sample, fresh_components()):
        assert part in component

    # 5 (one-hot Discrete) + 6 (2x3 Box) + 2 (Discrete) + 3 (1x3 Box).
    assert space.flat_dim == 5+2*3+2+3
    flat = space.flatten(sample)
    assert flat.shape == (16,)

    restored = space.unflatten(flat)
    assert sample[0] == restored[0]
    assert np.allclose(sample[1], restored[1])
    assert sample[2]['success'] == restored[2]['success']
    assert np.allclose(sample[2]['velocity'], restored[2]['velocity'])
def test_convert_gym_space():
    """Check that convert_gym_space maps gym Discrete/Box/Dict spaces to lagom spaces.

    For each space type the converted object must be an instance of the lagom
    class, must not be an instance of the original gym class, and must contain
    its own samples.
    """
    def check_converted(gym_space, lagom_cls, gym_cls):
        # Shared checks for every space type; returns the converted space for
        # type-specific assertions.
        lagom_space = convert_gym_space(gym_space)
        assert isinstance(lagom_space, lagom_cls)
        # The result must be a lagom object, not the gym space passed through.
        assert not isinstance(lagom_space, gym_cls)
        assert lagom_space.sample() in lagom_space
        return lagom_space

    # Discrete
    lagom_space = check_converted(gym.spaces.Discrete(n=5), Discrete, gym.spaces.Discrete)
    assert lagom_space.n == 5

    # Box
    gym_box = gym.spaces.Box(low=-2.0, high=2.0, shape=(2, 3), dtype=np.float32)
    lagom_space = check_converted(gym_box, Box, gym.spaces.Box)
    assert lagom_space.shape == (2, 3)

    # Dict (nested). The sample-membership check, previously done only for
    # Discrete and Box, now also covers Dict via check_converted.
    gym_dict = gym.spaces.Dict({
        'sensors': gym.spaces.Dict({
            'position': gym.spaces.Box(low=-100, high=100, shape=(3,), dtype=np.float32),
            'velocity': gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)}),
        'charge': gym.spaces.Discrete(100)})
    lagom_space = check_converted(gym_dict, Dict, gym.spaces.Dict)
    assert len(lagom_space.spaces) == 2
def test_vec_env(vec_env_class):
    """Smoke-test make_vec_env with 5 CartPole-v1 environments.

    Args:
        vec_env_class: a ``(v_id, cls)`` pair, presumably supplied by a pytest
            fixture/parametrization. v_id 0 must yield a SerialVecEnv and
            v_id 1 a ParallelVecEnv.
    """
    # unpack class
    v_id, vec_env_class = vec_env_class
    venv = make_vec_env(vec_env_class, make_gym_env, 'CartPole-v1', 5, 1, True)
    try:
        assert isinstance(venv, VecEnv)
        assert v_id in [0, 1]
        if v_id == 0:
            # Previously missing `assert`, which made this check a no-op.
            assert isinstance(venv, SerialVecEnv)
        elif v_id == 1:
            assert isinstance(venv, ParallelVecEnv)

        assert venv.num_env == 5
        assert not venv.closed and venv.viewer is None
        assert venv.unwrapped is venv
        assert isinstance(venv.observation_space, Box)
        assert isinstance(venv.action_space, Discrete)
        assert venv.T == 500
        assert venv.max_episode_reward == 475.0
        assert venv.reward_range == (-float('inf'), float('inf'))

        obs = venv.reset()
        assert len(obs) == 5
        assert np.asarray(obs).shape == (5, 4)
        # Sub-environments are expected to start in distinct states
        # (presumably seeded differently).
        assert all(not np.allclose(obs[0], obs[i]) for i in range(1, 5))

        actions = [1] * 5
        obs, rewards, dones, infos = venv.step(actions)
        assert all(len(item) == 5 for item in [obs, rewards, dones, infos])
        assert all(not np.allclose(obs[0], obs[i]) for i in range(1, 5))

        # EnvSpec
        env_spec = EnvSpec(venv)
        assert isinstance(env_spec.action_space, Discrete)
    finally:
        # Release the sub-environments (the parallel backend may hold worker
        # processes). Assumes VecEnv provides close() alongside `closed`.
        venv.close()
def test_box():
    """Check Box argument validation and the invariants of a (2, 3) float32 Box."""
    # dtype=None must be rejected.
    with pytest.raises(AssertionError):
        Box(-1.0, 1.0, dtype=None)
    # A scalar `low` combined with a list-valued `high` must be rejected.
    with pytest.raises(AssertionError):
        Box(-1.0, [1.0, 2.0], np.float32, shape=(2,))
    # An ndarray `low` paired with a plain-list `high` is expected to raise AttributeError.
    with pytest.raises(AttributeError):
        Box(np.array([-1.0, -2.0]), [3.0, 4.0], np.float32)

    def check(box):
        """Assert that `box` is a float32 Box of shape (2, 3) with bounds [-1, 1]."""
        assert all(dtype == np.float32 for dtype in (box.dtype, box.low.dtype, box.high.dtype))
        assert all(s == (2, 3) for s in (box.shape, box.low.shape, box.high.shape))
        assert np.allclose(box.low, np.full([2, 3], -1.0))
        assert np.allclose(box.high, np.full([2, 3], 1.0))

        sample = box.sample()
        assert sample.shape == (2, 3) and sample.dtype == np.float32
        assert sample in box

        # flat_dim must be a plain Python int, and flatten/unflatten must round-trip.
        assert box.flat_dim == 6 and isinstance(box.flat_dim, int)
        assert box.flatten(sample).shape == (6,)
        assert np.allclose(sample, box.unflatten(box.flatten(sample)))

        assert str(box) == 'Box(2, 3)'
        assert box == Box(-1.0, 1.0, np.float32, shape=[2, 3])

    # Previously the helper was defined twice (the first copy silently shadowed
    # by the second) and never invoked, so none of its assertions ran.
    # Exercise it with both list and tuple `shape` arguments.
    for shape in ([2, 3], (2, 3)):
        check(Box(-1.0, 1.0, np.float32, shape=shape))
def __init__(self, list_make_env):
    """Instantiate one environment per factory and register the batch.

    The observation and action spaces of the batch are taken from the first
    environment created.

    Args:
        list_make_env: sequence of zero-argument callables, each presumably
            returning an environment exposing `observation_space` and
            `action_space`.
    """
    self.envs = [factory() for factory in list_make_env]
    first_env = self.envs[0]
    super().__init__(len(list_make_env), first_env.observation_space, first_env.action_space)
    # Only Box or Dict observation spaces are accepted.
    assert isinstance(self.observation_space, (Box, Dict))
def __init__(self, env, keys):
    """Set up a flat float32 Box observation space for the selected Dict entries.

    Args:
        env: environment whose `observation_space.spaces` is a mapping of
            named sub-spaces; every sub-space must be a Box.
        keys: names of the sub-spaces to include. The resulting Box has one
            dimension whose size is the sum of their `flat_dim`s.
    """
    super().__init__(env)
    self.keys = keys
    sub_spaces = self.env.observation_space.spaces
    # All sub-spaces (not only the selected keys) are required to be Box.
    assert all(isinstance(sub_space, Box) for sub_space in sub_spaces.values())
    # Dimensionality of the flat observation: total flat size of selected entries.
    total_dim = int(sum(sub_spaces[key].flat_dim for key in self.keys))
    self._observation_space = Box(low=-np.inf, high=np.inf, shape=(total_dim,), dtype=np.float32)
def observation_space(self):
    """Return a shape-(4,) float32 Box with bounds [0, maze_size[0] - 1].

    The upper bound is recomputed on each call from the source environment's
    `maze_size`.
    """
    upper_bound = self.env.get_source_env().maze_size[0] - 1
    return Box(0, upper_bound, shape=(4,), dtype=np.float32)