Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_memoryview_float_notrans(A, B, a_rows, a_cols, out_cols):
    """Check ``gemm`` with float32 operands (no transposes) against ``dot``.

    ``_reshape_for_gemm`` builds the two operands and an output buffer from
    the raw inputs. Examples where any returned array is ``None`` or empty
    are discarded via ``assume``. The ``gemm`` result written into the
    buffer must match NumPy's ``left.dot(right)`` within 1e-3 (abs and rel).
    """
    left, right, out = _reshape_for_gemm(
        A, B, a_rows, a_cols, out_cols, dtype='float32'
    )
    # Reject missing arrays first, so ``.size`` is never read on ``None``.
    for operand in (left, right, out):
        assume(operand is not None)
    for operand in (left, right, out):
        assume(operand.size >= 1)

    gemm(left, right, out=out)
    expected = left.dot(right)
    assert_allclose(expected, out, atol=1e-3, rtol=1e-3)
def test_memoryview_double_notrans(A, B, a_rows, a_cols, out_cols):
    """Check ``gemm`` with float64 operands (no transposes) against ``dot``.

    ``_reshape_for_gemm`` builds the two operands and an output buffer from
    the raw inputs. Examples where any returned array is ``None`` or empty
    are discarded via ``assume``. The ``gemm`` result written into ``C``
    must match NumPy's ``A.dot(B)`` within 1e-3 (abs and rel).
    """
    # Pass dtype by keyword (as test_memoryview_float_notrans does) so the
    # call does not rely on dtype's position in the helper's signature.
    A, B, C = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, dtype='float64')
    # None checks must precede the .size checks below.
    assume(A is not None)
    assume(B is not None)
    assume(C is not None)
    assume(A.size >= 1)
    assume(B.size >= 1)
    assume(C.size >= 1)
    gemm(A, B, out=C)
    numpy_result = A.dot(B)
    assert_allclose(numpy_result, C, atol=1e-3, rtol=1e-3)
def test_memoryview_double_noconj(A, B):
    """Check ``dotv`` against NumPy's ``dot`` for two inputs.

    Both inputs are truncated to their common (shorter) length before the
    comparison. Examples with a ``None`` input are discarded via ``assume``.
    """
    # Reject missing inputs before calling len() on them: len(None) would
    # raise TypeError instead of discarding the example.
    assume(A is not None)
    assume(B is not None)
    # Trim both operands to the shorter length so the dot product is defined.
    common_len = min(len(A), len(B))
    A = A[:common_len]
    B = B[:common_len]
    numpy_result = A.dot(B)
    result = dotv(A, B)
    # The NumPy result is wrapped in a list for comparison with dotv's output.
    assert_allclose([numpy_result], result, atol=1e-3, rtol=1e-3)
def test_memoryview_float_noconj(A, B):
    """Check ``dotv`` against NumPy's ``dot`` for two inputs.

    Both inputs are truncated to their common (shorter) length before the
    comparison. Examples with a ``None`` input are discarded via ``assume``.
    """
    # Reject missing inputs before calling len() on them: len(None) would
    # raise TypeError instead of discarding the example.
    assume(A is not None)
    assume(B is not None)
    # Trim both operands to the shorter length so the dot product is defined.
    common_len = min(len(A), len(B))
    A = A[:common_len]
    B = B[:common_len]
    numpy_result = A.dot(B)
    result = dotv(A, B)
    # The NumPy result is wrapped in a list for comparison with dotv's output.
    assert_allclose([numpy_result], result, atol=1e-3, rtol=1e-3)
def blis_gemm(X, W, n=1000):
    """Benchmark ``gemm`` by computing ``X.dot(W)`` into a reused buffer.

    The product is computed ``n`` times. The sum of every result is
    accumulated and printed as ``Total: ...``.

    Args:
        X: left operand; its first dimension is the batch size, and its
            second dimension must equal ``W.shape[0]``.
        W: 2-D right operand.
        n: number of repetitions.
    """
    _, nI = W.shape
    batch_size = X.shape[0]
    total = 0.0
    # gemm(X, W, out=y) computes X.dot(W), whose shape is
    # (X.shape[0], W.shape[1]) == (batch_size, nI). A (batch_size, nO)
    # buffer would only fit when nO == nI.
    # NOTE(review): blis_einsum pairs X's second axis with W's second axis
    # (i.e. X @ W.T). If that layout is intended here too, a transposed gemm
    # is needed instead of a re-sized buffer — confirm.
    y = numpy.zeros((batch_size, nI), dtype="f")
    for _ in range(n):
        gemm(X, W, out=y)
        total += y.sum()
        # Reset the buffer between repetitions (in case gemm accumulates
        # into ``out``).
        y.fill(0.0)
    print("Total:", total)
def blis_einsum(X, W, n=1000):
    """Benchmark ``einsum("ab,cb->ca", X, W)`` into a reused buffer ``n`` times.

    The spec contracts the shared second axis of ``X`` and ``W``, so each
    result has shape ``(W.shape[0], X.shape[0])``. The sum of every result is
    accumulated and printed as ``Total: ...``.
    """
    n_out, _ = W.shape
    n_batch = X.shape[0]
    buf = numpy.zeros((n_out, n_batch), dtype="f")
    total = 0.0
    for _ in range(n):
        einsum("ab,cb->ca", X, W, out=buf)
        total += buf.sum()
        # Clear the output buffer before the next repetition.
        buf[...] = 0.0
    print("Total:", total)