Source code for mmedit.models.backbones.encoder_decoders.pconv_encoder_decoder

import torch.nn as nn
from mmcv.runner import auto_fp16, load_checkpoint

from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


[docs]@BACKBONES.register_module() class PConvEncoderDecoder(nn.Module): """Encoder-Decoder with partial conv module. Args: encoder (dict): Config of the encoder. decoder (dict): Config of the decoder. """ def __init__(self, encoder, decoder): super(PConvEncoderDecoder, self).__init__() self.encoder = build_component(encoder) self.decoder = build_component(decoder) # support fp16 self.fp16_enabled = False
[docs] @auto_fp16() def forward(self, x, mask_in): """Forward Function. Args: x (torch.Tensor): Input tensor with shape of (n, c, h, w). mask_in (torch.Tensor): Input tensor with shape of (n, c, h, w). Returns: torch.Tensor: Output tensor with shape of (n, c, h', w'). """ enc_outputs = self.encoder(x, mask_in) x, final_mask = self.decoder(enc_outputs) return x, final_mask
[docs] def init_weights(self, pretrained=None): """Init weights for models. Args: pretrained (str, optional): Path for pretrained weights. If given None, pretrained weights will not be loaded. Defaults to None. """ if isinstance(pretrained, str): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: # Here, we just use the default initialization in `ConvModule`. pass else: raise TypeError('pretrained must be a str or None')