Source code for mmedit.models.mattors.gca

import torch
from mmcv.runner import auto_fp16

from ..builder import build_loss
from ..registry import MODELS
from .base_mattor import BaseMattor
from .utils import get_unknown_tensor


@MODELS.register_module()
class GCA(BaseMattor):
    """Guided Contextual Attention image matting model.

    https://arxiv.org/abs/2001.04069

    Args:
        backbone (dict): Config of backbone.
        train_cfg (dict): Config of training. In ``train_cfg``,
            ``train_backbone`` should be specified. If the model has a
            refiner, ``train_refiner`` should be specified.
        test_cfg (dict): Config of testing. In ``test_cfg``, if the model has
            a refiner, ``train_refiner`` should be specified.
        pretrained (str): Path of the pretrained model.
        loss_alpha (dict): Config of the alpha prediction loss. Default: None.
    """

    def __init__(self,
                 backbone,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 loss_alpha=None):
        super(GCA, self).__init__(backbone, None, train_cfg, test_cfg,
                                  pretrained)
        self.loss_alpha = build_loss(loss_alpha)

        # support fp16
        self.fp16_enabled = False

    @auto_fp16(apply_to=('x', ))
    def _forward(self, x):
        raw_alpha = self.backbone(x)
        pred_alpha = (raw_alpha.tanh() + 1.0) / 2.0
        return pred_alpha

    def forward_dummy(self, inputs):
        return self._forward(inputs)
    def forward_train(self, merged, trimap, meta, alpha):
        """Forward function for training GCA model.

        Args:
            merged (Tensor): With shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            trimap (Tensor): With shape (N, C', H, W). Tensor of trimap. C'
                might be 1 or 3.
            meta (list[dict]): Meta data about the current data batch.
            alpha (Tensor): With shape (N, 1, H, W). Tensor of alpha.

        Returns:
            dict: Contains the loss items and batch information.
        """
        pred_alpha = self._forward(torch.cat((merged, trimap), 1))

        weight = get_unknown_tensor(trimap, meta)
        losses = {'loss': self.loss_alpha(pred_alpha, alpha, weight)}
        return {'losses': losses, 'num_samples': merged.size(0)}
    def forward_test(self,
                     merged,
                     trimap,
                     meta,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Defines the computation performed at every test call.

        Args:
            merged (Tensor): Image to predict alpha matte.
            trimap (Tensor): Trimap of the input image.
            meta (list[dict]): Meta data about the current data batch.
                Currently only batch_size 1 is supported. It may contain
                information needed to calculate metrics (``ori_alpha`` and
                ``ori_trimap``) or save the predicted alpha matte
                (``merged_path``).
            save_image (bool, optional): Whether to save the predicted alpha
                matte. Defaults to False.
            save_path (str, optional): The directory to save the predicted
                alpha matte. Defaults to None.
            iteration (int, optional): If given as None, the saved alpha matte
                will have the same file name as ``merged_path`` in the meta
                dict. If given as an int, the saved alpha matte will be named
                with the postfix ``_{iteration}.png``. Defaults to None.

        Returns:
            dict: Contains the predicted alpha matte and evaluation result.
        """
        pred_alpha = self._forward(torch.cat((merged, trimap), 1))

        pred_alpha = pred_alpha.detach().cpu().numpy().squeeze()
        pred_alpha = self.restore_shape(pred_alpha, meta)
        eval_result = self.evaluate(pred_alpha, meta)

        if save_image:
            self.save_image(pred_alpha, meta, save_path, iteration)

        return {'pred_alpha': pred_alpha, 'eval_result': eval_result}
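A minimal sketch of the output rescaling used in ``_forward``: the raw backbone output is unbounded, so it is squashed with ``tanh`` into (-1, 1) and then shifted and scaled into the valid alpha-matte range (0, 1).

import torch

raw_alpha = torch.tensor([-10.0, 0.0, 10.0])   # unbounded backbone output
pred_alpha = (raw_alpha.tanh() + 1.0) / 2.0    # mapped into (0, 1)
print(pred_alpha)                              # ~[0.0000, 0.5000, 1.0000]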
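``forward_train`` concatenates the merged image and the trimap along the channel dimension before calling the backbone, and weights the alpha loss by the trimap's unknown region. The tensor plumbing can be sketched as below, assuming a single-channel trimap normalized to {0, 0.5, 1}; the actual unknown-region mask is produced by ``get_unknown_tensor``, which also consults ``meta``, so the mask construction here is only illustrative.

import torch

merged = torch.randn(2, 3, 64, 64)                         # mean-centered RGB
trimap = torch.randint(0, 3, (2, 1, 64, 64)).float() / 2   # values in {0, 0.5, 1}
x = torch.cat((merged, trimap), dim=1)                     # (2, 4, 64, 64) backbone input
weight = (trimap == 0.5).float()                           # 1 inside the unknown region
print(x.shape, weight.mean().item())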
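``forward_test`` delegates saving to ``BaseMattor.save_image``; the naming rule described for ``iteration`` can be sketched with the hypothetical helper below, written only to illustrate the docstring and not taken from the repository.

import os.path as osp

def output_name(merged_path, iteration=None):
    # Hypothetical helper: reuse the stem of ``merged_path`` when iteration is
    # None, otherwise append the ``_{iteration}`` postfix.
    stem = osp.splitext(osp.basename(merged_path))[0]
    return f'{stem}.png' if iteration is None else f'{stem}_{iteration}.png'

print(output_name('data/merged/0001.jpg'))        # 0001.png
print(output_name('data/merged/0001.jpg', 4000))  # 0001_4000.png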