# Source code for mmedit.models.inpaintors.gl_inpaintor

import torch

from ..common import extract_around_bbox, extract_bbox_patch, set_requires_grad
from ..registry import MODELS
from .one_stage import OneStageInpaintor


@MODELS.register_module()
class GLInpaintor(OneStageInpaintor):
    """Inpaintor for global&local method.

    This inpaintor is implemented according to the paper:
    Globally and Locally Consistent Image Completion

    Importantly, this inpaintor is an example for using a custom training
    schedule based on `OneStageInpaintor`.

    The training pipeline of global&local is as follows:

    .. code-block:: python

        if cur_iter <= iter_tc:
            update generator with only l1 loss
        else:
            update discriminator
            if cur_iter > iter_td:
                update generator with l1 loss and adversarial loss

    The new attribute `cur_iter` is added for recording the current number of
    iterations. The `train_cfg` contains the setting of the training schedule:

    .. code-block:: python

        train_cfg = dict(
            start_iter=0,
            disc_step=1,
            iter_tc=90000,
            iter_td=100000
        )

    `train_cfg.local_size` is also required: it is passed to
    `extract_around_bbox` when cropping the local patch around
    ``data_batch['mask_bbox']``.

    `iter_tc` and `iter_td` correspond to the notation :math:`T_C` and
    :math:`T_D` of the original paper.

    Args:
        encdec (dict): Config for encoder-decoder style generator.
        disc (dict): Config for discriminator.
        loss_gan (dict): Config for adversarial loss.
        loss_gp (dict): Config for gradient penalty loss.
        loss_disc_shift (dict): Config for discriminator shift loss.
        loss_composed_percep (dict): Config for perceptual and style loss with
            composed image as input.
        loss_out_percep (dict | bool): Config/flag for perceptual and style
            loss with direct output as input. Default: False.
        loss_l1_hole (dict): Config for l1 loss in the hole.
        loss_l1_valid (dict): Config for l1 loss in the valid region.
        loss_tv (dict): Config for total variation loss.
        train_cfg (dict): Configs for training scheduler. `disc_step` must be
            contained to indicate the discriminator updating steps in each
            training step.
        test_cfg (dict): Configs for testing scheduler.
        pretrained (str): Path for pretrained model. Default None.
    """

    def __init__(self,
                 encdec,
                 disc=None,
                 loss_gan=None,
                 loss_gp=None,
                 loss_disc_shift=None,
                 loss_composed_percep=None,
                 loss_out_percep=False,
                 loss_l1_hole=None,
                 loss_l1_valid=None,
                 loss_tv=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(GLInpaintor, self).__init__(
            encdec,
            disc=disc,
            loss_gan=loss_gan,
            loss_gp=loss_gp,
            loss_disc_shift=loss_disc_shift,
            loss_composed_percep=loss_composed_percep,
            loss_out_percep=loss_out_percep,
            loss_l1_hole=loss_l1_hole,
            loss_l1_valid=loss_l1_valid,
            loss_tv=loss_tv,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

        if self.train_cfg is not None:
            # Iteration counter that drives the T_C / T_D schedule; it starts
            # from ``train_cfg.start_iter``.
            self.cur_iter = self.train_cfg.start_iter

    def generator_loss(self, fake_res, fake_img, fake_local, data_batch):
        """Forward function in generator training step.

        In this function, we mainly compute the loss items for generator with
        the given (fake_res, fake_img). In general, the `fake_res` is the
        direct output of the generator and the `fake_img` is the composition
        of direct output and ground-truth image.

        Args:
            fake_res (torch.Tensor): Direct output of the generator.
            fake_img (torch.Tensor): Composition of `fake_res` and
                ground-truth image.
            fake_local (torch.Tensor): Local patch cropped from `fake_img`
                around the mask bbox; fed to the discriminator together with
                `fake_img` for the adversarial loss.
            data_batch (dict): Contain other elements for computing losses.

        Returns:
            tuple[dict]: A tuple containing two dictionaries. The first one \
                is the result dict, which contains the results computed \
                within this function for visualization. The second one is the \
                loss dict, containing loss items computed in this function.
        """
        gt = data_batch['gt_img']
        mask = data_batch['mask']
        masked_img = data_batch['masked_img']

        loss = dict()

        # if cur_iter <= iter_td, do not calculate adversarial loss
        if self.with_gan and self.cur_iter > self.train_cfg.iter_td:
            g_fake_pred = self.disc((fake_img, fake_local))
            loss_g_fake = self.loss_gan(g_fake_pred, True, False)
            loss['loss_g_fake'] = loss_g_fake

        if self.with_l1_hole_loss:
            loss_l1_hole = self.loss_l1_hole(fake_res, gt, weight=mask)
            loss['loss_l1_hole'] = loss_l1_hole

        if self.with_l1_valid_loss:
            loss_l1_valid = self.loss_l1_valid(fake_res, gt, weight=1. - mask)
            loss['loss_l1_valid'] = loss_l1_valid

        res = dict(
            gt_img=gt.cpu(),
            masked_img=masked_img.cpu(),
            fake_res=fake_res.cpu(),
            fake_img=fake_img.cpu())

        return res, loss

    @staticmethod
    def _get_visual_results(gt_img, masked_img, fake_res, fake_img,
                            fake_gt_local):
        """Collect CPU copies of the tensors returned as visualization
        ``results`` when the generator is not updated in this step."""
        return dict(
            gt_img=gt_img.cpu(),
            masked_img=masked_img.cpu(),
            fake_res=fake_res.cpu(),
            fake_img=fake_img.cpu(),
            fake_gt_local=fake_gt_local.cpu())

    def _finish_step(self, log_vars, data_batch, results):
        """Pack the train-step outputs and advance ``cur_iter`` by one."""
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(data_batch['gt_img'].data),
            results=results)
        self.cur_iter += 1
        return outputs

    def train_step(self, data_batch, optimizer):
        """Train step function.

        In this function, the inpaintor will finish the train step following
        the pipeline:

            1. get fake res/image
            2. optimize discriminator (if in current schedule)
            3. optimize generator (if in current schedule)

        If ``self.train_cfg.disc_step > 1``, the train step will contain
        multiple iterations for optimizing the discriminator with different
        input data and only one iteration for optimizing the generator after
        `disc_step` iterations for the discriminator.

        Args:
            data_batch (dict): Batch of data as input. Keys read here:
                ``gt_img``, ``mask``, ``masked_img`` and ``mask_bbox``.
            optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for
                generator and discriminator (if have).

        Returns:
            dict: Dict with loss, information for logger, the number of \
                samples and results for visualization.
        """
        log_vars = {}

        gt_img = data_batch['gt_img']
        mask = data_batch['mask']
        masked_img = data_batch['masked_img']
        bbox_tensor = data_batch['mask_bbox']

        input_x = torch.cat([masked_img, mask], dim=1)
        fake_res = self.generator(input_x)
        # mask is 1 in the hole: keep ground truth in the valid region and the
        # prediction inside the hole.
        fake_img = gt_img * (1. - mask) + fake_res * mask

        fake_local, bbox_new = extract_around_bbox(fake_img, bbox_tensor,
                                                   self.train_cfg.local_size)
        gt_local = extract_bbox_patch(bbox_new, gt_img)
        # Fake/real local patches concatenated along dim 2 (presumably height
        # for NCHW tensors); only used for visualization.
        fake_gt_local = torch.cat([fake_local, gt_local], dim=2)

        # if cur_iter > iter_tc, update discriminator
        if (self.train_cfg.disc_step > 0
                and self.cur_iter > self.train_cfg.iter_tc):
            # set discriminator requires_grad as True
            set_requires_grad(self.disc, True)

            fake_data = (fake_img.detach(), fake_local.detach())
            real_data = (gt_img, gt_local)

            # Fake and real terms are backpropagated separately; their
            # gradients accumulate before the single optimizer step below.
            disc_losses = self.forward_train_d(fake_data, False, True)
            loss_disc, log_vars_d = self.parse_losses(disc_losses)
            log_vars.update(log_vars_d)
            optimizer['disc'].zero_grad()
            loss_disc.backward()

            disc_losses = self.forward_train_d(real_data, True, True)
            loss_disc, log_vars_d = self.parse_losses(disc_losses)
            log_vars.update(log_vars_d)
            loss_disc.backward()
            optimizer['disc'].step()

            self.disc_step_count = ((self.disc_step_count + 1) %
                                    self.train_cfg.disc_step)

            # Do not update the generator until `disc_step` discriminator
            # updates have been made, nor while cur_iter <= iter_td.
            if (self.disc_step_count != 0
                    or self.cur_iter <= self.train_cfg.iter_td):
                results = self._get_visual_results(gt_img, masked_img,
                                                   fake_res, fake_img,
                                                   fake_gt_local)
                return self._finish_step(log_vars, data_batch, results)

        # set discriminators requires_grad as False to avoid extra computation.
        set_requires_grad(self.disc, False)

        # update generator
        if (self.cur_iter <= self.train_cfg.iter_tc
                or self.cur_iter > self.train_cfg.iter_td):
            results, g_losses = self.generator_loss(fake_res, fake_img,
                                                    fake_local, data_batch)
            loss_g, log_vars_g = self.parse_losses(g_losses)
            log_vars.update(log_vars_g)
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

            results.update(fake_gt_local=fake_gt_local.cpu())
        else:
            # Only reachable when `disc_step <= 0` and
            # iter_tc < cur_iter <= iter_td: neither network is updated.
            # Previously `outputs` was left unbound here (UnboundLocalError).
            results = self._get_visual_results(gt_img, masked_img, fake_res,
                                               fake_img, fake_gt_local)

        return self._finish_step(log_vars, data_batch, results)