From 6f2cb67fe68b69f2084d833a59ab397c23f3f92f Mon Sep 17 00:00:00 2001 From: Jon Barron Date: Mon, 25 Apr 2022 00:43:46 -0700 Subject: [PATCH] Fix pix's SSIM's gradient to not nan-out on a flat image, and add a unit test that catches it. PiperOrigin-RevId: 444187129 --- dm_pix/_src/metrics.py | 5 +++-- dm_pix/_src/metrics_test.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/dm_pix/_src/metrics.py b/dm_pix/_src/metrics.py index 89c050b..cb14d9f 100644 --- a/dm_pix/_src/metrics.py +++ b/dm_pix/_src/metrics.py @@ -205,8 +205,9 @@ def filt_fn_x(z): # Clip the variances and covariances to valid values. # Variance must be non-negative: - sigma00 = jnp.maximum(0., sigma00) - sigma11 = jnp.maximum(0., sigma11) + epsilon = jnp.finfo(jnp.float32).eps**2 + sigma00 = jnp.maximum(epsilon, sigma00) + sigma11 = jnp.maximum(epsilon, sigma11) sigma01 = jnp.sign(sigma01) * jnp.minimum( jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01)) diff --git a/dm_pix/_src/metrics_test.py b/dm_pix/_src/metrics_test.py index 2b04983..1449fd5 100644 --- a/dm_pix/_src/metrics_test.py +++ b/dm_pix/_src/metrics_test.py @@ -127,6 +127,13 @@ def test_ssim_lowerbound(self): ssim = ssim_fn(img, -img) np.testing.assert_allclose(ssim, -np.ones_like(ssim), atol=1E-5, rtol=1E-5) + @chex.all_variants + def test_ssim_finite_grad(self): + """Test that SSIM produces a finite gradient on large flat regions.""" + img = np.zeros((64, 64, 3)) + grad = self.variant(jax.grad(metrics.ssim))(img, img) + np.testing.assert_equal(grad, np.zeros_like(grad)) + if __name__ == "__main__": absltest.main()