Source code for mmedit.datasets.base_sr_dataset

import copy
import os.path as osp
from collections import defaultdict
from pathlib import Path

from mmcv import scandir

from .base_dataset import BaseDataset

IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm',
                  '.PPM', '.bmp', '.BMP')


[docs]class BaseSRDataset(BaseDataset): """Base class for super resolution datasets. """ def __init__(self, pipeline, scale, test_mode=False): super(BaseSRDataset, self).__init__(pipeline, test_mode) self.scale = scale
[docs] @staticmethod def scan_folder(path): """Obtain image path list (including sub-folders) from a given folder. Args: path (str | :obj:`Path`): Folder path. Returns: list[str]: image list obtained form given folder. """ if isinstance(path, (str, Path)): path = str(path) else: raise TypeError("'path' must be a str or a Path object, " f'but received {type(path)}.') images = list(scandir(path, suffix=IMG_EXTENSIONS, recursive=True)) images = [osp.join(path, v) for v in images] assert images, f'{path} has no valid image file.' return images
def __getitem__(self, idx): """Get item at each call. Args: idx (int): Index for getting each item. """ results = copy.deepcopy(self.data_infos[idx]) results['scale'] = self.scale return self.pipeline(results)
[docs] def evaluate(self, results, logger=None): """Evaluate with different metrics. Args: results (list[tuple]): The output of forward_test() of the model. Return: dict: Evaluation results dict. """ if not isinstance(results, list): raise TypeError(f'results must be a list, but got {type(results)}') assert len(results) == len(self), ( 'The length of results is not equal to the dataset len: ' f'{len(results)} != {len(self)}') results = [res['eval_result'] for res in results] # a list of dict eval_results = defaultdict(list) # a dict of list for res in results: for metric, val in res.items(): eval_results[metric].append(val) for metric, val_list in eval_results.items(): assert len(val_list) == len(self), ( f'Length of evaluation result of {metric} is {len(val_list)}, ' f'should be {len(self)}') # average the results eval_results = { metric: sum(values) / len(self) for metric, values in eval_results.items() } return eval_results