# Source code for mmedit.models.common.generation_model_utils

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init
from torch.nn import init


def generation_init_weights(module, init_type='normal', init_gain=0.02):
    """Default initialization of network weights for image generation.

    By default, we use normal init, but xavier and kaiming might work
    better for some applications.

    The function is applied recursively to every submodule via
    ``module.apply``:

    - Submodules that have a ``weight`` attribute and whose class name
      contains ``Conv`` or ``Linear`` are initialized with ``init_type``.
    - Submodules whose class name contains ``BatchNorm2d`` are initialized
      with ``normal_init(m, 1.0, init_gain)``.
    - All other submodules are left untouched.

    Args:
        module (nn.Module): Module to be initialized.
        init_type (str): The name of an initialization method:
            normal | xavier | kaiming | orthogonal.
        init_gain (float): Scaling factor for normal, xavier and
            orthogonal.

    Raises:
        NotImplementedError: If ``init_type`` is not one of the supported
            names and ``module`` contains at least one Conv/Linear layer.
    """

    def init_func(m):
        """Initialize a single submodule ``m`` in place.

        Args:
            m (nn.Module): Module to be initialized.
        """
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname
                                     or 'Linear' in classname):
            if init_type == 'normal':
                normal_init(m, 0.0, init_gain)
            elif init_type == 'xavier':
                xavier_init(m, gain=init_gain, distribution='normal')
            elif init_type == 'kaiming':
                kaiming_init(
                    m,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    distribution='normal')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight, gain=init_gain)
                # Layers built with ``bias=False`` have ``bias is None``;
                # only zero the bias when one actually exists.
                if getattr(m, 'bias', None) is not None:
                    init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError(
                    f"Initialization method '{init_type}' is not implemented")
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix;
            # only normal distribution applies.
            normal_init(m, 1.0, init_gain)

    module.apply(init_func)
class GANImageBuffer(object):
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size <= 0,
            no buffer is used and ``query`` returns its input unchanged.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        # Bookkeeping state is always created so the instance is fully
        # initialized even when buffering is disabled.
        self.img_num = 0
        self.image_buffer = []

    def query(self, images):
        """Query current image batch using a history of generated images.

        While the buffer is not yet full, every incoming image is stored
        and returned as-is. Once full, each image is, with probability
        ``buffer_ratio``, swapped with a randomly chosen stored image (the
        stored one is returned); otherwise the current image is returned.

        Args:
            images (Tensor): Current image batch without history
                information. Iterated along dim 0.

        Returns:
            Tensor: The mixed batch, concatenated along dim 0. When
            buffering is enabled, returned images are taken from
            ``.data`` (detached from autograd). When buffering is
            disabled, ``images`` itself is returned.
        """
        if self.buffer_size <= 0:
            # No buffer: do nothing.
            return images

        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            # if the buffer is not full, keep inserting current images
            if self.img_num < self.buffer_size:
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                return_images.append(image)
            else:
                use_buffer = np.random.random() < self.buffer_ratio
                # by self.buffer_ratio, the buffer will return a previously
                # stored image, and insert the current image into the buffer
                if use_buffer:
                    random_id = np.random.randint(0, self.buffer_size)
                    image_tmp = self.image_buffer[random_id].clone()
                    self.image_buffer[random_id] = image
                    return_images.append(image_tmp)
                # by (1 - self.buffer_ratio), the buffer will return the
                # current image
                else:
                    return_images.append(image)

        if not return_images:
            # Empty batch: torch.cat cannot concatenate an empty list, so
            # pass the (empty) batch straight through.
            return images.data
        # collect all the images and return
        return torch.cat(return_images, 0)
