How to use the torchgan.losses.functional.energy_based_pulling_away_term function in torchgan

To help you get started, we've selected a few torchgan examples based on popular ways the function is used in public projects.


From torchgan/losses/energybased.py in the torchgan/torchgan GitHub repository:
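from torchgan.losses.functional import energy_based_pulling_away_term


# Method of the pulling-away loss class in torchgan/losses/energybased.py;
# self.pt_ratio is the weight applied to the term.
def forward(self, dgz, d_hid):
    r"""Computes the loss for the given input.

    Args:
        dgz (torch.Tensor): Output of the Discriminator with generated data. It must have the
                            dimensions (N, \*) where \* means any number of additional
                            dimensions.
        d_hid (torch.Tensor): The embeddings generated by the discriminator.

    Returns:
        scalar torch.Tensor: the pulling-away term weighted by ``self.pt_ratio``.
    """
    # Only the embeddings d_hid enter the computation; dgz is not used here.
    return self.pt_ratio * energy_based_pulling_away_term(d_hid)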
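The forward method above only weights the functional term by pt_ratio; it does not show what energy_based_pulling_away_term itself computes. The EBGAN paper (Zhao et al., 2016) defines the pulling-away term for a batch of N embeddings S_1, ..., S_N as the mean squared cosine similarity over all pairs i != j:

PT(S) = 1 / (N(N - 1)) * sum over i != j of (S_i . S_j / (||S_i|| ||S_j||))^2

The following is a minimal sketch of that definition, assuming d_hid has shape (N, *) and each sample is flattened to a vector. The names pulling_away_term_sketch and scaled_pulling_away_loss are hypothetical, and this is not torchgan's own implementation, which may differ in details such as flattening or numerical safeguards.

```python
import torch
import torch.nn.functional as F


def pulling_away_term_sketch(d_hid: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation of the EBGAN pulling-away term.

    Not torchgan's code. Assumes d_hid has shape (N, *) with N >= 2;
    each sample is flattened to a vector before comparison.
    """
    n = d_hid.size(0)
    if n < 2:
        raise ValueError("the pulling-away term needs at least two samples")
    # Flatten each embedding and scale it to unit L2 norm.
    s = F.normalize(d_hid.reshape(n, -1), p=2, dim=1)
    # Pairwise cosine similarities: cos[i, j] = s_i . s_j.
    cos = s @ s.t()
    # Keep only the off-diagonal pairs (i != j).
    off_diag = ~torch.eye(n, dtype=torch.bool, device=d_hid.device)
    # Mean of the squared similarities over the N * (N - 1) pairs.
    return cos[off_diag].pow(2).mean()


def scaled_pulling_away_loss(d_hid: torch.Tensor, pt_ratio: float) -> torch.Tensor:
    # Mirrors the forward() method above: weight the term by pt_ratio.
    return pt_ratio * pulling_away_term_sketch(d_hid)


# Mutually orthogonal embeddings versus identical embeddings.
print(pulling_away_term_sketch(torch.eye(4)))
print(pulling_away_term_sketch(torch.ones(4, 3)))
```

Orthogonal embeddings such as torch.eye(4) give 0, while identical embeddings such as torch.ones(4, 3) give approximately 1. Minimizing the weighted term therefore pushes the generator toward samples whose discriminator embeddings point in different directions, which is how EBGAN uses it to encourage sample diversity.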