Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
self.__class__,
self.storage(),
self.storage_offset(),
self.size(),
self.stride(),
self.requires_grad,
dict(),
)
return _rebuild_manifold_parameter, proto + (self.manifold,)
@insert_docs(Manifold.unpack_tensor.__doc__, r"\s+tensor : .+\n.+", "")
def unpack_tensor(self) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
    # The public docstring comes from ``insert_docs``. It is fed
    # ``Manifold.unpack_tensor.__doc__`` together with a regex/replacement
    # pair that appears intended to drop the ``tensor : ...`` parameter
    # entry, since here the tensor is ``self`` (confirm against insert_docs).
    # This method simply delegates to the attached manifold, passing itself.
    owning_manifold = self.manifold
    return owning_manifold.unpack_tensor(self)
class ManifoldParameter(ManifoldTensor, torch.nn.Parameter):
"""Same as :class:`torch.nn.Parameter` that has information about its manifold.
It should be used within :class:`torch.nn.Module` to be recognized
in parameter collection.
Other Parameters
----------------
manifold : :class:`geoopt.Manifold` (optional)
A manifold for the tensor if ``data`` is not a :class:`geoopt.ManifoldTensor`
"""
def __new__(cls, data=None, manifold=None, requires_grad=True):
if data is None:
data = ManifoldTensor(manifold=manifold or Euclidean())
elif not isinstance(data, ManifoldTensor):
data = ManifoldTensor(data, manifold=manifold or Euclidean())
self._step(p, self.state[p]["r"], group["epsilon"])
p.grad.zero_()
logp = closure()
logp.backward()
new_logp = logp.item()
new_H = -new_logp
with torch.no_grad():
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
if isinstance(p, (ManifoldParameter, ManifoldTensor)):
manifold = p.manifold
else:
manifold = Euclidean()
proju = manifold.proju
r = self.state[p]["r"]
r.add_(0.5 * epsilon * proju(p, p.grad))
p.grad.zero_()
new_H += 0.5 * (r * r).sum().item()
rho = min(1.0, math.exp(old_H - new_H))
if not self.burnin:
self.steps += 1
def __new__(cls, data=None, manifold=None, requires_grad=True):
    """Create a parameter instance of ``cls`` bound to a manifold.

    Parameters
    ----------
    data : tensor-like, ManifoldTensor or None
        Initial contents. ``None`` builds a ``ManifoldTensor`` with no data
        argument. Anything that is not already a ``ManifoldTensor`` is wrapped
        in one.
    manifold : Manifold or None
        Manifold used when ``data`` has to be wrapped; falls back to
        ``Euclidean()`` when falsy. If ``data`` is already a
        ``ManifoldTensor``, this must be ``None`` or compare equal to
        ``data.manifold``.
    requires_grad : bool
        Forwarded to ``_make_subclass``.

    Returns
    -------
    instance of ``cls`` whose ``manifold`` attribute is ``data.manifold``.

    Raises
    ------
    ValueError
        If ``data`` is a ``ManifoldTensor`` and an explicit ``manifold``
        compares unequal to ``data.manifold``.
    """
    if isinstance(data, ManifoldTensor):
        # A ManifoldTensor already carries its manifold; an explicit one
        # is only accepted when it agrees.
        if manifold is not None and data.manifold != manifold:
            raise ValueError(
                "Manifolds do not match: {}, {}".format(data.manifold, manifold)
            )
    else:
        target = manifold or Euclidean()
        if data is None:
            data = ManifoldTensor(manifold=target)
        else:
            data = ManifoldTensor(data, manifold=target)
    param = ManifoldTensor._make_subclass(cls, data, requires_grad)
    param.manifold = data.manifold
    return param
def __new__(cls, data=None, manifold=None, requires_grad=True):
    """Create a parameter instance of ``cls`` bound to a manifold.

    Parameters
    ----------
    data : tensor-like, ManifoldTensor or None
        Initial contents. ``None`` builds a ``ManifoldTensor`` with no data
        argument. Anything that is not already a ``ManifoldTensor`` is wrapped
        in one.
    manifold : Manifold or None
        Manifold used when ``data`` has to be wrapped; falls back to
        ``Euclidean()`` when falsy. If ``data`` is already a
        ``ManifoldTensor``, this must be ``None`` or compare equal to
        ``data.manifold``.
    requires_grad : bool
        Forwarded to ``_make_subclass``.

    Returns
    -------
    instance of ``cls`` whose ``manifold`` attribute is ``data.manifold``.

    Raises
    ------
    ValueError
        If ``data`` is a ``ManifoldTensor`` and an explicit ``manifold``
        compares unequal to ``data.manifold``.
    """
    if isinstance(data, ManifoldTensor):
        # A ManifoldTensor already carries its manifold; an explicit one
        # is only accepted when it agrees.
        if manifold is not None and data.manifold != manifold:
            raise ValueError(
                "Manifolds do not match: {}, {}".format(data.manifold, manifold)
            )
    else:
        target = manifold or Euclidean()
        if data is None:
            data = ManifoldTensor(manifold=target)
        else:
            data = ManifoldTensor(data, manifold=target)
    param = ManifoldTensor._make_subclass(cls, data, requires_grad)
    param.manifold = data.manifold
    return param
