Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Mar 15, 2024
1 parent 07e9459 commit 9b42f0a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
11 changes: 7 additions & 4 deletions dask_snowflake/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Sequence
from typing import Optional, Sequence

import pandas as pd
import pyarrow as pa
Expand All @@ -23,7 +23,10 @@

@delayed
def write_snowflake(
df: pd.DataFrame, name: str, connection_kwargs: dict, write_pandas_kwargs: dict = {}
df: pd.DataFrame,
name: str,
connection_kwargs: dict,
write_pandas_kwargs: Optional[dict] = None,
):
connection_kwargs = {
**{"application": dask.config.get("snowflake.partner", "dask")},
Expand All @@ -37,7 +40,7 @@ def write_snowflake(
# NOTE: since ensure_db_exists uses uppercase for the table name
table_name=name.upper(),
quote_identifiers=False,
**write_pandas_kwargs,
**(write_pandas_kwargs or {}),
)


Expand Down Expand Up @@ -72,7 +75,7 @@ def to_snowflake(
df: dd.DataFrame,
name: str,
connection_kwargs: dict,
write_pandas_kwargs: dict = {},
write_pandas_kwargs: Optional[dict] = None,
compute: bool = True,
):
"""Write a Dask DataFrame to a Snowflake table.
Expand Down
23 changes: 23 additions & 0 deletions dask_snowflake/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,29 @@ def test_arrow_options(table, connection_kwargs, client):
)


def test_write_pandas_kwargs(table, connection_kwargs, client):
    """Exercise ``to_snowflake`` with an explicit ``write_pandas_kwargs``.

    A single-partition frame is written to ``table`` twice: first with the
    default arguments, then with ``write_pandas_kwargs={"overwrite": True}``.
    The table is then read back and compared against ``df``.
    """
    shared_args = {"name": table, "connection_kwargs": connection_kwargs}

    # First write relies on the default ``write_pandas_kwargs``.
    to_snowflake(ddf.repartition(npartitions=1), **shared_args)
    # Overwrite existing table
    to_snowflake(
        ddf.repartition(npartitions=1),
        write_pandas_kwargs={"overwrite": True},
        **shared_args,
    )

    result = read_snowflake(
        f"SELECT * FROM {table}",
        connection_kwargs=connection_kwargs,
        npartitions=2,
    )
    # FIXME: Why does read_snowflake return lower-case column names?
    result.columns = result.columns.str.upper()
    # FIXME: We need to sort the DataFrame because partitions are written
    # in a non-sequential order.
    ordered = result.sort_values(by="A").reset_index(drop=True)
    dd.utils.assert_eq(df, ordered, check_dtype=False)


def test_application_id_default(table, connection_kwargs, monkeypatch):
# Patch Snowflake's normal connection mechanism with checks that
# the expected application ID is set
Expand Down

0 comments on commit 9b42f0a

Please sign in to comment.