import torch
import torch.autograd as autograd
import torch.nn as nn
from ..registry import LOSSES
@LOSSES.register_module()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported types.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Pick the criterion once; forward() dispatches on gan_type for hinge.
        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            # Only the ReLU is stored; the hinge terms are built in forward().
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss, ``-mean(input)`` when ``target`` is truthy,
            otherwise ``mean(input)``.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return a Tensor shaped like ``input`` filled with the real or
                fake label value.
        """
        if self.gan_type == 'wgan':
            # wgan needs only the real/fake flag, not a label tensor.
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_full(input.size(), target_val)

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss for a network prediction.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value. Scaled by ``loss_weight`` only when
                ``is_disc`` is False.
        """
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                # relu(1 - input) for real targets, relu(1 + input) for fake.
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan; target_is_real is unused
                loss = -input.mean()
        else:  # other gan types
            # Label tensors are only built for the label-based losses; hinge
            # never reads them, so skip that allocation there.
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
def gradient_penalty_loss(discriminator, real_data, fake_data, mask=None):
    """Calculate gradient penalty for wgan-gp.

    Points are sampled on straight lines between ``real_data`` and
    ``fake_data`` (one random mixing coefficient per sample). The
    discriminator is evaluated there, and the squared deviation from 1 of the
    gradient norm (taken over dim 1) is averaged.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data, batch in dim 0.
        fake_data (Tensor): Fake input data, same shape as ``real_data``.
        mask (Tensor): Masks for inpainting. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """
    batch_size = real_data.size(0)
    # One mixing coefficient per sample, shaped to broadcast over every
    # non-batch dim (equals (B, 1, 1, 1) for 4D input). Drawn from the default
    # generator as before, then moved to real_data's dtype/device. This avoids
    # new_tensor()'s copy-construct warning.
    alpha_shape = (batch_size, ) + (1, ) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape).to(real_data)

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # Fresh leaf (detached from real/fake graphs) whose input-gradient we query;
    # replaces the deprecated autograd.Variable(..., requires_grad=True).
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    # create_graph=True so the penalty itself can be backpropagated.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True)[0]

    if mask is not None:
        gradients = gradients * mask

    # NOTE: masked-out positions have zero gradient, so each contributes a
    # constant (0 - 1)**2 = 1 to the mean. That term has no gradient.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if mask is not None:
        gradients_penalty = gradients_penalty / torch.mean(mask)

    return gradients_penalty
@LOSSES.register_module()
class GradientPenaltyLoss(nn.Module):
    """Weighted module wrapper around ``gradient_penalty_loss`` (wgan-gp).

    Args:
        loss_weight (float): Factor the penalty is multiplied by.
            Default: 1.0.
    """

    def __init__(self, loss_weight=1.):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, discriminator, real_data, fake_data, mask=None):
        """Compute the weighted gradient penalty.

        Args:
            discriminator (nn.Module): Network for the discriminator.
            real_data (Tensor): Real input data.
            fake_data (Tensor): Fake input data.
            mask (Tensor): Masks for inpainting. Default: None.

        Returns:
            Tensor: ``gradient_penalty_loss(...) * loss_weight``.
        """
        penalty = gradient_penalty_loss(
            discriminator,
            real_data,
            fake_data,
            mask=mask,
        )
        weighted_penalty = penalty * self.loss_weight
        return weighted_penalty
@LOSSES.register_module()
class DiscShiftLoss(nn.Module):
    """Disc shift loss: penalizes the mean squared value of its input.

    Args:
        loss_weight (float, optional): Loss weight. Defaults to 0.1.
    """

    def __init__(self, loss_weight=0.1):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (n, c, h, w). The mean is taken
                over all elements, so other shapes are accepted as well.

        Returns:
            Tensor: Scalar loss, ``mean(x ** 2) * loss_weight``.
        """
        loss = torch.mean(x**2)
        return loss * self.loss_weight