Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@property
def variance(self):
    """Variance ``scale**2 * alpha / ((alpha - 1)**2 * (alpha - 2))``.

    The closed form is only meaningful for ``alpha > 2``; wherever
    ``alpha <= 2`` the variance is reported as ``inf`` instead.
    """
    alpha = self.alpha
    numerator = (self.scale ** 2) * alpha
    denominator = (alpha - 1) ** 2 * (alpha - 2)
    finite_variance = lax.div(numerator, denominator)
    # jnp.where evaluates both branches; the masked ones may hold inf/nan.
    is_infinite = alpha <= 2
    return jnp.where(is_infinite, jnp.inf, finite_variance)
# Defined explicitly, overriding the default behaviour, to save computations.
@property
def support(self):
    """Support: the ``greater_than`` constraint with ``scale`` as its lower bound."""
    lower_bound = self.scale
    return constraints.greater_than(lower_bound)
def tree_flatten(self):
    """Delegate flattening to the class that follows ``TransformedDistribution`` in the MRO."""
    # Two-argument super() starts the method lookup *after* TransformedDistribution,
    # so TransformedDistribution's own tree_flatten (if it defines one) is bypassed.
    # NOTE(review): presumably the enclosing class (header not visible here)
    # subclasses TransformedDistribution but should flatten like its parent - confirm.
    return super(TransformedDistribution, self).tree_flatten()
