Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_scale_poincare():
    """Compare origin distances on a Poincare ball and its ``Scaled`` wrapper.

    Two checks are made, both with ``atol=1e-5``:

    1. ``dist0(expmap0(vec))`` gives the same value on the plain ball and on
       ``geoopt.Scaled(ball, 2)``.
    2. On the scaled ball, that same distance matches
       ``norm(zeros_like(vec), vec)``.
    """
    plain = geoopt.PoincareBallExact()
    scaled = geoopt.Scaled(plain, 2)
    vec = torch.arange(10).float() / 10

    def origin_distance(manifold):
        # Map ``vec`` through expmap0, then measure how far the result is
        # from the origin, as a Python scalar.
        return manifold.dist0(manifold.expmap0(vec)).item()

    np.testing.assert_allclose(
        origin_distance(plain),
        origin_distance(scaled),
        atol=1e-5,
    )
    np.testing.assert_allclose(
        origin_distance(scaled),
        scaled.norm(torch.zeros_like(vec), vec),
        atol=1e-5,
    )
def test_rescaling_methods_accessible():
    """Smoke-test ``geodesic`` on a ball wrapped in ``Scaled`` twice.

    Nothing is asserted about the returned value; the test passes as long as
    the call on the doubly wrapped manifold does not raise.
    """
    rescaled = geoopt.Scaled(
        geoopt.Scaled(geoopt.PoincareBallExact(), 2),
        0.5,
    )
    start = torch.arange(10).float() / 10
    end = -torch.arange(10).float() / 10
    rescaled.geodesic(0.5, start, end)
def test_scaling_compensates():
    """Check that wrapping with ``Scaled(…, 2)`` then ``Scaled(…, 0.5)`` cancels.

    ``expmap0`` on the doubly wrapped ball must match ``expmap0`` on the
    plain ball (``assert_allclose`` with its default tolerances).
    """
    plain = geoopt.PoincareBallExact()
    restored = geoopt.Scaled(geoopt.Scaled(plain, 2), 0.5)
    vec = torch.arange(10).float() / 10

    expected = plain.expmap0(vec)
    actual = restored.expmap0(vec)
    np.testing.assert_allclose(expected, actual)
def test_tensor_is_attached():
    """Check ``is_attached`` for a point sampled from a nested ``Scaled``.

    A Euclidean manifold is wrapped in ``geoopt.Scaled`` twice (with the
    wrapper's default arguments); a point drawn via ``random(())`` from the
    outermost wrapper must be reported as attached to that wrapper.
    """
    manifold = geoopt.Euclidean()
    for _ in range(2):
        manifold = geoopt.Scaled(manifold)

    point = manifold.random(())
    assert manifold.is_attached(point)
def unary_case(unary_case_base, scaled):
    """Return the base case, optionally with its manifold wrapped in ``Scaled``.

    Parameters
    ----------
    unary_case_base
        Object exposing a ``manifold`` attribute and a ``_replace`` method
        (presumably a namedtuple -- confirm against its definition).
    scaled
        When truthy, the manifold is wrapped as ``geoopt.Scaled(manifold, 2)``.

    Returns
    -------
    The untouched ``unary_case_base`` when ``scaled`` is falsy, otherwise the
    result of ``_replace`` with the wrapped manifold.
    """
    if not scaled:
        return unary_case_base

    wrapped = geoopt.Scaled(unary_case_base.manifold, 2)
    return unary_case_base._replace(manifold=wrapped)
def test_scaling_compensates():
    """Check that scaling by 2 and then by 0.5 leaves ``expmap0`` unchanged.

    The result on the doubly wrapped ball is compared against the plain ball
    with ``np.testing.assert_allclose`` at its default tolerances.
    """
    reference = geoopt.PoincareBallExact()
    doubled = geoopt.Scaled(reference, 2)
    undone = geoopt.Scaled(doubled, 0.5)

    sample = torch.arange(10).float() / 10
    np.testing.assert_allclose(
        reference.expmap0(sample),
        undone.expmap0(sample),
    )
def test_ismanifold():
    """Exercise ``geoopt.ismanifold`` on plain, wrapped and invalid inputs.

    Covers:
    * a Euclidean manifold is recognised as ``geoopt.Euclidean``;
    * the same holds after wrapping it in ``geoopt.Scaled`` twice;
    * a ``cls`` argument of ``int`` or ``1`` raises ``TypeError``;
    * a non-manifold instance (``1``) yields a falsy result.
    """
    euclidean = geoopt.Euclidean()
    assert geoopt.ismanifold(euclidean, geoopt.Euclidean)

    wrapped = geoopt.Scaled(geoopt.Scaled(euclidean))
    assert geoopt.ismanifold(wrapped, geoopt.Euclidean)

    # First a class that is not a manifold type, then a non-class value.
    for bad_cls in (int, 1):
        with pytest.raises(TypeError):
            geoopt.ismanifold(wrapped, bad_cls)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
def test_rescaling_methods_accessible():
    """Verify ``geodesic`` is callable through two layers of ``Scaled``.

    The test only requires that the call completes without raising; the
    returned value is not inspected.
    """
    inner = geoopt.Scaled(geoopt.PoincareBallExact(), 2)
    outer = geoopt.Scaled(inner, 0.5)

    first = torch.arange(10).float() / 10
    second = -torch.arange(10).float() / 10
    outer.geodesic(0.5, first, second)
def test_ismanifold():
    """Check ``geoopt.ismanifold`` through ``Scaled`` wrappers and bad inputs.

    A Euclidean manifold must be recognised as ``geoopt.Euclidean`` both
    directly and after two layers of ``geoopt.Scaled``. Passing ``int`` or
    ``1`` as the class must raise ``TypeError``, and passing ``1`` as the
    instance must give a falsy result.
    """
    manifold = geoopt.Euclidean()
    assert geoopt.ismanifold(manifold, geoopt.Euclidean)

    for _ in range(2):
        manifold = geoopt.Scaled(manifold)
    assert geoopt.ismanifold(manifold, geoopt.Euclidean)

    invalid_classes = [int, 1]
    for candidate in invalid_classes:
        with pytest.raises(TypeError):
            geoopt.ismanifold(manifold, candidate)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
check if a given manifold is compatible with cls API
cls : type
manifold type
Returns
-------
bool
comparison result
"""
if not issubclass(cls, geoopt.manifolds.Manifold):
raise TypeError("`cls` should be a subclass of geoopt.manifolds.Manifold")
if not isinstance(instance, geoopt.manifolds.Manifold):
return False
else:
# this is the case to care about, Scaled class is a proxy, but fails instance checks
while isinstance(instance, geoopt.Scaled):
instance = instance.base
return isinstance(instance, cls)