Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def model():
    """Sample site ``'x'`` from a Normal parameterised by learnable params.

    The loc is the param ``'a'``, constrained to be greater than ``a_minval``.
    The scale is the param ``'b'``, constrained to be positive.
    NOTE(review): ``a_init``, ``a_minval`` and ``b_init`` are free names,
    presumably bound in an enclosing scope -- confirm.
    """
    loc_constraint = constraints.greater_than(a_minval)
    loc = numpyro.param('a', a_init, constraint=loc_constraint)
    scale = numpyro.param('b', b_init, constraint=constraints.positive)
    numpyro.sample('x', dist.Normal(loc, scale))
def guide(data):
    """Variational guide: a Beta distribution over the latent site ``"beta"``.

    Both concentrations are learnable params initialised to 1.0 and kept
    positive. ``data`` is accepted for signature compatibility but unused.
    """
    # The generator is unpacked in order, so "alpha_q" is registered before "beta_q".
    conc1, conc0 = (
        numpyro.param(param_name, 1.0, constraint=constraints.positive)
        for param_name in ("alpha_q", "beta_q")
    )
    numpyro.sample("beta", dist.Beta(conc1, conc0))
def model(data):
    """Observe ``data`` under a Normal with learnable ``'mean'`` and ``'std'``.

    ``'mean'`` starts at 0.0. ``'std'`` starts at 1.0 and is constrained to be
    positive. Returns the value of the observed ``'obs'`` site.
    """
    loc = numpyro.param('mean', 0.)
    scale = numpyro.param('std', 1., constraint=constraints.positive)
    likelihood = dist.Normal(loc, scale)
    return numpyro.sample('obs', likelihood, obs=data)
def guide(data):
    """Variational guide placing Beta(alpha_q, beta_q) on the site ``"beta"``.

    ``data`` is unused; it only keeps the signature aligned with the model.
    """
    # Both concentrations start at 1.0, i.e. a uniform Beta, and must stay positive.
    positive = constraints.positive
    a_q = numpyro.param("alpha_q", 1.0, constraint=positive)
    b_q = numpyro.param("beta_q", 1.0, constraint=positive)
    q_beta = dist.Beta(a_q, b_q)
    numpyro.sample("beta", q_beta)
# NOTE(review): orphaned fragment. These two lines repeat the body of the
# model() that immediately follows, but their enclosing `def` (and the line
# binding `x`) is missing here, so `x` is unbound and the bare `return` is
# outside any function. Confirm, then restore the header or remove them.
y = handlers.substitute(lambda: numpyro.param('y', None) * numpyro.param('x', None), {'y': x})()
return x + y
def model():
    """Return ``x + y``, where ``y`` is computed with param ``'y'`` replaced by ``x``.

    ``x`` is read from the param ``'x'`` (no initial value supplied).
    """
    x_value = numpyro.param('x', None)

    def _product():
        # Evaluated under substitute(); the 'y' param site receives x_value.
        return numpyro.param('y', None) * numpyro.param('x', None)

    y_value = handlers.substitute(_product, {'y': x_value})()
    return x_value + y_value
def _get_posterior(self):
    """Build a MultivariateNormal posterior from two learnable params.

    ``'<prefix>_loc'`` is initialised to ``self._init_latent``.
    ``'<prefix>_scale_tril'`` is initialised to ``self._init_scale`` times the
    ``self.latent_dim`` identity matrix, and is constrained to be lower
    Cholesky.
    """
    loc_name = '{}_loc'.format(self.prefix)
    loc = numpyro.param(loc_name, self._init_latent)

    scale_tril_name = '{}_scale_tril'.format(self.prefix)
    initial_scale_tril = jnp.identity(self.latent_dim) * self._init_scale
    scale_tril = numpyro.param(scale_tril_name, initial_scale_tril,
                               constraint=constraints.lower_cholesky)

    return dist.MultivariateNormal(loc, scale_tril=scale_tril)
