Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_match_flipped(self):
    """Check that ``allow_reorder`` lets a pattern match swapped inputs.

    The graph computes ``Mul(Sub(i1, i1), Add(i2, i2))``. The pattern lists
    the Mul inputs in the opposite order (``Add``, then ``Sub``). With
    ``allow_reorder=True`` the matcher is expected to find exactly one match.
    """
    n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
    n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")
    n3 = helper.make_node("Mul", ["n1:0", "n2:0"], ["n3:0"], name="n3")

    graph_proto = helper.make_graph(
        nodes=[n1, n2, n3],
        name="test",
        inputs=[helper.make_tensor_value_info("i1", TensorProto.FLOAT, [2, 2]),
                helper.make_tensor_value_info("i2", TensorProto.FLOAT, [2, 2])],
        # The Mul under test is the terminal node, so its result is the graph
        # output. Declaring the intermediate "n2:0" here left n3 with no consumer.
        outputs=[helper.make_tensor_value_info("n3:0", TensorProto.FLOAT, [2, 2])],
        initializer=[]
    )
    g = GraphUtil.create_graph_from_onnx_graph(graph_proto)

    # Inputs deliberately listed in the reverse of the graph's (Sub, Add) order.
    pattern = OpTypePattern('Mul', inputs=[
        OpTypePattern('Add'),
        OpTypePattern('Sub')
    ])
    ops = g.get_nodes()
    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))
    self.assertEqual(1, len(match_results))
def test_rewrite_subgraph(self):
graph_proto = self.sample_net()
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
pattern = \
OpTypePattern('Abs', name='output', inputs=[
OpTypePattern('Add', name='input')
])
ops = g.get_nodes()
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
input_node = match.get_op('input')
output_node = match.get_op('output')
op_name = utils.make_name("ReplacedOp")
out_name = utils.port_name(op_name)
new_node = g.make_node("Sub", inputs=input_node.input, outputs=[out_name], name=op_name)
g.replace_all_inputs(ops, output_node.output[0], new_node.output[0])
for n in set(match.get_nodes()):
g.remove_node(n.name)
g.topological_sort(ops)
result = onnx_to_graphviz(g)
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="hidden_state_bias")
]),
OpTypePattern("MatMul", inputs=[
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="hidden_state_kernel"),
]),
OpTypePattern("Identity")
])
])
]),
OpTypePattern("BiasAdd", inputs=[
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="hidden_input_bias")
]),
OpTypePattern("MatMul", inputs=[
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="hidden_input_kernel"),
]),
OpTypePattern("*")
])
])
])
])
]),
OpTypePattern("Mul", inputs=[
gru_split_pattern,
OpTypePattern("Identity")
])
])
def rewrite_random_uniform_fold_const(g, ops):
pattern = \
OpTypePattern('Add', name='output', inputs=[
OpTypePattern('Mul', name='mul', inputs=[
OpTypePattern('RandomUniform', name='input1', inputs=["*"]),
None,
]),
None,
])
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
output = match.get_op('output')
mul = match.get_op('mul')
ru_op = match.get_op('input1')
tmax_minus_tmin = mul.inputs[1].get_tensor_value()
tmin = output.inputs[1].get_tensor_value()
# Logger for this module; its name is the module's __name__.
logger = logging.getLogger(name=__name__)
class REWRITER_RESULT(Enum):
    """Three-valued result enum: SKIP = 1, OK = 2, FAIL = 3.

    NOTE(review): presumably reported by graph rewriters; confirm at call sites.
    """

    # Tuple unpacking defines the members in declaration order: SKIP, OK, FAIL.
    SKIP, OK, FAIL = 1, 2, 3
# TensorFlow LSTMCell/BasicLSTMCell computation graph matching
# Shape matched:
#   Split(Const axis,
#         BiasAdd(MatMul(Concat "xh", Enter(cell_kernel)),
#                 Enter(cell_bias)))
# The named sub-patterns (bias_add, xh, cell_kernel, cell_bias) label the
# nodes a caller can look up in a match result.
xc_pattern = OpTypePattern(
    'Split',
    inputs=[
        OpTypePattern("Const"),  # axis for split
        OpTypePattern(
            "BiasAdd",
            name="bias_add",
            inputs=[
                OpTypePattern(
                    "MatMul",
                    inputs=[
                        OpTypePattern("ConcatV2|Concat", name="xh"),
                        OpTypePattern(
                            "Enter",
                            inputs=[OpTypePattern("*", name="cell_kernel")],
                        ),
                    ],
                ),
                OpTypePattern(
                    "Enter",
                    inputs=[OpTypePattern("*", name="cell_bias")],
                ),
            ],
        ),
    ],
)
lstmcell_pattern = \
OpTypePattern('Mul', name='ht', inputs=[
OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
OpTypePattern('Tanh', inputs=[
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="gate_bias")
]),
OpTypePattern("MatMul", name="update_reset_gate", inputs=[
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="gate_kernel")
]),
OpTypePattern("ConcatV2|Concat", name="cell_inputs")
])
])
])
])
grucell_pattern = \
OpTypePattern("Add", name="cell_output", inputs=[
OpTypePattern("Mul", inputs=[
gru_split_pattern,
OpTypePattern("Identity")
]),
OpTypePattern("Mul", inputs=[
OpTypePattern("Sub", inputs=[
OpTypePattern("Const"), # 1-u
gru_split_pattern
]),
OpTypePattern("*", name="optional_activation", inputs=[
OpTypePattern("BiasAdd", inputs=[
OpTypePattern("Enter", inputs=[
OpTypePattern("*", name="hidden_bias")
]),
OpTypePattern("MatMul", inputs=[
OpTypePattern("Enter", inputs=[
])
# Shape matched:
#   ht = Sigmoid(xc) ["ot"] * Tanh(ct)
#   ct = Sigmoid(xc + ft_bias) ["ft"] * <any input>
#        + Sigmoid(xc) ["it"] * Tanh(xc) ["gt"]
# Here xc is xc_pattern, reused for every gate. The "<any input>" wildcard is
# presumably the previous cell state (NOTE(review): confirm).
lstmcell_pattern = OpTypePattern(
    'Mul',
    name='ht',
    inputs=[
        OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
        OpTypePattern(
            'Tanh',
            inputs=[
                OpTypePattern(
                    "Add",
                    name="ct",
                    inputs=[
                        # First term: forget-gate activation times a wildcard input.
                        OpTypePattern(
                            "Mul",
                            name="ct_identity_consumer",
                            inputs=[
                                OpTypePattern(
                                    "Sigmoid",
                                    name="ft",
                                    inputs=[
                                        OpTypePattern(
                                            "Add",
                                            inputs=[
                                                xc_pattern,
                                                OpTypePattern("*", name="ft_bias"),
                                            ],
                                        ),
                                    ],
                                ),
                                OpTypePattern("*"),
                            ],
                        ),
                        # Second term: input-gate activation times candidate values.
                        OpTypePattern(
                            "Mul",
                            inputs=[
                                OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern]),
                                OpTypePattern("Tanh", name="gt", inputs=[xc_pattern]),
                            ],
                        ),
                    ],
                ),
            ],
        ),
    ],
)
# Pattern inputs are listed top to bottom, left to right.
# The Split separates the update gate from the reset gate.
gru_split_pattern = \
OpTypePattern("Split", inputs=[
OpTypePattern("Const"), # split dim, a constant
OpTypePattern("Sigmoid", inputs=[
OpTypePattern("BiasAdd", inputs=[
OpTypePattern("*"),
]),
OpTypePattern("Mul", inputs=[
OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern]),
OpTypePattern("Tanh", name="gt", inputs=[xc_pattern]),
]),
]),
]),
])
# Pattern inputs are listed top to bottom, left to right.
# The Split separates the update gate from the reset gate.
# Shape matched:
#   Split(Const dim,
#         Sigmoid(BiasAdd(Enter(gate_bias),
#                         MatMul(Enter(gate_kernel), Concat "cell_inputs"))))
# The Split output carries the GRU's combined update/reset gate activations.
gru_split_pattern = OpTypePattern(
    "Split",
    inputs=[
        OpTypePattern("Const"),  # split dim, a constant
        OpTypePattern(
            "Sigmoid",
            inputs=[
                OpTypePattern(
                    "BiasAdd",
                    inputs=[
                        OpTypePattern(
                            "Enter",
                            inputs=[OpTypePattern("*", name="gate_bias")],
                        ),
                        OpTypePattern(
                            "MatMul",
                            name="update_reset_gate",
                            inputs=[
                                OpTypePattern(
                                    "Enter",
                                    inputs=[OpTypePattern("*", name="gate_kernel")],
                                ),
                                OpTypePattern("ConcatV2|Concat", name="cell_inputs"),
                            ],
                        ),
                    ],
                ),
            ],
        ),
    ],
)
grucell_pattern = \
def __init__(self, g):
self.g = g
self.ta_read_input_pattern = \
OpTypePattern("TensorArrayReadV3", name="ta_read", inputs=[
OpTypePattern("Enter", name="ta_enter", inputs=[
OpTypePattern("TensorArrayV3")
]),
OpTypePattern("Identity", name="ta_index"),
OpTypePattern("Enter", name="ta_scatter_enter", inputs=[
OpTypePattern("TensorArrayScatterV3", name="ta_input_scatter")
]),
def rewrite_transpose(g, ops):
    """Fold a computed reverse permutation into a static Transpose ``perm``.

    Matches ``Transpose(x, Sub(Sub(*, *), Range(*, *, *)))``, a Transpose
    whose permutation input is computed in the graph. The rewrite assumes
    that subgraph yields the reversed dimension order (presumably
    ``(rank - 1) - range(rank)``; NOTE(review): confirm).

    When the shape of ``x`` is known, the rewrite does three things:
      * sets a static ``perm`` attribute listing the dimensions in reverse
        order;
      * drops the Transpose's second input;
      * removes the other matched nodes.
    Matches whose input shape is unknown are left untouched.

    Args:
        g: graph being rewritten.
        ops: nodes of ``g`` to search for the pattern.

    Returns:
        ``ops``, the same object that was passed in.
    """
    pattern = \
        OpTypePattern('Transpose', name='output', inputs=[
            OpTypePattern(None),
            OpTypePattern('Sub', inputs=[
                OpTypePattern('Sub', inputs=["*", "*"]),
                OpTypePattern('Range', inputs=["*", "*", "*"]),
            ]),
        ])

    matcher = GraphMatcher(pattern)
    # Materialize all matches before mutating the graph below.
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        shape = g.get_shape(output.input[0])
        if shape is None:
            # g.get_shape presumably returns None for an unknown shape (TODO
            # confirm). Without a rank no static perm can be built, so skip
            # this match instead of raising TypeError on len(None).
            continue
        # ONNX "perm" is a list of ints; pass a concrete list, not a lazy range.
        dims = list(range(len(shape) - 1, -1, -1))
        output.set_attr("perm", dims)
        g.remove_input(output, output.input[1])
        to_delete = [n for n in match.get_nodes() if n != output]
        g.safe_remove_nodes(to_delete)
    return ops