How to use the numpyro.sample function in numpyro

To help you get started, we’ve selected a few examples of `numpyro.sample`, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: pyro-ppl/numpyro — test/test_compile.py (View on GitHub)
def model(deterministic=True):
    """Toy model that counts how many times it has been run.

    Every call increments ``GLOBAL["count"]`` and records a sample site ``"x"``
    drawn from ``dist.Normal()``. When *deterministic* is truthy, the sampled
    value is additionally recorded as the deterministic site ``"x_copy"``.
    Returns ``None``.
    """
    GLOBAL["count"] = GLOBAL["count"] + 1
    latent = numpyro.sample("x", dist.Normal())
    if not deterministic:
        return
    numpyro.deterministic("x_copy", latent)
GitHub: pyro-ppl/numpyro — test/test_handlers.py (View on GitHub)
def model_nested_plates_2():
    """Check site shapes when two plates are entered separately and then together.

    The ``'outer'`` plate has size 10 and no explicit dim. The ``'inner'`` plate
    has size 5 and is pinned to ``dim=-3``. The asserts pin the expected shape
    of every site.
    """
    outer_plate = numpyro.plate('outer', 10)
    inner_plate = numpyro.plate('inner', 5, dim=-3)

    # Only the outer plate is active.
    with outer_plate:
        x_site = numpyro.sample('x', dist.Normal(0., 1.))
        assert x_site.shape == (10,)

    # Only the inner plate (dim=-3) is active.
    with inner_plate:
        y_site = numpyro.sample('y', dist.Normal(0., 1.))
        assert y_site.shape == (5, 1, 1)
        # The assert expects the deterministic site to keep x's shape (10,)
        # even though it is recorded under the inner plate.
        z_site = numpyro.deterministic('z', x_site ** 2)
        assert z_site.shape == (10,)

    # Both plates active at once, plus an explicit sample_shape of (10,).
    with outer_plate, inner_plate:
        xy_site = numpyro.sample('xy', dist.Normal(0., 1.), sample_shape=(10,))
        assert xy_site.shape == (5, 1, 10)
GitHub: pyro-ppl/numpyro — test/test_mcmc_interface.py (View on GitHub)
def model(data):
    """Switch-point model: Poisson observations whose rate changes once.

    Two rates (``'lambda1'``, ``'lambda2'``) share an Exponential prior whose
    rate is ``1 / mean(data)``. A switch fraction ``'tau'`` is drawn from
    Uniform(0, 1). Observations at positions below ``tau * len(data)`` use
    ``lambda1``; the rest use ``lambda2``. *data* is scored at site ``'obs'``.
    """
    prior_rate = 1 / np.mean(data)
    early_rate = numpyro.sample('lambda1', dist.Exponential(prior_rate))
    late_rate = numpyro.sample('lambda2', dist.Exponential(prior_rate))
    tau = numpyro.sample('tau', dist.Uniform(0, 1))

    num_points = len(data)
    before_switch = np.arange(num_points) < tau * num_points
    per_point_rate = np.where(before_switch, early_rate, late_rate)
    numpyro.sample('obs', dist.Poisson(per_point_rate), obs=data)
GitHub: pyro-ppl/numpyro — test/test_autoguide.py (View on GitHub)
def model():
    """Sample ``'x'`` from a Normal whose loc and scale are constrained params.

    Param ``'a'`` (the loc) is constrained to be greater than ``a_minval``.
    Param ``'b'`` (the scale) is constrained to be positive. ``a_init``,
    ``b_init`` and ``a_minval`` are presumably supplied by an enclosing scope.
    """
    loc_param = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
    scale_param = numpyro.param('b', b_init, constraint=constraints.positive)
    numpyro.sample('x', dist.Normal(loc=loc_param, scale=scale_param))
GitHub: pyro-ppl/numpyro — test/contrib/test_funsor.py (View on GitHub)
def model(data):
    """Discrete-state HMM with Gaussian emissions, unrolled with ``markov``.

    Under the ``"states"`` plate of size ``dim`` (presumably from an enclosing
    scope), each state gets a Dirichlet row of the transition matrix, an
    emission loc and an emission scale. For each step ``t`` in *data*, a
    categorical latent ``x_t`` is drawn. The observation is then scored at
    ``y_t`` against the Normal for that state.
    """
    with numpyro.plate("states", dim):
        trans_matrix = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
        loc_per_state = numpyro.sample("emission_loc", dist.Normal(0, 1))
        scale_per_state = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

    # Probabilities used for the first hidden state.
    state_probs = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    for step, obs_value in markov(enumerate(data)):
        state = numpyro.sample(f"x_{step}", dist.Categorical(state_probs))
        numpyro.sample(
            f"y_{step}",
            dist.Normal(loc_per_state[state], scale_per_state[state]),
            obs=obs_value,
        )
        # The next step draws its state from the current state's transition row.
        state_probs = trans_matrix[state]
GitHub: pyro-ppl/numpyro — test/contrib/test_funsor.py (View on GitHub)
def model(data):
    """Gaussian-emission HMM over *data*, one latent state per observation.

    Per-state parameters are drawn inside the ``"states"`` plate of size ``dim``
    (presumably defined in an enclosing scope). The chain starts from the
    ``"initialize"`` probabilities and then follows rows of ``"transition"``.
    """
    with numpyro.plate("states", dim):
        transitions = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
        locs = numpyro.sample("emission_loc", dist.Normal(0, 1))
        scales = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

    probs = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    for idx, value in markov(enumerate(data)):
        hidden = numpyro.sample(f"x_{idx}", dist.Categorical(probs))
        emission = dist.Normal(locs[hidden], scales[hidden])
        numpyro.sample(f"y_{idx}", emission, obs=value)
        probs = transitions[hidden]
GitHub: pyro-ppl/numpyro — examples/utils.py (View on GitHub)
def sample_aux_noise(shape):
    """Return ``jax.random.normal`` noise of the given *shape*.

    The PRNG key is obtained from the sample site ``'rng_key_aux'``, which uses
    a ``PRNGIdentity`` distribution. The noise is then drawn inside a
    ``numpyro.handlers.block()`` context.
    """
    rng_key = numpyro.sample('rng_key_aux', numpyro.distributions.PRNGIdentity())
    with numpyro.handlers.block():
        noise = jax.random.normal(rng_key, shape=shape)
    return noise
GitHub: pyro-ppl/numpyro — benchmarks/sparse_regression.py (View on GitHub)
alpha_1, beta_1 = hypers['alpha_1'], hypers['beta_1']
    alpha_2, beta_2 = hypers['alpha_2'], hypers['beta_2']
    sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['sigma_scale']))
    phi = (m0 / (M - m0)) * (sigma / np.sqrt(N))
    eta1_base = numpyro.sample("eta1_base", dist.HalfCauchy(1.))
    eta1 = phi * eta1_base
    msq = numpyro.sample("m_sq", dist.InverseGamma(alpha_1, beta_1))
    psi_sq = numpyro.sample("psi_sq", dist.InverseGamma(alpha_2, beta_2))

    eta2 = np.square(eta1) * np.sqrt(psi_sq) / msq

    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(M)))
    kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))

    # sample observation noise
    var_obs = numpyro.sample("var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))

    # compute kernel
    kX = kappa * X
    kX2 = kappa * np.square(X)
    k = kernel_matrix(kX, kX2, eta1, eta2, hypers['c']) + var_obs * np.eye(N)
    assert k.shape == (N, N)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y", dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=k),
                   obs=Y)