Source code for mmedit.models.components.discriminators.deepfill_disc

import torch.nn as nn
from mmcv.cnn import normal_init
from mmcv.runner import load_checkpoint

from mmedit.models import build_component
from mmedit.models.registry import COMPONENTS
from mmedit.utils import get_root_logger


[docs]@COMPONENTS.register_module() class DeepFillv1Discriminators(nn.Module): """Discriminators used in DeepFillv1 model. In DeepFillv1 model, the discriminators are independent without any concatenation like Global&Local model. Thus, we call this model `DeepFillv1Discriminators`. There exist a global discriminator and a local discriminator with global and local input respectively. The details can be found in: Generative Image Inpainting with Contextual Attention. Args: global_disc_cfg (dict): Config dict for global discriminator. local_disc_cfg (dict): Config dict for local discriminator. """ def __init__(self, global_disc_cfg, local_disc_cfg): super(DeepFillv1Discriminators, self).__init__() self.global_disc = build_component(global_disc_cfg) self.local_disc = build_component(local_disc_cfg)
[docs] def forward(self, x): """Forward function. Args: x (tuple[torch.Tensor]): Contains global image and the local image patch. Returns: tuple[torch.Tensor]: Contains the prediction from discriminators \ in global image and local image patch. """ global_img, local_img = x global_pred = self.global_disc(global_img) local_pred = self.local_disc(local_img) return global_pred, local_pred
[docs] def init_weights(self, pretrained=None): """Init weights for models. Args: pretrained (str, optional): Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None. """ if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: for m in self.modules(): if isinstance(m, nn.Linear): normal_init(m, 0, std=0.02) elif isinstance(m, nn.Conv2d): normal_init(m, 0.0, std=0.02) else: raise TypeError('pretrained must be a str or None but got' f'{type(pretrained)} instead.')