Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def model(data):
    """Discrete-state HMM with Gaussian emissions, written for enumeration.

    For each of `dim` latent states (NOTE(review): `dim` is a free name defined
    elsewhere in the original file — confirm its source), samples a Dirichlet
    transition row plus Normal/LogNormal emission parameters, then unrolls the
    chain over `data`, observing each y_t from the emission of the current state.

    NOTE(review): indentation was stripped from this scraped snippet; the block
    structure below is reconstructed to match the canonical NumPyro enum-HMM
    example (plate over states, init/loop outside) — confirm against the source.
    """
    with numpyro.plate("states", dim):
        # One transition row and one (loc, scale) emission pair per latent state.
        transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
        emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
        emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))
    # Initial-state distribution; `trans_prob` is then updated each step below.
    trans_prob = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    # `markov` marks the sequential dependency so enumeration stays linear in T.
    for t, y in markov(enumerate(data)):
        x = numpyro.sample("x_{}".format(t), dist.Categorical(trans_prob))
        numpyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y)
        # Next step's categorical is the transition row of the current state.
        trans_prob = transition[x]
def model(data):
    """Dirichlet-Categorical model over three categories.

    Samples category probabilities `p_latent` from a flat Dirichlet(1, 1, 1)
    prior and conditions a Categorical likelihood on the observed `data`.

    NOTE(review): indentation was stripped from this scraped snippet; the
    single-level body below is the only valid reconstruction.

    Returns the sampled `p_latent` probability vector.
    """
    concentration = jnp.array([1.0, 1.0, 1.0])
    p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
    numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
    return p_latent
def model(data):
    """Dirichlet-Categorical model over three categories (duplicate snippet).

    Samples category probabilities `p_latent` from a flat Dirichlet(1, 1, 1)
    prior and conditions a Categorical likelihood on the observed `data`.

    NOTE(review): indentation was stripped from this scraped snippet; the
    single-level body below is the only valid reconstruction.

    Returns the sampled `p_latent` probability vector.
    """
    concentration = jnp.array([1.0, 1.0, 1.0])
    p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
    numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
    return p_latent
