# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def model(deterministic=True):
    """Sample a standard-normal site ``x`` and optionally record a copy of it.

    Every call increments ``GLOBAL["count"]`` (presumably so a caller can count
    how many times the model body is executed — confirm against the caller).

    Args:
        deterministic: when truthy, the drawn value is also registered as the
            deterministic site ``"x_copy"``.
    """
    GLOBAL["count"] += 1
    draw = numpyro.sample("x", dist.Normal())
    if not deterministic:
        return
    numpyro.deterministic("x_copy", draw)
def model_nested_plates_2():
    """Pin site shapes for two plates that are built first and entered separately.

    ``outer`` (size 10) gets no explicit dim; ``inner`` (size 5) is pinned to
    ``dim=-3``. Each assertion fixes the shape a site takes in its plate context.

    NOTE(review): nesting was restored from a flattened paste; ``y``/``z`` and
    their assertions are assumed to sit inside ``inner`` — confirm.
    """
    plate_outer = numpyro.plate('outer', 10)
    plate_inner = numpyro.plate('inner', 5, dim=-3)

    with plate_outer:
        x_outer = numpyro.sample('x', dist.Normal(0., 1.))
        assert x_outer.shape == (10,)

    with plate_inner:
        y_inner = numpyro.sample('y', dist.Normal(0., 1.))
        assert y_inner.shape == (5, 1, 1)
        # The assertion expects the deterministic site to keep x's shape,
        # i.e. it is not broadcast by the enclosing plate.
        z_det = numpyro.deterministic('z', x_outer ** 2)
        assert z_det.shape == (10,)

    with plate_outer, plate_inner:
        # The assertion expects sample_shape=(10,) to be absorbed into the
        # (5, 1, 10) batch shape implied by the two plates.
        xy_both = numpyro.sample('xy', dist.Normal(0., 1.), sample_shape=(10,))
        assert xy_both.shape == (5, 1, 10)
def model(data):
    """Poisson switch-point model with a fractional switch location.

    Both rates share an Exponential prior whose rate is ``1 / mean(data)``.
    ``tau`` is drawn from ``Uniform(0, 1)``; positions with index below
    ``tau * len(data)`` use ``lambda1`` and the rest use ``lambda2``.

    Args:
        data: observed counts, conditioned on through the ``'obs'`` site.
    """
    n_obs = len(data)
    prior_rate = 1 / np.mean(data)
    lambda1 = numpyro.sample('lambda1', dist.Exponential(prior_rate))
    lambda2 = numpyro.sample('lambda2', dist.Exponential(prior_rate))
    tau = numpyro.sample('tau', dist.Uniform(0, 1))
    before_switch = np.arange(n_obs) < tau * n_obs
    rate = np.where(before_switch, lambda1, lambda2)
    numpyro.sample('obs', dist.Poisson(rate), obs=data)
def model():
    """Sample ``x`` from a Normal whose loc and scale are learnable params.

    ``a`` is constrained to be greater than ``a_minval`` and ``b`` to be
    positive. ``a_init``, ``b_init`` and ``a_minval`` come from the enclosing
    scope.
    """
    loc_constraint = constraints.greater_than(a_minval)
    loc = numpyro.param('a', a_init, constraint=loc_constraint)
    scale = numpyro.param('b', b_init, constraint=constraints.positive)
    numpyro.sample('x', dist.Normal(loc, scale))
def model(data):
    """Discrete-state HMM with Gaussian emissions, unrolled over ``data``.

    Per-state transition rows, emission locs and emission scales are drawn
    under the ``"states"`` plate of size ``dim`` (from the enclosing scope).
    The time loop is wrapped in ``markov``.

    NOTE(review): nesting restored from a flattened paste; ``"initialize"`` is
    assumed to be drawn outside the plate — confirm.
    """
    with numpyro.plate("states", dim):
        transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
        emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
        emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

    # Initial state probabilities; replaced by a transition row after each step.
    probs = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    for step, observed in markov(enumerate(data)):
        state = numpyro.sample(f"x_{step}", dist.Categorical(probs))
        numpyro.sample(
            f"y_{step}",
            dist.Normal(emission_loc[state], emission_scale[state]),
            obs=observed,
        )
        probs = transition[state]
