Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _step(self, p, r, epsilon):
    """
    Update point ``p`` and its companion tensor ``r`` in place.

    ``r`` is first incremented by ``epsilon * egrad2rgrad(p, p.grad)``.
    Then ``retr_transp(p, r * epsilon, r)`` produces a new point and a
    new ``r``, which are written back into the tensors that were passed in.

    Parameters
    ----------
    p : tensor
        point to move; the manifold is ``p.manifold`` when ``p`` is a
        ``ManifoldParameter`` or ``ManifoldTensor``, otherwise
        ``self._default_manifold``
    r : tensor
        modified in place (presumably the momentum paired with ``p`` --
        confirm against the caller)
    epsilon : scalar
        factor applied to the gradient term and to the displacement
        ``r * epsilon``
    """
    on_manifold = isinstance(p, (ManifoldParameter, ManifoldTensor))
    geometry = p.manifold if on_manifold else self._default_manifold
    to_rgrad, move_and_carry = geometry.egrad2rgrad, geometry.retr_transp

    # gradient contribution, accumulated into r in place
    r.add_(epsilon * to_rgrad(p, p.grad))

    # r is passed twice: scaled as the displacement, unscaled as the
    # tensor to be carried along to the new point
    moved_point, carried_r = move_and_carry(p, r * epsilon, r)
    copy_or_set_(p, moved_point)
    r.set_(carried_r)
# NOTE(review): this span is a fragment. The enclosing `def` is not visible,
# `H_new` and `logp` are defined outside the visible lines, and the original
# indentation has been lost, so the nesting described below is inferred.
#
# For every parameter of every param group: select the manifold, move the
# parameter along its stored tensor `v`, refresh `v` with decay, gradient and
# noise terms, clear the gradient, and add a quadratic term to `H_new`.
for group in self.param_groups:
for p in group["params"]:
# Tensors without their own manifold use the optimizer-wide default.
if isinstance(p, (ManifoldParameter, ManifoldTensor)):
manifold = p.manifold
else:
manifold = self._default_manifold
egrad2rgrad = manifold.egrad2rgrad
retr_transp = manifold.retr_transp
epsilon, alpha = group["epsilon"], group["alpha"]
v = self.state[p]["v"]
# `v` is passed both as the displacement and as the tensor to carry along;
# the returned point and tensor are written back in place.
p_, v_ = retr_transp(p, v, v)
copy_or_set_(p, p_)
v.set_(v_)
# Standard-normal noise shaped like `v`, passed through egrad2rgrad at `p`.
n = egrad2rgrad(p, torch.randn_like(v))
# In place: v <- (1 - alpha) * v + epsilon * p.grad + sqrt(2 * alpha * epsilon) * n
# NOTE(review): p.grad is added as-is here, while the noise goes through
# egrad2rgrad -- confirm this asymmetry is intended.
v.mul_(1 - alpha).add_(epsilon * p.grad).add_(
math.sqrt(2 * alpha * epsilon) * n
)
p.grad.zero_()
# Adds 0.5 * sum(r * r), with r = v / epsilon, to the running total H_new.
r = v / epsilon
H_new += 0.5 * (r * r).sum().item()
# Outside burn-in, count the step and record `logp`.
# NOTE(review): whether this runs per parameter or once after the loops
# cannot be determined from the flattened text.
if not self.burnin:
self.steps += 1
self.log_probs.append(logp)
def proj_(self) -> torch.Tensor:
    """
    Project this tensor onto its manifold, modifying it in place.

    The projected value comes from ``self.manifold.projx(self)`` and is
    written back through ``copy_or_set_``.

    Returns
    -------
    tensor
        same instance
    """
    projected = self.manifold.projx(self)
    return copy_or_set_(self, projected)
def stabilize_group(self, group):
    """
    Re-project the manifold-valued parameters of ``group`` in place.

    Each ``ManifoldParameter`` / ``ManifoldTensor`` in ``group["params"]``
    is replaced by ``manifold.projx`` of itself. When the optimizer holds
    state for that parameter, the stored ``"old_p"`` tensor gets the same
    treatment. Other parameters are left untouched.

    Parameters
    ----------
    group : dict
        parameter group; only its ``"params"`` entry is read here
    """
    for param in group["params"]:
        if isinstance(param, (ManifoldParameter, ManifoldTensor)):
            copy_or_set_(param, param.manifold.projx(param))
            param_state = self.state[param]
            # empty state: due to None grads
            if param_state:
                previous_point = param_state["old_p"]
                copy_or_set_(
                    previous_point, param.manifold.projx(previous_point)
                )
