Config System for Restoration
An Example - EDSR
To help users get a basic idea of a complete config, we briefly comment on the config of the EDSR model we implemented, shown below. For more detailed usage and the available alternatives for each module, please refer to the API documentation.
# Name of this experiment; the work directory below is derived from it.
exp_name = 'edsr_x2c64b16_1x16_300k_div2k'
# Upsampling factor, referenced by the model, evaluation and dataset settings.
scale = 2
# ---------------------------- model settings ----------------------------
model = dict(
    type='BasicRestorer',  # restorer (model) class name
    generator=dict(  # generator network settings
        type='EDSR',  # generator class name
        in_channels=3,  # channels of the input image
        out_channels=3,  # channels of the output image
        mid_channels=64,  # width of the intermediate feature maps
        num_blocks=16,  # number of blocks in the trunk network
        upscale_factor=scale,  # kept in sync with the global `scale`
        res_scale=1,  # factor applied to the residual inside each residual block
        rgb_mean=(0.4488, 0.4371, 0.4040),  # image mean, RGB order
        rgb_std=(1.0, 1.0, 1.0),  # image std, RGB order
    ),
    # pixel-wise training loss
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
)
# ------------------ model training and testing settings ------------------
train_cfg = None  # no extra training-time options
test_cfg = dict(
    metrics=['PSNR'],  # metrics computed during testing
    crop_border=scale,  # border cropped during evaluation, tied to `scale`
)
# --------------------------- dataset settings ---------------------------
train_dataset_type = 'SRAnnotationDataset'  # dataset class used for training
val_dataset_type = 'SRFolderDataset'  # dataset class used for validation

# Training data processing pipeline; steps are applied in list order.
train_pipeline = [
    # Read the low-quality input from disk into results['lq'].
    dict(
        type='LoadImageFromFile',
        io_backend='disk',  # I/O backend
        key='lq',  # results key whose path is read
        flag='unchanged',  # image-reading flag
    ),
    # Read the ground-truth image from disk into results['gt'].
    dict(
        type='LoadImageFromFile',
        io_backend='disk',  # I/O backend
        key='gt',  # results key whose path is read
        flag='unchanged',  # image-reading flag
    ),
    # Rescale pixel values from [0, 255] to [0, 1].
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
    # Normalize both images; mean 0 / std 1 leaves values unchanged
    # apart from the channel-order conversion.
    dict(
        type='Normalize',
        keys=['lq', 'gt'],  # images to normalize
        mean=[0, 0, 0],  # per-channel mean
        std=[1, 1, 1],  # per-channel standard deviation
        to_rgb=True,  # convert to RGB channel order
    ),
    # Crop matching patches from lq and gt (96x96 on the gt side).
    dict(type='PairedRandomCrop', gt_patch_size=96),
    # Random horizontal flip with probability 0.5.
    dict(
        type='Flip',
        keys=['lq', 'gt'],
        flip_ratio=0.5,
        direction='horizontal',
    ),
    # Random vertical flip with probability 0.5.
    dict(
        type='Flip',
        keys=['lq', 'gt'],
        flip_ratio=0.5,
        direction='vertical',
    ),
    # Randomly swap height and width with probability 0.5.
    dict(
        type='RandomTransposeHW',
        keys=['lq', 'gt'],
        transpose_ratio=0.5,
    ),
    # Choose which keys are passed on to the model.
    dict(
        type='Collect',
        keys=['lq', 'gt'],  # data keys passed to the model
        meta_keys=['lq_path', 'gt_path'],  # meta info; not needed by training itself
    ),
    # Convert the images to tensors.
    dict(
        type='ImageToTensor',
        keys=['lq', 'gt'],
    ),
]
# Test/validation data processing pipeline; steps are applied in list order.
# It mirrors the training pipeline minus the random crop/flip/transpose
# augmentations.
test_pipeline = [
    # Read the low-quality input from disk into results['lq'].
    dict(
        type='LoadImageFromFile',
        io_backend='disk',  # I/O backend
        key='lq',  # results key whose path is read
        flag='unchanged',  # image-reading flag
    ),
    # Read the ground-truth image from disk into results['gt'].
    dict(
        type='LoadImageFromFile',
        io_backend='disk',  # I/O backend
        key='gt',  # results key whose path is read
        flag='unchanged',  # image-reading flag
    ),
    # Rescale pixel values from [0, 255] to [0, 1].
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
    # Normalize both images (identity values; only converts channel order).
    dict(
        type='Normalize',
        keys=['lq', 'gt'],  # images to normalize
        mean=[0, 0, 0],  # per-channel mean
        std=[1, 1, 1],  # per-channel standard deviation
        to_rgb=True,  # convert to RGB channel order
    ),
    # Choose which keys are passed on to the model.
    dict(
        type='Collect',
        keys=['lq', 'gt'],  # data keys passed to the model
        # Meta information keys. Previously listed 'lq_path' twice, which
        # dropped the GT path; now matches the training pipeline.
        meta_keys=['lq_path', 'gt_path'],
    ),
    # Convert the images to tensors.
    dict(type='ImageToTensor', keys=['lq', 'gt']),
]
data = dict(
    # ----- training loader -----
    samples_per_gpu=16,  # per-GPU batch size
    workers_per_gpu=6,  # data-prefetching workers per GPU
    drop_last=True,  # drop_last setting for the data loader
    train=dict(
        type='RepeatDataset',  # repeats the dataset for iteration-based training
        times=1000,  # number of repetitions
        dataset=dict(
            type=train_dataset_type,
            lq_folder='data/DIV2K/DIV2K_train_LR_bicubic/X2_sub',  # low-quality inputs
            gt_folder='data/DIV2K/DIV2K_train_HR_sub',  # ground-truth images
            ann_file='data/DIV2K/meta_info_DIV2K800sub_GT.txt',  # annotation file
            pipeline=train_pipeline,  # defined above
            scale=scale,  # upsampling factor
        ),
    ),
    # ----- validation loader -----
    val_samples_per_gpu=1,  # per-GPU batch size for validation
    val_workers_per_gpu=1,  # data-prefetching workers per GPU for validation
    val=dict(
        type=val_dataset_type,
        lq_folder='./data/val_set5/Set5_bicLRx2',  # low-quality inputs
        gt_folder='./data/val_set5/Set5_mod12',  # ground-truth images
        pipeline=test_pipeline,  # defined above
        scale=scale,  # upsampling factor
        filename_tmpl='{}',  # file-name template
    ),
    # ----- test loader (same data as validation) -----
    test=dict(
        type=val_dataset_type,
        lq_folder='./data/val_set5/Set5_bicLRx2',  # low-quality inputs
        gt_folder='./data/val_set5/Set5_mod12',  # ground-truth images
        pipeline=test_pipeline,  # defined above
        scale=scale,  # upsampling factor
        filename_tmpl='{}',  # file-name template
    ),
)
# ------------------------------ optimizer ------------------------------
# One optimizer per sub-network; any PyTorch optimizer can be named here,
# with the same arguments as its PyTorch counterpart.
optimizers = dict(
    generator=dict(type='Adam', lr=1e-4, betas=(0.9, 0.999)),
)

