Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_scale(use_context_manager):
    """Check that ``handlers.scale`` multiplies log-density terms by 10.

    When ``use_context_manager`` is true, the scale handler wraps only the
    ``'obs'`` sample statement, so only the likelihood term is scaled.
    Otherwise the whole model is wrapped, so the joint (prior on ``'x'`` plus
    likelihood) is scaled.
    """

    def model(data):
        latent = numpyro.sample('x', dist.Normal(0, 1))
        with optional(use_context_manager, handlers.scale(scale=10)):
            numpyro.sample('obs', dist.Normal(latent, 1), obs=data)

    if use_context_manager:
        scored_model = model
    else:
        scored_model = handlers.scale(model, 10.)

    data = random.normal(random.PRNGKey(0), (3,))
    x = random.normal(random.PRNGKey(1))
    log_joint = log_density(scored_model, (data,), {}, {'x': x})[0]

    prior_lp = dist.Normal(0, 1).log_prob(x)
    likelihood_lp = dist.Normal(x, 1).log_prob(data).sum()
    if use_context_manager:
        # only the observation site sits under the scale handler
        expected = prior_lp + 10 * likelihood_lp
    else:
        # the whole model is wrapped, so every term is scaled
        expected = 10 * (prior_lp + likelihood_lp)
    assert_allclose(log_joint, expected)
def get_actual_probe(loc, scale):
    """Run ``model(loc, scale)`` once and return ``get_moments`` of site ``'x'``.

    The call is seeded with ``rng_seed=0``, and site ``'x'`` is handled by
    ``reparam``. ``model``, ``reparam`` and ``get_moments`` are free names
    taken from the enclosing scope.
    """
    # A multi-item ``with`` enters the managers left to right, exactly like
    # the equivalent nested ``with`` statements.
    with numpyro.handlers.trace() as trace, \
            numpyro.handlers.seed(rng_seed=0), \
            numpyro.handlers.reparam(config={"x": reparam}):
        model(loc, scale)
    x_value = trace["x"]["value"]
    return get_moments(x_value)
def get_actual_probe(loc, scale):
    """Trace one seeded, reparameterized call of ``model`` and summarise ``'x'``.

    ``model``, ``reparam`` and ``get_moments`` come from the enclosing scope.
    The returned value is whatever ``get_moments`` produces for the value
    recorded at site ``'x'``.
    """
    with numpyro.handlers.trace() as trace:
        with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.reparam(config={"x": reparam}):
            model(loc, scale)
    sampled = trace["x"]["value"]
    return get_moments(sampled)
def model(data, mask):
    """Model with a masked Delta site and a masked, 2x-scaled likelihood.

    Every site lives inside plate ``'N'``. ``N`` is a free name from the
    enclosing scope. Sites ``'y'`` and ``'obs'`` are under
    ``handlers.mask(mask_array=mask)``, and only ``'obs'`` is additionally
    under ``handlers.scale(scale=2)``.
    NOTE(review): the nesting was reconstructed from indentation-stripped
    source — confirm against the original file.
    """
    with numpyro.plate('N', N):
        z = numpyro.sample('x', dist.Normal(0, 1))
        with handlers.mask(mask_array=mask):
            numpyro.sample('y', dist.Delta(z, log_density=1.))
            with handlers.scale(scale=2):
                numpyro.sample('obs', dist.Normal(z, 1), obs=data)
