How to use the numpyro.plate function in numpyro

To help you get started, we’ve selected a few `numpyro.plate` examples, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github pyro-ppl / numpyro / test / test_handlers.py View on Github external
def model_nested_plates_1():
    """Model with an ``inner`` plate nested inside an ``outer`` plate.

    The ``outer`` plate has size 10 and is pinned to ``dim=-2``; the
    ``inner`` plate has size 5 and no explicit ``dim``. The shape of every
    site value is asserted right after the site is created.
    """
    outer_plate = numpyro.plate('outer', 10, dim=-2)
    with outer_plate:
        # Note: the site drawn under ``outer`` alone is registered as 'y'.
        outer_draw = numpyro.sample('y', dist.Normal(0., 1.))
        assert outer_draw.shape == (10, 1)

        with numpyro.plate('inner', 5):
            # The site drawn under both plates is registered as 'x'.
            nested_draw = numpyro.sample('x', dist.Normal(0., 1.))
            assert nested_draw.shape == (10, 5)

            # Deterministic site built only from the outer draw.
            squared = numpyro.deterministic('z', outer_draw ** 2)
            assert squared.shape == (10, 1)
github pyro-ppl / numpyro / test / test_reparam.py View on Github external
def dirichlet_categorical(data):
    """Dirichlet prior over category probabilities with Categorical observations.

    :param data: observed values; ``data.shape[0]`` is used as the size of
        the ``'N'`` plate and ``data`` is passed as ``obs`` to the
        ``'obs'`` site.
    :return: the value sampled at the ``'p'`` site.
    """
    prior = dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))
    probs = numpyro.sample('p', prior)

    num_obs = data.shape[0]
    with numpyro.plate('N', num_obs):
        likelihood = dist.Categorical(probs)
        numpyro.sample('obs', likelihood, obs=data)

    return probs
github pyro-ppl / numpyro / test / test_handlers.py View on Github external
def model_subsample_1():
    """Model that reuses two subsampled plates, separately and combined.

    ``outer`` has size 20 with ``subsample_size=10``; ``inner`` has size 10
    with ``subsample_size=5`` and is pinned to ``dim=-3``. Both plate objects
    are created once up front and then entered several times. The shape of
    every site value is asserted right after the site is created.
    """
    plate_outer = numpyro.plate('outer', 20, subsample_size=10)
    plate_inner = numpyro.plate('inner', 10, subsample_size=5, dim=-3)

    # Site under the outer plate only.
    with plate_outer:
        draw_outer = numpyro.sample('x', dist.Normal(0., 1.))
        assert draw_outer.shape == (10,)

    # Sites under the inner plate only; 'z' is derived from the outer draw.
    with plate_inner:
        draw_inner = numpyro.sample('y', dist.Normal(0., 1.))
        assert draw_inner.shape == (5, 1, 1)
        squared = numpyro.deterministic('z', draw_outer ** 2)
        assert squared.shape == (10,)

    # Site under both plates at once.
    with plate_outer:
        with plate_inner:
            draw_both = numpyro.sample('xy', dist.Normal(0., 1.))
            assert draw_both.shape == (5, 1, 10)
github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def model(data=None):
    """Beta prior with Bernoulli observations inside a ``dim=-2`` plate.

    Relies on ``N`` (the plate size) being defined in the enclosing scope.

    :param data: optional observations forwarded as ``obs`` to the
        ``"obs"`` site; defaults to ``None``.
    """
    beta_prior = dist.Beta(np.ones(2), np.ones(2))
    beta = numpyro.sample("beta", beta_prior)

    with numpyro.plate("plate", N, dim=-2):
        likelihood = dist.Bernoulli(beta)
        numpyro.sample("obs", likelihood, obs=data)
github pyro-ppl / numpyro / test / test_reparam.py View on Github external
def neals_funnel(dim):
    """Funnel-shaped model: a scalar site ``'y'`` controls the spread of ``'x'``.

    :param dim: size of the ``'D'`` plate under which ``'x'`` is sampled.
    """
    funnel_axis = numpyro.sample('y', dist.Normal(0, 3))

    with numpyro.plate('D', dim):
        # Second Normal argument depends on 'y' through exp(y / 2).
        spread = jnp.exp(funnel_axis / 2)
        numpyro.sample('x', dist.Normal(0, spread))
github pyro-ppl / numpyro / test / test_infer_util.py View on Github external
def model(data):
    """Beta(1, 1) prior with Bernoulli observations under a size-10 plate.

    :param data: observations forwarded as ``obs`` to the ``"obs"`` site.
    """
    success_prob = numpyro.sample("beta", dist.Beta(1., 1.))

    with numpyro.plate("plate", 10):
        likelihood = dist.Bernoulli(success_prob)
        numpyro.sample("obs", likelihood, obs=data)
github pyro-ppl / numpyro / numpyro / contrib / funsor / infer_util.py View on Github external
This is useful when doing inference for the usual NumPyro programs with
    `numpyro.plate` statements. For example, to get trace of a `model` whose discrete
    latent sites are enumerated, we can use

        enum_model = numpyro.contrib.funsor.enum(model)
        with plate_to_enum_plate():
            model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
                *model_args, **model_kwargs)

    """
    try:
        numpyro.plate.__new__ = lambda cls, *args, **kwargs: enum_plate(*args, **kwargs)
        yield
    finally:
        numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)
github pyro-ppl / numpyro / examples / baseball.py View on Github external
def partially_pooled(at_bats, hits=None):
    r"""
    Partially pooled model: each player's hits are Binomial with a
    per-player success probability $\phi_i$. Every $\phi_i$ is drawn from a
    shared Beta distribution with concentration parameters
    $c_1 = m * kappa$ and $c_2 = (1 - m) * kappa$, where
    $m ~ Uniform(0, 1)$ and $kappa ~ Pareto(1, 1.5)$.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    mean = numpyro.sample("m", dist.Uniform(0, 1))
    concentration = numpyro.sample("kappa", dist.Pareto(1, 1.5))

    with numpyro.plate("num_players", at_bats.shape[0]):
        c1 = mean * concentration
        c2 = (1 - mean) * concentration
        phi = numpyro.sample("phi", dist.Beta(c1, c2))
        likelihood = dist.Binomial(at_bats, probs=phi)
        predicted_hits = numpyro.sample("obs", likelihood, obs=hits)

    return predicted_hits
github pyro-ppl / numpyro / examples / baseball.py View on Github external
def not_pooled(at_bats, hits=None):
    r"""
    Unpooled model: each player's hits are Binomial with a per-player
    success probability $\phi_i$, and every $\phi_i$ is drawn
    independently from Uniform(0, 1).

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    with numpyro.plate("num_players", at_bats.shape[0]):
        phi = numpyro.sample("phi", dist.Uniform(0, 1))
        likelihood = dist.Binomial(at_bats, probs=phi)
        predicted_hits = numpyro.sample("obs", likelihood, obs=hits)

    return predicted_hits