Source code for mmedit.datasets.pipelines.loading

from pathlib import Path

import mmcv
import numpy as np
from mmcv.fileio import FileClient

from mmedit.core.mask import (bbox2mask, brush_stroke_mask, get_irregular_mask,
                              random_bbox)
from ..registry import PIPELINES


@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Load image from file.

    The path is read from ``results[f'{key}_path']`` and the decoded image is
    written to ``results[key]``.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 io_backend='disk',
                 key='gt',
                 flag='color',
                 channel_order='bgr',
                 save_original_img=False,
                 **kwargs):
        self.io_backend = io_backend
        self.key = key
        self.flag = flag
        self.save_original_img = save_original_img
        self.channel_order = channel_order
        self.kwargs = kwargs
        # The file client is created on the first call rather than here.
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. Must contain ``f'{key}_path'``.

        Returns:
            dict: A dict containing the processed data and information.
                Keys ``key``, ``f'{key}_path'`` (as ``str``) and
                ``f'{key}_ori_shape'`` are set; ``f'ori_{key}'`` is added
                when ``save_original_img`` is True.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepath = str(results[f'{self.key}_path'])
        img_bytes = self.file_client.get(filepath)
        img = mmcv.imfrombytes(
            img_bytes, flag=self.flag, channel_order=self.channel_order)  # HWC

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            # Independent copy so later in-place edits of results[key] do not
            # touch the stored original.
            results[f'ori_{self.key}'] = img.copy()

        return results

    def __repr__(self):
        # channel_order is reported too: it changes the loaded data, so it
        # belongs in the textual description of the transform.
        repr_str = self.__class__.__name__
        repr_str += (
            f'(io_backend={self.io_backend}, key={self.key}, '
            f'flag={self.flag}, save_original_img={self.save_original_img}, '
            f'channel_order={self.channel_order})')
        return repr_str
