How to use the tf2onnx.handler.tf_op function in tf2onnx

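To help you get started, we’ve selected a few tf2onnx examples based on popular ways the function is used in public projects. Each snippet is an excerpt from the linked source file, so some start or end mid-function.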

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
class Identity:
    """Handler for ``Identity`` nodes."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Bypass an Identity whose input is a constant.

        Every consumer of the Identity's output is rewired to read the constant
        directly. Nothing happens when the input is not a constant, or when the
        Identity's output is one of ``ctx.outputs``.

        Args:
            ctx: graph context (provides ``outputs``, ``get_nodes`` and
                ``replace_all_inputs``).
            node: the Identity node.
            **kwargs: unused; accepted for handler-signature compatibility.
        """
        if not node.inputs[0].is_const():
            return
        # should not remove the identity node if it is output of the graph
        if node.output[0] in ctx.outputs:
            return
        # if identity has a const as input, bypass it: point every consumer of
        # its output at the constant instead
        input_name = node.input[0]
        output_name = node.output[0]
        ctx.replace_all_inputs(ctx.get_nodes(), output_name, input_name)

class Reshape:
    """Handler for ``Reshape`` nodes."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Convert Reshape for opset 1, where the target shape is an attribute.

        The value of the second (shape) input is moved into the ``shape``
        attribute, that input is removed, and the output shape is recorded on
        the graph. If the shape value is unavailable (``get_tensor_value()``
        returns None), an error is logged and the node is left unchanged.

        Args:
            ctx: graph context (provides ``remove_input`` and ``set_shape``).
            node: the Reshape node; ``node.inputs[1]`` produces the shape.
            **kwargs: unused; accepted for handler-signature compatibility.
        """
        # T output = Reshape(T tensor, Tshape shape, @type Tshape)
        # T reshaped = Reshape(T data, @INTS shape) - but takes an optional 2nd input for shape
        shape_node = node.inputs[1]
        shape = shape_node.get_tensor_value()
        if shape is None:
            logger.error("Reshape on node %s does not have a const shape", node.name)
            # nothing to move into the attribute; bail out instead of writing
            # shape=None onto the node and dropping its shape input
            return
        ctx.remove_input(node, node.input[1])
        node.set_attr("shape", shape)
        ctx.set_shape(node.output[0], shape)

    def version_5(cls, ctx, node, **kwargs):
        """Opset-5 Reshape handler.

        NOTE(review): the body is not included in this excerpt (the snippet
        stops at the signature); restore it from the upstream source.
        """
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
# NOTE(review): this excerpt starts mid-function; the enclosing def and the
# definitions of new_node, name, dtypes, shapes and supported are not shown,
# and the first line's indentation does not match the lines below it.
# Inserts a "Min" node after new_node's output, copies dtype/shape onto it and,
# when dtypes[0] is not in `supported`, appends a Cast back to dtypes[0].
new_node = ctx.insert_new_node_on_output("Min", new_node.output[0], name=utils.make_name(name))
        # copy shape and type
        ctx.set_dtype(new_node.output[0], dtypes[0])
        ctx.set_shape(new_node.output[0], shapes[0])
        if dtypes[0] not in supported:
            # cast output if needed
            new_node = ctx.insert_new_node_on_output("Cast", new_node.output[0],
                                                     name=utils.make_name(name), to=dtypes[0])
            # copy shape and type
            ctx.set_dtype(new_node.output[0], dtypes[0])
            ctx.set_shape(new_node.output[0], shapes[0])

class Softmax:
    """Handler for ``Softmax`` nodes."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Set ``axis`` to the last dimension as a non-negative index.

        TF always applies softmax over the last axis while the ONNX default is
        axis 1, so the axis is set explicitly from the input's rank. Requires
        the input shape to be known.
        """
        # T output = Softmax(T logits). The axis softmax would be performed on is always on -1.
        # T output = Softmax(T input, @int axis). Default axis is 1.
        logits_rank = len(ctx.get_shape(node.input[0]))
        node.set_attr("axis", logits_rank - 1)

    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Set ``axis`` to -1 directly; the input rank is not needed."""
        # opset 11 supports -ve axis
        node.set_attr("axis", -1)

class Square:
    """Handler for ``Square`` nodes.

    NOTE(review): the class body is not included in this excerpt.
    """
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
# NOTE(review): this excerpt starts mid-function; ctx, node, output_name and
# dtype come from the omitted part, and the utils.make_name( call below is cut
# off (its arguments and closing parenthesis are missing).
# cast each input to float
        for i, inp in enumerate(node.inputs):
            # NOTE(review): `inp` is unused; only the index i is needed
            input_cast = ctx.insert_new_node_on_input(node, "Cast", node.input[i])
            input_cast.set_attr("to", onnx_pb.TensorProto.FLOAT)
            ctx.set_dtype(input_cast.output[0], onnx_pb.TensorProto.FLOAT)
        next_nodes = ctx.find_output_consumers(node.output[0])
        # cast output back to dtype unless the next op is a cast
        # NOTE(review): if find_output_consumers returns an empty list,
        # next_nodes[0] raises IndexError — confirm that cannot happen here.
        if next_nodes[0].type != "Cast":
            op_name = utils.make_name(
            output_cast = ctx.insert_new_node_on_output("Cast", output_name, name=op_name)
            output_cast.set_attr("to", dtype)
            ctx.set_dtype(output_cast.output[0], dtype)
            ctx.copy_shape(output_name, output_cast.output[0])

class Size:
    """Handler for ``Size`` nodes."""
    def version_1(cls, ctx, node, **kwargs):
        """Opset-1 handler.

        NOTE(review): the body is not included in this excerpt.
        """

class Flatten:
    """Handler for ``Flatten`` nodes."""
    # NOTE(review): these methods take `cls`, and version_9 calls
    # `cls.version_1(ctx, node, ...)`, which only binds correctly if they are
    # classmethods; the @classmethod decorators appear to be missing from this
    # excerpt — confirm against upstream.
    def version_1(cls, ctx, node, **kwargs):
        """Opset-1 handler (body not included in this excerpt)."""

    def version_9(cls, ctx, node, **kwargs):
        """Opset-9 handler: delegates to version_1 unchanged."""
        # no change for us
        cls.version_1(ctx, node, **kwargs)
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
# NOTE(review): this excerpt starts mid-function; ctx, node and axis_node are
# defined in the omitted part (axis_node is presumably the producer of the last
# input, which is removed right below — confirm).
axis_val = axis_node.get_tensor_value()
        # the axis becomes an attribute (set below), so drop the axis input
        ctx.remove_input(node, node.input[-1])

        if axis_val < 0:  # onnxruntime does not support -1 axis, but TF supports.
            # normalize any negative axis to its non-negative index using the rank
            input_shape = ctx.get_shape(node.input[0])
            utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
            axis_val = len(input_shape) + axis_val
        node.set_attr("axis", axis_val)

        if ctx.opset < 8:
            # opset < 8: might need to wrap concat in casts since only float is supported
            _wrap_concat_with_cast(ctx, node)

class Slice:
    """Handler for ``Slice`` nodes (this excerpt is truncated mid-method)."""
    def version_1(cls, ctx, node, **kwargs):
        """Opset-1 Slice conversion.

        NOTE(review): the excerpt ends after the constant starts/sizes are read;
        the rest of the method is not shown.
        """
        # T output = Slice(T input, Index begin, Index size)
        # T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
        # "ends" are exclusive; "axes" and "steps" are optional, defaulting to [0, 1, ...] and 1
        input_tensor = node.input[0]
        starts = node.input[1]
        size = node.input[2]
        # In TF a size of -1 means "take all remaining elements", so ends cannot be
        # computed as starts + size directly.
        # Plan (the code is past the end of this excerpt): replace each -1 in
        # "sizes" with int max so that no size is negative.
        size_dtype = ctx.get_dtype(size)
        size_np_dtype = utils.map_onnx_to_numpy_type(size_dtype)
        if ctx.get_node_by_output(size).is_const() and ctx.get_node_by_output(starts).is_const():
            starts = ctx.get_node_by_output(starts).get_tensor_value()
            sizes = ctx.get_node_by_output(size).get_tensor_value()
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
        # NOTE(review): this excerpt starts mid-function. The make_node( call
        # below appears to be missing its leading arguments (op type, inputs)
        # here, and batch_dim / seq_dim come from the omitted part.
        node = ctx.make_node(
            attr={"batch_axis": batch_dim, "time_axis": seq_dim})

        seq_len_dtype = ctx.get_dtype(node.input[1])
        utils.make_sure(seq_len_dtype is not None, "dtype of {} is None".format(node.input[1]))
        # cast the sequence-lengths input (input 1) to int64 unless it already is
        target_dtype = TensorProto.INT64
        if seq_len_dtype != target_dtype:
            ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=target_dtype)

class ReverseV2:
    """Handler for TF ``ReverseV2`` (this excerpt is truncated mid-method)."""
    def version_10(cls, ctx, node, **kwargs):
        """Opset-10 ReverseV2 conversion.

        NOTE(review): the excerpt ends after the axis-rank check; the rest of
        the method is not shown.
        """
        # T output = ReverseV2(T tensor, Tidx axis) -- input 1 is the axes tensor read below
        # Implement tensorflow ReverseV2 op using multiple ReverseSequence (for each axis)
        # and Transpose ops. We sort the axis vector (if non-empty) at the start. Each axis can
        # be reversed only once (in tf) and so we can compute the transpose for each axis
        # (other than 0), feed the tensor to a ReverseSequence node and finally transpose again
        # to get back the original shape.

        axes_node = node.inputs[1]
        axes = axes_node.get_tensor_value(as_list=False)
        # Current support is for when axis is a 1D tensor.
        utils.make_sure(len(axes.shape) == 1 \
                        , "Currently no support for reverseV2 tensor axis")
GitHub: onnx/keras-onnx · keras2onnx/ktf2onnx/tf2onnx/onnx_opset (view on GitHub)
# NOTE(review): this excerpt starts mid-function; ctx, node, trans1, blocksize,
# shapes and dtypes come from the omitted part.
# Apply the original op type to trans1's output with the given block size, then
# write the node's original outputs through a Transpose with perm [1, 2, 3, 0].
reorganize_node = ctx.make_node(node.type, trans1.output, attr={"blocksize": blocksize[0]})
ctx.make_node("Transpose", reorganize_node.output, {"perm": [1, 2, 3, 0]}, outputs=node.output,
              shapes=shapes, dtypes=dtypes)

@tf_op("IsInf", onnx_op="IsInf")
class IsInf:
    """Handler mapping TF ``IsInf`` to ONNX ``IsInf`` (opset 10+)."""

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        """Check that the input dtype can be converted to ONNX IsInf.

        ``utils.make_sure`` rejects an unknown (None) input dtype.

        Raises:
            ValueError: if the input dtype is neither FLOAT nor DOUBLE.
        """
        node_dtype = ctx.get_dtype(node.input[0])
        utils.make_sure(node_dtype, "Dtype of {} is None".format(node.input[0]))
        # only float and double inputs are accepted by this converter
        if node_dtype not in [onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.DOUBLE]:
            raise ValueError("dtype " + str(node_dtype) + " is not supported in onnx for now")

@tf_op(["NonMaxSuppressionV2", "NonMaxSuppressionV3"], onnx_op="NonMaxSuppression")
class NonMaxSuppression:
    """Handler mapping TF NonMaxSuppressionV2/V3 to ONNX ``NonMaxSuppression``."""
    def version_10(cls, ctx, node, **kwargs):
        """Opset-10 conversion.

        NOTE(review): the excerpt ends after collecting the output dtype/shape;
        the replacement subgraph described below is not shown.
        """
        # int32 = NonMaxSuppressionV2(T boxes, T scores, int32 max_output_size, T iou_threshold, T score_threshold)
        # int64 = NonMaxSuppression(T boxes, T scores, int64 max_output_size, T iou_threshold, T score_threshold)
        # T means float32 here, the last 3 params are optional
        # tf boxes is 2D ([boxes_num, 4]) while onnx is 3D ([num_batches, boxes_num, 4])
        # tf scores is 1D ([boxes_num]) while onnx is 3D ([num_batches, num_classes, boxes_num])
        # onnx output is [num_selected_boxes, 3], the meaning of last dim is [batch_index, class_index, box_index]
        # while tf's output is [num_selected_boxes]
        # boxes: [boxes_num, 4] -> [1, boxes_num, 4]
        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[0], axes=[0])
        # scores: [boxes_num] -> [1, 1, boxes_num]
        ctx.insert_new_node_on_input(node, "Unsqueeze", node.input[1], axes=[0, 1])
        # max_output_size: TF int32 -> ONNX int64
        ctx.insert_new_node_on_input(node, "Cast", node.input[2], to=onnx_pb.TensorProto.INT64)
        # replace original node with NonMaxSuppression + Slice + Squeeze + Cast
        dtypes = [ctx.get_dtype(node.output[0])]
        shapes = [ctx.get_shape(node.output[0])]
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
    def version_1(cls, ctx, node, **kwargs):
        """Opset-1 handler.

        NOTE(review): neither the enclosing class nor this method's body is
        included in this excerpt.
        """

    def version_9(cls, ctx, node, **kwargs):
        """Opset-9 handler: the conversion is unchanged from opset 1, so delegate."""
        opset1_handler = cls.version_1
        opset1_handler(ctx, node, **kwargs)

    def version_11(cls, ctx, node, **kwargs):
        """Opset-11 handler: reuses the opset-1 conversion without modification."""
        baseline = cls.version_1
        baseline(ctx, node, **kwargs)

class Dropout:
    """Handler for ``Dropout`` nodes.

    NOTE(review): none of the method bodies are included in this excerpt.
    """
    def version_1(cls, ctx, node, **kwargs):
        """Opset-1 handler (body not shown)."""

    def version_6(cls, ctx, node, **kwargs):
        """Opset-6 handler (body not shown)."""

    def version_7(cls, ctx, node, **kwargs):
        """Opset-7 handler (body not shown)."""

    def version_10(cls, ctx, node, **kwargs):
        """Opset-10 handler (body not shown)."""
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
# NOTE(review): this excerpt starts mid-function (apparently a Squeeze handler,
# judging by the messages and the "axes" attribute) and its indentation is
# mangled. A branch header (e.g. an `else:`) between the negative-axis
# normalization and the "all size-1 dims" fallback seems to be missing: as
# shown, the fallback would overwrite the normalized axis. Confirm upstream.
utils.make_sure(shape is not None, "squeeze input shape cannot be None")
                shape_len = len(shape)
                # convert negative axes to non-negative indices
                axis = [a + shape_len if a < 0 else a for a in axis]
            # fallback: squeeze every dimension whose size is 1
            shape = ctx.get_shape(node.input[0])
            utils.make_sure(shape is not None, "squeeze input shape cannot be None")
            axis = [i for i, j in enumerate(shape) if j == 1]
        node.set_attr("axes", axis)

    def version_11(cls, ctx, node, **kwargs):
        """Opset-11 handler.

        Opset 11 accepts negative axes, but the conversion logic is shared with
        opset 1, so this forwards to ``version_1``.
        """
        base_conversion = cls.version_1
        base_conversion(ctx, node, **kwargs)

class Transpose:
    """Handler for ``Transpose`` nodes."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Move a constant ``perm`` input into the ONNX ``perm`` attribute.

        TF passes the permutation as a second input, while ONNX (opset 1) takes
        it as an INTS attribute. A constant perm input is removed and stored as
        the attribute. A non-constant perm is rejected via ``utils.make_sure``.
        With a single input there is nothing to do.

        Args:
            ctx: graph context (provides ``remove_input``).
            node: the Transpose node.
            **kwargs: unused; accepted for handler-signature compatibility.
        """
        # T y = Transpose(T x, Tperm perm, @type Tperm)
        # T transposed = Transpose(T data, @INTS perm)
        if len(node.input) > 1:
            perm = node.inputs[1]
            if perm.is_const():
                # perms is passed as const
                dims = perm.get_tensor_value()
                ctx.remove_input(node, node.input[1])
                node.set_attr("perm", dims)
            else:
                # a runtime-computed perm cannot be expressed as an attribute
                utils.make_sure(False, "perm can't be dynamic in ONNX")
        else:
            # graph rewrite moved perm to attribute
            pass
GitHub: onnx/keras-onnx · keras2onnx/ktf2onnx/tf2onnx/onnx_opset (view on GitHub)
    # NOTE(review): this excerpt starts mid-function; node, ctx,
    # needs_broadcast_op and has_correct_shape come from the omitted part, and
    # both ctx.make_node( calls below are cut off (closing arguments missing).
    if needs_broadcast_op:
        has_correct_shape = has_correct_shape[0]
        for i in needs_broadcast_op:
            input_node = node.inputs[i]
            # get a tensor with zeros (since there is no Fill op as of opset8):
            # subtracting has_correct_shape from itself yields zeros of its shape
            sub_node = ctx.make_node("Sub", [has_correct_shape, has_correct_shape],
            # use add as 'broadcast' op: input + zeros takes on the zeros' shape,
            # and the node's i-th input is rewired to the Add output
            add_node = ctx.make_node("Add", [input_node.output[0], sub_node.output[0]],
            node.input[i] = add_node.output[0]

@tf_op("Minimum", onnx_op="Min")
@tf_op("Maximum", onnx_op="Max")
class MinMaxOp:
    """Handler shared by TF ``Minimum``/``Maximum`` (ONNX ``Min``/``Max``)."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        """Delegate to ``make_min_or_max_op`` with the node's output shapes and dtypes."""
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        make_min_or_max_op(ctx, node.type, node.input, node.output, shapes, dtypes)

class Softmax:
    """Handler for ``Softmax`` nodes (this excerpt is truncated mid-method)."""
    def version_1(cls, ctx, node, **kwargs):
        """Opset-1 conversion.

        NOTE(review): the excerpt ends after computing the input rank; the rest
        of the method is not shown.
        """
        # T output = Softmax(T logits). The axis softmax would be performed on is always on -1.
        # T output = Softmax(T input, @int axis). Default axis is 1.
        logits_rank = len(ctx.get_shape(node.input[0]))
GitHub: onnx/tensorflow-onnx · tf2onnx/onnx_opset (view on GitHub)
    def version_7(cls, ctx, node, **kwargs):
        """Opset-7 Greater/Less conversion: cast inputs to a supported dtype.

        NOTE(review): the enclosing class header is not shown in this excerpt,
        and the contents and closing bracket of ``supported_dtypes`` are missing.
        """
        # T2 output = Greater(T1 x, T1 y), T2=tensor(bool)
        # T2 output = Less(T1 x, T1 y), T2=tensor(bool)
        # Greater/Less in opset7 only supports limited types, insert Cast if needed
        supported_dtypes = [
        # float is the cast target; the casting itself is presumably done inside
        # _add_cast_to_inputs (its body is not shown here)
        target_dtype = TensorProto.FLOAT
        _add_cast_to_inputs(ctx, node, supported_dtypes, target_dtype)

@tf_op("GreaterEqual", onnx_op="Less")
@tf_op("LessEqual", onnx_op="Greater")
class GreaterLessEqual:
    """Handler for TF ``GreaterEqual``/``LessEqual``.

    ``a >= b`` is emitted as ``Not(Less(a, b))`` and ``a <= b`` as
    ``Not(Greater(a, b))``.
    """

    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        """Convert via the negated strict comparison (opset 7)."""
        # reuse the strict comparison's input handling, then negate its result
        GreaterLess.version_7(ctx, node, **kwargs)
        output_name = node.output[0]
        new_node = ctx.insert_new_node_on_output("Not", output_name, name=utils.make_name(node.name))
        # the Not output keeps the comparison's shape and dtype
        ctx.copy_shape(output_name, new_node.output[0])
        ctx.set_dtype(new_node.output[0], ctx.get_dtype(output_name))