def save(self):
    # Pickle the entire trainer object to its configured path.
    save_pkl.save(path=self.path + self.trainer_file_name, object=self)
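Every snippet on this page funnels through `save_pkl.save`. As a rough mental model only, assume it behaves like the hypothetical stand-in below (directory creation plus a pickle dump); the real helper may differ.

import os
import pickle

def save(path, object, verbose=True):
    # Hypothetical stand-in for save_pkl.save: make sure the target directory
    # exists, then pickle the object to disk.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if verbose:
        print("Saving", path)
    with open(path, 'wb') as f:
        pickle.dump(object, f, protocol=pickle.HIGHEST_PROTOCOL)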
def save(self, file_prefix="", directory=None, return_name=False, verbose=True):
    """
    directory (str): if unspecified, use self.path as directory
    return_name (bool): return the file-names corresponding to this save as tuple (model_obj_file, net_params_file)
    """
    if directory is not None:
        path = directory + file_prefix
    else:
        path = self.path + file_prefix
    params_filepath = path + self.params_file_name
    modelobj_filepath = path + self.model_file_name
    if self.model is not None:
        self.model.save_parameters(params_filepath)
    # Detach attributes that should not be pickled, pickle the rest, then restore them.
    temp_model = self.model
    temp_sw = self.summary_writer
    self.model = None
    self.summary_writer = None
    save_pkl.save(path=modelobj_filepath, object=self, verbose=verbose)
    self.model = temp_model
    self.summary_writer = temp_sw
    if return_name:
        return (modelobj_filepath, params_filepath)
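A brief usage sketch of the method above; `model` is a hypothetical instance of the class that defines it.

# Ask save() for the two paths it wrote: the pickled (weight-free) model object
# and the separately stored network parameters.
modelobj_file, params_file = model.save(return_name=True)
print("pickled model object:", modelobj_file)
print("network parameters:  ", params_file)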
def _callback(env):
    # Checkpoint the model every `interval` iterations (skipping iteration 0) and update
    # a pointer file so the latest checkpoint can always be located.
    if (env.iteration - offset) % interval == 0 and env.iteration != 0:
        save_pkl.save(path=path, object=env.model)
        save_pointer.save(path=latest_model_checkpoint, content_path=path)
_callback.before_iteration = True
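`env` and the `before_iteration` flag follow LightGBM's callback protocol, so a checkpointer like the one above would typically be passed to `lgb.train`. A sketch under that assumption; `params`, `train_set`, and the closure variables (`path`, `interval`, `offset`, `latest_model_checkpoint`) are assumed to be defined by the surrounding code.

import lightgbm as lgb

# LightGBM invokes every registered callback once per boosting round with a CallbackEnv,
# so _callback gets a chance to checkpoint env.model on the configured interval.
bst = lgb.train(
    params,
    train_set,
    num_boost_round=100,
    callbacks=[_callback],
)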
def save(self, file_prefix=""):
""" Additional naming changes will be appended to end of file_prefix (must contain full absolute path) """
dataobj_file = file_prefix + self.DATAOBJ_SUFFIX
datalist_file = file_prefix + self.DATAVALUES_SUFFIX
data_list = self.dataset._data
self.dataset = None # Avoid pickling these
self.dataloader = None
save_pkl.save(path=dataobj_file, object=self)
mx.nd.save(datalist_file, data_list)
logger.debug("TabularNN Dataset saved to files: \n %s \n %s" % (dataobj_file, datalist_file))
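For context, a hedged sketch of what the matching load path could look like; the `load` classmethod, the `load_pkl` helper, and the `generate_dataset_and_dataloader` rebuild step are assumptions mirroring the suffix convention above, not confirmed API.

@classmethod
def load(cls, file_prefix=""):
    # Hypothetical counterpart to save(): unpickle the dataset wrapper, reload the raw
    # NDArray columns, and reattach them to rebuild the dataset/dataloader.
    dataobj_file = file_prefix + cls.DATAOBJ_SUFFIX
    datalist_file = file_prefix + cls.DATAVALUES_SUFFIX
    dataset = load_pkl.load(path=dataobj_file)      # assumed pickle-loading twin of save_pkl
    data_list = mx.nd.load(datalist_file)
    dataset.generate_dataset_and_dataloader(data_list=data_list)  # assumed rebuild hook
    return dataset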
def save(self, file_prefix="", directory=None, return_filename=False, verbose=True):
    if directory is None:
        directory = self.path
    file_name = directory + file_prefix + self.model_file_name
    save_pkl.save(path=file_name, object=self, verbose=verbose)
    if return_filename:
        return file_name
def save(self, file_prefix="", directory=None, return_filename=False, save_children=False, verbose=True):
    if directory is None:
        directory = self.path
    directory = directory + file_prefix
    if save_children:
        # Save each child model under its own subdirectory, then keep only the names so the
        # ensemble object itself stays lightweight when pickled.
        model_names = []
        for child in self.models:
            child = self.load_child(child)
            child.path = self.create_contexts(self.path + child.name + '/')
            child.save(verbose=False)
            model_names.append(child.name)
        self.models = model_names
    file_name = directory + self.model_file_name
    save_pkl.save(path=file_name, object=self, verbose=verbose)
    if return_filename:
        return file_name
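A short usage sketch of the ensemble variant above; `ensemble` and the `load_pkl` helper mentioned in the comment are assumptions.

# Persist the ensemble and all of its child models, collecting the path that was written.
file_name = ensemble.save(save_children=True, return_filename=True, verbose=False)
# Restoring is presumably the mirror image, e.g. ensemble = load_pkl.load(path=file_name)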
def save_artifacts(predictor, leaderboard, config):
    artifacts = config.framework_params.get('_save_artifacts', ['leaderboard'])
    try:
        models_dir = output_subdir("models", config)
        shutil.rmtree(os.path.join(models_dir, "utils"), ignore_errors=True)

        if 'leaderboard' in artifacts:
            save_pd.save(path=os.path.join(models_dir, "leaderboard.csv"), df=leaderboard)

        if 'info' in artifacts:
            ag_info = predictor.info()
            info_dir = output_subdir("info", config)
            save_pkl.save(path=os.path.join(info_dir, "info.pkl"), object=ag_info)

        if 'models' in artifacts:
            utils.zip_path(models_dir,
                           os.path.join(models_dir, "models.zip"))

        def delete(path, isdir):
            if isdir:
                shutil.rmtree(path, ignore_errors=True)
            elif os.path.splitext(path)[1] == '.pkl':
                os.remove(path)

        utils.walk_apply(models_dir, delete, max_depth=0)
    except Exception:
        log.warning("Error when saving artifacts.", exc_info=True)
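A minimal sketch of exercising `save_artifacts` outside the benchmark harness, assuming only the attributes the function actually touches; the real `config` object and the `output_subdir` helper come from the harness itself.

from types import SimpleNamespace

# Stand-in config exposing just framework_params; output_subdir(name, config) is assumed
# to resolve (and create) a per-run output folder from the rest of the config.
config = SimpleNamespace(framework_params={'_save_artifacts': ['leaderboard', 'info']})
save_artifacts(predictor, leaderboard, config)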