Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*Major T/s improvement* Use the Metal qmatmul MM kernels #2615

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion candle-core/benches/benchmarks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ impl BenchDevice for Device {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
{
use cuda::WrapErr;
return Ok(device.synchronize().w()?);
}
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}
Expand Down
83 changes: 81 additions & 2 deletions candle-core/src/cpu/avx.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::{Cpu, CpuF16};
use super::{Cpu, CpuBF16, CpuF16};
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use half::f16;
use half::{bf16, f16};

pub struct CurrentCpu {}

Expand Down Expand Up @@ -146,3 +146,82 @@ impl CpuF16<ARR> for CurrentCpuF16 {
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}

pub struct CurrentCpuBF16 {}
impl CpuBF16<ARR> for CurrentCpuBF16 {
type Unit = __m256;
type Array = [__m256; ARR];

const STEP: usize = STEP;
const EPR: usize = EPR;

fn n() -> usize {
ARR
}

unsafe fn zero() -> Self::Unit {
_mm256_setzero_ps()
}

unsafe fn zero_array() -> Self::Array {
[Self::zero(); ARR]
}

unsafe fn from_f32(v: f32) -> Self::Unit {
_mm256_set1_ps(v)
}

#[cfg(target_feature = "f16c")]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
_mm256_cvtph_ps(_mm_loadu_si128(mem_addr as *const __m128i))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn load(mem_addr: *const bf16) -> Self::Unit {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = (*mem_addr.add(i)).to_f32();
}
_mm256_loadu_ps(tmp.as_ptr())
}

unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit {
_mm256_add_ps(a, b)
}

unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit {
_mm256_add_ps(_mm256_mul_ps(b, c), a)
}

#[cfg(target_feature = "f16c")]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
_mm_storeu_si128(mem_addr as *mut __m128i, _mm256_cvtps_ph(a, 0))
}

#[cfg(not(target_feature = "f16c"))]
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit) {
let mut tmp = [0.0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), a);
for i in 0..8 {
*mem_addr.add(i) = bf16::from_f32(tmp[i]);
}
}

