How to use the `numpyro.distributions.distribution.Distribution` class in NumPyro

To help you get started, we've selected a few NumPyro examples based on popular ways this class is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / numpyro / distributions / View on Github external
    def variance(self):
        """Variance of the distribution.

        Computes ``scale**2 * alpha / ((alpha - 1)**2 * (alpha - 2))`` and
        replaces it with ``inf`` wherever ``alpha <= 2``, where the variance
        is infinite.
        """
        alpha = self.alpha
        numerator = (self.scale ** 2) * alpha
        denominator = (alpha - 1) ** 2 * (alpha - 2)
        finite_var = lax.div(numerator, denominator)
        # the finite formula is meaningless for alpha <= 2, so mask it out
        return jnp.where(alpha <= 2, jnp.inf, finite_var)

    # Overridden so the support is built directly from ``scale`` rather than
    # through the more expensive default behaviour.
    def support(self):
        """Return the support constraint: values greater than ``scale``."""
        lower = self.scale
        return constraints.greater_than(lower)

    def tree_flatten(self):
        # Begin the MRO lookup *after* TransformedDistribution, so that class's
        # own tree_flatten is bypassed and the next implementation is used.
        parent = super(TransformedDistribution, self)
        return parent.tree_flatten()

class StudentT(Distribution):
    arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, df, loc=0., scale=1., validate_args=None):
        """Build the distribution from degrees of freedom, location and scale.

        :param df: degrees-of-freedom parameter.
        :param loc: location parameter (defaults to ``0.``).
        :param scale: scale parameter (defaults to ``1.``).
        :param validate_args: forwarded to the base ``Distribution``.
        """
        shapes = [jnp.shape(arg) for arg in (df, loc, scale)]
        batch_shape = lax.broadcast_shapes(*shapes)
        self.df = jnp.broadcast_to(df, batch_shape)
        self.loc, self.scale = promote_shapes(loc, scale, shape=batch_shape)
        # helper chi-squared distribution built from the broadcast ``df``
        self._chi2 = Chi2(self.df)
        super(StudentT, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        key_normal, key_chi2 = random.split(key)
        std_normal = random.normal(key_normal, shape=sample_shape + self.batch_shape)
        z = self._chi2.sample(key_chi2, sample_shape)
        y = std_normal * jnp.sqrt(self.df / z)
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
    def support(self):
        # The support constraint is parameterized by ``total_count``.
        count = self.total_count
        return constraints.multinomial(count)

def Multinomial(total_count=1, probs=None, logits=None, validate_args=None):
    """Construct a multinomial distribution from either ``probs`` or ``logits``.

    :param total_count: number of trials (defaults to ``1``).
    :param probs: event probabilities; takes precedence over ``logits`` when
        both are given.
    :param logits: event logits, used only when ``probs`` is ``None``.
    :param validate_args: forwarded to the underlying distribution.
    :return: a ``MultinomialProbs`` or ``MultinomialLogits`` instance.
    :raises ValueError: if neither ``probs`` nor ``logits`` is specified.
    """
    if probs is not None:
        return MultinomialProbs(probs, total_count, validate_args=validate_args)
    if logits is not None:
        return MultinomialLogits(logits, total_count, validate_args=validate_args)
    # Reached only when neither parameterization was supplied; must not fall
    # through and silently return None.
    raise ValueError('One of `probs` or `logits` must be specified.')

class Poisson(Distribution):
    """Poisson distribution parameterized by its ``rate``."""
    arg_constraints = {'rate': constraints.positive}
    support = constraints.nonnegative_integer
    is_discrete = True

    def __init__(self, rate, validate_args=None):
        self.rate = rate
        # the batch shape is taken directly from ``rate``
        super(Poisson, self).__init__(jnp.shape(rate), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape``."""
        return random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)

    def log_prob(self, value):
        """Log probability mass ``value * log(rate) - log(value!) - rate``.

        When argument validation is enabled, ``value`` is first passed to
        ``self._validate_sample``.
        """
        if self._validate_args:
            # The guard previously had no body (a SyntaxError).
            # NOTE(review): assumes the base class provides ``_validate_sample``
            # for the support check — confirm.
            self._validate_sample(value)
        return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
    def mean(self):
        # The mean is the ``probs`` parameter itself.
        expected = self.probs
        return expected

    def variance(self):
        # p * (1 - p)
        p = self.probs
        return p * (1 - p)

    def enumerate_support(self, expand=True):
        """Return the values ``0`` and ``1`` stacked along a new leading axis.

        :param bool expand: if true, broadcast the result to
            ``(2,) + batch_shape``; otherwise keep singleton batch dimensions.
        """
        n_batch_dims = len(self.batch_shape)
        column = jnp.arange(2).reshape((-1,) + (1,) * n_batch_dims)
        if not expand:
            return column
        return jnp.broadcast_to(column, column.shape[:1] + self.batch_shape)

class BernoulliLogits(Distribution):
    """Bernoulli distribution parameterized by ``logits``."""
    arg_constraints = {'logits': constraints.real}
    support = constraints.boolean
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits=None, validate_args=None):
        self.logits = logits
        shape = jnp.shape(self.logits)
        super(BernoulliLogits, self).__init__(batch_shape=shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # NOTE(review): ``self.probs`` is not defined in this excerpt of the
        # class; presumably it is a property defined elsewhere — confirm.
        probs = self.probs
        return random.bernoulli(key, probs, shape=sample_shape + self.batch_shape)

    def log_prob(self, value):
        # log-likelihood is the negated binary cross-entropy on the logits
        bce = binary_cross_entropy_with_logits(self.logits, value)
        return -bce
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
    def log_prob(self, value):
        """Log probability mass ``value * log(rate) - log(value!) - rate``.

        When argument validation is enabled, ``value`` is first passed to
        ``self._validate_sample``.
        """
        if self._validate_args:
            # The guard previously had no body (a SyntaxError).
            # NOTE(review): assumes the base class provides ``_validate_sample``
            # for the support check — confirm.
            self._validate_sample(value)
        return (jnp.log(self.rate) * value) - gammaln(value + 1) - self.rate

    def mean(self):
        # The mean equals the rate parameter.
        rate = self.rate
        return rate

    def variance(self):
        # The variance equals the rate parameter.
        return getattr(self, 'rate')

class ZeroInflatedPoisson(Distribution):
    A Zero Inflated Poisson distribution.

    :param numpy.ndarray gate: probability of extra zeros.
    :param numpy.ndarray rate: rate of Poisson distribution.
    arg_constraints = {'gate': constraints.unit_interval, 'rate': constraints.positive}
    support = constraints.nonnegative_integer
    is_discrete = True

    def __init__(self, gate, rate=1., validate_args=None):
        """Store the parameters and set the broadcast batch shape.

        :param gate: probability of extra zeros.
        :param rate: rate of the Poisson component (defaults to ``1.``).
        :param validate_args: forwarded to the base ``Distribution``.
        """
        gate_shape = jnp.shape(gate)
        rate_shape = jnp.shape(rate)
        shape = lax.broadcast_shapes(gate_shape, rate_shape)
        self.gate, self.rate = promote_shapes(gate, rate)
        super(ZeroInflatedPoisson, self).__init__(shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
        log_det = _batch_lowrank_logdet(self.cov_factor,
        return -0.5 * (self.loc.shape[-1] * jnp.log(2 * jnp.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(self.cov_factor,
        H = 0.5 * (self.loc.shape[-1] * (1.0 + jnp.log(2 * jnp.pi)) + log_det)
        return jnp.broadcast_to(H, self.batch_shape)

class Normal(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        """Build the distribution from a location and a scale.

        :param loc: location parameter (defaults to ``0.``).
        :param scale: scale parameter (defaults to ``1.``).
        :param validate_args: forwarded to the base ``Distribution``.
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super(Normal, self).__init__(batch_shape=shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``."""
        full_shape = sample_shape + self.batch_shape + self.event_shape
        standard = random.normal(key, shape=full_shape)
        # shift and scale the standard-normal draws
        return self.loc + standard * self.scale

    def log_prob(self, value):
        normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
def set_default_validate_args(value):
    """Set the class-wide default for argument validation on ``Distribution``.

    :param bool value: ``True`` to enable validation by default, ``False`` to
        disable it.
    :raises ValueError: if ``value`` does not compare equal to ``True`` or
        ``False``.
    """
    if value not in (True, False):
        raise ValueError(
            'set_default_validate_args expects True or False, got {!r}'.format(value))
    # Store a real bool even when an equal value (e.g. ``1`` or a NumPy bool)
    # was passed.
    Distribution._validate_args = bool(value)
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from jax import lax, random
import jax.numpy as np
from jax.scipy.special import betaln, gammaln

from numpyro.distributions import constraints
from numpyro.distributions.continuous import Beta, Gamma
from numpyro.distributions.discrete import Binomial, Poisson
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import promote_shapes, validate_sample

class BetaBinomial(Distribution):
    Compound distribution comprising of a beta-binomial pair. The probability of
    success (``probs`` for the :class:`~numpyro.distributions.Binomial` distribution)
    is unknown and randomly drawn from a :class:`~numpyro.distributions.Beta` distribution
    prior to a certain number of Bernoulli trials given by ``total_count``.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive,
                       'total_count': constraints.nonnegative_integer}

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
        Masks a distribution by a boolean or boolean-valued array that is
        broadcastable to the distributions
        :attr:`Distribution.batch_shape` .

        :param mask: A boolean or boolean valued array.
        :type mask: bool or jnp.ndarray
        :return: A masked copy of this distribution.
        :rtype: :class:`MaskedDistribution`
        if mask is True:
            return self
        return MaskedDistribution(self, mask)

class ExpandedDistribution(Distribution):
    arg_constraints = {}

    def __init__(self, base_dist, batch_shape=()):
        """Wrap ``base_dist`` together with a requested ``batch_shape``.

        If ``base_dist`` is itself an ``ExpandedDistribution``, the wrappers are
        flattened: the requested shape is broadcast against the inner batch
        shape, and the inner ``base_dist`` is wrapped directly.
        """
        if isinstance(base_dist, ExpandedDistribution):
            # ``_broadcast_shape`` returns a 3-tuple (new shape, expanded sizes,
            # interstitial sizes), as unpacked in ``expand``; keep only the shape.
            batch_shape, _, _ = self._broadcast_shape(base_dist.batch_shape, batch_shape)
            base_dist = base_dist.base_dist
        self.base_dist = base_dist
        super().__init__(base_dist.batch_shape, base_dist.event_shape)
        # adjust batch shape
        # NOTE(review): the step that applies ``batch_shape`` is not visible in
        # this excerpt — confirm it follows here.

    def expand(self, batch_shape):
        # Do basic validation. e.g. we should not "unexpand" distributions even if that is possible.
        new_shape, _, _ = self._broadcast_shape(self.batch_shape, batch_shape)
        # Record interstitial and expanded dims/sizes w.r.t. the base distribution
        new_shape, expanded_sizes, interstitial_sizes = self._broadcast_shape(self.base_dist.batch_shape,
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
return constraints.integer_interval(0, self.total_count)

    def enumerate_support(self, expand=True):
        """Enumerate ``0 .. max(total_count)`` along a new leading axis.

        :param bool expand: if true, broadcast the result to
            ``(max(total_count) + 1,) + batch_shape``.
        :raises NotImplementedError: if ``total_count`` differs across the
            batch. This can only be detected outside of JAX tracing.
        """
        upper = jnp.amax(self.total_count)
        # The homogeneity check needs concrete values, so it is skipped while tracing.
        if not_jax_tracer(upper) and jnp.amin(self.total_count) != upper:
            raise NotImplementedError("Inhomogeneous total count not supported"
                                      " by `enumerate_support`.")
        support = jnp.arange(upper + 1)
        support = support.reshape((-1,) + (1,) * len(self.batch_shape))
        if not expand:
            return support
        return jnp.broadcast_to(support, support.shape[:1] + self.batch_shape)

class BinomialLogits(Distribution):
    arg_constraints = {'logits': constraints.real,
                       'total_count': constraints.nonnegative_integer}
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits, total_count=1, validate_args=None):
        """Store ``logits`` and ``total_count`` and set the broadcast batch shape.

        :param logits: logits parameter.
        :param total_count: number of trials (defaults to ``1``).
        :param validate_args: forwarded to the base ``Distribution``.
        """
        self.logits, self.total_count = promote_shapes(logits, total_count)
        logits_shape, count_shape = jnp.shape(logits), jnp.shape(total_count)
        super(BinomialLogits, self).__init__(
            batch_shape=lax.broadcast_shapes(logits_shape, count_shape),
            validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # NOTE(review): ``self.probs`` is not defined in the visible part of
        # this class — confirm where it is provided.
        probs = self.probs
        out_shape = sample_shape + self.batch_shape
        return binomial(key, probs, n=self.total_count, shape=out_shape)

    def log_prob(self, value):
        log_factorial_n = gammaln(self.total_count + 1)
github pyro-ppl / numpyro / numpyro / distributions / View on Github external
    def log_prob(self, value):
        # Pair each value with its complement on a new trailing axis and
        # delegate to the wrapped ``_dirichlet`` distribution.
        pair = jnp.stack([value, 1. - value], -1)
        return self._dirichlet.log_prob(pair)

    def mean(self):
        # concentration1 / (concentration1 + concentration0)
        c1, c0 = self.concentration1, self.concentration0
        return c1 / (c1 + c0)

    def variance(self):
        # c1 * c0 / (t**2 * (t + 1)) with t = c1 + c0
        c1 = self.concentration1
        c0 = self.concentration0
        t = c1 + c0
        denom = t ** 2 * (t + 1)
        return c1 * c0 / denom

class Cauchy(Distribution):
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real
    reparametrized_params = ['loc', 'scale']

    def __init__(self, loc=0., scale=1., validate_args=None):
        """Build the distribution from a location and a scale.

        :param loc: location parameter (defaults to ``0.``).
        :param scale: scale parameter (defaults to ``1.``).
        :param validate_args: forwarded to the base ``Distribution``.
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        loc_shape = jnp.shape(loc)
        scale_shape = jnp.shape(scale)
        super(Cauchy, self).__init__(batch_shape=lax.broadcast_shapes(loc_shape, scale_shape),
                                     validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape``."""
        standard = random.cauchy(key, shape=sample_shape + self.batch_shape)
        # shift and scale the standard Cauchy draws
        shifted = self.loc + standard * self.scale
        return shifted

    def log_prob(self, value):
        """Log density ``-log(pi) - log(scale) - log1p(z**2)`` with ``z = (value - loc) / scale``."""
        z = (value - self.loc) / self.scale
        return - jnp.log(jnp.pi) - jnp.log(self.scale) - jnp.log1p(z ** 2)