From 4a8d745604246ae6fdd1978688c28cf7ffda2a03 Mon Sep 17 00:00:00 2001 From: q yao Date: Mon, 11 Nov 2024 21:11:30 +0800 Subject: [PATCH] Support ep, column major moe kernel. (#2690) * support EP, optimize moe kernel * support ep and col major moe kernel * remove create_weight_ep --- lmdeploy/pytorch/backends/cuda/moe.py | 43 ++++++++++-- lmdeploy/pytorch/backends/dlinfer/moe.py | 14 ++-- lmdeploy/pytorch/backends/moe.py | 21 ++++-- lmdeploy/pytorch/kernels/cuda/fused_moe.py | 79 ++++++++-------------- lmdeploy/pytorch/nn/moe.py | 79 +++++++++++++++++----- 5 files changed, 153 insertions(+), 83 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index e5ae92d8bd..eb38401211 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List + import torch from lmdeploy.pytorch.kernels.cuda import fused_moe @@ -10,7 +12,11 @@ class TritonFusedMoEImpl(FusedMoEImpl): """triton fused moe implementation.""" - def __init__(self, top_k: int, renormalize: bool = False): + def __init__(self, + top_k: int, + num_experts: int, + renormalize: bool = False): + self.num_experts = num_experts self.top_k = top_k self.renormalize = renormalize @@ -23,16 +29,39 @@ def update_weights(self, gate_up_weights: torch.Tensor, 2).contiguous().transpose(1, 2) return gate_up_weights, down_weights - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor, - down_weights: torch.Tensor): + def support_ep(self): + """support expert parallelism.""" + return True + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + expert_list: List[int] = None): """forward.""" + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts return fused_moe(hidden_states, gate_up_weights, down_weights, topk_weights=topk_weights, topk_ids=topk_ids, topk=self.top_k, + expert_offset=expert_offset, + num_experts=num_experts, renormalize=self.renormalize) @@ -40,6 +69,8 @@ class TritonFusedMoEBuilder(FusedMoEBuilder): """triton fused moe builder.""" @staticmethod - def build(top_k: int, renormalize: bool = False): + def build(top_k: int, num_experts: int, renormalize: bool = False): """build from mlp.""" - return TritonFusedMoEImpl(top_k=top_k, renormalize=renormalize) + return TritonFusedMoEImpl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/dlinfer/moe.py b/lmdeploy/pytorch/backends/dlinfer/moe.py index eb8b1e591e..90f6335ecb 100644 --- a/lmdeploy/pytorch/backends/dlinfer/moe.py +++ b/lmdeploy/pytorch/backends/dlinfer/moe.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List + import torch from lmdeploy.pytorch.kernels.dlinfer import fused_moe, moe_gating_topk_softmax @@ -38,9 +40,13 @@ def __init__(self, top_k: int, renormalize: bool = False): self.top_k = top_k self.renormalize = renormalize - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor, - down_weights: torch.Tensor): + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + expert_list: List[int] = None): """forward.""" return fused_moe(hidden_states, self.top_k, topk_ids, topk_weights, gate_up_weights, down_weights) @@ -50,6 +56,6 @@ class DlinferFusedMoEBuilder(FusedMoEBuilder): """dlinfer fused moe builder.""" @staticmethod - def build(top_k: int, renormalize: bool = False): + def build(top_k: int, num_experts: int, renormalize: bool = False): """build from mlp.""" return DlinferFusedMoEImpl(top_k=top_k, renormalize=renormalize) diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py index 4a1d5b73da..8e7977625e 100644 --- a/lmdeploy/pytorch/backends/moe.py +++ b/lmdeploy/pytorch/backends/moe.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABC, abstractmethod +from typing import List import torch @@ -31,10 +32,22 @@ def update_weights(self, gate_up_weights: torch.Tensor, """update weights.""" return gate_up_weights, down_weights + def support_ep(self): + """support expert parallelism.""" + return False + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + raise NotImplementedError('Not Implemented.') + @abstractmethod - def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.LongTensor, gate_up_weights: torch.Tensor, - down_weights: torch.Tensor): + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, + expert_list: List[int] = None): """forward.""" raise NotImplementedError @@ -44,6 +57,6 @@ class FusedMoEBuilder(ABC): @staticmethod @abstractmethod - def build(top_k: int, renormalize: bool = False): + def build(top_k: int, num_experts: int, renormalize: bool = False): """build from mlp.""" raise NotImplementedError diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py index e9ac7087cd..9f9771368e 100644 --- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py +++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py @@ -5,7 +5,7 @@ import triton.language as tl from .activation import silu_and_mul -from .triton_utils import get_kernel_meta, wrap_jit_func +from .triton_utils import get_kernel_meta def get_cuda_autotune_config(): @@ -13,16 +13,16 @@ def get_cuda_autotune_config(): triton.Config( { 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, }, - num_stages=3, - num_warps=8), + num_stages=4, + num_warps=4), triton.Config( { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, }, @@ -43,34 +43,9 @@ def get_cuda_autotune_config(): @triton.autotune( configs=get_cuda_autotune_config(), key=['N', 'K', 'M_NP2'], + warmup=10, + rep=25, ) -@wrap_jit_func(type_hint=dict( - A=torch.Tensor, - B=torch.Tensor, - C=torch.Tensor, - SortedIdx=torch.Tensor, - ExpStart=torch.Tensor, - ExpEnd=torch.Tensor, - Weights=torch.Tensor, - N=int, - K=int, - stride_am=int, - stride_ak=int, - stride_be=int, - stride_bn=int, - stride_bk=int, - stride_cm=int, - stride_cn=int, - BLOCK_SIZE_M=torch.int32, - BLOCK_SIZE_N=torch.int32, - BLOCK_SIZE_K=torch.int32, - GROUP_SIZE_M=torch.int32, - ENABLE_WEIGHTS=bool, - top_k=torch.int32, - expert_offset=torch.int32, - reindex_a=bool, - reindex_c=bool, -)) @triton.jit def fused_moe_kernel( A, @@ -110,16 +85,23 @@ def fused_moe_kernel( if M <= 0: return - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - if pid_m * BLOCK_SIZE_M >= M: + + if GROUP_SIZE_M == 1: + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + # pid_m = pid // num_pid_n + # pid_n = pid % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N: return offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) @@ -189,11 +171,11 @@ def fused_moe_kernel_launcher( if num_tokens is None: num_tokens = A.size(0) M_NP2 = triton.next_power_of_2(num_tokens) - M_NP2 = max(32, M_NP2) + M_NP2 = max(64, M_NP2) E, N, K = B.shape def _grid_fn(META): - grid = (triton.cdiv(num_tokens, META['BLOCK_SIZE_M']) * + grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), E) return grid @@ -229,13 +211,6 @@ def _grid_fn(META): ) -@wrap_jit_func(type_hint=dict(TopkIdx=torch.Tensor, - SortedIdx=torch.Tensor, - ExpStart=torch.Tensor, - ExpEnd=torch.Tensor, - len_sorted_idx=int, - num_experts=torch.int32, - BLOCK=torch.int32)) @triton.jit def _start_end_kernel(TopkIdx, SortedIdx, ExpStart, ExpEnd, len_sorted_idx: int, num_experts: tl.constexpr, diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 6467a6de08..47176335c4 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -35,32 +35,54 @@ def __init__(self, renormalize: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - all_reduce: bool = True): + all_reduce: bool = True, + enable_ep: bool = False): super().__init__() if device is None: device = torch.device('cpu') if dtype is None: dtype = torch.float16 - hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim) impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE) - self.impl = impl_builder.build(top_k, renormalize) - - gate_up_weights, down_weights = self.create_weights(hidden_dim, - ffn_dim, - num_experts, - dtype=dtype, - device=device) + self.impl = impl_builder.build(top_k, num_experts, renormalize) + + self.expert_list = None + self.expert_map = None + enable_ep = enable_ep and self.impl.support_ep() + if enable_ep: + world_size, rank = get_world_rank() + expert_list = self.impl.ep_expert_list(world_size, rank) + self.expert_list = expert_list + self.expert_map = dict( + (eid, idx) for idx, eid in enumerate(expert_list)) + num_experts = len(expert_list) + gate_up_weights, down_weights = self.create_weights(hidden_dim, + ffn_dim, + num_experts, + dtype=dtype, + device=device) + else: + hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim) + gate_up_weights, down_weights = self.create_weights(hidden_dim, + ffn_dim, + num_experts, + dtype=dtype, + device=device) gate_up_weights = torch.nn.Parameter(gate_up_weights, requires_grad=False) down_weights = torch.nn.Parameter(down_weights, requires_grad=False) - gate_up_weights.weight_loader = self.weight_loader - down_weights.weight_loader = self.weight_loader gate_up_weights._weight_type = 'gate_up_weights' down_weights._weight_type = 'down_weights' self.register_parameter('gate_up_weights', gate_up_weights) self.register_parameter('down_weights', down_weights) + if enable_ep: + gate_up_weights.weight_loader = self.weight_loader_ep + down_weights.weight_loader = self.weight_loader_ep + else: + gate_up_weights.weight_loader = self.weight_loader_tp + down_weights.weight_loader = self.weight_loader_tp + self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim self.num_experts = num_experts @@ -91,21 +113,23 @@ def create_weights(self, hidden_dim: int, ffn_dim: int, num_experts: int, def update_weights(self): """update weights.""" + gateup_loader = self.gate_up_weights.weight_loader + down_loader = self.down_weights.weight_loader gate_up_weights, down_weights = self.impl.update_weights( self.gate_up_weights, self.down_weights) gate_up_weights = torch.nn.Parameter(gate_up_weights, requires_grad=False) down_weights = torch.nn.Parameter(down_weights, requires_grad=False) - gate_up_weights.weight_loader = self.weight_loader - down_weights.weight_loader = self.weight_loader + gate_up_weights.weight_loader = gateup_loader + down_weights.weight_loader = down_loader gate_up_weights._weight_type = 'gate_up_weights' down_weights._weight_type = 'down_weights' self.register_parameter('gate_up_weights', gate_up_weights) self.register_parameter('down_weights', down_weights) - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int, - shard_id: str): + def weight_loader_tp(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): """weight loader.""" world_size, rank = get_world_rank() if shard_id == 'gate': @@ -121,10 +145,31 @@ def weight_loader(self, param: torch.nn.Parameter, raise RuntimeError(f'Unknown shard_id: {shard_id}') param_data.copy_(weight) + def weight_loader_ep(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader.""" + expert_list = self.expert_list + if expert_id not in expert_list: + return + + expert_map = self.expert_map + param_id = expert_map[expert_id] + if shard_id == 'gate': + param_data = param.data[param_id, :self.ffn_dim] + elif shard_id == 'up': + param_data = param.data[param_id, self.ffn_dim:] + elif shard_id == 'down': + param_data = param.data[param_id] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(loaded_weight) + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.LongTensor): ret = self.impl.forward(hidden_states, topk_weights, topk_ids, - self.gate_up_weights, self.down_weights) + self.gate_up_weights, self.down_weights, + self.expert_list) if self.all_reduce: dist.all_reduce(ret) return ret