From 0362c46b96f33692bdabeda574d99ef66b61c4f6 Mon Sep 17 00:00:00 2001 From: Jamison Rose Date: Mon, 16 Dec 2024 15:26:22 -0800 Subject: [PATCH] SNOW-1803811: Allow mixed-case field names for struct type columns (#2640) --- CHANGELOG.md | 1 + .../snowpark/_internal/type_utils.py | 6 +- src/snowflake/snowpark/column.py | 6 + src/snowflake/snowpark/context.py | 4 + src/snowflake/snowpark/types.py | 81 ++++++-- tests/integ/scala/test_datatype_suite.py | 177 ++++++++++-------- tests/integ/test_stored_procedure.py | 4 +- 7 files changed, 179 insertions(+), 100 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3945c9e955d..51573adafa7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ - Added support for applying Snowpark Python function `snowflake_cortex_sentiment`. - Added support for `DataFrame.map`. - Added support for `DataFrame.from_dict` and `DataFrame.from_records`. +- Added support for mixed case field names in struct type columns. #### Improvements - Improve performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping numpy functions to snowpark functions if possible. diff --git a/src/snowflake/snowpark/_internal/type_utils.py b/src/snowflake/snowpark/_internal/type_utils.py index 55fe27c9f8f..3d1095132ab 100644 --- a/src/snowflake/snowpark/_internal/type_utils.py +++ b/src/snowflake/snowpark/_internal/type_utils.py @@ -30,6 +30,7 @@ get_origin, ) +import snowflake.snowpark.context as context import snowflake.snowpark.types # type: ignore from snowflake.connector.constants import FIELD_ID_TO_NAME from snowflake.connector.cursor import ResultMetadata @@ -157,9 +158,12 @@ def convert_metadata_to_sp_type( return StructType( [ StructField( - quote_name(field.name, keep_case=True), + field.name + if context._should_use_structured_type_semantics + else quote_name(field.name, keep_case=True), convert_metadata_to_sp_type(field, max_string_size), nullable=field.is_nullable, + _is_column=False, ) for field in metadata.fields ], diff --git a/src/snowflake/snowpark/column.py b/src/snowflake/snowpark/column.py index 3f32c981203..7e9588d9d5d 100644 --- a/src/snowflake/snowpark/column.py +++ b/src/snowflake/snowpark/column.py @@ -91,6 +91,9 @@ StringType, TimestampTimeZone, TimestampType, + ArrayType, + MapType, + StructType, ) from snowflake.snowpark.window import Window, WindowSpec @@ -917,6 +920,9 @@ def _cast( if isinstance(to, str): to = type_string_to_type_object(to) + if isinstance(to, (ArrayType, MapType, StructType)): + to = to._as_nested() + if self._ast is None: _emit_ast = False diff --git a/src/snowflake/snowpark/context.py b/src/snowflake/snowpark/context.py index c8f6888c5bd..8bc86f928a1 100644 --- a/src/snowflake/snowpark/context.py +++ b/src/snowflake/snowpark/context.py @@ -21,6 +21,10 @@ _should_continue_registration: Optional[Callable[..., bool]] = None +# Global flag that determines if structured type semantics should be used +_should_use_structured_type_semantics = False + + def get_active_session() -> "snowflake.snowpark.Session": """Returns the current active Snowpark session. diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 06bcc8969b5..333fc580f60 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -16,6 +16,7 @@ # Use correct version from here: from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name +import snowflake.snowpark.context as context # TODO: connector installed_pandas is broken. If pyarrow is not installed, but pandas is this function returns the wrong answer. # The core issue is that in the connector detection of both pandas/arrow are mixed, which is wrong. @@ -341,6 +342,14 @@ def __init__( def __repr__(self) -> str: return f"ArrayType({repr(self.element_type) if self.element_type else ''})" + def _as_nested(self) -> "ArrayType": + if not context._should_use_structured_type_semantics: + return self + element_type = self.element_type + if isinstance(element_type, (ArrayType, MapType, StructType)): + element_type = element_type._as_nested() + return ArrayType(element_type, self.structured) + def is_primitive(self): return False @@ -391,6 +400,14 @@ def __repr__(self) -> str: def is_primitive(self): return False + def _as_nested(self) -> "MapType": + if not context._should_use_structured_type_semantics: + return self + value_type = self.value_type + if isinstance(value_type, (ArrayType, MapType, StructType)): + value_type = value_type._as_nested() + return MapType(self.key_type, value_type, self.structured) + @classmethod def from_json(cls, json_dict: Dict[str, Any]) -> "MapType": return MapType( @@ -552,29 +569,46 @@ def __init__( column_identifier: Union[ColumnIdentifier, str], datatype: DataType, nullable: bool = True, + _is_column: bool = True, ) -> None: - self.column_identifier = ( - ColumnIdentifier(column_identifier) - if isinstance(column_identifier, str) - else column_identifier - ) + self.name = column_identifier + self._is_column = _is_column self.datatype = datatype self.nullable = nullable @property def name(self) -> str: - """Returns the column name.""" - return self.column_identifier.name + if self._is_column or not context._should_use_structured_type_semantics: + return self.column_identifier.name + else: + return self._name @name.setter - def name(self, n: str) -> None: - self.column_identifier = ColumnIdentifier(n) + def name(self, n: Union[ColumnIdentifier, str]) -> None: + if isinstance(n, ColumnIdentifier): + self._name = n.name + self.column_identifier = n + else: + self._name = n + self.column_identifier = ColumnIdentifier(n) + + def _as_nested(self) -> "StructField": + if not context._should_use_structured_type_semantics: + return self + datatype = self.datatype + if isinstance(datatype, (ArrayType, MapType, StructType)): + datatype = datatype._as_nested() + # Nested StructFields do not follow column naming conventions + return StructField(self._name, datatype, self.nullable, _is_column=False) def __repr__(self) -> str: return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})" def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + return isinstance(other, self.__class__) and ( + (self.name, self._is_column, self.datatype, self.nullable) + == (other.name, other._is_column, other.datatype, other.nullable) + ) @classmethod def from_json(cls, json_dict: Dict[str, Any]) -> "StructField": @@ -620,9 +654,9 @@ def __init__( self, fields: Optional[List["StructField"]] = None, structured=False ) -> None: self.structured = structured - if fields is None: - fields = [] - self.fields = fields + self.fields = [] + for field in fields or []: + self.add(field) def add( self, @@ -630,20 +664,31 @@ def add( datatype: Optional[DataType] = None, nullable: Optional[bool] = True, ) -> "StructType": - if isinstance(field, StructField): - self.fields.append(field) - elif isinstance(field, (str, ColumnIdentifier)): + if isinstance(field, (str, ColumnIdentifier)): if datatype is None: raise ValueError( "When field argument is str or ColumnIdentifier, datatype must not be None." ) - self.fields.append(StructField(field, datatype, nullable)) - else: + field = StructField(field, datatype, nullable) + elif not isinstance(field, StructField): raise ValueError( f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'" ) + + # Nested data does not follow the same schema conventions as top level fields. + if isinstance(field.datatype, (ArrayType, MapType, StructType)): + field.datatype = field.datatype._as_nested() + + self.fields.append(field) return self + def _as_nested(self) -> "StructType": + if not context._should_use_structured_type_semantics: + return self + return StructType( + [field._as_nested() for field in self.fields], self.structured + ) + @classmethod def _from_attributes(cls, attributes: list) -> "StructType": return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes]) diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index 63c24b7e35c..a1bd1d48acd 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -9,6 +9,7 @@ import pytest +import snowflake.snowpark.context as context from snowflake.connector.options import installed_pandas from snowflake.snowpark import Row from snowflake.snowpark.exceptions import SnowparkSQLException @@ -57,24 +58,28 @@ # make sure dataframe creation is the same as _create_test_dataframe -_STRUCTURE_DATAFRAME_QUERY = """ +_STRUCTURED_DATAFRAME_QUERY = """ select object_construct('k1', 1) :: map(varchar, int) as map, - object_construct('A', 'foo', 'B', 0.05) :: object(A varchar, B float) as obj, + object_construct('A', 'foo', 'b', 0.05) :: object(A varchar, b float) as obj, [1.0, 3.1, 4.5] :: array(float) as arr """ -# make sure dataframe creation is the same as _STRUCTURE_DATAFRAME_QUERY -def _create_test_dataframe(s): +# make sure dataframe creation is the same as _STRUCTURED_DATAFRAME_QUERY +def _create_test_dataframe(s, structured_type_support): + nested_field_name = "b" if structured_type_support else "B" df = s.create_dataframe([1], schema=["a"]).select( object_construct(lit("k1"), lit(1)) .cast(MapType(StringType(), IntegerType(), structured=True)) .alias("map"), - object_construct(lit("A"), lit("foo"), lit("B"), lit(0.05)) + object_construct(lit("A"), lit("foo"), lit(nested_field_name), lit(0.05)) .cast( StructType( - [StructField("A", StringType()), StructField("B", DoubleType())], + [ + StructField("A", StringType()), + StructField(nested_field_name, DoubleType()), + ], structured=True, ) ) @@ -86,55 +91,6 @@ def _create_test_dataframe(s): return df -STRUCTURED_TYPES_EXAMPLES = { - True: ( - _STRUCTURE_DATAFRAME_QUERY, - [ - ("MAP", "map"), - ("OBJ", "struct"), - ("ARR", "array"), - ], - StructType( - [ - StructField( - "MAP", - MapType(StringType(16777216), LongType(), structured=True), - nullable=True, - ), - StructField( - "OBJ", - StructType( - [ - StructField("A", StringType(16777216), nullable=True), - StructField("B", DoubleType(), nullable=True), - ], - structured=True, - ), - nullable=True, - ), - StructField( - "ARR", ArrayType(DoubleType(), structured=True), nullable=True - ), - ] - ), - ), - False: ( - _STRUCTURE_DATAFRAME_QUERY, - [ - ("MAP", "map"), - ("OBJ", "map"), - ("ARR", "array"), - ], - StructType( - [ - StructField("MAP", MapType(StringType(), StringType()), nullable=True), - StructField("OBJ", MapType(StringType(), StringType()), nullable=True), - StructField("ARR", ArrayType(StringType()), nullable=True), - ] - ), - ), -} - ICEBERG_CONFIG = { "catalog": "SNOWFLAKE", "external_volume": "python_connector_iceberg_exvol", @@ -142,6 +98,61 @@ def _create_test_dataframe(s): } +def _create_example(structured_types_enabled): + if structured_types_enabled: + return ( + _STRUCTURED_DATAFRAME_QUERY, + [ + ("MAP", "map"), + ("OBJ", "struct"), + ("ARR", "array"), + ], + StructType( + [ + StructField( + "MAP", + MapType(StringType(16777216), LongType(), structured=True), + nullable=True, + ), + StructField( + "OBJ", + StructType( + [ + StructField("A", StringType(16777216), nullable=True), + StructField("b", DoubleType(), nullable=True), + ], + structured=True, + ), + nullable=True, + ), + StructField( + "ARR", ArrayType(DoubleType(), structured=True), nullable=True + ), + ] + ), + ) + else: + return ( + _STRUCTURED_DATAFRAME_QUERY, + [ + ("MAP", "map"), + ("OBJ", "map"), + ("ARR", "array"), + ], + StructType( + [ + StructField( + "MAP", MapType(StringType(), StringType()), nullable=True + ), + StructField( + "OBJ", MapType(StringType(), StringType()), nullable=True + ), + StructField("ARR", ArrayType(StringType()), nullable=True), + ] + ), + ) + + @pytest.fixture(scope="module") def structured_type_support(session, local_testing_mode): yield structured_types_supported(session, local_testing_mode) @@ -149,14 +160,17 @@ def structured_type_support(session, local_testing_mode): @pytest.fixture(scope="module") def examples(structured_type_support): - yield STRUCTURED_TYPES_EXAMPLES[structured_type_support] + yield _create_example(structured_type_support) @pytest.fixture(scope="module") def structured_type_session(session, structured_type_support): if structured_type_support: with structured_types_enabled_session(session) as sess: + semantics_enabled = context._should_use_structured_type_semantics + context._should_use_structured_type_semantics = True yield sess + context._should_use_structured_type_semantics = semantics_enabled else: yield session @@ -365,9 +379,9 @@ def test_dtypes(session): "config.getoption('local_testing_mode', default=False)", reason="FEAT: SNOW-1372813 Cast to StructType not supported", ) -def test_structured_dtypes(structured_type_session, examples): +def test_structured_dtypes(structured_type_session, examples, structured_type_support): query, expected_dtypes, expected_schema = examples - df = _create_test_dataframe(structured_type_session) + df = _create_test_dataframe(structured_type_session, structured_type_support) assert df.schema == expected_schema assert df.dtypes == expected_dtypes @@ -380,13 +394,16 @@ def test_structured_dtypes(structured_type_session, examples): "config.getoption('local_testing_mode', default=False)", reason="FEAT: SNOW-1372813 Cast to StructType not supported", ) -def test_structured_dtypes_select(structured_type_session, examples): +def test_structured_dtypes_select( + structured_type_session, examples, structured_type_support +): query, expected_dtypes, expected_schema = examples - df = _create_test_dataframe(structured_type_session) + df = _create_test_dataframe(structured_type_session, structured_type_support) + nested_field_name = "b" if context._should_use_structured_type_semantics else "B" flattened_df = df.select( df.map["k1"].alias("value1"), df.obj["A"].alias("a"), - col("obj")["B"].alias("b"), + col("obj")[nested_field_name].alias("b"), df.arr[0].alias("value2"), df.arr[1].alias("value3"), col("arr")[2].alias("value4"), @@ -395,7 +412,7 @@ def test_structured_dtypes_select(structured_type_session, examples): [ StructField("VALUE1", LongType(), nullable=True), StructField("A", StringType(16777216), nullable=True), - StructField("B", DoubleType(), nullable=True), + StructField(nested_field_name, DoubleType(), nullable=True), StructField("VALUE2", DoubleType(), nullable=True), StructField("VALUE3", DoubleType(), nullable=True), StructField("VALUE4", DoubleType(), nullable=True), @@ -420,11 +437,13 @@ def test_structured_dtypes_select(structured_type_session, examples): reason="FEAT: SNOW-1372813 Cast to StructType not supported", ) def test_structured_dtypes_pandas(structured_type_session, structured_type_support): - pdf = _create_test_dataframe(structured_type_session).to_pandas() + pdf = _create_test_dataframe( + structured_type_session, structured_type_support + ).to_pandas() if structured_type_support: assert ( pdf.to_json() - == '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","B":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}' + == '{"MAP":{"0":[["k1",1.0]]},"OBJ":{"0":{"A":"foo","b":0.05}},"ARR":{"0":[1.0,3.1,4.5]}}' ) else: assert ( @@ -445,7 +464,7 @@ def test_structured_dtypes_iceberg( and iceberg_supported(structured_type_session, local_testing_mode) ): pytest.skip("Test requires iceberg support and structured type support.") - query, expected_dtypes, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] + query, expected_dtypes, expected_schema = _create_example(True) table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}".upper() dynamic_table_name = f"snowpark_dynamic_iceberg_{uuid.uuid4().hex[:5]}".upper() @@ -467,7 +486,7 @@ def test_structured_dtypes_iceberg( ) assert save_ddl[0][0] == ( f"create or replace ICEBERG TABLE {table_name.upper()} (\n\t" - "MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, B DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n " + "MAP MAP(STRING, LONG),\n\tOBJ OBJECT(A STRING, b DOUBLE),\n\tARR ARRAY(DOUBLE)\n)\n " "EXTERNAL_VOLUME = 'PYTHON_CONNECTOR_ICEBERG_EXVOL'\n CATALOG = 'SNOWFLAKE'\n " "BASE_LOCATION = 'python_connector_merge_gate/';" ) @@ -524,27 +543,27 @@ def test_iceberg_nested_fields( "NESTED_DATA", StructType( [ - StructField('"camelCase"', StringType(), nullable=True), - StructField('"snake_case"', StringType(), nullable=True), - StructField('"PascalCase"', StringType(), nullable=True), + StructField("camelCase", StringType(), nullable=True), + StructField("snake_case", StringType(), nullable=True), + StructField("PascalCase", StringType(), nullable=True), StructField( - '"nested_map"', + "nested_map", MapType( StringType(), StructType( [ StructField( - '"inner_camelCase"', + "inner_camelCase", StringType(), nullable=True, ), StructField( - '"inner_snake_case"', + "inner_snake_case", StringType(), nullable=True, ), StructField( - '"inner_PascalCase"', + "inner_PascalCase", StringType(), nullable=True, ), @@ -730,11 +749,11 @@ def test_structured_dtypes_iceberg_create_from_values( ): pytest.skip("Test requires iceberg support and structured type support.") - _, __, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] + _, __, expected_schema = _create_example(True) table_name = f"snowpark_structured_dtypes_{uuid.uuid4().hex[:5]}" data = [ - ({"x": 1}, {"A": "a", "B": 1}, [1, 1, 1]), - ({"x": 2}, {"A": "b", "B": 2}, [2, 2, 2]), + ({"x": 1}, {"A": "a", "b": 1}, [1, 1, 1]), + ({"x": 2}, {"A": "b", "b": 2}, [2, 2, 2]), ] try: create_df = structured_type_session.create_dataframe( @@ -760,7 +779,7 @@ def test_structured_dtypes_iceberg_udf( and iceberg_supported(structured_type_session, local_testing_mode) ): pytest.skip("Test requires iceberg support and structured type support.") - query, expected_dtypes, expected_schema = STRUCTURED_TYPES_EXAMPLES[True] + query, expected_dtypes, expected_schema = _create_example(True) table_name = f"snowpark_structured_dtypes_udf_test{uuid.uuid4().hex[:5]}" @@ -945,8 +964,8 @@ def test_structured_type_print_schema( " | |-- key: StringType()\n" " | |-- value: ArrayType\n" " | | |-- element: StructType\n" - ' | | | |-- "FIELD1": StringType() (nullable = True)\n' - ' | | | |-- "FIELD2": LongType() (nullable = True)\n' + ' | | | |-- "Field1": StringType() (nullable = True)\n' + ' | | | |-- "Field2": LongType() (nullable = True)\n' ) # Test that depth works as expected diff --git a/tests/integ/test_stored_procedure.py b/tests/integ/test_stored_procedure.py index 20c63d78642..9345bca0bb8 100644 --- a/tests/integ/test_stored_procedure.py +++ b/tests/integ/test_stored_procedure.py @@ -388,8 +388,8 @@ def test_stored_procedure_with_structured_returns( "OBJ", StructType( [ - StructField('"a"', StringType(16777216), nullable=True), - StructField('"b"', DoubleType(), nullable=True), + StructField("a", StringType(16777216), nullable=True), + StructField("b", DoubleType(), nullable=True), ], structured=True, ),