Skip to content

Commit

Permalink
Merge pull request #55 from nasa/feature/fix-variable-value-matching
Browse files Browse the repository at this point in the history
Fix variable value matching
  • Loading branch information
danielfromearth authored Oct 20, 2023
2 parents 6577862 + 14c8ffc commit ec2cfe3
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated
### Removed
### Fixed
- [pull-request/55](https://github.com/nasa/ncompare/pull/55): Fix variable value matching
### Security

## [1.2.0]
Expand Down
20 changes: 13 additions & 7 deletions ncompare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def run_through_comparisons(
_, _, _ = out.lists_diff(list_a, list_b)

# Show the groups in each NetCDF file and evaluate differences.
out.print(Fore.LIGHTBLUE_EX + "\nGroups:", add_to_history=True)
out.print(Fore.LIGHTBLUE_EX + "\nRoot-level Groups:", add_to_history=True)
list_a = _get_groups(nc_a)
list_b = _get_groups(nc_b)
_, _, _ = out.lists_diff(list_a, list_b)
Expand Down Expand Up @@ -158,7 +158,9 @@ def run_through_comparisons(
+ "\nChecking multiple random values within specified variable <%s>:"
% comparison_var_name
)
compare_multiple_random_values(out, nc_a, nc_b, groupname=comparison_var_group)
compare_multiple_random_values(out, nc_a, nc_b,
groupname=comparison_var_group,
varname=comparison_var_name)

except KeyError:
out.print(
Expand Down Expand Up @@ -186,12 +188,13 @@ def compare_multiple_random_values(
nc_a: Union[str, Path],
nc_b: Union[str, Path],
groupname: str,
varname: str,
num_comparisons: int = 100,
):
"""Iterate through N random samples, and evaluate whether the differences exceed a threshold."""
# Open a variable from each NetCDF
nc_var_a = xr.open_dataset(nc_a, backend_kwargs={"group": groupname}).varname
nc_var_b = xr.open_dataset(nc_b, backend_kwargs={"group": groupname}).varname
nc_var_a = xr.open_dataset(nc_a, backend_kwargs={"group": groupname}).variables[varname]
nc_var_b = xr.open_dataset(nc_b, backend_kwargs={"group": groupname}).variables[varname]

num_mismatches = 0
for _ in range(num_comparisons):
Expand All @@ -200,6 +203,7 @@ def compare_multiple_random_values(
out.print(".", colors=False, end="")
elif match_result is None:
out.print("n", colors=False, end="")
num_mismatches += 1
else:
out.print("x", colors=False, end="")
num_mismatches += 1
Expand Down Expand Up @@ -444,14 +448,14 @@ def _var_properties(group: Union[netCDF4.Dataset, netCDF4.Group], varname: str)


def _match_random_value(
out: Outputter, nc_var_a: netCDF4.Variable, nc_var_b: netCDF4.Variable, thresh: float = 1e-6
out: Outputter, nc_var_a: xr.Variable, nc_var_b: xr.Variable, thresh: float = 1e-6
) -> Union[bool, None]:
"""Check whether a randomly selected data point matches between two variables.
Returns
-------
None or bool
None if data point is null for either variable
None if data point is null for one and only one of the variables
True if values match
False if the difference exceeds the given threshold
"""
Expand All @@ -466,7 +470,9 @@ def _match_random_value(
value_b = nc_var_b.values[rand_index_tuple]

# Check whether null
if np.isnan(value_a) or np.isnan(value_b):
if np.isnan(value_a) and np.isnan(value_b):
return True
elif np.isnan(value_a) or np.isnan(value_b):
return None

# Evaluate difference between values
Expand Down
17 changes: 12 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from xarray import Dataset

from . import data_for_tests_dir
from ncompare.printing import Outputter

@pytest.fixture(scope="session")
def outputter_obj():
return Outputter()

@pytest.fixture(scope="session")
def temp_data_dir(tmpdir_factory) -> Path:
Expand All @@ -14,11 +19,12 @@ def temp_data_dir(tmpdir_factory) -> Path:
@pytest.fixture(scope="session")
def ds_3dims_2vars_4coords(temp_data_dir) -> Path:
ds = Dataset(
dict(
data_vars=dict(
# "normal" (Gaussian) distribution of mean 0 and variance 1
z1=(["y", "x"], np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
),
dict(
coords=dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
Expand All @@ -33,12 +39,13 @@ def ds_3dims_2vars_4coords(temp_data_dir) -> Path:
@pytest.fixture(scope="session")
def ds_4dims_3vars_5coords(temp_data_dir):
ds = Dataset(
dict(
z1=(["y", "x"], np.random.randn(2, 8)),
data_vars=dict(
# "normal" (Gaussian) distribution of mean 10 and standard deviation 2.5
z1=(["y", "x"], 10 + 2.5 * np.random.randn(2, 8)),
z2=(["time", "y"], np.random.randn(10, 2)),
z3=(["y", "z"], np.random.randn(2, 9)),
),
dict(
coords=dict(
x=("x", np.linspace(0, 1.0, 8)),
time=("time", np.linspace(0, 1.0, 10)),
c=("y", ["a", "b"]),
Expand Down
12 changes: 11 additions & 1 deletion tests/test_netcdf_compare.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from ncompare.core import compare
import xarray as xr

from ncompare.core import compare, _match_random_value


def test_dataset_compare_does_not_raise_exception(ds_3dims_2vars_4coords, ds_4dims_3vars_5coords):
compare(ds_3dims_2vars_4coords, ds_4dims_3vars_5coords)

def test_dataset_compare_does_not_raise_exception_2(ds_3dims_2vars_4coords, ds_3dims_3vars_4coords_1group):
compare(ds_3dims_2vars_4coords, ds_3dims_3vars_4coords_1group)

def test_matching_random_values(ds_3dims_2vars_4coords, ds_4dims_3vars_5coords,
ds_3dims_3vars_4coords_1group, outputter_obj):
variable_array_1 = xr.open_dataset(ds_3dims_2vars_4coords).variables['z1']
variable_array_2 = xr.open_dataset(ds_4dims_3vars_5coords).variables['z1']

assert _match_random_value(outputter_obj, variable_array_1, variable_array_1, ) is True
assert _match_random_value(outputter_obj, variable_array_1, variable_array_2, ) is False
6 changes: 0 additions & 6 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import pytest

from ncompare.printing import Outputter


@pytest.fixture
def outputter_obj():
return Outputter()

def test_list_of_strings_diff(outputter_obj):
left, right, both = outputter_obj.lists_diff(['hey', 'yo', 'beebop'],
Expand Down

0 comments on commit ec2cfe3

Please sign in to comment.