Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _make_code_from_fdef_node(self, fdef):
    """Return formatted source code for ``fdef`` with its type hints removed.

    ``fdef`` is passed through ``TypeHintRemover().visit``.
    The node it returns is unparsed with ``extast.unparse``, and the resulting
    text is normalized by ``format_str``. ``self`` is not used.

    NOTE(review): if ``TypeHintRemover`` transforms nodes in place, ``fdef``
    itself is modified by this call — confirm before reusing the node.
    """
    remover = TypeHintRemover()
    node_without_hints = remover.visit(fdef)
    raw_source = extast.unparse(node_without_hints)
    return format_str(raw_source)
return a + fooo()
def bar():
return foo()
"""
# Demo: parse the ``code`` string literal defined just above (only its tail
# is visible in this excerpt), take the fourth top-level statement of the
# parsed module and print the source of every node that CaptureX collects
# in ``capturex.external``.
module = extast.parse(code)
# Index 3 = fourth top-level statement of ``code``; presumably the last
# function definition of the snippet — confirm against the full string.
function = module.body[3]
capturex = CaptureX(module, function)
capturex.visit(function)
for node in capturex.external:
    # Debugging alternative: dump the raw AST instead of unparsed source.
    # print(astunparse.dump(node))
    print(extast.unparse(node).strip())
def strip_typehints(source):
    """Return ``source`` (a string of Python code) without its type hints.

    The code is first normalized with ``format_str`` and parsed with the
    standard ``ast`` module. The tree is then transformed by
    ``TypeHintRemover``, and the result is converted back to text with
    ``extast.unparse``. The returned text is not re-formatted.
    """
    normalized = format_str(source)
    tree = ast.parse(normalized)
    # Per the original author's intent, TypeHintRemover drops variable
    # annotations, return annotations and imports from ``typing``
    # (its implementation is not visible here).
    tree_without_hints = TypeHintRemover().visit(tree)
    return extast.unparse(tree_without_hints)
def make_code_from_fdef_node(fdef):
    """Convert a function-definition node to formatted, hint-free source code.

    Pipeline: ``TypeHintRemover().visit`` -> ``extast.unparse`` -> ``format_str``.

    NOTE(review): if ``TypeHintRemover`` works in place, the caller's
    ``fdef`` node is modified as well.
    """
    stripped_node = TypeHintRemover().visit(fdef)
    source_text = extast.unparse(stripped_node)
    formatted = format_str(source_text)
    return formatted
def analysis_jit(code, pathfile, backend_name):
    """Gather the information needed by ``@jit`` through an AST analysis.

    NOTE(review): only the beginning of this function is visible in this
    excerpt; what it finally returns is not shown here.

    Parameters
    ----------
    code :
        Source code of the module, parsed with ``extast.parse``.
    pathfile :
        Forwarded to ``get_decorated_dicts``.
    backend_name :
        Selects the per-backend entries of the dictionaries returned by
        ``get_decorated_dicts``.
    """
    # Local alias to keep the progress messages short.
    debug = logger.debug
    debug("extast.parse")
    module = extast.parse(code)
    debug("compute ancestors and chains")
    # duc / udc: presumably def-use and use-def chains (inferred from the
    # names) — confirm against compute_ancestors_chains.
    ancestors, duc, udc = compute_ancestors_chains(module)
    # Presumably collects the definitions decorated with ``@jit``, grouped by
    # kind and then by backend name.
    jitted_dicts = get_decorated_dicts(
        module, ancestors, duc, pathfile, backend_name, decorator="jit"
    )
    # Keep only the entries of the requested backend for each kind.
    jitted_dicts = dict(
        functions=jitted_dicts["functions"][backend_name],
        functions_ext=jitted_dicts["functions_ext"][backend_name],
        methods=jitted_dicts["methods"][backend_name],
        classes=jitted_dicts["classes"][backend_name],
    )
    debug("compute code dependance")
def make_backend_source(self, info_analysis, func, path_backend):
    """Build the backend source for ``func`` and tell whether it must be written.

    Parameters
    ----------
    info_analysis : dict
        Analysis results. The keys ``"jitted_dicts"``, ``"codes_dependance"``
        and ``"special"`` are read.
    func :
        The jitted function. Its ``__name__`` is always used; its source is
        used only when the name is not listed in ``info_analysis["special"]``.
    path_backend :
        Target file. ``exists()`` is called on it, and it is read when it
        exists on MPI rank 0.

    Returns
    -------
    tuple
        ``(src, has_to_write)``. ``has_to_write`` is ``False`` only when
        ``path_backend`` exists, ``mpi.rank == 0`` and the file already holds
        exactly ``src``.
    """
    name = func.__name__
    jitted = info_analysis["jitted_dicts"]
    source = info_analysis["codes_dependance"][name]

    if name not in info_analysis["special"]:
        # TODO find a prettier solution to remove decorator for cython
        # than doing two times a regex
        undecorated = get_source_without_decorator(func)
        source += re.sub(r"@.*?\sdef\s", "def ", undecorated)
    elif name in jitted["functions"]:
        source += extast.unparse(jitted["functions"][name])
    elif name in jitted["methods"]:
        source += extast.unparse(jitted["methods"][name])

    needs_write = True
    # exists() is checked first so mpi is only consulted for existing files.
    if path_backend.exists() and mpi.rank == 0:
        with open(path_backend) as handle:
            needs_write = handle.read() != source
    return source, needs_write
def find_decorated_function(module, function_name: str, pathfile: str = None):
    """Find the definition of ``function_name`` for ``module``.

    The top-level statements of ``module`` are scanned in order:

    - a function definition named ``function_name`` is returned directly;
    - a ``from ... import ...`` statement importing ``function_name`` causes
      the imported module to be located with ``find_path``, read, and parsed
      with ``extast.parse``. The definition is then searched for
      (recursively) in that module.

    Parameters
    ----------
    module :
        Parsed module whose ``body`` is scanned.
    function_name : str
        Name of the function to look for.
    pathfile : str, optional
        Forwarded to ``find_path`` to locate imported modules.

    Returns
    -------
    tuple
        ``(node, ext_module)``. ``ext_module`` is ``False`` when the function
        is defined in ``module`` itself; otherwise it is the parsed external
        module that contains the definition.

    Raises
    ------
    RuntimeError
        If ``function_name`` is neither defined nor imported in ``module``.
    """
    import tokenize

    for node in module.body:
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            # Defined locally: no external module involved.
            return node, False
        if not isinstance(node, ast.ImportFrom):
            continue
        # look for function_name in the names imported by this statement
        for alias in node.names:
            if alias.name != function_name:
                continue
            # find and read the imported module file
            _, path = find_path(node, pathfile)
            # tokenize.open honours PEP 263 coding cookies and UTF-8 BOMs
            # (defaulting to UTF-8), whereas a plain open() would use the
            # locale encoding and can mis-decode Python source files.
            with tokenize.open(path) as file:
                ext_module = extast.parse(file.read())
            # find the definition of function_name in the imported module
            # NOTE(review): pathfile is not forwarded here, so imports inside
            # ext_module are resolved with pathfile=None — confirm intended.
            definition, _ = find_decorated_function(ext_module, function_name)
            return definition, ext_module

    message = f"Cannot find the definition of function {function_name!r}"
    if pathfile is not None:
        message += f" (pathfile={pathfile!r})"
    raise RuntimeError(message)
def _fill_ast_annotations_class(class_def, ast_annotations):
dict_node = ast_annotations.value
for node in class_def.body:
if not isinstance(node, ast.AnnAssign):
continue
if node.annotation is not None:
name = node.target.id
dict_node.keys.append(extast.Constant(value=name))
dict_node.values.append(node.annotation)
def make_code_external(self):
    """Return the source of every node in ``self.external``, one per line.

    Each node is unparsed with ``extast.unparse`` and stripped of surrounding
    whitespace. The pieces are joined with newlines, giving an empty string
    when ``self.external`` is empty.
    """
    return "\n".join(
        extast.unparse(external_node).strip()
        for external_node in self.external
    )
# NOTE(review): the enclosing function header is not visible in this excerpt;
# variables, module, ancestors, duc and udc come from it.
# For each name in ``variables``, fetch a definition node via
# find_last_def_node (presumably its last definition in ``module``), then
# order these nodes by line number.
nodes_def_vars = [
    find_last_def_node(variable, module) for variable in variables
]
nodes_def_vars.sort(key=lambda x: x.lineno)
# Presumably collects what these definitions reference from outside
# themselves, using the ancestors and def-use / use-def chains.
capturex = CaptureX(
    list(nodes_def_vars),
    module,
    ancestors,
    defuse_chains=duc,
    usedef_chains=udc,
)
# Emit the captured external nodes first, then the definitions themselves;
# a definition is skipped when its unparsed text was already emitted.
lines_ext = []
for node in capturex.external:
    lines_ext.append(extast.unparse(node).strip())
for node in nodes_def_vars:
    line = extast.unparse(node).strip()
    # NOTE(review): list membership is O(len(lines_ext)) per check; a
    # companion set would keep this linear if these lists grow large.
    if line not in lines_ext:
        lines_ext.append(line)
return "\n".join(lines_ext)