Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, layer, layer_name=None, is_global_input=False):
    """Wrap a graph input or initializer as a data node.

    Args:
        layer: Object being wrapped; its ``name`` attribute is used as the
            node name when ``layer_name`` is not given.
        layer_name: Name for the node. Defaults to ``layer.name``.
        is_global_input: When true, ``layer_type`` is ``'place_holder'``;
            otherwise it is ``'create_parameter'``.
    """
    # Resolve the fallback once so the parent constructor and the
    # ``layer_name`` attribute agree. Previously an omitted layer_name was
    # passed to the parent as layer.name but then stored here as None.
    if layer_name is None:
        layer_name = layer.name
    super(ONNXGraphDataNode, self).__init__(layer, layer_name)

    if is_global_input:
        self.layer_type = 'place_holder'
    else:
        self.layer_type = 'create_parameter'

    # Assigned after the parent constructor, as before, so the name is kept
    # exactly as supplied.
    self.layer_name = layer_name
    self.fluid_code = FluidCode()
    # weight and embeded_as start unset here; they are assigned elsewhere.
    self.weight = None
    self.embeded_as = None
    self.which_child = {}
is_place_holder = self.is_place_holder_nodes(layer.name)
self.node_map[layer.name] = ONNXGraphDataNode(
layer,
layer_name=layer.name,
is_global_input=is_place_holder)
#set data node's weight
for initializer in self.model.initializer:
name = initializer.name
weight = to_array(initializer)
if name in self.node_map:
if isinstance(self.node_map[name], ONNXGraphDataNode):
self.node_map[name].weight = weight
self.node_map[name].embeded_as = []
else:
self.node_map[name] = ONNXGraphDataNode(initializer,
layer_name=name,
is_global_input=False)
self.node_map[name].weight = weight
self.node_map[name].embeded_as = []
#generate connection between nodes for topo
for layer_name, node in self.node_map.items():
if isinstance(node, ONNXGraphNode):
self.build_connection(layer_name, node)
#generate topo
super(ONNXGraph, self).build()
self.input_nodes = self.place_holder_nodes
def build(self):
    """Populate ``node_map`` from the ONNX model and build its topo sort.

    Steps, in order:
      1. Register every entry of ``self.model.node`` as an ONNXGraphNode.
      2. Register every entry of ``self.model.input`` that is not already
         present as an ONNXGraphDataNode.
      3. Attach each initializer's array to its data node, creating the
         data node when the initializer has no entry yet.
      4. Build the connections of every ONNXGraphNode and run the parent
         class's ``build``.
    """
    for layer in self.model.node:
        self.node_map[layer.name] = ONNXGraphNode(layer)

    for layer in self.model.input:
        if layer.name in self.node_map:
            continue
        self.node_map[layer.name] = ONNXGraphDataNode(
            layer,
            layer_name=layer.name,
            is_global_input=self.is_place_holder_nodes(layer.name))

    # set data node's weight
    for initializer in self.model.initializer:
        name = initializer.name
        weight = to_array(initializer)
        if name not in self.node_map:
            self.node_map[name] = ONNXGraphDataNode(
                initializer, layer_name=name, is_global_input=False)
        elif not isinstance(self.node_map[name], ONNXGraphDataNode):
            # The name is already taken by a non-data node; leave it alone.
            continue
        # A node created just above must receive the weight as well;
        # otherwise the converted array is silently discarded.
        data_node = self.node_map[name]
        data_node.weight = weight
        data_node.embeded_as = []

    # generate connection between nodes for topo
    for layer_name, node in self.node_map.items():
        if isinstance(node, ONNXGraphNode):
            self.build_connection(layer_name, node)

    # generate topo
    super(ONNXGraph, self).build()
def Reshape(self, node):
# Begin handling a Reshape node: look up its inputs and start filling `attr`.
# NOTE(review): only the first part of this method is visible here; the
# code that consumes `attr`, `shape`, `val_x` and `val_reshaped` is not shown.
# Input 0 is fetched as val_x and input 1 as val_shape; val_reshaped is the
# node registered under the first output name.
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
shape = None
# A shape provided by a data node has its name appended to omit_nodes.
if isinstance(val_shape, ONNXGraphDataNode):
self.omit_nodes.append(val_shape.layer_name)
attr = {'name': string(node.layer_name)}
# catch dynamic graph shape
if isinstance(val_shape, ONNXGraphNode):
# Only the first of the three values returned by get_dynamic_shape is kept.
shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
# An int64 shape gets a 'cast' layer to int32, and the cast output name is
# stored as attr['actual_shape'].
if val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer('cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
attr['actual_shape'] = val_shape_cast
# NOTE(review): indentation is lost, so it is unclear which `if` this `else`
# pairs with; presumably the dtype check directly above - confirm.
else:
attr['actual_shape'] = val_shape
self.node_map[layer.name] = node
for layer in self.model.input:
if layer.name not in self.node_map:
is_place_holder = self.is_place_holder_nodes(layer.name)
self.node_map[layer.name] = ONNXGraphDataNode(
layer,
layer_name=layer.name,
is_global_input=is_place_holder)
#set data node's weight
for initializer in self.model.initializer:
name = initializer.name
weight = to_array(initializer)
if name in self.node_map:
if isinstance(self.node_map[name], ONNXGraphDataNode):
self.node_map[name].weight = weight
self.node_map[name].embeded_as = []
else:
self.node_map[name] = ONNXGraphDataNode(initializer,
layer_name=name,
is_global_input=False)
self.node_map[name].weight = weight
self.node_map[name].embeded_as = []
#generate connection between nodes for topo
for layer_name, node in self.node_map.items():
if isinstance(node, ONNXGraphNode):
self.build_connection(layer_name, node)
#generate topo
super(ONNXGraph, self).build()