Skip to content

Commit

Permalink
Fix invalid query pushdown (#11)
Browse files Browse the repository at this point in the history
* Initial testcase

* Correct test

* Attempt to create a new version of normal `.join`

* Delete mess

* Switch to new spark extensions framework

* First explicit pitJoin compiles

* In progress: create new entrypoint to PIT join

* Working test

* All scala tests working

* Create python side joinPIT to replace context

* Rename some paramaters

* Update parameter order in tests

* Orgnise imports

* Tidy
  • Loading branch information
Tom-Newton authored Jan 18, 2024
1 parent 67a454e commit 40ef962
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 395 deletions.
2 changes: 1 addition & 1 deletion python/ackuq/pit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
# SOFTWARE.
#

from ackuq.pit.context import PitContext # noqa: F401
from ackuq.pit.joinPIT import joinPIT # noqa: F401
187 changes: 0 additions & 187 deletions python/ackuq/pit/context.py

This file was deleted.

46 changes: 46 additions & 0 deletions python/ackuq/pit/joinPIT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# MIT License
#
# Copyright (c) 2022 Axel Pettersson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#

from pyspark.sql import Column, DataFrame


def joinPIT(
left: DataFrame,
right: DataFrame,
leftPitKey: Column,
rightPitKey: Column,
on: Column,
how: str = "inner",
tolerance: int = 0,
) -> DataFrame:
jdf = left.sparkSession.sparkContext._jvm.io.github.ackuq.pit.EarlyStopSortMerge.joinPIT(
left._jdf,
right._jdf,
leftPitKey._jc,
rightPitKey._jc,
on._jc,
how,
tolerance,
)
return DataFrame(jdf, left.sparkSession)
43 changes: 17 additions & 26 deletions python/tests/test_sort_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,22 @@
# SOFTWARE.
#

from ackuq.pit.joinPIT import joinPIT
from tests.data import SmallDataSortMerge
from tests.utils import SparkTests


class SortMergeUnionAsOfTest(SparkTests):
def setUp(self) -> None:
super().setUp()
self.small_data = SmallDataSortMerge(self.spark)
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
cls.small_data = SmallDataSortMerge(cls.spark)

def test_two_aligned(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg2

pit_join = fg1.join(
fg2,
self.pit_context.pit_udf(fg1["ts"], fg2["ts"]) & (fg1["id"] == fg2["id"]),
)
pit_join = joinPIT(fg1, fg2, fg1["ts"], fg2["ts"], (fg1["id"] == fg2["id"]))

self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_2.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_2.collect())
Expand All @@ -47,10 +46,7 @@ def test_two_misaligned(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg3

pit_join = fg1.join(
fg2,
self.pit_context.pit_udf(fg1["ts"], fg2["ts"]) & (fg1["id"] == fg2["id"]),
)
pit_join = joinPIT(fg1, fg2, fg1["ts"], fg2["ts"], (fg1["id"] == fg2["id"]))

self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_3.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_3.collect())
Expand All @@ -60,15 +56,15 @@ def test_three_misaligned(self):
fg2 = self.small_data.fg2
fg3 = self.small_data.fg3

left = fg1.join(
left = joinPIT(
fg1,
fg2,
self.pit_context.pit_udf(fg1["ts"], fg2["ts"]) & (fg1["id"] == fg2["id"]),
fg1["ts"],
fg2["ts"],
(fg1["id"] == fg2["id"]),
)

pit_join = left.join(
fg3,
self.pit_context.pit_udf(fg1["ts"], fg3["ts"]) & (fg1["id"] == fg3["id"]),
)
pit_join = joinPIT(left, fg3, fg1["ts"], fg3["ts"], (fg1["id"] == fg3["id"]))

self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_2_3.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_2_3.collect())
Expand All @@ -77,10 +73,8 @@ def test_two_tolerance(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg3

pit_join = fg1.join(
fg2,
self.pit_context.pit_udf(fg1["ts"], fg2["ts"], 1)
& (fg1["id"] == fg2["id"]),
pit_join = joinPIT(
fg1, fg2, fg1["ts"], fg2["ts"], (fg1["id"] == fg2["id"]), tolerance=1
)
self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_3_T1.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_3_T1.collect())
Expand All @@ -89,11 +83,8 @@ def test_two_tolerance_outer(self):
fg1 = self.small_data.fg1
fg2 = self.small_data.fg3

pit_join = fg1.join(
fg2,
self.pit_context.pit_udf(fg1["ts"], fg2["ts"], 1)
& (fg1["id"] == fg2["id"]),
"left",
pit_join = joinPIT(
fg1, fg2, fg1["ts"], fg2["ts"], (fg1["id"] == fg2["id"]), "left", 1
)
self.assertSchemaEqual(pit_join.schema, self.small_data.PIT_1_3_T1_OUTER.schema)
self.assertEqual(pit_join.collect(), self.small_data.PIT_1_3_T1_OUTER.collect())
23 changes: 12 additions & 11 deletions python/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,30 @@
import os
import unittest

from pyspark import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StructType

from ackuq.pit.context import PitContext


class SparkTests(unittest.TestCase):
def setUp(self) -> None:
self.jar_location = os.environ["SCALA_PIT_JAR"]
print("Loading jar from location: {}".format(self.jar_location))
self.spark = (
spark: SparkSession

@classmethod
def setUpClass(cls) -> None:
jar_location = os.environ["SCALA_PIT_JAR"]
print("Loading jar from location: {}".format(jar_location))
cls.spark = (
SparkSession.builder.appName("sparkTests")
.master("local")
.config("spark.ui.showConsoleProgress", False)
.config("spark.driver.extraClassPath", self.jar_location)
.config("spark.driver.extraClassPath", jar_location)
.config("spark.sql.shuffle.partitions", 1)
.config("spark.sql.extensions", "io.github.ackuq.pit.SparkPIT")
.getOrCreate()
)
self.pit_context = PitContext(self.spark)

def tearDown(self) -> None:
self.spark.stop()
@classmethod
def tearDownClass(cls) -> None:
cls.spark.stop()

def _assertFieldsEqual(self, a: StructField, b: StructField):
self.assertEqual(a.name.lower(), b.name.lower())
Expand Down
Loading

0 comments on commit 40ef962

Please sign in to comment.