Source code for mmedit.core.hooks.visualization

import os.path as osp

import mmcv
import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.dist_utils import master_only
from torchvision.utils import save_image


[docs]@HOOKS.register_module() class VisualizationHook(Hook): """Visualization hook. In this hook, we use the official api `save_image` in torchvision to save the visualization results. Args: output_dir (str): The file path to store visualizations. res_name_list (str): The list contains the name of results in outputs dict. The results in outputs dict must be a torch.Tensor with shape (n, c, h, w). interval (int): The interval of calling this hook. If set to -1, the visualization hook will not be called. Default: -1. filename_tmpl (str): Format string used to save images. The output file name will be formatted as this args. Default: 'iter_{}.png'. rerange (bool): Whether to rerange the output value from [-1, 1] to [0, 1]. We highly recommend users should preprocess the visualization results on their own. Here, we just provide a simple interface. Default: True. bgr2rgb (bool): Whether to reformat the channel dimension from BGR to RGB. The final image we will save is following RGB style. Default: True. nrow (int): The number of samples in a row. Default: 1. padding (int): The number of padding pixels between each samples. Default: 4. """ def __init__(self, output_dir, res_name_list, interval=-1, filename_tmpl='iter_{}.png', rerange=True, bgr2rgb=True, nrow=1, padding=4): assert mmcv.is_list_of(res_name_list, str) self.output_dir = output_dir self.res_name_list = res_name_list self.interval = interval self.filename_tmpl = filename_tmpl self.bgr2rgb = bgr2rgb self.rerange = rerange self.nrow = nrow self.padding = padding mmcv.mkdir_or_exist(self.output_dir)
[docs] @master_only def after_train_iter(self, runner): """The behavior after each train iteration. Args: runner (object): The runner. """ if not self.every_n_iters(runner, self.interval): return results = runner.outputs['results'] filename = self.filename_tmpl.format(runner.iter + 1) img_list = [x for k, x in results.items() if k in self.res_name_list] img_cat = torch.cat(img_list, dim=3).detach() if self.rerange: img_cat = ((img_cat + 1) / 2) if self.bgr2rgb: img_cat = img_cat[:, [2, 1, 0], ...] img_cat = img_cat.clamp_(0, 1) save_image( img_cat, osp.join(self.output_dir, filename), nrow=self.nrow, padding=self.padding)