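# imports assumed for these snippets (based on the Tianshou test suite and collector)
import pickle

import numpy as np
import pytest
import torch

from tianshou.data import Batch, to_numpy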
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
# test stack with incompatible keys
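# keys missing from one of the batches are zero-filled in the stacked result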
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
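# stacking only empty batches gives an empty Batch; reserved (empty) keys stay empty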
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
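# only the shared key common.c is stacked along the last axis; the reserved keys a and b stay empty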
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert np.allclose(test.common.c,
np.stack([b1.common.c, b2.common.c], axis=-1))
def test_batch_from_to_numpy_without_copy():
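    # a to_torch()/to_numpy() round trip should reuse the original numpy memory (no copy)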
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
a_mem_addr_orig = batch.a.__array_interface__['data'][0]
c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
batch.to_torch()
batch.to_numpy()
a_mem_addr_new = batch.a.__array_interface__['data'][0]
c_mem_addr_new = batch.b.c.__array_interface__['data'][0]
assert a_mem_addr_new == a_mem_addr_orig
assert c_mem_addr_new == c_mem_addr_orig
def test_batch_empty():
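    # Batch.empty()/empty_() zero out numeric values (None for object entries) while keeping the structure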
b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
{'a': True, 'b': {'c': 3.0}}])
b5 = Batch(b5_dict)
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': np.array([2., 'st'], dtype=object),
'd': [1, None],
'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
assert list(data.b.c) == [2.0, None]
assert list(data.b.d) == [1, None]
assert np.allclose(data.b.e, [2, 0])
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
data[0].empty_()  # only the torch field t is emptied in place; the numpy-backed a, b.c, b.d, b.e, c are copies and stay unchanged
assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
data.empty_(index=0)
assert np.allclose(data.c, [0, 3, 0])
assert list(data.b.c) == [None, None]
assert list(data.b.d) == [None, None]
assert list(data.b.e) == [0, 0]
b0 = Batch()
b0.empty_()
assert b0.shape == []
assert batch2_from_comp.a.d.e == batch2.a.d.e
for batch_slice in [
batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
assert batch_slice.a.b == batch2.a.b
assert batch_slice.a.c == batch2.a.c
assert batch_slice.a.d.e == batch2.a.d.e
batch2_sum = (batch2 + 1.0) * 2
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
batch3 = Batch(a={
'c': np.zeros(1),
'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
batch3.a.d[0] = {'e': 4.0}
assert batch3.a.d.e[0] == 4.0
batch3.a.d[0] = Batch(f=5.0)
assert batch3.a.d.f[0] == 5.0
with pytest.raises(KeyError):
batch3.a.d[0] = Batch(f=5.0, g=0.0)
# auto convert: string arrays are stored with object dtype, dict arrays become nested Batch
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == object  # auto convert to object dtype
batch4.update(a=np.array(['c', 'd']))
assert list(batch4.a) == ['c', 'd']
assert batch4.a.dtype == object  # auto convert to object dtype
batch5 = Batch(a=np.array([{'index': 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
batch5.b = np.array([{'index': 1}])
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])
def test_batch_pickle():
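    # a pickle round trip should preserve nested scalars, torch tensors and numpy arrays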
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
assert batch.obs.a == batch_pk.obs.a
assert torch.all(batch.obs.c == batch_pk.obs.c)
assert np.all(batch.np == batch_pk.np)
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2
# test cat with incompatible keys
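# missing keys are zero-padded along the concatenation dimension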
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with reserved keys (values are Batch())
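# a key reserved as an empty Batch is treated like a missing key and zero-padded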
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with all reserved keys (values are Batch())
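# a key that is reserved in every input should remain an empty Batch after cat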
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
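# test stack with incompatible keys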
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])
ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
b=torch.stack([torch.zeros(4, 6), b2.b]),
common=Batch(c=np.stack([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# calculate the next action
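# with random exploration, sample from the action space; otherwise run the policy without gradient tracking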
if random:
action_space = self.env.action_space
if isinstance(action_space, list):
result = Batch(act=[a.sample() for a in action_space])
else:
result = Batch(act=self._make_batch(action_space.sample()))
else:
with torch.no_grad():
result = self.policy(self.data, last_state)
# convert None to Batch(), since None is reserved for 0-init
state = result.get('state', Batch())
if state is None:
state = Batch()
self.data.state = state
if hasattr(result, 'policy'):
self.data.policy = to_numpy(result.policy)
# save hidden state to policy._state, in order to save into buffer
self.data.policy._state = self.data.state
self.data.act = to_numpy(result.act)
if self._action_noise is not None:
self.data.act += self._action_noise(self.data.act.shape)
# step in env
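# a single (non-vectorized) env takes one action instead of a batch of actions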
obs_next, rew, done, info = self.env.step(
self.data.act if self._multi_env else self.data.act[0])
# move data to self.data
if not self._multi_env: