Commit

pre-commit
Chi-Chu319 committed Dec 18, 2024
1 parent fe44173 commit 6f737c6
Showing 2 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion python/perf-kernels/README.md
@@ -82,4 +82,4 @@ Kernel that implements RMS Norm over a row of tensor.
Kernel that implements Layer Normalization over a row on tensor

## `fused_moe/moe-gemm.py`
-Kernel that implements moe gemm
\ No newline at end of file
+Kernel that implements moe gemm
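
For orientation, here is a rough sketch of what the MoE GEMM computes, added for this write-up rather than taken from the README or the commit: each token's activation row is multiplied by the weight matrix of every expert it was routed to, optionally scaled by the routing weight. Shapes follow the test helper in moe-gemm.py (a: (M, K), b: (E, N, K), output: (M, top_k, N)).

```python
import torch

def moe_gemm_reference(a, b, topk_weights, topk_ids):
    """Naive loop over each token's top-k experts (illustrative only)."""
    M, _ = a.shape
    E, N, K = b.shape
    top_k = topk_ids.shape[1]
    out = torch.zeros(M, top_k, N, dtype=a.dtype, device=a.device)
    for m in range(M):
        for j in range(top_k):
            e = int(topk_ids[m, j])
            out[m, j] = b[e] @ a[m]              # (N, K) @ (K,) -> (N,)
            if topk_weights is not None:
                out[m, j] *= topk_weights[m, j]  # routed-weight scaling
    return out
```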
31 changes: 15 additions & 16 deletions python/perf-kernels/fused_moe/moe-gemm.py
@@ -93,9 +93,7 @@ def moe_gemm_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk

if MUL_ROUTED_WEIGHT:
-moe_weight = tl.load(topk_weights_ptr + offs_token,
-mask=token_mask,
-other=0)
+moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]

offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
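
As a side note, the MUL_ROUTED_WEIGHT branch scales each token's accumulated tile by its routing weight. A minimal PyTorch sketch of that broadcast, with illustrative shapes only (not code from this commit):

```python
import torch

accumulator = torch.randn(4, 8)  # partial GEMM results for a tile of 4 tokens x 8 columns
moe_weight = torch.rand(4)       # routing weight gathered per token (topk_weights_ptr + offs_token)

# Same broadcast as `accumulator = accumulator * moe_weight[:, None]` in the kernel.
scaled = accumulator * moe_weight[:, None]
assert scaled.shape == (4, 8)
```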
@@ -105,8 +103,10 @@ def moe_gemm_kernel(
# tl.store(out_ptrs, accumulator, mask=c_mask)
tl.store(out_ptrs, accumulator, mask=c_mask)

-def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int, sorted_token_ids: torch.Tensor,
-expert_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None:
+
+def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int,
+sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor,
+num_tokens_post_pad: torch.Tensor) -> None:
M, top_k = topk_ids.shape

# 1) Build a list of tokens for each expert
@@ -132,7 +132,7 @@ def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int,
# Reorder all actual tokens for expert e_id
reordered_token_ids.extend(tokens_for_expert)
# reordered_expert_ids.extend([e_id]*num_tokens)
-reordered_expert_ids.extend([e_id]*n_blocks)
+reordered_expert_ids.extend([e_id] * n_blocks)

# Pad with dummy token_id = -1 (or any sentinel), if needed
if padded_size > num_tokens:
@@ -192,7 +192,7 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int,
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
# TODO do we need to predefine the vars? may be they can be defined in the function and reurned
-_moe_align_block_size(topk_ids, num_experts, top_k ,block_size, sorted_ids, expert_ids, num_tokens_post_pad)
+_moe_align_block_size(topk_ids, num_experts, top_k, block_size, sorted_ids, expert_ids, num_tokens_post_pad)

return sorted_ids, expert_ids, num_tokens_post_pad
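
To make the bookkeeping above concrete, here is an unoptimized Python sketch of the same alignment, added as an illustration rather than taken from the commit: tokens are grouped per expert, each group is padded up to a multiple of block_size with a sentinel id, and one expert id is recorded per block. The sentinel mirrors the sorted_ids.fill_(topk_ids.numel()) default used by the wrapper.

```python
import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """Illustrative, unoptimized version of the alignment described above."""
    sentinel = topk_ids.numel()          # dummy id used for padding
    flat = topk_ids.flatten()
    sorted_token_ids, expert_ids = [], []
    for e_id in range(num_experts):
        tokens_for_expert = (flat == e_id).nonzero(as_tuple=False).flatten().tolist()
        n_blocks = -(-len(tokens_for_expert) // block_size)        # ceiling division
        padding = [sentinel] * (n_blocks * block_size - len(tokens_for_expert))
        sorted_token_ids.extend(tokens_for_expert + padding)
        expert_ids.extend([e_id] * n_blocks)                       # one expert id per block of rows
    return sorted_token_ids, expert_ids, len(sorted_token_ids)

# Example: 4 tokens routed to 2 experts each, 3 experts, block_size = 4.
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 0], [2, 1]])
print(moe_align_block_size_reference(topk_ids, num_experts=3, block_size=4))
```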

@@ -279,6 +279,7 @@ def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = Fa

def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor) -> None:
+# TODO get config type once quantization is enabled
config_dtype = None
get_config_func = functools.partial(
try_get_optimal_moe_config,
@@ -295,11 +296,11 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to

EM = num_tokens_post_padded.item()
_, N, K = b.shape
-grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
+grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )

moe_gemm_kernel[grid](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1),
-c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N,
-K, MUL_ROUTED_WEIGHT=topk_weights is not None, **config)
+c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N, K,
+MUL_ROUTED_WEIGHT=topk_weights is not None, **config)
return c
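
A minimal usage sketch, assuming the definitions in this file are in scope and a CUDA device is available; the shapes and the softmax/top-k routing follow input_helper below:

```python
import torch

M, K, N, top_k, E = 64, 128, 256, 2, 8

a = torch.randn(M, K, dtype=torch.float32, device='cuda')         # token activations
b = torch.randn(E, N, K, dtype=torch.float32, device='cuda')      # one weight matrix per expert
c = torch.zeros(M, top_k, N, dtype=torch.float32, device='cuda')  # per-(token, expert) output

router_logits = torch.randn(M, E, device='cuda')
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=1), k=top_k, dim=1)

out = moe_gemm(a, b, c, topk_weights, topk_ids)  # writes into c and returns it
```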


@@ -308,11 +309,6 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
b = torch.randn((E, N, K), dtype=torch.float32, device='cuda')
c = torch.zeros((M, top_k, N), dtype=torch.float32, device='cuda')

-# config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
-# use_int8_w8a16=use_int8_w8a16,
-# dtype=hidden_states.dtype)
-config_dtype = None
-
values = torch.randn(M, E, device='cuda')

softmax_vals = torch.softmax(values, dim=1)
@@ -322,8 +318,11 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
return a, b, c, None, topk_ids

return a, b, c, topk_weights, topk_ids


# TODO assert the input shape


@pytest.mark.parametrize("M, K, N, top_k, E", [
(16, 1, 14336, 2, 4),
(1, 128, 14336, 2, 4),
@@ -335,7 +334,7 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
(64, 128, 64, 2, 8),
])
@pytest.mark.parametrize('routed_weight', [True, False])
-def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight:bool):
+def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool):
torch.manual_seed(20)
a, b, c, topk_weights, topk_ids = input_helper(M, K, N, top_k, E, routed_weight=routed_weight)

