Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_chain(use_init_params, chain_method):
    """Run multi-chain NUTS on a synthetic logistic-regression problem.

    Args:
        use_init_params: when True, every chain starts at ``coefs = ones(dim)``;
            otherwise ``init_params`` is ``None`` and the sampler initialises itself.
        chain_method: how the chains are executed; forwarded to ``MCMC``.
    """
    N, dim = 3000, 3
    num_chains = 2
    num_warmup, num_samples = 5000, 5000
    # Synthetic data: labels drawn from a logistic model whose true
    # coefficients are 1, 2, ..., dim.
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        # Standard-normal prior on the coefficients, Bernoulli likelihood on labels.
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = NUTS(model=model)
    # Pass the sampler settings by keyword. Newer numpyro releases make
    # num_warmup/num_samples keyword-only. Hand chain_method to the constructor
    # rather than patching the attribute afterwards, so MCMC itself sees it.
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples,
                num_chains=num_chains, chain_method=chain_method)
    # One initial value per chain: shape (num_chains, dim).
    init_params = None
    if use_init_params:
        init_params = {'coefs': jnp.ones((num_chains, dim))}
    mcmc.run(random.PRNGKey(2), labels, init_params=init_params)
    # NOTE(review): samples_flat is unused in the visible code; presumably it is
    # checked further down in the original test. Verify.
    samples_flat = mcmc.get_samples()
def image_sample(rng, params, nrow, ncol):
    """Sample images from the generative model.

    Draws ``nrow * ncol`` 10-dimensional standard-normal codes, decodes them to
    per-pixel logits and samples binary pixels from them.

    Args:
        rng: JAX PRNG key. It is split into one key for the codes and one for the pixels.
        params: pair whose second element holds the decoder parameters.
        nrow: number of rows in the output grid.
        ncol: number of columns in the output grid.

    Returns:
        The result of ``image_grid`` applied to the sampled 28x28 images.
    """
    _, dec_params = params
    code_rng, img_rng = random.split(rng)
    logits = decode(dec_params, random.normal(code_rng, (nrow * ncol, 10)))
    # Treat the decoder output as Bernoulli logits, as the variable name says.
    # The pixel probability is then sigmoid(logits) = exp(-softplus(-logits)),
    # computed via logaddexp so it cannot overflow.
    # logaddexp(0, logits) on its own would be softplus(logits), which is not a
    # probability: it is ~0.69 at 0 and exceeds 1 for positive logits.
    probs = np.exp(-np.logaddexp(0., -logits))
    sampled_images = random.bernoulli(img_rng, probs)
    return image_grid(nrow, ncol, sampled_images, (28, 28))
