  def k_fold_split(self, dataset, k, directories=None, **kwargs):
    """Needs custom implementation due to ragged splits for stratification."""
    log("Computing K-fold split", self.verbose)
    if directories is None:
      directories = [tempfile.mkdtemp() for _ in range(k)]
    else:
      assert len(directories) == k
    fold_datasets = []
    # rem_dataset is the remaining portion of the dataset.
    rem_dataset = dataset
    for fold in range(k):
      # frac_fold starts at 1/k since fold starts at 0 and reaches 1 when
      # fold == k - 1, so each pass peels off an equal share of what remains.
      frac_fold = 1. / (k - fold)
      fold_dir = directories[fold]
      rem_dir = tempfile.mkdtemp()
      fold_dataset, rem_dataset = self.split(rem_dataset, frac_fold,
                                             [fold_dir, rem_dir])
      fold_datasets.append(fold_dataset)
    return fold_datasets
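
# A minimal, self-contained sketch of the "peel off 1 / (k - fold) of whatever
# remains" arithmetic above, applied to a plain index array rather than a
# Dataset object. Not the library's API; just the idea.
import numpy as np

def k_fold_indices(n_samples, k):
  rem = np.arange(n_samples)
  folds = []
  for fold in range(k):
    frac_fold = 1. / (k - fold)    # 1/k on the first pass, 1 on the last
    cut = int(frac_fold * len(rem))
    folds.append(rem[:cut])
    rem = rem[cut:]
  return folds

# k_fold_indices(10, 3) -> folds of sizes 3, 3 and 4 covering all ten indices.
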
  def split(self,
            dataset,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
            log_every_n=1000):
    """
    Splits internal compounds into train/validation/test by scaffold.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    scaffolds = {}
    log("About to generate scaffolds", self.verbose)
    data_len = len(dataset)
    for ind, smiles in enumerate(dataset.ids):
      if ind % log_every_n == 0:
        log("Generating scaffold %d/%d" % (ind, data_len), self.verbose)
      scaffold = generate_scaffold(smiles)
      if scaffold not in scaffolds:
        scaffolds[scaffold] = [ind]
      else:
        scaffolds[scaffold].append(ind)
    # Sort from largest to smallest scaffold sets.
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
    scaffold_sets = [
        scaffold_set
        for (scaffold, scaffold_set) in sorted(
            scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]
    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []
    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds
  def train_valid_test_split(self,
                             dataset,
                             train_dir=None,
                             valid_dir=None,
                             test_dir=None,
                             frac_train=.8,
                             frac_valid=.1,
                             frac_test=.1,
                             seed=None,
                             log_every_n=1000,
                             verbose=True,
                             **kwargs):
    """
    Splits self into train/validation/test sets.

    Returns Dataset objects.
    """
    log("Computing train/valid/test indices", self.verbose)
    train_inds, valid_inds, test_inds = self.split(
        dataset,
        seed=seed,
        frac_train=frac_train,
        frac_test=frac_test,
        frac_valid=frac_valid,
        log_every_n=log_every_n,
        **kwargs)
    if train_dir is None:
      train_dir = tempfile.mkdtemp()
    if valid_dir is None:
      valid_dir = tempfile.mkdtemp()
    if test_dir is None:
      test_dir = tempfile.mkdtemp()
    train_dataset = dataset.select(train_inds, train_dir)
    if frac_valid != 0:
      valid_dataset = dataset.select(valid_inds, valid_dir)
    else:
      valid_dataset = None
    test_dataset = dataset.select(test_inds, test_dir)
    return train_dataset, valid_dataset, test_dataset
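
# Hedged usage sketch (an assumption, not taken from this code): how a
# DeepChem-style splitter's train_valid_test_split is typically invoked.
# dc.splits.ScaffoldSplitter and dc.data.NumpyDataset are assumed entry
# points; substitute whatever classes this codebase actually exposes.
import numpy as np
import deepchem as dc

smiles = ["CCO", "CCCO", "c1ccccc1", "CC(=O)O"]
X = np.random.rand(len(smiles), 8)
y = np.random.rand(len(smiles), 1)
dataset = dc.data.NumpyDataset(X, y, ids=smiles)

splitter = dc.splits.ScaffoldSplitter()
train, valid, test = splitter.train_valid_test_split(
    dataset, frac_train=0.5, frac_valid=0.25, frac_test=0.25)
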
def _to_singletask(dataset, task_dirs):
  """Transforms a multitask dataset into a collection of singletask datasets."""
  tasks = dataset.get_task_names()
  assert len(tasks) == len(task_dirs)
  log("Splitting multitask dataset into singletask datasets", dataset.verbose)
  task_datasets = [
      DiskDataset.create_dataset([], task_dirs[task_num], [task])
      for (task_num, task) in enumerate(tasks)
  ]
  #task_metadata_rows = {task: [] for task in tasks}
  for shard_num, (X, y, w, ids) in enumerate(dataset.itershards()):
    log("Processing shard %d" % shard_num, dataset.verbose)
    basename = "dataset-%d" % shard_num
    for task_num, task in enumerate(tasks):
      log("\tTask %s" % task, dataset.verbose)
      if len(w.shape) == 1:
        w_task = w
      elif w.shape[1] == 1:
        w_task = w[:, 0]
      else:
        w_task = w[:, task_num]
      y_task = y[:, task_num]
      # Extract the datapoints which are present for this task.
      X_nonzero = X[w_task != 0]
      num_datapoints = X_nonzero.shape[0]
      y_nonzero = np.reshape(y_task[w_task != 0], (num_datapoints, 1))
      w_nonzero = np.reshape(w_task[w_task != 0], (num_datapoints, 1))
      ids_nonzero = ids[w_task != 0]
      task_datasets[task_num].add_shard(X_nonzero, y_nonzero, w_nonzero,
                                        ids_nonzero)
  return task_datasets
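
# Small NumPy illustration of the per-task masking above: the weight column
# w[:, task] marks which rows carry a label for that task, and only those
# rows are pulled out and reshaped into single-task columns.
import numpy as np

X = np.arange(8).reshape(4, 2)
y = np.array([[1., 0.], [0., 1.], [1., 1.], [0., 0.]])
w = np.array([[1., 0.], [1., 1.], [0., 1.], [1., 0.]])  # 0 => label missing

task_num = 0
w_task = w[:, task_num]
X_nonzero = X[w_task != 0]                            # rows 0, 1 and 3
y_nonzero = y[w_task != 0, task_num].reshape(-1, 1)   # shape (3, 1)
w_nonzero = w_task[w_task != 0].reshape(-1, 1)
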
        metadata_rows.append(
            Dataset.write_data_to_disk(
                self.data_dir, "data", tasks, X, y, w, ids,
                compute_feature_statistics=compute_feature_statistics))
        self.metadata_df = Dataset.construct_metadata(metadata_rows)
        self.save_to_disk()
      else:
        # Create an empty metadata dataframe to be filled at a later time.
        basename = "metadata"
        metadata_rows = [
            Dataset.write_data_to_disk(self.data_dir, basename, tasks)
        ]
        self.metadata_df = Dataset.construct_metadata(metadata_rows)
        self.save_to_disk()
    else:
      log("Loading pre-existing metadata file.", self.verbosity)
      if os.path.exists(self._get_metadata_filename()):
        self.metadata_df = load_from_disk(self._get_metadata_filename())
      else:
        raise ValueError("No metadata found.")
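
# Hedged sketch (assumed file name and columns, not from the original) of the
# same save-or-reload pattern using plain pandas: persist a metadata table
# alongside the data and, on reload, either read it back or fail loudly.
import os
import pandas as pd

def save_metadata(data_dir, rows):
  df = pd.DataFrame(rows, columns=["basename", "tasks"])
  df.to_csv(os.path.join(data_dir, "metadata.csv"), index=False)
  return df

def load_metadata(data_dir):
  path = os.path.join(data_dir, "metadata.csv")
  if not os.path.exists(path):
    raise ValueError("No metadata found.")
  return pd.read_csv(path)
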
  def featurize(self, input_files, data_dir=None, shard_size=8192):
    """Featurize provided files and write to specified location.

    For large datasets, automatically shards into smaller chunks
    for convenience.

    Parameters
    ----------
    input_files: list
      List of input filenames.
    data_dir: str
      (Optional) Directory to store featurized dataset.
    shard_size: int
      (Optional) Number of examples stored in each shard.
    """
    log("Loading raw samples now.", self.verbose)
    log("shard_size: %d" % shard_size, self.verbose)

    if not isinstance(input_files, list):
      input_files = [input_files]

    def shard_generator():
      for shard_num, shard in enumerate(
          self.get_shards(input_files, shard_size)):
        time1 = time.time()
        X, valid_inds = self.featurize_shard(shard)
        ids = shard[self.id_field].values
        ids = ids[valid_inds]
        if len(self.tasks) > 0:
          # Featurize task results iff they exist.
          y, w = convert_df_to_numpy(shard, self.tasks, self.id_field)
          # Filter out examples where featurization failed.
          y, w = (y[valid_inds], w[valid_inds])
          assert len(X) == len(ids) == len(y) == len(w)
          yield X, y, w, ids
        else:
          # No task results: yield None for labels and weights.
          yield X, None, None, ids

    return DiskDataset.create_dataset(
        shard_generator(), data_dir, self.tasks, verbose=self.verbose)
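
# Self-contained sketch of the sharding idea behind get_shards: stream a CSV
# in fixed-size chunks so large inputs never need to fit in memory at once.
# The file name and the "smiles" column are illustrative assumptions.
import pandas as pd

def iter_shards(input_file, shard_size=8192):
  for shard in pd.read_csv(input_file, chunksize=shard_size):
    yield shard

# for shard in iter_shards("compounds.csv"):
#   ids = shard["smiles"].values
#   ...featurize one shard at a time...
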
      train_op = self.get_training_op(self.train_graph.graph,
                                      self.train_graph.loss)
      with self._get_shared_session(train=True) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(max_to_keep=max_checkpoints_to_keep)
        # Save an initial checkpoint.
        saver.save(sess, self._save_path, global_step=0)
        for epoch in range(nb_epoch):
          avg_loss, n_batches = 0., 0
          # There are valid cases where we don't want pad_batches on by
          # default, so it is left configurable rather than hard-coded.
          for ind, (X_b, y_b, w_b, ids_b) in enumerate(
              dataset.iterbatches(
                  self.batch_size, pad_batches=self.pad_batches)):
            if ind % log_every_N_batches == 0:
              log("On batch %d" % ind, self.verbose)
            # Run the training op.
            feed_dict = self.construct_feed_dict(X_b, y_b, w_b, ids_b)
            fetches = self.train_graph.output + [
                train_op, self.train_graph.loss
            ]
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            output = fetched_values[:len(self.train_graph.output)]
            loss = fetched_values[-1]
            avg_loss += loss
            y_pred = np.squeeze(np.array(output))
            y_b = y_b.flatten()
            n_batches += 1
          saver.save(sess, self._save_path, global_step=epoch)
          avg_loss = float(avg_loss) / n_batches
          log('Ending epoch %d: Average loss %g' % (epoch, avg_loss),
              self.verbose)
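
# Minimal sketch of an iterbatches-style generator with optional padding,
# illustrating what pad_batches controls in the loop above. This shows the
# idea only, not the library's implementation.
import numpy as np

def iterbatches(X, y, batch_size, pad_batches=False):
  for start in range(0, len(X), batch_size):
    X_b, y_b = X[start:start + batch_size], y[start:start + batch_size]
    if pad_batches and len(X_b) < batch_size:
      # Repeat randomly chosen samples so every batch has the same shape.
      idx = np.random.choice(len(X_b), batch_size - len(X_b))
      X_b = np.concatenate([X_b, X_b[idx]])
      y_b = np.concatenate([y_b, y_b[idx]])
    yield X_b, y_b
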
  def _add_user_specified_features(self, df, ori_df):
    """Merge user-specified features.

    Merge features included in the dataset provided by the user
    into the final features dataframe.
    """
    if self.user_specified_features is not None:
      log("Aggregating User-Specified Features", self.verbosity)
      features_data = []
      for ind, row in ori_df.iterrows():
        # pandas rows are tuples (row_num, row_data)
        feature_list = []
        for feature_name in self.user_specified_features:
          feature_list.append(row[feature_name])
        features_data.append(np.array(feature_list))
      df["user-specified-features"] = features_data
  def split(self,
            dataset,
            seed=None,
            frac_train=.8,
            frac_valid=.1,
            frac_test=.1,
            log_every_n=1000):
    """
    Splits internal compounds into train/validation/test by scaffold.
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.)
    scaffolds = {}
    log("About to generate scaffolds", self.verbose)
    data_len = len(dataset)
    for ind, smiles in enumerate(dataset.ids):
      if ind % log_every_n == 0:
        log("Generating scaffold %d/%d" % (ind, data_len), self.verbose)
      scaffold = generate_scaffold(smiles)
      if scaffold not in scaffolds:
        scaffolds[scaffold] = [ind]
      else:
        scaffolds[scaffold].append(ind)
    # Sort from largest to smallest scaffold sets.
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
    scaffold_sets = [
        scaffold_set for (scaffold, scaffold_set) in sorted(
            scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]
    train_cutoff = frac_train * len(dataset)
    valid_cutoff = (frac_train + frac_valid) * len(dataset)
    train_inds, valid_inds, test_inds = [], [], []
    log("About to sort in scaffold sets", self.verbose)
    for scaffold_set in scaffold_sets:
      if len(train_inds) + len(scaffold_set) > train_cutoff:
        if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
          test_inds += scaffold_set
        else:
          valid_inds += scaffold_set
      else:
        train_inds += scaffold_set
    return train_inds, valid_inds, test_inds
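
# Toy run of the greedy assignment above with concrete numbers: whole scaffold
# sets are poured into train until the train cutoff is exceeded, then into
# valid, then into test, so compounds sharing a scaffold never straddle splits.
scaffold_sets = [[0, 1, 2, 3], [4, 5, 6], [7, 8], [9]]
train_cutoff, valid_cutoff = 0.6 * 10, 0.8 * 10
train_inds, valid_inds, test_inds = [], [], []
for scaffold_set in scaffold_sets:
  if len(train_inds) + len(scaffold_set) > train_cutoff:
    if len(train_inds) + len(valid_inds) + len(scaffold_set) > valid_cutoff:
      test_inds += scaffold_set
    else:
      valid_inds += scaffold_set
  else:
    train_inds += scaffold_set
# train_inds == [0, 1, 2, 3, 7, 8], valid_inds == [4, 5, 6], test_inds == [9]
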
"Bad PDB! Improperly formatted CONECT line (too short)")
continue
atom_index = int(line[6:11].strip())
if atom_index not in self.all_atoms:
log(
"Bad PDB! Improper CONECT line: (atom index not loaded)")
continue
bonded_atoms = []
ranges = [(11, 16), (16, 21), (21, 26), (26, 31)]
misformatted = False
for (lower, upper) in ranges:
# Check that the range is nonempty.
if line[lower:upper].strip():
index = int(line[lower:upper])
if index not in self.all_atoms:
log(
"Bad PDB! Improper CONECT line: (bonded atom not loaded)")
misformatted = True
break
bonded_atoms.append(index)
if misformatted:
continue
atom = self.all_atoms[atom_index]
atom.add_neighbor_atom_indices(bonded_atoms)
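
# Standalone illustration of the fixed-column CONECT parsing above: in a PDB
# CONECT record the central atom serial occupies columns 7-11 and up to four
# bonded serials occupy columns 12-31 (0-indexed slices 6:11 through 26:31).
line = "CONECT  413  412  414"
atom_serial = int(line[6:11])
bonded = [
    int(line[lower:upper])
    for (lower, upper) in [(11, 16), (16, 21), (21, 26), (26, 31)]
    if line[lower:upper].strip()
]
# atom_serial == 413, bonded == [412, 414]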