def model(data):
    """Dirichlet-Categorical model over three categories (NumPy-array variant).

    Same model as the jnp-based duplicates above, but builds the concentration
    vector with `np.array`; NumPyro distributions accept NumPy inputs, so
    behavior is identical.

    NOTE(review): indentation was stripped from this scraped snippet; the
    single-level body below is the only valid reconstruction.

    Returns the sampled `p_latent` probability vector.
    """
    concentration = np.array([1.0, 1.0, 1.0])
    p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
    numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
    return p_latent
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Simulate (category, word) sequences from a random HMM.

    Draws a transition matrix (flat Dirichlet rows) and an emission matrix
    (sparse Dirichlet(0.1) rows), then rolls the chain forward for
    `num_supervised_data + num_unsupervised_data` steps, restarting from the
    uniform start distribution at t == 0 and at the supervised/unsupervised
    boundary so the two segments are independent chains.

    NOTE(review): this scraped snippet is truncated — it builds `categories`
    and `words` but has no return statement; restore the tail (stack/split of
    the two lists) from the original example before use. Indentation below is
    reconstructed, as it was stripped by the scrape.
    """
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)
    # One Dirichlet row per source category / per emitting category.
    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))
    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        # Fresh subkeys each step so draws are independent.
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        if t == 0 or t == num_supervised_data:
            # Restart the chain at the start of each segment.
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            category = dist.Categorical(transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)
def simulate_data(rng_key, num_categories, num_words, num_supervised_data, num_unsupervised_data):
    """Simulate (category, word) sequences from a random HMM (longer duplicate).

    Identical to the preceding snippet through the simulation loop, then stacks
    the collected lists into arrays in preparation for splitting into
    supervised and unsupervised portions.

    NOTE(review): this scraped snippet is still truncated — it stacks
    `categories`/`words` but the actual split and return are cut off; restore
    the tail from the original example before use. Indentation below is
    reconstructed, as it was stripped by the scrape.
    """
    rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
    transition_prior = np.ones(num_categories)
    emission_prior = np.repeat(0.1, num_words)
    # One Dirichlet row per source category / per emitting category.
    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))
    start_prob = np.repeat(1. / num_categories, num_categories)
    categories, words = [], []
    for t in range(num_supervised_data + num_unsupervised_data):
        # Fresh subkeys each step so draws are independent.
        rng_key, rng_key_transition, rng_key_emission = random.split(rng_key, 3)
        if t == 0 or t == num_supervised_data:
            # Restart the chain at the start of each segment.
            category = dist.Categorical(start_prob).sample(key=rng_key_transition)
        else:
            category = dist.Categorical(transition_prob[category]).sample(key=rng_key_transition)
        word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
        categories.append(category)
        words.append(word)
    # split into supervised data and unsupervised data
    categories, words = np.stack(categories), np.stack(words)
# NOTE(review): scraped snippet — indentation has been stripped and the block is
# truncated mid-call (the final pyro_sample(...) has an unclosed argument list),
# so it is not valid Python as written. Comments below annotate intent only;
# restore the full block from the original factorial-HMM example before use.
def model_3(sequences, lengths, args, include_prior=True):
# sequences: presumably (num_sequences, max_length, data_dim) — shape unpack below.
num_sequences, max_length, data_dim = sequences.shape
# Factorial HMM: hidden state is split into two chains w and x, each of size
# sqrt(args.hidden_dim).
hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x
# Mask toggles whether the priors contribute to the log density.
with numpyro_mask(mask_array=include_prior):
# Near-identity Dirichlet transition priors for each chain.
probs_w = pyro_sample("probs_w",
dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
.to_event(1))
probs_x = pyro_sample("probs_x",
dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
.to_event(1))
# Emission probabilities indexed by the joint (w, x) state, one Beta per tone.
probs_y_shape = (hidden_dim, hidden_dim, data_dim)
probs_y = pyro_sample("probs_y",
dist.Beta(np.full(probs_y_shape, 0.1),
np.full(probs_y_shape, 0.9))
.to_event(len(probs_y_shape)))
tones_plate = pyro_plate("tones", data_dim, dim=-1)
with pyro_plate("sequences", num_sequences, dim=-2) as batch:
lengths = lengths[batch]
w, x = 0, 0
# markov keeps enumeration linear in max_length.
for t in pyro_markov(range(max_length)):
# Mask out timesteps past each sequence's true length.
with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
probs_ww = probs_w[w]
# Broadcast the transition row across the sequences batch dimension.
probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
# TRUNCATED: the pyro_sample call below is cut off mid-arguments in the scrape.
w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
# NOTE(review): scraped snippet — indentation has been stripped and the body is
# truncated: the forward-algorithm recursion over the unsupervised words is cut
# off right after init_log_prob. Comments annotate intent only; restore the full
# block from the original semi-supervised HMM example before use.
def semi_supervised_hmm(transition_prior, emission_prior,
supervised_categories, supervised_words,
unsupervised_words):
num_categories, num_words = transition_prior.shape[0], emission_prior.shape[0]
# Dirichlet priors: one transition row per category, one emission row per category.
transition_prob = numpyro.sample('transition_prob', dist.Dirichlet(
np.broadcast_to(transition_prior, (num_categories, num_categories))))
emission_prob = numpyro.sample('emission_prob', dist.Dirichlet(
np.broadcast_to(emission_prior, (num_categories, num_words))))
# models supervised data;
# here we don't make any assumption about the first supervised category, in other words,
# we place a flat/uniform prior on it.
# Each observed category conditions on its predecessor's transition row.
numpyro.sample('supervised_categories', dist.Categorical(transition_prob[supervised_categories[:-1]]),
obs=supervised_categories[1:])
numpyro.sample('supervised_words', dist.Categorical(emission_prob[supervised_categories]),
obs=supervised_words)
# computes log prob of unsupervised data
transition_log_prob = np.log(transition_prob)
emission_log_prob = np.log(emission_prob)
# Forward-algorithm initialization: log p(word_0 | category) for every category.
# TRUNCATED: the forward recursion and the factor over unsupervised data are cut off.
init_log_prob = emission_log_prob[:, unsupervised_words[0]]
# NOTE(review): scraped snippet — indentation has been stripped and the body is
# truncated right after the explanatory comment (the sampling loop over
# timesteps is cut off), so it is not valid Python as written. Comments annotate
# intent only; restore the full block from the original example before use.
def model_4(sequences, lengths, args, include_prior=True):
# sequences: presumably (num_sequences, max_length, data_dim) — shape unpack below.
num_sequences, max_length, data_dim = sequences.shape
# Factorial HMM: hidden state split into chains w and x of size sqrt(args.hidden_dim).
hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x
# Mask toggles whether the priors contribute to the log density.
with numpyro_mask(mask_array=include_prior):
probs_w = pyro_sample("probs_w",
dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
.to_event(1))
# Unlike model_3, x's transition here also conditions on w: one Dirichlet row
# per (w, x) pair, hence the broadcast to (hidden_dim, hidden_dim, hidden_dim).
probs_x = pyro_sample("probs_x",
dist.Dirichlet(
np.broadcast_to(0.9 * np.eye(hidden_dim) + 0.1,
(hidden_dim, hidden_dim, hidden_dim)))
.to_event(2))
# Emission probabilities indexed by the joint (w, x) state, one Beta per tone.
probs_y_shape = (hidden_dim, hidden_dim, data_dim)
probs_y = pyro_sample("probs_y",
dist.Beta(np.full(probs_y_shape, 0.1),
np.full(probs_y_shape, 0.9))
.to_event(len(probs_y_shape)))
tones_plate = pyro_plate("tones", data_dim, dim=-1)
with pyro_plate("sequences", num_sequences, dim=-2) as batch:
lengths = lengths[batch]
# Note the broadcasting tricks here: we declare a hidden arange and
# ensure that w and x are always tensors so we can unsqueeze them below,
# thus ensuring that the x sample sites have correct distribution shape.
def numpyro_model(transition_prior, emission_prior, supervised_categories, supervised_words, unsupervised_words):
    """Semi-supervised HMM: plate-based priors plus a forward-algorithm factor.

    Samples per-category transition and emission Dirichlet rows inside a plate,
    conditions on the supervised (category, word) pairs directly, and folds the
    unsupervised words' marginal likelihood in via `_forward_log_prob` (defined
    elsewhere in the original file) through a factor site.

    NOTE(review): indentation was stripped from this scraped snippet; the plate
    is reconstructed to enclose only the two Dirichlet sample sites, matching
    the canonical semi-supervised HMM example — confirm against the source.
    """
    num_categories = transition_prior.shape[0]
    # One Dirichlet transition row and one emission row per latent category.
    with numpyro.plate('K', num_categories):
        transition_prob = numpyro.sample('transition_prob', ndist.Dirichlet(transition_prior))
        emission_prob = numpyro.sample('emission_prob', ndist.Dirichlet(emission_prior))
    # Supervised part: each observed category conditions on its predecessor's
    # transition row; each observed word on its category's emission row.
    numpyro.sample('supervised_categories', ndist.Categorical(transition_prob[supervised_categories[:-1]]),
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words', ndist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)
    # Unsupervised part: marginalize the latent chain with the forward algorithm
    # and add the resulting log-likelihood to the joint density.
    log_prob = _forward_log_prob(unsupervised_words, np.log(transition_prob), np.log(emission_prob))
    numpyro.factor('forward_log_prob', log_prob)