Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_rewrite_subgraph(self):
graph_proto = self.sample_net()
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
pattern = \
OpTypePattern('Abs', name='output', inputs=[
OpTypePattern('Add', name='input')
])
ops = g.get_nodes()
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
input_node = match.get_op('input')
output_node = match.get_op('output')
op_name = utils.make_name("ReplacedOp")
out_name = utils.port_name(op_name)
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
for n in set(match.get_nodes()):
g.remove_node(n.name)
g.topological_sort(ops)
result = onnx_to_graphviz(g)
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 ' \
OpTypePattern('Reshape', name='reshape', inputs=[
OpTypePattern("*", name="input"),
OpTypePattern('Pack', name="pack", inputs=[
OpTypePattern('StridedSlice', name="slice", inputs=[
OpTypePattern('Shape', inputs=[
OpTypePattern("*", name="input2")
]),
"*", "*", "*",
]),
"*",
]),
])
matcher = GraphMatcher(pattern_fixed_shape_input)
match_results_1 = list(matcher.match_ops(ops))
matcher = GraphMatcher(pattern_non_fixed_shape_input)
match_results_2 = list(matcher.match_ops(ops))
match_results = [(match_results_1, True), (match_results_2, False)]
for match_results, check_fixed_input_shape in match_results:
for match in match_results:
input_node = match.get_op('input')
reshape_node = match.get_op('reshape')
pack_node = match.get_op('pack')
slice_node = match.get_op('slice')
need_rewrite = pack_node.inputs[1].is_const() and pack_node.inputs[1].get_tensor_value() == -1
if not need_rewrite:
continue
input_shape = g.get_shape(reshape_node.input[0])
need_rewrite = input_shape is not None
if not need_rewrite:
def find_sequence_length_node(self, context):
    """Return the loop's sequence-length node, or None when no length was given.

    Looks at the node that produces the next-iteration input of the first
    state variable in ``context.state_variables``. If that node is not a TF
    Select op, no sequence length is in use: a debug message is logged and
    None is returned. Otherwise the node must match ``seq_len_pattern`` and
    the op bound to ``"seq_len_node"`` in that match is returned.

    Raises:
        IndexError: if ``context.state_variables`` is empty.
        RuntimeError: if the Select op does not match ``seq_len_pattern``.
    """
    # Any state variable works here; take the first one.
    state_vars = list(context.state_variables.values())
    any_state_var = state_vars[0]
    iter_input_id = any_state_var.next_iteration_input.id
    select_candidate = self.g.get_node_by_output(iter_input_id)

    if is_tf_select_op(select_candidate):
        seq_len_match = GraphMatcher(seq_len_pattern).match_op(select_candidate)
        if seq_len_match:
            return seq_len_match.get_op("seq_len_node")
        raise RuntimeError("failed to find sequence length.")

    logger.debug("no sequence length node is given")
    return None
