Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def compare(a):
    """Check `B.transpose(a)` against the NumPy transpose of `to_np(a)`."""
    transposed = B.transpose(a)
    reference = to_np(a).T
    allclose(transposed, reference)
def test_shorthands():
    """Check that the `.T` and `__matmul__` shorthands match the `B` functions."""
    mat = Dense(np.random.randn(4, 4))

    # Attribute shorthand for transposition.
    allclose(mat.T, B.transpose(mat))

    # Operator-method shorthand for matrix multiplication; the dunder is
    # called explicitly, exactly as the shorthand is defined.
    product = mat.__matmul__(mat)
    allclose(product, B.matmul(mat, mat))
def matmul(a, b, tr_a=False, tr_b=False):
    """Multiply `a` by `b`, using only the diagonal of `b`.

    `b` enters the computation solely through `B.diag(b)`, so the product is
    formed by scaling the columns of `dense(a)` instead of performing a
    general matrix product.

    Args:
        a: Left operand; converted with `dense`.
        b: Right operand; only its shape, data type, and diagonal are read.
        tr_a (bool, optional): Transpose `a` first. Defaults to `False`.
        tr_b (bool, optional): Transpose `b` first. Defaults to `False`.

    Returns:
        The product, with as many columns as `b`.
    """
    a = B.transpose(a) if tr_a else a
    b = B.transpose(b) if tr_b else b

    # Get shape of `b`.
    b_rows, b_cols = B.shape(b)

    # If `b` is square, don't do complicated things: every column of `a` has
    # a matching diagonal entry. The `None` check skips this shortcut when the
    # shape is reported as `None` (presumably a statically unknown shape).
    if b_rows == b_cols and b_rows is not None:
        return dense(a) * B.diag(b)[None, :]

    # Compute the core part: the leading columns of `a`, scaled by the
    # matching leading diagonal entries of `b`.
    cols = B.minimum(B.shape(a)[1], b_cols)
    core = dense(a)[:, :cols] * B.diag(b)[None, :cols]

    # Compute extra zeros to be appended: these make up the remaining columns
    # of `b` that the core part does not cover.
    extra_cols = b_cols - cols
    extra_zeros = B.zeros(B.dtype(b), B.shape(a)[0], extra_cols)

    # Append the zeros column-wise so the result has `b_cols` columns.
    return B.concat(core, extra_zeros, axis=1)
def transpose(a):
    """Transpose `a` by transposing its `diag` and `lr` parts.

    The `lr_pd` attribute is passed through to the new `Woodbury` unchanged.
    """
    diag_t = B.transpose(a.diag)
    lr_t = B.transpose(a.lr)
    return Woodbury(diag=diag_t, lr=lr_t, lr_pd=a.lr_pd)
def transpose(a):
    """Transpose `a` via its dense form and wrap the result with `matrix`."""
    as_dense = dense(a)
    transposed = B.transpose(as_dense)
    return matrix(transposed)