Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""
grid = Grid(shape=(4, 4, 4))
ti0 = Function(name='ti0', grid=grid)
ti1 = Function(name='ti1', grid=grid)
tu = TimeFunction(name='tu', grid=grid)
tv = TimeFunction(name='tv', grid=grid)
eq1 = Eq(tu, tv*ti0 + ti0)
eq2 = Eq(ti0, tu + 3.)
eq3 = Eq(tv, ti0*ti1)
op1 = Operator([eq1, eq2, eq3], dse='noop', dle='noop')
op2 = Operator([eq2, eq1, eq3], dse='noop', dle='noop')
op3 = Operator([eq3, eq2, eq1], dse='noop', dle='noop')
trees = [retrieve_iteration_tree(i) for i in [op1, op2, op3]]
assert all(len(i) == 1 for i in trees)
trees = [i[0] for i in trees]
for tree in trees:
assert IsPerfectIteration().visit(tree[1])
exprs = FindNodes(Expression).visit(tree[-1])
assert len(exprs) == 3
"""
grid = Grid(shape=(4, 4, 4))
ti0 = Function(name='ti0', grid=grid)
ti1 = Function(name='ti1', grid=grid)
tu = TimeFunction(name='tu', grid=grid)
tv = TimeFunction(name='tv', grid=grid)
eq1 = Eq(tu, tv*ti0 + ti0)
eq2 = Eq(ti0, tu + 3.)
eq3 = Eq(tv, ti0*ti1)
op1 = Operator([eq1, eq2, eq3], opt='noop')
op2 = Operator([eq2, eq1, eq3], opt='noop')
op3 = Operator([eq3, eq2, eq1], opt='noop')
trees = [retrieve_iteration_tree(i) for i in [op1, op2, op3]]
assert all(len(i) == 1 for i in trees)
trees = [i[0] for i in trees]
for tree in trees:
assert IsPerfectIteration().visit(tree[1])
exprs = FindNodes(Expression).visit(tree[-1])
assert len(exprs) == 3
call = tree.root.nodes[1]
assert call.name == 'pokempi0'
assert call.arguments[0].name == 'msg0'
if configuration['openmp']:
# W/ OpenMP, we prod until all comms have completed
assert call.then_body[0].body[0].is_While
# W/ OpenMP, we expect dynamic thread scheduling
assert 'dynamic,1' in tree.root.pragmas[0].value
else:
# W/o OpenMP, it's a different story
assert call._single_thread
# Now we do as before, but enforcing loop blocking (by default off,
# as heuristically it is not enabled when the Iteration nest has depth < 3)
op = Operator(eqn, dle=('advanced', {'blockinner': True}))
trees = retrieve_iteration_tree(op._func_table['bf0'].root)
assert len(trees) == 2
tree = trees[1]
# Make sure `pokempi0` is the last node within the inner Iteration over blocks
assert len(tree) == 2
assert len(tree.root.nodes[0].nodes) == 2
call = tree.root.nodes[0].nodes[1]
assert call.name == 'pokempi0'
assert call.arguments[0].name == 'msg0'
if configuration['openmp']:
# W/ OpenMP, we prod until all comms have completed
assert call.then_body[0].body[0].is_While
# W/ OpenMP, we expect dynamic thread scheduling
assert 'dynamic,1' in tree.root.pragmas[0].value
else:
# W/o OpenMP, it's a different story
assert call._single_thread
"""
# NOTE(review): the header of the enclosing test is outside this view; the
# statements below are kept exactly as they appear.
grid = Grid(shape=(3, 3, 3))
x, y, z = grid.dimensions
t0 = Constant(name='t0')
t1 = Scalar(name='t1')
# `e` is declared over (x,), `f` over (x, y); `g` and `h` are declared over
# the whole grid
e = Function(name='e', shape=(3,), dimensions=(x,), space_order=0)
f = Function(name='f', shape=(3, 3), dimensions=(x, y), space_order=0)
g = Function(name='g', grid=grid, space_order=0)
h = Function(name='h', grid=grid, space_order=0)
# Three equations writing, in order, to `t1`, `f` and `h`
eq0 = Eq(t1, e*1.)
eq1 = Eq(f, t0*3. + t1)
eq2 = Eq(h, g + 4. + f*5.)
op = Operator([eq0, eq1, eq2], dse='noop', dle='noop')
trees = retrieve_iteration_tree(op)
# Three trees are expected, of depth 1, 2 and 3 respectively
assert len(trees) == 3
outer, middle, inner = trees
assert len(outer) == 1 and len(middle) == 2 and len(inner) == 3
# The trees must agree on their common leading Iterations
assert outer[0] == middle[0] == inner[0]
assert middle[1] == inner[1]
# The last Iteration of each tree must hold the indexified right-hand side
# of the corresponding equation
assert outer[-1].nodes[0].exprs[0].expr.rhs == indexify(eq0.rhs)
assert middle[-1].nodes[0].exprs[0].expr.rhs == indexify(eq1.rhs)
assert inner[-1].nodes[0].exprs[0].expr.rhs == indexify(eq2.rhs)
def test_equations_mixed_timedim_stepdim(self):
    """
    Test that two equations, one using a TimeDimension and the other a
    derived SteppingDimension, end up in the same loop nest.
    """
    grid = Grid(shape=(3, 3, 3))
    x, y, z = grid.dimensions
    time = grid.time_dim
    t = grid.stepping_dim

    u1 = TimeFunction(name='u1', grid=grid)
    u2 = TimeFunction(name='u2', grid=grid, save=2)

    # `u1` is indexed through the stepping dimension, `u2` through the
    # time dimension
    eqn_1 = Eq(u1[t+1, x, y, z], u1[t, x, y, z] + 1.)
    eqn_2 = Eq(u2[time+1, x, y, z], u2[time, x, y, z] + 1.)

    op = Operator([eqn_1, eqn_2], opt='topofuse')

    # A single Iteration tree is expected
    trees = retrieve_iteration_tree(op)
    assert len(trees) == 1

    # Both expressions must sit under the last Iteration of that tree, in
    # the order in which the equations were supplied
    exprs = trees[0][-1].nodes[0].exprs
    assert len(exprs) == 2
    assert exprs[0].write == u1
    assert exprs[1].write == u2
