diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md
index e1edcd81090f..b142030af087 100644
--- a/python/perf-kernels/README.md
+++ b/python/perf-kernels/README.md
@@ -82,4 +82,4 @@ Kernel that implements RMS Norm over a row of tensor.
 Kernel that implements Layer Normalization over a row on tensor
 
 ## `fused_moe/moe-gemm.py`
-Kernel that implements moe gemm
\ No newline at end of file
+Kernel that implements moe gemm
diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py
index 9faa81598e59..7a4b6940484b 100644
--- a/python/perf-kernels/fused_moe/moe-gemm.py
+++ b/python/perf-kernels/fused_moe/moe-gemm.py
@@ -93,9 +93,7 @@ def moe_gemm_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -105,8 +103,10 @@ def moe_gemm_kernel(
     # tl.store(out_ptrs, accumulator, mask=c_mask)
     tl.store(out_ptrs, accumulator, mask=c_mask)
 
-def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int, sorted_token_ids: torch.Tensor,
-                          expert_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None:
+
+def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int,
+                          sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor,
+                          num_tokens_post_pad: torch.Tensor) -> None:
     M, top_k = topk_ids.shape
 
     # 1) Build a list of tokens for each expert
@@ -132,7 +132,7 @@ def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int,
         # Reorder all actual tokens for expert e_id
         reordered_token_ids.extend(tokens_for_expert)
         # reordered_expert_ids.extend([e_id]*num_tokens)
-        reordered_expert_ids.extend([e_id]*n_blocks)
+        reordered_expert_ids.extend([e_id] * n_blocks)
 
         # Pad with dummy token_id = -1 (or any sentinel), if needed
         if padded_size > num_tokens:
@@ -192,7 +192,7 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
     sorted_ids.fill_(topk_ids.numel())
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     # TODO do we need to predefine the vars? may be they can be defined in the function and reurned
-    _moe_align_block_size(topk_ids, num_experts, top_k ,block_size, sorted_ids, expert_ids, num_tokens_post_pad)
+    _moe_align_block_size(topk_ids, num_experts, top_k, block_size, sorted_ids, expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
@@ -279,6 +279,7 @@ def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = Fa
 
 def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor) -> None:
+    # TODO get config type once quantization is enabled
     config_dtype = None
     get_config_func = functools.partial(
        try_get_optimal_moe_config,
@@ -295,11 +296,11 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to
     EM = num_tokens_post_padded.item()
     _, N, K = b.shape
 
-    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
+    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
     moe_gemm_kernel[grid](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1),
-                          c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N,
-                          K, MUL_ROUTED_WEIGHT=topk_weights is not None, **config)
+                          c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N, K,
+                          MUL_ROUTED_WEIGHT=topk_weights is not None, **config)
 
     return c
 
@@ -308,11 +309,6 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
     b = torch.randn((E, N, K), dtype=torch.float32, device='cuda')
     c = torch.zeros((M, top_k, N), dtype=torch.float32, device='cuda')
 
-    # config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
-    #                                      use_int8_w8a16=use_int8_w8a16,
-    #                                      dtype=hidden_states.dtype)
-    config_dtype = None
-
     values = torch.randn(M, E, device='cuda')
 
     softmax_vals = torch.softmax(values, dim=1)
@@ -322,8 +318,11 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
         return a, b, c, None, topk_ids
 
     return a, b, c, topk_weights, topk_ids
 
+
+# TODO assert the input shape
+
 @pytest.mark.parametrize("M, K, N, top_k, E", [
     (16, 1, 14336, 2, 4),
     (1, 128, 14336, 2, 4),
@@ -335,7 +334,7 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
     (64, 128, 64, 2, 8),
 ])
 @pytest.mark.parametrize('routed_weight', [True, False])
-def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight:bool):
+def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool):
     torch.manual_seed(20)
     a, b, c, topk_weights, topk_ids = input_helper(M, K, N, top_k, E, routed_weight=routed_weight)