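# Assumed common imports for the snippets below (this page collects excerpts
# from numpyro and jaxnet test suites; exact import paths may differ by version):
import os
import numpy as onp
from numpy.testing import assert_allclose
from jax import jit, random
from jax.random import PRNGKey
import jax.numpy as jnp
import jax.numpy as np  # older snippets below use the `np` alias for jax.numpy
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, RenyiELBO
from numpyro.infer.hmc_util import build_adaptation_schedule, warmup_adapter
# the jaxnet snippets additionally assume something like:
# from jaxnet import parametrized, parameter, Dense, Reparametrized, zeros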
def test_warmup_adapter(jitted):
    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.
    mass_matrix_size = 3
    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    wa_update = jit(wa_update) if jitted else wa_update

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init((z, None, None, None), rng_key, init_step_size,
                       mass_matrix_size=mass_matrix_size)
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert step_size == find_reasonable_step_size(init_step_size, inverse_mass_matrix, z, rng_key)
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    window = adaptation_schedule[0]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.7 + 0.1 * t / (window.end - window.start), z, wa_state)
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
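# For orientation, a minimal sketch of the schedule the loop above walks,
# assuming build_adaptation_schedule returns windows with .start/.end fields
# (as the test itself relies on):
for window in build_adaptation_schedule(150):
    print(window.start, window.end)  # contiguous warmup windows covering steps 0..149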
def test_param():
    # tests that model/guide param sites whose constraints are composed
    # transforms are handled correctly
    rng_keys = random.split(random.PRNGKey(0), 5)
    a_minval = 1
    c_minval = -2
    c_maxval = -1
    a_init = jnp.exp(random.normal(rng_keys[0])) + a_minval
    b_init = jnp.exp(random.normal(rng_keys[1]))
    c_init = random.uniform(rng_keys[2], minval=c_minval, maxval=c_maxval)
    d_init = random.uniform(rng_keys[3])
    obs = random.normal(rng_keys[4])

    def model():
        a = numpyro.param('a', a_init, constraint=constraints.greater_than(a_minval))
        b = numpyro.param('b', b_init, constraint=constraints.positive)
        numpyro.sample('x', dist.Normal(a, b), obs=obs)

    def guide():
        c = numpyro.param('c', c_init, constraint=constraints.interval(c_minval, c_maxval))
        # assumed completion (the excerpt breaks off here): 'd' takes the
        # remaining init value with a unit-interval constraint
        d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
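# A minimal sketch of how a constrained model/guide pair like the one above is
# typically optimized, assuming numpyro's SVI/Trace_ELBO/Adam API (not part of
# the original excerpt):
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

svi = SVI(model, guide, Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(1), 1000)  # params are optimized in unconstrained space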
def test_Reparametrized_unparametrized_transform():
    def doubled(params):
        return 2 * params

    @parametrized
    def net():
        return parameter((), lambda key, shape: 2 * np.ones(shape))

    scaled_params = Reparametrized(net, reparametrization_factory=lambda: doubled)
    params = scaled_params.init_parameters(key=PRNGKey(0))
    reg_loss_out = scaled_params.apply(params)
    assert 4 == reg_loss_out
def random_inputs(input_shape, key=PRNGKey(0)):
    if type(input_shape) is tuple:
        return random.uniform(key, input_shape, np.float32)
    elif type(input_shape) is list:
        # one random array per shape in the list
        return [random_inputs(shape, key) for shape in input_shape]
    else:
        raise TypeError(type(input_shape))
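# Usage sketch: a tuple of ints yields one array, a list of shapes yields a
# matching list of arrays (all drawn with the helper's default key here):
single = random_inputs((2, 3))
pair = random_inputs([(2, 3), (4,)])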
def model(data):
    y_prob = numpyro.sample("y_prob", dist.Beta(1., 1.))
    with numpyro.plate("data", data.shape[0]):
        y = numpyro.sample("y", dist.Bernoulli(y_prob))
        z = numpyro.sample("z", dist.Bernoulli(0.65 * y + 0.1))
        numpyro.sample("obs", dist.Normal(2. * z, 1.), obs=data)

N = 2000
y_prob = 0.3
y = dist.Bernoulli(y_prob).sample(random.PRNGKey(0), (N,))
z = dist.Bernoulli(0.65 * y + 0.1).sample(random.PRNGKey(1))
data = dist.Normal(2. * z, 1.0).sample(random.PRNGKey(2))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(3), data)
samples = mcmc.get_samples()
assert_allclose(samples["y_prob"].mean(0), y_prob, atol=0.05)
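# For comparison, a sketch using numpyro's DiscreteHMCGibbs to sample the
# finite-support discrete sites y and z explicitly, instead of relying on NUTS
# to marginalize them by enumeration (usage assumed from numpyro's API):
from numpyro.infer import DiscreteHMCGibbs

gibbs_kernel = DiscreteHMCGibbs(NUTS(model))
mcmc = MCMC(gibbs_kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(3), data)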
def test_unnormalized_normal(kernel_cls, dense_mass):
    true_mean, true_std = 1., 2.
    warmup_steps, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * np.sum(((z - true_mean) / true_std) ** 2)

    init_params = np.array(0.)
    kernel = kernel_cls(potential_fn=potential_fn, trajectory_length=9, dense_mass=dense_mass)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    hmc_states = mcmc.get_samples()
    assert_allclose(np.mean(hmc_states), true_mean, rtol=0.05)
    assert_allclose(np.std(hmc_states), true_std, rtol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert hmc_states.dtype == np.float64
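# The float64 check above only fires when x64 mode is on; the standard jax way
# to enable it (must run before any jax computation):
from jax import config
config.update('jax_enable_x64', True)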
def renyi_loss_fn(x):
    return RenyiELBO(alpha=alpha, num_particles=10).loss(
        random.PRNGKey(0), {}, model, guide, x)
def test_logistic_regression(kernel_cls):
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=10)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
def test_categorical_stats(p):
    rng_key = random.PRNGKey(0)
    n = 10000
    z = categorical(rng_key, p, (n,))
    _, counts = onp.unique(z, return_counts=True)
    assert_allclose(counts / float(n), p, atol=0.01)
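# Note: `categorical` above is assumed to take probabilities (as numpyro's
# distributions.util helper did); jax.random.categorical takes *logits*, so
# the plain-jax equivalent would be:
# z = random.categorical(rng_key, jnp.log(p), shape=(n,))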
def test_internal_param_sharing():
    @parametrized
    def shared_net(inputs, layer=Dense(2, zeros, zeros)):
        # applying the same `layer` twice shares a single set of weights
        return layer(layer(inputs))

    inputs = np.zeros((1, 2))
    params = shared_net.init_parameters(inputs, key=PRNGKey(0))
    assert_parameters_equal(((np.zeros((2, 2)), np.zeros(2),),), params)

    out = shared_net.apply(params, inputs)
    assert np.array_equal(np.zeros((1, 2)), out)
    out_ = shared_net.apply(params, inputs, jit=True)
    assert np.array_equal(out, out_)
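# For contrast, a sketch (jaxnet-style, names assumed) using two distinct Dense
# instances, so init_parameters would return two parameter entries rather than
# the single shared one asserted above:
@parametrized
def unshared_net(inputs, first=Dense(2, zeros, zeros), second=Dense(2, zeros, zeros)):
    return second(first(inputs))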