Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def model(data=None):
    """Beta prior over a length-2 ``beta`` with Bernoulli observations.

    ``beta`` is drawn from ``Beta(ones(2), ones(2))``; the ``"obs"`` site is a
    ``Bernoulli(beta)`` declared inside the ``"plate"`` plate of size ``N``
    (a name taken from the enclosing scope) placed at ``dim=-2``.

    :param data: value passed as ``obs`` to the ``"obs"`` site; with the
        default ``None`` the site is not conditioned.
    """
    concentration1 = np.ones(2)
    concentration0 = np.ones(2)
    prior = dist.Beta(concentration1, concentration0)
    beta = numpyro.sample("beta", prior)

    likelihood = dist.Bernoulli(beta)
    with numpyro.plate("plate", N, dim=-2):
        numpyro.sample("obs", likelihood, obs=data)
def guide(data):
    """Beta guide for the ``"beta"`` site with two positive parameters.

    Registers ``"alpha_q"`` and ``"beta_q"`` (both initialised to ``1.0`` and
    constrained to be positive) and samples ``"beta"`` from the resulting
    ``Beta`` distribution.

    :param data: accepted for signature parity with the model; not read here.
    """
    positive = constraints.positive
    concentration1 = numpyro.param("alpha_q", 1.0, constraint=positive)
    concentration0 = numpyro.param("beta_q", 1.0, constraint=positive)
    posterior = dist.Beta(concentration1, concentration0)
    numpyro.sample("beta", posterior)
class _ImproperWrapper(dist.ImproperUniform):
    """``ImproperUniform`` variant that can actually produce samples."""

    def sample(self, key, sample_shape=()):
        """Draw values by sampling uniformly in unconstrained space.

        Uniform draws between -2 and 2 are generated in the unconstrained
        space and then pushed through ``biject_to(self.support)`` so the
        returned values lie in this distribution's support.

        :param key: random key forwarded to ``random.uniform``.
        :param sample_shape: leading shape prepended to ``batch_shape``.
        :return: the transformed (constrained) samples.
        """
        to_support = biject_to(self.support)

        # The unconstrained event shape may differ from ``event_shape``; probe
        # it by inverting a zero-valued prototype of the constrained event.
        probe = jnp.zeros(self.event_shape)
        raw_event_shape = jnp.shape(to_support.inv(probe))

        full_shape = sample_shape + self.batch_shape + raw_event_shape
        raw_draws = random.uniform(key, full_shape, minval=-2, maxval=2)
        return to_support(raw_draws)
_DIST_MAP = {
dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
dist.Beta: lambda con1, con0: osp.beta(con1, con0),
dist.BinomialProbs: lambda probs, total_count: osp.binom(n=total_count, p=probs),
dist.BinomialLogits: lambda logits, total_count: osp.binom(n=total_count, p=_to_probs_bernoulli(logits)),
dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
dist.Chi2: lambda df: osp.chi2(df),
dist.Dirichlet: lambda conc: osp.dirichlet(conc),
dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)),
dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1. / rate),
dist.Gumbel: lambda loc, scale: osp.gumbel_r(loc=loc, scale=scale),
dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
dist.InverseGamma: lambda conc, rate: osp.invgamma(conc, scale=rate),
dist.Laplace: lambda loc, scale: osp.laplace(loc=loc, scale=scale),
dist.LogNormal: lambda loc, scale: osp.lognorm(s=scale, scale=jnp.exp(loc)),
dist.MultinomialProbs: lambda probs, total_count: osp.multinomial(n=total_count, p=probs),
dist.MultinomialLogits: lambda logits, total_count: osp.multinomial(n=total_count,
p=_to_probs_multinom(logits)),
def model(data):
    """Model with two Bernoulli latents per datum and Normal observations.

    A single ``"y_prob"`` is drawn from ``Beta(1, 1)``.  Inside the ``"data"``
    plate (size ``data.shape[0]``) the sites are: ``"y" ~ Bernoulli(y_prob)``,
    ``"z" ~ Bernoulli(0.65 * y + 0.1)`` and ``"obs" ~ Normal(2 * z, 1)``
    observed at ``data``.

    :param data: observations; only ``data.shape[0]`` and the values
        themselves (as ``obs``) are used.
    """
    y_prob = numpyro.sample("y_prob", dist.Beta(1., 1.))
    num_points = data.shape[0]
    with numpyro.plate("data", num_points):
        y = numpyro.sample("y", dist.Bernoulli(y_prob))
        z_prob = 0.65 * y + 0.1
        z = numpyro.sample("z", dist.Bernoulli(z_prob))
        obs_loc = 2. * z
        numpyro.sample("obs", dist.Normal(obs_loc, 1.), obs=data)
