Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_adam_poincare():
    """RiemannianAdam should move a Poincare-ball parameter onto ``ideal``.

    The objective is the squared ``poincare.math.dist`` between the parameter
    and a fixed target point; after 2000 optimizer steps the parameter must
    equal the target within ``atol=rtol=1e-5``.
    """
    torch.manual_seed(44)
    target = torch.tensor([0.5, 0.5])
    # Map a random 2-vector onto the ball with expmap0 (c=1.0) to get a start point.
    initial = geoopt.manifolds.poincare.math.expmap0(torch.randn(2) / 2, c=1.0)
    param = geoopt.ManifoldParameter(initial, manifold=geoopt.PoincareBall())
    optimizer = geoopt.optim.RiemannianAdam([param], lr=1e-2)

    def closure():
        # Recompute the loss and its gradient; step() calls this each iteration.
        optimizer.zero_grad()
        objective = geoopt.manifolds.poincare.math.dist(param, target) ** 2
        objective.backward()
        return objective.item()

    for _step in range(2000):
        optimizer.step(closure)

    np.testing.assert_allclose(param.data, target, atol=1e-5, rtol=1e-5)
def test_dtype_checked_properly():
    """Combining factors of different dtypes in a ProductManifold must fail.

    One factor keeps its default dtype, the other is converted with
    ``.double()``; construction is expected to raise ``ValueError`` whose
    message matches "Not all manifold share the same dtype".
    """
    # NOTE(review): a function with this exact name appears more than once in
    # this listing; if both live in one module the later definition shadows the
    # earlier and pytest collects only one of them — confirm.
    p1 = PoincareBall()
    p2 = PoincareBall().double()
    # Passing ``match=`` makes the message check part of the raises context.
    # An ``e.match(...)`` placed inside the ``with`` body after the raising call
    # would never execute.
    with pytest.raises(ValueError, match="Not all manifold share the same dtype"):
        ProductManifold((p1, (10,)), (p2, (12,)))
def test_product():
    """Random points drawn from a mixed ProductManifold lie on that manifold."""
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factors)
    # Draw with size arguments (20, product.n_elements).
    points = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(points)
def poincare_case():
    """Yield float64 ``UnaryCase`` fixtures for PoincareBall and PoincareBallExact.

    Both yielded cases share the same ``shape``, expected point ``ex``, vector
    ``v`` and raw vector ``ev``; they differ only in the manifold attached.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    raw_point = torch.randn(*shape, dtype=torch.float64) / 3
    raw_vector = torch.randn(*shape, dtype=torch.float64) / 3
    # Rescale raw_point so its norm becomes tanh(norm), which is below 1.
    norm = torch.norm(raw_point)
    x = torch.tanh(norm) * raw_point / norm
    expected_point = x.clone()
    vector = raw_vector.clone()
    for manifold_cls in (geoopt.PoincareBall, geoopt.PoincareBallExact):
        manifold = manifold_cls().to(dtype=torch.float64)
        # ``x`` is re-wrapped cumulatively: the Exact case wraps the tensor
        # already tagged with the PoincareBall manifold.
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, expected_point, vector, raw_vector, manifold)
def test_dtype_checked_properly():
    """Combining factors of different dtypes in a ProductManifold must fail.

    One factor keeps its default dtype, the other is converted with
    ``.double()``; construction is expected to raise ``ValueError`` whose
    message matches "Not all manifold share the same dtype".
    """
    # NOTE(review): a function with this exact name appears more than once in
    # this listing; if both live in one module the later definition shadows the
    # earlier and pytest collects only one of them — confirm.
    p1 = PoincareBall()
    p2 = PoincareBall().double()
    # Passing ``match=`` makes the message check part of the raises context.
    # An ``e.match(...)`` placed inside the ``with`` body after the raising call
    # would never execute.
    with pytest.raises(ValueError, match="Not all manifold share the same dtype"):
        ProductManifold((p1, (10,)), (p2, (12,)))
def test_random_Poincare():
    """``random_normal`` samples lie on the ball and are tagged with it."""
    ball = geoopt.PoincareBall()
    drawn = ball.random_normal(3, 10, 10)
    ball.assert_check_point_on_manifold(drawn)
    # The returned tensor must carry a reference to the very same manifold object.
    assert drawn.manifold is ball
yield UnaryCase(
manifold_shapes[geoopt.ProductManifold],
product_manifold.pack_point(*x),
product_manifold.pack_point(*ex),
product_manifold.pack_point(*v),
product_manifold.pack_point(*ev),
product_manifold,
)
# + 1 case without stiefel
torch.manual_seed(42)
ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
manifolds = [
geoopt.Sphere(),
geoopt.PoincareBall(),
# geoopt.Stiefel(),
geoopt.Euclidean(),
]
x = [manifolds[i].projx(ex[i]) for i in range(len(manifolds))]
v = [manifolds[i].proju(x[i], ev[i]) for i in range(len(manifolds))]
product_manifold = geoopt.ProductManifold(
*((manifolds[i], ex[i].shape) for i in range(len(ex)))
)
yield UnaryCase(
manifold_shapes[geoopt.ProductManifold],
product_manifold.pack_point(*x),
product_manifold.pack_point(*ex),
product_manifold.pack_point(*v),
product_manifold.pack_point(*ev),
def test_fails_Poincare():
    """``random_normal`` called with an empty size tuple ``()`` raises ValueError."""
    # Build the manifold outside ``pytest.raises``: otherwise a ValueError from
    # the constructor would also satisfy the expectation and hide a regression.
    manifold = geoopt.PoincareBall()
    with pytest.raises(ValueError):
        manifold.random_normal(())