Source code for mmedit.models.common.model_utils
import numpy as np
import torch
[docs]def set_requires_grad(nets, requires_grad=False):
"""Set requies_grad for all the networks.
Args:
nets (nn.Module | list[nn.Module]): A list of networks or a single
network.
requires_grad (bool): Whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
[docs]def extract_bbox_patch(bbox, img, channel_first=True):
"""Extract patch from a given bbox
Args:
bbox (torch.Tensor | numpy.array): Bbox with (top, left, h, w). If
`img` has batch dimension, the `bbox` must be stacked at first
dimension. The shape should be (4,) or (n, 4).
img (torch.Tensor | numpy.array): Image data to be extracted. If
organized in batch dimension, the batch dimension must be the first
order like (n, h, w, c) or (n, c, h, w).
channel_first (bool): If True, the channel dimension of img is before
height and width, e.g. (c, h, w). Otherwise, the img shape (samples
in the batch) is like (h, w, c).
Returns:
(torch.Tensor | numpy.array): Extracted patches. The dimension of the \
output should be the same as `img`.
"""
def _extract(bbox, img):
assert len(bbox) == 4
t, l, h, w = bbox
if channel_first:
img_patch = img[..., t:t + h, l:l + w]
else:
img_patch = img[t:t + h, l:l + w, ...]
return img_patch
input_size = img.shape
assert len(input_size) == 3 or len(input_size) == 4
bbox_size = bbox.shape
assert bbox_size == (4, ) or (len(bbox_size) == 2
and bbox_size[0] == input_size[0])
# images with batch dimension
if len(input_size) == 4:
output_list = []
for i in range(input_size[0]):
img_patch_ = _extract(bbox[i], img[i:i + 1, ...])
output_list.append(img_patch_)
if isinstance(img, torch.Tensor):
img_patch = torch.cat(output_list, dim=0)
else:
img_patch = np.concatenate(output_list, axis=0)
# standardize image
else:
img_patch = _extract(bbox, img)
return img_patch
[docs]def scale_bbox(bbox, target_size):
"""Modify bbox to target size.
The original bbox will be enlarged to the target size with the original
bbox in the center of the new bbox.
Args:
bbox (np.ndarray | torch.Tensor): Bboxes to be modified. Bbox can
be in batch or not. The shape should be (4,) or (n, 4).
target_size (tuple[int]): Target size of final bbox.
Returns:
(np.ndarray | torch.Tensor): Modified bboxes.
"""
def _mod(bbox, target_size):
top_ori, left_ori, h_ori, w_ori = bbox
h, w = target_size
assert h >= h_ori and w >= w_ori
top = int(max(0, top_ori - (h - h_ori) // 2))
left = int(max(0, left_ori - (w - w_ori) // 2))
if isinstance(bbox, torch.Tensor):
bbox_new = torch.Tensor([top, left, h, w]).type_as(bbox)
else:
bbox_new = np.asarray([top, left, h, w])
return bbox_new
if isinstance(bbox, torch.Tensor):
bbox_new = torch.zeros_like(bbox)
elif isinstance(bbox, np.ndarray):
bbox_new = np.zeros_like(bbox)
else:
raise TypeError('bbox mush be torch.Tensor or numpy.ndarray'
f'but got type {type(bbox)}')
bbox_shape = list(bbox.shape)
if len(bbox_shape) == 2:
for i in range(bbox_shape[0]):
bbox_new[i, :] = _mod(bbox[i], target_size)
else:
bbox_new = _mod(bbox, target_size)
return bbox_new
[docs]def extract_around_bbox(img, bbox, target_size, channel_first=True):
"""Extract patches around the given bbox.
Args:
bbox (np.ndarray | torch.Tensor): Bboxes to be modified. Bbox can
be in batch or not.
target_size (List(int)): Target size of final bbox.
Returns:
(torch.Tensor | numpy.array): Extracted patches. The dimension of the \
output should be the same as `img`.
"""
bbox_new = scale_bbox(bbox, target_size)
img_patch = extract_bbox_patch(bbox_new, img, channel_first=channel_first)
return img_patch, bbox_new