import torch.nn as nn
from mmcv.runner import load_checkpoint
from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
default_init_weights, make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger
@BACKBONES.register_module()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    A compacted version modified from SRResNet in "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network".
    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factors.

    The network output is added in place to a bilinearly upsampled copy of
    the input (global residual path), so the input's channel count must be
    equal to, or broadcastable to, ``out_channels``.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Support x2, x3 and x4.
            Default: 4.

    Raises:
        ValueError: If ``upscale_factor`` is not in
            ``_supported_upscale_factors``.
    """

    # Single source of truth for the accepted upscale factors; used both for
    # validation and in the error message raised by ``__init__``.
    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super().__init__()
        # Fail fast, before any layer is allocated, on unsupported factors.
        if upscale_factor not in self._supported_upscale_factors:
            raise ValueError(
                f'Unsupported scale factor {upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        # Shallow feature extraction from the input image.
        self.conv_first = nn.Conv2d(
            in_channels, mid_channels, 3, 1, 1, bias=True)
        # Trunk: ``num_blocks`` BN-free residual blocks.
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # upsampling
        # NOTE: the attribute names ``upsample1``/``upsample2`` are part of
        # the state_dict layout and must stay stable for pretrained weights.
        if self.upscale_factor == 4:
            # x4 is realised as two successive x2 upsampling stages.
            self.upsample1 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
            self.upsample2 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
        else:
            # x2 and x3 use a single upsampling stage.
            self.upsample1 = PixelShufflePack(
                mid_channels,
                mid_channels,
                self.upscale_factor,
                upsample_kernel=3)

        self.conv_hr = nn.Conv2d(
            mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)

        # Bilinear upsampler for the global residual path in ``forward``.
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bilinear',
            align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results. Because they are summed with the
                bilinearly upsampled input, their spatial size is
                (h * upscale_factor, w * upscale_factor).
        """
        feat = self.lrelu(self.conv_first(x))
        out = self.trunk_net(feat)

        if self.upscale_factor in [2, 3]:
            out = self.upsample1(out)
        elif self.upscale_factor == 4:
            out = self.upsample1(out)
            out = self.upsample2(out)

        out = self.conv_last(self.lrelu(self.conv_hr(out)))

        # Global residual path: add a bilinearly upsampled copy of the input.
        # The in-place add requires ``x``'s channels to match (or broadcast
        # to) ``out_channels``.
        upsampled_img = self.img_upsampler(x)
        out += upsampled_img
        return out

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            # Initialization methods like `kaiming_init` are for VGG-style
            # modules. For modules with residual paths, using smaller std is
            # better for stability and performance. There is a global residual
            # path in MSRResNet and empirically we use 0.1. See more details in
            # "ESRGAN: Enhanced Super-Resolution Generative Adversarial
            # Networks"
            # Only these three convs are (re)initialized here; the trunk and
            # upsampling modules keep the initialization they received at
            # construction (presumably done in their own constructors).
            for m in [self.conv_first, self.conv_hr, self.conv_last]:
                default_init_weights(m, 0.1)
        else:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')