Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from ddsp.synth import SynthModule
class Effects(SynthModule):
    """
    Generic base class for effects.

    The base implementation is an identity: ``forward`` unpacks its input as
    a ``(z, conditions)`` pair and hands ``z`` back unchanged.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): Effects defines no ``init_parameters`` of its own, so
        # this relies on SynthModule (not visible here) providing it.
        self.apply(self.init_parameters)

    def n_parameters(self):
        """Report how many parameters this module expects (none for the base class)."""
        return 0

    def forward(self, z):
        """Unpack ``z`` as ``(z, conditions)`` and return the first element as-is."""
        signal, _conditions = z
        return signal
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from modules import ResConv1d
from ddsp.synth import SynthModule
class Filter(SynthModule):
    """
    Generic base class for trainable signal filters.

    The base implementation is an identity: ``forward`` unpacks a
    ``(z, conditions)`` pair and returns ``z`` unchanged.
    """
    # Derive from SynthModule only, like Effects, Oscillator and Generator.
    # Those classes call ``self.apply`` with SynthModule as their sole base,
    # which indicates SynthModule already derives from ``nn.Module``
    # (NOTE(review): confirm in ddsp/synth.py). In that case a base list of
    # ``(nn.Module, SynthModule)`` places a class before its own subclass and
    # C3 linearization rejects it at import time with
    # "TypeError: Cannot create a consistent method resolution order (MRO)".

    def __init__(self):
        super(Filter, self).__init__()
        # Run the ``init_parameters`` hook through ``self.apply``.
        self.apply(self.init_parameters)

    def init_parameters(self, m):
        """Hook passed to ``self.apply`` in ``__init__``; no-op in the base class."""
        pass

    def forward(self, z):
        """Unpack ``z`` as ``(z, conditions)`` and return ``z`` unchanged."""
        z, conditions = z
        return z
"""
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
from ddsp.synth import SynthModule
class Oscillator(SynthModule):
    """
    Generic base class for oscillators.

    Both hooks are placeholders here: ``init_parameters`` does nothing and
    ``forward`` produces ``None``.
    """

    def __init__(self):
        super().__init__()
        # Run the (no-op) ``init_parameters`` hook through ``self.apply``.
        self.apply(self.init_parameters)

    def init_parameters(self, m):
        """Hook passed to ``self.apply`` in ``__init__``; intentionally empty."""
        return None

    def forward(self, z):
        """Placeholder forward pass: the base oscillator yields ``None``."""
        return None
class HarmonicOscillators(Oscillator):
    """
    Oscillator configured by a number of partials, a sample rate and a
    block size.

    Args:
        n_partial: number of partials (presumably harmonics — confirm);
            stored as ``self.n_partial``.
        sample_rate: stored as ``self.sample_rate``.
        block_size: stored as ``self.block_size``.
    """

    def __init__(self, n_partial, sample_rate, block_size):
        # Name *this* class in super() so that Oscillator.__init__ runs: it
        # calls the next initializer in the MRO and then applies
        # ``self.init_parameters``. Passing ``Oscillator`` here would skip
        # Oscillator.__init__ and force this constructor to duplicate it.
        super(HarmonicOscillators, self).__init__()
        # Keep the constructor arguments instead of silently discarding them.
        self.n_partial = n_partial
        self.sample_rate = sample_rate
        self.block_size = block_size
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from ddsp.synth import SynthModule
class Generator(SynthModule):
    """
    Generic base class for trainable signal generators.

    The base implementation is an identity: ``forward`` unpacks a
    ``(z, conditions)`` pair and returns ``z`` untouched.
    """

    def __init__(self):
        super().__init__()
        # Run the (no-op) ``init_parameters`` hook through ``self.apply``.
        self.apply(self.init_parameters)

    def init_parameters(self, m):
        """Hook passed to ``self.apply`` in ``__init__``; intentionally empty."""
        return None

    def forward(self, z):
        """Return the first element of the ``(z, conditions)`` pair ``z``."""
        signal, _conditions = z
        return signal
class FilteredNoise(Generator):
def construct_extractors(self, args):
    """
    Build ``self.extractors``, a dict of feature extractors keyed by name.

    - ``'f0'``: ``FundamentalFrequency(args.sr, args.block_size, args.sequence_size)``
    - ``'loudness'``: ``Loudness(args.block_size, args.kernel_size)``

    Both extractors are converted with ``.float()``.

    Args:
        args: configuration namespace providing ``sr``, ``block_size``,
            ``sequence_size`` and ``kernel_size``.
    """
    # Entries are built in the same order as before: 'f0' first, then 'loudness'.
    self.extractors = {
        'f0': FundamentalFrequency(args.sr, args.block_size, args.sequence_size).float(),
        'loudness': Loudness(args.block_size, args.kernel_size).float(),
    }
