Source code for mmedit.models.restorers.esrgan

import torch

from ..common import set_requires_grad
from ..registry import MODELS
from .srgan import SRGAN


@MODELS.register_module()
class ESRGAN(SRGAN):
    """Enhanced SRGAN model for single image super-resolution.

    Ref:
    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    It uses RaGAN for GAN updates:
    The relativistic discriminator: a key element missing from standard GAN.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict): Config for the discriminator. Default: None.
        gan_loss (dict): Config for the gan loss.
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict): Config for the pixel loss. Default: None.
        perceptual_loss (dict): Config for the perceptual loss. Default: None.
        train_cfg (dict): Config for training. Default: None.
            You may change the GAN training schedule by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def train_step(self, data_batch, optimizer):
        """Train step.

        Args:
            data_batch (dict): A batch of data.
            optimizer (obj): Optimizer.

        Returns:
            dict: Returned output.
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters.
        set_requires_grad(self.discriminator, False)

        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # gan loss for generator: relativistic average GAN, where each
            # prediction is shifted by the mean prediction of the opposite set
            real_d_pred = self.discriminator(gt).detach()
            fake_g_pred = self.discriminator(fake_g_output)
            loss_gan_fake = self.gan_loss(
                fake_g_pred - torch.mean(real_d_pred),
                target_is_real=True,
                is_disc=False)
            loss_gan_real = self.gan_loss(
                real_d_pred - torch.mean(fake_g_pred),
                target_is_real=False,
                is_disc=False)
            losses['loss_gan'] = (loss_gan_fake + loss_gan_real) / 2

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        set_requires_grad(self.discriminator, True)
        # real
        fake_d_pred = self.discriminator(fake_g_output).detach()
        real_d_pred = self.discriminator(gt)
        loss_d_real = self.gan_loss(
            real_d_pred - torch.mean(fake_d_pred),
            target_is_real=True,
            is_disc=True) * 0.5  # 0.5 for averaging loss_d_real and loss_d_fake
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_real=loss_d_real))
        optimizer['discriminator'].zero_grad()
        loss_d.backward()
        log_vars.update(log_vars_d)
        # fake
        fake_d_pred = self.discriminator(fake_g_output.detach())
        loss_d_fake = self.gan_loss(
            fake_d_pred - torch.mean(real_d_pred.detach()),
            target_is_real=False,
            is_disc=True) * 0.5  # 0.5 for averaging loss_d_real and loss_d_fake
        loss_d, log_vars_d = self.parse_losses(dict(loss_d_fake=loss_d_fake))
        loss_d.backward()
        log_vars.update(log_vars_d)

        optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs
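
A minimal configuration sketch for building this model (not part of the module source). The component names (RRDBNet, ModifiedVGG, GANLoss, PerceptualLoss, L1Loss) and the parameter values below follow the typical MMEditing ESRGAN config for DIV2K, but exact names and defaults may differ between versions, so treat them as illustrative:

    # Hedged example config; values are illustrative, not prescriptive.
    model = dict(
        type='ESRGAN',
        generator=dict(
            type='RRDBNet',
            in_channels=3,
            out_channels=3,
            mid_channels=64,
            num_blocks=23,
            growth_channels=32,
            upscale_factor=4),
        discriminator=dict(type='ModifiedVGG', in_channels=3, mid_channels=64),
        pixel_loss=dict(type='L1Loss', loss_weight=1e-2, reduction='mean'),
        perceptual_loss=dict(
            type='PerceptualLoss',
            layer_weights={'34': 1.0},
            vgg_type='vgg19',
            perceptual_weight=1.0,
            style_weight=0,
            norm_img=False),
        gan_loss=dict(
            type='GANLoss',
            gan_type='vanilla',
            loss_weight=5e-3,
            real_label_val=1.0,
            fake_label_val=0),
        pretrained=None)
    # disc_steps / disc_init_steps are read from train_cfg (see docstring above).
    train_cfg = dict(disc_steps=1, disc_init_steps=0)
    test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=4)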
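
For reference, the relativistic average GAN objective that `train_step` computes through `self.gan_loss` can be written as a standalone sketch, assuming the configured gan_loss is the vanilla (BCE-with-logits) GANLoss; the helper names here are hypothetical and only mirror the math above:

    import torch
    import torch.nn.functional as F


    def ragan_g_loss(real_logits, fake_logits):
        # Generator: fake outputs should score above the mean real score,
        # real images below the mean fake score; average the two halves.
        loss_fake = F.binary_cross_entropy_with_logits(
            fake_logits - real_logits.mean(), torch.ones_like(fake_logits))
        loss_real = F.binary_cross_entropy_with_logits(
            real_logits - fake_logits.mean(), torch.zeros_like(real_logits))
        return (loss_fake + loss_real) / 2


    def ragan_d_loss(real_logits, fake_logits):
        # Discriminator: the mirror objective, with the same 0.5 factors
        # used for averaging loss_d_real and loss_d_fake in train_step.
        loss_real = F.binary_cross_entropy_with_logits(
            real_logits - fake_logits.mean(), torch.ones_like(real_logits)) * 0.5
        loss_fake = F.binary_cross_entropy_with_logits(
            fake_logits - real_logits.mean(), torch.zeros_like(fake_logits)) * 0.5
        return loss_real + loss_fake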