Source code for mmedit.models.common.contextual_attention

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F


[docs]class ContextualAttentionModule(nn.Module): """Contexture attention module. The details of this module can be found in: Generative Image Inpainting with Contextual Attention Args: unfold_raw_kernel_size (int): Kernel size used in unfolding raw feature. Default: 4. unfold_raw_stride (int): Stride used in unfolding raw feature. Default: 2. unfold_raw_padding (int): Padding used in unfolding raw feature. Default: 1. unfold_corr_kernel_size (int): Kernel size used in unfolding context for computing correlation maps. Default: 3. unfold_corr_stride (int): Stride used in unfolding context for computing correlation maps. Default: 1. unfold_corr_dilation (int): Dilation used in unfolding context for computing correlation maps. Default: 1. unfold_corr_padding (int): Padding used in unfolding context for computing correlation maps. Default: 1. scale (float): The resale factor used in resize input features. Default: 0.5. fuse_kernel_size (int): The kernel size used in fusion module. Default: 3. softmax_scale (float): The scale factor for softmax function. Default: 10. return_attenion_score (bool): If True, the attention score will be returned. Default: True. """ def __init__(self, unfold_raw_kernel_size=4, unfold_raw_stride=2, unfold_raw_padding=1, unfold_corr_kernel_size=3, unfold_corr_stride=1, unfold_corr_dilation=1, unfold_corr_padding=1, scale=0.5, fuse_kernel_size=3, softmax_scale=10, return_attenion_score=True): super(ContextualAttentionModule, self).__init__() self.unfold_raw_kernel_size = unfold_raw_kernel_size self.unfold_raw_stride = unfold_raw_stride self.unfold_raw_padding = unfold_raw_padding self.unfold_corr_kernel_size = unfold_corr_kernel_size self.unfold_corr_stride = unfold_corr_stride self.unfold_corr_dilation = unfold_corr_dilation self.unfold_corr_padding = unfold_corr_padding self.scale = scale self.fuse_kernel_size = fuse_kernel_size self.with_fuse_correlation = fuse_kernel_size > 1 self.softmax_scale = softmax_scale self.return_attention_score = return_attenion_score if self.with_fuse_correlation: assert fuse_kernel_size % 2 == 1 fuse_kernel = torch.eye(fuse_kernel_size).view( 1, 1, fuse_kernel_size, fuse_kernel_size) self.register_buffer('fuse_kernel', fuse_kernel) padding = int((fuse_kernel_size - 1) // 2) self.fuse_conv = partial(F.conv2d, padding=padding, stride=1) self.softmax = nn.Softmax(dim=1)
[docs] def forward(self, x, context, mask=None): """Forward Function. Args: x (torch.Tensor): Tensor with shape (n, c, h, w). context (torch.Tensor): Tensor with shape (n, c, h, w). mask (torch.Tensor): Tensor with shape (n, 1, h, w). Default: None. Returns: tuple(torch.Tensor): Features after contextural attention. """ # raw features to be used in copy (deconv) raw_context = context raw_context_cols = self.im2col( raw_context, kernel_size=self.unfold_raw_kernel_size, stride=self.unfold_raw_stride, padding=self.unfold_raw_padding, normalize=False, return_cols=True) # resize the feature to reduce computational cost x = F.interpolate(x, scale_factor=self.scale) context = F.interpolate(context, scale_factor=self.scale) context_cols = self.im2col( context, kernel_size=self.unfold_corr_kernel_size, stride=self.unfold_corr_stride, padding=self.unfold_corr_padding, dilation=self.unfold_corr_dilation, normalize=True, return_cols=True) h_unfold, w_unfold = self.calculate_unfold_hw( context.size()[-2:], kernel_size=self.unfold_corr_kernel_size, stride=self.unfold_corr_stride, padding=self.unfold_corr_padding, dilation=self.unfold_corr_dilation, ) # reshape context_cols to # (n*h_unfold*w_unfold, c, unfold_mks, unfold_mks) # 'mks' is short for 'mask_kernel_size' context_cols = context_cols.reshape(-1, *context_cols.shape[2:]) # the shape of correlation map should be: # (n, h_unfold*w_unfold, h', w') correlation_map = self.patch_correlation(x, context_cols) # fuse correlation map to enlarge consistent attention region. if self.with_fuse_correlation: correlation_map = self.fuse_correlation_map( correlation_map, h_unfold, w_unfold) correlation_map = self.mask_correlation_map(correlation_map, mask=mask) attention_score = self.softmax(correlation_map * self.softmax_scale) raw_context_filter = raw_context_cols.reshape( -1, *raw_context_cols.shape[2:]) output = self.patch_copy_deconv(attention_score, raw_context_filter) # deconv will cause overlap and we need to remove the effects of that overlap_factor = self.calculate_overlap_factor(attention_score) output /= overlap_factor if self.return_attention_score: n, _, h_s, w_s = attention_score.size() attention_score = attention_score.view(n, h_unfold, w_unfold, h_s, w_s) return output, attention_score return output
[docs] def patch_correlation(self, x, kernel): """Calculate patch correlation. Args: x (torch.Tensor): Input tensor. kernel (torch.Tensor): Kernel tensor. Returns: torch.Tensor: Tensor with shape of (n, l, h, w). """ n, _, h_in, w_in = x.size() patch_corr = F.conv2d( x.view(1, -1, h_in, w_in), kernel, stride=self.unfold_corr_stride, padding=self.unfold_corr_padding, dilation=self.unfold_corr_dilation, groups=n) h_out, w_out = patch_corr.size()[-2:] return patch_corr.view(n, -1, h_out, w_out)
[docs] def patch_copy_deconv(self, attention_score, context_filter): """Copy patches using deconv. Args: attention_score (torch.Tensor): Tensor with shape of (n, l , h, w). context_filter (torch.Tensor): Filter kernel. Returns: torch.Tensor: Tensor with shape of (n, c, h, w). """ n, num_context, h, w = attention_score.size() attention_score = attention_score.view(1, -1, h, w) output = F.conv_transpose2d( attention_score, context_filter, stride=self.unfold_raw_stride, padding=self.unfold_raw_padding, groups=n) h_out, w_out = output.size()[-2:] return output.view(n, -1, h_out, w_out)
[docs] def fuse_correlation_map(self, correlation_map, h_unfold, w_unfold): """Fuse correlation map. This operation is to fuse correlation map for increasing large consistent correlation regions. The mechanism behind this op is simple and easy to understand. A standard 'Eye' matrix will be applied as a filter on the correlation map in horizontal and vertical direction. The shape of input correlation map is (n, h_unfold*w_unfold, h, w). When adopting fusing, we will apply convolutional filter in the reshaped feature map with shape of (n, 1, h_unfold*w_fold, h*w). A simple specification for horizontal direction is shown below: .. code-block:: python (h, (h, (h, (h, 0) 1) 2) 3) ... (h, 0) (h, 1) 1 (h, 2) 1 (h, 3) 1 ... """ # horizontal direction n, _, h_map, w_map = correlation_map.size() map_ = correlation_map.permute(0, 2, 3, 1) map_ = map_.reshape(n, h_map * w_map, h_unfold * w_unfold, 1) map_ = map_.permute(0, 3, 1, 2).contiguous() map_ = self.fuse_conv(map_, self.fuse_kernel) correlation_map = map_.view(n, h_unfold, w_unfold, h_map, w_map) # vertical direction map_ = correlation_map.permute(0, 2, 1, 4, 3).reshape(n, 1, h_unfold * w_unfold, h_map * w_map) map_ = self.fuse_conv(map_, self.fuse_kernel) # Note that the dimension should be transposed since the convolution of # eye matrix will put the normed scores into the last several dimension correlation_map = map_.view(n, w_unfold, h_unfold, w_map, h_map).permute(0, 4, 3, 2, 1) correlation_map = correlation_map.reshape(n, -1, h_unfold, w_unfold) return correlation_map
[docs] def calculate_unfold_hw(self, input_size, kernel_size=3, stride=1, dilation=1, padding=0): """Calculate (h, w) after unfolding The official implementation of `unfold` in pytorch will put the dimension (h, w) into `L`. Thus, this function is just to calculate the (h, w) according to the equation in: https://pytorch.org/docs/stable/nn.html#torch.nn.Unfold """ h_in, w_in = input_size h_unfold = int((h_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1) w_unfold = int((w_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1) return h_unfold, w_unfold
[docs] def calculate_overlap_factor(self, attention_score): """Calculte the overlap factor after applying deconv. Args: attention_score (torch.Tensor): The attention score with shape of (n, c, h, w). Returns: torch.Tensor: The overlap factor will be returned. """ h, w = attention_score.shape[-2:] kernel_size = self.unfold_raw_kernel_size ones_input = torch.ones(1, 1, h, w).to(attention_score) ones_filter = torch.ones(1, 1, kernel_size, kernel_size).to(attention_score) overlap = F.conv_transpose2d( ones_input, ones_filter, stride=self.unfold_raw_stride, padding=self.unfold_raw_padding) # avoid division by zero overlap[overlap == 0] = 1. return overlap
[docs] def mask_correlation_map(self, correlation_map, mask): """Add mask weight for correlation map. Add a negative infinity number to the masked regions so that softmax function will result in 'zero' in those regions. Args: correlation_map (torch.Tensor): Correlation map with shape of (n, h_unfold*w_unfold, h_map, w_map). mask (torch.Tensor): Mask tensor with shape of (n, c, h, w). '1' in the mask indicates masked region while '0' indicates valid region. Returns: torch.Tensor: Updated correlation map with mask. """ if mask is not None: mask = F.interpolate(mask, scale_factor=self.scale) # if any pixel is masked in patch, the patch is considered to be # masked mask_cols = self.im2col( mask, kernel_size=self.unfold_corr_kernel_size, stride=self.unfold_corr_stride, padding=self.unfold_corr_padding, dilation=self.unfold_corr_dilation) mask_cols = (mask_cols.sum(dim=1, keepdim=True) > 0).float() mask_cols = mask_cols.permute(0, 2, 1).reshape(mask.size(0), -1, 1, 1) # add negative inf will bring zero in softmax mask_cols[mask_cols == 1] = -float('inf') correlation_map += mask_cols return correlation_map
[docs] def im2col(self, img, kernel_size, stride=1, padding=0, dilation=1, normalize=False, return_cols=False): """Reshape image-style feature to columns. This function is used for unfold feature maps to columns. The details of this function can be found in: https://pytorch.org/docs/1.1.0/nn.html?highlight=unfold#torch.nn.Unfold Args: img (torch.Tensor): Features to be unfolded. The shape of this feature should be (n, c, h, w). kernel_size (int): In this function, we only support square kernel with same height and width. stride (int): Stride number in unfolding. Default: 1. padding (int): Padding number in unfolding. Default: 0. dilation (int): Dilation number in unfolding. Default: 1. normalize (bool): If True, the unfolded feature will be normalized. Default: False. return_cols (bool): The official implementation in PyTorch of unfolding will return features with shape of (n, c*$prod{kernel_size}$, L). If True, the features will be reshaped to (n, L, c, kernel_size, kernel_size). Otherwise, the results will maintain the shape as the official implementation. Returns: torch.Tensor: Unfolded columns. If `return_cols` is True, the \ shape of output tensor is \ `(n, L, c, kernel_size, kernel_size)`. Otherwise, the shape \ will be `(n, c*$prod{kernel_size}$, L)`. """ # unfold img to columns with shape (n, c*kernel_size**2, num_cols) img_unfold = F.unfold( img, kernel_size, stride=stride, padding=padding, dilation=dilation) # normalize the feature map if normalize: norm = torch.sqrt((img_unfold**2).sum(dim=1, keepdim=True)) eps = torch.tensor([1e-4]).to(img) img_unfold = img_unfold / torch.max(norm, eps) if return_cols: img_unfold_ = img_unfold.permute(0, 2, 1) n, num_cols = img_unfold_.size()[:2] img_cols = img_unfold_.view(n, num_cols, img.size(1), kernel_size, kernel_size) return img_cols return img_unfold