How to use the geoopt.utils.broadcast_shapes function in geoopt

To help you get started, we’ve selected a few geoopt examples based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: geoopt/geoopt — geoopt/manifolds/product.py (view on GitHub)
def component_inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None
    ) -> torch.Tensor:
        """Compute component-wise inner products on every factor manifold.

        For each factor, the matching slices of ``x``, ``u`` and (if given)
        ``v`` are handed to that factor's own ``component_inner``. Each
        result is expanded to the common broadcast shape of all operands,
        and the pieces are packed back into a single tensor.

        Parameters
        ----------
        x : torch.Tensor
            point on the product manifold
        u : torch.Tensor
            first tangent vector
        v : torch.Tensor, optional
            second tangent vector; when omitted, ``None`` is forwarded to
            each factor's ``component_inner``

        Returns
        -------
        torch.Tensor
            packed per-factor component-wise inner products
        """
        products = []
        for i, manifold in enumerate(self.manifolds):
            point = self.take_submanifold_value(x, i)
            u_vec = self.take_submanifold_value(u, i)
            # Gather the shape of every operand that takes part in the
            # product. Previously ``v`` was left out, so a ``v`` with a
            # larger batch shape made the ``expand`` below fail.
            shapes = [point.shape, u_vec.shape]
            if v is not None:
                v_vec = self.take_submanifold_value(v, i)
                shapes.append(v_vec.shape)
            else:
                v_vec = None
            target_shape = geoopt.utils.broadcast_shapes(*shapes)
            inner = manifold.component_inner(point, u_vec, v_vec)
            inner = inner.expand(target_shape)
            products.append(inner)
        return self.pack_point(*products)
GitHub: geoopt/geoopt — geoopt/manifolds/euclidean.py (view on GitHub)
def inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False
    ) -> torch.Tensor:
        """Inner product of ``u`` and ``v``, reduced over the trailing ``self.ndim`` dims.

        When ``v`` is omitted, ``u`` squared is used, giving the squared norm.
        No reduction happens when ``self.ndim`` is not positive. The result is
        expanded so that it also broadcasts against the batch part of ``x``.
        """
        prod = u.pow(2) if v is None else u * v
        n = self.ndim
        if not n > 0:
            # Nothing to reduce: align the element-wise product with x directly.
            return prod.expand(broadcast_shapes(x.shape, prod.shape))
        prod = prod.sum(dim=tuple(range(-n, 0)), keepdim=keepdim)
        # Drop x's trailing event dims, or replace them with singletons under keepdim.
        batch_shape = x.shape[:-n] + (1,) * n * keepdim
        return prod.expand(broadcast_shapes(batch_shape, prod.shape))
GitHub: geoopt/geoopt — geoopt/manifolds/euclidean.py (view on GitHub)
def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Transport ``v`` from ``x`` to ``y``.

        The values of ``v`` are returned unchanged; the only effect is that
        ``v`` is expanded to the common broadcast shape of ``x``, ``y`` and ``v``.
        """
        return v.expand(broadcast_shapes(x.shape, y.shape, v.shape))
GitHub: geoopt/geoopt — geoopt/manifolds/poincare/__init__.py (view on GitHub)
def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        The values of ``u`` are left untouched; ``u`` is only expanded to the
        broadcast shape of ``x`` and ``u``.
        """
        common = broadcast_shapes(x.shape, u.shape)
        projected = u.expand(common)
        return projected
GitHub: geoopt/geoopt — geoopt/manifolds/euclidean.py (view on GitHub)
def egrad2rgrad(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Convert a Euclidean gradient ``u`` at ``x`` into a Riemannian gradient.

        The gradient values pass through as-is; the result is merely expanded
        to the broadcast shape of ``x`` and ``u``.
        """
        return u.expand(broadcast_shapes(x.shape, u.shape))
GitHub: geoopt/geoopt — geoopt/manifolds/sphere.py (view on GitHub)
def inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None, *, keepdim=False
    ) -> torch.Tensor:
        """Inner product of ``u`` and ``v`` taken over the last dimension.

        ``v`` defaults to ``u``, which yields the squared norm. The reduced
        result is expanded to broadcast against the leading dimensions of
        ``x``, with a trailing singleton dimension when ``keepdim`` is true.
        """
        other = u if v is None else v
        reduced = (u * other).sum(-1, keepdim=keepdim)
        batch = x.shape[:-1] + (1,) * keepdim
        return reduced.expand(broadcast_shapes(batch, reduced.shape))
GitHub: geoopt/geoopt — geoopt/manifolds/euclidean.py (view on GitHub)
def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        No values of ``u`` are modified; ``u`` is returned expanded to the
        broadcast shape of ``x`` and ``u``.
        """
        return u.expand(broadcast_shapes(x.shape, u.shape))
GitHub: geoopt/geoopt — geoopt/manifolds/euclidean.py (view on GitHub)
def component_inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor = None
    ) -> torch.Tensor:
        """Element-wise (unreduced) inner product of ``u`` and ``v``.

        Returns ``u`` squared when ``v`` is omitted and ``u * v`` otherwise. No
        sum is taken; the product is expanded to the broadcast shape of ``x``
        and the product itself.
        """
        # Keep one term per coordinate instead of summing them.
        elementwise = u * v if v is not None else u.pow(2)
        return elementwise.expand(broadcast_shapes(x.shape, elementwise.shape))