def body_fn(wrapped_carry, x):
    """Run one scan step of ``f`` and return the advanced state plus its trace.

    ``wrapped_carry`` is ``(i, rng_key, carry)``. When ``rng_key`` is not
    ``None``, it is split, and ``f`` is seeded with the fresh subkey. Each
    ``(subs_type, subs_map)`` in ``substitute_stack`` wraps the function in a
    ``condition`` or ``substitute`` handler driven by
    ``partial(_subs_wrapper, subs_map, i, length)``. Other types are ignored.

    Returns ``((i + 1, rng_key, carry), (PytreeTrace(trace), y))``, where
    ``carry, y`` are the outputs of the wrapped call.

    ``f``, ``substitute_stack``, ``length``, ``_subs_wrapper`` and
    ``PytreeTrace`` are free names from the enclosing scope.
    """
    i, rng_key, carry = wrapped_carry
    if rng_key is None:
        subkey = None
    else:
        rng_key, subkey = random.split(rng_key)

    # presumably blocks the step's sites from handlers outside the scan — verify
    with handlers.block():
        step_fn = f if subkey is None else handlers.seed(f, subkey)
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                step_fn = handlers.condition(step_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                step_fn = handlers.substitute(step_fn, substitute_fn=subs_fn)

        with handlers.trace() as step_trace:
            carry, y = step_fn(carry, x)

    next_state = (i + 1, rng_key, carry)
    return next_state, (PytreeTrace(step_trace), y)
def body_fn(wrapped_carry, x):
    """Single scan step: optionally seed ``f``, apply substitutions, trace one call.

    Input state is ``(i, rng_key, carry)``. The result is
    ``((i + 1, rng_key, carry), (PytreeTrace(trace), y))``, with ``rng_key``
    advanced by ``random.split`` when it is not ``None``. Handlers are taken
    from ``substitute_stack`` in order, and only ``'condition'`` and
    ``'substitute'`` entries have an effect.

    Free names (``f``, ``substitute_stack``, ``length``, ``_subs_wrapper``,
    ``PytreeTrace``) are resolved from the enclosing scope.
    """
    step_index, rng_key, carry = wrapped_carry
    subkey = None
    if rng_key is not None:
        rng_key, subkey = random.split(rng_key)

    with handlers.block():
        wrapped = f
        if subkey is not None:
            wrapped = handlers.seed(f, subkey)
        for kind, values in substitute_stack:
            lookup = partial(_subs_wrapper, values, step_index, length)
            if kind == 'condition':
                wrapped = handlers.condition(wrapped, condition_fn=lookup)
            elif kind == 'substitute':
                wrapped = handlers.substitute(wrapped, substitute_fn=lookup)

        with handlers.trace() as recorded:
            carry, y = wrapped(carry, x)

    return (step_index + 1, rng_key, carry), (PytreeTrace(recorded), y)
def predict(model, rng_key, samples, X, D_H):
    """Return the value sampled at site ``'Y'`` when running ``model``.

    ``model`` is seeded with ``rng_key``, and its sites are substituted with
    ``samples`` (presumably a dict of site name -> value — confirm with the
    caller). It is then traced once with ``X``, ``D_H`` and ``Y=None``.
    """
    seeded = handlers.seed(model, rng_key)
    conditioned = handlers.substitute(seeded, samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = handlers.trace(conditioned).get_trace(X=X, Y=None, D_H=D_H)
    y_site = model_trace['Y']
    return y_site['value']
def predict(model, rng_key, samples, X, D_H):
    """Trace a seeded, ``samples``-substituted run of ``model`` and return ``'Y'``.

    Because ``Y=None`` is passed, the model draws ``'Y'`` itself, and that
    draw's recorded value is returned.
    """
    predictive = handlers.trace(
        handlers.substitute(handlers.seed(model, rng_key), samples)
    )
    return predictive.get_trace(X=X, Y=None, D_H=D_H)['Y']['value']
def _setup_prototype(self, *args, **kwargs):
    """Initialise ``self.model`` once and cache the guide's prototype state.

    Sets ``self._postprocess_fn``, ``self.prototype_trace``,
    ``self._init_latent`` (the flattened ``init_params[0]``),
    ``self._unpack_latent`` and ``self.latent_dim``.

    Raises:
        RuntimeError: if the flattened initial latent has size 0.
    """
    setup_site = "_{}_rng_key_setup".format(self.prefix)
    rng_key = numpyro.sample(setup_site, dist.PRNGIdentity())
    # presumably hides the initialisation run's sites from outer handlers — verify
    with handlers.block():
        init_params, _, self._postprocess_fn, self.prototype_trace = initialize_model(
            rng_key,
            self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )

    flat_latent, unpack_latent = ravel_pytree(init_params[0])
    self._init_latent = flat_latent
    # this is to match the behavior of Pyro, where we can apply
    # unpack_latent for a batch of samples
    self._unpack_latent = UnpackTransform(unpack_latent)
    self.latent_dim = jnp.size(self._init_latent)
    if self.latent_dim == 0:
        guide_name = type(self).__name__
        raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                           .format(guide_name))