def __init__(self, datadir, args, transform=None, splits=None, shuffle_files=True, train='train'):
    """
    Index the files under ``datadir`` and prepare extractors, features,
    dataset statistics, splits and normalization.

    Args:
        datadir: root directory. Raw audio is listed from
            ``datadir + '/raw/*.wav'``. Features are listed from
            ``datadir + '/data/*.npy'``; they are produced by
            ``self.preprocess_dataset(datadir)`` when none exist yet.
        args: configuration namespace, stored as ``self.args``. This method
            reads ``args.scales`` and passes ``args`` to
            ``construct_extractors``.
        transform: stored as ``self.transform`` (set last, after the
            normalization statistics are computed).
        splits: proportions forwarded to ``self.create_splits``. ``None``
            (the default) means ``[.8, .1, .1]``.
        shuffle_files: forwarded to ``self.create_splits``.
        train: accepted for compatibility; not read by this method.
    """
    # Use a fresh list per call. A list literal as the default value would be
    # one object shared (and mutable) across every instance.
    if splits is None:
        splits = [.8, .1, .1]
    self.args = args
    # Metadata and raw
    self.data_files = []
    # Spectral transforms
    self.features_files = []
    # Construct set of extractors
    self.construct_extractors(args)
    # Construct the FFT extractor
    self.multi_fft = MultiscaleFFT(args.scales)
    # Retrieve list of raw audio files
    self.data_files.extend(sorted(glob.glob(datadir + '/raw/*.wav')))
    features_dir = datadir + '/data'
    features_pattern = features_dir + '/*.npy'
    # Preprocess when no feature file exists yet. A missing directory also
    # yields an empty glob, so this single test covers that case too.
    if len(glob.glob(features_pattern)) == 0:
        # exist_ok=True: the directory may already exist while holding no
        # .npy file, and a plain os.makedirs would then raise FileExistsError.
        os.makedirs(features_dir, exist_ok=True)
        self.preprocess_dataset(datadir)
    self.features_files.extend(sorted(glob.glob(features_pattern)))
    # Analyze dataset
    self.analyze_dataset()
    # Create splits
    self.create_splits(splits, shuffle_files)
    # Compute mean and std of dataset
    self.compute_normalization()
    # Now we can create the normalization / augmentation transform
    self.transform = transform
plot_batch_detailed(fixed_batch)
# A latent size of 0 is replaced by the output size
if args.latent_dims == 0:
    args.latent_dims = args.output_size
"""
###################
Model definition section
###################
"""
print('[Creating model]')
# Only these model types are handled; anything else aborts immediately.
if args.model not in ('ae', 'vae', 'wae', 'flow'):
    raise Exception('Unknown model ' + args.model)
# Encoder / decoder architectures, then the synthesizer
encoder, decoder = construct_architecture(args)
synth = construct_synth(args)
# Assemble the full model (first only AE) and send it to the target device
model = DDSSynth(encoder, decoder, synth, args).to(args.device)
"""
###################
Optimizer section
###################
"""
# Optimizer model
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Learning rate scheduler: multiply the LR by 0.5 when the monitored quantity
# (mode='min') has not improved by more than `threshold` for `patience` (20)
# consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, verbose=True, threshold=1e-7)
# Loss: only 'msstft' (MSSTFTLoss over args.scales) is supported
if args.loss != 'msstft':
    raise Exception('Unknown loss ' + args.loss)
loss = MSSTFTLoss(args.scales)
"""
###################
Training section
###################
"""
#% Monitoring quantities
# Per-epoch tensor of shape (args.epochs, 3).
# NOTE(review): the meaning of the 3 columns is not visible in this chunk; confirm against the loop body.
losses = torch.zeros(args.epochs, 3)
# Presumably the best (lowest) loss seen so far; initialized to +inf.
best_loss = np.inf
# Presumably an early-stopping counter. TODO confirm against the loop body.
early = 0
print('[Starting training]')
for i in range(args.epochs):
# Set warm-up values
# beta = beta_factor * i / warm_latent while i < warm_latent; afterwards max(...) == i, so beta == beta_factor.
# NOTE(review): if args.warm_latent == 0, the first epoch (i == 0) computes float(0) / float(0) and raises ZeroDivisionError.
# Consider max(args.warm_latent, i, 1).
args.beta = args.beta_factor * (float(i) / float(max(args.warm_latent, i)))