# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_batch_cat_and_stack():
    """Check ``Batch.cat`` / ``Batch.cat_`` on nested and partly empty batches.

    NOTE(review): another function with this exact name is defined later
    in this file; that later ``def`` rebinds the name, so this definition
    is not reachable by name once the module has been loaded.
    """
    # test cat with compatible keys
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    # cat_ is applied to a deep copy, so b1 itself is left untouched.
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # Batch.cat and cat_ must agree on both leaves (a.b and a.d.e).
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # A batch whose only leaf is an empty Batch() adds no rows to a.a.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    # b.a.a is empty while b.b has 3 rows, so the expected a.a carries a
    # (3, 4) block of zeros in b's position.
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()
def test_batch_cat_and_stack():
    """Exercise ``Batch.cat``, ``Batch.cat_`` and ``Batch.stack``.

    Sections, in order: compatible keys, length inference for empty
    sub-batches, incompatible keys, reserved keys (``Batch()`` values),
    all-reserved keys, and stacking along a non-default axis.
    """
    # test cat with compatible keys
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    # cat_ is applied to a deep copy so b1 stays usable for the stack
    # check further down.
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # Batch.cat and cat_ must agree on both leaves (a.b and a.d.e).
    assert np.all(b12_cat_in.a.b == b12_cat_out.a.b)
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # A batch whose only leaf is an empty Batch() adds no rows to a.a.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    # b.a.a is empty while b.b has 3 rows, so the expected a.a carries a
    # (3, 4) block of zeros in b's position.
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # Stacking the two single-element batches adds a leading axis.
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # A key missing from one batch is expected to be zero-filled for that
    # batch's rows (4 rows for a, 3 rows for b).
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # An explicit empty Batch() for a key is expected to behave like a
    # missing key: same zero-filled result as the section above.
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    # Check the concatenation result itself, not the hand-built fixture:
    # a key that is empty in every input must stay empty.
    assert test.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    # NOTE(review): only key a is verified here; b and c.d are built but
    # not asserted in the visible code - confirm whether checks are missing.
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))