unsafe fn vec_reduce(mut x: Self::Array, y: *mut f32) {
let mut offset = ARR >> 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
offset >>= 1;
for i in 0..offset {
x[i] = _mm256_add_ps(x[i], x[offset + i]);
}
let t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), _mm256_extractf128_ps(x[0], 1));
let t1 = _mm_hadd_ps(t0, t0);
*y = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));
}
}
7 changes: 7 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ impl VecOps for half::bf16 {
fn max(self, other: Self) -> Self {
Self::max(self, other)
}

/// Dot product of `len` `bf16` values read from `lhs` and `rhs`.
///
/// The sum is accumulated in `f32` by `super::vec_dot_bf16` and converted to
/// `bf16` once, when it is written to `res`.
///
/// # Safety
/// `lhs` and `rhs` must be valid for reading `len` elements and `res` must be
/// valid for writing one element.
#[inline(always)]
unsafe fn vec_dot(lhs: *const Self, rhs: *const Self, res: *mut Self, len: usize) {
    let mut acc = 0f32;
    super::vec_dot_bf16(lhs, rhs, &mut acc, len);
    res.write(half::bf16::from_f32(acc));
}
}
impl VecOps for u8 {
#[inline(always)]
Expand Down
62 changes: 60 additions & 2 deletions candle-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,33 @@ trait CpuF16<const ARR: usize> {
unsafe fn from_f32(v: f32) -> Self::Unit;
unsafe fn vec_store(mem_addr: *mut f16, a: Self::Unit);
}
use half::f16;

/// Operations a SIMD backend provides for `bf16` data.
///
/// `Unit` is a single vector register and `Array` a group of `ARR` such
/// registers used as independent accumulators.
#[allow(unused)]
trait CpuBF16<const ARR: usize> {
/// One vector register.
type Unit;
/// The set of accumulator registers handed to `vec_reduce`.
type Array;
/// Number of elements consumed per outer-loop iteration of `vec_dot_bf16`.
const STEP: usize;
/// Element stride between two consecutive `load` calls in `vec_dot_bf16`
/// (presumably the number of elements per register — TODO confirm).
const EPR: usize;

/// Number of registers processed per outer-loop iteration.
fn n() -> usize;
/// Returns an all-zero register.
unsafe fn zero() -> Self::Unit;
/// Returns an all-zero accumulator array.
unsafe fn zero_array() -> Self::Array;
/// Loads `bf16` values starting at `mem_addr` into a register.
unsafe fn load(mem_addr: *const bf16) -> Self::Unit;
/// Lane-wise `a + b`.
unsafe fn vec_add(a: Self::Unit, b: Self::Unit) -> Self::Unit;
/// Multiply-accumulate; `vec_dot_bf16` passes the accumulator as `a` and the
/// two operands as `b` and `c` (i.e. presumably `a + b * c`).
unsafe fn vec_fma(a: Self::Unit, b: Self::Unit, c: Self::Unit) -> Self::Unit;
/// Reduces the accumulator array `x` to a single `f32` written to `y`.
unsafe fn vec_reduce(x: Self::Array, y: *mut f32);
/// Builds a register from the scalar `v` (presumably broadcast to every lane).
unsafe fn from_f32(v: f32) -> Self::Unit;
/// Writes the register `a` to `mem_addr` as `bf16` values.
unsafe fn vec_store(mem_addr: *mut bf16, a: Self::Unit);
}

use half::{bf16, f16};

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub mod avx;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(target_feature = "avx")]
pub use avx::{CurrentCpu, CurrentCpuF16};
pub use avx::{CurrentCpu, CurrentCpuBF16, CurrentCpuF16};

#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
Expand Down Expand Up @@ -170,6 +189,34 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
*c = sumf;
}

/// SIMD dot product of two `bf16` rows of length `k`, accumulated in `f32`
/// and written to `c`.
///
/// The first `k` rounded down to a multiple of `CurrentCpuBF16::STEP`
/// elements go through the vector backend; the remaining tail is summed with
/// scalar code.
///
/// # Safety
/// `a_row` and `b_row` must be valid for reading `k` elements and `c` must be
/// valid for writing one `f32`.
#[cfg(target_feature = "avx")]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
    // The mask trick assumes STEP is a power of two.
    let vector_len = k & !(CurrentCpuBF16::STEP - 1);

    let mut acc = CurrentCpuBF16::zero_array();
    for base in (0..vector_len).step_by(CurrentCpuBF16::STEP) {
        for lane in 0..CurrentCpuBF16::n() {
            let offset = base + lane * CurrentCpuBF16::EPR;
            let lhs = CurrentCpuBF16::load(a_row.add(offset));
            let rhs = CurrentCpuBF16::load(b_row.add(offset));
            acc[lane] = CurrentCpuBF16::vec_fma(acc[lane], lhs, rhs);
        }
    }

    let mut total = 0.0f32;
    CurrentCpuBF16::vec_reduce(acc, &mut total);

    // Scalar tail: elements that do not fill a whole STEP.
    for idx in vector_len..k {
        let lhs = (*a_row.add(idx)).to_f32();
        let rhs = (*b_row.add(idx)).to_f32();
        total += lhs * rhs;
    }
    *c = total;
}

#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f32, k: usize) {
Expand All @@ -180,3 +227,14 @@ pub(crate) unsafe fn vec_dot_f16(a_row: *const f16, b_row: *const f16, c: *mut f
}
*c = sum;
}

