Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def logpdf(self, x):
    """Compute the log-pdf.

    The value is ``-(log|var| + dim * log(2 pi) + q) / 2``, where ``q`` is
    what ``B.iqf_diag(var, uprank(x) - mean)`` returns for the centred
    input.

    Args:
        x (input): Values to compute the log-pdf of.

    Returns:
        list[tensor]: Log-pdf for every input in `x`. If it can be
            determined that the list contains only a single log-pdf,
            then the list is flattened to a scalar.
    """
    dtype = self.dtype
    log_det = B.logdet(self.var)
    normaliser = B.cast(dtype, self.dim) * B.cast(dtype, B.log_2_pi)
    quad = B.iqf_diag(self.var, uprank(x) - self.mean)
    logpdfs = -(log_det + normaliser + quad) / 2

    # A result of shape (1,) holds a single log-pdf; unwrap it to a scalar.
    if B.shape(logpdfs) == (1,):
        return logpdfs[0]
    return logpdfs
def entropy(self):
    """Compute the entropy.

    Evaluates ``(log|var| + dim * (log(2 pi) + 1)) / 2``.

    Returns:
        scalar: The entropy.
    """
    dtype = self.dtype
    log_det = B.logdet(self.var)
    dims = B.cast(dtype, self.dim)
    per_dim = B.cast(dtype, B.log_2_pi + 1)
    return (log_det + dims * per_dim) / 2
def _compute(self, dists2):
    """Indicator of entries of `dists2` that lie strictly below `epsilon`.

    Args:
        dists2 (tensor): Presumably squared distances (inferred from the
            name only; confirm with callers).

    Returns:
        tensor: The result of ``B.lt(dists2, epsilon)``, cast back to the
            dtype of `dists2`.
    """
    target_dtype = B.dtype(dists2)
    threshold = B.cast(target_dtype, self.epsilon)
    below = B.lt(dists2, threshold)
    return B.cast(target_dtype, below)
def from_(cls, constant, ref):
    """Build an instance of `cls` holding `constant`, shaped like `ref`.

    NOTE(review): takes `cls`, so this is presumably a ``@classmethod``
    where it is defined; the decorator is not visible here.

    Args:
        constant (scalar): Value to store, cast to the dtype of `ref`.
        ref (tensor): Reference whose dtype and shape are copied.

    Returns:
        An instance of `cls`.
    """
    value = B.cast(B.dtype(ref), constant)
    shape = B.shape(ref)
    return cls(value, *shape)
def __init__(self, dtype, rows, cols=None):
    """Initialise as a `Constant` whose value is 1 cast to `dtype`.

    Args:
        dtype (dtype): Data type of the constant.
        rows (int): Number of rows, forwarded to `Constant.__init__`.
        cols (int, optional): Number of columns, forwarded to
            `Constant.__init__`.
    """
    one = B.cast(dtype, 1)
    Constant.__init__(self, one, rows=rows, cols=cols)
def kl(self, other):
    """Compute the KL divergence with respect to another normal
    distribution.

    Evaluates ``(ratio(var, other.var) + q - dim + log|other.var| -
    log|var|) / 2``. Here ``q`` is the first entry of
    ``B.iqf_diag(other.var, other.mean - mean)``.

    Args:
        other (:class:`.random.Normal`): Other normal.

    Returns:
        scalar: KL divergence.
    """
    trace_term = B.ratio(self.var, other.var)
    mean_gap = other.mean - self.mean
    quad_form = B.iqf_diag(other.var, mean_gap)[0]
    dims = B.cast(self.dtype, self.dim)

    # Accumulate left to right, exactly as the single-expression form does,
    # so floating-point rounding is unchanged.
    total = trace_term + quad_form - dims
    total = total + B.logdet(other.var) - B.logdet(self.var)
    return total / 2
def __neg__(self):
    """Negate by multiplying with -1 cast to this object's dtype."""
    minus_one = B.cast(B.dtype(self), -1)
    return mul(minus_one, self)
def __init__(self, dtype, rows, cols=None):
    """Initialise as a `Constant` whose value is 0 cast to `dtype`.

    Args:
        dtype (dtype): Data type of the constant.
        rows (int): Number of rows, forwarded to `Constant.__init__`.
        cols (int, optional): Number of columns, forwarded to
            `Constant.__init__`.
    """
    zero = B.cast(dtype, 0)
    Constant.__init__(self, zero, rows=rows, cols=cols)
def logpdf(self, x):
    """Compute the log-pdf.

    Args:
        x (input): Values to compute the log-pdf of.

    Returns:
        list[tensor]: Log-pdf for every input in `x`. If it can be
            determined that the list contains only a single log-pdf,
            then the list is flattened to a scalar.
    """
    log_det = B.logdet(self.var)
    dim_term = B.cast(self.dtype, self.dim)
    log_two_pi = B.cast(self.dtype, B.log_2_pi)
    centred = uprank(x) - self.mean
    quad = B.iqf_diag(self.var, centred)

    logpdfs = -(log_det + dim_term * log_two_pi + quad) / 2
    # Collapse a length-one result into a scalar for convenience.
    is_single = B.shape(logpdfs) == (1,)
    return logpdfs[0] if is_single else logpdfs