Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_inverse_and_logdet():
    def check(mat):
        # `mat` must be a two-sided inverse of itself under `B.inverse`, and
        # `B.logdet` must agree with the dense NumPy log-determinant.
        # `allclose` is a test helper from outside this view, assumed to assert.
        eye = np.eye(3)
        allclose(B.matmul(mat, B.inverse(mat)), eye)
        allclose(B.matmul(B.inverse(mat), mat), eye)
        allclose(B.logdet(mat), np.log(np.linalg.det(to_np(mat))))

    # Test `Dense`.
    root = np.random.randn(3, 3)
    check(Dense(root.dot(root.T)))

    # Test `Diagonal`.
    diag = Diagonal(np.array([1, 2, 3]))
    check(diag)
    # Inverting a rectangular diagonal matrix swaps its shape.
    rect = Diagonal(np.array([1, 2]), rows=2, cols=4)
    assert B.shape(B.inverse(rect)) == (4, 2)

    # Test `Woodbury`. Repeatedly inverting also exercises the inverse of an
    # inverse (whose low-rank part has the opposite definiteness).
    left = np.random.randn(3, 2)
    factor = np.random.randn(2, 2) + 1e-2 * np.eye(2)
    wb = diag + LowRank(left=left, middle=factor.dot(factor.T))
    for _ in range(4):
        check(wb)
        wb = B.inverse(wb)

    # Test `LowRank`.
    # NOTE(review): this heading has no checks under it in this function;
    # either add the `LowRank` checks or drop the heading.
def test_sample():
    # Two covariances to sample from: a dense PSD matrix, and a Woodbury
    # matrix built from its diagonal plus a random PSD low-rank part.
    root = np.random.randn(3, 3)
    dense = Dense(root.dot(root.T))
    factor = np.random.randn(2, 2)
    woodbury = Diagonal(B.diag(dense)) + LowRank(
        left=np.random.randn(3, 2),
        middle=factor.dot(factor.T),
    )

    # Test `Dense` and `Woodbury`: the empirical second moment of many draws
    # must be close to the covariance on average. Draws appear to be stored
    # as columns (the 3 x n result is multiplied by its own transpose).
    n = 500000
    for cov in (dense, woodbury):
        draws = B.sample(cov, n)
        emp = B.matmul(draws, draws, tr_b=True) / n
        err = np.mean(np.abs(to_np(emp) - to_np(cov)))
        assert err <= 5e-2
def compare(a, b):
    """Check `B.matmul(a, b)` against the product of the NumPy versions.

    Args:
        a: Left factor (possibly a structured matrix type).
        b: Right factor (possibly a structured matrix type).

    Returns:
        bool: `True` if both products are element-wise close.
    """
    structured = to_np(B.matmul(a, b))
    dense = B.matmul(to_np(a), to_np(b))
    return np.allclose(structured, dense)
def lr_diff(a, b):
    """Subtract two low-rank matrices, forcing the resulting middle part to
    be positive definite if the result is so.

    The difference `a - b` is refactored with SVDs of its left and right
    factors. The left and right factors of the result are the respective
    `u` matrices. The singular values and `v` matrices are folded into the
    middle: `S_l (v_l^T M v_r) S_r^T`.

    Args:
        a (:class:`.matrix.LowRank`): `a`.
        b (:class:`.matrix.LowRank`): `b`.

    Returns:
        :class:`.matrix.LowRank`: Difference between `a` and `b`.
    """
    diff = a - b
    u_l, s_l, v_l = B.svd(diff.left)
    u_r, s_r, v_r = B.svd(diff.right)
    # NOTE(review): assumes `B.svd(x)` returns `(u, s, v)` with
    # `x = u diag(s) v^T`, and that `u` is square. Only then do the
    # rectangular `Diagonal`s (shaped like the factors) line up with `u`.
    # Confirm.
    sigma_l = Diagonal(s_l, *B.shape(diff.left))
    core = B.matmul(v_l, diff.middle, v_r, tr_a=True)
    sigma_r_t = Diagonal(s_r, *B.shape(diff.right)).T
    new_middle = B.matmul(sigma_l, core, sigma_r_t)
    return LowRank(left=u_l, right=u_r, middle=new_middle)
def inverse(a):
    """Invert a Woodbury-structured matrix with the Woodbury matrix identity.

    The result is cached on `a.inverse`. Later calls return the cached
    object instead of recomputing it.

    Args:
        a: Matrix with `diag`, `lr` (low-rank part) and `lr_pd` attributes.

    Returns:
        :class:`.matrix.Woodbury`: Inverse of `a`.
    """
    if a.inverse is not None:
        return a.inverse

    inv_diag = B.inverse(a.diag)
    # Low-rank correction D^{-1} L S^{-1} R^T D^{-1}, with S = `B.schur(a)`.
    correction = LowRank(
        left=B.matmul(inv_diag, a.lr.left),
        right=B.matmul(inv_diag, a.lr.right),
        middle=B.inverse(B.schur(a)),
    )
    # The sign of the correction and the PD flag flip together.
    # NOTE(review): the non-PD branch *adds* the correction. That is only
    # right if `B.schur` accounts for the sign of a non-PD low-rank part.
    # Confirm against its definition.
    if a.lr_pd:
        new_lr, new_lr_pd = -correction, False
    else:
        new_lr, new_lr_pd = correction, True
    a.inverse = Woodbury(diag=inv_diag, lr=new_lr, lr_pd=new_lr_pd)
    return a.inverse
