From 7ba973a748b1187c6e6c01779d842fe51d32f5bf Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 10 Apr 2024 19:48:01 -0700 Subject: [PATCH] GH-40695 [C++] Expand Substrait type support (#40696) ### Rationale for this change See #40695 ### What changes are included in this PR? This PR does a few things: * Substrait is upgraded to the latest version * Support is added for the parameterized timestamp type (but not literals due to https://github.com/substrait-io/substrait/issues/611). * Support is added for the following arrow-specific types: * fp16 * date_millis * time_seconds * time_millis * time_nanos * large_string * large_binary When adding support for the new timestamp types I also relaxed the restrictions on the time zone column. Substrait puts time zone information in the function and not the type. In other words, to print the "America/New York" value of a column of instants one would do something like `to_char(my_timestamp, "America/New York")` instead of `to_char(cast(my_timestamp, timestamp("nanos", "America/New York")`. However, the current implementation makes it impossible to produce or consume a plan with `to_char(my_timestamp, "America/New York")` because it would reject the type because it has a non-UTC time zone. With this latest change, we treat any non-empty timezone as a timezone_tz type. In addition, I have enabled conversions from "encoded types" to their unencoded representation. E.g. a type of `DICTIONARY` will convert to `INT32`. At a logical expression / plan perspective these encodings are irrelevant. If anything, they may belong in a more physical plan representation. Should a need for them arise we can dig into it more later. However, I believe it is better to err on the side of generating "something" rather than failing in these cases. I don't consider this last change critical and can back it out if need be. ### Are these changes tested? Yes, I added new unit tests ### Are there any user-facing changes? Yes, via the Substrait conversion. These changes should be backwards compatible in that they only add functionality in places that previously reported "Not Supported". * GitHub Issue: #40695 Lead-authored-by: Weston Pace Co-authored-by: Benjamin Kietzman Signed-off-by: Weston Pace --- .../engine/substrait/expression_internal.cc | 262 +++++++++++++++--- .../arrow/engine/substrait/extension_set.cc | 8 +- .../arrow/engine/substrait/extension_set.h | 4 + .../arrow/engine/substrait/extension_types.cc | 31 +++ .../arrow/engine/substrait/extension_types.h | 10 + cpp/src/arrow/engine/substrait/serde_test.cc | 125 +++++++-- .../arrow/engine/substrait/type_internal.cc | 110 ++++++-- cpp/src/arrow/engine/substrait/util.h | 2 +- cpp/thirdparty/versions.txt | 4 +- format/substrait/extension_types.yaml | 142 ++++++++-- python/pyarrow/tests/test_substrait.py | 48 ++++ 11 files changed, 649 insertions(+), 97 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 5d892af9a394e..480cf30d3033f 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -33,6 +33,7 @@ #include #include +#include #include "arrow/array/array_base.h" #include "arrow/array/array_nested.h" @@ -58,8 +59,10 @@ #include "arrow/util/decimal.h" #include "arrow/util/endian.h" #include "arrow/util/logging.h" +#include "arrow/util/macros.h" #include "arrow/util/small_vector.h" #include "arrow/util/string.h" +#include "arrow/util/unreachable.h" #include "arrow/visit_scalar_inline.h" namespace arrow { @@ -71,6 +74,9 @@ namespace engine { namespace { +constexpr int64_t kMicrosPerSecond = 1000000; +constexpr int64_t kMicrosPerMilli = 1000; + Id NormalizeFunctionName(Id id) { // Substrait plans encode the types into the function name so it might look like // add:opt_i32_i32. We don't care about the :opt_i32_i32 so we just trim it @@ -421,6 +427,86 @@ Result FromProto(const substrait::Expression& expr, expr.DebugString()); } +namespace { +struct UserDefinedLiteralToArrow { + Status Visit(const DataType& type) { + return Status::NotImplemented("User defined literals of type ", type); + } + Status Visit(const IntegerType& type) { + google::protobuf::UInt64Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined integer literal to UInt64Value"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const Time32Type& type) { + google::protobuf::Int32Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined time32 literal to Int32Value"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const Time64Type& type) { + google::protobuf::Int64Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined time64 literal to Int64Value"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const Date64Type& type) { + google::protobuf::Int64Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined date64 literal to Int64Value"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), value.value())); + return Status::OK(); + } + Status Visit(const HalfFloatType& type) { + google::protobuf::UInt32Value value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined half_float literal to UInt32Value"); + } + uint16_t half_float_value = value.value(); + ARROW_ASSIGN_OR_RAISE(scalar_, MakeScalar(type.GetSharedPtr(), half_float_value)); + return Status::OK(); + } + Status Visit(const LargeStringType& type) { + google::protobuf::StringValue value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined large_string literal to StringValue"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, + MakeScalar(type.GetSharedPtr(), std::string(value.value()))); + return Status::OK(); + } + Status Visit(const LargeBinaryType& type) { + google::protobuf::BytesValue value; + if (!user_defined_->value().UnpackTo(&value)) { + return Status::Invalid( + "Failed to unpack user defined large_binary literal to BytesValue"); + } + ARROW_ASSIGN_OR_RAISE(scalar_, + MakeScalar(type.GetSharedPtr(), std::string(value.value()))); + return Status::OK(); + } + Status operator()(const DataType& type) { return VisitTypeInline(type, this); } + + std::shared_ptr scalar_; + const substrait::Expression::Literal::UserDefined* user_defined_; + const ExtensionSet* ext_set_; + const ConversionOptions& conversion_options_; +}; +} // namespace + Result FromProto(const substrait::Expression::Literal& lit, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { @@ -455,6 +541,7 @@ Result FromProto(const substrait::Expression::Literal& lit, case substrait::Expression::Literal::kBinary: return Datum(BinaryScalar(lit.binary())); + ARROW_SUPPRESS_DEPRECATION_WARNING case substrait::Expression::Literal::kTimestamp: return Datum( TimestampScalar(static_cast(lit.timestamp()), TimeUnit::MICRO)); @@ -462,7 +549,17 @@ Result FromProto(const substrait::Expression::Literal& lit, case substrait::Expression::Literal::kTimestampTz: return Datum(TimestampScalar(static_cast(lit.timestamp_tz()), TimeUnit::MICRO, TimestampTzTimezoneString())); - + ARROW_UNSUPPRESS_DEPRECATION_WARNING + case substrait::Expression::Literal::kPrecisionTimestamp: { + // https://github.com/substrait-io/substrait/issues/611 + // TODO(GH-40741) don't break, return precision timestamp + break; + } + case substrait::Expression::Literal::kPrecisionTimestampTz: { + // https://github.com/substrait-io/substrait/issues/611 + // TODO(GH-40741) don't break, return precision timestamp + break; + } case substrait::Expression::Literal::kDate: return Datum(Date32Scalar(lit.date())); case substrait::Expression::Literal::kTime: @@ -674,18 +771,30 @@ Result FromProto(const substrait::Expression::Literal& lit, return Datum(MakeNullScalar(std::move(type_nullable.first))); } + case substrait::Expression::Literal::kUserDefined: { + const auto& user_defined = lit.user_defined(); + ARROW_ASSIGN_OR_RAISE(auto type_record, + ext_set.DecodeType(user_defined.type_reference())); + UserDefinedLiteralToArrow visitor{nullptr, &user_defined, &ext_set, + conversion_options}; + ARROW_RETURN_NOT_OK((visitor)(*type_record.type)); + return Datum(std::move(visitor.scalar_)); + } + + case substrait::Expression::Literal::LITERAL_TYPE_NOT_SET: + return Status::Invalid("substrait literal did not have any literal type set"); + default: break; } - return Status::NotImplemented("conversion to arrow::Datum from Substrait literal ", - lit.DebugString()); + return Status::NotImplemented("conversion to arrow::Datum from Substrait literal `", + lit.DebugString(), "`"); } namespace { -struct ScalarToProtoImpl { - Status Visit(const NullScalar& s) { return NotImplemented(s); } +struct ScalarToProtoImpl { using Lit = substrait::Expression::Literal; template @@ -702,6 +811,25 @@ struct ScalarToProtoImpl { return Status::OK(); } + Status EncodeUserDefined(const DataType& data_type, + const google::protobuf::Message& value) { + ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(data_type)); + auto user_defined = std::make_unique(); + user_defined->set_type_reference(anchor); + auto value_any = std::make_unique(); + value_any->PackFrom(value); + user_defined->set_allocated_value(value_any.release()); + lit_->set_allocated_user_defined(user_defined.release()); + return Status::OK(); + } + + Status Visit(const NullScalar& s) { + ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(*s.type)); + auto user_defined = std::make_unique(); + user_defined->set_type_reference(anchor); + lit_->set_allocated_user_defined(user_defined.release()); + return Status::OK(); + } Status Visit(const BooleanScalar& s) { return Primitive(&Lit::set_boolean, s); } Status Visit(const Int8Scalar& s) { return Primitive(&Lit::set_i8, s); } @@ -709,12 +837,31 @@ struct ScalarToProtoImpl { Status Visit(const Int32Scalar& s) { return Primitive(&Lit::set_i32, s); } Status Visit(const Int64Scalar& s) { return Primitive(&Lit::set_i64, s); } - Status Visit(const UInt8Scalar& s) { return NotImplemented(s); } - Status Visit(const UInt16Scalar& s) { return NotImplemented(s); } - Status Visit(const UInt32Scalar& s) { return NotImplemented(s); } - Status Visit(const UInt64Scalar& s) { return NotImplemented(s); } - - Status Visit(const HalfFloatScalar& s) { return NotImplemented(s); } + Status Visit(const UInt8Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const UInt16Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const UInt32Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const UInt64Scalar& s) { + google::protobuf::UInt64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const HalfFloatScalar& s) { + google::protobuf::UInt32Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } Status Visit(const FloatScalar& s) { return Primitive(&Lit::set_fp32, s); } Status Visit(const DoubleScalar& s) { return Primitive(&Lit::set_fp64, s); } @@ -722,12 +869,18 @@ struct ScalarToProtoImpl { return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); }, s); } + Status Visit(const StringViewScalar& s) { + return FromBuffer([](Lit* lit, std::string&& s) { lit->set_string(std::move(s)); }, + s); + } Status Visit(const BinaryScalar& s) { return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); }, s); } - - Status Visit(const BinaryViewScalar& s) { return NotImplemented(s); } + Status Visit(const BinaryViewScalar& s) { + return FromBuffer([](Lit* lit, std::string&& s) { lit->set_binary(std::move(s)); }, + s); + } Status Visit(const FixedSizeBinaryScalar& s) { return FromBuffer( @@ -735,28 +888,64 @@ struct ScalarToProtoImpl { } Status Visit(const Date32Scalar& s) { return Primitive(&Lit::set_date, s); } - Status Visit(const Date64Scalar& s) { return NotImplemented(s); } + Status Visit(const Date64Scalar& s) { + google::protobuf::Int64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } Status Visit(const TimestampScalar& s) { const auto& t = checked_cast(*s.type); - if (t.unit() != TimeUnit::MICRO) return NotImplemented(s); + uint64_t micros; + switch (t.unit()) { + case TimeUnit::SECOND: + micros = s.value * kMicrosPerSecond; + break; + case TimeUnit::MILLI: + micros = s.value * kMicrosPerMilli; + break; + case TimeUnit::MICRO: + micros = s.value; + break; + case TimeUnit::NANO: + // TODO(GH-40741): can support nanos when + // https://github.com/substrait-io/substrait/issues/611 is resolved + return NotImplemented(s); + default: + return NotImplemented(s); + } - if (t.timezone() == "") return Primitive(&Lit::set_timestamp, s); + // Remove these and use precision timestamp once + // https://github.com/substrait-io/substrait/issues/611 is resolved + ARROW_SUPPRESS_DEPRECATION_WARNING - if (t.timezone() == TimestampTzTimezoneString()) { - return Primitive(&Lit::set_timestamp_tz, s); + if (t.timezone() == "") { + lit_->set_timestamp(micros); + } else { + // Some loss of info here, Substrait doesn't store timezone + // in field data + lit_->set_timestamp_tz(micros); } + ARROW_UNSUPPRESS_DEPRECATION_WARNING - return NotImplemented(s); + return Status::OK(); } - Status Visit(const Time32Scalar& s) { return NotImplemented(s); } + // Need to support parameterized UDTs + Status Visit(const Time32Scalar& s) { + google::protobuf::Int32Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); + } Status Visit(const Time64Scalar& s) { - if (checked_cast(*s.type).unit() != TimeUnit::MICRO) { - return NotImplemented(s); + if (checked_cast(*s.type).unit() == TimeUnit::MICRO) { + return Primitive(&Lit::set_time, s); + } else { + google::protobuf::Int64Value value; + value.set_value(s.value); + return EncodeUserDefined(*s.type, value); } - return Primitive(&Lit::set_time, s); } Status Visit(const MonthIntervalScalar& s) { return NotImplemented(s); } @@ -778,9 +967,10 @@ struct ScalarToProtoImpl { return Status::OK(); } + // Need support for parameterized UDTs Status Visit(const Decimal256Scalar& s) { return NotImplemented(s); } - Status Visit(const ListScalar& s) { + Status Visit(const BaseListScalar& s) { if (s.value->length() == 0) { ARROW_ASSIGN_OR_RAISE(auto list_type, ToProto(*s.type, /*nullable=*/true, ext_set_, conversion_options_)); @@ -807,10 +997,6 @@ struct ScalarToProtoImpl { return Status::OK(); } - Status Visit(const ListViewScalar& s) { - return Status::NotImplemented("list-view to proto"); - } - Status Visit(const LargeListViewScalar& s) { return Status::NotImplemented("list-view to proto"); } @@ -830,7 +1016,10 @@ struct ScalarToProtoImpl { Status Visit(const SparseUnionScalar& s) { return NotImplemented(s); } Status Visit(const DenseUnionScalar& s) { return NotImplemented(s); } - Status Visit(const DictionaryScalar& s) { return NotImplemented(s); } + Status Visit(const DictionaryScalar& s) { + ARROW_ASSIGN_OR_RAISE(auto encoded, s.GetEncodedValue()); + return (*this)(*encoded); + } Status Visit(const MapScalar& s) { if (s.value->length() == 0) { @@ -914,10 +1103,21 @@ struct ScalarToProtoImpl { return NotImplemented(s); } + // Need support for parameterized UDTs Status Visit(const FixedSizeListScalar& s) { return NotImplemented(s); } Status Visit(const DurationScalar& s) { return NotImplemented(s); } - Status Visit(const LargeStringScalar& s) { return NotImplemented(s); } - Status Visit(const LargeBinaryScalar& s) { return NotImplemented(s); } + + Status Visit(const LargeStringScalar& s) { + google::protobuf::StringValue value; + value.set_value(s.view().data(), s.view().size()); + return EncodeUserDefined(*s.type, value); + } + Status Visit(const LargeBinaryScalar& s) { + google::protobuf::BytesValue value; + value.set_value(s.view().data(), s.view().size()); + return EncodeUserDefined(*s.type, value); + } + // Need support for parameterized UDTs Status Visit(const LargeListScalar& s) { return NotImplemented(s); } Status Visit(const MonthDayNanoIntervalScalar& s) { return NotImplemented(s); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index b0dd6aeffbcfa..e955084dcdfbb 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -1035,7 +1035,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { }; // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; + // with the YAML at format/substrait/extension_types.yaml manually; // see ARROW-15535. for (TypeName e : { TypeName{uint8(), "u8"}, @@ -1043,6 +1043,12 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { TypeName{uint32(), "u32"}, TypeName{uint64(), "u64"}, TypeName{float16(), "fp16"}, + TypeName{large_utf8(), "large_string"}, + TypeName{large_binary(), "large_binary"}, + TypeName{time32(TimeUnit::SECOND), "time_seconds"}, + TypeName{time32(TimeUnit::MILLI), "time_millis"}, + TypeName{date64(), "date_millis"}, + TypeName{time64(TimeUnit::NANO), "time_nanos"}, }) { DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 0a502960447e6..c18e0cf77aae5 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -295,6 +295,10 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { constexpr std::string_view kArrowExtTypesUri = "https://github.com/apache/arrow/blob/main/format/substrait/" "extension_types.yaml"; +// Extension types that don't match 1:1 with a data type (or the data type is +// parameterized) +constexpr std::string_view kTimeNanosTypeName = "time_nanos"; +constexpr Id kTimeNanosId = {kArrowExtTypesUri, kTimeNanosTypeName}; /// A default registry with all supported functions and data types registered /// diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc index fcc722e9d9410..f71b5f7185d00 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.cc +++ b/cpp/src/arrow/engine/substrait/extension_types.cc @@ -22,6 +22,7 @@ #include #include "arrow/engine/simple_extension_type_internal.h" +#include "arrow/engine/substrait/type_internal.h" #include "arrow/result.h" #include "arrow/type_fwd.h" #include "arrow/util/reflection_internal.h" @@ -113,6 +114,36 @@ std::shared_ptr interval_year() { return IntervalYearType::Make({}); } std::shared_ptr interval_day() { return IntervalDayType::Make({}); } +Result> precision_timestamp(int precision) { + switch (precision) { + case 0: + return timestamp(TimeUnit::SECOND); + case 3: + return timestamp(TimeUnit::MILLI); + case 6: + return timestamp(TimeUnit::MICRO); + case 9: + return timestamp(TimeUnit::NANO); + default: + return Status::NotImplemented("Unrecognized timestamp precision (", precision, ")"); + } +} + +Result> precision_timestamp_tz(int precision) { + switch (precision) { + case 0: + return timestamp(TimeUnit::SECOND, TimestampTzTimezoneString()); + case 3: + return timestamp(TimeUnit::MILLI, TimestampTzTimezoneString()); + case 6: + return timestamp(TimeUnit::MICRO, TimestampTzTimezoneString()); + case 9: + return timestamp(TimeUnit::NANO, TimestampTzTimezoneString()); + default: + return Status::NotImplemented("Unrecognized timestamp precision (", precision, ")"); + } +} + bool UnwrapUuid(const DataType& t) { if (UuidType::GetIf(t)) { return true; diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h index 28a4898a878d7..ae71ad83f7e54 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.h +++ b/cpp/src/arrow/engine/substrait/extension_types.h @@ -56,6 +56,16 @@ std::shared_ptr interval_year(); ARROW_ENGINE_EXPORT std::shared_ptr interval_day(); +/// constructs the appropriate timestamp type given the precision +/// no time zone +ARROW_ENGINE_EXPORT +Result> precision_timestamp(int precision); + +/// constructs the appropriate timestamp type given the precision +/// and the UTC time zone +ARROW_ENGINE_EXPORT +Result> precision_timestamp_tz(int precision); + /// Return true if t is Uuid, otherwise false ARROW_ENGINE_EXPORT bool UnwrapUuid(const DataType&); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 1e771ccdd25c2..3e80192377937 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -54,6 +54,7 @@ #include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/serde.h" #include "arrow/engine/substrait/test_util.h" +#include "arrow/engine/substrait/type_internal.h" #include "arrow/engine/substrait/util.h" #include "arrow/filesystem/filesystem.h" #include "arrow/filesystem/localfs.h" @@ -299,6 +300,8 @@ TEST(Substrait, SupportedTypes) { map(utf8(), field("", utf8()), false)); } +// These types don't exist in Substrait. However, we have user defined types +// defined for them and they should be able to round-trip TEST(Substrait, SupportedExtensionTypes) { ExtensionSet ext_set; @@ -308,6 +311,12 @@ TEST(Substrait, SupportedExtensionTypes) { uint16(), uint32(), uint64(), + large_utf8(), + large_binary(), + date64(), + time32(TimeUnit::SECOND), + time32(TimeUnit::MILLI), + time64(TimeUnit::NANO), }) { auto anchor = ext_set.num_types(); @@ -332,6 +341,53 @@ TEST(Substrait, SupportedExtensionTypes) { } } +// Encodings are not considered distinct types in Substrait. The encoding information +// is lost during a round-trip. +TEST(Substrait, OneWayTypes) { + ExtensionSet ext_set; + + for (auto [source_type, return_type] : + {std::pair{binary_view(), binary()}, + {utf8_view(), utf8()}, + {dictionary(int32(), utf8()), utf8()}, + {run_end_encoded(int32(), int32()), int32()}, + {dictionary(int32(), dictionary(int32(), utf8())), utf8()}}) { + ASSERT_OK_AND_ASSIGN(auto substrait_type, SerializeType(*source_type, &ext_set, {})); + + ASSERT_OK_AND_ASSIGN(auto actual_return, + DeserializeType(*substrait_type, ext_set, {})); + + EXPECT_EQ(*actual_return, *return_type); + } +} + +// Substrait does not store the time zone as part of the type. That information is stored +// on the function instead. As a result, that information is lost on a round-trip. +TEST(Substrait, TimestampTypes) { + ExtensionSet ext_set; + + for (auto time_unit : + {TimeUnit::NANO, TimeUnit::MICRO, TimeUnit::MILLI, TimeUnit::SECOND}) { + for (auto time_zone : {"UTC", "America/New_York"}) { + auto input_type = timestamp(time_unit, time_zone); + ASSERT_OK_AND_ASSIGN(auto substrait_type, SerializeType(*input_type, &ext_set, {})); + + auto expected_return = timestamp(time_unit, TimestampTzTimezoneString()); + ASSERT_OK_AND_ASSIGN(auto actual_return, + DeserializeType(*substrait_type, ext_set, {})); + + EXPECT_EQ(*actual_return, *expected_return); + } + auto input_type = timestamp(time_unit); + ASSERT_OK_AND_ASSIGN(auto substrait_type, SerializeType(*input_type, &ext_set, {})); + + ASSERT_OK_AND_ASSIGN(auto actual_return, + DeserializeType(*substrait_type, ext_set, {})); + + EXPECT_EQ(*actual_return, *input_type); + } +} + TEST(Substrait, NamedStruct) { ExtensionSet ext_set; @@ -415,26 +471,11 @@ TEST(Substrait, NoEquivalentArrowType) { TEST(Substrait, NoEquivalentSubstraitType) { for (auto type : { - date64(), - timestamp(TimeUnit::SECOND), - timestamp(TimeUnit::NANO), - timestamp(TimeUnit::MICRO, "New York"), - time32(TimeUnit::SECOND), - time32(TimeUnit::MILLI), - time64(TimeUnit::NANO), - decimal256(76, 67), - sparse_union({field("i8", int8()), field("f32", float32())}), dense_union({field("i8", int8()), field("f32", float32())}), - dictionary(int32(), utf8()), - fixed_size_list(float16(), 3), - duration(TimeUnit::MICRO), - - large_utf8(), - large_binary(), large_list(utf8()), }) { ARROW_SCOPED_TRACE(type->ToString()); @@ -563,6 +604,56 @@ TEST(Substrait, SupportedLiterals) { } } +template +void CheckArrowSpecificLiteral(ScalarType scalar) { + compute::Expression lit = compute::literal(scalar); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(lit, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + ASSERT_EQ(lit, roundtripped); +} + +TEST(Substrait, ArrowSpecificLiterals) { + CheckArrowSpecificLiteral(UInt8Scalar(7)); + CheckArrowSpecificLiteral(UInt16Scalar(7)); + CheckArrowSpecificLiteral(UInt32Scalar(7)); + CheckArrowSpecificLiteral(UInt64Scalar(7)); + CheckArrowSpecificLiteral(Date64Scalar(86400000)); + CheckArrowSpecificLiteral(Time64Scalar(7, TimeUnit::NANO)); + CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::SECOND)); + CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::MILLI)); + CheckArrowSpecificLiteral(Time32Scalar(7, TimeUnit::SECOND)); + // We serialize as a signed integer, which doesn't make sense for Time scalars but + // Arrow supports it so we might as well round-trip it. + CheckArrowSpecificLiteral(Time32Scalar(-7, TimeUnit::MILLI)); + CheckArrowSpecificLiteral(Time32Scalar(-7, TimeUnit::SECOND)); + CheckArrowSpecificLiteral(Time64Scalar(-7, TimeUnit::NANO)); + // Negative date scalars DO make sense and we should make sure they work + CheckArrowSpecificLiteral(Date64Scalar(-86400000)); + CheckArrowSpecificLiteral(HalfFloatScalar(0)); + CheckArrowSpecificLiteral(LargeStringScalar("hello")); + CheckArrowSpecificLiteral(LargeBinaryScalar("hello")); + CheckArrowSpecificLiteral(MakeNullScalar(null())); +} + +template +void CheckOneWayLiteral(SourceScalarType source, DestScalarType expected) { + compute::Expression lit = compute::literal(source); + ExtensionSet ext_set; + ASSERT_OK_AND_ASSIGN(auto serialized, SerializeExpression(lit, &ext_set)); + ASSERT_OK_AND_ASSIGN(auto roundtripped, DeserializeExpression(*serialized, ext_set)); + compute::Expression expected_lit = compute::literal(expected); + ASSERT_EQ(expected_lit, roundtripped); +} + +TEST(Substrait, OneWayLiterals) { + CheckOneWayLiteral(StringViewScalar("test"), StringScalar("test")); + CheckOneWayLiteral(BinaryViewScalar("test"), BinaryScalar("test")); + CheckOneWayLiteral(RunEndEncodedScalar(std::make_shared(7), + run_end_encoded(int16(), uint32())), + UInt32Scalar(7)); +} + TEST(Substrait, CannotDeserializeLiteral) { ExtensionSet ext_set; @@ -823,8 +914,8 @@ TEST(Substrait, Cast) { std::shared_ptr cast_opts = std::dynamic_pointer_cast(call_opts); ASSERT_TRUE(!!cast_opts); - // It is unclear whether a Substrait cast should be safe or not. In the meantime we are - // assuming it is unsafe based on the behavior of many SQL engines. + // It is unclear whether a Substrait cast should be safe or not. In the meantime we + // are assuming it is unsafe based on the behavior of many SQL engines. ASSERT_TRUE(cast_opts->allow_int_overflow); ASSERT_TRUE(cast_opts->allow_float_truncate); ASSERT_TRUE(cast_opts->allow_decimal_truncate); diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index f4a2e6800eb49..5e7e364fe00c5 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -34,8 +34,13 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/type_fwd.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/unreachable.h" #include "arrow/visit_type_inline.h" +using arrow::internal::checked_cast; + namespace arrow { namespace engine { @@ -121,11 +126,24 @@ Result, bool>> FromProto( case substrait::Type::kBinary: return FromProtoImpl(type.binary()); + ARROW_SUPPRESS_DEPRECATION_WARNING case substrait::Type::kTimestamp: return FromProtoImpl(type.timestamp(), TimeUnit::MICRO); case substrait::Type::kTimestampTz: return FromProtoImpl(type.timestamp_tz(), TimeUnit::MICRO, TimestampTzTimezoneString()); + ARROW_UNSUPPRESS_DEPRECATION_WARNING + case substrait::Type::kPrecisionTimestamp: { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ts_type, + precision_timestamp(type.precision_timestamp().precision())); + return std::make_pair(ts_type, IsNullable(type.precision_timestamp())); + } + case substrait::Type::kPrecisionTimestampTz: { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr ts_type, + precision_timestamp_tz(type.precision_timestamp_tz().precision())); + return std::make_pair(ts_type, IsNullable(type.precision_timestamp_tz())); + } case substrait::Type::kDate: return FromProtoImpl(type.date()); @@ -263,7 +281,14 @@ struct DataTypeToProtoImpl { return SetWith(&substrait::Type::set_allocated_binary); } - Status Visit(const BinaryViewType& t) { return NotImplemented(t); } + // From Substrait's point of view the view types are encodings, and an execution detail, + // and not distinct from the non-view type. + Status Visit(const BinaryViewType& t) { + return SetWith(&substrait::Type::set_allocated_binary); + } + Status Visit(const StringViewType& t) { + return SetWith(&substrait::Type::set_allocated_string); + } Status Visit(const FixedSizeBinaryType& t) { SetWithThen(&substrait::Type::set_allocated_fixed_binary)->set_length(t.byte_width()); @@ -273,25 +298,50 @@ struct DataTypeToProtoImpl { Status Visit(const Date32Type& t) { return SetWith(&substrait::Type::set_allocated_date); } - Status Visit(const Date64Type& t) { return NotImplemented(t); } + Status Visit(const Date64Type& t) { return EncodeUserDefined(t); } - Status Visit(const TimestampType& t) { - if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); + template + Status VisitTimestamp(const TimestampType& t, + void (substrait::Type::*set_allocated_sub)(Sub*)) { + auto ts = SetWithThen(set_allocated_sub); + switch (t.unit()) { + case TimeUnit::SECOND: + ts->set_precision(0); + break; + case TimeUnit::MILLI: + ts->set_precision(3); + break; + case TimeUnit::MICRO: + ts->set_precision(6); + break; + case TimeUnit::NANO: + ts->set_precision(9); + break; + default: + return NotImplemented(t); + } + return Status::OK(); + } + Status Visit(const TimestampType& t) { if (t.timezone() == "") { - return SetWith(&substrait::Type::set_allocated_timestamp); - } - if (t.timezone() == TimestampTzTimezoneString()) { - return SetWith(&substrait::Type::set_allocated_timestamp_tz); + return VisitTimestamp(t, &substrait::Type::set_allocated_precision_timestamp); + } else { + // Note: The timezone information is discarded here. In Substrait the time zone + // information is part of the function and not part of the type. For example, to + // convert a timestamp to a string, the time zone is passed as an argument to the + // function. + return VisitTimestamp(t, &substrait::Type::set_allocated_precision_timestamp_tz); } - - return NotImplemented(t); } - Status Visit(const Time32Type& t) { return NotImplemented(t); } + Status Visit(const Time32Type& t) { return EncodeUserDefined(t); } Status Visit(const Time64Type& t) { - if (t.unit() != TimeUnit::MICRO) return NotImplemented(t); - return SetWith(&substrait::Type::set_allocated_time); + if (t.unit() == TimeUnit::MICRO) { + return SetWith(&substrait::Type::set_allocated_time); + } else { + return EncodeUserDefined(t); + } } Status Visit(const MonthIntervalType& t) { return EncodeUserDefined(t); } @@ -303,6 +353,7 @@ struct DataTypeToProtoImpl { dec->set_scale(t.scale()); return Status::OK(); } + // TODO(GH-40740) support parameterized UDT Status Visit(const Decimal256Type& t) { return NotImplemented(t); } Status Visit(const ListType& t) { @@ -313,8 +364,16 @@ struct DataTypeToProtoImpl { return Status::OK(); } - Status Visit(const ListViewType& t) { return NotImplemented(t); } + // From Substrait's point of view this is an encoding, and an implementation detail, + // and not distinct from the list type. + Status Visit(const ListViewType& t) { + ARROW_ASSIGN_OR_RAISE(auto type, ToProto(*t.value_type(), t.value_field()->nullable(), + ext_set_, conversion_options_)); + SetWithThen(&substrait::Type::set_allocated_list)->set_allocated_type(type.release()); + return Status::OK(); + } + // TODO(GH-40740) support parameterized UDT Status Visit(const LargeListViewType& t) { return NotImplemented(t); } Status Visit(const StructType& t) { @@ -335,8 +394,9 @@ struct DataTypeToProtoImpl { Status Visit(const SparseUnionType& t) { return NotImplemented(t); } Status Visit(const DenseUnionType& t) { return NotImplemented(t); } - Status Visit(const DictionaryType& t) { return NotImplemented(t); } - Status Visit(const RunEndEncodedType& t) { return NotImplemented(t); } + // The caller should have unwrapped the dictionary / RLE type + Status Visit(const DictionaryType& t) { Unreachable(); } + Status Visit(const RunEndEncodedType& t) { Unreachable(); } Status Visit(const MapType& t) { // FIXME assert default field names; custom ones won't roundtrip @@ -379,10 +439,13 @@ struct DataTypeToProtoImpl { return NotImplemented(t); } + // TODO(GH-40740) support parameterized UDT Status Visit(const FixedSizeListType& t) { return NotImplemented(t); } + // TODO(GH-40740) support parameterized UDT Status Visit(const DurationType& t) { return NotImplemented(t); } - Status Visit(const LargeStringType& t) { return NotImplemented(t); } - Status Visit(const LargeBinaryType& t) { return NotImplemented(t); } + Status Visit(const LargeStringType& t) { return EncodeUserDefined(t); } + Status Visit(const LargeBinaryType& t) { return EncodeUserDefined(t); } + // TODO(GH-40740) support parameterized UDT Status Visit(const LargeListType& t) { return NotImplemented(t); } Status Visit(const MonthDayNanoIntervalType& t) { return EncodeUserDefined(t); } @@ -429,6 +492,17 @@ struct DataTypeToProtoImpl { Result> ToProto( const DataType& type, bool nullable, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + // From Substrait's perspective the "dictionary type" is just an encoding. As a result, + // we lose that information on conversion and just convert the value type. + if (type.id() == Type::DICTIONARY) { + const auto& dict_type = checked_cast(type); + return ToProto(*dict_type.value_type(), nullable, ext_set, conversion_options); + } + // Ditto for REE + if (type.id() == Type::RUN_END_ENCODED) { + const auto& ree_type = checked_cast(type); + return ToProto(*ree_type.value_type(), nullable, ext_set, conversion_options); + } auto out = std::make_unique(); RETURN_NOT_OK( (DataTypeToProtoImpl{out.get(), nullable, ext_set, conversion_options})(type)); diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 5128ec44bff77..bef2a6c7e1823 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -70,7 +70,7 @@ ARROW_ENGINE_EXPORT const std::string& default_extension_types_uri(); // TODO(ARROW-18145) Populate these from cmake files constexpr uint32_t kSubstraitMajorVersion = 0; -constexpr uint32_t kSubstraitMinorVersion = 27; +constexpr uint32_t kSubstraitMinorVersion = 44; constexpr uint32_t kSubstraitPatchVersion = 0; constexpr uint32_t kSubstraitMinimumMajorVersion = 0; diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 4093b0ec43efd..4983f3cee2c2d 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -103,8 +103,8 @@ ARROW_RE2_BUILD_VERSION=2022-06-01 ARROW_RE2_BUILD_SHA256_CHECKSUM=f89c61410a072e5cbcf8c27e3a778da7d6fd2f2b5b1445cd4f4508bee946ab0f ARROW_SNAPPY_BUILD_VERSION=1.1.10 ARROW_SNAPPY_BUILD_SHA256_CHECKSUM=49d831bffcc5f3d01482340fe5af59852ca2fe76c3e05df0e67203ebbe0f1d90 -ARROW_SUBSTRAIT_BUILD_VERSION=v0.27.0 -ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=4ed375f69d972a57fdc5ec406c17003a111831d8640d3f1733eccd4b3ff45628 +ARROW_SUBSTRAIT_BUILD_VERSION=v0.44.0 +ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=f989a862f694e7dbb695925ddb7c4ce06aa6c51aca945105c075139aed7e55a2 ARROW_S2N_TLS_BUILD_VERSION=v1.3.35 ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d ARROW_THRIFT_BUILD_VERSION=0.16.0 diff --git a/format/substrait/extension_types.yaml b/format/substrait/extension_types.yaml index 888d6c94c8182..0073da1acc1ed 100644 --- a/format/substrait/extension_types.yaml +++ b/format/substrait/extension_types.yaml @@ -35,36 +35,52 @@ # - interval # - arrow::ExtensionTypes # -# Note that not all of these are currently implemented. In particular, these -# extension types are currently not parameterizable in Substrait, which means -# among other things that we can't declare dictionary type here at all since -# we'd have to declare a different dictionary type for all encoded types -# (but that is an infinite space). Similarly, we would have to declare a -# timestamp variation for all possible timezone strings. +# These types fall into several categories of behavior: -type_variations: - - parent: i8 - name: u8 - description: an unsigned 8 bit integer - functions: SEPARATE - - parent: i16 - name: u16 - description: an unsigned 16 bit integer - functions: SEPARATE - - parent: i32 - name: u32 - description: an unsigned 32 bit integer - functions: SEPARATE - - parent: i64 - name: u64 - description: an unsigned 64 bit integer - functions: SEPARATE +# Certain Arrow data types are, from Substrait's point of view, encodings. +# These include dictionary, the view types (e.g. binary view, list view), +# and REE. +# +# These types are not logically distinct from the type they are encoding. +# Specifically, the types meet the following criteria: +# * There is no value in the decoded type that cannot be represented +# as a value in the encoded type and vice versa. +# * Functions have the same meaning when applied to the encoded type +# +# Note: if two types have a different range (e.g. string and large_string) then +# they do not satisfy the above criteria and are not encodings. +# +# These types will never have a Substrait equivalent. In the Substrait point +# of view these are execution details. + +# The following types are encodings: + +# binary_view +# list_view +# dictionary +# ree + +# Arrow-cpp's Substrait serde does not yet handle parameterized UDTs. This means +# the following types are not yet supported but may be supported in the future. +# We define them below in case other implementations support them in the meantime. - - parent: i16 - name: fp16 - description: a 16 bit floating point number - functions: SEPARATE +# decimal256 +# large_list +# fixed_size_list +# duration +# Other types are not encodings, but are not first-class in Substrait. These +# types are often similar to existing Substrait types but define a different range +# of values. For example, unsigned integer types are very similar to their integer +# counterparts, but have a different range of values. These types are defined here +# as extension types. +# +# A full description of the types, along with their specified range, can be found +# in Schema.fbs +# +# Consumers should take care when supporting the below types. Should Substrait decide +# later to support these types, the consumer will need to make sure to continue supporting +# the extension type names as aliases for proper backwards compatibility. types: - name: "null" structure: {} @@ -80,3 +96,75 @@ types: months: i32 days: i32 nanos: i64 + # All unsigned integer literals are encoded as user defined literals with + # a google.protobuf.UInt64Value message. + - name: u8 + structure: {} + - name: u16 + structure: {} + - name: u32 + structure: {} + - name: u64 + structure: {} + # fp16 literals are encoded as user defined literals with + # a google.protobuf.UInt32Value message where the lower 16 bits are + # the fp16 value. + - name: fp16 + structure: {} + # 64-bit integers are big. Even though date64 stores ms and not days it + # can still represent about 50x more dates than date32. Since it has a + # different range of values, it is an extension type. + # + # date64 literals are encoded as user defined literals with + # a google.protobuf.Int64Value message. + - name: date_millis + structure: {} + # time literals are encoded as user defined literals with + # a google.protobuf.Int32Value message (for time_seconds/time_millis) + # or a google.protobuf.Int64Value message (for time_nanos). + - name: time_seconds + structure: {} + - name: time_millis + structure: {} + - name: time_nanos + structure: {} + # Large string literals are encoded using a + # google.protobuf.StringValue message. + - name: large_string + structure: {} + # Large binary literals are encoded using a + # google.protobuf.BytesValue message. + - name: large_binary + structure: {} + # We cannot generate these today because they are parameterized UDTs and + # substrait-cpp does not yet support parameterized UDTs. + - name: decimal256 + structure: {} + parameters: + - name: precision + type: integer + min: 0 + max: 76 + - name: scale + type: integer + min: 0 + max: 76 + - name: large_list + structure: {} + parameters: + - name: value_type + type: dataType + - name: fixed_size_list + structure: {} + parameters: + - name: value_type + type: dataType + - name: dimension + type: integer + min: 0 + - name: duration + structure: {} + parameters: + - name: unit + type: string + diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index d4fbfb7406838..40700e4741321 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -944,6 +944,54 @@ def test_serializing_expressions(expr): assert "test_expr" in returned.expressions +def test_arrow_specific_types(): + fields = { + "time_seconds": (pa.time32("s"), 0), + "time_millis": (pa.time32("ms"), 0), + "time_nanos": (pa.time64("ns"), 0), + "date_millis": (pa.date64(), 0), + "large_string": (pa.large_string(), "test_string"), + "large_binary": (pa.large_binary(), b"test_string"), + } + schema = pa.schema([pa.field(name, typ) for name, (typ, _) in fields.items()]) + + def check_round_trip(expr): + buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert schema == returned.schema + + for name, (typ, val) in fields.items(): + check_round_trip(pc.field(name) == pa.scalar(val, type=typ)) + + +def test_arrow_one_way_types(): + schema = pa.schema( + [ + pa.field("binary_view", pa.binary_view()), + pa.field("string_view", pa.string_view()), + pa.field("dictionary", pa.dictionary(pa.int32(), pa.string())), + pa.field("ree", pa.run_end_encoded(pa.int32(), pa.string())), + ] + ) + alt_schema = pa.schema( + [ + pa.field("binary_view", pa.binary()), + pa.field("string_view", pa.string()), + pa.field("dictionary", pa.string()), + pa.field("ree", pa.string()) + ] + ) + + def check_one_way(field): + expr = pc.is_null(pc.field(field.name)) + buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema) + returned = pa.substrait.deserialize_expressions(buf) + assert alt_schema == returned.schema + + for field in schema: + check_one_way(field) + + def test_invalid_expression_ser_des(): schema = pa.schema([ pa.field("x", pa.int32()),