Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_uprank():
    """Exercise `uprank` directly and through calls to a `OneKernel`."""
    # Rank-0, rank-1 and rank-2 zeros must all come out as `[[0]]`.
    for value in (0, np.array([0]), np.array([[0]])):
        allclose(uprank(value), [[0]])

    # The type of a `Component('test')` instance must survive `uprank`.
    assert type(uprank(Component('test')(0))) == Component('test')

    k = OneKernel()

    # Each entry is (first input, second input, expected output shape).
    shape_cases = [
        (0, 0, (1, 1)),
        (0, np.ones(5), (1, 5)),
        (0, np.ones((5, 2)), (1, 5)),
        (np.ones(5), 0, (5, 1)),
        (np.ones(5), np.ones(5), (5, 5)),
        (np.ones(5), np.ones((5, 2)), (5, 5)),
        (np.ones((5, 2)), 0, (5, 1)),
        (np.ones((5, 2)), np.ones(5), (5, 5)),
        (np.ones((5, 2)), np.ones((5, 2)), (5, 5)),
    ]
    for first, second, expected_shape in shape_cases:
        assert B.shape(k(first, second)) == expected_shape

    # Rank-3 inputs must be rejected, both as second and as sole argument.
    with pytest.raises(ValueError):
        k(0, np.ones((5, 2, 1)))
    with pytest.raises(ValueError):
        k(np.ones((5, 2, 1)))

    # NOTE(review): the checks that use `m` are not visible in this excerpt.
    m = OneMean()
def __init__(self, constant, rows, cols=None):
    """Store the constant and sizes and initialise the low-rank factors.

    Args:
        constant: Value used for the middle factor. Its data type is also
            used for the factors of ones.
        rows: Size of the first dimension of the left factor.
        cols (optional): Size of the first dimension of the right factor.
            Defaults to `rows`.
    """
    self.constant = constant
    self.rows = rows
    self.cols = cols if cols is not None else rows

    # Both outer factors are single columns of ones in the constant's dtype.
    ones_dtype = B.dtype(self.constant)
    left = B.ones(ones_dtype, self.rows, 1)
    # Reuse the left factor when both sizes are the very same object.
    right = (
        left
        if self.rows is self.cols
        else B.ones(ones_dtype, self.cols, 1)
    )

    # Prepend two axes to the constant so it can act as the middle factor.
    middle = B.expand_dims(B.expand_dims(self.constant, axis=0), axis=0)

    LowRank.__init__(self, left=left, right=right, middle=middle)
@_dispatch(B.Numeric, B.Numeric)
@uprank
def __call__(self, x, y):
    """Evaluate `exp(-d)` on the pairwise distances `d` between `x` and `y`.

    Args:
        x: First numeric input.
        y: Second numeric input.

    Returns:
        Dense: The exponentiated negative pairwise distances.
    """
    distances = B.pw_dists(x, y)
    return Dense(B.exp(-distances))
def __call__(self, x):
    """Return a column of ones with one entry per leading element of `x`.

    Args:
        x: Input whose data type and first dimension determine the output.

    Returns:
        Ones of the data type of `x`, with sizes `(B.shape(x)[0], 1)`.
    """
    num_inputs = B.shape(x)[0]
    return B.ones(B.dtype(x), num_inputs, 1)
@B.schur.extend(Woodbury)
def schur(a):
    """Return `a.schur`, computing and caching it on first use.

    Args:
        a (Woodbury): Matrix providing `diag`, `lr`, `lr_pd` and the
            `schur` cache attribute.

    Returns:
        The value stored in `a.schur`.
    """
    # Serve the cached value if an earlier call already computed it.
    if a.schur is not None:
        return a.schur

    # Product of the low-rank outer factors around the inverted diagonal
    # part; `tr_a=True` is presumably a transpose of the first factor.
    prod = B.matmul(a.lr.right, B.inverse(a.diag), a.lr.left, tr_a=True)

    # The sign convention depends on the `lr_pd` flag of `a`.
    if a.lr_pd:
        a.schur = B.inverse(a.lr.middle) + prod
    else:
        a.schur = B.inverse(-a.lr.middle) - prod

    return a.schur
@B.subtract.extend(Dense, Dense)
def subtract(a, b):
    """Subtract `b` from `a` by adding the negation of `b` to `a`.

    Args:
        a (Dense): Minuend.
        b (Dense): Subtrahend.

    Returns:
        The result of `B.add(a, -b)`.
    """
    negated_b = -b
    return B.add(a, negated_b)
def uprank(x):
    """Lift a numeric input to rank 2.

    Anything that is not a `B.Numeric` is handed back untouched.

    Args:
        x (tensor): Input to lift.

    Returns:
        tensor: `x` itself if it is non-numeric or already of rank 2,
            otherwise `x` with axes added until it has rank 2.

    Raises:
        ValueError: If the rank of `x` exceeds 2.
    """
    # Non-numerical inputs pass straight through.
    if not isinstance(x, B.Numeric):
        return x

    ndim = B.rank(x)

    if ndim > 2:
        raise ValueError('Input must be at most rank 2.')
    if ndim == 2:
        # Nothing to do.
        return x
    if ndim == 1:
        # Append a trailing axis.
        return B.expand_dims(x, axis=1)

    # Remaining case is rank 0: add a leading and then a trailing axis.
    with_leading_axis = B.expand_dims(x, axis=0)
    return B.expand_dims(with_leading_axis, axis=1)
@_dispatch(B.DType, [object])
def __init__(self, dtype, rows, cols=None):
    """Initialise through `Constant.__init__` with a zero cast to `dtype`.

    Args:
        dtype (B.DType): Data type that the zero constant is cast to.
        rows: Forwarded to `Constant.__init__` as `rows`.
        cols (optional): Forwarded to `Constant.__init__` as `cols`.
    """
    zero = B.cast(dtype, 0)
    Constant.__init__(self, zero, rows=rows, cols=cols)