# Source code for mmedit.datasets.sr_reds_dataset

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRREDSDataset(BaseSRDataset):
    """REDS dataset for video super resolution.

    The dataset loads several LQ (Low-Quality) frames and a center GT
    (Ground-Truth) frame. Then it applies specified transforms and finally
    returns a dict containing paired data and other information.

    It reads REDS keys from the txt file. Each line contains:
    1. image name; 2. image shape, separated by a white space.
    Blank lines are ignored.
    Examples:

    ::

        000/00000000.png (720, 1280, 3)
        000/00000001.png (720, 1280, 3)

    Args:
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        ann_file (str | :obj:`Path`): Path to the annotation file.
        num_input_frames (int): Window size for input frames. Must be odd.
        pipeline (list[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        val_partition (str): Validation partition mode. Choices ['official' or
            'REDS4']. Default: 'official'.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self,
                 lq_folder,
                 gt_folder,
                 ann_file,
                 num_input_frames,
                 pipeline,
                 scale,
                 val_partition='official',
                 test_mode=False):
        super(SRREDSDataset, self).__init__(pipeline, scale, test_mode)
        # An odd window size is required so that the window has a single
        # center frame to pair with the one GT frame.
        assert num_input_frames % 2 == 1, (
            f'num_input_frames should be odd numbers, '
            f'but received {num_input_frames }.')
        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.ann_file = str(ann_file)
        self.num_input_frames = num_input_frames
        self.val_partition = val_partition
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for REDS dataset.

        Keys are read from ``self.ann_file``. Each key is the part of the
        line before the first ``'.'``, e.g. ``'000/00000000'``. Blank lines
        are skipped. In test mode only clips in the validation partition
        are kept; otherwise those clips are excluded.

        Returns:
            list[dict]: One dict per kept key, holding the LQ/GT folder
                paths, the key, the maximum frame number and the number of
                input frames.

        Raises:
            ValueError: If ``self.val_partition`` is neither ``'REDS4'`` nor
                ``'official'``.
        """
        # get keys; skip blank lines, which would otherwise yield an empty
        # key that is never in a validation partition and thus leaks into
        # the training set.
        with open(self.ann_file, 'r') as fin:
            keys = [
                line.strip().split('.')[0] for line in fin if line.strip()
            ]

        # Clip (folder) names that make up the validation partition.
        if self.val_partition == 'REDS4':
            val_partition = {'000', '011', '015', '020'}
        elif self.val_partition == 'official':
            val_partition = {f'{v:03d}' for v in range(240, 270)}
        else:
            raise ValueError(
                f'Wrong validation partition {self.val_partition}. '
                f'Supported ones are ["official", "REDS4"]')

        # The clip name is the first path component of the key.
        if self.test_mode:
            keys = [v for v in keys if v.split('/')[0] in val_partition]
        else:
            keys = [v for v in keys if v.split('/')[0] not in val_partition]

        return [
            dict(
                lq_path=self.lq_folder,
                gt_path=self.gt_folder,
                key=key,
                max_frame_num=100,  # REDS has 100 frames for each clip
                num_input_frames=self.num_input_frames) for key in keys
        ]