from collections.abc import Sequence
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from torch.nn import functional as F
from ..registry import PIPELINES
def to_tensor(data):
    """Turn a supported python object into a :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data: The object to convert.

    Returns:
        torch.Tensor: ``data`` itself when it is already a tensor, otherwise
        a tensor built from it.

    Raises:
        TypeError: If ``data`` is none of the supported types.
    """
    # Tensors pass through untouched (same object is returned).
    if isinstance(data, torch.Tensor):
        return data

    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)

    # Strings are sequences too, but they are not convertible.
    if isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)

    # Scalars become one-element tensors.
    if isinstance(data, int):
        return torch.LongTensor([data])

    if isinstance(data, float):
        return torch.FloatTensor([data])

    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module()
class ToTensor:
    """Pipeline step that converts selected entries of the results dict to
    `torch.Tensor`.

    Args:
        keys (Sequence[str]): Required keys to be converted.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Convert every configured entry in place.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: The same dict, with each entry named in ``keys`` replaced
                by its tensor conversion.
        """
        for name in self.keys:
            value = results[name]
            results[name] = to_tensor(value)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(keys={self.keys})'
@PIPELINES.register_module()
class ImageToTensor(object):
    """Convert image type to `torch.Tensor` type.

    Each image is moved from channel-last to channel-first layout via
    ``transpose(2, 0, 1)`` before being converted.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        to_float32 (bool): Whether convert numpy image array to np.float32
            before converted to tensor. Default: True.
    """

    def __init__(self, keys, to_float32=True):
        self.keys = keys
        self.to_float32 = to_float32

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.
        """
        for key in self.keys:
            img = results[key]
            # deal with gray scale img: expand a color channel so that the
            # 3-axis transpose below is valid
            if len(img.shape) == 2:
                img = img[..., None]
            # Compare the array's dtype. An ndarray is never an instance of
            # the scalar type np.float32, so an isinstance() test here would
            # always force a conversion, even for float32 input.
            if self.to_float32 and img.dtype != np.float32:
                img = img.astype(np.float32)
            results[key] = to_tensor(img.transpose(2, 0, 1))
        return results

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, to_float32={self.to_float32})')
@PIPELINES.register_module()
class FramesToTensor(ImageToTensor):
    """Convert frames type to `torch.Tensor` type.

    It accepts a list of frames, converts each to `torch.Tensor` type and then
    concatenates in a new dimension (dim=0). When the list holds a single
    frame, that leading dimension is removed again.

    Args:
        keys (Sequence[str]): Required keys to be converted.
        to_float32 (bool): Whether convert numpy image array to np.float32
            before converted to tensor. Default: True.
    """

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            TypeError: If the entry stored under one of ``keys`` is not a
                list.
        """
        for key in self.keys:
            frames = results[key]
            if not isinstance(frames, list):
                raise TypeError(f'results["{key}"] should be a list, '
                                f'but got {type(frames)}')
            # Collect the converted frames in a fresh list so the caller's
            # list object is left unmodified.
            tensors = []
            for frame in frames:
                # deal with gray scale img: expand a color channel
                if len(frame.shape) == 2:
                    frame = frame[..., None]
                # Compare the array's dtype. An ndarray is never an instance
                # of the scalar type np.float32, so an isinstance() test
                # would always force a conversion.
                if self.to_float32 and frame.dtype != np.float32:
                    frame = frame.astype(np.float32)
                tensors.append(to_tensor(frame.transpose(2, 0, 1)))
            stacked = torch.stack(tensors, dim=0)
            # Drop only the frame axis for a single frame. A bare squeeze
            # would also remove a size-1 channel (or spatial) axis.
            if stacked.size(0) == 1:
                stacked = stacked.squeeze(0)
            results[key] = stacked
        return results
@PIPELINES.register_module()
class GetMaskedImage:
    """Produce the masked version of an image.

    The result is ``image * (1 - mask)`` and is stored under the
    ``'masked_img'`` key.

    Args:
        img_name (str): Key for clean image.
        mask_name (str): Key for mask image. The mask shape should be
            (h, w, 1) while '1' indicate holes and '0' indicate valid
            regions.
    """

    def __init__(self, img_name='gt_img', mask_name='mask'):
        self.img_name = img_name
        self.mask_name = mask_name

    def __call__(self, results):
        """Compute the masked image and add it to ``results``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: The same dict with an added ``'masked_img'`` entry.
        """
        # Look up the image before the mask, then multiply by the inverted
        # mask.
        image = results[self.img_name]
        hole_mask = results[self.mask_name]
        results['masked_img'] = image * (1. - hole_mask)
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return (f"{cls_name}(img_name='{self.img_name}', "
                f"mask_name='{self.mask_name}')")
@PIPELINES.register_module()
class Collect(object):
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img", "gt_labels".

    The "meta" item is always populated. The contents of the "meta"
    dictionary depends on "meta_keys".

    Args:
        keys (Sequence[str]): Required keys to be collected.
        meta_keys (Sequence[str]): Required keys to be collected to "meta".
            Default: None, which collects no meta entries.
    """

    def __init__(self, keys, meta_keys=None):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation.

        Returns:
            dict: A new dict holding "meta" (the selected meta entries
                wrapped in a DataContainer created with ``cpu_only=True``)
                followed by every entry named in ``keys``.
        """
        # The default meta_keys=None must not be iterated directly; treat it
        # as "no meta keys" so the default configuration works.
        meta_keys = () if self.meta_keys is None else self.meta_keys
        img_meta = {key: results[key] for key in meta_keys}

        data = {}
        data['meta'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + (
            f'(keys={self.keys}, meta_keys={self.meta_keys})')