Source code for mmedit.datasets.sr_vid4_dataset

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS

@DATASETS.register_module()
class SRVid4Dataset(BaseSRDataset):
    """Vid4 dataset for video super resolution.

    The dataset loads several LQ (Low-Quality) frames and a center GT
    (Ground-Truth) frame. Then it applies specified transforms and finally
    returns a dict containing paired data and other information.

    It reads Vid4 keys from the txt file. Each line contains:

        1. folder name;
        2. number of frames in this clip (in the same folder);
        3. image shape,

    separated by white space. Blank lines are ignored.

    Examples:

    ::

        calendar 40 (320,480,3)
        city 34 (320,480,3)

    Args:
        lq_folder (str | :obj:`Path`): Path to a lq folder.
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        ann_file (str | :obj:`Path`): Path to the annotation file.
        num_input_frames (int): Window size for input frames. Must be odd.
        pipeline (list[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Default: '{:08d}'.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self,
                 lq_folder,
                 gt_folder,
                 ann_file,
                 num_input_frames,
                 pipeline,
                 scale,
                 filename_tmpl='{:08d}',
                 test_mode=False):
        super().__init__(pipeline, scale, test_mode)
        # The window must be odd so that it is symmetric around the GT frame.
        assert num_input_frames % 2 == 1, (
            f'num_input_frames should be odd numbers, '
            f'but received {num_input_frames}.')
        self.lq_folder = str(lq_folder)
        self.gt_folder = str(gt_folder)
        self.ann_file = str(ann_file)
        self.num_input_frames = num_input_frames
        self.filename_tmpl = filename_tmpl
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for Vid4 dataset.

        Each non-blank line of ``self.ann_file`` describes one clip and is
        expanded into one entry per frame of that clip.

        Returns:
            list[dict]: One dict per frame, with keys ``lq_path``,
                ``gt_path``, ``key`` (``'<folder>/<formatted index>'``),
                ``num_input_frames`` and ``max_frame_num``.
        """
        data_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines in the file.
                    continue
                # split() collapses runs of whitespace (spaces or tabs);
                # maxsplit=2 keeps the unused image-shape field as a single
                # token even if it contains spaces.
                folder, frame_num, _ = line.split(maxsplit=2)
                frame_num = int(frame_num)
                for i in range(frame_num):
                    data_infos.append(
                        dict(
                            lq_path=self.lq_folder,
                            gt_path=self.gt_folder,
                            key=f'{folder}/{self.filename_tmpl.format(i)}',
                            num_input_frames=self.num_input_frames,
                            max_frame_num=frame_num))
        return data_infos