@PIPELINES.register_module()
class LoadImageFromFileList(LoadImageFromFile):
    """Load image from file list.

    It accepts a list of path and read each frame from each path. A list
    of frames will be returned.

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. ``results[f'{key}_path']`` must be a
                list of paths.

        Returns:
            dict: A dict containing the processed data and information.
                ``results[key]`` is a list of images,
                ``results[f'{key}_ori_shape']`` the list of their shapes.

        Raises:
            TypeError: If ``results[f'{key}_path']`` is not a list.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepaths = results[f'{self.key}_path']
        if not isinstance(filepaths, list):
            raise TypeError(
                f'filepath should be list, but got {type(filepaths)}')
        filepaths = [str(v) for v in filepaths]

        imgs = []
        shapes = []
        for filepath in filepaths:
            img_bytes = self.file_client.get(filepath)
            # Honour the channel_order given to the constructor, as the
            # parent class does, instead of silently using the decoder's
            # default order.
            img = mmcv.imfrombytes(
                img_bytes, flag=self.flag,
                channel_order=self.channel_order)  # HWC
            if img.ndim == 2:
                # Add a trailing channel axis so every frame is 3-D.
                img = np.expand_dims(img, axis=2)
            imgs.append(img)
            shapes.append(img.shape)

        results[self.key] = imgs
        results[f'{self.key}_path'] = filepaths
        results[f'{self.key}_ori_shape'] = shapes
        if self.save_original_img:
            results[f'ori_{self.key}'] = [img.copy() for img in imgs]

        return results
@PIPELINES.register_module()
class RandomLoadResizeBg(object):
    """Pick a random background image and resize it to the foreground size.

    Required key is "fg", added key is "bg".

    Args:
        bg_dir (str): Path of directory to load background images from.
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        flag (str): Loading flag for images. Default: 'color'.
        kwargs (dict): Args for file client.
    """

    def __init__(self, bg_dir, io_backend='disk', flag='color', **kwargs):
        self.bg_dir = bg_dir
        # The directory is scanned once, at construction time.
        self.bg_list = list(mmcv.scandir(bg_dir))
        self.io_backend = io_backend
        self.flag = flag
        self.kwargs = kwargs
        self.file_client = None

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. Must contain the "fg" image.

        Returns:
            dict: The same dict with the resized background stored in "bg".
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        fg_shape = results['fg'].shape
        target_h, target_w = fg_shape[0], fg_shape[1]

        chosen = self.bg_list[np.random.randint(len(self.bg_list))]
        bg_path = Path(self.bg_dir) / chosen
        raw = self.file_client.get(bg_path)
        decoded = mmcv.imfrombytes(raw, flag=self.flag)  # HWC, BGR

        # mmcv.imresize takes the target size as (width, height).
        results['bg'] = mmcv.imresize(
            decoded, (target_w, target_h), interpolation='bicubic')
        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(bg_dir='{self.bg_dir}')"
@PIPELINES.register_module()
class LoadMask(object):
    """Load Mask for multiple types.

    For different types of mask, users need to provide the corresponding
    config dict.

    Example config for bbox:

    .. code-block:: python

        config = dict(img_shape=(256, 256), max_bbox_shape=128)

    Example config for irregular:

    .. code-block:: python

        config = dict(
            img_shape=(256, 256),
            num_vertexes=(4, 12),
            max_angle=4.,
            length_range=(10, 100),
            brush_width=(10, 40),
            area_ratio_range=(0.15, 0.5))

    Example config for ff:

    .. code-block:: python

        config = dict(
            img_shape=(256, 256),
            num_vertexes=(4, 12),
            mean_angle=1.2,
            angle_range=0.4,
            brush_width=(12, 40))

    Example config for set:

    .. code-block:: python

        config = dict(
            mask_list_file='xxx/xxx/ooxx.txt',
            prefix='/xxx/xxx/ooxx/',
            io_backend='disk',
            flag='unchanged',
            file_client_kwargs=dict()
        )

        The mask_list_file contains the list of mask file name like this:
            test1.jpeg
            test2.jpeg
            ...
            ...

        The prefix gives the data path.

    Args:
        mask_mode (str): Mask mode in ['bbox', 'irregular', 'ff', 'set',
            'file'].
            * bbox: square bounding box masks.
            * irregular: irregular holes.
            * ff: free-form holes from DeepFillv2.
            * set: randomly get a mask from a mask set.
            * file: get mask from 'mask_path' in results.
        mask_config (dict): Params for creating masks. Each type of mask needs
            different configs.
    """

    def __init__(self, mask_mode='bbox', mask_config=None):
        self.mask_mode = mask_mode
        self.mask_config = dict() if mask_config is None else mask_config
        assert isinstance(self.mask_config, dict)

        # set init info if needed in some modes
        self._init_info()

    def _init_info(self):
        """Prepare the mode-specific state ('set' and 'file' modes only)."""
        if self.mask_mode == 'set':
            # get mask list information
            mask_list_file = self.mask_config['mask_list_file']
            prefix = Path(self.mask_config['prefix'])
            self.mask_list = []
            with open(mask_list_file, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A blank line would otherwise register the bare
                        # prefix directory as a mask file.
                        continue
                    # Only the first space-separated field is the file name.
                    mask_name = line.split(' ')[0]
                    self.mask_list.append(prefix.joinpath(mask_name))
            self.mask_set_size = len(self.mask_list)
            self.io_backend = self.mask_config['io_backend']
            self.flag = self.mask_config['flag']
            self.file_client_kwargs = self.mask_config['file_client_kwargs']
            self.file_client = None
        elif self.mask_mode == 'file':
            self.io_backend = 'disk'
            self.flag = 'unchanged'
            self.file_client_kwargs = dict()
            self.file_client = None

    def _get_random_mask_from_set(self):
        """Load one mask chosen uniformly at random from the mask set.

        Returns:
            np.ndarray: Mask with a single trailing channel (see
                ``_get_mask_from_file``).

        Raises:
            ValueError: If the mask set is empty.
        """
        if self.mask_set_size == 0:
            raise ValueError(
                'The mask set is empty, please check "mask_list_file" in '
                'mask_config.')
        # The upper bound of randint is exclusive, so the index is always
        # within [0, mask_set_size - 1].
        mask_idx = np.random.randint(0, self.mask_set_size)
        return self._get_mask_from_file(self.mask_list[mask_idx])

    def _get_mask_from_file(self, path):
        """Read a mask image and binarise it.

        Args:
            path (str | Path): Location of the mask, handed to the file
                client.

        Returns:
            np.ndarray: Mask of shape (h, w, 1) in which every positive value
                has been set to 1.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend,
                                          **self.file_client_kwargs)
        mask_bytes = self.file_client.get(path)
        mask = mmcv.imfrombytes(mask_bytes, flag=self.flag)  # HWC, BGR
        if mask.ndim == 2:
            mask = np.expand_dims(mask, axis=2)
        else:
            # Keep only the first channel of a multi-channel mask.
            mask = mask[:, :, 0:1]

        mask[mask > 0] = 1.
        return mask

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. In 'file' mode it must contain
                'mask_path'.

        Returns:
            dict: A dict containing the processed data and information.
                'mask' is always added; 'mask_bbox' is added in 'bbox' mode.

        Raises:
            NotImplementedError: If ``mask_mode`` is not one of the supported
                modes.
        """
        if self.mask_mode == 'bbox':
            mask_bbox = random_bbox(**self.mask_config)
            mask = bbox2mask(self.mask_config['img_shape'], mask_bbox)
            results['mask_bbox'] = mask_bbox
        elif self.mask_mode == 'irregular':
            mask = get_irregular_mask(**self.mask_config)
        elif self.mask_mode == 'set':
            mask = self._get_random_mask_from_set()
        elif self.mask_mode == 'ff':
            mask = brush_stroke_mask(**self.mask_config)
        elif self.mask_mode == 'file':
            mask = self._get_mask_from_file(results['mask_path'])
        else:
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')
        results['mask'] = mask
        return results

    def __repr__(self):
        return self.__class__.__name__ + f"(mask_mode='{self.mask_mode}')"
@PIPELINES.register_module()
class GetSpatialDiscountMask(object):
    """Get spatial discounting mask constant.

    Spatial discounting mask is first introduced in:
    Generative Image Inpainting with Contextual Attention.

    Required keys are "mask_bbox" and "mask", added key is "discount_mask".

    Args:
        gamma (float, optional): Gamma for computing spatial discounting.
            Defaults to 0.99.
        beta (float, optional): Beta for computing spatial discounting.
            Defaults to 1.5.
    """

    def __init__(self, gamma=0.99, beta=1.5):
        self.gamma = gamma
        self.beta = beta

    def spatial_discount_mask(self, mask_width, mask_height):
        """Generate spatial discounting mask constant.

        Each pixel gets ``gamma ** (d * beta)`` where ``d`` is its distance
        to the border along one axis; the larger of the two per-axis values
        is kept (for ``gamma < 1`` that is the one of the nearest border).

        Args:
            mask_width (int): The width of bbox hole.
            mask_height (int): The height of bbox hole.

        Returns:
            np.ndarray: Spatial discounting mask of shape
                (mask_height, mask_width, 1).
        """
        w, h = np.meshgrid(np.arange(mask_width), np.arange(mask_height))
        grid_stack = np.stack([h, w], axis=2)
        # Per-axis distance of every pixel to the closer of the two borders.
        border_dist = np.minimum(
            grid_stack, [mask_height - 1, mask_width - 1] - grid_stack)
        mask_values = (self.gamma**(border_dist * self.beta)).max(
            axis=2, keepdims=True)

        return mask_values

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        mask_bbox = results['mask_bbox']
        mask = results['mask']
        # The last two bbox entries are (height, width); the first two are
        # used below as the row and column offsets of the hole.
        mask_height, mask_width = mask_bbox[-2:]
        discount_hole = self.spatial_discount_mask(mask_width, mask_height)

        # The discount values are fractions. Allocating the output with an
        # integer mask dtype would truncate them on assignment, so a floating
        # dtype is forced unless the mask is already floating point.
        mask_dtype = np.asarray(mask).dtype
        if not np.issubdtype(mask_dtype, np.floating):
            mask_dtype = np.float32
        discount_mask = np.zeros_like(mask, dtype=mask_dtype)
        discount_mask[mask_bbox[0]:mask_bbox[0] + mask_height,
                      mask_bbox[1]:mask_bbox[1] + mask_width,
                      ...] = discount_hole

        results['discount_mask'] = discount_mask

        return results

    def __repr__(self):
        return self.__class__.__name__ + (f'(gamma={self.gamma}, '
                                          f'beta={self.beta})')
@PIPELINES.register_module()
class LoadPairedImageFromFile(LoadImageFromFile):
    """Load a pair of images from file.

    Each sample contains a pair of images, which are concatenated in the w
    dimension (a|b). This is a special loading class for generation paired
    dataset. It loads a pair of images as the common loader does and crops
    it into two images with the same shape in different domains.

    Required key is "pair_path". Added or modified keys are "pair",
    "pair_ori_shape", "ori_pair", "img_a", "img_b", "img_a_path",
    "img_b_path", "img_a_ori_shape", "img_b_ori_shape", "ori_img_a" and
    "ori_img_b".

    Args:
        io_backend (str): io backend where images are stored.
            Default: 'disk'.
        key (str): Keys in results to find corresponding path.
            Default: 'gt'.
        flag (str): Loading flag for images. Default: 'color'.
        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
            Default: 'bgr'.
        save_original_img (bool): If True, maintain a copy of the image in
            `results` dict with name of `f'ori_{key}'`. Default: False.
        kwargs (dict): Args for file client.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If the width of the loaded image is odd.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend, **self.kwargs)

        filepath = str(results[f'{self.key}_path'])
        img_bytes = self.file_client.get(filepath)
        # Honour the documented channel_order argument instead of always
        # using the decoder's default order.
        img = mmcv.imfrombytes(
            img_bytes, flag=self.flag,
            channel_order=self.channel_order)  # HWC
        if img.ndim == 2:
            # Add a trailing channel axis so the slicing below is uniform.
            img = np.expand_dims(img, axis=2)

        results[self.key] = img
        results[f'{self.key}_path'] = filepath
        results[f'{self.key}_ori_shape'] = img.shape
        if self.save_original_img:
            results[f'ori_{self.key}'] = img.copy()

        # crop pair into a and b
        w = img.shape[1]
        if w % 2 != 0:
            raise ValueError(
                f'The width of image pair must be even number, but got {w}.')
        new_w = w // 2
        # Left half is domain a, right half is domain b; both are slices of
        # the loaded pair image.
        img_a = img[:, :new_w, :]
        img_b = img[:, new_w:, :]

        results['img_a'] = img_a
        results['img_b'] = img_b
        results['img_a_path'] = filepath
        results['img_b_path'] = filepath
        results['img_a_ori_shape'] = img_a.shape
        results['img_b_ori_shape'] = img_b.shape
        if self.save_original_img:
            results['ori_img_a'] = img_a.copy()
            results['ori_img_b'] = img_b.copy()

        return results