Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def unittest_main():
    """Configure logging from the test config, log the config, then run unittest.

    The config object is logged through the logger yielded by
    ``logging.set_scope_level(logging.INFO)``; once that ``with`` block exits,
    ``unittest.main()`` is invoked to discover and run the tests.
    """
    cfg = get_test_config()
    logging.basicConfig(level=cfg.log_level)
    with logging.set_scope_level(logging.INFO) as scoped_logger:
        scoped_logger.info(cfg)
    unittest.main()
def main():
# Command-line entry point: parse arguments, configure logging and debug mode,
# collect extra opsets / custom ops, then load the TensorFlow model.
# NOTE(review): this definition appears truncated in this view -- only the
# graphdef branch of the model-loading step is visible. An identical main()
# is defined again immediately after it; Python keeps the later definition,
# so this one is dead code. Confirm which copy is intended and remove the other.
args = get_args()
logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
if args.debug:
utils.set_debug_mode(True)
logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
# NOTE(review): when args.extra_opset is a non-empty list, `or` returns that
# same list object, so the append below mutates args.extra_opset in place.
# Consider `list(args.extra_opset or [])` if callers reuse the parsed args.
extra_opset = args.extra_opset or []
custom_ops = {}
if args.custom_ops:
# default custom ops for tensorflow-onnx are in the "tf" namespace
# map each comma-separated op name to (default_custom_op_handler, [])
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
# register the TensorFlow custom-op opset (presumably only when
# args.custom_ops is set -- nesting is not recoverable here; confirm)
extra_opset.append(constants.TENSORFLOW_OPSET)
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
if args.graphdef:
# loader.from_graphdef yields the graph plus the resolved input/output names
graph_def, inputs, outputs = loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
model_path = args.graphdef
def main():
# Command-line entry point: parse arguments, configure logging and debug mode,
# collect extra opsets / custom ops, then load the TensorFlow model.
# NOTE(review): this is a verbatim duplicate of the main() defined just before
# it; being later in the file, this is the definition Python actually keeps.
# It also appears truncated in this view (only the graphdef branch is visible).
args = get_args()
logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
if args.debug:
utils.set_debug_mode(True)
logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
# NOTE(review): when args.extra_opset is a non-empty list, `or` returns that
# same list object, so the append below mutates args.extra_opset in place.
# Consider `list(args.extra_opset or [])` if callers reuse the parsed args.
extra_opset = args.extra_opset or []
custom_ops = {}
if args.custom_ops:
# default custom ops for tensorflow-onnx are in the "tf" namespace
# map each comma-separated op name to (default_custom_op_handler, [])
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
# register the TensorFlow custom-op opset (presumably only when
# args.custom_ops is set -- nesting is not recoverable here; confirm)
extra_opset.append(constants.TENSORFLOW_OPSET)
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
if args.graphdef:
# loader.from_graphdef yields the graph plus the resolved input/output names
graph_def, inputs, outputs = loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
model_path = args.graphdef
import numpy as np
import tensorflow as tf
TF2 = tf.__version__.startswith("2.")
from tensorflow.core.framework import types_pb2, tensor_pb2
from tensorflow.python.framework import tensor_util
if not TF2:
from tensorflow.tools.graph_transforms import TransformGraph
from onnx import helper, onnx_pb, numpy_helper
from tf2onnx.utils import make_sure, is_tf_const_op, port_name, node_name
from . import logging
logger = logging.getLogger(__name__)
#
# mapping dtypes from tensorflow to onnx
#
TF_TO_ONNX_DTYPE = {
types_pb2.DT_FLOAT: onnx_pb.TensorProto.FLOAT,
types_pb2.DT_HALF: onnx_pb.TensorProto.FLOAT16,
types_pb2.DT_DOUBLE: onnx_pb.TensorProto.DOUBLE,
types_pb2.DT_INT32: onnx_pb.TensorProto.INT32,
types_pb2.DT_INT16: onnx_pb.TensorProto.INT16,
types_pb2.DT_INT8: onnx_pb.TensorProto.INT8,
types_pb2.DT_UINT8: onnx_pb.TensorProto.UINT8,
types_pb2.DT_UINT16: onnx_pb.TensorProto.UINT16,
types_pb2.DT_INT64: onnx_pb.TensorProto.INT64,
types_pb2.DT_STRING: onnx_pb.TensorProto.STRING,
types_pb2.DT_COMPLEX64: onnx_pb.TensorProto.COMPLEX64,
# NOTE(review): fragment of a rewriter-driver function whose header is not
# visible here (it recurses into run_rewriters below). It applies one
# rewriter `func` to graph `g`, optionally checks integrity in debug mode,
# then recurses into the graphs held in g.contained_graphs.
try:
# run the rewriter over the current node list and install its result
ops = func(g, g.get_nodes())
g.reset_nodes(ops)
except Exception as ex:
type_, value_, traceback_ = sys.exc_info()
logger.error("rewriter %s: exception %s", func, ex)
ex_ext = traceback.format_exception(type_, value_, traceback_)
# either log the formatted traceback and keep going, or propagate the error
if continue_on_error:
logger.info(ex_ext)
else:
raise ex
# in debug mode, ask the graph for broken outputs and report them
if utils.is_debug_mode():
broken_outputs = g.check_integrity()
if broken_outputs:
# NOTE(review): this calls the `logging` module directly while the rest
# of the fragment logs via `logger`; likely meant to be logger.error.
logging.error(
"After rewriter %s, graph breaks at outputs %s",
func.__name__, broken_outputs
)
# recurse into nested graphs; judging by the unpacking below, each value of
# g.contained_graphs is a dict of attribute name -> subgraph
if g.contained_graphs:
for dict_val in g.contained_graphs.values():
for attr_name, b_g in dict_val.items():
# NOTE(review): attr_name is passed as the third argument. If that
# parameter is continue_on_error (the name used above), any non-empty
# attribute name is truthy and silently enables continue-on-error for
# subgraphs. Confirm against the run_rewriters signature.
run_rewriters(b_g, funcs, attr_name)
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
"""
tf2onnx.graph_helper - helper class for building graphs, e.g. for making complex nodes
"""
import numpy as np
from tf2onnx import utils, logging
# pylint: disable=missing-docstring
logger = logging.getLogger(__name__)
class GraphBuilder(object):
"""help to build graph"""
def __init__(self, graph):
self._g = graph
@property
def graph(self):
return self._g
def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
"""
slice changes its schema at opset 10: it treats some attributes as dynamic inputs,
so this function has to process inputs according to the graph's opset version
to get the "inputs" and "attr" to feed to "make_node"
import sys
import traceback
import numpy as np
from onnx import onnx_pb
import tf2onnx
import tf2onnx.onnx_opset # pylint: disable=unused-import
import tf2onnx.custom_opsets # pylint: disable=unused-import
from tf2onnx.graph import Graph
from tf2onnx.rewriter import * # pylint: disable=wildcard-import
from tf2onnx.shape_inference import infer_shape
from tf2onnx.tf_utils import tensorflow_to_onnx, get_tf_version
from . import constants, logging, schemas, utils, handler
logger = logging.getLogger(__name__)
# pylint: disable=useless-return,broad-except,logging-not-lazy,unused-argument,missing-docstring
# FIXME:
# pylint: disable=unused-variable
def rewrite_constant_fold(g, ops):
"""
We call tensorflow transform with constant folding but in some cases tensorflow does
not fold all constants. Since there are a bunch of ops in onnx that use attributes where
tensorflow has dynamic inputs, we badly want constant folding to work. For cases where
tensorflow missed something, make another pass over the graph and fix what we care about.
"""
func_map = {
"Add": np.add,