# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this chunk of the file appears corrupted — the lines below are a
# truncated TAIL of should_continue() (a full duplicate of the method appears
# later in this file) followed by stray docstring terminator lines. Indentation
# has been flattened; code is kept byte-identical pending reconstruction.
# Accuracy-target stop: the highest accuracy reached so far met the requested threshold.
log.info('stop_training_in_accuracy flag was passed and contition was met, thus, we are forcing training to stop now')
return False
# Re-read the persisted metadata record to pick up externally-set stop/kill flags.
model_data = self.persistent_model_metadata.find_one({'model_name': self.model_name}) #type: PersistentModelMetadata
if model_data is None:
# Metadata record is gone — treat as a stop signal.
return False
if model_data.stop_training == True:
log.info('[FORCED] Stopping model training....')
return False
elif model_data.kill_training == True:
# kill additionally removes the persisted model state before stopping.
log.info('[FORCED] Stopping model training....')
self.persistent_model_metadata.delete()
self.ml_model_info.delete()
return False
# No stop condition triggered — keep training.
return True
# NOTE(review): the two lines below are an orphaned docstring fragment left
# behind by the corruption; they belong to some other method's docstring.
:return:
"""
# NOTE(review): fragment of a train(...)-style method whose `def` header is not
# visible in this chunk (indentation flattened by corruption). Code kept
# byte-identical; comments added for orientation.
# Record wall-clock start time; should_continue() compares against it for the
# stop_training_in_x_seconds budget.
self.train_init_time = time.time()
self.data = data
# Build train/test samplers in LEARN mode over the provided data splits,
# sharing the persisted metadata and the model class's ignored column types.
self.train_sampler = Sampler(self.data.train_set, metadata_as_stored=self.persistent_model_metadata,
ignore_types=self.ml_model_class.ignore_types, sampler_mode=SAMPLER_MODES.LEARN)
self.test_sampler = Sampler(self.data.test_set, metadata_as_stored=self.persistent_model_metadata,
ignore_types=self.ml_model_class.ignore_types, sampler_mode=SAMPLER_MODES.LEARN)
# Samplers wrap produced variables with the model class's wrapper (presumably
# tensor conversion — TODO confirm against Sampler/ml_model_class).
self.train_sampler.variable_wrapper = self.ml_model_class.variable_wrapper
self.test_sampler.variable_wrapper = self.ml_model_class.variable_wrapper
self.sample_batch = self.train_sampler.getSampleBatch()
if self.model_name is None:
log.info('Starting model...')
self.data_model_object = self.ml_model_class(self.sample_batch)
else:
# NOTE(review): both branches construct the model identically — the else
# branch likely intended to LOAD an existing model; verify against upstream.
self.data_model_object = self.ml_model_class(self.sample_batch)
log.info('Training model...')
# Per-run bookkeeping: best (lowest) test error, best accuracy, and the local
# model files from the last save (used to delete stale files on re-save).
last_epoch = 0
lowest_error = None
highest_accuracy = 0
local_files = None
# One pass per configured learning rate; the model is told which rate to use.
for i in range(len(self.data_model_object.learning_rates)):
self.data_model_object.setLearningRateIndex(i)
def save_to_disk(self, local_files):
    """
    Persist the current model to disk and remove the previously stored files.

    :param local_files: file-response objects from the previous save (each has
        a ``.path`` attribute), or None when nothing was saved yet
    :return: the file-response objects produced by this save (each carries a
        ``.file_id``), so the caller can pass them back on the next save
    """
    # Best-effort cleanup of the previous save's files; a failure to delete one
    # file must not abort the save of the new model state.
    if local_files is not None:
        for file_response_object in local_files:
            try:
                os.remove(file_response_object.path)
            # was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit; file removal can only fail with OSError.
            except OSError:
                log.info('Could not delete file {path}'.format(path=file_response_object.path))
    # Stable identifier for this model/config combination; used as the on-disk
    # file-id prefix and recorded in ml_model_info for later retrieval.
    file_id = '{model_name}.{ml_model_name}.{config_hash}'.format(model_name=self.model_name, ml_model_name=self.ml_model_name, config_hash=self.config_hash)
    return_objects = self.data_model_object.saveToDisk(file_id)
    file_ids = [ret.file_id for ret in return_objects]
    self.ml_model_info.fs_file_ids = file_ids
    self.ml_model_info.update()
    return return_objects
# NOTE(review): another duplicated, truncated copy of should_continue() — it
# begins with the docstring (its `def` header is missing) and is cut off inside
# the kill_training branch. Code kept byte-identical; see the complete method
# definition later in this file for the intact logic.
"""
Check if the training should continue
:return:
"""
model_name = self.model_name
# check if stop training is set in which case we should exit the training
# Wall-clock budget stop: elapsed time since train_init_time exceeded the cap.
if self.stop_training_in_x_seconds:
if time.time() - self.train_init_time > self.stop_training_in_x_seconds:
log.info('stop_training_in_x_seconds flag was passed and contition was met, thus, we are forcing training to stop now')
return False
# Accuracy-target stop. NOTE(review): highest_accuracy may be None before the
# first epoch, which would make this comparison raise a TypeError — confirm.
if self.stop_training_in_accuracy is not None and highest_accuracy >= self.stop_training_in_accuracy:
log.info('stop_training_in_accuracy flag was passed and contition was met, thus, we are forcing training to stop now')
return False
# Externally-set stop/kill flags persisted in the metadata store.
model_data = self.persistent_model_metadata.find_one({'model_name': self.model_name}) #type: PersistentModelMetadata
if model_data is None:
return False
if model_data.stop_training == True:
log.info('[FORCED] Stopping model training....')
return False
elif model_data.kill_training == True:
log.info('[FORCED] Stopping model training....')
# NOTE(review): fragment truncated here — the intact copy also deletes
# ml_model_info and returns False.
self.persistent_model_metadata.delete()
# NOTE(review): fragment of the per-epoch evaluation section of the training
# loop (its enclosing `for`/`def` headers are not visible in this chunk;
# indentation flattened by corruption). Code kept byte-identical.
is_it_lowest_error_epoch = False
# if lowest error save model
# First evaluation ever: seed the best-error/best-accuracy trackers.
# NOTE(review): `lowest_error in [None]` is an unidiomatic `is None` check.
if lowest_error in [None]:
lowest_error = test_ret.error
highest_accuracy = test_ret.accuracy
log.info(f'Got best accuracy so far: {highest_accuracy}')
# New best (lowest) test error with a usable accuracy: mark epoch as best.
if lowest_error > test_ret.error and test_ret.accuracy > 0:
is_it_lowest_error_epoch = True
lowest_error = test_ret.error
highest_accuracy = test_ret.accuracy
log.debug('[SAVING MODEL] Lowest ERROR so far! - Test Error: {error}, Accuracy: {accuracy}'.format(error=test_ret.error, accuracy=test_ret.accuracy))
log.debug('Lowest ERROR so far! Saving: model {model_name}, {data_model} config:{config}'.format(
model_name=self.model_name, data_model=self.ml_model_name, config=self.ml_model_info.config_serialized))
log.info(f'Got best accuracy so far: {highest_accuracy}')
# save model local file
# Persist best model; save_to_disk removes the previous save's files.
local_files = self.save_to_disk(local_files)
# throttle model saving into GridFS to 10 minutes
# self.saveToGridFs(local_files, throttle=True)
# save model predicted - real vectors
log.debug('Saved: model {model_name}:{ml_model_name} state vars into db [OK]'.format(model_name=self.model_name, ml_model_name = self.ml_model_name))
# check if continue training
# Stop requested (time/accuracy/external flag): save once more and return the
# trained model object to the caller.
if self.should_continue(highest_accuracy) == False:
# save model local file
local_files = self.save_to_disk(local_files)
return self.data_model_object
# save/update model loss, error, confusion_matrix
self.register_model_data(train_ret, test_ret, is_it_lowest_error_epoch)
def should_continue(self, highest_accuracy=None):
    """
    Check if the training should continue.

    :param highest_accuracy: best test accuracy reached so far, or None when
        no evaluation epoch has completed yet
    :return: True to keep training, False to stop
    """
    # Wall-clock budget stop: stop_training_in_x_seconds caps training time
    # measured from train_init_time.
    if self.stop_training_in_x_seconds:
        if time.time() - self.train_init_time > self.stop_training_in_x_seconds:
            log.info('stop_training_in_x_seconds flag was passed and condition was met, thus, we are forcing training to stop now')
            return False
    # Accuracy-target stop. The extra `highest_accuracy is not None` guard
    # prevents a TypeError (None >= float) before the first epoch finishes.
    if (self.stop_training_in_accuracy is not None
            and highest_accuracy is not None
            and highest_accuracy >= self.stop_training_in_accuracy):
        log.info('stop_training_in_accuracy flag was passed and condition was met, thus, we are forcing training to stop now')
        return False
    # Externally-set stop/kill flags persisted in the metadata store.
    model_data = self.persistent_model_metadata.find_one({'model_name': self.model_name})  # type: PersistentModelMetadata
    if model_data is None:
        # Metadata record is gone — treat as a stop signal.
        return False
    if model_data.stop_training == True:
        log.info('[FORCED] Stopping model training....')
        return False
    elif model_data.kill_training == True:
        # kill also removes the persisted model state (this branch and the
        # final `return True` were truncated here; restored from the intact
        # duplicate copy of this method earlier in the file).
        log.info('[FORCED] Stopping model training....')
        self.persistent_model_metadata.delete()
        self.ml_model_info.delete()
        return False
    return True
# NOTE(review): duplicated fragment of the same train(...)-style setup that
# appears earlier in this chunk (starts mid-call and runs past the visible end
# of the file; indentation flattened by corruption). Code kept byte-identical.
ignore_types=self.ml_model_class.ignore_types, sampler_mode=SAMPLER_MODES.LEARN)
self.test_sampler = Sampler(self.data.test_set, metadata_as_stored=self.persistent_model_metadata,
ignore_types=self.ml_model_class.ignore_types, sampler_mode=SAMPLER_MODES.LEARN)
self.train_sampler.variable_wrapper = self.ml_model_class.variable_wrapper
self.test_sampler.variable_wrapper = self.ml_model_class.variable_wrapper
self.sample_batch = self.train_sampler.getSampleBatch()
if self.model_name is None:
log.info('Starting model...')
self.data_model_object = self.ml_model_class(self.sample_batch)
else:
# NOTE(review): identical to the if-branch — likely meant to load an
# existing model instead; verify against upstream history.
self.data_model_object = self.ml_model_class(self.sample_batch)
log.info('Training model...')
last_epoch = 0
lowest_error = None
highest_accuracy = 0
local_files = None
# One training pass per configured learning rate.
for i in range(len(self.data_model_object.learning_rates)):
self.data_model_object.setLearningRateIndex(i)
# Inner loop: the model yields per-batch training state objects.
for train_ret in self.data_model_object.trainModel(self.train_sampler):
log.debug('Training State epoch:{epoch}, batch:{batch}, loss:{loss}'.format(epoch=train_ret.epoch,
batch=train_ret.batch,
loss=train_ret.loss))