# Source code for mmedit.models.backbones.encoder_decoders.encoders.pconv_encoder

import torch.nn as nn
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmedit.models.common import MaskConvModule
from mmedit.models.registry import COMPONENTS


@COMPONENTS.register_module()
class PConvEncoder(nn.Module):
    """Encoder with partial conv.

    About the details for this architecture, pls see:
    Image Inpainting for Irregular Holes Using Partial Convolutions

    The encoder is a chain of ``MaskConvModule`` stages registered as the
    sub-modules ``enc1``, ``enc2``, ... Every stage uses stride 2 and a ReLU
    activation. The first four stages (``in_channels`` -> 64 -> 128 -> 256
    -> 512 channels) are always created; ``enc5`` .. ``enc{num_layers}`` are
    additional 512 -> 512 stages. The first stage has no norm layer.

    Args:
        in_channels (int): The number of input channels. Default: 3.
        num_layers (int): The number of convolutional layers. Default 7.
        conv_cfg (dict): Config for convolution module. Default:
            {'type': 'PConv', 'multi_channel': True}.
        norm_cfg (dict): Config for norm layer. Default:
            {'type': 'BN', 'requires_grad': True}.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effective on Batch
            Norm and its variants only. Default: False.
    """

    # NOTE: the dict defaults are shared between instances. This class only
    # forwards them to ``MaskConvModule`` and never modifies them itself.
    def __init__(self,
                 in_channels=3,
                 num_layers=7,
                 conv_cfg=dict(type='PConv', multi_channel=True),
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False):
        super().__init__()
        self.num_layers = num_layers
        self.norm_eval = norm_eval

        # One row per stage:
        # (in_channels, out_channels, kernel_size, padding, norm_cfg)
        stage_specs = [
            (in_channels, 64, 7, 3, None),  # first stage: no norm layer
            (64, 128, 5, 2, norm_cfg),
            (128, 256, 5, 2, norm_cfg),
            (256, 512, 3, 1, norm_cfg),
        ]
        # Stages beyond the fourth keep 512 channels.
        stage_specs.extend(
            (512, 512, 3, 1, norm_cfg) for _ in range(4, num_layers))

        # Register the stages in order as ``enc1``, ``enc2``, ... so that
        # ``forward`` can look them up by name.
        for idx, spec in enumerate(stage_specs, start=1):
            c_in, c_out, kernel, pad, stage_norm_cfg = spec
            self.add_module(
                f'enc{idx}',
                MaskConvModule(
                    c_in,
                    c_out,
                    kernel_size=kernel,
                    stride=2,
                    padding=pad,
                    conv_cfg=conv_cfg,
                    norm_cfg=stage_norm_cfg,
                    act_cfg=dict(type='ReLU')))

    def train(self, mode=True):
        """Set the training mode, optionally keeping norm layers frozen.

        Args:
            mode (bool): ``True`` for training mode, ``False`` for
                evaluation mode. Default: True.

        Returns:
            PConvEncoder: ``self``, like ``nn.Module.train``, so that
            ``train()`` / ``eval()`` calls can be chained.
        """
        super().train(mode)
        if mode and self.norm_eval:
            for module in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(module, _BatchNorm):
                    module.eval()
        # ``nn.Module.eval()`` returns the result of ``train(False)``, so
        # returning ``self`` here keeps both calls chainable.
        return self

    def forward(self, x, mask):
        """Forward function for partial conv encoder.

        Args:
            x (torch.Tensor): Masked image with shape (n, c, h, w).
            mask (torch.Tensor): Mask tensor with shape (n, c, h, w).

        Returns:
            dict: Contains the results and middle level features in this
            module. ``out`` is the feature map of the last layer,
            ``hidden_feats`` contain the middle feature maps and
            ``hidden_masks`` store updated masks. Both dicts are keyed
            ``'h0'`` (the inputs) to ``'h{num_layers}'``.
        """
        # 'h0' holds the raw inputs; 'h{i}' holds the outputs of 'enc{i}'.
        hidden_feats = {'h0': x}
        hidden_masks = {'h0': mask}

        feat, cur_mask = x, mask
        for i in range(1, self.num_layers + 1):
            layer = getattr(self, f'enc{i}')
            # Each stage returns the new feature map and the updated mask.
            feat, cur_mask = layer(feat, cur_mask)
            hidden_feats[f'h{i}'] = feat
            hidden_masks[f'h{i}'] = cur_mask

        return dict(
            out=hidden_feats[f'h{self.num_layers}'],
            hidden_feats=hidden_feats,
            hidden_masks=hidden_masks)