Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
# Set up evaluation of a pretrained ConvTasNet on the LibriMix test set.
# NOTE(review): this copy of the snippet has lost all indentation and stops
# right after the `for` header below (the loop body is missing), so it is not
# runnable as-is — restore it from the upstream recipe before use.
# conf keys read here: 'exp_dir', 'use_gpu', 'test_dir', 'task', 'sample_rate',
# 'train_conf' -> 'data' -> 'n_src', 'out_dir', 'n_save_ex'.
def main(conf):
# Load the weights stored as best_model.pth inside the experiment directory.
model_path = os.path.join(conf['exp_dir'], 'best_model.pth')
model = ConvTasNet.from_pretrained(model_path)
# Handle device placement
if conf['use_gpu']:
model.cuda()
# Device of the model's parameters; later inputs should be moved here.
# (Presumably this line runs regardless of use_gpu — indentation lost, confirm.)
model_device = next(model.parameters()).device
test_set = LibriMix(csv_dir=conf['test_dir'],
task=conf['task'],
sample_rate=conf['sample_rate'],
n_src=conf['train_conf']['data']['n_src'],
segment=None) # Uses all segment length
# Used to reorder sources only
loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
# Output directories: <exp_dir>/<out_dir> for results, plus an examples/ subdir.
eval_save_dir = os.path.join(conf['exp_dir'], conf['out_dir'])
ex_save_dir = os.path.join(eval_save_dir, 'examples/')
# Randomly choose the indexes of sentences to save; n_save_ex == -1 means all.
if conf['n_save_ex'] == -1:
conf['n_save_ex'] = len(test_set)
# NOTE(review): random.sample raises ValueError if n_save_ex exceeds
# len(test_set) (or is negative other than -1); consider clamping.
save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
series_list = []
# NOTE(review): entering no_grad without a matching __exit__ leaves gradient
# tracking disabled for the rest of the thread; `with torch.no_grad():`
# around the loop would scope it explicitly.
torch.no_grad().__enter__()
for idx in tqdm(range(len(test_set))):
# Fragment of a second evaluation snippet (its `def` line is missing): load a
# raw checkpoint, drop loss-related entries, then rebuild the test pipeline.
# NOTE(review): `best_model_path`, `model` and `load_state_dict_in` are not
# defined anywhere in the visible text — they must come from elsewhere; verify.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer PyTorch versions offer weights_only=True for this).
checkpoint = torch.load(best_model_path, map_location='cpu')
state = checkpoint['state_dict']
# Iterate the original dict while deleting from a shallow copy, so the dict
# being iterated is never mutated.
state_copy = state.copy()
# Remove unwanted keys
# (every key starting with 'loss'; `values` is unused)
for keys, values in state.items():
if keys.startswith('loss'):
del state_copy[keys]
# NOTE(review): indentation lost — unclear whether this prints only removed
# keys or every key; confirm against the original source.
print(keys)
model = load_state_dict_in(state_copy, model)
# Handle device placement
if conf['use_gpu']:
model.cuda()
model_device = next(model.parameters()).device
# NOTE(review): positional call with None as the second argument, unlike the
# keyword form used elsewhere; presumably targets an older LibriMix signature.
test_set = LibriMix(conf['test_dir'], None,
conf['sample_rate'],
conf['train_conf']['data']['n_src'])
# NOTE(review): `mode='pairwise'` vs `pit_from='pw_mtx'` in the other snippet —
# likely different PITLossWrapper API versions; confirm which one is installed.
loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
# Directory for saved examples (hard-coded name, not taken from conf['out_dir']).
ex_save_dir = os.path.join(conf['exp_dir'], 'examples_mss_8K/')
# Randomly choose the indexes of sentences to save; n_save_ex == -1 means all.
if conf['n_save_ex'] == -1:
conf['n_save_ex'] = len(test_set)
save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
series_list = []
# NOTE(review): no_grad is entered but never exited; prefer a `with` block.
torch.no_grad().__enter__()
# NOTE(review): the loop body below is truncated after the first statement.
for idx in tqdm(range(len(test_set))):
# Forward the network on the mixture.
mix, sources = tensors_to_device(test_set[idx], device=model_device)
# Build LibriMix train/validation datasets and their DataLoaders from conf.
# NOTE(review): this copy has lost its indentation and is truncated inside the
# val_loader DataLoader(...) call below; it is not runnable as-is.
# conf keys read: 'data' -> metadata_train_path, metadata_val_path,
# desired_length, sample_rate, n_src; 'training' -> batch_size, num_workers.
def main(conf):
train_set = LibriMix(conf['data']['metadata_train_path'],
conf['data']['desired_length'],
conf['data']['sample_rate'],
conf['data']['n_src'])
val_set = LibriMix(conf['data']['metadata_val_path'],
conf['data']['desired_length'],
conf['data']['sample_rate'],
conf['data']['n_src'])
# drop_last=True discards the final incomplete batch each epoch.
train_loader = DataLoader(train_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
# NOTE(review): shuffling validation data is usually unnecessary; confirm
# shuffle=True here is intended.
val_loader = DataLoader(val_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
# NOTE(review): call is unterminated in this copy — closing arguments and ')' missing.
def main(conf):
train_set = LibriMix(conf['data']['metadata_train_path'],
conf['data']['desired_length'],
conf['data']['sample_rate'],
conf['data']['n_src'])
val_set = LibriMix(conf['data']['metadata_val_path'],
conf['data']['desired_length'],
conf['data']['sample_rate'],
conf['data']['n_src'])
train_loader = DataLoader(train_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
val_loader = DataLoader(val_set, shuffle=True,
batch_size=conf['training']['batch_size'],
num_workers=conf['training']['num_workers'],
drop_last=True)
conf['masknet'].update({'n_src': 2})
# Define model and optimizer in a local function (defined in the recipe).