Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def rewrite_test(g, ops):
    """Custom rewriter: retype every ``Add`` node matched in ``g`` to ``Mul``.

    The incoming ``ops`` argument is not used; the node list is re-read from
    ``g.get_nodes()``, rewritten in place, and returned.
    """
    add_pattern = OpTypePattern('Add', name='op', inputs=["*", "*"])
    nodes = g.get_nodes()
    # Materialise every match before any node type is mutated.
    matches = list(GraphMatcher(add_pattern).match_ops(nodes))
    for found in matches:
        found.get_op('op').type = "Mul"
    return nodes
with tf.Session() as sess:
    # Build input1 -> Add(input1, input1) -> output.
    inp = tf.placeholder(tf.float32, [2, 3], name="input1")
    doubled = tf.add(inp, inp)
    tf.identity(doubled, name="output")
    converted = process_tf_graph(sess.graph, opset=self.config.opset, custom_rewriter=[rewrite_test])
    # The node keeps its TF name "Add" but the rewriter changed its op_type to Mul.
    self.assertEqual(
        'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
        'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
        onnx_to_graphviz(converted))
"""Custom op test."""
# Register a conversion handler mapping the TF "Print" op to ONNX "Identity".
@tf_op("Print", onnx_op="Identity")
class Print:
    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Rewrite a Print node in place and return it.

        Presumably selected by the tf_op registry for opset 1 and later,
        going by the ``version_1`` name; confirm against the registry.
        """
        # `self` is not a parameter here. Presumably it is the enclosing test
        # method's `self`, captured by closure; that def is not visible here.
        self.assertEqual(node.type, "Identity")
        # Place the node in the TensorFlow custom-op domain.
        node.domain = constants.TENSORFLOW_OPSET.domain
        # Keep only the first input; every later input is removed.
        del node.input[1:]
        return node
with tf.Session() as sess:
    # Build input1 -> Print -> output.
    inp = tf.placeholder(tf.float32, [2, 3], name="input1")
    printed = tf.Print(inp, [inp], "hello")
    tf.identity(printed, name="output")
    converted = process_tf_graph(sess.graph,
                                 opset=self.config.opset,
                                 extra_opset=[constants.TENSORFLOW_OPSET])
    # Expected: Print becomes an Identity in the TF converters domain, fed only by input1.
    self.assertEqual(
        'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [domain="ai.onnx.converters.tensorflow" '
        'op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
        onnx_to_graphviz(converted))
    # Both the main opset and the extra TF opset must be recorded on the graph.
    self.assertEqual(converted.opset, self.config.opset)
    self.assertEqual(converted.extra_opset, [constants.TENSORFLOW_OPSET])
def test_squeeze(self):
    """tf.squeeze on a [2, 3] placeholder converts to a single ONNX Squeeze node."""
    with tf.Session() as sess:
        placeholder = tf.placeholder(tf.float32, [2, 3], name="input1")
        tf.identity(tf.squeeze(placeholder), name="output")
        onnx_g = process_tf_graph(sess.graph, opset=self.config.opset)
        expected = ('digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '
                    'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }')
        self.assertEqual(expected, onnx_to_graphviz(onnx_g))
def test_randomnormal(self):
    """tf.random_normal with a static shape converts to an ONNX RandomNormal node."""
    with tf.Session() as sess:
        noise = tf.random_normal([2, 3], name="rand")
        tf.identity(noise, name="output")
        onnx_g = process_tf_graph(sess.graph, opset=self.config.opset)
        self.assertEqual(
            'digraph { RandomNormal__2 [op_type=RandomNormal shape="[2, 3]"] output [op_type=Identity] '
            'RandomNormal__2:0 -> output }',
            onnx_to_graphviz(onnx_g))
def test_randomuniform(self):
    """tf.random_uniform with a constant shape feeds a chain of three Identity ops."""
    with tf.Session() as sess:
        dims = tf.constant([2, 3], name="shape")
        tensor = tf.random_uniform(dims, name="rand")
        # Chain identities named output1 -> output2 -> output, in that order.
        for ident_name in ("output1", "output2", "output"):
            tensor = tf.identity(tensor, name=ident_name)
        onnx_g = process_tf_graph(sess.graph, opset=self.config.opset)
        self.assertEqual(
            'digraph { RandomUniform__2 [op_type=RandomUniform shape="[2, 3]"] output1 [op_type=Identity] '
            'output2 [op_type=Identity] output [op_type=Identity] RandomUniform__2:0 -> output1 '
            'output1:0 -> output2 output2:0 -> output }',
            onnx_to_graphviz(onnx_g))
"""
A simple example of how to call tensorflow-onnx from Python.
"""
import os
import tempfile

import tensorflow as tf
import tf2onnx
with tf.Session() as sess:
    # Build a tiny graph: input -> Add(input, input) -> output.
    x = tf.placeholder(tf.float32, [2, 3], name="input")
    x_ = tf.add(x, x)
    _ = tf.identity(x_, name="output")
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph, output_names=["output:0"])
    model_proto = onnx_graph.make_model("test")
    # Use the platform's temp directory rather than a hard-coded POSIX "/tmp",
    # so the example also runs on Windows.
    output_path = os.path.join(tempfile.gettempdir(), "model.onnx")
    with open(output_path, "wb") as f:
        f.write(model_proto.SerializeToString())
# Load the graph definition and its input/output names from a SavedModel.
# NOTE(review): only this branch is visible. graph_def, inputs, outputs and
# model_path must be bound elsewhere when args.saved_model is unset; confirm.
if args.saved_model:
graph_def, inputs, outputs = loader.from_saved_model(
args.saved_model, args.inputs, args.outputs, args.signature_def)
model_path = args.saved_model
if args.verbose:
logger.info("inputs: %s", inputs)
logger.info("outputs: %s", outputs)
# TODO: consider enabling constant folding by default?
# args.fold_const is passed as tf_optimize's fourth argument (presumably
# fold_constant, the keyword used elsewhere for this function).
graph_def = tf_optimize(inputs, outputs, graph_def, args.fold_const)
# Import the optimized GraphDef into a fresh graph; name='' adds no name prefix.
with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(graph_def, name='')
# Convert to ONNX, forwarding the command-line conversion options.
with tf.Session(graph=tf_graph):
g = process_tf_graph(tf_graph,
continue_on_error=args.continue_on_error,
target=args.target,
opset=args.opset,
custom_op_handlers=custom_ops,
extra_opset=extra_opset,
shape_override=args.shape_override,
input_names=inputs,
output_names=outputs,
inputs_as_nchw=args.inputs_as_nchw)
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model("converted from {}".format(model_path))
# write onnx graph
# NOTE(review): no code writing model_proto is visible in this excerpt, yet a
# success message is logged below. Confirm that the save step exists.
logger.info("")
logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
settings: SerializationSettings, frozen_graph_def: tf.GraphDef
) -> Any:
"""Convert a frozen TensorFlow GraphDef into a model proto via tf2onnx.

