Source code for mmedit.core.evaluation.eval_hooks

import os.path as osp

from mmcv.runner import Hook
from torch.utils.data import DataLoader


class EvalIterHook(Hook):
    """Non-distributed evaluation hook for iteration-based runners.

    This hook performs evaluation at a given interval when running in a
    non-distributed environment.

    Args:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval. Default: 1.
        eval_kwargs (dict): Other eval kwargs. It contains:
            save_image (bool): Whether to save images.
            save_path (str): The path to save images.
    """

    def __init__(self, dataloader, interval=1, **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError('dataloader must be a pytorch DataLoader, '
                            f'but got {type(dataloader)}')
        self.dataloader = dataloader
        self.interval = interval
        self.eval_kwargs = eval_kwargs
        self.save_image = self.eval_kwargs.pop('save_image', False)
        self.save_path = self.eval_kwargs.pop('save_path', None)
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        runner.log_buffer.clear()
        # Lazy import so that mmedit.apis is only pulled in when the hook
        # actually runs, not at module import time.
        from mmedit.apis import single_gpu_test
        results = single_gpu_test(
            runner.model,
            self.dataloader,
            save_image=self.save_image,
            save_path=self.save_path,
            iteration=runner.iter)
        self.evaluate(runner, results)
    def evaluate(self, runner, results):
        """Evaluation function.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
            results (dict): Model forward results.
        """
        eval_res = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        # Mark the buffer ready so the logger hooks pick up the results.
        runner.log_buffer.ready = True
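
# Usage sketch (editorial example, not part of the original module): build a
# validation dataloader and register the hook on an iteration-based runner.
# ``runner`` and ``val_dataset`` are assumed to exist; the interval and the
# save path are illustrative.
def _example_register_eval_hook(runner, val_dataset):
    """Register an ``EvalIterHook`` that evaluates every 5000 iterations."""
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    runner.register_hook(
        EvalIterHook(
            val_loader,
            interval=5000,
            save_image=True,
            save_path=osp.join(runner.work_dir, 'val_images')))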
class DistEvalIterHook(EvalIterHook):
    """Distributed evaluation hook.

    Args:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval. Default: 1.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Default: False.
        eval_kwargs (dict): Other eval kwargs. It may contain:
            save_image (bool): Whether to save images.
            save_path (str): The path to save images.
    """

    def __init__(self,
                 dataloader,
                 interval=1,
                 gpu_collect=False,
                 **eval_kwargs):
        super().__init__(dataloader, interval, **eval_kwargs)
        self.gpu_collect = gpu_collect
    def after_train_iter(self, runner):
        """The behavior after each train iteration.

        Args:
            runner (``mmcv.runner.BaseRunner``): The runner.
        """
        if not self.every_n_iters(runner, self.interval):
            return
        runner.log_buffer.clear()
        # Lazy import, as in EvalIterHook.after_train_iter.
        from mmedit.apis import multi_gpu_test
        # Results from all ranks are collected in a fixed temporary
        # directory under the runner's work_dir.
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect,
            save_image=self.save_image,
            save_path=self.save_path,
            iteration=runner.iter)
        # Only the rank-0 process evaluates and logs the collected results.
        if runner.rank == 0:
            print('\n')
            self.evaluate(runner, results)
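
# Usage sketch for the distributed variant (editorial example): registration
# looks the same, but the runner's model is assumed to be wrapped in
# ``mmcv.parallel.MMDistributedDataParallel`` and ``val_loader`` to use a
# distributed sampler, so every rank runs its shard of the evaluation.
def _example_register_dist_eval_hook(runner, val_loader):
    """Register a ``DistEvalIterHook`` that collects results on GPU."""
    runner.register_hook(
        DistEvalIterHook(val_loader, interval=5000, gpu_collect=True))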