# NB: `be_inputs` (batch + event inputs, consumed by random_gaussian below) is
# assumed to be built from the raw pair lists before they are converted to
# OrderedDicts:
be_inputs = OrderedDict(batch_inputs + event_inputs)
expected_inputs = OrderedDict(sample_inputs + batch_inputs + event_inputs)
sample_inputs = OrderedDict(sample_inputs)
batch_inputs = OrderedDict(batch_inputs)
event_inputs = OrderedDict(event_inputs)
x = random_gaussian(be_inputs)
# a raw uint32 pair is equivalent to jax.random.PRNGKey(0); torch needs no key
rng_key = subkey = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
xfail = False
for num_sampled in range(len(event_inputs) + 1):
for sampled_vars in itertools.combinations(list(event_inputs), num_sampled):
sampled_vars = frozenset(sampled_vars)
print('sampled_vars: {}'.format(', '.join(sampled_vars)))
try:
if rng_key is not None:
import jax
rng_key, subkey = jax.random.split(rng_key)
y = x.sample(sampled_vars, sample_inputs, rng_key=subkey)
except NotImplementedError:
xfail = True
continue
if num_sampled == len(event_inputs):
assert isinstance(y, (Delta, Contraction))
if sampled_vars:
assert dict(y.inputs) == dict(expected_inputs), sampled_vars
else:
assert y is x
if xfail:
pytest.xfail(reason='Not implemented')
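The loop above shows the canonical JAX pattern: split the key before every draw so that no key is ever reused. A minimal, self-contained sketch of that pattern (the function below is illustrative, not part of funsor):

import jax

def draw_three(rng_key):
    # split off a fresh subkey before each draw so no key is reused
    samples = []
    for _ in range(3):
        rng_key, subkey = jax.random.split(rng_key)
        samples.append(jax.random.normal(subkey, (2,)))
    return samples

draws = draw_three(jax.random.PRNGKey(0))
assert all(d.shape == (2,) for d in draws)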
def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
"""
:param IntegratorState z_info: The initial integrator state.
:param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
:param float step_size: Initial step size.
:param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
the inverse mass matrix will be an identity matrix whose size is determined
by the argument ``mass_matrix_size``.
:param int mass_matrix_size: Size of the mass matrix.
:return: initial state of the adapt scheme.
"""
# NB: `dense_mass`, `adapt_step_size`, `find_reasonable_step_size`, and
# `ss_init` are closed over from the enclosing adapter that defines `init_fn`.
rng_key, rng_key_ss = random.split(rng_key)
if inverse_mass_matrix is None:
assert mass_matrix_size is not None
if dense_mass:
inverse_mass_matrix = jnp.identity(mass_matrix_size)
else:
inverse_mass_matrix = jnp.ones(mass_matrix_size)
mass_matrix_sqrt = inverse_mass_matrix
else:
if dense_mass:
mass_matrix_sqrt = cholesky_of_inverse(inverse_mass_matrix)
else:
mass_matrix_sqrt = jnp.sqrt(jnp.reciprocal(inverse_mass_matrix))
if adapt_step_size:
step_size = find_reasonable_step_size(step_size, inverse_mass_matrix, z_info, rng_key_ss)
ss_state = ss_init(jnp.log(10 * step_size))
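For the dense branch above, `cholesky_of_inverse` produces a factor L of the mass matrix M = inv(inverse_mass_matrix) with L @ L.T == M; the diagonal branch is the elementwise analogue. A small sketch of both identities using plain jax.numpy rather than numpyro's helpers:

import jax.numpy as jnp

inverse_mass_matrix = jnp.array([[2.0, 0.5], [0.5, 1.0]])
mass_matrix = jnp.linalg.inv(inverse_mass_matrix)
# dense case: the "square root" is the Cholesky factor of the mass matrix
mass_matrix_sqrt = jnp.linalg.cholesky(mass_matrix)
assert jnp.allclose(mass_matrix_sqrt @ mass_matrix_sqrt.T, mass_matrix, atol=1e-5)
# diagonal case: elementwise reciprocal square root
diag_sqrt = jnp.sqrt(jnp.reciprocal(jnp.array([2.0, 1.0])))
assert jnp.allclose(diag_sqrt ** 2, jnp.reciprocal(jnp.array([2.0, 1.0])))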
# NB: this excerpt starts mid-function; the `else:` below pairs with an `if`
# above the excerpt, and this branch obtains initial values by tracing the model.
constrained_values, inv_transforms = {}, {}
for k, v in model_trace.items():
if v['type'] == 'sample' and not v['is_observed'] and not v['fn'].is_discrete:
constrained_values[k] = v['value']
inv_transforms[k] = biject_to(v['fn'].support)
params = transform_fn(inv_transforms,
{k: v for k, v in constrained_values.items()},
invert=True)
else: # this branch doesn't require tracing the model
params = {}
for k, v in prototype_params.items():
if k in init_values:
params[k] = init_values[k]
else:
params[k] = random.uniform(subkey, jnp.shape(v), minval=-radius, maxval=radius)
key, subkey = random.split(key)
potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
pe, z_grad = value_and_grad(potential_fn)(params)
z_grad_flat = ravel_pytree(z_grad)[0]
is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
return i + 1, key, (params, pe, z_grad), is_valid
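The closing validity check is a generic JAX pattern: compute a scalar potential and its gradient in one pass, flatten the gradient pytree, and test everything for finiteness. A toy, self-contained version (the quadratic potential is made up for illustration):

import jax.numpy as jnp
from jax import value_and_grad
from jax.flatten_util import ravel_pytree

def potential(params):
    # toy quadratic potential over an arbitrary pytree of parameters
    flat, _ = ravel_pytree(params)
    return 0.5 * jnp.sum(flat ** 2)

params = {"loc": jnp.array([1.0, -2.0]), "scale": jnp.array(0.5)}
pe, z_grad = value_and_grad(potential)(params)
z_grad_flat = ravel_pytree(z_grad)[0]
is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
assert bool(is_valid)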
# body of a per-sample prediction helper: `val` carries (rng_key, samples)
rng_key, samples = val
model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
*model_args, **model_kwargs)
if return_sites is not None:
if return_sites == '':
sites = {k for k, site in model_trace.items() if site['type'] != 'plate'}
else:
sites = return_sites
else:
sites = {k for k, site in model_trace.items()
if (site['type'] == 'sample' and k not in samples) or (site['type'] == 'deterministic')}
return {name: site['value'] for name, site in model_trace.items() if name in sites}
num_samples = int(np.prod(batch_shape))
if num_samples > 1:
rng_key = random.split(rng_key, num_samples)
rng_key = rng_key.reshape(batch_shape + (2,))
chunk_size = num_samples if parallel else 1
return soft_vmap(single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size)
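The reshape above works because a legacy uint32 PRNG key is a length-2 vector, so a flat batch of keys folds back into any batch shape plus a trailing (2,). A quick sketch, assuming the classic uint32 key representation rather than the newer typed keys:

import numpy as np
from jax import random

batch_shape = (2, 3)
num_samples = int(np.prod(batch_shape))          # 6
rng_key = random.PRNGKey(0)
keys = random.split(rng_key, num_samples)        # shape (6, 2)
keys = keys.reshape(batch_shape + (2,))          # shape (2, 3, 2)
assert keys.shape == (2, 3, 2)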
def body_fun(i, opt_state):
elbo_rng, data_rng = random.split(random.fold_in(rng, i))
batch = binarize_batch(data_rng, i, train_images)
loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
g = grad(loss)(optimizers.get_params(opt_state))
return opt_update(i, g, opt_state)
return lax.fori_loop(0, num_batches, body_fun, opt_state)
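`random.fold_in(rng, i)` derives a distinct stream per loop index without threading the key through the carry, which is what lets `body_fun` stay pure inside `lax.fori_loop`. A self-contained sketch of the same per-iteration pattern (the accumulated quantity is illustrative only):

import jax.numpy as jnp
from jax import lax, random

rng = random.PRNGKey(0)

def body_fun(i, total):
    # derive two independent per-iteration keys from the loop index alone
    elbo_rng, data_rng = random.split(random.fold_in(rng, i))
    noise = random.normal(elbo_rng, ()) + random.normal(data_rng, ())
    return total + noise

total = lax.fori_loop(0, 10, body_fun, jnp.float32(0.0))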
def update(self, svi_state, *args, **kwargs):
"""
Take a single step of SVI (possibly on a batch / minibatch of data),
using the optimizer.
:param svi_state: current state of SVI.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:return: tuple of `(svi_state, loss)`.
"""
rng_key, rng_key_step = random.split(svi_state.rng_key)
params = self.optim.get_params(svi_state.optim_state)
loss_val, grads = value_and_grad(
lambda x: self.loss.loss(rng_key_step, self.constrain_fn(x), self.model, self.guide,
*args, **kwargs, **self.static_kwargs))(params)
optim_state = self.optim.update(grads, svi_state.optim_state)
return SVIState(optim_state, rng_key), loss_val
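A minimal usage sketch of the init/update cycle around this method; the model and guide below are placeholders, but the SVI calls follow numpyro's public API:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_loc = numpyro.param("q_loc", 0.0)
    numpyro.sample("loc", dist.Normal(q_loc, 1.0))

svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
data = jnp.array([0.3, -0.1, 0.8])
svi_state = svi.init(random.PRNGKey(0), data)
for _ in range(100):
    svi_state, loss = svi.update(svi_state, data)
params = svi.get_params(svi_state)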
def init(self, rng_key, *args, **kwargs):
"""
:param jax.random.PRNGKey rng_key: random number generator seed.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:return: the initial :data:`SVIState`; unconstrained values from the optimizer
are mapped back to the constrained domain via ``self.constrain_fn``.
"""
rng_key, model_seed, guide_seed = random.split(rng_key, 3)
model_init = seed(self.model, model_seed)
guide_init = seed(self.guide, guide_seed)
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
model_trace = trace(model_init).get_trace(*args, **kwargs, **self.static_kwargs)
params = {}
inv_transforms = {}
# NB: params in model_trace will be overwritten by params in guide_trace
for site in list(model_trace.values()) + list(guide_trace.values()):
if site['type'] == 'param':
constraint = site['kwargs'].pop('constraint', constraints.real)
transform = biject_to(constraint)
inv_transforms[site['name']] = transform
params[site['name']] = transform.inv(site['value'])
self.constrain_fn = partial(transform_fn, inv_transforms)
return SVIState(self.optim.init(params), rng_key)
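The `biject_to(constraint)` transform used above maps unconstrained optimizer values into the constrained domain, and `transform.inv` goes the other way, which is exactly how the constrained `site['value']` becomes an unconstrained parameter here. A small sketch with numpyro's transforms:

import jax.numpy as jnp
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

transform = biject_to(constraints.positive)   # exp/log bijection
unconstrained = transform.inv(jnp.array(2.5))  # log(2.5), lives in R
constrained = transform(unconstrained)         # back to the positive reals
assert jnp.allclose(constrained, 2.5)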
# NB: the opening condition is cut off in this excerpt; these two lines sit
# inside `if rng_key is not None and get_backend() == "jax":`
import jax
rng_keys = jax.random.split(rng_key, len(self.terms))
else:
rng_keys = [None] * len(self.terms)
# Design choice: we sample over logaddexp reductions, but leave logaddexp
# binary choices symbolic.
terms = [
term.unscaled_sample(sampled_vars.intersection(term.inputs), sample_inputs, rng_key)
for term, rng_key in zip(self.terms, rng_keys)]
return Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)
if self.bin_op is ops.add:
if rng_key is not None:
import jax
rng_keys = jax.random.split(rng_key)
else:
rng_keys = [None] * 2
# Sample variables greedily in order of the terms in which they appear.
for term in self.terms:
greedy_vars = sampled_vars.intersection(term.inputs)
if greedy_vars:
break
greedy_terms, terms = [], []
for term in self.terms:
(terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms).append(term)
if len(greedy_terms) == 1:
term = greedy_terms[0]
terms.append(term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]))
result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)
elif (len(greedy_terms) == 2 and
def body_fn(carry):
    """Inner loop of Knuth's Poisson sampling algorithm."""
    i, k, rng, log_prod = carry
    rng, subkey = random.split(rng)
    # increment k wherever the running log-product still exceeds -lam
    k = jnp.where(log_prod > -lam, k + 1, k)
    return i + 1, k, rng, log_prod + jnp.log(random.uniform(subkey, shape))
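This body is one step of Knuth's multiplication method: keep adding logs of uniforms and count how many steps the running log-product stays above -lam. A self-contained scalar sketch of the whole loop with `lax.while_loop` (simplified, not numpyro's vectorized implementation):

import jax.numpy as jnp
from jax import lax, random

def poisson_knuth(key, lam):
    # carry: (k, key, log_prod); multiply uniforms until the running
    # log-product falls below -lam, then the count minus one is the sample
    def cond_fn(carry):
        _, _, log_prod = carry
        return log_prod > -lam

    def body_fn(carry):
        k, key, log_prod = carry
        key, subkey = random.split(key)
        return k + 1, key, log_prod + jnp.log(random.uniform(subkey))

    k, _, _ = lax.while_loop(cond_fn, body_fn, (jnp.int32(0), key, jnp.float32(0.0)))
    return k - 1

sample = poisson_knuth(random.PRNGKey(0), 4.0)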
def _btrs_body_fn(val):
_, key, _, _ = val
key, key_u, key_v = random.split(key, 3)
u = random.uniform(key_u)
v = random.uniform(key_v)
u = u - 0.5
k = jnp.floor((2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c).astype(n.dtype)
return k, key, u, v