How to use the `numpyro.distributions.Dirichlet` distribution in NumPyro

To help you get started, we’ve selected a few NumPyro examples based on popular ways `Dirichlet` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

github pyro-ppl / numpyro / test / contrib / test_funsor.py View on Github external
def model(data):
        """Gaussian-emission HMM over ``dim`` hidden states.

        ``dim`` is not defined here; it is read from the enclosing scope.
        Inside the ``"states"`` plate, each state gets a Dirichlet row of
        transition probabilities plus a Normal location and LogNormal scale
        for its emissions. The chain then runs over ``data``: each step
        draws a hidden state ``x_t`` and scores the observation ``y_t``
        under that state's Normal emission.
        """
        flat_concentration = jnp.ones(dim)
        with numpyro.plate("states", dim):
            transition = numpyro.sample("transition", dist.Dirichlet(flat_concentration))
            emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
            emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

        # Distribution of the first hidden state; afterwards it is replaced
        # by the transition row of the previously drawn state.
        state_probs = numpyro.sample("initialize", dist.Dirichlet(flat_concentration))
        for step, observed in markov(enumerate(data)):
            state = numpyro.sample("x_{}".format(step), dist.Categorical(state_probs))
            emission = dist.Normal(emission_loc[state], emission_scale[state])
            numpyro.sample("y_{}".format(step), emission, obs=observed)
            state_probs = transition[state]
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
        """Categorical likelihood for ``data`` with a flat Dirichlet(1, 1, 1) prior.

        Returns the sampled probability vector ``p_latent``.
        """
        prior = dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))
        probs = numpyro.sample('p_latent', prior)
        numpyro.sample('obs', dist.Categorical(probs), obs=data)
        return probs
github pyro-ppl / numpyro / test / test_mcmc.py View on Github external
def model(data):
        """Score ``data`` as Categorical draws whose probabilities have a uniform Dirichlet prior.

        The returned value is the latent probability vector sampled at site
        ``'p_latent'``.
        """
        alpha = jnp.array([1.0] * 3)
        p = numpyro.sample('p_latent', dist.Dirichlet(alpha))
        likelihood = dist.Categorical(p)
        numpyro.sample('obs', likelihood, obs=data)
        return p
github pyro-ppl / numpyro / test / test_mcmc_interface.py View on Github external
def model(data):
        """Dirichlet(1, 1, 1) prior over three category probabilities, observed via ``data``.

        The concentration is built with ``np``, as in the original. Returns
        the sampled ``p_latent`` vector.
        """
        concentration_vec = np.array([1.0, 1.0, 1.0])
        latent = numpyro.sample('p_latent', dist.Dirichlet(concentration_vec))
        observation_dist = dist.Categorical(latent)
        numpyro.sample('obs', observation_dist, obs=data)
        return latent
