Skip to content

Commit

Permalink
Fix more lints.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 28, 2024
1 parent 3c82d21 commit 0b29ef3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
20 changes: 10 additions & 10 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ pub fn call_unary_contiguous_tiled(
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = (length + tile_size - 1) / tile_size;
let tiles = length.div_ceil(tile_size);

encoder.set_compute_pipeline_state(&pipeline);

Expand Down Expand Up @@ -594,7 +594,7 @@ pub fn call_reduce_contiguous(

let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
(elements_to_sum as u64 + 2 - 1) / 2,
(elements_to_sum as u64).div_ceil(2),
)
.next_power_of_two();

Expand Down Expand Up @@ -1735,7 +1735,7 @@ pub fn call_sdpa_full(
}
};

let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
Expand All @@ -1759,16 +1759,16 @@ pub fn call_sdpa_full(
let ldo = dk;

let tn = 1;
let tm = (m + BM - 1) / BM;
let tm = m.div_ceil(BM);

let b_stride_q = dk * qseq;
let b_stride_k = dk * qseq;
let b_stride_v = dk * qseq;
let b_stride_o = dk * qseq;
let swizzle_log = 0;
let gemm_n_iterations_aligned = (n + BN - 1) / BN;
let gemm_k_iterations_aligned = (k + bk - 1) / bk;
let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
let gemm_n_iterations_aligned = n.div_ceil(BN);
let gemm_k_iterations_aligned = k.div_ceil(*bk);
let gemm_sv_m_block_iterations = m.div_ceil(BM);
let batch_ndim = batch_shape.len();

let alpha = if softcapping != 1. {
Expand Down Expand Up @@ -1906,7 +1906,7 @@ pub fn call_sdpa_vector(
alpha
};

let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
Expand All @@ -1933,7 +1933,7 @@ pub fn call_sdpa_vector(
let grid_dims = MTLSize {
width: 1,
height: b as u64,
depth: 1 as u64,
depth: 1_u64,
};
let group_dims = MTLSize {
width: 1024,
Expand Down Expand Up @@ -2320,7 +2320,7 @@ pub fn call_quantized_matmul_mv_t(
}

fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
m.div_ceil(b) as NSUInteger
}

#[allow(clippy::too_many_arguments)]
Expand Down
8 changes: 4 additions & 4 deletions candle-metal-kernels/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::ffi::c_void;
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
let size = length as u64;
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
let count = (size + width - 1) / width;
let count = size.div_ceil(width);
let thread_group_count = MTLSize {
width: count,
height: 1,
Expand Down Expand Up @@ -128,7 +128,7 @@ impl EncoderParam for (&Buffer, usize) {
}
}

impl<'a> EncoderParam for &BufferOffset<'a> {
impl EncoderParam for &BufferOffset<'_> {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
}
Expand Down Expand Up @@ -169,15 +169,15 @@ pub struct WrappedEncoder<'a> {
end_encoding_on_drop: bool,
}

impl<'a> Drop for WrappedEncoder<'a> {
impl Drop for WrappedEncoder<'_> {
fn drop(&mut self) {
if self.end_encoding_on_drop {
self.inner.end_encoding()
}
}
}

impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
impl AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'_> {
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
self.inner
}
Expand Down

0 comments on commit 0b29ef3

Please sign in to comment.