Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Concatenate two populated Batches around one whose innermost `a` is an
# empty Batch(); the check expects the empty operand to contribute no rows.
a = Batch(a=Batch(a=np.random.randn(3, 4)))
expected_leaf = np.concatenate([a.a.a, a.a.a])
hollow = Batch(a=Batch(a=Batch()))
joined = Batch.cat([a, hollow, a])
assert np.allclose(expected_leaf, joined.a.a)
# test cat with lens infer
# `b.a.a` is an empty Batch(); the reference below expects its slot in the
# result to be a (3, 4) block of zeros (presumably sized from `b.b`).
a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
ans = Batch.cat([a, b, a])

expected_nested = np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])
assert np.allclose(ans.a.a, expected_nested)

expected_flat = np.concatenate([a.b, b.b, a.b])
assert np.allclose(ans.b, expected_flat)

# `t` only ever appears as an empty Batch() and is expected to stay empty.
assert ans.a.t.is_empty()
# Fragment: `b1` and `b2` are not defined in the visible lines above --
# presumably created earlier in the original test; confirm before running.
b12_stack = Batch.stack((b1, b2))
# The stacked nested field `a.d.e` is expected to be a 2-D NumPy array.
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2
# test cat with incompatible keys
# `a` exists only in b1 and `b` only in b2; the reference result pads the
# operand that lacks a key with zeros covering that operand's rows.
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])

expected_a = np.concatenate([b1.a, np.zeros((4, 4))])
expected_b = torch.cat([torch.zeros(3, 3), b2.b])
expected_common = Batch(c=np.concatenate([b1.common.c, b2.common.c]))
ans = Batch(a=expected_a, b=expected_b, common=expected_common)

assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with reserved keys (values are Batch())
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
# NOTE(review): the reserved-keys case appears to be cut off after `b1`; the
# assertion below reads `d`, which is not assigned in the visible lines above.
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
all_empty = Batch.stack([Batch(), Batch(), Batch()])
assert all_empty.is_empty()

# `d` is an empty Batch() in the first operand and a scalar in the others;
# `e` is an empty Batch() in every operand.
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])

# Expected stacked values per key; a 0 is expected wherever an operand has
# no value (or only an empty Batch()) for that key.
expected_columns = {
    "a": [1, 4, 0],
    "b": [2, 5, 6],
    "c": [3, 0, 7],
    "d": [0, 6, 9],
}
for key, column in expected_columns.items():
    assert np.allclose(getattr(d, key), column)
assert d.e.is_empty()
# Stack along the last axis when the non-shared keys (`a`, `b`) hold only
# empty Batch() values: they are expected to stay empty, while the shared
# `common.c` arrays are expected to match np.stack(..., axis=-1).
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)

for key in ("a", "b"):
    assert getattr(test, key).is_empty()

expected_common_c = np.stack([b1.common.c, b2.common.c], axis=-1)
assert np.allclose(test.common.c, expected_common_c)
# Stack two Batches whose keys only partly overlap: `a` (NumPy) is present
# only in b1, `b` (torch) only in b2, and `common.c` in both.
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])

# Zero placeholders standing in for the operand that lacks each key.
missing_a = np.zeros((4, 4))
missing_b = torch.zeros(4, 6)
ans = Batch(
    a=np.stack([b1.a, missing_a]),
    b=torch.stack([missing_b, b2.b]),
    common=Batch(c=np.stack([b1.common.c, b2.common.c])),
)

assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# Fragment: `b3` and `b4` are not defined in the visible lines above --
# presumably built earlier in the original test; confirm before running.
b34_stack = Batch.stack((b3, b4), axis=1)
# Field `a` is expected to equal np.stack of the operands' `a` along axis 1.
assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
# Nested field `c.d` is expected to pair the operands' entries element-wise.
assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
# Build a Batch from a NumPy array of dicts whose nested keys differ:
# the second record has no `b.d` entry.
first_record = {'a': False, 'b': {'c': 2.0, 'd': 1.0}}
second_record = {'a': True, 'b': {'c': 3.0}}
b5_dict = np.array([first_record, second_record])
b5 = Batch(b5_dict)

assert b5.a[0] == np.array(False)
assert b5.a[1] == np.array(True)

expected_c = np.stack([record['b']['c'] for record in b5_dict], axis=0)
assert np.all(b5.b.c == expected_c)

assert b5.b.d[0] == b5_dict[0]['b']['d']
# The `b.d` missing from the second record is expected to read back as 0.0.
assert b5.b.d[1] == 0.0
# test stack with incompatible keys
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])

