Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
# NOTE(review): fragment — the two lines below close a call that begins
# outside this view; presumably that call produces the `graph_def` imported
# next. Confirm against the full function.
output_name_without_port
)
# Load `graph_def` into a fresh default graph; name='' imports the nodes
# without a name-scope prefix.
tf.reset_default_graph()
tf.import_graph_def(graph_def, name='')
# optimize graph
# The 4th argument (True) occupies the slot that the other tf_optimize callers
# in this file name `fold_const` / `constant_fold`.
# NOTE(review): no visible statement between here and the import above binds
# `sess`, so it refers to a session created earlier (outside this view), which
# may not hold the graph just imported. Indentation is lost in this copy —
# confirm whether this call sits inside a `with` block, or whether `graph_def`
# itself should be passed instead of `sess.graph_def`.
graph_def = tf_optimize(input_names_with_port, output_names_with_port,
sess.graph_def, True)
with tf.Session() as sess:
# In debug mode, dump the optimized GraphDef as
# <test_data_directory>/<test method name>_after_tf_optimize.pb, creating the
# directory first when it does not exist.
if self.config.is_debug_mode:
if not os.path.exists(self.test_data_directory):
os.makedirs(self.test_data_directory)
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
utils.save_protobuf(model_path, graph_def)
self.logger.debug("created file %s", model_path)
# Re-import the optimized GraphDef into a fresh default graph and run shape
# inference on it.
tf.reset_default_graph()
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
inferred_graph = infer_shape_for_graph(sess.graph)
# compare each operation
# `origin_graph` is bound outside this view (presumably the graph before
# optimization — confirm). An op whose name lookup in the inferred graph
# raises KeyError (presumably because the op no longer exists there) is
# skipped; every matched pair is checked by _compare_shape_for_op.
for op in origin_graph.get_operations():
inferred_op = None
try:
inferred_op = inferred_graph.get_operation_by_name(op.name)
except KeyError:
continue
self._compare_shape_for_op(op, inferred_op)
def create_onnx_file(name, model_proto, inputs, outdir):
    """Write ``model_proto`` to ``<outdir>/<name>.onnx`` and log the location.

    Args:
        name: base file name, without the ``.onnx`` extension.
        model_proto: object handed to ``utils.save_protobuf`` together with
            the target path.
        inputs: accepted for call compatibility; not used by this function.
        outdir: destination directory; it and any missing parents are created.
    """
    # exist_ok=True keeps repeated calls into the same directory harmless.
    os.makedirs(outdir, exist_ok=True)
    onnx_file = os.path.join(outdir, name + ".onnx")
    utils.save_protobuf(onnx_file, model_proto)
    logger.info("Created %s", onnx_file)
def test_dropout(self):
    """Convert a small graph containing ``tf.nn.dropout`` and check the ONNX result.

    The TF graph is ``input1 + input2 -> dropout(prob) -> output1 -> output2 ->
    output``. In the expected graphviz rendering below there is no dropout
    node: ``Add`` feeds ``output1`` directly and the ``prob`` placeholder is
    left without consumers.
    """
    with tf.Session() as sess:
        x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
        x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
        prob = tf.placeholder(tf.float32, (), name="prob")
        x_ = tf.add(x1, x2)
        x_ = tf.nn.dropout(x_, prob)
        x_ = tf.identity(x_, name="output1")
        x_ = tf.identity(x_, name="output2")
        _ = tf.identity(x_, name="output")
        # feed output_names in order to remove unused nodes.
        g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=["output:0"])
    # No model file is written here: the assertion only needs the in-memory
    # graph, and saving to the current working directory leaves artifacts.
    actual = onnx_to_graphviz(g)
    expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
               'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] output1 [op_type=Identity] ' \
               'output2 [op_type=Identity] output [op_type=Identity] output_graph_outputs_Identity__3 ' \
               '[op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output1 output1:0 -> output2 ' \
               'output2:0 -> output output_raw_output___2:0 -> output_graph_outputs_Identity__3 }'
    self.assertEqual(expected, actual)