/// Scalar dot product of two `bf16` rows of length `k`, used when AVX is not
/// enabled. Every product is formed and accumulated in `f32`, and the result
/// is written to `c`.
///
/// # Safety
/// `a_row` and `b_row` must be valid for reading `k` elements and `c` must be
/// valid for writing one `f32`.
#[cfg(not(target_feature = "avx"))]
#[inline(always)]
pub(crate) unsafe fn vec_dot_bf16(a_row: *const bf16, b_row: *const bf16, c: *mut f32, k: usize) {
    let mut acc = 0.0f32;
    let mut idx = 0;
    while idx < k {
        let lhs = (*a_row.add(idx)).to_f32();
        let rhs = (*b_row.add(idx)).to_f32();
        acc += lhs * rhs;
        idx += 1;
    }
    *c = acc;
}
1 change: 1 addition & 0 deletions candle-core/src/quantized/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ impl QCudaStorage {
match self.dtype {
GgmlDType::F32 => deq::<f32>(&buffer, block_len, &mut out)?,
GgmlDType::F16 => deq::<half::f16>(&buffer, block_len, &mut out)?,
GgmlDType::BF16 => deq::<half::bf16>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_0 => deq::<crate::quantized::BlockQ4_0>(&buffer, block_len, &mut out)?,
GgmlDType::Q4_1 => deq::<crate::quantized::BlockQ4_1>(&buffer, block_len, &mut out)?,
GgmlDType::Q5_0 => deq::<crate::quantized::BlockQ5_0>(&buffer, block_len, &mut out)?,
Expand Down
1 change: 1 addition & 0 deletions candle-core/src/quantized/ggml_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ pub fn qtensor_from_ggml(
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::BF16 => from_raw_data::<half::bf16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
Expand Down
46 changes: 45 additions & 1 deletion candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::utils::{
use super::GgmlDType;
use crate::Result;
use byteorder::{ByteOrder, LittleEndian};
use half::f16;
use half::{bf16, f16};
use rayon::prelude::*;

// Default to QK_K 256 rather than 64.
Expand Down Expand Up @@ -1963,3 +1963,47 @@ impl GgmlType for f16 {
Ok(())
}
}

/// `bf16` as a GGML storage type: one value per block, with dot products
/// taken against `bf16` as well.
impl GgmlType for bf16 {
    const DTYPE: GgmlDType = GgmlDType::BF16;
    const BLCK_SIZE: usize = 1;
    type VecDotType = bf16;

    /// Dot product of the first `n` elements of `xs` and `ys`; forwards to
    /// [`Self::vec_dot_unopt`].
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Dot product of the first `n` elements of `xs` and `ys`, computed by
    /// `crate::cpu::vec_dot_bf16`.
    ///
    /// # Errors
    /// Fails when either slice holds fewer than `n` elements.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        let (lhs_len, rhs_len) = (xs.len(), ys.len());
        if lhs_len < n {
            crate::bail!("size mismatch {} < {n}", lhs_len)
        }
        if rhs_len < n {
            crate::bail!("size mismatch {} < {n}", rhs_len)
        }
        let mut dot = 0f32;
        // SAFETY: both slices were checked above to hold at least `n`
        // elements, which is the element count handed to the kernel.
        unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut dot, n) };
        Ok(dot)
    }

    /// Converts every `f32` in `xs` to `bf16`, writing into `ys`.
    ///
    /// # Errors
    /// Fails when the two slices differ in length.
    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
        if xs.len() != ys.len() {
            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
        }
        // TODO: vectorize
        ys.iter_mut()
            .zip(xs)
            .for_each(|(dst, &src)| *dst = bf16::from_f32(src));
        Ok(())
    }

    /// Widens every `bf16` in `xs` to `f32`, writing into `ys`.
    ///
    /// # Errors
    /// Fails when the two slices differ in length.
    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
        if xs.len() != ys.len() {
            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
        }
        // TODO: vectorize
        ys.iter_mut()
            .zip(xs)
            .for_each(|(dst, src)| *dst = src.to_f32());
        Ok(())
    }
}
Loading
Loading