Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def _batch_mahalanobis(bL, bx):
    """Batched squared Mahalanobis term ``sum((bL^{-1} bx)**2, -1)``.

    ``bL`` is used as a lower-triangular factor (``solve_triangular(..., lower=True)``)
    and ``bx`` holds the vectors along its last axis.

    NOTE(review): this view of the function is truncated after the ``xt`` reshape;
    the remaining steps (presumably the batched triangular solve on the permuted
    layout and the reshape back to ``out_shape``) are not visible in this chunk.
    """
    if bL.shape[:-1] == bx.shape:
        # no need to use the below optimization procedure
        # bL's leading dims already line up with bx, so solve bL @ y = bx directly
        # and return the squared norm of y over the last axis.
        solve_bL_bx = solve_triangular(bL, bx[..., None], lower=True).squeeze(-1)
        return jnp.sum(jnp.square(solve_bL_bx), -1)
    # NB: The following procedure handles the case: bL.shape = (i, 1, n, n), bx.shape = (i, j, n)
    # because we don't want to broadcast bL to the shape (i, j, n, n).
    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tril_solve
    sample_ndim = bx.ndim - bL.ndim + 1  # number of extra leading (sample) dims bx has beyond bL's batch dims
    out_shape = jnp.shape(bx)[:-1]  # shape of output
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = out_shape[:sample_ndim]
    # Split each bx batch dim sx into (sx // sL, sL), where sL is bL's matching dim:
    # sL == sx gives (1, sx) and sL == 1 gives (sx, 1).
    # Assumes sx is a multiple of sL (i.e. sL is sx or 1 under broadcasting) — TODO confirm.
    for (sL, sx) in zip(bL.shape[:-2], out_shape[sample_ndim:]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (-1,)
    bx = jnp.reshape(bx, bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    # Keep the sample dims first, then all (sx // sL) axes, then all sL axes,
    # and the vector axis n last.
    permute_dims = (tuple(range(sample_ndim))
                    + tuple(range(sample_ndim, bx.ndim - 1, 2))
                    + tuple(range(sample_ndim + 1, bx.ndim - 1, 2))
                    + (bx.ndim - 1,))
    bx = jnp.transpose(bx, permute_dims)
    # reshape to (-1, i, 1, n)
    xt = jnp.reshape(bx, (-1,) + bL.shape[:-1])
    # permute to (i, 1, n, -1)
    # NOTE(review): the code for this step (and the rest of the function) is not
    # visible in this chunk.
# NOTE(review): these lines are the body of a transform round-trip test whose `def`
# line, and the setup of `transform`, `shape`, `batch_shape`, `event_shape` and
# `rng_key`, are not visible in this chunk.
# Presumably maps a standard-normal draw into `transform.domain`.
x = biject_to(transform.domain)(random.normal(rng_key, shape))
y = transform(x)
# test codomain: the codomain check must hold for every batch element of y
assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))
# test inv: the inverse must map y back to x
z = transform.inv(y)
assert_allclose(x, z, atol=1e-6, rtol=1e-6)
# test domain
assert_array_equal(transform.domain(z), jnp.ones(batch_shape))
# test log_abs_det_jacobian: one value per batch element
actual = transform.log_abs_det_jacobian(x, y)
assert jnp.shape(actual) == batch_shape
# Compare against an autodiff reference only when the sample shape has exactly
# event_dim dims (presumably a single un-batched event), so the Jacobian of the
# whole transform is the per-event Jacobian.
if len(shape) == transform.event_dim:
    if len(event_shape) == 1:
        # Vector event: log|det J| from the full Jacobian (slogdet(...)[1]).
        expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
        inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
    else:
        # Presumably a scalar event (grad needs a scalar output): log|dy/dx|.
        expected = jnp.log(jnp.abs(grad(transform)(x)))
        inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))
    # The forward log-det must match autodiff and equal minus the inverse's log-det.
    assert_allclose(actual, expected, atol=1e-6)
    assert_allclose(actual, -inv_expected, atol=1e-6)
