Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): the `def` line enclosing this test body is not visible in this
# view; `kernel_cls` is presumably a parametrized test argument selecting the
# MCMC kernel class (it is compared against `SA` below) — confirm.
# SA gets a far longer chain (100000 warmup + 100000 samples) than the other
# kernels (500 warmup + 20000 samples).
warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)
# Model: independent Beta(1.1, 1.1) priors on two Bernoulli success
# probabilities; `data` is scored against them through the 'obs' site.
def model(data):
alpha = jnp.array([1.1, 1.1])
beta = jnp.array([1.1, 1.1])
p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
return p_latent
# Ground-truth success probabilities used to simulate the observations.
true_probs = jnp.array([0.9, 0.1])
# sample_shape (1000, 2) is prepended to the batch shape implied by
# `true_probs`, so `data` presumably has shape (1000, 2, 2) — confirm.
data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
# SA is built from the model alone; every other kernel class also receives
# trajectory_length=0.1.
if kernel_cls is SA:
kernel = SA(model=model)
else:
kernel = kernel_cls(model=model, trajectory_length=0.1)
mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
mcmc.run(random.PRNGKey(2), data)
mcmc.print_summary()
samples = mcmc.get_samples()
# Posterior mean of each probability must lie within 0.05 of the truth.
assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)
# When JAX_ENABLE_X64 is set in the environment, samples must be float64.
if 'JAX_ENABLE_X64' in os.environ:
assert samples['p_latent'].dtype == jnp.float64
def test_reuse_mcmc_run(jit_args, shape):
    """Check that a single MCMC object can be ``run`` twice on different data.

    The first run conditions on 100 observations centred at +3. The second run
    reuses the same sampler on ``shape`` observations centred at -3, and the
    posterior mean of ``mu`` must follow the new data.

    Args:
        jit_args: forwarded to ``MCMC(jit_model_args=...)``.
        shape: number of observations in the second dataset.
    """
    # Seeded generator so the simulated data, and hence the test outcome, is
    # reproducible from run to run.
    rng = np.random.RandomState(0)
    y1 = rng.normal(3, 0.1, (100,))
    y2 = rng.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # First run: condition on the 100 observations in y1.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, 300, 500, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(32), y1)

    # Re-run the same MCMC object on new data. Per the numpyro MCMC docs, with
    # jit_model_args=True and a dataset of the same size (shape == 100), the
    # compiled sampler can be reused, so this run should be much faster.
    mcmc.run(random.PRNGKey(32), y2)
    assert_allclose(mcmc.get_samples()['mu'].mean(), -3., atol=0.1)
def test_uniform_normal():
    """Fit ``loc ~ Uniform(0, alpha)`` to noisy data and check the posterior mean.

    The run uses ``collect_warmup=True``, so ``get_samples()`` returns warmup and
    post-warmup draws together. The length check covers both. The accuracy check
    uses only the post-warmup draws, so the warmup transient cannot bias it.
    """
    # NOTE(review): another ``test_uniform_normal`` is defined later in this
    # file; at module level the later definition replaces this one, so pytest
    # only collects the later test. Consider renaming one of them.
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    data = true_coef + random.normal(random.PRNGKey(0), (1000,))
    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), data, collect_warmup=True)
    samples = mcmc.get_samples()
    assert len(samples['loc']) == num_warmup + num_samples
    # Draws are collected in iteration order, so the first `num_warmup`
    # entries are warmup draws; exclude them from the accuracy check.
    post_warmup_loc = samples['loc'][num_warmup:]
    assert_allclose(jnp.mean(post_warmup_loc, 0), true_coef, atol=0.05)
