Support ep, column major moe kernel. (#2690)
* support EP, optimize moe kernel

* support ep and col major moe kernel

* remove create_weight_ep
grimoire authored Nov 11, 2024
1 parent 47b0d1a commit 4a8d745
Showing 5 changed files with 153 additions and 83 deletions.
43 changes: 37 additions & 6 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -1,5 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
 import torch
 
 from lmdeploy.pytorch.kernels.cuda import fused_moe
@@ -10,7 +12,11 @@
 class TritonFusedMoEImpl(FusedMoEImpl):
     """triton fused moe implementation."""
 
-    def __init__(self, top_k: int, renormalize: bool = False):
+    def __init__(self,
+                 top_k: int,
+                 num_experts: int,
+                 renormalize: bool = False):
+        self.num_experts = num_experts
         self.top_k = top_k
         self.renormalize = renormalize
 
@@ -23,23 +29,48 @@ def update_weights(self, gate_up_weights: torch.Tensor,
             2).contiguous().transpose(1, 2)
         return gate_up_weights, down_weights
 
-    def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
-                topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
-                down_weights: torch.Tensor):
+    def support_ep(self):
+        """support expert parallelism."""
+        return True
+
+    def ep_expert_list(self, world_size: int, rank: int):
+        """experts list of current rank."""
+        num_experts = self.num_experts
+        expert_per_rank = (num_experts + world_size - 1) // world_size
+        first_expert = rank * expert_per_rank
+        last_expert = min(first_expert + expert_per_rank, num_experts)
+        return list(range(first_expert, last_expert))
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                topk_weights: torch.Tensor,
+                topk_ids: torch.LongTensor,
+                gate_up_weights: torch.Tensor,
+                down_weights: torch.Tensor,
+                expert_list: List[int] = None):
         """forward."""
+        expert_offset = 0
+        num_experts = None
+        if expert_list is not None and len(expert_list) != self.num_experts:
+            expert_offset = expert_list[0]
+            num_experts = self.num_experts
         return fused_moe(hidden_states,
                          gate_up_weights,
                          down_weights,
                          topk_weights=topk_weights,
                          topk_ids=topk_ids,
                          topk=self.top_k,
+                         expert_offset=expert_offset,
+                         num_experts=num_experts,
                          renormalize=self.renormalize)
 
 
 class TritonFusedMoEBuilder(FusedMoEBuilder):
     """triton fused moe builder."""
 
     @staticmethod
-    def build(top_k: int, renormalize: bool = False):
+    def build(top_k: int, num_experts: int, renormalize: bool = False):
         """build from mlp."""
-        return TritonFusedMoEImpl(top_k=top_k, renormalize=renormalize)
+        return TritonFusedMoEImpl(top_k=top_k,
+                                  num_experts=num_experts,
+                                  renormalize=renormalize)
14 changes: 10 additions & 4 deletions lmdeploy/pytorch/backends/dlinfer/moe.py
@@ -1,5 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
 import torch
 
 from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax
@@ -38,9 +40,13 @@ def __init__(self, top_k: int, renormalize: bool = False):
         self.top_k = top_k
         self.renormalize = renormalize
 
-    def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
-                topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
-                down_weights: torch.Tensor):
+    def forward(self,
+                hidden_states: torch.Tensor,
+                topk_weights: torch.Tensor,
+                topk_ids: torch.LongTensor,
+                gate_up_weights: torch.Tensor,
+                down_weights: torch.Tensor,
+                expert_list: List[int] = None):
         """forward."""
         return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights,
                          gate_up_weights, down_weights)
@@ -50,6 +56,6 @@ class DlinferFusedMoEBuilder(FusedMoEBuilder):
"""dlinfer fused moe builder."""

@staticmethod
def build(top_k: int, renormalize: bool = False):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize)
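
Note that the dlinfer backend widens its signatures purely for interface parity: expert_list and num_experts are accepted but never used, and (assuming it does not override the base class elsewhere) it inherits the default support_ep() defined in the next file, which returns False. A tiny sketch with hypothetical values:

impl = DlinferFusedMoEBuilder.build(top_k=2, num_experts=8)
assert impl.support_ep() is False  # inherited default; only the Triton impl opts in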
21 changes: 17 additions & 4 deletions lmdeploy/pytorch/backends/moe.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
+from typing import List
 
 import torch

@@ -31,10 +32,22 @@ def update_weights(self, gate_up_weights: torch.Tensor,
"""update weights."""
return gate_up_weights, down_weights

def support_ep(self):
"""support expert parallelism."""
return False

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
raise NotImplementedError('Not Implemented.')

@abstractmethod
def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor):
def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
down_weights: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
raise NotImplementedError

@@ -44,6 +57,6 @@ class FusedMoEBuilder(ABC):

     @staticmethod
     @abstractmethod
-    def build(top_k: int, renormalize: bool = False):
+    def build(top_k: int, num_experts: int, renormalize: bool = False):
         """build from mlp."""
         raise NotImplementedError
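
Taken together, support_ep, ep_expert_list and the widened forward define a small expert-parallelism protocol. A plausible consumer, as a hedged sketch (the function name, weight shapes and torch.distributed usage are assumptions, not code from this commit):

import torch
import torch.distributed as dist

def build_fused_moe(builder, top_k, num_experts, gate_up_weights,
                    down_weights):
    """builder is any FusedMoEBuilder subclass, e.g. TritonFusedMoEBuilder."""
    impl = builder.build(top_k=top_k, num_experts=num_experts)
    expert_list = None
    if dist.is_initialized() and impl.support_ep():
        # keep only this rank's contiguous expert shard; forward() later
        # recovers the shard origin from expert_list[0]
        expert_list = impl.ep_expert_list(dist.get_world_size(),
                                          dist.get_rank())
        gate_up_weights = gate_up_weights[expert_list]
        down_weights = down_weights[expert_list]
    gate_up_weights, down_weights = impl.update_weights(
        gate_up_weights, down_weights)
    return impl, gate_up_weights, down_weights, expert_list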
79 changes: 27 additions & 52 deletions lmdeploy/pytorch/kernels/cuda/fused_moe.py
@@ -5,24 +5,24 @@
 import triton.language as tl
 
 from .activation import silu_and_mul
-from .triton_utils import get_kernel_meta, wrap_jit_func
+from .triton_utils import get_kernel_meta
 
 
 def get_cuda_autotune_config():
     return [
         triton.Config(
             {
                 'BLOCK_SIZE_M': 128,
-                'BLOCK_SIZE_N': 256,
-                'BLOCK_SIZE_K': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 32,
                 'GROUP_SIZE_M': 1,
             },
-            num_stages=3,
-            num_warps=8),
+            num_stages=4,
+            num_warps=4),
         triton.Config(
             {
-                'BLOCK_SIZE_M': 128,
-                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 256,
                 'BLOCK_SIZE_K': 32,
                 'GROUP_SIZE_M': 1,
             },
@@ -43,34 +43,9 @@ def get_cuda_autotune_config():
 @triton.autotune(
     configs=get_cuda_autotune_config(),
     key=['N', 'K', 'M_NP2'],
+    warmup=10,
+    rep=25,
 )
-@wrap_jit_func(type_hint=dict(
-    A=torch.Tensor,
-    B=torch.Tensor,
-    C=torch.Tensor,
-    SortedIdx=torch.Tensor,
-    ExpStart=torch.Tensor,
-    ExpEnd=torch.Tensor,
-    Weights=torch.Tensor,
-    N=int,
-    K=int,
-    stride_am=int,
-    stride_ak=int,
-    stride_be=int,
-    stride_bn=int,
-    stride_bk=int,
-    stride_cm=int,
-    stride_cn=int,
-    BLOCK_SIZE_M=torch.int32,
-    BLOCK_SIZE_N=torch.int32,
-    BLOCK_SIZE_K=torch.int32,
-    GROUP_SIZE_M=torch.int32,
-    ENABLE_WEIGHTS=bool,
-    top_k=torch.int32,
-    expert_offset=torch.int32,
-    reindex_a=bool,
-    reindex_c=bool,
-))
 @triton.jit
 def fused_moe_kernel(
     A,
@@ -110,16 +85,23 @@ def fused_moe_kernel(
     if M <= 0:
         return
 
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + (pid % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
-
-    if pid_m * BLOCK_SIZE_M >= M:
+
+    if GROUP_SIZE_M == 1:
+        pid_m = pid % num_pid_m
+        pid_n = pid // num_pid_m
+        # pid_m = pid // num_pid_n
+        # pid_n = pid % num_pid_n
+    else:
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+    if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
         return
 
     offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -189,11 +171,11 @@ def fused_moe_kernel_launcher(
     if num_tokens is None:
         num_tokens = A.size(0)
     M_NP2 = triton.next_power_of_2(num_tokens)
-    M_NP2 = max(32, M_NP2)
+    M_NP2 = max(64, M_NP2)
     E, N, K = B.shape
 
     def _grid_fn(META):
-        grid = (triton.cdiv(num_tokens, META['BLOCK_SIZE_M']) *
+        grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
                 triton.cdiv(N, META['BLOCK_SIZE_N']), E)
         return grid
@@ -229,13 +211,6 @@ def _grid_fn(META):
     )
 
 
-@wrap_jit_func(type_hint=dict(TopkIdx=torch.Tensor,
-                              SortedIdx=torch.Tensor,
-                              ExpStart=torch.Tensor,
-                              ExpEnd=torch.Tensor,
-                              len_sorted_idx=int,
-                              num_experts=torch.int32,
-                              BLOCK=torch.int32))
 @triton.jit
 def _start_end_kernel(TopkIdx, SortedIdx, ExpStart, ExpEnd,
                       len_sorted_idx: int, num_experts: tl.constexpr,
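
The "column major" half of the commit title is the new GROUP_SIZE_M == 1 branch in fused_moe_kernel: consecutive program ids now sweep down M first (pid_m = pid % num_pid_m) instead of across N, so programs launched back to back share the same expert-weight tile of B, which presumably improves cache reuse when the per-expert token count M is small, as it typically is in MoE decoding. The commented-out lines preserve the row-major alternative that was not chosen. And because num_pid_m is now derived from the padded M_NP2 (also an autotune key), the launch grid is sized from M_NP2 as well, with surplus tiles exiting early through the pid_m * BLOCK_SIZE_M >= M guard. A plain-Python sketch of the two orderings (hypothetical 4x2 tile grid, no Triton required):

def row_major(pid, num_pid_m, num_pid_n):
    # the commented-out alternative: pid_n changes fastest
    return pid // num_pid_n, pid % num_pid_n

def column_major(pid, num_pid_m, num_pid_n):
    # new ordering: pid_m changes fastest, one B tile per column of programs
    return pid % num_pid_m, pid // num_pid_m

num_pid_m, num_pid_n = 4, 2
print([row_major(p, num_pid_m, num_pid_n) for p in range(8)])
# [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1)]
print([column_major(p, num_pid_m, num_pid_n) for p in range(8)])
# [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]

Rounding M_NP2 up to a power of two (now at least 64) keeps the set of autotune keys small, so varying token counts presumably do not retrigger tuning on every call; the new warmup=10 / rep=25 arguments shorten each tuning run itself.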
