How to use the tf2onnx.utils.create_vague_shape_like function in tf2onnx

To help you get started, we’ve selected a few tf2onnx examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / loop_rewriter.py View on Github external
loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # If the variable is not used in the body graph, create a placeholder input
                    # with the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                index_node = loop_body_g.make_node("Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
                gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = loop_body_g.make_node("Squeeze", [gather_node.output[0]], attr={"axes": [0]})
                loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id, data_node.output[0])

            ## create Loop node
            loop_node = self._create_loop_node(context, loop_props)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL
            loop_node.set_body_graph_as_attr("body", loop_body_g)

            logger.debug("rewrite successfully")
github onnx / tensorflow-onnx / tf2onnx / function / select.py View on Github external
)

    fake_var_identity_name = utils.make_name("fake_var_output")
    fake_var_output_id = utils.port_name(fake_var_identity_name)
    g.make_node(
        'Identity',
        [fake_var_name],
        outputs=[fake_var_output_id],
        name=fake_var_identity_name
    )

    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())

    # Use None for all dims and keep only the original rank, because dims have been observed to change inside the loop.
    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))

    return g
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / onnx_opset / controlflow.py View on Github external
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
    """Build the subgraph used as one branch of an If node.

    The new graph is created from ``parent_g`` with the same config, is linked
    back to it through ``parent_graph``, and holds a single Identity node that
    forwards ``chosen_cur_cond_val_out_name`` to a graph output named ``y``.

    Args:
        parent_g: graph the branch graph is derived from and attached to.
        data_type: element type declared for the ``y`` graph output.
        output_shape: shape passed through ``utils.create_vague_shape_like``
            to obtain the declared shape of ``y``.
        chosen_cur_cond_val_out_name: name of the tensor the Identity node reads.
        op_name: not used by this function.

    Returns:
        The newly created branch graph.
    """
    branch_graph = parent_g.create_new_graph_with_same_config()
    branch_graph.parent_graph = parent_g

    # A lone Identity node exposes the chosen value under a fixed output name.
    output_name = "y"
    identity_name = utils.make_name("Identity")
    branch_graph.make_node(
        "Identity",
        inputs=[chosen_cur_cond_val_out_name],
        outputs=[output_name],
        name=identity_name,
    )

    vague_shape = utils.create_vague_shape_like(output_shape)
    branch_graph.add_graph_output(output_name, data_type, vague_shape)
    return branch_graph
github onnx / tensorflow-onnx / tf2onnx / onnx_opset / controlflow.py View on Github external
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
    """Create an If-branch body graph that passes one tensor through as ``y``.

    Args:
        parent_g: graph whose config is reused; also stored as the new
            graph's ``parent_graph``.
        data_type: element type declared for the branch output.
        output_shape: shape given to ``utils.create_vague_shape_like`` to
            derive the declared output shape.
        chosen_cur_cond_val_out_name: input tensor name fed to the Identity node.
        op_name: not used by this function.

    Returns:
        The branch graph containing one Identity node and one output ``y``.
    """
    subgraph = parent_g.create_new_graph_with_same_config()
    subgraph.parent_graph = parent_g

    node_name = utils.make_name("Identity")
    identity_kwargs = {
        "inputs": [chosen_cur_cond_val_out_name],
        "outputs": ["y"],
        "name": node_name,
    }
    # The branch only forwards the selected value; no computation happens here.
    subgraph.make_node("Identity", **identity_kwargs)

    subgraph.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
    return subgraph
github onnx / tensorflow-onnx / tf2onnx / function / select.py View on Github external
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
    """Create an If-branch body graph that forwards one tensor to output ``y``.

    The graph is created from ``parent_g`` with the same config and contains a
    single Identity node reading ``chosen_cur_cond_val_out_name``.

    Args:
        parent_g: graph whose config the new graph copies.
        data_type: element type declared for the ``y`` graph output.
        output_shape: shape passed through ``utils.create_vague_shape_like``
            to obtain the declared shape of ``y``.
        chosen_cur_cond_val_out_name: name of the tensor the Identity node reads.
        op_name: not used by this function.

    Returns:
        The newly created branch graph.
    """
    body_graph = parent_g.create_new_graph_with_same_config()

    # Single pass-through node; its output name doubles as the graph output.
    result_name = "y"
    identity_node_name = utils.make_name("Identity")
    body_graph.make_node(
        "Identity",
        inputs=[chosen_cur_cond_val_out_name],
        outputs=[result_name],
        name=identity_node_name,
    )

    declared_shape = utils.create_vague_shape_like(output_shape)
    body_graph.add_graph_output(result_name, data_type, declared_shape)
    return body_graph
github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / onnx_opset / controlflow.py View on Github external
)

    fake_var_identity_name = utils.make_name("fake_var_output")
    fake_var_output_id = utils.port_name(fake_var_identity_name)
    g.make_node(
        'Identity',
        [fake_var_name],
        outputs=[fake_var_output_id],
        name=fake_var_identity_name
    )

    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())

    # Use None for all dims and keep only the original rank, because dims have been observed to change inside the loop.
    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))

    return g