diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index 0411f1ee..c7fe8c1e 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -20,6 +20,7 @@
 from .mae import MeanAbsoluteError
 from .matting_mse import MattingMeanSquaredError
 from .mean_iou import MeanIoU
+from .ms_ssim import MultiScaleStructureSimilarity
 from .mse import MeanSquaredError
 from .niqe import NaturalImageQualityEvaluator
 from .oid_map import OIDMeanAP
@@ -34,6 +35,7 @@
 from .sad import SumAbsoluteDifferences
 from .snr import SignalNoiseRatio
 from .ssim import StructuralSimilarity
+from .swd import SlicedWassersteinDistance
 from .voc_map import VOCMeanAP
 from .word_accuracy import WordAccuracy
 
@@ -48,7 +50,8 @@
     'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
     'WordAccuracy', 'PrecisionRecallF1score',
     'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
-    'CharRecallPrecision'
+    'CharRecallPrecision', 'MultiScaleStructureSimilarity',
+    'SlicedWassersteinDistance'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/ms_ssim.py b/mmeval/metrics/ms_ssim.py
new file mode 100644
index 00000000..4f63b3d8
--- /dev/null
+++ b/mmeval/metrics/ms_ssim.py
@@ -0,0 +1,295 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from scipy import signal
+from typing import Dict, List, Sequence, Tuple
+
+from mmeval.core import BaseMetric
+from .utils.image_transforms import reorder_image
+
+
+class MultiScaleStructureSimilarity(BaseMetric):
+    """MS-SSIM (Multi-Scale Structure Similarity) metric.
+
+    Ref:
+    This class implements Multi-Scale Structural Similarity (MS-SSIM) Image
+    Quality Assessment according to Zhou Wang's paper, "Multi-scale
+    structural similarity for image quality assessment" (2003).
+    Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf
+
+    Author's MATLAB implementation:
+    http://www.cns.nyu.edu/~lcv/ssim/msssim.zip
+
+    PGGAN's implementation:
+    https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/ms_ssim.py
+
+    Args:
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Defaults to 'CHW'.
+        max_val (int): The dynamic range of the images (i.e., the difference
+            between the maximum and the minimum allowed values).
+            Defaults to 255.
+        filter_size (int): Size of blur kernel to use (will be reduced for
+            small images). Defaults to 11.
+        filter_sigma (float): Standard deviation for Gaussian blur kernel
+            (will be reduced for small images). Defaults to 1.5.
+        k1 (float): Constant used to maintain stability in the SSIM
+            calculation (0.01 in the original paper). Defaults to 0.01.
+        k2 (float): Constant used to maintain stability in the SSIM
+            calculation (0.03 in the original paper). Defaults to 0.03.
+        weights (List[float]): List of weights for each level. Defaults to
+            [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]. Note that the default
+            weights do not sum to 1.0 but match the paper / MATLAB code.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Examples:
+
+        >>> from mmeval import MultiScaleStructureSimilarity as MS_SSIM
+        >>> import numpy as np
+        >>>
+        >>> ms_ssim = MS_SSIM()
+        >>> preds = [np.random.randint(0, 255, size=(3, 32, 32)) for _ in range(4)]  # noqa
+        >>> gts = [np.random.randint(0, 255, size=(3, 32, 32)) for _ in range(4)]  # noqa
+        >>> ms_ssim(preds, gts)  # doctest: +ELLIPSIS
+        {'ms-ssim': ...}
+    """
+
+    def __init__(self,
+                 input_order: str = 'CHW',
+                 max_val: int = 255,
+                 filter_size: int = 11,
+                 filter_sigma: float = 1.5,
+                 k1: float = 0.01,
+                 k2: float = 0.03,
+                 weights: List[float] = [
+                     0.0448, 0.2856, 0.3001, 0.2363, 0.1333
+                 ],
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        assert input_order.upper() in [
+            'CHW', 'HWC'
+        ], (f'Wrong input_order {input_order}. Supported input_orders are '
+            '"HWC" and "CHW"')
+        self.input_order = input_order
+
+        self.max_val = max_val
+        self.filter_size = filter_size
+        self.filter_sigma = filter_sigma
+        self.k1 = k1
+        self.k2 = k2
+        self.weights = np.array(weights)
+
+    def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarray]) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Add a batch of images to calculate the metric result.
+
+        Args:
+            predictions (Sequence[np.ndarray]): Predictions of the model.
+                The length of `predictions` must be the same as
+                `groundtruths`. The width and height of each element must be
+                divisible by 2 ** num_scales (`self.weights.size`). The
+                channel order of each element should align with
+                `self.input_order` and the range should be [0, 255].
+            groundtruths (Sequence[np.ndarray]): Groundtruths of the model.
+                The number of elements in the Sequence must be the same as
+                `predictions`, and the width and height of each element must
+                be divisible by 2 ** num_scales (`self.weights.size`). The
+                channel order of each element should align with
+                `self.input_order` and the range should be [0, 255].
+        """
+        assert len(predictions) == len(groundtruths), (
+            'The length of "predictions" and "groundtruths" must be the '
+            'same.')
+        half1, half2 = predictions, groundtruths
+
+        half1 = [reorder_image(samp, self.input_order) for samp in half1]
+        half2 = [reorder_image(samp, self.input_order) for samp in half2]
+        least_size = 2**self.weights.size
+        assert all([
+            sample.shape[0] % least_size == 0
+            and sample.shape[1] % least_size == 0 for sample in half1
+        ]), ('The height and width of each sample must be divisible by '
+             f'{least_size} (2 ** self.weights.size).')
+        assert all([
+            sample.shape[0] % least_size == 0
+            and sample.shape[1] % least_size == 0 for sample in half2
+        ]), ('The height and width of each sample must be divisible by '
+             f'{least_size} (2 ** self.weights.size).')
+
+        half1 = np.stack(half1, axis=0).astype(np.uint8)
+        half2 = np.stack(half2, axis=0).astype(np.uint8)
+
+        self._results += self.compute_ms_ssim(half1, half2)
+
+    def compute_metric(self, results: List[np.float64]) -> Dict[str, float]:
+        """Compute the MS-SSIM metric.
+
+        This method would be invoked in ``BaseMetric.compute`` after
+        distributed synchronization.
+
+        Args:
+            results (List[np.float64]): A list consisting of the MS-SSIM
+                scores. This list has already been synced across all ranks.
+
+        Returns:
+            Dict[str, float]: The computed MS-SSIM metric.
+        """
+        return {'ms-ssim': float(np.array(results).mean())}
+
+    def compute_ms_ssim(self, img1: np.ndarray,
+                        img2: np.ndarray) -> List[float]:
+        """Calculate MS-SSIM (multi-scale structural similarity).
+
+        Args:
+            img1 (np.ndarray): Images with range [0, 255] and order "NHWC".
+            img2 (np.ndarray): Images with range [0, 255] and order "NHWC".
+
+        Returns:
+            List[float]: MS-SSIM scores between `img1` and `img2`, one score
+                per image pair.
+        """
+        if img1.shape != img2.shape:
+            raise RuntimeError(
+                'Input images must have the same shape (%s vs. %s).' %
+                (img1.shape, img2.shape))
+        if img1.ndim != 4:
+            raise RuntimeError(
+                'Input images must have four dimensions, not %d' % img1.ndim)
+
+        levels = self.weights.size
+        im1, im2 = (x.astype(np.float32) for x in [img1, img2])
+        mssim = []
+        mcs = []
+        for _ in range(levels):
+            ssim, cs = self._ssim_for_multi_scale(
+                im1,
+                im2,
+                max_val=self.max_val,
+                filter_size=self.filter_size,
+                filter_sigma=self.filter_sigma,
+                k1=self.k1,
+                k2=self.k2)
+            mssim.append(ssim)
+            mcs.append(cs)
+            im1, im2 = (self._hox_downsample(x) for x in [im1, im2])
+
+        # Clip to zero. Otherwise we get NaNs.
+        mssim = np.clip(np.asarray(mssim), 0.0, np.inf)
+        mcs = np.clip(np.asarray(mcs), 0.0, np.inf)
+
+        results = np.prod(
+            mcs[:-1, :]**self.weights[:-1, np.newaxis], axis=0) * (
+                mssim[-1, :]**self.weights[-1])
+        return results.tolist()
+
+    @staticmethod
+    def _f_special_gauss(size: int, sigma: float) -> np.ndarray:
+        r"""Return a circularly symmetric Gaussian kernel.
+
+        Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/ms_ssim.py#L25-L36  # noqa
+
+        Args:
+            size (int): Size of Gaussian kernel.
+            sigma (float): Standard deviation for Gaussian blur kernel.
+
+        Returns:
+            np.ndarray: Gaussian kernel.
+        """
+        radius = size // 2
+        offset = 0.0
+        start, stop = -radius, radius + 1
+        if size % 2 == 0:
+            offset = 0.5
+            stop -= 1
+        x, y = np.mgrid[offset + start:stop,  # type: ignore # noqa
+                        offset + start:stop]  # type: ignore # noqa
+        assert len(x) == size
+        g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
+        return g / g.sum()
+
+    @staticmethod
+    def _hox_downsample(img: np.ndarray) -> np.ndarray:
+        r"""Downsample images by a factor of 2 via 2x2 average pooling.
+
+        Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/ms_ssim.py#L110-L111  # noqa
+
+        Args:
+            img (np.ndarray): Images with order "NHWC".
+
+        Returns:
+            np.ndarray: Downsampled images with order "NHWC".
+        """
+        return (img[:, 0::2, 0::2, :] + img[:, 1::2, 0::2, :] +
+                img[:, 0::2, 1::2, :] + img[:, 1::2, 1::2, :]) * 0.25
+
+    def _ssim_for_multi_scale(
+            self,
+            img1: np.ndarray,
+            img2: np.ndarray,
+            max_val: int = 255,
+            filter_size: int = 11,
+            filter_sigma: float = 1.5,
+            k1: float = 0.01,
+            k2: float = 0.03) -> Tuple[np.ndarray, np.ndarray]:
+        """Calculate SSIM (structural similarity) and contrast sensitivity.
+
+        Ref:
+        Our implementation is based on PGGAN:
+        https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/ms_ssim.py#L38-L108  # noqa
+
+        Args:
+            img1 (np.ndarray): Images with range [0, 255] and order "NHWC".
+            img2 (np.ndarray): Images with range [0, 255] and order "NHWC".
+            max_val (int): The dynamic range of the images (i.e., the
+                difference between the maximum and the minimum allowed
+                values). Defaults to 255.
+            filter_size (int): Size of blur kernel to use (will be reduced
+                for small images). Defaults to 11.
+            filter_sigma (float): Standard deviation for Gaussian blur
+                kernel (will be reduced for small images). Defaults to 1.5.
+            k1 (float): Constant used to maintain stability in the SSIM
+                calculation (0.01 in the original paper). Defaults to 0.01.
+            k2 (float): Constant used to maintain stability in the SSIM
+                calculation (0.03 in the original paper). Defaults to 0.03.
+
+        Returns:
+            tuple: Pair containing the mean SSIM and contrast sensitivity
+                between `img1` and `img2`.
+        """
+        img1 = img1.astype(np.float32)
+        img2 = img2.astype(np.float32)
+        _, height, width, _ = img1.shape
+
+        # Filter size can't be larger than height or width of images.
+        size = min(filter_size, height, width)
+
+        # Scale down sigma if a smaller filter size is used.
+        sigma = size * filter_sigma / filter_size if filter_size else 0
+
+        if filter_size:
+            window = np.reshape(
+                self._f_special_gauss(size, sigma), (1, size, size, 1))
+            mu1 = signal.fftconvolve(img1, window, mode='valid')
+            mu2 = signal.fftconvolve(img2, window, mode='valid')
+            sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
+            sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
+            sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
+        else:
+            # Empty blur kernel so no need to convolve.
+            mu1, mu2 = img1, img2
+            sigma11 = img1 * img1
+            sigma22 = img2 * img2
+            sigma12 = img1 * img2
+
+        mu11 = mu1 * mu1
+        mu22 = mu2 * mu2
+        mu12 = mu1 * mu2
+        sigma11 -= mu11
+        sigma22 -= mu22
+        sigma12 -= mu12
+
+        # Calculate intermediate values used by both ssim and cs_map.
+        c1 = (k1 * max_val)**2
+        c2 = (k2 * max_val)**2
+        v1 = 2.0 * sigma12 + c2
+        v2 = sigma11 + sigma22 + c2
+        ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)),
+                       axis=(1, 2, 3))  # Return for each image individually.
+        cs = np.mean(v1 / v2, axis=(1, 2, 3))
+        return ssim, cs
diff --git a/mmeval/metrics/swd.py b/mmeval/metrics/swd.py
new file mode 100644
index 00000000..7cbc89fb
--- /dev/null
+++ b/mmeval/metrics/swd.py
@@ -0,0 +1,305 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+from scipy import ndimage
+from typing import Any, Dict, List, Sequence
+
+from mmeval.core import BaseMetric
+
+
+class SlicedWassersteinDistance(BaseMetric):
+    """SWD (Sliced Wasserstein distance) metric.
+
+    We calculate the SWD of two sets of images in the following way. In each
+    call to ``add``, we obtain the Laplacian pyramid of every image and
+    extract patches from the Laplacian pyramids as descriptors. In
+    ``compute_metric``, we normalize these descriptors along the channel
+    dimension and reshape them so that we can use these descriptors to
+    represent the distribution of real/fake images. Then we calculate the
+    sliced Wasserstein distance of the real and fake descriptors as the SWD
+    of the real and fake images. Note that, as in the official
+    implementation, we multiply the result by 10^3 to prevent the value from
+    being too small and to facilitate comparison.
+
+    Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py  # noqa
+
+    Args:
+        resolution (int): Resolution of the input images.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Defaults to 'CHW'.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Examples:
+
+        >>> from mmeval import SlicedWassersteinDistance as SWD
+        >>> import numpy as np
+        >>>
+        >>> swd = SWD(resolution=64)
+        >>> preds = [np.random.randint(0, 255, size=(3, 64, 64)) for _ in range(4)]  # noqa
+        >>> gts = [np.random.randint(0, 255, size=(3, 64, 64)) for _ in range(4)]  # noqa
+        >>> swd(preds, gts)  # doctest: +ELLIPSIS
+        {'SWD/64': ..., 'SWD/32': ..., 'SWD/16': ..., 'SWD/avg': ...}
+    """
+
+    def __init__(self,
+                 resolution: int,
+                 input_order: str = 'CHW',
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+
+        assert input_order.upper() in [
+            'CHW', 'HWC'
+        ], (f'Wrong input_order {input_order}. Supported input_orders are '
+            '"HWC" and "CHW"')
+        self.input_order = input_order
+
+        self.nhood_size = 7  # height and width of the extracted patches
+        self.nhoods_per_image = 128  # number of extracted patches per image
+        self.dir_repeats = 4  # times of sampling directions
+        self.dirs_per_repeat = 128  # number of directions per sampling
+
+        self.resolution = resolution
+        self._resolutions: List[Any] = []
+        while resolution >= 16 and len(self._resolutions) < 4:
+            self._resolutions.append(resolution)
+            resolution //= 2
+        self.n_pyramids = len(self._resolutions)
+        self.gaussian_k = self.get_gaussian_kernel()
+
+    @staticmethod
+    def get_gaussian_kernel() -> np.ndarray:
+        """Get Gaussian kernel.
+
+        Returns:
+            np.ndarray: Gaussian kernel.
+        """
+        kernel = np.array(
+            [[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6],
+             [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]], np.float32) / 256.0
+        return kernel.reshape(1, 1, 5, 5)
+
+    def add(self, predictions: Sequence[np.ndarray], groundtruths: Sequence[np.ndarray]) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Add processed features of a batch to ``self._results``.
+
+        Args:
+            predictions (Sequence[np.ndarray]): Predictions of the model.
+                The channel order of each element in the Sequence should
+                align with `self.input_order` and the range should be
+                [0, 255].
+            groundtruths (Sequence[np.ndarray]): The ground truth images.
+                The channel order of each element in the Sequence should
+                align with `self.input_order` and the range should be
+                [0, 255].
+        """
+        if self.input_order == 'HWC':
+            predictions = [pred.transpose(2, 0, 1) for pred in predictions]
+            groundtruths = [gt.transpose(2, 0, 1) for gt in groundtruths]
+        preds = np.stack(predictions, axis=0)
+        gts = np.stack(groundtruths, axis=0)
+
+        # convert to [-1, 1]
+        preds, gts = (preds - 127.5) / 127.5, (gts - 127.5) / 127.5
+
+        # prepare feature pyramid
+        fake_pyramid = self.laplacian_pyramid(preds, self.n_pyramids - 1,
+                                              self.gaussian_k)
+        real_pyramid = self.laplacian_pyramid(gts, self.n_pyramids - 1,
+                                              self.gaussian_k)
+
+        # init lists to save fake and real descriptors
+        if self._results == []:
+            self._results.append([[] for _ in self._resolutions])
+            self._results.append([[] for _ in self._resolutions])
+        for lod, level in enumerate(fake_pyramid):
+            desc = self.get_descriptors_for_minibatch(level, self.nhood_size,
+                                                      self.nhoods_per_image)
+            self._results[0][lod].append(desc)
+
+        for lod, level in enumerate(real_pyramid):
+            desc = self.get_descriptors_for_minibatch(level, self.nhood_size,
+                                                      self.nhoods_per_image)
+            self._results[1][lod].append(desc)
+
+    def compute_metric(self,
+                       results: List[List[np.ndarray]]) -> Dict[str, float]:
+        """Compute the SWD metric.
+
+        This method would be invoked in ``BaseMetric.compute`` after
+        distributed synchronization.
+
+        Args:
+            results (List[List[np.ndarray]]): A list of features at
+                different resolutions extracted from real and fake images.
+                This list has already been synced across all ranks.
+
+        Returns:
+            Dict[str, float]: The computed SWD metric.
+ """ + results_fake, results_real = results + fake_descs = [self.finalize_descriptors(d) for d in results_fake] + real_descs = [self.finalize_descriptors(d) for d in results_real] + + distance = [ + self.compute_swd(dreal, dfake, self.dir_repeats, + self.dirs_per_repeat) + for dreal, dfake in zip(real_descs, fake_descs) + ] + del real_descs + del fake_descs + # multiply by 10^3 refers to https://github.com/tkarras/progressive_growing_of_gans/blob/2504c3f3cb98ca58751610ad61fa1097313152bd/metrics/sliced_wasserstein.py#L132 # noqa + distance = [d * 1e3 for d in distance] + result = distance + [np.mean(distance)] + return { + f'SWD/{resolution}': d + for resolution, d in zip(self._resolutions + ['avg'], result) + } + + def compute_swd(self, distribution_a: np.array, distribution_b: np.array, + dir_repeats: int, dirs_per_repeat: int) -> np.ndarray: + """Calculate SWD (SlicedWassersteinDistance). + + Args: + distribution_a (np.ndarray): Distribution to compute sliced + wasserstein distance. + distribution_b (np.ndarray): Distribution to compute sliced + wasserstein distance. + dir_repeats (int): The number of projection times. + dirs_per_repeat (int): The number of directions per projection. + + Returns: + np.ndarray: SWD between `distribution_a` and `distribution_b`. + """ + results = [] + for _ in range(dir_repeats): + # (descriptor_component, direction) + dirs = np.random.randn(distribution_a.shape[1], dirs_per_repeat) + # normalize descriptor components for each direction + dirs /= np.sqrt(np.sum(np.square(dirs), axis=0, keepdims=True)) + dirs = dirs.astype(np.float32) + # (neighborhood, direction) + projA = np.matmul(distribution_a, dirs) + projB = np.matmul(distribution_b, dirs) + # sort neighborhood projections for each direction + projA = np.sort(projA, axis=0) + projB = np.sort(projB, axis=0) + # pointwise wasserstein distances + dists = np.abs(projA - projB) + # average over neighborhoods and directions + results.append(np.mean(dists)) + + return np.mean(results) + + def laplacian_pyramid(self, original: np.ndarray, n_pyramids: int, + gaussian_k: np.ndarray) -> List[np.ndarray]: + """Calculate Laplacian pyramid. + + Ref: https://github.com/koshian2/swd-pytorch/blob/master/swd.py + + Args: + original (np.ndarray): Batch of Images with range [0, 1] and order + "NCHW". + n_pyramids (int): Levels of pyramids minus one. + gaussian_k (np.ndarray): Gaussian kernel with shape (1, 1, 5, 5). + + Return: + list[np.ndarray]. Laplacian pyramids of original. + """ + # create gaussian pyramid + pyramids = self.gaussian_pyramid(original, n_pyramids, gaussian_k) + + # pyramid up - diff + laplacian = [] + for i in range(len(pyramids) - 1): + diff = pyramids[i] - self.get_pyramid_layer( + pyramids[i + 1], gaussian_k, 'up') + laplacian.append(diff) + # Add last gaussian pyramid + laplacian.append(pyramids[len(pyramids) - 1]) + return laplacian + + def gaussian_pyramid(self, original: np.ndarray, n_pyramids: int, + gaussian_k: np.ndarray) -> List[np.ndarray]: + """Get a group of gaussian pyramid. + + Args: + original (np.ndarray): The input image. + n_pyramids (int): The number of pyramids. + gaussian_k (np.ndarray): The gaussian kernel. + + Returns: + List[np.ndarray]: The list of output of gaussian pyramid. 
+ """ + x = original + # pyramid down + pyramids = [original] + for _ in range(n_pyramids): + x = self.get_pyramid_layer(x, gaussian_k) + pyramids.append(x) + return pyramids + + @staticmethod + def get_pyramid_layer(image: np.ndarray, + gaussian_k: np.ndarray, + direction: str = 'down') -> np.ndarray: + """Get the pyramid layer. + + Args: + image (np.ndarray): Input image. + gaussian_k (np.ndarray): Gaussian kernel. + direction (str, optional): The direction of pyramid. Defaults to + 'down'. + + Returns: + np.ndarray: The output of the pyramid. + """ + shape = image.shape + if direction == 'up': + # nearest interpolation with scale_factor = 2 + res = np.zeros((shape[0], shape[1], shape[2] * 2, shape[3] * 2), + image.dtype) + res[:, :, ::2, ::2] = image + res[:, :, 1::2, 1::2] = image + res[:, :, 1::2, ::2] = image + res[:, :, ::2, 1::2] = image + return ndimage.convolve(res, gaussian_k, mode='constant') + else: + return ndimage.convolve( + image, gaussian_k, mode='constant')[:, :, ::2, ::2] + + def get_descriptors_for_minibatch(self, minibatch: np.ndarray, + nhood_size: int, + nhoods_per_image: int) -> np.ndarray: + r"""Get descriptors of one level of pyramids. + + Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa + + Args: + minibatch (np.ndarray): Pyramids of one level with order "NCHW". + nhood_size (int): Pixel neighborhood size. + nhoods_per_image (int): The number of descriptors per image. + + Return: + np.ndarray: Descriptors of images from one level batch. + """ + shape = minibatch.shape # (minibatch, channel, height, width) + assert len(shape) == 4 and shape[1] == 3 + N = nhoods_per_image * shape[0] + H = nhood_size // 2 + nhood, chan, x, y = np.ogrid[0:N, 0:3, -H:H + 1, -H:H + 1] + img = nhood // nhoods_per_image + x = x + np.random.randint(H, shape[3] - H, size=(N, 1, 1, 1)) + y = y + np.random.randint(H, shape[2] - H, size=(N, 1, 1, 1)) + idx = ((img * shape[1] + chan) * shape[2] + y) * shape[3] + x + return minibatch.flat[idx] + + @staticmethod + def finalize_descriptors(desc: List[np.ndarray]) -> np.ndarray: + r"""Normalize and reshape descriptors. + + Ref: https://github.com/tkarras/progressive_growing_of_gans/blob/master/metrics/sliced_wasserstein.py # noqa + + Args: + desc (List[np.ndarray]): List of descriptors of one level. + + Return: + np.ndarray: Descriptors after normalized along channel and flattened. + """ + desc = np.concatenate(desc, axis=0) + desc -= np.mean(desc, axis=(0, 2, 3), keepdims=True) + desc /= np.std(desc, axis=(0, 2, 3), keepdims=True) + desc = desc.reshape(desc.shape[0], -1) + return desc diff --git a/setup.cfg b/setup.cfg index b35a1556..9ae6b0dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = dota, rouge +ignore-words-list = dota, rouge, lod [mypy] allow_redefinition = True diff --git a/tests/test_metrics/test_ms_ssim.py b/tests/test_metrics/test_ms_ssim.py new file mode 100644 index 00000000..b5758b8a --- /dev/null +++ b/tests/test_metrics/test_ms_ssim.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+
+# yapf: disable
+
+import numpy as np
+import pytest
+
+from mmeval.metrics import MultiScaleStructureSimilarity as MS_SSIM
+
+
+def test_ms_ssim_init():
+    ms_ssim = MS_SSIM()
+    assert (ms_ssim.weights == np.array(
+        [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])).all()
+
+    ms_ssim = MS_SSIM(weights=[0.22, 0.72])
+    assert (ms_ssim.weights == np.array([0.22, 0.72])).all()
+
+
+@pytest.mark.parametrize(
+    argnames=['init_kwargs', 'preds', 'gts', 'results'],
+    argvalues=[
+        ({'input_order': 'CHW'}, [np.ones((3, 64, 64)) * 255.] * 2,
+         [np.ones((3, 64, 64)) * 255.] * 2, 1),
+        ({'input_order': 'HWC'}, [np.zeros((64, 64, 3)) * 255.] * 2,
+         [np.zeros((64, 64, 3)) * 255.] * 2, 1),
+        ({'input_order': 'HWC', 'filter_size': 0},
+         [np.zeros((64, 64, 3)) * 255.] * 2,
+         [np.zeros((64, 64, 3)) * 255.] * 2, 1),
+        ({}, [np.ones((3, 64, 64)) * 255], [np.zeros((3, 64, 64))],
+         0.2929473249689634),
+        ({'input_order': 'HWC'},
+         [np.ones((64, 64, 3)) * 255], [np.zeros((64, 64, 3))],
+         0.2929473249689634),
+        ({'input_order': 'HWC', 'filter_size': 0},
+         [np.ones((64, 64, 3)) * 255], [np.zeros((64, 64, 3))],
+         0.29295045137405396)]
+    )
+def test_ms_ssim(init_kwargs, preds, gts, results):
+    ms_ssim = MS_SSIM(**init_kwargs)
+    ms_ssim_results = ms_ssim(preds, gts)
+    np.testing.assert_allclose(
+        ms_ssim_results['ms-ssim'], results)
+
+
+def test_raise_error():
+    ms_ssim = MS_SSIM()
+    preds = [np.random.randint(0, 255, (3, 64, 64))]
+    gts = [np.random.randint(0, 255, (3, 64, 64))] * 2
+    with pytest.raises(AssertionError):
+        ms_ssim(preds, gts)
+
+    # shape checking
+    with pytest.raises(RuntimeError):
+        ms_ssim.compute_ms_ssim(
+            np.random.randint(0, 255, (64, 64, 3)),
+            np.random.randint(0, 255, (3, 64, 64))
+        )
+
+    with pytest.raises(RuntimeError):
+        ms_ssim.compute_ms_ssim(
+            np.random.randint(0, 255, (64, 64, 3)),
+            np.random.randint(0, 255, (64, 64, 3))
+        )
+
+    with pytest.raises(AssertionError):
+        preds = [np.random.randint(0, 255, (3, 16, 16))] * 3
+        gts = [np.random.randint(0, 255, (3, 16, 16))] * 3
+        ms_ssim(preds, gts)
diff --git a/tests/test_metrics/test_swd.py b/tests/test_metrics/test_swd.py
new file mode 100644
index 00000000..e543229c
--- /dev/null
+++ b/tests/test_metrics/test_swd.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# yapf: disable
+
+import numpy as np
+import pytest
+
+from mmeval.metrics import SlicedWassersteinDistance as SWD
+
+
+@pytest.mark.parametrize(
+    argnames=['init_kwargs', 'preds', 'gts', 'results'],
+    argvalues=[
+        ({'resolution': 32},
+         [np.ones((3, 32, 32)) * i for i in range(100)],
+         [np.ones((3, 32, 32)) * 2 * i for i in range(100)],
+         [198.67430960025712, 33.72058904027052, 116.19744932026381])]
+)
+def test_swd(init_kwargs, preds, gts, results):
+    swd = SWD(**init_kwargs)
+    swd_results = swd(preds, gts)
+    for out, res in zip(swd_results.values(), results):
+        np.testing.assert_almost_equal(out / 100, res / 100, decimal=1)
+    swd.reset()
+    assert swd._results == []
+
+    swd.add(preds[:50], gts[:50])
+    swd.add(preds[50:], gts[50:])
+    swd_results = swd.compute()
+    for out, res in zip(swd_results.values(), results):
+        np.testing.assert_almost_equal(out / 100, res / 100, decimal=1)
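
For reference, the multi-scale combination implemented in `compute_ms_ssim` can be written compactly as follows (this restates the code above, with M equal to `self.weights.size` and scales produced by `_hox_downsample`; both factors are clipped to be non-negative before exponentiation, exactly as in the code):

```latex
\mathrm{MS\text{-}SSIM}(x, y)
    = \big[\mathrm{ssim}_M(x, y)\big]^{w_M}
      \prod_{j=1}^{M-1} \big[\mathrm{cs}_j(x, y)\big]^{w_j}
```

Here `ssim_j` and `cs_j` are the mean SSIM and contrast sensitivity returned by `_ssim_for_multi_scale` at scale `j`, so only the coarsest scale contributes its full SSIM value while the finer scales contribute contrast sensitivity only.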
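
The core of `compute_swd` is the standard random-projection estimator of the sliced Wasserstein distance: project both point sets onto random unit directions, sort the 1-D projections, and average the pointwise differences. A minimal self-contained sketch follows (illustrative only, not part of this patch; the function name `sliced_wasserstein` and the toy data are assumptions):

```python
import numpy as np


def sliced_wasserstein(a: np.ndarray, b: np.ndarray,
                       dir_repeats: int = 4,
                       dirs_per_repeat: int = 128) -> float:
    """Estimate SWD between two point sets of shape (n_samples, n_features)."""
    assert a.shape == b.shape
    results = []
    for _ in range(dir_repeats):
        # Random directions, normalized to unit length per column.
        dirs = np.random.randn(a.shape[1], dirs_per_repeat)
        dirs /= np.sqrt(np.sum(dirs**2, axis=0, keepdims=True))
        # Project both sets; sorting makes each column an optimal 1-D
        # coupling, so the mean absolute difference is the 1-D
        # Wasserstein-1 distance along that direction.
        proj_a = np.sort(a @ dirs, axis=0)
        proj_b = np.sort(b @ dirs, axis=0)
        results.append(np.mean(np.abs(proj_a - proj_b)))
    return float(np.mean(results))


rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(1024, 16))
y = rng.normal(0.5, 1.0, size=(1024, 16))
print(sliced_wasserstein(x, y))  # grows with the mean shift between x and y
```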
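
`SlicedWassersteinDistance` extracts its descriptors from a Laplacian pyramid, where each level stores the detail lost between consecutive Gaussian levels: `L_i = G_i - up(G_{i+1})`, which makes the pyramid exactly invertible. A short sketch of that property, reusing the same binomial kernel and the nearest-neighbour upsampling of `get_pyramid_layer` (a standalone reimplementation for illustration, not the PR's API):

```python
import numpy as np
from scipy import ndimage

# Same 5x5 binomial kernel as SlicedWassersteinDistance.get_gaussian_kernel.
k = np.array([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6],
              [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]], np.float32) / 256.0
k = k.reshape(1, 1, 5, 5)


def down(x: np.ndarray) -> np.ndarray:
    # Blur with the binomial kernel, then keep every second pixel.
    return ndimage.convolve(x, k, mode='constant')[:, :, ::2, ::2]


def up(x: np.ndarray) -> np.ndarray:
    # Nearest-neighbour upsample by 2, then blur.
    n, c, h, w = x.shape
    res = np.zeros((n, c, h * 2, w * 2), x.dtype)
    res[:, :, ::2, ::2] = x
    res[:, :, 1::2, ::2] = x
    res[:, :, ::2, 1::2] = x
    res[:, :, 1::2, 1::2] = x
    return ndimage.convolve(res, k, mode='constant')


imgs = np.random.rand(2, 3, 32, 32).astype(np.float32)
gauss = [imgs]
for _ in range(2):
    gauss.append(down(gauss[-1]))
laplacian = [gauss[i] - up(gauss[i + 1]) for i in range(2)] + [gauss[-1]]

# Reconstruct from coarsest to finest; exact by construction.
recon = laplacian[-1]
for level in reversed(laplacian[:-1]):
    recon = level + up(recon)
print(np.allclose(recon, imgs))  # True
```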