Source code for mmedit.models.restorers.srgan

from mmcv.runner import auto_fp16

from ..builder import build_backbone, build_component, build_loss
from ..common import set_requires_grad
from ..registry import MODELS
from .basic_restorer import BasicRestorer


@MODELS.register_module()
class SRGAN(BasicRestorer):
    """SRGAN model for single image super-resolution.

    Ref:
    Photo-Realistic Single Image Super-Resolution Using a Generative
    Adversarial Network.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
        gan_loss (dict): Config for the gan loss. Default: None.
            Note that the loss weight in gan loss is only for the generator.
            When it is not given, ``train_step`` skips the adversarial term
            of the generator loss and the whole discriminator update.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        perceptual_loss (dict): Config for the perceptual loss.
            Default: None.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 perceptual_loss=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # The two-argument ``super`` starts the MRO lookup *after*
        # ``BasicRestorer``, so ``BasicRestorer.__init__`` is bypassed.
        # NOTE(review): presumably intentional because every sub-module is
        # built right here instead -- confirm against BasicRestorer.
        super(BasicRestorer, self).__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # generator
        self.generator = build_backbone(generator)
        # discriminator
        self.discriminator = build_component(
            discriminator) if discriminator else None

        # support fp16
        self.fp16_enabled = False

        # loss
        self.gan_loss = build_loss(gan_loss) if gan_loss else None
        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
        self.perceptual_loss = build_loss(
            perceptual_loss) if perceptual_loss else None

        # Update schedule (see ``train_step``): the generator is optimized
        # only on steps where ``step_counter % disc_steps == 0`` and
        # ``step_counter >= disc_init_steps``.
        self.disc_steps = 1 if self.train_cfg is None else self.train_cfg.get(
            'disc_steps', 1)
        self.disc_init_steps = (0 if self.train_cfg is None else
                                self.train_cfg.get('disc_init_steps', 0))
        self.step_counter = 0  # counting training steps

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        The same ``pretrained`` value is forwarded to the generator and, when
        one was built, to the discriminator.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained=pretrained)
        if self.discriminator:
            self.discriminator.init_weights(pretrained=pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Only the test path is available through ``forward``; training is
        driven by ``train_step``.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments, passed on to ``forward_test``.

        Returns:
            The value returned by ``self.forward_test(lq, gt, **kwargs)``.

        Raises:
            ValueError: If ``test_mode`` is False.
        """
        if not test_mode:
            raise ValueError(
                'SRGAN model does not support `forward_train` function.')

        return self.forward_test(lq, gt, **kwargs)

    def train_step(self, data_batch, optimizer):
        """Train step.

        Runs the generator once, optionally optimizes the generator
        (depending on the ``disc_steps`` / ``disc_init_steps`` schedule) and,
        when a gan loss is configured, optimizes the discriminator on the
        real and the generated batch.

        Args:
            data_batch (dict): A batch of data. The keys ``'lq'`` and
                ``'gt'`` are read.
            optimizer (obj): Optimizer. It is indexed with ``'generator'``
                and, when a gan loss is configured, ``'discriminator'``.

        Returns:
            dict: Returned output with the keys ``log_vars`` (logged loss
                values), ``num_samples`` (``len(gt.data)``) and ``results``
                (``lq``, ``gt`` and the generator output moved to CPU).
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters.
        # The constructor allows a GAN-free setup (``gan_loss=None``), so all
        # discriminator work is guarded in the same truthiness style as the
        # other optional losses.
        if self.gan_loss:
            set_requires_grad(self.discriminator, False)

        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style
            # gan loss for generator
            if self.gan_loss:
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(
                    fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real
            real_d_pred = self.discriminator(gt)
            loss_d_real = self.gan_loss(
                real_d_pred, target_is_real=True, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)
            # fake
            # ``detach`` keeps the discriminator loss from back-propagating
            # into the generator.
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(
                fake_d_pred, target_is_real=False, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)

            # Gradients of the real and the fake pass are accumulated and
            # applied in a single optimizer step.
            optimizer['discriminator'].step()

        self.step_counter += 1

        # remove the unnecessary 'loss'; it is absent on a step where neither
        # network was optimized, hence the default.
        log_vars.pop('loss', None)
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs