# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
elif max_run != num_qstrat:
raise ValueError(
'num_qstrat in feature_encoder_args must agree '
'with max_run in feature_encoder_args')
# Create and serialise to file model ancilliaries
feature_encoder = feature_encoders[args.feature_encoder](
**args.feature_encoder_args)
ds.set_meta(feature_encoder, 'feature_encoder')
label_scheme = medaka.labels.label_schemes[args.label_scheme](
**args.label_scheme_args)
ds.set_meta(label_scheme, 'label_scheme')
model_function = functools.partial(
medaka.models.build_model,
feature_encoder.feature_vector_length,
len(label_scheme._decoding))
ds.set_meta(model_function, 'model_function')
# TODO: this parallelism would be better in
# `SampleGenerator.bams_to_training_samples` since training
# alignments are usually chunked.
with concurrent.futures.ProcessPoolExecutor(max_workers=args.threads) \
as executor:
# break up overly long chunks
MAX_SIZE = int(1e6)
regions = itertools.chain(*(r.split(MAX_SIZE) for r in regions))
futures = [executor.submit(
_samples_worker, args, reg,
feature_encoder, label_scheme) for reg in regions]
inter_op_parallelism_threads=args.threads)
))
# Split overly long regions to maximum size so as to not create
# massive feature matrices
MAX_REGION_SIZE = int(1e6) # 1Mb
regions = []
for region in args.regions:
if region.size > MAX_REGION_SIZE:
regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
else:
regs = [region]
regions.extend(regs)
logger.info("Processing {} long region(s) with batching.".format(len(regions)))
model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
# the returned regions are those where the pileup width is smaller than chunk_len
remainder_regions = run_prediction(
args.output, args.bam, regions, model, args.model, args.rle_ref, args.read_fraction,
args.chunk_len, args.chunk_ovlp,
batch_size=args.batch_size, save_features=args.save_features,
tag_name=args.tag_name, tag_value=args.tag_value, tag_keep_missing=args.tag_keep_missing
)
# short/remainder regions: just do things without chunking. We can do this
# here because we now have the size of all pileups (and know they are small).
# TODO: can we avoid calculating pileups twice whilst controlling memory?
if len(remainder_regions) > 0:
logger.info("Processing {} short region(s).".format(len(remainder_regions)))
model = medaka.models.load_model(args.model, time_steps=None)
for region in remainder_regions:
new_remainders = run_prediction(
for region in args.regions:
if region.size > MAX_REGION_SIZE:
# chunk_ovlp is mostly used in overlapping pileups (which generally
# end up being expanded compared to the draft coordinate system)
regs = region.split(
MAX_REGION_SIZE, overlap=args.chunk_ovlp, fixed_size=False)
else:
regs = [region]
regions.extend(regs)
logger.info("Processing {} long region(s) with batching.".format(
len(regions)))
logger.info("Using model: {}.".format(args.model))
model = medaka.models.load_model(args.model, time_steps=args.chunk_len,
allow_cudnn=args.allow_cudnn)
# the returned regions are those where the pileup width is smaller than
# chunk_len
remainder_regions = run_prediction(
args.output, args.bam, regions, model, feature_encoder,
args.chunk_len, args.chunk_ovlp,
batch_size=args.batch_size, save_features=args.save_features
)
# short/remainder regions: just do things without chunking. We can do this
# here because we now have the size of all pileups (and know they are
# small).
# TODO: can we avoid calculating pileups twice whilst controlling memory?
if len(remainder_regions) > 0:
logger.info("Processing {} short region(s).".format(
parser.add_argument('model', help='input model filepath')
parser.add_argument('output', help='output model filepath')
args = parser.parse_args()
if os.path.exists(args.output):
sys.stderr.write('{} exists already!\n'.format(args.output))
sys.exit(1)
shutil.copy(args.model, args.output)
with h5py.File(args.output) as h:
# check that model can be built
model_name = yaml.unsafe_load(h['medaka_model_name'][()])
try:
build_model = medaka.models.model_builders[model_name]
except KeyError('Can not convert; requires deprecated ' + \
'{} function.'.format(model_name)):
sys.exit(1)
features = yaml.unsafe_load(h['medaka_feature_decoding'][()])
feat_len = len(features)
classes = yaml.unsafe_load(h['medaka_label_decoding'][()])
num_classes = len(classes)
gru_size = 128
classify_activation = 'softmax'
# load specified model kwargs if they exist
model_kwargs = yaml.unsafe_load(h['medaka_model_kwargs'][()])
if 'gru_size' in model_kwargs:
gru_size = model_kwargs['gru_size']
if 'classify_activation' in model_kwargs:
def run_training(train_name, batcher, model_fp=None,
                 epochs=5000, class_weight=None, n_mini_epochs=1, threads_io=1, multi_label=False,
                 optimizer='rmsprop', optim_args=None):
    """Build (or reload) a medaka model and prepare it for training.

    :param train_name: presumably an output/run directory or run label
        (unused in the visible portion of this function — TODO confirm).
    :param batcher: object providing `feature_shape`, `label_scheme`,
        `feature_encoder` and `meta` (see uses below).
    :param model_fp: optional filepath of an existing model; when given,
        model metadata and weights are loaded from it.
    :param epochs: maximum number of training epochs (used beyond the
        visible portion of this function — TODO confirm).
    :param class_weight: optional per-class weighting passed to training.
    :param n_mini_epochs: number of mini-epochs per epoch.
    :param threads_io: IO worker threads for batching.
    :param multi_label: whether labels are multi-label (binary) targets.
    :param optimizer: optimizer name, default 'rmsprop'.
    :param optim_args: optional dict of optimizer keyword arguments.

    NOTE(review): this function appears to merge two incompatible
    versions of the code. The block at the top builds `model` via
    `medaka.models.model_builders`, and is then immediately overwritten
    by a second build via the `model_function` stored in the datastore.
    The names `time_steps` and `allow_cudnn` used below are not defined
    anywhere in this function or its signature — as written this raises
    NameError at runtime. The duplicated paths should be reconciled and
    the missing parameters threaded through — verify against upstream.
    """
    # Keras imports are deferred to call time (keeps module import cheap
    # when training is not used).
    from tensorflow.keras.callbacks import CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
    from tensorflow.keras import optimizers
    from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher
    logger = medaka.common.get_named_logger('RunTraining')
    if model_fp is None:
        # No existing model: use the default architecture with the
        # builder's own default keyword arguments (harvested via
        # introspection of the builder signature).
        model_name = medaka.models.default_model
        model_kwargs = {
            k:v.default for (k,v) in
            inspect.signature(medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
    else:
        # Existing model: recover its architecture name and kwargs from
        # the datastore metadata.
        with medaka.datastore.DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']
    opt_str = '\n'.join(['{}: {}'.format(k,v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
    num_classes = len(batcher.label_scheme.label_decoding)
    timesteps, feat_dim = batcher.feature_shape
    # First model build (older code path).
    # NOTE(review): this result is discarded — both branches below
    # rebind `model`.
    model = medaka.models.model_builders[model_name](timesteps, feat_dim, num_classes, **model_kwargs)
    if model_fp is not None:
        # Newer code path: rebuild the model from the partial function
        # serialised in the datastore, then try to restore weights.
        with medaka.datastore.DataStore(model_fp) as ds:
            partial_model_function = ds.get_meta('model_function')
            # NOTE(review): `time_steps` and `allow_cudnn` are undefined
            # here (the older path above uses `timesteps`) — NameError.
            model = partial_model_function(time_steps=time_steps,
                                           allow_cudnn=allow_cudnn)
        try:
            model.load_weights(model_fp)
            logger.info("Loading weights from {}".format(model_fp))
        except Exception:
            # Best-effort: a fresh model (random weights) is used when
            # the stored weights cannot be loaded.
            logger.info("Could not load weights from {}".format(model_fp))
    else:
        # Fresh model from the default builder, wrapped in a partial so
        # it can be serialised alongside the model metadata below.
        num_classes = batcher.label_scheme.num_classes
        model_name = medaka.models.default_model
        model_function = medaka.models.model_builders[model_name]
        partial_model_function = functools.partial(
            model_function, feat_dim, num_classes)
        # NOTE(review): same undefined `time_steps`/`allow_cudnn` issue
        # as above — confirm against upstream run_training signature.
        model = partial_model_function(time_steps=time_steps,
                                       allow_cudnn=allow_cudnn)
    # Metadata stored with checkpoints so a saved model is self-describing.
    model_metadata = {'model_function': partial_model_function,
                      'label_scheme': batcher.label_scheme,
                      'feature_encoder': batcher.feature_encoder}
    # Checkpoint options: keep only the best model by the monitored
    # metric (mode='max').
    opts = dict(verbose=1, save_best_only=True, mode='max')
    if isinstance(batcher.label_scheme,
                  medaka.labels.DiploidZygosityLabelScheme):
        # Multi-label zygosity scheme uses per-output binary accuracy.
        metrics = ['binary_accuracy']
        call_back_metrics = metrics
if args.command == 'tools' and not hasattr(args, 'func'):
# display help if given `medaka tools (--help)`
toolparser.print_help()
elif args.command == 'methylation' and not hasattr(args, 'func'):
methparser.print_help()
else:
# do some common argument validation here
if hasattr(args, 'bam') and args.bam is not None:
RG = args.RG if hasattr(args, 'RG') else None
CheckBam.check_read_groups(args.bam, RG)
if RG is not None:
msg = "Reads will be filtered to only those with RG tag: {}"
logger.info(msg.format(RG))
# if model is default, resolve to file, save mess in help text
if hasattr(args, 'model') and args.model is not None:
args.model = medaka.models.resolve_model(args.model)
args.func(args)
model_name = medaka.models.default_model
model_kwargs = {
k:v.default for (k,v) in
inspect.signature(medaka.models.model_builders[model_name]).parameters.items()
if v.default is not inspect.Parameter.empty
}
else:
with medaka.datastore.DataStore(model_fp) as ds:
model_name = ds.meta['medaka_model_name']
model_kwargs = ds.meta['medaka_model_kwargs']
opt_str = '\n'.join(['{}: {}'.format(k,v) for k, v in model_kwargs.items()])
logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
num_classes = len(batcher.label_scheme.label_decoding)
timesteps, feat_dim = batcher.feature_shape
model = medaka.models.model_builders[model_name](timesteps, feat_dim, num_classes, **model_kwargs)
if model_fp is not None:
try:
model.load_weights(model_fp)
logger.info("Loading weights from {}".format(model_fp))
except:
logger.info("Could not load weights from {}".format(model_fp))
msg = "feat_dim: {}, timesteps: {}, num_classes: {}"
logger.info(msg.format(feat_dim, timesteps, num_classes))
model.summary()
model_details = batcher.meta.copy()
model_details['medaka_model_name'] = model_name
model_details['medaka_model_kwargs'] = model_kwargs