flows.append(transforms.PermuteTransform(jnp.arange(dim + 1)[::-1]))
arn_init, arn_apply = AutoregressiveNN(dim + 1, [dim + 1, dim + 1],
                                       permutation=jnp.arange(dim + 1),
                                       skip_connections=guide._skip_connections,
                                       nonlinearity=guide._nonlinearity)
arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
flows.append(InverseAutoregressiveTransform(arn))
flows.append(guide._unpack_latent)
transform = transforms.ComposeTransform(flows)
_, rng_key_sample = random.split(rng_key)
expected_sample = transform(dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
expected_output = transform(x)
assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
assert_allclose(actual_sample['offset'],
                transforms.biject_to(constraints.interval(-1, 1))(expected_sample['offset']))
check_eq(actual_output, expected_output)
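# A minimal sketch (separate from the test above) of how ComposeTransform chains
# transforms: parts are applied in list order. It assumes numpyro's ExpTransform
# and AffineTransform(loc, scale), which maps x to loc + scale * x.
import jax.numpy as jnp
from numpyro.distributions import transforms

_t = transforms.ComposeTransform([transforms.ExpTransform(),
                                  transforms.AffineTransform(1., 2.)])
_x = jnp.array([0., 1.])
assert jnp.allclose(_t(_x), 1. + 2. * jnp.exp(_x))  # exp is applied first, then the affine map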
def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):
    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints._GreaterThan):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound - random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints._Interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.)
    elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
        return lax.full(size, jnp.nan)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]) + 1
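# A hedged usage sketch for the helper above (the interval bounds and size are
# arbitrary example values): numpyro constraints are callable and return a
# boolean mask, so values generated here should all fail the membership check.
_c = constraints.interval(-1., 1.)
_bad = gen_values_outside_bounds(_c, size=(5,))
assert not bool(_c(_bad).any())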
#
# Copyright (c) 2003-2019 SciPy Developers.
# All rights reserved.
import jax.numpy as jnp
from jax.numpy.lax_numpy import _promote_dtypes
import jax.random as random
from jax.scipy.special import digamma, gammaln, log_ndtr, ndtr, ndtri
from numpyro.contrib.distributions.distribution import jax_continuous
from numpyro.distributions import constraints
class beta_gen(jax_continuous):
    arg_constraints = {'a': constraints.positive, 'b': constraints.positive}
    _support_mask = constraints.unit_interval

    def _rvs(self, a, b):
        # TODO: use upstream implementation when available
        # XXX the implementation is different from PyTorch's one
        # in PyTorch, a sample is generated from dirichlet distribution
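        # (Beta sampling via the gamma ratio: if X ~ Gamma(a) and Y ~ Gamma(b)
        # are independent, then X / (X + Y) ~ Beta(a, b).)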
        key_a, key_b = random.split(self._random_state)
        gamma_a = random.gamma(key_a, a, shape=self._size)
        gamma_b = random.gamma(key_b, b, shape=self._size)
        return gamma_a / (gamma_a + gamma_b)

    def _cdf(self, x, a, b):
        raise NotImplementedError('Missing jax.scipy.special.btdtr')

    def _ppf(self, q, a, b):
        raise NotImplementedError('Missing jax.scipy.special.btdtri')

    def support(self):
        return constraints.integer_interval(0, self.total_count)
    useful when we know a priori that some underlying variables are correlated.

    :param int dimension: dimension of the matrices
    :param ndarray concentration: concentration/shape parameter of the
        distribution (often referred to as eta)
    :param str sample_method: Either "cvine" or "onion". Both methods are proposed in [1] and
        give the same distribution over correlation matrices, but they differ in how
        samples are generated. Defaults to "onion".

    **References**

    [1] `Generating random correlation matrices based on vines and extended onion method`,
        Daniel Lewandowski, Dorota Kurowicka, Harry Joe
    """
    arg_constraints = {'concentration': constraints.positive}
    support = constraints.corr_matrix

    def __init__(self, dimension, concentration=1., sample_method='onion', validate_args=None):
        base_dist = LKJCholesky(dimension, concentration, sample_method)
        self.dimension, self.concentration = base_dist.dimension, base_dist.concentration
        self.sample_method = sample_method
        super(LKJ, self).__init__(base_dist, InvCholeskyTransform(domain=constraints.corr_cholesky),
                                  validate_args=validate_args)

    @property
    def mean(self):
        return jnp.broadcast_to(jnp.identity(self.dimension),
                                self.batch_shape + (self.dimension, self.dimension))

    def tree_flatten(self):
        return (self.concentration,), (self.dimension, self.sample_method)
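# A hedged usage sketch for the LKJ fragment above (the key, concentration and
# sample shape are arbitrary example values):
_lkj = LKJ(dimension=3, concentration=2., sample_method='onion')
_corr = _lkj.sample(random.PRNGKey(0), sample_shape=(5,))
# each draw is a 3 x 3 correlation matrix, so _corr has shape (5, 3, 3)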
    def register(self, constraint, factory=None):
        if factory is None:
            return lambda factory: self.register(constraint, factory)
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)
        self._registry[constraint] = factory
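# A hedged sketch of how such a registry's `register` method is typically used,
# either called directly or as a decorator. The registry instance name
# `biject_to` and the factory below are illustrative assumptions, not taken
# from this fragment:
#
# @biject_to.register(constraints.positive)
# def _positive_bijector(constraint):
#     return transforms.ExpTransform()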
    def arg_constraints(self):
        if self.is_logits:
            return {'n': constraints.nonnegative_integer,
                    'p': constraints.real}
        else:
            return {'n': constraints.nonnegative_integer,
                    'p': constraints.simplex}
    @property
    def variance(self):
        return self.concentration / jnp.power(self.rate, 2)


class Chi2(Gamma):
    arg_constraints = {'df': constraints.positive}

    def __init__(self, df, validate_args=None):
        self.df = df
        super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args)
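# A hedged sanity check for the Chi2 fragment above: chi-squared with df degrees
# of freedom is Gamma(concentration=df / 2, rate=1 / 2), so the Gamma `variance`
# property shown earlier gives (df / 2) / (1 / 2) ** 2 = 2 * df and, assuming
# the usual Gamma mean (concentration / rate), the mean is df.
_chi2 = Chi2(df=4.)
assert jnp.isclose(_chi2.mean, 4.) and jnp.isclose(_chi2.variance, 8.)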
class GaussianRandomWalk(Distribution):
    arg_constraints = {'scale': constraints.positive, 'num_steps': constraints.positive_integer}
    support = constraints.real_vector
    reparametrized_params = ['scale']

    def __init__(self, scale=1., num_steps=1, validate_args=None):
        assert jnp.shape(num_steps) == ()
        self.scale = scale
        self.num_steps = num_steps
        batch_shape, event_shape = jnp.shape(scale), (num_steps,)
        super(GaussianRandomWalk, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        shape = sample_shape + self.batch_shape + self.event_shape
        walks = random.normal(key, shape=shape)
        return jnp.cumsum(walks, axis=-1) * jnp.expand_dims(self.scale, axis=-1)
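# A hedged usage sketch for GaussianRandomWalk above: a draw is the cumulative
# sum of num_steps independent Normal(0, scale) increments along the last axis
# (the key, scale and shapes are arbitrary example values).
_walk = GaussianRandomWalk(scale=0.5, num_steps=10)
_paths = _walk.sample(random.PRNGKey(0), sample_shape=(3,))
# _paths has shape (3, 10): three independent length-10 random walks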
    @validate_sample
    def arg_constraints(self):
        if self.is_logits:
            return {'n': constraints.nonnegative_integer,
                    'p': constraints.real}
        else:
            return {'n': constraints.nonnegative_integer,
                    'p': constraints.unit_interval}

    def variance(self):
        return jnp.full(self.batch_shape, jnp.nan, dtype=get_dtype(self.probs))

    @property
    def support(self):
        return constraints.integer_interval(0, jnp.shape(self.probs)[-1] - 1)

    def enumerate_support(self, expand=True):
        values = jnp.arange(self.probs.shape[-1]).reshape((-1,) + (1,) * len(self.batch_shape))
        if expand:
            values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
        return values
class CategoricalLogits(Distribution):
    arg_constraints = {'logits': constraints.real_vector}
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, logits, validate_args=None):
        if jnp.ndim(logits) < 1:
            raise ValueError("`logits` parameter must be at least one-dimensional.")
        self.logits = logits
        super(CategoricalLogits, self).__init__(batch_shape=jnp.shape(logits)[:-1],
                                                validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        return random.categorical(key, self.logits, shape=sample_shape + self.batch_shape)

    @validate_sample
    def log_prob(self, value):
        batch_shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)