    args.chunk_len, args.chunk_ovlp,  # ignored when chunking is disabled
    batch_size=args.batch_size, save_features=args.save_features,
    enable_chunking=False
)
if len(new_remainders) > 0:
    # with chunking disabled, no remainder regions are expected here
    ignored = [x[0] for x in new_remainders]
    n_ignored = len(ignored)
    logger.warning("{} regions were not processed: {}.".format(
        n_ignored, ignored))
logger.info("Finished processing all regions.")

if args.check_output:
    logger.info("Validating and finalising output data.")
    # re-opening the output in append mode and immediately closing it
    # triggers the DataStore's verify-on-close checks (the 'w' open used
    # for writing passes verify_on_close=False to skip them until now)
    with medaka.datastore.DataStore(args.output, 'a') as ds:
        pass
# Quieten TensorFlow's C++ logging unless debugging; TF_CPP_MIN_LOG_LEVEL
# must be set before tensorflow is first imported to take effect.
logger_level = logging.getLogger(__package__).level
if logger_level > logging.DEBUG:
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras import backend as K
args.regions = medaka.common.get_regions(args.bam, region_strs=args.regions)
logger = medaka.common.get_named_logger('Predict')
logger.info('Processing region(s): {}'.format(
    ' '.join(str(r) for r in args.regions)))

# copy model metadata (feature encoder, label scheme / class names) to the
# output so downstream decoding matches the trained network
with medaka.datastore.DataStore(args.model) as ds:
    meta = ds.meta
with medaka.datastore.DataStore(args.output, 'w', verify_on_close=False) as ds:
    ds.update_meta(meta)
logger.info("Setting tensorflow threads to {}.".format(args.threads))
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
K.set_session(tf.Session(
config=tf.ConfigProto(
intra_op_parallelism_threads=args.threads,
inter_op_parallelism_threads=args.threads)
))
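# Aside: the session-based setup above is TF1-style; under TensorFlow 2's
# eager mode the native equivalent (set before any ops execute) would be:
#     tf.config.threading.set_intra_op_parallelism_threads(args.threads)
#     tf.config.threading.set_inter_op_parallelism_threads(args.threads)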
# Split overly long regions to a maximum size so as to not create
# massive feature matrices.
MAX_REGION_SIZE = int(1e6)  # 1Mb
regions = []
for region in args.regions:
    if region.size > MAX_REGION_SIZE:
        # reconstructed: Region.split yields sub-regions of bounded size
        regs = region.split(MAX_REGION_SIZE, args.chunk_ovlp)
    else:
        regs = [region]
    regions.extend(regs)
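# For illustration: with MAX_REGION_SIZE = 1Mb, a 2.5Mb region is replaced
# by three sub-regions of at most 1Mb each, adjacent pieces overlapping by
# args.chunk_ovlp (coordinates illustrative).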
"""Run training."""
from tensorflow.keras.callbacks import (
    CSVLogger, TensorBoard, EarlyStopping, ReduceLROnPlateau)
from tensorflow.keras import optimizers
from medaka.keras_ext import ModelMetaCheckpoint, SequenceBatcher
logger = medaka.common.get_named_logger('RunTraining')
if model_fp is None:
    model_name = medaka.models.default_model
    # collect the builder's default keyword arguments (uses the standard
    # library's `inspect` module)
    model_kwargs = {
        k: v.default for (k, v) in
        inspect.signature(
            medaka.models.model_builders[model_name]).parameters.items()
        if v.default is not inspect.Parameter.empty
    }
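    # Illustration (hypothetical builder, not medaka's actual signature):
    #     def build_model(timesteps, feat_dim, num_classes,
    #                     gru_size=128, dropout=0.2): ...
    # gives model_kwargs == {'gru_size': 128, 'dropout': 0.2}: only
    # parameters declaring a default are collected.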
else:
    with medaka.datastore.DataStore(model_fp) as ds:
        model_name = ds.meta['medaka_model_name']
        model_kwargs = ds.meta['medaka_model_kwargs']

opt_str = '\n'.join(['{}: {}'.format(k, v) for k, v in model_kwargs.items()])
logger.info('Building {} model with:\n{}'.format(model_name, opt_str))
num_classes = len(batcher.label_scheme.label_decoding)
timesteps, feat_dim = batcher.feature_shape
model = medaka.models.model_builders[model_name](
    timesteps, feat_dim, num_classes, **model_kwargs)
if model_fp is not None:
    try:
        model.load_weights(model_fp)
        logger.info("Loaded weights from {}".format(model_fp))
    except Exception:
        logger.warning("Could not load weights from {}".format(model_fp))
self.logger = medaka.common.get_named_logger('TrainBatcher')
self.features = features
self.validation = validation
self.seed = seed
self.batch_size = batch_size
di = medaka.datastore.DataIndex(self.features, threads=threads)
self.samples = di.samples.copy()
self.label_scheme = di.metadata['label_scheme']
self.feature_encoder = di.metadata['feature_encoder']
# infer the feature shape from the first sample on disk
test_sample, test_fname = self.samples[0]
with medaka.datastore.DataStore(test_fname) as ds:
    # TODO: this should come from feature_encoder
    self.feature_shape = ds.load_sample(test_sample).features.shape
self.logger.info(
    "Sample features have shape {}".format(self.feature_shape))
if isinstance(self.validation, float):
    # hold out a random fraction of samples for validation
    np.random.seed(self.seed)
    np.random.shuffle(self.samples)
    n_sample_train = int((1 - self.validation) * len(self.samples))
    self.train_samples = self.samples[:n_sample_train]
    self.valid_samples = self.samples[n_sample_train:]
    self.logger.info(
        'Randomly selected {} ({:3.2%}) of features for '
        'validation (seed {})'.format(
            len(self.valid_samples), self.validation, self.seed))
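    # e.g. validation=0.2 over 1000 samples gives 800 training and 200
    # validation samples; fixing `seed` makes the shuffle, and hence the
    # split, reproducible across runs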
else:
    all_samples = self.index
    if regions is None:
        regions = [
            medaka.common.Region.from_string(x)
            for x in sorted(all_samples)]
    for reg in regions:
        if reg.ref_name not in self.index:
            continue
        for sample in self.index[reg.ref_name]:
            # sample coords may be 'major.minor' strings; truncate to int
            # and add one to make the end coordinate exclusive
            sam_reg = medaka.common.Region(
                sample['ref_name'],
                int(float(sample['start'])),
                int(float(sample['end'])) + 1)
            if sam_reg.overlaps(reg):
                with DataStore(sample['filename']) as store:
                    yield store.load_sample(sample['sample_key'])
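# Usage sketch (names assumed: this generator is taken to be a method of the
# index object built over feature files, e.g. medaka.datastore.DataIndex):
#     reg = medaka.common.Region.from_string('chr1:0-100000')
#     for sample in index.yield_from_feature_files(regions=[reg]):
#         process(sample)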