Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@tf_op("Identity")
class Identity:
    """Handler for the TF ``Identity`` op."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Remove an Identity node that only forwards a constant.

        The node is left in place when its first input is not a constant or
        when its output is one of the graph outputs.
        """
        if not node.inputs[0].is_const():
            return
        forwarded = node.output[0]
        # should not remove the identity node if it is output of the graph
        if forwarded in ctx.outputs:
            return
        # identity has a const as input: point every consumer at the const
        # and drop the identity node itself
        source = node.input[0]
        ctx.replace_all_inputs(ctx.get_nodes(), forwarded, source)
        ctx.remove_node(node.name)
@tf_op("Reshape")
class Reshape:
    """Handler for the TF ``Reshape`` op."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Move the shape input of a Reshape node into a ``shape`` attribute.

        When the tensor value of the second input is ``None`` an error is
        logged and the node is left unchanged.
        """
        # TF:   T output = Reshape(T tensor, Tshape shape, @type Tshape)
        # ONNX: T reshaped = Reshape(T data, @INTS shape) - but takes a optional 2nd input for shape
        target_shape = node.inputs[1].get_tensor_value()
        if target_shape is not None:
            # drop the shape input and carry the value as an attribute instead
            ctx.remove_input(node, node.input[1])
            node.set_attr("shape", target_shape)
            ctx.set_shape(node.output[0], target_shape)
            return
        logger.error("Reshape on node %s does not have a const shape", node.name)
# NOTE(review): the statements after this header do not look like the body of
# version_5: new_node is read on its first line before it is assigned, and
# max_node, name, dtypes, shapes and supported are not defined anywhere in view.
# This is presumably the tail of a different definition whose start is missing --
# confirm against the full file before editing.
@classmethod
def version_5(cls, ctx, node, **kwargs):
new_node = ctx.insert_new_node_on_output("Min", new_node.output[0], name=utils.make_name(name))
# the freshly inserted "Min" node gets max_node's output as an additional input
new_node.input.append(max_node.output[0])
# copy shape and type
ctx.set_dtype(new_node.output[0], dtypes[0])
ctx.set_shape(new_node.output[0], shapes[0])
if dtypes[0] not in supported:
# cast output if needed
new_node = ctx.insert_new_node_on_output("Cast", new_node.output[0],
name=utils.make_name(name), to=dtypes[0])
# copy shape and type
ctx.set_dtype(new_node.output[0], dtypes[0])
ctx.set_shape(new_node.output[0], shapes[0])
@tf_op("Softmax")
class Softmax:
    """Handler for the TF ``Softmax`` op, which always works on the last axis."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Set ``axis`` to the index of the last dimension of the logits.

        The axis is written as a non-negative index, so the shape of the
        first input is read to obtain its rank.
        """
        # T output = Softmax(T logits). The axis softmax would be performed on is always on -1.
        # T output = Softmax(T input, @int axis). Default axis is 1.
        logits_rank = len(ctx.get_shape(node.input[0]))
        node.set_attr("axis", logits_rank - 1)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Set ``axis`` to -1 so the softmax runs on the last dimension.

        Leaving the attribute unset would fall back to the ONNX default axis
        of 1, which differs from the TF behaviour for inputs of rank > 2.
        """
        # opset 11 supports -ve axis, so the last axis can be selected
        # directly without knowing the rank of the input
        node.set_attr("axis", -1)
# NOTE(review): no method header is visible between the class line and the
# statements below. They use ctx, node, output_name and dtype, none of which are
# defined in this view -- presumably the interior of a definition whose start is
# missing. Confirm against the full file before editing.
@tf_op("Square")
class Square:
# cast each inputs to float
for i, inp in enumerate(node.inputs):
input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
next_nodes = ctx.find_output_consumers(node.output[0])
# cast output back to dtype unless the next op is a cast
# NOTE(review): next_nodes[0] assumes the output has at least one consumer -- TODO confirm
if next_nodes[0].type != "Cast":
op_name = utils.make_name(node.name)
output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=op_name)
output_cast.set_attr("to", dtype)
ctx.set_dtype(output_cast.output[0], dtype)
ctx.copy_shape(output_name, output_cast.output[0])
@tf_op("Size")
class Size:
    """Handler for the TF ``Size`` op."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Leave the node untouched."""
        return