# NOTE(review): the `def` line enclosing this test body is not visible in this
# view; these statements are presumably the body of a correlated multivariate
# normal test — confirm.
# This requires dense mass matrix estimation.
D = 5
warmup_steps, num_samples = 5000, 8000
true_mean = 0.
# Factor `a`: 0.5 * (anti-diagonal identity) plus strictly positive random
# entries (0.1 * exp of normals), truncated to the lower triangle. Every entry
# of true_cov = a @ a.T is then positive, so the target covariance is dense.
a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
true_cov = jnp.dot(a, a.T)
true_prec = jnp.linalg.inv(true_cov)
# Potential energy (negative log-density up to a constant) of N(0, true_cov).
def potential_fn(z):
return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))
init_params = jnp.zeros(D)
# dense_mass=True asks NUTS to adapt a full (non-diagonal) mass matrix.
kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
mcmc = MCMC(kernel, warmup_steps, num_samples)
mcmc.run(random.PRNGKey(0), init_params=init_params)
samples = mcmc.get_samples()
# The overall sample mean must be within 0.02 of 0, and the mean absolute
# elementwise error of the empirical covariance must be below 0.02.
# np.cov(samples.T) treats rows of samples.T as variables, so `samples` is
# presumably shaped (num_samples, D) — confirm.
assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
def test_uniform_normal():
    """Warm up and sample NUTS in two separate calls on a reparameterized model.

    ``loc ~ Uniform(0, alpha)`` is reparameterized with ``TransformReparam``.
    ``MCMC.warmup(..., collect_warmup=True)`` must yield exactly ``num_warmup``
    draws. A following ``MCMC.run`` must yield exactly ``num_samples`` draws
    whose mean recovers the generating value.
    """
    expected_loc = 0.9
    n_warmup, n_draws = 1000, 1000

    def reparam_model(obs_values):
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        reparam_cfg = {'loc': TransformReparam()}
        with numpyro.handlers.reparam(config=reparam_cfg):
            location = numpyro.sample('loc', dist.Uniform(0, upper))
        numpyro.sample('obs', dist.Normal(location, 0.1), obs=obs_values)

    observed = expected_loc + random.normal(random.PRNGKey(0), (1000,))
    sampler = MCMC(NUTS(model=reparam_model), num_warmup=n_warmup, num_samples=n_draws)

    # Phase 1: warmup only, keeping the warmup draws.
    sampler.warmup(random.PRNGKey(2), observed, collect_warmup=True)
    warmup_draws = sampler.get_samples()

    # Phase 2: sampling, continuing from the warmed-up state.
    sampler.run(random.PRNGKey(3), observed)
    posterior_draws = sampler.get_samples()

    assert len(warmup_draws['loc']) == n_warmup
    assert len(posterior_draws['loc']) == n_draws
    assert_allclose(jnp.mean(posterior_draws['loc'], 0), expected_loc, atol=0.05)