github pyro-ppl / numpyro / examples / hmm.py View on Github external
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Simulate a category/word sequence from a randomly drawn HMM.

    A transition matrix and an emission matrix are drawn row by row from
    Dirichlet priors. The hidden chain is then walked for
    ``num_supervised_data + num_unsupervised_data`` steps, and the sampled
    category and emitted word are appended to ``categories`` / ``words``
    at each step.

    Args:
        rng_key: PRNG key; it is re-split with ``random.split`` before each
            group of draws so no key is reused.
        num_categories: number of hidden categories (chain states).
        num_words: number of distinct words (emission outcomes).
        num_supervised_data: number of steps in the first segment.
        num_unsupervised_data: number of steps in the second segment.

    NOTE(review): this listing stops after the sampling loop, so the
    function's return value is not visible here.
    """
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    # Concentration 1 for every next category; 0.1 for every word.
    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)

    # sample_shape=(num_categories,) draws one Dirichlet row per category:
    # a next-category distribution and a word distribution for each category.
    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))

    # Uniform distribution used to (re)start the chain.
    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        # Fresh subkeys for this step's category draw and word draw.
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        # The chain restarts from start_prob at t == 0 and again at the first
        # step of the second segment, so the two segments are not linked.
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            # Next category comes from the previous category's transition row.
            category = dist.Categorical(transition_prob[category]).sample(key=rng_key_transition)
        # The emitted word depends only on the current category.
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)
github pyro-ppl / numpyro / examples / hmm.py View on Github external
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Simulate a category/word sequence from a randomly drawn HMM.

    A transition matrix and an emission matrix are drawn row by row from
    Dirichlet priors. The hidden chain is then walked for
    ``num_supervised_data + num_unsupervised_data`` steps, and the
    per-step category and word draws are stacked into arrays.

    Args:
        rng_key: PRNG key; it is re-split with ``random.split`` before each
            group of draws so no key is reused.
        num_categories: number of hidden categories (chain states).
        num_words: number of distinct words (emission outcomes).
        num_supervised_data: number of steps in the first (supervised) segment.
        num_unsupervised_data: number of steps in the second (unsupervised)
            segment.

    NOTE(review): this listing stops right after stacking, so the split
    into segments and the return value are not visible here.
    """
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)

    # Concentration 1 for every next category; 0.1 for every word.
    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)

    # sample_shape=(num_categories,) draws one Dirichlet row per category:
    # a next-category distribution and a word distribution for each category.
    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))

    # Uniform distribution used to (re)start the chain.
    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        # Fresh subkeys for this step's category draw and word draw.
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        # The chain restarts from start_prob at t == 0 and again at the first
        # unsupervised step, so the two segments are not linked.
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            # Next category comes from the previous category's transition row.
            category = dist.Categorical(transition_prob[category]).sample(key=rng_key_transition)
        # The emitted word depends only on the current category.
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)

    # split into supervised data and unsupervised data
    # (the per-step draws are first stacked into arrays)
    categories, words = np.stack(categories), np.stack(words)
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
def model_3(sequences, lengths, args, include_prior=True):
    """HMM-style model with two hidden chains ``w`` and ``x`` (excerpt).

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; re-indexed by the ``"sequences"`` plate.
        args: namespace; ``args.hidden_dim`` is read and its square root
            (truncated to int) is used as the state count of each chain.
        include_prior: passed as ``mask_array`` around the prior sites;
            presumably ``False`` drops the prior log-density — confirm
            against ``numpyro_mask``.

    NOTE(review): this listing is cut off partway through the ``w`` sample
    statement; the rest of the time loop is not shown.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x: each chain gets int(sqrt(args.hidden_dim)) states
    with numpyro_mask(mask_array=include_prior):
        # Transition rows with concentration 1.0 on the diagonal and 0.1
        # elsewhere; to_event(1) treats the whole matrix as one event.
        probs_w = pyro_sample("probs_w",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        # One Beta(0.1, 0.9) probability per (w state, x state, tone).
        probs_y_shape = (hidden_dim, hidden_dim, data_dim)
        probs_y = pyro_sample("probs_y",
                              dist.Beta(np.full(probs_y_shape, 0.1),
                                        np.full(probs_y_shape, 0.9))
                                  .to_event(len(probs_y_shape)))

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        w, x = 0, 0
        for t in pyro_markov(range(max_length)):
            # Mask out time steps at or beyond each sequence's length.
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_ww = probs_w[w]
                # Insert a (num_sequences, 1) axis pair before the last dim so the
                # probabilities line up with the plates — presumably needed for
                # enumeration broadcasting; confirm.
                probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
                w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
github pyro-ppl / numpyro / examples / hmm.py View on Github external
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    """Semi-supervised HMM with hidden categories and observed words (excerpt).

    Args:
        transition_prior: Dirichlet concentration shared by every transition
            row; its length sets ``num_categories``.
        emission_prior: Dirichlet concentration shared by every emission row;
            its length sets ``num_words``.
        supervised_categories: observed category sequence; consecutive pairs
            score the transitions.
        supervised_words: words observed at ``supervised_categories``.
        unsupervised_words: word sequence whose categories are not observed.

    NOTE(review): this listing stops after ``init_log_prob``; the
    unsupervised log-probability computation continues beyond it.
    """
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    # Broadcasting the prior gives one Dirichlet row per category.
    transition_prob = numpyro.sample('transition_prob', dist.Dirichlet(
        np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = numpyro.sample('emission_prob', dist.Dirichlet(
        np.broadcast_to(emission_prior, (num_categories, num_words))))

    # models supervised data;
    # here we don't make any assumption about the first supervised category, in other words,
    # we place a flat/uniform prior on it.
    # Each category c[t] is scored under the transition row of c[t-1].
    numpyro.sample('supervised_categories', dist.Categorical(transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words', dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # computes log prob of unsupervised data
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    # Log-probability of the first unsupervised word under each category's emission row.
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
github pyro-ppl / numpyro / examples / hmm_enum.py View on Github external
def model_4(sequences, lengths, args, include_prior=True):
    """Two-chain HMM-style model where ``probs_x`` has an extra leading state axis (excerpt).

    Like ``model_3`` in its priors, except ``probs_x`` is a
    ``(hidden_dim, hidden_dim, hidden_dim)`` stack of transition matrices.
    Presumably it is indexed by both the ``w`` and ``x`` states; confirm in
    the truncated remainder.

    Args:
        sequences: array unpacked as ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths; re-indexed by the ``"sequences"`` plate.
        args: namespace; ``args.hidden_dim`` is read and its square root
            (truncated to int) is used as the state count of each chain.
        include_prior: passed as ``mask_array`` around the prior sites;
            presumably ``False`` drops the prior log-density — confirm
            against ``numpyro_mask``.

    NOTE(review): this listing ends at the comment below; the time loop is
    not shown.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim ** 0.5)  # split between w and x: each chain gets int(sqrt(args.hidden_dim)) states
    with numpyro_mask(mask_array=include_prior):
        # Transition rows with concentration 1.0 on the diagonal and 0.1 elsewhere.
        probs_w = pyro_sample("probs_w",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                                  .to_event(1))
        # Same matrix broadcast to hidden_dim copies; to_event(2) makes the
        # whole 3-D stack a single event.
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(
                                  np.broadcast_to(0.9 * np.eye(hidden_dim) + 0.1,
                                                  (hidden_dim, hidden_dim, hidden_dim)))
                                  .to_event(2))

        # One Beta(0.1, 0.9) probability per (w state, x state, tone).
        probs_y_shape = (hidden_dim, hidden_dim, data_dim)
        probs_y = pyro_sample("probs_y",
                              dist.Beta(np.full(probs_y_shape, 0.1),
                                        np.full(probs_y_shape, 0.9))
                                  .to_event(len(probs_y_shape)))

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
github pyro-ppl / numpyro / benchmarks / hmm.py View on Github external
def numpyro_model(transition_prior, emission_prior, supervised_categories, supervised_words, unsupervised_words):
    """Semi-supervised HMM: Dirichlet priors, supervised likelihood, and an unsupervised factor.

    Args:
        transition_prior: concentration for each transition row; ``shape[0]``
            is the number of categories and sizes the ``'K'`` plate.
        emission_prior: concentration for each emission row.
        supervised_categories: observed category sequence; each category is
            scored under the transition row of its predecessor.
        supervised_words: words observed at ``supervised_categories``.
        unsupervised_words: word sequence handed to ``_forward_log_prob``
            together with the log transition/emission matrices. Presumably
            that helper marginalizes the hidden categories (forward
            algorithm) — confirm. Its result is added via ``numpyro.factor``.
    """
    num_states = transition_prior.shape[0]
    # One Dirichlet row per state inside the 'K' plate.
    with numpyro.plate('K', num_states):
        trans = numpyro.sample('transition_prob', ndist.Dirichlet(transition_prior))
        emit = numpyro.sample('emission_prob', ndist.Dirichlet(emission_prior))

    prev_states, next_states = supervised_categories[:-1], supervised_categories[1:]
    numpyro.sample('supervised_categories', ndist.Categorical(trans[prev_states]),
                   obs=next_states)
    numpyro.sample('supervised_words', ndist.Categorical(emit[supervised_categories]),
                   obs=supervised_words)

    log_trans, log_emit = np.log(trans), np.log(emit)
    numpyro.factor('forward_log_prob', _forward_log_prob(unsupervised_words, log_trans, log_emit))