if dataset == 'train':
    self.data = batcher.train_samples
elif dataset == 'validation':
    self.data = batcher.valid_samples
    if mini_epochs != 1:
        raise ValueError(
            "'mini_epochs' must be equal to 1 for validation data.")
else:
    raise ValueError("'dataset' should be 'train' or 'validation'.")

original_size = len(self.data)
self.n_batches = len(self.data) // self.batch_size
self.data = self.data[:self.n_batches * self.batch_size]
np.random.shuffle(self.data)
self.logger = medaka.common.get_named_logger(
    '{}Batcher'.format(dataset.capitalize()))
self.logger.info(
    '{} batches of {} samples ({}), from {} original.'.format(
        self.n_batches, self.batch_size, len(self.data),
        original_size))
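# The floor division above drops any trailing partial batch so that every
# batch served contains exactly batch_size samples. A tiny worked example
# with illustrative numbers (not medaka data):
example_samples = list(range(1050))
example_batch_size = 500
example_n_batches = len(example_samples) // example_batch_size    # 2
example_kept = example_samples[:example_n_batches * example_batch_size]
print(example_n_batches, len(example_kept))   # 2 1000  (50 samples dropped)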
def __init__(
        self, features, validation=0.2, seed=0,
        batch_size=500, threads=1):
    """Serve up batches of training or validation data.

    :param features: iterable of str, training feature files.
    :param validation: float, fraction of batches to use for validation,
        or iterable of str, validation feature files.
    :param seed: int, random seed for separation of batches into
        training/validation.
    :param batch_size: int, number of samples per batch.
    :param threads: int, number of threads to use for preparing batches.
    """
    self.logger = medaka.common.get_named_logger('TrainBatcher')
    self.features = features
    self.validation = validation
    self.seed = seed
    self.batch_size = batch_size

    di = medaka.datastore.DataIndex(self.features, threads=threads)
    self.samples = di.samples.copy()

    self.label_scheme = di.metadata['label_scheme']
    self.feature_encoder = di.metadata['feature_encoder']
    # check sample size using the first sample
    test_sample, test_fname = self.samples[0]
    with medaka.datastore.DataStore(test_fname) as ds:
        # TODO: this should come from feature_encoder
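# A minimal usage sketch for the constructor documented above, assuming
# TrainBatcher is importable from medaka.training; the feature file name
# is hypothetical (such files are typically produced beforehand with the
# `medaka features` command).
import medaka.training

batcher = medaka.training.TrainBatcher(
    ['train_features.hdf'],   # hypothetical feature file
    validation=0.2,           # hold out 20% of samples for validation
    seed=0, batch_size=500, threads=4)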
def stitch_from_probs(h5_fp, regions=None):
    """Join overlapping label probabilities from HDF5 files.

    Network outputs from multiple samples stored within a file are spliced
    together into a logically contiguous array and decoded to generate
    contiguous sequence(s).

    :param h5_fp: iterable of HDF5 filepaths.
    :param regions: iterable of regions (strings) to process.

    :returns: list of (region string, sequence).
    """
    logger = medaka.common.get_named_logger('Stitch')
    if isinstance(regions, medaka.common.Region):
        regions = [regions]
    logger.info("Stitching regions: {}".format([str(r) for r in regions]))

    index = medaka.datastore.DataIndex(h5_fp)
    label_scheme = index.metadata['label_scheme']

    logger.debug("Label decoding is:\n{}".format(
        '\n'.join('{}: {}'.format(k, v)
                  for k, v in label_scheme._decoding.items())))

    def get_pos(sample, i):
        return '{}.{}'.format(
            sample.positions[i]['major'] + 1, sample.positions[i]['minor'])

    ref_assemblies = []
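# A sketch of calling the function above, based only on its docstring; the
# probabilities file (typically written by `medaka consensus`) and the
# region string are hypothetical.
for region_str, sequence in stitch_from_probs(
        ['consensus_probs.hdf'], regions=['contig1']):
    print('>{}\n{}'.format(region_str, sequence))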
def _samples_worker(args, region, feature_encoder, label_scheme):
    logger = medaka.common.get_named_logger('PrepWork')
    logger.info("Processing region {}.".format(region))
    data_gen = SampleGenerator(
        args.bam, region, feature_encoder, truth_bam=args.truth,
        label_scheme=label_scheme, truth_haplotag=args.truth_haplotag,
        chunk_len=args.chunk_len, chunk_overlap=args.chunk_ovlp)
    return list(data_gen.samples), region
def run_prediction(output, bam, regions, model, model_file, rle_ref,
                   read_fraction, chunk_len, chunk_ovlp, batch_size=200,
                   save_features=False, tag_name=None, tag_value=None,
                   tag_keep_missing=False, enable_chunking=True):
    """Inference worker."""
    logger = medaka.common.get_named_logger('PWorker')

    remainder_regions = list()

    def sample_gen():
        # chain all samples whilst dispensing with generators when done
        # (they hold the feature vector in memory until they die)
        for region in regions:
            data_gen = medaka.features.SampleGenerator(
                bam, region, model_file, rle_ref, read_fraction,
                chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
                tag_name=tag_name, tag_value=tag_value,
                tag_keep_missing=tag_keep_missing,
                enable_chunking=enable_chunking)
            yield from data_gen.samples
            remainder_regions.extend(data_gen._quarantined)
    batches = medaka.common.background_generator(
        medaka.common.grouper(sample_gen(), batch_size), 10)
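# The sample_gen/grouper/background_generator chain above streams samples
# region by region and batches them lazily, so only a handful of batches is
# ever held in memory. A minimal, self-contained sketch of the same batching
# idea in plain Python (illustrative only, not medaka's implementation):
import itertools

def batched(gen, batch_size):
    """Yield lists of up to batch_size items drawn lazily from gen."""
    while True:
        batch = list(itertools.islice(gen, batch_size))
        if not batch:
            return
        yield batch

for batch in batched(iter(range(8)), 3):
    print(batch)   # [0, 1, 2] then [3, 4, 5] then [6, 7]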
# Collect truth alignments from a bam, grouped by haplotype:
# returns {haplotype: [`TruthAlignment`]}.
alignments = collections.defaultdict(list)
with pysam.AlignmentFile(truth_bam, 'rb') as bamfile:
    aln_reads = bamfile.fetch(
        reference=region.ref_name,
        start=region.start, end=region.end)
    for r in aln_reads:
        if (r.is_unmapped or r.is_secondary):
            continue
        else:
            hap = r.get_tag(haplotag) if haplotag is not None else None
            alignments[hap].append(TruthAlignment(r))
logger = medaka.common.get_named_logger("TruthAlign")
for hap in alignments.keys():
    alignments[hap].sort(key=attrgetter('start'))
    logger.info("Retrieved {} alignments for haplotype {}.".format(
        len(alignments[hap]), hap))
return alignments
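# For reference, the haplotype-tag handling above can be exercised directly
# with pysam. 'HP' (the tag whatshap writes) is only an assumed value for
# haplotag here; the bam path and region are hypothetical.
import pysam

hap_counts = {}
with pysam.AlignmentFile('truth_to_ref.bam', 'rb') as bam:
    for read in bam.fetch('contig1', 0, 100000):
        if read.is_unmapped or read.is_secondary:
            continue
        hap = read.get_tag('HP') if read.has_tag('HP') else None
        hap_counts[hap] = hap_counts.get(hap, 0) + 1
print(hap_counts)   # e.g. {1: 812, 2: 790, None: 35}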
def __init__(self, filename, mode='r', verify_on_close=True):
    """Initialize a datastore.

    :param filename: file to open.
    :param mode: file opening mode ('r', 'w', 'a').
    :param verify_on_close: on file close, check that all samples logged
        as being stored in the file have a corresponding group within the
        `.hdf`.
    """
    self.filename = filename
    self.mode = mode
    self.verify_on_close = verify_on_close

    self.logger = medaka.common.get_named_logger('DataStore')

    self.write_executor = ThreadPoolExecutor(1)
    self.write_futures = []
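# The arguments above map directly onto context-manager use, as already seen
# in the TrainBatcher snippet; a minimal read-only sketch with a hypothetical
# file name (verification on close only matters once samples have been
# written through the store):
import medaka.datastore

with medaka.datastore.DataStore('features.hdf', mode='r') as ds:
    pass  # read samples / metadata from ds here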