Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def b(seed, c):
    """Return 100 random 10-dimensional points passed through ``poincare.math.project``.

    How the points are sampled depends on ``seed``:

    * ``seed`` in ``{30, 35}``: raw ``torch.randn`` samples, projected
      without any prior rescaling.
    * any other seed: samples drawn with ``normal_(-1, 1)`` (mean -1, std 1).
      Each row is rescaled to norm ``1 / 1.3``, then multiplied by
      ``sqrt(u * c)`` with ``u`` in ``[0, 1)``. ``u`` is drawn with
      ``torch.rand_like`` when ``seed > 35`` and with Python's ``random``
      module otherwise.

    ``c`` must be a tensor: its ``dtype`` is used and it is passed to
    ``torch.rand_like``. It is presumably a scalar curvature (TODO: confirm).
    This looks like a pytest fixture with ``seed`` and ``c`` injected
    (verify against the test module).
    """
    rows, cols = 100, 10
    if seed in {30, 35}:
        # No rescaling: the raw Gaussian samples go straight to projection.
        sample = torch.randn(rows, cols, dtype=c.dtype)
        return poincare.math.project(sample, c=c)

    # Decide the RNG source up front, matching the original if/elif order.
    use_torch_rng = seed > 35
    sample = torch.empty(rows, cols, dtype=c.dtype).normal_(-1, 1)
    # Give every row the same length, 1 / 1.3.
    sample /= sample.norm(dim=-1, keepdim=True) * 1.3
    # Shrink by sqrt(u * c) with u in [0, 1). Only the RNG source differs.
    if use_torch_rng:
        scale = (torch.rand_like(c) * c) ** 0.5
    else:
        scale = random.uniform(0, c) ** 0.5
    sample *= scale
    return poincare.math.project(sample, c=c)
def geodesic_unit(
    self, t: torch.Tensor, x: torch.Tensor, u: torch.Tensor, *, dim=-1, project=True
) -> torch.Tensor:
    """Evaluate ``math.geodesic_unit`` with this object's ``c``.

    Parameters
    ----------
    t, x, u : torch.Tensor
        Forwarded unchanged to ``math.geodesic_unit``. They are presumably
        the time, start point and direction (TODO: confirm against ``math``).
    dim : int
        Forwarded as ``dim`` to ``math`` (default ``-1``).
    project : bool
        If truthy, the result is passed through ``math.project``.

    Returns
    -------
    torch.Tensor
        The (optionally projected) result, computed with ``c=self.c``.
    """
    point = math.geodesic_unit(t, x, u, c=self.c, dim=dim)
    return math.project(point, c=self.c, dim=dim) if project else point
def mobius_scalar_mul(
    self, r: torch.Tensor, x: torch.Tensor, *, dim=-1, project=True
) -> torch.Tensor:
    """Apply ``math.mobius_scalar_mul`` to ``r`` and ``x`` with this object's ``c``.

    Parameters
    ----------
    r, x : torch.Tensor
        Forwarded unchanged to ``math.mobius_scalar_mul``. ``r`` is presumably
        the scalar factor and ``x`` the point (TODO: confirm).
    dim : int
        Forwarded as ``dim`` to ``math`` (default ``-1``).
    project : bool
        If truthy, the result is passed through ``math.project``.

    Returns
    -------
    torch.Tensor
        The (optionally projected) result, computed with ``c=self.c``.
    """
    scaled = math.mobius_scalar_mul(r, x, c=self.c, dim=dim)
    if not project:
        return scaled
    return math.project(scaled, c=self.c, dim=dim)
def retr(self, x: torch.Tensor, u: torch.Tensor, *, dim=-1) -> torch.Tensor:
    """Move from ``x`` by plain addition of ``u``, then project.

    No exponential map is involved. The update is ``x + u`` followed by
    ``math.project`` with ``c=self.c``, so the caller is trusted to have
    already scaled ``u`` appropriately.
    """
    return math.project(x + u, c=self.c, dim=dim)
def mobius_matvec(
    self, m: torch.Tensor, x: torch.Tensor, *, dim=-1, project=True
) -> torch.Tensor:
    """Apply ``math.mobius_matvec`` to ``m`` and ``x`` with this object's ``c``.

    Parameters
    ----------
    m, x : torch.Tensor
        Forwarded unchanged to ``math.mobius_matvec``. ``m`` is presumably
        the matrix and ``x`` the point(s) (TODO: confirm).
    dim : int
        Forwarded as ``dim`` to ``math`` (default ``-1``).
    project : bool
        If truthy, the result is passed through ``math.project``.

    Returns
    -------
    torch.Tensor
        The (optionally projected) result, computed with ``c=self.c``.
    """
    product = math.mobius_matvec(m, x, c=self.c, dim=dim)
    if project:
        product = math.project(product, c=self.c, dim=dim)
    return product
def expmap0(self, u: torch.Tensor, *, dim=-1, project=True) -> torch.Tensor:
    """Apply ``math.expmap0`` to ``u`` with this object's ``c``.

    Parameters
    ----------
    u : torch.Tensor
        Forwarded unchanged to ``math.expmap0``.
    dim : int
        Forwarded as ``dim`` to ``math`` (default ``-1``).
    project : bool
        If truthy, the result is passed through ``math.project``.

    Returns
    -------
    torch.Tensor
        The (optionally projected) result, computed with ``c=self.c``.
    """
    mapped = math.expmap0(u, c=self.c, dim=dim)
    return mapped if not project else math.project(mapped, c=self.c, dim=dim)
def projx(self, x: torch.Tensor, dim=-1) -> torch.Tensor:
    """Return ``math.project(x)`` computed with this object's ``c``.

    NOTE(review): unlike the sibling methods, ``dim`` here is
    positional-or-keyword rather than keyword-only. It is kept that way so
    existing callers that pass it positionally keep working.
    """
    projected = math.project(x, c=self.c, dim=dim)
    return projected
def mobius_pointwise_mul(
    self, w: torch.Tensor, x: torch.Tensor, *, dim=-1, project=True
) -> torch.Tensor:
    """Apply ``math.mobius_pointwise_mul`` to ``w`` and ``x`` with this object's ``c``.

    Parameters
    ----------
    w, x : torch.Tensor
        Forwarded unchanged to ``math.mobius_pointwise_mul``. ``w`` is
        presumably the element-wise weights and ``x`` the point(s)
        (TODO: confirm).
    dim : int
        Forwarded as ``dim`` to ``math`` (default ``-1``).
    project : bool
        If truthy, the result is passed through ``math.project``.

    Returns
    -------
    torch.Tensor
        The (optionally projected) result, computed with ``c=self.c``.
    """
    weighted = math.mobius_pointwise_mul(w, x, c=self.c, dim=dim)
    if project:
        return math.project(weighted, c=self.c, dim=dim)
    return weighted