Source code for mmedit.models.backbones.sr_backbones.edsr

import math

import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
                                  make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


class UpsampleModule(nn.Sequential):
    """Upsample module used in EDSR.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        mid_channels (int): Channel number of intermediate features.
    """

    def __init__(self, scale, mid_channels):
        modules = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                modules.append(
                    PixelShufflePack(
                        mid_channels, mid_channels, 2, upsample_kernel=3))
        elif scale == 3:
            modules.append(
                PixelShufflePack(
                    mid_channels, mid_channels, scale, upsample_kernel=3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')

        super(UpsampleModule, self).__init__(*modules)
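
# Usage sketch (illustrative, not part of the original module): with
# upscale_factor=4 the module stacks two PixelShufflePack layers that each
# upsample by 2, while upscale_factor=3 uses a single PixelShufflePack with
# scale 3. The shapes below assume mid_channels=64.
#
#   >>> up = UpsampleModule(scale=4, mid_channels=64)
#   >>> feat = torch.rand(1, 64, 32, 32)
#   >>> up(feat).shape
#   torch.Size([1, 64, 128, 128])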


@BACKBONES.register_module()
class EDSR(nn.Module):
    """EDSR network structure.

    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
    Ref repo: https://github.com/thstkdgus35/EDSR-PyTorch

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Supported factors: 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
        rgb_std (tuple[float]): Image std in RGB orders. In EDSR, it uses
            (1.0, 1.0, 1.0). Default: (1.0, 1.0, 1.0).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4,
                 res_scale=1,
                 rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0)):
        super(EDSR, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.num_blocks = num_blocks
        self.upscale_factor = upscale_factor

        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        self.std = torch.Tensor(rgb_std).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, padding=1)
        self.body = make_layer(
            ResidualBlockNoBN,
            num_blocks,
            mid_channels=mid_channels,
            res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.upsample = UpsampleModule(upscale_factor, mid_channels)
        self.conv_last = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=True)
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """

        self.mean = self.mean.to(x)
        self.std = self.std.to(x)

        x = (x - self.mean) / self.std

        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x * self.std + self.mean

        return x
    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
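
# Illustrative smoke test (an assumption-labelled sketch, not part of the
# original file): builds the default 4x EDSR model and checks that an
# (n, 3, h, w) input is upscaled to (n, 3, 4h, 4w). The 32 x 32 input size
# is arbitrary.
if __name__ == '__main__':
    model = EDSR(in_channels=3, out_channels=3)
    model.init_weights(pretrained=None)
    lr_img = torch.rand(1, 3, 32, 32)
    sr_img = model(lr_img)
    assert sr_img.shape == (1, 3, 128, 128)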