Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_log_normal(shape):
    """Check that ``TransformReparam`` preserves the distribution of a log-normal site.

    Site ``"x"`` is a standard Normal pushed through ``AffineTransform(loc, scale)``
    and then ``ExpTransform``. It is sampled once directly and once under
    ``TransformReparam`` with the same seed. The moments of ``log(x)`` from the
    two runs must agree, and the reparameterized site must be recorded as
    deterministic.
    """
    num_particles = 100000
    loc = np.random.rand(*shape) * 2 - 1
    scale = np.random.rand(*shape) + 0.5

    def model():
        base = dist.Normal(jnp.zeros_like(loc), jnp.ones_like(scale))
        transforms = [AffineTransform(loc, scale), ExpTransform()]
        # The batch size added by expand_by must match the "particles" plate size.
        log_normal = dist.TransformedDistribution(base, transforms).expand_by([num_particles])
        with numpyro.plate_stack("plates", shape):
            with numpyro.plate("particles", num_particles):
                return numpyro.sample("x", log_normal)

    # Reference run: sample the transformed distribution directly.
    with handlers.trace():
        value = handlers.seed(model, 0)()
    expected_moments = get_moments(jnp.log(value))

    # Same model and seed, but with TransformReparam applied to site "x".
    with numpyro.handlers.reparam(config={"x": TransformReparam()}):
        with handlers.trace() as tr:
            value = handlers.seed(model, 0)()
    assert tr["x"]["type"] == "deterministic"
    actual_moments = get_moments(jnp.log(value))
    assert_allclose(actual_moments, expected_moments, atol=0.05)
def get_expected_probe(loc, scale):
    """Return the moments of site ``"x"`` from one seeded run of ``model(loc, scale)``.

    ``model`` and ``get_moments`` are free names resolved from the enclosing
    scope. The run is seeded with a fixed ``rng_seed=0``.
    """
    # Named ``tr`` so the local does not shadow a module-level ``trace`` handler.
    with numpyro.handlers.trace() as tr, numpyro.handlers.seed(rng_seed=0):
        model(loc, scale)
    x_value = tr["x"]["value"]
    return get_moments(x_value)
def init(self, rng_key, *args, **kwargs):
    """
    Initialize optimizer state from one seeded run of the guide and the model.

    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: initial :data:`SVIState`. As a side effect, ``self.constrain_fn`` is
        set to a callable that transforms unconstrained parameter values from the
        optimizer to the specified constrained domain.
    """
    rng_key, model_seed, guide_seed = random.split(rng_key, 3)
    model_init = seed(self.model, model_seed)
    guide_init = seed(self.guide, guide_seed)
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    model_trace = trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
    params = {}
    inv_transforms = {}
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            # Read (rather than pop) the constraint so the recorded trace sites stay intact.
            constraint = site['kwargs'].get('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[site['name']] = transform
            # The optimizer works on unconstrained values, so store the inverse-transformed value.
            params[site['name']] = transform.inv(site['value'])
    self.constrain_fn = partial(transform_fn, inv_transforms)
    return SVIState(self.optim.init(params), rng_key)
def single_prediction(val):
    """Replay the model for one set of samples and collect the selected site values.

    ``val`` is unpacked as ``(rng_key, samples)``. ``model``, ``model_args``,
    ``model_kwargs`` and ``return_sites`` are free names resolved from the
    enclosing scope.

    Sites are selected as follows:

    * ``return_sites is None``: sample sites whose names are not keys of
      ``samples``, plus every deterministic site.
    * ``return_sites == ''``: every site except plates.
    * Otherwise: the names contained in ``return_sites``.

    :return: dict mapping each selected site name to its traced value.
    """
    rng_key, samples = val
    conditioned_model = substitute(model, samples)
    model_trace = trace(seed(conditioned_model, rng_key)).get_trace(*model_args, **model_kwargs)

    if return_sites is None:
        sites = {
            name
            for name, site in model_trace.items()
            if (site['type'] == 'sample' and name not in samples) or (site['type'] == 'deterministic')
        }
    elif return_sites == '':
        sites = {name for name, site in model_trace.items() if site['type'] != 'plate'}
    else:
        sites = return_sites

    return {name: site['value'] for name, site in model_trace.items() if name in sites}
def _get_model_transforms(model, model_args=(), model_kwargs=None):
    """Trace one run of ``model`` and collect per-site transforms for latent sample sites.

    For each unobserved sample site, the loop does one of two things:

    * Continuous site: stores ``biject_to(support)`` in ``inv_transforms``,
      keyed by site name.
    * Discrete site: sets ``has_enumerate_support``, and raises ``RuntimeError``
      if the distribution lacks enumerate support.

    NOTE(review): this excerpt looks truncated. ``args`` is assigned but never
    used, ``replay_model`` is never updated, and no return statement is present.
    The code that consumes these presumably follows; TODO confirm against the
    full file.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    # model code may need to be replayed in the presence of deterministic sites
    replay_model = False
    has_enumerate_support = False
    for k, v in model_trace.items():
        # Only latent (unobserved) sample sites need a transform.
        if v['type'] == 'sample' and not v['is_observed']:
            if v['fn'].is_discrete:
                # Discrete latents get no transform; they must be enumerable.
                has_enumerate_support = True
                if not v['fn'].has_enumerate_support:
                    raise RuntimeError("MCMC only supports continuous sites or discrete sites "
                                       f"with enumerate support, but got {type(v['fn']).__name__}.")
            else:
                support = v['fn'].support
                inv_transforms[k] = biject_to(support)
                # XXX: the following code filters out most situations with dynamic supports
                args = ()
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: tuple ``(log_joint, model_trace)``: the log of the joint density
        (a scalar array) and the corresponding model trace.
    """
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.array(0.)
    for site in model_trace.values():
        # Only sample sites contribute; PRNGIdentity sites are excluded from the density.
        if site['type'] != 'sample' or isinstance(site['fn'], dist.PRNGIdentity):
            continue
        value = site['value']
        intermediates = site['intermediates']
        scale = site['scale']
        # When the site recorded intermediates, they are passed to log_prob as well.
        if intermediates:
            log_prob = site['fn'].log_prob(value, intermediates)
        else:
            log_prob = site['fn'].log_prob(value)
        # A non-trivial site scale reweights this site's log-probability.
        if (scale is not None) and (not is_identically_one(scale)):
            log_prob = scale * log_prob
        log_joint = log_joint + jnp.sum(log_prob)
    # Return both values promised by the docstring (previously nothing was returned).
    return log_joint, model_trace
def _setup_prototype(self, *args, **kwargs):
    """Run the model once and store its trace in ``self.prototype_trace``.

    The seed for this run comes from a ``dist.PRNGIdentity()`` sample site named
    ``"_{prefix}_rng_key_setup"``, where ``prefix`` is ``self.prefix``.
    """
    # run the model so we can inspect its structure
    rng_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
    model = handlers.seed(self.model, rng_key)
    # `get_trace` is wrapped in `handlers.block`, presumably so the model's sites are
    # hidden from any outer handlers while the prototype is recorded; TODO confirm.
    self.prototype_trace = handlers.block(handlers.trace(model).get_trace)(*args, **kwargs)