Skip to content

Commit

Permalink
updated typing with mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
danielfromearth committed Sep 19, 2023
1 parent 950c1b7 commit 50acd08
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 45 deletions.
55 changes: 30 additions & 25 deletions ncompare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import namedtuple
from collections.abc import Iterable
from pathlib import Path
from typing import Union
from typing import Optional, Union

import netCDF4
import numpy as np
Expand All @@ -25,14 +25,14 @@
def compare(
nc_a: Union[str, Path],
nc_b: Union[str, Path],
comparison_var_group: str = None,
comparison_var_name: str = None,
comparison_var_group: Optional[str] = None,
comparison_var_name: Optional[str] = None,
no_color: bool = False,
show_chunks: bool = False,
show_attributes: bool = False,
file_text: str = None,
file_csv: str = None,
file_xlsx: str = None,
file_text: Union[str, Path] = "",
file_csv: Union[str, Path] = "",
file_xlsx: Union[str, Path] = "",
) -> None:
"""Compare the variables contained within two different NetCDF datasets.
Expand Down Expand Up @@ -103,8 +103,8 @@ def run_through_comparisons(
out: Outputter,
nc_a: Union[str, Path],
nc_b: Union[str, Path],
comparison_var_group: str,
comparison_var_name: str,
comparison_var_group: Optional[str],
comparison_var_name: Optional[str],
show_chunks: bool,
show_attributes: bool,
) -> None:
Expand Down Expand Up @@ -164,8 +164,9 @@ def run_through_comparisons(
out.print(
Style.BRIGHT
+ Fore.RED
+ "\nError when comparing values for variable <%s> in group <%s>."
% (comparison_var_name, comparison_var_group)
+ "\nError when comparing values for variable <{}> in group <{}>.".format(
comparison_var_name, comparison_var_group
)
)
out.print(traceback.format_exc())
out.print("\n")
Expand All @@ -181,7 +182,11 @@ def run_through_comparisons(


def compare_multiple_random_values(
out: Outputter, nc_a: Path, nc_b: Path, groupname: str, num_comparisons: int = 100
out: Outputter,
nc_a: Union[str, Path],
nc_b: Union[str, Path],
groupname: str,
num_comparisons: int = 100,
):
"""Iterate through N random samples, and evaluate whether the differences exceed a threshold."""
# Open a variable from each NetCDF
Expand All @@ -206,7 +211,7 @@ def compare_multiple_random_values(
out.print("Done.", colors=False)


def walk_common_groups_tree(
def walk_common_groups_tree( # type:ignore[misc]
top_a_name: str,
top_a: Union[netCDF4.Dataset, netCDF4.Group],
top_b_name: str,
Expand Down Expand Up @@ -257,8 +262,8 @@ def walk_common_groups_tree(

def compare_two_nc_files(
out: Outputter,
nc_one: Path,
nc_two: Path,
nc_one: Union[str, Path],
nc_two: Union[str, Path],
show_chunks: bool = False,
show_attributes: bool = False,
) -> tuple[int, int, int]:
Expand Down Expand Up @@ -322,7 +327,7 @@ def _print_group_details_side_by_side(
)

# Count the number of variables in this group as long as this group exists.
vars_a_sorted, vars_b_sorted = "", ""
vars_a_sorted, vars_b_sorted = [""], [""]
if group_a:
vars_a_sorted = sorted(group_a.variables)
if group_b:
Expand Down Expand Up @@ -453,11 +458,11 @@ def _match_random_value(
rand_index = []
for dim_length in nc_var_a.shape:
rand_index.append(random.randint(0, dim_length - 1))
rand_index = tuple(rand_index)
rand_index_tuple = tuple(rand_index)

# Get the values from each variable
value_a = nc_var_a.values[rand_index]
value_b = nc_var_b.values[rand_index]
value_a = nc_var_a.values[rand_index_tuple]
value_b = nc_var_b.values[rand_index_tuple]

# Check whether null
if np.isnan(value_a) or np.isnan(value_b):
Expand All @@ -469,7 +474,7 @@ def _match_random_value(
out.print()
out.print(Fore.RED + f"Difference exceeded threshold (diff == {diff}")
out.print(f"var shape: {nc_var_a.shape}", colors=False)
out.print(f"indices: {rand_index}", colors=False)
out.print(f"indices: {rand_index_tuple}", colors=False)
out.print(f"value a: {value_a}", colors=False)
out.print(f"value b: {value_b}", colors=False, end="\n\n")
return False
Expand All @@ -479,7 +484,7 @@ def _match_random_value(

def _print_sample_values(out: Outputter, nc_filepath, groupname: str, varname: str) -> None:
comparison_variable = xr.open_dataset(nc_filepath, backend_kwargs={"group": groupname})[varname]
out.print(comparison_variable.values[0, :], colors=False)
out.print(str(comparison_variable.values[0, :]), colors=False)


def _get_attribute_value_as_str(varprops: VarProperties, attribute_key: str) -> str:
Expand All @@ -490,31 +495,31 @@ def _get_attribute_value_as_str(varprops: VarProperties, attribute_key: str) ->
# we are preventing any subsequent difference checker from detecting
# differences past the 5th element in the iterable.
# So, we need to figure out a way to still check for other differences past the 5th element.
return "[" + ", ".join([str(x) for x in attr[:5]]) + ", ..." + "]"
return "[" + ", ".join([str(x) for x in attr[:5]]) + ", ..." + "]" # type:ignore[index]

return str(attr)

return ""


def _get_vars(nc_filepath: Path, groupname: str) -> list:
def _get_vars(nc_filepath: Union[str, Path], groupname: str) -> list:
try:
grp = xr.open_dataset(nc_filepath, backend_kwargs={"group": groupname})
except OSError as err:
print("\nError occurred when attempting to open group within <%s>.\n" % nc_filepath)
raise err
grp_varlist = sorted(list(grp.variables.keys()))
grp_varlist = sorted(list(grp.variables.keys())) # type:ignore[type-var]

return grp_varlist


def _get_groups(nc_filepath: Path) -> list:
def _get_groups(nc_filepath: Union[str, Path]) -> list:
with netCDF4.Dataset(nc_filepath) as dataset:
groups_list = list(dataset.groups.keys())
return groups_list


def _get_dims(nc_filepath: Path) -> list:
def _get_dims(nc_filepath: Union[str, Path]) -> list:
def __get_dim_list(decode_times=True):
with xr.open_dataset(nc_filepath, decode_times=decode_times) as dataset:
return list(dataset.dims.items())
Expand Down
13 changes: 5 additions & 8 deletions ncompare/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from collections.abc import Iterable
from pathlib import Path
from typing import Union
from typing import Optional, TextIO, Union

import colorama
import openpyxl
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
self,
keep_print_history: bool = False,
no_color: bool = False,
text_file: Union[str, Path] = None,
text_file: Optional[Union[str, Path]] = None,
):
"""Set up the handling of printing and saving destinations.
Expand All @@ -51,10 +51,7 @@ def __init__(
"""
# Parse the print history option.
self._keep_print_history = keep_print_history
if self._keep_print_history:
self._line_history = []
else:
self._line_history = None
self._line_history: list[list[str]] = []

if no_color:
# Replace colorized styles with blank strings.
Expand All @@ -71,7 +68,7 @@ def __init__(
if filepath.exists():
pass
# This will overwrite any existing file at this path, if one exists.
self._text_file_obj = open(
self._text_file_obj: Optional[TextIO] = open(
filepath, "w", encoding="utf-8"
) # pylint: disable=consider-using-with
else:
Expand Down Expand Up @@ -145,7 +142,7 @@ def _parse_single_str(s): # pylint: disable=invalid-name
else:
raise TypeError(f"Invalid type <{type(args)}>. Expected a `str` or `list`.")

if self._line_history is not None:
if self._keep_print_history:
self._line_history.append(parsed_strings)

@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions ncompare/sequence_operations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Helper functions for operating on iterables, such as lists or sets."""
from collections.abc import Iterable
from collections.abc import Generator, Iterable
from typing import Union

from ncompare.utils import coerce_to_str


def common_elements(sequence_a: Iterable, sequence_b: Iterable) -> tuple[int, str, str]:
def common_elements(
sequence_a: Iterable, sequence_b: Iterable
) -> Generator[tuple[int, str, str], None, None]:
"""Loop over combined items of two iterables, and yield aligned item pairs.
Note
Expand Down Expand Up @@ -43,7 +45,7 @@ def common_elements(sequence_a: Iterable, sequence_b: Iterable) -> tuple[int, st


def count_diffs(
list_a: list[Union[str, int]], list_b: list[Union[str, int]]
list_a: Union[list[str], list[int]], list_b: Union[list[str], list[int]]
) -> tuple[int, int, int]:
"""Count how many elements are either uniquely in one list or the other, or in both.
Expand Down
2 changes: 1 addition & 1 deletion ncompare/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def ensure_valid_path_exists(should_be_path: Union[str, Path]) -> Path:
raise TypeError(wrong_type_msg + str(type(should_be_path)))


def ensure_valid_path_with_suffix(should_be_path: Union[str, Path], suffix: str = None) -> Path:
def ensure_valid_path_with_suffix(should_be_path: Union[str, Path], suffix: str) -> Path:
"""Coerce input to a pathlib.Path with given suffix."""
wrong_type_msg = "Unexpected type for something that should be convertable to a Path: "

Expand Down
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ mypy = "^1.5.1"
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.isort]
profile = "black"
line_length = 120
multi_line_output = 3
include_trailing_comma = "true"
default_section = "THIRDPARTY"
lines_after_imports = 2
combine_as_imports = "true"
[[tool.mypy.overrides]]
module = [
"colorama.*",
"netCDF4.*",
"openpyxl.*"
]
ignore_missing_imports = true

[tool.black]
line-length = 100
Expand Down

0 comments on commit 50acd08

Please sign in to comment.