diff --git a/CHANGELOG.md b/CHANGELOG.md index 7794f3757cf..280458671c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ - Added support for following functions in `functions.py`: - `size` to get size of array, object, or map columns. - `collect_list` an alias of `array_agg`. + - `concat_ws_ignore_nulls` to concatenate strings with a separator, ignoring null values. + - `substring` makes `len` argument optional. - Added parameter `ast_enabled` to session for internal usage (default: `False`). #### Improvements diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index 71e83093565..ba97c62a93a 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -97,6 +97,7 @@ Functions column concat concat_ws + concat_ws_ignore_nulls contains convert_timezone corr diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 8c08810efcb..38b3d0e1359 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -3017,7 +3017,7 @@ def split(str: ColumnOrName, pattern: ColumnOrName, _emit_ast: bool = True) -> C def substring( str: ColumnOrName, pos: Union[Column, int], - len: Union[Column, int], + len: Optional[Union[Column, int]] = None, _emit_ast: bool = True, ) -> Column: """Returns the portion of the string or binary value str, starting from the @@ -3030,16 +3030,26 @@ def substring( :func:`substr` is an alias of :func:`substring`. - Example:: + Example 1:: >>> df = session.create_dataframe( ... ["abc", "def"], ... schema=["S"], - ... ).select(substring(col("S"), 1, 1)) - >>> df.collect() + ... ) + >>> df.select(substring(col("S"), 1, 1)).collect() [Row(SUBSTRING("S", 1, 1)='a'), Row(SUBSTRING("S", 1, 1)='d')] + + Example 2:: + >>> df = session.create_dataframe( + ... ["abc", "def"], + ... schema=["S"], + ... ) + >>> df.select(substring(col("S"), 2)).collect() + [Row(SUBSTRING("S", 2)='bc'), Row(SUBSTRING("S", 2)='ef')] """ s = _to_col_if_str(str, "substring") p = pos if isinstance(pos, Column) else lit(pos, _emit_ast=_emit_ast) + if len is None: + return builtin("substring", _emit_ast=_emit_ast)(s, p) length = len if isinstance(len, Column) else lit(len, _emit_ast=_emit_ast) return builtin("substring", _emit_ast=_emit_ast)(s, p, length) @@ -3392,6 +3402,46 @@ def concat_ws(*cols: ColumnOrName, _emit_ast: bool = True) -> Column: return builtin("concat_ws", _emit_ast=_emit_ast)(*columns) +@publicapi +def concat_ws_ignore_nulls( + sep: str, *cols: ColumnOrName, _emit_ast: bool = True +) -> Column: + """Concatenates two or more strings, or concatenates two or more binary values. Null values are ignored. + + Args: + sep: The separator to use between the strings. + + Examples:: + >>> df = session.create_dataframe([ + ... ['Hello', 'World', None], + ... [None, None, None], + ... ['Hello', None, None], + ... ], schema=['a', 'b', 'c']) + >>> df.select(concat_ws_ignore_nulls(',', df.a, df.b, df.c)).show() + ---------------------------------------------------- + |"CONCAT_WS_IGNORE_NULLS(',', ""A"",""B"",""C"")" | + ---------------------------------------------------- + |Hello,World | + | | + |Hello | + ---------------------------------------------------- + + """ + # TODO: SNOW-1831917 create ast + columns = [_to_col_if_str(c, "concat_ws_ignore_nulls") for c in cols] + names = ",".join([c.get_name() for c in columns]) + + input_column_array = array_construct_compact(*columns, _emit_ast=False) + reduced_result = builtin("reduce", _emit_ast=False)( + input_column_array, + lit("", _emit_ast=False), + sql_expr(f"(l, r) -> l || '{sep}' || r", _emit_ast=False), + ) + return substring(reduced_result, 2, _emit_ast=False).alias( + f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False + ) + + @publicapi def translate( src: ColumnOrName,