Skip to content

Commit

Permalink
SNOW-1803811: Allow mixed-case field names for struct type columns (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jrose authored Dec 16, 2024
1 parent 665bef1 commit 0362c46
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 100 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- Added support for applying Snowpark Python function `snowflake_cortex_sentiment`.
- Added support for `DataFrame.map`.
- Added support for `DataFrame.from_dict` and `DataFrame.from_records`.
- Added support for mixed case field names in struct type columns.

#### Improvements
- Improve performance of `DataFrame.map`, `Series.apply` and `Series.map` methods by mapping numpy functions to snowpark functions if possible.
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_origin,
)

import snowflake.snowpark.context as context
import snowflake.snowpark.types # type: ignore
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata
Expand Down Expand Up @@ -157,9 +158,12 @@ def convert_metadata_to_sp_type(
return StructType(
[
StructField(
quote_name(field.name, keep_case=True),
field.name
if context._should_use_structured_type_semantics
else quote_name(field.name, keep_case=True),
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
_is_column=False,
)
for field in metadata.fields
],
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@
StringType,
TimestampTimeZone,
TimestampType,
ArrayType,
MapType,
StructType,
)
from snowflake.snowpark.window import Window, WindowSpec

Expand Down Expand Up @@ -917,6 +920,9 @@ def _cast(
if isinstance(to, str):
to = type_string_to_type_object(to)

if isinstance(to, (ArrayType, MapType, StructType)):
to = to._as_nested()

if self._ast is None:
_emit_ast = False

Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
_should_continue_registration: Optional[Callable[..., bool]] = None


# Global flag that determines if structured type semantics should be used
_should_use_structured_type_semantics = False


def get_active_session() -> "snowflake.snowpark.Session":
"""Returns the current active Snowpark session.
Expand Down
81 changes: 63 additions & 18 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# Use correct version from here:
from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name
import snowflake.snowpark.context as context

# TODO: connector installed_pandas is broken. If pyarrow is not installed, but pandas is this function returns the wrong answer.
# The core issue is that in the connector detection of both pandas/arrow are mixed, which is wrong.
Expand Down Expand Up @@ -341,6 +342,14 @@ def __init__(
def __repr__(self) -> str:
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"

def _as_nested(self) -> "ArrayType":
if not context._should_use_structured_type_semantics:
return self
element_type = self.element_type
if isinstance(element_type, (ArrayType, MapType, StructType)):
element_type = element_type._as_nested()
return ArrayType(element_type, self.structured)

def is_primitive(self):
return False

Expand Down Expand Up @@ -391,6 +400,14 @@ def __repr__(self) -> str:
def is_primitive(self):
return False

def _as_nested(self) -> "MapType":
if not context._should_use_structured_type_semantics:
return self
value_type = self.value_type
if isinstance(value_type, (ArrayType, MapType, StructType)):
value_type = value_type._as_nested()
return MapType(self.key_type, value_type, self.structured)

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "MapType":
return MapType(
Expand Down Expand Up @@ -552,29 +569,46 @@ def __init__(
column_identifier: Union[ColumnIdentifier, str],
datatype: DataType,
nullable: bool = True,
_is_column: bool = True,
) -> None:
self.column_identifier = (
ColumnIdentifier(column_identifier)
if isinstance(column_identifier, str)
else column_identifier
)
self.name = column_identifier
self._is_column = _is_column
self.datatype = datatype
self.nullable = nullable

@property
def name(self) -> str:
"""Returns the column name."""
return self.column_identifier.name
if self._is_column or not context._should_use_structured_type_semantics:
return self.column_identifier.name
else:
return self._name

@name.setter
def name(self, n: str) -> None:
self.column_identifier = ColumnIdentifier(n)
def name(self, n: Union[ColumnIdentifier, str]) -> None:
if isinstance(n, ColumnIdentifier):
self._name = n.name
self.column_identifier = n
else:
self._name = n
self.column_identifier = ColumnIdentifier(n)

def _as_nested(self) -> "StructField":
if not context._should_use_structured_type_semantics:
return self
datatype = self.datatype
if isinstance(datatype, (ArrayType, MapType, StructType)):
datatype = datatype._as_nested()
# Nested StructFields do not follow column naming conventions
return StructField(self._name, datatype, self.nullable, _is_column=False)

def __repr__(self) -> str:
return f"StructField({self.name!r}, {repr(self.datatype)}, nullable={self.nullable})"

def __eq__(self, other):
return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
return isinstance(other, self.__class__) and (
(self.name, self._is_column, self.datatype, self.nullable)
== (other.name, other._is_column, other.datatype, other.nullable)
)

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "StructField":
Expand Down Expand Up @@ -620,30 +654,41 @@ def __init__(
self, fields: Optional[List["StructField"]] = None, structured=False
) -> None:
self.structured = structured
if fields is None:
fields = []
self.fields = fields
self.fields = []
for field in fields or []:
self.add(field)

def add(
self,
field: Union[str, ColumnIdentifier, "StructField"],
datatype: Optional[DataType] = None,
nullable: Optional[bool] = True,
) -> "StructType":
if isinstance(field, StructField):
self.fields.append(field)
elif isinstance(field, (str, ColumnIdentifier)):
if isinstance(field, (str, ColumnIdentifier)):
if datatype is None:
raise ValueError(
"When field argument is str or ColumnIdentifier, datatype must not be None."
)
self.fields.append(StructField(field, datatype, nullable))
else:
field = StructField(field, datatype, nullable)
elif not isinstance(field, StructField):
raise ValueError(
f"field argument must be one of str, ColumnIdentifier or StructField. Got: '{type(field)}'"
)

# Nested data does not follow the same schema conventions as top level fields.
if isinstance(field.datatype, (ArrayType, MapType, StructType)):
field.datatype = field.datatype._as_nested()

self.fields.append(field)
return self

def _as_nested(self) -> "StructType":
if not context._should_use_structured_type_semantics:
return self
return StructType(
[field._as_nested() for field in self.fields], self.structured
)

@classmethod
def _from_attributes(cls, attributes: list) -> "StructType":
return cls([StructField(a.name, a.datatype, a.nullable) for a in attributes])
Expand Down
Loading

0 comments on commit 0362c46

Please sign in to comment.