Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_mcmc_parallel_chain(deterministic):
    """Run two NUTS chains and check the value of GLOBAL["count"] afterwards.

    NOTE(review): presumably `model` increments GLOBAL["count"] whenever it is
    traced, so this pins the number of traces for each `deterministic` mode —
    TODO confirm against the definition of `model`.
    """
    GLOBAL["count"] = 0
    sampler = MCMC(NUTS(model), 100, 100, num_chains=2)
    sampler.run(random.PRNGKey(0), deterministic=deterministic)
    sampler.get_samples()
    expected = 4 if deterministic else 3
    assert GLOBAL["count"] == expected
def test_chain_smoke(chain_method, compile_args):
    """Smoke-test warmup followed by run for a 2-chain NUTS sampler.

    The model is a Dirichlet prior over 3 categories with categorical
    observations; the data are 2000 draws from fixed category probabilities.
    """
    def dirichlet_categorical(data):
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(jnp.array([1.0, 1.0, 1.0])))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    sampler = MCMC(NUTS(dirichlet_categorical), 2, 5, num_chains=2,
                   chain_method=chain_method, jit_model_args=compile_args)
    sampler.warmup(random.PRNGKey(0), data)
    sampler.run(random.PRNGKey(1), data)
def test_empty_model(num_chains, chain_method, progress_bar):
    """A model with no sample sites must yield an empty samples dict."""
    def empty_model():
        return None

    sampler = MCMC(NUTS(empty_model), num_warmup=10, num_samples=10,
                   num_chains=num_chains, chain_method=chain_method,
                   progress_bar=progress_bar)
    sampler.run(random.PRNGKey(0))
    collected = sampler.get_samples()
    assert collected == {}
# NOTE(review): fragment — the enclosing `def model(...)` and nested
# `def transition(...)` headers, and the line that defines `x1`, are not
# visible here; `beta`, `q`, `r` and `T` presumably come from that missing
# signature — TODO confirm.
mu1 = beta * mu0 + x1
y1 = numpyro.sample('y', dist.Normal(mu1, r))
# Record a derived site so it appears in samples/predictions as 'y2'.
numpyro.deterministic('y2', y1 * 2)
# Returned as (new carry, per-step output), matching how `scan` consumes it below.
return (x1, mu1), (x1, y1)
mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
# Run `transition` over T steps starting from carry (x0, mu0); keep only the
# stacked per-step outputs.
_, xy = scan(transition, (x0, mu0), jnp.arange(T))
x, y = xy
# Put the initial values in front of the scanned trajectories.
return jnp.append(x0, x), jnp.append(y0, y)
# NOTE(review): fragment — the enclosing test function header is not visible;
# `model` is presumably the scan-based model whose body appears just above —
# TODO confirm.
T = 10
num_samples = 100
kernel = NUTS(model)
# 100 warmup steps, then `num_samples` draws.
mcmc = MCMC(kernel, 100, num_samples)
mcmc.run(jax.random.PRNGKey(0), T=T)
# Sampled sites plus the deterministic 'y2' must all be recorded.
assert set(mcmc.get_samples()) == {'x', 'y', 'y2', 'x_0', 'y_0'}
mcmc.print_summary()
samples = mcmc.get_samples()
# Remove 'x' from the posterior samples and keep its first draw; it is
# re-imposed below through `condition`.
x = samples.pop('x')[0] # take 1 sample of x
# this tests for the composition of condition and substitute
# this also tests if we can use `vmap` for predictive.
future = 5
predictive = Predictive(numpyro.handlers.condition(model, {'x': x}),
samples, return_sites=['x', 'y', 'y2'], parallel=True)
result = predictive(jax.random.PRNGKey(1), T=T + future)
expected_shape = (num_samples, T + future)
# NOTE(review): 'y2' is requested in return_sites but its shape is not checked.
assert result['x'].shape == expected_shape
assert result['y'].shape == expected_shape
# NOTE(review): fragment of a model body whose header is not visible;
# `transition` and `x` are defined in that missing scope. Presumably this
# picks the transition-probability row for the current state `x` — TODO confirm.
trans_prob = transition[x]
def _generate_data():
    """Simulate an observation sequence from a random Gaussian-emission HMM.

    Reads ``dim`` (number of hidden states) and ``num_steps`` (sequence length)
    from the enclosing scope. NOTE(review): neither name is defined in the
    visible lines — confirm they exist where this helper is defined.

    Returns:
        A 1-D numpy array of ``max(num_steps, 1)`` observations.
    """
    transition_probs = np.random.rand(dim, dim)
    # Normalise each row so it is a valid categorical distribution.
    transition_probs = transition_probs / transition_probs.sum(-1, keepdims=True)
    # State k emits Normal(k, 1).
    emissions_loc = np.arange(dim)
    emissions_scale = 1.
    # Initial state is uniform over all `dim` states. A hard-coded 3 here
    # indexed out of range for dim < 3 and never started in states >= 3 for
    # dim > 3; for dim == 3 the draw is identical.
    state = np.random.choice(dim)
    obs = [np.random.normal(emissions_loc[state], emissions_scale)]
    for _ in range(num_steps - 1):
        state = np.random.choice(dim, p=transition_probs[state])
        obs.append(np.random.normal(emissions_loc[state], emissions_scale))
    return np.stack(obs)
