Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
args = gast.arguments(arg_vars, None, [], [], None, [])
def visit_GeneratorExp(self, node):
    """Hoist a generator expression into a module-level generator function.

    The element expression is visited (and replaced) first. A new
    ``FunctionDef`` named ``generator_expression<N>`` is then appended to
    ``self.ctx.module.body``. Its body is obtained by folding
    ``node.generators`` (innermost first) with ``self.nest_reducer`` around a
    ``yield`` of the element. Each identifier returned by
    ``self.gather(ImportedIds, node)`` becomes a parameter of that function.
    The expression itself is replaced by a call passing those identifiers.

    Side effects: sets ``self.update``, increments ``self.count``, resets
    ``self.count_iter`` to 0, tags the new function with
    ``metadata.Local()`` and appends it to the module body.

    Returns:
      An ``ast.Call`` node invoking the generated function.
    """
    self.update = True
    node.elt = self.visit(node.elt)

    func_name = "generator_expression{0}".format(self.count)
    self.count += 1

    # Gathered only after node.elt was rewritten above, so the rewritten
    # element is what gets scanned (presumably for free variables).
    free_vars = self.gather(ImportedIds, node)

    # Reset before folding; nest_reducer presumably reads/updates this
    # counter per loop level -- TODO confirm.
    self.count_iter = 0
    innermost = ast.Expr(ast.Yield(node.elt))
    loop_nest = reduce(self.nest_reducer, reversed(node.generators), innermost)

    params = [ast.Name(var, ast.Param(), None, None) for var in free_vars]
    signature = ast.arguments(params, [], None, [], [], None, [])
    gen_def = ast.FunctionDef(func_name, signature, [loop_nest], [], None, None)
    metadata.add(gen_def, metadata.Local())
    self.ctx.module.body.append(gen_def)

    # Fresh Load-context Name nodes for the call site: the Param nodes are
    # not reused, so no node is shared between the definition and the call.
    call_args = [ast.Name(param.id, ast.Load(), None, None) for param in params]
    callee = ast.Name(func_name, ast.Load(), None, None)
    return ast.Call(callee, call_args, [])
def visit_Name(self, node):
    """Handle function parameters and propagate annotations to name loads.

    Children are visited first. The node's qualified name is then read from
    its ``anno.Basic.QN`` annotation:

    * ``Param`` context: the name is passed to ``self._process_function_arg``.
    * ``Load`` context with a value in ``self.scope``: the ``type`` and
      ``type_fqn`` annotations (when the definition has ``type``) and the
      ``element_type`` annotation (when present) are copied from that
      definition onto ``node``.

    Returns:
      ``node``, possibly with new annotations.
    """
    self.generic_visit(node)
    qn = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Param):
        self._process_function_arg(qn)
        return node

    if not isinstance(node.ctx, gast.Load) or not self.scope.hasval(qn):
        return node

    # E.g. after
    #   a = b
    # later reads of `a` should carry the annotations of `b`.
    definition = self.scope.getval(qn)

    # (annotation that must be present, annotations copied when it is)
    propagated = (
        ('type', ('type', 'type_fqn')),
        ('element_type', ('element_type',)),
    )
    for guard_key, copied_keys in propagated:
        if anno.hasanno(definition, guard_key):
            for key in copied_keys:
                anno.setanno(node, key, anno.getanno(definition, key))
    return node
