Skip to content

Commit

Permalink
Allow pickle when loading numpy array file (#836)
Browse files Browse the repository at this point in the history
* Allow pickle when loading numpy array file

Co-authored-by: tokaessm

* Add Opt-In for unsafe loading of np.ndarray

* Add tests for load_detector()

---------

Co-authored-by: tomglk <>
  • Loading branch information
tomglk authored Oct 4, 2023
1 parent 4a1b4f7 commit d188e02
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
28 changes: 21 additions & 7 deletions alibi_detect/saving/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,20 @@
]


def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector, ConfigurableDetector]:
def load_detector(filepath: Union[str, os.PathLike], enable_unsafe_loading: bool = False,
**kwargs) -> Union[Detector, ConfigurableDetector]:
"""
Load outlier, drift or adversarial detector.
Parameters
----------
filepath
Load directory.
enable_unsafe_loading
Sets allow_pickle=True when a np.ndarray is loaded from a .npy file referenced in the detector config. Needed
if you have to load objects.
Only applied if the filepath is config.toml or a directory containing a config.toml.
It has security implications: https://nvd.nist.gov/vuln/detail/cve-2019-6446.
Returns
-------
Expand All @@ -90,13 +96,13 @@ def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector
filepath = Path(filepath)
# If reference is a 'config.toml' itself, pass to new load function
if filepath.name == 'config.toml':
return _load_detector_config(filepath)
return _load_detector_config(filepath, enable_unsafe_loading=enable_unsafe_loading)

# Otherwise, if a directory, look for meta.dill, meta.pickle or config.toml inside it
elif filepath.is_dir():
files = [str(f.name) for f in filepath.iterdir() if f.is_file()]
if 'config.toml' in files:
return _load_detector_config(filepath.joinpath('config.toml'))
return _load_detector_config(filepath.joinpath('config.toml'), enable_unsafe_loading=enable_unsafe_loading)
elif 'meta.dill' in files:
return load_detector_legacy(filepath, '.dill', **kwargs)
elif 'meta.pickle' in files:
Expand All @@ -110,14 +116,19 @@ def load_detector(filepath: Union[str, os.PathLike], **kwargs) -> Union[Detector


# TODO - will eventually become load_detector
def _load_detector_config(filepath: Union[str, os.PathLike]) -> ConfigurableDetector:
def _load_detector_config(filepath: Union[str, os.PathLike], enable_unsafe_loading: bool = False) \
-> ConfigurableDetector:
"""
Loads a drift detector specified in a detector config dict. Validation is performed with pydantic.
Parameters
----------
filepath
Filepath to the `config.toml` file.
enable_unsafe_loading
Sets allow_pickle=True when a np.ndarray is loaded from a .npy file (happens if the .toml references one).
Needed if you have to load objects.
It has security implications: https://nvd.nist.gov/vuln/detail/cve-2019-6446
Returns
-------
Expand All @@ -134,7 +145,7 @@ def _load_detector_config(filepath: Union[str, os.PathLike]) -> ConfigurableDete
# Resolve and validate config
cfg = validate_config(cfg)
logger.info('Validated unresolved config.')
cfg = resolve_config(cfg, config_dir=config_dir)
cfg = resolve_config(cfg, config_dir=config_dir, enable_unsafe_loading=enable_unsafe_loading)
cfg = validate_config(cfg, resolved=True)
logger.info('Validated resolved config.')

Expand Down Expand Up @@ -453,7 +464,7 @@ def read_config(filepath: Union[os.PathLike, str]) -> dict:
return cfg