# NOTE(review): tail of a test whose header is not visible; `model` is defined
# in that missing scope.
data = _generate_data()
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
# Observations are passed positionally to the model.
mcmc.run(random.PRNGKey(0), data)
def benchmark_hmc(args, features, labels):
    """Time a NUTS run of ``model`` with no warmup and print a summary.

    Args:
        args: namespace read for ``num_steps`` and ``num_samples``.
        features: data whose ``shape[0]`` (number of rows) sets the
            trajectory length; passed to the model with ``labels``.
        labels: second positional data argument forwarded to the model.
    """
    num_rows = features.shape[0]
    # Trajectory length = sqrt(0.5 / N) * num_steps.
    trajectory_length = np.sqrt(0.5 / num_rows) * args.num_steps
    key = random.PRNGKey(1)
    t_start = time.time()
    sampler = MCMC(NUTS(model, trajectory_length=trajectory_length), 0, args.num_samples)
    sampler.run(key, features, labels)
    sampler.print_summary()
    print('\nMCMC elapsed time:', time.time() - t_start)
def numpyro_inference(data, args):
    """Time compilation, warmup and sampling of NUTS on ``numpyro_model``.

    Args:
        data: sequence of positional arguments unpacked into every MCMC call.
        args: namespace read for ``seed``, ``num_warmup``, ``num_samples``,
            ``num_chains`` and ``disable_progbar``.
    """
    rng_key = jax.random.PRNGKey(args.seed)
    kernel = numpyro.infer.NUTS(numpyro_model)
    mcmc = numpyro.infer.MCMC(kernel, args.num_warmup, args.num_samples,
                              num_chains=args.num_chains, progress_bar=not args.disable_progbar)
    tic = time.time()
    # `_compile` is a private MCMC helper, used here to separate compile time
    # from run time.
    mcmc._compile(rng_key, *data, extra_fields=('num_steps',))
    print('MCMC (numpyro) compiling time:', time.time() - tic, '\n')
    tic = time.time()
    mcmc.warmup(rng_key, *data, extra_fields=('num_steps',))
    # Re-sets the same value already given to the constructor.
    mcmc.num_samples = args.num_samples
    # JAX dispatches work asynchronously. Block until warmup has really
    # finished so `tic_run` marks the start of sampling. A bare `.copy()` only
    # synchronised in old JAX, where it transferred to host memory.
    rng_key = mcmc._warmup_state.rng_key.block_until_ready()
    tic_run = time.time()
    mcmc.run(rng_key, *data, extra_fields=('num_steps',))
    # Wait for sampling to complete before stopping the clock.
    mcmc._last_state.rng_key.block_until_ready()
    toc = time.time()
    mcmc.print_summary()
    # Total of warmup + sampling.
    print('\nMCMC (numpyro) elapsed time:', toc - tic)
    sampling_time = toc - tic_run
def run_inference(model, at_bats, hits, rng_key, args):
    """Run MCMC on ``model`` with the kernel selected by ``args.algo``.

    Args:
        model: model callable wrapped by the chosen kernel.
        at_bats, hits: positional data forwarded through ``mcmc.run``.
        rng_key: PRNG key passed to ``mcmc.run``.
        args: namespace read for ``algo``, ``num_warmup``, ``num_samples``,
            ``num_chains`` and ``disable_progbar``.

    Returns:
        The result of ``mcmc.get_samples()``.

    Raises:
        ValueError: if ``args.algo`` is not "NUTS", "HMC" or "SA". Previously an
            unknown value left ``kernel`` unbound and raised UnboundLocalError.
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        raise ValueError(
            "Unknown inference algorithm {!r}; expected 'NUTS', 'HMC' or 'SA'.".format(args.algo))
    # Progress bars are off when building docs or when explicitly disabled.
    show_progress = not ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar=show_progress)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
def run_hmc(model, args, rng_key, X, Y, hypers):
    """Sample from ``model`` with a single NUTS chain and time the run.

    Args:
        model: model callable wrapped by NUTS.
        args: mapping read for 'mtd' (max tree depth), 'num_warmup' and
            'num_samples'.
        rng_key: PRNG key passed to ``mcmc.run``.
        X, Y, hypers: positional data forwarded to the model via ``mcmc.run``.

    Returns:
        (samples, elapsed_time): the ``get_samples()`` result and the wall-clock
        seconds spent building and running the sampler.
    """
    t0 = time.time()
    nuts = NUTS(model, max_tree_depth=args['mtd'])
    # Disable the progress bar while the docs are being built.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(nuts, args['num_warmup'], args['num_samples'], num_chains=1,
                   progress_bar=show_progress)
    sampler.run(rng_key, X, Y, hypers)
    elapsed = time.time() - t0
    return sampler.get_samples(), elapsed