Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
reader = csv.reader(f)
next(reader) # skip header
lines = []
for line in reader:
smiles = line[0]
if smiles in skip_smiles:
continue
lines.append(line)
if len(lines) >= max_data_size:
break
data = MoleculeDataset([
MoleculeDatapoint(
line=line,
args=args,
features=features_data[i] if features_data is not None else None,
use_compound_names=use_compound_names
) for i, line in tqdm(enumerate(lines), total=len(lines))
])
# Filter out invalid SMILES
if skip_invalid_smiles:
original_data_len = len(data)
data = filter_invalid_smiles(data)
if len(data) < original_data_len:
debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.')
test_scaffold_count += 1
if logger is not None:
logger.debug(f'Total scaffolds = {len(scaffold_to_indices):,} | '
f'train scaffolds = {train_scaffold_count:,} | '
f'val scaffolds = {val_scaffold_count:,} | '
f'test scaffolds = {test_scaffold_count:,}')
log_scaffold_stats(data, index_sets, logger=logger)
# Map from indices to data
train = [data[i] for i in train]
val = [data[i] for i in val]
test = [data[i] for i in test]
return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset:
    """
    Filters out invalid SMILES.

    A datapoint is kept only if all of the following hold:
      * its ``smiles`` string is not empty,
      * its ``mol`` attribute is not None, and
      * ``mol.GetNumHeavyAtoms()`` is greater than zero.

    :param data: A MoleculeDataset.
    :return: A new MoleculeDataset containing only the datapoints that pass the checks above.
    """
    # Checks are ordered so that GetNumHeavyAtoms() is only called on a non-None mol.
    return MoleculeDataset([datapoint for datapoint in data
                            if datapoint.smiles != '' and datapoint.mol is not None
                            and datapoint.mol.GetNumHeavyAtoms() > 0])
elif split_type == 'scaffold_overlap':
assert scaffold_overlap is not None
return scaffold_split_overlap(data, overlap=scaffold_overlap, seed=seed, logger=logger)
elif split_type == 'random':
data.shuffle(seed=seed)
train_size = int(sizes[0] * len(data))
train_val_size = int((sizes[0] + sizes[1]) * len(data))
train = data[:train_size]
val = data[train_size:train_val_size]
test = data[train_val_size:]
return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
else:
raise ValueError(f'split_type "{split_type}" not supported.')
def chunk(self, num_chunks: int, seed: int = None) -> List['MoleculeDataset']:
    """
    Shuffles the dataset and splits it into ``num_chunks`` contiguous chunks.

    The reordering is done by ``self.shuffle(seed)``, which this method relies on to
    reorder ``self.data`` in place. Chunk sizes differ by at most one, so no chunk is
    empty unless ``num_chunks`` exceeds the number of datapoints.

    :param num_chunks: Number of chunks to produce; must be at least 1.
    :param seed: Seed forwarded to ``self.shuffle``; may be None.
    :return: A list of exactly ``num_chunks`` MoleculeDatasets that together contain every datapoint.
    :raises ValueError: If ``num_chunks`` is less than 1.
    """
    # Validate before shuffling so an invalid call does not reorder the dataset.
    if num_chunks < 1:
        raise ValueError(f'num_chunks must be at least 1, got {num_chunks}.')

    self.shuffle(seed)

    # Balanced split: the first `extra` chunks receive one additional datapoint.
    # A ceil-based chunk length could leave trailing chunks empty (e.g. 10 items into 6 chunks).
    base_size, extra = divmod(len(self.data), num_chunks)
    datasets = []
    start = 0
    for i in range(num_chunks):
        end = start + base_size + (1 if i < extra else 0)
        datasets.append(MoleculeDataset(self.data[start:end]))
        start = end

    return datasets
f'train scaffolds = {len(set(index_to_scaffold[index] for index in train)):,} | '
f'val scaffolds = {len(set(index_to_scaffold[index] for index in val)):,} | '
f'test scaffolds = {len(set(index_to_scaffold[index] for index in test)):,}')
# Map from indices to data
train = [data[i] for i in train]
val = [data[i] for i in val]
test = [data[i] for i in test]
# Shuffle since overlap and non-overlap are not shuffled
random.seed(seed)
random.shuffle(train)
random.shuffle(val)
random.shuffle(test)
return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)