Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"discriminator": {
"name": ACGANDiscriminator,
"args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
}
losses_list = [
MinimaxGeneratorLoss(),
MinimaxDiscriminatorLoss(),
AuxiliaryClassifierGeneratorLoss(),
AuxiliaryClassifierDiscriminatorLoss(),
]
trainer = Trainer(
network_params,
losses_list,
sample_size=1,
epochs=1,
device=torch.device("cpu"),
)
trainer(mnist_dataloader())
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
"discriminator": {
"name": ConditionalGANDiscriminator,
"args": {"num_classes": 10, "in_channels": 1, "step_channels": 4},
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
}
losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
trainer = Trainer(
network_params,
losses_list,
sample_size=1,
epochs=1,
device=torch.device("cpu"),
)
trainer(mnist_dataloader())
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
"discriminator": {
"name": DCGANDiscriminator,
"args": {"in_channels": 1, "step_channels": 4},
"optimizer": {
"name": Adam,
"args": {"lr": 0.0002, "betas": (0.5, 0.999)},
},
},
}
losses_list = [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()]
trainer = Trainer(
network_params,
losses_list,
sample_size=1,
epochs=1,
device=torch.device("cpu"),
)
trainer(mnist_dataloader())
)
if not generator.label_type == "none":
raise Exception("EBGAN PT supports models which donot require labels")
if not discriminator.embeddings:
raise Exception("EBGAN PT requires the embeddings for loss computation")
noise = torch.randn(batch_size, generator.encoding_dims, device=device)
optimizer_generator.zero_grad()
fake = generator(noise)
d_hid, dgz = discriminator(fake)
loss = self.forward(dgz, d_hid)
loss.backward()
optimizer_generator.step()
return loss.item()
class EnergyBasedDiscriminatorLoss(DiscriminatorLoss):
r"""Energy Based GAN discriminator loss from `"Energy Based Generative Adversarial Network
by Zhao et al." <https://arxiv.org/abs/1609.03126>`_ paper
The loss can be described as:
.. math:: L(D) = D(x) + max(0, m - D(G(z)))
where
- :math:`G` : Generator
- :math:`D` : Discriminator
- :math:`m` : Margin Hyperparameter
- :math:`z` : A sample from the noise prior
.. note::
The convergence of EBGAN is highly sensitive to hyperparameters. The ``margin``
device, batch_size, labels=None):
if self.override_train_ops is not None:
return self.override_train_ops(self, generator, discriminator, optimizer_generator,
real_inputs, device, labels)
else:
if isinstance(generator, AdversarialAutoEncodingGenerator):
setattr(generator, "embeddings", False)
recon, encodings = generator(real_inputs)
optimizer_generator.zero_grad()
dgz = discriminator(encodings)
loss = self.forward(real_inputs, recon, dgz)
loss.backward()
optimizer_generator.step()
return loss.item()
class AdversarialAutoEncoderDiscriminatorLoss(DiscriminatorLoss):
def forward(self, dx, dgz):
    r"""Computes the minimax discriminator loss for the adversarial autoencoder.

    Args:
        dx (torch.Tensor): Output of the Discriminator for samples drawn from the prior
            (in ``train_ops`` these are the ``noise`` tensors).
        dgz (torch.Tensor): Output of the Discriminator for the encodings produced by
            the generator (encoder).

    Returns:
        The value returned by ``minimax_discriminator_loss``, reduced according to
        ``self.reduction``.
    """
    # Pass the configured reduction through, matching the other losses in this module
    # (which forward ``self.reduction``); without it the user's ``reduction`` setting
    # would be silently ignored and the helper's own choice would always apply.
    return minimax_discriminator_loss(dx, dgz, reduction=self.reduction)
def train_ops(self, generator, discriminator, optimizer_discriminator, real_inputs,
device, batch_size, labels=None):
if self.override_train_ops is not None:
return self.override_train_ops(self, generator, discriminator, optimizer_discriminator,
real_inputs, device, labels)
else:
if isinstance(generator, AdversarialAutoEncodingGenerator):
setattr(generator, "embeddings", True)
encodings = generator(real_inputs).detach()
noise = torch.randn(real_inputs.size(0), generator.encoding_dims, device=device)
optimizer_discriminator.zero_grad()
dx = discriminator(noise)
dgz = discriminator(encodings)
def forward(self, dgz):
    r"""Evaluate the minimax generator loss on discriminator outputs.

    Args:
        dgz (torch.Tensor) : Discriminator output for generated samples. Expected
            shape is (N, \*), where \* denotes any number of trailing dimensions.

    Returns:
        A scalar when a reduction is applied, otherwise a Tensor of shape (N, \*).
    """
    # Both settings are read from the instance and handed straight to the functional form.
    nonsaturating, reduction = self.nonsaturating, self.reduction
    return minimax_generator_loss(dgz, nonsaturating, reduction)
class MinimaxDiscriminatorLoss(DiscriminatorLoss):
r"""Minimax game discriminator loss from the original GAN paper `"Generative Adversarial Networks
by Goodfellow et al." <https://arxiv.org/abs/1406.2661>`_
The loss can be described as:
.. math:: L(D) = -[log(D(x)) + log(1 - D(G(z)))]
where
- :math:`G` : Generator
- :math:`D` : Discriminator
- :math:`x` : A sample from the data distribution
- :math:`z` : A sample from the noise prior
Args:
label_smoothing (float, optional): The factor by which the labels (1 in this case) needs
``train_iter_custom``.
.. warning::
This function is needed in this exact state for the Trainer to work correctly. So it is
highly recommended that this function is not changed even if the ``Trainer`` is subclassed.
Returns:
An NTuple of the ``generator loss``, ``discriminator loss``, ``number of times the generator
was trained`` and the ``number of times the discriminator was trained``.
"""
self.train_iter_custom()
ldis, lgen, dis_iter, gen_iter = 0.0, 0.0, 0, 0
loss_logs = self.logger.get_loss_viz()
grad_logs = self.logger.get_grad_viz()
for name, loss in self.losses.items():
if isinstance(loss, GeneratorLoss) and isinstance(loss, DiscriminatorLoss):
# NOTE(avik-pal): In most cases this loss is meant to optimize the Discriminator
# but we might need to think of a better solution
if self.loss_information["generator_iters"] % self.ngen == 0:
cur_loss = loss.train_ops(
**self._get_arguments(self.loss_arg_maps[name])
)
loss_logs.logs[name].append(cur_loss)
if type(cur_loss) is tuple:
lgen, ldis, gen_iter, dis_iter = (
lgen + cur_loss[0],
ldis + cur_loss[1],
gen_iter + 1,
dis_iter + 1,
)
else:
# NOTE(avik-pal): We assume that it is a Discriminator Loss by default.
import torch
from .functional import mutual_information_penalty
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = ["MutualInformationPenalty"]
class MutualInformationPenalty(GeneratorLoss, DiscriminatorLoss):
r"""Mutual Information Penalty as defined in
`"InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
by Chen et al." <https://arxiv.org/abs/1606.03657>`_ paper
The loss is the variational lower bound of the mutual information between
the latent codes and the generator distribution and is defined as
.. math:: L(G,Q) = \log Q(c|x)
where
- :math:`x` is drawn from the generator distribution G(z,c)
- :math:`c` drawn from the latent code prior :math:`P(c)`
Args:
lambd (float, optional): The scaling factor for the loss.
label_gen = torch.randint(
0, generator.num_classes, (batch_size,), device=device
)
fake = generator(noise, label_gen)
cgz = discriminator(fake, mode="classifier")
if generator.label_type == "required":
loss = self.forward(cgz, labels)
else:
label_gen = label_gen.type(torch.LongTensor).to(device)
loss = self.forward(cgz, label_gen)
loss.backward()
optimizer_generator.step()
return loss.item()
class AuxiliaryClassifierDiscriminatorLoss(DiscriminatorLoss):
r"""Auxiliary Classifier GAN (ACGAN) discriminator loss based on the
`"Conditional Image Synthesis With Auxiliary Classifier GANs
by Odena et al." <https://arxiv.org/abs/1610.09585>`_ paper
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
If ``sum`` the elements of the output are summed.
override_train_ops (function, optional): A function is passed to this argument,
if the default ``train_ops`` is not to be used.
"""
def forward(self, logits, labels):
    r"""Compute the auxiliary classification loss used by the ACGAN discriminator.

    Args:
        logits (torch.Tensor): Class predictions from the discriminator.
        labels (torch.Tensor): Target class labels for those predictions.

    Returns:
        The output of ``auxiliary_classification_loss`` under ``self.reduction``.
    """
    reduction = self.reduction
    return auxiliary_classification_loss(logits, labels, reduction)
