Source code for mmedit.models.inpaintors.two_stage

import os.path as osp
from pathlib import Path

import mmcv
import torch
from torchvision.utils import save_image

from mmedit.core import tensor2img
from ..common.model_utils import set_requires_grad
from ..registry import MODELS
from .one_stage import OneStageInpaintor


[docs]@MODELS.register_module() class TwoStageInpaintor(OneStageInpaintor): """Two-Stage Inpaintor. Currently, we support these loss types in each of two stage inpaintors: ['loss_gan', 'loss_l1_hole', 'loss_l1_valid', 'loss_composed_percep',\ 'loss_out_percep', 'loss_tv'] The `stage1_loss_type` and `stage2_loss_type` should be chosen from these loss types. Args: stage1_loss_type (tuple[str]): Contains the loss names used in the first stage model. stage2_loss_type (tuple[str]): Contains the loss names used in the second stage model. input_with_ones (bool): Whether to concatenate an extra ones tensor in input. Default: True. disc_input_with_mask (bool): Whether to add mask as input in discriminator. Default: False. """ def __init__(self, *args, stage1_loss_type=('loss_l1_hole', ), stage2_loss_type=('loss_l1_hole', 'loss_gan'), input_with_ones=True, disc_input_with_mask=False, **kwargs): super(TwoStageInpaintor, self).__init__(*args, **kwargs) self.stage1_loss_type = stage1_loss_type self.stage2_loss_type = stage2_loss_type self.input_with_ones = input_with_ones self.disc_input_with_mask = disc_input_with_mask self.eval_with_metrics = ('metrics' in self.test_cfg) and ( self.test_cfg['metrics'] is not None)
[docs] def forward_test(self, masked_img, mask, save_image=False, save_path=None, iteration=None, **kwargs): """Forward function for testing. Args: masked_img (torch.Tensor): Tensor with shape of (n, 3, h, w). mask (torch.Tensor): Tensor with shape of (n, 1, h, w). save_image (bool, optional): If True, results will be saved as image. Defaults to False. save_path (str, optional): If given a valid str, the reuslts will be saved in this path. Defaults to None. iteration (int, optional): Iteration number. Defaults to None. Returns: dict: Contain output results and eval metrics (if have). """ if self.input_with_ones: tmp_ones = torch.ones_like(mask) input_x = torch.cat([masked_img, tmp_ones, mask], dim=1) else: input_x = torch.cat([masked_img, mask], dim=1) stage1_fake_res, stage2_fake_res = self.generator(input_x) fake_img = stage2_fake_res * mask + masked_img * (1. - mask) output = dict() eval_results = {} if self.eval_with_metrics: gt_img = kwargs['gt_img'] data_dict = dict( gt_img=gt_img, fake_res=stage2_fake_res, mask=mask) for metric_name in self.test_cfg['metrics']: if metric_name in ['ssim', 'psnr']: eval_results[metric_name] = self._eval_metrics[ metric_name](tensor2img(fake_img, min_max=(-1, 1)), tensor2img(gt_img, min_max=(-1, 1))) else: eval_results[metric_name] = self._eval_metrics[ metric_name]()(data_dict).item() output['eval_results'] = eval_results else: output['stage1_fake_res'] = stage1_fake_res output['stage2_fake_res'] = stage2_fake_res output['fake_res'] = stage2_fake_res output['fake_img'] = fake_img output['meta'] = None if 'meta' not in kwargs else kwargs['meta'][0] if save_image: assert save_image and save_path is not None, ( 'Save path should be given') assert output['meta'] is not None, ( 'Meta information should be given to save image.') tmp_filename = output['meta']['gt_img_path'] filestem = Path(tmp_filename).stem if iteration is not None: filename = f'{filestem}_{iteration}.png' else: filename = f'{filestem}.png' mmcv.mkdir_or_exist(save_path) img_list = [kwargs['gt_img']] if 'gt_img' in kwargs else [] img_list.extend([ masked_img, mask.expand_as(masked_img), stage1_fake_res, stage2_fake_res, fake_img ]) img = torch.cat(img_list, dim=3).cpu() self.save_visualization(img, osp.join(save_path, filename)) output['save_img_path'] = osp.abspath( osp.join(save_path, filename)) return output
[docs] def save_visualization(self, img, filename): """Save visualization results. Args: img (torch.Tensor): Tensor with shape of (n, 3, h, w). filename (str): Path to save visualization. """ if self.test_cfg.get('img_rerange', True): img = (img + 1) / 2 if self.test_cfg.get('img_bgr2rgb', True): img = img[:, [2, 1, 0], ...] save_image(img, filename, nrow=1, padding=0)
[docs] def two_stage_loss(self, stage1_data, stage2_data, data_batch): """Calculate two-stage loss. Args: stage1_data (dict): Contain stage1 results. stage2_data (dict): Contain stage2 results. data_batch (dict): Contain data needed to calculate loss. Returns: dict: Contain losses with name. """ gt = data_batch['gt_img'] mask = data_batch['mask'] masked_img = data_batch['masked_img'] loss = dict() results = dict( gt_img=gt.cpu(), mask=mask.cpu(), masked_img=masked_img.cpu()) # calculate losses for stage1 if self.stage1_loss_type is not None: fake_res = stage1_data['fake_res'] fake_img = stage1_data['fake_img'] for type_key in self.stage1_loss_type: tmp_loss = self.calculate_loss_with_type( type_key, fake_res, fake_img, gt, mask, prefix='stage1_') loss.update(tmp_loss) results.update( dict( stage1_fake_res=stage1_data['fake_res'].cpu(), stage1_fake_img=stage1_data['fake_img'].cpu())) if self.stage2_loss_type is not None: fake_res = stage2_data['fake_res'] fake_img = stage2_data['fake_img'] for type_key in self.stage2_loss_type: tmp_loss = self.calculate_loss_with_type( type_key, fake_res, fake_img, gt, mask, prefix='stage2_') loss.update(tmp_loss) results.update( dict( stage2_fake_res=stage2_data['fake_res'].cpu(), stage2_fake_img=stage2_data['fake_img'].cpu())) return results, loss
[docs] def calculate_loss_with_type(self, loss_type, fake_res, fake_img, gt, mask, prefix='stage1_'): """Calculate multiple types of losses. Args: loss_type (str): Type of the loss. fake_res (torch.Tensor): Direct results from model. fake_img (torch.Tensor): Composited results from model. gt (torch.Tensor): Ground-truth tensor. mask (torch.Tensor): Mask tensor. prefix (str, optional): Prefix for loss name. Defaults to 'stage1_'. Returns: dict: Contain loss value with its name. """ loss_dict = dict() if loss_type == 'loss_gan': if self.disc_input_with_mask: disc_input_x = torch.cat([fake_img, mask], dim=1) else: disc_input_x = fake_img g_fake_pred = self.disc(disc_input_x) loss_g_fake = self.loss_gan(g_fake_pred, True, is_disc=False) loss_dict[prefix + 'loss_g_fake'] = loss_g_fake elif 'percep' in loss_type: loss_pecep, loss_style = self.loss_percep(fake_img, gt) if loss_pecep is not None: loss_dict[prefix + loss_type] = loss_pecep if loss_style is not None: loss_dict[prefix + loss_type[:-6] + 'style'] = loss_style elif 'tv' in loss_type: loss_tv = self.loss_tv(fake_img, mask=mask) loss_dict[prefix + loss_type] = loss_tv elif 'l1' in loss_type: weight = 1. - mask if 'valid' in loss_type else mask loss_l1 = getattr(self, loss_type)(fake_res, gt, weight=weight) loss_dict[prefix + loss_type] = loss_l1 else: raise NotImplementedError( f'Please check your loss type {loss_type}' f' and the config dict in init function. ' f'We cannot find the related loss function.') return loss_dict
[docs] def train_step(self, data_batch, optimizer): """Train step function. In this function, the inpaintor will finish the train step following the pipeline: 1. get fake res/image 2. optimize discriminator (if have) 3. optimize generator If `self.train_cfg.disc_step > 1`, the train step will contain multiple iterations for optimizing discriminator with different input data and only one iteration for optimizing gerator after `disc_step` iterations for discriminator. Args: data_batch (torch.Tensor): Batch of data as input. optimizer (dict[torch.optim.Optimizer]): Dict with optimizers for generator and discriminator (if have). Returns: dict: Dict with loss, information for logger, the number of \ samples and results for visualization. """ log_vars = {} gt_img = data_batch['gt_img'] mask = data_batch['mask'] masked_img = data_batch['masked_img'] # get common output from encdec if self.input_with_ones: tmp_ones = torch.ones_like(mask) input_x = torch.cat([masked_img, tmp_ones, mask], dim=1) else: input_x = torch.cat([masked_img, mask], dim=1) stage1_fake_res, stage2_fake_res = self.generator(input_x) stage1_fake_img = masked_img * (1. - mask) + stage1_fake_res * mask stage2_fake_img = masked_img * (1. - mask) + stage2_fake_res * mask # discriminator training step # In this version, we only use the results from the second stage to # train discriminators, which is a commonly used setting. This can be # easily modified to your custom training schedule. if self.train_cfg.disc_step > 0: set_requires_grad(self.disc, True) if self.disc_input_with_mask: disc_input_x = torch.cat([stage2_fake_img.detach(), mask], dim=1) else: disc_input_x = stage2_fake_img.detach() disc_losses = self.forward_train_d( disc_input_x, False, is_disc=True) loss_disc, log_vars_d = self.parse_losses(disc_losses) log_vars.update(log_vars_d) optimizer['disc'].zero_grad() loss_disc.backward() if self.disc_input_with_mask: disc_input_x = torch.cat([gt_img, mask], dim=1) else: disc_input_x = gt_img disc_losses = self.forward_train_d( disc_input_x, True, is_disc=True) loss_disc, log_vars_d = self.parse_losses(disc_losses) log_vars.update(log_vars_d) loss_disc.backward() if self.with_gp_loss: # gradient penalty loss should not be used with mask as input assert not self.disc_input_with_mask loss_d_gp = self.loss_gp( self.disc, gt_img, stage2_fake_img, mask=mask) loss_disc, log_vars_d = self.parse_losses( dict(loss_gp=loss_d_gp)) log_vars.update(log_vars_d) loss_disc.backward() optimizer['disc'].step() self.disc_step_count = (self.disc_step_count + 1) % self.train_cfg.disc_step if self.disc_step_count != 0: # results contain the data for visualization results = dict( gt_img=gt_img.cpu(), masked_img=masked_img.cpu(), fake_res=stage2_fake_res.cpu(), fake_img=stage2_fake_img.cpu()) outputs = dict( log_vars=log_vars, num_samples=len(data_batch['gt_img'].data), results=results) return outputs # prepare stage1 results and stage2 results dict for calculating losses stage1_results = dict( fake_res=stage1_fake_res, fake_img=stage1_fake_img) stage2_results = dict( fake_res=stage2_fake_res, fake_img=stage2_fake_img) # generator (encdec) and refiner training step, results contain the # data for visualization if self.with_gan: set_requires_grad(self.disc, False) results, two_stage_losses = self.two_stage_loss( stage1_results, stage2_results, data_batch) loss_two_stage, log_vars_two_stage = self.parse_losses( two_stage_losses) log_vars.update(log_vars_two_stage) optimizer['generator'].zero_grad() loss_two_stage.backward() optimizer['generator'].step() outputs = dict( log_vars=log_vars, num_samples=len(data_batch['gt_img'].data), results=results) return outputs