Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def build(self):
"""
Build the node map and topological ordering for the ONNX model.

Every entry of ``self.model.node`` is wrapped in an ONNXGraphNode and
stored in ``self.node_map`` under its name. Entries of
``self.model.input`` that are not yet in the map are stored as
ONNXGraphDataNode objects, and ``self.model.initializer`` is then
walked to attach weights.

NOTE(review): this excerpt has lost its indentation and stops right
after the final ``isinstance`` check; restore nesting from the original
file before relying on it.
"""
# Register one graph node per ONNX operator, keyed by layer name.
for layer in self.model.node:
node = ONNXGraphNode(layer)
self.node_map[layer.name] = node
# Register graph inputs that no operator node already claimed.
for layer in self.model.input:
if layer.name not in self.node_map:
# The result of is_place_holder_nodes() is passed on as is_global_input.
is_place_holder = self.is_place_holder_nodes(layer.name)
self.node_map[layer.name] = ONNXGraphDataNode(
layer,
layer_name=layer.name,
is_global_input=is_place_holder)
# Set each data node's weight from the model initializers.
for initializer in self.model.initializer:
name = initializer.name
# Convert the initializer with to_array() before storing it.
weight = to_array(initializer)
if name in self.node_map:
if isinstance(self.node_map[name], ONNXGraphDataNode):
def __init__(self, layer, layer_name=None):
    """
    Initialise a graph node wrapping ``layer``.

    Args:
        layer: object exposing ``name`` and ``op_type`` attributes.
        layer_name: optional name passed to the parent initialiser;
            ``layer.name`` is used when this is None.
    """
    # Resolve the name once instead of branching around two super() calls.
    node_name = layer.name if layer_name is None else layer_name
    super(ONNXGraphNode, self).__init__(layer, node_name)
    self.layer_type = layer.op_type
    self.fluid_code = FluidCode()
    # get_attr_map() runs after the parent initialiser, as in the original.
    self.attr_map = self.get_attr_map()
    self.out_shapes = []
    self.dtype = None
    self.which_child = dict()
def _interpolate(self, node):
"""
Gather output shape and scale information for an interpolation node.

Reads input 0 and input 1 of ``node`` and the node registered for its
first output, then derives ``out_shape`` and ``scale`` from them.

NOTE(review): this excerpt has lost its indentation and ends before any
layer is emitted; the remainder of the method is not visible here.
"""
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = val_y.out_shapes[0]
if out_shape is not None:
assert len(out_shape) == 4, 'only 4-D Tensor as X and Y supported'
# Keep only the last two entries of the 4-element shape.
# presumably these are H and W (the messages below mention (NC)HW) - confirm
out_shape = out_shape[2:]
# Start from the constant weight of the scales input, if it has one.
scales = _const_weight_or_none(val_scales)
# For an ONNXGraphNode input, the first value returned by
# get_dynamic_shape() replaces the constant lookup above.
if isinstance(val_scales, ONNXGraphNode):
scales, _, _ = self.get_dynamic_shape(val_scales.layer_name)
attr = {'name': string(node.layer_name)}
use_scales = True
if scales is not None:
# Any failure inside this try only clears use_scales; nothing is re-raised.
# NOTE(review): the bare except also hides non-assertion errors such as
# TypeError or KeyboardInterrupt - consider narrowing it.
try:
assert len(scales) == 4, 'only 4-D Tensor as X and Y supported'
assert scales[0] == 1 and scales[
1] == 1, 'only scale on (NC)HW supported'
assert scales[2] == scales[
3], 'only aspect-ratio-invariant scale supported'
except:
use_scales = False
# NOTE(review): `if scales` is a truth test; if scales can be a
# multi-element numpy array this raises ValueError - TODO confirm its type.
scale = scales[2] if scales else None
if scale is None:
# With no usable scale, an output shape is required.
assert out_shape, 'neither scales nor output shape is available'
def Reshape(self, node):
"""
Resolve the target shape for a Reshape node.

Reads input 0 (data) and input 1 (shape) of ``node`` plus the node
registered for its first output, and works out ``shape`` by trying the
dynamic shape, then the recorded output shape, then ``[1, -1]``.

NOTE(review): this excerpt has lost its indentation and ends before any
reshape layer is emitted; which ``if`` the ``else`` below belongs to
cannot be determined from here.
"""
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
shape = None
# A data-node shape input is recorded in self.omit_nodes.
if isinstance(val_shape, ONNXGraphDataNode):
self.omit_nodes.append(val_shape.layer_name)
attr = {'name': string(node.layer_name)}
# Catch a shape produced dynamically by another graph node.
if isinstance(val_shape, ONNXGraphNode):
shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
# An int64 shape input gets a 'cast' layer to int32, and the cast output
# name is what ends up in attr['actual_shape'].
if val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer('cast',
inputs=val_shape,
output=val_shape_cast,
param_attr={'dtype': string('int32')})
attr['actual_shape'] = val_shape_cast
else:
attr['actual_shape'] = val_shape
# Fallback 1: the first recorded output shape of the reshaped node.
if shape is None:
shape = val_reshaped.out_shapes[0]
# Fallback 2: a fixed [1, -1] shape when nothing else is known.
if shape is None:
shape = [1, -1]
# NOTE(review): fragment without its enclosing `def` or loop header and with
# indentation lost; it appears to continue the initializer loop of build() -
# confirm against the original file.
name = initializer.name
weight = to_array(initializer)
if name in self.node_map:
# Existing data nodes get the weight attached and embeded_as reset to [].
if isinstance(self.node_map[name], ONNXGraphDataNode):
self.node_map[name].weight = weight
self.node_map[name].embeded_as = []
else:
# Otherwise a new ONNXGraphDataNode is stored under this name with
# is_global_input=False, and the weight is attached to it.
self.node_map[name] = ONNXGraphDataNode(initializer,
layer_name=name,
is_global_input=False)
self.node_map[name].weight = weight
self.node_map[name].embeded_as = []
# Generate connections between nodes for the topological sort.
for layer_name, node in self.node_map.items():
# Only ONNXGraphNode entries are passed to build_connection().
if isinstance(node, ONNXGraphNode):
self.build_connection(layer_name, node)
# Generate the topological order via the parent class.
super(ONNXGraph, self).build()
# The graph's input nodes are its place-holder nodes.
self.input_nodes = self.place_holder_nodes