# NOTE(review): this is a fragment. Two enclosing `def` lines are not visible:
# - an inner per-particle ELBO function taking `rng_key`;
# - the loss method that calls it.
# `model_seed` / `guide_seed` are presumably derived from that inner function's
# `rng_key`. Verify against the full source.
# Seed the model and the guide with separate keys.
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
# Guide log-density, plus the trace of the values the guide produced.
guide_log_density, guide_trace = log_density(seeded_guide, args, kwargs, param_map)
# Replay the model against the guide's trace, presumably so the model is
# scored at the guide's sampled latent values.
seeded_model = replay(seeded_model, guide_trace)
model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)
# log p(z) - log q(z)
elbo = model_log_density - guide_log_density
return elbo
# Return (-elbo) since by convention we do gradient descent on a loss and
# the ELBO is a lower bound that needs to be maximized.
if self.num_particles == 1:
return - single_particle_elbo(rng_key)
else:
# Multiple particles: one key per particle. The per-particle ELBO is
# vectorised with vmap and averaged.
rng_keys = random.split(rng_key, self.num_particles)
return - np.mean(vmap(single_particle_elbo)(rng_keys))
# NOTE(review): fragment. The opening of this dict literal (its name and `{`)
# is not visible here. It appears to be the JAX counterpart of the
# `_NUMPY_BACKEND` mapping that follows. Verify.
# Maps backend-neutral operation names to their JAX implementations.
"name": "jax",
"np": jnp,
"logsumexp": jax_special.logsumexp,
"expit": jax_special.expit,
"erf": jax_special.erf,
"conv": jax_conv,
"avg_pool": jax_avg_pool,
"max_pool": jax_max_pool,
"sum_pool": jax_sum_pool,
"jit": jax.jit,
"grad": jax.grad,
"pmap": jax.pmap,
"eval_on_shapes": jax_eval_on_shapes,
"random_uniform": jax_random.uniform,
"random_randint": jax_randint,
"random_normal": jax_random.normal,
"random_bernoulli": jax_random.bernoulli,
# PRNG key construction is wrapped in jax.jit.
"random_get_prng": jax.jit(jax_random.PRNGKey),
"random_split": jax_random.split,
"dataset_as_numpy": tfds.as_numpy,
}
# Plain-NumPy backend: the same operation names as the JAX backend, with
# NumPy-only stand-ins. It has no explicit PRNG keys: key creation returns
# None, and splitting returns `num` Nones.
_NUMPY_BACKEND = {
    "name": "numpy",
    "np": onp,
    # No compilation step: jit is the identity.
    "jit": (lambda f: f),
    "random_get_prng": lambda seed: None,
    "random_split": lambda prng, num=2: (None,) * num,
    # Logistic sigmoid 1 / (1 + exp(-x)), written as exp(-log(1 + exp(-x))).
    # logaddexp never overflows, unlike exp(-x) for large negative x, which
    # emitted an overflow RuntimeWarning (or raised under errstate(over='raise')).
    "expit": (lambda x: onp.exp(-onp.logaddexp(0., -x))),
}
def evaluate(opt_state, images):
    """Score the model on ``images`` and draw a grid of sampled images.

    Uses ``test_rng``, ``get_params``, ``elbo``, ``image_sample``, ``nrow`` and
    ``ncol`` from the enclosing scope. ``test_rng`` is not advanced here, so
    repeated calls reuse the same three derived keys.

    Returns:
        ``(test_elbo, sampled_images)``. ``test_elbo`` is the ELBO divided by
        ``images.shape[0]``.
    """
    params = get_params(opt_state)
    keys = random.split(test_rng, 3)
    elbo_key, binarize_key, sample_key = keys
    batch_size = images.shape[0]
    # Binarise the inputs stochastically, using pixel values as Bernoulli probabilities.
    binary_images = random.bernoulli(binarize_key, images)
    mean_elbo = elbo(elbo_key, params, binary_images) / batch_size
    return mean_elbo, image_sample(sample_key, params, nrow, ncol)
def _double_tree(current_tree, vv_update, kinetic_fn, inverse_mass_matrix, step_size,
                 going_right, rng_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts):
    """Build a new subtree from ``current_tree`` and combine the two.

    ``rng_key`` is split in two. One key drives ``_iterative_build_subtree``
    (in the ``going_right`` direction); the other is passed to ``_combine_tree``.
    Judging by the name, this is the tree-doubling step of NUTS.

    Returns:
        Whatever ``_combine_tree`` returns for the old and new trees.
    """
    build_key, merge_key = random.split(rng_key)
    subtree_args = (vv_update, kinetic_fn, inverse_mass_matrix, step_size, going_right,
                    build_key, energy_current, max_delta_energy, r_ckpts, r_sum_ckpts)
    extension = _iterative_build_subtree(current_tree, *subtree_args)
    # The trailing True is forwarded unchanged as _combine_tree's last argument.
    return _combine_tree(current_tree, extension, inverse_mass_matrix, going_right,
                         merge_key, True)
