Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
alpha_tensor_name = neg_alpha_op.name
_create_alpha_node(neg_alpha_op, updated_alpha)
relu_neg_input_op = None
for name in mul_op.input:
op = graph_rewrite_util.node_from_map(input_node_map, name)
if op.op == 'Relu':
relu_neg_input_op = op
break
if (not relu_neg_input_op or len(relu_neg_input_op.input) != 1 or
relu_neg_input_op.op != 'Relu'):
continue
# This detects a Neg op followed by a separated Relu op.
neg_input_op = graph_rewrite_util.node_from_map(
input_node_map, relu_neg_input_op.input[0])
if (not neg_input_op or len(neg_input_op.input) != 1 or
neg_input_op.op != 'Neg'):
continue
final_input_op = neg_input_op
if relu_input_op.input[0] != final_input_op.input[0]:
continue
relu_input_op.op = 'Prelu'
relu_input_op.input.extend([alpha_tensor_name])
# Remove the T attr that is defined in Relu op, since our custom Prelu op
# definition does not have that.
del relu_input_op.attr['T']
node.op = 'Identity'
"""
input_node_map = {}
for node in input_graph_def.node:
if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError('Duplicate node names detected for ', node.name)
nodes_to_skip = {}
inputs_to_remove = []
updated_alpha = []
for node in input_graph_def.node:
if (node.op not in ('Add', 'AddV2') or len(node.input) != 2):
continue
relu_input_op = graph_rewrite_util.node_from_map(
input_node_map, node.input[0])
if (not relu_input_op or relu_input_op.op != 'Relu'):
continue
mul_op = graph_rewrite_util.node_from_map(input_node_map, node.input[1])
if (not mul_op or mul_op.op != 'Mul'):
continue
neg_alpha_op = None
for name in mul_op.input:
op = graph_rewrite_util.node_from_map(input_node_map, name)
if op.op == 'Const':
neg_alpha_op = op
break
if not neg_alpha_op:
def _find_contraction_with_bias(node, node_map):
if node.op != 'BiasAdd':
return False
# Input to the BiasAdd must be a DepthwiseConv2dNative.
if not node.input:
return False
conv2d_node = graph_rewrite_util.node_from_map(node_map, node.input[0])
if conv2d_node.op != 'DepthwiseConv2dNative':
return False
return {'contraction': conv2d_node, 'bias': node, 'activation': None}
new_ops = []
for node in input_graph_def.node:
if (node.op not in ("BatchNormWithGlobalNormalization",
"FusedBatchNorm", "FusedBatchNormV3")):
continue
bias = None
conv_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("conv_op")])
# There might be an Add/BiasAdd op between the conv and the batchnorm,
# which we can fold into the mean param of the batchnorm.
if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
add_op = conv_op
# Follow the first input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[0])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
# Follow the second input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[1])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
if bias and bias.op != 'Const':
tf_logging.warning("The bias %s after the conv %s was not a constant. "
"Maybe because freeze_graph wasn't "
"run first?" % (bias.name, conv_op.name))
continue
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
" input to '%s'" % node.name)
continue
def _add_fused_contraction_node(contraction, bias_add, activation,
                                inputs_to_remove, nodes_to_skip):
  """Folds a BiasAdd (and an optional activation) into `contraction` in place.

  The contraction node is mutated:
  - It gains the BiasAdd's second input.
  - Its op becomes `graph_rewrite_util.FUSED_DEPTHWISE_CONV2D`.
  - 'BiasAdd' is appended to its 'fused_ops' attr, followed by the
    activation's op name when an activation is given.
  - Its 'num_args' attr is incremented by one.

  The folded BiasAdd and activation nodes are re-pointed at the contraction
  and recorded in `inputs_to_remove` / `nodes_to_skip`.

  Args:
    contraction: The contraction node to turn into the fused op (the
      DepthwiseConv2dNative found by `_find_contraction_with_bias`). Mutated.
    bias_add: The BiasAdd node being folded in. Its inputs are overwritten.
    activation: Optional activation node applied after the bias, or a falsy
      value when there is none.
    inputs_to_remove: List of folded nodes. Appended to.
    nodes_to_skip: Dict of node name -> True for folded nodes. Updated.
  """
  # Read the bias operand before the BiasAdd's input list is rewritten below.
  bias_operand = bias_add.input[1]
  contraction.input.append(bias_operand)
  contraction.op = graph_rewrite_util.FUSED_DEPTHWISE_CONV2D
  contraction.attr['fused_ops'].list.s.append(b'BiasAdd')
  # Incremented once per extra input appended above.
  # NOTE(review): presumably the fused op's extra-argument count; confirm.
  contraction.attr['num_args'].i += 1

  fused_name = contraction.name
  # Detach the BiasAdd from its original inputs; it now reads only from the
  # fused node.
  bias_add.input[:] = [fused_name]

  if activation:
    contraction.attr['fused_ops'].list.s.append(activation.op.encode('ascii'))
    nodes_to_skip[activation.name] = True
    activation.input[:] = [fused_name]
    inputs_to_remove.append(activation)

  inputs_to_remove.append(bias_add)
  nodes_to_skip[bias_add.name] = True
if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError('Duplicate node names detected for ', node.name)
nodes_to_skip = {}
inputs_to_remove = []
for node in input_graph_def.node:
nodes = match_function(node, input_node_map)
if nodes:
_add_fused_contraction_node(nodes['contraction'], nodes['bias'],
nodes['activation'], inputs_to_remove,
nodes_to_skip)
if nodes_to_skip or inputs_to_remove:
return graph_rewrite_util.cleanup_graph_def(
input_graph_def, nodes_to_skip, inputs_to_remove)
# No pattern detected
return input_graph_def
input_node_map = {}
for node in input_graph_def.node:
if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError("Duplicate node names detected for ", node.name)
nodes_to_skip = {}
new_ops = []
for node in input_graph_def.node:
if (node.op not in ("BatchNormWithGlobalNormalization",
"FusedBatchNorm", "FusedBatchNormV3")):
continue
bias = None
conv_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("conv_op")])
# There might be an Add/BiasAdd op between the conv and the batchnorm,
# which we can fold into the mean param of the batchnorm.
if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
add_op = conv_op
# Follow the first input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[0])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
# Follow the second input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[1])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
if bias and bias.op != 'Const':