Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_logistic_regression(auto_class):
    """Fit ``auto_class`` (an autoguide) with SVI on synthetic logistic-regression data.

    Features are ``N`` standard-normal rows of width ``dim``; labels are Bernoulli
    draws whose logits are the row-wise dot product with ``[1, 2, ..., dim]``.

    NOTE(review): this fragment looks truncated -- ``body_fn`` is defined but never
    driven, and nothing is asserted. ``init_strategy`` is not bound in this
    function either; presumably it is a pytest parameter or module-level name in
    the full source -- confirm.
    """
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        # Standard-normal prior on each coefficient, Bernoulli likelihood on labels.
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = jnp.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = auto_class(model, init_strategy=init_strategy)
    svi = SVI(model, guide, adam, ELBO())
    svi_state = svi.init(rng_key_init, data, labels)

    def body_fn(i, val):
        # One SVI step: ``val`` is the current SVI state; the index ``i`` is unused
        # and the loss is discarded. The (i, val) signature suggests a
        # fori_loop-style driver -- TODO confirm against the full source.
        svi_state, loss = svi.update(val, data, labels)
        return svi_state
def model(data):
    """Beta-Bernoulli model with Beta(1, 1) priors on two success probabilities.

    ``data`` is conditioned on at the ``'obs'`` site; nothing is returned.
    """
    prior = dist.Beta(jnp.ones(2), jnp.ones(2))
    success_probs = numpyro.sample('beta', prior)
    numpyro.sample('obs', dist.Bernoulli(success_probs), obs=data)
def model(labels):
    """Logistic regression that also records its logits as a deterministic site.

    ``data`` and ``dim`` are free names here, presumably bound in an enclosing
    scope. Returns the value of the ``'obs'`` site conditioned on ``labels``.
    """
    coef_prior = dist.Normal(jnp.zeros(dim), jnp.ones(dim))
    coefs = numpyro.sample('coefs', coef_prior)
    # Row-wise dot product of each data row with the coefficient vector.
    linear_predictor = jnp.sum(coefs * data, axis=-1)
    logits = numpyro.deterministic('logits', linear_predictor)
    likelihood = dist.Bernoulli(logits=logits)
    return numpyro.sample('obs', likelihood, obs=labels)
def test_log_likelihood():
    """``log_likelihood`` yields one log-probability per (draw, datum) for ``'obs'``."""
    num_draws = 100
    model, data, _ = beta_bernoulli()
    predictive = Predictive(model, return_sites=["beta"], num_samples=num_draws)
    posterior = predictive.get_samples(random.PRNGKey(1))
    loglik = log_likelihood(model, posterior, data)
    assert loglik.keys() == {"obs"}

    obs_loglik = loglik["obs"]
    # Shape check: a leading draw axis followed by the data's own shape.
    assert obs_loglik.shape == (num_draws,) + data.shape
    # Insert a singleton axis so each draw's probabilities broadcast over the data rows.
    beta_draws = posterior["beta"].reshape((num_draws, 1, -1))
    expected = dist.Bernoulli(beta_draws).log_prob(data)
    assert_allclose(obs_loglik, expected)
def test_beta_bernoulli_x64(kernel_cls):
    """MCMC with ``kernel_cls`` recovers the biases of two independent coins.

    When ``JAX_ENABLE_X64`` is present in the environment, the draws must also
    be float64.
    """
    use_sa = kernel_cls is SA
    # SA is given a far longer run than the other kernels.
    if use_sa:
        warmup_steps, num_samples = 100000, 100000
    else:
        warmup_steps, num_samples = 500, 20000

    def model(data):
        conc1 = jnp.array([1.1, 1.1])
        conc0 = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample('p_latent', dist.Beta(conc1, conc0))
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))

    kernel = SA(model=model) if use_sa else kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()

    p_draws = mcmc.get_samples()['p_latent']
    assert_allclose(jnp.mean(p_draws, 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_draws.dtype == jnp.float64
def test_logistic_regression(kernel_cls):
    """MCMC with ``kernel_cls`` recovers logistic-regression coefficients.

    Features are ``N`` standard-normal rows of width ``dim``; labels are Bernoulli
    draws with logits equal to the row-wise dot product with ``[1, 2, ..., dim]``.
    The posterior mean of ``'coefs'`` must match the true coefficients to within
    ``atol=0.22``. When ``JAX_ENABLE_X64`` is set, the draws must be float64.
    """
    N, dim = 3000, 3
    warmup_steps, num_samples = 1000, 8000
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim), np.ones(dim)))
        logits = np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    kernel = kernel_cls(model=model, trajectory_length=10)
    # Pass the counts by keyword. Newer NumPyro releases make these arguments
    # keyword-only, and the names match the older positional parameters.
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(2), labels)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)
    # The flag is JAX_ENABLE_X64 (capital X). Environment variable names are
    # case-sensitive on POSIX, so the previous lowercase 'x64' check never fired.
    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == np.float64