# Each row pairs a key with its expected stacked values; 0 is expected in
# the position of an operand that does not define that key.
expectations = (
    ("a", [1, 4, 0]),
    ("b", [2, 5, 6]),
    ("c", [3, 0, 7]),
    ("d", [0, 6, 9]),
)
for field, values in expectations:
    assert np.allclose(getattr(d, field), values)
# test stack with empty Batch()
# Stacking only empty Batches is expected to give an empty Batch.
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
# `d` is an empty Batch() in the first operand and a scalar in the others;
# `e` is an empty Batch() in every operand.
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
# A 0 is expected wherever an operand has no value (or only an empty
# Batch()) for the key; each expectation is checked exactly once.
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
# Last-axis stack of two Batches that share `common.c` but whose other keys
# (`a` in b1, `b` in b2) are empty Batch() placeholders.
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)

# The placeholder keys are expected to remain empty after stacking.
assert test.a.is_empty()
assert test.b.is_empty()

# The shared array is expected to match a plain np.stack on the last axis.
shared_parts = [b1.common.c, b2.common.c]
stacked_reference = np.stack(shared_parts, axis=-1)
assert np.allclose(test.common.c, stacked_reference)
# Default-axis stack of two Batches with partly disjoint keys: `a` (NumPy)
# exists only in b1, `b` (torch) only in b2, `common.c` in both.
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])

# Reference result: the side lacking a key contributes zeros of the same
# shape and array library as the side that has it.
ref_a = np.stack([b1.a, np.zeros((4, 4))])
ref_b = torch.stack([torch.zeros(4, 6), b2.b])
ref_common = Batch(c=np.stack([b1.common.c, b2.common.c]))
ans = Batch(a=ref_a, b=ref_b, common=ref_common)

assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# Fragment: `batch3` is not defined in the visible lines above -- presumably
# created earlier in the original test; confirm before running.
batch3.a.d[0] = Batch(f=5.0, g=0.0)
# auto convert
# String arrays are expected to be stored with object dtype, both at
# construction time and after update(). The builtin `object` is used for the
# comparison because the `np.object` alias no longer exists in current NumPy.
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == object  # auto convert to object dtype
batch4.update(a=np.array(['c', 'd']))
assert list(batch4.a) == ['c', 'd']
assert batch4.a.dtype == object  # auto convert to object dtype
# An array of dicts is expected to become a nested Batch, whether passed to
# the constructor or assigned as an attribute afterwards.
batch5 = Batch(a=np.array([{'index': 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
batch5.b = np.array([{'index': 1}])
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])
# None is a valid object and can be stored in Batch
# Each operand defines only one of the two keys; after stacking, both
# positions of both keys are expected to hold None.
a = Batch.stack([Batch(a=None), Batch(b=None)])
for key in ("a", "b"):
    column = getattr(a, key)
    assert column[0] is None
    assert column[1] is None
# Fragment: these checks read `b5` and `b5_dict`, which are not assigned
# here; they rely on the dict-array Batch built earlier in the file.
# `b.c` is expected to equal the per-record `c` values stacked on axis 0.
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
# `b.d` exists only in the first record; the second is expected to be 0.0.
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
# test stack with incompatible keys
# No single key is shared by all three operands except `b`.
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
operands = [a, b, c]
d = Batch.stack(operands)

# Expected stacked values, keyed by field name; 0 is expected for an
# operand that does not carry the field.
expected_by_field = dict(
    a=[1, 4, 0],
    b=[2, 5, 6],
    c=[3, 0, 7],
    d=[0, 6, 9],
)
for field, expected in expected_by_field.items():
    assert np.allclose(getattr(d, field), expected)
# test stack with empty Batch()
assert Batch.stack([Batch() for _ in range(3)]).is_empty()

# `d` is an empty Batch() in the first operand only; `e` is an empty
# Batch() in all three.
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])

# Keys and their expected stacked values, in matching order.
scalar_keys = ("a", "b", "c", "d")
wanted = ([1, 4, 0], [2, 5, 6], [3, 0, 7], [0, 6, 9])
for key, column in zip(scalar_keys, wanted):
    assert np.allclose(getattr(d, key), column)

# A key that is empty in every operand is expected to stay empty.
assert d.e.is_empty()
# Stack on axis=-1: `a` and `b` are empty Batch() placeholders that each
# appear in only one operand, while `common.c` is a real array in both.
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)

for empty_key in ("a", "b"):
    assert getattr(test, empty_key).is_empty()

reference_c = np.stack([b1.common.c, b2.common.c], axis=-1)
assert np.allclose(test.common.c, reference_c)