How to use the `geoopt.utils.copy_or_set_` function in geoopt

To help you get started, we’ve selected a few geoopt examples based on popular ways `copy_or_set_` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: geoopt/geoopt — `geoopt/samplers/rhmc.py` (View on GitHub)
def _step(self, p, r, epsilon):
        """
        Advance parameter ``p`` and momentum ``r`` by one integration step.

        The momentum is first incremented by ``epsilon * egrad2rgrad(p, p.grad)``.
        The point is then moved along ``r * epsilon`` with ``retr_transp``, which
        also transports ``r``. ``p`` is updated through ``copy_or_set_`` and ``r``
        through ``set_``.

        Parameters
        ----------
        p : tensor
            parameter being sampled; its own manifold is used when it is a
            ManifoldParameter/ManifoldTensor, otherwise the sampler's default
        r : tensor
            momentum paired with ``p``
        epsilon
            step size multiplier
        """
        has_own_manifold = isinstance(p, (ManifoldParameter, ManifoldTensor))
        manifold = p.manifold if has_own_manifold else self._default_manifold

        r.add_(epsilon * manifold.egrad2rgrad(p, p.grad))
        new_point, new_momentum = manifold.retr_transp(p, r * epsilon, r)
        copy_or_set_(p, new_point)
        r.set_(new_momentum)
GitHub: geoopt/geoopt — `geoopt/samplers/sgrhmc.py` (View on GitHub)
# Excerpt from sgrhmc.py (presumably the sampler's step method); the enclosing
# function, H_new and logp are defined outside this excerpt.
for group in self.param_groups:
                    for p in group["params"]:
                        # Manifold-aware tensors carry their own manifold; plain
                        # tensors fall back to the sampler's default manifold.
                        if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                            manifold = p.manifold
                        else:
                            manifold = self._default_manifold

                        egrad2rgrad = manifold.egrad2rgrad
                        retr_transp = manifold.retr_transp

                        epsilon, alpha = group["epsilon"], group["alpha"]

                        # Per-parameter velocity buffer kept in optimizer state.
                        v = self.state[p]["v"]

                        # Move p along v and transport v to the new point;
                        # copy_or_set_ writes p_ back into p (presumably in place,
                        # given the trailing underscore).
                        p_, v_ = retr_transp(p, v, v)
                        copy_or_set_(p, p_)
                        v.set_(v_)

                        # Gaussian noise passed through egrad2rgrad at the new p.
                        n = egrad2rgrad(p, torch.randn_like(v))
                        # v <- (1 - alpha) * v + epsilon * p.grad
                        #      + sqrt(2 * alpha * epsilon) * n
                        # NOTE(review): p.grad is used as-is here while the noise
                        # goes through egrad2rgrad — confirm this is intended.
                        v.mul_(1 - alpha).add_(epsilon * p.grad).add_(
                            math.sqrt(2 * alpha * epsilon) * n
                        )
                        p.grad.zero_()

                        # Accumulate 0.5 * |v / epsilon|^2 into H_new
                        # (presumably the kinetic-energy term of the Hamiltonian).
                        r = v / epsilon
                        H_new += 0.5 * (r * r).sum().item()

        # Outside burn-in, count the step and record logp.
        if not self.burnin:
            self.steps += 1
            self.log_probs.append(logp)
GitHub: geoopt/geoopt — `geoopt/tensor.py` (View on GitHub)
def proj_(self) -> torch.Tensor:
        """
        Project this tensor onto its manifold in place.

        The projection is computed with ``self.manifold.projx`` and written
        back into this tensor with ``copy_or_set_``.

        Returns
        -------
        tensor
            same instance
        """
        projected = self.manifold.projx(self)
        return copy_or_set_(self, projected)
GitHub: geoopt/geoopt — `geoopt/samplers/rhmc.py` (View on GitHub)
def stabilize_group(self, group):
        """
        Re-project manifold parameters of ``group`` (and their saved copies).

        Every ManifoldParameter/ManifoldTensor in ``group["params"]`` is
        replaced by ``manifold.projx`` of itself. When the parameter has sampler
        state, its ``state["old_p"]`` tensor is re-projected the same way.
        Parameters with empty state (due to None grads) only get the first
        projection. Plain tensors are ignored.
        """
        manifold_params = (
            param
            for param in group["params"]
            if isinstance(param, (ManifoldParameter, ManifoldTensor))
        )
        for param in manifold_params:
            manifold = param.manifold
            copy_or_set_(param, manifold.projx(param))
            param_state = self.state[param]
            if param_state:
                previous_point = param_state["old_p"]
                copy_or_set_(previous_point, manifold.projx(previous_point))
