From 31a25d3c629cb46efdf5077a6bc634897f2b448d Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Tue, 17 Dec 2024 14:47:09 +0000 Subject: [PATCH 1/7] BAI-1540 minimal check for errors in modelscan results value --- backend/src/clients/modelScan.ts | 25 ++++++++++--------- .../src/connectors/fileScanning/modelScan.ts | 16 ++++++++++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/backend/src/clients/modelScan.ts b/backend/src/clients/modelScan.ts index a71d48b8d..bd4a3b19f 100644 --- a/backend/src/clients/modelScan.ts +++ b/backend/src/clients/modelScan.ts @@ -33,18 +33,19 @@ interface ModelScanResponse { skipped_files: string[] } } - issues: [ - { - description: string - operator: string - module: string - source: string - scanner: string - severity: string - }, - ] - // TODO: currently unknown what this might look like - errors: object[] + issues: { + description: string + operator: string + module: string + source: string + scanner: string + severity: string + }[] + errors: { + category: string + description: string + source: string + }[] } export async function getModelScanInfo() { diff --git a/backend/src/connectors/fileScanning/modelScan.ts b/backend/src/connectors/fileScanning/modelScan.ts index a9f01fa86..2a6437934 100644 --- a/backend/src/connectors/fileScanning/modelScan.ts +++ b/backend/src/connectors/fileScanning/modelScan.ts @@ -63,6 +63,21 @@ export class ModelScanFileScanningConnector extends BaseFileScanningConnector { try { const scanResults = await scanStream(s3Stream, file.name, file.size) + if (scanResults.errors.length !== 0) { + log.error( + { errors: scanResults.errors, modelId: file.modelId, fileId: file._id, name: file.name }, + 'Scan errored.', + ) + return [ + { + toolName: modelScanToolName, + state: ScanState.Error, + scannerVersion: modelscanVersion, + lastRunAt: new Date(), + }, + ] + } + const issues = scanResults.summary.total_issues const isInfected = issues > 0 const viruses: 
string[] = [] @@ -91,6 +106,7 @@ export class ModelScanFileScanningConnector extends BaseFileScanningConnector { { toolName: modelScanToolName, state: ScanState.Error, + scannerVersion: modelscanVersion, lastRunAt: new Date(), }, ] From f549fe3c9ddc01a3e560a5f9137600d401572ec0 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:28:13 +0000 Subject: [PATCH 2/7] BAI-1540 rework basic modelscan tests --- lib/modelscan_api/pytest.ini | 5 ++ .../test_dependencies.py | 2 +- .../test_main.py | 55 +++++++++++-------- 3 files changed, 39 insertions(+), 23 deletions(-) create mode 100644 lib/modelscan_api/pytest.ini rename lib/modelscan_api/{bailo_modelscan_api => tests}/test_dependencies.py (98%) rename lib/modelscan_api/{bailo_modelscan_api => tests}/test_main.py (55%) diff --git a/lib/modelscan_api/pytest.ini b/lib/modelscan_api/pytest.ini new file mode 100644 index 000000000..1c0f70191 --- /dev/null +++ b/lib/modelscan_api/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +filterwarnings = ignore::DeprecationWarning +pythonpath = "." 
+testpaths = "tests" +junit_family = "xunit2" diff --git a/lib/modelscan_api/bailo_modelscan_api/test_dependencies.py b/lib/modelscan_api/tests/test_dependencies.py similarity index 98% rename from lib/modelscan_api/bailo_modelscan_api/test_dependencies.py rename to lib/modelscan_api/tests/test_dependencies.py index fe27ce221..2fc17eeff 100644 --- a/lib/modelscan_api/bailo_modelscan_api/test_dependencies.py +++ b/lib/modelscan_api/tests/test_dependencies.py @@ -9,7 +9,7 @@ import pytest -from .dependencies import safe_join +from bailo_modelscan_api.dependencies import safe_join # Helpers diff --git a/lib/modelscan_api/bailo_modelscan_api/test_main.py b/lib/modelscan_api/tests/test_main.py similarity index 55% rename from lib/modelscan_api/bailo_modelscan_api/test_main.py rename to lib/modelscan_api/tests/test_main.py index 456b163e6..82c013d9e 100644 --- a/lib/modelscan_api/bailo_modelscan_api/test_main.py +++ b/lib/modelscan_api/tests/test_main.py @@ -5,14 +5,16 @@ from functools import lru_cache from pathlib import Path +from typing import Any from unittest.mock import Mock, patch import modelscan from fastapi.testclient import TestClient +import pytest -from .config import Settings -from .dependencies import parse_path -from .main import app, get_settings +from bailo_modelscan_api.config import Settings +from bailo_modelscan_api.dependencies import parse_path +from bailo_modelscan_api.main import app, get_settings client = TestClient(app) @@ -25,6 +27,10 @@ def get_settings_override(): app.dependency_overrides[get_settings] = get_settings_override +EMPTY_CONTENTS = rb"" +H5_MIME_TYPE = "application/x-hdf5" + + def test_info(): response = client.get("/info") @@ -38,20 +44,13 @@ def test_info(): @patch("modelscan.modelscan.ModelScan.scan") -def test_scan_file(mock_scan: Mock): - mock_scan.return_value = {} - files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")} - - response = client.post("/scan/file", files=files) - - assert response.status_code == 200 - 
mock_scan.assert_called_once() - - -@patch("modelscan.modelscan.ModelScan.scan") -def test_scan_file_escape_path(mock_scan: Mock): +@pytest.mark.parametrize( + ("file_name", "file_content", "file_mime_type"), + [("foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE), ("../foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE)], +) +def test_scan_file(mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str): mock_scan.return_value = {} - files = {"in_file": ("../foo.bar", rb"", "application/x-hdf5")} + files = {"in_file": (file_name, file_content, file_mime_type)} response = client.post("/scan/file", files=files) @@ -59,8 +58,12 @@ def test_scan_file_escape_path(mock_scan: Mock): mock_scan.assert_called_once() -def test_scan_file_escape_path_error(): - files = {"in_file": ("..", rb"", "text/plain")} +@pytest.mark.parametrize( + ("file_name", "file_content", "file_mime_type"), + [("..", EMPTY_CONTENTS, H5_MIME_TYPE), ("../", EMPTY_CONTENTS, H5_MIME_TYPE)], +) +def test_scan_file_escape_path_error(file_name: str, file_content: Any, file_mime_type: str): + files = {"in_file": (file_name, file_content, file_mime_type)} response = client.post("/scan/file", files=files) @@ -69,9 +72,13 @@ def test_scan_file_escape_path_error(): @patch("modelscan.modelscan.ModelScan.scan") -def test_scan_file_exception(mock_scan: Mock): +@pytest.mark.parametrize( + ("file_name", "file_content", "file_mime_type"), + [("foo.h5", EMPTY_CONTENTS, H5_MIME_TYPE)], +) +def test_scan_file_exception(mock_scan: Mock, file_name: str, file_content: Any, file_mime_type: str): mock_scan.side_effect = Exception("Mocked error!") - files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")} + files = {"in_file": (file_name, file_content, file_mime_type)} response = client.post("/scan/file", files=files) @@ -83,8 +90,12 @@ def test_scan_file_exception(mock_scan: Mock): Path.unlink(Path.joinpath(parse_path(get_settings().download_dir), "foo.h5"), missing_ok=True) -def test_scan_file_filename_missing(): - files = 
{"in_file": (" ", rb"", "application/x-hdf5")} +@pytest.mark.parametrize( + ("file_name", "file_content", "file_mime_type"), + [(" ", EMPTY_CONTENTS, H5_MIME_TYPE)], +) +def test_scan_file_filename_missing(file_name: str, file_content: Any, file_mime_type: str): + files = {"in_file": (file_name, file_content, file_mime_type)} response = client.post("/scan/file", files=files) From 0aaa2c4db3eec67ef9b6599d7c61ef974189883f Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:09:46 +0000 Subject: [PATCH 3/7] BAI-1540 continue modelscan api pytest rework --- .../bailo_modelscan_api/dependencies.py | 19 +- lib/modelscan_api/tests/test_dependencies.py | 169 +++++++++--------- lib/modelscan_api/tests/test_main.py | 4 +- 3 files changed, 98 insertions(+), 94 deletions(-) diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py index 2fbc9a1c4..1fe9480c1 100644 --- a/lib/modelscan_api/bailo_modelscan_api/dependencies.py +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging +import re from pathlib import Path logger = logging.getLogger(__name__) @@ -21,7 +22,17 @@ def parse_path(path: str | Path | None) -> Path: return Path().cwd() if path == "." else Path(path).absolute() -def safe_join(root_dir: str | Path | None, filename: str | Path) -> Path: +def sanitise_unix_filename(filename: str) -> str: + """Safely convert an arbitrary string to a valid unix filename by only preserving explicitly allowed characters as per https://en.wikipedia.org/wiki/Filename#Reserved_characters_and_words + Note that this is not safe for Windows users as it doesn't check for reserved words e.g. CON and AUX. 
+ + :param filename: the untrusted filename to be sanitised + :return: a valid filename with trusted characters + """ + return re.sub(r"[/\\?%*:|\"<>\x7F\x00-\x1F]", "-", filename) + + +def safe_join(root_dir: str | Path | None, filename: str) -> Path: """Combine a trusted directory path with an untrusted filename to get a full path. :param root_dir: Trusted path/directory. @@ -33,13 +44,13 @@ def safe_join(root_dir: str | Path | None, filename: str | Path) -> Path: if not filename or not str(filename).strip(): raise ValueError("filename must not be empty") - stripped_filename = Path(str(filename)).name.strip() + safe_filename = sanitise_unix_filename(filename).strip() - if not stripped_filename: + if not safe_filename: raise ValueError("filename must not be empty") parent_dir = parse_path(root_dir).resolve() - full_path = parent_dir.joinpath(stripped_filename).resolve() + full_path = parent_dir.joinpath(safe_filename).resolve() if not full_path.is_relative_to(parent_dir): raise ValueError("Could not safely join paths.") diff --git a/lib/modelscan_api/tests/test_dependencies.py b/lib/modelscan_api/tests/test_dependencies.py index 2fc17eeff..a6db717b8 100644 --- a/lib/modelscan_api/tests/test_dependencies.py +++ b/lib/modelscan_api/tests/test_dependencies.py @@ -9,7 +9,7 @@ import pytest -from bailo_modelscan_api.dependencies import safe_join +from bailo_modelscan_api.dependencies import parse_path, safe_join, sanitise_unix_filename # Helpers @@ -26,106 +26,99 @@ def type_matrix(data: Iterable[Any], types: Iterable[type]) -> itertools.product return itertools.product(*[[t(d) for t in types] for d in data]) -def string_path_matrix(path1: str | Path, path2: str | Path) -> itertools.product[tuple[str, Path]]: - """Wrap type_matrix for convenience with str and Path types. - - :param path1: A path to process. - :param path2: Another path to process. - :return: The matrix of both paths with types str and Path. 
- """ - return type_matrix([path1, path2], [str, Path]) +# Tests -def helper_test_safe_join(path1: str | Path, path2: str | Path, output: Path) -> None: - """Helper method for testing that all str and Path representations of the two paths will match the given output when joined. +@pytest.mark.parametrize( + ("path", "output"), + [ + ("foo.bar", "foo.bar"), + (".foo.bar", ".foo.bar"), + ("/foo.bar", "-foo.bar"), + ("foo/./bar", "foo-.-bar"), + ("foo.-/bar", "foo.--bar"), + (".", "."), + ("..", ".."), + ("/", "-"), + ("/.", "-."), + ("./", ".-"), + ("\n", "-"), + ("\r", "-"), + ("~", "~"), + ("".join(['\\[/\\?%*:|"<>0x7F0x00-0x1F]', chr(0x1F) * 15]), "-[----------0x7F0x00-0x1F]---------------"), + ("ad\nbla'{-+\\)(ç?", "ad-bla'{-+-)(ç-"), # type: ignore + ], +) +def test_sanitise_unix_filename(path: str, output: str) -> None: + assert sanitise_unix_filename(path) == output + + +@pytest.mark.parametrize( + ("path", "output"), + [ + (None, Path().cwd()), + ("", Path().cwd()), + (".", Path().cwd()), + ("/tmp", Path("/tmp")), + ("/foo/bar", Path("/foo/bar")), + ("/foo/../bar", Path("/foo/../bar")), + ("/foo/bar space/baz", Path("/foo/bar space/baz")), + ("/C:\\Program Files\\HAL 9000", Path("/C:\\Program Files\\HAL 9000")), + ("/ISO&Emulator", Path("/ISO&Emulator")), + ("/$HOME", Path("/$HOME")), + ("~", Path().cwd().joinpath("~")), + ], +) +def test_parse_path(path: str | Path | None, output: Path) -> None: + if path is None: + assert parse_path(path) == output + else: + for (test_path,) in type_matrix((path,), (str, Path)): + assert parse_path(test_path) == output + + +@pytest.mark.parametrize( + ("path1", "path2", "output"), + [ + ("", "foo.bar", Path.cwd().joinpath("foo.bar")), + (".", "foo.bar", Path.cwd().joinpath("foo.bar")), + ("/tmp", "foo.bar", Path("/tmp/foo.bar")), + ("/tmp/", "foo.bar", Path("/tmp/foo.bar")), + ("/tmp/", "/foo.bar", Path("/tmp/-foo.bar")), + ("/tmp", ".foo.bar", Path("/tmp/.foo.bar")), + ("/tmp", "/foo.bar", Path("/tmp/-foo.bar")), + 
("/tmp", "//foo.bar", Path("/tmp/--foo.bar")), + ("/tmp", "./foo.bar", Path("/tmp/.-foo.bar")), + ("/tmp", "./.foo.bar", Path("/tmp/.-.foo.bar")), + ("/tmp", "..foo.bar", Path("/tmp/..foo.bar")), + ("/tmp", "../foo.bar", Path("/tmp/..-foo.bar")), + ("/tmp", "../.foo.bar", Path("/tmp/..-.foo.bar")), + ("/tmp", ".", Path("/tmp/.")), + ("/tmp", "/", Path("/tmp/-")), + ("/tmp", "//", Path("/tmp/--")), + ("/tmp", "~", Path("/tmp/~")), + ], +) +def test_safe_join(path1: str | Path, path2: str, output: Path) -> None: + """Test that all str and Path representations of the two paths will match the given output when joined. :param path1: Directory part of the final path. :param path2: Filename part of the final path. :param output: Expected final path value. """ - for test_dir, test_file in string_path_matrix(path1, path2): - res = safe_join(test_dir, test_file) + for (test_dir,) in type_matrix((path1,), (str, Path)): + res = safe_join(test_dir, path2) assert res == output -def helper_test_safe_join_catch(path1: str | Path, path2: str | Path) -> None: - """Helper method for testing that all str and Path representation of the two paths will throw an error when joined. +@pytest.mark.parametrize(("path1", "path2"), [("/tmp", ""), ("/tmp", "..")]) +def test_safe_join_catch(path1: str | Path, path2: str) -> None: + """Test that all str and Path representation of the two paths will throw an error when joined. :param path1: Directory part of the final path. :param path2: Filename part of the final path. 
""" # check error thrown given two inputs - for test_dir, test_file in string_path_matrix(path1, path2): + for (test_dir,) in type_matrix((path1,), (str, Path)): with pytest.raises(ValueError): - safe_join(test_dir, test_file) - - -# Tests - - -def test_safe_join_blank(): - helper_test_safe_join("", "foo.bar", Path.cwd().joinpath("foo.bar")) - - -def test_safe_join_local(): - helper_test_safe_join(".", "foo.bar", Path.cwd().joinpath("foo.bar")) - - -def test_safe_join_abs(): - helper_test_safe_join("/tmp", "foo.bar", Path("/tmp").joinpath("foo.bar")) - - -def test_safe_join_abs_trailing(): - helper_test_safe_join("/tmp/", "foo.bar", Path("/tmp").joinpath("foo.bar")) - - -def test_safe_join_abs_dot(): - helper_test_safe_join("/tmp", ".foo.bar", Path("/tmp").joinpath(".foo.bar")) - - -def test_safe_join_abs_slash(): - helper_test_safe_join("/tmp", "/foo.bar", Path("/tmp").joinpath("foo.bar")) - - -def test_safe_join_abs_double_slash(): - helper_test_safe_join("/tmp", "//foo.bar", Path("/tmp").joinpath("foo.bar")) - - -def test_safe_join_abs_dot_slash(): - helper_test_safe_join("/tmp", "./foo.bar", Path("/tmp").joinpath("foo.bar")) - - -def test_safe_join_abs_dot_slash_dot(): - helper_test_safe_join("/tmp", "./.foo.bar", Path("/tmp").joinpath(".foo.bar")) - - -def test_safe_join_abs_double_dot(): - helper_test_safe_join("/tmp", "..foo.bar", Path("/tmp").joinpath("..foo.bar")) - - -def test_safe_join_abs_double_dot_slash(): - helper_test_safe_join("/tmp", "../foo.bar", Path("/tmp").joinpath("foo.bar")) - - -def test_safe_join_abs_double_dot_slash_dot(): - helper_test_safe_join("/tmp", "../.foo.bar", Path("/tmp").joinpath(".foo.bar")) - - -def test_safe_join_fail_blank(): - helper_test_safe_join_catch("/tmp", "") - - -def test_safe_join_fail_dot(): - helper_test_safe_join_catch("/tmp", ".") - - -def test_safe_join_fail_double_dot(): - helper_test_safe_join_catch("/tmp", "..") - - -def test_safe_join_fail_slash(): - helper_test_safe_join_catch("/tmp", "/") - - -def 
test_safe_join_fail_double_slash(): - helper_test_safe_join_catch("/tmp", "//") + safe_join(test_dir, path2) diff --git a/lib/modelscan_api/tests/test_main.py b/lib/modelscan_api/tests/test_main.py index 82c013d9e..418fe5c54 100644 --- a/lib/modelscan_api/tests/test_main.py +++ b/lib/modelscan_api/tests/test_main.py @@ -9,8 +9,8 @@ from unittest.mock import Mock, patch import modelscan -from fastapi.testclient import TestClient import pytest +from fastapi.testclient import TestClient from bailo_modelscan_api.config import Settings from bailo_modelscan_api.dependencies import parse_path @@ -60,7 +60,7 @@ def test_scan_file(mock_scan: Mock, file_name: str, file_content: Any, file_mime @pytest.mark.parametrize( ("file_name", "file_content", "file_mime_type"), - [("..", EMPTY_CONTENTS, H5_MIME_TYPE), ("../", EMPTY_CONTENTS, H5_MIME_TYPE)], + [("..", EMPTY_CONTENTS, H5_MIME_TYPE)], ) def test_scan_file_escape_path_error(file_name: str, file_content: Any, file_mime_type: str): files = {"in_file": (file_name, file_content, file_mime_type)} From 18dc6197514f60ec553f86fb1e3939237fc79258 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Wed, 18 Dec 2024 12:50:04 +0000 Subject: [PATCH 4/7] BAI-1540 add modelscan integration tests --- .github/workflows/modelscan.yml | 6 + lib/modelscan_api/README.md | 11 +- lib/modelscan_api/pytest.ini | 1 + lib/modelscan_api/tests/test_integration.py | 165 ++++++++++++++++++ .../tests/test_integration/README.md | 15 ++ .../test_integration/generate_test_data.py | 67 +++++++ .../tests/test_integration/safe.pkl | Bin 0 -> 27 bytes .../tests/test_integration/unsafe.pkl | Bin 0 -> 71 bytes lib/modelscan_api/tests/test_main.py | 6 +- 9 files changed, 269 insertions(+), 2 deletions(-) create mode 100644 lib/modelscan_api/tests/test_integration.py create mode 100644 lib/modelscan_api/tests/test_integration/README.md create mode 100644 lib/modelscan_api/tests/test_integration/generate_test_data.py create mode 
100644 lib/modelscan_api/tests/test_integration/safe.pkl create mode 100644 lib/modelscan_api/tests/test_integration/unsafe.pkl diff --git a/.github/workflows/modelscan.yml b/.github/workflows/modelscan.yml index aab632dd0..85e4af68d 100644 --- a/.github/workflows/modelscan.yml +++ b/.github/workflows/modelscan.yml @@ -31,3 +31,9 @@ jobs: run: | cd lib/modelscan_api python -m pytest + + # Pytest -m integration + - name: Run integration testing + run: | + cd lib/modelscan_api + python -m pytest -m integration diff --git a/lib/modelscan_api/README.md b/lib/modelscan_api/README.md index 8e654040f..043039362 100644 --- a/lib/modelscan_api/README.md +++ b/lib/modelscan_api/README.md @@ -53,12 +53,21 @@ pre-commit install ### Tests -To run the tests: +To run the unit tests: ```bash pytest ``` +To run the integration tests (does not require any externally running services): + +```bash +pytest -m integration +``` + +Note that the integration tests use safe but technically "malicious" file(s) to check ModelScan's performance. Please +refer to [test_integration](./tests/test_integration/README.md) for details. + ### Running To run in [dev mode](https://fastapi.tiangolo.com/fastapi-cli/#fastapi-dev): diff --git a/lib/modelscan_api/pytest.ini b/lib/modelscan_api/pytest.ini index 1c0f70191..256a0f80b 100644 --- a/lib/modelscan_api/pytest.ini +++ b/lib/modelscan_api/pytest.ini @@ -3,3 +3,4 @@ filterwarnings = ignore::DeprecationWarning pythonpath = "." testpaths = "tests" junit_family = "xunit2" +markers = integration diff --git a/lib/modelscan_api/tests/test_integration.py b/lib/modelscan_api/tests/test_integration.py new file mode 100644 index 000000000..afd5be3e6 --- /dev/null +++ b/lib/modelscan_api/tests/test_integration.py @@ -0,0 +1,165 @@ +"""Integration tests for working with ModelScan. 
+""" + +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import Any +from unittest.mock import ANY + +from fastapi.testclient import TestClient +import modelscan +import pytest + +from bailo_modelscan_api.config import Settings +from bailo_modelscan_api.main import app, get_settings + +client = TestClient(app) + + +@lru_cache +def get_settings_override(): + return Settings(download_dir=".") + + +app.dependency_overrides[get_settings] = get_settings_override + + +H5_MIME_TYPE = "application/x-hdf5" +OCTET_STREAM_TYPE = "application/octet-stream" + + +@pytest.mark.integration +@pytest.mark.parametrize( + ("file_name", "file_content", "file_mime_type", "expected_response"), + [ + ( + "empty.txt", + rb"", + "text/plain", + { + "errors": [], + "issues": [], + "summary": { + "absolute_path": str(Path().cwd().absolute()), + "input_path": str(Path().cwd().absolute().joinpath("empty.txt")), + "modelscan_version": modelscan.__version__, + "scanned": {"total_scanned": 0}, + "skipped": { + "skipped_files": [ + { + "category": "SCAN_NOT_SUPPORTED", + "description": "Model Scan did not scan file", + "source": "empty.txt", + } + ], + "total_skipped": 1, + }, + "timestamp": ANY, + "total_issues": 0, + "total_issues_by_severity": {"CRITICAL": 0, "HIGH": 0, "LOW": 0, "MEDIUM": 0}, + }, + }, + ), + ( + "null.h5", + rb"", + H5_MIME_TYPE, + { + "summary": { + "total_issues_by_severity": {"LOW": 0, "MEDIUM": 0, "HIGH": 0, "CRITICAL": 0}, + "total_issues": 0, + "input_path": str(Path().cwd().absolute().joinpath("null.h5")), + "absolute_path": str(Path().cwd().absolute()), + "modelscan_version": modelscan.__version__, + "timestamp": ANY, + "scanned": {"total_scanned": 0}, + "skipped": { + "total_skipped": 1, + "skipped_files": [ + { + "category": "SCAN_NOT_SUPPORTED", + "description": "Model Scan did not scan file", + "source": "null.h5", + } + ], + }, + }, + "issues": [], + "errors": [ + { + "category": "MODEL_SCAN", + 
"description": "Unable to synchronously open file (file signature not found)", + "source": "null.h5", + } + ], + }, + ), + ( + "safe.pkl", + Path().cwd().joinpath("tests/test_integration/safe.pkl"), + OCTET_STREAM_TYPE, + { + "summary": { + "total_issues_by_severity": {"LOW": 0, "MEDIUM": 0, "HIGH": 0, "CRITICAL": 0}, + "total_issues": 0, + "input_path": str(Path().cwd().absolute().joinpath("safe.pkl")), + "absolute_path": str(Path().cwd().absolute()), + "modelscan_version": modelscan.__version__, + "timestamp": ANY, + "scanned": {"total_scanned": 1, "scanned_files": ["safe.pkl"]}, + "skipped": { + "total_skipped": 0, + "skipped_files": [], + }, + }, + "issues": [], + "errors": [], + }, + ), + ( + "unsafe.pkl", + Path().cwd().joinpath("unsafe.pkl"), + OCTET_STREAM_TYPE, + { + "summary": { + "total_issues_by_severity": {"LOW": 0, "MEDIUM": 0, "HIGH": 0, "CRITICAL": 1}, + "total_issues": 1, + "input_path": str(Path().cwd().absolute().joinpath("unsafe.pkl")), + "absolute_path": str(Path().cwd().absolute()), + "modelscan_version": modelscan.__version__, + "timestamp": ANY, + "scanned": {"total_scanned": 1, "scanned_files": ["unsafe.pkl"]}, + "skipped": { + "total_skipped": 0, + "skipped_files": [], + }, + }, + "issues": [ + { + "description": "Use of unsafe operator 'system' from module 'posix'", + "module": "posix", + "operator": "system", + "scanner": "modelscan.scanners.PickleUnsafeOpScan", + "severity": "CRITICAL", + "source": "unsafe.pkl", + }, + ], + "errors": [], + }, + ), + ], +) +def test_scan_file(file_name: str, file_content: Path | bytes, file_mime_type: str, expected_response: dict) -> None: + # allow passing in a Path to read the file's contents for specific data types + if isinstance(file_content, Path): + with open(file_content, "rb") as f: + file_content = f.read() + + files = {"in_file": (file_name, file_content, file_mime_type)} + + response = client.post("/scan/file", files=files) + + assert response.status_code == 200 + assert response.json() == 
expected_response diff --git a/lib/modelscan_api/tests/test_integration/README.md b/lib/modelscan_api/tests/test_integration/README.md new file mode 100644 index 000000000..fd43b5b61 --- /dev/null +++ b/lib/modelscan_api/tests/test_integration/README.md @@ -0,0 +1,15 @@ +# Integration test files + +Simple, minimal test files for use with `pytest -m integration`. These files may contain "malicious" code which is +specifically chosen to actually be safe to run while having the footprint of dangerous code. + +Test files can be regenerated by running `python3 generate_test_data.py`. + +## safe.pkl + +Simply stores `{"foo": "bar"}`. + +## unsafe.pkl + +A "malicious" pickle file which executes the system command `echo hello world` when loaded, as well as storing +`{"foo": "bar"}`. diff --git a/lib/modelscan_api/tests/test_integration/generate_test_data.py b/lib/modelscan_api/tests/test_integration/generate_test_data.py new file mode 100644 index 000000000..cf2123d49 --- /dev/null +++ b/lib/modelscan_api/tests/test_integration/generate_test_data.py @@ -0,0 +1,67 @@ +import os +import pickle +import struct + + +class _Pickler(pickle._Pickler): + """A minimal reproduction of https://github.com/protectai/modelscan/blob/main/notebooks/utils/pickle_codeinjection.py""" + + def __init__(self, file, protocol, inj_objs): + super().__init__(file, protocol) + self.inj_objs = inj_objs + + def dump(self, obj): + "Pickle data, inject object before or after" + if self.proto >= 2: # type: ignore + self.write(pickle.PROTO + struct.pack("<B", self.proto)) # type: ignore + if self.proto >= 4: # type: ignore + self.framer.start_framing() # type: ignore + for inj_obj in self.inj_objs: + self.save(inj_obj) # type: ignore + self.save(obj) # type: ignore + self.write(pickle.STOP) # type: ignore + self.framer.end_framing() # type: ignore + + +class _PickleInject: + def __init__(self, args, command=None): + self.command = command + self.args = args + + def __reduce__(self): + return self.command, (self.args,) + + +def _generate_unsafe_file(data, 
malicious_code, unsafe_model_path): + """Generate a malicious pickle file with real data as well as a malicious call to the system. + + :param data: normal data to store in the pickle file + :param malicious_code: malicious code to run on the host device + :param unsafe_model_path: where to write the pickle file to + """ + payload = _PickleInject(malicious_code, command=os.system) + with open(unsafe_model_path, "wb") as f: + mypickler = _Pickler(f, 4, [payload]) + mypickler.dump(data) + + +def safe_pickle(): + """Creates a simple, safe pickle file containing the data `{"foo": "bar"}`""" + with open("safe.pkl", "wb") as f: + pickle.dump({"foo": "bar"}, f) + + +def unsafe_pickle(): + """Creates a minimal malicious pickle file that would run the system command `echo hello world` as well as containing the data `{"foo": "bar"}`""" + _generate_unsafe_file( + {"foo", "bar"}, + """echo hello world +""", + "unsafe.pkl", + ) + + +if __name__ == "__main__": + # only generate the files if the file is explicitly run + safe_pickle() + unsafe_pickle() diff --git a/lib/modelscan_api/tests/test_integration/safe.pkl b/lib/modelscan_api/tests/test_integration/safe.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d99f54500b444e4950269cf7a6d7e8629c86091e GIT binary patch literal 27 ccmZo*nJT~l0kuR9iJ Date: Wed, 18 Dec 2024 12:52:30 +0000 Subject: [PATCH 5/7] BAI-1540 correct bad test filepath --- lib/modelscan_api/tests/test_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/modelscan_api/tests/test_integration.py b/lib/modelscan_api/tests/test_integration.py index afd5be3e6..58e95d0c4 100644 --- a/lib/modelscan_api/tests/test_integration.py +++ b/lib/modelscan_api/tests/test_integration.py @@ -120,7 +120,7 @@ def get_settings_override(): ), ( "unsafe.pkl", - Path().cwd().joinpath("unsafe.pkl"), + Path().cwd().joinpath("tests/test_integration/unsafe.pkl"), OCTET_STREAM_TYPE, { "summary": { From 
0729767b79d5d0d6c2f2f5837db403782c587683 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:20:42 +0000 Subject: [PATCH 6/7] BAI-1540 expand modelscan api /scan/file docs examples --- lib/modelscan_api/bailo_modelscan_api/main.py | 119 +++++++++++++++--- 1 file changed, 99 insertions(+), 20 deletions(-) diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 6bf107dca..ae4e3fec4 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -87,28 +87,107 @@ async def info(settings: Annotated[Settings, Depends(get_settings)]) -> ApiInfor "description": "modelscan returned results", "content": { "application/json": { - "example": { - "summary": { - "total_issues_by_severity": {"LOW": 0, "MEDIUM": 1, "HIGH": 0, "CRITICAL": 0}, - "total_issues": 1, - "input_path": "/foo/bar/unsafe_model.h5", - "absolute_path": "/foo/bar", - "modelscan_version": "0.8.1", - "timestamp": "2024-11-19T12:00:00.000000", - "scanned": {"total_scanned": 1, "scanned_files": ["unsafe_model.h5"]}, - "skipped": {"total_skipped": 0, "skipped_files": []}, + "examples": { + "Normal": { + "value": { + "summary": { + "total_issues_by_severity": {"LOW": 0, "MEDIUM": 0, "HIGH": 0, "CRITICAL": 0}, + "total_issues": 0, + "input_path": "/foo/bar/safe_model.pkl", + "absolute_path": "/foo/bar", + "modelscan_version": "0.8.1", + "timestamp": "2024-11-19T12:00:00.000000", + "scanned": {"total_scanned": 1, "scanned_files": ["safe_model.pkl"]}, + "skipped": { + "total_skipped": 0, + "skipped_files": [], + }, + }, + "issues": [], + "errors": [], + } + }, + "Issue": { + "value": { + "summary": { + "total_issues_by_severity": {"LOW": 0, "MEDIUM": 1, "HIGH": 0, "CRITICAL": 0}, + "total_issues": 1, + "input_path": "/foo/bar/unsafe_model.h5", + "absolute_path": "/foo/bar", + "modelscan_version": "0.8.1", + "timestamp": "2024-11-19T12:00:00.000000", + 
"scanned": {"total_scanned": 1, "scanned_files": ["unsafe_model.h5"]}, + "skipped": {"total_skipped": 0, "skipped_files": []}, + }, + "issues": [ + { + "description": "Use of unsafe operator 'Lambda' from module 'Keras'", + "operator": "Lambda", + "module": "Keras", + "source": "unsafe_model.h5", + "scanner": "modelscan.scanners.H5LambdaDetectScan", + "severity": "MEDIUM", + } + ], + "errors": [], + } }, - "issues": [ - { - "description": "Use of unsafe operator 'Lambda' from module 'Keras'", - "operator": "Lambda", - "module": "Keras", - "source": "unsafe_model.h5", - "scanner": "modelscan.scanners.H5LambdaDetectScan", - "severity": "MEDIUM", + "Skipped": { + "value": { + "errors": [], + "issues": [], + "summary": { + "input_path": "/foo/bar/empty.txt", + "absolute_path": "/foo/bar", + "modelscan_version": "0.8.1", + "scanned": {"total_scanned": 0}, + "skipped": { + "skipped_files": [ + { + "category": "SCAN_NOT_SUPPORTED", + "description": "Model Scan did not scan file", + "source": "empty.txt", + } + ], + "total_skipped": 1, + }, + "timestamp": "2024-11-19T12:00:00.000000", + "total_issues": 0, + "total_issues_by_severity": {"CRITICAL": 0, "HIGH": 0, "LOW": 0, "MEDIUM": 0}, + }, } - ], - "errors": [], + }, + "Error": { + "value": { + "summary": { + "total_issues_by_severity": {"LOW": 0, "MEDIUM": 0, "HIGH": 0, "CRITICAL": 0}, + "total_issues": 0, + "input_path": "/foo/bar/null.h5", + "absolute_path": "/foo/bar", + "modelscan_version": "0.8.1", + "timestamp": "2024-11-19T12:00:00.000000", + "scanned": {"total_scanned": 0}, + "skipped": { + "total_skipped": 1, + "skipped_files": [ + { + "category": "SCAN_NOT_SUPPORTED", + "description": "Model Scan did not scan file", + "source": "null.h5", + } + ], + }, + }, + "issues": [], + "errors": [ + { + "category": "MODEL_SCAN", + "description": "Unable to synchronously open file (file signature not found)", + "source": "null.h5", + } + ], + } + }, } } }, From 26269d8e07341e8407f9457292bcca73bfa87648 Mon Sep 17 00:00:00 
2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:53:19 +0000 Subject: [PATCH 7/7] BAI-1540 minor linting --- lib/modelscan_api/tests/test_integration.py | 3 +-- lib/modelscan_api/tests/test_integration/generate_test_data.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/modelscan_api/tests/test_integration.py b/lib/modelscan_api/tests/test_integration.py index 58e95d0c4..d8e37184a 100644 --- a/lib/modelscan_api/tests/test_integration.py +++ b/lib/modelscan_api/tests/test_integration.py @@ -5,12 +5,11 @@ from functools import lru_cache from pathlib import Path -from typing import Any from unittest.mock import ANY -from fastapi.testclient import TestClient import modelscan import pytest +from fastapi.testclient import TestClient from bailo_modelscan_api.config import Settings from bailo_modelscan_api.main import app, get_settings diff --git a/lib/modelscan_api/tests/test_integration/generate_test_data.py b/lib/modelscan_api/tests/test_integration/generate_test_data.py index cf2123d49..69459656d 100644 --- a/lib/modelscan_api/tests/test_integration/generate_test_data.py +++ b/lib/modelscan_api/tests/test_integration/generate_test_data.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pickle import struct