# get real and predicted values by running the model with the input of this batch
predicted_target = self.forwardWrapper(batch)
real_target = batch.getTarget(flatten=self.flatTarget)
# append to all targets and all real values
real_target_all += real_target.data.tolist()
predicted_target_all += predicted_target.data.tolist()
if len(permutation) == 0:
    # append to all targets and all real values
    real_target_all_ret += real_target.data.tolist()
    predicted_target_all_ret += predicted_target.data.tolist()
if batch is None:
    log.error('there is no data in test, we should not be here')
    return
# calculate the error for all values
predicted_targets = batch.deflatTarget(np.array(predicted_target_all))
real_targets = batch.deflatTarget(np.array(real_target_all))
# calculate the error for all values
predicted_targets_ret = batch.deflatTarget(np.array(predicted_target_all_ret))
real_targets_ret = batch.deflatTarget(np.array(real_target_all_ret))
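# Note (assumption based on usage here): deflatTarget appears to be the inverse of the
# flattened-target representation, mapping the flat prediction/target arrays back into a
# dict keyed by target column, which is what the per-column loop below iterates over.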
r_values = {}
# calculate r and other statistical properties of error
for target_key in real_targets_ret:
    r_values[target_key] = explained_variance_score(real_targets_ret[target_key], predicted_targets_ret[target_key], multioutput='variance_weighted')
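# Illustrative reference for the explained_variance_score call above (sklearn.metrics):
# it returns 1.0 for a perfect fit and can go negative for poor fits, and
# multioutput='variance_weighted' averages per-output scores weighted by each output's
# variance. The values below are the standard sklearn docs example, not data from this model:
# >>> from sklearn.metrics import explained_variance_score
# >>> explained_variance_score([3, -0.5, 2, 7], [2.5, 0.0, 2, 8])
# 0.957...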
files = self.fs_file_ids
if not isinstance(files, list):
    files = [files]
for file in files:
    filename = '{path}/{filename}.pt'.format(path=MINDSDB_STORAGE_PATH, filename=file)
    try:
        file_path = Path(filename)
        if file_path.is_file():
            os.remove(filename)
        else:
            log.warning('could not delete file {file} because it does not exist'.format(file=filename))
    except OSError:
        log.error('could not delete file {file}'.format(file=filename))
# save/update model loss, error, confusion_matrix
self.register_model_data(train_ret, test_ret, is_it_lowest_error_epoch)
log.info('Loading model from storage to retrain with new learning rate {lr}'.format(lr=self.data_model_object.learning_rates[i][LEARNING_RATE_INDEX]))
# after it's done with the first batch group, get the one with the lowest error and keep training
ml_model_info = self.ml_model_info.find_one({
    'model_name': self.model_name,
    'ml_model_name': self.ml_model_name,
    'config_serialized': json.dumps(self.config)
})
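# The stored record is looked up by model name, ML backend name and the serialized config,
# so (presumably) the same model name trained under a different config will not match this
# query and resolves to a different stored entry.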
if ml_model_info is None:
    # TODO: Make sure we have a model for this
    log.info('No model found in storage')
    return self.data_model_object
fs_file_ids = ml_model_info.fs_file_ids
if fs_file_ids is not None:
    self.data_model_object = self.ml_model_class.load_from_disk(file_ids=fs_file_ids)
# When out of training loop:
# - if stopped or finished, leave as is (TODO: Have the ability to stop model training, but not necessarily delete it)
# * save best lowest error into GridFS (we only save into GridFS at the end because it takes too long)
# * remove local model file
# self.saveToGridFs(local_files=local_files, throttle=False)
return self.data_model_object
def __init__(self, session, light_transaction_metadata, heavy_transaction_metadata, logger = log):
"""
A transaction is the interface to start some MindsDB operation within a session
:param session:
:type session: utils.controllers.session_controller.SessionController
:param transaction_type:
:param transaction_metadata:
:type transaction_metadata: dict
:type heavy_transaction_metadata: dict
"""
    self.session = session
    self.lmd = light_transaction_metadata
    self.lmd['created_at'] = str(datetime.datetime.now())
    self.hmd = heavy_transaction_metadata
uuid_str = str(uuid.uuid4())
try:
    with open(uuid_file, 'w') as fp:
        fp.write(uuid_str)
except Exception:
    log.warning('Cannot store token. Please add write permissions to file: ' + uuid_file)
    uuid_str = uuid_str + '.NO_WRITE'
file_path = Path(mdb_file)
if file_path.is_file():
    with open(mdb_file) as fp:
        token = fp.read()
else:
    token = '{system}|{version}|{uid}'.format(system=platform.system(), version=__version__, uid=uuid_str)
    try:
        with open(mdb_file, 'w') as fp:
            fp.write(token)
    except Exception:
        log.warning('Cannot store token. Please add write permissions to file: ' + mdb_file)
        token = token + '.NO_WRITE'
extra = urllib.parse.quote_plus(token)
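# For reference: the token has the form '<system>|<version>|<uuid>', e.g. (hypothetical)
# 'Linux|1.2.3|0f8b...'. urllib.parse.quote_plus percent-encodes reserved characters and
# turns spaces into '+', so that example becomes 'Linux%7C1.2.3%7C0f8b...', making the
# token safe to embed in the URL below.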
try:
    r = requests.get('http://mindsdb.com/updates/check/{extra}'.format(extra=extra), headers={'referer': 'http://check.mindsdb.com/?token={token}'.format(token=token)})
except Exception:
    log.warning('Could not check for updates')
    return
try:
    # TODO: Extract version, compare with version in version.py
    ret = r.json()
    if 'version' in ret and ret['version'] != __version__:
        pass
        # log.warning("There is a new version of MindsDB {version}, please do:\n pip3 uninstall mindsdb\n pip3 install mindsdb --user".format(version=ret['version']))
    else:
        log.debug('MindsDB is up to date!')
"""
:return:
"""
# here we will also determine based on the query if we should do a moving window for the training
# TODO: if order by encode
ret = {}
total_groups = len(self.data)
for group in self.data:
    group_pointer = 0
    first_column = next(iter(self.data[group]))
    total_length = len(self.data[group][first_column])
    log.debug('Iterator on group {group}/{total_groups}, total rows: {total_rows}'.format(group=group, total_groups=total_groups, total_rows=total_length))
    while group_pointer < total_length:
        limit = min(group_pointer + self.batch_size, total_length)
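        # e.g. (illustrative) with batch_size=3 and total_length=7, and assuming group_pointer
        # advances by batch_size after each pass, the slices taken below are [0:3], [3:6], [6:7]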
        allcols_time = time.time()
        for column in self.model_columns:
            # log.debug('Generating: pytorch variables, batch: {column}-[{group_pointer}:{limit}]-{column_type}'.format(column=column, group_pointer=group_pointer, limit=limit, column_type=self.stats[column]['data_type']))
            # col_start_time = time.time()
            # if self.stats[column]['data_type'] != DATA_TYPES.FULL_TEXT:
            ret[column] = self.data[group][column][group_pointer:limit]
            ext_col_name = EXTENSION_COLUMNS_TEMPLATE.format(column_name=column)
            if ext_col_name in self.data[group]: