# Tail of a NumPyro categorical model; the Dirichlet-prior head shown here is an
# assumed reconstruction so the snippet reads as a complete function.
def model(data):
    concentration = np.array([1.0, 1.0, 1.0])
    p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
    numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
    return p_latent

@jit
def get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob):
    kernel = kernel_cls(model, step_size=step_size, trajectory_length=trajectory_length,
                        target_accept_prob=target_accept_prob)
    mcmc = MCMC(kernel, warmup_steps, num_samples, num_chains=2, chain_method=chain_method,
                progress_bar=False)
    mcmc.run(rng_key, data)
    return mcmc.get_samples()
true_probs = np.array([0.1, 0.6, 0.3])
data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
samples = get_samples(rng_key, data, step_size, trajectory_length, target_accept_prob)
assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)
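The snippet above leaves kernel_cls, warmup_steps, num_samples, chain_method, rng_key, and the sampler hyperparameters undefined (in the original test they come from parametrization). A minimal, assumed setup that makes it self-contained; the specific values are illustrative, not the test's:

import jax.numpy as np
from jax import jit, random
from numpy.testing import assert_allclose
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC

kernel_cls = HMC                         # assumed; NUTS would work similarly
warmup_steps, num_samples = 1000, 2000   # assumed sizes
chain_method = 'sequential'              # assumed
step_size = 1.0                          # assumed sampler settings
trajectory_length = 1.0
target_accept_prob = 0.8
rng_key = random.PRNGKey(0)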
def fastvar(x, axis, keepdims):
    """A fast but less numerically-stable variance calculation than np.var."""
    return np.mean(x ** 2, axis, keepdims=keepdims) - np.mean(x, axis, keepdims=keepdims) ** 2
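A quick sanity check for fastvar (a sketch assuming np is jax.numpy, as the surrounding snippets do): it matches np.var up to floating-point error, with the caveat the docstring mentions.

import jax.numpy as np
from jax import random

x = random.normal(random.PRNGKey(0), (4, 1000))
v_fast = fastvar(x, axis=1, keepdims=False)
v_ref = np.var(x, axis=1, keepdims=False)
print(np.max(np.abs(v_fast - v_ref)))  # tiny here; the gap grows when |mean(x)| >> std(x)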
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(logsoftmax(preds) * targets)
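This loss assumes a predict(params, inputs) function returning logits, one-hot targets, and a logsoftmax helper, none of which appear in the excerpt. A minimal sketch using jax.nn.log_softmax and a made-up single-layer model (all names and shapes here are illustrative):

import jax.numpy as np
from jax import grad, random
from jax.nn import log_softmax as logsoftmax

def predict(params, inputs):
    # Hypothetical single-layer model: logits = inputs @ W + b.
    W, b = params
    return np.dot(inputs, W) + b

params = (0.01 * random.normal(random.PRNGKey(0), (784, 10)), np.zeros(10))
inputs = random.normal(random.PRNGKey(1), (32, 784))
targets = np.eye(10)[random.randint(random.PRNGKey(2), (32,), 0, 10)]  # one-hot labels
print(loss(params, (inputs, targets)))         # note: np.mean also averages over the class axis
grads = grad(loss)(params, (inputs, targets))  # gradients w.r.t. (W, b)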
D = b.shape[-1]
# Low-rank preconditioner for the conjugate-gradient solves below.
presolve = lowrank_presolve(kX, diag, eta1, eta2, c, kappa, rank1, rank2)
mvm = lambda rhs: np.matmul(A, rhs)
# Stack the right-hand side b with the probe vectors so a single batched
# CG call yields both A^{-1} b and A^{-1} z_i for every probe z_i.
b_probes = np.concatenate([b[None, :], probes])
Ainv_b_probes, res_norm, iters = pcg_batch_b(b_probes, mvm, presolve=presolve,
                                             cg_tol=cg_tol, max_iters=max_iters)
Ainv_b, Ainv_probes = Ainv_b_probes[0], Ainv_b_probes[1:]
# Tangent of the quadratic form b^T A^{-1} b w.r.t. A and b.
quad_form_dA = -np.dot(Ainv_b, np.matmul(A_dot, Ainv_b))
quad_form_db = 2.0 * np.dot(Ainv_b, b_dot)
# Hutchinson estimate of tr(A^{-1} A_dot), i.e. the derivative of log det A
# (assuming unit-covariance probes).
log_det_dA = np.mean(np.einsum('...i,...i->...', np.matmul(probes, A_dot), Ainv_probes))
tangent_out = log_det_dA + quad_form_dA + quad_form_db
quad_form = np.dot(b, Ainv_b)
# Primal outputs followed by their tangents (custom-JVP style).
return (quad_form, np.mean(res_norm), np.mean(iters)), (tangent_out, 0.0, 0.0)
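For reference, the tangent assembled above combines the directional derivative of the quadratic form b^T A^{-1} b with a Hutchinson estimate of d log det A = tr(A^{-1} A_dot). A small dense-matrix sketch of the same identities (local names only), handy for checking a CG-based implementation against exact linear algebra:

import numpy as onp

rng = onp.random.default_rng(0)
N = 6
M = rng.standard_normal((N, N))
A = M @ M.T + N * onp.eye(N)            # SPD test matrix
b = rng.standard_normal(N)
A_dot = rng.standard_normal((N, N))
A_dot = 0.5 * (A_dot + A_dot.T)         # symmetric perturbation direction
b_dot = rng.standard_normal(N)

Ainv_b = onp.linalg.solve(A, b)
# d/dt [ b^T A^{-1} b ] = -Ainv_b^T A_dot Ainv_b + 2 b_dot^T Ainv_b
quad_tangent = -Ainv_b @ A_dot @ Ainv_b + 2.0 * b_dot @ Ainv_b
# d/dt [ log det A ] = tr(A^{-1} A_dot), which the probes estimate stochastically.
logdet_tangent = onp.trace(onp.linalg.solve(A, A_dot))

eps = 1e-6
f = lambda A_, b_: b_ @ onp.linalg.solve(A_, b_) + onp.linalg.slogdet(A_)[1]
fd = (f(A + eps * A_dot, b + eps * b_dot) - f(A - eps * A_dot, b - eps * b_dot)) / (2 * eps)
print(quad_tangent + logdet_tangent, fd)   # agree to finite-difference accuracy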
assert (RT + 1, AT) == rewards_to_actions.shape

# (B, RT)
td_deltas = deltas(
    value_predictions_old,  # (B, RT+1)
    padded_rewards,
    reward_mask,
    gamma=gamma)

# (B, RT)
advantages = gae_advantages(
    td_deltas, reward_mask, lambda_=lambda_, gamma=gamma)

# Normalize the advantages.
advantage_mean = np.mean(advantages)
advantage_std = np.std(advantages)
advantages = (advantages - advantage_mean) / (advantage_std + 1e-8)

# Scatter advantages over padded_actions.
# rewards_to_actions is RT + 1 -> AT, so we pad the advantages and the reward
# mask by 1.
advantages = np.dot(np.pad(advantages, ((0, 0), (0, 1))), rewards_to_actions)
action_mask = np.dot(
    np.pad(reward_mask, ((0, 0), (0, 1))), rewards_to_actions)

# (B, AT)
ratios = compute_probab_ratios(log_probab_actions_new, log_probab_actions_old,
                               padded_actions, action_mask)
assert (B, AT) == ratios.shape
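compute_probab_ratios is not shown in this excerpt. Under the standard PPO formulation it gathers the log-probability of each taken action under the new and old policies, exponentiates the difference, and zeros out padded positions; a sketch under that assumption (log-prob tensors of shape (B, AT, A), actions and mask of shape (B, AT)):

import jax.numpy as np

def compute_probab_ratios_sketch(log_p_new, log_p_old, actions, mask):
    """pi_new(a|s) / pi_old(a|s) for the actions actually taken, zeroed on padding."""
    idx = actions.astype(np.int32)[..., None]
    lp_new = np.take_along_axis(log_p_new, idx, axis=-1)[..., 0]
    lp_old = np.take_along_axis(log_p_old, idx, axis=-1)[..., 0]
    return np.exp(lp_new - lp_old) * mask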
from jax.scipy.special import logsumexp
import jax.numpy as jnp

def cross_entropy(logits, labels):
    assert logits.ndim == 2
    assert labels.ndim == 1
    assert len(logits) == len(labels)
    logprobs = logits - logsumexp(logits, axis=1, keepdims=True)
    nll = jnp.take_along_axis(logprobs, jnp.expand_dims(labels, axis=1), axis=1)
    ce = -jnp.mean(nll)
    return ce
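A quick usage check for cross_entropy (the logits and labels below are made up):

from jax import random

logits = random.normal(random.PRNGKey(0), (8, 5))
labels = random.randint(random.PRNGKey(1), (8,), 0, 5)
print(cross_entropy(logits, labels))  # scalar mean negative log-likelihood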
def loss(batch):
    # Drop the last model output so predictions align with the targets sliced below.
    theta = wavenet(batch)[:, :-1, :]
    # Now slice the padding (the receptive field) off the batch.
    sliced_batch = batch[:, receptive_field:, :]
    # Mean negative log-likelihood, converted from nats to bits per predicted sample.
    return (np.mean(discretized_mix_logistic_loss(
        theta, sliced_batch, num_class=1 << 16), axis=0)
        * np.log2(np.e) / (output_width - 1))
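The trailing np.log2(np.e) factor converts the negative log-likelihood from nats to bits before normalizing by the number of predicted samples; a one-line arithmetic illustration of that conversion (values made up):

import math

nll_nats = 5.545                      # e.g. an average -log p(x) in nats
print(nll_nats * math.log2(math.e))   # ≈ 8.0 bits, i.e. nll_nats / ln(2)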
kX = kappa * X
omega_b = b * diag
# Matrix-vector product for the kernel system being solved (includes the diag term).
mvm = lambda _b: kernel_mvm_diag(_b, kX, eta1, eta2, c, diag, dilation=dilation, dilation2=dilation2)
# Low-rank preconditioner for the batched CG solves below.
presolve = lowrank_presolve(kX, diag, eta1, eta2, c, kappa, rank1, rank2)
# Solve for A^{-1} (diag * b) and A^{-1} z_i (one probe per row) in a single batched CG call.
om_b_probes = np.concatenate([omega_b[None, :], probes])
Ainv_om_b_probes, res_norm, iters = pcg_batch_b(om_b_probes, mvm, presolve=presolve,
                                                cg_tol=cg_tol, max_iters=max_iters)
Ainv_om_b, Ainv_probes = Ainv_om_b_probes[0], Ainv_om_b_probes[1:]
K_Ainv_om_b = kernel_mvm(Ainv_om_b, kX, eta1, eta2, c, dilation=dilation, dilation2=dilation2)
quad_form = 0.125 * np.dot(b, K_Ainv_om_b)
# Save the intermediates needed by the backward pass.
residuals = (kX, kappa, eta1, eta2, K_Ainv_om_b, Ainv_om_b, diag, Ainv_probes, probes)
return (quad_form, np.mean(res_norm), np.mean(iters)), residuals
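pcg_batch_b is not defined in this excerpt; its assumed role is to run preconditioned conjugate gradients on each row of a batch of right-hand sides. A minimal sketch with jax.scipy.sparse.linalg.cg and vmap; the return values are simplified, since jax's cg does not report residual norms or iteration counts, so this version recomputes or placeholders them:

import jax.numpy as np
from jax import vmap
from jax.scipy.sparse.linalg import cg

def pcg_batch_b_sketch(b_batch, mvm, presolve=lambda v: v, cg_tol=1e-5, max_iters=200):
    """Solve A x = b for every row of b_batch with preconditioned CG.
    mvm applies A to a single vector; presolve is assumed to be a callable
    applying an approximate inverse of A (the preconditioner)."""
    def solve_one(b):
        x, _ = cg(mvm, b, tol=cg_tol, maxiter=max_iters, M=presolve)
        return x
    solutions = vmap(solve_one)(b_batch)
    # Recompute residual norms; use a placeholder for iteration counts.
    res_norm = np.linalg.norm(vmap(mvm)(solutions) - b_batch, axis=-1)
    iters = np.full(b_batch.shape[0], max_iters)
    return solutions, res_norm, iters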