@tf_op("Flatten")
class Flatten:
    """Handler for the TF ``Flatten`` op."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Leave the node untouched."""
        return

    @classmethod
    def version_9(cls, ctx, node, **kwargs):
        """Apply the same (empty) handling as ``version_1``."""
        # no change for us
        cls.version_1(ctx, node, **kwargs)
# NOTE(review): fragment -- axis_node, ctx and node are not defined in this view and
# no enclosing def is visible. The comments below describe only the visible lines.
axis_val = axis_node.get_tensor_value()
# the axis arrives as the last input; it is removed and stored as an attribute instead
ctx.remove_input(node, node.input[-1])
if axis_val < 0: # onnxruntime does not support -1 axis, but TF supports.
input_shape = ctx.get_shape(node.input[0])
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
# convert the negative axis into the equivalent non-negative index
axis_val = len(input_shape) + axis_val
node.set_attr("axis", axis_val)
if ctx.opset < 8:
# opset < 8: might need to wrap concat in casts since only float is supported
_wrap_concat_with_cast(ctx, node)
return
# NOTE(review): version_1 is truncated in this view -- the visible part only reads
# the three inputs and, when begin and size are both constants, their values. What is
# done with them is not shown; input_tensor and size_np_dtype are not used in the
# visible lines.
@tf_op("Slice")
class Slice:
@classmethod
def version_1(cls, ctx, node, **kwargs):
# T output = Slice(T input, Index begin, Index size)
# T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
# "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
input_tensor = node.input[0]
starts = node.input[1]
size = node.input[2]
# in tf, size can be -1 which means all elem are taken, so size can't be added starts directly.
# the way to make sure size are not less than 0: set "sizes"'s elem to be int_max if elem val is -1
size_dtype = ctx.get_dtype(size)
size_np_dtype = utils.map_onnx_to_numpy_type(size_dtype)
if ctx.get_node_by_output(size).is_const() and ctx.get_node_by_output(starts).is_const():
# from here on, starts holds the constant value rather than the input name
starts = ctx.get_node_by_output(starts).get_tensor_value()
sizes = ctx.get_node_by_output(size).get_tensor_value()
# NOTE(review): fragment -- no enclosing def is visible and batch_dim / seq_dim are
# not defined in this view. The comments below describe only the visible lines.
# the original node is removed and a "ReverseSequence" node is created with the same
# inputs and outputs
ctx.remove_node(node.name)
node = ctx.make_node(
"ReverseSequence",
node.input,
outputs=node.output,
attr={"batch_axis": batch_dim, "time_axis": seq_dim})
seq_len_dtype = ctx.get_dtype(node.input[1])
utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))
# the second input is cast to INT64 whenever its dtype is anything else
target_dtype = TensorProto.INT64
if seq_len_dtype != target_dtype:
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)
# NOTE(review): version_10 is truncated in this view -- only the read of the axes
# input and the 1-D check are visible; the ReverseSequence/Transpose construction
# described in the comment block is not shown.
@tf_op("ReverseV2")
class ReverseV2:
@classmethod
def version_10(cls, ctx, node, **kwargs):
# T output = ReverseV2(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
# Implement tensorflow ReverseV2 op using multiple ReverseSequence (for each axis)
# and Transpose ops. We sort the axis vector (if non-empty) at the start. Each axis can
# be reversed only once (in tf) and so we can compute the transpose for each axis
# (other than 0), feed the tensor to a ReverseSequence node and finally transpose again
# to get back the original shape.
axes_node = node.inputs[1]
# as_list=False: the value is used through its .shape attribute below
axes = axes_node.get_tensor_value(as_list=False)
# Current support is for when axis is a 1D tensor.
utils.make_sure(len(axes.shape) == 1 \
, "Currently no support for reverseV2 tensor axis")
# NOTE(review): fragment -- trans1, blocksize, shapes and dtypes are not defined in
# this view and no enclosing def is visible.
# a node of the original type is created on trans1's output, and a final Transpose
# with perm [1, 2, 3, 0] takes over the original node's name and outputs
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, name=node.name, outputs=node.output,
shapes=shapes, dtypes=dtypes)
@tf_op("IsInf", onnx_op="IsInf")
class IsInf:
    """Handler for the TF ``IsInf`` op."""

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        """Check that the first input has a dtype accepted here.

        Raises:
            ValueError: if the dtype is neither FLOAT nor DOUBLE.
        """
        input_dtype = ctx.get_dtype(node.input[0])
        utils.make_sure(input_dtype, "Dtype of {} is None".format(node.name))
        accepted = [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]
        if input_dtype in accepted:
            return
        raise ValueError("dtype " + str(input_dtype) + " is not supported in onnx for now")
