Source code for mmedit.models.backbones.encoder_decoders.decoders.plain_decoder

import torch.nn as nn
from mmcv.cnn.utils.weight_init import xavier_init

from mmedit.models.registry import COMPONENTS

@COMPONENTS.register_module()
class PlainDecoder(nn.Module):
    """Simple decoder from Deep Image Matting.

    The decoder is a stack of convolutions interleaved with 2x2 max
    unpooling steps. Each unpooling step consumes the indices recorded by
    the matching max pooling layer of the encoder, going from the last
    pooling layer back to the first.

    Args:
        in_channels (int): Channel num of input features.
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv5x5(c_in, c_out):
            # All decoder stages except the first use the same 5x5 conv
            # with padding 2.
            return nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)

        # Layers are declared from the deepest stage to the shallowest one;
        # the attribute names are the keys used in the module's state dict.
        self.deconv6_1 = nn.Conv2d(in_channels, 512, kernel_size=1)
        self.deconv5_1 = conv5x5(512, 512)
        self.deconv4_1 = conv5x5(512, 256)
        self.deconv3_1 = conv5x5(256, 128)
        self.deconv2_1 = conv5x5(128, 64)
        self.deconv1_1 = conv5x5(64, 64)

        # Prediction layer: maps the 64-channel features to one channel.
        self.deconv1 = conv5x5(64, 1)

        # Parameter-free layers shared by every stage.
        self.relu = nn.ReLU(inplace=True)
        self.max_unpool2d = nn.MaxUnpool2d(kernel_size=2, stride=2)

    def init_weights(self):
        """Init weights for the module.

        Applies ``xavier_init`` to every ``nn.Conv2d`` submodule.
        """
        conv_layers = (
            layer for layer in self.modules()
            if isinstance(layer, nn.Conv2d))
        for conv in conv_layers:
            xavier_init(conv)

    def forward(self, inputs):
        """Forward function of PlainDecoder.

        Args:
            inputs (dict): Output dictionary of the VGG encoder containing:

              - out (Tensor): Output of the VGG encoder.
              - max_idx_1 (Tensor): Index of the first maxpooling layer in the
                VGG encoder.
              - max_idx_2 (Tensor): Index of the second maxpooling layer in the
                VGG encoder.
              - max_idx_3 (Tensor): Index of the third maxpooling layer in the
                VGG encoder.
              - max_idx_4 (Tensor): Index of the fourth maxpooling layer in the
                VGG encoder.
              - max_idx_5 (Tensor): Index of the fifth maxpooling layer in the
                VGG encoder.

        Returns:
            Tensor: Output tensor. It is the output of the last convolution;
                no activation is applied to it here.
        """
        # Look up every pooling index before doing any computation, so a
        # missing key is reported before any layer is run.
        pool_indices = tuple(
            inputs[key]
            for key in ('max_idx_1', 'max_idx_2', 'max_idx_3', 'max_idx_4',
                        'max_idx_5'))

        feat = self.relu(self.deconv6_1(inputs['out']))

        # Each remaining stage is "unpool -> conv -> ReLU". The indices are
        # used in reverse order: the fifth pooling layer is undone first.
        stage_convs = (self.deconv5_1, self.deconv4_1, self.deconv3_1,
                       self.deconv2_1, self.deconv1_1)
        for conv, indices in zip(stage_convs, reversed(pool_indices)):
            feat = self.relu(conv(self.max_unpool2d(feat, indices)))

        return self.deconv1(feat)