From 091441dbaca906df012db22a6bb6ef1bc1a6feca Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Fri, 29 Nov 2024 21:32:03 -0800 Subject: [PATCH 1/2] Move `{widening, carrying}_mul` to an intrinsic with fallback MIR Including implementing it for `u128`, so it can be defined in `uint_impl!`. This way it works for all backends, including CTFE. --- .../rustc_hir_analysis/src/check/intrinsic.rs | 5 + compiler/rustc_span/src/symbol.rs | 1 + library/core/src/intrinsics/fallback.rs | 111 ++++++++++++++ library/core/src/intrinsics/mod.rs | 29 ++++ library/core/src/lib.rs | 1 + library/core/src/num/mod.rs | 135 ------------------ library/core/src/num/uint_macros.rs | 116 +++++++++++++++ library/core/tests/intrinsics.rs | 58 ++++++++ library/core/tests/lib.rs | 1 + 9 files changed, 322 insertions(+), 135 deletions(-) create mode 100644 library/core/src/intrinsics/fallback.rs diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 394794019109d..427ef141c72a1 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -94,6 +94,7 @@ pub fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) - | sym::add_with_overflow | sym::sub_with_overflow | sym::mul_with_overflow + | sym::carrying_mul_add | sym::wrapping_add | sym::wrapping_sub | sym::wrapping_mul @@ -436,6 +437,10 @@ pub fn check_intrinsic_type( (1, 0, vec![param(0), param(0)], Ty::new_tup(tcx, &[param(0), tcx.types.bool])) } + sym::carrying_mul_add => { + (2, 0, vec![param(0); 4], Ty::new_tup(tcx, &[param(1), param(0)])) + } + sym::ptr_guaranteed_cmp => ( 1, 0, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index a7ff0576f92d6..ad4463864f3a0 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -555,6 +555,7 @@ symbols! { call_ref_future, caller_location, capture_disjoint_fields, + carrying_mul_add, catch_unwind, cause, cdylib, diff --git a/library/core/src/intrinsics/fallback.rs b/library/core/src/intrinsics/fallback.rs new file mode 100644 index 0000000000000..d87331f726208 --- /dev/null +++ b/library/core/src/intrinsics/fallback.rs @@ -0,0 +1,111 @@ +#![unstable( + feature = "core_intrinsics_fallbacks", + reason = "The fallbacks will never be stable, as they exist only to be called \ + by the fallback MIR, but they're exported so they can be tested on \ + platforms where the fallback MIR isn't actually used", + issue = "none" +)] +#![allow(missing_docs)] + +#[const_trait] +pub trait CarryingMulAdd: Copy + 'static { + type Unsigned: Copy + 'static; + fn carrying_mul_add( + self, + multiplicand: Self, + addend: Self, + carry: Self, + ) -> (Self::Unsigned, Self); +} + +macro_rules! impl_carrying_mul_add_by_widening { + ($($t:ident $u:ident $w:ident,)+) => {$( + #[rustc_const_unstable(feature = "core_intrinsics_fallbacks", issue = "none")] + impl const CarryingMulAdd for $t { + type Unsigned = $u; + #[inline] + fn carrying_mul_add(self, a: Self, b: Self, c: Self) -> ($u, $t) { + let wide = (self as $w) * (a as $w) + (b as $w) + (c as $w); + (wide as _, (wide >> Self::BITS) as _) + } + } + )+}; +} +impl_carrying_mul_add_by_widening! 
+{
+    u8 u8 u16,
+    u16 u16 u32,
+    u32 u32 u64,
+    u64 u64 u128,
+    usize usize UDoubleSize,
+    i8 u8 i16,
+    i16 u16 i32,
+    i32 u32 i64,
+    i64 u64 i128,
+    isize usize IDoubleSize,
+}
+
+#[cfg(target_pointer_width = "16")]
+type UDoubleSize = u32;
+#[cfg(target_pointer_width = "32")]
+type UDoubleSize = u64;
+#[cfg(target_pointer_width = "64")]
+type UDoubleSize = u128;
+
+#[cfg(target_pointer_width = "16")]
+type IDoubleSize = i32;
+#[cfg(target_pointer_width = "32")]
+type IDoubleSize = i64;
+#[cfg(target_pointer_width = "64")]
+type IDoubleSize = i128;
+
+#[inline]
+const fn wide_mul_u128(a: u128, b: u128) -> (u128, u128) {
+    #[inline]
+    const fn to_low_high(x: u128) -> [u128; 2] {
+        const MASK: u128 = u64::MAX as _;
+        [x & MASK, x >> 64]
+    }
+    #[inline]
+    const fn from_low_high(x: [u128; 2]) -> u128 {
+        x[0] | (x[1] << 64)
+    }
+    #[inline]
+    const fn scalar_mul(low_high: [u128; 2], k: u128) -> [u128; 3] {
+        let [x, c] = to_low_high(k * low_high[0]);
+        let [y, z] = to_low_high(k * low_high[1] + c);
+        [x, y, z]
+    }
+    let a = to_low_high(a);
+    let b = to_low_high(b);
+    let low = scalar_mul(a, b[0]);
+    let high = scalar_mul(a, b[1]);
+    let r0 = low[0];
+    let [r1, c] = to_low_high(low[1] + high[0]);
+    let [r2, c] = to_low_high(low[2] + high[1] + c);
+    let r3 = high[2] + c;
+    (from_low_high([r0, r1]), from_low_high([r2, r3]))
+}
+
+#[rustc_const_unstable(feature = "core_intrinsics_fallbacks", issue = "none")]
+impl const CarryingMulAdd for u128 {
+    type Unsigned = u128;
+    #[inline]
+    fn carrying_mul_add(self, b: u128, c: u128, d: u128) -> (u128, u128) {
+        let (low, mut high) = wide_mul_u128(self, b);
+        let (low, carry) = u128::overflowing_add(low, c);
+        high += carry as u128;
+        let (low, carry) = u128::overflowing_add(low, d);
+        high += carry as u128;
+        (low, high)
+    }
+}
+
+#[rustc_const_unstable(feature = "core_intrinsics_fallbacks", issue = "none")]
+impl const CarryingMulAdd for i128 {
+    type Unsigned = u128;
+    #[inline]
+    fn carrying_mul_add(self, b: i128, c: i128, d: i128) -> (u128, i128) {
+        let (low, high) = wide_mul_u128(self as u128, b as u128);
+        let mut high = high as i128;
+        high = high.wrapping_add((self >> 127) * b);
+        high = high.wrapping_add(self * (b >> 127));
+        let (low, carry) = u128::overflowing_add(low, c as u128);
+        high = high.wrapping_add((carry as i128) + (c >> 127));
+        let (low, carry) = u128::overflowing_add(low, d as u128);
+        high = high.wrapping_add((carry as i128) + (d >> 127));
+        (low, high)
+    }
+}
diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs
index 3e53c0497cc78..53b3293175aeb 100644
--- a/library/core/src/intrinsics/mod.rs
+++ b/library/core/src/intrinsics/mod.rs
@@ -68,6 +68,7 @@ use crate::marker::{DiscriminantKind, Tuple};
 use crate::mem::SizedTypeProperties;
 use crate::{ptr, ub_checks};
 
+pub mod fallback;
 pub mod mir;
 pub mod simd;
 
@@ -3305,6 +3306,34 @@ pub const fn mul_with_overflow<T: Copy>(_x: T, _y: T) -> (T, bool) {
     unimplemented!()
 }
 
+/// Performs full-width multiplication and addition with a carry:
+/// `multiplier * multiplicand + addend + carry`.
+///
+/// This is possible without any overflow. For `uN`:
+///    MAX * MAX + MAX + MAX
+/// => (2ⁿ-1) × (2ⁿ-1) + (2ⁿ-1) + (2ⁿ-1)
+/// => (2²ⁿ - 2ⁿ⁺¹ + 1) + (2ⁿ⁺¹ - 2)
+/// => 2²ⁿ - 1
+///
+/// For `iN`, the upper bound is MIN * MIN + MAX + MAX => 2²ⁿ⁻² + 2ⁿ - 2,
+/// and the lower bound is MAX * MIN + MIN + MIN => -2²ⁿ⁻² - 2ⁿ + 2ⁿ⁻¹.
+///
+/// This works for both signed and unsigned integers, since the fallback
+/// trait is implemented for all of them. The non-intrinsic versions of
+/// this operation are the (still unstable) `carrying_mul` and
+/// `widening_mul` methods on the unsigned integer primitives.
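+///
+/// As a concrete instance of the bound above (shown as prose rather than a
+/// doctest, since the intrinsic itself is unstable): for `u8`,
+/// `carrying_mul_add(7, 13, 5, 2)` computes 7 × 13 + 5 + 2 = 98 and yields
+/// `(98, 0)` as `(low, high)`, while `carrying_mul_add(255, 255, 255, 255)`
+/// computes 255 × 255 + 255 + 255 = 65535 = 0xFFFF and yields `(255, 255)`,
+/// exactly filling both halves without overflowing.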
+#[unstable(feature = "core_intrinsics", issue = "none")]
+#[rustc_const_unstable(feature = "const_carrying_mul_add", issue = "85532")]
+#[rustc_nounwind]
+#[cfg_attr(not(bootstrap), rustc_intrinsic)]
+#[cfg_attr(not(bootstrap), miri::intrinsic_fallback_is_spec)]
+pub const fn carrying_mul_add<T: ~const fallback::CarryingMulAdd<Unsigned = U>, U>(
+    multiplier: T,
+    multiplicand: T,
+    addend: T,
+    carry: T,
+) -> (U, T) {
+    multiplier.carrying_mul_add(multiplicand, addend, carry)
+}
+
 /// Performs an exact division, resulting in undefined behavior where
 /// `x % y != 0` or `y == 0` or `x == T::MIN && y == -1`
 ///
diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs
index 18bd9bb811836..a23bed92e9063 100644
--- a/library/core/src/lib.rs
+++ b/library/core/src/lib.rs
@@ -111,6 +111,7 @@
 #![cfg_attr(bootstrap, feature(do_not_recommend))]
 #![feature(array_ptr_get)]
 #![feature(asm_experimental_arch)]
+#![feature(const_carrying_mul_add)]
 #![feature(const_eval_select)]
 #![feature(const_typed_swap)]
 #![feature(core_intrinsics)]
diff --git a/library/core/src/num/mod.rs b/library/core/src/num/mod.rs
index 357a85ae843cd..1f5b8ce6c6e22 100644
--- a/library/core/src/num/mod.rs
+++ b/library/core/src/num/mod.rs
@@ -228,134 +228,6 @@ macro_rules! midpoint_impl {
     };
 }
 
-macro_rules! widening_impl {
-    ($SelfT:ty, $WideT:ty, $BITS:literal, unsigned) => {
-        /// Calculates the complete product `self * rhs` without the possibility to overflow.
-        ///
-        /// This returns the low-order (wrapping) bits and the high-order (overflow) bits
-        /// of the result as two separate values, in that order.
-        ///
-        /// If you also need to add a carry to the wide result, then you want
-        /// [`Self::carrying_mul`] instead.
-        ///
-        /// # Examples
-        ///
-        /// Basic usage:
-        ///
-        /// Please note that this example is shared between integer types.
-        /// Which explains why `u32` is used here.
-        ///
-        /// ```
-        /// #![feature(bigint_helper_methods)]
-        /// assert_eq!(5u32.widening_mul(2), (10, 0));
-        /// assert_eq!(1_000_000_000u32.widening_mul(10), (1410065408, 2));
-        /// ```
-        #[unstable(feature = "bigint_helper_methods", issue = "85532")]
-        #[must_use = "this returns the result of the operation, \
-                      without modifying the original"]
-        #[inline]
-        pub const fn widening_mul(self, rhs: Self) -> (Self, Self) {
-            // note: longer-term this should be done via an intrinsic,
-            //   but for now we can deal without an impl for u128/i128
-            // SAFETY: overflow will be contained within the wider types
-            let wide = unsafe { (self as $WideT).unchecked_mul(rhs as $WideT) };
-            (wide as $SelfT, (wide >> $BITS) as $SelfT)
-        }
-
-        /// Calculates the "full multiplication" `self * rhs + carry`
-        /// without the possibility to overflow.
-        ///
-        /// This returns the low-order (wrapping) bits and the high-order (overflow) bits
-        /// of the result as two separate values, in that order.
-        ///
-        /// Performs "long multiplication" which takes in an extra amount to add, and may return an
-        /// additional amount of overflow. This allows for chaining together multiple
-        /// multiplications to create "big integers" which represent larger values.
-        ///
-        /// If you don't need the `carry`, then you can use [`Self::widening_mul`] instead.
-        ///
-        /// # Examples
-        ///
-        /// Basic usage:
-        ///
-        /// Please note that this example is shared between integer types.
-        /// Which explains why `u32` is used here.
-        ///
-        /// ```
-        /// #![feature(bigint_helper_methods)]
-        /// assert_eq!(5u32.carrying_mul(2, 0), (10, 0));
-        /// assert_eq!(5u32.carrying_mul(2, 10), (20, 0));
-        /// assert_eq!(1_000_000_000u32.carrying_mul(10, 0), (1410065408, 2));
-        /// assert_eq!(1_000_000_000u32.carrying_mul(10, 10), (1410065418, 2));
-        #[doc = concat!("assert_eq!(",
-            stringify!($SelfT), "::MAX.carrying_mul(", stringify!($SelfT), "::MAX, ", stringify!($SelfT), "::MAX), ",
-            "(0, ", stringify!($SelfT), "::MAX));"
-        )]
-        /// ```
-        ///
-        /// This is the core operation needed for scalar multiplication when
-        /// implementing it for wider-than-native types.
-        ///
-        /// ```
-        /// #![feature(bigint_helper_methods)]
-        /// fn scalar_mul_eq(little_endian_digits: &mut Vec<u16>, multiplicand: u16) {
-        ///     let mut carry = 0;
-        ///     for d in little_endian_digits.iter_mut() {
-        ///         (*d, carry) = d.carrying_mul(multiplicand, carry);
-        ///     }
-        ///     if carry != 0 {
-        ///         little_endian_digits.push(carry);
-        ///     }
-        /// }
-        ///
-        /// let mut v = vec![10, 20];
-        /// scalar_mul_eq(&mut v, 3);
-        /// assert_eq!(v, [30, 60]);
-        ///
-        /// assert_eq!(0x87654321_u64 * 0xFEED, 0x86D3D159E38D);
-        /// let mut v = vec![0x4321, 0x8765];
-        /// scalar_mul_eq(&mut v, 0xFEED);
-        /// assert_eq!(v, [0xE38D, 0xD159, 0x86D3]);
-        /// ```
-        ///
-        /// If `carry` is zero, this is similar to [`overflowing_mul`](Self::overflowing_mul),
-        /// except that it gives the value of the overflow instead of just whether one happened:
-        ///
-        /// ```
-        /// #![feature(bigint_helper_methods)]
-        /// let r = u8::carrying_mul(7, 13, 0);
-        /// assert_eq!((r.0, r.1 != 0), u8::overflowing_mul(7, 13));
-        /// let r = u8::carrying_mul(13, 42, 0);
-        /// assert_eq!((r.0, r.1 != 0), u8::overflowing_mul(13, 42));
-        /// ```
-        ///
-        /// The value of the first field in the returned tuple matches what you'd get
-        /// by combining the [`wrapping_mul`](Self::wrapping_mul) and
-        /// [`wrapping_add`](Self::wrapping_add) methods:
-        ///
-        /// ```
-        /// #![feature(bigint_helper_methods)]
-        /// assert_eq!(
-        ///     789_u16.carrying_mul(456, 123).0,
-        ///     789_u16.wrapping_mul(456).wrapping_add(123),
-        /// );
-        /// ```
-        #[unstable(feature = "bigint_helper_methods", issue = "85532")]
-        #[must_use = "this returns the result of the operation, \
-                      without modifying the original"]
-        #[inline]
-        pub const fn carrying_mul(self, rhs: Self, carry: Self) -> (Self, Self) {
-            // note: longer-term this should be done via an intrinsic,
-            //   but for now we can deal without an impl for u128/i128
-            // SAFETY: overflow will be contained within the wider types
-            let wide = unsafe {
-                (self as $WideT).unchecked_mul(rhs as $WideT).unchecked_add(carry as $WideT)
-            };
-            (wide as $SelfT, (wide >> $BITS) as $SelfT)
-        }
-    };
-}
-
 impl i8 {
     int_impl! {
         Self = i8,
@@ -576,7 +448,6 @@ impl u8 {
         from_xe_bytes_doc = u8_xe_bytes_doc!(),
         bound_condition = "",
     }
-    widening_impl! { u8, u16, 8, unsigned }
     midpoint_impl! { u8, u16, unsigned }
 
     /// Checks if the value is within the ASCII range.
@@ -1192,7 +1063,6 @@ impl u16 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
-    widening_impl! { u16, u32, 16, unsigned }
    midpoint_impl! { u16, u32, unsigned }
 
     /// Checks if the value is a Unicode surrogate code point, which are disallowed values for [`char`].
@@ -1240,7 +1110,6 @@ impl u32 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
-    widening_impl! { u32, u64, 32, unsigned }
     midpoint_impl! { u32, u64, unsigned }
 }
 
@@ -1264,7 +1133,6 @@ impl u64 {
         from_xe_bytes_doc = "",
         bound_condition = "",
     }
-    widening_impl! { u64, u128, 64, unsigned }
     midpoint_impl! { u64, u128, unsigned }
 }
 
@@ -1314,7 +1182,6 @@ impl usize {
         from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(),
         bound_condition = " on 16-bit targets",
     }
-    widening_impl! { usize, u32, 16, unsigned }
     midpoint_impl! { usize, u32, unsigned }
 }
 
@@ -1339,7 +1206,6 @@ impl usize {
         from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(),
         bound_condition = " on 32-bit targets",
     }
-    widening_impl! { usize, u64, 32, unsigned }
     midpoint_impl! { usize, u64, unsigned }
 }
 
@@ -1364,7 +1230,6 @@ impl usize {
         from_xe_bytes_doc = usize_isize_from_xe_bytes_doc!(),
         bound_condition = " on 64-bit targets",
     }
-    widening_impl! { usize, u128, 64, unsigned }
     midpoint_impl! { usize, u128, unsigned }
 }
 
diff --git a/library/core/src/num/uint_macros.rs b/library/core/src/num/uint_macros.rs
index 151d879fabeed..4a5fdbfb0ea2c 100644
--- a/library/core/src/num/uint_macros.rs
+++ b/library/core/src/num/uint_macros.rs
@@ -3347,6 +3347,122 @@ macro_rules! uint_impl {
             unsafe { mem::transmute(bytes) }
         }
 
+        /// Calculates the complete product `self * rhs` without the possibility to overflow.
+        ///
+        /// This returns the low-order (wrapping) bits and the high-order (overflow) bits
+        /// of the result as two separate values, in that order.
+        ///
+        /// If you also need to add a carry to the wide result, then you want
+        /// [`Self::carrying_mul`] instead.
+        ///
+        /// # Examples
+        ///
+        /// Basic usage:
+        ///
+        /// Please note that this example is shared between integer types,
+        /// which explains why `u32` is used here.
+        ///
+        /// ```
+        /// #![feature(bigint_helper_methods)]
+        /// assert_eq!(5u32.widening_mul(2), (10, 0));
+        /// assert_eq!(1_000_000_000u32.widening_mul(10), (1410065408, 2));
+        /// ```
+        #[unstable(feature = "bigint_helper_methods", issue = "85532")]
+        #[rustc_const_unstable(feature = "bigint_helper_methods", issue = "85532")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn widening_mul(self, rhs: Self) -> (Self, Self) {
+            Self::carrying_mul(self, rhs, 0)
+        }
+
+        /// Calculates the "full multiplication" `self * rhs + carry`
+        /// without the possibility to overflow.
+        ///
+        /// This returns the low-order (wrapping) bits and the high-order (overflow) bits
+        /// of the result as two separate values, in that order.
+        ///
+        /// Performs "long multiplication" which takes in an extra amount to add, and may return an
+        /// additional amount of overflow. This allows for chaining together multiple
+        /// multiplications to create "big integers" which represent larger values.
+        ///
+        /// If you don't need the `carry`, then you can use [`Self::widening_mul`] instead.
+        ///
+        /// # Examples
+        ///
+        /// Basic usage:
+        ///
+        /// Please note that this example is shared between integer types,
+        /// which explains why `u32` is used here.
+        ///
+        /// ```
+        /// #![feature(bigint_helper_methods)]
+        /// assert_eq!(5u32.carrying_mul(2, 0), (10, 0));
+        /// assert_eq!(5u32.carrying_mul(2, 10), (20, 0));
+        /// assert_eq!(1_000_000_000u32.carrying_mul(10, 0), (1410065408, 2));
+        /// assert_eq!(1_000_000_000u32.carrying_mul(10, 10), (1410065418, 2));
+        #[doc = concat!("assert_eq!(",
+            stringify!($SelfT), "::MAX.carrying_mul(", stringify!($SelfT), "::MAX, ", stringify!($SelfT), "::MAX), ",
+            "(0, ", stringify!($SelfT), "::MAX));"
+        )]
+        /// ```
+        ///
+        /// This is the core operation needed for scalar multiplication when
+        /// implementing it for wider-than-native types.
+        ///
+        /// ```
+        /// #![feature(bigint_helper_methods)]
+        /// fn scalar_mul_eq(little_endian_digits: &mut Vec<u16>, multiplicand: u16) {
+        ///     let mut carry = 0;
+        ///     for d in little_endian_digits.iter_mut() {
+        ///         (*d, carry) = d.carrying_mul(multiplicand, carry);
+        ///     }
+        ///     if carry != 0 {
+        ///         little_endian_digits.push(carry);
+        ///     }
+        /// }
+        ///
+        /// let mut v = vec![10, 20];
+        /// scalar_mul_eq(&mut v, 3);
+        /// assert_eq!(v, [30, 60]);
+        ///
+        /// assert_eq!(0x87654321_u64 * 0xFEED, 0x86D3D159E38D);
+        /// let mut v = vec![0x4321, 0x8765];
+        /// scalar_mul_eq(&mut v, 0xFEED);
+        /// assert_eq!(v, [0xE38D, 0xD159, 0x86D3]);
+        /// ```
+        ///
+        /// If `carry` is zero, this is similar to [`overflowing_mul`](Self::overflowing_mul),
+        /// except that it gives the value of the overflow instead of just whether one happened:
+        ///
+        /// ```
+        /// #![feature(bigint_helper_methods)]
+        /// let r = u8::carrying_mul(7, 13, 0);
+        /// assert_eq!((r.0, r.1 != 0), u8::overflowing_mul(7, 13));
+        /// let r = u8::carrying_mul(13, 42, 0);
+        /// assert_eq!((r.0, r.1 != 0), u8::overflowing_mul(13, 42));
+        /// ```
+        ///
+        /// The value of the first field in the returned tuple matches what you'd get
+        /// by combining the [`wrapping_mul`](Self::wrapping_mul) and
+        /// [`wrapping_add`](Self::wrapping_add) methods:
+        ///
+        /// ```
+        /// #![feature(bigint_helper_methods)]
+        /// assert_eq!(
+        ///     789_u16.carrying_mul(456, 123).0,
+        ///     789_u16.wrapping_mul(456).wrapping_add(123),
+        /// );
+        /// ```
+        #[unstable(feature = "bigint_helper_methods", issue = "85532")]
+        #[rustc_const_unstable(feature = "bigint_helper_methods", issue = "85532")]
+        #[must_use = "this returns the result of the operation, \
+                      without modifying the original"]
+        #[inline]
+        pub const fn carrying_mul(self, rhs: Self, carry: Self) -> (Self, Self) {
+            intrinsics::carrying_mul_add(self, rhs, 0, carry)
+        }
+
         /// New code should prefer to use
 #[doc = concat!("[`", stringify!($SelfT), "::MIN", "`] instead.")]
         ///
diff --git a/library/core/tests/intrinsics.rs b/library/core/tests/intrinsics.rs
index 8b731cf5b25d1..76f4259409120 100644
--- a/library/core/tests/intrinsics.rs
+++ b/library/core/tests/intrinsics.rs
@@ -125,3 +125,61 @@ fn test_three_way_compare_in_const_contexts() {
     assert_eq!(SIGNED_EQUAL, Equal);
     assert_eq!(SIGNED_GREATER, Greater);
 }
+
+fn fallback_cma<T: core::intrinsics::fallback::CarryingMulAdd>(
+    a: T,
+    b: T,
+    c: T,
+    d: T,
+) -> (T::Unsigned, T) {
+    a.carrying_mul_add(b, c, d)
+}
+
+#[test]
+fn carrying_mul_add_fallback_u32() {
+    let r = fallback_cma::<u32>(0x9e37_79b9, 0x7f4a_7c15, 0xf39c_c060, 0x5ced_c834);
+    assert_eq!(r, (0x2087_20c1, 0x4eab_8e1d));
+    let r = fallback_cma::<u32>(0x1082_276b, 0xf3a2_7251, 0xf86c_6a11, 0xd0c1_8e95);
+    assert_eq!(r, (0x7aa0_1781, 0x0fb6_0528));
+}
+
+#[test]
+fn carrying_mul_add_fallback_i32() {
+    let r = fallback_cma::<i32>(-1, -1, -1, -1);
+    assert_eq!(r, (u32::MAX, -1));
+    let r = fallback_cma::<i32>(1, -1, 1, 1);
+    assert_eq!(r, (1, 0));
+}
+
+#[test]
+fn carrying_mul_add_fallback_u128() {
+    assert_eq!(fallback_cma::<u128>(1, 1, 1, 1), (3, 0));
+    assert_eq!(fallback_cma::<u128>(0, 0, u128::MAX, u128::MAX), (u128::MAX - 1, 1));
+    assert_eq!(
+        fallback_cma::<u128>(u128::MAX, u128::MAX, u128::MAX, u128::MAX),
+        (u128::MAX, u128::MAX),
+    );
+
+    let r = fallback_cma::<u128>(
+        0x243f6a8885a308d313198a2e03707344,
+        0xa4093822299f31d0082efa98ec4e6c89,
+        0x452821e638d01377be5466cf34e90c6c,
+        0xc0ac29b7c97c50dd3f84d5b5b5470917,
+    );
+    assert_eq!(r, (0x8050ec20ed554e40338d277e00b674e7, 0x1739ee6cea07da409182d003859b59d8));
+    let r = fallback_cma::<u128>(
+        0x9216d5d98979fb1bd1310ba698dfb5ac,
+        0x2ffd72dbd01adfb7b8e1afed6a267e96,
+        0xba7c9045f12c7f9924a19947b3916cf7,
+        0x0801f2e2858efc16636920d871574e69,
+    );
+    assert_eq!(r, (0x185525545fdb2fefb502a3a602efd628, 0x1b62d35fe3bff6b566f99667ef7ebfd6));
+}
+
+#[test]
+fn carrying_mul_add_fallback_i128() {
+    let r = fallback_cma::<i128>(-1, -1, -1, -1);
+    assert_eq!(r, (u128::MAX, -1));
+    let r = fallback_cma::<i128>(1, -1, 1, 1);
+    assert_eq!(r, (1, 0));
+}
diff --git a/library/core/tests/lib.rs b/library/core/tests/lib.rs
index 89b65eefd027e..668d5c19c2378 100644
--- a/library/core/tests/lib.rs
+++ b/library/core/tests/lib.rs
@@ -18,6 +18,7 @@
 #![feature(const_swap)]
 #![feature(const_trait_impl)]
 #![feature(core_intrinsics)]
+#![feature(core_intrinsics_fallbacks)]
 #![feature(core_io_borrowed_buf)]
 #![feature(core_private_bignum)]
 #![feature(core_private_diy_float)]

From 0c1cc9eab5bcf53cc3c6dda5348fa811023b170c Mon Sep 17 00:00:00 2001
From: Scott McMurray
Date: Fri, 29 Nov 2024 22:51:46 -0800
Subject: [PATCH 2/2] Override `carrying_mul_add` in cg_llvm

---
 compiler/rustc_codegen_llvm/src/intrinsic.rs |  31 +++++
 compiler/rustc_codegen_llvm/src/lib.rs       |   1 +
 library/core/src/intrinsics/fallback.rs      |   4 +-
 library/core/tests/intrinsics.rs             |  10 ++
 tests/codegen/intrinsics/carrying_mul_add.rs | 137 +++++++++++++++++++
 5 files changed, 181 insertions(+), 2 deletions(-)
 create mode 100644 tests/codegen/intrinsics/carrying_mul_add.rs

diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index c38c5d4c6442c..cabcfc9b42b4e 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -340,6 +340,37 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
                     self.const_i32(cache_type),
                 ])
             }
+            sym::carrying_mul_add => {
+                let (size, signed) = fn_args.type_at(0).int_size_and_signed(self.tcx);
+
+                let wide_llty = self.type_ix(size.bits() * 2);
+                let args = args.as_array().unwrap();
+                let [a, b, c, d] = args.map(|a| self.intcast(a.immediate(), wide_llty, signed));
+
+                let wide = if signed {
+                    let prod = self.unchecked_smul(a, b);
+                    let acc = self.unchecked_sadd(prod, c);
+                    self.unchecked_sadd(acc, d)
+                } else {
+                    let prod = self.unchecked_umul(a, b);
+                    let acc = self.unchecked_uadd(prod, c);
+                    self.unchecked_uadd(acc, d)
+                };
+
+                let narrow_llty = self.type_ix(size.bits());
+                let low = self.trunc(wide, narrow_llty);
+                let bits_const = self.const_uint(wide_llty, size.bits());
+                // No need for ashr when signed; LLVM changes it to lshr anyway.
+                let high = self.lshr(wide, bits_const);
+                // FIXME: could be `trunc nuw`, even for signed.
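+                // (The `lshr` leaves at most `size.bits()` significant bits,
+                // which is exactly what the `trunc` below keeps, so the high
+                // half is preserved losslessly.)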
+                let high = self.trunc(high, narrow_llty);
+
+                let pair_llty = self.type_struct(&[narrow_llty, narrow_llty], false);
+                let pair = self.const_poison(pair_llty);
+                let pair = self.insert_value(pair, low, 0);
+                let pair = self.insert_value(pair, high, 1);
+                pair
+            }
             sym::ctlz
             | sym::ctlz_nonzero
             | sym::cttz
diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs
index 0de0c6a7a89ec..dca7738daf774 100644
--- a/compiler/rustc_codegen_llvm/src/lib.rs
+++ b/compiler/rustc_codegen_llvm/src/lib.rs
@@ -17,6 +17,7 @@
 #![feature(iter_intersperse)]
 #![feature(let_chains)]
 #![feature(rustdoc_internals)]
+#![feature(slice_as_array)]
 #![feature(try_blocks)]
 #![warn(unreachable_pub)]
 // tidy-alphabetical-end
diff --git a/library/core/src/intrinsics/fallback.rs b/library/core/src/intrinsics/fallback.rs
index d87331f726208..1779126b180ea 100644
--- a/library/core/src/intrinsics/fallback.rs
+++ b/library/core/src/intrinsics/fallback.rs
@@ -100,8 +100,8 @@ impl const CarryingMulAdd for i128 {
     fn carrying_mul_add(self, b: i128, c: i128, d: i128) -> (u128, i128) {
         let (low, high) = wide_mul_u128(self as u128, b as u128);
         let mut high = high as i128;
-        high = high.wrapping_add((self >> 127) * b);
-        high = high.wrapping_add(self * (b >> 127));
+        high = high.wrapping_add(i128::wrapping_mul(self >> 127, b));
+        high = high.wrapping_add(i128::wrapping_mul(self, b >> 127));
         let (low, carry) = u128::overflowing_add(low, c as u128);
         high = high.wrapping_add((carry as i128) + (c >> 127));
         let (low, carry) = u128::overflowing_add(low, d as u128);
diff --git a/library/core/tests/intrinsics.rs b/library/core/tests/intrinsics.rs
index 76f4259409120..744a6a0d2dd8f 100644
--- a/library/core/tests/intrinsics.rs
+++ b/library/core/tests/intrinsics.rs
@@ -153,6 +153,7 @@ fn carrying_mul_add_fallback_i32() {
 
 #[test]
 fn carrying_mul_add_fallback_u128() {
+    assert_eq!(fallback_cma::<u128>(u128::MAX, u128::MAX, 0, 0), (1, u128::MAX - 1));
     assert_eq!(fallback_cma::<u128>(1, 1, 1, 1), (3, 0));
     assert_eq!(fallback_cma::<u128>(0, 0, u128::MAX, u128::MAX), (u128::MAX - 1, 1));
     assert_eq!(
@@ -178,8 +179,17 @@ fn carrying_mul_add_fallback_u128() {
 
 #[test]
 fn carrying_mul_add_fallback_i128() {
+    assert_eq!(fallback_cma::<i128>(-1, -1, 0, 0), (1, 0));
     let r = fallback_cma::<i128>(-1, -1, -1, -1);
     assert_eq!(r, (u128::MAX, -1));
     let r = fallback_cma::<i128>(1, -1, 1, 1);
     assert_eq!(r, (1, 0));
+    assert_eq!(
+        fallback_cma::<i128>(i128::MAX, i128::MAX, i128::MAX, i128::MAX),
+        (u128::MAX, i128::MAX / 2),
+    );
+    assert_eq!(
+        fallback_cma::<i128>(i128::MIN, i128::MIN, i128::MAX, i128::MAX),
+        (u128::MAX - 1, -(i128::MIN / 2)),
+    );
 }
diff --git a/tests/codegen/intrinsics/carrying_mul_add.rs b/tests/codegen/intrinsics/carrying_mul_add.rs
new file mode 100644
index 0000000000000..17cb2a49ebea8
--- /dev/null
+++ b/tests/codegen/intrinsics/carrying_mul_add.rs
@@ -0,0 +1,137 @@
+//@ revisions: RAW OPT
+//@ compile-flags: -C opt-level=1
+//@[RAW] compile-flags: -C no-prepopulate-passes
+//@[OPT] min-llvm-version: 19
+
+#![crate_type = "lib"]
+#![feature(core_intrinsics)]
+#![feature(core_intrinsics_fallbacks)]
+
+// Note that LLVM seems to sometimes permute the order of arguments to mul and add,
+// so these tests don't check the arguments in the optimized revision.
+
+use std::intrinsics::{carrying_mul_add, fallback};
+
+// The fallbacks are emitted even when they're never used, but optimize out.
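+// (`wide_mul_u128` is the u128 helper in `std::intrinsics::fallback`; the RAW
+// revision still carries a dead copy of it, while OPT deletes it as unused.)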
+ +// RAW: wide_mul_u128 +// OPT-NOT: wide_mul_u128 + +// CHECK-LABEL: @cma_u8 +#[no_mangle] +pub unsafe fn cma_u8(a: u8, b: u8, c: u8, d: u8) -> (u8, u8) { + // CHECK: [[A:%.+]] = zext i8 %a to i16 + // CHECK: [[B:%.+]] = zext i8 %b to i16 + // CHECK: [[C:%.+]] = zext i8 %c to i16 + // CHECK: [[D:%.+]] = zext i8 %d to i16 + // CHECK: [[AB:%.+]] = mul nuw i16 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nuw i16 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nuw i16 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i16 [[ABCD]] to i8 + // CHECK: [[HIGHW:%.+]] = lshr i16 [[ABCD]], 8 + // RAW: [[HIGH:%.+]] = trunc i16 [[HIGHW]] to i8 + // OPT: [[HIGH:%.+]] = trunc nuw i16 [[HIGHW]] to i8 + // CHECK: [[PAIR0:%.+]] = insertvalue { i8, i8 } poison, i8 [[LOW]], 0 + // CHECK: [[PAIR1:%.+]] = insertvalue { i8, i8 } [[PAIR0]], i8 [[HIGH]], 1 + // OPT: ret { i8, i8 } [[PAIR1]] + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @cma_u32 +#[no_mangle] +pub unsafe fn cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) { + // CHECK: [[A:%.+]] = zext i32 %a to i64 + // CHECK: [[B:%.+]] = zext i32 %b to i64 + // CHECK: [[C:%.+]] = zext i32 %c to i64 + // CHECK: [[D:%.+]] = zext i32 %d to i64 + // CHECK: [[AB:%.+]] = mul nuw i64 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nuw i64 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nuw i64 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32 + // CHECK: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32 + // RAW: [[HIGH:%.+]] = trunc i64 [[HIGHW]] to i32 + // OPT: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32 + // CHECK: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0 + // CHECK: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1 + // OPT: ret { i32, i32 } [[PAIR1]] + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @cma_u128 +// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d +#[no_mangle] +pub unsafe fn cma_u128(a: u128, b: u128, c: u128, d: u128) -> (u128, u128) { + // CHECK: [[A:%.+]] = zext i128 %a to i256 + // CHECK: [[B:%.+]] = zext i128 %b to i256 + // CHECK: [[C:%.+]] = zext i128 %c to i256 + // CHECK: [[D:%.+]] = zext i128 %d to i256 + // CHECK: [[AB:%.+]] = mul nuw i256 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nuw i256 + // RAW-SAME: [[AB]], [[C]] + // CHECK: [[ABCD:%.+]] = add nuw i256 + // RAW-SAME: [[ABC]], [[D]] + // CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128 + // CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128 + // RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128 + // OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128 + // RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0 + // RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1 + // OPT: store i128 [[LOW]], ptr %_0 + // OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, i64 16 + // OPT: store i128 [[HIGH]], ptr [[P1]] + // CHECK: ret void + carrying_mul_add(a, b, c, d) +} + +// CHECK-LABEL: @cma_i128 +// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d +#[no_mangle] +pub unsafe fn cma_i128(a: i128, b: i128, c: i128, d: i128) -> (u128, i128) { + // CHECK: [[A:%.+]] = sext i128 %a to i256 + // CHECK: [[B:%.+]] = sext i128 %b to i256 + // CHECK: [[C:%.+]] = sext i128 %c to i256 + // CHECK: [[D:%.+]] = sext i128 %d to i256 + // CHECK: [[AB:%.+]] = mul nsw i256 + // RAW-SAME: [[A]], [[B]] + // CHECK: [[ABC:%.+]] = add nsw i256 + // 
+    // RAW-SAME: [[AB]], [[C]]
+    // CHECK: [[ABCD:%.+]] = add nsw i256
+    // RAW-SAME: [[ABC]], [[D]]
+    // CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128
+    // CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128
+    // RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128
+    // OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128
+    // RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0
+    // RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1
+    // OPT: store i128 [[LOW]], ptr %_0
+    // OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, i64 16
+    // OPT: store i128 [[HIGH]], ptr [[P1]]
+    // CHECK: ret void
+    carrying_mul_add(a, b, c, d)
+}
+
+// CHECK-LABEL: @fallback_cma_u32
+#[no_mangle]
+pub unsafe fn fallback_cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) {
+    // OPT-DAG: [[A:%.+]] = zext i32 %a to i64
+    // OPT-DAG: [[B:%.+]] = zext i32 %b to i64
+    // OPT-DAG: [[AB:%.+]] = mul nuw i64
+    // OPT-DAG: [[C:%.+]] = zext i32 %c to i64
+    // OPT-DAG: [[ABC:%.+]] = add nuw i64{{.+}}[[C]]
+    // OPT-DAG: [[D:%.+]] = zext i32 %d to i64
+    // OPT-DAG: [[ABCD:%.+]] = add nuw i64{{.+}}[[D]]
+    // OPT-DAG: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32
+    // OPT-DAG: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32
+    // OPT-DAG: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32
+    // OPT-DAG: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0
+    // OPT-DAG: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1
+    // OPT-DAG: ret { i32, i32 } [[PAIR1]]
+    fallback::CarryingMulAdd::carrying_mul_add(a, b, c, d)
+}
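
Usage sketch (not part of the diff): chaining `carrying_mul` across digits
gives schoolbook multiplication of little-endian multi-word integers, the
"big integers" use case the doc comment mentions. The helper name
`mul_wide_words` is ours for illustration only, and this assumes a nightly
toolchain with `bigint_helper_methods` enabled:

#![feature(bigint_helper_methods)]

// Multiplies two little-endian u64-digit numbers into a little-endian
// product of `a.len() + b.len()` digits. (Hypothetical helper, for
// illustration; not part of this patch.)
fn mul_wide_words(a: &[u64], b: &[u64]) -> Vec<u64> {
    let mut out = vec![0u64; a.len() + b.len()];
    for (i, &x) in a.iter().enumerate() {
        let mut carry = 0u64;
        for (j, &y) in b.iter().enumerate() {
            // x * y + carry never overflows the (low, high) pair:
            // that is exactly the bound the intrinsic docs derive.
            let (low, high) = x.carrying_mul(y, carry);
            let (sum, overflowed) = out[i + j].overflowing_add(low);
            out[i + j] = sum;
            // If `high == u64::MAX` then `low == 0` and `overflowed` is
            // false, so this addition cannot wrap either.
            carry = high + overflowed as u64;
        }
        // This slot has not been written yet, so the add cannot overflow.
        out[i + b.len()] += carry;
    }
    out
}

fn main() {
    // 2^64 * 2^64 = 2^128, with digits laid out [low, .., high]:
    assert_eq!(mul_wide_words(&[0, 1], &[0, 1]), [0, 0, 1, 0]);
    // (2^64 - 1)^2 = 2^128 - 2^65 + 1:
    assert_eq!(mul_wide_words(&[u64::MAX], &[u64::MAX]), [1, u64::MAX - 1]);
}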