Skip to content

Commit

Permalink
Fix pix's SSIM's gradient to not nan-out on a flat image, and add a u…
Browse files Browse the repository at this point in the history
…nit test that catches it.

PiperOrigin-RevId: 444187129
  • Loading branch information
jonbarron authored and PIXDev committed Apr 25, 2022
1 parent 6accc96 commit 6f2cb67
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 3 additions & 2 deletions dm_pix/_src/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ def filt_fn_x(z):

# Clip the variances and covariances to valid values.
# Variance must be non-negative:
sigma00 = jnp.maximum(0., sigma00)
sigma11 = jnp.maximum(0., sigma11)
epsilon = jnp.finfo(jnp.float32).eps**2
sigma00 = jnp.maximum(epsilon, sigma00)
sigma11 = jnp.maximum(epsilon, sigma11)
sigma01 = jnp.sign(sigma01) * jnp.minimum(
jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

Expand Down
7 changes: 7 additions & 0 deletions dm_pix/_src/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ def test_ssim_lowerbound(self):
ssim = ssim_fn(img, -img)
np.testing.assert_allclose(ssim, -np.ones_like(ssim), atol=1E-5, rtol=1E-5)

@chex.all_variants
def test_ssim_finite_grad(self):
"""Test that SSIM produces a finite gradient on large flat regions."""
img = np.zeros((64, 64, 3))
grad = self.variant(jax.grad(metrics.ssim))(img, img)
np.testing.assert_equal(grad, np.zeros_like(grad))


if __name__ == "__main__":
absltest.main()

0 comments on commit 6f2cb67

Please sign in to comment.