Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def optimize_graph(graph, signature_def, output_graph,
                   tf_version, quantization_dtype=None, skip_op_check=False,
                   strip_debug_ops=False):
    """Takes a Python Graph object and optimizes the graph.

    Args:
      graph: The frozen graph to optimize.
      signature_def: the SignatureDef of the inference graph.
      output_graph: The location of the output graph.
      tf_version: Tensorflow version of the input graph.
      quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
      skip_op_check: Bool whether to skip the op check.
      strip_debug_ops: Bool whether to strip debug ops.

    Raises:
      ValueError: If `validate` reports unsupported ops in the graph before
        any optimization is attempted.
    """
    fuse_prelu.register_prelu_func(graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    # Only the output entries are needed (not their signature keys); the
    # ':<index>' suffix is dropped so the name can be looked up as an op.
    for output in signature_def.outputs.values():
        name = output.name.split(':')[0]
        graph.add_to_collection('train_op', graph.get_operation_by_name(name))

    graph_def = graph.as_graph_def()

    # Fail early, listing every offending op, rather than part-way through
    # optimization.
    unsupported = validate(graph_def.node, skip_op_check, strip_debug_ops)
    if unsupported:
        raise ValueError('Unsupported Ops in the model before optimization\n' +
                         ', '.join(unsupported))

    # first pass of grappler optimization, this is needed for batch norm
    # folding.
    config = config_pb2.ConfigProto()
    # NOTE(review): the excerpt ends here; the code that consumes `config`
    # (and output_graph / tf_version / quantization_dtype) is not visible.
signature_def: the SignatureDef of the inference graph.
quantization_dtype: An optional numpy dtype to quantize weights to for
compression. Only np.uint8 and np.uint16 are supported.
"""
# Collect every Const node of the graph; their values become the weights.
constants = [node for node in graph_def.node if node.op == 'Const']

# removed the conditional inputs for constants
# The original inputs are remembered per node name so they can be put back
# once the values have been read.
const_inputs = {}
for const in constants:
    const_inputs[const.name] = const.input[:]
    del const.input[:]

const_manifest = []
try:
    print('Writing weight file ' + output_graph + '...')

    # Import into a fresh graph that has the fused-op functions registered.
    graph = tf.Graph()
    fuse_prelu.register_prelu_func(graph)
    fuse_depthwise_conv2d.register_fused_depthwise_conv2d_func(graph)

    extracted_graph = fuse_depthwise_conv2d.extract_op_attributes(graph_def)

    with tf.compat.v1.Session(graph=graph) as sess:
        tf.import_graph_def(extracted_graph, name='')
        for const in constants:
            tensor = graph.get_tensor_by_name(const.name + ':0')
            value = tensor.eval(session=sess)
            # Normalise non-ndarray results so every manifest entry holds
            # an ndarray.
            if not isinstance(value, np.ndarray):
                value = np.array(value)

            const_manifest.append({'name': const.name, 'data': value})
finally:
    # Restore the conditional inputs
    # Done in `finally` so the caller's graph_def is never left with
    # stripped Const inputs when registration, import or evaluation raises.
    for const in constants:
        const.input[:] = const_inputs[const.name]