def model_0(sequences, lengths, args, include_prior=True):
    """HMM-style model that loops over sequences and time steps one by one.

    Global sites ``"probs_x"`` (Dirichlet) and ``"probs_y"`` (Beta) are
    sampled under ``numpyro_mask(mask_array=include_prior)``.  For every
    sequence ``i`` and step ``t < lengths[i]`` a categorical state
    ``"x_{i}_{t}"`` is sampled from the row of ``probs_x`` selected by the
    previous state (starting from ``0``), and ``"y_{i}_{t}"`` is a Bernoulli
    observed at ``sequences[i, t]`` inside the ``"tones"`` plate.

    :param sequences: array whose shape unpacks to
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths, indexed by sequence number.
    :param args: object providing ``hidden_dim``.
    :param include_prior: forwarded as ``mask_array`` for the global sites.
    """
    num_sequences, max_length, data_dim = sequences.shape

    transition_conc = 0.9 * np.eye(args.hidden_dim) + 0.1
    emission_shape = (args.hidden_dim, data_dim)

    with numpyro_mask(mask_array=include_prior):
        probs_x = pyro_sample(
            "probs_x",
            dist.Dirichlet(transition_conc).to_event(1),
        )
        # the parameter expansion here is unfortunate, and
        # necessitated by the fact that NumPyro allows some
        # batch dimensions that are not plate or enum dims
        probs_y = pyro_sample(
            "probs_y",
            dist.Beta(0.1 * np.ones(emission_shape),
                      0.9 * np.ones(emission_shape)).to_event(2),
        )

    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    # ``num_sequences`` is ``sequences.shape[0]``, i.e. ``len(sequences)``.
    for i in pyro_plate("sequences", num_sequences):
        seq_len = lengths[i]
        observed = sequences[i, :seq_len]
        x = 0  # every chain starts in state 0
        for t in pyro_markov(range(seq_len)):
            state_dist = dist.Categorical(probs_x[x])
            x = pyro_sample("x_{}_{}".format(i, t), state_dist,
                            infer={"enumerate": "parallel"})
            logging.info(f"x[{i}, {t}]: {x.shape}")
            with tones_plate:
                emission_dist = dist.Bernoulli(probs_y[x.squeeze(-1)])
                pyro_sample("y_{}_{}".format(i, t), emission_dist,
                            obs=sequence_obs(observed, t) if False else observed[t])
def model_4(sequences, lengths, args, include_prior=True):
    """Two-chain (``w`` and ``x``) model vectorised over sequences.

    Samples the global sites ``"probs_w"``, ``"probs_x"`` and ``"probs_y"``
    under ``numpyro_mask(mask_array=include_prior)`` and then, inside the
    ``"sequences"`` plate, samples a categorical ``"w_{t}"`` per time step.

    NOTE(review): the visible body stops after the ``w_{t}`` site; ``x`` and
    ``tones_plate`` are set up but not used in what is shown, so the function
    appears to be truncated here -- confirm against the full source.

    :param sequences: array whose shape unpacks to
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; re-indexed by the plate's ``batch``.
    :param args: object providing ``hidden_dim``.
    :param include_prior: forwarded as ``mask_array`` for the global sites.
    """
    num_sequences, max_length, data_dim = sequences.shape
    hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x
    with numpyro_mask(mask_array=include_prior):
        probs_w = pyro_sample("probs_w",
                              dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
                              .to_event(1))
        # Same concentration matrix as probs_w, broadcast to a leading
        # extra axis of size hidden_dim.
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(
                                  np.broadcast_to(0.9 * np.eye(hidden_dim) + 0.1,
                                                  (hidden_dim, hidden_dim, hidden_dim)))
                              .to_event(2))
        probs_y_shape = (hidden_dim, hidden_dim, data_dim)
        probs_y = pyro_sample("probs_y",
                              dist.Beta(np.full(probs_y_shape, 0.1),
                                        np.full(probs_y_shape, 0.9))
                              .to_event(len(probs_y_shape)))
    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = np.array(0)
        for t in pyro_markov(range(max_length)):
            # Mask is True only for sequences that are still running at step t;
            # the reshape appends a trailing axis of size 1.
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_ww = probs_w[w]
                # Replace the axes at positions -3 and -2 (when present) by
                # (num_sequences, 1), keeping the last axis.
                probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
                w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
                                infer={"enumerate": "parallel"})
