import torch.nn as nn
import torch.nn.functional as F
from .sr_backbone_utils import default_init_weights
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    A convolution first expands the channel dimension to
    ``out_channels * scale_factor * scale_factor``; ``F.pixel_shuffle`` then
    rearranges those extra channels into spatial resolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of Conv layer to expand channels.

    Returns:
        Upsampled feature map.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        # The conv must emit scale_factor**2 channels per output channel so
        # that pixel_shuffle can fold them into the spatial dimensions.
        # (kernel - 1) // 2 padding keeps the spatial size unchanged for odd
        # kernel sizes (default stride of 1).
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        """Initialize weights for PixelShufflePack.

        Delegates to ``default_init_weights`` on this module.
        """
        # NOTE(review): the second argument is presumably the init scale;
        # confirm against sr_backbone_utils.default_init_weights.
        default_init_weights(self, 1)

    def forward(self, x):
        """Forward function for PixelShufflePack.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x