# NOTE(review): this span is a fragment. The enclosing `def` is not visible,
# `logp` is defined outside the visible lines, and the original indentation
# has been lost, so the nesting described below is inferred.
#
# With autograd recording disabled, every parameter of every param group is
# moved by a retraction along a direction built from its gradient plus
# Gaussian noise; its gradient is then cleared.
with torch.no_grad():
for group in self.param_groups:
for p in group["params"]:
# Tensors without their own manifold use the optimizer-wide default.
if isinstance(p, (ManifoldParameter, ManifoldTensor)):
manifold = p.manifold
else:
manifold = self._default_manifold
egrad2rgrad, retr = manifold.egrad2rgrad, manifold.retr
epsilon = group["epsilon"]
# Standard-normal noise shaped like `p`, scaled in place by sqrt(epsilon).
n = torch.randn_like(p).mul_(math.sqrt(epsilon))
# Direction: egrad2rgrad applied at `p` to (0.5 * epsilon * p.grad + n).
r = egrad2rgrad(p, 0.5 * epsilon * p.grad + n)
# use copy only for user facing point
copy_or_set_(p, retr(p, r))
p.grad.zero_()
# Outside burn-in, count the step and record logp as a Python number.
# NOTE(review): whether this runs per parameter or once after the loops
# cannot be determined from the flattened text.
if not self.burnin:
self.steps += 1
self.log_probs.append(logp.item())
def stabilize_group(self, group):
    """
    Re-project the manifold-valued parameters of ``group`` in place.

    Each ``ManifoldParameter`` / ``ManifoldTensor`` in ``group["params"]``
    is overwritten with ``manifold.projx`` of itself; every other
    parameter is skipped.

    Parameters
    ----------
    group : dict
        parameter group; only its ``"params"`` entry is read here
    """
    for param in group["params"]:
        if isinstance(param, (ManifoldParameter, ManifoldTensor)):
            projected = param.manifold.projx(param)
            copy_or_set_(param, projected)
# NOTE(review): this span is a fragment. It starts inside a branch whose `if`
# is not visible (see the `else:` further down), and `state`, `momentum`,
# `dampening`, `nesterov`, `grad`, `point`, `manifold`, `learning_rate` and
# `loss` are all defined outside the visible lines. Indentation has been
# lost, so the nesting described below is inferred.
#
# Momentum branch: in place, buffer <- momentum * buffer + (1 - dampening) * grad.
# NOTE(review): `add_(scalar, tensor)` is the positional overload that newer
# PyTorch releases deprecate in favour of `add_(tensor, alpha=scalar)` --
# confirm against the supported torch version.
momentum_buffer = state["momentum_buffer"]
momentum_buffer.mul_(momentum).add_(1 - dampening, grad)
# Nesterov: grad <- grad + momentum * buffer (in place, same overload as above);
# otherwise the buffer itself is used as the update direction.
if nesterov:
grad = grad.add_(momentum, momentum_buffer)
else:
grad = momentum_buffer
# we have all the things projected
# One call returns both the new point and the buffer carried to that point.
new_point, new_momentum_buffer = manifold.retr_transp(
point, -learning_rate * grad, momentum_buffer
)
momentum_buffer.set_(new_momentum_buffer)
# use copy only for user facing point
copy_or_set_(point, new_point)
# Branch without a momentum buffer: plain retraction along -learning_rate * grad.
else:
new_point = manifold.retr(point, -learning_rate * grad)
copy_or_set_(point, new_point)
# Count the step for this group; every `self._stabilize` steps (when it is
# not None) the group is passed to stabilize_group.
group["step"] += 1
if self._stabilize is not None and group["step"] % self._stabilize == 0:
self.stabilize_group(group)
return loss
def stabilize_group(self, group):
    """
    Re-project manifold-valued parameters and their ``exp_avg`` state.

    For each ``ManifoldParameter`` / ``ManifoldTensor`` in
    ``group["params"]`` that has optimizer state, the parameter is
    overwritten with ``manifold.projx`` of itself. The stored ``"exp_avg"``
    tensor is then overwritten with ``manifold.proju(p, exp_avg)``, using
    the already re-projected parameter. Parameters of other types, or with
    empty state, are skipped.

    Parameters
    ----------
    group : dict
        parameter group; only its ``"params"`` entry is read here
    """
    for param in group["params"]:
        if isinstance(param, (ManifoldParameter, ManifoldTensor)):
            param_state = self.state[param]
            # empty state: due to None grads
            if param_state:
                geometry = param.manifold
                avg_buffer = param_state["exp_avg"]
                copy_or_set_(param, geometry.projx(param))
                avg_buffer.set_(geometry.proju(param, avg_buffer))