Config System for Inpainting

As in MMDetection, we adopt a modular and inheritance-based design in our config system, which makes it convenient to conduct various experiments.
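For example, a new config can inherit from an existing one through the _base_ mechanism and override only the fields that change. Below is a minimal sketch; the file names are assumptions for illustration:

# gl_places_lr1e-4.py -- a hypothetical child config
_base_ = './gl_256x256_8x12_places.py'  # inherit all fields from the assumed base config

# Override only what differs; every other field keeps its value from the base file.
optimizers = dict(
    generator=dict(type='Adam', lr=0.0001), disc=dict(type='Adam', lr=0.0001))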

Config Field Description

To give users a basic idea of a complete config and of the modules in an inpainting system, we add brief comments to the config of Global&Local below. For more detailed usage and the alternatives for each module, please refer to the API documentation.

model = dict(
    type='GLInpaintor', # The name of inpaintor
    encdec=dict(
        type='GLEncoderDecoder', # The name of encoder-decoder
        encoder=dict(type='GLEncoder', norm_cfg=dict(type='SyncBN')), # The config of encoder
        decoder=dict(type='GLDecoder', norm_cfg=dict(type='SyncBN')), # The config of decoder
        dilation_neck=dict(
            type='GLDilationNeck', norm_cfg=dict(type='SyncBN'))), # The config of dilation neck
    disc=dict(
        type='GLDiscs', # The name of discriminator
        global_disc_cfg=dict(
            in_channels=3, # The input channels of the discriminator
            max_channels=512, # The maximum number of middle channels in the discriminator
            fc_in_channels=512 * 4 * 4, # The input channels of the last fc layer
            fc_out_channels=1024, # The output channels of the last fc layer
            num_convs=6, # The number of conv layers used in the discriminator
            norm_cfg=dict(type='SyncBN') # The config of the norm layer
        ),
        local_disc_cfg=dict(
            in_channels=3, # The input channels of the discriminator
            max_channels=512, # The maximum number of middle channels in the discriminator
            fc_in_channels=512 * 4 * 4, # The input channels of the last fc layer
            fc_out_channels=1024, # The output channels of the last fc layer
            num_convs=5, # The number of conv layers used in the discriminator
            norm_cfg=dict(type='SyncBN') # The config of the norm layer
        ),
    ),
    loss_gan=dict(
        type='GANLoss', # The name of GAN loss
        gan_type='vanilla', # The type of GAN loss
        loss_weight=0.001 # The weight of GAN loss
    ),
    loss_l1_hole=dict(
        type='L1Loss', # The type of l1 loss
        loss_weight=1.0 # The weight of l1 loss
    ),
    pretrained=None) # The path of pretrained weight

train_cfg = dict(
    disc_step=1, # The number of steps to train the discriminator before training the generator
    iter_tc=90000, # Iterations for warming up the generator
    iter_td=100000, # Iterations for warming up the discriminator
    start_iter=0, # Starting iteration
    local_size=(128, 128)) # The size of local patches
test_cfg = dict(metrics=['l1']) # The config of testing scheme
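# Note: a rough sketch (an assumption based on the Global&Local training
# schedule; the exact GLInpaintor logic may differ) of how the warm-up fields
# above interact:
#   iteration < iter_tc            -> train the generator only (l1 loss)
#   iter_tc <= iteration < iter_td -> train the discriminator only
#   iteration >= iter_td           -> train generator and discriminator jointly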

dataset_type = 'ImgInpaintingDataset' # The type of dataset
input_shape = (256, 256) # The shape of input image

train_pipeline = [
    dict(type='LoadImageFromFile', key='gt_img'), # The config of loading image
    dict(
        type='LoadMask', # The type of loading mask pipeline
        mask_mode='bbox', # The type of mask
        mask_config=dict(
            max_bbox_shape=(128, 128), # The maximum shape of the bbox
            max_bbox_delta=40, # The maximum delta of the bbox height and width
            min_margin=20,  # The minimum margin from bbox to the image border
            img_shape=input_shape)),  # The input image shape
    dict(
        type='Crop', # The type of crop pipeline
        keys=['gt_img'],  # The keys of images to be cropped
        crop_size=(384, 384),  # The size of cropped patch
        random_crop=True,  # Whether to use random crop
    ),
    dict(
        type='Resize',  # The type of resizing pipeline
        keys=['gt_img'],  # The keys of images to be resized
        scale=input_shape,  # The scale of resizing function
        keep_ratio=False,  # Whether to keep ratio during resizing
    ),
    dict(
        type='Normalize',  # The type of normalizing pipeline
        keys=['gt_img'],  # The keys of images to be normalized
        mean=[127.5] * 3,  # Mean values used in normalization
        std=[127.5] * 3,  # Std values used in normalization
        to_rgb=False),  # Whether to convert image channels from BGR to RGB
    dict(type='GetMaskedImage'),  # The config of getting masked image pipeline
    dict(
        type='Collect',  # The type of collecting data from current pipeline
        keys=['gt_img', 'masked_img', 'mask', 'mask_bbox'],  # The keys of data to be collected
        meta_keys=['gt_img_path']),  # The meta keys of data to be collected
    dict(type='ImageToTensor', keys=['gt_img', 'masked_img', 'mask']),  # The config dict of image to tensor pipeline
    dict(type='ToTensor', keys=['mask_bbox'])  # The config dict of ToTensor pipeline
]

test_pipeline = train_pipeline  # Reuse the training pipeline for testing/validation

data_root = './data/places365/'  # Set data root

data = dict(
    samples_per_gpu=12,  # Batch size of a single GPU
    workers_per_gpu=8,  # Workers to pre-fetch data for each single GPU
    val_samples_per_gpu=1,  # Batch size of a single GPU in validation
    val_workers_per_gpu=8,  # Workers to pre-fetch data for each single GPU during validation
    drop_last=True,  # Whether to drop the last incomplete batch
    train=dict(  # Train dataset config
        type=dataset_type,
        ann_file=data_root + 'train_places_img_list_total.txt',
        data_prefix=data_root,
        pipeline=train_pipeline,
        test_mode=False),
    val=dict(  # Validation dataset config
        type=dataset_type,
        ann_file=data_root + 'val_places_img_list.txt',
        data_prefix=data_root,
        pipeline=test_pipeline,
        test_mode=True))

optimizers = dict(  # Config used to build optimizers; it supports all optimizers in PyTorch with the same arguments as in PyTorch
    generator=dict(type='Adam', lr=0.0004), disc=dict(type='Adam', lr=0.0004))

lr_config = dict(policy='Fixed', by_epoch=False)  # Learning rate scheduler config used to register LrUpdater hook

checkpoint_config = dict(by_epoch=False, interval=50000)  # Config to set the checkpoint hook; refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
log_config = dict(  # Config to register the logger hook
    interval=100,  # Interval to print the log
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        # dict(type='TensorboardLoggerHook'),  # The Tensorboard logger is also supported
        dict(type='PaviLoggerHook', init_kwargs=dict(project='mmedit'))
    ])  # The logger used to record the training process.

visual_config = dict(  # Config to register the visualization hook
    type='VisualizationHook',
    output_dir='visual',
    interval=1000,
    res_name_list=[
        'gt_img', 'masked_img', 'fake_res', 'fake_img', 'fake_gt_local'
    ],
)  # The hook used to visualize the training process.

evaluation = dict(interval=50000)  # The config to build the evaluation hook

total_iters = 500002  # Total number of training iterations
dist_params = dict(backend='nccl')  # Parameters to set up distributed training; the port can also be set.
log_level = 'INFO'  # The level of logging.
work_dir = None  # Directory to save the model checkpoints and logs for the current experiment.
load_from = None  # Load a model from the given path as a pretrained model. This will not resume training.
resume_from = None  # Resume from a checkpoint at the given path; training will resume from the iteration at which the checkpoint was saved.
workflow = [('train', 10000)]  # Workflow for the runner. [('train', 10000)] means there is only one workflow, named 'train', and it is executed for 10000 iterations at a time. Training proceeds by iterations until total_iters is reached.
exp_name = 'gl_places'  # The experiment name
find_unused_parameters = False  # Whether to set find_unused_parameters in DDP
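
With a complete config like the one above, each module is built from its dict through the registry. Below is a minimal sketch of loading the config and building the model; the config path is an assumption and should be adjusted to your checkout:

from mmcv import Config
from mmedit.models import build_model

cfg = Config.fromfile('configs/inpainting/global_local/gl_256x256_8x12_places.py')  # assumed path

# Fields can be inspected or overridden programmatically before building.
cfg.data.samples_per_gpu = 4

# Build the inpaintor together with its training/testing schemes.
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
print(model.__class__.__name__)  # GLInpaintor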