Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
)
yield UnaryCase(
manifold_shapes[geoopt.ProductManifold],
product_manifold.pack_point(*x),
product_manifold.pack_point(*ex),
product_manifold.pack_point(*v),
product_manifold.pack_point(*ev),
product_manifold,
)
# + 1 case without stiefel
torch.manual_seed(42)
ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
manifolds = [
geoopt.Sphere(),
geoopt.PoincareBall(),
# geoopt.Stiefel(),
geoopt.Euclidean(),
]
x = [manifolds[i].projx(ex[i]) for i in range(len(manifolds))]
v = [manifolds[i].proju(x[i], ev[i]) for i in range(len(manifolds))]
product_manifold = geoopt.ProductManifold(
*((manifolds[i], ex[i].shape) for i in range(len(ex)))
)
yield UnaryCase(
manifold_shapes[geoopt.ProductManifold],
product_manifold.pack_point(*x),
product_manifold.pack_point(*ex),
product_manifold.pack_point(*v),
def sphere_subspace_case():
    """Yield unary cases for spheres intersected with a random 2-column subspace.

    A random ``float64`` basis (``shape[-1] x 2``) is orthonormalised via QR and
    turned into a projector. A random point is projected into the subspace and
    normalised, and a random vector has its component along that point removed
    before being projected into the subspace. The same data is yielded twice:
    once with ``geoopt.Sphere(intersection=subspace)`` and once with
    ``geoopt.SphereExact(intersection=subspace)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    # RNG draw order matters under the fixed seed: basis, raw point, raw vector.
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace)
    projector = basis @ basis.t()
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    # Project the raw point into the subspace, then scale it to unit norm.
    # NOTE(review): assumes `shape` is 1-D so `@` below acts as a dot product.
    projected = ex @ projector.t()
    x = projected / torch.norm(projected)

    # Remove the component along x, then restrict the result to the subspace.
    along_x = (x @ ev) * x
    v = (ev - along_x) @ projector.t()

    point = x
    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        # Re-wrap the previous tensor so the new case carries this manifold.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, v, ev, manifold)
def test_pickle2():
    """Round-trip a ``ManifoldParameter`` through ``torch.save`` / ``torch.load``.

    The loaded object must keep its class, stride, storage offset,
    ``requires_grad`` flag, values, and the exact type of its manifold.
    """
    import inspect

    t = torch.ones(10)
    p = geoopt.ManifoldParameter(t, manifold=geoopt.Sphere())

    # Newer PyTorch releases default ``torch.load`` to ``weights_only=True``,
    # which rejects custom pickled classes such as ManifoldParameter. Opt out
    # explicitly, but only when the installed torch accepts the keyword.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    with tempfile.TemporaryDirectory() as path:
        filename = os.path.join(path, "tens.t7")
        torch.save(p, filename)
        p1 = torch.load(filename, **load_kwargs)

    assert isinstance(p1, geoopt.ManifoldParameter)
    assert p.stride() == p1.stride()
    assert p.storage_offset() == p1.storage_offset()
    assert p.requires_grad == p1.requires_grad
    np.testing.assert_allclose(p.detach(), p1.detach())
    # Exact type identity: a base-class manifold on p1 must not pass.
    assert type(p1.manifold) is type(p.manifold)
def test_random_Sphere():
    """``Sphere.random_uniform`` samples lie on the sphere and reference it."""
    sphere = geoopt.Sphere()
    sample = sphere.random_uniform(3, 10, 10)
    sphere.assert_check_point_on_manifold(sample)
    # Identity check: the sample must carry this exact manifold instance.
    assert sample.manifold is sphere
def test_fails_Sphere():
    """``Sphere.random_uniform`` raises ValueError for size ``()`` and size ``1``."""
    for bad_size in ((), 1):
        with pytest.raises(ValueError):
            # Construction stays inside the context, as in each original check.
            manifold = geoopt.Sphere()
            manifold.random_uniform(bad_size)
def test_component_inner_product():
    """``component_inner`` on a product manifold has shape (batch, n_elements)."""
    batch = 5
    components = [(Sphere(), 10), (Sphere(), (3, 2)), (Euclidean(), ())]
    pman = ProductManifold(*components)
    parts = [
        Sphere().random_uniform(batch, 10),
        Sphere().random_uniform(batch, 3, 2),
        Euclidean().random_normal(batch),
    ]
    packed = pman.pack_point(*parts)
    # Project a random ambient vector onto the tangent space at `packed`.
    tangent = pman.proju(packed, torch.randn_like(packed))
    inner = pman.component_inner(packed, tangent)
    assert inner.shape == (batch, pman.n_elements)