Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def make_value_info(self, node_or_name, data_type, shape=None, usage=None):
    """Declare a tensor value-info entry for ``node_or_name`` in ``self.graph_``.

    The entry is appended to ``self.graph_.input`` when ``usage`` equals
    ``NodeFactory.ValueInfoType.input``, to ``self.graph_.output`` when it
    equals ``NodeFactory.ValueInfoType.output``, and to
    ``self.graph_.value_info`` when ``usage`` is falsy (e.g. the default
    ``None``).

    Args:
        node_or_name: The tensor name as a ``str``, or a node-like object whose
            ``output`` sequence holds exactly one name (that name is used).
        data_type: Element type forwarded to ``helper.make_tensor_value_info``.
        shape: Optional shape forwarded to ``helper.make_tensor_value_info``.
        usage: Selects which graph collection receives the entry (see above).

    Raises:
        NotImplementedError: If ``usage`` is truthy but matches neither
            ``NodeFactory.ValueInfoType.input`` nor ``.output``.
        ValueError: If ``node_or_name`` is not a ``str`` and does not have
            exactly one output.
    """
    # Pick the destination collection first (same error precedence as before),
    # but do not append to it yet.
    if usage == NodeFactory.ValueInfoType.input:
        target = self.graph_.input
    elif usage == NodeFactory.ValueInfoType.output:
        target = self.graph_.output
    elif not usage:
        target = self.graph_.value_info
    else:
        raise NotImplementedError("unknown usage")

    if isinstance(node_or_name, str):
        name = node_or_name
    else:
        outputs = node_or_name.output
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(outputs) != 1:
            raise ValueError(
                "make_value_info expects a node with exactly one output, "
                "got %d" % len(outputs))
        name = outputs[0]

    # Build the proto before touching the graph: repeated-message `.add()`
    # appends a new element immediately, so any failure after it would leave
    # an empty placeholder entry in the graph.
    proto = helper.make_tensor_value_info(name, data_type, shape)
    target.add().CopyFrom(proto)
