import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def test_model_with_multiple_exec_paths(jit_args):
    def model(a=None, b=None, z=None):
        int_term = numpyro.sample('a', dist.Normal(0., 0.2))
        x_term, y_term = 0., 0.
        if a is not None:
            x = numpyro.sample('x', dist.HalfNormal(0.5))
            x_term = a * x
        if b is not None:
            y = numpyro.sample('y', dist.HalfNormal(0.5))
            y_term = b * y
        sigma = numpyro.sample('sigma', dist.Exponential(1.))
        mu = int_term + x_term + y_term
        numpyro.sample('obs', dist.Normal(mu, sigma), obs=z)

    a = jnp.exp(np.random.randn(10))
    b = jnp.exp(np.random.randn(10))
    z = np.random.randn(10)

    # Run MCMC under each execution path and check which sites get sampled.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=20, num_samples=10, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(1), a, b=None, z=z)
    assert set(mcmc.get_samples()) == {'a', 'x', 'sigma'}
    mcmc.run(random.PRNGKey(2), a=None, b=b, z=z)
    assert set(mcmc.get_samples()) == {'a', 'y', 'sigma'}
    mcmc.run(random.PRNGKey(3), a=a, b=b, z=z)
    assert set(mcmc.get_samples()) == {'a', 'x', 'y', 'sigma'}
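
# A quick way to inspect the execution paths without running MCMC (a sketch,
# assuming `model`, `a`, and `z` from the test body are hoisted to module scope):
from numpyro.handlers import seed, trace

tr = trace(seed(model, random.PRNGKey(0))).get_trace(a=a, b=None, z=z)
print(sorted(tr.keys()))  # ['a', 'obs', 'sigma', 'x'] -- no 'y' site on this path
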
def sigmoid(self, x):
    return 1 / (1 + np.exp(-x))
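
# The plain sigmoid above overflows in np.exp(-x) for large negative x
# (1 / (1 + inf) still returns 0.0, but with a RuntimeWarning). A common
# alternative is scipy.special.expit, which computes the same function
# stably (a sketch, plain NumPy/SciPy assumed):
import numpy as np
from scipy.special import expit

x = np.array([-800.0, 0.0, 800.0])
print(expit(x))  # [0.  0.5 1. ] with no overflow warning
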
scales_ = jnp.concatenate([scales, scale[None, ...]])
if scale.ndim == 2:  # dense_mass
    log_weights_ = dist.MultivariateNormal(locs_, scale_tril=scales_).log_prob(zs_) + pes_
else:
    log_weights_ = dist.Normal(locs_, scales_).log_prob(zs_).sum(-1) + pes_
# mask out invalid weights before sampling an index
log_weights_ = jnp.where(jnp.isnan(log_weights_), -jnp.inf, log_weights_)
# sample the index of the point to reject
j = random.categorical(rng_key_reject, log_weights_)
zs = _numpy_delete(zs_, j)
pes = _numpy_delete(pes_, j)
loc = locs_[j]
scale = scales_[j]
adapt_state = SAAdaptState(zs, pes, loc, scale)
# NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
itr = sa_state.i + 1
n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = sa_state.mean_accept_prob + (accept_prob - sa_state.mean_accept_prob) / n
# XXX: we modify the SA sampler of [1]: there, each MCMC state carries all
# N points `zs`; here we resample so the state returns a single point drawn
# uniformly at random from those N points.
k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
z = unravel_fn(zs[k])
pe = pes[k]
return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging, adapt_state, rng_key)
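
# The fragment above is internal to NumPyro's SA (Sample Adaptive) kernel.
# End-to-end usage mirrors NUTS (a sketch, assuming a NumPyro version that
# exports SA; note that SA typically needs far more warmup and samples than
# gradient-based kernels):
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, SA

def toy_model():
    numpyro.sample('theta', dist.Normal(0., 1.))

mcmc = MCMC(SA(toy_model), num_warmup=10000, num_samples=10000)
mcmc.run(random.PRNGKey(0))
print(mcmc.get_samples()['theta'].mean())  # ~0
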
def model(returns):
    step_size = numpyro.sample('sigma', dist.Exponential(50.))
    s = numpyro.sample('s', dist.GaussianRandomWalk(scale=step_size,
                                                    num_steps=jnp.shape(returns)[0]))
    nu = numpyro.sample('nu', dist.Exponential(.1))
    return numpyro.sample('r', dist.StudentT(df=nu, loc=0., scale=jnp.exp(s)),
                          obs=returns)
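
# Exercising the stochastic volatility model above on synthetic data (a
# sketch; the original NumPyro example fits real S&P 500 returns):
import numpy as np
from jax import random
from numpyro.infer import MCMC, NUTS

returns = 0.01 * np.random.randn(100)
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), returns)
print(mcmc.get_samples()['nu'].mean())
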
@primitive
def logsumexp(x):
    # Numerically stable log(sum(exp(x))): shift by the max before exponentiating.
    max_x = np.max(x)
    return max_x + np.log(np.sum(np.exp(x - max_x)))
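
# Why the max is subtracted first: the naive form overflows while the
# shifted form does not (plain NumPy check):
import numpy as np

x = np.array([1000.0, 1000.0])
# np.log(np.sum(np.exp(x))) -> inf, since np.exp(1000.) overflows
max_x = np.max(x)
print(max_x + np.log(np.sum(np.exp(x - max_x))))  # 1000.6931... = 1000 + log(2)
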
def __call__(self, x):
    # XXX: consider clamping from below for stability if necessary
    return jnp.exp(x)
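
# One way to realize the XXX note above (an illustrative assumption, not
# NumPyro's actual implementation): clamp the input from below so exp never
# underflows to exactly zero.
import jax.numpy as jnp

def clamped_exp(x, low=-80.0):
    # exp(-80) ~ 1.8e-35 is still representable in float32
    return jnp.exp(jnp.clip(x, low))
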
def softmax(x, axis=-1):
    r"""Softmax function.

    Computes the function which rescales elements to the range :math:`[0, 1]`
    such that the elements along :code:`axis` sum to :math:`1`.

    .. math ::
        \mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Args:
        axis: the axis or axes along which the softmax should be computed. The
            softmax output summed across these dimensions should sum to :math:`1`.
            Either an integer or a tuple of integers.
    """
    unnormalized = np.exp(x - x.max(axis, keepdims=True))
    return unnormalized / unnormalized.sum(axis, keepdims=True)
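
# The max-subtraction inside softmax is the same overflow guard as in
# logsumexp above; a quick check (assumes `softmax` as defined here):
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])
print(softmax(x))  # ~[0.0900, 0.2447, 0.6652], identical to softmax([0, 1, 2])
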
def _logpmf(self, x, n, p):
    x, n, p = _promote_dtypes(x, n, p)
    combiln = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
    if self.is_logits:
        # TODO: move this implementation to PyTorch if it does not run into
        # the non-continuity problem. In PyTorch, k * logit - n * log1p(e^logit)
        # overflows when logit is a large positive number. In that case we can
        # reformulate it as
        #   k * logit - n * log1p(e^logit) = k * logit - n * (log1p(e^-logit) + logit)
        #                                  = k * logit - n * logit - n * log1p(e^-logit)
        # More context: https://github.com/pytorch/pytorch/pull/15962/
        # Here `p` holds logits; jnp.clip(p, 0) is max(logit, 0).
        return combiln + x * p - (n * jnp.clip(p, 0) + xlog1py(n, jnp.exp(-jnp.abs(p))))
    else:
        return combiln + xlogy(x, p) + xlog1py(n - x, -p)
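
# Sanity check of the probs branch against SciPy (a sketch; the method above
# lives on a distribution class, so the formula is inlined here):
import numpy as np
from scipy.special import gammaln, xlog1py, xlogy
from scipy.stats import binom

x, n, p = 3, 10, 0.4
combiln = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
print(combiln + xlogy(x, p) + xlog1py(n - x, -p))
print(binom.logpmf(x, n, p))  # the two values agree
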
def log_prob(self, value):
    z = (value - self.loc) / self.scale
    return -(z + jnp.exp(-z)) - jnp.log(self.scale)
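
# The expression above is the Gumbel(loc, scale) log-density; a cross-check
# against SciPy (plain NumPy sketch, outside the class):
import numpy as np
from scipy.stats import gumbel_r

loc, scale, value = 1.0, 2.0, 0.5
z = (value - loc) / scale
print(-(z + np.exp(-z)) - np.log(scale))
print(gumbel_r.logpdf(value, loc=loc, scale=scale))  # same number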