Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def component_inner(
    self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None
) -> torch.Tensor:
    """Compute the componentwise inner product on the product manifold.

    Each factor manifold in ``self.manifolds`` evaluates its own
    ``component_inner`` on its slice of ``x``, ``u`` and (optionally) ``v``.
    Slices are obtained with ``self.take_submanifold_value``. The per-factor
    results are broadcast to a common shape and joined back together with
    ``self.pack_point``.

    Parameters
    ----------
    x : torch.Tensor
        point on the product manifold
    u : torch.Tensor
        first tangent vector at ``x``
    v : torch.Tensor, optional
        second tangent vector at ``x``; when ``None``, every factor receives
        ``None`` and decides on its own what to compute

    Returns
    -------
    torch.Tensor
        packed per-factor componentwise inner products
    """
    products = []
    for i, manifold in enumerate(self.manifolds):
        point = self.take_submanifold_value(x, i)
        u_vec = self.take_submanifold_value(u, i)
        shapes = [point.shape, u_vec.shape]
        if v is not None:
            v_vec = self.take_submanifold_value(v, i)
            # v must take part in the target shape. Otherwise a v that
            # broadcasts to a larger shape than (point, u) produces a larger
            # per-factor result, and the expand below cannot shrink it.
            shapes.append(v_vec.shape)
        else:
            v_vec = None
        inner = manifold.component_inner(point, u_vec, v_vec)
        target_shape = geoopt.utils.broadcast_shapes(*shapes)
        products.append(inner.expand(target_shape))
    return self.pack_point(*products)
def inner(
    self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False
) -> torch.Tensor:
    """Inner product of ``u`` and ``v`` (or squared norm of ``u``) at ``x``.

    The elementwise product is summed over the trailing ``self.ndim`` axes.
    When ``self.ndim`` is 0, no reduction happens. The result is then
    broadcast against the batch shape of ``x``.

    Parameters
    ----------
    x : torch.Tensor
        point; only its shape is used
    u : torch.Tensor
        first tangent vector
    v : torch.Tensor, optional
        second tangent vector; ``u`` is squared when omitted
    keepdim : bool
        keep the reduced trailing axes as size-1 dimensions

    Returns
    -------
    torch.Tensor
        inner product expanded to the broadcast shape
    """
    product = u.pow(2) if v is None else u * v
    batch_shape = x.shape
    if self.ndim > 0:
        reduced_axes = tuple(range(-self.ndim, 0))
        product = product.sum(dim=reduced_axes, keepdim=keepdim)
        # With keepdim the reduced axes survive as size-1 dims, so the
        # batch shape of x must carry matching trailing ones.
        batch_shape = x.shape[: -self.ndim] + (1,) * (self.ndim * keepdim)
    return product.expand(broadcast_shapes(batch_shape, product.shape))
def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Transport ``v`` from ``x`` to ``y``.

    Here transport is the identity: ``v`` is only broadcast (via ``expand``)
    to the common shape of ``x``, ``y`` and ``v``.
    """
    return v.expand(broadcast_shapes(x.shape, y.shape, v.shape))
def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Project ``u`` onto the tangent space at ``x``.

    The projection is the identity here; ``u`` is only broadcast (via
    ``expand``) to the common shape of ``x`` and ``u``.
    """
    return u.expand(broadcast_shapes(x.shape, u.shape))
def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Map the gradient ``u`` at ``x`` to the Riemannian gradient.

    The mapping is the identity here; ``u`` is only broadcast (via
    ``expand``) to the common shape of ``x`` and ``u``.
    """
    return u.expand(broadcast_shapes(x.shape, u.shape))
def inner(
    self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False
) -> torch.Tensor:
    """Inner product of ``u`` and ``v`` over the last axis, at point ``x``.

    When ``v`` is omitted, ``u`` is paired with itself (its squared norm).
    The reduced result is broadcast against ``x`` with its last axis
    removed, or kept as size 1 when ``keepdim`` is set.

    Parameters
    ----------
    x : torch.Tensor
        point; only its shape is used
    u : torch.Tensor
        first tangent vector
    v : torch.Tensor, optional
        second tangent vector; defaults to ``u``
    keepdim : bool
        keep the reduced last axis as a size-1 dimension

    Returns
    -------
    torch.Tensor
        inner product expanded to the broadcast shape
    """
    other = u if v is None else v
    reduced = (u * other).sum(-1, keepdim=keepdim)
    batch_shape = x.shape[:-1] + (1,) * keepdim
    return reduced.expand(broadcast_shapes(batch_shape, reduced.shape))
def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """Project ``u`` onto the tangent space at ``x``.

    No projection is applied; ``u`` is returned broadcast (via ``expand``)
    to the shape shared by ``x`` and ``u``.
    """
    shared = broadcast_shapes(x.shape, u.shape)
    projected = u.expand(shared)
    return projected
def component_inner(
    self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None
) -> torch.Tensor:
    """Componentwise (unreduced) inner product of ``u`` and ``v`` at ``x``.

    Returns ``u * v`` elementwise, or ``u ** 2`` when ``v`` is omitted,
    broadcast against the shape of ``x``.

    Parameters
    ----------
    x : torch.Tensor
        point; only its shape is used
    u : torch.Tensor
        first tangent vector
    v : torch.Tensor, optional
        second tangent vector; ``u`` is squared when omitted

    Returns
    -------
    torch.Tensor
        elementwise product expanded to the broadcast shape
    """
    # Per the original author's note, the manifold can be factorized, so no
    # reduction over any axis is applied here.
    elementwise = u * v if v is not None else u.pow(2)
    return elementwise.expand(broadcast_shapes(x.shape, elementwise.shape))