# NOTE(review): version_10 is truncated in this view -- the visible part adapts the
# first three inputs and records the output dtype and shape; the replacement of the
# node announced by the last comment is not shown.
@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
class NonMaxSuppression:
@classmethod
def version_10(cls, ctx, node, **kwargs):
# int32 = NonMaxSuppressionV2(T boxes, T scores, int32 max_output_size, T iou_threshold, T score_threshold)
# int64 = NonMaxSuppression(T boxes, T scores, int64 max_output_size, T iou_threshold, T score_threshold),
# T means float32 here, the last 3 params are optional
# tf boxes is 2D ([boxes_num, 4]) while onnx is 3D ([num_batches, boxes_num, 4])
# tf scores is 1D ([boxes_num])while onnx is 2D ([num_batches, num_classes, boxes_num])
# onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
# while tf's output is [num_selected_boxes]
# boxes gain one leading axis, scores gain two, and max_output_size is cast to INT64
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
# replace original node with nonmaxsurppress + slice + squeeze + cast
dtypes = [ctx.get_dtype(node.output[0])]
shapes = [ctx.get_shape(node.output[0])]
# NOTE(review): the class these three methods belong to is not identifiable from
# this view (they follow a truncated method of another snippet) -- confirm against
# the full file.
# version_1 does nothing; version_9 and version_11 both delegate to it.
@classmethod
def version_1(cls, ctx, node, **kwargs):
pass
@classmethod
def version_9(cls, ctx, node, **kwargs):
# no change for us
cls.version_1(ctx, node, **kwargs)
@classmethod
def version_11(cls, ctx, node, **kwargs):
# no change
cls.version_1(ctx, node, **kwargs)
@tf_op("Dropout")
class Dropout:
    """Handler for the TF ``Dropout`` op; every listed opset leaves the node as is."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Leave the node untouched."""
        return

    @classmethod
    def version_6(cls, ctx, node, **kwargs):
        """Leave the node untouched."""
        return

    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        """Leave the node untouched."""
        return

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        """Leave the node untouched."""
        return
