Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
elif isinstance(constraint, (constraints._Real, constraints._RealVector)):
return random.normal(key, size)
elif isinstance(constraint, constraints._Simplex):
return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
elif isinstance(constraint, constraints._Multinomial):
n = size[-1]
return multinomial(key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
elif isinstance(constraint, constraints._CorrCholesky):
return signed_stick_breaking_tril(
random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
elif isinstance(constraint, constraints._CorrMatrix):
cholesky = signed_stick_breaking_tril(
random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1))
return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
elif isinstance(constraint, constraints._LowerCholesky):
return jnp.tril(random.uniform(key, size))
elif isinstance(constraint, constraints._PositiveDefinite):
x = random.normal(key, size)
return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
elif isinstance(constraint, constraints._OrderedVector):
x = jnp.cumsum(random.exponential(key, size), -1)
return x - random.normal(key, size[:-1])
else:
raise NotImplementedError('{} not implemented.'.format(constraint))
def scale_tril(self):
    """Lower-triangular factor ``L`` with ``L @ L^T == W @ W^T + D``.

    Here ``W = self.cov_factor`` and ``D = diag(self.cov_diag)``. Rather than
    factorising the covariance directly, the following identity is used
    (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):

        W @ W.T + D = D^1/2 @ (I + D^-1/2 @ W @ W.T @ D^-1/2) @ D^1/2

    The bracketed matrix has all eigenvalues bounded below by 1. It is
    therefore well conditioned, and taking its Cholesky factor is numerically
    safe. Left-multiplying that factor by ``D^1/2`` (a row scaling) keeps it
    lower triangular.
    """
    # Column vector of sqrt(D) so it broadcasts as a row scaling.
    sqrt_diag = jnp.sqrt(self.cov_diag)[..., None]
    scaled_factor = self.cov_factor / sqrt_diag  # D^-1/2 @ W
    inner = scaled_factor @ jnp.swapaxes(scaled_factor, -1, -2)
    inner = inner + jnp.eye(inner.shape[-1])  # I + D^-1/2 W W^T D^-1/2
    return sqrt_diag * jnp.linalg.cholesky(inner)
dilation=4, max_iters=200, cg_tol=1.0e-3):
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
diag = 1.0 / omega
mvm = lambda b: kernel_mvm_diag(b, kX, eta1, eta2, c, diag, dilation=dilation)
presolve = lowrank_presolve(kX, diag, eta1, eta2, c, kappa, rank1, rank2)
Y_omega = 0.5 * Y / omega
Y_kprb = np.concatenate([Y_omega[None, :], k_probeX])
Kinv_Y_kprb = pcg_batch_b(Y_kprb, mvm, presolve=presolve, cg_tol=cg_tol, max_iters=max_iters)[0]
mu = np.dot(vec, np.dot(k_probeX, Kinv_Y_kprb[0]))
var = k_prbprb - np.matmul(Kinv_Y_kprb[1:], np.transpose(k_probeX))
var = np.dot(vec, np.matmul(var, vec))
return mu, var
def euclidean_kinetic_energy(inverse_mass_matrix, r):
    """Return the Euclidean kinetic energy ``0.5 * r^T @ M^-1 @ r``.

    :param inverse_mass_matrix: inverse mass matrix, given either as a dense
        2-D matrix or as a 1-D array holding its diagonal.
    :param r: momentum; any pytree, flattened into one 1-D vector before use.
    :return: the scalar kinetic energy.
    :raises ValueError: if ``inverse_mass_matrix`` is neither 1-D nor 2-D.
    """
    # Flatten the momentum pytree so it lines up with the flat mass matrix.
    r, _ = ravel_pytree(r)

    if inverse_mass_matrix.ndim == 2:
        # Dense inverse mass matrix: full matrix-vector product.
        v = jnp.matmul(inverse_mass_matrix, r)
    elif inverse_mass_matrix.ndim == 1:
        # Diagonal inverse mass matrix stored as a vector: elementwise product.
        v = jnp.multiply(inverse_mass_matrix, r)
    else:
        # Any other rank is unsupported; fail with a clear message instead of
        # reaching the return with `v` unbound.
        raise ValueError(
            "inverse_mass_matrix should have dimension 1 or 2, got {}.".format(
                inverse_mass_matrix.ndim))

    return 0.5 * jnp.dot(v, r)
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
k_xx_inv = jnp.linalg.inv(k_xx)
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
vec = jnp.array([0.25, -0.25, -0.25, 0.25])
mu = jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, Y))
mu = jnp.dot(mu, vec)
var = k_prbprb - jnp.matmul(k_probeX, jnp.matmul(k_xx_inv, jnp.transpose(k_probeX)))
var = jnp.matmul(var, vec)
var = jnp.dot(var, vec)
return mu, var
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
L = cho_factor(k_xx, lower=True)[0]
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
mu = jnp.sum(mu * vec, axis=-1)
Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))
# sample from N(mu, covar)
L = jnp.linalg.cholesky(covar)
sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))
return sample
jnp.array([0.25, -0.25, -0.25, 0.25]))
start1 += 4
start2 += 1
eta2 = jnp.square(eta1) * jnp.sqrt(xisq) / msq
kappa = jnp.sqrt(msq) * lam / jnp.sqrt(msq + jnp.square(eta1 * lam))
kX = kappa * X
kprobe = kappa * probe
k_xx = kernel(kX, kX, eta1, eta2, c) + sigma ** 2 * jnp.eye(N)
L = cho_factor(k_xx, lower=True)[0]
k_probeX = kernel(kprobe, kX, eta1, eta2, c)
k_prbprb = kernel(kprobe, kprobe, eta1, eta2, c)
mu = jnp.matmul(k_probeX, cho_solve((L, True), Y))
mu = jnp.sum(mu * vec, axis=-1)
Linv_k_probeX = solve_triangular(L, jnp.transpose(k_probeX), lower=True)
covar = k_prbprb - jnp.matmul(jnp.transpose(Linv_k_probeX), Linv_k_probeX)
covar = jnp.matmul(vec, jnp.matmul(covar, jnp.transpose(vec)))
# sample from N(mu, covar)
L = jnp.linalg.cholesky(covar)
sample = mu + jnp.matmul(L, np.random.randn(num_coefficients))
return sample
def predict(rng_key, X, Y, X_test, var, length, noise):
    """Return the GP predictive mean at ``X_test`` and one random draw around it.

    :param rng_key: JAX PRNG key used to draw the standard-normal noise.
    :param X: training inputs.
    :param Y: training targets.
    :param X_test: test inputs; one draw is produced per row (``X_test.shape[:1]``).
    :param var, length, noise: kernel hyperparameters, forwarded to ``kernel``.
    :return: ``(mean, mean + sigma_noise)``. ``sigma_noise`` scales independent
        standard normals by the square root of the predictive covariance
        diagonal. It is therefore a per-point marginal draw, not a joint
        sample from the full covariance.
    """
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = np.linalg.inv(k_XX)
    K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    # Round-off can make diagonal entries slightly negative; floor them at 0
    # before the sqrt. np.maximum replaces clip(a_min=...), whose `a_min`
    # keyword is deprecated in jax.numpy and needs `a_max` in NumPy < 2.1.
    std = np.sqrt(np.maximum(np.diag(K), 0.))
    sigma_noise = std * jax.random.normal(rng_key, X_test.shape[:1])
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise
'MatMul': lambda x, y: [np.matmul(x, y)],
'MaxPool': onnx_maxpool,