Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import torch
import torch.nn.functional as F
from .loss import GeneratorLoss, DiscriminatorLoss
from ..models import AdversarialAutoEncodingGenerator
from .functional import minimax_generator_loss, minimax_discriminator_loss
__all__ = ['AdversarialAutoEncoderGeneratorLoss', 'AdversarialAutoEncoderDiscriminatorLoss']
class AdversarialAutoEncoderGeneratorLoss(GeneratorLoss):
def __init__(self, recon_weight=0.999, gen_weight=0.001, reduction='mean', override_train_ops=None):
    r"""Set up the adversarial autoencoder generator loss.

    Args:
        recon_weight (float, optional): Weight of the reconstruction (MSE)
            term; stored as ``self.recon_weight``.
        gen_weight (float, optional): Weight of the adversarial (minimax)
            generator term; stored as ``self.gen_weight``.
        reduction (str, optional): Passed positionally, unchanged, to the
            ``GeneratorLoss`` base-class constructor.
        override_train_ops (function, optional): Passed positionally,
            unchanged, to the ``GeneratorLoss`` base-class constructor. When
            it is not ``None``, ``train_ops`` delegates to it instead of
            running the default training step.
    """
    parent = super(AdversarialAutoEncoderGeneratorLoss, self)
    parent.__init__(reduction, override_train_ops)
    # The two weights are independent, plain attribute assignments.
    self.recon_weight = recon_weight
    self.gen_weight = gen_weight
def forward(self, real_inputs, gen_outputs, dgz):
    r"""Compute the adversarial autoencoder generator loss.

    The loss is a weighted sum of a reconstruction term and an adversarial
    term:

    .. math:: L(G) = w_{recon} \cdot \mathrm{MSE}(\hat{x}, x) + w_{gen} \cdot L_{minimax}(dgz)

    where :math:`x` is ``real_inputs``, :math:`\hat{x}` is ``gen_outputs``,
    :math:`w_{recon}` is ``self.recon_weight`` and :math:`w_{gen}` is
    ``self.gen_weight``.

    Args:
        real_inputs (torch.Tensor): Real input batch used as the
            reconstruction target.
        gen_outputs (torch.Tensor): Generator outputs compared against
            ``real_inputs``; expected to have the same shape as
            ``real_inputs``.
        dgz (torch.Tensor): Discriminator scores forwarded to
            ``minimax_generator_loss`` (presumably the discriminator's output
            on generated samples/codes -- confirm against the caller).

    Returns:
        torch.Tensor: The combined generator loss.
    """
    # MSE is symmetric, so the argument order does not change the value or the
    # gradients; (prediction, target) just follows the PyTorch convention.
    # NOTE(review): this term always uses mean reduction and ignores
    # ``self.reduction``, unlike the adversarial term -- confirm intended.
    recon_loss = F.mse_loss(gen_outputs, real_inputs)
    adv_loss = minimax_generator_loss(dgz, reduction=self.reduction)
    # gen_weight must multiply the adversarial term. Adding it would leave that
    # term unweighted and merely shift the total loss by a constant.
    return self.recon_weight * recon_loss + self.gen_weight * adv_loss
