Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_optim_multi_params(optim_class, args):
    """Check that an optimizer drives a dict of parameter arrays to zero.

    Two length-3 arrays start at +1 and -1. The optimizer state is advanced
    2000 times with ``step``. Every array returned by ``opt.get_params`` must
    then be close to ``zeros(3)``.

    Args:
        optim_class: optimizer factory; called as ``optim_class(*args)``.
        args: positional arguments forwarded to ``optim_class``.
    """
    params = {'x': np.array([1., 1., 1.]), 'y': np.array([-1., -1., -1.])}
    opt = optim_class(*args)
    opt_state = opt.init(params)
    # The iteration index is not needed; only the number of steps matters.
    for _ in range(2000):
        opt_state = step(opt_state, opt)
    # Only the arrays are checked, so iterate values rather than items.
    for param in opt.get_params(opt_state).values():
        assert np.allclose(param, np.zeros(3))
def test_dirichlet_categorical(kernel_cls, dense_mass):
    """Recover Categorical probabilities under a flat Dirichlet prior with MCMC.

    2000 observations are drawn from ``Categorical([0.1, 0.6, 0.3])``. The
    model places ``Dirichlet([1, 1, 1])`` on the latent probabilities and is
    sampled with ``kernel_cls`` (100 warmup steps, 20000 samples). The
    posterior mean must match the true probabilities within ``atol=0.02``.
    When 64-bit mode is enabled through ``JAX_ENABLE_X64``, the samples must
    also be float64.

    Args:
        kernel_cls: MCMC kernel class; called with ``trajectory_length`` and
            ``dense_mass`` keyword arguments.
        dense_mass: forwarded to ``kernel_cls``.
    """
    warmup_steps, num_samples = 100, 20000

    def model(data):
        concentration = np.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample('p_latent', dist.Dirichlet(concentration))
        numpyro.sample('obs', dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = np.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    kernel = kernel_cls(model, trajectory_length=1., dense_mass=dense_mass)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(np.mean(samples['p_latent'], 0), true_probs, atol=0.02)

    # JAX reads the upper-case variable JAX_ENABLE_X64, and environment
    # variable names are case-sensitive on POSIX, so the key must match
    # exactly. Only enable the dtype check for the truthy spellings JAX
    # itself accepts; e.g. JAX_ENABLE_X64=0 must not trigger it.
    x64_flag = os.environ.get('JAX_ENABLE_X64', '').lower()
    if x64_flag in ('1', 'true', 't', 'yes', 'y', 'on'):
        assert samples['p_latent'].dtype == np.float64
return (quad_form, np.mean(res_norm), np.mean(iters)), (tangent_out, 0.0, 0.0)
if __name__ == '__main__':
    # Small random problem instance. No random seed is set in these lines,
    # so the values differ from run to run.
    N = 5
    P = 4
    # b is a length-N vector; X and dkX are (N, P) matrices (see reshape).
    b = np.array(onp.random.randn(N))
    X = np.array(onp.random.randn(N * P).reshape((N, P)))
    # NOTE(review): dkX is not used within the visible lines.
    dkX = np.array(onp.random.randn(N * P).reshape((N, P)))
    # Alternative (random) probe setups, currently disabled.
    #Ainv_b_probes = np.array(onp.random.randn(N * 2).reshape((2, N)))
    #probes = np.array(onp.random.randn(N * 1).reshape((1, N)))
    # Length-P vector with entries in [0.3, 2.3) since rand() is in [0, 1).
    kappa = 0.3 + 2.0 * np.array(onp.random.rand(P))
    eta1 = 0.8
    eta2 = 0.5
    # Length-N vector added to the kernel diagonal in `direct` below.
    diag = np.array(onp.random.rand(N))
    c = 1.0
    #num_probes = 1
    #probes = np.array(onp.random.randn(N * num_probes).reshape((num_probes, N)))
    # Deterministic probes: row i is sqrt(N) times the i-th standard basis
    # vector. NOTE(review): presumably chosen so that averaging over the N
    # probes reproduces an exact trace rather than a stochastic estimate —
    # confirm against the estimator that consumes `probes`.
    probes = math.sqrt(N) * np.eye(N)

    def direct(_kappa, _b, _eta1, _eta2, _diag, include_log_det):
        # Dense evaluation: build k = kernel(kappa*X, kappa*X, ...) and hand
        # k + diag(_diag), k @ _b and _b * _diag to direct_quad_form_log_det.
        kX = _kappa * X
        k = kernel(kX, kX, _eta1, _eta2, c)
        k_diag = k + np.diag(_diag)
        return direct_quad_form_log_det(k_diag, np.matmul(k, _b), _b * _diag, include_log_det=include_log_det)

    def pcpcg(_kappa, _b, _eta1, _eta2, _diag):
        # Calls pcpcg_quad_form_log_det2 with the fixed `probes` and the
        # positional settings (2, 2, 1.0e-5, 400, 1) and keeps element [0]
        # of its result. NOTE(review): those positionals presumably are two
        # ranks, a CG tolerance, an iteration cap and one more option —
        # confirm against pcpcg_quad_form_log_det2's signature.
        return pcpcg_quad_form_log_det2(_kappa, _b, _eta1, _eta2, _diag, c, X,
                                        probes, 2, 2, 1.0e-5, 400, 1)[0]

    # NOTE(review): `which` is not read within the visible lines; presumably
    # it selects a comparison performed further down — TODO confirm.
    which = 3
def do_chunk(svi_state):
    """Advance ``svi_state`` through one chunk of iterations via ``_fori_loop``.

    ``_fori_loop`` is called with bounds ``0`` and ``report_frequency``, the
    step function ``body_fn`` (both taken from the enclosing scope), and an
    initial carry ``(svi_state, 0.0, zeros(2))``. Presumably the loop
    follows ``fori_loop(lower, upper, body, init)`` semantics — confirm
    against ``_fori_loop``. Whatever ``_fori_loop`` returns is passed
    through unchanged.
    """
    lower, upper = np.array(0), np.array(report_frequency)
    # Carry: the SVI state plus a scalar and a length-2 accumulator; their
    # meaning is defined by body_fn, which is not visible here.
    init_carry = (svi_state, np.array(0.0), np.zeros(2))
    return _fori_loop(lower, upper, body_fn, init_carry)
y0: initial value for the state.
f0: initial value for the derivative, computed from `func(t0, y0)`.
t0: initial time.
dt: time step.
alpha, beta, c: Butcher tableau describing how to take the Runge-Kutta
step.
Returns:
y1: estimated function at t1 = t0 + dt
f1: derivative of the state at t1
y1_error: estimated error at t1
k: list of Runge-Kutta coefficients `k` used for calculating these terms.
"""
# Dopri5 Butcher tableaux
alpha = np.array([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1., 0])
beta = np.array(
[[1 / 5, 0, 0, 0, 0, 0, 0],
[3 / 40, 9 / 40, 0, 0, 0, 0, 0],
[44 / 45, -56 / 15, 32 / 9, 0, 0, 0, 0],
[19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729, 0, 0, 0],
[9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656, 0, 0],
[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0]])
c_sol = np.array([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84,
0])
c_error = np.array([35 / 384 - 1951 / 21600, 0, 500 / 1113 - 22642 / 50085,
125 / 192 - 451 / 720, -2187 / 6784 - -12231 / 42400,
11 / 84 - 649 / 6300, -1. / 60.])
def _fori_body_fun(i, val):
ti = t0 + dt * alpha[i-1]
yi = y0 + dt * np.dot(beta[i-1, :], val)
ft = func(yi, ti)
def ham_ising():
    """Dimension 2 "Ising" Hamiltonian.

    This version from Evenbly & White, Phys. Rev. Lett. 116, 140403 (2016).

    Builds the 8x8 three-site operator

        X (x) Z (x) X  -  0.5 * (X (x) X (x) I  +  I (x) X (x) X)

    from Pauli X/Z and the identity, then reshapes it to a rank-6 tensor with
    every axis of size 2. With the row-major reshape, axes 0-2 index the row
    sites and axes 3-5 the column sites.

    Returns:
        A float array of shape ``(2, 2, 2, 2, 2, 2)``.
    """
    E = np.array([[1, 0], [0, 1]])
    X = np.array([[0, 1], [1, 0]])
    Z = np.array([[1, 0], [0, -1]])
    hmat = np.kron(X, np.kron(Z, X))
    # The kron product above is an integer array. An in-place `-=` with a
    # float right-hand side fails in NumPy (same_kind casting rule), so
    # rebind instead; this also matches what `-=` does on immutable JAX
    # arrays.
    hmat = hmat - 0.5 * (np.kron(np.kron(X, X), E) + np.kron(E, np.kron(X, X)))
    return np.reshape(hmat, [2] * 6)
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, sigma):
P, N = X.shape[1], X.shape[0]
probe = jnp.zeros((4, P))
probe = jax.ops.index_update(probe, jax.ops.index[:, dim1], jnp.array([1.0, 1.0, -1.0, -1.0]))
probe = jax.ops.index_update(probe, jax.ops.index[:, dim2], jnp.array([1.0, -1.0, 1.0, -1.0]))
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
k_xx_inv = jnp.linalg.inv(k_xx)
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
vec = jnp.array([0.25, -0.25, -0.25, 0.25])
mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
mu = jnp.dot(mu, vec)
def process_singleton_pcg(dim, P, kappa, kX, omega, Y, eta1, eta2, c, rank1, rank2,
                          cg_tol=1.0e-3, max_iters=200):
    """Run ``process_probe_pcg`` on the pair of probes that isolate column ``dim``.

    Two probe rows of length ``P`` are built. Both are zero except at column
    ``dim``, which holds +1.0 in the first row and -1.0 in the second. The
    probes are scaled by ``kappa`` and passed to ``process_probe_pcg``
    together with the weights ``(0.5, -0.5)``. All other arguments are
    forwarded unchanged.

    Returns:
        The two values unpacked from ``process_probe_pcg``'s result,
        presumably a (mean, variance) pair — confirm against that helper.
    """
    column = jax.ops.index[:, dim]
    probe = jax.ops.index_update(np.zeros((2, P)), column, np.array([1.0, -1.0]))
    # Weighting the +1/-1 probes by +0.5/-0.5 takes half their difference.
    weights = np.array([0.50, -0.50])
    mean, variance = process_probe_pcg(
        kappa * probe, kX, kappa, omega, Y, weights, eta1, eta2, c, rank1, rank2,
        cg_tol=cg_tol, max_iters=max_iters)
    return mean, variance
- 2.0 * meansum(probes_kX * Ainv_probes_kX) \
- meansum(probes_ksqXsq * Ainv_probes_ksqXsq) \
- np.mean(np.sum(probes, axis=-1) * np.sum(Ainv_probes, axis=-1)))
log_det_ddiag = meansum(probes * diag_dot * Ainv_probes)
tangent_out = -0.125 * (quad_form_dk + quad_form_deta1 + quad_form_deta2 + quad_form_ddiag - quad_form_db) + \
-0.5 * (log_det_dk + log_det_deta1 + log_det_deta2 + log_det_ddiag)
quad_form = 0.125 * np.dot(Kb, Ainv_b)
return (quad_form, np.mean(res_norm), np.mean(iters)), (tangent_out, 0.0, 0.0)
if __name__ == '__main__':
N = 5
P = 4
b = np.array(onp.random.randn(N))
X = np.array(onp.random.randn(N * P).reshape((N, P)))
dkX = np.array(onp.random.randn(N * P).reshape((N, P)))
#Ainv_b_probes = np.array(onp.random.randn(N * 2).reshape((2, N)))
#probes = np.array(onp.random.randn(N * 1).reshape((1, N)))
kappa = 0.3 + 2.0 * np.array(onp.random.rand(P))
eta1 = 0.8
eta2 = 0.5
diag = np.array(onp.random.rand(N))
c = 1.0
#num_probes = 1
#probes = np.array(onp.random.randn(N * num_probes).reshape((num_probes, N)))
probes = math.sqrt(N) * np.eye(N)
def direct(_kappa, _b, _eta1, _eta2, _diag, include_log_det):
kX = _kappa * X
k = kernel(kX, kX, _eta1, _eta2, c)