# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from collections import namedtuple
from texttable import Texttable
from nncf.compression_method_api import CompressionAlgorithm
from nncf.dynamic_graph.transform_graph import replace_modules_by_nncf_modules
from nncf.layers import NNCF_MODULES
from nncf.layer_utils import COMPRESSION_MODULES
from nncf.operations import UpdateWeight
from nncf.utils import get_all_modules_by_type, in_scope_list
# Immutable record with three positional fields: module_name, module, operand.
# presumably one record per sparsified layer (its name, the module object and the
# sparsifying operation attached to it) — TODO confirm against
# _register_weight_sparsifying_operations, which populates sparsified_module_info.
SparseModuleInfo = namedtuple('SparseModuleInfo', 'module_name module operand')
class BaseSparsityAlgo(CompressionAlgorithm):
def freeze(self):
    """Abstract hook that concrete sparsity algorithms must override.

    This base implementation does nothing except signal that it is not
    implemented. What "freezing" means is up to the subclass (presumably it
    stops further updates to sparsity masks; confirm in subclasses).

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()
def set_sparsity_level(self, sparsity_level):
    """Abstract hook for changing the target sparsity level; subclasses override.

    Args:
        sparsity_level: desired sparsity level. Presumably a fraction of
            weights to zero out, in [0, 1]; TODO confirm in subclasses.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()
def _replace_sparsifying_modules_by_nncf_modules(self, device, ignored_scopes, target_scopes, logger):
    """Rebuild ``self._model`` via ``replace_modules_by_nncf_modules`` and move it to ``device``.

    Args:
        device: target passed unchanged to ``self._model.to``.
        ignored_scopes: forwarded as ``ignored_scopes`` to the replacement helper.
        target_scopes: forwarded as ``target_scopes`` to the replacement helper.
        logger: forwarded as ``logger`` to the replacement helper.

    The helper's return value becomes the new ``self._model``. Presumably it
    is the model with eligible layers swapped for NNCF wrappers; confirm
    in nncf.dynamic_graph.transform_graph.
    """
    scope_filters = dict(ignored_scopes=ignored_scopes, target_scopes=target_scopes)
    rebuilt_model = replace_modules_by_nncf_modules(self._model, logger=logger, **scope_filters)
    self._model = rebuilt_model
    # Newly created wrapper modules may not live on the requested device yet.
    self._model.to(device)
def _register_weight_sparsifying_operations(self, device, ignored_scopes, target_scopes, logger):
sparsified_modules = get_all_modules_by_type(self._model, NNCF_MODULES)
self.sparsified_module_info = []