Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_constant_fill(self):
    """ConstantFill with ``extra_shape`` fills a float32 tensor with ``value``.

    Skipped unless ``legacy_opset_pre_ver(9)`` is true (presumably: the
    installed opset predates 9, where ConstantFill is still available).
    The expected output has shape ``shape + extra_shape`` and every element
    equal to ``value``.
    """
    if not legacy_opset_pre_ver(9):
        raise unittest.SkipTest(
            "ONNX version {} doesn't support ConstantFill.".format(
                defs.onnx_opset_version()))
    shape = [1, 2, 3, 4]
    extra_shape = [5, 6]
    value = 3.
    node_def = helper.make_node(
        "ConstantFill",
        ["X"],
        ["Y"],
        value=value,
        extra_shape=extra_shape,
        dtype=1,  # TensorProto.FLOAT; matches the tf.float32 check below
    )
    x = self._get_rnd_float32(shape=shape)
    # Reference result: input shape followed by extra_shape, filled with value.
    y = np.zeros(shape + extra_shape)
    y.fill(value)
    output = run_node(node_def, [x])
    np.testing.assert_equal(output["Y"].dtype, tf.float32)
    # Previously `y` was computed but never compared, so wrong shapes or
    # fill values went undetected; verify the actual contents too.
    np.testing.assert_equal(output["Y"], y)
def test_compress(self):
    """Compress along axis 1 must agree with ``numpy.compress``.

    Skipped when ``legacy_opset_pre_ver(9)`` is true, i.e. on opsets that
    do not provide Compress.
    """
    if legacy_opset_pre_ver(9):
        raise unittest.SkipTest(
            "ONNX version {} doesn't support Compress.".format(
                defs.onnx_opset_version()))
    axis = 1
    node_def = helper.make_node("Compress",
                                inputs=['X', 'condition'],
                                outputs=['Y'],
                                axis=axis)
    x = self._get_rnd_float32(shape=[5, 5, 5])
    # The ONNX schema types `condition` as tensor(bool); an int64 array does
    # not match it. The condition is deliberately shorter than axis 1
    # (3 < 5): slices beyond its length are discarded by both ONNX and numpy.
    cond = np.array([1, 0, 1], dtype=bool)
    output = run_node(node_def, inputs=[x, cond])
    np.testing.assert_almost_equal(output['Y'], np.compress(cond, x, axis=axis))
def test_mul(self):
    """OnnxMul maps to the ONNX 'Mul' operator and builds an OnnxOperator."""
    from skl2onnx.algebra.onnx_ops import OnnxMul
    assert OnnxMul.operator_name == 'Mul'
    target_opset = onnx.defs.onnx_opset_version()
    mul_node = OnnxMul('a', 'b', op_version=target_opset)
    assert isinstance(mul_node, OnnxOperator)
def test_max_pool_2d_dilations_same_lower(self):
    """Pool a [10, 3, 24, 24] input with dilations and SAME_LOWER padding.

    Uses a 3x3 kernel, stride 2 and dilation 3 via ``self._test_pooling``.
    Skipped when ``legacy_opset_pre_ver(10)`` is true, since dilations are
    unsupported there.
    """
    if legacy_opset_pre_ver(10):
        raise unittest.SkipTest(
            "ONNX version {} doesn't support dilations.".format(
                defs.onnx_opset_version()))
    self._test_pooling(input_shape=[10, 3, 24, 24],
                       kernel_shape=[3, 3],
                       strides=[2, 2],
                       dilations=[3, 3],
                       auto_pad="same_lower")
def test_kernel_constant1(self):
    """The ONNX conversion of kernel ``C(5.)`` reproduces the kernel's output.

    ``C`` is presumably scikit-learn's ConstantKernel. The converted graph is
    run with onnxruntime on ``Xtest_`` and compared to a direct call of the
    kernel, to 5 decimals.
    """
    kernel = C(5.)
    graph = convert_kernel(kernel, 'X', output_names=['Y'], dtype=np.float32,
                           op_version=onnx_opset_version())
    model = graph.to_onnx(
        inputs=[('X', FloatTensorType([None, None]))], dtype=np.float32)
    session = InferenceSession(model.SerializeToString())
    feeds = {'X': Xtest_.astype(np.float32)}
    onnx_result = session.run(None, feeds)[0]
    assert_almost_equal(onnx_result, kernel(Xtest_), decimal=5)
