Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
self.trivializing = True
self.namer.target = node.targets[0]
if isinstance(node.targets[0], (gast.Subscript, gast.Attribute)):
node.value = self.trivialize(node.value)
node.targets[0] = self.visit(node.targets[0])
elif isinstance(node.targets[0], gast.Tuple):
node.value = self.visit(node.value)
name = self.namer.name(node.targets[0])
target = gast.Name(id=name, ctx=gast.Store(), annotation=None)
for i, elt in enumerate(node.targets[0].elts):
stmt = gast.Assign(
targets=[elt],
value=gast.Subscript(
value=gast.Name(id=name, ctx=gast.Load(),
annotation=None),
slice=gast.Index(value=gast.Num(n=i)),
ctx=gast.Load()))
self.mark(stmt)
self.append(stmt)
node.targets[0] = target
elif not isinstance(node.targets[0], gast.Name):
raise ValueError('Cannot Assign to %s' % type(node.target))
node = self.generic_visit(node)
self.namer.target = None
self.trivializing = False
return node
stmts.append(
ast.Return(
ast.Tuple(
[ast.Name(fp, ast.Load(), None) for fp in out_parameters],
ast.Load()
)
)
)
if has_return:
pr = PatchReturn(stmts[-1], has_break or has_cont)
pr.visit(fdef)
if has_break or has_cont:
if not has_return:
stmts[-1].value = ast.Tuple([ast.Num(LOOP_NONE),
stmts[-1].value],
ast.Load())
pbc = PatchBreakContinue(stmts[-1])
pbc.visit(fdef)
return fdef
def visit_If(self, node):
node = self.generic_visit(node)
node = self.add_mask(node, node.test)
nodes = [node]
if len(node.orelse) > 0:
test_inverse = gast.Call(
gast.Attribute(
node.test, gast.Name('eq', gast.Load(), None), gast.Load()),
[gast.Num(0)], [])
else_node = gast.If(any_active(test_inverse), node.orelse, [])
node.orelse = []
self.add_mask(else_node, test_inverse)
nodes.append(else_node)
node.test = any_active(node.test)
return nodes
def visit_Subscript(self, node):
if isinstance(node.value, (gast.Name, gast.Num)) and node.value.id == 'd':
if (not isinstance(node.slice, gast.Index) or
not isinstance(node.slice.value,
(gast.Subscript, gast.Name, gast.Str))):
# This happens when the gradient of a constant is taken
if self.replace_grad == Replace.TANGENT:
new_node = gast.Num(0)
else:
new_node = gast.Name(id='_', ctx=None, annotation=None)
self.remove(new_node)
elif (self.replace_grad in (Replace.FULL, Replace.TANGENT) or
isinstance(node.ctx, gast.Load)):
new_node = create.create_grad(node.slice.value, self.namer,
self.tangent)
elif isinstance(node.ctx, gast.Store):
new_node = create.create_temp_grad(node.slice.value, self.namer,
self.tangent)
else:
raise ValueError
new_node.ctx = node.ctx
if isinstance(new_node, gast.Tuple):
for elt in new_node.elts:
elt.ctx = node.ctx
if ty_obj.is_fixed_len:
# if the object is fixed-length list, coerce it.
unify(ty_obj, TyList(TyVar()))
ty_ret = list_attr_ty[node.attr](ty_obj)
self.nodetype[node] = ty_ret
elif isinstance(node, gast.Subscript):
# Subscript(expr value, slice slice, expr_context ctx)
ty_obj = self.infer_expr(node.value)
if isinstance(ty_obj, TySequence):
self.infer_slice(node.slice, TyInt())
if ty_obj.is_fixed_len:
if isinstance(node.slice, gast.Index) and \
isinstance(node.slice.value, gast.Num):
# TODO(momohatt): handle cases where index is
# more complex but still a constant
self.nodetype[node] = ty_obj.get_tys()[node.slice.value.n]
else:
ty_obj.coerce_to_variable_len()
if isinstance(node.slice, gast.Index):
self.nodetype[node] = ty_obj.get_ty()
elif isinstance(node.slice, gast.Slice):
self.nodetype[node] = ty_obj
else:
assert False
else:
if isinstance(node.slice, gast.Index):
self.nodetype[node] = ty_obj.get_ty()
elif isinstance(node.slice, gast.Slice):
ty_ret = ty_ret.deref()
self.nodetype[node.func] = TyArrow(ty_args, ty_ret)
self.nodetype[node] = ty_ret
else:
ty_fun = self.infer_expr(node.func)
unify(ty_fun, TyArrow(ty_args, ty_ret))
self.nodetype[node.func] = ty_fun.deref()
self.nodetype[node] = ty_ret.deref()
else:
ty_fun = self.infer_expr(node.func)
unify(ty_fun, TyArrow(ty_args, ty_ret))
self.nodetype[node] = ty_ret.deref()
elif isinstance(node, gast.Num):
# Num(object n)
if isinstance(node.n, int):
self.nodetype[node] = TyInt()
elif isinstance(node.n, float):
self.nodetype[node] = TyFloat()
elif isinstance(node, gast.Str):
# Str(string s)
self.nodetype[node] = TyString()
elif isinstance(node, gast.NameConstant):
# NameConstant(singleton value)
# value is either True, False or None
if isinstance(node.value, bool):
self.update = True
# name for various variables resulting from the static_if
n = len(self.new_functions)
status_n = "$status{}".format(n)
return_n = "$return{}".format(n)
cont_n = "$cont{}".format(n)
if has_return:
cont_ass = self.make_control_flow_handlers(cont_n, status_n,
expected_return,
has_cont, has_break)
cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None),
[ast.Eq()], [ast.Num(EARLY_RET)])
fast_return = [ast.Name(status_n, ast.Store(), None),
ast.Name(return_n, ast.Store(), None),
ast.Name(cont_n, ast.Store(), None)]
return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
actual_call),
ast.If(cmpr,
[ast.Return(ast.Name(return_n, ast.Load(), None))],
cont_ass)]
elif has_break or has_cont:
cont_ass = self.make_control_flow_handlers(cont_n, status_n,
expected_return,
has_cont, has_break)
fast_return = [ast.Name(status_n, ast.Store(), None),
def create_grad_list(self, node):
    """Build the tangent counterpart of a list/tuple literal node.

    Each element of ``node.elts`` is mapped as follows:
      * ``Name`` / ``Subscript`` -> ``create.create_grad(elt, self.namer,
        tangent=True)``, with its ``ctx`` set to ``node.ctx``;
      * ``Num`` -> ``gast.Num(0)``;
      * nested ``List`` / ``Tuple`` -> handled recursively.

    Args:
      node: a ``gast.List`` or ``gast.Tuple``.

    Returns:
      A new node of the same class as ``node`` holding the mapped elements,
      with the same ``ctx`` as ``node``.

    Raises:
      AssertionError: if ``node`` is not a list or tuple (only when asserts
        are enabled).
      ValueError: if an element has an unsupported node type.
    """
    assert isinstance(node, (gast.List, gast.Tuple)), 'Must be list or tuple'
    elts = []
    for elt in node.elts:
        if isinstance(elt, (gast.Name, gast.Subscript)):
            grad_node = create.create_grad(elt, self.namer, tangent=True)
            grad_node.ctx = node.ctx
            elts.append(grad_node)
        elif isinstance(elt, gast.Num):
            elts.append(gast.Num(0))
        elif isinstance(elt, (gast.List, gast.Tuple)):
            # Recurse on the nested node itself. Passing ``elt.elts`` (a plain
            # Python list) would always fail the isinstance assertion above.
            elts.append(self.create_grad_list(elt))
        else:
            raise ValueError('Cannot handle node type %s' % type(elt))
    return node.__class__(elts=elts, ctx=node.ctx)
def visit_Subscript(self, node):
if isinstance(node.value, (gast.Name, gast.Num)) and node.value.id == 'd':
if (not isinstance(node.slice, gast.Index) or
not isinstance(node.slice.value,
(gast.Subscript, gast.Name, gast.Str))):
# This happens when the gradient of a constant is taken
if self.replace_grad == Replace.TANGENT:
new_node = gast.Num(0)
else:
new_node = gast.Name(id='_', ctx=None, annotation=None)
self.remove(new_node)
elif (self.replace_grad in (Replace.FULL, Replace.TANGENT) or
isinstance(node.ctx, gast.Load)):
new_node = create.create_grad(node.slice.value, self.namer,
self.tangent)
elif isinstance(node.ctx, gast.Store):
new_node = create.create_temp_grad(node.slice.value, self.namer,
self.tangent)
# to just assert that they're set.
for a in node.args:
self._check_inner_children_have_context(a)
for k in node.keywords:
self._check_inner_children_have_context(k.value)
elif isinstance(node, gast.Dict):
# We may be able to override these to Load(), but for now it's simpler
# to just assert that they're set.
for e in node.keys:
self._check_inner_children_have_context(e)
for e in node.values:
self._check_inner_children_have_context(e)
elif isinstance(node, gast.Subscript):
self._set_inner_child_context(node.value, ctx)
self._check_inner_children_have_context(node.slice)
elif isinstance(node, (gast.Str, gast.Num)):
pass
else:
raise ValueError('unexpected node type "%s"' % node)