def _binomial(key, p, n, shape):
    """Draw binomial samples elementwise, using one PRNG subkey per (p, n) pair.

    :param key: PRNG key; split into one subkey per output element.
    :param p: success probabilities.
    :param n: numbers of trials.
    :param shape: output shape; when falsy, the broadcast shape of ``p`` and ``n`` is used.
    :return: samples reshaped to ``shape``.
    """
    if not shape:
        shape = lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # Flatten the broadcast parameters so the scalar sampler can be mapped over axis 0.
    flat_p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    flat_n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    subkeys = random.split(key, jnp.size(flat_p))
    backend_platform = xla_bridge.get_backend().platform
    if backend_platform != 'cpu':
        # Non-CPU backends: vectorize the sampler across all elements.
        draws = vmap(_binomial_dispatch)(subkeys, flat_p, flat_n)
    else:
        # CPU backend: apply the sampler element by element via lax.map.
        draws = lax.map(lambda args: _binomial_dispatch(*args),
                        (subkeys, flat_p, flat_n))
    return jnp.reshape(draws, shape)
def _multinomial(key, p, n, n_max, shape=()):
    """Draw multinomial samples by drawing ``n_max`` categorical indices per batch element.

    NOTE(review): this view of the function is truncated after ``indices_2D`` is
    built; the step that turns indices into per-category counts and the return
    statement are not visible in this chunk.

    :param key: PRNG key passed to ``categorical``.
    :param p: category probabilities; the trailing axis indexes categories.
    :param n: number of trials per batch element (scalar or batched).
    :param n_max: number of indices drawn per batch element; used as a shape, so it
        must be static. Draws at positions >= n are masked out, so presumably n <= n_max.
    :param shape: batch shape of the draws; falls back to ``p.shape[:-1]`` when falsy.
    """
    # Broadcast n against the batch dims of p (everything but the trailing axis).
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    # get indices from categorical distribution then gather the result
    # (the draw axis of length n_max leads: shape (n_max,) + shape)
    indices = categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # mask is True for draw positions < n; the draw axis is moved to the front
        # to line up with `indices`.
        mask = promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0 after `indices * mask`. `excess` holds
        # (n_max - n) in category 0 and zeros elsewhere, presumably subtracted
        # from the counts later to cancel those spurious hits.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1), jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))], -1)
    else:
        # scalar n: no masking is applied
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    # resulting shape: (prod(shape), n_max)
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1,))).T
def von_mises_centered(key, concentration, shape=(), dtype=jnp.float64):
    """Sample from a von Mises distribution centered at zero.

    The sampling itself is delegated to ``_von_mises_centered``; per [1] it uses
    rejection sampling with a wrapped Cauchy proposal.

    *** References ***
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
        Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    :param key: random number generator key
    :param concentration: concentration of the distribution
    :param shape: shape of the samples; when falsy, the shape of ``concentration`` is used
    :param dtype: float precision, used to choose the correct cutoff
    :return: centered samples from von Mises
    """
    if not shape:
        shape = jnp.shape(concentration)
    float_dtype = canonicalize_dtype(dtype)
    # Cast first, then broadcast the concentration to the requested sample shape.
    kappa = jnp.broadcast_to(lax.convert_element_type(concentration, float_dtype), shape)
    return _von_mises_centered(key, kappa, shape, float_dtype)
def __init__(self, logits=None, validate_args=None):
    """Store ``logits`` and use its shape as the batch shape.

    :param logits: the logits parameter (presumably log-odds), stored unchanged.
    :param validate_args: forwarded to the base-class initializer.
    """
    self.logits = logits
    inferred_batch_shape = jnp.shape(self.logits)
    super(BernoulliLogits, self).__init__(
        batch_shape=inferred_batch_shape,
        validate_args=validate_args,
    )
