Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
Matthew D. Hoffman, Andrew Gelman
:param potential_fn: A callable to compute potential energy.
:param kinetic_fn: A callable to compute kinetic energy.
:param momentum_generator: A generator to get a random momentum variable.
:param float init_step_size: Initial step size to be tuned.
:param inverse_mass_matrix: Inverse of mass matrix.
:param IntegratorState z_info: The current integrator state.
:param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
:return: a reasonable value for step size.
:rtype: float
"""
# We are going to find a step_size which makes accept_prob (Metropolis correction)
# close to the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
# then we have to decrease step_size; otherwise, we increase step_size.
target_accept_prob = jnp.log(0.8)
_, vv_update = velocity_verlet(potential_fn, kinetic_fn)
z, _, potential_energy, z_grad = z_info
if potential_energy is None or z_grad is None:
potential_energy, z_grad = value_and_grad(potential_fn)(z)
finfo = jnp.finfo(get_dtype(init_step_size))
def _body_fn(state):
step_size, _, direction, rng_key = state
rng_key, rng_key_momentum = random.split(rng_key)
# scale step_size: increase 2x or decrease 2x depending on direction;
# direction=1 means keep increasing step_size, otherwise keep decreasing step_size.
# Note that the direction is -1 if delta_energy is `NaN`, which may be the
# case for a diverging trajectory (e.g. in the case of evaluating log prob
# of a value simulated using a large step size for a constrained sample site).
step_size = (2.0 ** direction) * step_size
def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Log determinant of ``W @ W.T + D`` via the "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`.

    ``log|C|`` is taken as twice the sum of the log of the diagonal of
    ``capacitance_tril`` (presumably the lower Cholesky factor of ``C`` --
    confirm against the caller), and ``log|D|`` as the sum of ``log(D)`` over
    its last axis. ``W`` is accepted for interface reasons but is not read here.

    :param W: low-rank factor (unused in this computation).
    :param D: entries whose logs are summed over the last axis.
    :param capacitance_tril: matrix whose diagonal (last two axes) is used.
    :return: the log determinant, with the last axis of ``D`` reduced away.
    """
    tril_diag = jnp.diagonal(capacitance_tril, axis1=-2, axis2=-1)
    log_det_capacitance = 2 * jnp.sum(jnp.log(tril_diag), axis=-1)
    log_det_diag = jnp.log(D).sum(-1)
    return log_det_capacitance + log_det_diag