def bernoulli_guide(X, Y, hypers, method="direct", num_probes=4, cg_tol=0.001):
    """Guide with positive point-mass (Delta) sites plus a bounded ``omega`` site.

    - ``eta1``, ``msq`` and ``xisq`` are scalar Delta sites.
    - ``lambda`` is a Delta site with one entry per column of ``X``.
    - ``omega`` has one entry per row of ``X``. It is a Normal pushed through
      a sigmoid and then scaled by 2.5, so its values lie in ``(0, 2.5)``.

    NOTE(review): ``Y``, ``method``, ``num_probes``, ``cg_tol`` and the local
    ``phi`` are not used in this body. They are presumably consumed by code
    that is not shown here -- confirm before removing anything.
    """
    S = hypers['expected_sparsity']
    sigma = hypers['sigma']
    P = X.shape[1]
    N = X.shape[0]
    phi = sigma * (S / np.sqrt(N)) / (P - S)

    # (param name, sample-site name, initial value), registered in this order.
    delta_sites = (
        ("eta1_loc", "eta1", 0.25),
        ("msq_loc", "msq", 1.0),
        ("xisq_loc", "xisq", 1.0),
        ("lam_loc", "lambda", 0.5 * np.ones(P)),
    )
    for param_name, site_name, init_value in delta_sites:
        point = numpyro.param(param_name, init_value, constraint=constraints.positive)
        numpyro.sample(site_name, dist.Delta(point))

    omega_loc = numpyro.param('omega_loc', -2.0 * np.ones(N))
    omega_scale = numpyro.param('omega_scale', 0.8 * np.ones(N), constraint=constraints.positive)
    omega_dist = dist.TransformedDistribution(
        dist.Normal(omega_loc, omega_scale),
        [SigmoidTransform(), AffineTransform(0, 2.5)],
    )
    omega = numpyro.sample("omega", omega_dist)
def __call__(self, name, fn, obs):
    """Reparameterize a loc/scale sample site into a (partially) decentered one.

    :param str name: The sample site name. It is used to name the new
        ``"{name}_decentered"`` sample site and, when ``self.centered`` is
        None, the ``"{name}_centered"`` param.
    :param fn: The site's distribution. Its ``loc`` and ``scale`` attributes
        are read, plus every attribute named in ``self.shape_params``.
    :param obs: Must be ``None``; observed sites are rejected by an assert.
    :return: A pair ``(new_fn, value)``.
        - If ``is_identically_one(self.centered)`` holds, the site is left
          alone and ``(fn, obs)`` is returned.
        - Otherwise ``(None, value)`` is returned, where ``value`` is a
          differentiable transform of the decentered sample.
    """
    assert obs is None, "LocScaleReparam does not support observe statements"
    centered = self.centered
    if is_identically_one(centered):
        # Nothing to reparameterize. Return the same pair shape as the path
        # below; the previous 3-tuple (name, fn, obs) could not be unpacked
        # as a (new_fn, value) pair.
        return fn, obs
    event_shape = fn.event_shape
    fn, event_dim = self._unwrap(fn)
    fn, batch_shape = self._unexpand(fn)

    # Apply a partial decentering transform.
    params = {key: getattr(fn, key) for key in self.shape_params}
    if self.centered is None:
        # No fixed centering given: learn one in [0, 1], starting at 0.5.
        centered = numpyro.param("{}_centered".format(name),
                                 jnp.full(event_shape, 0.5),
                                 constraint=constraints.unit_interval)
    params["loc"] = fn.loc * centered
    params["scale"] = fn.scale ** centered
    decentered_fn = type(fn)(**params).expand(batch_shape)

    # Draw decentered noise.
    decentered_value = numpyro.sample("{}_decentered".format(name),
                                      self._wrap(decentered_fn, event_dim))

    # Differentiably transform back:
    # value = loc + scale ** (1 - centered) * (decentered - centered * loc)
    delta = decentered_value - centered * fn.loc
    value = fn.loc + jnp.power(fn.scale, 1 - centered) * delta

    # Simulate a pyro.deterministic() site.
    return None, value
def record_stats(stat_value, num_stats=2):
    """Attach ``stat_value`` to the gradient of the ``'stats'`` param.

    ``'stats'`` is initialised to ``np.zeros(num_stats)``. The factor value is
    ``stop_gradient(w) - w`` with ``w = stats * stop_gradient(stat_value)``.
    For finite inputs that value is zero, but its gradient with respect to
    ``'stats'`` is ``-stat_value``.
    """
    stats_param = numpyro.param('stats', np.zeros(num_stats))
    weighted = stats_param * stop_gradient(stat_value)
    numpyro.factor('stats_dummy_factor', stop_gradient(weighted) - weighted)