def resolve_config(cfg: dict, config_dir: Optional[Path]) -> dict:
def resolve_config(cfg: dict, config_dir: Optional[Path], enable_unsafe_loading: bool = False) -> dict:
"""
Resolves artefacts in a config dict. For example x_ref='x_ref.npy' is resolved by loading the np.ndarray from
the .npy file. For a list of fields that are resolved, see
Expand All @@ -466,6 +477,9 @@ def resolve_config(cfg: dict, config_dir: Optional[Path]) -> dict:
config_dir
Filepath to directory the `config.toml` is located in. Only required if different from the
runtime directory, and artefacts are specified with filepaths relative to the config.toml file.
enable_unsafe_loading
If set to true, allow_pickle=True is set in np.load(). Needed if you have to load objects.
It has security implications: https://nvd.nist.gov/vuln/detail/cve-2019-6446
Returns
-------
Expand Down Expand Up @@ -506,7 +520,7 @@ def resolve_config(cfg: dict, config_dir: Optional[Path]) -> dict:
if Path(src).suffix == '.dill':
obj = dill.load(open(src, 'rb'))
if Path(src).suffix == '.npy':
obj = np.load(src)
obj = np.load(src, allow_pickle=enable_unsafe_loading)

# Resolve artefact dicts
elif isinstance(src, dict):
Expand Down
57 changes: 57 additions & 0 deletions alibi_detect/saving/tests/test_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np
import pytest

from alibi_detect.cd import TabularDrift
from alibi_detect.saving import save_detector, load_detector


def test_loading_detector_with_plain_data(tmp_path):
    """Round-trip a detector with purely numeric reference data using the default loading mode."""
    expected_p_val = 0.42
    x_ref = np.array([
        [42, 1, 1234077000],
        [42, 2, 1234088000],
        [42, 3, 1234099000],
    ])
    create_and_save_detector(x_ref, expected_p_val, tmp_path)

    # No `enable_unsafe_loading` opt-in is passed for this numeric array.
    restored_cfg = load_detector(tmp_path).get_config()

    assert restored_cfg["name"] == "TabularDrift"
    assert restored_cfg["p_val"] == expected_p_val
    np.testing.assert_array_equal(restored_cfg["x_ref"], x_ref)


def test_loading_detector_with_data_containing_objects_throws_exception(tmp_path):
    """Loading a detector saved with `dtype=object` reference data must fail without the opt-in flag."""
    data = np.array([['42', 1, 1234077000], ['42', 2, 1234088000], ['42', 3, 1234099000]], dtype=object)
    p_val = 0.42

    create_and_save_detector(data, p_val, tmp_path)

    # Expect the specific exception type instead of catching `Exception` and
    # comparing the type name as a string afterwards.
    with pytest.raises(ValueError) as ex_info:
        load_detector(tmp_path)

    assert ex_info.value.args[0] == "Object arrays cannot be loaded when allow_pickle=False"


def test_loading_detector_with_data_containing_objects(tmp_path):
    """Round-trip a detector with `dtype=object` reference data when unsafe loading is enabled."""
    expected_p_val = 0.42
    x_ref = np.array(
        [
            ['42', 1, 1234077000],
            ['42', 2, 1234088000],
            ['42', 3, 1234099000],
        ],
        dtype=object,
    )
    create_and_save_detector(x_ref, expected_p_val, tmp_path)

    # Explicit opt-in: the saved reference array has dtype=object.
    restored_cfg = load_detector(tmp_path, enable_unsafe_loading=True).get_config()

    assert restored_cfg["name"] == "TabularDrift"
    assert restored_cfg["p_val"] == expected_p_val
    np.testing.assert_array_equal(restored_cfg["x_ref"], x_ref)


def create_and_save_detector(data: np.ndarray, p_val: float, path):
    """
    Build a `TabularDrift` detector and save it to `path`.

    Parameters
    ----------
    data
        Reference data passed to the detector as `x_ref`.
    p_val
        p-value passed to the detector.
    path
        Location handed to `save_detector`.
    """
    drift_detector = TabularDrift(x_ref=data, p_val=p_val, x_ref_preprocessed=True)
    save_detector(drift_detector, path)

0 comments on commit d188e02

Please sign in to comment.