Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def sphere_compliment_case():
    """Yield one test case per sphere manifold class, sharing a random ``complement``.

    The RNG is seeded with 42, so the generated data is reproducible. A random
    ``complement`` column is QR-factorised and used to build
    ``proj = I - Q @ Q^T`` (a projector onto the orthogonal complement of span(Q),
    assuming ``Q`` from ``batch_linalg.qr`` has orthonormal columns). A random
    point is projected with it and normalised, and a random vector is made
    orthogonal to that point and projected as well.

    The same data is then wrapped twice: once for ``geoopt.Sphere`` and once
    for ``geoopt.SphereExact``. Each is constructed with the same ``complement``.

    NOTE: the misspelling "compliment" in the name is kept; renaming it would
    break any fixture/parametrize references elsewhere.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)

    basis, _ = geoopt.linalg.batch_linalg.qr(complement)
    proj = -basis @ basis.transpose(-1, -2)
    # Add 1 to every diagonal entry in place (diagonal() is a view of proj).
    proj.diagonal(dim1=-2, dim2=-1).add_(1)

    # Both randn draws must happen in this order to keep the seeded data stable.
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    proj_t = proj.t()
    projected = ex @ proj_t
    x = projected / torch.norm(projected)
    # Remove the component of ev along x, then project into the complement.
    v = (ev - (x @ ev) * x) @ proj_t

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(complement=complement)
        # x is rebound each iteration, exactly like the original sequential code.
        x = geoopt.ManifoldTensor(x, manifold=manifold)
        yield UnaryCase(shape, x, ex, v, ev, manifold)
def test_qr(A):
    """Check that ``geoopt.linalg.qr`` on a batch matches ``torch.qr`` per matrix.

    ``A`` is presumably a pytest fixture providing a batch of matrices indexed
    along its first dimension (TODO confirm against the fixture definition).
    """
    q_batch, r_batch = geoopt.linalg.qr(A)
    with torch.no_grad():
        q_ref_batch = q_batch.detach()
        r_ref_batch = r_batch.detach()
        for idx, matrix in enumerate(A):
            # NOTE(review): torch.qr is deprecated in newer PyTorch in favour of
            # torch.linalg.qr; kept here to match the rest of this file.
            q_expected, r_expected = torch.qr(matrix)
            np.testing.assert_allclose(q_ref_batch[idx], q_expected.detach())
            np.testing.assert_allclose(r_ref_batch[idx], r_expected.detach())
def test_svd(A):
    """Compare batched ``geoopt.linalg.svd`` to per-matrix ``torch.svd``, then backprop.

    After the numeric comparison, ``u.sum().backward()`` is called to check that
    gradients can flow through the batched SVD. This is done outside
    ``no_grad`` so the graph is still attached.
    """
    u, d, v = geoopt.linalg.svd(A)
    with torch.no_grad():
        u_det, d_det, v_det = u.detach(), d.detach(), v.detach()
        for idx, matrix in enumerate(A):
            u_expected, d_expected, v_expected = torch.svd(matrix)
            np.testing.assert_allclose(u_det[idx], u_expected.detach())
            np.testing.assert_allclose(d_det[idx], d_expected.detach())
            np.testing.assert_allclose(v_det[idx], v_expected.detach())
    u.sum().backward()  # this should work
def _configure_manifold_intersection(self, intersection: torch.Tensor):
    """Register ``Q @ Q^T`` as the ``projector`` buffer, with Q from QR of ``intersection``.

    Assuming ``Q`` has orthonormal columns, this projects onto span(intersection).
    """
    basis = geoopt.linalg.batch_linalg.qr(intersection)[0]
    projector = basis @ basis.transpose(-1, -2)
    self.register_buffer("projector", projector)
def _configure_manifold_complement(self, complement: torch.Tensor):
    """Register ``I - Q @ Q^T`` as the ``projector`` buffer, with Q from QR of ``complement``.

    Assuming ``Q`` has orthonormal columns, this projects onto the orthogonal
    complement of span(complement). Works on batched inputs because only the
    last two dimensions are touched.
    """
    basis = geoopt.linalg.batch_linalg.qr(complement)[0]
    projector = -basis @ basis.transpose(-1, -2)
    # diagonal() returns a view, so add_ writes the +1 straight into projector.
    projector.diagonal(dim1=-2, dim2=-1).add_(1)
    self.register_buffer("projector", projector)
def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Return ``u - x @ sym(x^T @ u)`` over the last two dimensions.

    ``sym`` is ``linalg.batch_linalg.sym``; presumably it takes the symmetric
    part of its argument (verify against its definition).
    """
    x_t_u = x.transpose(-1, -2) @ u
    correction = x @ linalg.batch_linalg.sym(x_t_u)
    return u - correction
