How to use the numpyro.infer.MCMC function in numpyro

To help you get started, we’ve selected a few examples of numpyro.infer.MCMC, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: pyro-ppl/numpyro — test/test_mcmc.py (view on GitHub)
# SA gets a far longer run (100k warmup + 100k samples) than the other kernels (500 + 20k).
warmup_steps, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

    def model(data):
        # Beta(1.1, 1.1) priors on two Bernoulli success probabilities.
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(alpha, beta))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    # Synthetic Bernoulli data with true success probabilities 0.9 and 0.1
    # (sample_shape (1000, 2) drawn on top of the distribution's batch of 2 probabilities).
    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))
    # SA is built from the model alone; every other kernel_cls also receives
    # trajectory_length=0.1 (presumably HMC-family kernels — confirm against the parametrization).
    if kernel_cls is SA:
        kernel = SA(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    # Posterior mean of each probability must be within 0.05 of the value used to simulate data.
    assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.05)

    # When x64 mode is requested via the environment, samples should be float64.
    # NOTE(review): this checks only that the variable is present, not its value.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['p_latent'].dtype == jnp.float64
Source: pyro-ppl/numpyro — test/test_mcmc.py (view on GitHub)
def test_reuse_mcmc_run(jit_args, shape):
    """Check that a single MCMC object can be run a second time on new data.

    The sampler is first run on 100 observations centred at 3, then re-run on
    ``shape`` observations centred at -3; the posterior mean of ``mu`` must
    follow the second dataset.

    Args:
        jit_args: forwarded to ``MCMC(..., jit_model_args=...)``.
        shape: number of observations in the second dataset.
    """
    # Draw the synthetic data from a seeded generator rather than the global
    # np.random state, so the data (and hence the test outcome) is reproducible.
    rng = np.random.RandomState(0)
    y1 = rng.normal(3, 0.1, (100,))
    y2 = rng.normal(-3, 0.1, (shape,))

    def model(y_obs):
        mu = numpyro.sample('mu', dist.Normal(0., 1.))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(3.))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y_obs)

    # First run on the 100 observations in y1.
    kernel = NUTS(model)
    mcmc = MCMC(kernel, 300, 500, jit_model_args=jit_args)
    mcmc.run(random.PRNGKey(32), y1)

    # Re-run the same MCMC object on new data; presumably faster because the
    # compiled sampler can be reused (especially with jit_model_args=True).
    mcmc.run(random.PRNGKey(32), y2)
    assert_allclose(mcmc.get_samples()['mu'].mean(), -3., atol=0.1)
Source: pyro-ppl/numpyro — test/test_mcmc_interface.py (view on GitHub)
def test_uniform_normal():
    """NUTS recovers the location of a Normal likelihood under a nested Uniform prior.

    Also checks that passing ``collect_warmup=True`` to ``run`` makes
    ``get_samples`` return warmup and post-warmup draws together.
    """
    expected_loc = 0.9
    n_warmup = 1000
    n_draws = 1000

    def model(data):
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, upper))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    observations = expected_loc + noise
    sampler = MCMC(NUTS(model=model), num_warmup=n_warmup, num_samples=n_draws)
    sampler.run(random.PRNGKey(2), observations, collect_warmup=True)
    loc_draws = sampler.get_samples()['loc']

    # Warmup draws are kept, so both phases contribute to the length.
    assert len(loc_draws) == n_warmup + n_draws
    assert_allclose(np.mean(loc_draws, 0), expected_loc, atol=0.05)
Source: pyro-ppl/numpyro — test/test_mcmc.py (view on GitHub)
# This requires dense mass matrix estimation.
# Samples a correlated 5-D zero-mean Gaussian from a hand-written potential function.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    # Lower-triangular factor; true_cov = a @ a.T is symmetric positive semi-definite.
    a = jnp.tril(0.5 * jnp.fliplr(jnp.eye(D)) + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        # Negative log density of N(0, true_cov) up to an additive constant: 0.5 * z^T P z.
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    # The kernel is driven by potential_fn (no model), so the start point is given explicitly.
    init_params = jnp.zeros(D)
    kernel = NUTS(potential_fn=potential_fn, dense_mass=True)
    mcmc = MCMC(kernel, warmup_steps, num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    # Mean over every sample entry (all dimensions pooled) should be close to 0.
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    # Mean absolute elementwise error of the empirical covariance must be below 0.02;
    # samples assumed shaped (num_samples, D), hence the transpose for np.cov.
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02
Source: pyro-ppl/numpyro — test/test_mcmc.py (view on GitHub)
def test_uniform_normal():
    """Warm up, then sample, NUTS on a model whose 'loc' site uses TransformReparam.

    Checks that ``warmup(..., collect_warmup=True)`` yields ``num_warmup`` draws,
    that the following ``run`` yields ``num_samples`` draws, and that the
    posterior mean of ``loc`` is close to the value used to simulate the data.
    """
    target = 0.9
    n_warmup = 1000
    n_draws = 1000

    def model(data):
        upper = numpyro.sample('alpha', dist.Uniform(0, 1))
        reparam_config = {'loc': TransformReparam()}
        with numpyro.handlers.reparam(config=reparam_config):
            loc = numpyro.sample('loc', dist.Uniform(0, upper))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    observations = target + noise
    sampler = MCMC(NUTS(model=model), num_warmup=n_warmup, num_samples=n_draws)

    # Warmup phase only, keeping the warmup trace.
    sampler.warmup(random.PRNGKey(2), observations, collect_warmup=True)
    warmup_draws = sampler.get_samples()

    # Sampling phase, with a different key.
    sampler.run(random.PRNGKey(3), observations)
    posterior_draws = sampler.get_samples()

    assert len(warmup_draws['loc']) == n_warmup
    assert len(posterior_draws['loc']) == n_draws
    assert_allclose(jnp.mean(posterior_draws['loc'], 0), target, atol=0.05)
Source: pyro-ppl/numpyro — examples/baseball.py (view on GitHub)
def run_inference(model, at_bats, hits, rng_key, args):
    """Run MCMC on ``model`` with the kernel named by ``args.algo``.

    Args:
        model: NumPyro model passed to the kernel constructor.
        at_bats, hits: data forwarded to ``mcmc.run`` after ``rng_key``.
        rng_key: JAX PRNG key for the sampler.
        args: namespace providing ``algo`` ("NUTS", "HMC" or "SA"),
            ``num_warmup``, ``num_samples``, ``num_chains`` and ``disable_progbar``.

    Returns:
        The result of ``mcmc.get_samples()``.

    Raises:
        ValueError: if ``args.algo`` names an unsupported algorithm.
    """
    if args.algo == "NUTS":
        kernel = NUTS(model)
    elif args.algo == "HMC":
        kernel = HMC(model)
    elif args.algo == "SA":
        kernel = SA(model)
    else:
        # Fail fast with a clear message; otherwise ``kernel`` would be unbound
        # below and surface as an UnboundLocalError.
        raise ValueError(
            "Unknown MCMC algorithm {!r}; expected 'NUTS', 'HMC' or 'SA'.".format(args.algo))
    # Hide the progress bar while building the Sphinx docs or when explicitly disabled.
    show_progress = not ("NUMPYRO_SPHINXBUILD" in os.environ or args.disable_progbar)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar=show_progress)
    mcmc.run(rng_key, at_bats, hits)
    return mcmc.get_samples()
Source: pyro-ppl/numpyro — examples/neutra.py (view on GitHub)
# Fit the guide with SVI, sample from it, then run NUTS on the guide-transformed potential (NeuTra).
print("Start training guide...")
    # Apply svi.update args.num_iters times under lax.scan; the zeros array only fixes the
    # iteration count (the index `i` is ignored). `losses` stacks the per-step outputs
    # (presumably the loss returned alongside the new state).
    last_state, losses = lax.scan(lambda state, i: svi.update(state), svi_state, np.zeros(args.num_iters))
    params = svi.get_params(last_state)
    print("Finish training guide. Extract samples...")
    # Draw args.num_samples samples of site 'x' from the trained guide.
    guide_samples = guide.sample_posterior(random.PRNGKey(0), params,
                                           sample_shape=(args.num_samples,))['x'].copy()

    # Compose the guide's learned transform with the model's potential and constrain
    # functions, so NUTS runs on the transformed coordinates and results can be mapped back.
    transform = guide.get_transform(params)
    _, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), dual_moon_model)
    transformed_potential_fn = partial(transformed_potential_energy, potential_fn, transform)
    transformed_constrain_fn = lambda x: constrain_fn(transform(x))  # noqa: E731

    print("\nStart NeuTra HMC...")
    nuts_kernel = NUTS(potential_fn=transformed_potential_fn)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples)
    # The kernel has only a potential function, so the chain starts from an explicit all-zeros point.
    init_params = np.zeros(guide.latent_size)
    mcmc.run(random.PRNGKey(3), init_params=init_params)
    mcmc.print_summary()
    zs = mcmc.get_samples()
    print("Transform samples into unwarped space...")
    # Map every NUTS draw through transform and then constrain_fn.
    samples = vmap(transformed_constrain_fn)(zs)
    # Add a leading axis to every leaf (presumably the chain dimension print_summary expects).
    print_summary(tree_map(lambda x: x[None, ...], samples))
    samples = samples['x'].copy()

    # make plots

    # guide samples (for plotting)
    # Push 1000 2-D standard-normal base draws through the same transform + constrain map.
    guide_base_samples = dist.Normal(np.zeros(2), 1.).sample(random.PRNGKey(4), (1000,))
    guide_trans_samples = vmap(transformed_constrain_fn)(guide_base_samples)['x']

    x1 = np.linspace(-3, 3, 100)
