From 50acd085eb3d1226716497eaf770bb677f75f1c0 Mon Sep 17 00:00:00 2001 From: danielfromearth Date: Tue, 19 Sep 2023 14:38:48 -0400 Subject: [PATCH] updated typing with mypy --- ncompare/core.py | 55 ++++++++++++++++++--------------- ncompare/printing.py | 13 +++----- ncompare/sequence_operations.py | 8 +++-- ncompare/utils.py | 2 +- pyproject.toml | 15 +++++---- 5 files changed, 48 insertions(+), 45 deletions(-) diff --git a/ncompare/core.py b/ncompare/core.py index 9ae7ddc..bfe46f8 100644 --- a/ncompare/core.py +++ b/ncompare/core.py @@ -8,7 +8,7 @@ from collections import namedtuple from collections.abc import Iterable from pathlib import Path -from typing import Union +from typing import Optional, Union import netCDF4 import numpy as np @@ -25,14 +25,14 @@ def compare( nc_a: Union[str, Path], nc_b: Union[str, Path], - comparison_var_group: str = None, - comparison_var_name: str = None, + comparison_var_group: Optional[str] = None, + comparison_var_name: Optional[str] = None, no_color: bool = False, show_chunks: bool = False, show_attributes: bool = False, - file_text: str = None, - file_csv: str = None, - file_xlsx: str = None, + file_text: Union[str, Path] = "", + file_csv: Union[str, Path] = "", + file_xlsx: Union[str, Path] = "", ) -> None: """Compare the variables contained within two different NetCDF datasets. @@ -103,8 +103,8 @@ def run_through_comparisons( out: Outputter, nc_a: Union[str, Path], nc_b: Union[str, Path], - comparison_var_group: str, - comparison_var_name: str, + comparison_var_group: Optional[str], + comparison_var_name: Optional[str], show_chunks: bool, show_attributes: bool, ) -> None: @@ -164,8 +164,9 @@ def run_through_comparisons( out.print( Style.BRIGHT + Fore.RED - + "\nError when comparing values for variable <%s> in group <%s>." - % (comparison_var_name, comparison_var_group) + + "\nError when comparing values for variable <{}> in group <{}>.".format( + comparison_var_name, comparison_var_group + ) ) out.print(traceback.format_exc()) out.print("\n") @@ -181,7 +182,11 @@ def run_through_comparisons( def compare_multiple_random_values( - out: Outputter, nc_a: Path, nc_b: Path, groupname: str, num_comparisons: int = 100 + out: Outputter, + nc_a: Union[str, Path], + nc_b: Union[str, Path], + groupname: str, + num_comparisons: int = 100, ): """Iterate through N random samples, and evaluate whether the differences exceed a threshold.""" # Open a variable from each NetCDF @@ -206,7 +211,7 @@ def compare_multiple_random_values( out.print("Done.", colors=False) -def walk_common_groups_tree( +def walk_common_groups_tree( # type:ignore[misc] top_a_name: str, top_a: Union[netCDF4.Dataset, netCDF4.Group], top_b_name: str, @@ -257,8 +262,8 @@ def walk_common_groups_tree( def compare_two_nc_files( out: Outputter, - nc_one: Path, - nc_two: Path, + nc_one: Union[str, Path], + nc_two: Union[str, Path], show_chunks: bool = False, show_attributes: bool = False, ) -> tuple[int, int, int]: @@ -322,7 +327,7 @@ def _print_group_details_side_by_side( ) # Count the number of variables in this group as long as this group exists. - vars_a_sorted, vars_b_sorted = "", "" + vars_a_sorted, vars_b_sorted = [""], [""] if group_a: vars_a_sorted = sorted(group_a.variables) if group_b: @@ -453,11 +458,11 @@ def _match_random_value( rand_index = [] for dim_length in nc_var_a.shape: rand_index.append(random.randint(0, dim_length - 1)) - rand_index = tuple(rand_index) + rand_index_tuple = tuple(rand_index) # Get the values from each variable - value_a = nc_var_a.values[rand_index] - value_b = nc_var_b.values[rand_index] + value_a = nc_var_a.values[rand_index_tuple] + value_b = nc_var_b.values[rand_index_tuple] # Check whether null if np.isnan(value_a) or np.isnan(value_b): @@ -469,7 +474,7 @@ def _match_random_value( out.print() out.print(Fore.RED + f"Difference exceeded threshold (diff == {diff}") out.print(f"var shape: {nc_var_a.shape}", colors=False) - out.print(f"indices: {rand_index}", colors=False) + out.print(f"indices: {rand_index_tuple}", colors=False) out.print(f"value a: {value_a}", colors=False) out.print(f"value b: {value_b}", colors=False, end="\n\n") return False @@ -479,7 +484,7 @@ def _match_random_value( def _print_sample_values(out: Outputter, nc_filepath, groupname: str, varname: str) -> None: comparison_variable = xr.open_dataset(nc_filepath, backend_kwargs={"group": groupname})[varname] - out.print(comparison_variable.values[0, :], colors=False) + out.print(str(comparison_variable.values[0, :]), colors=False) def _get_attribute_value_as_str(varprops: VarProperties, attribute_key: str) -> str: @@ -490,31 +495,31 @@ def _get_attribute_value_as_str(varprops: VarProperties, attribute_key: str) -> # we are preventing any subsequent difference checker from detecting # differences past the 5th element in the iterable. # So, we need to figure out a way to still check for other differences past the 5th element. - return "[" + ", ".join([str(x) for x in attr[:5]]) + ", ..." + "]" + return "[" + ", ".join([str(x) for x in attr[:5]]) + ", ..." + "]" # type:ignore[index] return str(attr) return "" -def _get_vars(nc_filepath: Path, groupname: str) -> list: +def _get_vars(nc_filepath: Union[str, Path], groupname: str) -> list: try: grp = xr.open_dataset(nc_filepath, backend_kwargs={"group": groupname}) except OSError as err: print("\nError occurred when attempting to open group within <%s>.\n" % nc_filepath) raise err - grp_varlist = sorted(list(grp.variables.keys())) + grp_varlist = sorted(list(grp.variables.keys())) # type:ignore[type-var] return grp_varlist -def _get_groups(nc_filepath: Path) -> list: +def _get_groups(nc_filepath: Union[str, Path]) -> list: with netCDF4.Dataset(nc_filepath) as dataset: groups_list = list(dataset.groups.keys()) return groups_list -def _get_dims(nc_filepath: Path) -> list: +def _get_dims(nc_filepath: Union[str, Path]) -> list: def __get_dim_list(decode_times=True): with xr.open_dataset(nc_filepath, decode_times=decode_times) as dataset: return list(dataset.dims.items()) diff --git a/ncompare/printing.py b/ncompare/printing.py index 3c49411..179ef08 100644 --- a/ncompare/printing.py +++ b/ncompare/printing.py @@ -4,7 +4,7 @@ import re from collections.abc import Iterable from pathlib import Path -from typing import Union +from typing import Optional, TextIO, Union import colorama import openpyxl @@ -41,7 +41,7 @@ def __init__( self, keep_print_history: bool = False, no_color: bool = False, - text_file: Union[str, Path] = None, + text_file: Optional[Union[str, Path]] = None, ): """Set up the handling of printing and saving destinations. @@ -51,10 +51,7 @@ def __init__( """ # Parse the print history option. self._keep_print_history = keep_print_history - if self._keep_print_history: - self._line_history = [] - else: - self._line_history = None + self._line_history: list[list[str]] = [] if no_color: # Replace colorized styles with blank strings. @@ -71,7 +68,7 @@ def __init__( if filepath.exists(): pass # This will overwrite any existing file at this path, if one exists. - self._text_file_obj = open( + self._text_file_obj: Optional[TextIO] = open( filepath, "w", encoding="utf-8" ) # pylint: disable=consider-using-with else: @@ -145,7 +142,7 @@ def _parse_single_str(s): # pylint: disable=invalid-name else: raise TypeError(f"Invalid type <{type(args)}>. Expected a `str` or `list`.") - if self._line_history is not None: + if self._keep_print_history: self._line_history.append(parsed_strings) @staticmethod diff --git a/ncompare/sequence_operations.py b/ncompare/sequence_operations.py index 8b6cc79..2f109a2 100644 --- a/ncompare/sequence_operations.py +++ b/ncompare/sequence_operations.py @@ -1,11 +1,13 @@ """Helper functions for operating on iterables, such as lists or sets.""" -from collections.abc import Iterable +from collections.abc import Generator, Iterable from typing import Union from ncompare.utils import coerce_to_str -def common_elements(sequence_a: Iterable, sequence_b: Iterable) -> tuple[int, str, str]: +def common_elements( + sequence_a: Iterable, sequence_b: Iterable +) -> Generator[tuple[int, str, str], None, None]: """Loop over combined items of two iterables, and yield aligned item pairs. Note @@ -43,7 +45,7 @@ def common_elements(sequence_a: Iterable, sequence_b: Iterable) -> tuple[int, st def count_diffs( - list_a: list[Union[str, int]], list_b: list[Union[str, int]] + list_a: Union[list[str], list[int]], list_b: Union[list[str], list[int]] ) -> tuple[int, int, int]: """Count how many elements are either uniquely in one list or the other, or in both. diff --git a/ncompare/utils.py b/ncompare/utils.py index 75b7b15..716d71f 100644 --- a/ncompare/utils.py +++ b/ncompare/utils.py @@ -23,7 +23,7 @@ def ensure_valid_path_exists(should_be_path: Union[str, Path]) -> Path: raise TypeError(wrong_type_msg + str(type(should_be_path))) -def ensure_valid_path_with_suffix(should_be_path: Union[str, Path], suffix: str = None) -> Path: +def ensure_valid_path_with_suffix(should_be_path: Union[str, Path], suffix: str) -> Path: """Coerce input to a pathlib.Path with given suffix.""" wrong_type_msg = "Unexpected type for something that should be convertable to a Path: " diff --git a/pyproject.toml b/pyproject.toml index eaaba0d..318db0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,14 +31,13 @@ mypy = "^1.5.1" requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" -[tool.isort] -profile = "black" -line_length = 120 -multi_line_output = 3 -include_trailing_comma = "true" -default_section = "THIRDPARTY" -lines_after_imports = 2 -combine_as_imports = "true" +[[tool.mypy.overrides]] +module = [ + "colorama.*", + "netCDF4.*", + "openpyxl.*" +] +ignore_missing_imports = true [tool.black] line-length = 100