import pandas as pd

# read_csv, get_split and split_dataset_tvt are Ludwig helpers; the import
# paths below are assumed from Ludwig's module layout and may differ by version.
from ludwig.data.preprocessing import get_split
from ludwig.utils.data_utils import read_csv, split_dataset_tvt


def obtain_df_splits(data_csv):
    """Split an input data CSV file into train, validation and test dataframes.

    :param data_csv: Input data CSV file.
    :return test_df, train_df, val_df: Train, validation and test dataframe
        splits
    """
    data_df = read_csv(data_csv)
    # Obtain the data split array mapping data rows to split type:
    # 0 - train, 1 - validation, 2 - test
    data_split = get_split(data_df)
    train_split, test_split, val_split = split_dataset_tvt(data_df, data_split)
    # The splits are Python dictionaries, not dataframes; convert them.
    test_df = pd.DataFrame(test_split)
    train_df = pd.DataFrame(train_split)
    val_df = pd.DataFrame(val_split)
    return test_df, train_df, val_df
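
A minimal usage sketch for the helper above; the CSV file name is hypothetical, and the data is assumed to carry the 0/1/2 split encoding described in the comments. Note the return order is test, train, validation:

test_df, train_df, val_df = obtain_df_splits('dataset.csv')  # hypothetical file
print(len(train_df), len(val_df), len(test_df))
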
        else:
            dataset[output_feature['name']] = hdf5_data[
                output_feature['name']][()]

        if 'limit' in output_feature:
            # Collapse labels beyond the configured 'limit' into a single category
            dataset[output_feature['name']] = collapse_rare_labels(
                dataset[output_feature['name']],
                output_feature['limit']
            )

    if not split_data:
        hdf5_data.close()
        return dataset

    # The stored 'split' column maps each row to 0 - train, 1 - validation, 2 - test
    split = hdf5_data['split'][()]
    hdf5_data.close()
    training_set, test_set, validation_set = split_dataset_tvt(dataset, split)

    # shuffle the training set (feature arrays are shuffled in unison)
    if shuffle_training:
        training_set = data_utils.shuffle_dict_unison_inplace(training_set)

    return training_set, test_set, validation_set
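
The calls to split_dataset_tvt above all follow the same pattern: a dict-of-arrays dataset plus a per-row split vector goes in, and (train, test, validation) sets come out. Below is a simplified, illustrative sketch of that idea, assuming the 0/1/2 encoding noted above; it is not Ludwig's actual implementation.

import numpy as np

def split_by_vector_sketch(dataset, split):
    # Illustrative only: select rows of a dict-of-arrays dataset by split code,
    # using the 0 = train, 1 = validation, 2 = test convention and returning
    # the sets in the same (train, test, validation) order as the calls above.
    split = np.asarray(split)

    def take(code):
        mask = split == code
        return {name: np.asarray(values)[mask] for name, values in dataset.items()}

    return take(0), take(2), take(1)
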
        if not skip_save_processed_input:
            # Cache the preprocessed data as hdf5 next to the original csv
            logger.info('Writing dataset')
            data_hdf5_fp = replace_file_extension(data_csv, 'hdf5')
            data_utils.save_hdf5(data_hdf5_fp, data, train_set_metadata)
            train_set_metadata[DATA_TRAIN_HDF5_FP] = data_hdf5_fp
            logger.info('Writing train set metadata with vocabulary')
            train_set_metadata_json_fp = replace_file_extension(
                data_csv,
                'json'
            )
            data_utils.save_json(
                train_set_metadata_json_fp, train_set_metadata)

        training_set, test_set, validation_set = split_dataset_tvt(
            data,
            data['split']
        )
    elif data_train_csv is not None:
        # use data_train (including _validation and _test if they are present)
        # and ignore data and train set metadata
        # needs preprocessing
        logger.info(
            'Using training raw csv, no hdf5 and json '
            'file with the same name has been found'
        )
        logger.info('Building dataset (it may take a while)')
        concatenated_df = concatenate_csv(
            data_train_csv,
            data_validation_csv,
            data_test_csv
        )
        # Remember which csv the concatenated dataframe came from
        concatenated_df.csv = data_train_csv
        data, train_set_metadata = build_dataset_df(
            concatenated_df,
            features,
            preprocessing_params,
            train_set_metadata=train_set_metadata,
            random_seed=random_seed
        )
        training_set, test_set, validation_set = split_dataset_tvt(
            data,
            data['split']
        )
        if not skip_save_processed_input:
            # Cache the preprocessed training split as hdf5 next to the training csv
            logger.info('Writing dataset')
            data_train_hdf5_fp = replace_file_extension(data_train_csv, 'hdf5')
            data_utils.save_hdf5(
                data_train_hdf5_fp,
                training_set,
                train_set_metadata
            )
            train_set_metadata[DATA_TRAIN_HDF5_FP] = data_train_hdf5_fp
            if validation_set is not None:
                data_validation_hdf5_fp = replace_file_extension(
                    data_validation_csv,
                    'hdf5'