Source code for mmedit.models.common.linear_module

import torch.nn as nn
from mmcv.cnn import build_activation_layer, kaiming_init


class LinearModule(nn.Module):
    """A linear block that bundles a linear layer and an activation layer.

    Spectral normalization can optionally be applied to the linear layer.

    Args:
        in_features (int): Same as nn.Linear.
        out_features (int): Same as nn.Linear.
        bias (bool): Same as nn.Linear. Default: True.
        act_cfg (dict | None): Config dict for activation layer, "relu" by
            default. ``None`` disables the activation layer.
        inplace (bool): Whether to use inplace mode for activation. It is
            only forwarded to activation types that accept an ``inplace``
            argument. Default: True.
        with_spectral_norm (bool): Whether use spectral norm in linear
            module. Default: False.
        order (tuple[str]): The order of linear/activation layers. It is a
            sequence of "linear" and "act". Examples are ("linear", "act")
            and ("act", "linear").
    """

    # Activation types whose constructors take no ``inplace`` argument;
    # passing it to them raises TypeError. Mirrors the guard used by mmcv's
    # ConvModule.
    _ACT_TYPES_WITHOUT_INPLACE = ('Tanh', 'PReLU', 'Sigmoid', 'HSigmoid',
                                  'Swish', 'GELU')

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 order=('linear', 'act')):
        super().__init__()
        assert act_cfg is None or isinstance(act_cfg, dict), (
            'act_cfg must be None or a dict')
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 2, (
            'order must be a tuple of length 2')
        assert set(order) == {'linear', 'act'}, (
            'order must contain exactly "linear" and "act"')

        self.with_activation = act_cfg is not None
        self.with_bias = bias

        # build linear layer
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        # export the attributes of self.linear to a higher level for
        # convenience
        self.in_features = self.linear.in_features
        self.out_features = self.linear.out_features

        if self.with_spectral_norm:
            self.linear = nn.utils.spectral_norm(self.linear)

        # build activation layer
        if self.with_activation:
            # Work on a copy so the caller's dict (and the shared default
            # argument) is never mutated.
            act_cfg_ = act_cfg.copy()
            # Only inject ``inplace`` where the activation accepts it; an
            # explicit ``inplace`` key supplied by the caller is kept as is.
            if act_cfg_.get('type') not in self._ACT_TYPES_WITHOUT_INPLACE:
                act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

        # Use msra init by default
        self.init_weights()

    def init_weights(self):
        """Initialize the linear layer with ``kaiming_init``.

        When the configured activation is ``LeakyReLU`` the 'leaky_relu'
        nonlinearity is used together with its ``negative_slope`` (0.01 if
        not given in ``act_cfg``); otherwise 'relu' with ``a=0`` is used.
        """
        if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
            nonlinearity = 'leaky_relu'
            a = self.act_cfg.get('negative_slope', 0.01)
        else:
            nonlinearity = 'relu'
            a = 0

        kaiming_init(self.linear, a=a, nonlinearity=nonlinearity)

    def forward(self, x, activate=True):
        r"""Forward Function.

        Args:
            x (torch.Tensor): Input tensor with shape of (n, \*, c).
                Same as ``torch.nn.Linear``.
            activate (bool, optional): Whether to use activation layer.
                Defaults to True.

        Returns:
            torch.Tensor: Same as ``torch.nn.Linear``.
        """
        # The activation runs only if requested AND one was configured.
        use_activation = activate and self.with_activation
        for layer in self.order:
            if layer == 'linear':
                x = self.linear(x)
            elif layer == 'act' and use_activation:
                x = self.activate(x)
        return x