Commit

SNOW-919476: Do not ignore stage location for using imports when is_permanent is False (#1053)

* SNOW-919476: Do not ignore stage location for using imports when is_permanent is False

* SNOW-919476: Test updates; changelog updates

* SNOW-919476: fix test

* separate stage location between upload and import stage

* fix _resolve_imports

* add test

* fix lint

* fix unit test for new file

* refactor and double unwrap tests

* refactor variable names

* remove redundant comment
sfc-gh-aalam authored Sep 28, 2023
1 parent c6cfc66 commit 00ca832
Showing 12 changed files with 198 additions and 100 deletions.
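For orientation before the per-file diffs: the fix lets a temporary (`is_permanent=False`) stored procedure, UDF, UDTF, or UDAF import files that already sit on a permanent stage, instead of silently ignoring the stage location. A minimal sketch of the now-supported call, modeled on the new integration test further down (the stage and file names here are hypothetical, and a connected Snowpark session is assumed):

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sproc
from snowflake.snowpark.types import IntegerType

# Assumes `session` is a live Session and that @my_perm_stage/helper.py
# already exists on a permanent stage and defines mod5(session, x).
# session = Session.builder.configs(connection_parameters).create()

def use_helper(session_, x):
    from helper import mod5  # resolvable thanks to the stage import below

    return mod5(session_, x)

temp_sproc = sproc(
    use_helper,
    return_type=IntegerType(),
    input_types=[IntegerType()],
    imports=["@my_perm_stage/helper.py"],  # read from the permanent stage
    is_permanent=False,  # temporary: nothing is uploaded to @my_perm_stage
    session=session,
)
```

Before this commit, `check_register_args` warned that `stage_location` would be ignored for temporary objects; now the location is honored for import resolution, while uploads for temporary objects still go to the session stage.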
CHANGELOG.md (4 changes: 4 additions & 0 deletions)

@@ -6,6 +6,10 @@
 
 - Added back the dependency of `typing-extensions`.
 
+### Bug Fixes
+
+- Fixed a bug where imports from permanent stage locations were ignored for temporary stored procedures, UDTFs, UDFs, and UDAFs.
+
 ## 1.8.0 (2023-09-14)
 
 ### New Features
src/snowflake/snowpark/_internal/udf_utils.py (30 changes: 17 additions & 13 deletions)

@@ -349,11 +349,6 @@ def check_register_args(
         raise ValueError(
             f"stage_location must be specified for permanent {get_error_message_abbr(object_type)}"
         )
-    else:
-        if stage_location:
-            logger.warn(
-                "is_permanent is False therefore stage_location will be ignored"
-            )
 
     if parallel < 1 or parallel > 99:
         raise ValueError(
@@ -815,12 +810,16 @@ def resolve_imports_and_packages(
     skip_upload_on_content_match: bool = False,
     is_permanent: bool = False,
 ) -> Tuple[str, str, str, str, str, bool]:
-    upload_stage = (
+    import_only_stage = (
         unwrap_stage_location_single_quote(stage_location)
-        if stage_location and is_permanent
+        if stage_location
         else session.get_session_stage()
     )
 
+    upload_and_import_stage = (
+        import_only_stage if is_permanent else session.get_session_stage()
+    )
+
     # resolve packages
     resolved_packages = (
         session._resolve_packages(packages, include_pandas=is_pandas_udf)
@@ -850,11 +849,16 @@
             )
             udf_level_imports[resolved_import_tuple[0]] = resolved_import_tuple[1:]
         all_urls = session._resolve_imports(
-            upload_stage, udf_level_imports, statement_params=statement_params
+            import_only_stage,
+            upload_and_import_stage,
+            udf_level_imports,
+            statement_params=statement_params,
         )
     elif imports is None:
         all_urls = session._resolve_imports(
-            upload_stage, statement_params=statement_params
+            import_only_stage,
+            upload_and_import_stage,
+            statement_params=statement_params,
         )
     else:
         all_urls = []
@@ -883,7 +887,7 @@
         if len(code) > _MAX_INLINE_CLOSURE_SIZE_BYTES:
             dest_prefix = get_udf_upload_prefix(udf_name)
             upload_file_stage_location = normalize_remote_file_or_dir(
-                f"{upload_stage}/{dest_prefix}/{udf_file_name}"
+                f"{upload_and_import_stage}/{dest_prefix}/{udf_file_name}"
             )
             udf_file_name_base = os.path.splitext(udf_file_name)[0]
             with io.BytesIO() as input_stream:
@@ -893,7 +897,7 @@
                     zf.writestr(f"{udf_file_name_base}.py", code)
                 session._conn.upload_stream(
                     input_stream=input_stream,
-                    stage_location=upload_stage,
+                    stage_location=upload_and_import_stage,
                     dest_filename=udf_file_name,
                     dest_prefix=dest_prefix,
                     parallel=parallel,
Expand Down Expand Up @@ -924,11 +928,11 @@ def resolve_imports_and_packages(
all_urls.append(func[0])
else:
upload_file_stage_location = normalize_remote_file_or_dir(
f"{upload_stage}/{dest_prefix}/{udf_file_name}"
f"{upload_and_import_stage}/{dest_prefix}/{udf_file_name}"
)
session._conn.upload_file(
path=func[0],
stage_location=upload_stage,
stage_location=upload_and_import_stage,
dest_prefix=dest_prefix,
parallel=parallel,
compress_data=False,
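The heart of the change above is the split of the old `upload_stage` into two values: `import_only_stage` honors a caller-supplied `stage_location` regardless of `is_permanent`, while `upload_and_import_stage` only targets that location for permanent objects. A self-contained sketch of just this selection logic (a toy model with the real helpers stubbed out; not the library's API):

```python
def pick_stages(stage_location, is_permanent, session_stage, unwrap=lambda s: s):
    """Toy model of the stage selection in resolve_imports_and_packages."""
    # Imports may be resolved from stage_location even for temporary objects.
    import_only_stage = unwrap(stage_location) if stage_location else session_stage
    # New files are uploaded to stage_location only for permanent objects.
    upload_and_import_stage = import_only_stage if is_permanent else session_stage
    return import_only_stage, upload_and_import_stage

# A temporary object with a permanent stage_location: imports are read from
# the permanent stage, but uploads still land on the session stage.
assert pick_stages("@perm_stage", False, "@session_stage") == (
    "@perm_stage",
    "@session_stage",
)
# A permanent object uses the same stage for both.
assert pick_stages("@perm_stage", True, "@session_stage") == (
    "@perm_stage",
    "@perm_stage",
)
```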
src/snowflake/snowpark/session.py (36 changes: 25 additions & 11 deletions)

@@ -685,7 +685,8 @@ def _resolve_import_path(
 
     def _resolve_imports(
         self,
-        stage_location: str,
+        import_only_stage: str,
+        upload_and_import_stage: str,
         udf_level_import_paths: Optional[
             Dict[str, Tuple[Optional[str], Optional[str]]]
         ] = None,
@@ -695,9 +696,15 @@
         """Resolve the imports and upload local files (if any) to the stage."""
         resolved_stage_files = []
         stage_file_list = self._list_files_in_stage(
-            stage_location, statement_params=statement_params
+            import_only_stage, statement_params=statement_params
         )
 
-        normalized_stage_location = unwrap_stage_location_single_quote(stage_location)
+        normalized_import_only_location = unwrap_stage_location_single_quote(
+            import_only_stage
+        )
+        normalized_upload_and_import_location = unwrap_stage_location_single_quote(
+            upload_and_import_stage
+        )
 
         import_paths = udf_level_import_paths or self._import_paths
         for path, (prefix, leading_path) in import_paths.items():
@@ -713,7 +720,12 @@
                 filename_with_prefix = f"{prefix}/{filename}"
                 if filename_with_prefix in stage_file_list:
                     _logger.debug(
-                        f"{filename} exists on {normalized_stage_location}, skipped"
+                        f"{filename} exists on {normalized_import_only_location}, skipped"
                     )
+                    resolved_stage_files.append(
+                        normalize_remote_file_or_dir(
+                            f"{normalized_import_only_location}/{filename_with_prefix}"
+                        )
+                    )
                 else:
                     # local directory or .py file
@@ -723,7 +735,7 @@
                     ) as input_stream:
                         self._conn.upload_stream(
                             input_stream=input_stream,
-                            stage_location=normalized_stage_location,
+                            stage_location=normalized_upload_and_import_location,
                             dest_filename=filename,
                             dest_prefix=prefix,
                             source_compression="DEFLATE",
@@ -736,17 +748,17 @@
                     else:
                         self._conn.upload_file(
                             path=path,
-                            stage_location=normalized_stage_location,
+                            stage_location=normalized_upload_and_import_location,
                             dest_prefix=prefix,
                             compress_data=False,
                             overwrite=True,
                             skip_upload_on_content_match=True,
                         )
-                resolved_stage_files.append(
-                    normalize_remote_file_or_dir(
-                        f"{normalized_stage_location}/{filename_with_prefix}"
-                    )
-                )
+                    resolved_stage_files.append(
+                        normalize_remote_file_or_dir(
+                            f"{normalized_upload_and_import_location}/{filename_with_prefix}"
+                        )
+                    )
 
         return resolved_stage_files
 
@@ -1177,7 +1189,9 @@ def get_req_identifiers_list(
         if name in result_dict:
             if version is not None:
                 added_package_has_version = "==" in result_dict[name]
-                if added_package_has_version and result_dict[name] != str(package):
+                if added_package_has_version and result_dict[name] != str(
+                    package
+                ):
                     raise ValueError(
                         f"Cannot add dependency package '{name}=={version}' "
                         f"because {result_dict[name]} is already added."
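The net effect of the `_resolve_imports` changes, sketched as a standalone function (hypothetical names and signature, not the real method): a file already present on the import-only stage is referenced in place with no upload, while everything else is uploaded to, and imported from, the upload-and-import stage.

```python
def resolve_one_import(
    filename_with_prefix, stage_file_list, import_only_loc, upload_and_import_loc, upload
):
    """Toy model of the per-file branch in Session._resolve_imports."""
    if filename_with_prefix in stage_file_list:
        # Already on the import-only stage (e.g. a permanent stage): skip upload.
        return f"{import_only_loc}/{filename_with_prefix}"
    # Otherwise upload to the upload-and-import stage and import the copy.
    upload(filename_with_prefix, upload_and_import_loc)
    return f"{upload_and_import_loc}/{filename_with_prefix}"

# With a no-op uploader: mod5.py already lives on @perm, mod3.py gets uploaded.
noop = lambda *_: None
assert resolve_one_import("app/mod5.py", {"app/mod5.py"}, "@perm", "@sess", noop) == "@perm/app/mod5.py"
assert resolve_one_import("app/mod3.py", set(), "@perm", "@sess", noop) == "@sess/app/mod3.py"
```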
tests/integ/test_stored_procedure.py (101 changes: 85 additions & 16 deletions)

@@ -592,28 +592,24 @@ def test_permanent_sp(session, db_parameters):
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_sp_negative(session, db_parameters, caplog):
+def test_permanent_sp_negative(session, db_parameters):
     stage_name = Utils.random_stage_name()
     sp_name = Utils.random_name_for_temp_object(TempObjectType.PROCEDURE)
     with Session.builder.configs(db_parameters).create() as new_session:
         new_session.sql_simplifier_enabled = session.sql_simplifier_enabled
         new_session.add_packages("snowflake-snowpark-python")
         try:
-            with caplog.at_level(logging.WARN):
-                sproc(
-                    lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[
-                        0
-                    ][0],
-                    return_type=IntegerType(),
-                    input_types=[IntegerType(), IntegerType()],
-                    name=sp_name,
-                    is_permanent=False,
-                    stage_location=stage_name,
-                    session=new_session,
-                )
-            assert (
-                "is_permanent is False therefore stage_location will be ignored"
-                in caplog.text
+            Utils.create_stage(session, stage_name, is_temporary=False)
+            sproc(
+                lambda session_, x, y: session_.sql(f"SELECT {x} + {y}").collect()[0][
+                    0
+                ],
+                return_type=IntegerType(),
+                input_types=[IntegerType(), IntegerType()],
+                name=sp_name,
+                is_permanent=False,
+                stage_location=stage_name,
+                session=new_session,
             )
 
             with pytest.raises(
Expand All @@ -623,6 +619,7 @@ def test_permanent_sp_negative(session, db_parameters, caplog):
assert new_session.call(sp_name, 8, 9) == 17
finally:
new_session._run_query(f"drop function if exists {sp_name}(int, int)")
Utils.drop_stage(session, stage_name)


@pytest.mark.skipif(not is_pandas_available, reason="Requires pandas")
@@ -934,6 +931,78 @@ def hello_sp(session: Session, name: str, age: int) -> DataFrame:
         Utils.drop_procedure(session, f"{temp_sp_name}(string, bigint)")
 
 
+def test_temp_sp_with_import_and_upload_stage(session, resources_path):
+    """We want temporary stored procs to be able to do the following:
+    - Do not upload packages to permanent stage locations
+    - Can import packages from permanent stage locations
+    - Can upload packages to temp stages for custom usage
+    - Import from permanent stage location and upload to temp stage + import from temp stage should
+    work
+    """
+    stage_name = Utils.random_stage_name()
+    Utils.create_stage(session, stage_name, is_temporary=False)
+    test_files = TestFiles(resources_path)
+    # upload test_sp_dir.test_sp_file (mod5) to permanent stage and use mod3
+    # file for temporary stage import correctness
+    session._conn.upload_file(
+        path=test_files.test_sp_py_file,
+        stage_location=unwrap_stage_location_single_quote(stage_name),
+        compress_data=False,
+        overwrite=True,
+        skip_upload_on_content_match=True,
+    )
+    try:
+        # Can import packages from permanent stage locations
+        def mod5_(session_, x):
+            from test_sp_file import mod5
+
+            return mod5(session_, x)
+
+        mod5_sproc = sproc(
+            mod5_,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            imports=[f"@{stage_name}/test_sp_file.py"],
+            is_permanent=False,
+        )
+        assert mod5_sproc(5) == 0
+
+        # Can upload packages to temp stages for custom usage
+        def mod3_(session_, x):
+            from test_sp_mod3_file import mod3
+
+            return mod3(session_, x)
+
+        mod3_sproc = sproc(
+            mod3_,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            imports=[test_files.test_sp_mod3_py_file],
+        )
+
+        assert mod3_sproc(3) == 0
+
+        # Import from permanent stage location and upload to temp stage + import
+        # from temp stage should work
+        def mod3_of_mod5_(session_, x):
+            from test_sp_file import mod5
+            from test_sp_mod3_file import mod3
+
+            return mod3(session_, mod5(session_, x))
+
+        mod3_of_mod5_sproc = sproc(
+            mod3_of_mod5_,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            imports=[f"@{stage_name}/test_sp_file.py", test_files.test_sp_mod3_py_file],
+        )
+
+        assert mod3_of_mod5_sproc(4) == 1
+    finally:
+        Utils.drop_stage(session, stage_name)
+        pass
+
+
 def test_add_import_negative(session, resources_path):
     test_files = TestFiles(resources_path)
 
tests/integ/test_udaf.py (26 changes: 11 additions & 15 deletions)

@@ -4,7 +4,6 @@
 
 import datetime
 import decimal
-import logging
 from typing import Any, Dict, List
 
 import pytest
@@ -414,7 +413,7 @@ def test_register_udaf_from_file_with_type_hints(session, resources_path):
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
-def test_permanent_udaf_negative(session, db_parameters, caplog):
+def test_permanent_udaf_negative(session, db_parameters):
     stage_name = Utils.random_stage_name()
     udaf_name = Utils.random_name_for_temp_object(TempObjectType.AGGREGATE_FUNCTION)
     df1 = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
@@ -442,19 +441,15 @@ def finish(self):
         "a", "b"
     )
     try:
-        with caplog.at_level(logging.WARN):
-            sum_udaf = udaf(
-                PythonSumUDAFHandler,
-                return_type=IntegerType(),
-                input_types=[IntegerType()],
-                name=udaf_name,
-                is_permanent=False,
-                stage_location=stage_name,
-                session=new_session,
-            )
-        assert (
-            "is_permanent is False therefore stage_location will be ignored"
-            in caplog.text
+        Utils.create_stage(session, stage_name, is_temporary=False)
+        sum_udaf = udaf(
+            PythonSumUDAFHandler,
+            return_type=IntegerType(),
+            input_types=[IntegerType()],
+            name=udaf_name,
+            is_permanent=False,
+            stage_location=stage_name,
+            session=new_session,
         )
 
         with pytest.raises(
@@ -465,6 +460,7 @@ def finish(self):
             Utils.check_answer(df2.agg(sum_udaf("a")), [Row(6)])
     finally:
         new_session._run_query(f"drop function if exists {udaf_name}(int)")
+        Utils.drop_stage(session, stage_name)
 
 
 def test_udaf_negative(session):
(Diffs for the remaining changed files not shown.)