if len(iters) == 1:
iterAST = iters[0]
varAST = ast.arguments([variables[0]], [], None, [], [], None, [])
else:
self.use_itertools = True
prodName = ast.Attribute(
value=ast.Name(id=mangle('itertools'),
ctx=ast.Load(),
annotation=None, type_comment=None),
attr='product', ctx=ast.Load())
varid = variables[0].id # retarget this id, it's free
renamings = {v.id: (i,) for i, v in enumerate(variables)}
node.elt = ConvertToTuple(varid, renamings).visit(node.elt)
iterAST = ast.Call(prodName, iters, [])
varAST = ast.arguments([ast.Name(varid, ast.Param(), None, None)],
[], None, [], [], None, [])
ldBodymap = node.elt
ldmap = ast.Lambda(varAST, ldBodymap)
return make_attr(ldmap, iterAST)
else:
return self.generic_visit(node)
if MODULES['functools'] not in self.global_declarations.values():
import_ = ast.Import([ast.alias('functools', mangle('functools'))])
self.imports.append(import_)
functools_module = MODULES['functools']
self.global_declarations[mangle('functools')] = functools_module
self.generic_visit(node)
forged_name = "{0}_lambda{1}".format(
self.prefix,
len(self.lambda_functions))
ii = self.gather(ImportedIds, node)
ii.difference_update(self.lambda_functions) # remove current lambdas
binded_args = [ast.Name(iin, ast.Load(), None, None) for iin in sorted(ii)]
node.args.args = ([ast.Name(iin, ast.Param(), None, None)
for iin in sorted(ii)] +
node.args.args)
forged_fdef = ast.FunctionDef(
forged_name,
copy(node.args),
[ast.Return(node.body)],
[], None, None)
metadata.add(forged_fdef, metadata.Local())
self.lambda_functions.append(forged_fdef)
self.global_declarations[forged_name] = forged_fdef
proxy_call = ast.Name(forged_name, ast.Load(), None, None)
if binded_args:
return ast.Call(
ast.Attribute(
ast.Name(mangle('functools'), ast.Load(), None, None),
"partial",
def visit_Name(self, node):
if isinstance(node.ctx, (ast.Param, ast.Store)):
dnode = self.chains.setdefault(node, Def(node))
if node.id in self._promoted_locals[-1]:
self.extend_definition(node.id, dnode)
if dnode not in self.locals[self.module]:
self.locals[self.module].append(dnode)
else:
self.set_definition(node.id, dnode)
if dnode not in self.locals[self._currenthead[-1]]:
self.locals[self._currenthead[-1]].append(dnode)
if node.annotation is not None:
self.visit(node.annotation)
elif isinstance(node.ctx, (ast.Load, ast.Del)):
node_in_chains = node in self.chains
if node_in_chains:
)
)
# add extra metadata to this node
metadata.add(body, metadata.Comprehension(starget))
init = ast.Assign(
[ast.Name(starget, ast.Store(), None, None)],
ast.Call(
ast.Attribute(
ast.Name('__builtin__', ast.Load(), None, None),
comp_type,
ast.Load()
),
[], [],)
)
result = ast.Return(ast.Name(starget, ast.Load(), None, None))
sargs = [ast.Name(arg, ast.Param(), None, None) for arg in args]
fd = ast.FunctionDef(name,
ast.arguments(sargs, [], None, [], [], None, []),
[init, body, result],
[], None, None)
metadata.add(fd, metadata.Local())
self.ctx.module.body.append(fd)
return ast.Call(
ast.Name(name, ast.Load(), None, None),
[ast.Name(arg.id, ast.Load(), None, None) for arg in sargs],
[],
) # no sharing !
# We add the stack as first argument of the primal
node.args.args = [self.stack] + node.args.args
# Rename the function to its primal name
func = anno.getanno(node, 'func')
node.name = naming.primal_name(func, self.wrt)
# The new body is the primal body plus the return statement
node.body = body + node.body[-1:]
# Find the cost; the first variable of potentially multiple return values
# The adjoint will receive a value for the initial gradient of the cost
y = node.body[-1].value
if isinstance(y, gast.Tuple):
y = y.elts[0]
dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
annotation=None)
if self.check_dims:
def shape_match_template(primal, adjoint):
assert tangent.shapes_match(
primal, adjoint
), 'Shape mismatch between return value (%s) and seed derivative (%s)' % (
numpy.shape(primal), numpy.shape(adjoint))
shape_check = template.replace(shape_match_template, primal=y, adjoint=dy)
adjoint_body = shape_check + adjoint_body
# Construct the adjoint
adjoint_template = grads.adjoints[gast.FunctionDef]
adjoint, = template.replace(adjoint_template, namer=self.namer,
def make_Iterator(self, gen):
    """Return the iterable that a comprehension generator loops over.

    If ``gen`` has no ``if`` clauses, ``gen.iter`` is returned unchanged.
    Otherwise the clauses are turned into a one-parameter lambda over
    ``gen.target``. Several clauses are joined with ``and`` via ``BoolOp``.
    The result is the call ``<ASMODULE>.<IFILTER>(lambda, gen.iter)``.

    Args:
      gen: comprehension node. Only ``gen.target.id`` is read, so the
        target is assumed to be a plain ``Name`` -- TODO confirm tuple
        targets are normalized away before this is called.

    Returns:
      ``gen.iter`` or an ``ast.Call`` node wrapping it in the filter.
    """
    if gen.ifs:
        # Node arities follow the newer gast signatures used by the rest of
        # this file: Name(id, ctx, annotation, type_comment) and
        # arguments(args, posonlyargs, vararg, kwonlyargs, kw_defaults,
        #           kwarg, defaults).
        # gast rejects a positional-argument count that does not match the
        # node's field count.
        ldFilter = ast.Lambda(
            ast.arguments([ast.Name(gen.target.id, ast.Param(), None, None)],
                          [], None, [], [], None, []),
            ast.BoolOp(ast.And(), gen.ifs)
            if len(gen.ifs) > 1 else gen.ifs[0])
        # Record whether the filter comes from itertools (presumably so an
        # import can be emitted later -- TODO confirm).
        self.use_itertools |= MODULE == 'itertools'
        ifilterName = ast.Attribute(
            value=ast.Name(id=ASMODULE,
                           ctx=ast.Load(),
                           annotation=None, type_comment=None),
            attr=IFILTER, ctx=ast.Load())
        return ast.Call(ifilterName, [ldFilter, gen.iter], [])
    else:
        return gen.iter
def make_Iterator(self, gen):
    """Produce the iterable a comprehension generator should iterate.

    Without ``if`` clauses the generator's own ``gen.iter`` is returned.
    With clauses, a lambda taking ``gen.target.id`` is built whose body is
    the single clause, or an ``and`` ``BoolOp`` of all clauses. The
    returned node is ``<ASMODULE>.<IFILTER>(lambda, gen.iter)``.

    Side effect: ``self.use_itertools`` is or-ed with
    ``MODULE == 'itertools'`` whenever a filter is built.
    """
    if not gen.ifs:
        return gen.iter

    clauses = gen.ifs
    if len(clauses) > 1:
        predicate = ast.BoolOp(ast.And(), clauses)
    else:
        predicate = clauses[0]

    target_param = ast.Name(gen.target.id, ast.Param(), None, None)
    lambda_args = ast.arguments([target_param], [], None, [], [], None, [])
    filter_fn = ast.Lambda(lambda_args, predicate)

    self.use_itertools |= (MODULE == 'itertools')

    module_ref = ast.Name(id=ASMODULE, ctx=ast.Load(),
                          annotation=None, type_comment=None)
    filter_attr = ast.Attribute(value=module_ref, attr=IFILTER,
                                ctx=ast.Load())
    return ast.Call(filter_attr, [filter_fn, gen.iter], [])