Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def main():
args = get_args()
logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
if args.debug:
utils.set_debug_mode(True)
logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
extra_opset = args.extra_opset or []
custom_ops = {}
if args.custom_ops:
# default custom ops for tensorflow-onnx are in the "tf" namespace
custom_ops = {op: (default_custom_op_handler, []) for op in args.custom_ops.split(",")}
extra_opset.append(constants.TENSORFLOW_OPSET)
# get the frozen tensorflow model from graphdef, checkpoint or saved_model.
if args.graphdef:
graph_def, inputs, outputs = loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
model_path = args.graphdef
if args.checkpoint:
graph_def, inputs, outputs = loader.from_checkpoint(args.checkpoint, args.inputs, args.outputs)
model_path = args.checkpoint
if args.saved_model:
def version_9(cls, ctx, node, **kwargs):
# T output = OneHot(uint8/int32/int64 input, T depth, T on-value, T off-value, @int axis, @dtype)
# tf requires that dtype is same as on-value's and off-value's dtype
# in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
# onnxruntime only supports int64
output_dtype = ctx.get_dtype(node.input[2])
if ctx.is_target(constants.TARGET_RS6) \
and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
logger.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
cls.version_1(ctx, node, **kwargs)
return
depth = node.input[1]
depth = ctx.make_node("Unsqueeze", [depth], attr={"axes": [0]}).output[0]
on_value = node.input[2]
off_value = node.input[3]
on_value = ctx.make_node("Unsqueeze", [on_value], attr={"axes": [0]}).output[0]
off_value = ctx.make_node("Unsqueeze", [off_value], attr={"axes": [0]}).output[0]
off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]
indices = node.input[0]
if ctx.is_target(constants.TARGET_RS6) \
new_kernel_shape: reshape the kernel
"""
if input_indices is None:
input_indices = [0]
if output_indices is None:
output_indices = [0]
if node.is_nhwc():
# transpose input if needed, no need to record shapes on input
for idx in input_indices:
parent = node.inputs[idx]
if node.inputs[idx].is_const() and len(ctx.find_output_consumers(node.input[1])) == 1:
# if input is a constant, transpose that one if we are the only consumer
val = parent.get_tensor_value(as_list=False)
parent.set_tensor_value(val.transpose(constants.NHWC_TO_NCHW))
else:
# if input comes from a op, insert transpose op
input_name = node.input[idx]
transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
transpose.set_attr("perm", constants.NHWC_TO_NCHW)
transpose.skip_conversion = True
shape = ctx.get_shape(input_name)
if shape is not None:
new_shape = spatial_map(shape, constants.NHWC_TO_NCHW)
ctx.set_shape(transpose.output[0], new_shape)
# kernel must to be transposed
if with_kernel:
parent = node.inputs[1]
need_transpose = True
if node.inputs[1].is_const():
"""Insert a transpose from NHWC to NCHW on model input on users request."""
ops = []
for node in ctx.get_nodes():
for idx, output_name in enumerate(node.output):
if output_name in inputs_as_nchw:
shape = ctx.get_shape(output_name)
if len(shape) != len(constants.NCHW_TO_NHWC):
logger.warning("transpose_input for %s: shape must be rank 4, ignored" % output_name)
ops.append(node)
continue
# insert transpose
op_name = utils.make_name(node.name)
transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
transpose.set_attr("perm", constants.NCHW_TO_NHWC)
ctx.copy_shape(output_name, transpose.output[0])
ctx.set_shape(output_name, np.array(shape)[constants.NHWC_TO_NCHW])
ops.append(transpose)
ops.append(node)
continue
ops.append(node)
ctx.reset_nodes(ops)
return helper.make_opsetid(domain, version)
def is_onnx_domain(domain):
    """Return True if *domain* denotes the default ONNX operator domain.

    The default domain may be spelled either as ``None`` or as the empty
    string; any other domain name (e.g. a custom-op domain) yields False.
    """
    # Return the comparison itself rather than branching to literal True/False.
    return domain is None or domain == ""
def parse_bool(val):
    """Interpret a textual flag (e.g. an environment variable value) as a bool.

    Args:
        val: the raw value. ``None`` (such as an unset environment variable)
            is treated as False; a ``bool`` is returned unchanged.

    Returns:
        True when the value, ignoring case and surrounding whitespace, is one
        of "yes", "true", "t", "y" or "1"; False otherwise.
    """
    if val is None:
        return False
    if isinstance(val, bool):
        # Already a boolean; calling .lower() on it would raise AttributeError.
        return val
    # Strip so values such as "true\n" (common when set from scripts/files) parse.
    return val.strip().lower() in ("yes", "true", "t", "y", "1")
# Module-level debug flag. It is initialised from the environment variable named
# by constants.ENV_TF2ONNX_DEBUG_MODE via parse_bool; an unset variable yields
# None and therefore False. It can be changed at runtime with set_debug_mode().
_is_debug_mode = parse_bool(os.environ.get(constants.ENV_TF2ONNX_DEBUG_MODE))
def is_debug_mode():
    """Report the current value of the module-level debug flag."""
    return _is_debug_mode


def set_debug_mode(enabled):
    """Overwrite the module-level debug flag.

    Args:
        enabled: new flag value. It is stored as given (no bool coercion) and
            is what is_debug_mode() returns afterwards.
    """
    global _is_debug_mode
    _is_debug_mode = enabled
def get_max_value(np_dtype):
    """Return the largest representable value of a numpy dtype.

    Integer dtypes are resolved with ``np.iinfo``. Floating-point dtypes,
    which ``np.iinfo`` rejects with ValueError, are resolved with ``np.finfo``.
    Other dtypes behave as before and raise from ``np.iinfo``.
    """
    if np.issubdtype(np_dtype, np.floating):
        return np.finfo(np_dtype).max
    return np.iinfo(np_dtype).max
def get_min_value(np_dtype):
def set_level(level):
    """Set the log level of the tf2onnx package logger.

    The same level is forwarded to set_tf_verbosity() so that TensorFlow's
    verbosity is updated accordingly.
    """
    package_logger = _logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)
    package_logger.setLevel(level)
    set_tf_verbosity(level)
scales = ctx.make_node("Concat", [const_one_array.output[0], scales_hw.output[0]], {"axis": 0})
# because onnxruntime only supports to scale the last two dims so transpose is inserted
input_nchw = ctx.make_node("Transpose", [node.input[0]], {"perm": constants.NHWC_TO_NCHW})
roi = ctx.make_const(tf2onnx.utils.make_name("roi"), np.array([]).astype(np.float32))
attrs = {"mode": mode}
attrs['coordinate_transformation_mode'] = 'asymmetric'
if attrs['mode'] == 'nearest':
attrs['nearest_mode'] = 'floor'
upsample = ctx.make_node("Resize", [input_nchw.output[0], roi.output[0], scales.output[0]],
attr=attrs)
shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
ctx.make_node("Transpose", upsample.output, {"perm": constants.NCHW_TO_NHWC},
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
def convert_to_onnx(self, graph_def, inputs, outputs):
    """Convert a TensorFlow GraphDef into an ONNX ModelProto.

    Args:
        graph_def: the TensorFlow GraphDef to convert.
        inputs: input tensor names, passed to tf_optimize and process_tf_graph.
        outputs: output tensor names, passed to tf_optimize and process_tf_graph.

    Returns:
        The result of GraphUtil.optimize_model_proto on the converted model
        when that is truthy, otherwise the unoptimized converted model.
    """
    # FIXME: folding const = False
    optimized_graph_def = tf2onnx.tfonnx.tf_optimize(inputs, outputs, graph_def, False)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(optimized_graph_def, name='')

    # NOTE(review): the session object is never referenced; presumably it is kept
    # so that tf_graph is bound to a session during conversion - confirm before removing.
    with tf.Session(graph=tf_graph):
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(
            tf_graph,
            continue_on_error=False,
            verbose=False,
            target=",".join(constants.DEFAULT_TARGET),
            opset=9,
            input_names=inputs,
            output_names=outputs,
            inputs_as_nchw=None,
        )

    model_proto = onnx_graph.make_model("converted from {}".format(self._tf_file))
    # Fall back to the unoptimized model when the optimizer returns nothing.
    return GraphUtil.optimize_model_proto(model_proto) or model_proto
def find_opset(opset):
    """Resolve the ONNX opset to target.

    An explicit, non-zero *opset* is returned unchanged. ``None`` or ``0``
    selects automatically: the newest opset reported by the installed onnx
    package, capped at constants.PREFERRED_OPSET.
    """
    if opset is not None and opset != 0:
        return opset
    # If we use a newer onnx opset than most runtimes support, default to the
    # one most supported.
    return min(defs.onnx_opset_version(), constants.PREFERRED_OPSET)
def version_6(cls, ctx, node, **kwargs):
    """Elementwise Ops with broadcast flag.

    Rewrites ``node`` in place:
    - a TensorFlow ``AddV2`` node is renamed to ``Add``;
    - when the two input shapes differ and the target is RS4, a constant input
      whose shape is empty/unknown (falsy) is passed through
      ``scalar_to_dim1()`` and the returned value is used as its shape;
    - for ``Mul``/``Add`` nodes whose first input has lower rank than the
      second, the two inputs are swapped so the higher-rank operand comes
      first (order does not matter for these commutative ops).

    Args:
        ctx: graph being converted; used for shape lookups and target checks.
        node: the binary elementwise node to rewrite.
    """
    if node.type == "AddV2":
        node.type = "Add"
    shape0 = ctx.get_shape(node.input[0])
    shape1 = ctx.get_shape(node.input[1])
    if shape0 != shape1:
        # this works around shortcomings in the broadcasting code
        # of caffe2 and winml/rs4.
        if ctx.is_target(constants.TARGET_RS4):
            # in rs4 mul and add do not support scalar correctly
            if not shape0:
                if node.inputs[0].is_const():
                    shape0 = node.inputs[0].scalar_to_dim1()
            if not shape1:
                if node.inputs[1].is_const():
                    shape1 = node.inputs[1].scalar_to_dim1()
        # "AddV2" can no longer match here: such nodes were renamed to "Add" above.
        if shape0 and shape1 and len(shape0) < len(shape1) and node.type in ["Mul", "Add", "AddV2"]:
            # Swap the inputs so the lower-rank operand becomes the second input.
            tmp = node.input[0]
            node.input[0] = node.input[1]
            node.input[1] = tmp