# --------------------------- learning policy ---------------------------
total_iters = 300000  # total number of training iterations
# Scheduler config used to register the LrUpdater hook. Step decay halves the
# LR at iteration 200k; CosineAnnealing, Cyclic, etc. are also supported, see
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py#L9.
lr_config = dict(
    policy='Step',
    by_epoch=False,
    step=[200000],
    gamma=0.5,
)
# Checkpoint hook settings; implementation in
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py.
checkpoint_config = dict(
    interval=5000,  # save every 5000 iterations
    save_optimizer=True,  # store optimizer state as well
    by_epoch=False,  # interval is counted in iterations
)
# Evaluation hook settings.
evaluation = dict(
    interval=5000,  # evaluate every 5000 iterations
    save_image=True,  # write output images during evaluation
    gpu_collect=True,  # collect results on GPU
)
# Logger hook settings.
log_config = dict(
    interval=100,  # logging interval
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),  # text log of training progress
        dict(type='TensorboardLoggerHook'),  # TensorBoard logging
    ],
)
visual_config = None  # visualization is not used
# --------------------------- runtime settings ---------------------------
# Distributed-training parameters (the port can also be set here).
dist_params = dict(backend='nccl')
log_level = 'INFO'  # logging level
# Directory for the checkpoints and logs of this experiment.
work_dir = f'./work_dirs/{exp_name}'
# Load weights from this path as a pre-trained model; does NOT resume training.
load_from = None
# Resume from the checkpoint at this path; training continues from the
# iteration at which that checkpoint was saved.
resume_from = None
# Runner workflow: [('train', 1)] means a single phase named 'train' that is
# executed once. Keep this unchanged when training the current restoration
# models.
workflow = [('train', 1)]