Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
if module is None:
continue
child_name = get_node_name(module, name, prefix)
replaced_module = replace_fn(module)
if replaced_module is not None and module is not replaced_module:
if in_scope_list(child_name, ignored_scopes):
if logger is not None:
logger.info("Ignored wrapping modules in scope: {}".format(child_name))
continue
if target_scopes is None or in_scope_list(child_name, target_scopes):
if logger is not None:
logger.info("Wrapping module {} by {}".format(
child_name, get_node_name(replaced_module, name, prefix)))
if isinstance(model, nn.Sequential):
# pylint: disable=protected-access
model._modules[name] = replaced_module
else:
setattr(model, name, replaced_module)
replace_modules(module, replace_fn, ignored_scopes, target_scopes, memo, child_name, logger)
return model
def replace_modules(model: nn.Module, replace_fn, ignored_scopes=None, target_scopes=None, memo=None, prefix=None,
logger=None):
if memo is None:
memo = set()
prefix = model.__class__.__name__
if model not in memo:
memo.add(model)
for name, module in model.named_children():
if module is None:
continue
child_name = get_node_name(module, name, prefix)
replaced_module = replace_fn(module)
if replaced_module is not None and module is not replaced_module:
if in_scope_list(child_name, ignored_scopes):
if logger is not None:
logger.info("Ignored wrapping modules in scope: {}".format(child_name))
continue
if target_scopes is None or in_scope_list(child_name, target_scopes):
if logger is not None:
logger.info("Wrapping module {} by {}".format(
child_name, get_node_name(replaced_module, name, prefix)))
if isinstance(model, nn.Sequential):
# pylint: disable=protected-access
model._modules[name] = replaced_module
else: