# (fragment) inside a loop over the main graph's nodes: non-Scan nodes are copied through unchanged
    out_n.CopyFrom(in_n)
    continue

# in_n is a Scan node: fetch its body subgraph and rebuild it with its nodes cleared
in_sg = NodeFactory.get_attribute(in_n, 'body')
scan_input_directions = NodeFactory.get_attribute(in_n, 'scan_input_directions')
scan_output_directions = NodeFactory.get_attribute(in_n, 'scan_output_directions')
out_sg = onnx.GraphProto()
out_sg.CopyFrom(in_sg)
out_sg.ClearField('node')
nf_subgraph = NodeFactory(out_mp.graph, out_sg, prefix='opt_inproj_sg_' + in_n.name + '_')
new_inputs = list(in_n.input)
in_sg_inputs = [i.name for i in in_sg.input]
replaced_matmul = None
for in_sn in in_sg.node:
    if in_sn.op_type == 'Concat' and len(in_sn.input) == 2 and all([i in in_sg_inputs for i in in_sn.input]):
        # make sure the concat's inputs are scan input and scan state
        if NodeFactory.get_attribute(in_sn, 'axis') != len(in_sg.input[-1].type.tensor_type.shape.dim) - 1:
            continue  # must concat last dim
        matmul_node = [nn for nn in in_sg.node if nn.op_type == 'MatMul' and in_sn.output[0] in nn.input]
        if not matmul_node:
            continue
        replaced_matmul = matmul_node[0]
        assert replaced_matmul.input[1] in initializers
        aa = nf.get_initializer(replaced_matmul.input[1])
        input_size = in_sg.input[-1].type.tensor_type.shape.dim[-1].dim_value
        if in_sg_inputs[-1] == in_sn.input[0]:
            # concat order is [x, h]: the first input_size rows of aa project the scan input
            hidden_idx = 1
            input_proj_weights, hidden_proj_weights = np.vsplit(aa, [input_size])
        else:
            # concat order is [h, x]: np.vsplit cuts along axis 0, so split at shape[0] - input_size
            hidden_idx = 0
            hidden_proj_weights, input_proj_weights = np.vsplit(aa, [aa.shape[0] - input_size])
        # add matmul for input_proj outside of Scan
        input_proj = nf.make_node('MatMul', [new_inputs[-1], input_proj_weights])
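
The split above relies on the identity concat([x, h]) @ A == x @ A_top + h @ A_bottom, which is what lets the input projection be hoisted out of the Scan body. A minimal, self-contained numpy sketch; the sizes (input_size, hidden_size, proj_size) are hypothetical, not taken from any model:

import numpy as np

# hypothetical sizes, for illustration only
input_size, hidden_size, proj_size = 4, 3, 5
aa = np.random.rand(input_size + hidden_size, proj_size).astype(np.float32)

# same split as the [x, h] branch above: np.vsplit cuts along axis 0
input_proj_weights, hidden_proj_weights = np.vsplit(aa, [input_size])
assert input_proj_weights.shape == (input_size, proj_size)
assert hidden_proj_weights.shape == (hidden_size, proj_size)

# concat-then-matmul equals the sum of the two split projections
x = np.random.rand(2, input_size).astype(np.float32)
h = np.random.rand(2, hidden_size).astype(np.float32)
fused = np.concatenate([x, h], axis=-1) @ aa
split = x @ input_proj_weights + h @ hidden_proj_weights
assert np.allclose(fused, split, atol=1e-5)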
def convert_gru_to_scan(node, out_main_graph):
    assert node.op_type == 'GRU'
    nf = NodeFactory(out_main_graph)
    with nf.scoped_prefix(node.output[0]) as scoped_prefix:
        X = node.input[0]
        Wa = nf.get_initializer(node.input[1])
        Ra = nf.get_initializer(node.input[2])
        num_inputs = len(node.input)
        # optional GRU inputs: bias, sequence lengths, initial hidden state
        Ba = nf.get_initializer(node.input[3]) if num_inputs > 3 else None
        seq_len = node.input[4] if num_inputs > 4 else None
        InitHa = node.input[5] if num_inputs > 5 else None

        direction, num_directions, activations = handle_common_attributes(node, ['Sigmoid', 'Tanh'])
        hidden_size = NodeFactory.get_attribute(node, 'hidden_size')
        linear_before_reset = NodeFactory.get_attribute(node, 'linear_before_reset')

        InitHa = handle_init_state(InitHa, nf, num_directions)

        batch_size, batch_node = handle_batch_size(X, nf, InitHa is None)
        if InitHa is None:
            zero_init_state = default_init_state(X, batch_size, batch_node, hidden_size, nf)

        scan_outputs = []
        scan_h_outputs = []
        for direction_index in range(num_directions):
            # for each direction
            # X       [seq_len, batch_size, input_size]
            # W       [3*hidden_size, input_size]
            # R       [3*hidden_size, hidden_size]
            # B       [6*hidden_size]
            # seq_len [batch_size]
            # init_h  [batch_size, hidden_size]
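
The shape comments above describe per-direction slices of the packed ONNX GRU tensors, whose leading axis is num_directions. A small numpy sketch of that slicing, with hypothetical sizes:

import numpy as np

# hypothetical sizes, for illustration only
input_size, hidden_size, num_directions = 4, 3, 2

# ONNX GRU packs W/R/B with a leading num_directions axis
Wa = np.zeros((num_directions, 3 * hidden_size, input_size), dtype=np.float32)
Ra = np.zeros((num_directions, 3 * hidden_size, hidden_size), dtype=np.float32)
Ba = np.zeros((num_directions, 6 * hidden_size), dtype=np.float32)

for direction_index in range(num_directions):
    W = Wa[direction_index]   # [3*hidden_size, input_size]
    R = Ra[direction_index]   # [3*hidden_size, hidden_size]
    B = Ba[direction_index]   # [6*hidden_size]
    assert W.shape == (3 * hidden_size, input_size)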
def handle_common_attributes(node, default_activations):
    # ONNX stores string attributes as bytes, hence the utf-8 decoding below
    direction = NodeFactory.get_attribute(node, 'direction')
    if direction:
        direction = str(direction, 'utf-8')
    else:
        direction = 'forward'
    num_directions = 2 if direction == 'bidirectional' else 1

    activations = NodeFactory.get_attribute(node, 'activations')
    if activations:
        activations = [str(x, 'utf-8').lower().capitalize() for x in activations]
    else:
        activations = default_activations * num_directions

    activation_alpha = NodeFactory.get_attribute(node, 'activation_alpha')
    activation_beta = NodeFactory.get_attribute(node, 'activation_beta')
    clip_threshold = NodeFactory.get_attribute(node, 'clip')
    # TODO: support these activation attributes
    assert not activation_alpha
    assert not activation_beta
    assert not clip_threshold
    return direction, num_directions, activations
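
To see why the utf-8 decoding is needed: ONNX string attributes round-trip as bytes in AttributeProto. A standalone check (the node inputs/outputs here are made-up names), reading the raw attribute fields directly instead of NodeFactory.get_attribute:

import onnx.helper as helper

node = helper.make_node('GRU', ['X', 'W', 'R'], ['Y'],
                        direction='bidirectional',
                        activations=['sigmoid', 'tanh', 'sigmoid', 'tanh'])
attrs = {a.name: a for a in node.attribute}

# string attributes come back as bytes (AttributeProto.s / AttributeProto.strings)
direction = attrs['direction'].s.decode('utf-8')
activations = [s.decode('utf-8').lower().capitalize()
               for s in attrs['activations'].strings]
assert direction == 'bidirectional'
assert activations == ['Sigmoid', 'Tanh', 'Sigmoid', 'Tanh']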
def _append_initializer_from_graph(graph):
    initializers = [i.name for i in graph.initializer]
    for node in graph.node:
        if node.op_type == 'Scan':  # currently only handle Scan
            subgraph = NodeFactory.get_attribute(node, 'body')
            initializers += _append_initializer_from_graph(subgraph)
    return initializers
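
Typical use, as a sketch: this assumes NodeFactory from the same module is in scope, and 'model.onnx' is a placeholder path:

import onnx

model = onnx.load('model.onnx')  # placeholder path
# initializer names from the main graph plus any nested Scan bodies
names = _append_initializer_from_graph(model.graph)
print(len(names))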