How to use the geoopt.tensor.ManifoldTensor class in geoopt

To help you get started, we’ve selected a few geoopt examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github geoopt / geoopt / geoopt / tensor.py View on Github external
self.__class__,
            self.storage(),
            self.storage_offset(),
            self.size(),
            self.stride(),
            self.requires_grad,
            dict(),
        )
        return _rebuild_manifold_parameter, proto + (self.manifold,)

    # NOTE(review): the docstring appears to be supplied by ``insert_docs``, which
    # receives ``Manifold.unpack_tensor``'s docstring, a regex matching the
    # ``tensor : ...`` parameter entry, and an empty replacement — confirm against
    # the ``insert_docs`` helper before adding a literal docstring here.
    @insert_docs(Manifold.unpack_tensor.__doc__, r"\s+tensor : .+\n.+", "")
    def unpack_tensor(self) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
        # Delegate to the manifold attached to this tensor, passing the tensor itself.
        return self.manifold.unpack_tensor(self)


class ManifoldParameter(ManifoldTensor, torch.nn.Parameter):
    """Same as :class:`torch.nn.Parameter` that has information about its manifold.

    It should be used within :class:`torch.nn.Module` to be recognized
    in parameter collection.

    Other Parameters
    ----------------
    manifold : :class:`geoopt.Manifold` (optional)
        A manifold for the tensor if ``data`` is not a :class:`geoopt.ManifoldTensor`
    """

    def __new__(cls, data=None, manifold=None, requires_grad=True):
        if data is None:
            data = ManifoldTensor(manifold=manifold or Euclidean())
        elif not isinstance(data, ManifoldTensor):
            data = ManifoldTensor(data, manifold=manifold or Euclidean())
github geoopt / geoopt / geoopt / samplers / samplers.py View on Github external
self._step(p, self.state[p]["r"], group["epsilon"])
                        p.grad.zero_()

        logp = closure()
        logp.backward()

        new_logp = logp.item()
        new_H = -new_logp
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue

                    if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                        manifold = p.manifold
                    else:
                        manifold = Euclidean()

                    proju = manifold.proju

                    r = self.state[p]["r"]
                    r.add_(0.5 * epsilon * proju(p, p.grad))
                    p.grad.zero_()

                    new_H += 0.5 * (r * r).sum().item()

            rho = min(1.0, math.exp(old_H - new_H))

            if not self.burnin:
                self.steps += 1
github geoopt / geoopt / geoopt / tensor.py View on Github external
def __new__(cls, data=None, manifold=None, requires_grad=True):
        """Build an instance of ``cls`` that carries the manifold of ``data``.

        Parameters
        ----------
        data : optional
            Source values. Anything that is not already a ``ManifoldTensor``
            (including ``None``) is wrapped in one first.
        manifold : optional
            Manifold used when wrapping ``data``; ``Euclidean()`` is substituted
            when this is falsy. If ``data`` is already a ``ManifoldTensor`` the
            argument is only checked against ``data.manifold``.
        requires_grad : bool
            Forwarded to ``ManifoldTensor._make_subclass``.

        Raises
        ------
        ValueError
            If ``data`` is a ``ManifoldTensor`` and a non-``None`` ``manifold``
            compares unequal to ``data.manifold``.
        """
        if isinstance(data, ManifoldTensor):
            # Already manifold-aware: only verify that an explicit manifold agrees.
            mismatch = manifold is not None and data.manifold != manifold
            if mismatch:
                raise ValueError(
                    "Manifolds do not match: {}, {}".format(data.manifold, manifold)
                )
        elif data is None:
            data = ManifoldTensor(manifold=manifold or Euclidean())
        else:
            data = ManifoldTensor(data, manifold=manifold or Euclidean())
        result = ManifoldTensor._make_subclass(cls, data, requires_grad)
        # The new object takes its manifold from the (possibly wrapped) data.
        result.manifold = data.manifold
        return result
github geoopt / geoopt / geoopt / tensor.py View on Github external
def __new__(cls, data=None, manifold=None, requires_grad=True):
        """Build an instance of ``cls`` that carries the manifold of ``data``.

        Parameters
        ----------
        data : optional
            Source values. Anything that is not already a ``ManifoldTensor``
            (including ``None``) is wrapped in one first.
        manifold : optional
            Manifold used when wrapping ``data``; ``Euclidean()`` is substituted
            when this is falsy. If ``data`` is already a ``ManifoldTensor`` the
            argument is only checked against ``data.manifold``.
        requires_grad : bool
            Forwarded to ``ManifoldTensor._make_subclass``.

        Raises
        ------
        ValueError
            If ``data`` is a ``ManifoldTensor`` and a non-``None`` ``manifold``
            compares unequal to ``data.manifold``.
        """
        if isinstance(data, ManifoldTensor):
            # Already manifold-aware: only verify that an explicit manifold agrees.
            mismatch = manifold is not None and data.manifold != manifold
            if mismatch:
                raise ValueError(
                    "Manifolds do not match: {}, {}".format(data.manifold, manifold)
                )
        elif data is None:
            data = ManifoldTensor(manifold=manifold or Euclidean())
        else:
            data = ManifoldTensor(data, manifold=manifold or Euclidean())
        result = ManifoldTensor._make_subclass(cls, data, requires_grad)
        # The new object takes its manifold from the (possibly wrapped) data.
        result.manifold = data.manifold
        return result