class StudentT(Distribution):
    """Student's t distribution with ``df`` degrees of freedom, location ``loc``
    and scale ``scale``."""
    arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, df, loc=0., scale=1., validate_args=None):
        # Batch shape is the broadcast of all three parameter shapes.
        batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(loc), jnp.shape(scale))
        self.df = jnp.broadcast_to(df, batch_shape)
        self.loc, self.scale = promote_shapes(loc, scale, shape=batch_shape)
        # Chi-square helper built from the broadcast ``df``; ``sample`` draws from it.
        self._chi2 = Chi2(self.df)
        super(StudentT, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # Independent keys for the normal draw and the chi-square draw.
        key_normal, key_chi2 = random.split(key)
        std_normal = random.normal(key_normal, shape=sample_shape + self.batch_shape)
        z = self._chi2.sample(key_chi2, sample_shape)
        # Standard-t construction: N(0, 1) / sqrt(chi2_df / df).
        y = std_normal * jnp.sqrt(self.df / z)
        # NOTE(review): the visible body ends here without a ``return``; the rest
        # of this method (presumably applying ``loc``/``scale`` to ``y``) is not
        # shown in this view - confirm against the full file.
@property
def support(self):
    """Support: ``constraints.multinomial`` parameterised by ``total_count``."""
    count = self.total_count
    return constraints.multinomial(count)
def Multinomial(total_count=1, probs=None, logits=None, validate_args=None):
    """Build a multinomial distribution from either ``probs`` or ``logits``.

    When both are supplied, ``probs`` takes precedence.

    :param total_count: number of trials, forwarded to the concrete class.
    :param probs: selects :class:`MultinomialProbs` when given.
    :param logits: selects :class:`MultinomialLogits` when ``probs`` is absent.
    :param validate_args: forwarded unchanged.
    :raises ValueError: if neither ``probs`` nor ``logits`` is given.
    """
    if probs is None and logits is None:
        raise ValueError('One of `probs` or `logits` must be specified.')
    if probs is None:
        return MultinomialLogits(logits, total_count, validate_args=validate_args)
    return MultinomialProbs(probs, total_count, validate_args=validate_args)
class Poisson(Distribution):
    """Poisson distribution with rate ``rate``."""
    arg_constraints = {'rate': constraints.positive}
    support = constraints.nonnegative_integer
    is_discrete = True

    def __init__(self, rate, validate_args=None):
        self.rate = rate
        batch_shape = jnp.shape(rate)
        super(Poisson, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        lam = self.rate
        return random.poisson(key, lam, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        # log p(k) = k * log(rate) - log(k!) - rate, with log(k!) = gammaln(k + 1).
        if self._validate_args:
            self._validate_sample(value)
        log_rate = jnp.log(self.rate)
        return (log_rate * value) - gammaln(value + 1) - self.rate
@property
def mean(self):
    """Mean of the distribution, which equals ``probs``."""
    expected_value = self.probs
    return expected_value
@property
def variance(self):
    """Variance ``probs * (1 - probs)``."""
    p = self.probs
    q = 1 - p
    return p * q
def enumerate_support(self, expand=True):
    """Return the support values ``[0, 1]`` along a new leading axis.

    :param bool expand: if True, broadcast to ``(2,) + batch_shape``; otherwise
        keep singleton dimensions for every batch dimension.
    """
    batch_ndim = len(self.batch_shape)
    support = jnp.reshape(jnp.arange(2), (-1,) + (1,) * batch_ndim)
    if not expand:
        return support
    leading = support.shape[:1]
    return jnp.broadcast_to(support, leading + self.batch_shape)
class BernoulliLogits(Distribution):
    """Bernoulli distribution parameterised by ``logits``."""
    arg_constraints = {'logits': constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits=None, validate_args=None):
        self.logits = logits
        batch_shape = jnp.shape(self.logits)
        super(BernoulliLogits, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # NOTE(review): ``self.probs`` is not defined in the lines shown here;
        # presumably derived from ``logits`` elsewhere in the class - confirm.
        success_prob = self.probs
        return random.bernoulli(key, success_prob, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        # The Bernoulli log-likelihood is the negated binary cross-entropy on logits.
        cross_entropy = binary_cross_entropy_with_logits(self.logits, value)
        return -cross_entropy
@validate_sample
def log_prob(self, value):
    """Poisson log-pmf ``value * log(rate) - log(value!) - rate``."""
    if self._validate_args:
        self._validate_sample(value)
    log_rate = jnp.log(self.rate)
    # gammaln(k + 1) == log(k!)
    log_factorial = gammaln(value + 1)
    return (log_rate * value) - log_factorial - self.rate
@property
def mean(self):
    """Mean of the distribution, which equals ``rate``."""
    lam = self.rate
    return lam
@property
def variance(self):
    """Variance of the distribution, which also equals ``rate``."""
    lam = self.rate
    return lam
class ZeroInflatedPoisson(Distribution):
    """
    A Zero Inflated Poisson distribution.

    :param numpy.ndarray gate: probability of extra zeros.
    :param numpy.ndarray rate: rate of Poisson distribution.
    """
    arg_constraints = {'gate': constraints.unit_interval, 'rate': constraints.positive}
    support = constraints.nonnegative_integer
    is_discrete = True

    def __init__(self, gate, rate=1., validate_args=None):
        # Batch shape is the broadcast of the ``gate`` and ``rate`` shapes.
        batch_shape = lax.broadcast_shapes(jnp.shape(gate), jnp.shape(rate))
        self.gate, self.rate = promote_shapes(gate, rate)
        super(ZeroInflatedPoisson, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # NOTE(review): the body of ``sample`` is not visible in this view - confirm
        # against the full file.
diff,
self._capacitance_tril)
log_det = _batch_lowrank_logdet(self.cov_factor,
self.cov_diag,
self._capacitance_tril)
return -0.5 * (self.loc.shape[-1] * jnp.log(2 * jnp.pi) + log_det + M)
def entropy(self):
    """Differential entropy ``0.5 * (k * (1 + log(2*pi)) + log_det)``.

    ``k`` is the event size (last dimension of ``loc``) and ``log_det`` comes from
    ``_batch_lowrank_logdet`` (presumably the log-determinant of the low-rank plus
    diagonal covariance - confirm). The result is broadcast to ``batch_shape``.
    """
    log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag, self._capacitance_tril)
    event_size = self.loc.shape[-1]
    entropy_value = 0.5 * (event_size * (1.0 + jnp.log(2 * jnp.pi)) + log_det)
    return jnp.broadcast_to(entropy_value, self.batch_shape)
class Normal(Distribution):
    """Normal distribution with location ``loc`` and scale ``scale``."""
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        # Batch shape is the broadcast of the ``loc`` and ``scale`` shapes.
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super(Normal, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # Reparameterised draw: shift and scale standard-normal noise.
        eps = random.normal(key, shape=sample_shape + self.batch_shape + self.event_shape)
        return self.loc + eps * self.scale

    @validate_sample
    def log_prob(self, value):
        # log(sqrt(2*pi) * scale): the log normalising constant of the density.
        normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
        # NOTE(review): the visible body ends here without a ``return``; the
        # remainder of ``log_prob`` is not shown in this view - confirm.
def set_default_validate_args(value):
    """Set the class-level default ``validate_args`` flag for distributions.

    :param bool value: new default, stored on ``Distribution._validate_args``.
    :raises ValueError: if ``value`` is not ``True`` or ``False``.
    """
    # ``in`` compares with ``==``, so values equal to True/False (e.g. 1, 0)
    # are accepted as well; anything else is rejected with an explicit message.
    if value not in [True, False]:
        raise ValueError(
            "validate_args must be True or False, got {!r}".format(value))
    Distribution._validate_args = value
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
from jax import lax, random
import jax.numpy as np
from jax.scipy.special import betaln, gammaln
from numpyro.distributions import constraints
from numpyro.distributions.continuous import Beta, Gamma
from numpyro.distributions.discrete import Binomial, Poisson
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample
class BetaBinomial(Distribution):
    r"""
    Compound distribution consisting of a beta-binomial pair. The probability of
    success (``probs`` for the :class:`~numpyro.distributions.Binomial` distribution)
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Beta` distribution
    prior to a certain number of Bernoulli trials given by ``total_count``.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
    arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive,
                       'total_count': constraints.nonnegative_integer}

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        # NOTE(review): the body below does not match this signature. It references
        # ``mask``, which is not a parameter here, and its docstring describes
        # masking, so it appears to belong to a different method (e.g. a
        # ``mask(self, mask)`` method). Confirm against the full file.
        """
        Masks a distribution by a boolean or boolean-valued array that is
        broadcastable to the distribution's
        :attr:`Distribution.batch_shape` .

        :param mask: A boolean or boolean valued array.
        :type mask: bool or jnp.ndarray
        :return: A masked copy of this distribution.
        :rtype: :class:`MaskedDistribution`
        """
        # Identity check: only the literal ``True`` short-circuits to ``self``.
        if mask is True:
            return self
        return MaskedDistribution(self, mask)
class ExpandedDistribution(Distribution):
    """Wrap ``base_dist`` and expand its batch shape to ``batch_shape``."""
    arg_constraints = {}

    def __init__(self, base_dist, batch_shape=()):
        """
        :param base_dist: distribution to wrap; a nested ``ExpandedDistribution``
            is collapsed onto its own ``base_dist``.
        :param tuple batch_shape: target batch shape.
        """
        if isinstance(base_dist, ExpandedDistribution):
            # ``_broadcast_shape`` returns a ``(new_shape, expanded_sizes,
            # interstitial_sizes)`` triple (``expand`` unpacks it the same way);
            # only the broadcast shape is wanted here. Assigning the whole
            # triple would pass a tuple-of-tuples on to ``expand``.
            batch_shape, _, _ = self._broadcast_shape(base_dist.batch_shape, batch_shape)
            base_dist = base_dist.base_dist
        self.base_dist = base_dist
        super().__init__(base_dist.batch_shape, base_dist.event_shape)
        # adjust batch shape
        self.expand(batch_shape)
def expand(self, batch_shape):
    # NOTE(review): this method is cut off mid-call in this view (the second
    # ``_broadcast_shape`` call below is unterminated); the rest is not shown.
    # Do basic validation. e.g. we should not "unexpand" distributions even if that is possible.
    new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape)
    # Record interstitial and expanded dims/sizes w.r.t. the base distribution
    new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape,
return constraints.integer_interval(0, self.total_count)
def enumerate_support(self, expand=True):
    """Return the support values ``0, 1, ..., total_count`` along a new leading axis.

    Only a homogeneous ``total_count`` (the same value in every batch entry) is
    supported.

    :param bool expand: if True, broadcast the values to
        ``(total_count + 1,) + batch_shape``; otherwise keep singleton batch dims.
    :raises NotImplementedError: when ``total_count`` is concrete and inhomogeneous.
    """
    max_count = jnp.amax(self.total_count)
    # The homogeneity check needs concrete values, so it cannot run while tracing.
    if not_jax_tracer(max_count) and jnp.amin(self.total_count) != max_count:
        raise NotImplementedError("Inhomogeneous total count not supported"
                                  " by `enumerate_support`.")
    batch_ndim = len(self.batch_shape)
    values = jnp.arange(max_count + 1).reshape((-1,) + (1,) * batch_ndim)
    if not expand:
        return values
    return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
class BinomialLogits(Distribution):
    """Binomial distribution parameterised by ``logits`` and ``total_count``."""
    arg_constraints = {'logits': constraints.real,
                       'total_count': constraints.nonnegative_integer}
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits, total_count=1, validate_args=None):
        self.logits, self.total_count = promote_shapes(logits, total_count)
        # Batch shape is the broadcast of the ``logits`` and ``total_count`` shapes.
        batch_shape = lax.broadcast_shapes(jnp.shape(logits), jnp.shape(total_count))
        super(BinomialLogits, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # NOTE(review): ``self.probs`` is not defined in the lines shown here;
        # presumably derived from ``logits`` elsewhere in the class - confirm.
        return binomial(key, self.probs, n=self.total_count, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        # log(n!) via gammaln(n + 1).
        log_factorial_n = gammaln(self.total_count + 1)
        # NOTE(review): the visible body ends here without a ``return``; the
        # remainder of ``log_prob`` is not shown in this view - confirm.
@validate_sample
def log_prob(self, value):
    """Log-density, delegated to ``self._dirichlet`` on the pair ``(value, 1 - value)``."""
    pair = jnp.stack((value, 1. - value), axis=-1)
    return self._dirichlet.log_prob(pair)
@property
def mean(self):
    """Mean ``concentration1 / (concentration1 + concentration0)``."""
    c1 = self.concentration1
    total = c1 + self.concentration0
    return c1 / total
@property
def variance(self):
    """Variance ``c1 * c0 / ((c1 + c0)**2 * (c1 + c0 + 1))``."""
    c1, c0 = self.concentration1, self.concentration0
    total = c1 + c0
    return c1 * c0 / (total ** 2 * (total + 1))
class Cauchy(Distribution):
    """Cauchy distribution with location ``loc`` and scale ``scale``."""
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super(Cauchy, self).__init__(batch_shape=shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # Reparameterised draw: shift and scale a standard Cauchy variate.
        standard_draw = random.cauchy(key, shape=sample_shape + self.batch_shape)
        return self.loc + standard_draw * self.scale

    @validate_sample
    def log_prob(self, value):
        # log pdf = -log(pi) - log(scale) - log(1 + z**2), z = (value - loc) / scale
        log_pi = jnp.log(jnp.pi)
        log_scale = jnp.log(self.scale)
        z = (value - self.loc) / self.scale
        return - log_pi - log_scale - jnp.log1p(z ** 2)