# NOTE(review): fragment -- the lines up to node.set_attr("axes", ...) start in the
# middle of a definition: the `if` matching the `else:` below and the assignments of
# shape and axis for the first branch are not visible.
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
shape_len = len(shape)
# negative axes are converted to the equivalent non-negative index
axis = [a + shape_len if a < 0 else a for a in axis]
else:
shape = ctx.get_shape(node.input[0])
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
# in this branch every dimension of size 1 is selected
axis = [i for i, j in enumerate(shape) if j == 1]
node.set_attr("axes", axis)
@classmethod
def version_11(cls, ctx, node, **kwargs):
# Opset 11 supports negative axis, but core logic is same
cls.version_1(ctx, node, **kwargs)
# NOTE(review): version_1 is truncated in this view -- the final `else:` has only a
# comment after it, so its body is not visible.
@tf_op("Transpose")
class Transpose:
@classmethod
def version_1(cls, ctx, node, **kwargs):
# T y = Transpose(T x, Tperm perm, @type Tperm)
# T transposed = Transpose(T data, @INTS perm)
if len(node.input) > 1:
perm = node.inputs[1]
if perm.is_const():
# perms is passed as const
# the perm input is removed and its value stored as the "perm" attribute
dims = perm.get_tensor_value()
ctx.remove_input(node, node.input[1])
node.set_attr("perm", dims)
else:
# make_sure is called with False, i.e. this branch always reports the error
utils.make_sure(False, "perm can't be dynamic in ONNX")
else:
# graph rewrite moved perm to attribute
# NOTE(review): fragment -- no enclosing def is visible; input_name and
# needs_broadcast_op are not defined in this view, and has_correct_shape is
# appended to before any visible assignment.
has_correct_shape.append(input_name)
if needs_broadcast_op:
# only the first collected name is used from here on (list rebound to one element)
has_correct_shape = has_correct_shape[0]
for i in needs_broadcast_op:
input_node = node.inputs[i]
# get a tensor with zeros (since there is no Fill op as of opset8)
sub_node = ctx.make_node("Sub", [has_correct_shape, has_correct_shape],
op_name_scope=input_node.name)
# use add as 'broadcast' op
add_node = ctx.make_node("Add", [input_node.output[0], sub_node.output[0]],
op_name_scope=input_node.name)
# the i-th input of node now reads the Add result
node.input[i] = add_node.output[0]
@tf_op("Minimum", onnx_op="Min")
@tf_op("Maximum", onnx_op="Max")
class MinMaxOp:
    """Handler shared by the TF ``Minimum`` and ``Maximum`` ops."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Remove the node and rebuild it through ``make_min_or_max_op``."""
        # read the output metadata before the node is removed from the graph
        out_shapes, out_dtypes = node.output_shapes, node.output_dtypes
        ctx.remove_node(node.name)
        make_min_or_max_op(ctx, node.type, node.input, node.output, out_shapes, out_dtypes)
# NOTE(review): version_1 is truncated in this view -- logits_rank is computed but
# the statement that uses it is not visible.
@tf_op("Softmax")
class Softmax:
@classmethod
def version_1(cls, ctx, node, **kwargs):
# T output = Softmax(T logits). The axis softmax would be performed on is always on -1.
# T output = Softmax(T input, @int axis). Default axis is 1.
# NOTE(review): len() fails if get_shape returns None -- TODO confirm whether that can happen here
logits_rank = len(ctx.get_shape(node.input[0]))
@classmethod
def version_7(cls, ctx, node, **kwargs):
    """Hand the node to ``_add_cast_to_inputs`` with the dtypes accepted as-is.

    FLOAT, FLOAT16 and DOUBLE are passed as the supported dtypes and FLOAT as
    the target dtype.
    """
    # T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
    # T2 output = Less(T1 x, T1 y), T2=tensor(bool)
    # Great/Less in opset7 only supports limited types, insert Cast if needed
    accepted = [TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE]
    _add_cast_to_inputs(ctx, node, accepted, TensorProto.FLOAT)
@tf_op("GreaterEqual", onnx_op="Less")
@tf_op("LessEqual", onnx_op="Greater")
class GreaterLessEqual:
    """Handler for TF ``GreaterEqual`` / ``LessEqual``.

    The ops are registered against the opposite strict comparison
    (``Less`` / ``Greater``) and a ``Not`` node is placed on the output.
    """

    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        """Apply the ``GreaterLess`` handling, then negate the result."""
        GreaterLess.version_7(ctx, node, **kwargs)
        compare_out = node.output[0]
        negation = ctx.insert_new_node_on_output("Not", compare_out, name=utils.make_name(node.name))
        negated_out = negation.output[0]
        # the Not output carries the same shape and dtype as the comparison output
        ctx.copy_shape(compare_out, negated_out)
        ctx.set_dtype(negated_out, ctx.get_dtype(compare_out))