GitHub: geoopt/geoopt — `geoopt/samplers/rsgld.py` (View on GitHub)
# Excerpt from rsgld.py (presumably the sampler's step method); the enclosing
# function and logp are defined outside this excerpt.
# Parameter updates run under torch.no_grad() so autograd does not record them.
with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    # Manifold-aware tensors carry their own manifold; plain
                    # tensors fall back to the sampler's default manifold.
                    if isinstance(p, (ManifoldParameter, ManifoldTensor)):
                        manifold = p.manifold
                    else:
                        manifold = self._default_manifold

                    egrad2rgrad, retr = manifold.egrad2rgrad, manifold.retr
                    epsilon = group["epsilon"]

                    # Gaussian noise scaled by sqrt(epsilon).
                    n = torch.randn_like(p).mul_(math.sqrt(epsilon))
                    # Langevin-style direction 0.5 * epsilon * grad + noise,
                    # passed through egrad2rgrad at p.
                    r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)
                    # use copy only for user facing point
                    # (retract p along r and write the result back into p)
                    copy_or_set_(p, retr(p, r))
                    p.grad.zero_()

        # Outside burn-in, count the step and record logp as a Python number.
        if not self.burnin:
            self.steps += 1
            self.log_probs.append(logp.item())
GitHub: geoopt/geoopt — `geoopt/samplers/rsgld.py` (View on GitHub)
def stabilize_group(self, group):
        """
        Re-project every manifold parameter of ``group`` onto its manifold.

        Each ManifoldParameter/ManifoldTensor in ``group["params"]`` is replaced
        by ``manifold.projx`` of itself via ``copy_or_set_``; other tensors are
        left untouched.
        """
        for param in group["params"]:
            if isinstance(param, (ManifoldParameter, ManifoldTensor)):
                projected = param.manifold.projx(param)
                copy_or_set_(param, projected)
GitHub: geoopt/geoopt — `geoopt/optim/rsgd.py` (View on GitHub)
# Excerpt from rsgd.py (the momentum / retraction part of the optimizer step);
# the enclosing condition and names such as point, grad, manifold, momentum,
# dampening, nesterov, learning_rate and loss are defined outside this excerpt.
momentum_buffer = state["momentum_buffer"]
                        # buf <- momentum * buf + (1 - dampening) * grad
                        # NOTE(review): add_(Number, Tensor) is a deprecated
                        # PyTorch overload; the modern spelling is
                        # add_(grad, alpha=1 - dampening) — confirm against the
                        # targeted torch version.
                        momentum_buffer.mul_(momentum).add_(1 - dampening, grad)
                        if nesterov:
                            # grad <- grad + momentum * buf (same deprecated
                            # overload; modern form add_(momentum_buffer,
                            # alpha=momentum)). add_ mutates grad in place — if
                            # grad aliases point.grad, that is modified too;
                            # NOTE(review): confirm.
                            grad = grad.add_(momentum, momentum_buffer)
                        else:
                            grad = momentum_buffer
                        # we have all the things projected
                        # Retract point along -learning_rate * grad and
                        # transport the momentum buffer in the same call.
                        new_point, new_momentum_buffer = manifold.retr_transp(
                            point, -learning_rate * grad, momentum_buffer
                        )
                        momentum_buffer.set_(new_momentum_buffer)
                        # use copy only for user facing point
                        copy_or_set_(point, new_point)
                    else:
                        # Branch without a momentum buffer: plain retraction.
                        new_point = manifold.retr(point, -learning_rate * grad)
                        copy_or_set_(point, new_point)

                    group["step"] += 1
                # Every self._stabilize steps, re-project the group's params.
                if self._stabilize is not None and group["step"] % self._stabilize == 0:
                    self.stabilize_group(group)
        return loss
GitHub: geoopt/geoopt — `geoopt/optim/radam.py` (View on GitHub)
def stabilize_group(self, group):
        """
        Re-project manifold parameters of ``group`` and their ``exp_avg`` buffers.

        For every ManifoldParameter/ManifoldTensor with optimizer state, the
        point is replaced by ``manifold.projx`` of itself. The ``exp_avg`` buffer
        is then replaced by ``manifold.proju`` evaluated at the updated point.
        Parameters with empty state (due to None grads) and plain tensors are
        skipped.
        """
        for param in group["params"]:
            if isinstance(param, (ManifoldParameter, ManifoldTensor)):
                param_state = self.state[param]
                if param_state:
                    manifold = param.manifold
                    avg_buffer = param_state["exp_avg"]
                    copy_or_set_(param, manifold.projx(param))
                    avg_buffer.set_(manifold.proju(param, avg_buffer))