From c086d0846158cb023b22f4459237a1907d86b646 Mon Sep 17 00:00:00 2001
From: Vinayak Gokhale
Date: Wed, 18 Dec 2024 11:31:02 -0600
Subject: [PATCH] Load scales instead of constexpr (#684)

Load scales from global memory. Scales are typically not provided as
constexprs by users.
---
 python/perf-kernels/gemm.py | 36 ++++++++++++++++++++++--------------
 1 file changed, 22 insertions(+), 14 deletions(-)

diff --git a/python/perf-kernels/gemm.py b/python/perf-kernels/gemm.py
index bf8b7f657c36..21fbbcf8cc0a 100644
--- a/python/perf-kernels/gemm.py
+++ b/python/perf-kernels/gemm.py
@@ -55,7 +55,8 @@ def matmul_kernel(
         stride_bn,
         stride_cm,
         stride_cn,
-        scale,
+        a_scale_ptr,
+        b_scale_ptr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
@@ -92,6 +93,9 @@ def matmul_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    if APPLY_SCALE:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr)

     acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
@@ -110,12 +114,13 @@ def matmul_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
+    # Apply scale to recover dynamic range reduced due to lower precision inputs.
+    if APPLY_SCALE:
+        accumulator = accumulator * a_scale * b_scale
     # Apply activation function, if specified.
+    # TODO(vgokhale): Add different types of activations.
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    # Apply scale to recover dynamic range reduced due to lower precision inputs.
-    if APPLY_SCALE:
-        accumulator = accumulator * scale
     c = accumulator.to(c_ptr.type.element_ty)

     # Write back the block of the output matrix C with masks.
@@ -134,15 +139,13 @@ def leaky_relu(x):


 # Wrapper for gemm kernel.
-def matmul(a, b, c, a_scale, b_scale, activation=""):
+def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
     assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!"
     M, K = a.shape
     K, N = b.shape
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-    apply_scale = a_scale is not None and b_scale is not None
-    scale = a_scale * b_scale if apply_scale else None
     matmul_kernel[grid](
         a,
         b,
@@ -156,8 +159,9 @@ def matmul(a, b, c, a_scale, b_scale, activation=""):
         b.stride(1),
         c.stride(0),
         c.stride(1),
-        scale,
-        APPLY_SCALE=apply_scale,
+        a_scale,
+        b_scale,
+        APPLY_SCALE=scale_a8_b8,
         ACTIVATION=activation,
     )

@@ -173,9 +177,12 @@ def matmul(a, b, c, a_scale, b_scale, activation=""):
 }

 dtype_max = {
-    torch.float8_e5m2fnuz: 57344,
-    torch.float8_e4m3fnuz: 240,
-    torch.int8: 127,
+    dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max
+    for dtype in [
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fnuz,
+        torch.int8,
+    ]
 }


@@ -213,6 +220,7 @@ def get_x_vals():


 # Unit tests
+#TODO(vgokhale): Test activation.
 @pytest.mark.parametrize(
     "M, N, K, in_dtype, out_dtype, col_a, col_b",
     [(*shape, in_dtype, out_dtype, col_a, col_b)
@@ -232,12 +240,12 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
     # This requires us to compute in fp32 because for e5m2, the range is same as fp16 (e5m10).
     # If we use fp16 it is possible to return infs from the torch.matmul call.
     if dtype_is_8_bit(torch_in_dtype):
-        matmul(a, b, c, a_scale.item(), b_scale.item(), activation="")
+        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
         torch_output = torch.matmul(a_fp32, b_fp32)
         torch_output = torch_output * a_scale * b_scale
     # For other dtypes, use the same torch matmul as the dtype.
     else:
-        matmul(a, b, c, a_scale=None, b_scale=None, activation="")
+        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
         torch_output = torch.matmul(a.to(torch_in_dtype), b.to(torch_in_dtype))
     if out_dtype == 'int8':
         torch.testing.assert_close(c.to(torch.float32),
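Usage note (not part of the patch): a minimal sketch of how a caller drives the new
wrapper signature after this change. The shapes, fp8 dtype, and import path below are
illustrative assumptions; the point is that a_scale and b_scale now stay on the device
as small tensors and are read inside the kernel via tl.load, instead of being collapsed
on the host into a single scale passed as a constexpr.

    import torch
    from gemm import matmul  # assumed import path for the wrapper defined above

    M, N, K = 512, 512, 256
    # Illustrative fp8 inputs; other dtype pairs accepted by the kernel work the same way.
    a = torch.randn((M, K), device='cuda', dtype=torch.float32).to(torch.float8_e4m3fnuz)
    b = torch.randn((K, N), device='cuda', dtype=torch.float32).to(torch.float8_e4m3fnuz)
    c = torch.empty((M, N), device='cuda', dtype=torch.float16)

    # Per-tensor scales live in global memory as single-element tensors,
    # not as Python scalars baked into the kernel at compile time.
    a_scale = torch.tensor([0.5], device='cuda', dtype=torch.float32)
    b_scale = torch.tensor([0.25], device='cuda', dtype=torch.float32)

    # scale_a8_b8=True makes the kernel load both scales and apply them to the
    # accumulator before the optional activation.
    matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")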