@switchconfig(platform='nvidiaX')
def test_basic(self):
    """
    Check the OpenMP pragmas attached to a single-equation Operator built
    with ``platform='nvidiaX'`` and the ``openmp`` option enabled.
    """
    grid = Grid(shape=(3, 3, 3))
    u = TimeFunction(name='u', grid=grid)
    update = Eq(u.forward, u + 1)
    op = Operator(update, dle=('advanced', {'openmp': True}))

    trees = retrieve_iteration_tree(op)
    assert len(trees) == 1

    # Pragma on the second Iteration of the only tree
    expected_loop = 'omp target teams distribute parallel for collapse(3)'
    assert trees[0][1].pragmas[0].value == expected_loop

    # Data-mapping pragmas found in the header and footer of `op.body[1]`
    expected_enter = ('omp target enter data map(to: u[0:u_vec->size[0]]'
                      '[0:u_vec->size[1]][0:u_vec->size[2]][0:u_vec->size[3]])')
    assert op.body[1].header[1].value == expected_enter

    expected_exit = ('omp target exit data map(from: u[0:u_vec->size[0]]'
                     '[0:u_vec->size[1]][0:u_vec->size[2]][0:u_vec->size[3]])')
    assert op.body[1].footer[0].value == expected_exit
def test_cache_blocking_structure(blockinner, expected):
    """
    Check the structure of the Iteration trees produced when blocking is
    requested: first the remainder loops, then the placement of the OpenMP
    pragmas.

    ``blockinner`` is forwarded as the ``blockinner`` option; ``expected``
    is the number of Iteration trees the Operator must contain.
    """
    shape = (10, 31, 45)

    # Check presence of remainder loops
    sequential_opts = {'blockalways': True, 'blockinner': blockinner}
    _, op = _new_operator1(shape, dle=('blocking', sequential_opts))
    trees = retrieve_iteration_tree(op)
    assert len(trees) == expected
    assert not trees[0][0].is_Remainder
    assert all(tree[0].is_Remainder for tree in trees[1:])

    # Check presence of openmp pragmas at the right place
    parallel_opts = {'openmp': True,
                     'blockalways': True,
                     'blockinner': blockinner}
    _, op = _new_operator1(shape, dle=('blocking', parallel_opts))
    trees = retrieve_iteration_tree(op)
    assert len(trees) == expected

    # The outermost loop of the last tree is not parallel; the outermost
    # loop of every other tree must carry exactly one pragma
    assert not trees[-1][0].is_Parallel
    for tree in trees[:-1]:
        outermost = tree[0]
        assert len(outermost.pragmas) == 1
# NOTE(review): `grid` and `time` are not defined in this fragment; presumably
# they are set up earlier in the enclosing test -- confirm against the caller.
u1 = TimeFunction(name="u1", grid=grid, save=10, time_order=2)
u2 = TimeFunction(name="u2", grid=grid, time_order=2)
sf1 = SparseTimeFunction(name='sf1', grid=grid, npoint=1, nt=10)
sf2 = SparseTimeFunction(name='sf2', grid=grid, npoint=1, nt=10)
# Deliberately inject into u1, rather than u1.forward, to create a WAR w/ eqn3
eqn1 = Eq(u1.forward, u1 + 2.0 - u1.backward)
eqn2 = sf1.inject(u1, expr=sf1)
eqn3 = Eq(u2.forward, u2 + 2*u2.backward - u1.dt2)
eqn4 = sf2.interpolate(u2)
# Note: opts disabled only because with OpenMP otherwise there might be more
# `trees` than 4
op = Operator([eqn1] + eqn2 + [eqn3] + eqn4, opt=('noop', {'openmp': False}))
trees = retrieve_iteration_tree(op)
assert len(trees) == 4
# Due to the WAR, the time loop is not shared by all four trees: trees 0-1
# share one time loop, trees 2-3 share another
assert trees[0][0].dim is time and trees[0][0] is trees[1][0] # this IS shared
assert trees[1][0] is not trees[2][0]
assert trees[2][0].dim is time and trees[2][0] is trees[3][0] # this IS shared
# Now single, shared time loop expected: the injection targets u1.forward
eqn2 = sf1.inject(u1.forward, expr=sf1)
op = Operator([eqn1] + eqn2 + [eqn3] + eqn4, opt=('noop', {'openmp': False}))
trees = retrieve_iteration_tree(op)
assert len(trees) == 4
# Every tree must be rooted at the very same Iteration object
assert all(trees[0][0] is i[0] for i in trees)
def test_expressions_imperfect_loops(self, ti0, ti1, ti2, t0):
    """
    Check that an equation which depends on only a subset of the indices
    used across all equations is scheduled in an earlier (shallower) loop
    of the loop nest tree.
    """
    shallow_eq = Eq(ti2, t0*3.)
    deep_eq = Eq(ti0, ti1 + 4. + ti2*5.)
    op = Operator([shallow_eq, deep_eq], dse='noop', dle='noop')

    trees = retrieve_iteration_tree(op)
    assert len(trees) == 2
    outer, inner = trees

    # The shallow tree has depth 2, the deep one depth 3
    assert len(outer) == 2
    assert len(inner) == 3

    # The shallow tree must match the deep one up to the latter's
    # innermost loop
    for shallow_it, deep_it in zip(outer, inner[:-1]):
        assert shallow_it == deep_it

    # Each equation's right-hand side sits in the last loop of its tree
    assert outer[-1].nodes[0].expr.rhs == shallow_eq.rhs
    assert inner[-1].nodes[0].expr.rhs == deep_eq.rhs
...
one may provide ``blockshape = {i: 4, j: 7}``, in which case the
two outer loops will be blocked, and the resulting 2-dimensional block will
have size 4x7. The latter may be set to True to also block innermost parallel
:class:`Iteration` objects.
"""
exclude_innermost = not self.params.get('blockinner', False)
ignore_heuristic = self.params.get('blockalways', False)
# Make sure loop blocking will span as many Iterations as possible
fold = fold_blockable_tree(nodes, exclude_innermost)
mapper = {}
blocked = OrderedDict()
for tree in retrieve_iteration_tree(fold):
# Is the Iteration tree blockable ?
iterations = [i for i in tree if i.is_Parallel]
if exclude_innermost:
iterations = [i for i in iterations if not i.is_Vectorizable]
if len(iterations) <= 1:
continue
root = iterations[0]
if not IsPerfectIteration().visit(root):
# Illegal/unsupported
continue
if not tree[0].is_Sequential and not ignore_heuristic:
# Heuristic: avoid polluting the generated code with blocked
# nests (thus increasing JIT compilation time and affecting
# readability) if the blockable tree isn't embedded in a
# sequential loop (e.g., a timestepping loop)
continue