def train_ops(
# NOTE(avik-pal): We assume that it is a Discriminator Loss by default.
ldis, dis_iter = ldis + cur_loss, dis_iter + 1
for model_name in self.model_names:
grad_logs.update_grads(model_name, getattr(self, model_name))
elif isinstance(loss, GeneratorLoss):
if self.loss_information["discriminator_iters"] % self.ncritic == 0:
cur_loss = loss.train_ops(
**self._get_arguments(self.loss_arg_maps[name])
)
loss_logs.logs[name].append(cur_loss)
lgen, gen_iter = lgen + cur_loss, gen_iter + 1
for model_name in self.model_names:
model = getattr(self, model_name)
if isinstance(model, Generator):
grad_logs.update_grads(model_name, model)
elif isinstance(loss, DiscriminatorLoss):
if self.loss_information["generator_iters"] % self.ngen == 0:
cur_loss = loss.train_ops(
**self._get_arguments(self.loss_arg_maps[name])
)
loss_logs.logs[name].append(cur_loss)
ldis, dis_iter = ldis + cur_loss, dis_iter + 1
for model_name in self.model_names:
model = getattr(self, model_name)
if isinstance(model, Discriminator):
grad_logs.update_grads(model_name, model)
return lgen, ldis, gen_iter, dis_iter