# Source code for mmedit.models.common.upsample

import torch.nn as nn
import torch.nn.functional as F

from .sr_backbone_utils import default_init_weights


class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer: channel-expanding conv + pixel shuffle.

    A convolution first multiplies the channel count by
    ``scale_factor * scale_factor``. :func:`torch.nn.functional.pixel_shuffle`
    then folds those extra channels into a spatial grid that is
    ``scale_factor`` times larger along both height and width.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the Conv layer that expands
            the channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel

        # pixel_shuffle turns every scale_factor**2 channels into one output
        # channel, so the conv must emit that many channels per output channel.
        expanded_channels = out_channels * scale_factor * scale_factor
        # With Conv2d's default stride of 1, (k - 1) // 2 preserves h and w for
        # odd kernel sizes. An even kernel shrinks each spatial dim by one.
        same_padding = (upsample_kernel - 1) // 2
        self.upsample_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            upsample_kernel,
            padding=same_padding,
        )
        self.init_weights()

    def init_weights(self):
        """Initialize the weights of this module.

        Delegates to ``default_init_weights`` from ``sr_backbone_utils``,
        passing ``1`` as its second argument. Presumably that argument is a
        weight scale; confirm against the helper's definition.
        """
        default_init_weights(self, 1)

    def forward(self, x):
        """Upsample ``x`` by ``self.scale_factor``.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Tensor of shape
            (n, out_channels, h * scale_factor, w * scale_factor) when the
            conv preserves h and w (odd ``upsample_kernel``).
        """
        expanded = self.upsample_conv(x)
        return F.pixel_shuffle(expanded, self.scale_factor)