Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""
Substitutes old for new in an expression after sympifying args.
new_elements : list of (old, new) tuples like [(x, 2), (y, 3)]
"""
if len(list(new_elements))==0:
return expr
if isinstance(expr, (list, tuple, Tuple)):
return [subs(expr, new_elements) for i in expr]
elif isinstance(expr, (Expr, Assign)):
return expr.subs(new_elements)
elif isinstance(expr, While):
test = subs(expr.test, a_old, a_new)
body = subs(expr.body, a_old, a_new)
return While(test, body)
elif isinstance(expr, For):
# TODO treat iter correctly
target = subs(expr.target, a_old, a_new)
it = subs(expr.iterable, a_old, a_new)
target = expr.target
it = expr.iterable
body = subs(expr.body, a_old, a_new)
return For(target, it, body)
elif isinstance(expr, If):
args = []
for block in expr.args:
test = block[0]
stmts = block[1]
t = subs(test, a_old, a_new)
s = subs(stmts, a_old, a_new)
args.append((t, s))
return If(*args)
body += list(stmts)
# ...
# ...
if isinstance(generator, ZipGenerator):
for i,n in zip(index, length):
for x,xs in zip(iterator, iterable):
if not isinstance(xs, (list, tuple, Tuple)):
body = [Assign(x, IndexedBase(xs.name)[i])] + body
else:
for v in xs:
body = [Assign(x, IndexedBase(v.name)[i])] + body
body = [For(i, Range(0, n), Tuple(*body), strict=False)]
else:
for i,n,x,xs in zip(index, length, iterator, iterable):
if not isinstance(xs, (list, tuple, Tuple)):
body = [Assign(x, IndexedBase(xs.name)[i])] + body
else:
for v in xs:
body = [Assign(x, IndexedBase(v.name)[i])] + body
body = [For(i, Range(0, n), Tuple(*body), strict=False)]
# ...
return decs, body
return [mpify(i, **options) for i in stmt]
if isinstance(stmt, MPI):
return stmt
if isinstance(stmt, Tensor):
options['label'] = stmt.name
return stmt
if isinstance(stmt, ForIterator):
iterable = mpify(stmt.iterable, **options)
target = stmt.target
body = mpify(stmt.body, **options)
return ForIterator(target, iterable, body, strict=False)
if isinstance(stmt, For):
iterable = mpify(stmt.iterable, **options)
target = stmt.target
body = mpify(stmt.body, **options)
return For(target, iterable, body, strict=False)
if isinstance(stmt, list):
return [mpify(a, **options) for a in stmt]
if isinstance(stmt, While):
test = mpify(stmt.test, **options)
body = mpify(stmt.body, **options)
return While(test, body)
if isinstance(stmt, If):
args = []
for block in stmt.args:
def set_fst(expr, fst):
    """Propagate ``fst`` to the nodes reachable from ``expr``.

    The walk is recursive and dispatches on the type of ``expr``:

    * ``tuple`` / ``list`` / ``Tuple``: every element is visited in order.
    * ``For``: only the loop body is visited.
    * ``Assign`` / ``AugAssign``: ``expr.set_fst(fst)`` is called on the node.
    * ``GC``: ``expr.set_fst(fst)`` is called on the node, then its
      ``loops`` are visited.

    Any other type is left untouched.

    Parameters
    ----------
    expr : object
        Node, or sequence of nodes, to walk.
    fst : object
        Value forwarded unchanged to each ``set_fst`` method call.
        NOTE(review): presumably the parser's syntax-tree node -- confirm.

    Returns
    -------
    None
    """
    # Containers: visit each element.
    if isinstance(expr, (tuple, list, Tuple)):
        for item in expr:
            set_fst(item, fst)
        return

    # Loops: only the body is searched; the For node itself is not tagged.
    if isinstance(expr, For):
        set_fst(expr.body, fst)
        return

    # Assignment-like statements are tagged directly.
    if isinstance(expr, (Assign, AugAssign)):
        expr.set_fst(fst)
        return

    # GC nodes are tagged first, then their loops are walked.
    if isinstance(expr, GC):
        expr.set_fst(fst)
        set_fst(expr.loops, fst)
def __new__(cls, target, iter, body, strict=True):
    """Build the node by delegating to ``For.__new__``.

    When ``iter`` is a bare ``Symbol`` it is replaced by
    ``Range(Len(iter))`` before delegation; any other value is passed
    through unchanged.

    Parameters
    ----------
    target : object
        Loop target, forwarded as-is.
    iter : object
        Iterable of the loop, or a ``Symbol`` to be wrapped as described
        above.
    body : object
        Loop body, forwarded as-is.
    strict : bool, optional
        Forwarded as-is to ``For.__new__`` (default ``True``).

    Returns
    -------
    object
        Whatever ``For.__new__`` returns for the given arguments.
    """
    # A bare symbol is looped over by index range rather than directly.
    iterable = Range(Len(iter)) if isinstance(iter, Symbol) else iter
    return For.__new__(cls, target, iterable, body, strict)
if isinstance(expr, Expr):
atoms = expr.atoms(Symbol)
return READ*len(atoms)
elif isinstance(expr, Assign):
return count_access(expr.rhs, visual) + WRITE
elif isinstance(expr, Tuple):
return sum(count_access(i, visual) for i in expr)
elif isinstance(expr, CodeBlock):
return sum(count_access(i, visual) for i in expr.body)
elif isinstance(expr, For):
s = expr.iterable.size
ops = sum(count_access(i, visual) for i in expr.body.body)
return ops*s
elif isinstance(expr, (Zeros, Ones)):
import numpy as np
return WRITE*np.prod(expr.shape)
elif isinstance(expr, NewLine):
return 0
else:
raise NotImplementedError('TODO count_access for {}'.format(type(expr)))
return expr
if isinstance(expr, (list, tuple, Tuple)):
return [subs(i, new_elements) for i in expr]
elif isinstance(expr, While):
test = subs(expr.test, new_elements)
body = subs(expr.body, new_elements)
return While(test, body)
elif isinstance(expr, For):
target = subs(expr.target, new_elements)
it = subs(expr.iterable, new_elements)
target = expr.target
it = expr.iterable
body = subs(expr.body, new_elements)
return For(target, it, body)
elif isinstance(expr, If):
args = []
for block in expr.args:
test = block[0]
stmts = block[1]
t = subs(test, new_elements)
s = subs(stmts, new_elements)
args.append((t, s))
return If(*args)
elif isinstance(expr, Return):
for i in new_elements:
expr = expr.subs(i[0],i[1])
return expr
def get_assigned_symbols(expr):
"""Returns all assigned symbols (as sympy Symbol) in the AST.
Parameters
----------
expr: Expression
any AST valid expression
"""
if isinstance(expr, (CodeBlock, FunctionDef, For, While)):
return get_assigned_symbols(expr.body)
elif isinstance(expr, FunctionalFor):
return get_assigned_symbols(expr.loops)
elif isinstance(expr, If):
return get_assigned_symbols(expr.bodies)
elif iterable(expr):
symbols = []
for a in expr:
symbols += get_assigned_symbols(a)
symbols = set(symbols)
symbols = list(symbols)
return symbols
elif isinstance(expr, (Assign, AugAssign)):
ind = vars_new[i].indices
tp = list(stmts[i + 1].atoms(Tuple))
size = None
size = [None] * len(ind)
for (j, k) in enumerate(ind):
for t in tp:
if k == t[0]:
size[j] = t[2] - t[1] + 1
break
if not all(size):
raise ValueError('Unable to find range of index')
name = str(vars_new[i].base)
var = Symbol(name)
stmt = Assign(var, Function('empty')(size[0]))
allocate.append(stmt)
stmts[i] = For(ind[0], Function('range')(size[0]), [stmts[i]], strict=False)
lhs = create_variable(expr)
stmts[-1] = Assign(lhs, stmts[-1])
imports = [Import('empty', 'numpy')]
return imports + allocate + stmts