Skip to content

Commit

Permalink
SNOW-872425: fix adding package when the specifier contains name with…
Browse files Browse the repository at this point in the history
… underscore and version (#1098)
  • Loading branch information
sfc-gh-aling authored Oct 20, 2023
1 parent 4f6d745 commit 9dcaa92
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
- The `format` argument changed from optional to required.
- The returned result changed from a date object to a date-formatted string.

### Bug Fixes

- Fixed a bug that `session.add_packages` can not handle requirement specifier that contains project name with underscore and version.

## 1.9.0 (2023-10-13)

### New Features
Expand Down
15 changes: 12 additions & 3 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import logging
import os
import re
import sys
import tempfile
from array import array
Expand Down Expand Up @@ -1034,10 +1035,18 @@ def _resolve_packages(
# get the standard package name if there is no underscore
# underscores are discouraged in package names, but are still used in Anaconda channel
# pkg_resources.Requirement.parse will convert all underscores to dashes
# the regexp is to deal with case that "_" is in the package requirement as well as version restrictions
# we only extract the valid package name from the string by following:
# https://packaging.python.org/en/latest/specifications/name-normalization/
# A valid name consists only of ASCII letters and numbers, period, underscore and hyphen.
# It must start and end with a letter or number.
# however, we don't validate the pkg name as this is done by pkg_resources.Requirement.parse
# find the index of the first char which is not an valid package name character
package_name = package_req.key
if not use_local_version and "_" in package:
reg_match = re.search(r"[^0-9a-zA-Z\-_.]", package)
package_name = package[: reg_match.start()] if reg_match else package

package_name = (
package if not use_local_version and "_" in package else package_req.key
)
package_dict[package] = (package_name, use_local_version, package_req)

package_table = "information_schema.packages"
Expand Down
21 changes: 21 additions & 0 deletions tests/integ/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,27 @@ def check_if_package_installed() -> bool:
Utils.check_answer(session.sql(f"select {udf_name}()").collect(), [Row(True)])


@pytest.mark.udf
def test_add_packages_with_underscore_and_versions(session):
session.add_packages(["huggingface_hub==0.15.1"])
assert session.get_packages() == {
"huggingface_hub": "huggingface_hub==0.15.1",
}
session.clear_packages()

session.add_packages(["huggingface_hub>0.14.1"])
assert session.get_packages() == {
"huggingface_hub": "huggingface_hub>0.14.1",
}
session.clear_packages()

session.add_packages(["huggingface_hub<=0.15.1"])
assert session.get_packages() == {
"huggingface_hub": "huggingface_hub<=0.15.1",
}
session.clear_packages()


@pytest.mark.skipif(
IS_IN_STORED_PROC, reason="Need certain version of datautil/pandas/numpy"
)
Expand Down

0 comments on commit 9dcaa92

Please sign in to comment.