# Source code for mmedit.models.backbones.encoder_decoders.simple_encoder_decoder

import torch.nn as nn

from mmedit.models.builder import build_component
from mmedit.models.registry import BACKBONES


@BACKBONES.register_module()
class SimpleEncoderDecoder(nn.Module):
    """Simple encoder-decoder model for matting.

    The encoder is built first; the decoder is then built with its
    ``in_channels`` set to the encoder's ``out_channels`` attribute.

    Args:
        encoder (dict): Config of the encoder.
        decoder (dict): Config of the decoder. ``in_channels`` is filled in
            on an internal copy, so the caller's dict is left unmodified.
    """

    def __init__(self, encoder, decoder):
        super(SimpleEncoderDecoder, self).__init__()

        self.encoder = build_component(encoder)

        # Set ``in_channels`` on a shallow copy rather than on the caller's
        # config, so that a config object reused elsewhere is not silently
        # altered by constructing this module.
        decoder = decoder.copy()
        decoder['in_channels'] = self.encoder.out_channels
        self.decoder = build_component(decoder)

    def init_weights(self, pretrained=None):
        """Initialize the weights of the encoder and the decoder.

        Args:
            pretrained: Forwarded unchanged to ``encoder.init_weights``.
                The decoder's ``init_weights`` is called without arguments.
                Defaults to None.
        """
        self.encoder.init_weights(pretrained)
        self.decoder.init_weights()

    def forward(self, *args, **kwargs):
        """Forward function.

        All positional and keyword arguments are passed to the encoder; the
        encoder's output is then passed to the decoder.

        Returns:
            Tensor: The output tensor of the decoder.
        """
        out = self.encoder(*args, **kwargs)
        out = self.decoder(out)
        return out