Skip to content

Commit

Permalink
test ColumnComparer with StructType
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Oct 13, 2024
1 parent af6299a commit 29bbceb
Showing 1 changed file with 75 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,81 @@ class ColumnComparerTest extends AnyFreeSpec with ColumnComparer with SparkSessi
)
assertColumnEquality(sourceDF, "l1", "l2")
}

"throws an error for unequal nested StructType columns with same schema" in {
  // Both struct columns share one nested layout: (name: String, age: Int).
  val personStruct = StructType(
    List(
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = true)
    )
  )
  val dfSchema = StructType(
    List(
      StructField("struct1", personStruct, nullable = true),
      StructField("struct2", personStruct, nullable = true)
    )
  )
  // Only the first row differs between the two columns (age 30 vs 31);
  // the last row is null on both sides.
  val rows = Seq(
    Row(Row("John", 30), Row("John", 31)),
    Row(Row("Jane", 25), Row("Jane", 25)),
    Row(Row("Jake", 40), Row("Jake", 40)),
    Row(null, null)
  )
  val rowsRDD = spark.sparkContext.parallelize(rows)
  val sourceDF = spark.createDataFrame(rowsRDD, dfSchema)

  // The test passes only if the comparison raises ColumnMismatch.
  intercept[ColumnMismatch] {
    assertColumnEquality(sourceDF, "struct1", "struct2")
  }
}

"throws an error for unequal nested StructType columns with different schema" in {
  // struct1 carries (name, age) while struct2 carries only (name).
  val nameAndAgeStruct = StructType(
    List(
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = true)
    )
  )
  val nameOnlyStruct = StructType(
    List(
      StructField("name", StringType, nullable = true)
    )
  )
  val dfSchema = StructType(
    List(
      StructField("struct1", nameAndAgeStruct, nullable = true),
      StructField("struct2", nameOnlyStruct, nullable = true)
    )
  )
  // Names agree row by row, but the second struct has no age field;
  // the last row is null on both sides.
  val rows = Seq(
    Row(Row("John", 30), Row("John")),
    Row(Row("Jane", 25), Row("Jane")),
    Row(Row("Jake", 40), Row("Jake")),
    Row(null, null)
  )
  val rowsRDD = spark.sparkContext.parallelize(rows)
  val sourceDF = spark.createDataFrame(rowsRDD, dfSchema)

  // The test passes only if the comparison raises ColumnMismatch.
  intercept[ColumnMismatch] {
    assertColumnEquality(sourceDF, "struct1", "struct2")
  }
}

"work with StructType columns" in {
  // Both struct columns share one nested layout: (name: String, age: Int).
  val personStruct = StructType(
    List(
      StructField("name", StringType, nullable = true),
      StructField("age", IntegerType, nullable = true)
    )
  )
  val dfSchema = StructType(
    List(
      StructField("struct1", personStruct, nullable = true),
      StructField("struct2", personStruct, nullable = true)
    )
  )
  // Every row holds the same struct in both columns, including a row
  // that is null on both sides.
  val rows = Seq(
    Row(Row("John", 30), Row("John", 30)),
    Row(Row("Jane", 25), Row("Jane", 25)),
    Row(Row("Jake", 40), Row("Jake", 40)),
    Row(null, null)
  )
  val rowsRDD = spark.sparkContext.parallelize(rows)
  val sourceDF = spark.createDataFrame(rowsRDD, dfSchema)

  // No intercept here: the test passes only if this call does not throw.
  assertColumnEquality(sourceDF, "struct1", "struct2")
}
}

}

0 comments on commit 29bbceb

Please sign in to comment.