Source code for mmedit.models.backbones.encoder_decoders.encoders.resnet_enc

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, constant_init
from mmcv.runner import load_checkpoint

from mmedit.models.common import GCAModule
from mmedit.models.registry import COMPONENTS
from mmedit.utils.logger import get_root_logger


class BasicBlock(nn.Module):
    """Residual block made of two convolutions and a shortcut.

    The main branch runs ``conv1`` (with activation) followed by ``conv2``
    (without activation). The shortcut is the block input itself, or
    ``interpolation(input)`` when an interpolation module is given. Both
    branches are summed and the activation is applied to the sum.

    Args:
        in_channels (int): Input channels of the block.
        out_channels (int): Output channels of the block.
        kernel_size (int): Kernel size of the convolution layers. Both
            convolutions are padded with ``kernel_size // 2``.
        stride (int): Stride of the first conv of the block. Only 1 and 2
            are accepted.
        interpolation (nn.Module, optional): Module applied to the input on
            the skip connection. Required when ``stride`` is 2.
        conv_cfg (dict): Dictionary to construct convolution layer. If it is
            None, 2d convolution will be applied. Default: None.
        norm_cfg (dict): Config dict for normalization layer. "BN" by default.
        act_cfg (dict): Config dict for activation layer, "ReLU" by default.
        with_spectral_norm (bool): Whether use spectral norm after conv.
            Default: False.
    """
    expansion = 1

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 interpolation=None,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=False):
        super().__init__()
        assert stride in (1, 2), (
            f'stride other than 1 and 2 is not implemented, got {stride}')

        # A strided main branch changes the spatial size, so the shortcut
        # needs a module that resizes the input accordingly.
        assert not (stride == 2 and interpolation is None), (
            'if stride is 2, interpolation should be specified')

        # The two builders are separate methods so that subclasses can swap
        # in different layers without touching the rest of the block.
        self.conv1 = self.build_conv1(in_channels, out_channels, kernel_size,
                                      stride, conv_cfg, norm_cfg, act_cfg,
                                      with_spectral_norm)
        self.conv2 = self.build_conv2(in_channels, out_channels, kernel_size,
                                      conv_cfg, norm_cfg, with_spectral_norm)

        self.interpolation = interpolation
        self.activation = build_activation_layer(act_cfg)
        self.stride = stride

    def build_conv1(self, in_channels, out_channels, kernel_size, stride,
                    conv_cfg, norm_cfg, act_cfg, with_spectral_norm):
        """Build the first conv module (strided, with activation)."""
        options = dict(
            stride=stride,
            padding=kernel_size // 2,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm)
        return ConvModule(in_channels, out_channels, kernel_size, **options)

    def build_conv2(self, in_channels, out_channels, kernel_size, conv_cfg,
                    norm_cfg, with_spectral_norm):
        """Build the second conv module (stride 1, no activation).

        ``in_channels`` is accepted for signature symmetry with
        ``build_conv1`` but is not used here: the module maps
        ``out_channels`` to ``out_channels``.
        """
        options = dict(
            stride=1,
            padding=kernel_size // 2,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            # the activation is applied after the residual sum instead
            act_cfg=None,
            with_spectral_norm=with_spectral_norm)
        return ConvModule(out_channels, out_channels, kernel_size, **options)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Activated sum of the main branch and the shortcut.
        """
        residual = self.conv2(self.conv1(x))

        if self.interpolation is None:
            shortcut = x
        else:
            shortcut = self.interpolation(x)

        residual += shortcut
        return self.activation(residual)


@COMPONENTS.register_module()
class ResNetEnc(nn.Module):
    """ResNet encoder for image matting.

    This class is adopted from https://github.com/Yaoyi-Li/GCA-Matting.
    Implement and pre-train on ImageNet with the tricks from
    https://arxiv.org/abs/1812.01187
    without the mix-up part.

    The encoder is a three-conv stem (``conv1``-``conv3``) followed by four
    residual stages (``layer1``-``layer4``). The output has 512 channels,
    exposed as ``self.out_channels``.

    Args:
        block (str): Type of residual block. Currently only `BasicBlock` is
            implemented.
        layers (list[int]): Number of residual blocks in each of the four
            stages (``layer1`` to ``layer4``).
        in_channels (int): Number of input channels.
        conv_cfg (dict): Dictionary to construct convolution layer. If it is
            None, 2d convolution will be applied. Default: None.
        norm_cfg (dict): Config dict for normalization layer. "BN" by default.
        act_cfg (dict): Config dict for activation layer, "ReLU" by default.
        with_spectral_norm (bool): Whether use spectral norm after conv.
            Default: False.
        late_downsample (bool): Whether to adopt late downsample strategy.
            When True the strides of (conv1, conv2, conv3, layer1) are
            (1, 2, 1, 2) instead of (2, 1, 2, 1), and conv2 outputs 64
            channels instead of 32. Default: False.
    """

    def __init__(self,
                 block,
                 layers,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=False,
                 late_downsample=False):
        super(ResNetEnc, self).__init__()
        if block == 'BasicBlock':
            block = BasicBlock
        else:
            raise NotImplementedError(f'{block} is not implemented.')

        # ``inplanes`` tracks the input channels of the next stage and is
        # updated by ``_make_layer``.
        self.inplanes = 64
        self.midplanes = 64 if late_downsample else 32

        # strides of conv1, conv2, conv3 and layer1, in that order
        start_stride = [1, 2, 1, 2] if late_downsample else [2, 1, 2, 1]
        self.conv1 = ConvModule(
            in_channels,
            32,
            3,
            stride=start_stride[0],
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm)
        self.conv2 = ConvModule(
            32,
            self.midplanes,
            3,
            stride=start_stride[1],
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm)
        self.conv3 = ConvModule(
            self.midplanes,
            self.inplanes,
            3,
            stride=start_stride[2],
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm)

        self.layer1 = self._make_layer(block, 64, layers[0], start_stride[3],
                                       conv_cfg, norm_cfg, act_cfg,
                                       with_spectral_norm)
        self.layer2 = self._make_layer(block, 128, layers[1], 2, conv_cfg,
                                       norm_cfg, act_cfg, with_spectral_norm)
        self.layer3 = self._make_layer(block, 256, layers[2], 2, conv_cfg,
                                       norm_cfg, act_cfg, with_spectral_norm)
        self.layer4 = self._make_layer(block, 512, layers[3], 2, conv_cfg,
                                       norm_cfg, act_cfg, with_spectral_norm)

        self.out_channels = 512

    def init_weights(self, pretrained=None):
        """Initialize the weights of the encoder.

        Args:
            pretrained (str, optional): Path of a checkpoint to load
                (non-strictly). If None, normalization layers are
                initialized from scratch. Default: None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # if pretrained weight is trained on 3-channel images,
            # initialize other channels with zeros
            self.conv1.conv.weight.data[:, 3:, :, :] = 0

            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # NOTE: ``constant_init`` expects a *module* and fills its
            # ``weight``/``bias`` attributes. Passing a parameter
            # (e.g. ``m.weight``) is silently ignored, so the modules
            # themselves are passed here.
            for m in self.modules():
                if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m, 1, bias=0)

            # Zero-initialize the last BN in each residual branch, so that the
            # residual branch starts with zeros, and each residual block
            # behaves like an identity. This improves the model by 0.2~0.3%
            # according to https://arxiv.org/abs/1706.02677
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    # ``norm`` is ConvModule's accessor for its norm layer
                    # whatever its type (BN, GN, ...); it is None when the
                    # module was built without normalization.
                    last_norm = getattr(m.conv2, 'norm', None)
                    if last_norm is not None:
                        constant_init(last_norm, 0)
        else:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')

    def _make_layer(self, block, planes, num_blocks, stride, conv_cfg,
                    norm_cfg, act_cfg, with_spectral_norm):
        """Build one residual stage.

        Args:
            block (type): Residual block class; must expose ``expansion``.
            planes (int): Base number of output channels of the stage.
            num_blocks (int): Number of residual blocks in the stage.
            stride (int): Stride of the first block of the stage.
            conv_cfg (dict): Config dict for convolution layer.
            norm_cfg (dict): Config dict for normalization layer.
            act_cfg (dict): Config dict for activation layer.
            with_spectral_norm (bool): Whether use spectral norm after conv.

        Returns:
            nn.Sequential: The stacked residual blocks.
        """
        downsample = None
        if stride != 1:
            # Shortcut for a strided first block: average pooling followed
            # by a 1x1 conv (no activation) to match the channels.
            downsample = nn.Sequential(
                nn.AvgPool2d(2, stride),
                ConvModule(
                    self.inplanes,
                    planes * block.expansion,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None,
                    with_spectral_norm=with_spectral_norm))

        layers = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                interpolation=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                with_spectral_norm=with_spectral_norm)
        ]
        # the remaining blocks (and the next stage) consume the new width
        self.inplanes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_spectral_norm=with_spectral_norm))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            Tensor: Output tensor.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x
