Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def sphere_subspace_case():
    """Yield two ``UnaryCase`` objects for a sphere intersected with a subspace.

    The torch RNG is seeded with 42, then a random two-column ``subspace``
    matrix is drawn and turned into the matrix ``Q @ Q.t()`` from its QR
    factor. A random point ``ex`` and a random vector ``ev`` are drawn;
    ``x`` is ``ex`` pushed through that matrix and divided by its norm,
    and ``v`` is ``ev`` with its component along ``x`` removed and then
    pushed through the same matrix.

    Yields:
        A ``UnaryCase`` built with ``geoopt.Sphere(intersection=subspace)``,
        followed by one built with
        ``geoopt.SphereExact(intersection=subspace)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)

    # NOTE(review): presumably the orthogonal projector onto the column span
    # of ``subspace`` -- that holds only if ``qr`` returns a reduced Q with
    # orthonormal columns; confirm against geoopt.linalg.batch_linalg.qr.
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace)
    projector = basis @ basis.t()
    projector_t = projector.t()

    # Draw order matters for reproducibility: ``ex`` first, then ``ev``.
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    projected_point = ex @ projector_t
    x = projected_point / torch.norm(projected_point)
    v = (ev - (x @ ev) * x) @ projector_t

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        # ``x`` is deliberately rebound: the second case wraps the
        # ManifoldTensor produced for the first one.
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Return the sign-adjusted Q factor of the QR decomposition of ``x + u``.

    Args:
        x: First summand passed to the QR decomposition.
        u: Second summand passed to the QR decomposition.

    Returns:
        The Q factor produced by ``linalg.batch_linalg.qr(x + u)``, with each
        column multiplied in place by the sign of the matching diagonal entry
        of R (a zero diagonal entry counts as ``+1``).
    """
    q_factor, r_factor = linalg.batch_linalg.qr(x + u)
    r_diag = linalg.batch_linalg.extract_diag(r_factor)
    # The inner sign() yields -1/0/+1; adding 0.5 before the outer sign()
    # turns a 0 into +1 while leaving -1 and +1 unchanged.
    column_signs = torch.sign(torch.sign(r_diag) + 0.5)
    # The inserted axis broadcasts over the second-to-last dimension, so
    # sign j scales every entry of column j. The multiply is in place.
    q_factor.mul_(column_signs[..., None, :])
    return q_factor