Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def make_model_and_optimizer(conf):
    """Build the model and its optimizer from a config dictionary.

    The main goal of this function is to make reloading for resuming
    training and for evaluation very simple: everything needed to rebuild
    the model and optimizer is derived from ``conf``.

    Args:
        conf (dict): Dictionary containing the output of hierarchical
            argparse. The keys read here are:
            ``'filterbank'`` -- keyword arguments for
            ``make_enc_dec('stft', ...)``;
            ``'main_args'`` -- must contain ``'is_complex'``;
            ``'masknet'`` -- keyword arguments for ``SimpleModel``;
            ``'optim'`` -- keyword arguments for ``make_optimizer``.

    Returns:
        tuple: ``(model, optimizer)``.

    Note:
        ``conf['masknet']`` is updated in place with the computed
        ``input_size`` and ``output_size`` before ``SimpleModel`` is built,
        so the caller's dictionary afterwards holds the sizes actually used.
    """
    # Define building blocks for local model
    stft, istft = make_enc_dec('stft', **conf['filterbank'])
    is_complex = conf['main_args']['is_complex']
    n_feats = stft.n_feats_out
    # Floor division keeps the size computation in exact integer arithmetic
    # (the previous int(x / 2) form went through a float).
    if is_complex:
        # Because we concatenate (re, im, mag) as input and compute a
        # complex mask.
        inp_size = n_feats * 3 // 2
        output_size = n_feats
    else:
        # NOTE(review): presumably a real-valued mask over half of
        # n_feats_out (e.g. magnitude only) -- confirm against Model.
        inp_size = output_size = n_feats // 2
    # Add these fields to the mask model dict (mutates the caller's conf).
    conf['masknet'].update(dict(input_size=inp_size,
                                output_size=output_size))
    masker = SimpleModel(**conf['masknet'])
    # Make the complete model
    model = Model(stft, masker, istft, is_complex=is_complex)
    # Define optimizer of this model
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    return model, optimizer