def __init__(self, logits, total_count=1, validate_args=None):
    """Initialize from ``logits`` and ``total_count``.

    :param logits: at least 1-D; its trailing axis becomes the event shape and its
        leading axes are broadcast against ``total_count`` to form the batch shape.
    :param total_count: broadcast against the leading axes of ``logits``.
    :param validate_args: forwarded to the base-class initializer.
    :raises ValueError: if ``logits`` is zero-dimensional.
    """
    if jnp.ndim(logits) < 1:
        raise ValueError("`logits` parameter must be at least one-dimensional.")
    logits_shape = jnp.shape(logits)
    batch_shape = lax.broadcast_shapes(logits_shape[:-1], jnp.shape(total_count))
    # Promote both parameters so their leading dims match the combined batch shape.
    self.logits = promote_shapes(logits, shape=batch_shape + logits_shape[-1:])[0]
    self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
    event_shape = jnp.shape(self.logits)[-1:]
    super(MultinomialLogits, self).__init__(
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
    )
# NOTE(review): fragment of a method body whose `def` line is not visible in this chunk.
# For every leaf of `unpacked_samples`, replace its leading axis with `sample_shape`
# (the leading axis is presumably a flattened sample axis — TODO confirm).
return tree_map(lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                unpacked_samples)
def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
    """Initialize a low-rank-plus-diagonal multivariate normal.

    The covariance is presumably ``cov_factor @ cov_factor.T + diag(cov_diag)``;
    the capacitance factor is built by ``_batch_capacitance_tril`` (defined elsewhere).

    :param loc: mean, at least 1-D; its trailing axis (length ``n``) is the event shape.
    :param cov_factor: batch of ``n x m`` matrices (at least 2-D).
    :param cov_diag: batch of length-``n`` vectors.
    :param validate_args: forwarded to the base-class initializer.
    :raises ValueError: if any argument has the wrong rank or trailing shape.
    """
    if jnp.ndim(loc) < 1:
        raise ValueError("`loc` must be at least one-dimensional.")
    event_shape = jnp.shape(loc)[-1:]
    if jnp.ndim(cov_factor) < 2:
        raise ValueError("`cov_factor` must be at least two-dimensional, "
                         "with optional leading batch dimensions")
    if jnp.shape(cov_factor)[-2:-1] != event_shape:
        raise ValueError("`cov_factor` must be a batch of matrices with shape {} x m"
                         .format(event_shape[0]))
    if jnp.shape(cov_diag)[-1:] != event_shape:
        # Report the local `event_shape`. The base-class __init__ (which presumably
        # populates `self.event_shape`) has not run yet at this point, so reading
        # `self.event_shape` here would fail instead of raising this ValueError.
        raise ValueError("`cov_diag` must be a batch of vectors with shape {}"
                         .format(event_shape))
    # Give loc and cov_diag a trailing singleton axis so all three parameters have
    # two trailing "matrix" dims; promote_shapes then aligns their leading batch dims.
    loc, cov_factor, cov_diag = promote_shapes(loc[..., jnp.newaxis], cov_factor, cov_diag[..., jnp.newaxis])
    # The batch shape is everything except the two trailing dims.
    batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(cov_factor), jnp.shape(cov_diag))[:-2]
    self.loc = jnp.broadcast_to(loc[..., 0], batch_shape + event_shape)
    self.cov_factor = cov_factor
    cov_diag = cov_diag[..., 0]  # drop the singleton axis added above
    self.cov_diag = cov_diag
    self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
    super(LowRankMultivariateNormal, self).__init__(
        batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args
    )
def __init__(self, log_factor, validate_args=None):
    """Store ``log_factor``; its shape becomes the batch shape.

    The event shape is fixed to ``(0,)`` so that it satisfies ``.size == 0``.

    :param log_factor: stored unchanged on the instance.
    :param validate_args: forwarded to the base-class initializer.
    """
    shape_of_batch = jnp.shape(log_factor)
    self.log_factor = log_factor
    super(Unit, self).__init__(shape_of_batch, (0,), validate_args=validate_args)