from mmcv.runner import auto_fp16
from ..builder import build_backbone, build_component, build_loss
from ..common import set_requires_grad
from ..registry import MODELS
from .basic_restorer import BasicRestorer
@MODELS.register_module()
class SRGAN(BasicRestorer):
    """SRGAN model for single image super-resolution.

    Ref:
    Photo-Realistic Single Image Super-Resolution Using a Generative
    Adversarial Network.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
        gan_loss (dict): Config for the gan loss.
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        perceptual_loss (dict): Config for the perceptual loss. Default: None.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 perceptual_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # NOTE: intentionally bypasses BasicRestorer.__init__ and calls the
        # grandparent's __init__ instead, since this class builds its own
        # generator and losses below — presumably to avoid BasicRestorer
        # constructing them a second time; confirm against BasicRestorer.
        super(BasicRestorer, self).__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # generator
        self.generator = build_backbone(generator)
        # discriminator (optional; required only for adversarial training)
        self.discriminator = build_component(
            discriminator) if discriminator else None

        # support fp16
        self.fp16_enabled = False

        # losses (each optional; skipped when its config is None)
        self.gan_loss = build_loss(gan_loss) if gan_loss else None
        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
        self.perceptual_loss = build_loss(
            perceptual_loss) if perceptual_loss else None

        # `disc_steps`: update the generator once every `disc_steps` steps.
        # `disc_init_steps`: during the first `disc_init_steps` steps only the
        # discriminator is updated. Both are useful for WGAN-style training.
        self.disc_steps = 1 if self.train_cfg is None else self.train_cfg.get(
            'disc_steps', 1)
        self.disc_init_steps = (0 if self.train_cfg is None else
                                self.train_cfg.get('disc_init_steps', 0))

        self.step_counter = 0  # counting training steps

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained=pretrained)
        if self.discriminator:
            self.discriminator.init_weights(pretrained=pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Training is driven through ``train_step``; calling ``forward`` with
        ``test_mode=False`` is therefore an error.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments.
        """
        if not test_mode:
            raise ValueError(
                'SRGAN model does not support `forward_train` function.')
        return self.forward_test(lq, gt, **kwargs)

    def train_step(self, data_batch, optimizer):
        """Train step.

        One call performs (at most) one generator update followed by one
        discriminator update (real and fake passes backpropagated
        separately).

        Args:
            data_batch (dict): A batch of data, with keys 'lq' and 'gt'.
            optimizer (dict): Dict of optimizers keyed by 'generator' and
                'discriminator'.

        Returns:
            dict: Returned output with keys 'log_vars', 'num_samples' and
                'results'.
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters while training the generator
        set_requires_grad(self.discriminator, False)

        # The generator is updated only every `disc_steps` iterations, and
        # never during the discriminator warm-up (`disc_init_steps`).
        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style
            # gan loss for generator
            fake_g_pred = self.discriminator(fake_g_output)
            losses['loss_gan'] = self.gan_loss(
                fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        set_requires_grad(self.discriminator, True)
        # real: backward immediately so the two passes do not share one graph
        real_d_pred = self.discriminator(gt)
        loss_d_real = self.gan_loss(
            real_d_pred, target_is_real=True, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
        optimizer['discriminator'].zero_grad()
        loss_d.backward()
        log_vars.update(log_vars_d)
        # fake: detach so gradients do not flow back into the generator
        fake_d_pred = self.discriminator(fake_g_output.detach())
        loss_d_fake = self.gan_loss(
            fake_d_pred, target_is_real=False, is_disc=True)
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
        loss_d.backward()
        log_vars.update(log_vars_d)
        optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs