diff --git a/tensorflow_transform/info_theory.py b/tensorflow_transform/info_theory.py index 751963d7..f536d105 100644 --- a/tensorflow_transform/info_theory.py +++ b/tensorflow_transform/info_theory.py @@ -84,7 +84,7 @@ def calculate_partial_mutual_information(n_ij, x_i, y_j, n): Returns: Mutual information for the cell x=i, y=j. """ - if n_ij == 0: + if n_ij == 0 or x_i == 0 or y_j == 0: return 0 return n_ij * ((log2(n_ij) + log2(n)) - (log2(x_i) + log2(y_j))) diff --git a/tensorflow_transform/info_theory_test.py b/tensorflow_transform/info_theory_test.py index 05bf4138..b8136b4f 100644 --- a/tensorflow_transform/info_theory_test.py +++ b/tensorflow_transform/info_theory_test.py @@ -137,6 +137,27 @@ def test_calculate_partial_expected_mutual_information( col_count=8, total_count=16, expected_mi=0), + dict( + testcase_name='invalid_input_zero_cell_count', + cell_count=4, + row_count=0, + col_count=8, + total_count=8, + expected_mi=0), + dict( + testcase_name='invalid_input_zero_row_count', + cell_count=4, + row_count=0, + col_count=8, + total_count=8, + expected_mi=0), + dict( + testcase_name='invalid_input_zero_col_count', + cell_count=4, + row_count=8, + col_count=0, + total_count=8, + expected_mi=0), ) def test_mutual_information(self, cell_count, row_count, col_count, total_count, expected_mi):