Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_plate(model):
    """Check that tracing ``model`` eagerly and under ``jit`` gives matching sample values.

    :param model: callable model under test; it is expected to expose a site named ``'z'``.
    """
    def seeded():
        # Both runs are seeded with the same key so their draws are comparable.
        return handlers.seed(model, random.PRNGKey(1))

    eager_sites = handlers.trace(seeded()).get_trace()
    jitted_sites = handlers.trace(jit(seeded())).get_trace()

    assert 'z' in eager_sites
    for site_name, eager_site in eager_sites.items():
        # Only sample sites are compared.
        if eager_site['type'] != 'sample':
            continue
        assert_allclose(jitted_sites[site_name]['value'], eager_site['value'])
# Call _sample once per integer in 0..99 inside a single vmap, seeding each call with its own value.
ys = vmap(lambda rng_key: handlers.seed(lambda: _sample(), rng_key)())(jnp.arange(100))
# The vmapped results must agree with xs (built outside this excerpt) to an absolute tolerance of 1e-6.
assert_allclose(xs, ys, atol=1e-6)
def test_model_with_transformed_distribution():
    """Check ``constrain_fn`` and ``potential_energy`` on a model with a transformed prior.

    The expected values are computed by hand from the priors and the bijections
    to their supports. The actual values come from ``constrain_fn`` and
    ``potential_energy``, which are given the bijection to the *base*
    distribution's support for the transformed site ``'y'``.
    """
    x_prior = dist.HalfNormal(2)
    y_prior = dist.LogNormal(scale=3.)  # transformed distribution

    def model():
        numpyro.sample('x', x_prior)
        numpyro.sample('y', y_prior)

    # Unconstrained parameter values that are mapped into each prior's support.
    params = {'x': np.array(-5.), 'y': np.array(7.)}
    model = handlers.seed(model, random.PRNGKey(0))
    inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.support)}
    expected_samples = partial(transform_fn, inv_transforms)(params)
    # Negative log joint of the constrained samples minus the log-Jacobian
    # terms of the unconstrained -> constrained maps.
    expected_potential_energy = (
        - x_prior.log_prob(expected_samples['x']) -
        y_prior.log_prob(expected_samples['y']) -
        inv_transforms['x'].log_abs_det_jacobian(params['x'], expected_samples['x']) -
        inv_transforms['y'].log_abs_det_jacobian(params['y'], expected_samples['y'])
    )

    base_inv_transforms = {'x': biject_to(x_prior.support), 'y': biject_to(y_prior.base_dist.support)}
    actual_samples = constrain_fn(
        handlers.seed(model, random.PRNGKey(0)), base_inv_transforms, (), {}, params)
    actual_potential_energy = potential_energy(model, base_inv_transforms, (), {}, params)

    assert_allclose(expected_samples['x'], actual_samples['x'])
    assert_allclose(expected_samples['y'], actual_samples['y'])
    # Both energies were computed but never compared, leaving potential_energy
    # untested. The two sides go through different op sequences, so allow a
    # slightly looser relative tolerance than the default.
    assert_allclose(expected_potential_energy, actual_potential_energy, rtol=1e-6)
def fn(rng_key_1, rng_key_2, rng_key_3):
    """Draw sites ``'x'`` and ``'y'`` under nested ``seed`` handlers and stack them.

    :param rng_key_1: seed for the outermost handler.
    :param rng_key_2: seed for the handler enclosing the ``'x'`` draw.
    :param rng_key_3: seed for the handler enclosing the ``'y'`` draw.
    :return: ``jnp.stack`` of the two draws, ``'x'`` first.
    """
    # NOTE(review): the source lost its indentation; the third seed handler is
    # assumed to sit inside the second one — confirm against the original file.
    with handlers.seed(rng_seed=rng_key_1), handlers.seed(rng_seed=rng_key_2):
        x_draw = numpyro.sample('x', dist.Normal(0., 1.))
        with handlers.seed(rng_seed=rng_key_3):
            y_draw = numpyro.sample('y', dist.Normal(0., 1.))
    return jnp.stack([x_draw, y_draw])
def test_condition():
    """Check that ``handlers.condition`` fixes a site's value and marks it observed.

    Site ``'y'`` is conditioned to ``2.``; the trace must report that value
    with ``is_observed`` set. Wrapping the conditioned model in a second
    ``condition`` with ``y = 3.`` must make the model call return ``3.``.
    """
    def summed_sites():
        x_val = numpyro.sample('x', dist.Delta(0.))
        y_val = numpyro.sample('y', dist.Normal(0., 1.))
        return x_val + y_val

    seeded = handlers.seed(summed_sites, random.PRNGKey(1))
    conditioned = handlers.condition(seeded, {'y': 2.})

    y_site = handlers.trace(conditioned).get_trace()['y']
    assert y_site['value'] == 2.
    assert y_site['is_observed']

    reconditioned = handlers.condition(conditioned, {'y': 3.})
    assert reconditioned() == 3.
