Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
neg_alpha_op = None
for name in mul_op.input:
op = graph_rewrite_util.node_from_map(input_node_map, name)
if op.op == 'Const':
neg_alpha_op = op
break
if not neg_alpha_op:
continue
alpha_tensor_name = neg_alpha_op.name
_create_alpha_node(neg_alpha_op, updated_alpha)
relu_neg_input_op = None
for name in mul_op.input:
op = graph_rewrite_util.node_from_map(input_node_map, name)
if op.op == 'Relu':
relu_neg_input_op = op
break
if (not relu_neg_input_op or len(relu_neg_input_op.input) != 1 or
relu_neg_input_op.op != 'Relu'):
continue
# This detects a Neg op followed by a separated Relu op.
neg_input_op = graph_rewrite_util.node_from_map(
input_node_map, relu_neg_input_op.input[0])
if (not neg_input_op or len(neg_input_op.input) != 1 or
neg_input_op.op != 'Neg'):
continue
final_input_op = neg_input_op
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
# Follow the second input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[1])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
if bias and bias.op != 'Const':
tf_logging.warning("The bias %s after the conv %s was not a constant. "
"Maybe because freeze_graph wasn't "
"run first?" % (bias.name, conv_op.name))
continue
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
" input to '%s'" % node.name)
continue
weights_op = graph_rewrite_util.node_from_map(
input_node_map, conv_op.input[1])
if weights_op.op != "Const":
tf_logging.warning("Didn't find expected conv Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (conv_op.name, weights_op))
continue
weights = graph_rewrite_util.values_from_const(weights_op)
if conv_op.op == "Conv2D":
channel_count = weights.shape[3]
elif conv_op.op == "DepthwiseConv2dNative":
channel_count = weights.shape[2] * weights.shape[3]
mean_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("mean_op")])
if mean_op.op != "Const":
for node in input_graph_def.node:
if (node.op not in ('Add', 'AddV2') or len(node.input) != 2):
continue
relu_input_op = graph_rewrite_util.node_from_map(
input_node_map, node.input[0])
if (not relu_input_op or relu_input_op.op != 'Relu'):
continue
mul_op = graph_rewrite_util.node_from_map(input_node_map, node.input[1])
if (not mul_op or mul_op.op != 'Mul'):
continue
neg_alpha_op = None
for name in mul_op.input:
op = graph_rewrite_util.node_from_map(input_node_map, name)
if op.op == 'Const':
neg_alpha_op = op
break
if not neg_alpha_op:
continue
alpha_tensor_name = neg_alpha_op.name
_create_alpha_node(neg_alpha_op, updated_alpha)
relu_neg_input_op = None
for name in mul_op.input:
op = graph_rewrite_util.node_from_map(input_node_map, name)
if op.op == 'Relu':
relu_neg_input_op = op
break
tf_logging.warning("Didn't find expected mean Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (node.name, mean_op))
continue
mean_value = graph_rewrite_util.values_from_const(mean_op)
if bias is not None:
# Adjust the mean of the batchnorm based on the add op in-between the conv
# and the batchnorm.
mean_value = mean_value - graph_rewrite_util.values_from_const(bias)
if mean_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
" for node %s" % (str(mean_value.shape), str(
(channel_count,)), node.name))
continue
var_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("var_op")])
if var_op.op != "Const":
tf_logging.warning("Didn't find expected var Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (node.name, var_op))
continue
var_value = graph_rewrite_util.values_from_const(var_op)
if var_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for var, found %s, expected %s,"
" for node %s" % (str(var_value.shape), str(
(channel_count,)), node.name))
continue
beta_op = graph_rewrite_util.node_from_map(
input_node_map,
conv_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("conv_op")])
# There might be an Add/BiasAdd op between the conv and the batchnorm,
# which we can fold into the mean param of the batchnorm.
if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
add_op = conv_op
# Follow the first input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[0])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
# Follow the second input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[1])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
if bias and bias.op != 'Const':
tf_logging.warning("The bias %s after the conv %s was not a constant. "
"Maybe because freeze_graph wasn't "
"run first?" % (bias.name, conv_op.name))
continue
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
" input to '%s'" % node.name)
continue
weights_op = graph_rewrite_util.node_from_map(
input_node_map, conv_op.input[1])
if weights_op.op != "Const":
tf_logging.warning("Didn't find expected conv Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (conv_op.name, weights_op))
continue
bias = None
conv_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("conv_op")])
# There might be an Add/BiasAdd op between the conv and the batchnorm,
# which we can fold into the mean param of the batchnorm.
if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
add_op = conv_op
# Follow the first input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[0])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
# Follow the second input of the add to get to the conv.
conv_op = graph_rewrite_util.node_from_map(
input_node_map, add_op.input[1])
bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
if bias and bias.op != 'Const':
tf_logging.warning("The bias %s after the conv %s was not a constant. "
"Maybe because freeze_graph wasn't "
"run first?" % (bias.name, conv_op.name))
continue
if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
" input to '%s'" % node.name)
continue
weights_op = graph_rewrite_util.node_from_map(
input_node_map, conv_op.input[1])
if weights_op.op != "Const":
tf_logging.warning("Didn't find expected conv Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
ValueError: If the graph is badly formed with duplicate node names.
"""
input_node_map = {}
nodes_to_skip = {}
inputs_to_remove = []
for node in input_graph_def.node:
if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError('Duplicate node names detected for ', node.name)
for node in input_graph_def.node:
if node.op != 'Prelu':
continue
fused_conv_op = graph_rewrite_util.node_from_map(
input_node_map, node.input[0])
if (not fused_conv_op or
(fused_conv_op.op != '_FusedConv2D'
and fused_conv_op.op != 'FusedDepthwiseConv2dNative') or
len(fused_conv_op.attr['fused_ops'].list.s) > 1):
continue
alpha_tensor_name = node.input[1]
fused_conv_op.input.extend([alpha_tensor_name])
fused_conv_op.attr['fused_ops'].list.s.extend([b'Prelu'])
fused_conv_op.attr['num_args'].i = fused_conv_op.attr['num_args'].i + 1
node.op = 'Identity'
node.input[:] = [node.input[0]]
nodes_to_skip[node.name] = True
inputs_to_remove.append(node)