Source code for mmedit.models.base

from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.nn as nn


[docs]class BaseModel(nn.Module, metaclass=ABCMeta): """Base model. All models should subclass it. All subclass should overwrite: ``init_weights``, supporting to initialize models. ``forward_train``, supporting to forward when training. ``forward_test``, supporting to forward when testing. ``train_step``, supporting to train one step when training. """ def __init__(self): super(BaseModel, self).__init__()
[docs] @abstractmethod def init_weights(self): """Abstract method for initializing weight. All subclass should overwrite it. """ pass
[docs] @abstractmethod def forward_train(self, imgs, labels): """Abstract method for training forward. All subclass should overwrite it. """ pass
[docs] @abstractmethod def forward_test(self, imgs): """Abstract method for testing forward. All subclass should overwrite it. """ pass
[docs] def forward(self, imgs, labels, test_mode, **kwargs): """Forward function for base model. Args: imgs (Tensor): Input image(s). labels (Tensor): Ground-truth label(s). test_mode (bool): Whether in test mode. kwargs (dict): Other arguments. Returns: Tensor: Forward results. """ if not test_mode: return self.forward_train(imgs, labels, **kwargs) else: return self.forward_test(imgs, **kwargs)
[docs] @abstractmethod def train_step(self, data_batch, optimizer): """Abstract method for one training step. All subclass should overwrite it. """ pass
[docs] def val_step(self, data_batch, **kwargs): """Abstract method for one validation step. All subclass should overwrite it. """ output = self.forward_test(**data_batch, **kwargs) return output
[docs] def parse_losses(self, losses): """Parse losses dict for different loss variants. Args: losses (dict): Loss dict. Returns: loss (float): Sum of the total loss. log_vars (dict): loss dict for different variants. """ log_vars = OrderedDict() for loss_name, loss_value in losses.items(): if isinstance(loss_value, torch.Tensor): log_vars[loss_name] = loss_value.mean() elif isinstance(loss_value, list): log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) else: raise TypeError( f'{loss_name} is not a tensor or list of tensors') loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key) log_vars['loss'] = loss for name in log_vars: log_vars[name] = log_vars[name].item() return loss, log_vars