Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_n_additions_via_scalar_multiplication(n, a, c):
    """Check that n-fold Mobius self-addition agrees with Mobius scalar multiplication.

    Starting from the origin (``torch.zeros_like(a)``), ``a`` is Mobius-added to
    the running result ``n`` times; the outcome must match
    ``mobius_scalar_mul(n, a, c=c)`` within a dtype-dependent tolerance
    selected by ``c.dtype``.
    """
    per_dtype_tol = {
        torch.float32: dict(atol=1e-7, rtol=1e-6),
        torch.float64: dict(atol=1e-10),
    }
    accumulated = torch.zeros_like(a)
    for _step in range(n):
        # the new summand goes on the left: a (+)_c accumulated
        accumulated = poincare.math.mobius_add(a, accumulated, c=c)
    expected = poincare.math.mobius_scalar_mul(n, a, c=c)
    np.testing.assert_allclose(accumulated, expected, **per_dtype_tol[c.dtype])
def test_mobius_addition_left_cancelation(a, b, c):
    """Check the left cancellation law: (-a) (+)_c (a (+)_c b) == b.

    float32 uses atol=rtol=1e-6; float64 falls back to the
    ``assert_allclose`` defaults.
    """
    tolerances = {torch.float32: dict(atol=1e-6, rtol=1e-6), torch.float64: dict()}
    inner = poincare.math.mobius_add(a, b, c=c)
    recovered = poincare.math.mobius_add(-a, inner, c=c)
    np.testing.assert_allclose(recovered, b, **tolerances[c.dtype])
def test_mobius_cosub(a, b, c):
    r"""Check that Mobius co-subtraction undoes Mobius addition.

    Identity under test: (a \oplus_c b) \boxminus_c b == a.
    A fixed absolute tolerance of 1e-5 is used for every dtype
    (relative tolerance is left at the ``assert_allclose`` default).
    """
    summed = poincare.math.mobius_add(a, b, c=c)
    restored = poincare.math.mobius_cosub(summed, b, c=c)
    np.testing.assert_allclose(restored, a, atol=1e-5)
def test_mobius_addition_zero_a(b, c):
    """Check that the origin is a left identity of Mobius addition: 0 (+)_c b == b.

    The zero operand is built with ``torch.zeros_like(b, dtype=c.dtype)`` so it
    always has ``b``'s shape (and device) rather than a hard-coded ``(100, 10)``
    that silently depended on the ``b`` fixture's size; its dtype stays tied to
    ``c`` exactly as before.
    """
    a = torch.zeros_like(b, dtype=c.dtype)
    res = poincare.math.mobius_add(a, b, c=c)
    # default assert_allclose tolerances (rtol=1e-7, atol=0)
    np.testing.assert_allclose(res, b)
def test_mobius_addition_negative_cancellation(a, c):
    """Check that a point cancels with its negation: a (+)_c (-a) == 0.

    The result is compared against an all-zeros tensor of the same shape,
    using a tolerance chosen by ``c.dtype``.
    """
    tol_by_dtype = {
        torch.float32: dict(atol=1e-7, rtol=1e-6),
        torch.float64: dict(atol=1e-10),
    }
    total = poincare.math.mobius_add(a, -a, c=c)
    origin = torch.zeros_like(total)
    np.testing.assert_allclose(total, origin, **tol_by_dtype[c.dtype])
def test_mobius_negative_addition(a, b, c):
    """Check that negation distributes over Mobius addition.

    Identity under test: (-b) (+)_c (-a) == -(b (+)_c a), with a tolerance
    chosen by ``c.dtype``.
    """
    tol_by_dtype = {
        torch.float32: dict(atol=1e-7, rtol=1e-6),
        torch.float64: dict(atol=1e-10),
    }
    lhs = poincare.math.mobius_add(-b, -a, c=c)
    rhs = -poincare.math.mobius_add(b, a, c=c)
    np.testing.assert_allclose(lhs, rhs, **tol_by_dtype[c.dtype])
def test_mobius_negative_addition_a_first(a, b, c):
    """Check that negation distributes over Mobius addition with ``a`` on the left.

    Identity under test: (-a) (+)_c (-b) == -(a (+)_c b).

    Companion to ``test_mobius_negative_addition``, which checks the operand
    order (-b) (+)_c (-a). Mobius addition is not commutative, so checking both
    argument orders is a genuinely different test. Giving this function its own
    name also keeps it from rebinding (and thereby shadowing) that test.
    """
    tolerance = {
        torch.float32: dict(atol=1e-7, rtol=1e-6),
        torch.float64: dict(atol=1e-10),
    }
    res = poincare.math.mobius_add(-a, -b, c=c)
    res1 = -poincare.math.mobius_add(a, b, c=c)
    np.testing.assert_allclose(res, res1, **tolerance[c.dtype])
def test_scalar_multiplication_distributive(a, c, r1, r2):
    """Check scalar distributivity of Mobius scalar multiplication.

    Identity under test:
        (r1 + r2) (x)_c a == (r1 (x)_c a) (+)_c (r2 (x)_c a)

    Both summands are scalar multiples of the same point ``a`` (hence collinear),
    and Mobius addition of collinear points is commutative. The second comparison
    therefore checks the swapped order (r2 (x)_c a) (+)_c (r1 (x)_c a) against
    the same target.
    """
    tolerance = {
        torch.float32: dict(atol=1e-6, rtol=1e-7),
        torch.float64: dict(atol=1e-7, rtol=1e-10),
    }
    # Compute each scalar multiple once and reuse it for both operand orders.
    r1_a = poincare.math.mobius_scalar_mul(r1, a, c=c)
    r2_a = poincare.math.mobius_scalar_mul(r2, a, c=c)
    res = poincare.math.mobius_scalar_mul(r1 + r2, a, c=c)
    res1 = poincare.math.mobius_add(r1_a, r2_a, c=c)
    # Operands swapped relative to res1. An identical copy of res1 here would
    # make the second assertion redundant.
    res2 = poincare.math.mobius_add(r2_a, r1_a, c=c)
    np.testing.assert_allclose(res1, res, **tolerance[c.dtype])
    np.testing.assert_allclose(res2, res, **tolerance[c.dtype])