Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _batch_mahalanobis(bL, bx):
if bL.shape[:-1] == bx.shape:
# no need to use the below optimization procedure
solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
return jnp.sum(jnp.square(solve_bL_bx), -1)
# NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
# because we don't want to broadcast bL to the shape (i, j, n, n).
# Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
# we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
sample_ndim = bx.ndim - bL.ndim + 1 # size of sample_shape
out_shape = jnp.shape(bx)[:-1] # shape of output
# Reshape bx with the shape (..., 1, i, j, 1, n)
bx_new_shape = out_shape[:sample_ndim]
for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
bx_new_shape += (sx // sL, sL)
bx_new_shape += (-1,)
bx = jnp.reshape(bx, bx_new_shape)
# Permute bx to make it have shape (..., 1, j, i, 1, n)
permute_dims = (tuple(range(sample_ndim))
def kXkXsq_row(i, kX):
    """Return ``(1 + kX @ kX[i]) ** 2`` element-wise.

    Assuming ``kX`` is 2-D (TODO confirm at call sites), ``kX @ kX[i]`` holds
    the inner products of every row of ``kX`` with row ``i``. That makes the
    result row ``i`` of ``(1 + kX kX^T)`` squared element-wise.

    Args:
        i: index of the row of ``kX`` to pair against all rows.
        kX: array whose rows are paired via ``np.matmul``.

    Returns:
        Element-wise square of ``1.0 + np.matmul(kX, kX[i])``.
    """
    inner_products = np.matmul(kX, kX[i])
    shifted = 1.0 + inner_products
    return np.square(shifted)
def kXkXsq_mvm(b, kX, dilation=2):
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    """Evaluate the kernel between inputs ``X`` and ``Z`` as a sum of four terms.

    With ``xz = kdot(X, Z)`` (``kdot`` is defined elsewhere in the module;
    presumably a Gram product such as ``X @ Z.T`` -- TODO confirm):

        k1 = 0.5 * eta2**2 * (1 + xz)**2
        k2 = -0.5 * eta2**2 * kdot(X**2, Z**2)
        k3 = (eta1**2 - eta2**2) * xz
        k4 = c**2 - 0.5 * eta2**2  (+ jitter * I when X.shape == Z.shape)

    Args:
        X, Z: inputs forwarded to ``kdot``.
        eta1, eta2: scale hyperparameters; only their squares are used.
        c: offset hyperparameter; only its square is used.
        jitter: coefficient of the identity matrix added to ``k4`` when
            ``X`` and ``Z`` have the same shape.

    Returns:
        ``k1 + k2 + k3 + k4``.
    """
    eta1sq = np.square(eta1)
    eta2sq = np.square(eta2)
    # kdot(X, Z) feeds both k1 and k3; evaluate it once instead of twice.
    xz = kdot(X, Z)
    k1 = 0.5 * eta2sq * np.square(1.0 + xz)
    k2 = -0.5 * eta2sq * kdot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * xz
    k4 = np.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        # NOTE(review): equal shapes act as a proxy for "X and Z are the same
        # input". Distinct inputs that happen to share a shape also get the
        # jitter -- confirm no caller depends on that.
        # Out-of-place add so a non-scalar k4 broadcasts to (N, N) instead of
        # failing an in-place broadcast.
        k4 = k4 + jitter * np.eye(X.shape[0])
    return k1 + k2 + k3 + k4
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    """Evaluate the kernel between inputs ``X`` and ``Z`` as a sum of four terms.

    With ``xz = kdot(X, Z)`` (``kdot`` is defined elsewhere in the module;
    presumably a Gram product such as ``X @ Z.T`` -- TODO confirm):

        k1 = 0.5 * eta2**2 * (1 + xz)**2
        k2 = -0.5 * eta2**2 * kdot(X**2, Z**2)
        k3 = (eta1**2 - eta2**2) * xz
        k4 = c**2 - 0.5 * eta2**2  (+ jitter * I when X.shape == Z.shape)

    Args:
        X, Z: inputs forwarded to ``kdot``.
        eta1, eta2: scale hyperparameters; only their squares are used.
        c: offset hyperparameter; only its square is used.
        jitter: coefficient of the identity matrix added to ``k4`` when
            ``X`` and ``Z`` have the same shape.

    Returns:
        ``k1 + k2 + k3 + k4``.
    """
    eta1sq = np.square(eta1)
    eta2sq = np.square(eta2)
    # kdot(X, Z) feeds both k1 and k3; evaluate it once instead of twice.
    xz = kdot(X, Z)
    k1 = 0.5 * eta2sq * np.square(1.0 + xz)
    k2 = -0.5 * eta2sq * kdot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * xz
    k4 = np.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        # NOTE(review): equal shapes act as a proxy for "X and Z are the same
        # input". Distinct inputs that happen to share a shape also get the
        # jitter -- confirm no caller depends on that.
        # Out-of-place add so a non-scalar k4 broadcasts to (N, N) instead of
        # failing an in-place broadcast.
        k4 = k4 + jitter * np.eye(X.shape[0])
    return k1 + k2 + k3 + k4
def kernel(X, Z, eta1, eta2, c):
    """Evaluate the kernel between inputs ``X`` and ``Z`` as a sum of four terms.

    With ``xz = kdot(X, Z)`` (``kdot`` is defined elsewhere in the module;
    presumably a Gram product such as ``X @ Z.T`` -- TODO confirm):

        k1 = 0.5 * eta2**2 * (1 + xz)**2
        k2 = -0.5 * eta2**2 * kdot(X**2, Z**2)
        k3 = (eta1**2 - eta2**2) * xz
        k4 = c**2 - 0.5 * eta2**2

    This variant adds no diagonal jitter.

    Args:
        X, Z: inputs forwarded to ``kdot``.
        eta1, eta2: scale hyperparameters; only their squares are used.
        c: offset hyperparameter; only its square is used.

    Returns:
        ``k1 + k2 + k3 + k4``.
    """
    eta1sq = np.square(eta1)
    eta2sq = np.square(eta2)
    # kdot(X, Z) feeds both k1 and k3; evaluate it once instead of twice.
    xz = kdot(X, Z)
    k1 = 0.5 * eta2sq * np.square(1.0 + xz)
    k2 = -0.5 * eta2sq * kdot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * xz
    k4 = np.square(c) - 0.5 * eta2sq
    return k1 + k2 + k3 + k4
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    """Evaluate the kernel between inputs ``X`` and ``Z`` as a sum of four terms.

    With ``xz = dot(X, Z)`` (``dot`` is defined elsewhere in the module;
    presumably a Gram product such as ``X @ Z.T`` -- TODO confirm):

        k1 = 0.5 * eta2**2 * (1 + xz)**2
        k2 = -0.5 * eta2**2 * dot(X**2, Z**2)
        k3 = (eta1**2 - eta2**2) * xz
        k4 = c**2 - 0.5 * eta2**2  (+ jitter * I when X.shape == Z.shape)

    Args:
        X, Z: inputs forwarded to ``dot``.
        eta1, eta2: scale hyperparameters; only their squares are used.
        c: offset hyperparameter; only its square is used.
        jitter: coefficient of the identity matrix added to ``k4`` when
            ``X`` and ``Z`` have the same shape.

    Returns:
        ``k1 + k2 + k3 + k4``.
    """
    eta1sq = np.square(eta1)
    eta2sq = np.square(eta2)
    # dot(X, Z) feeds both k1 and k3; evaluate it once instead of twice.
    xz = dot(X, Z)
    k1 = 0.5 * eta2sq * np.square(1.0 + xz)
    k2 = -0.5 * eta2sq * dot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * xz
    k4 = np.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        # NOTE(review): equal shapes act as a proxy for "X and Z are the same
        # input". Distinct inputs that happen to share a shape also get the
        # jitter -- confirm no caller depends on that.
        # Out-of-place add so a non-scalar k4 broadcasts to (N, N) instead of
        # failing an in-place broadcast.
        k4 = k4 + jitter * np.eye(X.shape[0])
    return k1 + k2 + k3 + k4
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    """Evaluate the kernel between inputs ``X`` and ``Z`` as a sum of four terms.

    With ``xz = kdot(X, Z)`` (``kdot`` is defined elsewhere in the module;
    presumably a Gram product such as ``X @ Z.T`` -- TODO confirm):

        k1 = 0.5 * eta2**2 * (1 + xz)**2
        k2 = -0.5 * eta2**2 * kdot(X**2, Z**2)
        k3 = (eta1**2 - eta2**2) * xz
        k4 = c**2 - 0.5 * eta2**2  (+ jitter * I when X.shape == Z.shape)

    Args:
        X, Z: inputs forwarded to ``kdot``.
        eta1, eta2: scale hyperparameters; only their squares are used.
        c: offset hyperparameter; only its square is used.
        jitter: coefficient of the identity matrix added to ``k4`` when
            ``X`` and ``Z`` have the same shape.

    Returns:
        ``k1 + k2 + k3 + k4``.
    """
    eta1sq = np.square(eta1)
    eta2sq = np.square(eta2)
    # kdot(X, Z) feeds both k1 and k3; evaluate it once instead of twice.
    xz = kdot(X, Z)
    k1 = 0.5 * eta2sq * np.square(1.0 + xz)
    k2 = -0.5 * eta2sq * kdot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * xz
    k4 = np.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        # NOTE(review): equal shapes act as a proxy for "X and Z are the same
        # input". Distinct inputs that happen to share a shape also get the
        # jitter -- confirm no caller depends on that.
        # Out-of-place add so a non-scalar k4 broadcasts to (N, N) instead of
        # failing an in-place broadcast.
        k4 = k4 + jitter * np.eye(X.shape[0])
    return k1 + k2 + k3 + k4
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-6):
    """Evaluate the kernel between inputs ``X`` and ``Z`` as a sum of four terms.

    With ``xz = dot(X, Z)`` (``dot`` is defined elsewhere in the module;
    presumably a Gram product such as ``X @ Z.T`` -- TODO confirm):

        k1 = 0.5 * eta2**2 * (1 + xz)**2
        k2 = -0.5 * eta2**2 * dot(X**2, Z**2)
        k3 = (eta1**2 - eta2**2) * xz
        k4 = c**2 - 0.5 * eta2**2  (+ jitter * I when X.shape == Z.shape)

    Args:
        X, Z: inputs forwarded to ``dot``.
        eta1, eta2: scale hyperparameters; only their squares are used.
        c: offset hyperparameter; only its square is used.
        jitter: coefficient of the identity matrix added to ``k4`` when
            ``X`` and ``Z`` have the same shape.

    Returns:
        ``k1 + k2 + k3 + k4``.
    """
    eta1sq = np.square(eta1)
    eta2sq = np.square(eta2)
    # dot(X, Z) feeds both k1 and k3; evaluate it once instead of twice.
    xz = dot(X, Z)
    k1 = 0.5 * eta2sq * np.square(1.0 + xz)
    k2 = -0.5 * eta2sq * dot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * xz
    k4 = np.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        # NOTE(review): equal shapes act as a proxy for "X and Z are the same
        # input". Distinct inputs that happen to share a shape also get the
        # jitter -- confirm no caller depends on that.
        # Out-of-place add so a non-scalar k4 broadcasts to (N, N) instead of
        # failing an in-place broadcast.
        k4 = k4 + jitter * np.eye(X.shape[0])
    return k1 + k2 + k3 + k4
def sample_hypers(sigma, S, N, P, hypers):
    """Draw the kernel hyperparameters via numpyro sample/deterministic sites.

    Sites are registered in this order: "eta1", "msq", "xisq", "eta2"
    (deterministic), "lambda", "kappa" (deterministic). The order is kept
    unchanged, presumably because numpyro's PRNG handling depends on it --
    confirm before reordering.

    Args:
        sigma: multiplies the HalfCauchy scale of "eta1".
        S: numerator factor of that scale; ``P - S`` must be nonzero.
        N: the scale is divided by ``sqrt(N)``.
        P: length of the "lambda" scale vector (``np.ones(P)``).
        hypers: mapping providing 'alpha1', 'beta1' (prior of "msq") and
            'alpha2', 'beta2' (prior of "xisq").

    Returns:
        Tuple ``(eta1, eta2, kappa)``.
    """
    # HalfCauchy scale for eta1: sigma * (S / sqrt(N)) / (P - S).
    s_over_root_n = S / np.sqrt(N)
    phi = sigma * s_over_root_n / (P - S)
    eta1 = numpyro.sample("eta1", dist.HalfCauchy(phi))

    msq_prior = dist.InverseGamma(hypers['alpha1'], hypers['beta1'])
    msq = numpyro.sample("msq", msq_prior)
    xisq_prior = dist.InverseGamma(hypers['alpha2'], hypers['beta2'])
    xisq = numpyro.sample("xisq", xisq_prior)

    eta1_squared = np.square(eta1)
    eta2 = numpyro.deterministic('eta2', eta1_squared * np.sqrt(xisq) / msq)

    lam = numpyro.sample("lambda", dist.HalfCauchy(np.ones(P)))
    kappa_denominator = np.sqrt(msq + np.square(eta1 * lam))
    kappa = numpyro.deterministic('kappa', np.sqrt(msq) * lam / kappa_denominator)

    return eta1, eta2, kappa