        test_sample, test_fname = self.samples[0]
        with medaka.datastore.DataStore(test_fname) as ds:
            self.feature_shape = ds.load_sample(test_sample).features.shape
        self.logger.info(
            "Sample features have shape {}".format(self.feature_shape))

        if isinstance(self.validation, float):
            np.random.seed(self.seed)
            np.random.shuffle(self.samples)
            n_sample_train = int((1 - self.validation) * len(self.samples))
            self.train_samples = self.samples[:n_sample_train]
            self.valid_samples = self.samples[n_sample_train:]
            msg = ('Randomly selected {} ({:3.2%}) of features for '
                   'validation (seed {})')
            self.logger.info(msg.format(
                len(self.valid_samples), self.validation, self.seed))
        else:
            self.train_samples = self.samples
            self.valid_samples = medaka.datastore.DataIndex(
                self.validation).samples.copy()
            msg = ('Found {} validation samples, equivalent to {:3.2%} '
                   'of all the data')
            # parentheses are required here: without them the division
            # binds tighter than the addition and the fraction is wrong
            fraction = len(self.valid_samples) / (
                len(self.valid_samples) + len(self.train_samples))
            self.logger.info(msg.format(len(self.valid_samples), fraction))

        msg = 'Got {} samples in {} batches ({} labels) for {}'
        self.logger.info(msg.format(
            len(self.train_samples),
            len(self.train_samples) // batch_size,
            len(self.train_samples) * self.feature_shape[0],
            'training'))
        self.logger.info(msg.format(
            len(self.valid_samples),
            len(self.valid_samples) // batch_size,
            len(self.valid_samples) * self.feature_shape[0],
            'validation'))
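
# Worked example of the float-validation split above (a standalone
# illustration, not medaka code): with 1000 samples and validation=0.2,
# int(0.8 * 1000) == 800 samples are kept for training and the
# remaining 200 (20.00%) are held out for validation.
samples = list(range(1000))
validation = 0.2
n_sample_train = int((1 - validation) * len(samples))
assert (n_sample_train, len(samples) - n_sample_train) == (800, 200)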

import concurrent.futures
import functools

import medaka.common
import medaka.datastore


def stitch(args):
    """Entry point for stitching program."""
    index = medaka.datastore.DataIndex(args.inputs)
    if args.regions is None:
        args.regions = sorted(index.index)
    # batch size is a simple empirical heuristic
    regions = medaka.common.grouper(
        (medaka.common.Region.from_string(r) for r in args.regions),
        batch_size=max(1, len(args.regions) // (2 * args.jobs)))
    with open(args.output, 'w') as fasta:
        Executor = concurrent.futures.ProcessPoolExecutor
        with Executor(max_workers=args.jobs) as executor:
            worker = functools.partial(stitch_from_probs, args.inputs)
            for contigs in executor.map(worker, regions):
                for name, info, seq in contigs:
                    fasta.write('>{} {}\n{}\n'.format(name, info, seq))
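
# A minimal sketch, assuming medaka.common.grouper behaves like a
# standard batching helper (an illustration, not medaka's actual
# implementation): it lazily yields lists of at most `batch_size`
# items from an iterable.
import itertools


def grouper_sketch(iterable, batch_size=100):
    """Yield successive lists of up to `batch_size` items."""
    it = iter(iterable)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch


# e.g. list(grouper_sketch('abcde', batch_size=2))
# == [['a', 'b'], ['c', 'd'], ['e']]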

    def _extract_sample_registries(self):
        """Load the sample registry of every feature file, in parallel."""
        self.samples = []
        with ProcessPoolExecutor(self.threads) as executor:
            future_to_fn = {
                executor.submit(DataIndex._load_sample_registry, fn): fn
                for fn in self.filenames}
            for i, future in enumerate(as_completed(future_to_fn), 1):
                fn = future_to_fn[future]
                try:
                    sample_registry = future.result()
                    self.samples.extend(
                        [(s, fn) for s in sample_registry])
                except Exception:
                    self.logger.info('No sample_registry in {}'.format(fn))
                else:
                    self.logger.info(
                        'Loaded {}/{} ({:.2f}%) sample files.'.format(
                            i, self.n_files,
                            i / self.n_files * 100))
        # make order of samples independent of order in which tasks complete
        self.samples.sort()

        :param validation: float, fraction of samples to use for validation,
            or iterable of str, validation feature files.
        :param seed: int, random seed for separation of samples into
            training/validation.
        :param batch_size: int, number of samples per batch.
        :param threads: int, number of threads to use for preparing batches.
        """
        self.logger = medaka.common.get_named_logger('TrainBatcher')
        self.features = features
        self.validation = validation
        self.seed = seed
        self.batch_size = batch_size

        di = medaka.datastore.DataIndex(self.features, threads=threads)
        self.samples = di.samples.copy()
        self.label_scheme = di.metadata['label_scheme']
        self.feature_encoder = di.metadata['feature_encoder']

        # check the feature shape using the first sample
        test_sample, test_fname = self.samples[0]
        with medaka.datastore.DataStore(test_fname) as ds:
            # TODO: this should come from feature_encoder
            self.feature_shape = ds.load_sample(test_sample).features.shape
        self.logger.info(
            "Sample features have shape {}".format(self.feature_shape))

        if isinstance(self.validation, float):
            np.random.seed(self.seed)
            np.random.shuffle(self.samples)
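
# A usage sketch (file names and values here are illustrative
# assumptions, not medaka defaults). The `validation` argument selects
# between the two branches above: a float holds out a random fraction
# of the training features, while an iterable of file paths names an
# explicit validation set.
batcher = TrainBatcher(
    features=['train_features.hdf'], validation=0.2,
    seed=0, batch_size=100, threads=4)
batcher_explicit = TrainBatcher(
    features=['train_features.hdf'],
    validation=['valid_features.hdf'],
    seed=0, batch_size=100, threads=4)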

def snps_from_hdf(args):
    """Entry point for SNP calling from HDF5 files.

    A `LabelScheme` read from file is used to decode SNPs. All
    `LabelScheme`s define a `decode_snps` public method. We do not need
    to use `join_samples` to look for variants overlapping sample slice
    boundaries because we only analyse a single locus at a time. This
    means that `LabelScheme`s that do not define the `decode_consensus`
    method (called within `join_samples`) can be used.
    """
    logger = medaka.common.get_named_logger('SNPs')
    index = medaka.datastore.DataIndex(args.inputs)
    if args.regions is None:
        args.regions = sorted(index.index)
    regions = [
        medaka.common.Region.from_string(r)
        for r in args.regions]

    # lookup LabelScheme stored in HDF5
    try:
        label_scheme = index.metadata['label_scheme']
    except KeyError:
        logger.debug(
            "Could not find `label_scheme` metadata in input file, "
            "assuming HaploidLabelScheme.")
        label_scheme = medaka.labels.HaploidLabelScheme()

def variants_from_hdf(args):
    """Entry point for variant calling from HDF5 files.

    A `LabelScheme` read from HDF must define both a `decode_variants`
    and a `decode_consensus` method. The latter is used with
    `join_samples` to detect multi-locus variants spanning `Sample`
    slice boundaries.
    """
    logger = medaka.common.get_named_logger('Variants')
    index = medaka.datastore.DataIndex(args.inputs)
    if args.regions is None:
        args.regions = sorted(index.index)
    regions = [
        medaka.common.Region.from_string(r)
        for r in args.regions]

    # lookup LabelScheme stored in HDF5
    try:
        label_scheme = index.metadata['label_scheme']
    except KeyError:
        logger.debug(
            "Could not find `label_scheme` metadata in input file, "
            "assuming HaploidLabelScheme.")
        label_scheme = medaka.labels.HaploidLabelScheme()
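
# The `label_scheme` lookup above is repeated verbatim in
# `snps_from_hdf`; a small shared helper could factor it out. A sketch
# only (not part of medaka, the name is hypothetical):
def _label_scheme_from_index(index, logger):
    """Return the LabelScheme stored in `index.metadata`, falling back
    to a HaploidLabelScheme when the metadata key is absent.
    """
    try:
        return index.metadata['label_scheme']
    except KeyError:
        logger.debug(
            "Could not find `label_scheme` metadata in input file, "
            "assuming HaploidLabelScheme.")
        return medaka.labels.HaploidLabelScheme()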

        :param validation: float, fraction of samples to use for validation,
            or iterable of str, validation feature files.
        :param seed: int, random seed for separation of samples into
            training/validation.
        :param batch_size: int, number of samples per batch.
        :param threads: int, number of threads to use for preparing batches.
        """
        self.logger = medaka.common.get_named_logger('TrainBatcher')
        self.features = features
        self.max_label_len = max_label_len
        self.validation = validation
        self.seed = seed
        self.sparse_labels = label_scheme_cls.sparse_labels
        self.batch_size = batch_size

        di = medaka.datastore.DataIndex(self.features, threads=threads)
        self.samples = di.samples.copy()
        self.meta = di.meta.copy()
        self.label_counts = self.meta['medaka_label_counts']
        self.label_scheme = label_scheme_cls(
            self.label_counts, max_label_len=self.max_label_len)

        # check the feature shape using the first sample
        test_sample, test_fname = self.samples[0]
        with medaka.datastore.DataStore(test_fname) as ds:
            self.feature_shape = ds.load_sample(test_sample).features.shape
        self.logger.info(
            "Sample features have shape {}".format(self.feature_shape))

        if isinstance(self.validation, float):
            np.random.seed(self.seed)
            np.random.shuffle(self.samples)
            n_sample_train = int((1 - self.validation) * len(self.samples))
            self.train_samples = self.samples[:n_sample_train]

    Network outputs from multiple samples stored within a file are
    spliced together into a logically contiguous array and decoded to
    generate contiguous sequence(s).

    :param h5_fp: iterable of HDF5 filepaths.
    :param regions: iterable of regions (strings) to process.

    :returns: list of (region string, sequence)
    """
    logger = medaka.common.get_named_logger('Stitch')
    if isinstance(regions, medaka.common.Region):
        regions = [regions]
    logger.info("Stitching regions: {}".format([str(r) for r in regions]))

    index = medaka.datastore.DataIndex(h5_fp)
    label_scheme = index.metadata['label_scheme']
    logger.debug("Label decoding is:\n{}".format(
        '\n'.join('{}: {}'.format(k, v)
                  for k, v in label_scheme._decoding.items())))

    def get_pos(sample, i):
        return '{}.{}'.format(
            sample.positions[i]['major'] + 1, sample.positions[i]['minor'])
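    # e.g. if sample.positions[i] is (major=999, minor=2), get_pos
    # returns '1000.2': the major coordinate is converted from 0-based
    # to 1-based for display, the minor index is reported unchanged.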

    ref_assemblies = []
    for reg in regions:
        logger.info("Processing {}.".format(reg))
        data_gen = index.yield_from_feature_files(regions=[reg])
        seq_parts = list()
        cur_ref_name = ''