How to use the numpyro.distributions.constraints module in numpyro

To help you get started, we've selected a few numpyro examples based on popular ways the module is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_autoguide.py View on Github external
flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                               permutation=jnp.arange(dim + 1),
                                               skip_connections=guide._skip_connections,
                                               nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(actual_sample['offset'],
                    transforms.biject_to(constraints.interval(-1, 1))(expected_sample['offset']))
    check_eq(actual_output, expected_output)
github pyro-ppl / numpyro / test / test_distributions.py View on Github external
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    """Generate an array of values that all violate ``constraint``.

    The concrete constraint class selects the generation strategy.

    :param constraint: a numpyro constraint instance.
    :param tuple size: shape of the returned array.
    :param key: JAX PRNG key used for the random draws. The ``_Simplex``
        branch ignores it and draws from SciPy's global RNG instead.
    :return: an array of shape ``size`` whose entries lie outside ``constraint``.
    :raises NotImplementedError: if ``constraint`` matches none of the branches.
    """
    if isinstance(constraint, constraints._Boolean):
        # bernoulli yields {0, 1}; shifting by -2 gives {-2, -1}.
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        # exp(.) > 0, so every value lies strictly below the lower bound.
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        # randint's maxval is exclusive, so every draw equals lower_bound - 1.
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        # A Poisson draw can be 0, which would land exactly on the (inclusive)
        # lower bound; the extra -1 keeps every value strictly below it.
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size) - 1
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        # NaN serves as the invalid value for unbounded real constraints.
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        # Dirichlet rows sum to 1; adding 1e-2 per entry pushes each sum above 1.
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        # Counts presumably sum to upper_bound; +1 per entry overshoots that total.
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
    else:
        # Fail loudly instead of silently returning None for unsupported constraints.
        raise NotImplementedError('{} not implemented.'.format(constraint))
github pyro-ppl / numpyro / numpyro / contrib / distributions / continuous.py View on Github external
#
# Copyright (c) 2003-2019 SciPy Developers.
# All rights reserved.

import jax.numpy as jnp
from jax.numpy.lax_numpy import _promote_dtypes
import jax.random as random
from jax.scipy.special import digamma, gammaln, log_ndtr, ndtr, ndtri

from numpyro.contrib.distributions.distribution import jax_continuous
from numpyro.distributions import constraints


class beta_gen(jax_continuous):
    """Beta distribution following scipy.stats naming (``_rvs``, ``_cdf``, ``_ppf``).

    Both shape parameters ``a`` and ``b`` must be positive; the support is the
    unit interval. CDF and PPF are not available.
    """
    arg_constraints = {'a': constraints.positive, 'b': constraints.positive}
    _support_mask = constraints.unit_interval

    def _rvs(self, a, b):
        """Sample Beta(a, b) as G_a / (G_a + G_b) from two independent gamma draws.

        Relies on ``self._random_state`` (a PRNG key) and ``self._size`` (sample
        shape), presumably populated by ``jax_continuous`` before this is called.
        """
        # TODO: use upstream implementation when available
        # NB: this differs from PyTorch, which draws the sample from a Dirichlet.
        first_key, second_key = random.split(self._random_state)
        numerator = random.gamma(first_key, a, shape=self._size)
        other = random.gamma(second_key, b, shape=self._size)
        denominator = numerator + other
        return numerator / denominator

    def _cdf(self, x, a, b):
        # Needs a regularized incomplete beta function that JAX does not provide here.
        raise NotImplementedError('Missing jax.scipy.special.btdtr')

    def _ppf(self, q, a, b):
        # Needs the inverse regularized incomplete beta function.
        raise NotImplementedError('Missing jax.scipy.special.btdtri')
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github external
def support(self):
        """Integer-interval constraint running from 0 to ``self.total_count``."""
        upper = self.total_count
        return constraints.integer_interval(0, upper)
github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github external
useful when we know a priori that some underlying variables are correlated.

    :param int dimension: dimension of the matrices
    :param ndarray concentration: concentration/shape parameter of the
        distribution (often referred to as eta)
    :param str sample_method: Either "cvine" or "onion". Both methods are proposed in [1] and
        yield the same distribution over correlation matrices; they differ only in how
        samples are generated. Defaults to "onion".

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method`,
    Daniel Lewandowski, Dorota Kurowicka, Harry Joe
    """
    # The concentration parameter (eta) must be positive.
    arg_constraints = {'concentration': constraints.positive}
    # Samples are correlation matrices.
    support = constraints.corr_matrix

    def __init__(self, dimension, concentration=1., sample_method='onion', validate_args=None):
        # Start from the distribution over Cholesky factors and map it through
        # InvCholeskyTransform so that samples become correlation matrices.
        cholesky_dist = LKJCholesky(dimension, concentration, sample_method)
        self.dimension = cholesky_dist.dimension
        self.concentration = cholesky_dist.concentration
        self.sample_method = sample_method
        to_corr = InvCholeskyTransform(domain=constraints.corr_cholesky)
        super().__init__(cholesky_dist, to_corr, validate_args=validate_args)

    @property
    def mean(self):
        """Mean matrix: the identity, broadcast across ``batch_shape``."""
        eye = jnp.identity(self.dimension)
        full_shape = self.batch_shape + (self.dimension, self.dimension)
        return jnp.broadcast_to(eye, full_shape)

    def tree_flatten(self):
        """Return ``((concentration,), (dimension, sample_method))``.

        Presumably the (children, aux_data) pair used for JAX pytree
        flattening: ``concentration`` is the only leaf.
        """
        children = (self.concentration,)
        aux_data = (self.dimension, self.sample_method)
        return children, aux_data

    @classmethod
github pyro-ppl / numpyro / numpyro / distributions / constraint_registry.py View on Github external
def register(self, constraint, factory=None):
        """Register ``factory`` for ``constraint`` in this registry.

        Usable directly, ``registry.register(constraint, factory)``, or as a
        decorator, ``@registry.register(constraint)``.

        :param constraint: a constraint instance or a constraint class. Instances
            are keyed by their type, so all instances of a class share one entry.
        :param factory: callable to store. If ``None``, a decorator is returned.
        :return: ``factory`` itself, so that decorator usage keeps the decorated
            name bound to the function; or a decorator when ``factory`` is ``None``.
        """
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        self._registry[constraint] = factory
        # Without this return, decorator usage would rebind the decorated
        # function's name to None.
        return factory
github pyro-ppl / numpyro / numpyro / contrib / distributions / multivariate.py View on Github external
def arg_constraints(self):
        """Parameter constraints: ``n`` is a nonnegative integer.

        ``p`` is unconstrained (real) when ``self.is_logits`` is true, and
        otherwise a probability simplex.
        """
        p_constraint = constraints.real if self.is_logits else constraints.simplex
        return {'n': constraints.nonnegative_integer, 'p': p_constraint}
github pyro-ppl / numpyro / numpyro / distributions / continuous.py View on Github external
    @property
    def variance(self):
        """Variance ``concentration / rate**2``."""
        rate_squared = jnp.power(self.rate, 2)
        return self.concentration / rate_squared


class Chi2(Gamma):
    """Chi-squared distribution with ``df`` degrees of freedom.

    Implemented as a ``Gamma`` constructed with parameters ``(0.5 * df, 0.5)``.
    """
    arg_constraints = {'df': constraints.positive}

    def __init__(self, df, validate_args=None):
        self.df = df
        half_df = 0.5 * df
        super().__init__(half_df, 0.5, validate_args=validate_args)


class GaussianRandomWalk(Distribution):
    # ``scale`` must be positive; ``num_steps`` must be a positive integer.
    arg_constraints = {'scale': constraints.positive, 'num_steps': constraints.positive_integer}
    # Each sample is a real-valued vector (one entry per step).
    support = constraints.real_vector
    # Only ``scale`` is listed as reparametrized.
    reparametrized_params = ['scale']

    def __init__(self, scale=1., num_steps=1, validate_args=None):
        # Only a single scalar step count is supported: it fixes the event shape.
        assert jnp.shape(num_steps) == ()
        self.scale = scale
        self.num_steps = num_steps
        # batch_shape follows ``scale``; each event is a walk of ``num_steps`` points.
        super().__init__(jnp.shape(scale), (num_steps,), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw walks as the cumulative sum of standard-normal steps, scaled by ``scale``."""
        full_shape = sample_shape + self.batch_shape + self.event_shape
        increments = random.normal(key, shape=full_shape)
        # Add a trailing axis so ``scale`` broadcasts over the step dimension.
        step_scale = jnp.expand_dims(self.scale, axis=-1)
        return jnp.cumsum(increments, axis=-1) * step_scale

    @validate_sample
github pyro-ppl / numpyro / numpyro / contrib / distributions / discrete.py View on Github external
def arg_constraints(self):
        """Parameter constraints: ``n`` is a nonnegative integer.

        ``p`` is unconstrained (real) when ``self.is_logits`` is true, and
        otherwise must lie in the unit interval.
        """
        p_constraint = constraints.real if self.is_logits else constraints.unit_interval
        return {'n': constraints.nonnegative_integer, 'p': p_constraint}
github pyro-ppl / numpyro / numpyro / distributions / discrete.py View on Github external
def variance(self):
        """Variance is reported as NaN, broadcast to ``batch_shape`` in ``probs``' dtype."""
        nan_dtype = get_dtype(self.probs)
        return jnp.full(self.batch_shape, jnp.nan, dtype=nan_dtype)

    @property
    def support(self):
        """Integer interval from 0 to one less than the size of ``probs``' last axis."""
        num_categories = jnp.shape(self.probs)[-1]
        return constraints.integer_interval(0, num_categories - 1)

    def enumerate_support(self, expand=True):
        """Return category indices 0..K-1 along a new leading axis.

        ``K`` is the size of ``probs``' last axis. Trailing singleton axes are
        added, one per batch dimension. With ``expand`` true, the result is
        broadcast to ``(K,) + batch_shape``.
        """
        num_categories = self.probs.shape[-1]
        trailing_ones = (1,) * len(self.batch_shape)
        values = jnp.arange(num_categories).reshape((-1,) + trailing_ones)
        if not expand:
            return values
        return jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)


class CategoricalLogits(Distribution):
    # ``logits`` must be a real vector; its last axis indexes the categories.
    arg_constraints = {'logits': constraints.real_vector}
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits, validate_args=None):
        # The last axis indexes categories, so scalar logits are rejected.
        ndim = jnp.ndim(logits)
        if ndim < 1:
            raise ValueError("`logits` parameter must be at least one-dimensional.")
        self.logits = logits
        batch_shape = jnp.shape(logits)[:-1]
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw category indices with shape ``sample_shape + batch_shape``."""
        out_shape = sample_shape + self.batch_shape
        return random.categorical(key, self.logits, shape=out_shape)

    @validate_sample
    def log_prob(self, value):
        batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)