Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def stitch(args):
    """Entry point for stitching program.

    Opens the inputs named by ``args.inputs`` with
    ``medaka.datastore.DataIndex``, stitches the requested regions in a
    process pool via ``stitch_from_probs`` and writes every resulting
    record to ``args.output`` in FASTA format.

    :param args: parsed command-line namespace. Attributes read here are
        ``inputs``, ``regions`` (may be ``None``), ``output`` and ``jobs``.
        When ``args.regions`` is ``None`` it is filled in place with the
        sorted entries of the index.
    """
    index = medaka.datastore.DataIndex(args.inputs)
    if args.regions is None:
        # default to every entry of the index (parsed below as region
        # strings), in sorted order
        args.regions = sorted(index.index)

    # batch size is a simple empirical heuristic: about two batches per
    # worker, and never fewer than one region per batch.
    batch_size = max(1, len(args.regions) // (2 * args.jobs))
    # use the fully-qualified module path, as for ``grouper``, rather than a
    # bare ``common`` name that is not otherwise bound in this function.
    regions = medaka.common.grouper(
        (medaka.common.Region.from_string(r) for r in args.regions),
        batch_size=batch_size)

    with open(args.output, 'w') as fasta:
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.jobs) as executor:
            worker = functools.partial(stitch_from_probs, args.inputs)
            # each item from ``grouper`` is presumably one batch of regions;
            # each worker result is unpacked as (name, info, seq) records.
            # ``executor.map`` yields results in submission order, so the
            # output order follows ``args.regions``.
            for contigs in executor.map(worker, regions):
                for name, info, seq in contigs:
                    fasta.write('>{} {}\n{}\n'.format(name, info, seq))
remainder_regions = list()
def sample_gen():
    """Yield every sample for every region, one generator at a time.

    Closure over names from the enclosing (not visible) scope: ``regions``,
    ``bam``, ``model_file``, ``rle_ref``, ``read_fraction``, ``chunk_len``,
    ``chunk_ovlp``, ``tag_name``, ``tag_value``, ``tag_keep_missing``,
    ``enable_chunking`` and the output list ``remainder_regions``.
    """
    # chain all samples whilst dispensing with generators when done
    # (they hold the feature vector in memory until they die)
    for region in regions:
        # rebinding ``data_gen`` each iteration drops the reference to the
        # previous region's generator so it can be released
        data_gen = medaka.features.SampleGenerator(
            bam, region, model_file, rle_ref, read_fraction,
            chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
            tag_name=tag_name, tag_value=tag_value,
            tag_keep_missing=tag_keep_missing,
            enable_chunking=enable_chunking)
        yield from data_gen.samples
        # only reached after this region's samples are fully consumed; if
        # the consumer stops early, the region's ``_quarantined`` entries
        # are never collected.
        # NOTE(review): relies on the private ``_quarantined`` attribute of
        # SampleGenerator — confirm it is populated once ``samples`` is
        # exhausted.
        remainder_regions.extend(data_gen._quarantined)
batches = medaka.common.background_generator(
medaka.common.grouper(sample_gen(), batch_size), 10
)
total_region_mbases = sum(r.size for r in regions) / 1e6
logger.info("Running inference for {:.1f}M draft bases.".format(total_region_mbases))
with medaka.datastore.DataStore(output, 'a', verify_on_close=False) as ds:
mbases_done = 0
t0 = now()
tlast = t0
for data in batches:
x_data = np.stack([x.features for x in data])
class_probs = model.predict_on_batch(x_data)
# calculate bases done taking into account overlap
new_bases = 0
for x in data:
tmp_output, 'wb', header=alignments_bam.header) as output:
for region in regions:
bam_current = alignments_bam.fetch(
reference=region.ref_name,
start=region.start,
end=region.end)
ref_sequence = ref_fasta.fetch(region.ref_name)
ref_rle = RLEConverter(ref_sequence)
func = functools.partial(
_compress_alignment,
ref_rle=ref_rle,
fast5_dir=fast5_dir,
file_index=file_index)
with concurrent.futures.ThreadPoolExecutor(
max_workers=threads) as executor:
for chunk in medaka.common.grouper(bam_current, 100):
for new_alignment in executor.map(func, chunk):
if new_alignment is not None:
output.write(new_alignment)
pysam.sort("-o", bam_output, tmp_output)
os.remove(tmp_output)
pysam.index(bam_output)
def run_prediction(
output, bam, regions, model, feature_encoder,
chunk_len, chunk_ovlp, batch_size=200,
save_features=False, enable_chunking=True):
"""Inference worker."""
logger = medaka.common.get_named_logger('PWorker')
remainder_regions = list()
loader = DataLoader(
4 * batch_size, bam, regions, feature_encoder,
chunk_len=chunk_len, chunk_overlap=chunk_ovlp,
enable_chunking=enable_chunking)
batches = medaka.common.grouper(loader, batch_size)
total_region_mbases = sum(r.size for r in regions) / 1e6
logger.info(
"Running inference for {:.1f}M draft bases.".format(
total_region_mbases))
with medaka.datastore.DataStore(output, 'a') as ds:
mbases_done = 0
cache_size_log_interval = 5
t0 = now()
tlast = t0
tcache = t0
for data in batches:
if now() - tcache > cache_size_log_interval:
logger.info("Samples in cache: {}.".format(