diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index a7721f339..f68732fb1 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -1010,12 +1010,6 @@ fn cast_string_to_int_with_range_check( } } -#[derive(PartialEq)] -enum State { - ParseSignAndDigits, - ParseFractionalDigits, -} - /// Equivalent to /// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) /// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) @@ -1036,15 +1030,13 @@ fn do_cast_string_to_int< let mut negative = false; let radix = T::from(10); let stop_value = min_value / radix; - let mut state = State::ParseSignAndDigits; - let mut parsed_sign = false; + let mut parse_sign_and_digits = true; for (i, ch) in trimmed_str.char_indices() { - if state == State::ParseSignAndDigits { - if !parsed_sign { + if parse_sign_and_digits { + if i == 0 { negative = ch == '-'; let positive = ch == '+'; - parsed_sign = true; if negative || positive { if i + 1 == len { // input string is just "+" or "-" @@ -1058,7 +1050,7 @@ fn do_cast_string_to_int< if ch == '.' { if eval_mode == EvalMode::Legacy { // truncate decimal in legacy mode - state = State::ParseFractionalDigits; + parse_sign_and_digits = false; continue; } else { return none_or_err(eval_mode, type_name, str); @@ -1090,9 +1082,8 @@ fn do_cast_string_to_int< return none_or_err(eval_mode, type_name, str); } } - } - - if state == State::ParseFractionalDigits { + } else { + // make sure fractional digits are valid digits but ignore them if !ch.is_ascii_digit() { return none_or_err(eval_mode, type_name, str); }