Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def sphere_subspace_case():
    """Yield UnaryCase fixtures for Sphere and SphereExact restricted to a subspace.

    A random 2-column float64 basis is orthonormalised with QR and turned into
    a projector. A random sample is projected onto that span and normalised to
    give the point. A second random sample has its component along the point
    removed and is then projected onto the span, giving the tangent vector.
    The same data is then wrapped for each manifold variant, in order:
    ``Sphere`` first, then ``SphereExact``.

    ``point @ ev`` is used as a scalar dot product, so the configured Sphere
    shape is presumably 1-D — confirm against ``manifold_shapes``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    # Draw order matters for reproducibility: subspace, then ex, then ev.
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace)
    projector = basis @ basis.t()
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    projected = ex @ projector.t()
    point = projected / torch.norm(projected)
    tangent = (ev - (point @ ev) * point) @ projector.t()
    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
def canonical_stiefel_case():
    """Yield a single UnaryCase for the CanonicalStiefel manifold.

    The point is the polar factor ``u @ v.T`` taken from the SVD of a random
    sample. The tangent vector is ``ev - point @ ev.T @ point``.

    NOTE(review): unlike the sibling cases, the samples here use torch's
    default dtype (float32 unless changed globally), not float64.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.CanonicalStiefel]
    ex = torch.randn(*shape)
    ev = torch.randn(*shape)
    left, _, right = torch.svd(ex)
    point = left @ right.t()
    tangent = ev - point @ ev.t() @ point
    manifold = geoopt.manifolds.CanonicalStiefel()
    yield UnaryCase(
        shape,
        geoopt.ManifoldTensor(point, manifold=manifold),
        ex,
        tangent,
        ev,
        manifold,
    )
def test_compare_manifolds():
    """Re-wrapping a ManifoldTensor with a different manifold must raise ValueError.

    The tensor is created on ``Euclidean()`` and then handed to
    ``ManifoldParameter`` together with ``Euclidean(ndim=1)``. The test
    expects the two manifolds to be treated as different, so the call fails
    with a "Manifolds do not match" message.
    """
    source_manifold = geoopt.Euclidean()
    target_manifold = geoopt.Euclidean(ndim=1)
    tensor = geoopt.ManifoldTensor(10, manifold=source_manifold)
    with pytest.raises(ValueError, match="Manifolds do not match"):
        geoopt.ManifoldParameter(tensor, manifold=target_manifold)
def euclidean_case():
    """Yield a single UnaryCase for a Euclidean manifold with ``ndim=1``.

    No projection is applied: the point and the tangent vector are plain
    copies of the two random float64 samples.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Euclidean]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    manifold = geoopt.Euclidean(ndim=1)
    point = geoopt.ManifoldTensor(ex.clone(), manifold=manifold)
    yield UnaryCase(shape, point, ex, ev.clone(), ev, manifold)
def poincare_case():
    """Yield UnaryCase fixtures for PoincareBall and PoincareBallExact.

    A random float64 sample is rescaled to norm ``tanh(||sample||)``, which is
    below 1. This presumably places the point inside the ball for the default
    curvature — confirm. The reference ``ex`` passed to the case is a copy of
    that rescaled point rather than the raw sample. The tangent vector is a
    copy of the second sample. Both manifolds are cast to float64 and yielded
    in order: ``PoincareBall`` first, then ``PoincareBallExact``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    raw = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    raw_norm = torch.norm(raw)
    point = torch.tanh(raw_norm) * raw / raw_norm
    ex = point.clone()
    tangent = ev.clone()
    for manifold_cls in (geoopt.PoincareBall, geoopt.PoincareBallExact):
        manifold = manifold_cls().to(dtype=torch.float64)
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
def sphere_case():
    """Yield UnaryCase fixtures for Sphere and SphereExact.

    A random float64 sample is normalised to unit norm to give the point. A
    second sample has its component along the point removed to give the
    tangent vector. The same data is wrapped for ``Sphere`` first, then
    ``SphereExact``.

    ``point @ ev`` is used as a scalar dot product, so the configured shape is
    presumably 1-D — confirm against ``manifold_shapes``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    point = ex / torch.norm(ex)
    tangent = ev - (point @ ev) * point
    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls()
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
def origin(
    self, *size, dtype=None, device=None, seed=42
) -> "geoopt.ManifoldTensor":
    """Return the origin of this product manifold as a packed tensor.

    The requested shape is validated first. Then each component manifold's
    ``origin`` is called with the shared leading (batch) dimensions followed
    by that component's own shape. The pieces are combined with
    ``pack_point``.

    Parameters
    ----------
    size : shape
        the desired shape; all dimensions except the last are treated as batch
        dimensions
    dtype : torch.dtype
        forwarded to each component's ``origin``
    device : torch.device
        forwarded to each component's ``origin``
    seed : int
        forwarded to each component's ``origin``

    Returns
    -------
    ManifoldTensor
        the packed origin point, attached to this manifold
    """
    full_shape = geoopt.utils.size2shape(*size)
    self._assert_check_shape(full_shape, "x")
    batch_dims = full_shape[:-1]
    components = [
        component.origin(
            batch_dims + component_shape, dtype=dtype, device=device, seed=seed
        )
        for component, component_shape in zip(self.manifolds, self.shapes)
    ]
    return geoopt.ManifoldTensor(self.pack_point(*components), manifold=self)
----------
size : shape
the desired shape
device : torch.device
the desired device
dtype : torch.dtype
the desired dtype
seed : int
ignored
Returns
-------
ManifoldTensor
the origin point on the manifold (an all-zeros tensor)
"""
return geoopt.ManifoldTensor(
torch.zeros(*size, dtype=dtype, device=device), manifold=self
)
standard deviation of the Normal distribution
device : torch.device
the desired device
dtype : torch.dtype
the desired dtype
Returns
-------
ManifoldTensor
random point on the manifold
"""
self._assert_check_shape(size2shape(*size), "x")
mean = torch.as_tensor(mean, device=device, dtype=dtype)
std = torch.as_tensor(std, device=device, dtype=dtype)
tens = std.new_empty(*size).normal_() * std + mean
return geoopt.ManifoldTensor(tens, manifold=self)