github geoopt / geoopt / geoopt / tensor.py View on Github external
def __new__(cls, data=None, manifold=None, requires_grad=True):
        """Build an instance of ``cls`` that carries the manifold of ``data``.

        Parameters
        ----------
        data : optional
            Source values. Anything that is not already a ``ManifoldTensor``
            (including ``None``) is wrapped in one first.
        manifold : optional
            Manifold used when wrapping ``data``; ``Euclidean()`` is substituted
            when this is falsy. If ``data`` is already a ``ManifoldTensor`` the
            argument is only checked against ``data.manifold``.
        requires_grad : bool
            Forwarded to ``ManifoldTensor._make_subclass``.

        Raises
        ------
        ValueError
            If ``data`` is a ``ManifoldTensor`` and a non-``None`` ``manifold``
            compares unequal to ``data.manifold``.
        """
        if isinstance(data, ManifoldTensor):
            # Already manifold-aware: only verify that an explicit manifold agrees.
            mismatch = manifold is not None and data.manifold != manifold
            if mismatch:
                raise ValueError(
                    "Manifolds do not match: {}, {}".format(data.manifold, manifold)
                )
        elif data is None:
            data = ManifoldTensor(manifold=manifold or Euclidean())
        else:
            data = ManifoldTensor(data, manifold=manifold or Euclidean())
        result = ManifoldTensor._make_subclass(cls, data, requires_grad)
        # The new object takes its manifold from the (possibly wrapped) data.
        result.manifold = data.manifold
        return result
github geoopt / geoopt / geoopt / optim / rsgd.py View on Github external
def stabilize_group(self, group):
        """Re-project the manifold parameters of ``group`` with ``projx``.

        Entries of ``group["params"]`` that are neither ``ManifoldParameter`` nor
        ``ManifoldTensor`` instances are left untouched. When
        ``group["momentum"]`` is positive and a ``"momentum_buffer"`` is stored
        for the parameter, that buffer is overwritten in place with
        ``manifold.proju(p, buf)``.
        """
        for param in group["params"]:
            if isinstance(param, (ManifoldParameter, ManifoldTensor)):
                manifold = param.manifold
                momentum = group["momentum"]
                copy_or_set_(param, manifold.projx(param))
                if momentum > 0:
                    state = self.state[param]
                    # An empty state is expected for parameters whose grads
                    # were None; there is no buffer to re-project for them.
                    if state and "momentum_buffer" in state:
                        stored = state["momentum_buffer"]
                        stored.set_(manifold.proju(param, stored))
github geoopt / geoopt / geoopt / tensor.py View on Github external
def __new__(cls, data=None, manifold=None, requires_grad=True):
        """Build an instance of ``cls`` that carries the manifold of ``data``.

        Parameters
        ----------
        data : optional
            Source values. Anything that is not already a ``ManifoldTensor``
            (including ``None``) is wrapped in one first.
        manifold : optional
            Manifold used when wrapping ``data``; ``Euclidean()`` is substituted
            when this is falsy. If ``data`` is already a ``ManifoldTensor`` the
            argument is only checked against ``data.manifold``.
        requires_grad : bool
            Forwarded to ``ManifoldTensor._make_subclass``.

        Raises
        ------
        ValueError
            If ``data`` is a ``ManifoldTensor`` and a non-``None`` ``manifold``
            compares unequal to ``data.manifold``.
        """
        if isinstance(data, ManifoldTensor):
            # Already manifold-aware: only verify that an explicit manifold agrees.
            mismatch = manifold is not None and data.manifold != manifold
            if mismatch:
                raise ValueError(
                    "Manifolds do not match: {}, {}".format(data.manifold, manifold)
                )
        elif data is None:
            data = ManifoldTensor(manifold=manifold or Euclidean())
        else:
            data = ManifoldTensor(data, manifold=manifold or Euclidean())
        result = ManifoldTensor._make_subclass(cls, data, requires_grad)
        # The new object takes its manifold from the (possibly wrapped) data.
        result.manifold = data.manifold
        return result
github geoopt / geoopt / geoopt / manifolds / stiefel.py View on Github external
the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored

        Returns
        -------
        ManifoldTensor
        """
        self._assert_check_shape(size2shape(*size), "x")
        eye = torch.zeros(*size, dtype=dtype, device=device)
        eye[..., torch.arange(eye.shape[-1]), torch.arange(eye.shape[-1])] += 1
        return ManifoldTensor(eye, manifold=self)
github geoopt / geoopt / geoopt / samplers / rhmc.py View on Github external
def stabilize_group(self, group):
        """Re-project manifold parameters and their stored ``"old_p"`` with ``projx``.

        Entries of ``group["params"]`` that are neither ``ManifoldParameter`` nor
        ``ManifoldTensor`` instances are left untouched. For the others the
        parameter itself is re-projected, and, when a non-empty state exists
        for it, so is the ``"old_p"`` entry of that state.
        """
        for param in group["params"]:
            if isinstance(param, (ManifoldParameter, ManifoldTensor)):
                manifold = param.manifold
                copy_or_set_(param, manifold.projx(param))
                param_state = self.state[param]
                # An empty state is expected for parameters whose grads were
                # None; nothing further is stored for them.
                if param_state:
                    previous = param_state["old_p"]
                    copy_or_set_(previous, manifold.projx(previous))
github geoopt / geoopt / geoopt / samplers / base.py View on Github external
def __init__(self, params, defaults):
        super().__init__(params, defaults)
        self.n_rejected = 0
        self.steps = 0
        self.burnin = True

        self.log_probs = []
        self.acceptance_probs = []
        for group in self.param_groups:
            for p in group["params"]:
                if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                    if not p.manifold.reversible:
                        raise ValueError(
                            "Sampling methods can't me applied to manifolds that "
                            "do not implement reversible retraction"