"""
# initial population
z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10), 1), sample_shape=(2,))
# measurement times
ts = jnp.arange(float(N))
# parameters alpha, beta, gamma, delta of dz_dt
theta = numpyro.sample(
"theta",
dist.TruncatedNormal(low=0., loc=jnp.array([0.5, 0.05, 1.5, 0.05]),
scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
# integrate dz/dt, the result will have shape N x 2
z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)
# measurement errors, we expect that measured hare has larger error than measured lynx
sigma = numpyro.sample("sigma", dist.Exponential(jnp.array([1, 2])))
# measured populations (in log scale)
numpyro.sample("y", dist.Normal(jnp.log(z), sigma), obs=y)
# Source code modified from scipy.stats._continuous_distns.py
#
# Copyright (c) 2001, 2002 Enthought, Inc.
# All rights reserved.
#
# Copyright (c) 2003-2019 SciPy Developers.
# All rights reserved.
import jax.numpy as np
import jax.random as random
from numpyro.distributions.distribution import jax_continuous
# sqrt(2 * pi): the divisor applied to exp(-x**2 / 2) in `_norm_pdf`.
_norm_pdf_C = np.sqrt(2 * np.pi)
# log(sqrt(2 * pi)): precomputed once and subtracted in `_norm_logpdf`.
_norm_pdf_logC = np.log(_norm_pdf_C)
def _norm_pdf(x):
    """Return ``exp(-x**2 / 2) / sqrt(2 * pi)``, the standard normal density at ``x``."""
    exponent = -x ** 2 / 2.0
    unnormalized = np.exp(exponent)
    return unnormalized / _norm_pdf_C
def _norm_logpdf(x):
    """Return ``-x**2 / 2 - log(sqrt(2 * pi))``, the standard normal log-density at ``x``."""
    quadratic_term = -x ** 2 / 2.0
    return quadratic_term - _norm_pdf_logC
class norm_gen(jax_continuous):
    """Standard normal distribution built on the ``jax_continuous`` base class."""

    def _rvs(self):
        """Draw normal samples of shape ``self._size`` using ``self._random_state`` as the key."""
        key = self._random_state
        shape = self._size
        return random.normal(key, shape)

    def _stats(self):
        """Return four fixed statistics.

        NOTE(review): presumably (mean, variance, skewness, excess kurtosis),
        following the scipy ``_stats`` convention this file was adapted from.
        """
        stats = (0.0, 1.0, 0.0, 0.0)
        return stats
def _logpdf(self, x):
    """Return ``0.5 * log(2 / pi) - x**2 / 2`` evaluated at ``x``.

    NOTE(review): this is the log of ``sqrt(2 / pi) * exp(-x**2 / 2)``, i.e. it
    looks like a standard half-normal log-density; the enclosing class is not
    visible here, so confirm which distribution this belongs to.
    """
    log_normalizer = 0.5 * jnp.log(2.0 / jnp.pi)
    half_square = x * x / 2.0
    return log_normalizer - half_square
def predict(model, at_bats, hits, z, rng_key, player_names, train=True):
    """Print predictions of ``model`` and, for test data, its log predictive density.

    Samples ``'obs'`` from ``Predictive(model, posterior_samples=z)`` and hands
    the result to ``print_results`` under a banner built from the model's name.
    When ``train`` is false it also prints the summed log pointwise predictive
    density computed from ``log_likelihood``.

    :param model: model callable; its ``__name__`` is used in the banner.
    :param at_bats: passed to the predictive call and to ``log_likelihood``.
    :param hits: passed to ``print_results`` and to ``log_likelihood``.
    :param z: posterior samples used for prediction and likelihood evaluation.
    :param rng_key: random key handed to the predictive call.
    :param player_names: labels forwarded to ``print_results``.
    :param bool train: selects the ``TRAIN``/``TEST`` banner; the density
        report is printed only when this is false.
    """
    suffix = ' - TRAIN' if train else ' - TEST'
    header = model.__name__ + suffix
    predictive = Predictive(model, posterior_samples=z)
    predictions = predictive(rng_key, at_bats)['obs']
    rule = '=' * 30
    print_results(rule + header + rule, predictions, player_names, at_bats, hits)
    if train:
        return

    post_loglik = log_likelihood(model, z, at_bats, hits)['obs']
    # log of the mean likelihood over axis 0 (presumably the posterior-sample
    # axis), giving the expected log predictive density at each data point
    num_draws = jnp.shape(post_loglik)[0]
    exp_log_density = logsumexp(post_loglik, axis=0) - jnp.log(num_draws)
    # report the log predictive density summed over all test points
    total = exp_log_density.sum()
    print('\nLog pointwise predictive density: {:.2f}\n'.format(total))
# log probability for edge case of 0 (before scaling):
log_cdf_plus = plus_in - softplus(plus_in)
# log probability for edge case of 255 (before scaling):
log_one_minus_cdf_min = - softplus(min_in)
cdf_delta = cdf_plus - cdf_min # probability for all other cases
mid_in = inv_stdv * centered_y
log_pdf_mid = mid_in - log_scales - 2. * softplus(mid_in)
log_probs = np.where(
y < -0.999, log_cdf_plus,
np.where(y > 0.999, log_one_minus_cdf_min,
np.where(cdf_delta > 1e-5,
np.log(np.maximum(cdf_delta, 1e-12)),
log_pdf_mid - np.log((num_class - 1) / 2))))
log_probs = log_probs + log_softmax(logit_probs)
return -np.sum(logsumexp(log_probs, axis=-1), axis=-1)
def logpmf(self, x, p):
    """Return the log probability of the category indices ``x`` under ``p``.

    ``x`` and the leading dimensions of ``p`` are broadcast against each other;
    the last axis of ``p`` indexes the categories. When ``self.is_logits`` is
    true, ``p`` holds unnormalized log probabilities and is normalized with
    ``logsumexp``; otherwise ``p`` holds probabilities and their log is taken.

    :param x: integer category indices.
    :param p: probabilities or logits, with categories on the last axis.
    :return: log probabilities with the broadcast batch shape.
    """
    batch_shape = lax.broadcast_shapes(x.shape, p.shape[:-1])
    # give x a trailing axis so it can index along the category axis
    # TODO: consider to convert x.dtype to int
    index = jnp.broadcast_to(jnp.expand_dims(x, axis=-1), batch_shape + (1,))
    table = jnp.broadcast_to(p, batch_shape + p.shape[-1:])
    if self.is_logits:
        # normalize log prob
        log_table = table - logsumexp(table, axis=-1, keepdims=True)
    else:
        log_table = jnp.log(table)
    # gather and remove the trailing dimension
    return jnp.take_along_axis(log_table, index, axis=-1)[..., 0]
def logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.

    Both operands are in log space. The row-wise maximum of ``x`` and the
    column-wise maximum of ``y`` are subtracted before exponentiating, so the
    largest term of each product is ``exp(0)``, and are added back afterwards.

    :param x: left operand in log space; reduced over its last axis.
    :param y: right operand in log space; reduced over its second-to-last axis.
    :return: ``log(exp(x) @ exp(y))``.
    """
    # The shifts are constants for differentiation purposes.
    x_shift = lax.stop_gradient(jnp.amax(x, -1, keepdims=True))
    y_shift = lax.stop_gradient(jnp.amax(y, -2, keepdims=True))
    # A row/column that is entirely ``-inf`` (all-zero probabilities) has a
    # ``-inf`` shift, and ``-inf - (-inf)`` would be NaN. Fall back to a zero
    # shift there so the result is ``-inf`` instead of NaN.
    x_shift = jnp.where(jnp.isfinite(x_shift), x_shift, 0.0)
    y_shift = jnp.where(jnp.isfinite(y_shift), y_shift, 0.0)
    xy = jnp.log(jnp.matmul(jnp.exp(x - x_shift), jnp.exp(y - y_shift)))
    return xy + x_shift + y_shift
def log_prob(self, value):
    """Return the log-density at ``value`` for location ``self.loc`` and factor ``self.scale_tril``.

    The result is ``-0.5 * M - normalize_term``, where ``M`` comes from
    ``_batch_mahalanobis(self.scale_tril, value - self.loc)`` and
    ``normalize_term`` adds the summed log-diagonal of ``scale_tril`` to
    ``0.5 * d * log(2 * pi)``, with ``d`` the size of its last axis.

    :param value: point(s) at which to evaluate the log-density.
    :return: the log-density, batched like the inputs.
    """
    scale_tril = self.scale_tril
    centered = value - self.loc
    M = _batch_mahalanobis(scale_tril, centered)
    # sum of the log-diagonal of scale_tril (half the log-determinant of the
    # covariance if scale_tril is its Cholesky factor -- presumably)
    tril_diag = jnp.diagonal(scale_tril, axis1=-2, axis2=-1)
    half_log_det = jnp.log(tril_diag).sum(-1)
    event_size = scale_tril.shape[-1]
    normalize_term = half_log_det + 0.5 * event_size * jnp.log(2 * jnp.pi)
    return - 0.5 * M - normalize_term