Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
logger=None):
if memo is None:
memo = set()
prefix = model.__class__.__name__
if model not in memo:
memo.add(model)
for name, module in model.named_children():
if module is None:
continue
child_name = get_node_name(module, name, prefix)
replaced_module = replace_fn(module)
if replaced_module is not None and module is not replaced_module:
if in_scope_list(child_name, ignored_scopes):
if logger is not None:
logger.info("Ignored wrapping modules in scope: {}".format(child_name))
continue
if target_scopes is None or in_scope_list(child_name, target_scopes):
if logger is not None:
logger.info("Wrapping module {} by {}".format(
child_name, get_node_name(replaced_module, name, prefix)))
if isinstance(model, nn.Sequential):
# pylint: disable=protected-access
model._modules[name] = replaced_module
else:
setattr(model, name, replaced_module)
replace_modules(module, replace_fn, ignored_scopes, target_scopes, memo, child_name, logger)
return model
def _register_weight_sparsifying_operations(self, device, ignored_scopes, target_scopes, logger):
    """Register a weight-sparsifying pre-forward operation on each selected module.

    Candidate modules are those returned by
    ``get_all_modules_by_type(self._model, NNCF_MODULES)``. A module is skipped when
    ``in_scope_list`` matches it against ``ignored_scopes``. The ignore check runs
    first, so ignoring takes precedence over targeting. Otherwise the module is
    processed when ``target_scopes`` is ``None`` or ``in_scope_list`` matches it
    against ``target_scopes``.

    For every processed module:

    - an operation is created with ``self.create_weight_sparsifying_operation``;
    - the operation is wrapped in ``UpdateWeight`` and moved with ``.to(device)``;
    - the wrapper is registered with ``module.register_pre_forward_operation``.

    Side effects:
        ``self.sparsified_module_info`` is reset to a new list on every call. It
        receives one ``SparseModuleInfo(module_name, module, operand)`` per
        processed module. ``operand`` is read back from the registered pre-op.

    Args:
        device: argument forwarded to ``.to()`` on the ``UpdateWeight`` wrapper.
        ignored_scopes: scopes to exclude, as understood by ``in_scope_list``.
        target_scopes: scopes to include, or ``None`` to include every
            non-ignored module.
        logger: object whose ``info`` method receives progress messages, or
            ``None`` to disable logging. ``None`` is handled the same way
            ``replace_modules`` handles it.
    """
    sparsified_modules = get_all_modules_by_type(self._model, NNCF_MODULES)
    self.sparsified_module_info = []
    for module_name, module in sparsified_modules.items():
        if in_scope_list(module_name, ignored_scopes):
            if logger is not None:
                logger.info("Ignored adding Weight Sparsifier in scope: {}".format(module_name))
            continue
        # A non-None target list restricts sparsification to matching scopes only.
        if target_scopes is not None and not in_scope_list(module_name, target_scopes):
            continue
        if logger is not None:
            logger.info("Adding Weight Sparsifier in scope: {}".format(module_name))
        operation = self.create_weight_sparsifying_operation(module)
        opid = module.register_pre_forward_operation(UpdateWeight(operation).to(device))
        # Record the operand of the pre-op that was just registered on this module.
        self.sparsified_module_info.append(
            SparseModuleInfo(module_name, module, module.get_pre_op(opid).operand))
def _register_weight_sparsifying_operations(self, device, ignored_scopes, target_scopes, logger):
    """Register a weight-sparsifying pre-forward operation on each selected module.

    Candidate modules are those returned by
    ``get_all_modules_by_type(self._model, NNCF_MODULES)``. A module is skipped when
    ``in_scope_list`` matches it against ``ignored_scopes``. The ignore check runs
    first, so ignoring takes precedence over targeting. Otherwise the module is
    processed when ``target_scopes`` is ``None`` or ``in_scope_list`` matches it
    against ``target_scopes``.

    For every processed module:

    - an operation is created with ``self.create_weight_sparsifying_operation``;
    - the operation is wrapped in ``UpdateWeight`` and moved with ``.to(device)``;
    - the wrapper is registered with ``module.register_pre_forward_operation``.

    Side effects:
        ``self.sparsified_module_info`` is reset to a new list on every call. It
        receives one ``SparseModuleInfo(module_name, module, operand)`` per
        processed module. ``operand`` is read back from the registered pre-op.

    Args:
        device: argument forwarded to ``.to()`` on the ``UpdateWeight`` wrapper.
        ignored_scopes: scopes to exclude, as understood by ``in_scope_list``.
        target_scopes: scopes to include, or ``None`` to include every
            non-ignored module.
        logger: object whose ``info`` method receives progress messages, or
            ``None`` to disable logging. ``None`` is handled the same way
            ``replace_modules`` handles it.
    """
    sparsified_modules = get_all_modules_by_type(self._model, NNCF_MODULES)
    self.sparsified_module_info = []
    for module_name, module in sparsified_modules.items():
        if in_scope_list(module_name, ignored_scopes):
            if logger is not None:
                logger.info("Ignored adding Weight Sparsifier in scope: {}".format(module_name))
            continue
        # A non-None target list restricts sparsification to matching scopes only.
        if target_scopes is not None and not in_scope_list(module_name, target_scopes):
            continue
        if logger is not None:
            logger.info("Adding Weight Sparsifier in scope: {}".format(module_name))
        operation = self.create_weight_sparsifying_operation(module)
        opid = module.register_pre_forward_operation(UpdateWeight(operation).to(device))
        # Record the operand of the pre-op that was just registered on this module.
        self.sparsified_module_info.append(
            SparseModuleInfo(module_name, module, module.get_pre_op(opid).operand))
if model not in memo:
memo.add(model)
for name, module in model.named_children():
if module is None:
continue
child_name = get_node_name(module, name, prefix)
replaced_module = replace_fn(module)
if replaced_module is not None and module is not replaced_module:
if in_scope_list(child_name, ignored_scopes):
if logger is not None:
logger.info("Ignored wrapping modules in scope: {}".format(child_name))
continue
if target_scopes is None or in_scope_list(child_name, target_scopes):
if logger is not None:
logger.info("Wrapping module {} by {}".format(
child_name, get_node_name(replaced_module, name, prefix)))
if isinstance(model, nn.Sequential):
# pylint: disable=protected-access
model._modules[name] = replaced_module
else:
setattr(model, name, replaced_module)
replace_modules(module, replace_fn, ignored_scopes, target_scopes, memo, child_name, logger)
return model