Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_parallel_transport0_preserves_inner_products(a, c):
    """Check that ``parallel_transport0`` leaves ``inner`` unchanged.

    Two random vectors are created with the shape of ``a``, passed through
    ``poincare.math.parallel_transport0`` towards ``a``, and the value of
    ``poincare.math.inner`` at ``a`` is compared with its value at the zero
    point for the untransported pair.
    """
    # rand_like samples from [0, 1); the offset keeps every entry positive
    tangent_v = torch.rand_like(a) + 1e-5
    tangent_u = torch.rand_like(a) + 1e-5
    origin = torch.zeros_like(a)

    # move both vectors to the point ``a``
    moved_v = poincare.math.parallel_transport0(a, tangent_v, c=c)
    moved_u = poincare.math.parallel_transport0(a, tangent_u, c=c)

    # inner products before (at zero) and after (at ``a``) the transport
    inner_at_origin = poincare.math.inner(
        origin, tangent_v, tangent_u, c=c, keepdim=True
    )
    inner_at_a = poincare.math.inner(a, moved_v, moved_u, c=c, keepdim=True)

    np.testing.assert_allclose(
        inner_at_a, inner_at_origin, atol=1e-6, rtol=1e-6
    )
def a(seed, c):
    """Build a random ``(100, 10)`` tensor and project it with curvature ``c``.

    The sampling scheme depends on ``seed``:

    * ``seed`` equal to 30 or 35: a plain ``torch.randn`` draw.
    * any other ``seed``: a ``normal_(-1, 1)`` draw whose rows are rescaled
      to norm ``1 / 1.3`` and then multiplied by the square root of a random
      scale -- ``torch.rand_like(c) * c`` when ``seed > 35``, otherwise
      ``random.uniform(0, c)``.

    The result is always passed through ``poincare.math.project``.
    """
    if seed in {30, 35}:
        point = torch.randn(100, 10, dtype=c.dtype)
    else:
        point = torch.empty(100, 10, dtype=c.dtype).normal_(-1, 1)
        point /= point.norm(dim=-1, keepdim=True) * 1.3
        if seed > 35:
            # do not check numerically unstable regions
            # I've manually observed small differences there
            scale = torch.rand_like(c) * c
        else:
            scale = random.uniform(0, c)
        point *= scale ** 0.5
    return poincare.math.project(point, c=c)
@poincare.math.mobiusify
def matvec(x):
    """Multiply ``x`` by ``mat`` with its last two axes swapped.

    ``mat`` comes from the enclosing scope; the function itself is wrapped
    by ``poincare.math.mobiusify``.
    """
    mat_t = mat.transpose(-1, -2)
    return x @ mat_t
def closure():
    """Run one loss evaluation for the optimizer and return its value.

    Clears the gradients held by ``optim``, computes the squared
    ``poincare.math.dist`` between ``start`` and ``ideal`` (both taken from
    the enclosing scope), backpropagates, and returns ``loss.item()``.
    """
    optim.zero_grad()
    distance = geoopt.manifolds.poincare.math.dist(start, ideal)
    loss = distance ** 2
    loss.backward()
    return loss.item()
def expmap(
    self, x: torch.Tensor, u: torch.Tensor, *, project=True, dim=-1
) -> torch.Tensor:
    """Evaluate ``math.expmap`` for ``x`` and ``u`` with curvature ``self.c``.

    Parameters
    ----------
    x, u : torch.Tensor
        Forwarded unchanged to ``math.expmap``.
    project : bool
        When truthy, the result is additionally passed through
        ``math.project`` before being returned.
    dim : int
        Axis forwarded to both ``math.expmap`` and ``math.project``.
    """
    mapped = math.expmap(x, u, c=self.c, dim=dim)
    if not project:
        return mapped
    return math.project(mapped, c=self.c, dim=dim)
def expmap(
    self, x: torch.Tensor, u: torch.Tensor, *, project=True, dim=-1
) -> torch.Tensor:
    """Evaluate ``math.expmap`` for ``x`` and ``u`` with curvature ``self.c``.

    Parameters
    ----------
    x, u : torch.Tensor
        Forwarded unchanged to ``math.expmap``.
    project : bool
        When truthy, the result is additionally passed through
        ``math.project`` before being returned.
    dim : int
        Axis forwarded to both ``math.expmap`` and ``math.project``.
    """
    mapped = math.expmap(x, u, c=self.c, dim=dim)
    if not project:
        return mapped
    return math.project(mapped, c=self.c, dim=dim)
def logmap0(self, x: torch.Tensor, *, dim=-1) -> torch.Tensor:
    """Return ``math.logmap0`` of ``x`` with curvature ``self.c`` along ``dim``."""
    result = math.logmap0(x, c=self.c, dim=dim)
    return result
def expmap0(self, u: torch.Tensor, *, dim=-1, project=True) -> torch.Tensor:
    """Evaluate ``math.expmap0`` for ``u`` with curvature ``self.c``.

    Parameters
    ----------
    u : torch.Tensor
        Forwarded unchanged to ``math.expmap0``.
    dim : int
        Axis forwarded to both ``math.expmap0`` and ``math.project``.
    project : bool
        When truthy, the result is additionally passed through
        ``math.project`` before being returned.
    """
    mapped = math.expmap0(u, c=self.c, dim=dim)
    if not project:
        return mapped
    return math.project(mapped, c=self.c, dim=dim)
def _check_point_on_manifold(
    self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
) -> Tuple[bool, Optional[str]]:
    """Report whether ``x`` is unchanged by ``math.project``.

    ``x`` is projected with curvature ``self.c`` and compared with the
    projection via ``torch.allclose`` using the given tolerances.

    Returns
    -------
    Tuple[bool, Optional[str]]
        ``(True, None)`` when the two tensors are close, otherwise
        ``(False, message)`` with a description of the failure.
    """
    projected = math.project(x, c=self.c)
    ok = torch.allclose(x, projected, atol=atol, rtol=rtol)
    reason = (
        None
        if ok
        else "'x' norm lies out of the bounds [-1/sqrt(c)+eps, 1/sqrt(c)-eps]"
    )
    return ok, reason