How to use the medaka.models module in medaka

To help you get started, we’ve selected a few medaka examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github nanoporetech / medaka / medaka / features.py View on Github external
# NOTE(review): fragment begins mid-function; `max_run`, `num_qstrat`,
# `feature_encoders`, `args`, `ds` and `regions` are defined above this
# excerpt (the leading `elif`'s indentation was lost in extraction).
elif max_run != num_qstrat:
            raise ValueError(
                 'num_qstrat in feature_encoder_args must agree '
                 'with max_run in feature_encoder_args')

        # Create and serialise to file model ancillaries: the feature
        # encoder, label scheme and model-building function are stored as
        # metadata in the output datastore `ds`.
        feature_encoder = feature_encoders[args.feature_encoder](
            **args.feature_encoder_args)
        ds.set_meta(feature_encoder, 'feature_encoder')

        label_scheme = medaka.labels.label_schemes[args.label_scheme](
            **args.label_scheme_args)
        ds.set_meta(label_scheme, 'label_scheme')

        # Partially-applied model builder: feature length and class count
        # are bound now, remaining arguments are supplied at build time.
        model_function = functools.partial(
            medaka.models.build_model,
            feature_encoder.feature_vector_length,
            len(label_scheme._decoding))
        ds.set_meta(model_function, 'model_function')

        # TODO: this parallelism would be better in
        # `SampleGenerator.bams_to_training_samples` since training
        # alignments are usually chunked.
        with concurrent.futures.ProcessPoolExecutor(max_workers=args.threads) \
                as executor:
            # break up overly long chunks (1 Mb cap per region)
            MAX_SIZE = int(1e6)
            regions = itertools.chain(*(r.split(MAX_SIZE) for r in regions))
            futures = [executor.submit(
                _samples_worker, args, reg,
                feature_encoder, label_scheme) for reg in regions]
github nanoporetech / medaka / medaka / inference.py View on Github external
# NOTE(review): fragment starts mid-expression; the enclosing TF session
# configuration call opens above this excerpt.
inter_op_parallelism_threads=args.threads)
    ))

    # Split overly long regions to maximum size so as to not create
    #   massive feature matrices
    MAX_REGION_SIZE = int(1e6)  # 1Mb
    regions = []
    for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
        else:
            regs = [region]
        regions.extend(regs)

    logger.info("Processing {} long region(s) with batching.".format(len(regions)))
    # Model sized to the fixed chunk length for the batched long regions.
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len)
    # the returned regions are those where the pileup width is smaller than chunk_len
    remainder_regions = run_prediction(
        args.output, args.bam, regions, model, args.model, args.rle_ref, args.read_fraction,
        args.chunk_len, args.chunk_ovlp,
        batch_size=args.batch_size, save_features=args.save_features,
        tag_name=args.tag_name, tag_value=args.tag_value, tag_keep_missing=args.tag_keep_missing
    )

    # short/remainder regions: just do things without chunking. We can do this
    # here because we now have the size of all pileups (and know they are small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(len(remainder_regions)))
        # Reload with time_steps=None so variable-length inputs are accepted.
        model = medaka.models.load_model(args.model, time_steps=None)
        for region in remainder_regions:
            new_remainders = run_prediction(
github nanoporetech / medaka / medaka / prediction.py View on Github external
# NOTE(review): fragment begins mid-function; `args`, `regions`,
# `MAX_REGION_SIZE`, `feature_encoder` and `logger` are defined above this
# excerpt (the leading `for`'s indentation was lost in extraction).
for region in args.regions:
        if region.size > MAX_REGION_SIZE:
            # chunk_ovlp is mostly used in overlapping pileups (which generally
            # end up being expanded compared to the draft coordinate system)
            regs = region.split(
                MAX_REGION_SIZE, overlap=args.chunk_ovlp, fixed_size=False)
        else:
            regs = [region]
        regions.extend(regs)

    logger.info("Processing {} long region(s) with batching.".format(
        len(regions)))

    logger.info("Using model: {}.".format(args.model))

    # Model built for fixed-length chunks; cuDNN use is caller-controlled.
    model = medaka.models.load_model(args.model, time_steps=args.chunk_len,
                                     allow_cudnn=args.allow_cudnn)

    # the returned regions are those where the pileup width is smaller than
    # chunk_len
    remainder_regions = run_prediction(
        args.output, args.bam, regions, model, feature_encoder,
        args.chunk_len, args.chunk_ovlp,
        batch_size=args.batch_size, save_features=args.save_features
    )

    # short/remainder regions: just do things without chunking. We can do this
    # here because we now have the size of all pileups (and know they are
    # small).
    # TODO: can we avoid calculating pileups twice whilst controlling memory?
    if len(remainder_regions) > 0:
        logger.info("Processing {} short region(s).".format(
github nanoporetech / medaka / scripts / update_model.py View on Github external
parser.add_argument('model', help='input model filepath')
    parser.add_argument('output', help='output model filepath')
    args = parser.parse_args()

    if os.path.exists(args.output):
        sys.stderr.write('{} exists already!\n'.format(args.output))
        sys.exit(1)

    shutil.copy(args.model, args.output)

    with h5py.File(args.output) as h:

        # check that model can be built
        model_name = yaml.unsafe_load(h['medaka_model_name'][()])
        try:
            build_model = medaka.models.model_builders[model_name]
        except KeyError('Can not convert; requires deprecated ' + \
                        '{} function.'.format(model_name)):
            sys.exit(1)

        features = yaml.unsafe_load(h['medaka_feature_decoding'][()])
        feat_len = len(features)
        classes = yaml.unsafe_load(h['medaka_label_decoding'][()])
        num_classes = len(classes)

        gru_size = 128
        classify_activation = 'softmax'
        # load specified model kwargs if they exist
        model_kwargs = yaml.unsafe_load(h['medaka_model_kwargs'][()])
        if 'gru_size' in model_kwargs:
            gru_size = model_kwargs['gru_size']
        if 'classify_activation' in model_kwargs:
github nanoporetech / medaka / medaka / inference.py View on Github external
def run_training(train_name, batcher, model_fp=None,
                 epochs=5000, class_weight=None, n_mini_epochs=1, threads_io=1, multi_label=False,
                 optimizer='rmsprop', optim_args=None):
    """Run training.

    :param train_name: name of the training run.
    :param batcher: supplies `label_scheme`, `feature_shape` and batches.
    :param model_fp: optional existing model file; when given, the model
        name and kwargs are read from its datastore metadata instead of
        the default architecture's signature defaults.

    NOTE(review): the remaining parameters are consumed below this
    excerpt — the fragment is truncated after the model is built.
    """
    # Keras imports are deferred so importing this module does not require
    # tensorflow at import time.
    from tensorflow.keras.callbacks import CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau
    from tensorflow.keras import optimizers
    from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher

    logger = medaka.common.get_named_logger('RunTraining')

    if model_fp is None:
        # No model given: use the default architecture with its builder's
        # declared default keyword arguments (read via introspection).
        model_name = medaka.models.default_model
        model_kwargs = {
            k:v.default for (k,v) in
            inspect.signature(medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
    else:
        # Reuse the architecture recorded in the existing model's metadata.
        with medaka.datastore.DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']

    opt_str = '\n'.join(['{}: {}'.format(k,v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
    num_classes = len(batcher.label_scheme.label_decoding)
    timesteps, feat_dim = batcher.feature_shape
    model = medaka.models.model_builders[model_name](timesteps, feat_dim, num_classes, **model_kwargs)
github nanoporetech / medaka / medaka / training.py View on Github external
# NOTE(review): fragment begins mid-function; `model_fp`, `time_steps`,
# `allow_cudnn`, `feat_dim`, `batcher` and `logger` come from above this
# excerpt (the leading `if`'s indentation was lost in extraction).
if model_fp is not None:
        # Rebuild the architecture stored in the model file, then attempt
        # to restore its weights (best effort: failure is only logged).
        with medaka.datastore.DataStore(model_fp) as ds:
            partial_model_function = ds.get_meta('model_function')
            model = partial_model_function(time_steps=time_steps,
                                           allow_cudnn=allow_cudnn)
        try:
            model.load_weights(model_fp)
            logger.info("Loading weights from {}".format(model_fp))
        except Exception:
            logger.info("Could not load weights from {}".format(model_fp))

    else:
        # Fresh model: default architecture sized from the batcher.
        num_classes = batcher.label_scheme.num_classes
        model_name = medaka.models.default_model
        model_function = medaka.models.model_builders[model_name]
        partial_model_function = functools.partial(
            model_function, feat_dim, num_classes)
        model = partial_model_function(time_steps=time_steps,
                                       allow_cudnn=allow_cudnn)

    # Metadata serialised alongside checkpoints so the model can later be
    # rebuilt and its inputs/outputs decoded.
    model_metadata = {'model_function': partial_model_function,
                      'label_scheme': batcher.label_scheme,
                      'feature_encoder': batcher.feature_encoder}

    opts = dict(verbose=1, save_best_only=True, mode='max')

    if isinstance(batcher.label_scheme,
                  medaka.labels.DiploidZygosityLabelScheme):

        # Multi-label output: accuracy is measured per class.
        metrics = ['binary_accuracy']
        call_back_metrics = metrics
github nanoporetech / medaka / medaka / medaka.py View on Github external
# NOTE(review): fragment begins mid-function; `args`, `toolparser`,
# `methparser`, `CheckBam` and `logger` are defined above this excerpt.
if args.command == 'tools' and not hasattr(args, 'func'):
        # display help if given `medaka tools (--help)`
        toolparser.print_help()
    elif args.command == 'methylation' and not hasattr(args, 'func'):
        methparser.print_help()
    else:
        # do some common argument validation here
        if hasattr(args, 'bam') and args.bam is not None:
            # Restrict processing to a single read group when RG is given.
            RG = args.RG if hasattr(args, 'RG') else None
            CheckBam.check_read_groups(args.bam, RG)
            if RG is not None:
                msg = "Reads will be filtered to only those with RG tag: {}"
                logger.info(msg.format(RG))
        # if model is default, resolve to file, save mess in help text
        if hasattr(args, 'model') and args.model is not None:
            args.model = medaka.models.resolve_model(args.model)
        # dispatch to the selected subcommand's handler
        args.func(args)
github nanoporetech / medaka / medaka / inference.py View on Github external
model_name = medaka.models.default_model
        model_kwargs = {
            k:v.default for (k,v) in
            inspect.signature(medaka.models.model_builders[model_name]).parameters.items()
            if v.default is not inspect.Parameter.empty
        }
    else:
        with medaka.datastore.DataStore(model_fp) as ds:
            model_name = ds.meta['medaka_model_name']
            model_kwargs = ds.meta['medaka_model_kwargs']

    opt_str = '\n'.join(['{}: {}'.format(k,v) for k, v in model_kwargs.items()])
    logger.info('Building {} model with: \n{}'.format(model_name, opt_str))
    num_classes = len(batcher.label_scheme.label_decoding)
    timesteps, feat_dim = batcher.feature_shape
    model = medaka.models.model_builders[model_name](timesteps, feat_dim, num_classes, **model_kwargs)

    if model_fp is not None:
        try:
            model.load_weights(model_fp)
            logger.info("Loading weights from {}".format(model_fp))
        except:
            logger.info("Could not load weights from {}".format(model_fp))

    msg = "feat_dim: {}, timesteps: {}, num_classes: {}"
    logger.info(msg.format(feat_dim, timesteps, num_classes))
    model.summary()

    model_details = batcher.meta.copy()

    model_details['medaka_model_name'] = model_name
    model_details['medaka_model_kwargs'] = model_kwargs