Skip to content

Commit

Permalink
fix(tests): add missing kagglehub.model_download mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
KeijiBranshi committed Nov 19, 2024
1 parent f91607b commit 87d34bb
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tests/torchtune/_cli/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,20 @@ def test_download_from_kaggle_warn_when_ignore_patterns_provided(

# tests when --kaggle-username and --kaggle-api-key are provided as CLI args
def test_download_from_kaggle_when_credentials_provided(
self, capsys, monkeypatch, mocker
self, capsys, monkeypatch, mocker, tmpdir
):
expected_username = "username"
expected_api_key = "api_key"
model = "metaresearch/llama-3.2/pytorch/1b"
testargs = (
f"tune download {model} "
f"--source kaggle --kaggle-username {expected_username} "
f"--source kaggle "
f"--kaggle-username {expected_username} "
f"--kaggle-api-key {expected_api_key}"
).split()
monkeypatch.setattr(sys, "argv", testargs)
# mock out kagglehub.model_download to get around key storage
mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)
set_kaggle_credentials_spy = mocker.patch(
"torchtune._cli.download.set_kaggle_credentials"
)
Expand All @@ -204,13 +207,15 @@ def test_download_from_kaggle_when_credentials_provided(
# passes partial credentials with just --kaggle-username (expect fallback to environment variables)
@mock.patch.dict(os.environ, {"KAGGLE_KEY": "env_api_key"})
def test_download_from_kaggle_when_partial_credentials_provided(
self, capsys, monkeypatch, mocker
self, capsys, monkeypatch, mocker, tmpdir
):
expected_username = "username"
expected_api_key = "env_api_key"
model = "metaresearch/llama-3.2/pytorch/1b"
testargs = f"tune download {model} --source kaggle --kaggle-username {expected_username}".split()
monkeypatch.setattr(sys, "argv", testargs)
# mock out kagglehub.model_download to get around key storage
mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)
set_kaggle_credentials_spy = mocker.patch(
"torchtune._cli.download.set_kaggle_credentials"
)
Expand All @@ -231,11 +236,13 @@ def test_download_from_kaggle_when_partial_credentials_provided(
)

def test_download_from_kaggle_when_set_kaggle_credentials_throws(
self, monkeypatch, mocker
self, monkeypatch, mocker, tmpdir
):
model = "metaresearch/llama-3.2/pytorch/1b"
testargs = f"tune download {model} --source kaggle --kaggle-username u --kaggle-api-key k".split()
monkeypatch.setattr(sys, "argv", testargs)
# mock out kagglehub.model_download to get around key storage
mocker.patch("torchtune._cli.download.model_download", return_value=tmpdir)
mocker.patch(
"torchtune._cli.download.set_kaggle_credentials",
side_effect=Exception("some error"),
Expand Down

0 comments on commit 87d34bb

Please sign in to comment.