# Source code for mmedit.models.restorers.basic_restorer

import numbers
import os.path as osp

import mmcv
from mmcv.runner import auto_fp16

from mmedit.core import psnr, ssim, tensor2img
from ..base import BaseModel
from ..builder import build_backbone, build_loss
from ..registry import MODELS


@MODELS.register_module()
class BasicRestorer(BaseModel):
    """Basic model for image restoration.

    It must contain a generator that takes an image as inputs and outputs a
    restored image. It also has a pixel-wise loss for training.

    The subclasses should overwrite the function `forward_train`,
    `forward_test` and `train_step`.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    # Maps the metric names accepted in ``test_cfg.metrics`` to the callables
    # that compute them; ``evaluate`` looks metrics up here.
    allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}

    def __init__(self,
                 generator,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # support fp16
        self.fp16_enabled = False

        # generator
        self.generator = build_backbone(generator)
        self.init_weights(pretrained)

        # loss
        self.pixel_loss = build_loss(pixel_loss)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        Delegates to ``self.generator.init_weights``.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Dispatches to ``forward_train`` or ``forward_test`` depending on
        ``test_mode``.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments. They are forwarded to
                ``forward_test`` only; the training path does not receive
                them.

        Returns:
            dict: The output of ``forward_train`` or ``forward_test``.
        """
        if not test_mode:
            return self.forward_train(lq, gt)

        return self.forward_test(lq, gt, **kwargs)

    def forward_train(self, lq, gt):
        """Training forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w).

        Returns:
            dict: A dict with keys ``losses`` (containing ``loss_pix``),
                ``num_samples`` and ``results`` (CPU copies of ``lq``,
                ``gt`` and the generator output).
        """
        losses = dict()
        output = self.generator(lq)
        loss_pix = self.pixel_loss(output, gt)
        losses['loss_pix'] = loss_pix
        outputs = dict(
            losses=losses,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu()))
        return outputs

    def evaluate(self, output, gt):
        """Evaluation function.

        Computes every metric listed in ``self.test_cfg.metrics`` using the
        callables registered in ``allowed_metrics``.

        Args:
            output (Tensor): Model output with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w).

        Returns:
            dict: Evaluation results, keyed by metric name.

        Raises:
            KeyError: If a configured metric is not in ``allowed_metrics``.
        """
        crop_border = self.test_cfg.crop_border

        output = tensor2img(output)
        gt = tensor2img(gt)

        eval_result = dict()
        for metric in self.test_cfg.metrics:
            if metric not in self.allowed_metrics:
                # Same exception type as the plain lookup would raise, but
                # with a message that tells the user what is accepted.
                raise KeyError(
                    f'Unsupported metric {metric!r}; supported metrics are '
                    f'{sorted(self.allowed_metrics)}.')
            eval_result[metric] = self.allowed_metrics[metric](output, gt,
                                                               crop_border)
        return eval_result

    def forward_test(self,
                     lq,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
            meta (list[dict]): Meta information. Only read when
                ``save_image`` is True, in which case ``meta[0]['lq_path']``
                supplies the base name of the saved file. Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: Output results. Contains ``eval_result`` when metrics are
                configured in ``test_cfg``; otherwise contains CPU copies of
                ``lq`` and ``output`` (and ``gt`` if it was given).

        Raises:
            ValueError: If ``save_image`` is True and ``iteration`` is
                neither a number nor None.
        """
        output = self.generator(lq)
        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
            assert gt is not None, (
                'evaluation with metrics must have gt images.')
            results = dict(eval_result=self.evaluate(output, gt))
        else:
            results = dict(lq=lq.cpu(), output=output.cpu())
            if gt is not None:
                results['gt'] = gt.cpu()

        # save image
        if save_image:
            lq_path = meta[0]['lq_path']
            folder_name = osp.splitext(osp.basename(lq_path))[0]
            if isinstance(iteration, numbers.Number):
                # One sub-folder per input image, one file per iteration
                # (file name uses the 1-based iteration, zero-padded).
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(output), save_path)

        return results

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        Args:
            img (Tensor): Input image.

        Returns:
            Tensor: Output image.
        """
        out = self.generator(img)
        return out

    def train_step(self, data_batch, optimizer):
        """Train step.

        Runs a training forward pass, back-propagates the parsed loss and
        steps the generator optimizer.

        Args:
            data_batch (dict): A batch of data. It is unpacked as keyword
                arguments of ``forward``.
            optimizer (dict): Optimizers keyed by name; only the entry under
                ``'generator'`` is used.

        Returns:
            dict: Returned output. The ``losses`` entry of the forward output
                is removed and a ``log_vars`` entry is added.
        """
        outputs = self(**data_batch, test_mode=False)
        loss, log_vars = self.parse_losses(outputs.pop('losses'))

        # optimize
        optimizer['generator'].zero_grad()
        loss.backward()
        optimizer['generator'].step()

        outputs.update({'log_vars': log_vars})
        return outputs

    def val_step(self, data_batch, **kwargs):
        """Validation step.

        Args:
            data_batch (dict): A batch of data.
            kwargs (dict): Other arguments for ``val_step``. They are
                forwarded to ``forward_test`` together with ``data_batch``.

        Returns:
            dict: Returned output.
        """
        output = self.forward_test(**data_batch, **kwargs)
        return output