@target_pass
def make_ops_kernels(iet):
    warning("The OPS backend is still work-in-progress")

    affine_trees = find_affine_trees(iet).items()

    # If there are no affine trees, then there are no loops to be optimized using OPS
    if not affine_trees:
        return iet, {}

    ops_init = Call(namespace['ops_init'], [0, 0, 2])
    ops_partition = Call(namespace['ops_partition'], Literal('""'))
    ops_exit = Call(namespace['ops_exit'])
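    # These three Call nodes lower to plain OPS runtime calls in the generated
    # C code. A hedged sketch of the emitted calls (their placement at the top
    # and bottom of the Operator body reflects the intent here, not the exact
    # codegen output):
    #
    #   ops_init(0, 0, 2);   /* start up the OPS runtime */
    #   ops_partition("");   /* partition after the ops_decl_* declarations */
    #   ...                  /* the OPS-offloaded kernels */
    #   ops_exit();          /* tear down the OPS runtime */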
    # Extract all symbols that need to be converted to ops_dat
    dims = []
    to_dat = set()
    for _, tree in affine_trees:
        dims.append(len(tree[0].dimensions))
        symbols = set(FindSymbols('symbolics').visit(tree[0].root))
        symbols -= set(FindSymbols('defines').visit(tree[0].root))
        to_dat |= symbols

    # Create the OPS block for this problem
    ops_block = OpsBlock('block')
    ops_block_init = Expression(ClusterizedEq(Eq(
        ops_block,
        namespace['ops_decl_block'](
            dims[0],
fromrank = FieldFromComposite(msg._C_field_from, msgi)
sizes = [FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
         for i in range(len(f._dist_dimensions))]
ofss = [FieldFromComposite('%s[%d]' % (msg._C_field_ofss, i), msgi)
        for i in range(len(f._dist_dimensions))]
ofss = [fixed.get(d) or ofss.pop(0) for d in f.dimensions]
# The `scatter` must be guarded as we must not alter the halo values along
# the domain boundary, where the sender is actually MPI.PROC_NULL
scatter = Call('scatter_%s' % key, [bufs] + sizes + [f] + ofss)
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)
rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])
# The -1 below is because an Iteration, by default, generates <=
ncomms = Symbol(name='ncomms')
iet = Iteration([waitsend, waitrecv, scatter], dim, ncomms - 1)
parameters = ([f] + list(fixed.values()) + [msg, ncomms])
return Callable('halowait%d' % key, iet, 'void', parameters, ('static',))
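# A hedged sketch of the C function the `halowait` Callable above is lowered
# to. The struct field names (rsend, rrecv, bufs, sizes, fromrank) and the
# `msg_t` type are illustrative stand-ins for the generated identifiers; the
# `<= ncomms - 1` bound mirrors the Iteration built above:
#
#   static void halowait<key>(..., msg_t *msg, int ncomms) {
#     for (int i = 0; i <= ncomms - 1; i += 1) {
#       MPI_Wait(&(msg[i].rsend), MPI_STATUS_IGNORE);
#       MPI_Wait(&(msg[i].rrecv), MPI_STATUS_IGNORE);
#       if (msg[i].fromrank != MPI_PROC_NULL)
#         scatter_<key>(msg[i].bufs, msg[i].sizes[0], ..., f, ...);
#     }
#   }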
msgi = IndexedPointer(msg, dim)
bufg = FieldFromComposite(msg._C_field_bufg, msgi)
bufs = FieldFromComposite(msg._C_field_bufs, msgi)
fromrank = FieldFromComposite(msg._C_field_from, msgi)
torank = FieldFromComposite(msg._C_field_to, msgi)
sizes = [FieldFromComposite('%s[%d]' % (msg._C_field_sizes, i), msgi)
         for i in range(len(f._dist_dimensions))]
ofsg = [FieldFromComposite('%s[%d]' % (msg._C_field_ofsg, i), msgi)
        for i in range(len(f._dist_dimensions))]
ofsg = [fixed.get(d) or ofsg.pop(0) for d in f.dimensions]
# The `gather` is unnecessary if sending to MPI.PROC_NULL
gather = Call('gather_%s' % key, [bufg] + sizes + [f] + ofsg)
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
# Make Irecv/Isend
count = reduce(mul, sizes, 1)
rrecv = Byref(FieldFromComposite(msg._C_field_rrecv, msgi))
rsend = Byref(FieldFromComposite(msg._C_field_rsend, msgi))
recv = Call('MPI_Irecv', [bufs, count, Macro(dtype_to_mpitype(f.dtype)),
                          fromrank, Integer(13), comm, rrecv])
send = Call('MPI_Isend', [bufg, count, Macro(dtype_to_mpitype(f.dtype)),
                          torank, Integer(13), comm, rsend])
# The -1 below is because an Iteration, by default, generates <=
ncomms = Symbol(name='ncomms')
iet = Iteration([recv, gather, send], dim, ncomms - 1)
parameters = ([f, comm, msg, ncomms]) + list(fixed.values())
return Callable('haloupdate%d' % key, iet, 'void', parameters, ('static',))
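# A hedged sketch of the C function produced for `haloupdate`: per message,
# post the receive, gather the halo into the send buffer (unless the neighbour
# is MPI_PROC_NULL), then post the send. Field names, `msg_t` and `<count>`
# are illustrative, not the exact generated identifiers:
#
#   static void haloupdate<key>(..., MPI_Comm comm, msg_t *msg, int ncomms) {
#     for (int i = 0; i <= ncomms - 1; i += 1) {
#       MPI_Irecv(msg[i].bufs, <count>, <MPI type of f>, msg[i].fromrank, 13,
#                 comm, &(msg[i].rrecv));
#       if (msg[i].torank != MPI_PROC_NULL)
#         gather_<key>(msg[i].bufg, msg[i].sizes[0], ..., f, ...);
#       MPI_Isend(msg[i].bufg, <count>, <MPI type of f>, msg[i].torank, 13,
#                 comm, &(msg[i].rsend));
#     }
#   }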
# Pick the device for each MPI rank so that we have nranks == ngpus (as long
# as the user has launched the right number of MPI processes per node given
# the available number of GPUs per node)
comm = None
for i in iet.parameters:
if isinstance(i, MPICommObject):
comm = i
break
device_nvidia = Macro('acc_device_nvidia')
body = Call('acc_init', [device_nvidia])
if comm is not None:
rank = Symbol(name='rank')
rank_decl = LocalExpression(DummyEq(rank, 0))
rank_init = Call('MPI_Comm_rank', [comm, Byref(rank)])
ngpus = Symbol(name='ngpus')
call = DefFunction('acc_get_num_devices', device_nvidia)
ngpus_init = LocalExpression(DummyEq(ngpus, call))
devicenum = Symbol(name='devicenum')
devicenum_init = LocalExpression(DummyEq(devicenum, rank % ngpus))
set_device_num = Call('acc_set_device_num', [devicenum, device_nvidia])
body = [rank_decl, rank_init, ngpus_init, devicenum_init,
        set_device_num, body]
init = List(header=c.Comment('Begin of OpenACC+MPI setup'),
            body=body,
            footer=(c.Comment('End of OpenACC+MPI setup'), c.Line()))
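# When an MPI communicator is found among the Operator parameters, the List
# built above expands, roughly, to the following C setup code, assigning one
# GPU per MPI rank via `rank % ngpus` (a sketch; the declared types and exact
# layout are an assumption about the generated code):
#
#   /* Begin of OpenACC+MPI setup */
#   int rank = 0;
#   MPI_Comm_rank(comm, &rank);
#   int ngpus = acc_get_num_devices(acc_device_nvidia);
#   int devicenum = rank % ngpus;
#   acc_set_device_num(devicenum, acc_device_nvidia);
#   acc_init(acc_device_nvidia);
#   /* End of OpenACC+MPI setup */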
# The `gather` is unnecessary if sending to MPI.PROC_NULL
gather = Conditional(CondNe(torank, Macro('MPI_PROC_NULL')), gather)
# The `scatter` must be guarded as we must not alter the halo values along
# the domain boundary, where the sender is actually MPI.PROC_NULL
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)
count = reduce(mul, bufs.shape, 1)
rrecv = MPIRequestObject(name='rrecv')
rsend = MPIRequestObject(name='rsend')
recv = Call('MPI_Irecv', [bufs, count, Macro(dtype_to_mpitype(f.dtype)),
                          fromrank, Integer(13), comm, rrecv])
send = Call('MPI_Isend', [bufg, count, Macro(dtype_to_mpitype(f.dtype)),
                          torank, Integer(13), comm, rsend])
waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])
iet = List(body=[recv, gather, send, waitsend, waitrecv, scatter])
parameters = ([f] + list(bufs.shape) + ofsg + ofss + [fromrank, torank, comm])
return Callable('sendrecv_%s' % key, iet, 'void', parameters, ('static',))
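# A hedged sketch of the generated `sendrecv_<key>` routine: non-blocking
# send/recv of the halo buffers, with the gather/scatter guarded against
# MPI_PROC_NULL neighbours. Buffer names, `<count>` and the offset arguments
# are illustrative placeholders:
#
#   static void sendrecv_<key>(<f>, <buffer shape>, <ofsg and ofss offsets>,
#                              int fromrank, int torank, MPI_Comm comm) {
#     MPI_Request rrecv, rsend;
#     MPI_Irecv(bufs, <count>, <MPI type of f>, fromrank, 13, comm, &rrecv);
#     if (torank != MPI_PROC_NULL)
#       gather_<key>(bufg, <buffer shape>, f, <ofsg offsets>);
#     MPI_Isend(bufg, <count>, <MPI type of f>, torank, 13, comm, &rsend);
#     MPI_Wait(&rsend, MPI_STATUS_IGNORE);
#     MPI_Wait(&rrecv, MPI_STATUS_IGNORE);
#     if (fromrank != MPI_PROC_NULL)
#       scatter_<key>(bufs, <buffer shape>, f, <ofss offsets>);
#   }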
bufs = FieldFromPointer(msg._C_field_bufs, msg)
ofss = [Symbol(name='os%s' % d.root) for d in f.dimensions]
fromrank = Symbol(name='fromrank')
sizes = [FieldFromPointer('%s[%d]' % (msg._C_field_sizes, i), msg)
         for i in range(len(f._dist_dimensions))]
scatter = Call('scatter_%s' % key, [bufs] + sizes + [f] + ofss)
# The `scatter` must be guarded as we must not alter the halo values along
# the domain boundary, where the sender is actually MPI.PROC_NULL
scatter = Conditional(CondNe(fromrank, Macro('MPI_PROC_NULL')), scatter)
rrecv = Byref(FieldFromPointer(msg._C_field_rrecv, msg))
waitrecv = Call('MPI_Wait', [rrecv, Macro('MPI_STATUS_IGNORE')])
rsend = Byref(FieldFromPointer(msg._C_field_rsend, msg))
waitsend = Call('MPI_Wait', [rsend, Macro('MPI_STATUS_IGNORE')])
iet = List(body=[waitsend, waitrecv, scatter])
parameters = ([f] + ofss + [fromrank, msg])
return Callable('wait_%s' % key, iet, 'void', parameters, ('static',))
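# A hedged sketch of the generated `wait_<key>` routine: complete the pending
# send/recv held in `msg`, then scatter the received halo back into `f`,
# unless the neighbour is MPI_PROC_NULL. Field names are illustrative:
#
#   static void wait_<key>(<f>, <os... offsets>, int fromrank, msg_t *msg) {
#     MPI_Wait(&(msg->rsend), MPI_STATUS_IGNORE);
#     MPI_Wait(&(msg->rrecv), MPI_STATUS_IGNORE);
#     if (fromrank != MPI_PROC_NULL)
#       scatter_<key>(msg->bufs, msg->sizes[0], ..., f, <os... offsets>);
#   }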