def rewrite_random_uniform(g, ops):
    """Replace ``RandomUniform * (max - min) + ...`` subgraphs with a single op.

    Matches ``Add(Mul(RandomUniform(*), Sub(*, *)), <any>)``. For every match
    whose Sub operands are constant, a replacement node is built by
    ``create_onnx_random_uniform_op``. Consumers of the Add output are rewired
    to the new node, and the matched nodes are removed with
    ``g.safe_remove_nodes``.

    Args:
        g: graph being rewritten.
        ops: nodes to search; also forwarded to ``g.replace_all_inputs``.

    Returns:
        ``ops`` (the same object that was passed in).
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0, min on input 1 (Sub computes max - min)
        tmax_node = input2.inputs[0]
        tmin_node = input2.inputs[1]
        # The pattern accepts any producer ("*") for the Sub operands, but the
        # bounds are read with get_tensor_value(), which needs a constant node.
        # Skip matches with runtime-computed bounds instead of aborting the
        # whole rewrite; the original Sub/Mul/Add ops are left in place.
        if not (tmax_node.is_const() and tmin_node.is_const()):
            continue
        tmax = tmax_node.get_tensor_value()
        tmin = tmin_node.get_tensor_value()
        # may be edited by create_onnx_random_uniform_op, so pass a fresh list
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        g.safe_remove_nodes(to_delete)
    return ops
def _parse_input_ta(self, context):
    """Register TensorArray reads in the loop body as scan inputs.

    Keeps only matches of ``self.ta_read_input_pattern`` whose ``ta_index``
    output is one of the loop variables' (non-empty)
    ``switch_true_identity_output`` ids. For each such match, an
    ``InputTensorArray`` is added to ``context.loop_properties``. It is built
    from the scatter's value input, the read's index input and the read's
    output.
    """
    # A set gives constant-time membership for the filter below; empty/falsy
    # ids are excluded just as before.
    graph_inputs = {v.switch_true_identity_output.id
                    for v in context.loop_properties.all_variables.values()
                    if v.switch_true_identity_output.id}
    matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
    match_results = [r for r in matcher.match_ops(self.g.get_nodes())
                     if r.get_op("ta_index").output[0] in graph_inputs]
    for match in match_results:
        ta_input_scatter = match.get_op("ta_input_scatter")
        # the 3rd input of scatter is the value
        data_input_id = ta_input_scatter.input[2]
        ta_read_node = match.get_op("ta_read")
        # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
        # then we can be sure this is equivalent to scan input behavior.
        index_input_id = ta_read_node.input[1]
        unstacked_ta_consumer = ta_read_node.output[0]
        ta = InputTensorArray(data_input_id, index_input_id, unstacked_ta_consumer, self.g)
        context.loop_properties.add_scan_input(ta)
2 input_x
3 weight
4 sequence node
5 initializer
6 state output & hidden output
3 process found info according to ONNX requirement
remember: op pattern and scope name are useful
they are used to get needed info from tensorflow graph
raw found info need to be formatted according to ONNX requirement
"""
# allow_reorder must be true. because LSTMCell and BasicLSTMCell's call function
# are defining the calculation with different orders. Then we can share the same
# pattern.
cell_pattern = get_pattern(unit_type)
matcher = GraphMatcher(cell_pattern, allow_reorder=True)
match_results = list(matcher.match_ops(self.g.get_nodes()))
if match_results:
for match in match_results:
self.run_single_match(match)
self.g.delete_unused_nodes(self.g.outputs)
self.print_step("finish handling")
return self.g.get_nodes()
def rewrite_random_uniform(g, ops):
    """Replace ``RandomUniform * (max - min) + ...`` subgraphs with a single op.

    Matches ``Add(Mul(RandomUniform(*), Sub(*, *)), <any>)``. For every match
    whose Sub operands are constant, a replacement node is built by
    ``create_onnx_random_uniform_op``. Consumers of the Add output are rewired
    to the new node, and each node left in ``to_delete`` is removed by name.

    Args:
        g: graph being rewritten.
        ops: nodes to search; also forwarded to ``g.replace_all_inputs``.

    Returns:
        ``ops`` (the same object that was passed in).
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', inputs=[
                OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
                OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
            ]), None
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input2 = match.get_op('input2')
        output = match.get_op('output')
        ru_op = match.get_op('input1')
        # max is on input 0, min on input 1 (Sub computes max - min)
        tmax_node = input2.inputs[0]
        tmin_node = input2.inputs[1]
        # The pattern accepts any producer ("*") for the Sub operands, but the
        # bounds are read with get_tensor_value(), which needs a constant node.
        # Skip matches with runtime-computed bounds instead of aborting the
        # whole rewrite; the original Sub/Mul/Add ops are left in place.
        if not (tmax_node.is_const() and tmin_node.is_const()):
            continue
        tmax = tmax_node.get_tensor_value()
        tmin = tmin_node.get_tensor_value()
        # may be edited by create_onnx_random_uniform_op, so pass a fresh list
        to_delete = list(set(match.get_nodes()))
        new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
        g.replace_all_inputs(ops, output.output[0], new_node.output[0])
        for n in to_delete:
            g.remove_node(n.name)
    return ops
def _parse_input_ta(self, context):
    """Register TensorArray reads in the loop body as scan inputs.

    Keeps only matches of ``self.ta_read_input_pattern`` whose ``ta_index``
    output is one of the loop variables' (non-empty)
    ``switch_true_identity_output`` ids. For each such match, an
    ``InputTensorArray`` is added to ``context.loop_properties``. It is built
    from the scatter's value input, the read's index input and the read's
    output.
    """
    # A set gives constant-time membership for the filter below; empty/falsy
    # ids are excluded just as before.
    graph_inputs = {v.switch_true_identity_output.id
                    for v in context.loop_properties.all_variables.values()
                    if v.switch_true_identity_output.id}
    matcher = GraphMatcher(self.ta_read_input_pattern, allow_reorder=False)
    match_results = [r for r in matcher.match_ops(self.g.get_nodes())
                     if r.get_op("ta_index").output[0] in graph_inputs]
    for match in match_results:
        ta_input_scatter = match.get_op("ta_input_scatter")
        # the 3rd input of scatter is the value
        data_input_id = ta_input_scatter.input[2]
        ta_read_node = match.get_op("ta_read")
        # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
        # then we can be sure this is equivalent to scan input behavior.
        index_input_id = ta_read_node.input[1]
        unstacked_ta_consumer = ta_read_node.output[0]
        ta = InputTensorArray(data_input_id, index_input_id, unstacked_ta_consumer, self.g)
        context.loop_properties.add_scan_input(ta)
OpTypePattern('Const', name='beta'),
OpTypePattern('*', name='C')
])
])
# pattern3: A*B + C
pattern3 = \
OpTypePattern('Add|AddV2', name='add', inputs=[
OpTypePattern('MatMul', name='matmul'),
OpTypePattern('*', name='C'),
])
pattern_list = [pattern0, pattern1, pattern2, pattern3]
for pattern in pattern_list:
matcher = GraphMatcher(pattern, allow_reorder=True)
match_results = list(matcher.match_ops(ops))
if match_results:
for match in match_results:
matmul_node = match.get_op("matmul")
if g.get_dtype(matmul_node.input[0]) != onnx_pb.TensorProto.FLOAT:
logging.warning(u"For now, onnxruntime only support float32 type for Gemm rewriter")
continue
attr, is_valid = get_gemm_attr(match)
if not is_valid:
continue
add_node = match.get_op('add')
input_c_node = match.get_op("C")
a_edge_name = matmul_node.input[0]
def find_sequence_length_node(self, context):
    """Return the loop's sequence-length node, or None when no length was given.

    Looks at the node that produces the next-iteration input of the first
    state variable in ``context.state_variables``. If that node is not a TF
    Select op, no sequence length is in use: a debug message is logged and
    None is returned. Otherwise the node must match ``seq_len_pattern`` and
    the op bound to ``"seq_len_node"`` in that match is returned.

    Raises:
        IndexError: if ``context.state_variables`` is empty.
        RuntimeError: if the Select op does not match ``seq_len_pattern``.
    """
    # Any state variable works here; take the first one.
    state_vars = list(context.state_variables.values())
    any_state_var = state_vars[0]
    iter_input_id = any_state_var.next_iteration_input.id
    select_candidate = self.g.get_node_by_output(iter_input_id)

    if is_tf_select_op(select_candidate):
        seq_len_match = GraphMatcher(seq_len_pattern).match_op(select_candidate)
        if seq_len_match:
            return seq_len_match.get_op("seq_len_node")
        raise RuntimeError("failed to find sequence length.")

    logger.debug("no sequence length node is given")
    return None