import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import load_checkpoint
from mmedit.models.common import flow_warp
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger
class BasicModule(nn.Module):
    """Per-level flow estimator used inside SPyNet.

    The module is a stack of five 7x7 ``ConvModule`` layers (stride 1,
    padding 3). The first four layers are configured with batch
    normalization and ReLU; the last layer is configured with neither and
    produces the 2-channel output.

    Note that unlike the common SPyNet architecture, the basic module here
    contains batch normalization.
    """

    def __init__(self):
        super().__init__()

        # Channel widths of the normalized + activated layers:
        # 8 -> 32 -> 64 -> 32 -> 16.
        widths = (8, 32, 64, 32, 16)

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(
                ConvModule(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=7,
                    stride=1,
                    padding=3,
                    # Fresh config dicts per layer, as in a hand-written
                    # layer list.
                    norm_cfg=dict(type='BN'),
                    act_cfg=dict(type='ReLU')))

        # Output layer: plain convolution, no normalization or activation.
        stages.append(
            ConvModule(
                in_channels=widths[-1],
                out_channels=2,
                kernel_size=7,
                stride=1,
                padding=3,
                norm_cfg=None,
                act_cfg=None))

        self.basic_module = nn.Sequential(*stages)

    def forward(self, tensor_input):
        """Run the convolution stack on the concatenated input.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Estimated flow with shape (b, 2, h, w)
        """
        estimated_flow = self.basic_module(tensor_input)
        return estimated_flow
class SPyNet(nn.Module):
    """SPyNet architecture.

    Note that this implementation is specifically for TOFlow. It differs from
    the common SPyNet in the following aspects:
        1. The basic modules here contain BatchNorm.
        2. Normalization and denormalization are not done here, as
            they are done in TOFlow.

    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network
    Code reference:
        https://github.com/Coldog2333/pytoflow

    Args:
        load_path (str, optional): Kept for interface compatibility; it is
            not used by the code in this class. Default: None.
    """

    def __init__(self, load_path=None):
        super().__init__()

        # One flow estimator per pyramid level, ordered coarse to fine.
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])

    def forward(self, ref, supp):
        """Estimate the flow between two frames, coarse to fine.

        Args:
            ref (Tensor): Reference image with shape of (b, 3, h, w).
            supp: The supporting image to be warped: (b, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (b, 2, h, w).
        """
        num_levels = len(self.basic_module)
        num_batches = ref.size(0)

        # Build the image pyramids by repeated 2x average pooling. They are
        # collected fine-to-coarse and then reversed so that index 0 is the
        # coarsest level.
        ref_pyramid = [ref]
        supp_pyramid = [supp]
        for _ in range(num_levels - 1):
            ref_pyramid.append(
                F.avg_pool2d(
                    input=ref_pyramid[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
            supp_pyramid.append(
                F.avg_pool2d(
                    input=supp_pyramid[-1],
                    kernel_size=2,
                    stride=2,
                    count_include_pad=False))
        ref_pyramid.reverse()
        supp_pyramid.reverse()

        # Flow computation. The flow of each level is sized from the pyramid
        # level itself rather than from a fixed ``h // 16`` start and a fixed
        # upsampling factor, so the shapes passed to ``flow_warp`` and
        # ``torch.cat`` always agree, also when h or w is not a multiple
        # of 16. For sizes where pooling halves exactly, the result is the
        # same as upsampling by a factor of 2.
        flow = None
        for estimator, ref_lvl, supp_lvl in zip(self.basic_module,
                                                ref_pyramid, supp_pyramid):
            level_size = tuple(ref_lvl.shape[-2:])
            if flow is None:
                # Coarsest level starts from a zero flow.
                flow_up = ref_lvl.new_zeros(num_batches, 2, *level_size)
            else:
                # Upsample the previous estimate; displacements are doubled
                # because the resolution (nominally) doubles.
                flow_up = F.interpolate(
                    input=flow,
                    size=level_size,
                    mode='bilinear',
                    align_corners=True) * 2.0

            # flow_warp receives the flow in channel-last layout.
            warped_supp = flow_warp(supp_lvl, flow_up.permute(0, 2, 3, 1))

            # The estimator predicts a residual on top of the upsampled flow.
            flow = flow_up + estimator(
                torch.cat([ref_lvl, warped_supp, flow_up], 1))

        return flow
@BACKBONES.register_module()
class TOFlow(nn.Module):
    """PyTorch implementation of TOFlow.

    In TOFlow, the LR frames are pre-upsampled and have the same size with
    the GT frames.

    Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
    Code reference:
        1. https://github.com/anchen1011/toflow
        2. https://github.com/Coldog2333/pytoflow

    Args:
        adapt_official_weights (bool): Whether to adapt the weights translated
            from the official implementation. Set to false if you want to
            train from scratch. Default: False
    """

    def __init__(self, adapt_official_weights=False):
        super().__init__()

        self.adapt_official_weights = adapt_official_weights
        # With official weights the frames are reordered in ``forward`` so
        # that the reference frame sits at index 0; otherwise the middle
        # frame (index 3) is the reference.
        self.ref_idx = 0 if adapt_official_weights else 3

        # The mean and std are for img with range (0, 1)
        self.register_buffer(
            'mean',
            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std',
            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # flow estimation module
        self.spynet = SPyNet()

        # reconstruction module; the first conv takes the 7 aligned RGB
        # frames stacked along the channel dimension (3 * 7 channels)
        self.conv1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv2 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv3 = nn.Conv2d(64, 64, 1)
        self.conv4 = nn.Conv2d(64, 3, 1)

        # activation function
        self.relu = nn.ReLU(inplace=True)

    def normalize(self, img):
        """Normalize the input image.

        Args:
            img (Tensor): Input image.

        Returns:
            Tensor: Normalized image.
        """
        return (img - self.mean) / self.std

    def denormalize(self, img):
        """Denormalize the output image.

        Args:
            img (Tensor): Output image.

        Returns:
            Tensor: Denormalized image.
        """
        return img * self.std + self.mean

    def forward(self, lrs):
        """Align the neighboring frames to the reference and reconstruct.

        Args:
            lrs: Input lr frames: (b, 7, 3, h, w).

        Returns:
            Tensor: SR frame: (b, 3, h, w).
        """
        # In the official implementation, the 0-th frame is the reference frame
        if self.adapt_official_weights:
            lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]

        num_batches, num_lrs, _, h, w = lrs.size()

        # ``reshape`` (instead of ``view``) also accepts non-contiguous
        # input, e.g. a permuted or sliced tensor.
        lrs = self.normalize(lrs.reshape(-1, 3, h, w))
        lrs = lrs.reshape(num_batches, num_lrs, 3, h, w)

        lr_ref = lrs[:, self.ref_idx, :, :, :]
        lr_aligned = []
        for i in range(7):  # 7 frames
            if i == self.ref_idx:
                # The reference frame needs no alignment.
                lr_aligned.append(lr_ref)
            else:
                lr_supp = lrs[:, i, :, :, :]
                flow = self.spynet(lr_ref, lr_supp)
                # flow_warp receives the flow in channel-last layout.
                lr_aligned.append(
                    flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))

        # reconstruction: merge the frame and channel dimensions
        hr = torch.stack(lr_aligned, dim=1)
        hr = hr.view(num_batches, -1, h, w)
        hr = self.relu(self.conv1(hr))
        hr = self.relu(self.conv2(hr))
        hr = self.relu(self.conv3(hr))
        # The network predicts a residual on top of the reference frame.
        hr = self.conv4(hr) + lr_ref

        return self.denormalize(hr)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')