Source code for mmedit.models.backbones.encoder_decoders.decoders.pconv_decoder

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmedit.models.common import MaskConvModule
from mmedit.models.registry import COMPONENTS


@COMPONENTS.register_module()
class PConvDecoder(nn.Module):
    """Decoder with partial conv.

    About the details for this architecture, pls see:
    Image Inpainting for Irregular Holes Using Partial Convolutions

    The decoder owns one ``MaskConvModule`` per level, named ``dec1`` ...
    ``dec{num_layers}``. Levels above 4 map ``512 + 512`` input channels to
    512; ``dec4`` .. ``dec1`` map ``512 + 256 -> 256``, ``256 + 128 -> 128``,
    ``128 + 64 -> 64`` and ``64 + 3 -> 3`` respectively. ``dec1`` is built
    without a norm layer and without an activation.

    Args:
        num_layers (int): The number of convolutional layers. Default: 7.
        interpolation (str): The upsample mode. Default: 'nearest'.
        conv_cfg (dict): Config for convolution module. Default:
            {'type': 'PConv', 'multi_channel': True}.
        norm_cfg (dict): Config for norm layer. Default:
            {'type': 'BN'}.
    """

    # NOTE: the dict defaults are shared objects created once at definition
    # time. This class only forwards them to ``MaskConvModule``; they are
    # kept as-is so the public signature stays unchanged.
    def __init__(self,
                 num_layers=7,
                 interpolation='nearest',
                 conv_cfg=dict(type='PConv', multi_channel=True),
                 norm_cfg=dict(type='BN')):
        super().__init__()
        self.num_layers = num_layers
        self.interpolation = interpolation

        def build_layer(in_channels, out_channels, last=False):
            """Build one 3x3, stride-1 decoder layer.

            The last layer (``dec1``) gets neither norm nor activation; all
            other layers use ``norm_cfg`` and LeakyReLU(0.2).
            """
            return MaskConvModule(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=None if last else norm_cfg,
                act_cfg=None if last else dict(
                    type='LeakyReLU', negative_slope=0.2))

        # Registration order is kept as dec5 .. dec{num_layers} first and
        # then dec4 .. dec1, so parameter / state_dict ordering is unchanged.
        for i in range(4, num_layers):
            self.add_module(f'dec{i + 1}', build_layer(512 + 512, 512))

        self.dec4 = build_layer(512 + 256, 256)
        self.dec3 = build_layer(256 + 128, 128)
        self.dec2 = build_layer(128 + 64, 64)
        self.dec1 = build_layer(64 + 3, 3, last=True)

    def forward(self, input_dict):
        """Forward Function.

        Starting from the entry ``h{num_layers}``, each step upsamples the
        current feature and mask by a factor of 2, concatenates them along
        dim 1 with the entries stored under ``h{i-1}``, and feeds the pair
        through ``dec{i}``, for ``i = num_layers, ..., 1``.

        Args:
            input_dict (dict): Must contain the keys ``'hidden_feats'`` and
                ``'hidden_masks'``. Each of them is indexed with the keys
                ``'h0'`` ... ``'h{num_layers}'``; the values are presumably
                the encoder's intermediate features and masks (the producer
                is not visible here).

        Returns:
            tuple: ``(h, h_mask)``, the feature and mask returned by
                ``dec1``. Presumably both are tensors of shape
                (n, c, h, w).
        """
        hidden_feats = input_dict['hidden_feats']
        hidden_masks = input_dict['hidden_masks']

        top_key = f'h{self.num_layers:d}'
        h, h_mask = hidden_feats[top_key], hidden_masks[top_key]

        for level in range(self.num_layers, 0, -1):
            skip_key = f'h{level - 1}'
            decoder_layer = getattr(self, f'dec{level}')

            # Upsample so that the spatial size matches the skip entry.
            h = F.interpolate(h, scale_factor=2, mode=self.interpolation)
            h_mask = F.interpolate(
                h_mask, scale_factor=2, mode=self.interpolation)

            # Feature and mask are concatenated in parallel along dim 1.
            h = torch.cat([h, hidden_feats[skip_key]], dim=1)
            h_mask = torch.cat([h_mask, hidden_masks[skip_key]], dim=1)

            h, h_mask = decoder_layer(h, h_mask)

        return h, h_mask