Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simd128 version of the q2k-q8k vecdot product. #1011

Merged
merged 5 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,18 +710,17 @@ impl GgmlType for BlockQ2K {

let mut isum = 0;
let mut is = 0;
let mut d;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
d = (sc[is] & 0xF) as i32;
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = 0;
for l in 0..16 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
}
isum += d * isuml;
d = (sc[is] & 0xF) as i32;
let d = (sc[is] & 0xF) as i32;
is += 1;
isuml = 0;
for l in 16..32 {
Expand Down Expand Up @@ -1086,7 +1085,6 @@ impl GgmlType for BlockQ3K {
let d_all = block.d.to_f32();
let mut m = 1;
let mut is = 0;
let mut dl;

// Dequantize both 128 long blocks
// 32 qs values per 128 long block
Expand All @@ -1097,7 +1095,7 @@ impl GgmlType for BlockQ3K {
for (scale_index, scale_scoped_y) in
shift_scoped_y.chunks_exact_mut(16).enumerate()
{
dl = d_all * (scales[is] as f32 - 32.0);
let dl = d_all * (scales[is] as f32 - 32.0);
for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
let new_y = dl
* (((qs[i + 16 * scale_index] >> shift) & 3) as i8
Expand Down
112 changes: 72 additions & 40 deletions candle-core/src/quantized/simd128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,53 +102,85 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
#[inline(always)]
pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result<f32> {
if n % QK_K != 0 {
crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
}
let mut sumf = 0.0;
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;

let mut summs = 0;
for (bsum, scale) in y.bsums.iter().zip(sc) {
summs += *bsum as i32 * ((scale >> 4) as i32);
}
unsafe {
let mut sumf = f32x4_splat(0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let mut q2: &[_] = &x.qs;
let mut q8: &[_] = &y.qs;
let sc = &x.scales;

let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();

let mut isum = 0;
let mut is = 0;
let mut d;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = 0;
for l in 0..16 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
}
isum += d * isuml;
d = (sc[is] & 0xF) as i32;
is += 1;
isuml = 0;
for l in 16..32 {
isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
let mut summs = i32x4_splat(0);
for i in (0..(QK_K / 16)).step_by(4) {
let bsums = i32x4_load_extend_i16x4(y.bsums.as_ptr().add(i));
let scales = i32x4_shr(
i32x4(
sc[i] as i32,
sc[i + 1] as i32,
sc[i + 2] as i32,
sc[i + 3] as i32,
),
4,
);
summs = i32x4_add(summs, i32x4_mul(bsums, scales))
}
let summs = f32x4_convert_i32x4(summs);

let dall = y.d * x.d.to_f32();
let dmin = y.d * x.dmin.to_f32();

let mut isum = i32x4_splat(0);
let mut is = 0;
for _ in 0..(QK_K / 128) {
let mut shift = 0;
for _ in 0..4 {
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (0..16).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
let d = (sc[is] & 0xF) as i32;
is += 1;
let mut isuml = i16x8_splat(0);
for l in (16..32).step_by(8) {
let q8 = i16x8_load_extend_i8x8(q8.as_ptr().add(l));
let q2 = i16x8_load_extend_u8x8(q2.as_ptr().add(l));
let q2 = v128_and(i16x8_shr(q2, shift), i16x8_splat(3));
isuml = i16x8_add(isuml, i16x8_mul(q2, q8))
}
let dd = i32x4_splat(d);
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_low_i16x8(isuml), dd));
isum = i32x4_add(isum, i32x4_mul(i32x4_extend_high_i16x8(isuml), dd));
shift += 2;
// adjust the indexing
q8 = &q8[32..];
}
isum += d * isuml;
shift += 2;
// adjust the indexing
q8 = &q8[32..];
q2 = &q2[32..];
}
// adjust the indexing
q2 = &q2[32..];
let isum = f32x4_convert_i32x4(isum);
sumf = f32x4_add(
sumf,
f32x4_sub(
f32x4_mul(isum, f32x4_splat(dall)),
f32x4_mul(summs, f32x4_splat(dmin)),
),
);
}
sumf += dall * isum as f32 - dmin * summs as f32;
let sumf = f32x4_extract_lane::<0>(sumf)
+ f32x4_extract_lane::<1>(sumf)
+ f32x4_extract_lane::<2>(sumf)
+ f32x4_extract_lane::<3>(sumf);
Ok(sumf)
}

Ok(sumf)
}

#[inline(always)]
Expand Down
2 changes: 1 addition & 1 deletion candle-core/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {

let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

if error > GGML_MAX_DOT_PRODUCT_ERROR {
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
candle_core::bail!(
"Dot product error {} exceeds max error {}",
error,
Expand Down
2 changes: 1 addition & 1 deletion candle-wasm-tests/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {

let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

if error > GGML_MAX_DOT_PRODUCT_ERROR {
if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
candle::bail!(
"Dot product error {} exceeds max error {}",
error,
Expand Down