Source code for mmedit.models.backbones.sr_backbones.sr_resnet

import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
                                  default_init_weights, make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


@BACKBONES.register_module()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    A compacted version modified from SRResNet in "Photo-Realistic Single
    Image Super-Resolution Using a Generative Adversarial Network".

    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factors.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Supports x2, x3 and x4.
            Default: 4.
    """

    _supported_upscale_factors = [2, 3, 4]

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4):
        super(MSRResNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        self.conv_first = nn.Conv2d(
            in_channels, mid_channels, 3, 1, 1, bias=True)
        self.trunk_net = make_layer(
            ResidualBlockNoBN, num_blocks, mid_channels=mid_channels)

        # upsampling
        if self.upscale_factor in [2, 3]:
            self.upsample1 = PixelShufflePack(
                mid_channels,
                mid_channels,
                self.upscale_factor,
                upsample_kernel=3)
        elif self.upscale_factor == 4:
            self.upsample1 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
            self.upsample2 = PixelShufflePack(
                mid_channels, mid_channels, 2, upsample_kernel=3)
        else:
            raise ValueError(
                f'Unsupported scale factor {self.upscale_factor}. '
                f'Currently supported ones are '
                f'{self._supported_upscale_factors}.')

        self.conv_hr = nn.Conv2d(
            mid_channels, mid_channels, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)

        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bilinear',
            align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
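
    # Design note (added comment): x2 and x3 upsampling each use a single
    # PixelShufflePack stage, while x4 chains two x2 stages. For example,
    # with mid_channels=64 and upscale_factor=4, trunk features of shape
    # (n, 64, h, w) become (n, 64, 2h, 2w) after `upsample1` and
    # (n, 64, 4h, 4w) after `upsample2`, before the final 3x3 convolutions.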

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        feat = self.lrelu(self.conv_first(x))
        out = self.trunk_net(feat)

        if self.upscale_factor in [2, 3]:
            out = self.upsample1(out)
        elif self.upscale_factor == 4:
            out = self.upsample1(out)
            out = self.upsample2(out)

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        upsampled_img = self.img_upsampler(x)
        out += upsampled_img
        return out
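
    # Note (added comment): the bilinearly upsampled input image acts as a
    # global residual, so the learned branch only has to model the
    # high-frequency residual. E.g. with upscale_factor=4, an input of shape
    # (1, 3, 32, 32) produces an output of shape (1, 3, 128, 128).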

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            # Initialization methods like `kaiming_init` are for VGG-style
            # modules. For modules with residual paths, using a smaller std is
            # better for stability and performance. There is a global residual
            # path in MSRResNet, and empirically we use 0.1. See more details
            # in "ESRGAN: Enhanced Super-Resolution Generative Adversarial
            # Networks".
            for m in [self.conv_first, self.conv_hr, self.conv_last]:
                default_init_weights(m, 0.1)
        else:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
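

# A minimal usage sketch, not part of the original module: it assumes mmedit
# and its dependencies are importable; shapes follow the x4 example above.
if __name__ == '__main__':
    import torch

    model = MSRResNet(in_channels=3, out_channels=3, upscale_factor=4)
    # With `pretrained=None` this applies the small-std default init; passing
    # a checkpoint path string instead would load it via mmcv's
    # `load_checkpoint`, as shown in `init_weights` above.
    model.init_weights(pretrained=None)

    lr_img = torch.rand(1, 3, 32, 32)
    with torch.no_grad():
        sr_img = model(lr_img)
    assert sr_img.shape == (1, 3, 128, 128)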