def __new__(cls, data=None, manifold=None, requires_grad=True):
    """Create a parameter instance of ``cls`` bound to a manifold.

    Parameters
    ----------
    data : tensor-like, ManifoldTensor or None
        Initial contents. ``None`` builds a ``ManifoldTensor`` with no data
        argument. Anything that is not already a ``ManifoldTensor`` is wrapped
        in one.
    manifold : Manifold or None
        Manifold used when ``data`` has to be wrapped; falls back to
        ``Euclidean()`` when falsy. If ``data`` is already a
        ``ManifoldTensor``, this must be ``None`` or compare equal to
        ``data.manifold``.
    requires_grad : bool
        Forwarded to ``_make_subclass``.

    Returns
    -------
    instance of ``cls`` whose ``manifold`` attribute is ``data.manifold``.

    Raises
    ------
    ValueError
        If ``data`` is a ``ManifoldTensor`` and an explicit ``manifold``
        compares unequal to ``data.manifold``.
    """
    if isinstance(data, ManifoldTensor):
        # A ManifoldTensor already carries its manifold; an explicit one
        # is only accepted when it agrees.
        if manifold is not None and data.manifold != manifold:
            raise ValueError(
                "Manifolds do not match: {}, {}".format(data.manifold, manifold)
            )
    else:
        target = manifold or Euclidean()
        if data is None:
            data = ManifoldTensor(manifold=target)
        else:
            data = ManifoldTensor(data, manifold=target)
    param = ManifoldTensor._make_subclass(cls, data, requires_grad)
    param.manifold = data.manifold
    return param
def stabilize_group(self, group):
    """Re-project every manifold parameter in ``group`` and its momentum buffer.

    Parameters that are not ``ManifoldParameter``/``ManifoldTensor`` are
    skipped. Each remaining parameter is replaced (via ``copy_or_set_``) by
    ``manifold.projx`` of itself. When ``group["momentum"] > 0`` and the
    optimizer state for the parameter holds a ``"momentum_buffer"``, that
    buffer is overwritten in place with ``manifold.proju(param, buffer)``.

    Parameters
    ----------
    group : dict
        Optimizer param group; ``"params"`` is read, and ``"momentum"`` is
        read for each manifold parameter.
    """
    for param in group["params"]:
        if not isinstance(param, (ManifoldParameter, ManifoldTensor)):
            continue
        space = param.manifold
        momentum = group["momentum"]
        copy_or_set_(param, space.projx(param))
        if not momentum > 0:
            continue
        state = self.state[param]
        # State may still be empty (e.g. the parameter never got a gradient),
        # and the buffer may not exist yet; nothing to re-project then.
        if not state or "momentum_buffer" not in state:
            continue
        buf = state["momentum_buffer"]
        buf.set_(space.proju(param, buf))
def __new__(cls, data=None, manifold=None, requires_grad=True):
    """Create a parameter instance of ``cls`` bound to a manifold.

    Parameters
    ----------
    data : tensor-like, ManifoldTensor or None
        Initial contents. ``None`` builds a ``ManifoldTensor`` with no data
        argument. Anything that is not already a ``ManifoldTensor`` is wrapped
        in one.
    manifold : Manifold or None
        Manifold used when ``data`` has to be wrapped; falls back to
        ``Euclidean()`` when falsy. If ``data`` is already a
        ``ManifoldTensor``, this must be ``None`` or compare equal to
        ``data.manifold``.
    requires_grad : bool
        Forwarded to ``_make_subclass``.

    Returns
    -------
    instance of ``cls`` whose ``manifold`` attribute is ``data.manifold``.

    Raises
    ------
    ValueError
        If ``data`` is a ``ManifoldTensor`` and an explicit ``manifold``
        compares unequal to ``data.manifold``.
    """
    if isinstance(data, ManifoldTensor):
        # A ManifoldTensor already carries its manifold; an explicit one
        # is only accepted when it agrees.
        if manifold is not None and data.manifold != manifold:
            raise ValueError(
                "Manifolds do not match: {}, {}".format(data.manifold, manifold)
            )
    else:
        target = manifold or Euclidean()
        if data is None:
            data = ManifoldTensor(manifold=target)
        else:
            data = ManifoldTensor(data, manifold=target)
    param = ManifoldTensor._make_subclass(cls, data, requires_grad)
    param.manifold = data.manifold
    return param
the desired shape
device : torch.device
the desired device
dtype : torch.dtype
the desired dtype
seed : int
ignored
Returns
-------
ManifoldTensor
"""
self._assert_check_shape(size2shape(*size), "x")
eye = torch.zeros(*size, dtype=dtype, device=device)
eye[..., torch.arange(eye.shape[-1]), torch.arange(eye.shape[-1])] += 1
return ManifoldTensor(eye, manifold=self)
def stabilize_group(self, group):
    """Re-project every manifold parameter in ``group`` and its cached ``old_p``.

    Parameters that are not ``ManifoldParameter``/``ManifoldTensor`` are
    skipped. Each remaining parameter is replaced (via ``copy_or_set_``) by
    its manifold's ``projx`` of itself. If the optimizer state for the
    parameter is non-empty, ``state["old_p"]`` is projected the same way.
    (``old_p`` presumably holds a previous parameter value; confirm against
    the step implementation.)

    Parameters
    ----------
    group : dict
        Optimizer param group; only ``"params"`` is read.
    """
    for param in group["params"]:
        if isinstance(param, (ManifoldParameter, ManifoldTensor)):
            projx = param.manifold.projx
            copy_or_set_(param, projx(param))
            state = self.state[param]
            # Empty state means there is no old_p to fix up (original note:
            # "due to None grads").
            if state:
                copy_or_set_(state["old_p"], projx(state["old_p"]))
def __init__(self, params, defaults):
super().__init__(params, defaults)
self.n_rejected = 0
self.steps = 0
self.burnin = True
self.log_probs = []
self.acceptance_probs = []
for group in self.param_groups:
for p in group["params"]:
if isinstance(p, (ManifoldParameter, ManifoldTensor)):
if not p.manifold.reversible:
raise ValueError(
"Sampling methods can't me applied to manifolds that "
"do not implement reversible retraction"