def model(data):
    # Improper flat prior over the reals: mask(False) zeroes out the prior's log-density.
    mean = numpyro.sample('mean', dist.Normal(0, 1).mask(False))
    # Improper flat prior over the positive reals.
    std = numpyro.sample('std', dist.ImproperUniform(dist.constraints.positive, (), ()))
    return numpyro.sample('obs', dist.Normal(mean, std), obs=data)
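# Hypothetical usage sketch (the `data` values below are made up): even with both
# priors improper, the likelihood makes the posterior proper, so NUTS can sample it.
from jax import random
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

data = jnp.array([0.9, 1.1, 1.3, 0.7])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)
mcmc.print_summary()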
def reconstruct_img(epoch, rng_key):
    # Fetch a test image and save the original for comparison.
    img = test_fetch(0, test_idx)[0][0]
    plt.imsave(os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch)), img, cmap='gray')
    rng_key_binarize, rng_key_sample = random.split(rng_key)
    test_sample = binarize(rng_key_binarize, img)
    # Encode the binarized image, sample a latent code, and decode it back.
    params = svi.get_params(svi_state)
    z_mean, z_var = encoder_nn[1](params['encoder$params'], test_sample.reshape([1, -1]))
    z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
    img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
    plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')
def factor(name, log_factor):
    """
    Factor statement to add arbitrary log probability factor to a
    probabilistic model.

    :param str name: Name of the trivial sample.
    :param numpy.ndarray log_factor: A possibly batched log probability factor.
    """
    unit_dist = numpyro.distributions.distribution.Unit(log_factor)
    unit_value = unit_dist.sample(None)
    sample(name, unit_dist, obs=unit_value)
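# Hypothetical usage sketch for the factor statement above: `penalized_model` is a
# made-up example that adds a soft L1 penalty pulling `x` toward zero.
import jax.numpy as jnp

def penalized_model():
    x = numpyro.sample('x', dist.Normal(0., 1.))
    factor('penalty', -jnp.abs(x))  # extra log-probability term in the joint density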
def step(self, *args, rng_key=None, **kwargs):
    if self.svi_state is None:
        if rng_key is None:
            rng_key = numpyro.sample('svi.init', dist.PRNGIdentity())
        self.svi_state = self.init(rng_key, *args, **kwargs)
    try:
        self.svi_state, loss = jit(self.update)(self.svi_state, *args, **kwargs)
    except TypeError as e:
        if 'not a valid JAX type' in str(e):
            raise TypeError('NumPyro backend requires args and kwargs to be arrays, '
                            'or tuples/dicts of arrays.') from e
        raise
    params = jit(super(SVI, self).get_params)(self.svi_state)
    get_param_store().update(params)
    return loss
def build_hooks(npyro=False):
    # Select the NumPyro or PyTorch/Pyro backend for distributions and constraints.
    if npyro:
        d = np_dist
        const = np_constraints
        provider = jnp
    else:
        d = dist
        const = constraints
        provider = torch

    def categorical_logits(logits):
        return d.Categorical(logits=logits)

    def bernoulli_logit(logits):
        return d.Bernoulli(logits=logits)

    def binomial_logit(n, logits):
        return d.Binomial(total_count=n, logits=logits)
def get_base_dist(self):
    return dist.Normal(jnp.zeros(self.latent_dim), 1).to_event(1)
"""
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
latent values ``params``.
:param model: Python callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
model = substitute(model, data=params)
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
log_joint = jnp.array(0.)
for site in model_trace.values():
if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity):
value = site['value']
intermediates = site['intermediates']
scale = site['scale']
if intermediates:
log_prob = site['fn'].log_prob(value, intermediates)
else:
log_prob = site['fn'].log_prob(value)
if (scale is not None) and (not is_identically_one(scale)):
log_prob = scale * log_prob
log_prob = jnp.sum(log_prob)
log_joint = log_joint + log_prob
return log_joint, model_trace
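# Hypothetical usage sketch for log_density: `toy_model` and its values are made up.
import jax.numpy as jnp

def toy_model(data):
    loc = numpyro.sample('loc', dist.Normal(0., 1.))
    numpyro.sample('obs', dist.Normal(loc, 1.), obs=data)

log_joint, model_trace = log_density(
    toy_model, (jnp.array([0.5, -0.3]),), {}, {'loc': jnp.array(0.1)})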
def __call__(self, name, fn, obs):
    assert obs is None, "TransformReparam does not support observe statements"
    fn, batch_shape = self._unexpand(fn)
    assert isinstance(fn, dist.TransformedDistribution)

    # Draw noise from the base distribution.
    # We need to make sure that we have the same batch_shape.
    reinterpreted_batch_ndims = fn.event_dim - fn.base_dist.event_dim
    x = numpyro.sample("{}_base".format(name),
                       fn.base_dist.to_event(reinterpreted_batch_ndims).expand(batch_shape))

    # Differentiably transform.
    for t in fn.transforms:
        x = t(x)

    # Simulate a pyro.deterministic() site.
    return None, x
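# Hypothetical usage sketch: TransformReparam is normally applied through the
# numpyro.handlers.reparam effect handler; `reparam_model` is made up.
def reparam_model():
    with numpyro.handlers.reparam(config={'y': TransformReparam()}):
        # Sampled as y_base ~ Normal(0, 1), then transformed to y = 3 + 2 * y_base.
        return numpyro.sample('y', dist.TransformedDistribution(
            dist.Normal(0., 1.), dist.transforms.AffineTransform(3., 2.)))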
import os

from jax import random
import jax.numpy as np
from jax.scipy.special import logsumexp

import numpyro
from numpyro import optim
from numpyro.contrib.autoguide import AutoContinuousELBO, AutoIAFNormal
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI
from numpyro.infer.util import initialize_model, transformed_potential_energy

# TODO: remove when the issue https://github.com/google/jax/issues/939 is fixed upstream
# The behaviour when training the guide under fast-math mode is unstable.
os.environ["XLA_FLAGS"] = "--xla_cpu_enable_fast_math=false"
class DualMoonDistribution(dist.Distribution):
    support = constraints.real_vector

    def __init__(self):
        super(DualMoonDistribution, self).__init__(event_shape=(2,))

    def sample(self, key, sample_shape=()):
        # It is enough to return an arbitrary sample with the correct shape.
        return np.zeros(sample_shape + self.event_shape)

    def log_prob(self, x):
        term1 = 0.5 * ((np.linalg.norm(x, axis=-1) - 2) / 0.4) ** 2
        term2 = -0.5 * ((x[..., :1] + np.array([-2., 2.])) / 0.6) ** 2
        pe = term1 - logsumexp(term2, axis=-1)
        return -pe
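# Hypothetical usage sketch: wrapping the dual moon density in a model and sampling
# it with the NUTS/MCMC imports above; `dual_moon_model` is a made-up name.
def dual_moon_model():
    numpyro.sample('x', DualMoonDistribution())

mcmc = MCMC(NUTS(dual_moon_model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()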
def nonlin(x):
    # Non-linearity used between layers (tanh, as in NumPyro's BNN example).
    return np.tanh(x)

def model(X, Y, D_H):
    D_X, D_Y = X.shape[1], 1

    # Sample first layer (we put unit normal priors on all weights).
    w1 = numpyro.sample("w1", dist.Normal(np.zeros((D_X, D_H)), np.ones((D_X, D_H))))  # D_X D_H
    z1 = nonlin(np.matmul(X, w1))  # N D_H  <= first layer of activations

    # Sample second layer.
    w2 = numpyro.sample("w2", dist.Normal(np.zeros((D_H, D_H)), np.ones((D_H, D_H))))  # D_H D_H
    z2 = nonlin(np.matmul(z1, w2))  # N D_H  <= second layer of activations

    # Sample final layer of weights and neural network output.
    w3 = numpyro.sample("w3", dist.Normal(np.zeros((D_H, D_Y)), np.ones((D_H, D_Y))))  # D_H D_Y
    z3 = np.matmul(z2, w3)  # N D_Y  <= output of the neural network

    # We put a prior on the observation noise.
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / np.sqrt(prec_obs)

    # Observe data.
    numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)