class UnetSkipConnectionBlock(nn.Module):
    """Construct a Unet submodule with skip connections, with the following
    structure: downsampling - `submodule` - upsampling.

    Args:
        outer_channels (int): Number of channels at the outer conv layer.
        inner_channels (int): Number of channels at the inner conv layer.
        in_channels (int): Number of channels in input images/features. If
            is None, equals to `outer_channels`. Default: None.
        submodule (UnetSkipConnectionBlock): Previously constructed
            submodule. Default: None.
        is_outermost (bool): Whether this module is the outermost module.
            Default: False.
        is_innermost (bool): Whether this module is the innermost module.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer. Default:
            `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: False.
    """

    def __init__(self,
                 outer_channels,
                 inner_channels,
                 in_channels=None,
                 submodule=None,
                 is_outermost=False,
                 is_innermost=False,
                 norm_cfg=dict(type='BN'),
                 use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        # cannot be both outermost and innermost
        assert not (is_outermost and is_innermost), (
            "'is_outermost' and 'is_innermost' cannot be True "
            'at the same time.')
        self.is_outermost = is_outermost
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the unet skip connection block.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        kernel_size = 4
        stride = 2
        padding = 1
        if in_channels is None:
            in_channels = outer_channels
        down_conv_cfg = dict(type='Conv2d')
        down_norm_cfg = norm_cfg
        down_act_cfg = dict(type='LeakyReLU', negative_slope=0.2)
        up_conv_cfg = dict(type='Deconv')
        up_norm_cfg = norm_cfg
        up_act_cfg = dict(type='ReLU')
        # A non-outermost UnetSkipConnectionBlock submodule returns
        # cat([x, model(x)]) (see ``forward``), doubling its channels.
        up_in_channels = inner_channels * 2
        up_bias = use_bias
        middle = [submodule]
        upper = []

        if is_outermost:
            # No activation/norm on the raw input; Tanh on the final output.
            down_act_cfg = None
            down_norm_cfg = None
            up_bias = True
            up_norm_cfg = None
            upper = [nn.Tanh()]
        elif is_innermost:
            # No submodule, hence no channel-doubling concatenation.
            down_norm_cfg = None
            up_in_channels = inner_channels
            middle = []
        else:
            upper = [nn.Dropout(0.5)] if use_dropout else []

        # order=('act', 'conv', 'norm'): the activation precedes the conv.
        down = [
            ConvModule(
                in_channels=in_channels,
                out_channels=inner_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=use_bias,
                conv_cfg=down_conv_cfg,
                norm_cfg=down_norm_cfg,
                act_cfg=down_act_cfg,
                order=('act', 'conv', 'norm'))
        ]
        up = [
            ConvModule(
                in_channels=up_in_channels,
                out_channels=outer_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=up_bias,
                conv_cfg=up_conv_cfg,
                norm_cfg=up_norm_cfg,
                act_cfg=up_act_cfg,
                order=('act', 'conv', 'norm'))
        ]

        model = down + middle + up + upper

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results. For non-outermost blocks, the input is
            concatenated with the block output along dim 1 (skip
            connection).
        """
        if self.is_outermost:
            return self.model(x)

        # add skip connections
        return torch.cat([x, self.model(x)], 1)
class ResidualBlockWithDropout(nn.Module):
    """Define a Residual Block with dropout layers.

    Ref:
    Deep Residual Learning for Image Recognition

    A residual block is a conv block with skip connections. A dropout layer
    is added between two common conv modules.

    Args:
        channels (int): Number of channels in the conv layer.
        padding_mode (str): The name of padding layer:
            'reflect' | 'replicate' | 'zeros'.
        norm_cfg (dict): Config dict to build norm layer.
            Default: `dict(type='BN')`.
        use_dropout (bool): Whether to use dropout layers. Default: True.
    """

    def __init__(self,
                 channels,
                 padding_mode,
                 norm_cfg=dict(type='BN'),
                 use_dropout=True):
        super(ResidualBlockWithDropout, self).__init__()
        assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but "
                                            f'got {type(norm_cfg)}')
        assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'"
        # We use norm layers in the residual block with dropout layers.
        # Only for IN, use bias since it does not have affine parameters.
        use_bias = norm_cfg['type'] == 'IN'

        # Both convs keep spatial size and channel count (3x3, padding=1,
        # in == out == channels) so the skip addition in ``forward`` works.
        block = [
            ConvModule(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm_cfg=norm_cfg,
                padding_mode=padding_mode)
        ]

        if use_dropout:
            block += [nn.Dropout(0.5)]

        # The second conv has no activation (act_cfg=None), so the block
        # ends without a final ReLU.
        block += [
            ConvModule(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1,
                bias=use_bias,
                norm_cfg=norm_cfg,
                act_cfg=None,
                padding_mode=padding_mode)
        ]

        self.block = nn.Sequential(*block)

    def forward(self, x):
        """Forward function. Add skip connections without final ReLU.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        out = x + self.block(x)
        return out