diff --git a/splink/exceptions.py b/splink/exceptions.py
index 66fb327859..b5c1257992 100644
--- a/splink/exceptions.py
+++ b/splink/exceptions.py
@@ -36,6 +36,10 @@ class SplinkDeprecated(DeprecationWarning):
     pass
 
 
+class InvalidSplinkInput(SplinkException):
+    pass
+
+
 class InvalidDialect(SplinkException):
     pass
 
diff --git a/splink/linker.py b/splink/linker.py
index b4163b67f8..da65680b78 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -10,6 +10,7 @@
 from copy import copy, deepcopy
 from pathlib import Path
 from statistics import median
+from typing import Any, Dict
 
 import sqlglot
 
@@ -19,8 +20,9 @@
     SettingsColumnCleaner,
 )
 from splink.settings_validation.valid_types import (
+    _check_input_dataframes_for_single_comparison_column,
+    _log_comparison_errors,
     _validate_dialect,
-    log_comparison_errors,
 )
 
 from .accuracy import (
@@ -497,7 +499,26 @@ def _check_for_valid_settings(self):
         else:
             return True
 
-    def _validate_settings(self, validate_settings):
+    def _validate_settings_dictionary(
+        self, validate_settings: bool, settings_dict: Dict[str, Any]
+    ):
+        if settings_dict is None:
+            return
+
+        if validate_settings:
+            _check_input_dataframes_for_single_comparison_column(
+                self._input_tables_dict,
+                source_dataset_column_name=settings_dict.get(
+                    "source_dataset_column_name"
+                ),
+                unique_id_column_name=settings_dict.get("unique_id_column_name"),
+            )
+            # Check the user's comparisons (if they exist)
+            _log_comparison_errors(
+                settings_dict.get("comparisons"), settings_dict.get("sql_dialect")
+            )
+
+    def _validate_settings_object(self, validate_settings: bool):
-        # Vaidate our settings after plugging them through
+        # Validate our settings after plugging them through
         # `Settings()`
         if not self._check_for_valid_settings():
@@ -515,7 +536,7 @@ def _validate_settings(self, validate_settings):
         # Constructs output logs for our various settings inputs
         cleaned_settings = SettingsColumnCleaner(
             settings_object=self._settings_obj,
-            input_columns=self._input_tables_dict,
+            splink_input_table_dfs=self._input_tables_dict,
         )
         InvalidColumnsLogger(cleaned_settings).construct_output_logs(validate_settings)
 
@@ -1133,11 +1154,10 @@ def load_settings(
         settings_dict["sql_dialect"] = sql_dialect
         settings_dict["linker_uid"] = settings_dict.get("linker_uid", cache_uid)
-        # Check the user's comparisons (if they exist)
-        log_comparison_errors(settings_dict.get("comparisons"), sql_dialect)
+        self._validate_settings_dictionary(validate_settings, settings_dict)
         self._settings_obj_ = Settings(settings_dict)
 
         # Check the final settings object
-        self._validate_settings(validate_settings)
+        self._validate_settings_object(validate_settings)
 
     def load_model(self, model_path: Path):
         """
diff --git a/splink/predict.py b/splink/predict.py
index 3a7f7555a1..efc9e18def 100644
--- a/splink/predict.py
+++ b/splink/predict.py
@@ -55,7 +55,7 @@ def predict_from_comparison_vectors_sqls(
         thres_prob_as_weight = prob_to_match_weight(threshold_match_probability)
     else:
         thres_prob_as_weight = None
-    if threshold_match_probability or threshold_match_weight:
+    if threshold_match_probability is not None or threshold_match_weight is not None:
         thresholds = [
             thres_prob_as_weight,
             threshold_match_weight,
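Note on the `predict.py` hunk above: the original condition was a truthiness check, so a legitimate threshold of `0`/`0.0` (falsy in Python) silently skipped the thresholding branch. Comparing against `None` only skips the branch when no threshold was supplied at all. A minimal standalone sketch of the difference (toy variable, not splink code):

```python
threshold_match_weight = 0.0  # a perfectly valid cutoff

if threshold_match_weight:  # old check: 0.0 is falsy, branch never runs
    print("old check: threshold applied")

if threshold_match_weight is not None:  # new check: only skips when unset
    print("new check: threshold applied")
```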
diff --git a/splink/settings_validation/settings_column_cleaner.py b/splink/settings_validation/settings_column_cleaner.py
index 4c99e1c60b..054ab5a6bf 100644
--- a/splink/settings_validation/settings_column_cleaner.py
+++ b/splink/settings_validation/settings_column_cleaner.py
@@ -5,7 +5,7 @@
 from copy import deepcopy
 from functools import reduce
 from operator import and_
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List
 
 import sqlglot
 
@@ -15,6 +15,7 @@
 
 if TYPE_CHECKING:
     from ..settings import Settings
+    from ..splink_dataframe import SplinkDataFrame
 
 
 def remove_suffix(c):
@@ -28,7 +29,7 @@ def find_columns_not_in_input_dfs(
     does not apply any cleaning to the input column(s).
     """
     # the key to use when producing our warning logs
-    if type(columns_to_check) == str:
+    if isinstance(columns_to_check, str):
         columns_to_check = [columns_to_check]
 
     return {col for col in columns_to_check if col not in valid_input_dataframe_columns}
@@ -81,7 +82,9 @@ def clean_list_of_column_names(col_list: List[InputColumn]):
     return set((c.unquote().name for c in col_list))
 
 
-def clean_user_input_columns(input_columns: dict, return_as_single_column: bool = True):
+def clean_user_input_columns(
+    input_columns: Dict[str, "SplinkDataFrame"], return_as_single_column: bool = True
+):
     """A dictionary containing all input dataframes and the columns
     located within.
 
@@ -104,11 +107,11 @@ class SettingsColumnCleaner:
     cleaned up settings columns and SQL strings.
     """
 
-    def __init__(self, settings_object: Settings, input_columns: dict):
+    def __init__(self, settings_object: Settings, splink_input_table_dfs: dict):
         self.sql_dialect = settings_object._sql_dialect
         self._settings_obj = settings_object
         self.input_columns = clean_user_input_columns(
-            input_columns.items(), return_as_single_column=True
+            splink_input_table_dfs.items(), return_as_single_column=True
         )
 
     @property
diff --git a/splink/settings_validation/settings_validation_log_strings.py b/splink/settings_validation/settings_validation_log_strings.py
index 5a051b3d66..f236d02abf 100644
--- a/splink/settings_validation/settings_validation_log_strings.py
+++ b/splink/settings_validation/settings_validation_log_strings.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import List, NamedTuple, Tuple
+from typing import Dict, List, NamedTuple, Tuple
 
 
 def indent_error_message(message):
@@ -200,3 +200,23 @@ def create_incorrect_dialect_import_log_string(
         "for your specified linker.\n"
     )
     return indent_error_message(log_message)
+
+
+def construct_single_dataframe_log_str(input_columns: Dict[str, str]) -> str:
+    if len(input_columns) == 1:
+        df_txt = "dataframe is"
+        it_txt = "it contains"
+    else:
+        df_txt = "dataframes are"
+        it_txt = "they contain"
+
+    log_message = (
+        f"\nThe provided {df_txt} unsuitable for linkage with Splink as\n"
+        f"{it_txt} only a single column for matching.\n"
+        "Splink is not designed for linking based on a single 'bag of words'\n"
+        "column, such as a table with only a 'company name' column and\n"
+        "no other details.\n\nFor more information see: \n"
+        "https://github.com/moj-analytical-services/splink/issues/1362"
+    )
+
+    return log_message
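The `isinstance` change in `settings_column_cleaner.py` above is more than style: an exact-type comparison rejects `str` subclasses, which would then be iterated character by character instead of being wrapped in a single-element list. A standalone illustration:

```python
class ColumnName(str):
    """A toy str subclass standing in for any wrapped column name."""

col = ColumnName("first_name")

print(type(col) == str)      # False - exact-type check misses the subclass
print(isinstance(col, str))  # True  - still normalised to [col] correctly
```

Similarly, `construct_single_dataframe_log_str` picks its verb forms from the number of input dataframes so that both the singular and plural messages read correctly.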
diff --git a/splink/settings_validation/valid_types.py b/splink/settings_validation/valid_types.py
index 56ffa12d5d..f151c3058d 100644
--- a/splink/settings_validation/valid_types.py
+++ b/splink/settings_validation/valid_types.py
@@ -1,18 +1,30 @@
 from __future__ import annotations
 
 import logging
-from typing import Dict, Union
+from typing import TYPE_CHECKING, Dict, List, Union
 
 from ..comparison import Comparison
 from ..comparison_level import ComparisonLevel
-from ..exceptions import ComparisonSettingsException, ErrorLogger, InvalidDialect
+from ..default_from_jsonschema import default_value_from_schema
+from ..exceptions import (
+    ComparisonSettingsException,
+    ErrorLogger,
+    InvalidDialect,
+    InvalidSplinkInput,
+)
+from .settings_column_cleaner import clean_user_input_columns
 from .settings_validation_log_strings import (
+    construct_single_dataframe_log_str,
     create_incorrect_dialect_import_log_string,
     create_invalid_comparison_level_log_string,
     create_invalid_comparison_log_string,
     create_no_comparison_levels_error_log_string,
 )
 
+if TYPE_CHECKING:
+    from ..splink_dataframe import SplinkDataFrame
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -24,9 +36,6 @@ def extract_sql_dialect_from_cll(cll):
 
 
 def _validate_dialect(settings_dialect: str, linker_dialect: str, linker_type: str):
-    # settings_dialect = self.linker._settings_obj._sql_dialect
-    # linker_dialect = self.linker._sql_dialect
-    # linker_type = self.linker.__class__.__name__
     if settings_dialect != linker_dialect:
         raise ValueError(
             f"Incompatible SQL dialect! `settings` dictionary uses "
@@ -35,6 +44,34 @@ def _validate_dialect(settings_dialect: str, linker_dialect: str, linker_type: str):
         )
 
 
+def _check_input_dataframes_for_single_comparison_column(
+    input_columns: Dict[str, "SplinkDataFrame"],
+    source_dataset_column_name: str = None,
+    unique_id_column_name: str = None,
+):
+    if source_dataset_column_name is None:
+        source_dataset_column_name = default_value_from_schema(
+            "source_dataset_column_name", "root"
+        )
+    if unique_id_column_name is None:
+        unique_id_column_name = default_value_from_schema(
+            "unique_id_column_name", "root"
+        )
+
+    input_columns = clean_user_input_columns(
+        input_columns.items(), return_as_single_column=False
+    )
+
+    required_cols = (source_dataset_column_name, unique_id_column_name)
+
+    # Loop and exit if any dataframe has only one possible comparison column
+    for columns in input_columns.values():
+        unique_columns = set(columns) - set(required_cols)
+
+        if len(unique_columns) == 1:
+            raise InvalidSplinkInput(construct_single_dataframe_log_str(input_columns))
+
+
 def validate_comparison_levels(
     error_logger: ErrorLogger, comparisons: list, linker_dialect: str
 ):
@@ -53,40 +90,12 @@ def validate_comparison_levels(
         # If no error is found, append won't do anything
         error_logger.log_error(evaluate_comparison_dtype_and_contents(c_dict))
         error_logger.log_error(
-            evaluate_comparisons_for_imports_from_incorrect_dialects(
-                c_dict, linker_dialect
-            )
+            check_comparison_imported_for_correct_dialect(c_dict, linker_dialect)
         )
 
     return error_logger
 
 
-def log_comparison_errors(comparisons, linker_dialect):
-    """
-    Log any errors arising from `validate_comparison_levels`.
-    """
-
-    # Check for empty inputs - Expecting None or []
-    if not comparisons:
-        return
-
-    error_logger = ErrorLogger()
-
-    error_logger = validate_comparison_levels(error_logger, comparisons, linker_dialect)
-
-    # Raise and log any errors identified
-    plural_this = "this" if len(error_logger.raw_errors) == 1 else "these"
-    comp_hyperlink_txt = (
-        f"\nFor more info on how to construct comparisons and avoid {plural_this} "
-        "error, please visit:\n"
-        "https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html"
-    )
-
-    error_logger.raise_and_log_all_errors(
-        exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt
-    )
-
-
 def check_comparison_level_types(
     comparison_levels: Union[Comparison, Dict], comparison_str: str
 ):
@@ -146,9 +155,7 @@ def evaluate_comparison_dtype_and_contents(comparison_dict):
     return check_comparison_level_types(comp_levels, comp_str)
 
 
-def evaluate_comparisons_for_imports_from_incorrect_dialects(
-    comparison_dict, sql_dialect
-):
+def check_comparison_imported_for_correct_dialect(comparison_dict, sql_dialect):
     """
     Given a comparison_dict, assess whether the sql dialect is valid for
     your selected linker.
@@ -198,3 +205,29 @@ def check_comparison_imported_for_correct_dialect(comparison_dict, sql_dialect):
             comp_str, sorted(invalid_dialects)
         )
         return InvalidDialect(error_message)
+
+
+def _log_comparison_errors(comparisons: List[Comparison], linker_dialect: str):
+    """
+    Log any errors arising from various comparison validation checks.
+    """
+
+    # Check for empty inputs - Expecting None or []
+    if not comparisons:
+        return
+
+    error_logger = ErrorLogger()
+
+    error_logger = validate_comparison_levels(error_logger, comparisons, linker_dialect)
+
+    # Raise and log any errors identified
+    plural_this = "this" if len(error_logger.raw_errors) == 1 else "these"
+    comp_hyperlink_txt = (
+        f"\nFor more info on how to construct comparisons and avoid {plural_this} "
+        "error, please visit:\n"
+        "https://moj-analytical-services.github.io/splink/topic_guides/comparisons/customising_comparisons.html"
+    )
+
+    error_logger.raise_and_log_all_errors(
+        exception=ComparisonSettingsException, additional_txt=comp_hyperlink_txt
+    )
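The new guard above reduces to simple set arithmetic: discard the reserved `source_dataset`/`unique_id` columns and check whether exactly one column is left to compare on. A simplified standalone sketch (the real function pulls the default column names from the settings JSON schema and reads columns off `SplinkDataFrame` objects):

```python
def has_single_comparison_column(columns, required=("source_dataset", "unique_id")):
    # True when only one usable matching column remains after removing
    # the reserved identifier columns - the case Splink rejects.
    return len(set(columns) - set(required)) == 1

print(has_single_comparison_column(["company_name", "source_dataset", "unique_id"]))
# True -> would raise InvalidSplinkInput
print(has_single_comparison_column(["first_name", "surname", "unique_id"]))
# False -> enough columns to build comparisons
```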
diff --git a/tests/helpers.py b/tests/helpers.py
index dc0b63750c..d6f990a8e9 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -38,7 +38,7 @@ def Linker(self):
         pass
 
     def extra_linker_args(self):
-        return {}
+        return {"validate_settings": False}
 
     @property
     def date_format(self):
@@ -113,7 +113,8 @@ def Linker(self):
         return SparkLinker
 
     def extra_linker_args(self):
-        return {"spark": self.spark, "num_partitions_on_repartition": 1}
+        core_args = super().extra_linker_args()
+        return {"spark": self.spark, "num_partitions_on_repartition": 1, **core_args}
 
     def convert_frame(self, df):
         spark_frame = self.spark.createDataFrame(df)
@@ -159,7 +160,8 @@ def Linker(self):
         return SQLiteLinker
 
     def extra_linker_args(self):
-        return {"connection": self.con}
+        core_args = super().extra_linker_args()
+        return {"connection": self.con, **core_args}
 
     @classmethod
     def _get_input_name(cls):
@@ -208,7 +210,8 @@ def Linker(self):
         return PostgresLinker
 
     def extra_linker_args(self):
-        return {"engine": self.engine}
+        core_args = super().extra_linker_args()
+        return {"engine": self.engine, **core_args}
 
     @classmethod
     def _get_input_name(cls):
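The `tests/helpers.py` change works by having each dialect helper merge, rather than replace, the base class kwargs, so `validate_settings=False` reaches every backend without being repeated per test. The pattern in isolation (toy classes invented for the sketch):

```python
class BaseHelper:
    def extra_linker_args(self):
        return {"validate_settings": False}


class SparkHelper(BaseHelper):
    def extra_linker_args(self):
        core_args = super().extra_linker_args()
        # dialect-specific args first, shared defaults merged on top
        return {"num_partitions_on_repartition": 1, **core_args}


print(SparkHelper().extra_linker_args())
# {'num_partitions_on_repartition': 1, 'validate_settings': False}
```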
diff --git a/tests/test_caching_tables.py b/tests/test_caching_tables.py
index 8ef3ea8526..9a3a3787c3 100644
--- a/tests/test_caching_tables.py
+++ b/tests/test_caching_tables.py
@@ -30,7 +30,7 @@ def test_cache_tracking_works():
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
 
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
     cache = linker._intermediate_table_cache
 
     assert cache.is_in_executed_queries("__splink__df_concat_with_tf") is False
@@ -88,7 +88,7 @@ def test_cache_used_when_registering_nodes_table():
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
 
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
     cache = linker._intermediate_table_cache
     linker.register_table_input_nodes_concat_with_tf(splink__df_concat_with_tf)
     linker.estimate_u_using_random_sampling(target_rows=1e4)
@@ -137,7 +137,7 @@ def test_cache_used_when_registering_tf_tables():
     }
 
     # First test do not register any tf tables
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
     cache = linker._intermediate_table_cache
     linker.estimate_u_using_random_sampling(target_rows=1e4)
 
@@ -210,7 +210,7 @@ def test_cache_invalidation():
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
 
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
     cache = linker._intermediate_table_cache
     linker.compute_tf_table("name")
 
@@ -222,7 +222,7 @@ def test_cache_invalidation():
     assert len_before == len_after
     assert cache.is_in_queries_retrieved_from_cache("__splink__df_tf_name")
 
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
     cache = linker._intermediate_table_cache
     linker.compute_tf_table("name")
 
@@ -253,7 +253,7 @@ def test_table_deletions():
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
 
-    linker = DuckDBLinker("my_table", settings, connection=con)
+    linker = DuckDBLinker("my_table", settings, connection=con, validate_settings=False)
 
     table_names_before = set(get_duckdb_table_names_as_list(linker._con))
 
@@ -299,7 +299,9 @@ def test_table_deletions_with_preregistered():
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
 
-    linker = DuckDBLinker("my_data_table", settings, connection=con)
+    linker = DuckDBLinker(
+        "my_data_table", settings, connection=con, validate_settings=False
+    )
     linker.register_table_input_nodes_concat_with_tf("my_nodes_with_tf_table")
 
     table_names_before = set(get_duckdb_table_names_as_list(linker._con))
@@ -332,7 +334,7 @@ def test_single_deletion():
         "blocking_rules_to_generate_predictions": ["l.name = r.name"],
     }
 
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
     cache = linker._intermediate_table_cache
 
     tf_table = linker.compute_tf_table("name")
diff --git a/tests/test_comparison_lib.py b/tests/test_comparison_lib.py
index da899376c2..82ecab128d 100644
--- a/tests/test_comparison_lib.py
+++ b/tests/test_comparison_lib.py
@@ -78,7 +78,7 @@ def test_set_to_lowercase_parameter():
 
     df = pd.DataFrame(data)
 
-    linker = DuckDBLinker(df, settings)
+    linker = DuckDBLinker(df, settings, validate_settings=False)
 
     df_e = linker.predict().as_pandas_dataframe()
     row = dict(df_e.query("id_l == 1 and id_r == 2").iloc[0])
diff --git a/tests/test_comparison_template_lib.py b/tests/test_comparison_template_lib.py
index ad5224a500..ef926f91d7 100644
--- a/tests/test_comparison_template_lib.py
+++ b/tests/test_comparison_template_lib.py
@@ -5,11 +5,10 @@
 import splink.spark.comparison_template_library as ctls
 from splink.duckdb.linker import DuckDBLinker
 from splink.spark.linker import SparkLinker
+from tests.decorator import mark_with_dialects_excluding
 
 ## date_comparison
-
-
 @pytest.mark.parametrize(
     ("ctl"),
     [
@@ -486,14 +485,11 @@ def test_postcode_comparison_levels(spark, ctl, Linker, test_gamma_assert):
     test_gamma_assert(linker_output, size_gamma_lookup, col_name)
 
 
-@pytest.mark.parametrize(
-    ("ctl", "Linker"),
-    [
-        pytest.param(ctld, DuckDBLinker, id="DuckDB Email Comparison Template Test"),
-        pytest.param(ctls, SparkLinker, id="Spark Email Comparison Template Test"),
-    ],
-)
-def test_email_comparison_levels(spark, ctl, Linker, test_gamma_assert):
+@mark_with_dialects_excluding("sqlite", "postgres")
+def test_email_comparison_levels(test_helpers, dialect, test_gamma_assert):
+    helper = test_helpers[dialect]
+    ctl = helper.ctl
+
     col_name = "email"
 
     df = pd.DataFrame(
@@ -527,11 +523,8 @@ def test_email_comparison_levels(spark, ctl, Linker, test_gamma_assert):
         ],
     }
 
-    if Linker == SparkLinker:
-        df = spark.createDataFrame(df)
-        df.persist()
-
-    linker = Linker(df, settings)
+    df = helper.convert_frame(df)
+    linker = helper.Linker(df, settings, **helper.extra_linker_args())
     linker_output = linker.predict().as_pandas_dataframe()
 
     # Check individual IDs are assigned to the correct gamma values
diff --git a/tests/test_disable_tf_exact_match_detection.py b/tests/test_disable_tf_exact_match_detection.py
index 521ed1dd34..e01eeae4a7 100644
--- a/tests/test_disable_tf_exact_match_detection.py
+++ b/tests/test_disable_tf_exact_match_detection.py
@@ -4,7 +4,6 @@
 
 
 def test_disable_tf_exact_match_detection():
-
     settings = Settings({"link_type": "dedupe_only"})
 
     comparison_normal_dict = {
diff --git a/tests/test_settings_validation.py b/tests/test_settings_validation.py
index 676e30c1c8..db266ade3a 100644
--- a/tests/test_settings_validation.py
+++ b/tests/test_settings_validation.py
@@ -1,5 +1,6 @@
 import logging
 import re
+from typing import Dict, List
 
 import pandas as pd
 import pytest
@@ -9,7 +10,7 @@
 from splink.duckdb.blocking_rule_library import block_on
 from splink.duckdb.comparison_library import levenshtein_at_thresholds
 from splink.duckdb.linker import DuckDBLinker
-from splink.exceptions import ErrorLogger
+from splink.exceptions import ErrorLogger, InvalidSplinkInput
 from splink.settings_validation.log_invalid_columns import (
     InvalidColumnSuffixesLogGenerator,
     InvalidTableNamesLogGenerator,
@@ -20,7 +21,8 @@
     validate_table_names,
 )
 from splink.settings_validation.valid_types import (
-    log_comparison_errors,
+    _check_input_dataframes_for_single_comparison_column,
+    _log_comparison_errors,
     validate_comparison_levels,
 )
 
@@ -142,6 +144,30 @@
 )
 
 
+class MockSplinkDataFrame:
+    def __init__(self, columns: List["MockInputColumn"]):
+        self.columns = columns
+
+
+class MockInputColumn:
+    def __init__(self, name: str):
+        self.name = name
+
+    def unquote(self):
+        return self
+
+
+@pytest.fixture
+def mock_input_columns():
+    def _mock(columns_dict: Dict[str, List[str]]):
+        return {
+            key: MockSplinkDataFrame([MockInputColumn(col) for col in columns])
+            for key, columns in columns_dict.items()
+        }
+
+    return _mock
+
+
 @pytest.mark.parametrize(
     "input_name, expected_output", missing_settings_column_test_cases
 )
@@ -379,7 +405,7 @@ def test_comparison_validation():
         }
     )
 
-    log_comparison_errors(None, "duckdb")  # confirm it works with None as an input...
+    _log_comparison_errors(None, "duckdb")  # confirm it works with None as an input...
 
     # Init the error logger. This is normally handled in
-    # `log_comparison_errors`, but here we want to capture the
+    # `_log_comparison_errors`, but here we want to capture the
@@ -413,3 +439,23 @@ def test_comparison_validation():
     else:
         with pytest.raises(e, match=txt):
             raise errors[n]
+
+
+@pytest.mark.parametrize(
+    "input_columns_dict",
+    [
+        {
+            "df1": ["first_name", "surname", "source_dataset", "unique_id"],
+            "df2": ["first_name", "surname", "source_dataset", "unique_id"],
+            "df3": ["first_name", "source_dataset", "unique_id"],
+        },
+        {"df1": ["abcde", "source_dataset", "unique_id"]},
+    ],
+)
+def test_input_datasets_with_insufficient_columns_validation(
+    mock_input_columns, input_columns_dict
+):
+    input_columns = mock_input_columns(input_columns_dict)
+
+    with pytest.raises(InvalidSplinkInput):
+        _check_input_dataframes_for_single_comparison_column(input_columns)
diff --git a/tests/test_term_frequencies.py b/tests/test_term_frequencies.py
index 13cb274999..7d4d1d534e 100644
--- a/tests/test_term_frequencies.py
+++ b/tests/test_term_frequencies.py
@@ -79,7 +79,9 @@ def test_tf_basic():
         "retain_intermediate_calculation_columns": True,
     }
 
-    linker = DuckDBLinker(data, settings, connection=":memory:")
+    linker = DuckDBLinker(
+        data, settings, connection=":memory:", validate_settings=False
+    )
     df_predict = linker.predict()
     results = filter_results(df_predict)
 
@@ -115,7 +117,9 @@ def test_tf_clamp():
         "retain_intermediate_calculation_columns": True,
     }
 
-    linker = DuckDBLinker(data, settings, connection=":memory:")
+    linker = DuckDBLinker(
+        data, settings, connection=":memory:", validate_settings=False
+    )
     df_predict = linker.predict()
     results = filter_results(df_predict)
 
@@ -151,7 +155,9 @@ def test_weight():
         "retain_intermediate_calculation_columns": True,
     }
 
-    linker = DuckDBLinker(data, settings, connection=":memory:")
+    linker = DuckDBLinker(
+        data, settings, connection=":memory:", validate_settings=False
+    )
     df_predict = linker.predict()
     results = filter_results(df_predict)
 
@@ -200,7 +206,9 @@ def test_weightand_clamp():
         "retain_intermediate_calculation_columns": True,
     }
 
-    linker = DuckDBLinker(data, settings, connection=":memory:")
+    linker = DuckDBLinker(
+        data, settings, connection=":memory:", validate_settings=False
+    )
     df_predict = linker.predict()
     results = filter_results(df_predict)
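Taken together: with `validate_settings=True` (hence the `validate_settings=False` opt-outs threaded through the tests above), a linker whose input offers only one usable matching column should now fail fast. A hypothetical usage sketch; the toy dataframe and settings are invented, and the exact point at which the constructor runs validation may differ:

```python
import pandas as pd

from splink.duckdb.linker import DuckDBLinker
from splink.exceptions import InvalidSplinkInput

# Only one column besides unique_id: exactly the 'bag of words' case.
df = pd.DataFrame({"unique_id": [1, 2], "company_name": ["Acme", "ACME Ltd"]})

settings = {"link_type": "dedupe_only"}  # toy settings for illustration

try:
    DuckDBLinker(df, settings, validate_settings=True)
except InvalidSplinkInput as err:
    print(err)  # explains why single-column inputs are unsupported
```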