Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_add_infinity_and_beyond(a, b, c):
    """Drive a point far out in the ball repeatedly and check nothing overflows.

    Each of 100 iterations does the following, then makes the new point
    the base for the next iteration:

    * take a huge tangent step (``b`` scaled by 1e7) from the current base
      point with ``expmap``;
    * project the result back;
    * rescale it with ``mobius_scalar_mul(1000.0, ...)``;
    * project again;
    * parallel-transport the tangent step to the new point.

    Both the point and the transported step must stay finite at every
    iteration. Afterwards, stepping from the last point along the negated
    step must land close to ``-a`` (the last point reflected through the
    origin), within a loose dtype-dependent tolerance.
    """
    far_step = b * 10000000
    for step in range(100):
        point = poincare.math.expmap(a, far_step, c=c)
        point = poincare.math.project(point, c=c)
        point = poincare.math.mobius_scalar_mul(1000.0, point, c=c)
        point = poincare.math.project(point, c=c)
        far_step = poincare.math.parallel_transport(a, point, far_step, c=c)
        assert np.isfinite(point).all(), (step, point)
        assert np.isfinite(far_step).all(), (step, far_step)
        a = point
    landed = poincare.math.expmap(a, -far_step, c=c)
    # The two points only need to end up very far out near each other; an
    # exact match is not expected, hence the loose tolerances.
    tolerances_by_dtype = {
        torch.float32: {"rtol": 3e-1, "atol": 2e-1},
        torch.float64: {"rtol": 1e-1, "atol": 1e-3},
    }
    np.testing.assert_allclose(landed, -a, **tolerances_by_dtype[c.dtype])
def test_parallel_transport_a_b(a, b, c):
    """Parallel transport from ``a`` to ``b`` must preserve the inner product.

    Two random vectors at ``a`` are transported to ``b``. Their inner
    product at ``b`` (via ``poincare.math.inner`` with base point ``b``)
    must match the inner product of the originals at ``a``, to within 1e-6.
    """
    # NOTE(review): an identical test_parallel_transport_a_b appears again
    # later in this listing. If both definitions are in the same module, the
    # later one shadows this one and pytest collects only one of them
    # (pyflakes F811). Confirm and drop the duplicate.
    # Random vectors at ``a``; torch.rand_like samples uniformly from [0, 1).
    v_0 = torch.rand_like(a)
    u_0 = torch.rand_like(a)
    v_1 = poincare.math.parallel_transport(a, b, v_0, c=c)
    u_1 = poincare.math.parallel_transport(a, b, u_0, c=c)
    # Inner products of each pair at its own base point (not norms: the two
    # vectors in each pair are different).
    vu_1 = poincare.math.inner(b, v_1, u_1, c=c, keepdim=True)
    vu_0 = poincare.math.inner(a, v_0, u_0, c=c, keepdim=True)
    np.testing.assert_allclose(vu_0, vu_1, atol=1e-6, rtol=1e-6)
def test_parallel_transport0_is_same_as_usual(a, c):
    """``parallel_transport0(a, v)`` must equal ``parallel_transport(0, a, v)``.

    A random vector is moved to ``a`` twice:

    * with ``parallel_transport0``;
    * with the general ``parallel_transport`` starting at the zero point.

    Both results must agree to within 1e-6.
    """
    # The small offset keeps the random vector away from exactly zero.
    tangent = torch.rand_like(a) + 1e-5
    origin = torch.zeros_like(a)
    via_shortcut = poincare.math.parallel_transport0(a, tangent, c=c)
    via_general = poincare.math.parallel_transport(origin, a, tangent, c=c)
    np.testing.assert_allclose(via_shortcut, via_general, atol=1e-6, rtol=1e-6)
def test_parallel_transport_a_b(a, b, c):
    """Check that moving a pair of vectors from ``a`` to ``b`` keeps their inner product.

    The pair is drawn at random at ``a`` and transported to ``b``. The inner
    product from ``poincare.math.inner`` at ``b`` must equal the one computed
    at ``a`` for the untransported pair, to within 1e-6.
    """
    # NOTE(review): an identical test_parallel_transport_a_b appears earlier
    # in this listing. If both are in one module, only one is collected by
    # pytest (pyflakes F811). Confirm and remove the duplicate.

    def transport(vec):
        return poincare.math.parallel_transport(a, b, vec, c=c)

    # Sampled in the same order as before so the RNG stream is unchanged.
    first_at_a, second_at_a = torch.rand_like(a), torch.rand_like(a)
    first_at_b = transport(first_at_a)
    second_at_b = transport(second_at_a)
    product_at_b = poincare.math.inner(b, first_at_b, second_at_b, c=c, keepdim=True)
    product_at_a = poincare.math.inner(a, first_at_a, second_at_a, c=c, keepdim=True)
    np.testing.assert_allclose(product_at_a, product_at_b, atol=1e-6, rtol=1e-6)
def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, dim=-1):
    """Parallel-transport ``v`` from point ``x`` to point ``y``.

    This is a thin wrapper over ``math.parallel_transport``. It forwards
    ``x``, ``y``, ``v`` and ``dim`` unchanged and passes ``self.c`` as ``c``.
    NOTE(review): ``self.c`` is presumably the ball's curvature parameter;
    confirm against the class definition.
    """
    curvature = self.c
    return math.parallel_transport(x, y, v, c=curvature, dim=dim)