def model(data):
    # Hidden Markov model over `data`: a Categorical latent per time step selects
    # Gaussian emission parameters; the next step's state probabilities are the
    # transition row of the current state. `dim` and `markov` come from the
    # enclosing scope.
    # NOTE(review): nesting restored from a flattened paste; "initialize" is
    # assumed to be drawn outside the "states" plate — confirm.
    state_plate = numpyro.plate("states", dim)
    with state_plate:
        transition = numpyro.sample("transition", dist.Dirichlet(jnp.ones(dim)))
        emission_loc = numpyro.sample("emission_loc", dist.Normal(0, 1))
        emission_scale = numpyro.sample("emission_scale", dist.LogNormal(0, 1))

    current_probs = numpyro.sample("initialize", dist.Dirichlet(jnp.ones(dim)))
    for t, y in markov(enumerate(data)):
        latent_name, obs_name = "x_{}".format(t), "y_{}".format(t)
        x = numpyro.sample(latent_name, dist.Categorical(current_probs))
        emission = dist.Normal(emission_loc[x], emission_scale[x])
        numpyro.sample(obs_name, emission, obs=y)
        current_probs = transition[x]
def sample_aux_noise(shape):
    """Return standard-normal noise of ``shape`` using a key from the handler stack.

    The key is obtained through a ``PRNGIdentity`` sample site named
    ``'rng_key_aux'``, so this must run under a handler that supplies RNG keys
    to sample sites (e.g. ``numpyro.handlers.seed``).

    NOTE(review): ``PRNGIdentity`` is not available in recent NumPyro releases
    (``numpyro.prng_key()`` is the replacement there) — confirm the pinned version.

    Args:
        shape: shape of the returned noise array.

    Returns:
        An array of independent standard-normal draws with the given shape.
    """
    # The key site is deliberately sampled outside any `block()`: hiding it
    # would stop outer handlers such as `seed` from filling in the key.
    key = numpyro.sample('rng_key_aux', numpyro.distributions.PRNGIdentity())
    # No `block()` is needed here: `jax.random.normal` is plain JAX and emits
    # no NumPyro messages, so there is nothing to hide from the trace.
    return jax.random.normal(key, shape=shape)
# NOTE(review): this run of statements has no visible `def` line. It reads
# `hypers`, `m0`, `M`, `N`, `X`, `Y` and `kernel_matrix` from a scope that is
# not shown here and looks like the body of a sparse-regression Gaussian-process
# model — confirm against the original source.
# Hyperparameters of the two InverseGamma priors below.
alpha_1, beta_1 = hypers['alpha_1'], hypers['beta_1']
alpha_2, beta_2 = hypers['alpha_2'], hypers['beta_2']
sigma = numpyro.sample("sigma", dist.HalfNormal(hypers['sigma_scale']))
# Global scale phi = m0 / (M - m0) * sigma / sqrt(N).
# Presumably m0 is the expected number of relevant covariates out of M and N is
# the number of observations — TODO confirm.
phi = (m0 / (M - m0)) * (sigma / np.sqrt(N))
# Non-centred draw: phi * HalfCauchy(1) has the same distribution as
# HalfCauchy(phi), because HalfCauchy is a scale family.
eta1_base = numpyro.sample("eta1_base", dist.HalfCauchy(1.))
eta1 = phi * eta1_base
msq = numpyro.sample("m_sq", dist.InverseGamma(alpha_1, beta_1))
psi_sq = numpyro.sample("psi_sq", dist.InverseGamma(alpha_2, beta_2))
# eta2 = eta1^2 * sqrt(psi_sq) / msq
eta2 = np.square(eta1) * np.sqrt(psi_sq) / msq
# One local HalfCauchy(1) scale per entry of np.ones(M).
lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(M)))
# Elementwise: kappa = sqrt(msq) * lam / sqrt(msq + (eta1 * lam)^2).
kappa = np.sqrt(msq) * lam / np.sqrt(msq + np.square(eta1 * lam))
# sample observation noise
var_obs = numpyro.sample("var_obs", dist.InverseGamma(hypers['alpha_obs'], hypers['beta_obs']))
# compute kernel
# kappa is broadcast against X; assumes X's last axis has length M — TODO confirm.
kX = kappa * X
# NOTE(review): this is kappa * X**2, not (kappa * X)**2; whether that is what
# kernel_matrix expects cannot be checked from here — confirm.
kX2 = kappa * np.square(X)
# Kernel matrix plus var_obs added on the diagonal via an N x N identity.
k = kernel_matrix(kX, kX2, eta1, eta2, hypers['c']) + var_obs * np.eye(N)
# Shape sanity check; note that `assert` is skipped under `python -O`.
assert k.shape == (N, N)
# sample Y according to the standard gaussian process formula
# NOTE(review): the mean length uses X.shape[0] while the diagonal term uses N;
# these presumably coincide — confirm.
numpyro.sample("Y", dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=k),
obs=Y)