def run_inference(model, at_bats, hits, rng_key, args):
    """Run MCMC on ``model`` with the kernel selected by ``args.algo``.

    Args:
        model: numpyro model; ``at_bats`` and ``hits`` are forwarded to it by
            ``mcmc.run``.
        at_bats, hits: observed data passed to ``mcmc.run``.
        rng_key: JAX PRNG key for the sampler.
        args: namespace providing ``algo`` ("NUTS", "HMC" or "SA"),
            ``num_warmup``, ``num_samples``, ``num_chains`` and
            ``disable_progbar``.

    Returns:
        The dict returned by ``mcmc.get_samples()``.

    Raises:
        ValueError: if ``args.algo`` is not one of the supported names.
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        # Without this branch `kernel` would be unbound and the failure would
        # surface later as an opaque NameError/UnboundLocalError.
        raise ValueError(
            "Unsupported algo {!r}; expected 'NUTS', 'HMC' or 'SA'.".format(args.algo))
    # Hide the progress bar during Sphinx doc builds or when the user asks.
    show_progress = not ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples,
                num_chains=args.num_chains, progress_bar=show_progress)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
# NOTE(review): the enclosing function header and the plotting code after
# `x1` are outside this view; `args`, `svi`, `svi_state`, `guide`,
# `dual_moon_model` and the helpers used below are presumably defined or
# imported elsewhere — confirm.
print("Start training guide...")
# Run args.num_iters SVI updates inside lax.scan; `losses` presumably stacks the
# per-step loss that svi.update returns alongside the new state — confirm.
last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
params = svi.get_params(last_state)
print("Finish training guide. Extract samples...")
# Draw args.num_samples posterior samples of site 'x' from the trained guide.
guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
sample_shape=(args.num_samples,))['x'].copy()
# NeuTra setup: take the guide's learned transform and run NUTS on the model's
# potential composed with it (via the transformed_potential_energy helper).
# Draws are later mapped back through `transform` and then `constrain_fn`.
transform = guide.get_transform(params)
_, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), dual_moon_model)
transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
transformed_constrain_fn = lambda x: constrain_fn(transform(x)) # noqa: E731
print("\nStart NeuTra HMC...")
nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
# Start the chain at the origin of the warped latent space.
init_params = np.zeros(guide.latent_size)
mcmc.run(random.PRNGKey(3), init_params=init_params)
mcmc.print_summary()
zs = mcmc.get_samples()
print("Transform samples into unwarped space...")
samples = vmap(transformed_constrain_fn)(zs)
# Add a leading axis to every leaf (presumably the chain dimension that
# print_summary expects) before summarizing.
print_summary(tree_map(lambda x: x[None, ...], samples))
samples = samples['x'].copy()
# make plots
# guide samples (for plotting)
# 1000 draws from a 2-D standard normal pushed through transform + constrain_fn,
# presumably to visualize what the guide learned.
guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(4), (1000,))
guide_trans_samples = vmap(transformed_constrain_fn)(guide_base_samples)['x']
x1 = np.linspace(-3, 3, 100)
def run_hmc(model, args, rng_key, X, Y, hypers):
    """Run single-chain NUTS on ``model`` and time it.

    Args:
        model: numpyro model; ``X``, ``Y`` and ``hypers`` are forwarded to it
            by ``mcmc.run``.
        args: mapping providing ``'mtd'`` (NUTS ``max_tree_depth``),
            ``'num_warmup'`` and ``'num_samples'``.
        rng_key: JAX PRNG key for the sampler.
        X, Y, hypers: arguments forwarded to the model.

    Returns:
        ``(samples, elapsed_time)``: the dict from ``mcmc.get_samples()`` and
        the seconds spent building and running the sampler.
    """
    # perf_counter is monotonic, so the returned duration cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    kernel = NUTS(model, max_tree_depth=args['mtd'])
    # Hide the progress bar during Sphinx doc builds.
    mcmc = MCMC(kernel, args['num_warmup'], args['num_samples'], num_chains=1,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y, hypers)
    elapsed_time = time.perf_counter() - start
    samples = mcmc.get_samples()
    return samples, elapsed_time
def run_inference(model, args, rng_key, X, Y, hypers):
    """Run NUTS on ``model``, then print a posterior summary and the elapsed time.

    Args:
        model: numpyro model; ``X``, ``Y`` and ``hypers`` are forwarded to it
            by ``mcmc.run``.
        args: namespace providing ``num_warmup``, ``num_samples`` and
            ``num_chains``.
        rng_key: JAX PRNG key for the sampler.
        X, Y, hypers: arguments forwarded to the model.

    Returns:
        The dict returned by ``mcmc.get_samples()``.
    """
    # Monotonic clock: the reported duration is immune to wall-clock changes.
    start = time.perf_counter()
    kernel = NUTS(model)
    # Hide the progress bar during Sphinx doc builds.
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y, hypers)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.perf_counter() - start)
    return mcmc.get_samples()
def mcmc(self, num_samples=10000, warmup_steps=1000, num_chains=1, thin=1, kernel=None):
    """Build a numpyro MCMC runner for this model and wrap it in ``MCMCProxy``.

    ``num_samples`` counts the warmup iterations too. The sampler gets
    ``warmup_steps`` warmup steps and ``num_samples - warmup_steps``
    post-warmup draws (presumably mirroring Stan's ``iter``, which includes
    warmup).

    Args:
        num_samples: total iterations, warmup included.
        warmup_steps: number of warmup iterations.
        num_chains: forwarded to ``numpyro.infer.MCMC``.
        thin: forwarded to ``MCMCProxy``.
        kernel: MCMC kernel to use. Defaults to NUTS on ``self._model`` with
            step-size adaptation.

    Returns:
        An ``MCMCProxy`` around the configured (not yet run) MCMC object.

    Raises:
        ValueError: if ``num_samples < warmup_steps``, which would otherwise
            request a negative number of post-warmup draws.
    """
    if num_samples < warmup_steps:
        raise ValueError(
            "num_samples ({}) includes warmup and must be >= warmup_steps ({})".format(
                num_samples, warmup_steps))
    if kernel is None:
        kernel = numpyro.infer.NUTS(self._model, adapt_step_size=True)
    mcmc = numpyro.infer.MCMC(
        kernel, warmup_steps, num_samples - warmup_steps, num_chains=num_chains)
    return MCMCProxy(mcmc, True, self._generated_quantities, self._transformed_data, thin)