Input and output node names are derived from ``frozen_graph_def`` itself.
The graph is then optimized with constant folding enabled, converted with
``process_tf_graph`` at ``settings.onnx_opset``, run through
``optimizer.optimize_graph``, and packaged under ``settings.brain_name``.

:param settings: serialization settings; ``onnx_opset`` and ``brain_name``
    are read here.
:param frozen_graph_def: the frozen graph to convert.
:return: the result of ``onnx_graph.make_model(settings.brain_name)``.
"""
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
inputs = _get_input_node_names(frozen_graph_def)
outputs = _get_output_node_names(frozen_graph_def)
# NOTE(review): the f-string is formatted even when INFO logging is disabled;
# logger.info("... %s", value) would defer formatting.
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
frozen_graph_def = tf_optimize(
inputs, outputs, frozen_graph_def, fold_constant=True
)
# Import the optimized GraphDef into a fresh graph; name="" adds no name prefix.
with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(frozen_graph_def, name="")
with tf.Session(graph=tf_graph):
g = process_tf_graph(
tf_graph,
input_names=inputs,
output_names=outputs,
opset=settings.onnx_opset,
)
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model(settings.brain_name)
return model_proto
def convert_to_onnx(self, graph_def, inputs, outputs, opset=9, fold_constant=False):
    """Convert a TensorFlow GraphDef into an optimized ONNX model proto.

    :param graph_def: TensorFlow GraphDef to convert.
    :param inputs: input tensor names, forwarded to ``tf_optimize`` and
        ``process_tf_graph``.
    :param outputs: output tensor names, forwarded the same way.
    :param opset: ONNX opset to target. Defaults to 9, the value that was
        previously hard-coded here.
    :param fold_constant: whether ``tf_optimize`` folds constants. Defaults
        to False, matching the previous behaviour.
    :return: the model proto built from the converted graph. If
        ``GraphUtil.optimize_model_proto`` returns a truthy result, that
        result is returned instead.
    """
    # FIXME: constant folding stays off by default (fold_constant=False); revisit.
    graph_def = tf2onnx.tfonnx.tf_optimize(
        inputs, outputs, graph_def, fold_constant=fold_constant)
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=tf_graph):
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(
            tf_graph,
            continue_on_error=False,
            verbose=False,
            # Pass the targets as a list of names. Joining them with "," produced
            # a single string, which is the CLI's --target form, not this
            # function's.
            target=list(constants.DEFAULT_TARGET),
            opset=opset,
            input_names=inputs,
            output_names=outputs,
            inputs_as_nchw=None)
    model_proto = onnx_graph.make_model(
        "converted from {}".format(self._tf_file))
    # Keep the optimized proto only when the optimizer returns one.
    new_model_proto = GraphUtil.optimize_model_proto(model_proto)
    if new_model_proto:
        model_proto = new_model_proto
    return model_proto
def to_onnx(model, fname="./frozen_out.onnx", scope=""):
    """Convert ``model`` to ONNX and write the serialized proto to ``fname``.

    The TensorFlow default graph is reset. The model's GraphDef is then
    loaded, imported under ``scope``, and the resulting default graph is
    converted with ``tf2onnx.tfonnx.process_tf_graph``. ``<input_name>:0``
    and ``<last layer name>:0`` are passed to ``make_model`` as its second
    and third arguments; the local names suggest these are the input and
    output tensor names.
    """
    tf.reset_default_graph()
    model.load_graphdef()
    model.import_graph(scope=scope)
    # NOTE(review): model.import_graph presumably already imports
    # model.graph_def, so this second import may duplicate the graph.
    # Confirm that this is intended.
    tf.import_graph_def(
        model.graph_def, {}, name=scope)
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf.get_default_graph())

    input_tensor = model.input_name + ":0"
    output_tensor = model.layers[-1]['name'] + ":0"
    print(input_tensor, output_tensor)

    model_proto = onnx_graph.make_model("", [input_tensor], [output_tensor])
    with open(fname, "wb") as out_file:
        out_file.write(model_proto.SerializeToString())
    print("Done...")