def single_particle_elbo(rng_key):
    """Compute a one-sample ELBO estimate for the enclosing ``model`` and ``guide``.

    ``model``, ``guide``, ``args``, ``kwargs`` and ``param_map`` are taken from
    the enclosing scope.

    :param rng_key: random key, split into one key for the model and one for the guide.
    :return: model log density minus guide log density.
    """
    key_for_model, key_for_guide = random.split(rng_key)

    # Score the guide first: its trace is replayed into the model below.
    log_q, guide_trace = log_density(seed(guide, key_for_guide), args, kwargs, param_map)

    replayed_model = replay(seed(model, key_for_model), guide_trace)
    log_p, _ = log_density(replayed_model, args, kwargs, param_map)

    # log p(z) - log q(z)
    return log_p - log_q
# Key every flattened sample leaf as "Param:<index>" and pass it through jax.device_get.
samples = {
"Param:{}".format(i): jax.device_get(v)
for i, v in enumerate(tree_flatten_samples)
}
self._samples = samples
# Chain and draw counts are read from the posterior object.
self.nchains, self.ndraws = posterior.num_chains, posterior.num_samples
self.model = self.posterior.sampler.model
# model arguments and keyword arguments
self._args = self.posterior._args # pylint: disable=protected-access
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
# NOTE(review): the matching `if` is outside this excerpt; presumably this branch
# runs when no posterior was supplied — confirm against the full method.
else:
self.nchains = self.ndraws = 0
# Observed values are collected by tracing the model once under a fixed seed (PRNGKey(0)).
observations = {}
if self.model is not None:
seeded_model = numpyro.handlers.seed(self.model, jax.random.PRNGKey(0))
trace = numpyro.handlers.trace(seeded_model).get_trace(*self._args, **self._kwargs)
# Keep only sample sites that the trace flags as observed.
observations = {
name: site["value"]
for name, site in trace.items()
if site["type"] == "sample" and site["is_observed"]
}
# An empty dict is stored as None.
self.observations = observations if observations else None
def sample_posterior(self, rng_key, params, sample_shape=()):
    """
    Draw samples from the learned posterior.

    :param jax.random.PRNGKey rng_key: random key used to draw samples.
    :param dict params: Current parameters of model and autoguide.
        The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
        method from :class:`~numpyro.infer.svi.SVI`.
    :param tuple sample_shape: batch shape of each latent sample, defaults to ().
    :return: a dict containing samples drawn from this guide.
    :rtype: dict
    """
    # Seed the latent sampler, then substitute the given parameters into it.
    seeded_sampler = handlers.seed(self._sample_latent, rng_key)
    parameterized_sampler = handlers.substitute(seeded_sampler, params)
    latent_sample = parameterized_sampler(sample_shape=sample_shape)
    return self._unpack_and_constrain(latent_sample, params)
# Mask of entries along the last axis that equal 1 at least once after summing
# over the two leading axes.
# NOTE(review): assumes sequences is 3-D (sequence, time, note) — confirm.
present_notes = ((sequences == 1).sum(0).sum(0) > 0)
# remove notes that are never played (we remove 37/88 notes)
sequences = sequences[..., present_notes]

if args.truncate:
    # Cap both the recorded lengths and the second axis at args.truncate.
    lengths = lengths.clip(0, args.truncate)
    sequences = sequences[:, :args.truncate]

# All of our models have two plates: "data" and "tones".
# NOTE(review): model_0 is given a nesting of 1 below despite the line above — confirm.
max_plate_nesting = 1 if model is model_0 else 2

# To help debug our tensor shapes, let's print the shape of each site's
# distribution, value, and log_prob tensor. Note this information is
# automatically printed on most errors inside SVI.
if args.print_shapes:
    model_trace = packed_trace(enum(seed(model, 42), -max_plate_nesting - 1)).get_trace(
        sequences, lengths, args=args)
    for name in model_trace:
        site = model_trace[name]
        # Only observed sites and sites enumerated in parallel are reported.
        if site['is_observed'] or site['infer'].get('enumerate', None) == 'parallel':
            dim_to_name = site['infer']['dim_to_name']
            logging.info(to_funsor(site['fn'].log_prob(site['value']),
                                   output=funsor.reals(), dim_to_name=dim_to_name).inputs)

logging.info('Starting inference...')
rng_key = random.PRNGKey(2)
start = time.time()
# Pick the kernel class by name and wrap the model for enumeration.
kernel = {'nuts': NUTS, 'hmc': HMC}[args.kernel](enum(model, -max_plate_nesting - 1))
mcmc = MCMC(kernel, args.num_warmup, args.num_samples, progress_bar=True)
mcmc.run(rng_key, sequences, lengths, args=args)
mcmc.print_summary()
# samples = mcmc.get_samples() # TODO do something with this
# The elapsed time is a %-style argument, so the message needs a placeholder;
# without one logging reports a formatting error and the time is never logged.
logging.info('\nMCMC elapsed time: %s', time.time() - start)