Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# If you would like to license the source code under different terms,
# please contact James Kermode, james.kermode@gmail.com
import logging
import os
import warnings
import numpy as np
from f90wrap import codegen as cg
from f90wrap import fortran as ft
from f90wrap.six import string_types # Python 2/3 compatibility library
from f90wrap.transform import ArrayDimensionConverter
class F90WrapperGenerator(ft.FortranVisitor, cg.CodeGenerator):
"""
Creates the Fortran90 code necessary to wrap a given Fortran parse tree
suitable for input to `f2py`.
Each node of the tree (Module, Subroutine etc.) is wrapped according to the
rules in this class when visited (using `F90WrapperGenerator.visit()`).
Each module's wrapper is written to a separate file, with top-level
procedures written to another separate file. Derived-types and arrays (both
of normal types and derive-types) are specially treated. For each, a number
of subroutines allowing the getting/setting of items, and retrieval of array
length are written. Furthermore, derived-types are treated as opaque
references to enable wrapping with `f2py`.
Parameters
----------
def generic_visit(self, node):
    """
    Print `node` indented by the current depth, then visit its children.

    ``self.depth`` is raised by one while the children are visited, so each
    nesting level is printed one extra space to the right.
    """
    prefix = ' ' * self.depth
    print(prefix + str(node))
    self.depth += 1
    FortranVisitor.generic_visit(self, node)
    self.depth -= 1
new_attribs = []
for attrib in node.attributes:
if attrib.startswith('dimension('):
new_attribs.append(attrib.replace(old_name, new_name))
else:
new_attribs.append(attrib)
node.attributes = new_attribs
return self.generic_visit(node)
visit_Procedure = visit_Argument
visit_Element = visit_Argument
visit_Module = visit_Argument
visit_Type = visit_Argument
class RenameArgumentsPython(ft.FortranVisitor):
def __init__(self, types):
    """
    Remember `types`, the collection of type names that visit_Procedure
    checks (with ``in``) against a procedure's first argument.
    """
    self.types = types

def visit_Procedure(self, node):
    """
    Choose Python-side names for a procedure.

    For procedures carrying a ``method_name`` attribute, the Python name
    ``self`` goes to the first return value of a constructor. Otherwise it
    goes to the first argument when that argument's type is in
    ``self.types``. Procedures without ``method_name`` that are marked as
    callbacks are passed to ``visit_Argument``. In every case the node's
    children are then visited via ``generic_visit``.
    """
    if not hasattr(node, 'method_name'):
        is_callback = hasattr(node, 'attributes') and 'callback' in node.attributes
        if is_callback:
            self.visit_Argument(node)
        return self.generic_visit(node)

    if 'constructor' in node.attributes:
        # the constructed object becomes Python's `self`
        node.ret_val[0].py_name = 'self'
    elif len(node.arguments) > 0 and node.arguments[0].type in self.types:
        # a first argument of a wrapped type becomes Python's `self`
        node.arguments[0].py_name = 'self'
    return self.generic_visit(node)
def visit_Argument(self, node):
if not hasattr(node, 'py_name'):
node.py_name = node.name
continue
elif not isinstance(value, Fortran):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, Fortran):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
class FortranTreeDumper(FortranVisitor):
    """
    Subclass of `FortranVisitor` which prints a textual representation
    of the Fortran parse tree.

    Every node that reaches ``generic_visit`` is printed on its own line,
    indented by one space per level of nesting.
    """

    def __init__(self):
        # current nesting level == number of leading spaces printed
        self.depth = 0

    def generic_visit(self, node):
        """Print `node` at the current depth, then visit its children one level deeper."""
        prefix = ' ' * self.depth
        print(prefix + str(node))
        self.depth += 1
        FortranVisitor.generic_visit(self, node)
        self.depth -= 1
def dump(node):
    """
    Print contents of Fortran parse tree starting at `node`.

    A fresh FortranTreeDumper is used for each call, so indentation always
    starts at depth zero.
    """
    dumper = FortranTreeDumper()
    dumper.visit(node)
class NormaliseTypes(ft.FortranVisitor):
    """
    Convert all type names to standard form and resolve kind names.

    `kind_map` is passed unchanged to ``ft.normalise_type`` for every
    declaration and argument that is visited.
    """

    def __init__(self, kind_map):
        self.kind_map = kind_map

    def visit_Declaration(self, node):
        # replace the declared type by its normalised form, then recurse
        normalised = ft.normalise_type(node.type, self.kind_map)
        node.type = normalised
        return self.generic_visit(node)

    # arguments carry a ``type`` just like declarations: share the logic
    visit_Argument = visit_Declaration
class SetInterfaceProcedureCallNames(ft.FortranVisitor):
    """
    Set call names of procedures within overloaded interfaces to the name
    of the interface.
    """

    def visit_Interface(self, node):
        """
        Set ``call_name`` of every procedure in `node.procedures` to the
        interface's name.

        Returns `node` without visiting its children.
        """
        for proc in node.procedures:
            # Pass the values as logging arguments so the message is only
            # formatted when an INFO record is actually emitted.
            logging.info('setting call_name of %s to %s', proc.name, node.name)
            proc.call_name = node.name
        return node
def transform_to_generic_wrapper(tree, types, callbacks, constructors,
destructors, short_names, init_lines,
only_subs, only_mods, argument_name_map,
move_methods, shorten_routine_names,
modules_for_type, remove_optional_arguments):
logging.debug('marking private symbol ' + node.name)
self.mod.private_symbols.append(node.name)
else:
# symbol should be marked as public if it's not already
if node.name not in self.mod.public_symbols:
logging.debug('marking public symbol ' + node.name)
self.mod.public_symbols.append(node.name)
else:
raise ValueError('bad default access %s for module %s' %
(self.mod.default_access, self.mod.name))
return node # no need to recurse further
class PrivateSymbolsRemover(FortranTransformer):
"""
Transform a tree by removing private symbols.
"""
def __init__(self):
    # module currently being visited; None outside of any module
    self.mod = None

def visit_Module(self, mod):
    """
    Visit the contents of `mod` with ``self.mod`` set to it.

    Returns the visited module. This file's transformer ``generic_visit``
    drops a child whose visit returns None, so the module must be returned
    to stay in the tree. ``self.mod`` is reset to None afterwards, even if
    visiting the children raises.
    """
    # keep track of the current module
    self.mod = mod
    try:
        mod = self.generic_visit(mod)
    finally:
        # never leave a stale current module behind
        self.mod = None
    return mod
def visit(self, node):
if self.mod is None:
return self.generic_visit(node)
def fix_argument_attributes(node):
    """
    Walk over all procedures in the tree starting at `node` and
    fix the argument attributes.

    Every argument that has no ``type`` attribute is marked as a callback:
    its ``type`` becomes ``'callback'``, its ``value`` becomes ``''`` and
    ``'callback'`` is appended to its ``attributes``.

    Returns `node`, which is modified in place.
    """
    for _mod, _proc, args in walk_procedures(node):
        for arg in args:
            if hasattr(arg, 'type'):
                continue  # already typed: nothing to fix
            arg.type = 'callback'
            arg.value = ''
            arg.attributes.append('callback')
    return node
class LowerCaseConverter(FortranTransformer):
"""
Subclass of FortranTransformer which converts program, module,
procedure, interface, type and declaration names and attributes to
lower case. Original names are preserved in the *orig_name*
attribute.
"""
def visit_Program(self, node):
    """
    Lower-case the program's name, keeping the original spelling in
    ``orig_name``, then visit its children.
    """
    original = node.name
    node.orig_name, node.name = original, original.lower()
    return self.generic_visit(node)
def visit_Module(self, node):
node.orig_name = node.name
node.name = node.name.lower()
node.default_access = node.default_access.lower()
node.default_access = 'public'
if 'private' in node.attributes:
node.default_access = 'private'
node = self.generic_visit(node)
self.type = None
return node
def visit_Element(self, node):
if self.type is not None:
self.update_access(node, self.mod, self.type.default_access, in_type=True)
else:
self.update_access(node, self.mod, self.mod.default_access)
return node
class PrivateSymbolsRemover(ft.FortranTransformer):
"""
Transform a tree by removing private symbols
"""
def __init__(self):
    """Start outside any module: ``self.mod`` stays None until visit_Module runs."""
    self.mod = None

def visit_Module(self, mod):
    """
    Visit the children of `mod` with ``self.mod`` set to it.

    ``self.mod`` is reset to None afterwards, and whatever
    ``generic_visit`` returned is handed back to the caller.
    """
    self.mod = mod  # keep track of the current module
    visited = self.generic_visit(mod)
    self.mod = None
    return visited
def visit_Procedure(self, node):
if self.mod is None:
node.orig_name = node.name
node.name = node.name.lower()
node.type = node.type.lower()
node.attributes = [a.lower() for a in node.attributes]
return self.generic_visit(node)
def strip_type(t):
    """
    Return the type name from a type declaration.

    All whitespace is removed and the result is lower-cased. A ``type(...)``
    or ``class(...)`` wrapper is stripped, leaving the text up to the first
    closing parenthesis. Other declarations (e.g. ``real(8)``) are returned
    with only that normalisation applied.

    Fortran keywords are case-insensitive, so the wrapper is recognised in
    any case: ``strip_type('TYPE(Foo)')`` returns ``'foo'``.

    Raises
    ------
    ValueError
        If a ``type(``/``class(`` wrapper has no closing parenthesis.
    """
    # normalise case *before* looking for the wrapper keyword
    t = ''.join(t.split()).lower()
    for prefix in ('type(', 'class('):
        if t.startswith(prefix):
            t = t[len(prefix):t.index(')')]
    return t
class AccessUpdater(FortranTransformer):
"""Visit module contents and update public_symbols and
private_symbols lists to be consistent with (i) default module
access; (ii) public and private statements at module level;
(iii) public and private attibutes."""
def __init__(self):
    # module currently being visited; None outside of any module
    self.mod = None

def visit_Module(self, mod):
    """
    Visit the contents of `mod` with ``self.mod`` set to it.

    Returns the visited module. This file's transformer ``generic_visit``
    drops a child whose visit returns None, so the module must be returned
    to stay in the tree. ``self.mod`` is reset to None afterwards, even if
    visiting the children raises.
    """
    # keep track of the current module
    self.mod = mod
    try:
        mod = self.generic_visit(mod)
    finally:
        # never leave a stale current module behind
        self.mod = None
    return mod
def visit(self, node):
if self.mod is None:
def remove_private_symbols(node):
    """
    Walk the tree starting at *node*, removing all private symbols.

    Two passes are made. AccessUpdater first brings each module's
    *public_symbols* and *private_symbols* in line with *default_access*
    and the individual `public` and `private` attributes. PrivateSymbolsRemover
    then prunes what is private. The result of the second pass is returned.
    """
    for transformer_cls in (AccessUpdater, PrivateSymbolsRemover):
        node = transformer_cls().visit(node)
    return node
class UnwrappablesRemover(ft.FortranTransformer):
def __init__(self, callbacks, types, constructors, destructors, remove_optional_arguments):
    """
    Store the pruning configuration.

    Every argument is kept unchanged as an attribute of the same name.
    """
    self.callbacks, self.types = callbacks, types
    self.constructors, self.destructors = constructors, destructors
    self.remove_optional_arguments = remove_optional_arguments

def visit_Interface(self, node):
    """
    Skip interfaces that overload operators; visit all others.

    Operator-overloading routines are not wrapped. For an interface whose
    name starts with ``operator(``, None is returned instead of recursing.
    """
    if not node.name.startswith('operator('):
        return self.generic_visit(node)
    return None
def visit_Procedure(self, node):
# special case: keep all constructors and destructors, although