Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def generate_submodel(self, submodel):
    """
    Generate PMML network layers for a RetinaNet regression or classification
    submodel, one renamed copy per connected pyramid layer.

    For every name in ``self._pyramid_layers`` the submodel is converted with
    ``kerasAPI.KerasToPmml`` and the first layer of its first DeepNetwork is
    removed. The remaining first layer is connected to the pyramid layer by
    name, and every remaining layer gets ``"_" + <pyramid layer name>``
    appended to its ``layerId``. The same suffix is appended to its
    ``connectionLayerId``, except on that first layer.

    Parameters
    ----------
    submodel :
        The regression or the classification submodel.

    Returns
    -------
    list
        Nyoka NetworkLayer objects for all pyramid copies, concatenated in the
        order of ``self._pyramid_layers``.
    """
    all_layers = []
    for pyramid_name in self._pyramid_layers:
        # Convert the submodel anew for each pyramid layer; the resulting
        # layer objects are renamed in place below.
        converted = kerasAPI.KerasToPmml(submodel)
        layers = converted.DeepNetwork[0].NetworkLayer
        # Drop the first layer (presumably the submodel's own input layer --
        # TODO confirm) and feed the new first layer from the pyramid layer.
        del layers[0]
        layers[0].connectionLayerId = pyramid_name
        for position, layer in enumerate(layers):
            layer.layerId = layer.layerId + "_" + pyramid_name
            # The first layer's connection already names the pyramid layer
            # itself, so it must not receive the suffix.
            if position > 0:
                layer.connectionLayerId = layer.connectionLayerId + "_" + pyramid_name
        all_layers += layers
    return all_layers
def __init__(self, keras_model, model_name=None, description=None, copyright=None,
             dataSet=None, predictedClasses=None, script_args=None):
    """
    Build a version "4.4" PMML document for a Keras model.

    Parameters
    ----------
    keras_model :
        The Keras model, passed to ``KerasNetwork``.
    model_name :
        Model name, passed to ``KerasNetwork``.
    description :
        Header description, passed to ``KerasHeader``.
    copyright :
        Header copyright, passed to ``KerasHeader``.
    dataSet :
        Name of the input data set; any falsy value is replaced by 'input'.
    predictedClasses :
        Passed to both ``KerasDataDictionary`` and ``KerasNetwork``.
    script_args :
        When truthy, validated with ``self.validate_script_args`` and used to
        build a ``KerasTransformationDictionary``; otherwise no transformation
        dictionary is attached.
    """
    # Falsy values (None, empty string) fall back to the generic name 'input'.
    dataSet = dataSet or 'input'
    data_dict = KerasDataDictionary(dataSet, predictedClasses, script_args)
    if script_args:
        self.validate_script_args(script_args)
        trans_dict = KerasTransformationDictionary(dataSet, script_args)
    else:
        trans_dict = None
    header = KerasHeader(description=description, copyright=copyright)
    network = KerasNetwork(
        keras_model=keras_model,
        model_name=model_name,
        dataSet=dataSet,
        predictedClasses=predictedClasses,
        script_args=script_args,
    )
    super(KerasToPmml, self).__init__(
        version="4.4",
        Header=header,
        DataDictionary=data_dict,
        TransformationDictionary=trans_dict,
        DeepNetwork=[network],
    )
Returns
-------
Nyoka's PMML object
"""
from keras.models import Sequential
mod = Sequential()
for l in model.layers[1:]:
if l.__class__.__name__ == "Model":
break
mod.add(l)
if trained_classes == None:
warnings.warn(f"trained_classes are not provided. Maximum 80 classes will be considered.")
trained_classes = ["Category_"+str(i+1).zfill(2) for i in range(80)]
group1_pmml = kerasAPI.KerasToPmml(mod,model_name=self.model_name,dataSet=input_format, description=self.description,
predictedClasses=trained_classes, script_args=self.script_args)
return group1_pmml