Skip to content

Commit

Permalink
Merge pull request #168 from zeotuan/test-columnComparer-with-struct
Browse files Browse the repository at this point in the history
Test ColumnComparer with StructType
  • Loading branch information
zeotuan authored Oct 14, 2024
2 parents af6299a + e86046d commit 716be99
Showing 1 changed file with 83 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,89 @@ class ColumnComparerTest extends AnyFreeSpec with ColumnComparer with SparkSessi
)
assertColumnEquality(sourceDF, "l1", "l2")
}

"throws an error for unequal nested StructType columns with same schema" in {
val sourceData = Seq(
Row(Row("John", 30), Row("John", 31)),
Row(Row("Jane", 25), Row("Jane", 25)),
Row(Row("Jake", 40), Row("Jake", 40)),
Row(null, null)
)
val nestedSchema = StructType(
List(
StructField("name", StringType, true),
StructField("age", IntegerType, true)
)
)
val sourceSchema = List(
StructField("struct1", nestedSchema, true),
StructField("struct2", nestedSchema, true)
)
val sourceDF = spark.createDataFrame(
spark.sparkContext.parallelize(sourceData),
StructType(sourceSchema)
)
intercept[ColumnMismatch] {
assertColumnEquality(sourceDF, "struct1", "struct2")
}
}

"throws an error for unequal nested StructType columns with different schema" in {
val sourceData = Seq(
Row(Row("John", 30), Row("John")),
Row(Row("Jane", 25), Row("Jane")),
Row(Row("Jake", 40), Row("Jake")),
Row(null, null)
)
val nestedSchema1 = StructType(
List(
StructField("name", StringType, true),
StructField("age", IntegerType, true)
)
)

val nestedSchema2 = StructType(
List(
StructField("name", StringType, true)
)
)
val sourceSchema = List(
StructField("struct1", nestedSchema1, true),
StructField("struct2", nestedSchema2, true)
)
val sourceDF = spark.createDataFrame(
spark.sparkContext.parallelize(sourceData),
StructType(sourceSchema)
)
intercept[ColumnMismatch] {
assertColumnEquality(sourceDF, "struct1", "struct2")
}
}

"work with StructType columns" in {
val sourceData = Seq(
Row(Row("John", 30), Row("John", 30)),
Row(Row("Jane", 25), Row("Jane", 25)),
Row(Row("Jake", 40), Row("Jake", 40)),
Row(null, null)
)
val nestedSchema = StructType(
List(
StructField("name", StringType, true),
StructField("age", IntegerType, true)
)
)
val sourceSchema = List(
StructField("struct1", nestedSchema, true),
StructField("struct2", nestedSchema, true)
)
val sourceDF = spark.createDataFrame(
spark.sparkContext.parallelize(sourceData),
StructType(sourceSchema)
)

assertColumnEquality(sourceDF, "struct1", "struct2")
}
}

}

0 comments on commit 716be99

Please sign in to comment.