# Scan (opset-8 form) test. The guard below skips it unless
# legacy_opset_pre_ver(9) is true and legacy_opset_pre_ver(8) is false
# (presumably: runs only on opset 8 exactly — TODO confirm helper semantics).
# NOTE(review): this copy is truncated after the final `if` — the node
# construction and assertions are not visible here.
def test_scan_v8(self):
if legacy_opset_pre_ver(8) or not legacy_opset_pre_ver(9):
raise unittest.SkipTest(
"ONNX version {} not supported.".format(
defs.onnx_opset_version()))
# Initial state: random ints of shape [5, 1], cast to float32.
initial = self._get_rnd_int(0, 100, shape=[5, 1]).astype(np.float32)
# Two scan inputs of identical shape [5, 20, 6, 2]; presumably axis 0 is
# batch and axis 1 the sequence — TODO confirm.
x1 = self._get_rnd_float32(0, 1000, shape=[5, 20, 6, 2])
x2 = self._get_rnd_float32(0, 1000, shape=[5, 20, 6, 2])
directions = [0, 1]
# One length per row of axis 0 (5 entries; x1/x2 have 5 rows on axis 0).
sequence_lens = np.array([15, 20, 14, 18, 20]).astype(np.int32)
# NOTE(review): `sequence_lens is str` compares the array to the *type* str,
# so it is always False here; `isinstance(sequence_lens, str)` was probably
# intended. As written, Y = initial + sequence_lens reshaped to [5, 1].
Y = initial + (np.shape(x1)[1] if sequence_lens is str else \
np.reshape(sequence_lens,[-1, 1]))
x1_out = x1 + 1
# left-right flip x2 (reverse direction)
x2_out = x2[:,::-1] + 1
Z = np.concatenate([x1_out, x2_out], 2)
# NOTE(review): same identity check against the type `str` as above — this
# condition is always True for the ndarray assigned above.
if sequence_lens is not str:
def generate_onnx_graph(dim, nbnode, input_name='X1'):
    """Chain ``nbnode`` OnnxScaler nodes and convert the chain to an ONNX model.

    Every node uses a unit ``scale`` and a fresh offset of length ``dim``
    produced by ``rand``. The last node's output is named ``'Y'``.

    :param dim: number of columns of the float input and output
    :param nbnode: number of scaler nodes; at least one is always created
    :param input_name: name of the graph input
    :return: ``(onnx_model, offsets)`` where ``offsets`` lists every node's
        offset in creation order
    """
    target_opset = onnx.defs.onnx_opset_version()
    unit_scale = list(numpy.ones((1, dim)).ravel())
    offsets = []
    upstream = input_name
    # All but the last node feed the next scaler in the chain.
    for _ in range(nbnode - 1):
        offset = list(rand(1, dim).ravel())
        offsets.append(offset)
        upstream = OnnxScaler(
            upstream, offset=offset, scale=unit_scale,
            op_version=target_opset)
    # The final node carries the graph's public output name.
    last_offset = list(rand(1, dim).ravel())
    offsets.append(last_offset)
    final_node = OnnxScaler(
        upstream, offset=last_offset, scale=unit_scale, output_names=['Y'],
        op_version=target_opset)
    model = final_node.to_onnx(
        [(input_name, FloatTensorType((None, dim)))],
        outputs=[('Y', FloatTensorType((None, dim)))])
    return model, offsets
# Resolve the ONNX opset to target.
# opset: requested opset; None or 0 selects defs.onnx_opset_version().
# Returns the resolved opset; see the NOTE below for when PREFERRED_OPSET
# caps it.
# NOTE(review): indentation is lost in this copy, so it is ambiguous whether
# the PREFERRED_OPSET cap applies to every opset or only to the defaulted
# value (nested under the first `if`). Confirm against the original file:
# capping an explicitly requested opset would silently downgrade it.
def find_opset(opset):
if opset is None or opset == 0:
opset = defs.onnx_opset_version()
if opset > PREFERRED_OPSET:
opset = PREFERRED_OPSET
return opset
def check_node_args(graph_def, supported):
""" Check for required node arguments in graph
:param graph_def: the graph of operations
:param supported: the supported operators in graph
:return: whether all required parameters are provided
"""
logger.info('Checking for required node arguments...')
opset_dict = {}
opset_dict[defs.ONNX_DOMAIN] = defs.onnx_opset_version()
handlers = get_all_frontend_handlers(opset_dict)
total_nodes = 0
failed_nodes = 0
for node in graph_def.node:
if node.op in supported:
total_nodes += 1
tf_node = TensorflowNode(node)
kwargs = {}
for inp in node.input:
for attr_node in graph_def.node:
if inp == attr_node.name:
kwargs[inp] = attr_node.attr['value']
break
handler = handlers.get(defs.ONNX_DOMAIN, {}).get(node.op, None)
try: