How to use the tensorflowjs.converters.graph_rewrite_util module in tensorflowjs

To help you get started, we’ve selected a few tensorflowjs examples, based on popular ways it is used in public projects.

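The snippets below all follow the same skeleton: index every node of a frozen GraphDef by name, resolve producer nodes with graph_rewrite_util.node_from_map, rewrite matched patterns in place, and finally drop dead nodes with graph_rewrite_util.cleanup_graph_def. A minimal sketch of that skeleton, assuming a frozen model at a hypothetical path:

import tensorflow as tf
from tensorflowjs.converters import graph_rewrite_util

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('frozen_model.pb', 'rb') as f:  # hypothetical path
  graph_def.ParseFromString(f.read())

# Index nodes by name; every rewrite below starts from such a map.
input_node_map = {}
for node in graph_def.node:
  if node.name not in input_node_map:
    input_node_map[node.name] = node
  else:
    raise ValueError('Duplicate node names detected for %s' % node.name)

# node_from_map resolves an input string (e.g. 'conv/BiasAdd') back to the
# NodeDef that produces it.
for node in graph_def.node:
  if node.op == 'BiasAdd':
    producer = graph_rewrite_util.node_from_map(input_node_map, node.input[0])
    print('%s is fed by a %s node' % (node.name, producer.op))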

github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / fuse_prelu.py View on GitHub
    alpha_tensor_name = neg_alpha_op.name
    _create_alpha_node(neg_alpha_op, updated_alpha)

    # Find the Relu feeding the Mul (the relu(-x) branch of the pattern).
    relu_neg_input_op = None
    for name in mul_op.input:
      op = graph_rewrite_util.node_from_map(input_node_map, name)
      if op.op == 'Relu':
        relu_neg_input_op = op
        break

    if (not relu_neg_input_op or len(relu_neg_input_op.input) != 1 or
        relu_neg_input_op.op != 'Relu'):
      continue

    # This detects a Neg op followed by a separated Relu op.
    neg_input_op = graph_rewrite_util.node_from_map(
        input_node_map, relu_neg_input_op.input[0])
    if (not neg_input_op or len(neg_input_op.input) != 1 or
        neg_input_op.op != 'Neg'):
      continue
    final_input_op = neg_input_op

    if relu_input_op.input[0] != final_input_op.input[0]:
      continue

    relu_input_op.op = 'Prelu'
    relu_input_op.input.extend([alpha_tensor_name])
    # Remove the T attr that is defined in Relu op, since our custom Prelu op
    # definition does not have that.
    del relu_input_op.attr['T']

    node.op = 'Identity'
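For context, this pass recognizes PReLU in its decomposed form, prelu(x) = relu(x) + (-alpha) * relu(-x); the Const matched below as neg_alpha_op holds the negated alpha. A quick numeric check of that identity (a sketch, not converter code):

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5])
alpha = 0.25
prelu = np.where(x > 0, x, alpha * x)                       # PReLU definition
pattern = np.maximum(x, 0) + (-alpha) * np.maximum(-x, 0)   # Relu/Mul/Neg graph
assert np.allclose(prelu, pattern)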
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / fuse_prelu.py View on GitHub
"""
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError('Duplicate node names detected for %s' % node.name)

  nodes_to_skip = {}
  inputs_to_remove = []
  updated_alpha = []
  for node in input_graph_def.node:
    if (node.op not in ('Add', 'AddV2') or len(node.input) != 2):
      continue

    relu_input_op = graph_rewrite_util.node_from_map(
        input_node_map, node.input[0])
    if (not relu_input_op or relu_input_op.op != 'Relu'):
      continue

    mul_op = graph_rewrite_util.node_from_map(input_node_map, node.input[1])
    if (not mul_op or mul_op.op != 'Mul'):
      continue

    # Find the (negated) alpha constant among the Mul's inputs.
    neg_alpha_op = None
    for name in mul_op.input:
      op = graph_rewrite_util.node_from_map(input_node_map, name)
      if op.op == 'Const':
        neg_alpha_op = op
        break

    if not neg_alpha_op:
      continue
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / fuse_depthwise_conv2d.py View on GitHub
def _find_contraction_with_bias(node, node_map):
  # Match the DepthwiseConv2dNative -> BiasAdd pattern anchored at `node`;
  # return the pieces of the pattern, or False if there is no match.
  if node.op != 'BiasAdd':
    return False

  # Input to the BiasAdd must be a DepthwiseConv2dNative.
  if not node.input:
    return False

  conv2d_node = graph_rewrite_util.node_from_map(node_map, node.input[0])
  if conv2d_node.op != 'DepthwiseConv2dNative':
    return False

  return {'contraction': conv2d_node, 'bias': node, 'activation': None}
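A sketch of how a matcher like this is consumed; the driver loop shown further below receives it as match_function (input_graph_def and input_node_map as built in the snippets above):

for node in input_graph_def.node:
  nodes = _find_contraction_with_bias(node, input_node_map)
  if nodes:
    print('fusable: %s feeding %s' %
          (nodes['contraction'].name, nodes['bias'].name))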
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / fold_batch_norms.py View on GitHub
  new_ops = []
  for node in input_graph_def.node:
    if (node.op not in ("BatchNormWithGlobalNormalization",
                        "FusedBatchNorm", "FusedBatchNormV3")):
      continue

    bias = None
    conv_op = graph_rewrite_util.node_from_map(
        input_node_map,
        node.input[INPUT_ORDER[node.op].index("conv_op")])
    # There might be an Add/BiasAdd op between the conv and the batchnorm,
    # which we can fold into the mean param of the batchnorm.
    if conv_op.op in ['BiasAdd', 'Add', 'AddV2']:
      add_op = conv_op
      # Follow the first input of the add to get to the conv.
      conv_op = graph_rewrite_util.node_from_map(
          input_node_map, add_op.input[0])
      bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[1])
      if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
        # Follow the second input of the add to get to the conv.
        conv_op = graph_rewrite_util.node_from_map(
            input_node_map, add_op.input[1])
        bias = graph_rewrite_util.node_from_map(input_node_map, add_op.input[0])
    if bias and bias.op != 'Const':
      tf_logging.warning("The bias %s after the conv %s was not a constant. "
                         "Maybe because freeze_graph wasn't "
                         "run first?" % (bias.name, conv_op.name))
      continue
    if conv_op.op not in ["Conv2D", "DepthwiseConv2dNative"]:
      tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
                         " input to '%s'" % node.name)
      continue
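INPUT_ORDER, defined elsewhere in fold_batch_norms.py, maps each supported batch-norm op to the positional order of its inputs. Roughly, following the upstream TensorFlow optimize-for-inference code this pass is based on (a sketch; check the tfjs source for the exact definition):

INPUT_ORDER = {
    # Order of inputs for BatchNormWithGlobalNormalization.
    'BatchNormWithGlobalNormalization':
        ['conv_op', 'mean_op', 'var_op', 'beta_op', 'gamma_op'],
    # Order of inputs for FusedBatchNorm and FusedBatchNormV3.
    'FusedBatchNorm': ['conv_op', 'gamma_op', 'beta_op', 'mean_op', 'var_op'],
    'FusedBatchNormV3': ['conv_op', 'gamma_op', 'beta_op', 'mean_op', 'var_op'],
}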
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / fuse_depthwise_conv2d.py View on GitHub
def _add_fused_contraction_node(contraction, bias_add, activation,
                                inputs_to_remove, nodes_to_skip):
  # Rewrite the contraction node in place into a fused op that absorbs the
  # BiasAdd (and, if present, the activation that follows it).
  fused_op = contraction
  fused_op.input.extend([bias_add.input[1]])

  fused_op.op = graph_rewrite_util.FUSED_DEPTHWISE_CONV2D
  fused_op.attr['fused_ops'].list.s.extend([b'BiasAdd'])
  fused_op.attr['num_args'].i = fused_op.attr['num_args'].i + 1
  bias_add.input[:] = [contraction.name]

  if activation:
    fused_op.attr['fused_ops'].list.s.extend([activation.op.encode('ascii')])
    nodes_to_skip[activation.name] = True
    activation.input[:] = [contraction.name]
    inputs_to_remove.append(activation)

  inputs_to_remove.append(bias_add)
  nodes_to_skip[bias_add.name] = True
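After _add_fused_contraction_node runs, the contraction node itself carries the fused sequence. A sketch of inspecting the result (attribute values are illustrative):

fused = nodes['contraction']
print(fused.op)                        # graph_rewrite_util.FUSED_DEPTHWISE_CONV2D
print(fused.attr['fused_ops'].list.s)  # e.g. [b'BiasAdd', b'Relu']
print(fused.attr['num_args'].i)        # 1: the bias tensor appended above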
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / fuse_depthwise_conv2d.py View on GitHub
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name not in input_node_map:
      input_node_map[node.name] = node
    else:
      raise ValueError('Duplicate node names detected for %s' % node.name)

  nodes_to_skip = {}
  inputs_to_remove = []
  for node in input_graph_def.node:
    nodes = match_function(node, input_node_map)
    if nodes:
      _add_fused_contraction_node(nodes['contraction'], nodes['bias'],
                                  nodes['activation'], inputs_to_remove,
                                  nodes_to_skip)

  if nodes_to_skip or inputs_to_remove:
    return graph_rewrite_util.cleanup_graph_def(
        input_graph_def, nodes_to_skip, inputs_to_remove)

  # No pattern detected
  return input_graph_def
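Putting it together, the fusion pass is invoked once over the whole GraphDef. A hedged sketch, assuming the module exposes a top-level fuse_depthwise_conv2d(input_graph_def) entry point wrapping the driver loop above (mirroring fold_batch_norms; check the source for the exact name):

from tensorflowjs.converters import fuse_depthwise_conv2d

# graph_def: a tf.compat.v1.GraphDef, loaded as in the first sketch above.
# Assumption: fuse_depthwise_conv2d() is the module's public entry point.
optimized_graph_def = fuse_depthwise_conv2d.fuse_depthwise_conv2d(graph_def)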