Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def set_model_domain(model, domain):
    """
    Store *domain* in the ``domain`` field of an ONNX model, in place.

    :param model: ``onnx_proto.ModelProto`` instance to update
    :param domain: domain name to record, for instance ``"com.acme"``
    :raises ValueError: if *model* is ``None`` or not a ``ModelProto``, or
        if ``convert_utils.is_string_type`` rejects *domain*
    Example:
    ::
        from test_utils import set_model_domain
        onnx_model = load_model("SqueezeNet.onnx")
        set_model_domain(onnx_model, "com.acme")
    """
    # De Morgan form of "model is None or not isinstance(...)": the
    # short-circuit still skips the isinstance check when model is None.
    valid_model = model is not None and isinstance(
        model, onnx_proto.ModelProto)
    if not valid_model:
        raise ValueError("Parameter model is not an onnx model.")

    valid_domain = convert_utils.is_string_type(domain)
    if not valid_domain:
        raise ValueError("Parameter domain must be a string type.")

    model.domain = domain
# NOTE(review): fragment of a converter body whose `def` line is not visible
# here. It uses `container`, `tokenized`, `output`, `attrs`, `op`, `scope`
# and `operator`, none of which is defined in this span. The original
# indentation was lost in this copy and has been reconstructed; confirm it
# against the upstream source.
# Emit the n-gram counting node. Below opset 9, use 'Ngram' from the
# 'com.microsoft' domain. Otherwise, use 'TfIdfVectorizer' from the default
# domain ('') registered with op_version=9.
if container.target_opset < 9:
    op_type = 'Ngram'
    container.add_node(op_type, tokenized, output,
                       op_domain='com.microsoft', **attrs)
else:
    op_type = 'TfIdfVectorizer'
    container.add_node(op_type, tokenized, output, op_domain='',
                       op_version=9, **attrs)
# When op.binary is set, cast `output` to BOOL and then to FLOAT into the
# operator's final outputs. This clips every non-zero value to 1.0.
if op.binary:
    cast_result_name = scope.get_unique_variable_name('cast_result')
    apply_cast(scope, output, cast_result_name, container,
               to=onnx_proto.TensorProto.BOOL)
    apply_cast(scope, cast_result_name, operator.output_full_names,
               container, to=onnx_proto.TensorProto.FLOAT)
def predict(model, scope, operator, container, op_type, is_ensemble=False):
    """Predict target and calculate probability scores.

    Adds the following to *container*:

    - a FLOAT initializer holding the axis-reordered ``model.tree_.value``;
    - an *op_type* node (domain ``ai.onnx.ml``) with outputs
      ``[indices_name, dummy_proba_name]``;
    - an ``ArrayFeatureExtractor`` node that gathers from that initializer
      using the first output;
    - a Transpose with ``perm=(0, 2, 1)``;
    - a Cast to BOOL.

    :param model: fitted model exposing ``tree_.value``; presumably a
        scikit-learn decision tree -- confirm against callers
    :param scope: provides unique variable and operator names
    :param operator: its ``input_full_names`` are the inputs of the tree node
    :param container: receives the initializer and the nodes
    :param op_type: ONNX operator type of the tree node
    :param is_ensemble: not read in the visible part of the body

    NOTE(review): this excerpt looks truncated after the BOOL cast. No
    ``return`` is visible, and ``transposed_result_name`` is created but
    never used here. Confirm the rest of the body against the upstream
    source.
    """
    # Unique names for the intermediate tensors produced below.
    indices_name = scope.get_unique_variable_name('indices')
    dummy_proba_name = scope.get_unique_variable_name('dummy_proba')
    values_name = scope.get_unique_variable_name('values')
    out_values_name = scope.get_unique_variable_name('out_indices')
    transposed_result_name = scope.get_unique_variable_name(
        'transposed_result')
    proba_output_name = scope.get_unique_variable_name('proba_output')
    cast_result_name = scope.get_unique_variable_name('cast_result')
    # Reorder the three axes of tree_.value before flattening it into an
    # initializer. Its layout is presumably
    # (n_nodes, n_outputs, n_classes) -- confirm.
    value = model.tree_.value.transpose(1, 2, 0)
    # Attributes of the tree node are built by an external helper.
    attrs = populate_tree_attributes(
        model, scope.get_unique_operator_name(op_type))
    container.add_initializer(
        values_name, onnx_proto.TensorProto.FLOAT,
        value.shape, value.ravel())
    # Tree node. Only its first output (indices_name) is consumed below.
    container.add_node(
        op_type, operator.input_full_names,
        [indices_name, dummy_proba_name],
        op_domain='ai.onnx.ml', **attrs)
    # Gather entries of the values initializer at the tree node's indices.
    container.add_node(
        'ArrayFeatureExtractor',
        [values_name, indices_name],
        out_values_name, op_domain='ai.onnx.ml',
        name=scope.get_unique_operator_name('ArrayFeatureExtractor'))
    # Swap the last two axes of the gathered values.
    apply_transpose(scope, out_values_name, proba_output_name,
                    container, perm=(0, 2, 1))
    apply_cast(scope, proba_output_name, cast_result_name,
               container, to=onnx_proto.TensorProto.BOOL)
