Optimize w8a8 kernel (#1353)
* optimize autotune

* add kernel meta
grimoire authored May 16, 2024
1 parent 8f79144 commit cff5634
Showing 1 changed file with 66 additions and 81 deletions: lmdeploy/pytorch/kernels/w8a8_triton_kernels.py
@@ -3,6 +3,7 @@
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.runtime.jit import get_cuda_stream


def per_channel_quant(x, n_bits, dtype):
@@ -33,42 +34,19 @@ def per_channel_quant(x, n_bits, dtype):
@triton.autotune(
configs=[
triton.Config({
'BLOCK_M': 16,
'BLOCK_N': 128,
'BLOCK_K': 256,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4)
],
key=['M', 'N', 'K'],
key=['N', 'K'],
)
@triton.jit
def _linear(
@@ -138,42 +116,19 @@ def _linear(
@triton.autotune(
configs=[
triton.Config({
'BLOCK_M': 16,
'BLOCK_N': 128,
'BLOCK_K': 256,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4)
],
key=['M', 'N', 'K'],
key=['N', 'K'],
)
@triton.jit
def _linear_add(
@@ -258,21 +213,33 @@ def matmul_kernel_dynamic_quant(a,
`linear_scale`, and optionally adds a `residual` tensor and a `bias`. The
output is returned in the specified `output_dtype`.
"""

def __kernel_meta():
device = a.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)

assert a.shape[-1] == b.shape[-1]
assert b.ndim == 2 and b.is_contiguous()
b = b.t() # (K, N)
M = a.numel() // a.shape[-1]
K, N = b.shape
N, K = b.shape
c_shape = a.shape[:-1] + (N, )
if residual is not None:
assert residual.shape == c_shape
assert residual.is_contiguous()
c = torch.empty(c_shape, device=a.device, dtype=output_dtype)
c = a.new_empty(c_shape, dtype=output_dtype)

BLOCK_M = 128
if M < BLOCK_M:
BLOCK_M = triton.next_power_of_2(M)
BLOCK_M = max(BLOCK_M, 16)

def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) *
triton.cdiv(N, META['BLOCK_N']), )
return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), )

kernel_meta = __kernel_meta()
if residual is not None:
_linear_add[grid](a,
b,
@@ -283,13 +250,15 @@ def grid(META):
K,
a.stride(-2),
a.stride(-1),
b.stride(0),
b.stride(1),
b.stride(0),
c.stride(-2),
c.stride(-1),
BLOCK_M=BLOCK_M,
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale)
linear_scale_ptr=linear_scale,
**kernel_meta)
else:
_linear[grid](a,
b,
@@ -299,13 +268,15 @@ def grid(META):
K,
a.stride(-2),
a.stride(-1),
b.stride(0),
b.stride(1),
b.stride(0),
c.stride(-2),
c.stride(-1),
BLOCK_M=BLOCK_M,
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale)
linear_scale_ptr=linear_scale,
**kernel_meta)
if bias is not None:
c += bias
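
The `__kernel_meta` helper above resolves the CUDA device and stream once on the host (via `triton.runtime.jit.get_cuda_stream`) and forwards them to the launch as keyword arguments, presumably so the Triton launcher can skip its per-call device/stream lookup; dropping `M` from the autotune key and choosing `BLOCK_M` on the host likewise keeps a varying token count from re-triggering autotuning. Below is a minimal sketch of the same launch pattern on a toy kernel, assuming the Triton version used by this patch (where `get_cuda_stream` is importable and the launcher accepts `device`, `device_type` and `stream`); the names `_scale_kernel` and `scale_tensor` are illustrative, not part of lmdeploy.

import torch
import triton
import triton.language as tl
from triton.runtime.jit import get_cuda_stream


@triton.jit
def _scale_kernel(x_ptr, y_ptr, scale, N, BLOCK: tl.constexpr):
    # Toy element-wise kernel: y = x * scale.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x * scale, mask=mask)


def scale_tensor(x, scale):
    # Resolve device and stream once on the host and hand them to the
    # launcher, instead of letting every launch query them again.
    device = x.device
    stream = get_cuda_stream(device.index)
    kernel_meta = dict(device=device, device_type=device.type, stream=stream)

    y = torch.empty_like(x)
    N = x.numel()
    BLOCK = 1024
    grid = (triton.cdiv(N, BLOCK), )
    _scale_kernel[grid](x, y, scale, N, BLOCK=BLOCK, **kernel_meta)
    return y
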

@@ -340,7 +311,7 @@ def _per_token_quant_int8(
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / 127
y_q = tl.maximum(tl.minimum(tl.math.round(y / y_s), 127), -128).to(tl.int8)
y_q = tl.math.round(y / y_s).to(tl.int8)

tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@@ -352,6 +323,14 @@ def per_token_quant_int8(x, eps):
It converts the tensor values into signed 8-bit integers and returns the
quantized tensor along with the scaling factor used for quantization.
"""

def __kernel_meta():
device = x.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)

x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
@@ -362,14 +341,17 @@ def per_token_quant_int8(x, eps):
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
# enqueue kernel
kernel_meta = __kernel_meta()
_per_token_quant_int8[(M, )](x,
x_q,
x_s,
x.stride(-2),
N,
eps,
BLOCK=BLOCK,
num_warps=num_warps)
num_warps=num_warps,
**kernel_meta)

return x_q, x_s
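
For reference, the per-token path picks one scale per row as `absmax / 127` and rounds; since the scale comes from the row's own absolute maximum, the rounded values already lie in [-127, 127], which is presumably why the explicit clamp could be dropped. A hedged plain-PyTorch equivalent follows (the name `per_token_quant_int8_ref` is illustrative, not from the patch).

import torch


def per_token_quant_int8_ref(x: torch.Tensor, eps: float):
    # One scale per token (row): row absmax / 127, floored by eps as in the kernel.
    absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / 127
    # |x / scale| <= 127 by construction, so rounding cannot overflow int8.
    q = torch.round(x.float() / scale).to(torch.int8)
    return q, scale
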


@@ -389,18 +371,15 @@ def _rms_norm_fwd_fused_dynamic_symmetric(
row = tl.program_id(0)
Y += row * stride
X += row * stride
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

cols = tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
_var += x * x
mask = cols < N
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
_var = x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
rstd = tl.math.rsqrt(var + eps)

cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x_hat = x * rstd
y = x_hat * w

@@ -420,7 +399,15 @@ def rms_norm_dynamic_quant(x, w, eps):
with the same shape as `x`, and calculates RMS normalization on the
reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`.
"""
x_arg = x.reshape(-1, x.shape[-1])

def __kernel_meta():
device = x.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)

x_arg = x.flatten(0, -2)
y = torch.empty_like(x, dtype=torch.int8)
M, K = x_arg.shape
MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -429,20 +416,18 @@ def rms_norm_dynamic_quant(x, w, eps):
raise RuntimeError(
"This rms norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
scale = torch.empty(x.shape[:-1] + (1, ),
dtype=torch.float32,
device=x.device)
_rms_norm_fwd_fused_dynamic_symmetric[(M, )](
x_arg,
y,
w,
scale,
x_arg.stride(0),
K,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)
kernel_meta = __kernel_meta()
_rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg,
y,
w,
scale,
x_arg.stride(0),
K,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_meta)
return y, scale
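
The fused kernel normalizes each row with `rsqrt(mean(x^2) + eps)`, applies the weight, and (in the part of the file collapsed here) quantizes the result to int8 with a per-row symmetric scale. A hedged PyTorch reference follows, assuming the quantization step mirrors the `absmax / 127` scheme of the per-token kernel above; the name `rms_norm_dynamic_quant_ref` is illustrative.

import torch


def rms_norm_dynamic_quant_ref(x: torch.Tensor, w: torch.Tensor, eps: float):
    xf = x.float()
    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = xf * rstd * w.float()  # RMSNorm output before quantization
    # Per-row symmetric int8 scale; the tiny floor guards an all-zero row
    # (the kernel's exact guard is not visible in this hunk).
    scale = y.abs().amax(dim=-1, keepdim=True).clamp_min(1e-7) / 127
    y_q = torch.round(y / scale).to(torch.int8)
    return y_q, scale
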

