# Source code for mmedit.models.losses.composition_loss

import torch.nn as nn

from ..registry import LOSSES
from .pixelwise_loss import charbonnier_loss, l1_loss, mse_loss

# Reduction modes supported by the loss classes in this module; this list is
# interpolated into their "Unsupported reduction mode" error messages.
_reduction_modes = ['none', 'mean', 'sum']


@LOSSES.register_module()
class L1CompositionLoss(nn.Module):
    """L1 composition loss.

    The predicted alpha matte is used to composite the foreground over the
    background; the composite is then compared with the original merged
    image using ``l1_loss``.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce the loss with 'mean' per-sample, and then take the mean
            over all the samples. Default: False.

    Raises:
        ValueError: If `reduction` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
        super().__init__()
        # Validate against the same module-level list that the error message
        # reports, so the check and the message cannot drift apart.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise

    def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs):
        """Compute the L1 loss between the composited and original image.

        Args:
            pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
            fg (Tensor): of shape (N, 3, H, W). Tensor of foreground object.
            bg (Tensor): of shape (N, 3, H, W). Tensor of background object.
            ori_merged (Tensor): of shape (N, 3, H, W). Tensor of origin
                merged image before normalized by ImageNet mean and std.
            weight (Tensor, optional): of shape (N, 1, H, W). It is an
                indicating matrix: weight[trimap == 128] = 1. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The result of ``l1_loss`` on the composited image, multiplied by
            ``loss_weight``.
        """
        # Alpha compositing: blend foreground and background with the
        # predicted matte.
        pred_merged = pred_alpha * fg + (1. - pred_alpha) * bg
        if weight is not None:
            # Repeat the single-channel mask across the 3 image channels
            # (expand creates a view, not a copy).
            weight = weight.expand(-1, 3, -1, -1)
        return self.loss_weight * l1_loss(
            pred_merged,
            ori_merged,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@LOSSES.register_module()
class MSECompositionLoss(nn.Module):
    """MSE (L2) composition loss.

    The predicted alpha matte is used to composite the foreground over the
    background; the composite is then compared with the original merged
    image using ``mse_loss``.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce the loss with 'mean' per-sample, and then take the mean
            over all the samples. Default: False.

    Raises:
        ValueError: If `reduction` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
        super().__init__()
        # Validate against the same module-level list that the error message
        # reports, so the check and the message cannot drift apart.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise

    def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs):
        """Compute the MSE loss between the composited and original image.

        Args:
            pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
            fg (Tensor): of shape (N, 3, H, W). Tensor of foreground object.
            bg (Tensor): of shape (N, 3, H, W). Tensor of background object.
            ori_merged (Tensor): of shape (N, 3, H, W). Tensor of origin
                merged image before normalized by ImageNet mean and std.
            weight (Tensor, optional): of shape (N, 1, H, W). It is an
                indicating matrix: weight[trimap == 128] = 1. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The result of ``mse_loss`` on the composited image, multiplied by
            ``loss_weight``.
        """
        # Alpha compositing: blend foreground and background with the
        # predicted matte.
        pred_merged = pred_alpha * fg + (1. - pred_alpha) * bg
        if weight is not None:
            # Repeat the single-channel mask across the 3 image channels
            # (expand creates a view, not a copy).
            weight = weight.expand(-1, 3, -1, -1)
        return self.loss_weight * mse_loss(
            pred_merged,
            ori_merged,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
@LOSSES.register_module()
class CharbonnierCompLoss(nn.Module):
    """Charbonnier composition loss.

    The predicted alpha matte is used to composite the foreground over the
    background; the composite is then compared with the original merged
    image using ``charbonnier_loss``.

    Args:
        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. It will first
            reduce the loss with 'mean' per-sample, and then take the mean
            over all the samples. Default: False.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If `reduction` is not one of the supported modes.
    """

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 sample_wise=False,
                 eps=1e-12):
        super().__init__()
        # Validate against the same module-level list that the error message
        # reports, so the check and the message cannot drift apart.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise
        self.eps = eps

    def forward(self, pred_alpha, fg, bg, ori_merged, weight=None, **kwargs):
        """Compute the Charbonnier loss between composited and original image.

        Args:
            pred_alpha (Tensor): of shape (N, 1, H, W). Predicted alpha matte.
            fg (Tensor): of shape (N, 3, H, W). Tensor of foreground object.
            bg (Tensor): of shape (N, 3, H, W). Tensor of background object.
            ori_merged (Tensor): of shape (N, 3, H, W). Tensor of origin
                merged image before normalized by ImageNet mean and std.
            weight (Tensor, optional): of shape (N, 1, H, W). It is an
                indicating matrix: weight[trimap == 128] = 1. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The result of ``charbonnier_loss`` on the composited image,
            multiplied by ``loss_weight``.
        """
        # Alpha compositing: blend foreground and background with the
        # predicted matte.
        pred_merged = pred_alpha * fg + (1. - pred_alpha) * bg
        if weight is not None:
            # Repeat the single-channel mask across the 3 image channels
            # (expand creates a view, not a copy).
            weight = weight.expand(-1, 3, -1, -1)
        return self.loss_weight * charbonnier_loss(
            pred_merged,
            ori_merged,
            weight,
            eps=self.eps,
            reduction=self.reduction,
            sample_wise=self.sample_wise)