Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
# Fragment: tail of a custom-RNN -> ONNX Scan rewrite. The enclosing def and the
# matching `try:` are outside this view; scan_props, scan_body_g, context and the
# *_initial_values sequences are defined earlier.
# Register the Scan body graph inputs: all state inputs first, then all scan inputs.
for input_tensor_info in scan_props.state_inputs:
scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)
for input_tensor_info in scan_props.scan_inputs:
scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)
# Initial values are concatenated in the same order (state values, then scan values).
# NOTE(review): presumably this order must line up with the body-graph inputs
# registered above — confirm against _create_scan_node.
scan_node = self._create_scan_node(context, scan_props,
state_inputs_initial_values + scan_inputs_initial_values)
# A falsy result from _create_scan_node is treated as a hard failure.
if not scan_node:
log.error("failed to create scan node during rewrite")
return REWRITER_RESULT.FAIL
# Attach the body graph as the node's "body" attribute; _connect_scan_with_output
# presumably wires the Scan outputs to the rest of the graph (inferred from its name).
scan_node.set_body_graph_as_attr("body", scan_body_g)
self._connect_scan_with_output(context, scan_node)
return REWRITER_RESULT.OK
# Any exception raised in the try body is logged with its full traceback and
# converted into REWRITER_RESULT.FAIL instead of propagating.
except Exception as ex:
tb = traceback.format_exc()
log.error("custom rnn rewrite failed, due to exception: %s, details:%s", ex, tb)
return REWRITER_RESULT.FAIL
if not rnn_scope_name:
log.debug("unable to find rnn scope name, skip")
return REWRITER_RESULT.SKIP
log.debug("rnn scope name is %s", rnn_scope_name)
self.print_step("get_weight_and_bias starts")
rnn_weights = self.get_weight_and_bias(match)
if not rnn_weights:
log.debug("rnn weights check failed, skip")
return REWRITER_RESULT.SKIP
rnn_props = RnnProperties()
res = self.get_var_initializers(match, rnn_props, rnn_scope_name)
if not res or not rnn_props.var_initializers.keys:
log.debug("no cell variable initializers found, skip")
return REWRITER_RESULT.SKIP
seq_len_input_node = self.find_sequence_length_node(rnn_scope_name)
input_filter = self.get_rnn_input_blacklist(rnn_weights, rnn_props)
if seq_len_input_node:
input_filter.append(seq_len_input_node)
self.find_inputs(rnn_scope_name, rnn_props, match, input_filter)
if not rnn_props.is_valid():
log.debug("rnn properties are not valid, skip")
return REWRITER_RESULT.SKIP
if not self.process_input_x(rnn_props, rnn_scope_name):
log.debug("rnn input x not found, skip")
return REWRITER_RESULT.SKIP
self.print_step("process the weights/bias/ft_bias, to fit onnx weights/bias requirements")
# Fragment: tail of a TF while-loop -> ONNX Loop rewrite. The enclosing def, the
# matching `try:`, and the tensor-array loop that defines input_ta/data_node are
# outside this view.
# Presumably rewires every body-graph input that referenced input_ta.consumer.id
# to data_node's first output (inferred from replace_all_inputs' name/arguments).
loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id, data_node.output[0])
# Create the Loop node (this variant also passes the initial loop condition output).
loop_node = self._create_loop_node(context, loop_props, init_cond_output)
# A falsy result from _create_loop_node is treated as a hard failure.
if not loop_node:
logger.error("failed to create loop node during rewrite")
return REWRITER_RESULT.FAIL
# Attach the rewritten body graph as the Loop node's "body" attribute.
loop_node.set_body_graph_as_attr("body", loop_body_g)
logger.debug("rewrite successfully")
return REWRITER_RESULT.OK
# Any exception raised in the try body is logged with its traceback and
# converted into REWRITER_RESULT.FAIL instead of propagating.
except Exception as ex:
tb = traceback.format_exc()
logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
return REWRITER_RESULT.FAIL
# Fragment: tail of a custom-RNN -> ONNX Scan rewrite (variant logging via
# `logger`). The enclosing def and the matching `try:` are outside this view;
# scan_props, scan_body_g, context and the *_initial_values sequences are defined earlier.
# Register the Scan body graph inputs: all state inputs first, then all scan inputs.
for input_tensor_info in scan_props.state_inputs:
scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)
for input_tensor_info in scan_props.scan_inputs:
scan_body_g.add_graph_input(input_tensor_info.id, input_tensor_info.dtype, input_tensor_info.shape)
# Initial values are concatenated in the same order (state values, then scan values).
# NOTE(review): presumably this order must line up with the body-graph inputs
# registered above — confirm against _create_scan_node.
scan_node = self._create_scan_node(context, scan_props,
state_inputs_initial_values + scan_inputs_initial_values)
# A falsy result from _create_scan_node is treated as a hard failure.
if not scan_node:
logger.error("failed to create scan node during rewrite")
return REWRITER_RESULT.FAIL
# Attach the body graph as the node's "body" attribute; _connect_scan_with_output
# presumably wires the Scan outputs to the rest of the graph (inferred from its name).
scan_node.set_body_graph_as_attr("body", scan_body_g)
self._connect_scan_with_output(context, scan_node)
return REWRITER_RESULT.OK
# Any exception raised in the try body is logged with its full traceback and
# converted into REWRITER_RESULT.FAIL instead of propagating.
except Exception as ex:
tb = traceback.format_exc()
logger.error("custom rnn rewrite failed, due to exception: %s, details:%s", ex, tb)
return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Fallback rewrite hook that never performs a rewrite.

    ``context`` is accepted only to keep the hook's signature; it is not read.
    Presumably concrete rewriters override this method.

    Returns:
        REWRITER_RESULT.FAIL, unconditionally.
    """
    failure = REWRITER_RESULT.FAIL
    return failure
most found info is stored in "rnn_props"
"""
log.debug("=========================")
self.print_step("start handling a new potential rnn cell")
self.all_nodes = self.g.get_nodes()
# FIXME:
# pylint: disable=assignment-from-none,assignment-from-no-return
# when bi-directional, node in while will be rnnxx/fw/fw/while/... >> scope name is rnnxx/fw/fw
# when single direction, node in while will be rnnxx/while/... >> scope name is rnnxx
# and rnnxx can be assigned by users but not "fw", though maybe "FW" in another tf version
rnn_scope_name = self.get_rnn_scope_name(match)
if not rnn_scope_name:
log.debug("unable to find rnn scope name, skip")
return REWRITER_RESULT.SKIP
log.debug("rnn scope name is %s", rnn_scope_name)
self.print_step("get_weight_and_bias starts")
rnn_weights = self.get_weight_and_bias(match)
if not rnn_weights:
log.debug("rnn weights check failed, skip")
return REWRITER_RESULT.SKIP
rnn_props = RnnProperties()
res = self.get_var_initializers(match, rnn_props, rnn_scope_name)
if not res or not rnn_props.var_initializers.keys:
log.debug("no cell variable initializers found, skip")
return REWRITER_RESULT.SKIP
seq_len_input_node = self.find_sequence_length_node(rnn_scope_name)
input_filter = self.get_rnn_input_blacklist(rnn_weights, rnn_props)
# self.g.get_nodes may change inside this loop so that we parse all LoopCond first
# (loopcond_ops is collected before this point, outside this view).
for op in loopcond_ops:
logger.debug("======================\n handling loop cond node called %s", op.name)
# A fresh context is created per LoopCond op and seeded with that op.
context = self.create_context()
context.loop_cond = op
self._check_in_read_only_mode(context)
# Only ops for which need_rewrite(...) is truthy get their subgraphs cropped
# and are passed to rewrite().
if self.need_rewrite(context):
# cut off connection between cell/cond graphs and useless nodes like Merge, NextIteration.
self._cut_off_connection_for_cell(context)
context.cell_graph = self._crop_loop_body_sub_graph(context)
context.cond_graph = self._crop_loop_condition_sub_graph(context)
# Dispatch on the rewrite result: OK -> log; SKIP -> log and continue with the
# next op; FAIL -> raise ValueError. Any other value falls through silently.
_result = self.rewrite(context)
if _result == REWRITER_RESULT.OK:
logger.debug("rewrite successfully")
elif _result == REWRITER_RESULT.SKIP:
logger.debug("rewrite skipped for LoopCond called %s", op.name)
continue
elif _result == REWRITER_RESULT.FAIL:
raise ValueError("rewrite failed, so just fast fail it")
# NOTE(review): indentation is lost in this view; presumably the cleanup below
# runs once after the loop, not per op — confirm against the original layout.
if self.g.outputs:
# clean the graph based on output names.
self.g.delete_unused_nodes(self.g.outputs)
return self.g.get_nodes()
# Fragment: tail of a TF while-loop -> ONNX Loop rewrite. The enclosing def, the
# matching `try:`, and loop_props/loop_body_g/context are outside this view.
for input_ta in loop_props.tensor_array_inputs:
# Loop does not have scan inputs, so we use Gather to get data for each iteration.
# Unsqueeze the per-iteration index with axes=[0], Gather from the tensor-array
# data with it, then Squeeze axes=[0]; presumably this yields the current element
# without the extra leading dimension.
# NOTE(review): "axes" as an attribute of Unsqueeze/Squeeze is only valid before
# ONNX opset 13 (where it became an input) — confirm the target opset.
index_node = loop_body_g.make_node("Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
data_node = loop_body_g.make_node("Squeeze", [gather_node.output[0]], attr={"axes": [0]})
# Presumably rewires body-graph inputs that referenced input_ta.consumer.id to
# the squeezed element (inferred from replace_all_inputs' name/arguments).
loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id, data_node.output[0])
# Create the Loop node.
loop_node = self._create_loop_node(context, loop_props)
# A falsy result from _create_loop_node is treated as a hard failure.
if not loop_node:
logger.error("failed to create loop node during rewrite")
return REWRITER_RESULT.FAIL
# Attach the rewritten body graph as the Loop node's "body" attribute.
loop_node.set_body_graph_as_attr("body", loop_body_g)
logger.debug("rewrite successfully")
return REWRITER_RESULT.OK
# Any exception raised in the try body is logged with its traceback and
# converted into REWRITER_RESULT.FAIL instead of propagating.
except Exception as ex:
tb = traceback.format_exc()
logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Default rewrite hook: reports failure without touching the graph.

    Args:
        context: rewrite context for the current loop; unused here.

    Returns:
        Always REWRITER_RESULT.FAIL. Presumably subclasses override this hook.
    """
    outcome = REWRITER_RESULT.FAIL
    return outcome