Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Improve Schema Diff Error Message" #165

Merged
merged 10 commits into from
Oct 12, 2024
Original file line number Diff line number Diff line change
@@ -1,70 +1,59 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.Row
import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
object DataframeUtil {

import scala.reflect.ClassTag

object ProductUtil {
/**
 * Flattens an arbitrary value into a `Seq[Any]` of its elements/fields.
 *
 * Mapping, in match order:
 *   - `null`     -> empty sequence
 *   - `Array`    -> the array itself (relies on the implicit Array-to-Seq conversion)
 *   - `Iterable` -> `toSeq`
 *   - `Row`      -> `toSeq`
 *   - `Product`  -> the elements of `productIterator`
 *   - any other  -> a single-element sequence wrapping the value
 *
 * @param product the value to flatten; typed `Any`, so every input is accepted
 * @return the elements of `product` as a sequence; empty for `null`
 */
private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = {
product match {
case null => Seq.empty
case a: Array[_] => a
// NOTE(review): the Iterable case precedes the Product case, so a value that is both
// (e.g. a non-empty List) is treated as a collection - confirm this ordering is intended.
case i: Iterable[_] => i.toSeq
case r: Row => r.toSeq
case p: Product => p.productIterator.toSeq
// Fallback: any other value becomes a one-element sequence.
case s => Seq(s)
}
}
private[mrpowers] def showProductDiff[T: ClassTag](
private[mrpowers] def showDataframeDiff(
header: (String, String),
actual: Seq[T],
expected: Seq[T],
actual: Seq[Row],
expected: Seq[Row],
truncate: Int = 20,
minColWidth: Int = 3
): String = {

val runTimeClass = implicitly[ClassTag[T]].runtimeClass
val (className, lBracket, rBracket) = if (runTimeClass == classOf[Row]) ("", "[", "]") else (runTimeClass.getSimpleName, "(", ")")
val prodToString: Seq[Any] => String = s => s.mkString(s"$className$lBracket", ",", rBracket)
val emptyProd = "MISSING"

val sb = new StringBuilder

val fullJoin = actual.zipAll(expected, null, null)

val fullJoin = actual.zipAll(expected, Row(), Row())
val diff = fullJoin.map { case (actualRow, expectedRow) =>
if (actualRow == expectedRow) {
if (equals(actualRow, expectedRow)) {
List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString))
} else {
val actualSeq = productOrRowToSeq(actualRow)
val expectedSeq = productOrRowToSeq(expectedRow)
val actualSeq = actualRow.toSeq
val expectedSeq = expectedRow.toSeq
if (actualSeq.isEmpty)
List(Red(emptyProd), Green(prodToString(expectedSeq)))
List(
Red("[]"),
Green(expectedSeq.mkString("[", ",", "]"))
)
else if (expectedSeq.isEmpty)
List(Red(prodToString(actualSeq)), Green(emptyProd))
List(Red(actualSeq.mkString("[", ",", "]")), Green("[]"))
else {
val withEquals = actualSeq
.zipAll(expectedSeq, "MISSING", "MISSING")
.zip(expectedSeq)
.map { case (actualRowField, expectedRowField) =>
(actualRowField, expectedRowField, actualRowField == expectedRowField)
}
val allFieldsAreNotEqual = !withEquals.exists(_._3)
if (allFieldsAreNotEqual) {
List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq)))
List(
Red(actualSeq.mkString("[", ",", "]")),
Green(expectedSeq.mkString("[", ",", "]"))
)
} else {

val coloredDiff = withEquals
.map {
case (actualRowField, expectedRowField, true) =>
(DarkGray(actualRowField.toString), DarkGray(expectedRowField.toString))
case (actualRowField, expectedRowField, false) =>
(Red(actualRowField.toString), Green(expectedRowField.toString))
}
val start = DarkGray(s"$className$lBracket")
val start = DarkGray("[")
val sep = DarkGray(",")
val end = DarkGray(rBracket)
val end = DarkGray("]")
List(
coloredDiff.map(_._1).mkStr(start, sep, end),
coloredDiff.map(_._2).mkStr(start, sep, end)
Expand All @@ -80,12 +69,11 @@ object ProductUtil {
val colWidths = Array.fill(numCols)(minColWidth)

// Compute the width of each column
headerSeq.zipWithIndex.foreach({ case (cell, i) =>
for ((cell, i) <- headerSeq.zipWithIndex) {
colWidths(i) = math.max(colWidths(i), cell.length)
})

diff.foreach { row =>
row.zipWithIndex.foreach { case (cell, i) =>
}
for (row <- diff) {
for ((cell, i) <- row.zipWithIndex) {
colWidths(i) = math.max(colWidths(i), cell.length)
}
}
Expand Down Expand Up @@ -129,4 +117,5 @@ object ProductUtil {

sb.toString
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Expected DataFrame Row Count: '$expectedCount'
/**
* Raises an error unless `actualDS` and `expectedDS` are equal
*/
def assertSmallDatasetEquality[T: ClassTag](
def assertSmallDatasetEquality[T](
actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
Expand All @@ -53,7 +53,7 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}

def assertSmallDatasetContentEquality[T: ClassTag](
def assertSmallDatasetContentEquality[T](
actualDS: Dataset[T],
expectedDS: Dataset[T],
orderedComparison: Boolean,
Expand All @@ -66,12 +66,12 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals)
}

def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
val a = actualDS.collect().toSeq
val e = expectedDS.collect().toSeq
if (!a.approximateSameElements(e, equals)) {
val arr = ("Actual Content", "Expected Content")
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
val msg = "Diffs\n" ++ DataframeUtil.showDataframeDiff(arr, a.asRows, e.asRows, truncate)
throw DatasetContentMismatch(msg)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

object SchemaComparer {

case class DatasetSchemaMismatch(smth: String) extends Exception(smth)
private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = {
showProductDiff(
("Actual Schema", "Expected Schema"),
actualDS.schema.fields,
expectedDS.schema.fields,
truncate = 200
)
"\nActual Schema Field | Expected Schema Field\n" + actualDS.schema
.zipAll(
expectedDS.schema,
"",
""
)
.map {
case (sf1, sf2) if sf1 == sf2 =>
ufansi.Color.Blue(s"$sf1 | $sf2")
case ("", sf2) =>
ufansi.Color.Red(s"MISSING | $sf2")
case (sf1, "") =>
ufansi.Color.Red(s"$sf1 | MISSING")
case (sf1, sf2) =>
ufansi.Color.Red(s"$sf1 | $sf2")
}
.mkString("\n")
}

def assertSchemaEqual[T](
Expand All @@ -25,7 +36,7 @@ object SchemaComparer {
require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) {
throw DatasetSchemaMismatch(
"Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS)
betterSchemaMismatchMessage(actualDS, expectedDS)
)
}
}
Expand Down Expand Up @@ -65,4 +76,5 @@ object SchemaComparer {
case _ => dt1 == dt2
}
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import SparkSessionExt._
import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
import com.github.mrpowers.spark.fast.tests.StringExt.StringOps
Expand Down Expand Up @@ -310,41 +310,6 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar
)
assertLargeDataFrameEquality(sourceDF, expectedDF, ignoreColumnOrder = true)
}

// Checks the colouring of the schema-mismatch message raised by assertSmallDataFrameEquality
// when the two DataFrames have different schemas.
"correctly mark unequal schema field" in {
// Actual side: two columns (number: Integer, float: Double).
val sourceDF = spark.createDF(
List(
(1, 2.0),
(5, 3.0)
),
List(
("number", IntegerType, true),
("float", DoubleType, true)
)
)

// Expected side: three columns (number: Integer, word: String, long: Long), i.e. the second
// column differs from the actual side and the third has no counterpart there.
val expectedDF = spark.createDF(
List(
(1, "word", 1L),
(5, "word", 2L)
),
List(
("number", IntegerType, true),
("word", StringType, true),
("long", LongType, true)
)
)

// The comparison must fail with a DatasetSchemaMismatch; keep the exception to inspect its message.
val e = intercept[DatasetSchemaMismatch] {
assertSmallDataFrameEquality(sourceDF, expectedDF)
}

// NOTE(review): extractColorGroup comes from StringExt.StringOps (not visible here); presumably it
// groups the message fragments by the ANSI colour code that precedes them - confirm.
val colourGroup = e.getMessage.extractColorGroup
val expectedColourGroup = colourGroup.get(Console.GREEN)
val actualColourGroup = colourGroup.get(Console.RED)
// Green fragments: the differing name and type of the second field, plus the whole unmatched third field.
assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})")))
// Red fragments: the differing name and type of the second field, plus the "MISSING" placeholder.
assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING")))
}
}

"assertApproximateDataFrameEquality" - {
Expand Down
Loading