def test_functional_beta_bernoulli_x64(algo):
    """The functional ``hmc`` API with ``algo`` recovers the biases of two coins.

    The potential is built with ``initialize_model``, draws are gathered with
    ``fori_collect``, and each state's ``z`` is mapped back through
    ``constrain_fn``. When ``JAX_ENABLE_X64`` is set, the draws must be float64.
    """
    warmup_steps = 500
    num_samples = 20000

    def model(data):
        prior = dist.Beta(jnp.array([1.1, 1.1]), jnp.array([1.1, 1.1]))
        p_latent = numpyro.sample('p_latent', prior)
        numpyro.sample('obs', dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2))

    init_params, potential_fn, constrain_fn, _ = initialize_model(
        random.PRNGKey(2), model, model_args=(data,))
    init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
    hmc_state = init_kernel(init_params, trajectory_length=1., num_warmup=warmup_steps)

    def to_constrained(state):
        # Keep only the latent position ``z`` of each state, in constrained space.
        return constrain_fn(state.z)

    draws = fori_collect(0, num_samples, sample_kernel, hmc_state, transform=to_constrained)
    p_draws = draws['p_latent']
    assert_allclose(jnp.mean(p_draws, 0), true_probs, atol=0.05)
    if 'JAX_ENABLE_X64' in os.environ:
        assert p_draws.dtype == jnp.float64
def beta_bernoulli():
    """Build a two-coin Beta-Bernoulli fixture.

    Returns:
        ``(model, data, true_probs)``. ``data`` holds 800 draws from
        Bernoulli(``true_probs``), and ``model(data=None)`` places a Beta(1, 1)
        prior on the coin biases. Observations sit inside a plate of size 800
        at ``dim=-2``.
    """
    num_obs = 800
    true_probs = np.array([0.2, 0.7])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(0), (num_obs,))

    def model(data=None):
        prior = dist.Beta(np.ones(2), np.ones(2))
        beta = numpyro.sample("beta", prior)
        # Rows of ``data`` run along dim -2; the two coins occupy the last axis.
        with numpyro.plate("plate", num_obs, dim=-2):
            numpyro.sample("obs", dist.Bernoulli(beta), obs=data)

    return model, data, true_probs
def model(data, labels):
    """Logistic regression whose coefficient count is taken from ``data.shape[1]``.

    Returns the value of the ``'obs'`` site conditioned on ``labels``.
    """
    num_features = data.shape[1]
    weights = numpyro.sample('coefs', dist.Normal(np.zeros(num_features), np.ones(num_features)))
    return numpyro.sample('obs', dist.Bernoulli(logits=np.dot(data, weights)), obs=labels)
# NOTE(review): interior of a model body -- the enclosing function header and the
# bindings for max_length, lengths, probs_x, probs_y, num_sequences, tones_plate,
# sequences and batch are outside this view.
# x and y start at 0 and are used as the previous hidden state / previous
# observation when indexing probs_x and probs_y below.
x, y = 0, 0
for t in pyro_markov(range(max_length)):
    # Mask out time steps at or beyond each sequence's length; the reshape appends
    # a trailing singleton axis to the mask.
    with numpyro_mask(mask_array=(t < lengths).reshape(lengths.shape + (1,))):
        # Probabilities selected by the previous hidden state x
        # (presumably the transition row(s) -- confirm).
        probs_xx = probs_x[x]
        # Replace the two axes before the last with (num_sequences, 1), so every
        # sequence gets its own copy of the selected probabilities.
        probs_xx = np.broadcast_to(probs_xx, probs_xx.shape[:-3] + (num_sequences, 1) + probs_xx.shape[-1:])
        # Hidden state at step t, marked for parallel enumeration.
        x = pyro_sample("x_{}".format(t), dist.Categorical(probs_xx),
                        infer={"enumerate": "parallel"})
        logging.info(f"x[{t}]: {x.shape}")
        # Note the broadcasting tricks here: to index probs_y on tensors x and y,
        # we also need a final tensor for the tones dimension. This is conveniently
        # provided by the plate associated with that dimension.
        with tones_plate as tones:
            probs_yx = probs_y[x, y, tones]
            # Replace the second-to-last axis with num_sequences.
            probs_yx = np.broadcast_to(probs_yx, probs_yx.shape[:-2] + (num_sequences,) + probs_yx.shape[-1:])
            # Observed at step t; cast to int32 because y indexes probs_y on the
            # next iteration.
            y = pyro_sample("y_{}".format(t),
                            dist.Bernoulli(probs_yx),
                            obs=sequences[batch, t]).astype(np.int32)
            # NOTE(review): leftover, presumably the un-broadcast form of the
            # distribution argument above; candidate for deletion.
            # dist.Bernoulli(probs_y[x, y, tones]),