diff --git a/alibi_detect/saving/loading.py b/alibi_detect/saving/loading.py index 03caa4bd8..1f797be79 100644 --- a/alibi_detect/saving/loading.py +++ b/alibi_detect/saving/loading.py @@ -74,7 +74,8 @@ ] -def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector, ConfigurableDetector]: +def load_detector(filepath: Union[str, os.PathLike], enable_unsafe_loading: bool = False, + **kwargs) -> Union[Detector, ConfigurableDetector]: """ Load outlier, drift or adversarial detector. @@ -82,6 +83,11 @@ def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector ---------- filepath Load directory. + enable_unsafe_loading + Sets allow_pickle=True when a np.ndarray is loaded from a .npy file referenced in the detector config. Needed + if you have to load objects. + Only applied if the filepath is config.toml or a directory containing a config.toml. + It has security implications: https://nvd.nist.gov/vuln/detail/cve-2019-6446. Returns ------- @@ -90,13 +96,13 @@ def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector filepath = Path(filepath) # If reference is a 'config.toml' itself, pass to new load function if filepath.name == 'config.toml': - return _load_detector_config(filepath) + return _load_detector_config(filepath, enable_unsafe_loading=enable_unsafe_loading) # Otherwise, if a directory, look for meta.dill, meta.pickle or config.toml inside it elif filepath.is_dir(): files = [str(f.name) for f in filepath.iterdir() if f.is_file()] if 'config.toml' in files: - return _load_detector_config(filepath.joinpath('config.toml')) + return _load_detector_config(filepath.joinpath('config.toml'), enable_unsafe_loading=enable_unsafe_loading) elif 'meta.dill' in files: return load_detector_legacy(filepath, '.dill', **kwargs) elif 'meta.pickle' in files: @@ -110,7 +116,8 @@ def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector # TODO - will eventually become load_detector -def _load_detector_config(filepath: Union[str, os.PathLike]) -> ConfigurableDetector: +def _load_detector_config(filepath: Union[str, os.PathLike], enable_unsafe_loading: bool = False) \ + -> ConfigurableDetector: """ Loads a drift detector specified in a detector config dict. Validation is performed with pydantic. @@ -118,6 +125,10 @@ def _load_detector_config(filepath: Union[str, os.PathLike]) -> ConfigurableDete ---------- filepath Filepath to the `config.toml` file. + enable_unsafe_loading + Sets allow_pickle=True when a np.ndarray is loaded from a .npy file (happens if the .toml references one). + Needed if you have to load objects. + It has security implications: https://nvd.nist.gov/vuln/detail/cve-2019-6446 Returns ------- @@ -134,7 +145,7 @@ def _load_detector_config(filepath: Union[str, os.PathLike]) -> ConfigurableDete # Resolve and validate config cfg = validate_config(cfg) logger.info('Validated unresolved config.') - cfg = resolve_config(cfg, config_dir=config_dir) + cfg = resolve_config(cfg, config_dir=config_dir, enable_unsafe_loading=enable_unsafe_loading) cfg = validate_config(cfg, resolved=True) logger.info('Validated resolved config.') @@ -453,7 +464,7 @@ def read_config(filepath: Union[os.PathLike, str]) -> dict: return cfg -def resolve_config(cfg: dict, config_dir: Optional[Path]) -> dict: +def resolve_config(cfg: dict, config_dir: Optional[Path], enable_unsafe_loading: bool = False) -> dict: """ Resolves artefacts in a config dict. For example x_ref='x_ref.npy' is resolved by loading the np.ndarray from the .npy file. For a list of fields that are resolved, see @@ -466,6 +477,9 @@ def resolve_config(cfg: dict, config_dir: Optional[Path]) -> dict: config_dir Filepath to directory the `config.toml` is located in. Only required if different from the runtime directory, and artefacts are specified with filepaths relative to the config.toml file. + enable_unsafe_loading + If set to true, allow_pickle=True is set in np.load(). Needed if you have to load objects. + It has security implications: https://nvd.nist.gov/vuln/detail/cve-2019-6446 Returns ------- @@ -506,7 +520,7 @@ def resolve_config(cfg: dict, config_dir: Optional[Path]) -> dict: if Path(src).suffix == '.dill': obj = dill.load(open(src, 'rb')) if Path(src).suffix == '.npy': - obj = np.load(src) + obj = np.load(src, allow_pickle=enable_unsafe_loading) # Resolve artefact dicts elif isinstance(src, dict): diff --git a/alibi_detect/saving/tests/test_loading.py b/alibi_detect/saving/tests/test_loading.py new file mode 100644 index 000000000..ae5043081 --- /dev/null +++ b/alibi_detect/saving/tests/test_loading.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from alibi_detect.cd import TabularDrift +from alibi_detect.saving import save_detector, load_detector + + +def test_loading_detector_with_plain_data(tmp_path): + data = np.array([[42, 1, 1234077000], [42, 2, 1234088000], [42, 3, 1234099000]]) + p_val = 0.42 + + create_and_save_detector(data, p_val, tmp_path) + + loaded_detector = load_detector(tmp_path) + + assert loaded_detector.get_config()["name"] == "TabularDrift" + assert loaded_detector.get_config()["p_val"] == p_val + np.testing.assert_array_equal(loaded_detector.get_config()["x_ref"], data) + + +def test_loading_detector_with_data_containing_objects_throws_exception(tmp_path): + data = np.array([['42', 1, 1234077000], ['42', 2, 1234088000], ['42', 3, 1234099000]], dtype=object) + p_val = 0.42 + + create_and_save_detector(data, p_val, tmp_path) + + with pytest.raises(Exception) as ex_info: + _ = load_detector(tmp_path) + + assert ex_info.typename == "ValueError" + assert ex_info.value.args[0] == "Object arrays cannot be loaded when allow_pickle=False" + + +def test_loading_detector_with_data_containing_objects(tmp_path): + data = np.array([['42', 1, 1234077000], ['42', 2, 1234088000], ['42', 3, 1234099000]], dtype=object) + p_val = 0.42 + + create_and_save_detector(data, p_val, tmp_path) + + loaded_detector = load_detector(tmp_path, enable_unsafe_loading=True) + + assert loaded_detector.get_config()["name"] == "TabularDrift" + assert loaded_detector.get_config()["p_val"] == p_val + np.testing.assert_array_equal(loaded_detector.get_config()["x_ref"], data) + + +def create_and_save_detector(data: np.ndarray, p_val: float, path): + detector = TabularDrift( + x_ref=data, + p_val=p_val, + x_ref_preprocessed=True + ) + + save_detector( + detector, + path + )