def train_ops(self, generator, discriminator, optimizer_generator, real_inputs,
device, batch_size, labels=None):
if self.override_train_ops is not None:
return self.override_train_ops(self, generator, discriminator, optimizer_generator,
real_inputs, device, labels)
else:
if isinstance(generator, AdversarialAutoEncodingGenerator):
import torch
from .functional import least_squares_discriminator_loss, least_squares_generator_loss
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = ["LeastSquaresGeneratorLoss", "LeastSquaresDiscriminatorLoss"]
class LeastSquaresGeneratorLoss(GeneratorLoss):
r"""Least Squares GAN generator loss from `"Least Squares Generative Adversarial Networks
by Mao et al." <https://arxiv.org/abs/1611.04076>`_ paper
The loss can be described as
.. math:: L(G) = \frac{(D(G(z)) - c)^2}{2}
where
- :math:`G` : Generator
- :math:`D` : Discriminator
- :math:`c` : target generator label
- :math:`z` : A sample from the noise prior
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
**self._get_arguments(self.loss_arg_maps[name])
)
loss_logs.logs[name].append(cur_loss)
if type(cur_loss) is tuple:
lgen, ldis, gen_iter, dis_iter = (
lgen + cur_loss[0],
ldis + cur_loss[1],
gen_iter + 1,
dis_iter + 1,
)
else:
# NOTE(avik-pal): We assume that it is a Discriminator Loss by default.
ldis, dis_iter = ldis + cur_loss, dis_iter + 1
for model_name in self.model_names:
grad_logs.update_grads(model_name, getattr(self, model_name))
elif isinstance(loss, GeneratorLoss):
if self.loss_information["discriminator_iters"] % self.ncritic == 0:
cur_loss = loss.train_ops(
**self._get_arguments(self.loss_arg_maps[name])
)
loss_logs.logs[name].append(cur_loss)
lgen, gen_iter = lgen + cur_loss, gen_iter + 1
for model_name in self.model_names:
model = getattr(self, model_name)
if isinstance(model, Generator):
grad_logs.update_grads(model_name, model)
elif isinstance(loss, DiscriminatorLoss):
if self.loss_information["generator_iters"] % self.ngen == 0:
cur_loss = loss.train_ops(
**self._get_arguments(self.loss_arg_maps[name])
)
loss_logs.logs[name].append(cur_loss)
from .functional import (
wasserstein_discriminator_loss,
wasserstein_generator_loss,
wasserstein_gradient_penalty,
)
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = [
"WassersteinGeneratorLoss",
"WassersteinDiscriminatorLoss",
"WassersteinGradientPenalty",
]
class WassersteinGeneratorLoss(GeneratorLoss):
r"""Wasserstein GAN generator loss from
`"Wasserstein GAN by Arjovsky et al." <https://arxiv.org/abs/1701.07875>`_ paper
The loss can be described as:
.. math:: L(G) = -f(G(z))
where
- :math:`G` : Generator
- :math:`f` : Critic/Discriminator
- :math:`z` : A sample from the noise prior
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
If ``none`` no reduction will be applied. If ``mean`` the mean of the output.
import torch
from ..utils import reduce
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = ["HistoricalAverageGeneratorLoss", "HistoricalAverageDiscriminatorLoss"]
class HistoricalAverageGeneratorLoss(GeneratorLoss):
r"""Historical Average Generator Loss from
`"Improved Techniques for Training GANs
by Salimans et al." <https://arxiv.org/abs/1606.03498>`_ paper
The loss can be described as
.. math:: || \boldsymbol{\theta} - \frac{1}{t} \sum_{i=1}^t \boldsymbol{\theta}[i] ||^2
where
- :math:`G` : Generator
- :math:`\boldsymbol{\theta}` : Current Generator Parameters
- :math:`\boldsymbol{\theta}[i]` : Generator Parameters at Past Timestep :math:`i`
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
from ..models import AutoEncodingDiscriminator
from .functional import (
energy_based_discriminator_loss,
energy_based_generator_loss,
energy_based_pulling_away_term,
)
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = [
"EnergyBasedGeneratorLoss",
"EnergyBasedDiscriminatorLoss",
"EnergyBasedPullingAwayTerm",
]
class EnergyBasedGeneratorLoss(GeneratorLoss):
r"""Energy Based GAN generator loss from `"Energy Based Generative Adversarial Network
by Zhao et al." <https://arxiv.org/abs/1609.03126>`_ paper.
The loss can be described as:
.. math:: L(G) = D(G(z))
where
- :math:`G` : Generator
- :math:`D` : Discriminator
- :math:`z` : A sample from the noise prior
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
import torch
from .functional import minimax_discriminator_loss, minimax_generator_loss
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = ["MinimaxGeneratorLoss", "MinimaxDiscriminatorLoss"]
class MinimaxGeneratorLoss(GeneratorLoss):
r"""Minimax game generator loss from the original GAN paper `"Generative Adversarial Networks
by Goodfellow et al." <https://arxiv.org/abs/1406.2661>`_
The loss can be described as:
.. math:: L(G) = \log(1 - D(G(z)))
The nonsaturating heuristic is also supported:
.. math:: L(G) = -\log(D(G(z)))
where
- :math:`G` : Generator
- :math:`D` : Discriminator
- :math:`z` : A sample from the noise prior
if isinstance(discriminator, AutoEncodingDiscriminator):
setattr(discriminator, "embeddings", False)
loss = super(EnergyBasedGeneratorLoss, self).train_ops(
generator,
discriminator,
optimizer_generator,
device,
batch_size,
labels,
)
if isinstance(discriminator, AutoEncodingDiscriminator):
setattr(discriminator, "embeddings", True)
return loss
class EnergyBasedPullingAwayTerm(GeneratorLoss):
r"""Energy Based Pulling Away Term from `"Energy Based Generative Adversarial Network
by Zhao et al." <https://arxiv.org/abs/1609.03126>`_ paper.
The loss can be described as:
.. math:: f_{PT}(S) = \frac{1}{N(N-1)}\sum_i\sum_{j \neq i}\bigg(\frac{S_i^T S_j}{||S_i||\ ||S_j||}\bigg)^2
where
- :math:`S` : The feature output from the encoder for generated images
- :math:`N` : Batch Size of the Input
Args:
pt_ratio (float, optional): The weight given to the pulling away term.
override_train_ops (function, optional): A function is passed to this argument,
if the default ``train_ops`` is not to be used.
import torch
import torch.nn.functional as F
from ..utils import reduce
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = ["FeatureMatchingGeneratorLoss"]
class FeatureMatchingGeneratorLoss(GeneratorLoss):
r"""Feature Matching Generator loss from
`"Improved Techniques for Training GANs by Salimans et al." <https://arxiv.org/abs/1606.03498>`_ paper
The loss can be described as:
.. math:: L(G) = ||f(x)-f(G(z))||_2
where
- :math:`G` : Generator
- :math:`f` : An intermediate activation from the discriminator
- :math:`z` : A sample from the noise prior
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
import torch
from .functional import (
boundary_equilibrium_discriminator_loss,
boundary_equilibrium_generator_loss,
)
from .loss import DiscriminatorLoss, GeneratorLoss
__all__ = ["BoundaryEquilibriumGeneratorLoss", "BoundaryEquilibriumDiscriminatorLoss"]
class BoundaryEquilibriumGeneratorLoss(GeneratorLoss):
r"""Boundary Equilibrium GAN generator loss from
`"BEGAN : Boundary Equilibrium Generative Adversarial Networks
by Berthelot et al." <https://arxiv.org/abs/1703.10717>`_ paper
The loss can be described as
.. math:: L(G) = D(G(z))
where
- :math:`G` : Generator
- :math:`D` : Discriminator
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.