Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_elbo_dynamic_support():
    """Check the ELBO loss when the prior is a transformed Normal and the guide is Uniform(0, 3).

    The guide is wrapped in ``substitute`` with ``param_map={'x': 2.}``, and the
    loss reported by ``SVI.evaluate`` must be finite and equal to the guide
    log-density minus the prior log-density, both evaluated at ``x = 2``.
    """
    prior_transforms = [AffineTransform(0, 2), SigmoidTransform(), AffineTransform(0, 3)]
    x_prior = dist.TransformedDistribution(dist.Normal(), prior_transforms)
    x_guide = dist.Uniform(0, 3)
    x = 2.

    def model():
        numpyro.sample('x', x_prior)

    def guide():
        numpyro.sample('x', x_guide)

    # The wrapped guide is what SVI sees; the plain `guide` above is not used directly.
    pinned_guide = substitute(guide, param_map={'x': x})
    svi = SVI(model, pinned_guide, optim.Adam(0.01), ELBO())
    state = svi.init(random.PRNGKey(0))

    actual_loss = svi.evaluate(state)
    assert jnp.isfinite(actual_loss)

    # Reference value computed directly from the two densities at the pinned point.
    expected_loss = x_guide.log_prob(x) - x_prior.log_prob(x)
    assert_allclose(actual_loss, expected_loss)
def body_fn(state):
# Builds a candidate `params` dict from the incoming loop state.
# NOTE(review): this excerpt stops before any return statement, and `radius`,
# `prototype_params`, `init_strategy`, `init_values`, `model`, `model_args` and
# `model_kwargs` are not defined in it -- presumably closure variables of an
# enclosing function; confirm against the full file.
# `state` is unpacked as a 4-tuple; the last two entries are discarded and `i`
# is not used in the visible part.
i, key, _, _ = state
# Derive a fresh `subkey` for this call's random draws.
key, subkey = random.split(key)
if radius is None or prototype_params is None:
# Wrap model in a `substitute` handler to initialize from `init_strategy`.
seeded_model = substitute(seed(model, subkey), substitute_fn=init_strategy)
model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
constrained_values, inv_transforms = {}, {}
# Keep only sample sites that are neither observed nor discrete, recording
# each site's value together with `biject_to` of its distribution's support.
for k, v in model_trace.items():
if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
constrained_values[k] = v['value']
inv_transforms[k] = biject_to(v['fn'].support)
# Apply the collected transforms in the inverse direction (`invert=True`);
# presumably this maps the constrained values to unconstrained space -- TODO confirm.
params = transform_fn(inv_transforms,
{k: v for k, v in constrained_values.items()},
invert=True)
else: # this branch doesn't require tracing the model
params = {}
# Sites listed in `init_values` take that value; every other site gets a
# uniform draw in [-radius, radius) with the shape of its prototype value.
# NOTE(review): the same `subkey` is passed to every `random.uniform` call
# in this loop, so sites of equal shape receive identical draws -- confirm
# this is intended.
for k, v in prototype_params.items():
if k in init_values:
params[k] = init_values[k]
else:
params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
def get_transform(self, params):
    """
    Returns the transformation learned by the guide to generate samples from the unconstrained
    (approximate) posterior.

    :param dict params: Current parameters of model and autoguide.
        The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
        method from :class:`~numpyro.infer.svi.SVI`.
    :return: the transform of posterior distribution. A posterior with exactly one
        transform yields that transform itself; several transforms are wrapped in a
        ``ComposeTransform``.
    :rtype: :class:`~numpyro.distributions.transforms.Transform`
    :raises AssertionError: if the posterior is not a ``TransformedDistribution``.
    :raises IndexError: if the posterior carries an empty list of transforms.
    """
    posterior = handlers.substitute(self._get_posterior, params)()
    assert isinstance(posterior, dist.TransformedDistribution), \
        "posterior is not a transformed distribution"
    transforms = posterior.transforms
    # Only a chain of two or more transforms needs composing. The earlier `> 0`
    # test routed the single-transform case through ComposeTransform as well,
    # which left the indexing branch reachable only for an empty list.
    if len(transforms) > 1:
        return ComposeTransform(transforms)
    return transforms[0]
:param model: Python callable containing NumPyro primitives. Typically,
the model has been enumerated by using
:class:`~numpyro.contrib.funsor.enum_messenger.enum` handler::
def model(*args, **kwargs):
...
log_joint = log_density(enum(config_enumerate(model)), args, kwargs, params)
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
# NOTE(review): the `def` line and the opening of the docstring above are not
# part of this excerpt, and the loop below is cut off before its factors are
# combined or returned.
# Wrap the model so that the values in `params` are substituted by site name.
model = substitute(model, data=params)
# Record the trace with `packed_trace` while the `plate_to_enum_plate` context is active.
with plate_to_enum_plate():
model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
# Accumulators; none of them is updated in the visible part of the loop.
log_factors = []
sum_vars, prod_vars = frozenset(), frozenset()
for site in model_trace.values():
if site['type'] == 'sample':
value = site['value']
intermediates = site['intermediates']
scale = site['scale']
# Hand the recorded intermediates to `log_prob` only when there are any.
if intermediates:
log_prob = site['fn'].log_prob(value, intermediates)
else:
log_prob = site['fn'].log_prob(value)
# Skip the multiplication when no scale is set or the scale is identically one.
if (scale is not None) and (not is_identically_one(scale)):
log_prob = scale * log_prob
def log_density(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
latent values ``params``.
:param model: Python callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of current parameter values keyed by site
name.
:return: log of joint density and a corresponding model trace
"""
# NOTE(review): this excerpt ends inside the loop; the step that folds each
# site's `log_prob` into `log_joint` and the final return are not shown.
# Wrap the model so that the values in `params` are substituted by site name.
model = substitute(model, data=params)
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
# Running total, started at zero as a JAX array.
log_joint = jnp.array(0.)
for site in model_trace.values():
# Only sample sites count, and sites whose distribution is `dist.PRNGIdentity` are skipped.
if site['type'] == 'sample' and not isinstance(site['fn'], dist.PRNGIdentity):
value = site['value']
intermediates = site['intermediates']
scale = site['scale']
# Hand the recorded intermediates to `log_prob` only when there are any.
if intermediates:
log_prob = site['fn'].log_prob(value, intermediates)
else:
log_prob = site['fn'].log_prob(value)
# Skip the multiplication when no scale is set or the scale is identically one.
if (scale is not None) and (not is_identically_one(scale)):
log_prob = scale * log_prob
# Reduce the site's log-probability to a scalar.
log_prob = jnp.sum(log_prob)
the corresponding base value lies in the support of base distribution. Otherwise,
the base value lies in the support of the distribution.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: unconstrained parameters of `model`.
:param bool enum: whether to enumerate over discrete latent sites.
:return: potential energy given unconstrained parameters.
"""
# NOTE(review): the `def` line and the opening of the docstring above are not
# part of this excerpt.
# Choose the density routine: the funsor variant is imported only when `enum`
# is truthy; otherwise the `log_density` already in scope is used.
if enum:
from numpyro.contrib.funsor import log_density as log_density_
else:
log_density_ = log_density
# Values are supplied to the model through `_unconstrain_reparam`, with `params` bound via `partial`.
substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
# no param is needed for log_density computation because we already substitute
log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
# The result is the negated log joint; `model_trace` is not used here.
return - log_joint