Source code for mmedit.models.common.partial_conv

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import CONV_LAYERS


@CONV_LAYERS.register_module(name='PConv')
class PartialConv2d(nn.Conv2d):
    """Implementation for partial convolution.

    Image Inpainting for Irregular Holes Using Partial Convolutions
    [https://arxiv.org/abs/1804.07723]

    Args:
        multi_channel (bool): If True, the mask is multi-channel. Otherwise,
            the mask is single-channel.
        eps (float): Need to be changed for mixed precision training.
            For mixed precision training, you need change 1e-8 to 1e-6.
    """

    def __init__(self, *args, multi_channel=False, eps=1e-8, **kwargs):
        super().__init__(*args, **kwargs)

        # whether the mask is multi-channel or not
        self.multi_channel = multi_channel
        self.eps = eps

        if self.multi_channel:
            out_channels, in_channels = self.out_channels, self.in_channels
        else:
            out_channels, in_channels = 1, 1

        # All-ones kernel: convolving the mask with it sums the mask values
        # covered by each sliding window.
        self.register_buffer(
            'weight_mask_updater',
            torch.ones(out_channels, in_channels, self.kernel_size[0],
                       self.kernel_size[1]))

        # Number of elements in one mask window (in_channels * kh * kw),
        # stored as a plain Python int. ``np.asscalar`` was removed from
        # NumPy, so the conversion is done with ``int`` instead.
        self.mask_kernel_numel = int(
            np.prod(self.weight_mask_updater.shape[1:4]))

    def forward(self, input, mask=None, return_mask=True):
        """Forward function for partial conv2d.

        Args:
            input (torch.Tensor): Tensor with shape of (n, c, h, w).
            mask (torch.Tensor): Tensor with shape of (n, c, h, w) or
                (n, 1, h, w). If mask is not given, the function will
                work as standard conv2d. Default: None.
            return_mask (bool): If True and mask is not None, the updated
                mask will be returned. Default: True.

        Returns:
            torch.Tensor: Results after partial conv.
            torch.Tensor: Updated mask, returned only if mask is given and
                ``return_mask`` is True.
        """
        assert input.dim() == 4, 'input must be a 4-dim tensor'

        # Without a mask this layer is a plain conv2d; no mask is returned
        # regardless of ``return_mask``.
        if mask is None:
            return super().forward(input)

        assert mask.dim() == 4, 'mask must be a 4-dim tensor'
        if self.multi_channel:
            assert mask.shape[1] == input.shape[1], (
                'multi-channel mask must have as many channels as input')
        else:
            assert mask.shape[1] == 1, (
                'single-channel mask must have exactly one channel')

        # update mask and compute mask ratio (no gradients flow through it)
        with torch.no_grad():
            # Sum of the mask values under each window.
            updated_mask = F.conv2d(
                mask,
                self.weight_mask_updater,
                bias=None,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation)
            # Re-scaling factor window_size / window_sum; eps guards
            # against division by zero where the window sum is zero.
            mask_ratio = self.mask_kernel_numel / (updated_mask + self.eps)
            # Clamp the window sums into [0, 1] to form the new mask.
            updated_mask = torch.clamp(updated_mask, 0, 1)
            # Zero the ratio wherever the updated mask is zero.
            mask_ratio = mask_ratio * updated_mask

        # standard conv2d on the masked input
        raw_out = super().forward(input * mask)

        if self.bias is None:
            output = raw_out * mask_ratio
        else:
            # Re-scale only the weight response: remove the bias, apply the
            # ratio, add the bias back, then zero the positions where the
            # updated mask is zero.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = (raw_out - bias_view) * mask_ratio + bias_view
            output = output * updated_mask

        if return_mask:
            return output, updated_mask
        return output