import numbers
import os.path as osp
import mmcv
from mmcv.runner import auto_fp16
from mmedit.core import psnr, ssim, tensor2img
from ..base import BaseModel
from ..builder import build_backbone, build_loss
from ..registry import MODELS
@MODELS.register_module()
class BasicRestorer(BaseModel):
    """Basic model for image restoration.

    It must contain a generator that takes an image as inputs and outputs a
    restored image. It also has a pixel-wise loss for training.

    The subclasses should overwrite the function `forward_train`,
    `forward_test` and `train_step`.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    # Maps the metric names accepted in ``test_cfg.metrics`` to the function
    # that computes them; ``evaluate`` looks metrics up in this table.
    allowed_metrics = {'PSNR': psnr, 'SSIM': ssim}

    def __init__(self,
                 generator,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # support fp16
        self.fp16_enabled = False

        # generator
        self.generator = build_backbone(generator)
        self.init_weights(pretrained)

        # loss
        self.pixel_loss = build_loss(pixel_loss)

    def init_weights(self, pretrained=None):
        """Init weights for models.

        The call is delegated to the generator's own ``init_weights``.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
        """
        self.generator.init_weights(pretrained)

    @auto_fp16(apply_to=('lq', ))
    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Dispatches to ``forward_train`` or ``forward_test`` depending on
        ``test_mode``.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments. They are forwarded to
                ``forward_test`` only; in training mode they are ignored.

        Returns:
            dict: The output of ``forward_train`` or ``forward_test``.
        """
        if test_mode:
            return self.forward_test(lq, gt, **kwargs)

        # Extra keyword arguments (e.g. entries of the data batch other than
        # ``lq`` and ``gt``) are intentionally not passed on for training.
        return self.forward_train(lq, gt)

    def forward_train(self, lq, gt):
        """Training forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w).

        Returns:
            dict: A dict with the keys ``losses`` (dict holding ``loss_pix``),
                ``num_samples`` (``len(gt.data)``) and ``results`` (dict with
                the CPU copies of ``lq``, ``gt`` and the generator output).
        """
        losses = dict()
        output = self.generator(lq)
        loss_pix = self.pixel_loss(output, gt)
        losses['loss_pix'] = loss_pix
        outputs = dict(
            losses=losses,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu()))
        return outputs

    def evaluate(self, output, gt):
        """Evaluation function.

        Computes every metric listed in ``self.test_cfg.metrics`` on the
        images obtained from ``output`` and ``gt`` via ``tensor2img``, passing
        ``self.test_cfg.crop_border`` to each metric function.

        Args:
            output (Tensor): Model output with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w).

        Returns:
            dict: Evaluation results, keyed by metric name.

        Raises:
            KeyError: If a configured metric is not in ``allowed_metrics``.
        """
        crop_border = self.test_cfg.crop_border

        output = tensor2img(output)
        gt = tensor2img(gt)

        eval_result = dict()
        for metric in self.test_cfg.metrics:
            # Same exception type as a plain dict lookup would raise, but
            # with a message that tells the user what is actually supported.
            if metric not in self.allowed_metrics:
                raise KeyError(
                    f'metric {metric} is not supported, supported metrics '
                    f'are {list(self.allowed_metrics)}')
            eval_result[metric] = self.allowed_metrics[metric](
                output, gt, crop_border)
        return eval_result

    def forward_test(self,
                     lq,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
            meta (list[dict]): Meta information of the samples. Only read
                when ``save_image`` is True, in which case
                ``meta[0]['lq_path']`` is used to name the saved image.
                Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: Output results. If ``test_cfg`` provides ``metrics``, the
                dict holds ``eval_result``; otherwise it holds the CPU copies
                of ``lq`` and ``output`` (and ``gt`` when it is given).

        Raises:
            ValueError: If ``save_image`` is True and ``iteration`` is
                neither a number nor None.
        """
        output = self.generator(lq)
        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
            assert gt is not None, (
                'evaluation with metrics must have gt images.')
            results = dict(eval_result=self.evaluate(output, gt))
        else:
            results = dict(lq=lq.cpu(), output=output.cpu())
            if gt is not None:
                results['gt'] = gt.cpu()

        # save image
        if save_image:
            # The file name is derived from the first sample's lq path only.
            lq_path = meta[0]['lq_path']
            folder_name = osp.splitext(osp.basename(lq_path))[0]
            if isinstance(iteration, numbers.Number):
                # One sub-folder per image, one file per (1-based) iteration.
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(output), save_path)

        return results

    def forward_dummy(self, img):
        """Used for computing network FLOPs.

        Args:
            img (Tensor): Input image.

        Returns:
            Tensor: Output image.
        """
        out = self.generator(img)
        return out

    def train_step(self, data_batch, optimizer):
        """Train step.

        Runs the training forward pass, back-propagates the parsed loss and
        updates the generator.

        Args:
            data_batch (dict): A batch of data. It is unpacked as keyword
                arguments of ``forward``.
            optimizer (obj): Optimizer. It is indexed with the key
                ``'generator'`` to obtain the optimizer that is stepped.

        Returns:
            dict: Returned output, i.e. the training forward output without
                ``losses`` and with ``log_vars`` added.
        """
        outputs = self(**data_batch, test_mode=False)
        loss, log_vars = self.parse_losses(outputs.pop('losses'))

        # optimize
        optimizer['generator'].zero_grad()
        loss.backward()
        optimizer['generator'].step()

        outputs.update({'log_vars': log_vars})
        return outputs

    def val_step(self, data_batch, **kwargs):
        """Validation step.

        Args:
            data_batch (dict): A batch of data.
            kwargs (dict): Other arguments for ``val_step``. They are passed
                on to ``forward_test`` together with ``data_batch``.

        Returns:
            dict: Returned output.
        """
        output = self.forward_test(**data_batch, **kwargs)
        return output