From eff930e7f2fa0330bbc00b04c9fcdfb3c41f7353 Mon Sep 17 00:00:00 2001
From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com>
Date: Sat, 26 Oct 2024 17:13:37 -0400
Subject: [PATCH 01/10] Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel
* Add full sdpa kernel
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
---
 candle-metal-kernels/src/lib.rs             |  339 +++++
 .../src/scaled_dot_product_attention.metal  | 1251 +++++++++++++++++
 candle-nn/src/ops.rs                        |  190 +++
 candle-nn/tests/sdpa.rs                     |  206 +++
 4 files changed, 1986 insertions(+)
 create mode 100644 candle-metal-kernels/src/scaled_dot_product_attention.metal
 create mode 100644 candle-nn/tests/sdpa.rs

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 222ae8ad85..28cbc5499d 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -25,6 +25,7 @@ const REDUCE: &str = include_str!("reduce.metal");
 const SORT: &str = include_str!("sort.metal");
 const TERNARY: &str = include_str!("ternary.metal");
 const UNARY: &str = include_str!("unary.metal");
+const SDPA: &str = include_str!("scaled_dot_product_attention.metal");

 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
@@ -42,6 +43,7 @@ pub enum Source {
     Sort,
     Ternary,
     Unary,
+    Sdpa,
 }

 pub mod copy2d {
@@ -159,6 +161,17 @@ pub enum MetalKernelError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("Sdpa {variation} head size was {got}, expected {expected:?}")]
+    SdpaHeadSizeMismatch {
+        variation: &'static str,
+        got: usize,
+        expected: Vec<usize>,
+    },
+    #[error("Sdpa {variation} got dtype {got:?}")]
+    SdpaHeadDTypeMismatch {
+        variation: &'static str,
+        got: SdpaDType,
+    },
 }

 impl From> for MetalKernelError {
@@ -207,6 +220,7 @@ impl Kernels {
             Source::Sort => SORT,
             Source::Ternary => TERNARY,
             Source::Unary => UNARY,
+            Source::Sdpa => SDPA,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@@ -1627,6 +1641,331 @@ pub fn call_gemm(
     Ok(())
 }

+#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
+pub enum SdpaDType {
+    BF16,
+    F16,
+    F32,
+}
+
+/// SDPA full is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q heads == kv heads
+/// - final type != bf16 (TODO maybe just template this kernel too?)
+/// - q,k,v are contiguous +#[allow(clippy::too_many_arguments)] +pub fn call_sdpa_full( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + q_offset: usize, + q_shape: &[usize], + q_buffer: &Buffer, + k_offset: usize, + k_shape: &[usize], + k_buffer: &Buffer, + v_offset: usize, + v_buffer: &Buffer, + output: &Buffer, + alpha: f32, + softcapping: f32, + itype: SdpaDType, +) -> Result<(), MetalKernelError> { + #[derive(Debug)] + #[repr(C)] + struct MLXFastAttentionParams { + m: i32, + n: i32, + k: i32, + + ldq: i32, // ldq == ldo + ldk: i32, + ldv: i32, + lds: i32, + ldo: i32, + + tiles_n: i32, + tiles_m: i32, + + batch_stride_q: i32, + batch_stride_k: i32, + batch_stride_v: i32, + batch_stride_o: i32, + + swizzle_log: i32, + gemm_n_iterations_aligned: i32, + gemm_k_iterations_aligned: i32, + gemm_sv_m_block_iterations: i32, + + batch_ndim: i32, + alpha: f32, + softcapping: f32, + } + + let bk = q_shape.last().unwrap(); + + const BN: usize = 16; + const BM: usize = 16; + const WM: usize = 2; + const WN: usize = 2; + + let name = match (bk, itype) { + (32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half", + (64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half", + (96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half", + (128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half", + (256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half", + (32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float", + (64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float", + (96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float", + (128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float", + (256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float", + (other, SdpaDType::F16 | SdpaDType::F32) => { + return Err(MetalKernelError::SdpaHeadSizeMismatch { + variation: "full", + got: *other, + expected: vec![32, 64, 96, 128, 256], + }) + } + (_, SdpaDType::BF16) => { + return Err(MetalKernelError::SdpaHeadDTypeMismatch { + variation: "full", + got: SdpaDType::BF16, + }) + } + }; + + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + // q = (bs, qhead, seq, hidden) + // k/v = (bs, kv_head, seq, hidden) + + let _hidden = q_shape[q_shape.len() - 1]; + let qseq = q_shape[q_shape.len() - 2]; + let _qheads = q_shape[q_shape.len() - 3]; + + let _kvseq = k_shape[k_shape.len() - 2]; + let _nq_heads = q_shape[1]; + let _nkv_heads = k_shape[1]; + + let m = q_shape[q_shape.len() - 2]; + let n = m; + let k = q_shape[q_shape.len() - 1]; + let bs_out = q_shape[0] * q_shape[1]; + + let batch_shape = [q_shape[0] * q_shape[1]]; + let dk = q_shape[q_shape.len() - 1]; + let ldq = dk; + let ldk = dk; + let ldv = dk; + let lds = BN; + let ldo = dk; + + let tn = 1; + let tm = (m + BM - 1) / BM; + + let b_stride_q = dk * qseq; + let b_stride_k = dk * qseq; + let b_stride_v = dk * qseq; + let b_stride_o = dk * qseq; + let swizzle_log = 0; + let gemm_n_iterations_aligned = (n + BN - 1) / BN; + let gemm_k_iterations_aligned = (k + bk - 1) / bk; + let gemm_sv_m_block_iterations = (m + BM - 1) / BM; + let batch_ndim = batch_shape.len(); + + let alpha = if softcapping != 1. 
{
+        alpha / softcapping
+    } else {
+        alpha
+    };
+
+    let params = MLXFastAttentionParams {
+        m: m as i32,
+        n: n as i32,
+        k: k as i32,
+        ldq: ldq as i32,
+        ldk: ldk as i32,
+        ldv: ldv as i32,
+        lds: lds as i32,
+        ldo: ldo as i32,
+        tiles_n: tn,
+        tiles_m: tm as i32,
+        batch_stride_q: b_stride_q as i32,
+        batch_stride_k: b_stride_k as i32,
+        batch_stride_v: b_stride_v as i32,
+        batch_stride_o: b_stride_o as i32,
+        swizzle_log,
+        gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
+        gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
+        gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
+        batch_ndim: batch_ndim as i32,
+        alpha,
+        softcapping,
+    };
+    let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
+
+    encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
+    encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
+    encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
+    encoder.set_buffer(3, Some(&output), 0);
+
+    encoder.set_bytes(
+        4,
+        std::mem::size_of::<MLXFastAttentionParams>() as u64,
+        &params as *const MLXFastAttentionParams as *const c_void,
+    );
+    encoder.set_bytes(
+        6,
+        (std::mem::size_of::<i32>() * batch_shape.len()) as u64,
+        batch_shape.as_ptr() as *const i32 as *const c_void,
+    );
+    encoder.set_bytes(
+        7,
+        (std::mem::size_of::<usize>() * batch_strides.len()) as u64,
+        batch_strides.as_ptr() as *const c_void,
+    );
+
+    let grid_dims = MTLSize {
+        width: 1,
+        height: tm as u64,
+        depth: bs_out as u64,
+    };
+    let group_dims = MTLSize {
+        width: 32,
+        height: WM as u64,
+        depth: WN as u64,
+    };
+    encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(grid_dims, group_dims);
+    Ok(())
+}
+
+/// SDPA vector is supported when:
+/// - q head dim == 32, 64, 96, 128, 256
+/// - no mask
+/// - q,k,v are contiguous
+#[allow(clippy::too_many_arguments)]
+pub fn call_sdpa_vector(
+    device: &Device,
+    ep: impl EncoderProvider,
+    kernels: &Kernels,
+    q_offset: usize,
+    q_shape: &[usize],
+    q_buffer: &Buffer,
+    k_offset: usize,
+    k_shape: &[usize],
+    k_stride: &[usize],
+    k_buffer: &Buffer,
+    v_offset: usize,
+    v_buffer: &Buffer,
+    output: &Buffer,
+    alpha: f32,
+    softcapping: f32,
+    itype: SdpaDType,
+) -> Result<(), MetalKernelError> {
+    let bk = q_shape.last().unwrap();
+
+    let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
+    let n = k_shape[2] as i32;
+    let b = (q_shape[0] * q_shape[1]) as i32;
+    let stride = k_stride[1];
+
+    let name = match (bk, itype) {
+        (32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
+        (64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
+        (96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
+        (128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
+        (256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
+        (32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
+        (64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
+        (96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
+        (128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
+        (256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
+        (32, SdpaDType::F32) => "sdpa_vector_float_32",
+        (64, SdpaDType::F32) => "sdpa_vector_float_64",
+        (96, SdpaDType::F32) => "sdpa_vector_float_96",
+        (128, SdpaDType::F32) => "sdpa_vector_float_128",
+        (256, SdpaDType::F32) => "sdpa_vector_float_256",
+        (other, _) => {
+            return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector", + got: *other, + expected: vec![32, 64, 96, 128, 256], + }) + } + }; + + let alpha = if softcapping != 1. { + alpha / softcapping + } else { + alpha + }; + + let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?; + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + + // q = (bs, qhead, seq, hidden) + // k/v = (bs, kv_head, kv_seq, hidden) + + encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger); + encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger); + encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger); + encoder.set_buffer(3, Some(&output), 0); + + encoder.set_bytes( + 4, + std::mem::size_of::() as u64, + &gqa_factor as *const i32 as *const c_void, + ); + encoder.set_bytes( + 5, + std::mem::size_of::() as u64, + &n as *const i32 as *const c_void, + ); + encoder.set_bytes( + 6, + std::mem::size_of::() as u64, + &stride as *const usize as *const c_void, + ); + encoder.set_bytes( + 7, + std::mem::size_of::() as u64, + &alpha as *const f32 as *const c_void, + ); + encoder.set_bytes( + 8, + std::mem::size_of::() as u64, + &softcapping as *const f32 as *const c_void, + ); + + let grid_dims = MTLSize { + width: 1, + height: b as u64, + depth: 1 as u64, + }; + let group_dims = MTLSize { + width: 1024, + height: 1, + depth: 1, + }; + encoder.use_resource(q_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(k_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.dispatch_thread_groups(grid_dims, group_dims); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_im2col1d_strided( device: &Device, diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal new file mode 100644 index 0000000000..8fc6d5ddfc --- /dev/null +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -0,0 +1,1251 @@ +// Updated from MLX commit has f70764a + +#include +#include + +using namespace metal; + +// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.h" + +struct MLXFastAttentionParams { + const int M; + const int N; + const int K; + + const int ldq; // ldq == ldo + const int ldk; + const int ldv; + const int lds; + const int ldo; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_q; + const int batch_stride_k; + const int batch_stride_v; + const int batch_stride_o; + + const int swizzle_log; + const int gemm_n_iterations_aligned; + const int gemm_k_iterations_aligned; + const int gemm_sv_m_block_iterations; + + const int batch_ndim; + const float alpha; + const float softcapping; +}; + +struct MLXScaledDotProductAttentionParams { + // Associated dimensions & transposition information + const uint QUERY_SEQUENCE_LENGTH = 1; + const uint N_Q_HEADS = 32; + const uint N_KV_HEADS = 32; + const uint KV_TILES = 1; + const float INV_ALPHA = 0.08838834764831843f; +}; + +// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.sdpa_vector" + +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int& gqa_factor, + const constant int& N, + const constant size_t& k_stride, + const constant float& scale, + const constant float& 
softcapping, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D / BD; + + const int stride = BN * D; + + typedef float U; + + thread U q[elem_per_thread]; + thread U k[elem_per_thread]; + thread U o[elem_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int head_idx = tid.y; + const int kv_head_idx = head_idx / gqa_factor; + queries += head_idx * D + simd_lid * elem_per_thread; + keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + out += head_idx * D + simd_gid * elem_per_thread; + + // Read the query and 0 the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < elem_per_thread; i++) { + o[i] = 0; + } + + U max_score = -INFINITY; + U sum_exp_score = 0; + + // For each key + for (int i = simd_gid; i < N; i += BN) { + // Read the key + for (int i = 0; i < elem_per_thread; i++) { + k[i] = keys[i]; + } + + // Compute the i-th score + U score = 0; + for (int i = 0; i < elem_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + if (softcapping != 1.) { + score = precise::tanh(score); + score = score * softcapping; + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } + + // Move the pointers to the next kv + keys += stride; + values += stride; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each thread has a partial part of the output so we need to combine them. 
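+  // Each simdgroup has walked a strided subset of the keys, so it owns a running max,
+  // a running sum of exponentials, and an output accumulator expressed relative to its
+  // own max. They are merged below with a streaming log-sum-exp reduction: every partial
+  // sum and output is rescaled by exp(partial_max - global_max) before being combined.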
+ + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +// ============ "mlx/backend/metal/kernels/steel/defines.h" + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +// ============ "mlx/backend/metal/kernels/steel/gemm/transforms.h" + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +// ============ "mlx/backend/metal/kernels/utils.h" + +typedef bfloat bfloat16_t; +typedef half float16_t; + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const size_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} + +// ============ "mlx/backend/metal/kernels/scaled_dot_product_attention_params.metal" + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = 
(BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderFA { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoaderFA( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } + METAL_FUNC void next(short n) { + src += n * tile_stride; + } +}; + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMAFA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? 
TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + ushort sid; + ushort slid; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMAFA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + slid = simd_lane_id; + sid = simd_group_id; + + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + METAL_FUNC void rescale_output(const threadgroup float* Corrections) { + // Loop over all simdgroup tiles + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + short row = sm + tm + i * TM_stride; + float scale_value = Corrections[row]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = results[i * TN + j].thread_elements(); + // int offset = (i * TM_stride) * ldc + (j * TN_stride); + accum[0] *= scale_value; + accum[1] *= scale_value; + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* C, const int ldc) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out C + C[offset] = outs[0]; + C[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_to_tgp_memory( + threadgroup U* C, + const int ldc, + short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + METAL_FUNC void + store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 
TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } + + METAL_FUNC void clear_results() { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + results[i * TN + j] = simdgroup_matrix(0); + } + } + } +}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct FastAttentionKernel { + STEEL_CONST short tgp_padding = 16 / sizeof(T); + STEEL_CONST short float_padding = 16 / sizeof(float); + STEEL_CONST short tgp_mem_size_q = + transpose_q ? BK * (BM + tgp_padding) : BM * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_k = + transpose_k ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_v = + transpose_v ? BK * (BN + tgp_padding) : BN * (BK + tgp_padding); + STEEL_CONST short tgp_mem_size_s = BM * (BN + tgp_padding); + + // maxes, rowsums, rescale + STEEL_CONST short tgp_mem_size_corrections = + 4 * (BM * sizeof(float) + float_padding); + + STEEL_CONST bool share_kv_smem = transpose_k != transpose_v; + + STEEL_CONST short tgp_mem_size = share_kv_smem + ? tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + : tgp_mem_size_q + tgp_mem_size_k + tgp_mem_size_s + + tgp_mem_size_corrections + tgp_mem_size_v; + + STEEL_CONST short tgp_size = WM * WN * 32; + + static_assert(transpose_q == false, "Expected Q not transposed."); + static_assert(transpose_k == true, "Expected K transposed."); + static_assert(transpose_v == false, "Expected V not transposed."); + static_assert(tgp_mem_size <= 32768, "Excessive tgp memory requested."); + + using loader_q_t = BlockLoaderFA< + T, + transpose_q ? BK : BM, + transpose_q ? BM : BK, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + !transpose_q, + tgp_size>; + + using loader_k_t = BlockLoaderFA< + T, + transpose_k ? 
BN : BK, + transpose_k ? BK : BN, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + transpose_k, + tgp_size>; + + using loader_v_t = BlockLoaderFA< + T, + transpose_v ? BK : BN, + transpose_v ? BN : BK, + transpose_v ? BN + tgp_padding : BK + tgp_padding, + transpose_v, + tgp_size>; + + using mma_qk_t = BlockMMAFA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_q ? BM + tgp_padding : BK + tgp_padding, + transpose_k ? BK + tgp_padding : BN + tgp_padding, + AccumType, + Epilogue>; + + using mma_sv_t = BlockMMAFA< + T, + U, + BM, + BK, + BN, + WM, + WN, + false, + transpose_v, + BN + tgp_padding, + BK + tgp_padding, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_k_t& loader_b, + thread mma_qk_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + (void)tgp_bm; + + short2 tile_dims_B = transpose_k ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + // not valid for gemm_k_iterations > 1 (so, BK == d_k) + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + } + + static METAL_FUNC void initialize_corrections( + threadgroup float* C, + uint simd_lane_id, + uint simd_group_id) { + if (simd_group_id == 0) { + threadgroup float* maxes = C; + threadgroup float* sums = C + (BM + float_padding); + threadgroup float* o_rescale = sums + (BM + float_padding); + threadgroup float* output_rescale = o_rescale + (BM + float_padding); + + if (simd_lane_id < BM) { + maxes[simd_lane_id] = -INFINITY; // m_i + sums[simd_lane_id] = 0.f; // l_i + o_rescale[simd_lane_id] = 1.f; // li * exp(mi - mi_new) + output_rescale[simd_lane_id] = 1.f; // 1.0 / l_i + } + } + } + + static METAL_FUNC void rescale_ss( + threadgroup T* Ss, + threadgroup float* Corrections, + uint simd_group_id, + uint simd_lane_id, + short2 local_blocks, + float alpha, + float softcapping) { + if (simd_group_id == 0) { + short row_offset = BM + float_padding; + threadgroup float* maxes = Corrections; + threadgroup float* sums = Corrections + row_offset; + threadgroup float* o_rescale = sums + row_offset; + threadgroup float* output_scales = o_rescale + row_offset; + + if (simd_lane_id < uint(local_blocks.y)) { + float m_i_old = maxes[simd_lane_id]; + float l_i_old = sums[simd_lane_id]; + + float m_i_new = m_i_old; + float l_i_new = l_i_old; + + short offset = simd_lane_id * (BN + tgp_padding); + + float m_ij = -INFINITY; + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + if (softcapping != 1.) { + val = precise::tanh(val); + val = val * softcapping; + } + m_ij = max(m_ij, val); + } + + m_i_new = max(m_ij, m_i_new); + + float rowsum = 0.f; // lij + + for (short j = 0; j < local_blocks.x; j++) { + float val = alpha * float(Ss[offset + j]); + if (softcapping != 1.) 
{ + val = precise::tanh(val); + val = val * softcapping; + } + float P_i_j = exp(val - m_ij); + rowsum += P_i_j; + P_i_j = P_i_j * exp(m_ij - m_i_new); + Ss[offset + j] = T(P_i_j); + } + + l_i_new = + exp(m_i_old - m_i_new) * l_i_old + exp(m_ij - m_i_new) * rowsum; + maxes[simd_lane_id] = m_i_new; + sums[simd_lane_id] = l_i_new; + float rescale = l_i_old * exp(m_i_old - m_i_new); + o_rescale[simd_lane_id] = rescale; + output_scales[simd_lane_id] = 1.0 / l_i_new; + } + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device U* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + threadgroup T* Qs [[threadgroup(0)]], + threadgroup T* Ks [[threadgroup(1)]], + threadgroup T* Ss [[threadgroup(2)]], + threadgroup T* Vs [[threadgroup(3)]], + threadgroup float* Corrections [[threadgroup(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in Q, O; and head in K, V. + const int c_row = tid_y * BM; + + Q += transpose_q ? c_row : c_row * params->ldq; + thread loader_q_t loader_q(Q, params->ldq, Qs, simd_group_id, simd_lane_id); + + short tgp_bm = min(BM, params->M - c_row); + short2 tile_dims_Q = transpose_q ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + loader_q.load_safe(tile_dims_Q); + + initialize_corrections(Corrections, simd_lane_id, simd_group_id); + + O += c_row * params->ldo; + + // Prepare threadgroup mma operation + thread mma_qk_t mma_qk_op(simd_group_id, simd_lane_id); + thread mma_sv_t mma_softmax_sv_op(simd_group_id, simd_lane_id); + thread loader_k_t loader_k(K, params->ldk, Ks, simd_group_id, simd_lane_id); + thread loader_v_t loader_v(V, params->ldv, Vs, simd_group_id, simd_lane_id); + + for (short n_block = 0; n_block < params->gemm_n_iterations_aligned; + n_block++) { + short c_col = BN; + + // Prepare threadgroup loading operations + short gemm_k_iterations = params->gemm_k_iterations_aligned; + short tgp_bn_qk = min(BN, params->N - c_col * n_block); + threadgroup_barrier(mem_flags::mem_none); + + /////////////////////////////////////////////////////////////////////////////// + { // Loop over K - unaligned case + + if (tgp_bm == BM && tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } else if (tgp_bn_qk == BN) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else if (tgp_bm == BM) { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + + } else { + gemm_loop( + Qs, + Ks, + gemm_k_iterations, + loader_k, + mma_qk_op, + tgp_bm, + tgp_bn_qk); + } + } + + mma_qk_op.store_result_to_tgp_memory( + Ss, BN + tgp_padding, short2(BN, BM)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + rescale_ss( + Ss, + Corrections, + simd_group_id, + simd_lane_id, + short2(tgp_bn_qk, tgp_bm), + params->alpha, + params->softcapping); + + loader_v.load_safe(short2(BK, tgp_bn_qk)); + + 
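+      // Flash-attention style tile update: the output accumulator is kept normalized, so
+      // it is first re-multiplied by l_old * exp(m_old - m_new) (o_scales), the fresh P*V
+      // tile is accumulated on top, and the sum is scaled by 1 / l_new (final_output_scales).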
threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float* o_scales = Corrections + 2 * (BM + float_padding); + mma_softmax_sv_op.rescale_output(o_scales); + + mma_softmax_sv_op.mma(Ss, Vs); + + threadgroup float* final_output_scales = + Corrections + 3 * (BM + float_padding); + + mma_softmax_sv_op.rescale_output(final_output_scales); + + loader_v.next(); + loader_k.next(BN); + + mma_qk_op.clear_results(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_softmax_sv_op.store_result_safe(O, params->ldo, short2(BK, tgp_bm)); + } +}; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_q, + bool transpose_k, + bool transpose_v, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant MLXFastAttentionParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant size_t* batch_strides [[buffer(7)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using attention_kernel = FastAttentionKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_q, + transpose_k, + transpose_v, + MN_aligned, + K_aligned>; + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant size_t* Q_bstrides = batch_strides; + const constant size_t* KV_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, Q_bstrides, KV_bstrides, params->batch_ndim); + + Q += batch_offsets.x; + K += batch_offsets.y; + V += batch_offsets.y; + + } else { + Q += params->batch_stride_q * tid.z; + K += params->batch_stride_k * tid.z; + V += params->batch_stride_v * tid.z; + } + + // same shape as input + O += params->batch_stride_o * tid.z; + threadgroup T Qs[attention_kernel::tgp_mem_size_q]; + threadgroup T Ss[attention_kernel::tgp_mem_size_s]; + threadgroup float Corrections[attention_kernel::tgp_mem_size_corrections]; + + if (attention_kernel::share_kv_smem) { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T* Vs = Ks; //[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } else { + threadgroup T Ks[attention_kernel::tgp_mem_size_k]; + threadgroup T Vs[attention_kernel::tgp_mem_size_v]; + attention_kernel::run( + Q, + K, + V, + O, + params, + Qs, + Ks, + Ss, + Vs, + Corrections, + simd_lane_id, + simd_group_id, + tid, + lid); + } +} + +// clang-format off + +// SDPA full instantiations +#define instantiate_fast_inference_self_attention_kernel( \ + itype, otype, bm, bn, bk, wm, wn) \ + template [[host_name("steel_gemm_attention_bm_" #bm "_bn_" #bn "_bk_" #bk \ + "_itype_" #itype)]] [[kernel]] void \ + attention( \ + const device itype* Q [[buffer(0)]], \ + const device itype* K [[buffer(1)]], \ + const device itype* V [[buffer(2)]], \ + device otype* O [[buffer(3)]], \ + const constant MLXFastAttentionParams* params [[buffer(4)]], \ + const constant int* batch_shape [[buffer(6)]], \ + const constant size_t* batch_strides [[buffer(7)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid 
[[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 32, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 64, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 96, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 128, + 2, + 2); +instantiate_fast_inference_self_attention_kernel( + float, + float, + 16, + 16, + 256, + 2, + 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 32, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 64, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 96, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 128, 2, 2); +instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); + +// SDPA vector instantiations +#define instantiate_sdpa_vector(type, head_dim) \ + template [[host_name("sdpa_vector_" #type "_" #head_dim)]] \ + [[kernel]] void sdpa_vector( \ + const device type* queries [[buffer(0)]], \ + const device type* keys [[buffer(1)]], \ + const device type* values [[buffer(2)]], \ + device type* out [[buffer(3)]], \ + const constant int& gqa_factor, \ + const constant int& N, \ + const constant size_t& k_stride, \ + const constant float& scale, \ + const constant float& softcapping, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_sdpa_vector_heads(type) \ + instantiate_sdpa_vector(type, 32) \ + instantiate_sdpa_vector(type, 64) \ + instantiate_sdpa_vector(type, 96) \ + instantiate_sdpa_vector(type, 128) \ + instantiate_sdpa_vector(type, 256) + +instantiate_sdpa_vector_heads(float) +instantiate_sdpa_vector_heads(bfloat16_t) +instantiate_sdpa_vector_heads(float16_t) + // clang-format on diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 9a360c472c..a57b988e03 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -947,3 +947,193 @@ impl Module for Identity { Ok(xs.clone()) } } + +#[allow(dead_code)] +struct Sdpa { + scale: f32, + softcapping: f32, +} + +impl candle::CustomOp3 for Sdpa { + fn name(&self) -> &'static str { + "sdpa" + } + + fn cpu_fwd( + &self, + _s1: &CpuStorage, + _l1: &Layout, + _s2: &CpuStorage, + _l2: &Layout, + _s3: &CpuStorage, + _l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + candle::bail!("SDPA has no cpu impl") + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + q: &candle::MetalStorage, + q_l: &Layout, + k: &candle::MetalStorage, + k_l: &Layout, + v: &candle::MetalStorage, + v_l: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + use candle_metal_kernels::SdpaDType; + + let device = q.device(); + + let out_dims = vec![q_l.dims()[0], q_l.dims()[1], q_l.dims()[2], v_l.dims()[3]]; + let elem_count: usize = out_dims.iter().product(); + + let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; + + // q,k must have matching emb dim + if q_l.dims()[q_l.dims().len() - 1] != k_l.dims()[k_l.dims().len() - 1] { + candle::bail!("`q` and `k` last dims must match"); + } + + // k,v must have matching n kv heads + if v_l.dims()[v_l.dims().len() - 3] != k_l.dims()[k_l.dims().len() - 3] { + candle::bail!("`k` and `v` head dims must match"); + } + + // n_heads % n_kv_heads 
== 0; n_heads >= 1, n_kv_heads >= 1.
+        if q_l.dims()[q_l.dims().len() - 3] % k_l.dims()[k_l.dims().len() - 3] != 0 {
+            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
+        }
+
+        let k_head = k_l.dims()[k_l.dims().len() - 1];
+        let q_head = q_l.dims()[q_l.dims().len() - 1];
+        let q_seq = q_l.dims()[2];
+
+        let mut implementation_supports_use_case = q_head == k_head;
+        let supported_head_dim =
+            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;
+
+        const SDPA_FULL_THRESHOLD: usize = 2;
+
+        let supports_sdpa_full =
+            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
+        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;
+
+        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
+
+        if !supported_head_dim {
+            candle::bail!(
+                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
+                q_l.dims(),
+                k_l.dims(),
+                v_l.dims()
+            );
+        }
+        if !implementation_supports_use_case {
+            candle::bail!(
+                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
+                q_l.dims(),
+                k_l.dims(),
+                v_l.dims()
+            );
+        }
+
+        for t in [k.dtype(), v.dtype()] {
+            if q.dtype() != t {
+                candle::bail!("all q, k, v dtypes must match.");
+            }
+        }
+
+        let itype = match q.dtype() {
+            DType::BF16 => SdpaDType::BF16,
+            DType::F16 => SdpaDType::F16,
+            DType::F32 => SdpaDType::F32,
+            other => candle::bail!("unsupported sdpa type {other:?}"),
+        };
+
+        let command_buffer = q.device().command_buffer()?;
+        if supports_sdpa_vector {
+            command_buffer.set_label("vector_attention");
+            candle_metal_kernels::call_sdpa_vector(
+                q.device().device(),
+                &command_buffer,
+                q.device().kernels(),
+                q_l.start_offset(),
+                q_l.dims(),
+                q.buffer(),
+                k_l.start_offset(),
+                k_l.dims(),
+                k_l.stride(),
+                k.buffer(),
+                v_l.start_offset(),
+                v.buffer(),
+                &output,
+                self.scale,
+                self.softcapping,
+                itype,
+            )
+            .map_err(candle::Error::wrap)?;
+        } else if supports_sdpa_full {
+            if q_l.dims()[2] != k_l.dims()[2] {
+                candle::bail!(
+                    "query and key sequence length must be equal if using full metal sdpa"
+                )
+            }
+
+            command_buffer.set_label("full_attention");
+            candle_metal_kernels::call_sdpa_full(
+                q.device().device(),
+                &command_buffer,
+                q.device().kernels(),
+                q_l.start_offset(),
+                q_l.dims(),
+                q.buffer(),
+                k_l.start_offset(),
+                k_l.dims(),
+                k.buffer(),
+                v_l.start_offset(),
+                v.buffer(),
+                &output,
+                self.scale,
+                self.softcapping,
+                itype,
+            )
+            .map_err(candle::Error::wrap)?;
+        } else {
+            candle::bail!("must be vector or full sdpa kernel");
+        }
+
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
+        Ok((newstorage, Shape::from_dims(&out_dims)))
+    }
+}
+
+/// Scaled dot product attention with a fused kernel.
+///
+/// Computes softmax(qk^T*scale)v.
+///
+/// **Input shapes:**
+/// - `q`: (bs, qhead, seq, hidden)
+/// - `k`: (bs, kv_head, kv_seq, hidden)
+/// - `v`: (bs, kv_head, kv_seq, v_hidden)
+/// - `scale` is applied before softmax.
+/// - If `softcapping` != 1.0:
+///   - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
+///
+/// **Output shape:** (bs, qhead, seq, v_hidden)
+///
+/// **Supported head dims:** 32, 64, 96, 128, 256.
+///
+/// ## On Metal:
+/// - If `seq` == 1:
+///   - Use a vectorized kernel
+///   - Supports `seq` != `kv_seq` (cross attn.
support) +/// - Supports GQA when `qhead` is a multiple of `kv_head` +/// - Otherwise: +/// - Use an alternate kernel +/// - Requires `seq` == `kv_seq` +/// - GQA is not supported (requires `qhead` == `kv_head`) +pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result { + q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping }) +} diff --git a/candle-nn/tests/sdpa.rs b/candle-nn/tests/sdpa.rs new file mode 100644 index 0000000000..67ad3816b4 --- /dev/null +++ b/candle-nn/tests/sdpa.rs @@ -0,0 +1,206 @@ +#[cfg(feature = "metal")] +mod metal_sdpa_tests { + #[test] + fn sdpa_full() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Force seqlen = 100 + const BS: usize = 4; + const R: usize = 4; + const L: usize = 4; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0005, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_full_softcapping() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 4; + const L: usize = 4; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim( + &att.to_dtype(DType::F32)? + .div(SOFTCAP)? + .tanh()? + .mul(SOFTCAP)?, + )? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? 
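+            // The matmul above completes the unfused softcapped reference,
+            // softmax(tanh(q k^T * scale / SOFTCAP) * SOFTCAP) v, which the fused kernel must match.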
+ }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0004, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_softcapping() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + use std::ops::{Div, Mul}; + + // Allow vectorized, seqlen = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 1; + const DK: usize = 64; + const H: usize = 3; + const SOFTCAP: f64 = 50.; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim( + &att.to_dtype(DType::F32)? + .div(SOFTCAP)? + .tanh()? + .mul(SOFTCAP)?, + )? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, SOFTCAP as f32)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? + .to_scalar()?; + + assert!(error <= 0.0001, "{}", error); + + Ok(()) + } + + #[test] + fn sdpa_vector_cross() -> candle::Result<()> { + use candle::{DType, Device, Tensor}; + + // Allow vectorized, seqlen = 1. Simulat cross attention case where R != L, R = 1 + const BS: usize = 4; + const R: usize = 1; + const L: usize = 24; + const DK: usize = 64; + const H: usize = 3; + let scale: f64 = f64::from(DK as u32).sqrt().recip(); + + let device = Device::new_metal(0)?; + + let q = Tensor::randn(0f32, 1f32, (BS, H, R, DK), &device)?; + let k = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + let v = Tensor::randn(0f32, 1f32, (BS, H, L, DK), &device)?; + + let ground_truth = { + let att = (q.clone() * scale)?.matmul(&k.clone().t()?)?; + let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? + .to_dtype(q.dtype())?; + att.matmul(&v.clone())? + }; + + let sdpa_output = candle_nn::ops::sdpa(&q, &k, &v, scale as f32, 1.)?; + + assert_eq!(ground_truth.shape(), sdpa_output.shape()); + + let error: f32 = ((&ground_truth - &sdpa_output)?.abs()? / &ground_truth.abs()?)? + .sum_all()? 
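+            // `error` sums the element-wise relative error over the whole tensor,
+            // so the tolerance below bounds the sum rather than any single element.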
+ .to_scalar()?; + + assert!(error <= 0.0013, "{}", error); + + Ok(()) + } +} From 49c72556eb9888445af57e0e220f30c2640f09d3 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Tue, 29 Oct 2024 06:35:37 -0400 Subject: [PATCH 02/10] Conditional compilation for bf16 --- .vscode/settings.json | 3 ++- candle-metal-kernels/src/scaled_dot_product_attention.metal | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index b2dbd68012..383df783ce 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,6 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "rust-analyzer.cargo.features": ["metal"] } \ No newline at end of file diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index 8fc6d5ddfc..d65cc621ac 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -227,7 +227,9 @@ struct BlockSwizzle { // ============ "mlx/backend/metal/kernels/utils.h" +#if defined(__HAVE_BFLOAT__) typedef bfloat bfloat16_t; +#endif typedef half float16_t; METAL_FUNC ulong2 elem_to_loc_broadcast( @@ -1246,6 +1248,8 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); instantiate_sdpa_vector(type, 256) instantiate_sdpa_vector_heads(float) +#if defined(__HAVE_BFLOAT__) instantiate_sdpa_vector_heads(bfloat16_t) +#endif instantiate_sdpa_vector_heads(float16_t) // clang-format on From 10d357aba31151433e0dafc30b04029bfcd0536c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Tue, 29 Oct 2024 08:14:27 -0400 Subject: [PATCH 03/10] Use it in quantized llama --- .../src/models/quantized_llama.rs | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 6b326fbe92..4c0610af3f 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -205,21 +205,27 @@ impl LayerWeights { }; self.kv_cache = Some((k.clone(), v.clone())); - // Support for MQA, useful for 70B models and mistral. - let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; - let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; - - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let att = match mask { - None => att, - Some(mask) => { - let mask = mask.broadcast_as(att.shape())?; - masked_fill(&att, &mask, &self.neg_inf)? - } + let y = if q.device().is_metal() && seq_len == 1 { + // SDPA will do MQA for us + candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)? + } else { + // Support for MQA, useful for 70B models and mistral. + let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)? 
}; - let att = candle_nn::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = self.attention_wo.forward(&y)?; Ok(y) From 018bda719f675b127bd3398ba684175ca425cc97 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 31 Oct 2024 08:57:26 -0400 Subject: [PATCH 04/10] Some review comments --- .vscode/settings.json | 1 - candle-nn/src/ops.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 383df783ce..36c8eb1c90 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,5 +8,4 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, - "rust-analyzer.cargo.features": ["metal"] } \ No newline at end of file diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index a57b988e03..15a5fbec70 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -956,7 +956,7 @@ struct Sdpa { impl candle::CustomOp3 for Sdpa { fn name(&self) -> &'static str { - "sdpa" + "metal-sdpa" } fn cpu_fwd( From fa9cef3513275317d5f6bf46cf74f0cc1c2cc41b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 31 Oct 2024 09:13:09 -0400 Subject: [PATCH 05/10] Use set_params! --- candle-metal-kernels/src/lib.rs | 82 ++++++++----------- .../src/scaled_dot_product_attention.metal | 4 +- 2 files changed, 36 insertions(+), 50 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 28cbc5499d..9ff86e6475 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -8,7 +8,7 @@ use std::sync::RwLock; pub mod utils; pub use utils::BufferOffset; -use utils::{get_block_dims, linear_split, EncoderProvider}; +use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; const AFFINE: &str = include_str!("affine.metal"); const BINARY: &str = include_str!("binary.metal"); @@ -1809,25 +1809,27 @@ pub fn call_sdpa_full( }; let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o]; - encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger); - encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger); - encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger); - encoder.set_buffer(3, Some(&output), 0); + impl EncoderParam for MLXFastAttentionParams { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::() as u64, + &data as *const MLXFastAttentionParams as *const c_void, + ); + } + } - encoder.set_bytes( - 4, - std::mem::size_of::() as u64, - ¶ms as *const MLXFastAttentionParams as *const c_void, - ); - encoder.set_bytes( - 6, - (std::mem::size_of::() * batch_shape.len()) as u64, - batch_shape.as_ptr() as *const i32 as *const c_void, - ); - encoder.set_bytes( - 7, - (std::mem::size_of::() * batch_strides.len()) as u64, - batch_strides.as_ptr() as *const c_void, + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + params, + &batch_shape[..], + &batch_strides[..] 
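            // set_params! binds these arguments to consecutive buffer indices starting at 0,
            // which is why the [[buffer(N)]] attributes for batch_shape/batch_strides in the
            // Metal kernel move from 6/7 to 5/6 in this patch.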
+ ) ); let grid_dims = MTLSize { @@ -1917,35 +1919,19 @@ pub fn call_sdpa_vector( // q = (bs, qhead, seq, hidden) // k/v = (bs, kv_head, kv_seq, hidden) - encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger); - encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger); - encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger); - encoder.set_buffer(3, Some(&output), 0); - - encoder.set_bytes( - 4, - std::mem::size_of::() as u64, - &gqa_factor as *const i32 as *const c_void, - ); - encoder.set_bytes( - 5, - std::mem::size_of::() as u64, - &n as *const i32 as *const c_void, - ); - encoder.set_bytes( - 6, - std::mem::size_of::() as u64, - &stride as *const usize as *const c_void, - ); - encoder.set_bytes( - 7, - std::mem::size_of::() as u64, - &alpha as *const f32 as *const c_void, - ); - encoder.set_bytes( - 8, - std::mem::size_of::() as u64, - &softcapping as *const f32 as *const c_void, + set_params!( + encoder, + ( + (q_buffer, q_offset), + (k_buffer, k_offset), + (v_buffer, v_offset), + output, + gqa_factor, + n, + stride, + alpha, + softcapping + ) ); let grid_dims = MTLSize { diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index d65cc621ac..69815abd0d 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -1170,8 +1170,8 @@ template < const device itype* V [[buffer(2)]], \ device otype* O [[buffer(3)]], \ const constant MLXFastAttentionParams* params [[buffer(4)]], \ - const constant int* batch_shape [[buffer(6)]], \ - const constant size_t* batch_strides [[buffer(7)]], \ + const constant int* batch_shape [[buffer(5)]], \ + const constant size_t* batch_strides [[buffer(6)]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint3 tid [[threadgroup_position_in_grid]], \ From 56420cc5ecc6827af11d0a20c8eafa311f1b0f48 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Thu, 31 Oct 2024 09:14:56 -0400 Subject: [PATCH 06/10] Remove unused --- .vscode/settings.json | 1 + candle-metal-kernels/src/lib.rs | 7 ------- candle-nn/src/ops.rs | 1 - 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 36c8eb1c90..383df783ce 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,4 +8,5 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, + "rust-analyzer.cargo.features": ["metal"] } \ No newline at end of file diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 9ff86e6475..f190c185d4 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1663,7 +1663,6 @@ pub fn call_sdpa_full( q_shape: &[usize], q_buffer: &Buffer, k_offset: usize, - k_shape: &[usize], k_buffer: &Buffer, v_offset: usize, v_buffer: &Buffer, @@ -1744,13 +1743,7 @@ pub fn call_sdpa_full( // q = (bs, qhead, seq, hidden) // k/v = (bs, kv_head, seq, hidden) - let _hidden = q_shape[q_shape.len() - 1]; let qseq = q_shape[q_shape.len() - 2]; - let _qheads = q_shape[q_shape.len() - 3]; - - let _kvseq = k_shape[k_shape.len() - 2]; - let _nq_heads = q_shape[1]; - let _nkv_heads = k_shape[1]; let m = q_shape[q_shape.len() - 2]; let n = m; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 15a5fbec70..efb1a83e23 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1090,7 +1090,6 @@ impl candle::CustomOp3 for 
Sdpa { q_l.dims(), q.buffer(), k_l.start_offset(), - k_l.dims(), k.buffer(), v_l.start_offset(), v.buffer(), From 4124ae0866b6d52dd793f873af845f283be36ca1 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 4 Nov 2024 05:26:42 -0500 Subject: [PATCH 07/10] Remove feature --- .vscode/settings.json | 1 - 1 file changed, 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 383df783ce..36c8eb1c90 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,5 +8,4 @@ ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, - "rust-analyzer.cargo.features": ["metal"] } \ No newline at end of file From 7d85ffcecd2ca61c5d17777030a98131097d12a8 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 4 Nov 2024 11:35:53 -0500 Subject: [PATCH 08/10] Fix metal sdpa for v stride --- candle-metal-kernels/src/lib.rs | 7 +++++-- .../src/scaled_dot_product_attention.metal | 4 +++- candle-nn/src/ops.rs | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index f190c185d4..0843cc1179 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1860,6 +1860,7 @@ pub fn call_sdpa_vector( k_stride: &[usize], k_buffer: &Buffer, v_offset: usize, + v_stride: &[usize], v_buffer: &Buffer, output: &Buffer, alpha: f32, @@ -1871,7 +1872,8 @@ pub fn call_sdpa_vector( let gqa_factor = (q_shape[1] / k_shape[1]) as i32; let n = k_shape[2] as i32; let b = (q_shape[0] * q_shape[1]) as i32; - let stride = k_stride[1]; + let kstride = k_stride[1]; + let vstride = v_stride[1]; let name = match (bk, itype) { (32, SdpaDType::F16) => "sdpa_vector_float16_t_32", @@ -1921,7 +1923,8 @@ pub fn call_sdpa_vector( output, gqa_factor, n, - stride, + kstride, + vstride, alpha, softcapping ) diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index 69815abd0d..1abb9f080a 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -56,6 +56,7 @@ template const constant int& gqa_factor, const constant int& N, const constant size_t& k_stride, + const constant size_t& v_stride, const constant float& scale, const constant float& softcapping, uint3 tid [[threadgroup_position_in_grid]], @@ -82,7 +83,7 @@ template const int kv_head_idx = head_idx / gqa_factor; queries += head_idx * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; - values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -1234,6 +1235,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant int& gqa_factor, \ const constant int& N, \ const constant size_t& k_stride, \ + const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ uint3 tid [[threadgroup_position_in_grid]], \ diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index efb1a83e23..95aec1149f 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1067,6 +1067,7 @@ impl candle::CustomOp3 for Sdpa { k_l.stride(), k.buffer(), v_l.start_offset(), + v_l.stride(), v.buffer(), &output, self.scale, From a2471b197c81aa1ae4de37f96c7772dfecd1d581 Mon 
Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 4 Nov 2024 11:40:37 -0500 Subject: [PATCH 09/10] Remove comma --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 36c8eb1c90..b2dbd68012 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,5 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, + "python.testing.pytestEnabled": true } \ No newline at end of file From fcd2cb31764f8c1a5f5da015ba92e7f56aa16dcd Mon Sep 17 00:00:00 2001 From: Laurent Date: Tue, 5 Nov 2024 09:27:00 +0100 Subject: [PATCH 10/10] Add the dim method to layout and shape. --- candle-core/src/layout.rs | 6 ++++++ candle-core/src/shape.rs | 6 ++++++ candle-nn/src/ops.rs | 16 ++++++++-------- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index e6824b29a6..7e3b7afbba 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -35,6 +35,12 @@ impl Layout { self.shape.dims() } + /// The dimension size for a specified dimension index. + pub fn dim(&self, dim: D) -> Result { + let dim = dim.to_index(&self.shape, "dim")?; + Ok(self.dims()[dim]) + } + pub fn shape(&self) -> &Shape { &self.shape } diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 90a37be663..ca05d216a5 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -142,6 +142,12 @@ impl Shape { &self.0 } + /// The dimension size for a specified dimension index. + pub fn dim(&self, dim: D) -> Result { + let dim = dim.to_index(self, "dim")?; + Ok(self.dims()[dim]) + } + /// The total number of elements, this is the product of all dimension sizes. pub fn elem_count(&self) -> usize { self.0.iter().product() diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 95aec1149f..3e9a1029bd 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -986,29 +986,29 @@ impl candle::CustomOp3 for Sdpa { let device = q.device(); - let out_dims = vec![q_l.dims()[0], q_l.dims()[1], q_l.dims()[2], v_l.dims()[3]]; + let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?]; let elem_count: usize = out_dims.iter().product(); let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?; // q,k must have matching emb dim - if q_l.dims()[q_l.dims().len() - 1] != k_l.dims()[k_l.dims().len() - 1] { + if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? { candle::bail!("`q` and `k` last dims must match"); } // k,v must have matching n kv heads - if v_l.dims()[v_l.dims().len() - 3] != k_l.dims()[k_l.dims().len() - 3] { + if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? { candle::bail!("`k` and `v` head dims must match"); } // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1. - if q_l.dims()[q_l.dims().len() - 3] % k_l.dims()[k_l.dims().len() - 3] != 0 { + if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 { candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`"); } - let k_head = k_l.dims()[k_l.dims().len() - 1]; - let q_head = q_l.dims()[q_l.dims().len() - 1]; - let q_seq = q_l.dims()[2]; + let k_head = k_l.dim(D::Minus1)?; + let q_head = q_l.dim(D::Minus1)?; + let q_seq = q_l.dim(2)?; let mut implementation_supports_use_case = q_head == k_head; let supported_head_dim = @@ -1076,7 +1076,7 @@ impl candle::CustomOp3 for Sdpa { ) .map_err(candle::Error::wrap)?; } else if supports_sdpa_full { - if q_l.dims()[2] != k_l.dims()[2] { + if q_l.dim(2)? != k_l.dim(2)? 
{ candle::bail!( "query and key sequence length must be equal if using full metal sdpa" )
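                // The full kernel only covers the square self-attention case; decode-style
                // shapes (q_seq == 1) are routed to the vectorized kernel instead, which also
                // supports grouped/multi-query KV heads. A decode-style call that takes that
                // path looks roughly like this (shapes illustrative, Metal device assumed):
                //   q: (bs, n_head, 1, head_dim), k/v: (bs, n_kv_head, kv_seq, head_dim)
                //   let out = candle_nn::ops::sdpa(&q, &k, &v, 1. / (head_dim as f32).sqrt(), 1.)?;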