def model_1(sequences, lengths, args, include_prior=True):
    """Single-chain model vectorised over sequences with a plate.

    Samples the global sites ``"probs_x"`` and ``"probs_y"`` under
    ``numpyro_mask(mask_array=include_prior)`` and then, inside the
    ``"sequences"`` plate, samples a categorical ``"x_{t}"`` per time step
    and logs its shape.

    NOTE(review): the visible body ends right after computing ``probs_yx``,
    which is not used in what is shown, so the function appears to be
    truncated here -- confirm against the full source.

    :param sequences: array whose shape unpacks to
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: per-sequence lengths; re-indexed by the plate's ``batch``.
    :param args: object providing ``hidden_dim``.
    :param include_prior: forwarded as ``mask_array`` for the global sites.
    """
    num_sequences, max_length, data_dim = sequences.shape
    with numpyro_mask(mask_array=include_prior):
        probs_x = pyro_sample("probs_x",
                              dist.Dirichlet(0.9 * np.eye(args.hidden_dim) + 0.1)
                              .to_event(1))
        probs_y = pyro_sample("probs_y",
                              # the parameter expansion here is unfortunate, and
                              # necessitated by the fact that NumPyro allows some
                              # batch dimensions that are not plate or enum dims
                              dist.Beta(0.1 * np.ones((args.hidden_dim, data_dim)),
                                        0.9 * np.ones((args.hidden_dim, data_dim))
                                        ).to_event(2))
    tones_plate = pyro_plate("tones", data_dim, dim=-1)
    with pyro_plate("sequences", num_sequences, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0  # initial state used to index probs_x on the first step
        for t in pyro_markov(range(max_length)):
            # Mask is True only for sequences that are still running at step t;
            # the reshape appends a trailing axis of size 1.
            with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
                probs_xx = probs_x[x]
                # Replace the axes at positions -3 and -2 (when present) by
                # (num_sequences, 1), keeping the last axis.
                probs_xx = np.broadcast_to(probs_xx, probs_xx.shape[:-3] + (num_sequences, 1) + probs_xx.shape[-1:])
                x = pyro_sample("x_{}".format(t), dist.Categorical(probs_xx),
                                infer={"enumerate": "parallel"})
                logging.info(f"x[{t}]: {x.shape}")
                with tones_plate:
                    # Drop x's trailing axis before using it to index probs_y.
                    probs_yx = probs_y[x.squeeze(-1)]
def model_3(sequences, lengths, args, include_prior=True):
num_sequences, max_length, data_dim = sequences.shape
hidden_dim = int(args.hidden_dim ** 0.5) # split between w and x
with numpyro_mask(mask_array=include_prior):
probs_w = pyro_sample("probs_w",
dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
.to_event(1))
probs_x = pyro_sample("probs_x",
dist.Dirichlet(0.9 * np.eye(hidden_dim) + 0.1)
.to_event(1))
probs_y_shape = (hidden_dim, hidden_dim, data_dim)
probs_y = pyro_sample("probs_y",
dist.Beta(np.full(probs_y_shape, 0.1),
np.full(probs_y_shape, 0.9))
.to_event(len(probs_y_shape)))
tones_plate = pyro_plate("tones", data_dim, dim=-1)
with pyro_plate("sequences", num_sequences, dim=-2) as batch:
lengths = lengths[batch]
w, x = 0, 0
for t in pyro_markov(range(max_length)):
with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
probs_ww = probs_w[w]
probs_ww = np.broadcast_to(probs_ww, probs_ww.shape[:-3] + (num_sequences, 1) + probs_ww.shape[-1:])
w = pyro_sample("w_{}".format(t), dist.Categorical(probs_ww),
infer={"enumerate": "parallel"})
logging.info(f"w[{t}]: {w.shape}")
probs_xx = probs_x[x]