def matmul(*xs, **trs):
    """Multiply a chain of matrices, `xs[0] @ xs[1] @ ... @ xs[-1]`.

    A factor is transposed by passing `tr_<letter>=True`, where the letter
    gives the factor's position: `tr_a` for `xs[0]`, `tr_b` for `xs[1]`,
    and so on up to `tr_z`.

    Args:
        *xs: Matrices to multiply; at least two are required
            (fewer raises `IndexError`).
        **trs: Transposition flags `tr_a`, `tr_b`, ...; missing flags
            default to `False`.

    Returns:
        Product of all matrices in `xs`.
    """
    def tr(i):
        # Factors past the 26th have no `tr_<letter>` keyword and are never
        # transposed.
        if i >= len(ascii_lowercase):
            return False
        return trs.get('tr_' + ascii_lowercase[i], False)

    # Compute the first product; only here can the left factor be transposed.
    res = B.matmul(xs[0], xs[1], tr_a=tr(0), tr_b=tr(1))
    # Fold in the remaining factors. This iterates over all of `xs` rather
    # than zipping it with the alphabet, which silently dropped every factor
    # after the 26th.
    for i, x in enumerate(xs[2:], start=2):
        res = B.matmul(res, x, tr_b=tr(i))
    return res
def matmul(a, b, tr_a=False, tr_b=False):
    """Multiply `a` by a Woodbury matrix `b` by expanding `b` into its parts.

    Expanding out the Woodbury matrix is prioritised. This overload is meant
    to take even higher precedence, to resolve the ambiguity when both
    arguments are Woodbury matrices. The precedence itself is presumably set
    by a dispatch decorator outside this view.

    Args:
        a: Left factor.
        b: Woodbury matrix with `lr` and `diag` parts.
        tr_a (bool, optional): Transpose `a`. Defaults to `False`.
        tr_b (bool, optional): Transpose `b`. Defaults to `False`.

    Returns:
        `op(a) @ op(b.lr) + op(a) @ op(b.diag)`.
    """
    # Both the product and the transpose distribute over the sum of parts,
    # so the transposition flags are forwarded unchanged to each product.
    lr_part = B.matmul(a, b.lr, tr_a=tr_a, tr_b=tr_b)
    diag_part = B.matmul(a, b.diag, tr_a=tr_a, tr_b=tr_b)
    return B.add(lr_part, diag_part)
def diag(a):
    """Diagonal of the low-rank matrix `a = left @ middle @ right^T`.

    Entry `i` is the dot product of row `i` of `left @ middle` with row `i`
    of `right`, so the full matrix is never formed. `a` may be non-square,
    so only the first `B.diag_len(a)` rows are used.
    """
    n = B.diag_len(a)
    scaled_left = B.matmul(a.left, a.middle)
    return B.sum(scaled_left[:n, :] * a.right[:n, :], axis=1)
def ratio(a, b):
    """Sum of `B.qf_diag(b, right, left @ middle)` for low-rank `a`.

    NOTE(review): if `B.qf_diag(b, x, y)` is `diag(x^T b^{-1} y)`, this
    equals `tr(b^{-1} a)` by the cyclic property of the trace. Confirm
    against `qf_diag`.
    """
    scaled_left = B.matmul(a.left, a.middle)
    quad_diag = B.qf_diag(b, a.right, scaled_left)
    return B.sum(quad_diag)
def sum(a, axis=None):
    """Sum the entries of a low-rank matrix `a = left @ middle @ right^T`.

    Summing over all entries, over rows (`axis=0`) or over columns
    (`axis=1`) is done in the factored form, without forming `a`. Any other
    `axis` falls back to the generic `Dense` implementation.

    Args:
        a (:class:`.matrix.LowRank`): Matrix to sum.
        axis (int, optional): Axis to sum over. Defaults to `None`, which
            sums all entries.

    Returns:
        Scalar for `axis=None`, otherwise a vector of sums.
    """
    # Compare with `==` rather than `is`. An identity check against an int
    # literal depends on CPython's small-int caching and emits a SyntaxWarning
    # on Python 3.8+. It also never matches NumPy integer axes, which were
    # sent to the slow generic fallback.
    if axis is None:
        # 1^T (L M R^T) 1 = (1^T L) M (1^T R)^T.
        left = B.expand_dims(B.sum(a.left, axis=0), axis=0)
        right = B.expand_dims(B.sum(a.right, axis=0), axis=1)
        return B.matmul(left, a.middle, right)[0, 0]
    elif axis == 0:
        # Column sums: (1^T L) M R^T.
        left = B.expand_dims(B.sum(a.left, axis=0), axis=0)
        return B.matmul(left, a.middle, a.right, tr_c=True)[0, :]
    elif axis == 1:
        # Row sums: L M (1^T R)^T.
        right = B.expand_dims(B.sum(a.right, axis=0), axis=1)
        return B.matmul(a.left, a.middle, right)[:, 0]
    else:
        # Fall back to generic implementation.
        return B.sum.invoke(Dense)(a, axis=axis)