From cff56342e3dd8e7e8baef384eadefc08d05d6a00 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 16 May 2024 19:39:08 +0800 Subject: [PATCH] Optimize w8a8 kernel (#1353) * optimize autotune * add kernel meta --- .../pytorch/kernels/w8a8_triton_kernels.py | 147 ++++++++---------- 1 file changed, 66 insertions(+), 81 deletions(-) diff --git a/lmdeploy/pytorch/kernels/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/w8a8_triton_kernels.py index 0e3115029e..da65599bef 100644 --- a/lmdeploy/pytorch/kernels/w8a8_triton_kernels.py +++ b/lmdeploy/pytorch/kernels/w8a8_triton_kernels.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import triton import triton.language as tl +from triton.runtime.jit import get_cuda_stream def per_channel_quant(x, n_bits, dtype): @@ -33,42 +34,19 @@ def per_channel_quant(x, n_bits, dtype): @triton.autotune( configs=[ triton.Config({ - 'BLOCK_M': 16, - 'BLOCK_N': 128, - 'BLOCK_K': 256, - }, - num_stages=4, - num_warps=4), - triton.Config({ - 'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128, }, num_stages=4, num_warps=4), triton.Config({ - 'BLOCK_M': 64, - 'BLOCK_N': 64, - 'BLOCK_K': 128, - }, - num_stages=4, - num_warps=4), - triton.Config({ - 'BLOCK_M': 64, - 'BLOCK_N': 128, - 'BLOCK_K': 128, - }, - num_stages=4, - num_warps=4), - triton.Config({ - 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, }, num_stages=4, num_warps=4) ], - key=['M', 'N', 'K'], + key=['N', 'K'], ) @triton.jit def _linear( @@ -138,42 +116,19 @@ def _linear( @triton.autotune( configs=[ triton.Config({ - 'BLOCK_M': 16, - 'BLOCK_N': 128, - 'BLOCK_K': 256, - }, - num_stages=4, - num_warps=4), - triton.Config({ - 'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128, }, num_stages=4, num_warps=4), triton.Config({ - 'BLOCK_M': 64, - 'BLOCK_N': 64, - 'BLOCK_K': 128, - }, - num_stages=4, - num_warps=4), - triton.Config({ - 'BLOCK_M': 64, - 'BLOCK_N': 128, - 'BLOCK_K': 128, - }, - num_stages=4, - num_warps=4), - triton.Config({ - 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, }, num_stages=4, num_warps=4) ], - key=['M', 'N', 'K'], + key=['N', 'K'], ) @triton.jit def _linear_add( @@ -258,21 +213,33 @@ def matmul_kernel_dynamic_quant(a, `linear_scale`, and optionally adds a `residual` tensor and a `bias`. The output is returned in the specified `output_dtype`. 
""" + + def __kernel_meta(): + device = a.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + assert a.shape[-1] == b.shape[-1] assert b.ndim == 2 and b.is_contiguous() - b = b.t() # (K, N) M = a.numel() // a.shape[-1] - K, N = b.shape + N, K = b.shape c_shape = a.shape[:-1] + (N, ) if residual is not None: assert residual.shape == c_shape assert residual.is_contiguous() - c = torch.empty(c_shape, device=a.device, dtype=output_dtype) + c = a.new_empty(c_shape, dtype=output_dtype) + + BLOCK_M = 128 + if M < BLOCK_M: + BLOCK_M = triton.next_power_of_2(M) + BLOCK_M = max(BLOCK_M, 16) def grid(META): - return (triton.cdiv(M, META['BLOCK_M']) * - triton.cdiv(N, META['BLOCK_N']), ) + return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), ) + kernel_meta = __kernel_meta() if residual is not None: _linear_add[grid](a, b, @@ -283,13 +250,15 @@ def grid(META): K, a.stride(-2), a.stride(-1), - b.stride(0), b.stride(1), + b.stride(0), c.stride(-2), c.stride(-1), + BLOCK_M=BLOCK_M, GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, - linear_scale_ptr=linear_scale) + linear_scale_ptr=linear_scale, + **kernel_meta) else: _linear[grid](a, b, @@ -299,13 +268,15 @@ def grid(META): K, a.stride(-2), a.stride(-1), - b.stride(0), b.stride(1), + b.stride(0), c.stride(-2), c.stride(-1), + BLOCK_M=BLOCK_M, GROUP_SIZE_M=8, rms_scale_ptr=rms_scale, - linear_scale_ptr=linear_scale) + linear_scale_ptr=linear_scale, + **kernel_meta) if bias is not None: c += bias @@ -340,7 +311,7 @@ def _per_token_quant_int8( # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / 127 - y_q = tl.maximum(tl.minimum(tl.math.round(y / y_s), 127), -128).to(tl.int8) + y_q = tl.math.round(y / y_s).to(tl.int8) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) @@ -352,6 +323,14 @@ def per_token_quant_int8(x, eps): It converts the tensor values into signed 8-bit integers and returns the quantized tensor along with the scaling factor used for quantization. 
""" + + def __kernel_meta(): + device = x.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) M = x.numel() // x.shape[-1] N = x.shape[-1] @@ -362,6 +341,7 @@ def per_token_quant_int8(x, eps): # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) # enqueue kernel + kernel_meta = __kernel_meta() _per_token_quant_int8[(M, )](x, x_q, x_s, @@ -369,7 +349,9 @@ def per_token_quant_int8(x, eps): N, eps, BLOCK=BLOCK, - num_warps=num_warps) + num_warps=num_warps, + **kernel_meta) + return x_q, x_s @@ -389,18 +371,15 @@ def _rms_norm_fwd_fused_dynamic_symmetric( row = tl.program_id(0) Y += row * stride X += row * stride - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) cols = tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - _var += x * x + mask = cols < N + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + _var = x * x var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) + rstd = tl.math.rsqrt(var + eps) - cols = tl.arange(0, BLOCK_SIZE) - mask = cols < N w = tl.load(W + cols, mask=mask) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) x_hat = x * rstd y = x_hat * w @@ -420,7 +399,15 @@ def rms_norm_dynamic_quant(x, w, eps): with the same shape as `x`, and calculates RMS normalization on the reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`. """ - x_arg = x.reshape(-1, x.shape[-1]) + + def __kernel_meta(): + device = x.device + device_idx = device.index + device_type = device.type + stream = get_cuda_stream(device_idx) + return dict(device=device, device_type=device_type, stream=stream) + + x_arg = x.flatten(0, -2) y = torch.empty_like(x, dtype=torch.int8) M, K = x_arg.shape MAX_FUSED_SIZE = 65536 // x.element_size() @@ -429,20 +416,18 @@ def rms_norm_dynamic_quant(x, w, eps): raise RuntimeError( "This rms norm doesn't support feature dim >= 64KB.") num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - scale = torch.empty(x.shape[:-1] + (1, ), - dtype=torch.float32, - device=x.device) - _rms_norm_fwd_fused_dynamic_symmetric[(M, )]( - x_arg, - y, - w, - scale, - x_arg.stride(0), - K, - eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, - ) + scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32) + kernel_meta = __kernel_meta() + _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg, + y, + w, + scale, + x_arg.stride(0), + K, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + **kernel_meta) return y, scale