Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_multiple_imports(self):
    """A module that contributes several names renders as a parenthesized import, one name per line."""
    import_map = ImportMap()
    import_map['a.module'] = {'AClass', 'AnotherClass', 'AThirdClass'}
    # Names are listed in sorted order inside the parentheses.
    wanted_lines = (
        'from a.module import (',
        ' AClass,',
        ' AThirdClass,',
        ' AnotherClass,',
        ')',
    )
    assert ImportBlockStub(import_map).render() == "\n".join(wanted_lines)
def test_single_import(self):
    """Each module that contributes exactly one name renders as its own single-line import."""
    import_map = ImportMap()
    for module_name, class_name in (('a.module', 'AClass'), ('another.module', 'AnotherClass')):
        import_map[module_name] = {class_name}
    rendered = ImportBlockStub(import_map).render()
    assert rendered == 'from a.module import AClass' + "\n" + 'from another.module import AnotherClass'
def test_merge(self):
    """merge() unions names for shared modules and adds modules only present in the other map."""
    a = ImportMap()
    a['module.a'] = {'ClassA', 'ClassB'}
    a['module.b'] = {'ClassE', 'ClassF'}
    b = ImportMap()
    b['module.a'] = {'ClassB', 'ClassC'}
    b['module.c'] = {'ClassX', 'ClassY'}
    # Spell the expected result out literally rather than computing it as
    # a[mod] | b[mod]. Looking up a missing module in an ImportMap creates an
    # empty entry, so that computation would insert 'module.c' into `a` (and
    # 'module.b' into `b`) before merge() runs. It would then hide a merge()
    # that fails to add modules absent from the receiving map.
    expected = ImportMap()
    expected['module.a'] = {'ClassA', 'ClassB', 'ClassC'}
    expected['module.b'] = {'ClassE', 'ClassF'}
    expected['module.c'] = {'ClassX', 'ClassY'}
    a.merge(b)
    assert a == expected
def get_imports_for_signature(sig: inspect.Signature) -> ImportMap:
    """Collect the imports (module, name) required by every annotation in *sig*.

    Each parameter annotation is resolved in declaration order, followed by the
    return annotation. Every result comes from get_imports_for_annotation and is
    merged into a single ImportMap, which is returned.
    """
    combined = ImportMap()
    annotations = [param.annotation for param in sig.parameters.values()]
    annotations.append(sig.return_annotation)
    for anno in annotations:
        combined.merge(get_imports_for_annotation(anno))
    return combined
def get_imports_for_annotation(anno: Any) -> ImportMap:
"""Return the imports (module, name) needed for the type in the annotation"""
imports = ImportMap()
if (
anno is inspect.Parameter.empty or
anno is inspect.Signature.empty or
not (isinstance(anno, type) or is_any(anno) or is_union(anno) or is_generic(anno)) or
anno.__module__ == 'builtins'
):
return imports
if is_any(anno):
imports['typing'].add('Any')
elif _is_optional(anno):
imports['typing'].add('Optional')
elem_type = _get_optional_elem(anno)
elem_imports = get_imports_for_annotation(elem_type)
imports.merge(elem_imports)
elif is_generic(anno):
if is_union(anno):
def __init__(self, imports: ImportMap = None) -> None:
    """Keep the import map to render.

    A falsy *imports* (None, or an empty map) is replaced by a fresh ImportMap,
    so ``self.imports`` is never None.
    """
    # NOTE(review): the annotation relies on implicit Optional; consider
    # Optional[ImportMap] once it is confirmed that Optional is imported here.
    if not imports:
        imports = ImportMap()
    self.imports = imports