Source: pyro-ppl/numpyro — examples/cb.py (view on GitHub)
def run_hmc(model, args, rng_key, X, Y, hypers):
    """Run single-chain NUTS on ``model`` and time it.

    Args:
        model: NumPyro model passed to ``NUTS``.
        args: mapping with keys ``'mtd'`` (NUTS ``max_tree_depth``),
            ``'num_warmup'`` and ``'num_samples'``.
        rng_key: JAX PRNG key for the sampler.
        X, Y, hypers: forwarded to ``mcmc.run`` after ``rng_key``.

    Returns:
        ``(samples, elapsed_time)``: the value of ``mcmc.get_samples()`` and the
        seconds spent building and running the sampler.
    """
    # perf_counter is monotonic and high-resolution, so the measured interval is
    # not skewed by system clock adjustments (unlike time.time()).
    start = time.perf_counter()
    kernel = NUTS(model, max_tree_depth=args['mtd'])
    mcmc = MCMC(kernel, args['num_warmup'], args['num_samples'], num_chains=1,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    mcmc.run(rng_key, X, Y, hypers)
    elapsed_time = time.perf_counter() - start

    samples = mcmc.get_samples()

    return samples, elapsed_time
Source: pyro-ppl/numpyro — examples/sparse_regression.py (view on GitHub)
def run_inference(model, args, rng_key, X, Y, hypers):
    """Sample ``model`` with NUTS, print a summary and the elapsed time, return the samples.

    ``X``, ``Y`` and ``hypers`` are forwarded to ``mcmc.run``; warmup, sample and
    chain counts come from ``args``.
    """
    t0 = time.time()
    building_docs = "NUMPYRO_SPHINXBUILD" in os.environ
    sampler = MCMC(NUTS(model), args.num_warmup, args.num_samples,
                   num_chains=args.num_chains, progress_bar=not building_docs)
    sampler.run(rng_key, X, Y, hypers)
    sampler.print_summary()
    print('\nMCMC elapsed time:', time.time() - t0)
    return sampler.get_samples()
Source: deepppl/deepppl — deepppl/dppl.py (view on GitHub)
def mcmc(self, num_samples=10000, warmup_steps=1000, num_chains=1, thin=1, kernel=None):
    """Build a NumPyro MCMC sampler for this model, wrapped in an ``MCMCProxy``.

    ``num_samples`` counts iterations *including* warmup: only
    ``num_samples - warmup_steps`` post-warmup draws are requested from NumPyro.
    When ``kernel`` is None, NUTS over ``self._model`` with step-size
    adaptation is used.
    """
    chosen_kernel = kernel if kernel is not None else numpyro.infer.NUTS(self._model, adapt_step_size=True)
    kept_draws = num_samples - warmup_steps
    sampler = numpyro.infer.MCMC(chosen_kernel, warmup_steps, kept_draws, num_chains=num_chains)
    # NOTE(review): the positional True flag's meaning is defined by MCMCProxy — confirm there.
    return MCMCProxy(sampler, True, self._generated_quantities, self._transformed_data, thin)