scan_body = onnx.GraphProto()
scan_body.name = name_prefix + '_subgraph'
nf_body = NodeFactory(out_main_graph, scan_body)
with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
# subgraph inputs
X_proj_subgraph = X_proj.name + '_subgraph'
prev_h_subgraph = name_prefix + '_h_subgraph'
seq_len_subgraph = declare_seq_len_in_subgraph(seq_len, nf_body, X_proj.name, batch_size)
nf_body.make_value_info(prev_h_subgraph,
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, hidden_size),
usage=NodeFactory.ValueInfoType.input)
nf_body.make_value_info(X_proj_subgraph,
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, 3*hidden_size),
usage=NodeFactory.ValueInfoType.input)
# subgraph nodes
# zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
# rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
# ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
# ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
# Ht = (1 - zt) (.) ht + zt (.) Ht-1
split_X_outputs = ['split_Xzr', 'split_Xh']
nf_body.make_node('Split', X_proj_subgraph, {"axis":1, "split":[2*hidden_size, hidden_size]}, output_names=split_X_outputs)
nf_body.make_value_info('split_Xzr',
def declare_seq_len_in_subgraph(seq_len, nf_body, prefix, batch_size):
    """Optionally declare a sequence-length input on the graph built by ``nf_body``.

    If ``seq_len`` is falsy (e.g. ``None``), nothing is declared and ``None``
    is returned. Otherwise an INT32 input named
    ``prefix + '_seq_len_subgraph'`` with shape ``(batch_size,)`` is declared
    via ``nf_body.make_value_info``, and that name is returned.
    """
    if not seq_len:
        # No sequence lengths supplied: there is no input to declare.
        return None

    input_name = prefix + '_seq_len_subgraph'
    nf_body.make_value_info(
        input_name,
        data_type=onnx.TensorProto.INT32,
        shape=(batch_size,),
        usage=NodeFactory.ValueInfoType.input,
    )
    return input_name
def make_value_info(self, node_or_name, data_type, shape=None, usage=None):
    """Declare a tensor value-info entry for ``node_or_name`` in ``self.graph_``.

    The entry is appended to ``self.graph_.input`` when ``usage`` equals
    ``NodeFactory.ValueInfoType.input``, to ``self.graph_.output`` when it
    equals ``NodeFactory.ValueInfoType.output``, and to
    ``self.graph_.value_info`` when ``usage`` is falsy (e.g. the default
    ``None``).

    Args:
        node_or_name: The tensor name as a ``str``, or a node-like object whose
            ``output`` sequence holds exactly one name (that name is used).
        data_type: Element type forwarded to ``helper.make_tensor_value_info``.
        shape: Optional shape forwarded to ``helper.make_tensor_value_info``.
        usage: Selects which graph collection receives the entry (see above).

    Raises:
        NotImplementedError: If ``usage`` is truthy but matches neither
            ``NodeFactory.ValueInfoType.input`` nor ``.output``.
        ValueError: If ``node_or_name`` is not a ``str`` and does not have
            exactly one output.
    """
    # Pick the destination collection first (same error precedence as before),
    # but do not append to it yet.
    if usage == NodeFactory.ValueInfoType.input:
        target = self.graph_.input
    elif usage == NodeFactory.ValueInfoType.output:
        target = self.graph_.output
    elif not usage:
        target = self.graph_.value_info
    else:
        raise NotImplementedError("unknown usage")

    if isinstance(node_or_name, str):
        name = node_or_name
    else:
        outputs = node_or_name.output
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(outputs) != 1:
            raise ValueError(
                "make_value_info expects a node with exactly one output, "
                "got %d" % len(outputs))
        name = outputs[0]

    # Build the proto before touching the graph: repeated-message `.add()`
    # appends a new element immediately, so any failure after it would leave
    # an empty placeholder entry in the graph.
    proto = helper.make_tensor_value_info(name, data_type, shape)
    target.add().CopyFrom(proto)
with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix:
# subgraph inputs
X_proj_subgraph = X_proj.name + '_subgraph'
prev_h_subgraph = name_prefix + '_h_subgraph'
seq_len_subgraph = declare_seq_len_in_subgraph(seq_len, nf_body, X_proj.name, batch_size)
nf_body.make_value_info(prev_h_subgraph,
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, hidden_size),
usage=NodeFactory.ValueInfoType.input)
nf_body.make_value_info(X_proj_subgraph,
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, 3*hidden_size),
usage=NodeFactory.ValueInfoType.input)
# subgraph nodes
# zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
# rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
# ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
# ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
# Ht = (1 - zt) (.) ht + zt (.) Ht-1
split_X_outputs = ['split_Xzr', 'split_Xh']
nf_body.make_node('Split', X_proj_subgraph, {"axis":1, "split":[2*hidden_size, hidden_size]}, output_names=split_X_outputs)
nf_body.make_value_info('split_Xzr',
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, 2*hidden_size))
nf_body.make_value_info('split_Xh',
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, hidden_size))
final_subgraph_output.append(seq_len_output)
# since seq_len is rank-1, need to unsqueeze for Where op on rank-2 states
condition = nf_body.make_node('Unsqueeze', nf_body.make_node('Greater', [seq_len_subgraph, np.zeros(shape=(), dtype=np.int32)]), {'axes':[1]})
for valid, default in subgraph_output_or_default:
final_subgraph_output.append(nf_body.make_node('Where', [condition, valid, default]))
else:
final_subgraph_output.append(None)
for valid, default in subgraph_output_or_default:
final_subgraph_output.append(nf_body.make_node('Identity', valid))
for subgraph_o in final_subgraph_output[1:]:
nf_body.make_value_info(subgraph_o,
data_type=onnx.TensorProto.FLOAT,
shape=(batch_size, hidden_size),
usage=NodeFactory.ValueInfoType.output)
return final_subgraph_output