@COMPONENTS.register_module()
class ResShortcutEnc(ResNetEnc):
    """ResNet backbone for image matting with shortcut connection.

    ::

        image ---------------- shortcut[0] --- feat1
          |
        conv1-conv2 ---------- shortcut[1] --- feat2
               |
              conv3-layer1 --- shortcut[2] --- feat3
                      |
                     layer2 -- shortcut[3] --- feat4
                       |
                      layer3 - shortcut[4] --- feat5
                        |
                       layer4 ---------------- out

    Baseline model of Natural Image Matting via Guided Contextual Attention
    https://arxiv.org/pdf/2001.04069.pdf.

    Args:
        block (str): Type of residual block. Currently only `BasicBlock` is
            implemented.
        layers (list[int]): Number of layers in each block.
        in_channels (int): Number of input channels.
        conv_cfg (dict): Dictionary to construct convolution layer. If it is
            None, 2d convolution will be applied. Default: None.
        norm_cfg (dict): Config dict for normalization layer. "BN" by default.
        act_cfg (dict): Config dict for activation layer, "ReLU" by default.
        with_spectral_norm (bool): Whether use spectral norm after conv.
            Default: False.
        late_downsample (bool): Whether to adopt late downsample strategy.
            Default: False.
        order (tuple[str]): Order of `conv`, `norm` and `act` layer in shortcut
            convolution module. Default: ('conv', 'act', 'norm').
    """

    def __init__(self,
                 block,
                 layers,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=False,
                 late_downsample=False,
                 order=('conv', 'act', 'norm')):
        super().__init__(
            block,
            layers,
            in_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm,
            late_downsample=late_downsample)

        # TODO: rename self.midplanes to self.mid_channels in ResNetEnc
        # One (in, out) channel pair per shortcut branch, ordered as
        # feat1..feat5.
        self.shortcut_in_channels = [in_channels, self.midplanes, 64, 128, 256]
        self.shortcut_out_channels = [32, self.midplanes, 64, 128, 256]

        channel_pairs = zip(self.shortcut_in_channels,
                            self.shortcut_out_channels)
        self.shortcut = nn.ModuleList([
            self._make_shortcut(c_in, c_out, conv_cfg, norm_cfg, act_cfg,
                                order, with_spectral_norm)
            for c_in, c_out in channel_pairs
        ])

    def _make_shortcut(self, in_channels, out_channels, conv_cfg, norm_cfg,
                       act_cfg, order, with_spectral_norm):
        """Build one shortcut branch: two stacked 3x3 conv modules."""
        shared = dict(
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            with_spectral_norm=with_spectral_norm,
            order=order)
        first = ConvModule(in_channels, out_channels, 3, **shared)
        second = ConvModule(out_channels, out_channels, 3, **shared)
        return nn.Sequential(first, second)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            dict: Contains the output tensor and shortcut feature.
        """
        # main path; intermediate results are kept for the shortcuts
        stem = self.conv1(x)
        x1 = self.conv2(stem)
        stem = self.conv3(x1)

        x2 = self.layer1(stem)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        out = self.layer4(x4)

        # shortcut[i] consumes the i-th tapped tensor
        taps = (x, x1, x2, x3, x4)
        feat1, feat2, feat3, feat4, feat5 = [
            self.shortcut[idx](tap) for idx, tap in enumerate(taps)
        ]

        return dict(
            out=out,
            feat1=feat1,
            feat2=feat2,
            feat3=feat3,
            feat4=feat4,
            feat5=feat5)
@COMPONENTS.register_module()
class ResGCAEncoder(ResShortcutEnc):
    """ResNet backbone with shortcut connection and gca module.

    ::

        image ---------------- shortcut[0] -------------- feat1
         |
        conv1-conv2 ---------- shortcut[1] -------------- feat2
               |
             conv3-layer1 ---- shortcut[2] -------------- feat3
                     |
                     | image - guidance_head ------------ img_feat
                     |             |
                    layer2 --- gca_module - shortcut[3] - feat4
                                    |
                                  layer3 -- shortcut[4] - feat5
                                     |
                                  layer4 --------------- out

    * gca module also requires unknown tensor generated by trimap which is
      ignored in the above graph.

    Implementation of Natural Image Matting via Guided Contextual Attention
    https://arxiv.org/pdf/2001.04069.pdf.

    Args:
        block (str): Type of residual block. Currently only `BasicBlock` is
            implemented.
        layers (list[int]): Number of layers in each block.
        in_channels (int): Number of input channels. Must be 4 or 6; the
            first 3 channels are fed to the guidance head and the remaining
            ones are treated as the trimap.
        conv_cfg (dict): Dictionary to construct convolution layer. If it is
            None, 2d convolution will be applied. Default: None.
        norm_cfg (dict): Config dict for normalization layer. "BN" by default.
        act_cfg (dict): Config dict for activation layer, "ReLU" by default.
        with_spectral_norm (bool): Whether use spectral norm after conv.
            Default: False.
        late_downsample (bool): Whether to adopt late downsample strategy.
            Default: False.
        order (tuple[str]): Order of `conv`, `norm` and `act` layer in shortcut
            convolution module. Default: ('conv', 'act', 'norm').
    """

    def __init__(self,
                 block,
                 layers,
                 in_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_spectral_norm=False,
                 late_downsample=False,
                 order=('conv', 'act', 'norm')):
        super().__init__(block, layers, in_channels, conv_cfg, norm_cfg,
                         act_cfg, with_spectral_norm, late_downsample, order)

        assert in_channels in (4, 6), (
            f'in_channels must be 4 or 6, but got {in_channels}')

        # channels left after the 3 image channels
        self.trimap_channels = in_channels - 3

        # Three stride-2 conv modules: 3 -> 16 -> 32 -> 128 channels, which
        # brings the image down to 1/8 resolution.
        guidance_channels = [3, 16, 32, 128]
        stages = [
            ConvModule(
                c_in,
                c_out,
                3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                with_spectral_norm=with_spectral_norm,
                padding_mode='reflect',
                order=order)
            for c_in, c_out in zip(guidance_channels[:-1],
                                   guidance_channels[1:])
        ]
        self.guidance_head = nn.Sequential(*stages)

        self.gca = GCAModule(128, 128)

    def init_weights(self, pretrained=None):
        """Initialize the weights of the encoder.

        Unlike the parent implementation, loading a checkpoint does not zero
        any input channel of ``conv1`` beforehand.

        Args:
            pretrained (str, optional): Path of a checkpoint to load
                (non-strictly). If None, the parent's from-scratch
                initialization is used. Default: None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            super().init_weights()
            return

        if not isinstance(pretrained, str):
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')

        logger = get_root_logger()
        load_checkpoint(self, pretrained, strict=False, logger=logger)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (N, C, H, W).

        Returns:
            dict: Contains the output tensor, shortcut feature and
                intermediate feature.
        """
        stem = self.conv1(x)
        x1 = self.conv2(stem)
        stem = self.conv3(x1)

        # guidance features come from the first 3 (image) channels only
        img_feat = self.guidance_head(x[:, :3, ...])

        if self.trimap_channels != 3:
            # single trimap channel: positions equal to 1 form the mask
            # NOTE(review): presumably 1 encodes "unknown" -- confirm
            # against the data pipeline.
            unknown = x[:, 3:, ...].eq(1).float()
        else:
            # NOTE(review): presumably a one-hot trimap whose middle
            # channel (index 4 of the input) marks the unknown region --
            # confirm against the data pipeline.
            unknown = x[:, 4:5, ...]
        # same as img_feat, downsample to 1/8
        unknown = F.interpolate(unknown, scale_factor=1 / 8, mode='nearest')

        x2 = self.layer1(stem)
        x3 = self.layer2(x2)
        # the gca output replaces x3 for both layer3 and shortcut[3]
        x3 = self.gca(img_feat, x3, unknown)
        x4 = self.layer3(x3)
        out = self.layer4(x4)

        # shortcut block
        taps = (x, x1, x2, x3, x4)
        feat1, feat2, feat3, feat4, feat5 = [
            self.shortcut[idx](tap) for idx, tap in enumerate(taps)
        ]

        return dict(
            out=out,
            feat1=feat1,
            feat2=feat2,
            feat3=feat3,
            feat4=feat4,
            feat5=feat5,
            img_feat=img_feat,
            unknown=unknown)