import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist


# A simple Dirichlet-Categorical model used throughout these snippets.
def model(data):
    concentration = jnp.array([1.0, 1.0, 1.0])
    p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
    numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
    return p_latent
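A minimal sketch of running inference on this model with NUTS; the sample counts, PRNG seeds, and simulated data below are illustrative, not taken from the original snippet:

from jax import random
from numpyro.infer import MCMC, NUTS

# Simulate observations from known category probabilities (illustrative values).
data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.PRNGKey(1), (2000,))
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)
# The posterior mean of p_latent should land near [0.1, 0.6, 0.3].
print(mcmc.get_samples()['p_latent'].mean(axis=0))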
def test_chain_smoke(chain_method, compile_args):
    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    data = dist.Categorical(jnp.array([0.1, 0.6, 0.3])).sample(random.PRNGKey(1), (2000,))
    kernel = NUTS(model)
    # Tiny warmup/sample counts: this is only a smoke test of the chain machinery.
    mcmc = MCMC(kernel, num_warmup=2, num_samples=5, num_chains=2,
                chain_method=chain_method, jit_model_args=compile_args)
    mcmc.warmup(random.PRNGKey(0), data)
    mcmc.run(random.PRNGKey(1), data)
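In a pytest suite this smoke test would be driven by parametrize decorators placed above the definition. A plausible sketch: numpyro's supported chain methods are 'sequential', 'parallel', and 'vectorized', while the compile_args values here are an assumption:

import pytest

@pytest.mark.parametrize('chain_method', ['sequential', 'parallel', 'vectorized'])
@pytest.mark.parametrize('compile_args', [False, True])
def test_chain_smoke(chain_method, compile_args):
    ...  # body as defined above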
from jax import jit
from numpy.testing import assert_allclose


# kernel_cls, warmup_steps, num_samples, chain_method, and rng_key are free
# variables supplied by the enclosing test's parametrization.
@jit
def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
    kernel = kernel_cls(model, step_size=step_size, trajectory_length=trajectory_length,
                        target_accept_prob=target_accept_prob)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, num_chains=2,
                chain_method=chain_method, progress_bar=False)
    mcmc.run(rng_key, data)
    return mcmc.get_samples()


true_probs = jnp.array([0.1, 0.6, 0.3])
data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
samples = get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
assert_allclose(jnp.mean(samples['p_latent'], 0), true_probs, atol=0.02)
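Because progress_bar=False and every argument is an array or scalar, the entire warmup-and-sample loop can be traced by jax.jit. A later call with matching shapes and dtypes then reuses the compiled XLA program even when the hyperparameter values change; the values below are illustrative:

# Re-run with different hyperparameters; no recompilation is triggered
# because only traced scalar values change, not shapes or dtypes.
samples_2 = get_samples(random.PRNGKey(2), data, 0.05, 2 * jnp.pi, 0.9)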
# Data simulation for the semi-supervised HMM example. The original snippet
# shows only the function body; the enclosing signature is reconstructed here.
def simulate_data(rng_key, num_supervised_data, num_unsupervised_data,
                  num_categories, num_words):
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)
    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))
    # Uniform distribution over the initial category.
    start_prob = jnp.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        if t == 0 or t == num_supervised_data:
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            category = dist.Categorical(transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)
    # Split into supervised data and unsupervised data.
    categories, words = jnp.stack(categories), jnp.stack(words)
    supervised_categories = categories[:num_supervised_data]
    supervised_words = words[:num_supervised_data]
    unsupervised_words = words[num_supervised_data:]
    return (transition_prior, emission_prior, transition_prob, emission_prob,
            supervised_categories, supervised_words, unsupervised_words)
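A hedged usage sketch for the helper above; the dataset sizes are illustrative, not from the original example:

(transition_prior, emission_prior, transition_prob, emission_prob,
 supervised_categories, supervised_words, unsupervised_words) = simulate_data(
    random.PRNGKey(0),
    num_supervised_data=100,
    num_unsupervised_data=500,
    num_categories=3,
    num_words=10,
)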
from jax.scipy.special import logsumexp


def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
    transition_prob = numpyro.sample('transition_prob', dist.Dirichlet(
        jnp.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = numpyro.sample('emission_prob', dist.Dirichlet(
        jnp.broadcast_to(emission_prior, (num_categories, num_words))))
    # Model the supervised data. No assumption is made about the first
    # supervised category; in other words, it gets a flat/uniform prior.
    numpyro.sample('supervised_categories',
                   dist.Categorical(transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)
    # Compute the log probability of the unsupervised data with the forward
    # algorithm, marginalizing over the latent categories.
    transition_log_prob = jnp.log(transition_prob)
    emission_log_prob = jnp.log(emission_prob)
    init_log_prob = emission_log_prob[:, unsupervised_words[0]]
    log_prob = forward_log_prob(init_log_prob, unsupervised_words[1:],
                                transition_log_prob, emission_log_prob)
    log_prob = logsumexp(log_prob, axis=0, keepdims=True)
    # Inject the marginal log probability into the model's joint density.
    numpyro.factor('forward_log_prob', log_prob)
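The model above calls a forward_log_prob helper that the snippet does not show. A minimal sketch of such a helper, assuming a lax.scan-based forward recursion where transition_log_prob[i, j] is the log probability of moving from state i to state j:

from jax import lax
from jax.scipy.special import logsumexp


def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    # prev_log_prob[i] is the log joint of the words so far and state i;
    # add transition and emission scores, then marginalize the previous state.
    log_prob = prev_log_prob[:, None] + transition_log_prob + emission_log_prob[:, curr_word]
    return logsumexp(log_prob, axis=0)


def forward_log_prob(init_log_prob, words, transition_log_prob, emission_log_prob):
    def scan_fn(log_prob, word):
        return forward_one_step(log_prob, word, transition_log_prob, emission_log_prob), None

    log_prob, _ = lax.scan(scan_fn, init_log_prob, words)
    return log_prob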
# From the enumerated-HMM example (a port of Pyro's hmm.py model_4 to numpyro):
# pyro_plate, pyro_sample, and pyro_markov are aliased imports of numpyro's
# plate, sample, and markov primitives, and numpyro_mask of its mask handler.
# The Dirichlet priors over probs_w and probs_x are truncated in the original
# snippet; only the probs_y prior survives.
probs_y_shape = (hidden_dim, hidden_dim, data_dim)
probs_y = pyro_sample("probs_y",
                      dist.Beta(jnp.full(probs_y_shape, 0.1),
                                jnp.full(probs_y_shape, 0.9))
                      .to_event(len(probs_y_shape)))
tones_plate = pyro_plate("tones", data_dim, dim=-1)
with pyro_plate("sequences", num_sequences, dim=-2) as batch:
    lengths = lengths[batch]
    w, x = 0, 0
    for t in pyro_markov(range(max_length)):
        # Mask out time steps past each sequence's length.
        with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
            probs_ww = probs_w[w]
            probs_ww = jnp.broadcast_to(
                probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
            w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
                            infer={"enumerate": "parallel"})
            logging.info(f"w[{t}]: {w.shape}")
            probs_xx = probs_x[x]
            probs_xx = jnp.broadcast_to(
                probs_xx, probs_xx.shape[:-3] + (num_sequences, 1) + probs_xx.shape[-1:])
            x = pyro_sample("x_{}".format(t), dist.Categorical(probs_xx),
                            infer={"enumerate": "parallel"})
            logging.info(f"x[{t}]: {x.shape}")
            with tones_plate as tones:
                probs_ywx = probs_y[w, x, tones]
                probs_ywx = jnp.broadcast_to(
                    probs_ywx, probs_ywx.shape[:-2] + (num_sequences,) + probs_ywx.shape[-1:])
                pyro_sample("y_{}".format(t), dist.Bernoulli(probs_ywx),
                            obs=sequences[batch, t])
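Assuming the block above forms the body of a model function, say model_4(sequences, lengths) (the name and signature are hypothetical), inference could run with NUTS, which marginalizes the sites marked infer={"enumerate": "parallel"} when numpyro's funsor backend is available; the warmup and sample counts are illustrative:

from numpyro.infer import MCMC, NUTS

# sequences: (num_sequences, max_length, data_dim) array of 0/1 tone indicators;
# lengths: (num_sequences,) array of valid sequence lengths. Both are assumed inputs.
kernel = NUTS(model_4)
mcmc = MCMC(kernel, num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0), sequences, lengths)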