Source code for mmedit.models.losses.pixelwise_loss

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..registry import LOSSES
from .utils import masked_loss

_reduction_modes = ['none', 'mean', 'sum']
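# The `masked_loss` decorator imported from `.utils` is what turns the plain
# element-wise losses below into weighted, reducible losses. A minimal sketch
# of its assumed behavior follows; the name `_masked_loss_sketch` is
# hypothetical and illustrative only (the real implementation lives in
# `.utils`, where `sample_wise` also controls per-sample mean reduction,
# simplified away here).
import functools


def _masked_loss_sketch(loss_func):
    """Hypothetical stand-in for the real ``masked_loss`` decorator."""

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                sample_wise=False,
                **kwargs):
        # Raw element-wise loss, shape (n, c, h, w).
        loss = loss_func(pred, target, **kwargs)
        # Optional element-wise mask/weight.
        if weight is not None:
            loss = loss * weight
        # Standard reduction modes.
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss

    return wrapper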


@masked_loss
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Calculated L1 loss.
    """
    return F.l1_loss(pred, target, reduction='none')


@masked_loss
def mse_loss(pred, target):
    """MSE loss.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).

    Returns:
        Tensor: Calculated MSE loss.
    """
    return F.mse_loss(pred, target, reduction='none')


@masked_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Charbonnier loss.

    Args:
        pred (Tensor): Prediction Tensor with shape (n, c, h, w).
        target (Tensor): Target Tensor with shape (n, c, h, w).
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Returns:
        Tensor: Calculated Charbonnier loss.
    """
    return torch.sqrt((pred - target)**2 + eps)
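# A quick usage sketch for the decorated functional losses above, assuming
# `masked_loss` adds the `weight`/`reduction`/`sample_wise` keyword arguments
# as outlined earlier:
#
#     >>> pred = torch.rand(2, 3, 8, 8)
#     >>> target = torch.rand(2, 3, 8, 8)
#     >>> l1_loss(pred, target, reduction='mean')  # scalar tensor
#     >>> mask = (torch.rand(2, 1, 8, 8) > 0.5).float()
#     >>> charbonnier_loss(pred, target, weight=mask, reduction='sum')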


@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. The loss is
            first reduced with 'mean' per sample, and then averaged over
            all samples. Default: False.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
        super(L1Loss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise
    def forward(self, pred, target, weight=None, **kwargs):
        """Forward function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * l1_loss(
            pred,
            target,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
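# A minimal usage sketch for the module form of the loss:
#
#     >>> criterion = L1Loss(loss_weight=1.0, reduction='mean')
#     >>> pred = torch.rand(2, 3, 8, 8)
#     >>> target = torch.rand(2, 3, 8, 8)
#     >>> loss = criterion(pred, target)  # scalar: mean of |pred - target|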
@LOSSES.register_module()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. The loss is
            first reduced with 'mean' per sample, and then averaged over
            all samples. Default: False.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False):
        super(MSELoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise
    def forward(self, pred, target, weight=None, **kwargs):
        """Forward function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * mse_loss(
            pred,
            target,
            weight,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
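# Because these classes are registered in `LOSSES`, they are usually built
# from config dicts rather than instantiated directly. A sketch of the usual
# registry pattern (the exact builder import may differ across versions):
#
#     >>> loss_cfg = dict(type='MSELoss', loss_weight=1.0, reduction='mean')
#     >>> # from mmedit.models.builder import build_loss
#     >>> # criterion = build_loss(loss_cfg)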
@LOSSES.register_module()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (a variant of robust L1 loss, which is a
    differentiable variant of L1 loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Loss weight for the Charbonnier loss.
            Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        sample_wise (bool): Whether to calculate the loss sample-wise. This
            argument only takes effect when `reduction` is 'mean' and
            `weight` (argument of `forward()`) is not None. The loss is
            first reduced with 'mean' per sample, and then averaged over
            all samples. Default: False.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.
    """

    def __init__(self,
                 loss_weight=1.0,
                 reduction='mean',
                 sample_wise=False,
                 eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.sample_wise = sample_wise
        self.eps = eps
    def forward(self, pred, target, weight=None, **kwargs):
        """Forward function.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        return self.loss_weight * charbonnier_loss(
            pred,
            target,
            weight,
            eps=self.eps,
            reduction=self.reduction,
            sample_wise=self.sample_wise)
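# Numerically, the Charbonnier loss behaves like L1 for large residuals but
# stays smooth near zero thanks to `eps`. A small sanity-check sketch using
# the functional form defined above:
#
#     >>> pred = torch.tensor([0.0, 1.0])
#     >>> target = torch.tensor([0.0, 0.0])
#     >>> charbonnier_loss(pred, target, eps=1e-12, reduction='none')
#     tensor([1.0000e-06, 1.0000e+00])  # sqrt(0 + eps), sqrt(1 + eps)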
@LOSSES.register_module()
class MaskedTVLoss(L1Loss):
    """Masked TV loss.

    Args:
        loss_weight (float, optional): Loss weight. Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(MaskedTVLoss, self).__init__(loss_weight=loss_weight)
    def forward(self, pred, mask=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Tensor with shape of (n, c, h, w).
            mask (torch.Tensor, optional): Tensor with shape of (n, 1, h, w).
                Defaults to None.

        Returns:
            torch.Tensor: Calculated masked TV loss.
        """
        # Slice the mask to match the shifted prediction maps; if no mask is
        # given, fall back to an unweighted TV loss instead of crashing.
        y_weight = mask[:, :, :-1, :] if mask is not None else None
        x_weight = mask[:, :, :, :-1] if mask is not None else None
        y_diff = super(MaskedTVLoss, self).forward(
            pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(MaskedTVLoss, self).forward(
            pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss
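# A usage sketch for the masked TV loss, e.g. in inpainting where `mask`
# marks the hole region whose total variation should be penalized:
#
#     >>> tv_loss = MaskedTVLoss(loss_weight=0.01)
#     >>> pred = torch.rand(1, 3, 16, 16)
#     >>> mask = torch.zeros(1, 1, 16, 16)
#     >>> mask[:, :, 4:12, 4:12] = 1.  # penalize variation inside the hole
#     >>> loss = tv_loss(pred, mask=mask)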