# NOTE(review): fragment. The enclosing method's `def` line is not visible;
# it supplies `vecs`, `rng`, `n_hashes`, `r_div_2`, `backend`.
# The code that uses `candidate_idxs_2` is cut off after this view.
# Picks random row indices of the 2-D `vecs` for two groups of
# n_hashes * r_div_2 vectors each.
assert len(vecs.shape) == 2
n_vecs = vecs.shape[0]
rng1, rng2 = backend.random.split(rng, num=2)
# We need to sample 2 * n_hashes * r_div_2 vectors from `vecs` at random.
num_needed = 2 * n_hashes * r_div_2
if n_vecs < num_needed:
# Too few rows to draw without replacement. Indices are drawn independently
# with randint, so they may repeat.
# shape = (n_hashes, r_div_2)
random_idxs_1 = jax.random.randint(
rng1, (n_hashes, r_div_2), 0, n_vecs)
random_idxs_2 = jax.random.randint(
rng2, (n_hashes, r_div_2), 0, n_vecs)
else:
# Sample without replacement.
# NOTE(review): jax.random.shuffle is deprecated in newer JAX releases, in
# favour of jax.random.permutation. Check against the pinned JAX version.
shuffled_indices = jax.random.shuffle(rng1, np.arange(n_vecs))
random_idxs = np.reshape(shuffled_indices[:num_needed],
(2, n_hashes, r_div_2))
random_idxs_1 = random_idxs[0]
random_idxs_2 = random_idxs[1]
if self._data_rotation_farthest:
# shape = (n_hashes * r_div_2, )
random_idxs_1 = np.reshape(random_idxs_1, (-1,))
random_vecs_1 = vecs[random_idxs_1]
# Sample candidates for vec2s.
# NOTE(review): `rng` was already split into (rng1, rng2) above. If this
# split defaults to num=2 like jax.random.split, it deterministically yields
# the same two keys again: the new `rng` would equal rng1 and `subrng` would
# equal rng2. rng2 was already used for random_idxs_2 in the with-replacement
# branch. This looks like unintended key reuse; consider splitting `rng` into
# three keys up front. Confirm.
rng, subrng = backend.random.split(rng)
# shape = (self._data_rotation_farthest_num, n_hashes * r_div_2)
candidate_idxs_2 = jax.random.randint(
subrng, (self._data_rotation_farthest_num, n_hashes * r_div_2), 0,
n_vecs)
def __init__(self, seed):
    """Initialise the PRNG state from ``seed``.

    Args:
        seed: integer seed passed to ``random.PRNGKey``. ``None`` falls back to 0,
            which is the key this constructor always used before.
    """
    # The seed argument used to be ignored (the key was always PRNGKey(0)).
    key = random.PRNGKey(0 if seed is None else seed)
    self.key = key
    # Both attributes start out as the same key.
    self.subkey = key
def _rvs(self, alpha):
    """Draw Dirichlet samples by normalising independent Gamma draws.

    One Gamma(alpha_k) variate is drawn per category, with the batch shape taken
    from ``self._size`` and the key from ``self._random_state``. Each draw is
    then divided by its sum over the last axis.

    Returns:
        Array of shape ``self._size + (alpha.shape[-1],)``. It sums to 1 along
        the last axis.
    """
    num_categories = alpha.shape[-1]
    sample_shape = self._size + (num_categories,)
    draws = random.gamma(self._random_state, alpha, shape=sample_shape)
    totals = jnp.sum(draws, axis=-1, keepdims=True)
    return draws / totals
def dropout(self, x, p, seed=None):
    """Inverted dropout: zero each element of ``x`` with probability ``p``.

    Surviving elements are scaled by ``1 / (1 - p)``, so the expected output
    equals ``x``.

    Args:
        x: input array.
        p: drop probability.
        seed: PRNG key for the Bernoulli keep-mask. When ``None`` (the default),
            the next key is taken from ``self.rng``.

    Returns:
        Array shaped like ``x``.
    """
    # An explicit seed used to be silently overwritten. Honour it now, and only
    # consume a key from self.rng when none is supplied.
    if seed is None:
        seed = next(self.rng)
    keep_prob = 1 - p
    keep = random.bernoulli(seed, keep_prob, x.shape)
    # NOTE(review): with p == 1, keep_prob is 0 and x / keep_prob is inf/nan.
    # np.where still returns 0, but gradients through this path may be NaN.
    return np.where(keep, x / keep_prob, 0)