# NOTE(review): this span appears to be spliced from two unrelated places.
# - `is_ensemble` is a parameter of predict(), yet the body of this `if` is
#   not visible.
# - The lines after it look like the tail of a different routine that
#   finalizes `onnx_model` (opset imports and metadata) and ends with
#   `return onnx_model`.
# The original indentation was lost. Confirm the real structure against the
# upstream source before relying on this code.
if is_ensemble:
# domains
# For each operator domain, record the highest opset version requested by
# any node in the container. `version` (target_opset) is the fallback on both
# sides of max(), so it is also the minimum recorded version.
domains = {}
version = target_opset
for n in container.nodes:
    domains[n.domain] = max(domains.get(n.domain, version),
                            getattr(n, 'op_version', version))
# Write one opset_import entry per domain. The first domain reuses the model's
# existing entry when there is exactly one; every other domain adds a new one.
for i, (k, v) in enumerate(domains.items()):
    if i == 0 and len(onnx_model.opset_import) == 1:
        op_set = onnx_model.opset_import[0]
    else:
        op_set = onnx_model.opset_import.add()
    op_set.domain = k
    # Same value as `v`, because k always comes from domains itself.
    op_set.version = domains.get(k, version)
# metadata
# Stamp the IR version and the producer/domain/model-version fields.
onnx_model.ir_version = onnx_proto.IR_VERSION
onnx_model.producer_name = utils.get_producer()
onnx_model.producer_version = utils.get_producer_version()
onnx_model.domain = utils.get_domain()
onnx_model.model_version = utils.get_model_version()
return onnx_model
def _guess_type_proto(data_type, dims):
    """
    Build the tensor-type object that matches an ONNX element type.

    :param data_type: an ``onnx_proto.TensorProto`` element-type value;
        FLOAT, DOUBLE, STRING, INT64, INT32 and BOOL are supported
    :param dims: shape passed unchanged to the tensor-type constructor
    :return: e.g. ``FloatTensorType(dims)`` for ``TensorProto.FLOAT``
    :raises NotImplementedError: for any other *data_type*
    """
    # This could be moved to onnxconverter_common.
    element_types = onnx_proto.TensorProto
    # Checked in order with ==, exactly like the original if/elif chain.
    known_types = (
        (element_types.FLOAT, FloatTensorType),
        (element_types.DOUBLE, DoubleTensorType),
        (element_types.STRING, StringTensorType),
        (element_types.INT64, Int64TensorType),
        (element_types.INT32, Int32TensorType),
        (element_types.BOOL, BooleanTensorType),
    )
    for proto_type, tensor_cls in known_types:
        if data_type == proto_type:
            return tensor_cls(dims)
    raise NotImplementedError(
        "Unsupported data_type '{}'. You may raise an issue "
        "at https://github.com/onnx/sklearn-onnx/issues."
        "".format(data_type))
# NOTE(review): fragment. The next line closes an expression that starts
# outside this view (probably a filter of `categories` by `ohe_op.drop_idx_`),
# and the enclosing function and loop headers are not visible either.
# `categories`, `inp_type`, `attrs`, `name`, `result`, `categories_len`,
# `scope` and `container` are defined elsewhere. The original indentation was
# lost; the nesting below is reconstructed and should be confirmed against
# the upstream source.
ohe_op.drop_idx_[index]])
# Integer input types get int64 categories. Anything else is encoded as UTF-8
# bytes of str(category).
if isinstance(inp_type, (Int64TensorType, Int32TensorType)):
    attrs['cats_int64s'] = categories.astype(np.int64)
else:
    attrs['cats_strings'] = np.array(
        [str(s).encode('utf-8') for s in categories])
ohe_output = scope.get_unique_variable_name(name + 'out')
result.append(ohe_output)
if 'cats_int64s' in attrs:
    # Cast this input to int64 so the OneHotEncoder below receives int64
    # values; the cast output replaces `name` as the node input.
    cast_feature = scope.get_unique_variable_name(name + 'cast')
    apply_cast(scope, name, cast_feature, container,
               to=onnx_proto.TensorProto.INT64)
    name = cast_feature
container.add_node('OneHotEncoder', name,
                   ohe_output, op_domain='ai.onnx.ml',
                   **attrs)
categories_len += len(categories)
# Presumably after the per-feature loop: concatenate every one-hot output
# collected in `result` along axis 2.
concat_result_name = scope.get_unique_variable_name('concat_result')
apply_concat(scope, result, concat_result_name, container, axis=2)
reshape_input = concat_result_name
# For a signed-integer ohe_op.dtype, cast the concatenated result to INT64 and
# use the cast output as `reshape_input` instead.
if np.issubdtype(ohe_op.dtype, np.signedinteger):
    reshape_input = scope.get_unique_variable_name('cast')
    apply_cast(scope, concat_result_name, reshape_input,
               container, to=onnx_proto.TensorProto.INT64)