Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tree-string color diff #176

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package com.github.mrpowers.spark.fast.tests
import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Row}

import scala.reflect.ClassTag

Expand Down Expand Up @@ -49,7 +49,7 @@ Expected DataFrame Row Count: '$expectedCount'
truncate: Int = 500,
equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2)
): Unit = {
SchemaComparer.assertSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}
Expand Down Expand Up @@ -103,7 +103,7 @@ Expected DataFrame Row Count: '$expectedCount'
ignoreMetadata: Boolean = true
): Unit = {
// first check if the schemas are equal
SchemaComparer.assertSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
assertLargeDatasetContentEquality(actual, expectedDS, equals, orderedComparison)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,33 +1,133 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff
import com.github.mrpowers.spark.fast.tests.SchemaDiffOutputFormat.SchemaDiffOutputFormat
import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType}
import org.apache.spark.sql.types._

object SchemaComparer {
case class DatasetSchemaMismatch(smth: String) extends Exception(smth)
private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = {
private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
showProductDiff(
("Actual Schema", "Expected Schema"),
actualDS.schema.fields,
expectedDS.schema.fields,
actualSchema.fields,
expectedSchema.fields,
truncate = 200
)
}

def assertSchemaEqual[T](
private def treeSchemaMismatchMessage[T](actualSchema: StructType, expectedSchema: StructType): String = {
def flattenStrucType(s: StructType, indent: Int): (Seq[(Int, StructField)], Int) = s
.foldLeft((Seq.empty[(Int, StructField)], Int.MinValue)) { case ((fieldPair, maxWidth), f) =>
// 5 char for each level of indentation, 21 char for gap, and description words
val gap = indent * 5 + 21 + f.name.length + f.dataType.typeName.length + f.nullable.toString.length
val pair = fieldPair :+ (indent, f)
val newMaxWidth = scala.math.max(maxWidth, gap)
f.dataType match {
case st: StructType =>
val (flattenPair, width) = flattenStrucType(st, indent + 1)
(pair ++ flattenPair, scala.math.max(newMaxWidth, width))
case _ => (pair, newMaxWidth)
}
}

def depthToIndentStr(depth: Int): String = Range(0, depth).map(_ => "| ").mkString + "|--"
val treeSpaces = 6
val (treeFieldPair1, tree1MaxWidth) = flattenStrucType(actualSchema, 0)
val (treeFieldPair2, _) = flattenStrucType(expectedSchema, 0)
val (treePair, maxWidth) = treeFieldPair1
.zipAll(treeFieldPair2, (0, null), (0, null))
.foldLeft((Seq.empty[(String, String)], 0)) { case ((acc, maxWidth), ((indent1, field1), (indent2, field2))) =>
val prefix1 = depthToIndentStr(indent1)
val prefix2 = depthToIndentStr(indent2)
val (sprefix1, sprefix2) = if (indent1 != indent2) {
(Red(prefix1), Green(prefix2))
} else {
(DarkGray(prefix1), DarkGray(prefix2))
}

val pair = if (field1 != null && field2 != null) {
val (name1, name2) =
if (field1.name != field2.name)
(Red(field1.name), Green(field2.name))
else
(DarkGray(field1.name), DarkGray(field2.name))

val (dtype1, dtype2) =
if (field1.dataType != field2.dataType)
(Red(field1.dataType.typeName), Green(field2.dataType.typeName))
else
(DarkGray(field1.dataType.typeName), DarkGray(field2.dataType.typeName))

val (nullable1, nullable2) =
if (field1.nullable != field2.nullable)
(Red(field1.nullable.toString), Green(field2.nullable.toString))
else
(DarkGray(field1.nullable.toString), DarkGray(field2.nullable.toString))

val structString1 = s"$sprefix1 $name1 : $dtype1 (nullable = $nullable1)"
val structString2 = s"$sprefix2 $name2 : $dtype2 (nullable = $nullable2)"
(structString1, structString2)
} else {
val structString1 = if (field1 != null) {
val name = Red(field1.name)
val dtype = Red(field1.dataType.typeName)
val nullable = Red(field1.nullable.toString)
s"$sprefix1 $name : $dtype (nullable = $nullable)"
} else ""

val structString2 = if (field2 != null) {
val name = Green(field2.name)
val dtype = Green(field2.dataType.typeName)
val nullable = Green(field2.nullable.toString)
s"$sprefix2 $name : $dtype (nullable = $nullable)"
} else ""
(structString1, structString2)
}
(acc :+ pair, math.max(maxWidth, pair._1.length))
}

val schemaGap = maxWidth + treeSpaces
val headerGap = tree1MaxWidth + treeSpaces
treePair
.foldLeft(new StringBuilder("\nActual Schema".padTo(headerGap, ' ') + "Expected Schema\n")) { case (sb, (s1, s2)) =>
val gap = if (s1.isEmpty) headerGap else schemaGap
val s = if (s2.isEmpty) s1 else s1.padTo(gap, ' ')
sb.append(s + s2 + "\n")
}
.toString()
}

def assertDatasetSchemaEqual[T](
actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true,
ignoreMetadata: Boolean = true
ignoreMetadata: Boolean = true,
outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table
): Unit = {
assertSchemaEqual(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata, outputFormat)
}

def assertSchemaEqual(
actualSchema: StructType,
expectedSchema: StructType,
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true,
ignoreMetadata: Boolean = true,
outputFormat: SchemaDiffOutputFormat = SchemaDiffOutputFormat.Table
): Unit = {
require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)) {
throw DatasetSchemaMismatch(
"Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS)
)
if (!SchemaComparer.equals(actualSchema, expectedSchema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)) {
val diffString = outputFormat match {
case SchemaDiffOutputFormat.Tree => treeSchemaMismatchMessage(actualSchema, expectedSchema)
case SchemaDiffOutputFormat.Table => betterSchemaMismatchMessage(actualSchema, expectedSchema)
}

throw DatasetSchemaMismatch(s"Diffs\n$diffString")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.github.mrpowers.spark.fast.tests

// Output format selector for schema-mismatch diff messages.
// Table renders a side-by-side field table; Tree renders a printSchema-style tree.
// NOTE(review): scala.Enumeration is legacy — a sealed trait ADT (or Scala 3 enum)
// would be preferred if a breaking change is ever acceptable.
object SchemaDiffOutputFormat extends Enumeration {
  // Type alias so call sites can reference SchemaDiffOutputFormat.SchemaDiffOutputFormat
  type SchemaDiffOutputFormat = Value

  val Tree, Table = Value
}
Loading