def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Evaluate ``[x, u] @ expm([[A, -S], [I, A]]) @ [expm(-A); 0]``.

    Here ``A = x^T u`` and ``S = u^T u``. The block matrix is assembled with
    ``linalg.block_matrix``, and the matrix exponentials come from ``linalg.expm``.
    This appears to be the closed-form Stiefel geodesic under the Euclidean
    metric (Edelman, Arias & Smith) — NOTE(review): confirm against the class docs.
    """
    a = x.transpose(-1, -2) @ u
    s = u.transpose(-1, -2) @ u
    ident = torch.zeros_like(s)
    # Turn the zero matrix into an identity by writing 1s on its diagonal view.
    ident.diagonal(dim1=-2, dim2=-1).add_(1)
    exp_generator = linalg.expm(linalg.block_matrix(((a, -s), (ident, a))))
    # Stack expm(-A) on top of a zero block along the row axis.
    tail = torch.cat((linalg.expm(-a), torch.zeros_like(s)), dim=-2)
    stacked = torch.cat((x, u), dim=-1)
    return stacked @ exp_generator @ tail
def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Evaluate ``[x, u] @ expm([[A, -S], [I, A]]) @ [expm(-A); 0]``.

    Here ``A = x^T u`` and ``S = u^T u``. The block matrix is assembled with
    ``linalg.block_matrix``, and the matrix exponentials come from ``linalg.expm``.
    This appears to be the closed-form Stiefel geodesic under the Euclidean
    metric (Edelman, Arias & Smith) — NOTE(review): confirm against the class docs.
    """
    a = x.transpose(-1, -2) @ u
    s = u.transpose(-1, -2) @ u
    ident = torch.zeros_like(s)
    # Turn the zero matrix into an identity by writing 1s on its diagonal view.
    ident.diagonal(dim1=-2, dim2=-1).add_(1)
    exp_generator = linalg.expm(linalg.block_matrix(((a, -s), (ident, a))))
    # Stack expm(-A) on top of a zero block along the row axis.
    tail = torch.cat((linalg.expm(-a), torch.zeros_like(s)), dim=-2)
    stacked = torch.cat((x, u), dim=-1)
    return stacked @ exp_generator @ tail
----------
size : shape
the desired output shape
dtype : torch.dtype
desired dtype
device : torch.device
desired device
Returns
-------
ManifoldTensor
random point on Stiefel manifold
"""
self._assert_check_shape(size2shape(*size), "x")
tens = torch.randn(*size, device=device, dtype=dtype)
return ManifoldTensor(linalg.qr(tens)[0], manifold=self)
def proju(self, x, u):
# takes batch data
# batch_size, n, _ = x.shape
x_shape = x.shape
x = x.reshape(-1, x_shape[-2], x_shape[-1])
batch_size, n = x.shape[0:2]
e = torch.ones(batch_size, n, 1)
I = torch.unsqueeze(torch.eye(x.shape[-1]), 0).repeat(batch_size, 1, 1)
mu = x * u
A = linalg.block_matrix([[I, x], [torch.transpose(x, 1, 2), I]])
B = A[:, :, 1:]
b = torch.cat(
[
torch.sum(mu, dim=2, keepdim=True),
torch.transpose(torch.sum(mu, dim=1, keepdim=True), 1, 2),
],
dim=1,
)
zeta, _ = torch.solve(
B.transpose(1, 2) @ (b - A[:, :, 0:1]), B.transpose(1, 2) @ B
)
alpha = torch.cat([torch.ones(batch_size, 1, 1), zeta[:, 0 : n - 1]], dim=1)
beta = zeta[:, n - 1 : 2 * n - 1]
rgrad = mu - (alpha @ e.transpose(1, 2) + e @ beta.transpose(1, 2)) * x