Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
if not in_node.doc_string:
return None
from model_editor_internal import parse_custom_attributes
custom_qcfg = parse_custom_attributes(in_node)
if custom_qcfg:
assert custom_qcfg['IntermediateBit'] == 32
assert custom_qcfg['PerRowQuantization']
assert custom_qcfg['QuantizeBitOfVector'] == custom_qcfg['QuantizeBitOfMatrix']
qbits = custom_qcfg['QuantizeBitOfVector']
assert ("Asymmetric" in custom_qcfg['VectorQuantizationType']) == ("Asymmetric" in custom_qcfg['MatrixQuantizationType'])
symmetric = 0 if "Asymmetric" in custom_qcfg['VectorQuantizationType'] else 1
x_signed = 0 if "Unsigned" in custom_qcfg['VectorQuantizationType'] else 1
w_signed = 0 if "Unsigned" in custom_qcfg['MatrixQuantizationType'] else 1
x_reserved_bits = custom_qcfg['ReservedBitOfVector']
w_reserved_bits = custom_qcfg['ReservedBitOfMatrix']
return {'W' : dict(QuantizeConfig(signed=w_signed, reserved_bits=w_reserved_bits, type_bits=qbits)),
'X' : dict(QuantizeConfig(signed=x_signed, reserved_bits=x_reserved_bits, type_bits=qbits)),
'Symmetric' : symmetric}
return None
# Visible portion of convert_matmul_model: picks a preset quantization config by
# name, loads the ONNX model at `input_model`, optionally reads a JSON
# quantization config, and copies the loaded model into a new ModelProto whose
# ir_version is set to 5.
# NOTE(review): this excerpt has lost its indentation, and `output_model`,
# `only_for_scan` and `share_input_quantization` are not used in the lines shown,
# so the function body presumably continues beyond this excerpt — confirm against
# the full source before relying on this summary.
def convert_matmul_model(input_model, output_model, only_for_scan=False, share_input_quantization=False, preset_str='asymm8_param0_input1', qcfg_json=None, export_qcfg_json=None):
# Built-in presets keyed by name. Each preset holds a 'W' config, an 'X' config
# and a 'Symmetric' flag (0 or 1). Going by the preset names, 'W' is presumably
# the parameter/weight side and 'X' the input side — TODO confirm.
preset_qcfgs = {'asymm8_param0_input1' : {'W' : dict(QuantizeConfig(signed=1, reserved_bits=0, type_bits=8)),
'X' : dict(QuantizeConfig(signed=0, reserved_bits=1, type_bits=8)),
'Symmetric' : 0},
'symm16_param3_input3' : {'W' : dict(QuantizeConfig(signed=1, reserved_bits=3, type_bits=16)),
'X' : dict(QuantizeConfig(signed=1, reserved_bits=3, type_bits=16)),
'Symmetric' : 1}}
# Plain dict indexing: a preset_str that is not one of the keys above raises KeyError.
default_qcfg = preset_qcfgs[preset_str]
in_mp = onnx.load(input_model)
qcfg_dict = {}
# The JSON file is read only when qcfg_json is truthy and export_qcfg_json is
# falsy; otherwise qcfg_dict stays empty.
if qcfg_json and not export_qcfg_json:
with open(qcfg_json, 'r') as f:
qcfg_dict = json.load(f)
# Copy the loaded model into a fresh ModelProto; the ir_version change below is
# applied to the copy (out_mp), not to in_mp.
out_mp = onnx.ModelProto()
out_mp.CopyFrom(in_mp)
out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input
def from_dict(qcfg_dict):
    """Build a QuantizeConfig from its dictionary form.

    Reads the 'QuantizationType', 'ReservedBit' and 'QuantizeBit' entries of
    ``qcfg_dict`` (in that order) and passes them positionally to
    ``QuantizeConfig``. The first argument is 1 when 'QuantizationType'
    equals 'Signed' and 0 for any other value.

    Raises:
        KeyError: if any of the three keys is missing from ``qcfg_dict``.
    """
    if qcfg_dict['QuantizationType'] == 'Signed':
        signed_flag = 1
    else:
        signed_flag = 0
    reserved_bit = qcfg_dict['ReservedBit']
    quantize_bit = qcfg_dict['QuantizeBit']
    return QuantizeConfig(signed_flag, reserved_bit, quantize_bit)