Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_reduction_symbolic():
# Check reductions on every symbolic backend in `sym_op_backends` against
# numpy reference results computed with the matching ndarray method
# (`getattr(input, reduction)`).
# NOTE(review): this view has no indentation and the `test_cases` list below
# is never closed (the following line starts a new `def`); the rest of this
# function, including whatever consumes `test_cases`, is not visible here —
# restore it from the full file.
for backend in sym_op_backends:
print('Reduction tests for ', backend.framework_name)
for reduction in _reductions:
# 2x3x4x5x6 int64 tensor holding 0..719.
input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape(2, 3, 4, 5, 6)
# Rescaled to mean 1 (float64) for every reduction; note that
# test_reduction_imperatives in this file rescales only for 'mean'/'prod'.
input = input / input.astype('float64').mean()
# Each entry: [pattern, axes_lengths, expected numpy result].
test_cases = [
# Reduce all five axes to a scalar.
['a b c d e -> ', {},
getattr(input, reduction)()],
# `a` plus `...` cover every axis; reduce to a scalar.
['a ... -> ', {},
getattr(input, reduction)()],
# First/last axes decomposed with a trivial factor of 1, then reduced.
['(a a2) ... (e e2) -> ', dict(a2=1, e2=1),
getattr(input, reduction)()],
# Reduce b and d (axes 1, 3); output axes are (e c) merged, then a.
['a b c d e -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1, 2)],
# Same as above, with `...` standing in for axis b.
['a ... c d e -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1, 2)],
# Same as above; the input has exactly five axes, so `...` matches none.
['a b c d e ... -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1, 2)],
['a b c d e -> (e c a)', {},
def test_reduce_imperative():
    """Test the ``Reduce`` layer of every imperative backend that provides layers.

    For each reduction in ``_reductions`` and each
    ``(pattern, axes_lengths, input_shape, wrong_shapes)`` entry of
    ``reduction_patterns``:

    * a numpy reference result is computed with ``reduce``;
    * the layer must raise on every input shape listed in ``wrong_shapes``;
    * the layer is pickled and unpickled, and both the original and the
      restored layer must reproduce the numpy reference.

    Raises:
        AssertionError: if a wrong shape is accepted, or if a layer result
            differs from the numpy reference.
    """
    for backend in collect_test_backends(symbolic=False, layers=True):
        print('Test layer for ', backend.framework_name)
        for reduction in _reductions:
            for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
                print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
                # Values start at 1 and are rescaled to mean 1; presumably this
                # keeps 'prod' away from zero and within float32 range.
                x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                x /= x.mean()
                result_numpy = reduce(x, pattern, reduction, **axes_lengths)
                layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
                for shape in wrong_shapes:
                    try:
                        layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
                    except Exception:
                        # Expected: incompatible shapes must be rejected. Catching
                        # Exception rather than using a bare except lets
                        # KeyboardInterrupt / SystemExit propagate.
                        pass
                    else:
                        raise AssertionError('Failure expected')
                # simple pickling / unpickling
                layer2 = pickle.loads(pickle.dumps(layer))
                # Both the original and the unpickled layer must reproduce the
                # numpy reference, so the pickling round-trip is actually verified.
                result1 = backend.to_numpy(layer(backend.from_numpy(x)))
                result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
                assert numpy.allclose(result_numpy, result1)
                assert numpy.allclose(result1, result2)
def test_reduction_imperatives():
# Check reductions on every imperative backend in `imp_op_backends` against
# numpy reference results computed with the matching ndarray method
# (`getattr(input, reduction)`).
# NOTE(review): this view has no indentation, and the body of the final `for`
# loop lies beyond the visible lines — confirm against the full file.
# NOTE(review): the local name `input` shadows the builtin.
for backend in imp_op_backends:
print('Reduction tests for ', backend.framework_name)
for reduction in _reductions:
# 2x3x4x5x6 int64 tensor holding 0..719.
input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape(2, 3, 4, 5, 6)
if reduction in ['mean', 'prod']:
# Rescale to mean 1 (float64); presumably so 'prod' cannot overflow
# int64 and 'mean' is compared in floating point — confirm.
input = input / input.astype('float64').mean()
# Each entry: [pattern, axes_lengths, expected numpy result].
test_cases = [
# Reduce all five axes to a scalar.
['a b c d e -> ', {}, getattr(input, reduction)()],
# `...` alone covers every axis; reduce to a scalar.
['... -> ', {}, getattr(input, reduction)()],
# First axis (length 2) split as (1, 2), last (length 6) as (3, 2); all reduced.
['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2), getattr(input, reduction)()],
# Reduce b and d (axes 1, 3); output axes are (e c) merged, then a.
['a b c d e -> (e c) a', {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1, 2)],
# Same as above, with `...` standing in for axis b.
['a ... c d e -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1, 2)],
# Same as above; the input has exactly five axes, so `...` matches none.
['a b c d e ... -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1, 2)],
# Reduce b and d, then flatten e, c, a into one axis.
['a b c d e -> (e c a)', {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape(-1)],
# No reduction: axis a (length 2) split as (2, 1) and recomposed with the
# factors swapped, which leaves the data unchanged, so `input` is expected.
['(a1 a2) ... -> (a2 a1) ...', dict(a2=1), input],
]
for pattern, axes_lengths, expected_result in test_cases: