Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
exp_dir(str): Experiment directory. Expects to find
`'best_k_models.json'` there.
Returns:
nn.Module the best pretrained model according to the val_loss.
"""
# Create the model from recipe-local function
model, _ = make_model_and_optimizer(train_conf)
# Last best model summary
with open(os.path.join(exp_dir, 'best_k_models.json'), "r") as f:
best_k = json.load(f)
best_model_path = min(best_k, key=best_k.get)
# Load checkpoint
checkpoint = torch.load(best_model_path, map_location='cpu')
# Load state_dict into model.
model = torch_utils.load_state_dict_in(checkpoint['state_dict'],
model)
model.eval()
return model
# Build the model from the training config; the second value returned by
# make_model_and_optimizer (presumably the optimizer) is not needed here.
model, _ = make_model_and_optimizer(conf['train_conf'])
# Pick the best checkpoint: best_k_models.json maps checkpoint paths to
# scores, and the path with the *lowest* score is selected (presumably the
# validation loss -- confirm against the training script).
with open(os.path.join(conf['exp_dir'], 'best_k_models.json'), "r") as f:
best_k = json.load(f)
best_model_path = min(best_k, key=best_k.get)
# Load the checkpoint with every tensor mapped to CPU; GPU placement is
# decided further down.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
checkpoint = torch.load(best_model_path, map_location='cpu')
state = checkpoint['state_dict']
state_copy = state.copy()
# Drop every entry whose key starts with 'loss', printing each dropped key.
# The loop iterates over `state` but deletes from the shallow copy, so the
# dict being iterated is never mutated. Presumably these entries belong to
# a loss module saved with the training system that the bare model does not
# have -- confirm. `values` is unused.
for keys, values in state.items():
if keys.startswith('loss'):
del state_copy[keys]
print(keys)
# NOTE(review): called here as a bare name, while the other loaders in this
# file call torch_utils.load_state_dict_in -- confirm both resolve to the
# same helper.
model = load_state_dict_in(state_copy, model)
# Move the model to the GPU if requested, then record the device of its
# first parameter (presumably used later to place inputs on the same device).
if conf['use_gpu']:
model.cuda()
model_device = next(model.parameters()).device
# Test split of LibriMix; n_src is taken from the training data config.
# NOTE(review): the meaning of the second positional argument (None) is not
# visible here -- confirm against the LibriMix signature.
test_set = LibriMix(conf['test_dir'], None,
conf['sample_rate'],
conf['train_conf']['data']['n_src'])
# Permutation-invariant (PIT) wrapper around the pairwise negative SI-SDR.
loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
# Output directory for saved separation examples, and how many to save.
# The random choice of which examples to save is not part of this excerpt
# (presumably it follows below).
# NOTE(review): the '8K' in the directory name is hard-coded and does not
# follow conf['sample_rate'].
ex_save_dir = os.path.join(conf['exp_dir'], 'examples_mss_8K/')
# n_save_ex == -1 means "save every test example"; conf is updated in place.
if conf['n_save_ex'] == -1:
conf['n_save_ex'] = len(test_set)
# Look for checkpoint files in checkpoint_dir (defined before this excerpt)
# and load the first one found into a freshly built separator model.
# NOTE(review): model_available is only ever set to True below; this assumes
# it was initialised to False earlier. Otherwise a missing directory or an
# empty checkpoint list raises NameError instead of FileNotFoundError --
# confirm.
if os.path.exists(checkpoint_dir):
# NOTE(review): `'.ckpt' in p` is a substring test, so names such as
# 'x.ckpt.tmp' also match; p.endswith('.ckpt') would be stricter.
available_models = [p for p in os.listdir(checkpoint_dir)
if '.ckpt' in p]
if available_models:
model_available = True
if not model_available:
raise FileNotFoundError('There is no available separator model at: {}'
''.format(checkpoint_dir))
# NOTE(review): os.listdir returns entries in arbitrary order, so when
# several checkpoints exist, the one loaded here is not deterministic.
model_path = os.path.join(checkpoint_dir, available_models[0])
print('Going to load from: {}'.format(model_path))
# Checkpoint tensors are mapped to CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location='cpu')
# Build only the separator part of the model, around the pretrained
# filterbank (defined before this excerpt); the second return value is
# discarded.
model_c, _ = make_model_and_optimizer(conf, model_part='separator',
pretrained_filterbank=filterbank)
# Load the checkpoint's state_dict into the newly built separator.
model = torch_utils.load_state_dict_in(checkpoint['state_dict'], model_c)
print('Successfully loaded separator from: {}'.format(model_path))
return model