Source code for mmedit.datasets.generation_unpaired_dataset

import os.path as osp

import numpy as np

from .base_generation_dataset import BaseGenerationDataset
from .registry import DATASETS

@DATASETS.register_module()
class GenerationUnpairedDataset(BaseGenerationDataset):
    """Unpaired two-domain image folder dataset for image generation.

    Images of domain A and domain B live in two sibling folders below
    ``dataroot``. For training these are ``<dataroot>/trainA`` and
    ``<dataroot>/trainB``; when ``test_mode`` is set they are
    ``<dataroot>/testA`` and ``<dataroot>/testB``.

    The dataset length is the size of the larger domain. Indices are
    wrapped with modulo, so the smaller domain is revisited.

    Args:
        dataroot (str | :obj:`Path`): Root folder that contains the
            per-domain sub-folders of unpaired images.
        pipeline (List[dict | callable]): A sequence of data
            transformations.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self, dataroot, pipeline, test_mode=False):
        super(GenerationUnpairedDataset, self).__init__(pipeline, test_mode)
        # The sub-folder prefix depends only on the train/test switch.
        if test_mode:
            phase = 'test'
        else:
            phase = 'train'
        root = str(dataroot)
        self.dataroot_a = osp.join(root, phase + 'A')
        self.dataroot_b = osp.join(root, phase + 'B')
        # Each domain is indexed separately; no pairing is established.
        self.data_infos_a = self.load_annotations(self.dataroot_a)
        self.data_infos_b = self.load_annotations(self.dataroot_b)
        self.len_a = len(self.data_infos_a)
        self.len_b = len(self.data_infos_b)

    def load_annotations(self, dataroot):
        """Collect the image paths of a single domain.

        Args:
            dataroot (str): Folder holding the images of one domain.

        Returns:
            list[dict]: One ``dict(path=...)`` entry per path returned by
                ``self.scan_folder``, in sorted order.
        """
        image_paths = list(self.scan_folder(dataroot))
        image_paths.sort()
        return [{'path': image_path} for image_path in image_paths]

    def prepare_train_data(self, idx):
        """Build one training sample from the two unpaired domains.

        The domain-A image is selected by ``idx`` (wrapped with modulo over
        the size of domain A); the domain-B image is drawn at random with
        ``np.random.randint``.

        Args:
            idx (int): Index of the requested sample.

        Returns:
            The output of ``self.pipeline`` applied to a dict holding
            ``img_a_path`` and ``img_b_path``.
        """
        info_a = self.data_infos_a[idx % self.len_a]
        path_a = info_a['path']
        # Domain B is sampled independently of ``idx``.
        random_idx_b = np.random.randint(0, self.len_b)
        path_b = self.data_infos_b[random_idx_b]['path']
        sample = {'img_a_path': path_a, 'img_b_path': path_b}
        return self.pipeline(sample)

    def prepare_test_data(self, idx):
        """Build one test sample from the two unpaired domains.

        Both images are selected by ``idx``, wrapped with modulo over the
        size of the respective domain, so no randomness is involved.

        Args:
            idx (int): Index of the requested sample.

        Returns:
            The output of ``self.pipeline`` applied to a dict holding
            ``img_a_path`` and ``img_b_path``.
        """
        info_a = self.data_infos_a[idx % self.len_a]
        info_b = self.data_infos_b[idx % self.len_b]
        sample = {'img_a_path': info_a['path'], 'img_b_path': info_b['path']}
        return self.pipeline(sample)

    def __len__(self):
        """Return the number of images in the larger of the two domains."""
        return self.len_a if self.len_a >= self.len_b else self.len_b