Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_compare_manifolds():
    """Wrapping a ManifoldTensor in a ManifoldParameter with a mismatched manifold must fail.

    The tensor lives on ``Euclidean()`` while the parameter is requested on
    ``Euclidean(ndim=1)``; construction is expected to raise ``ValueError``
    whose message matches "Manifolds do not match".
    """
    tensor_manifold = geoopt.Euclidean()
    param_manifold = geoopt.Euclidean(ndim=1)
    source = geoopt.ManifoldTensor(10, manifold=tensor_manifold)
    # `match=` applies the same regex search as ExceptionInfo.match().
    with pytest.raises(ValueError, match="Manifolds do not match"):
        geoopt.ManifoldParameter(source, manifold=param_manifold)
def test_init_manifold():
    """One RiemannianSGD step with ``stabilize=1`` re-projects the Stiefel parameter.

    The Euclidean parameter must come out unchanged. The Stiefel parameter must
    change, stay contiguous, and equal ``stiefel.projx`` of its previous value.
    """
    torch.manual_seed(42)
    stiefel = geoopt.manifolds.Stiefel()
    euclidean = geoopt.manifolds.Euclidean()
    init_stiefel = torch.randn(10, 10)
    init_euclid = torch.randn(10, 10)
    with torch.no_grad():
        on_stiefel = geoopt.ManifoldParameter(init_stiefel, manifold=stiefel).proj_()
    in_euclid = geoopt.ManifoldParameter(init_euclid, manifold=euclidean)

    # Gradients are zero, so the change asserted below for the Stiefel point is
    # expected to come from the stabilize=1 re-projection, not a gradient step.
    for param in (on_stiefel, in_euclid):
        param.grad = torch.zeros_like(param)
    stiefel_before = on_stiefel.clone()
    euclid_before = in_euclid.clone()

    optimizer = geoopt.optim.RiemannianSGD([on_stiefel, in_euclid], lr=1, stabilize=1)
    optimizer.zero_grad()
    optimizer.step()

    assert not np.allclose(on_stiefel.data, stiefel_before.data)
    assert on_stiefel.is_contiguous()
    np.testing.assert_allclose(in_euclid.data, euclid_before.data)
    np.testing.assert_allclose(
        on_stiefel.data, stiefel.projx(stiefel_before.data), atol=1e-4
    )
def __init__(self, mu, sigma):
    """Hold a ``Normal(mu, sigma)`` distribution and a Stiefel parameter shaped like ``mu``.

    NOTE(review): the enclosing class is outside this view; ``super().__init__()``
    presumably initialises a ``torch.nn.Module`` base — confirm.
    """
    super().__init__()
    self.d = torch.distributions.Normal(mu, sigma)
    # Random start with the same shape as ``mu``; drawn before the manifold is
    # built, matching the original argument evaluation order.
    start = torch.randn_like(mu)
    self.x = geoopt.ManifoldParameter(start, manifold=geoopt.Stiefel())
def test_rsgd_stiefel(params):
    """RiemannianSGD must move a Stiefel point onto a fixed Stiefel target.

    ``params`` holds keyword arguments for ``RiemannianSGD``. Presumably it is
    supplied by a pytest parametrization outside this view — confirm.

    The loss is the squared distance to ``Xstar`` plus a heavy penalty on
    violating ``X^T X = I``. The test fails if ``X`` does not reach ``Xstar``
    within 1e-5 in at most 10000 steps.
    """
    stiefel = geoopt.manifolds.Stiefel()
    torch.manual_seed(42)
    with torch.no_grad():
        X = geoopt.ManifoldParameter(torch.randn(20, 10), manifold=stiefel).proj_()
    Xstar = torch.randn(20, 10)
    Xstar.set_(stiefel.projx(Xstar))

    def closure():
        optim.zero_grad()
        loss = (X - Xstar).pow(2).sum()
        # manifold constraint that makes optimization hard if violated
        loss += (X.t() @ X - torch.eye(X.shape[1])).pow(2).sum() * 100
        loss.backward()
        return loss.item()

    optim = geoopt.optim.RiemannianSGD([X], **params)
    assert (X - Xstar).norm() > 1e-5
    for _ in range(10000):
        if (X - Xstar).norm() < 1e-5:
            break
        # Without this call the loop never updates X and `closure` is dead code.
        optim.step(closure)
    assert X.is_contiguous()
    # Same tolerance as the early-exit criterion above.
    np.testing.assert_allclose(X.data, Xstar, atol=1e-5)
def test_pickle3():
    """Round-trip a ManifoldParameter on a span-intersected Sphere via torch.save/torch.load.

    The reloaded object must still be a ManifoldParameter with identical
    layout, grad flag, values, manifold type and manifold projector.
    """
    ones = torch.ones(10)
    span = torch.randn(10, 2)
    sphere = geoopt.manifolds.Sphere(intersection=span)
    original = geoopt.ManifoldParameter(ones, manifold=sphere)
    with tempfile.TemporaryDirectory() as tmpdir:
        target = os.path.join(tmpdir, "tens.t7")
        torch.save(original, target)
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
        # which may refuse to unpickle a ManifoldParameter — confirm torch version.
        restored = torch.load(target)
        assert isinstance(restored, geoopt.ManifoldParameter)
        assert original.stride() == restored.stride()
        assert original.storage_offset() == restored.storage_offset()
        assert original.requires_grad == restored.requires_grad
        np.testing.assert_allclose(original.detach(), restored.detach())
        assert isinstance(original.manifold, type(restored.manifold))
        np.testing.assert_allclose(
            original.manifold.projector, restored.manifold.projector
        )
def test_adam_stiefel(params):
    """RiemannianAdam (stabilize=4500) must move a Stiefel point onto a fixed Stiefel target.

    ``params`` holds extra keyword arguments for ``RiemannianAdam``. Presumably
    it is supplied by a pytest parametrization outside this view — confirm.

    The loss is the squared distance to ``Xstar`` plus a heavy penalty on
    violating ``X^T X = I``. The test fails if ``X`` does not reach ``Xstar``
    within 1e-5 in at most 10000 steps.
    """
    stiefel = geoopt.manifolds.Stiefel()
    torch.manual_seed(42)
    with torch.no_grad():
        X = geoopt.ManifoldParameter(torch.randn(20, 10), manifold=stiefel).proj_()
    Xstar = torch.randn(20, 10)
    Xstar.set_(stiefel.projx(Xstar))

    def closure():
        optim.zero_grad()
        loss = (X - Xstar).pow(2).sum()
        # manifold constraint that makes optimization hard if violated
        loss += (X.t() @ X - torch.eye(X.shape[1])).pow(2).sum() * 100
        loss.backward()
        return loss.item()

    optim = geoopt.optim.RiemannianAdam([X], stabilize=4500, **params)
    assert (X - Xstar).norm() > 1e-5
    for _ in range(10000):
        if (X - Xstar).norm() < 1e-5:
            break
        # Without this call the loop never updates X and `closure` is dead code.
        optim.step(closure)
    assert X.is_contiguous()
    # Same tolerance as the early-exit criterion above.
    np.testing.assert_allclose(X.data, Xstar, atol=1e-5)