# NOTE(review): fragment — starts and ends outside this view; `model_path`,
# `fold_const`, `name` and TEMP_DIR are bound elsewhere. Visible steps: load
# the model according to self.model_type, drop unused input names, run
# tf_optimize, import the result, then start building the input data.
logger.info("Load model from %s", model_path)
# self.input_names maps input tensor names to configured values (it is
# indexed by those names in the loop below).
input_names = list(self.input_names.keys())
outputs = self.output_names
# Pick the loader by model type; anything other than "checkpoint" or
# "saved_model" goes through from_graphdef. Each loader returns
# (graph_def, input_names, outputs).
if self.model_type in ["checkpoint"]:
graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
elif self.model_type in ["saved_model"]:
graph_def, input_names, outputs = tf_loader.from_saved_model(model_path, input_names, outputs)
else:
graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)
# remove unused input names
# NOTE(review): the set intersection discards ordering, so input_names comes
# back in arbitrary order.
input_names = list(set(input_names).intersection(self.input_names.keys()))
# NOTE(review): passes self.output_names rather than the `outputs` returned by
# the loader above — confirm that is intended.
graph_def = tf2onnx.tf_loader.tf_optimize(input_names, self.output_names, graph_def, fold_const)
if utils.is_debug_mode():
utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)
# Neither dict is populated by any visible line; presumably filled further down.
inputs = {}
shape_override = {}
# NOTE(review): tf.import_graph_def returns None when return_elements is not
# given, so `g` is None here and graph_def is imported into the current default
# graph. Whether tf_session(graph=None) then uses that default graph is not
# visible here — confirm.
g = tf.import_graph_def(graph_def, name='')
# with tf_session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:
with tf_session(graph=g) as sess:
# create the input data
# For each kept input, look up its tensor to get the expected dtype name. A
# configured value that is a text string starting with "np." is evaluated as
# a Python expression.
for k in input_names:
v = self.input_names[k]
t = sess.graph.get_tensor_by_name(k)
expected_dtype = tf.as_dtype(t.dtype).name
# SECURITY NOTE(review): eval() executes arbitrary code taken from the test
# configuration; this is only acceptable if that configuration is trusted.
if isinstance(v, six.text_type) and v.startswith("np."):
np_value = eval(v) # pylint: disable=eval-used
# NOTE(review): compares a dtype-name string with np_value.dtype. This
# relies on numpy dtype/string equality and assumes eval produced an
# ndarray — confirm.
if expected_dtype != np_value.dtype:
logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
np_value.dtype)
# NOTE(review): fragment — `graph_def`, `feed_dict`, `input_names_with_port`,
# `output_names_with_port`, `constant_fold` and `process_args` are bound outside
# this view. Visible steps: run the TF graph to get reference outputs,
# optionally dump debug protobufs, run tf_optimize, re-import the result and
# convert it with process_tf_graph.
tf.reset_default_graph()
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
variables_lib.global_variables_initializer().run()
# Despite its name, `output_dict` is a list of output tensors in the order of
# output_names_with_port, so `expected` follows that same order.
output_dict = []
for out_name in output_names_with_port:
output_dict.append(sess.graph.get_tensor_by_name(out_name))
expected = sess.run(output_dict, feed_dict=feed_dict)
# Debug mode: save the pre-optimization GraphDef as <test>_original.pb,
# creating test_data_directory first if needed.
if self.config.is_debug_mode:
if not os.path.exists(self.test_data_directory):
os.makedirs(self.test_data_directory)
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_original.pb")
utils.save_protobuf(model_path, sess.graph_def)
self.logger.debug("created file %s", model_path)
# NOTE(review): indentation is lost in this copy. `sess.graph_def` is read from
# the session opened above — confirm this call is still inside that `with`.
graph_def = tf_optimize(input_names_with_port, output_names_with_port,
sess.graph_def, constant_fold)
# Debug mode: save the optimized GraphDef. The directory was created by the
# debug-mode branch above, which runs under the same flag.
if self.config.is_debug_mode:
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
utils.save_protobuf(model_path, graph_def)
self.logger.debug("created file %s", model_path)
# Load the optimized GraphDef into a fresh default graph and convert it.
tf.reset_default_graph()
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=output_names_with_port,
target=self.config.target, **process_args)
# NOTE(review): fragment — the lines down to `inputs_as_nchw=...)` are the
# trailing keyword arguments of a call that starts outside this view;
# presumably it builds the graph `g` that is optimized below — confirm.
opset=args.opset,
custom_op_handlers=custom_ops,
extra_opset=extra_opset,
shape_override=args.shape_override,
input_names=inputs,
output_names=outputs,
inputs_as_nchw=args.inputs_as_nchw)
# Run the ONNX graph optimizer, then build the model proto. The
# "converted from <model_path>" text is presumably stored as the model's
# description — confirm against make_model.
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model("converted from {}".format(model_path))
# write onnx graph
# The empty log line presumably just separates the summary from earlier output.
logger.info("")
logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
# Save only when --output was given; otherwise tell the user how to save.
if args.output:
utils.save_protobuf(args.output, model_proto)
logger.info("ONNX model is saved at %s", args.output)
else:
logger.info("To export ONNX model to file, please run with `--output` option")