From dc77aa4776e163fb9c118f9b5953439361d4e551 Mon Sep 17 00:00:00 2001 From: Daniel Grittner Date: Sun, 26 Nov 2023 23:56:50 +0100 Subject: [PATCH 1/3] [Feature] Add expert loss function --- pipegoose/nn/expert_parallel/__init__.py | 1 + .../nn/expert_parallel/expert_context.py | 23 +++++++++++++ .../nn/expert_parallel/expert_parallel.py | 4 +++ pipegoose/nn/expert_parallel/layers.py | 9 ++++-- pipegoose/nn/expert_parallel/loss.py | 20 +++++++++--- pipegoose/nn/expert_parallel/routers.py | 32 ++++++++++++++----- .../nn/expert_parallel/test_expert_context.py | 17 ++++++++++ tests/nn/expert_parallel/test_expert_loss.py | 30 +++++++++++------ .../expert_parallel/test_expert_parallel.py | 24 +++++++++++--- tests/nn/expert_parallel/test_layers.py | 13 +++++++- tests/nn/expert_parallel/test_routers.py | 21 ++++++------ 11 files changed, 155 insertions(+), 39 deletions(-) create mode 100644 pipegoose/nn/expert_parallel/expert_context.py create mode 100644 tests/nn/expert_parallel/test_expert_context.py diff --git a/pipegoose/nn/expert_parallel/__init__.py b/pipegoose/nn/expert_parallel/__init__.py index a6f4ed9..1c55c66 100644 --- a/pipegoose/nn/expert_parallel/__init__.py +++ b/pipegoose/nn/expert_parallel/__init__.py @@ -1,3 +1,4 @@ +from pipegoose.nn.expert_parallel.expert_context import ExpertContext from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel from pipegoose.nn.expert_parallel.loss import ExpertLoss from pipegoose.nn.expert_parallel.routers import Top1Router, Top2Router, SwitchNoisePolicy diff --git a/pipegoose/nn/expert_parallel/expert_context.py b/pipegoose/nn/expert_parallel/expert_context.py new file mode 100644 index 0000000..23acb7d --- /dev/null +++ b/pipegoose/nn/expert_parallel/expert_context.py @@ -0,0 +1,23 @@ +from typing import List + +from torchtyping import TensorType + + +class ExpertContext: + def __init__(self): + self.aux_loss = [] + self.z_loss = [] + + def push_aux_loss(self, aux_loss: TensorType): + self.aux_loss.append(aux_loss) + + def pop_all_aux_loss(self) -> List[TensorType]: + aux_loss, self.aux_loss = self.aux_loss, [] + return aux_loss + + def push_z_loss(self, z_loss: TensorType): + self.z_loss.append(z_loss) + + def pop_all_z_loss(self) -> List[TensorType]: + z_loss, self.z_loss = self.z_loss, [] + return z_loss diff --git a/pipegoose/nn/expert_parallel/expert_parallel.py b/pipegoose/nn/expert_parallel/expert_parallel.py index faa961c..8191033 100644 --- a/pipegoose/nn/expert_parallel/expert_parallel.py +++ b/pipegoose/nn/expert_parallel/expert_parallel.py @@ -8,6 +8,7 @@ from pipegoose.distributed.parallel_mode import ParallelMode from pipegoose.nn.expert_parallel.layers import ExpertLayer from pipegoose.nn.parallel import Parallel +from pipegoose.nn.expert_parallel.expert_context import ExpertContext class ExpertParallel(Parallel): @@ -28,6 +29,7 @@ def __init__( # noise_poligy: Union[str, Callable], enable_tensor_parallelism: bool = False, parallel_context: ParallelContext = None, + expert_context: ExpertContext = None ): tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR) assert parallel_context is not None, "parallel_context must be provided" @@ -49,6 +51,7 @@ def __init__( # self.noise_policy = noise_poligy self.enable_tensor_parallelism = enable_tensor_parallelism self.parallel_context = parallel_context + self.expert_context = expert_context @torch.no_grad() def parallelize(self) -> nn.Module: @@ -65,6 +68,7 @@ def parallelize(self) -> nn.Module: self.router, self.enable_tensor_parallelism, 
self.parallel_context, + self.expert_context ) getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer diff --git a/pipegoose/nn/expert_parallel/layers.py b/pipegoose/nn/expert_parallel/layers.py index 9a31804..e9c4970 100644 --- a/pipegoose/nn/expert_parallel/layers.py +++ b/pipegoose/nn/expert_parallel/layers.py @@ -5,6 +5,7 @@ from pipegoose.nn.expert_parallel.experts import Experts from pipegoose.nn.expert_parallel.routers import Router from pipegoose.nn.expert_parallel.utils import get_num_local_experts +from pipegoose.nn.expert_parallel.expert_context import ExpertContext class ExpertLayer(nn.Module): @@ -21,6 +22,7 @@ def __init__( router: Router, enable_tensor_parallel: bool, parallel_context: ParallelContext, + expert_context: ExpertContext ): super().__init__() self.router = router @@ -31,6 +33,7 @@ def __init__( self._experts = Experts(self.num_local_experts, expert, enable_tensor_parallel, parallel_context) self.parallel_context = parallel_context + self.expert_context = expert_context @property def experts(self) -> nn.ModuleList: @@ -39,6 +42,8 @@ def experts(self) -> nn.ModuleList: def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_model"]: # TODO: use torch.fx to extract the inputs from args, and kwargs inputs = args[0] - dispatching_order, _, _ = self.router(inputs) - outputs = self._experts(inputs, dispatching_order, *args, **kwargs) + router_output = self.router(inputs) + self.expert_context.push_aux_loss(router_output.aux_loss) + self.expert_context.push_z_loss(router_output.z_loss) + outputs = self._experts(inputs, router_output.dispatching_order, *args, **kwargs) return outputs diff --git a/pipegoose/nn/expert_parallel/loss.py b/pipegoose/nn/expert_parallel/loss.py index 93bcb85..7b3541e 100644 --- a/pipegoose/nn/expert_parallel/loss.py +++ b/pipegoose/nn/expert_parallel/loss.py @@ -1,12 +1,22 @@ from typing import Callable +from torchtyping import TensorType -import torch +from pipegoose.nn.expert_parallel.expert_context import ExpertContext class ExpertLoss: - def __init__(self, loss: Callable, aux_weight: float): - self.loss = loss + def __init__(self, loss_func: Callable, aux_weight: float, z_weight: float): + self.loss_func = loss_func self.aux_weight = aux_weight + self.z_weight = z_weight + self._expert_context = ExpertContext() - def __call__(self) -> torch.Tensor: - pass + @property + def expert_context(self) -> ExpertContext: + return self._expert_context + + def __call__(self, *args, **kwargs) -> TensorType: + loss = self.loss_func(*args, **kwargs) + loss += self.aux_weight * sum(self._expert_context.pop_all_aux_loss()) + loss += self.z_weight * sum(self._expert_context.pop_all_z_loss()) + return loss diff --git a/pipegoose/nn/expert_parallel/routers.py b/pipegoose/nn/expert_parallel/routers.py index 1997948..b7adec0 100644 --- a/pipegoose/nn/expert_parallel/routers.py +++ b/pipegoose/nn/expert_parallel/routers.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from torch import nn from torchtyping import TensorType +from dataclasses import dataclass class RouterExplorationNoisePolicy(ABC): @@ -32,6 +33,14 @@ def sample_like(self, input: TensorType) -> TensorType: return noise +@dataclass +class RouterOutput: + dispatching_order: TensorType["batch_size * seq_len", "num_experts"] + weight: TensorType["batch_size * seq_len", "num_experts"] + aux_loss: TensorType["1"] + z_loss: TensorType["1"] + + class Router(ABC, nn.Module): pass @@ -93,9 +102,7 @@ def _expert_capacity(self, total_tokens: int) -> int: def forward( self, 
inputs: TensorType["batch_size", "seq_len", "d_model"] - ) -> Tuple[ - TensorType["batch_size*seq_len", "num_experts"], TensorType["batch_size*seq_len", "num_experts"], TensorType["1"] - ]: + ) -> RouterOutput: orig_dtype = inputs.dtype total_tokens = inputs.shape[0] * inputs.shape[1] @@ -115,15 +122,19 @@ def forward( topk_expert_mask = topk_expert_mask.scatter_(1, topk_idxs, True) # calculate router loss - loss = self.aux_loss_weight * self._aux_loss(router_prob, topk_expert_mask) + self.z_loss_weight * self._z_loss( - router_logits - ) + aux_loss = self._aux_loss(router_prob, topk_expert_mask) + z_loss = self._z_loss(router_logits) if not self.expert_capacity: # we don't limit the capacity of the experts topk_weight = router_prob * topk_expert_mask topk_weight = topk_weight.to(orig_dtype) - return topk_expert_mask, topk_weight, loss + return RouterOutput( + dispatching_order=topk_expert_mask, + weight=topk_weight, + aux_loss=aux_loss, + z_loss=z_loss + ) # limit the number of tokens per expert position_in_expert = torch.cumsum(topk_expert_mask, dim=0) * topk_expert_mask @@ -137,7 +148,12 @@ def forward( topk_weight = router_prob * capacity_limited_topk_expert_mask topk_weight = topk_weight.to(orig_dtype) - return capacity_limited_topk_expert_mask, topk_weight, loss + return RouterOutput( + dispatching_order=capacity_limited_topk_expert_mask, + weight=topk_weight, + aux_loss=aux_loss, + z_loss=z_loss + ) class Top1Router(_TopKRouter): diff --git a/tests/nn/expert_parallel/test_expert_context.py b/tests/nn/expert_parallel/test_expert_context.py new file mode 100644 index 0000000..9973d89 --- /dev/null +++ b/tests/nn/expert_parallel/test_expert_context.py @@ -0,0 +1,17 @@ +from pipegoose.nn.expert_parallel import ExpertContext + + +def test_expert_context(): + expert_context = ExpertContext() + + expert_context.push_aux_loss(1.01) + expert_context.push_z_loss(2.01) + + expert_context.push_aux_loss(1.02) + expert_context.push_z_loss(2.02) + + assert expert_context.pop_all_aux_loss() == [1.01, 1.02] + assert expert_context.pop_all_aux_loss() == [] + + assert expert_context.pop_all_z_loss() == [2.01, 2.02] + assert expert_context.pop_all_z_loss() == [] diff --git a/tests/nn/expert_parallel/test_expert_loss.py b/tests/nn/expert_parallel/test_expert_loss.py index 66f25f3..bde4dab 100644 --- a/tests/nn/expert_parallel/test_expert_loss.py +++ b/tests/nn/expert_parallel/test_expert_loss.py @@ -1,24 +1,34 @@ +import torch from torch import nn +import torch.nn.functional as F from pipegoose.nn.expert_parallel import ExpertLoss def test_expert_loss(): - loss_func = nn.CrossEntropyLoss() + torch.manual_seed(42) + logits = torch.randn((10, 5)) + gt = torch.randn((10, 5)) - expert_loss = ExpertLoss(loss_func, aux_weight=0.1) + loss_func = nn.MSELoss() + + expert_loss = ExpertLoss(loss_func, aux_weight=0.1, z_weight=0.2) + expert_context = expert_loss.expert_context assert expert_loss.aux_weight == 0.1 + assert expert_loss.z_weight == 0.2 assert expert_loss.loss_func == loss_func - ExpertLoss.add_aux_loss(1.01) - ExpertLoss.add_z_loss(2.01) + expert_context.push_aux_loss(1.01) + expert_context.push_z_loss(2.01) + + expert_context.push_aux_loss(1.02) + expert_context.push_z_loss(2.02) - assert expert_loss.get_aux_loss() == [1.01] - assert expert_loss.get_z_loss() == [2.01] + expected_loss = F.mse_loss(logits, gt) + 0.1 * (1.01 + 1.02) + 0.2 * (2.01 + 2.02) + loss = expert_loss(logits, gt) - ExpertLoss.add_aux_loss(1.02) - ExpertLoss.add_z_loss(2.02) + assert torch.allclose(loss, expected_loss) - assert 
expert_loss.get_aux_loss() == [1.01, 1.02] - assert expert_loss.get_z_loss() == [2.01, 2.02] + assert expert_context.aux_loss == [] + assert expert_context.z_loss == [] diff --git a/tests/nn/expert_parallel/test_expert_parallel.py b/tests/nn/expert_parallel/test_expert_parallel.py index 5d5bcd4..795f079 100644 --- a/tests/nn/expert_parallel/test_expert_parallel.py +++ b/tests/nn/expert_parallel/test_expert_parallel.py @@ -4,6 +4,7 @@ import numpy as np import pytest import torch +import torch.nn as nn from torch.optim import Adam from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM @@ -11,6 +12,8 @@ from pipegoose.nn.expert_parallel.layers import ExpertLayer from pipegoose.nn.expert_parallel.utils import get_num_local_experts from pipegoose.testing.utils import init_parallel_context, spawn +from pipegoose.nn.expert_parallel.routers import RouterOutput +from pipegoose.nn.expert_parallel.loss import ExpertLoss MODEL_NAME = "bigscience/bloom-560m" @@ -21,7 +24,12 @@ def __init__(self, num_experts): def __call__(self, inputs): n_tokens = inputs.shape[0] * inputs.shape[1] - return torch.randint(0, self.num_experts, (n_tokens,)), None, None + return RouterOutput( + torch.randint(0, self.num_experts, (n_tokens,)), + None, + torch.tensor(0.0), + torch.tensor(0.0), + ) @pytest.fixture @@ -77,6 +85,8 @@ def log_routed_expert(module, grad_input, grad_output, key): key = (layer_idx, expert_idx) expert.register_backward_hook(partial(log_routed_expert, key=key)) + loss_func = ExpertLoss(nn.CrossEntropyLoss(), aux_weight=0.1, z_weight=0.1) + parallel_context = init_parallel_context( rank, world_size, @@ -91,6 +101,7 @@ def log_routed_expert(module, grad_input, grad_output, key): mapping=mapping, router=router, parallel_context=parallel_context, + expert_context=loss_func.expert_context ).parallelize() optim = Adam(model.parameters(), lr=1e-3) @@ -117,14 +128,19 @@ def log_routed_expert(module, grad_input, grad_output, key): # NOTE: we haven't go through any weight update yet # so the logits should be the same - outputs = model(**kwargs["input"], labels=kwargs["labels"]) + outputs = model(**kwargs["input"]) + + # compute the loss + logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1]) + labels = kwargs["labels"][..., 1:].view(-1).to(logits.device) + loss = loss_func(logits, labels) # assert torch.allclose(outputs.logits, REF_LOGITS) assert outputs.logits.shape == REF_LOGITS.shape - assert torch.allclose(outputs.loss, REF_LOSS) + assert torch.allclose(loss, REF_LOSS) optim.zero_grad() - outputs.loss.backward() + loss.backward() optim.step() # NOTE: After the backward pass, check if the gradients flowing to the routed experts diff --git a/tests/nn/expert_parallel/test_layers.py b/tests/nn/expert_parallel/test_layers.py index bb64a44..eee97a3 100644 --- a/tests/nn/expert_parallel/test_layers.py +++ b/tests/nn/expert_parallel/test_layers.py @@ -6,6 +6,8 @@ from pipegoose.distributed.parallel_mode import ParallelMode from pipegoose.nn.expert_parallel.layers import ExpertLayer from pipegoose.testing.utils import count_model_parameters, init_parallel_context, spawn +from pipegoose.nn.expert_parallel.expert_context import ExpertContext +from pipegoose.nn.expert_parallel.routers import RouterOutput class DummyRouter: @@ -14,7 +16,12 @@ def __init__(self, num_experts): def __call__(self, inputs): n_tokens = inputs.shape[0] * inputs.shape[1] - return torch.randint(0, self.num_experts, (n_tokens,)), None, None + return RouterOutput( + torch.randint(0, self.num_experts, (n_tokens,)), 
+ None, + None, + None + ) def run_expert_layer( @@ -28,6 +35,7 @@ def run_expert_layer( num_experts, expert, router, + expert_context, enable_tensor_parallel, ): parallel_context = init_parallel_context( @@ -48,6 +56,7 @@ def run_expert_layer( router, enable_tensor_parallel, parallel_context, + expert_context ) local_param_count = count_model_parameters(expert_layer) @@ -78,6 +87,7 @@ def test_expert_layer(tensor_parallel_size, num_experts, enable_tensor_parallel) nn.Linear(HIDDEN_SIZE * 4, HIDDEN_SIZE), ) router = DummyRouter(num_experts) + expert_context = ExpertContext() spawn( run_expert_layer, @@ -89,5 +99,6 @@ def test_expert_layer(tensor_parallel_size, num_experts, enable_tensor_parallel) num_experts=num_experts, expert=expert, router=router, + expert_context=expert_context, enable_tensor_parallel=enable_tensor_parallel, ) diff --git a/tests/nn/expert_parallel/test_routers.py b/tests/nn/expert_parallel/test_routers.py index b715686..7aa9f0f 100644 --- a/tests/nn/expert_parallel/test_routers.py +++ b/tests/nn/expert_parallel/test_routers.py @@ -16,11 +16,12 @@ def run_topk_router( input = torch.randn(batch_size, seq_len, d_model, requires_grad=True) - dispatch_order, gate_values, loss = router(input) + router_output = router(input) - assert dispatch_order.shape == (batch_size*seq_len, num_experts) - assert gate_values.shape == (batch_size*seq_len, num_experts) - assert loss.shape == () + assert router_output.dispatching_order.shape == (batch_size*seq_len, num_experts) + assert router_output.weight.shape == (batch_size*seq_len, num_experts) + assert router_output.aux_loss.shape == () + assert router_output.z_loss.shape == () total_tokens = batch_size * seq_len @@ -28,19 +29,21 @@ def run_topk_router( expert_capacity = router._expert_capacity(total_tokens) for expert_id in range(num_experts): - assert dispatch_order[..., expert_id].sum().item() < expert_capacity + assert router_output.dispatching_order[..., expert_id].sum().item() < expert_capacity for token_id in range(total_tokens): - assert dispatch_order[token_id, ...].sum().item() <= top_k + assert router_output.dispatching_order[token_id, ...].sum().item() <= top_k else: for token_id in range(total_tokens): - assert dispatch_order[token_id, ...].sum().item() == top_k + assert router_output.dispatching_order[token_id, ...].sum().item() == top_k # test backwardpass - target_gate_values = torch.randn_like(gate_values) # Random target for testing - loss += F.mse_loss(gate_values, target_gate_values) + target_weight = torch.randn_like(router_output.weight) # Random target for testing + + loss = router_output.aux_loss + router_output.z_loss + loss += F.mse_loss(router_output.weight, target_weight) loss.backward() From 23716881f07f8ddc5cc1636da5706eb6dc8d3591 Mon Sep 17 00:00:00 2001 From: Daniel Grittner Date: Wed, 29 Nov 2023 00:04:49 +0100 Subject: [PATCH 2/3] [Refactor] Make expert context a singleton --- pipegoose/nn/expert_parallel/__init__.py | 1 - pipegoose/nn/expert_parallel/expert_context.py | 9 +++++++++ pipegoose/nn/expert_parallel/expert_parallel.py | 8 ++------ pipegoose/nn/expert_parallel/layers.py | 9 ++++----- pipegoose/nn/expert_parallel/loss.py | 6 +++--- tests/nn/expert_parallel/test_expert_context.py | 7 +++++-- tests/nn/expert_parallel/test_expert_loss.py | 3 ++- tests/nn/expert_parallel/test_expert_parallel.py | 3 +-- tests/nn/expert_parallel/test_layers.py | 7 +------ 9 files changed, 27 insertions(+), 26 deletions(-) diff --git a/pipegoose/nn/expert_parallel/__init__.py 
b/pipegoose/nn/expert_parallel/__init__.py index 1c55c66..a6f4ed9 100644 --- a/pipegoose/nn/expert_parallel/__init__.py +++ b/pipegoose/nn/expert_parallel/__init__.py @@ -1,4 +1,3 @@ -from pipegoose.nn.expert_parallel.expert_context import ExpertContext from pipegoose.nn.expert_parallel.expert_parallel import ExpertParallel from pipegoose.nn.expert_parallel.loss import ExpertLoss from pipegoose.nn.expert_parallel.routers import Top1Router, Top2Router, SwitchNoisePolicy diff --git a/pipegoose/nn/expert_parallel/expert_context.py b/pipegoose/nn/expert_parallel/expert_context.py index 23acb7d..ad760fb 100644 --- a/pipegoose/nn/expert_parallel/expert_context.py +++ b/pipegoose/nn/expert_parallel/expert_context.py @@ -1,9 +1,12 @@ +from __future__ import annotations from typing import List from torchtyping import TensorType class ExpertContext: + _instance = None + def __init__(self): self.aux_loss = [] self.z_loss = [] @@ -21,3 +24,9 @@ def push_z_loss(self, z_loss: TensorType): def pop_all_z_loss(self) -> List[TensorType]: z_loss, self.z_loss = self.z_loss, [] return z_loss + + @classmethod + def get_instance(cls) -> ExpertContext: + if not cls._instance: + cls._instance = ExpertContext() + return cls._instance diff --git a/pipegoose/nn/expert_parallel/expert_parallel.py b/pipegoose/nn/expert_parallel/expert_parallel.py index 8191033..263ceaa 100644 --- a/pipegoose/nn/expert_parallel/expert_parallel.py +++ b/pipegoose/nn/expert_parallel/expert_parallel.py @@ -8,7 +8,6 @@ from pipegoose.distributed.parallel_mode import ParallelMode from pipegoose.nn.expert_parallel.layers import ExpertLayer from pipegoose.nn.parallel import Parallel -from pipegoose.nn.expert_parallel.expert_context import ExpertContext class ExpertParallel(Parallel): @@ -28,8 +27,7 @@ def __init__( router: Union[int, Callable] = 1, # noise_poligy: Union[str, Callable], enable_tensor_parallelism: bool = False, - parallel_context: ParallelContext = None, - expert_context: ExpertContext = None + parallel_context: ParallelContext = None ): tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR) assert parallel_context is not None, "parallel_context must be provided" @@ -51,7 +49,6 @@ def __init__( # self.noise_policy = noise_poligy self.enable_tensor_parallelism = enable_tensor_parallelism self.parallel_context = parallel_context - self.expert_context = expert_context @torch.no_grad() def parallelize(self) -> nn.Module: @@ -67,8 +64,7 @@ def parallelize(self) -> nn.Module: module if self.expert is None else self.expert, self.router, self.enable_tensor_parallelism, - self.parallel_context, - self.expert_context + self.parallel_context ) getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer diff --git a/pipegoose/nn/expert_parallel/layers.py b/pipegoose/nn/expert_parallel/layers.py index e9c4970..8bfb98c 100644 --- a/pipegoose/nn/expert_parallel/layers.py +++ b/pipegoose/nn/expert_parallel/layers.py @@ -21,8 +21,7 @@ def __init__( expert: nn.Module, router: Router, enable_tensor_parallel: bool, - parallel_context: ParallelContext, - expert_context: ExpertContext + parallel_context: ParallelContext ): super().__init__() self.router = router @@ -33,7 +32,6 @@ def __init__( self._experts = Experts(self.num_local_experts, expert, enable_tensor_parallel, parallel_context) self.parallel_context = parallel_context - self.expert_context = expert_context @property def experts(self) -> nn.ModuleList: @@ -43,7 +41,8 @@ def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_mod # TODO: use 
torch.fx to extract the inputs from args, and kwargs inputs = args[0] router_output = self.router(inputs) - self.expert_context.push_aux_loss(router_output.aux_loss) - self.expert_context.push_z_loss(router_output.z_loss) + expert_context = ExpertContext.get_instance() + expert_context.push_aux_loss(router_output.aux_loss) + expert_context.push_z_loss(router_output.z_loss) outputs = self._experts(inputs, router_output.dispatching_order, *args, **kwargs) return outputs diff --git a/pipegoose/nn/expert_parallel/loss.py b/pipegoose/nn/expert_parallel/loss.py index 7b3541e..ad99d64 100644 --- a/pipegoose/nn/expert_parallel/loss.py +++ b/pipegoose/nn/expert_parallel/loss.py @@ -9,7 +9,6 @@ def __init__(self, loss_func: Callable, aux_weight: float, z_weight: float): self.loss_func = loss_func self.aux_weight = aux_weight self.z_weight = z_weight - self._expert_context = ExpertContext() @property def expert_context(self) -> ExpertContext: @@ -17,6 +16,7 @@ def expert_context(self) -> ExpertContext: def __call__(self, *args, **kwargs) -> TensorType: loss = self.loss_func(*args, **kwargs) - loss += self.aux_weight * sum(self._expert_context.pop_all_aux_loss()) - loss += self.z_weight * sum(self._expert_context.pop_all_z_loss()) + expert_context = ExpertContext.get_instance() + loss += self.aux_weight * sum(expert_context.pop_all_aux_loss()) + loss += self.z_weight * sum(expert_context.pop_all_z_loss()) return loss diff --git a/tests/nn/expert_parallel/test_expert_context.py b/tests/nn/expert_parallel/test_expert_context.py index 9973d89..843baab 100644 --- a/tests/nn/expert_parallel/test_expert_context.py +++ b/tests/nn/expert_parallel/test_expert_context.py @@ -1,8 +1,8 @@ -from pipegoose.nn.expert_parallel import ExpertContext +from pipegoose.nn.expert_parallel.expert_context import ExpertContext def test_expert_context(): - expert_context = ExpertContext() + expert_context = ExpertContext.get_instance() expert_context.push_aux_loss(1.01) expert_context.push_z_loss(2.01) @@ -10,6 +10,9 @@ def test_expert_context(): expert_context.push_aux_loss(1.02) expert_context.push_z_loss(2.02) + # make sure that we have a singleton! 
+ expert_context = ExpertContext.get_instance() + assert expert_context.pop_all_aux_loss() == [1.01, 1.02] assert expert_context.pop_all_aux_loss() == [] diff --git a/tests/nn/expert_parallel/test_expert_loss.py b/tests/nn/expert_parallel/test_expert_loss.py index bde4dab..6f4128f 100644 --- a/tests/nn/expert_parallel/test_expert_loss.py +++ b/tests/nn/expert_parallel/test_expert_loss.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from pipegoose.nn.expert_parallel import ExpertLoss +from pipegoose.nn.expert_parallel.expert_context import ExpertContext def test_expert_loss(): @@ -13,7 +14,7 @@ def test_expert_loss(): loss_func = nn.MSELoss() expert_loss = ExpertLoss(loss_func, aux_weight=0.1, z_weight=0.2) - expert_context = expert_loss.expert_context + expert_context = ExpertContext.get_instance() assert expert_loss.aux_weight == 0.1 assert expert_loss.z_weight == 0.2 diff --git a/tests/nn/expert_parallel/test_expert_parallel.py b/tests/nn/expert_parallel/test_expert_parallel.py index 795f079..0d5b3ef 100644 --- a/tests/nn/expert_parallel/test_expert_parallel.py +++ b/tests/nn/expert_parallel/test_expert_parallel.py @@ -100,8 +100,7 @@ def log_routed_expert(module, grad_input, grad_output, key): NUM_EXPERTS, mapping=mapping, router=router, - parallel_context=parallel_context, - expert_context=loss_func.expert_context + parallel_context=parallel_context ).parallelize() optim = Adam(model.parameters(), lr=1e-3) diff --git a/tests/nn/expert_parallel/test_layers.py b/tests/nn/expert_parallel/test_layers.py index eee97a3..8329a43 100644 --- a/tests/nn/expert_parallel/test_layers.py +++ b/tests/nn/expert_parallel/test_layers.py @@ -6,7 +6,6 @@ from pipegoose.distributed.parallel_mode import ParallelMode from pipegoose.nn.expert_parallel.layers import ExpertLayer from pipegoose.testing.utils import count_model_parameters, init_parallel_context, spawn -from pipegoose.nn.expert_parallel.expert_context import ExpertContext from pipegoose.nn.expert_parallel.routers import RouterOutput @@ -35,7 +34,6 @@ def run_expert_layer( num_experts, expert, router, - expert_context, enable_tensor_parallel, ): parallel_context = init_parallel_context( @@ -55,8 +53,7 @@ def run_expert_layer( expert, router, enable_tensor_parallel, - parallel_context, - expert_context + parallel_context ) local_param_count = count_model_parameters(expert_layer) @@ -87,7 +84,6 @@ def test_expert_layer(tensor_parallel_size, num_experts, enable_tensor_parallel) nn.Linear(HIDDEN_SIZE * 4, HIDDEN_SIZE), ) router = DummyRouter(num_experts) - expert_context = ExpertContext() spawn( run_expert_layer, @@ -99,6 +95,5 @@ def test_expert_layer(tensor_parallel_size, num_experts, enable_tensor_parallel) num_experts=num_experts, expert=expert, router=router, - expert_context=expert_context, enable_tensor_parallel=enable_tensor_parallel, ) From c28fc89865054c45dcd61fbeaccb98f8b64b155f Mon Sep 17 00:00:00 2001 From: xrsrke Date: Wed, 29 Nov 2023 12:00:04 +0700 Subject: [PATCH 3/3] [Refactor] Add testing for ExpertParallel with top1 routing --- pipegoose/nn/expert_parallel/loss.py | 5 +- .../expert_parallel/test_expert_parallel.py | 105 ++++++++++++++++-- tests/nn/expert_parallel/test_routers.py | 55 ++------- 3 files changed, 105 insertions(+), 60 deletions(-) diff --git a/pipegoose/nn/expert_parallel/loss.py b/pipegoose/nn/expert_parallel/loss.py index ad99d64..27ecc17 100644 --- a/pipegoose/nn/expert_parallel/loss.py +++ b/pipegoose/nn/expert_parallel/loss.py @@ -1,4 +1,5 @@ from typing import Callable + from torchtyping import TensorType from 
pipegoose.nn.expert_parallel.expert_context import ExpertContext @@ -10,10 +11,6 @@ def __init__(self, loss_func: Callable, aux_weight: float, z_weight: float): self.aux_weight = aux_weight self.z_weight = z_weight - @property - def expert_context(self) -> ExpertContext: - return self._expert_context - def __call__(self, *args, **kwargs) -> TensorType: loss = self.loss_func(*args, **kwargs) expert_context = ExpertContext.get_instance() diff --git a/tests/nn/expert_parallel/test_expert_parallel.py b/tests/nn/expert_parallel/test_expert_parallel.py index 0d5b3ef..c4f37e9 100644 --- a/tests/nn/expert_parallel/test_expert_parallel.py +++ b/tests/nn/expert_parallel/test_expert_parallel.py @@ -10,10 +10,14 @@ from pipegoose.nn import ExpertParallel from pipegoose.nn.expert_parallel.layers import ExpertLayer +from pipegoose.nn.expert_parallel.loss import ExpertLoss +from pipegoose.nn.expert_parallel.routers import ( + RouterOutput, + SwitchNoisePolicy, + Top1Router, +) from pipegoose.nn.expert_parallel.utils import get_num_local_experts from pipegoose.testing.utils import init_parallel_context, spawn -from pipegoose.nn.expert_parallel.routers import RouterOutput -from pipegoose.nn.expert_parallel.loss import ExpertLoss MODEL_NAME = "bigscience/bloom-560m" @@ -95,13 +99,7 @@ def log_routed_expert(module, grad_input, grad_output, key): pipeline_parallel_size, data_parallel_size, ) - model = ExpertParallel( - model, - NUM_EXPERTS, - mapping=mapping, - router=router, - parallel_context=parallel_context - ).parallelize() + model = ExpertParallel(model, NUM_EXPERTS, mapping=mapping, router=router, parallel_context=parallel_context).parallelize() optim = Adam(model.parameters(), lr=1e-3) # NOTE: check the specified layers are replaced with expert layers @@ -129,13 +127,15 @@ def log_routed_expert(module, grad_input, grad_output, key): # so the logits should be the same outputs = model(**kwargs["input"]) + assert all(key in outputs for key in ["logits", "past_key_values"]) + # NOTE: why so high tolerance? 
+ assert torch.allclose(outputs.logits, REF_LOGITS, rtol=1e-1) + # compute the loss logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1]) labels = kwargs["labels"][..., 1:].view(-1).to(logits.device) loss = loss_func(logits, labels) - # assert torch.allclose(outputs.logits, REF_LOGITS) - assert outputs.logits.shape == REF_LOGITS.shape assert torch.allclose(loss, REF_LOSS) optim.zero_grad() @@ -189,3 +189,86 @@ def test_expert_parallel(model, tokenizer, tensor_parallel_size, num_experts): data_parallel_size=DATA_PARALLEL_SIZE, kwargs=kwargs, ) + + +def run_expert_parallel_with_top1_router( + rank, + world_size, + port, + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + kwargs, +): + model = kwargs["model"] + mapping = kwargs["mapping"] + router = kwargs["router"] + NUM_EXPERTS = kwargs["num_experts"] + + # TODO: remove after adding seed to parallel_context + random.seed(42) + np.random.seed(42) + torch.manual_seed(42) + + parallel_context = init_parallel_context( + rank, + world_size, + port, + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + ) + model = ExpertParallel(model, NUM_EXPERTS, mapping=mapping, router=router, parallel_context=parallel_context).parallelize() + loss_func = ExpertLoss(nn.CrossEntropyLoss(), aux_weight=0.1, z_weight=0.1) + optim = Adam(model.parameters(), lr=1e-3) + + outputs = model(**kwargs["input"]) + + assert all(key in outputs for key in ["logits", "past_key_values"]) + + logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1]) + labels = kwargs["labels"][..., 1:].view(-1).to(logits.device) + loss = loss_func(logits, labels) + + assert isinstance(loss, torch.Tensor) + + optim.zero_grad() + loss.backward() + optim.step() + + +def test_expert_parallel_with_top1_router(model, tokenizer): + TENSOR_PARALLEL_SIZE = 2 + PIPELINE_PARALLEL_SIZE = 1 + DATA_PARALLEL_SIZE = 1 + WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE + + NUM_EXPERTS = 2 + NUM_EXPERT_LAYERS = 2 + NUM_LAYERS = model.config.num_hidden_layers + D_MODEL = model.config.hidden_size + + mapping = [layer_idx for layer_idx in random.sample(range(NUM_LAYERS - 1), NUM_EXPERT_LAYERS)] + noise_policy = SwitchNoisePolicy() + router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL) + + text = "Persistence is all you need." 
+ input = tokenizer(text, return_tensors="pt") + + kwargs = { + "input": input, + "labels": input["input_ids"], + "model": model, + "mapping": mapping, + "num_experts": NUM_EXPERTS, + "router": router, + } + + spawn( + run_expert_parallel_with_top1_router, + world_size=WORLD_SIZE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, + data_parallel_size=DATA_PARALLEL_SIZE, + kwargs=kwargs, + ) diff --git a/tests/nn/expert_parallel/test_routers.py b/tests/nn/expert_parallel/test_routers.py index 7aa9f0f..9f3dbd7 100644 --- a/tests/nn/expert_parallel/test_routers.py +++ b/tests/nn/expert_parallel/test_routers.py @@ -1,25 +1,18 @@ import torch import torch.nn.functional as F -from pipegoose.nn.expert_parallel import Top1Router, Top2Router, SwitchNoisePolicy +from pipegoose.nn.expert_parallel import SwitchNoisePolicy, Top1Router, Top2Router -def run_topk_router( - router, - batch_size, - seq_len, - d_model, - num_experts, - top_k -): +def run_topk_router(router, batch_size, seq_len, d_model, num_experts, top_k): router.train() input = torch.randn(batch_size, seq_len, d_model, requires_grad=True) router_output = router(input) - assert router_output.dispatching_order.shape == (batch_size*seq_len, num_experts) - assert router_output.weight.shape == (batch_size*seq_len, num_experts) + assert router_output.dispatching_order.shape == (batch_size * seq_len, num_experts) + assert router_output.weight.shape == (batch_size * seq_len, num_experts) assert router_output.aux_loss.shape == () assert router_output.z_loss.shape == () @@ -38,9 +31,9 @@ def run_topk_router( for token_id in range(total_tokens): assert router_output.dispatching_order[token_id, ...].sum().item() == top_k - # test backwardpass + # test backward pass - target_weight = torch.randn_like(router_output.weight) # Random target for testing + target_weight = torch.randn_like(router_output.weight) # Random target for testing loss = router_output.aux_loss + router_output.z_loss loss += F.mse_loss(router_output.weight, target_weight) @@ -62,14 +55,7 @@ def test_top1_router(): noise_policy = SwitchNoisePolicy() top1_router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL) - run_topk_router( - top1_router, - BATCH_SIZE, - SEQ_LEN, - D_MODEL, - NUM_EXPERTS, - top_k=1 - ) + run_topk_router(top1_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=1) def test_top1_router_with_expert_capacity(): @@ -79,14 +65,7 @@ def test_top1_router_with_expert_capacity(): noise_policy = SwitchNoisePolicy() top1_router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL, expert_capacity=(1.0, 2.0)) - run_topk_router( - top1_router, - BATCH_SIZE, - SEQ_LEN, - D_MODEL, - NUM_EXPERTS, - top_k=1 - ) + run_topk_router(top1_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=1) def test_top2_router(): @@ -96,14 +75,7 @@ def test_top2_router(): noise_policy = SwitchNoisePolicy() top2_router = Top2Router(noise_policy, NUM_EXPERTS, D_MODEL) - run_topk_router( - top2_router, - BATCH_SIZE, - SEQ_LEN, - D_MODEL, - NUM_EXPERTS, - top_k=2 - ) + run_topk_router(top2_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=2) def test_top2_router_with_expert_capacity(): @@ -113,11 +85,4 @@ def test_top2_router_with_expert_capacity(): noise_policy = SwitchNoisePolicy() top2_router = Top2Router(noise_policy, NUM_EXPERTS, D_MODEL, expert_capacity=(1.0, 2.0)) - run_topk_router( - top2_router, - BATCH_SIZE, - SEQ_LEN, - D_MODEL, - NUM_EXPERTS, - top_k=2 - ) + run_topk_router(top2_router, BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS, top_k=2)
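
Usage sketches, condensed from the tests in this series; all sizes, seeds, and weights below are illustrative, not prescribed by the patches. First, the new router interface: routers now return a RouterOutput dataclass (dispatching_order, weight, aux_loss, z_loss) instead of a 3-tuple carrying a single pre-weighted loss. A minimal sketch mirroring test_routers.py:

    import torch
    from pipegoose.nn.expert_parallel import SwitchNoisePolicy, Top1Router

    BATCH_SIZE, SEQ_LEN, D_MODEL, NUM_EXPERTS = 5, 10, 64, 8

    router = Top1Router(SwitchNoisePolicy(), NUM_EXPERTS, D_MODEL)
    router.train()  # the tests route in training mode

    router_output = router(torch.randn(BATCH_SIZE, SEQ_LEN, D_MODEL))

    # one dispatch decision and one gate weight per (token, expert) pair,
    # flattened over batch and sequence
    assert router_output.dispatching_order.shape == (BATCH_SIZE * SEQ_LEN, NUM_EXPERTS)
    assert router_output.weight.shape == (BATCH_SIZE * SEQ_LEN, NUM_EXPERTS)
    # scalar regularizers, now returned unweighted
    assert router_output.aux_loss.shape == ()
    assert router_output.z_loss.shape == ()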
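
When an expert capacity is configured (e.g. Top1Router(..., expert_capacity=(1.0, 2.0)) as in the tests), the router drops tokens that overflow an expert. The hunk in routers.py computes each token's queue position with a cumulative sum over the dispatch mask; a toy illustration of that one step (the final capacity mask is applied by code outside the hunk shown above):

    import torch

    # 4 tokens, 2 experts; tokens 0-2 pick expert 0, token 3 picks expert 1
    topk_expert_mask = torch.tensor([[1, 0], [1, 0], [1, 0], [0, 1]])

    # 1-based position of each token in its expert's queue, 0 elsewhere
    position_in_expert = torch.cumsum(topk_expert_mask, dim=0) * topk_expert_mask
    # tensor([[1, 0],
    #         [2, 0],
    #         [3, 0],
    #         [0, 1]])
    # positions beyond the expert's capacity are then masked out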
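
On the loss side, the weighting that used to happen inside the router now lives in ExpertLoss. With the ExpertContext singleton from PATCH 2/3, the effective objective is

    L_total = loss_func(*args, **kwargs)
              + aux_weight * sum(pop_all_aux_loss())
              + z_weight * sum(pop_all_z_loss())

where both pop_all_* calls drain the lists that every ExpertLayer filled during the forward pass, so each ExpertLoss call only accounts for router losses accumulated since the previous one.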
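
A condensed version of test_expert_loss.py showing that flow end to end (MSE and the pushed values are illustrative stand-ins for real router losses):

    import torch
    import torch.nn.functional as F
    from torch import nn

    from pipegoose.nn.expert_parallel import ExpertLoss
    from pipegoose.nn.expert_parallel.expert_context import ExpertContext

    expert_loss = ExpertLoss(nn.MSELoss(), aux_weight=0.1, z_weight=0.2)

    # normally done by the ExpertLayers during the forward pass
    expert_context = ExpertContext.get_instance()
    expert_context.push_aux_loss(1.01)
    expert_context.push_z_loss(2.01)

    logits, targets = torch.randn(10, 5), torch.randn(10, 5)
    loss = expert_loss(logits, targets)

    assert torch.allclose(loss, F.mse_loss(logits, targets) + 0.1 * 1.01 + 0.2 * 2.01)
    assert expert_context.aux_loss == [] and expert_context.z_loss == []  # drained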
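
Finally, a sketch of the full wiring from run_expert_parallel_with_top1_router; the distributed setup is elided, so model, mapping, parallel_context, inputs, input_ids, NUM_EXPERTS, and D_MODEL are assumed to be prepared as in the test:

    import torch.nn as nn
    from torch.optim import Adam

    from pipegoose.nn import ExpertParallel
    from pipegoose.nn.expert_parallel.loss import ExpertLoss
    from pipegoose.nn.expert_parallel.routers import SwitchNoisePolicy, Top1Router

    noise_policy = SwitchNoisePolicy()
    router = Top1Router(noise_policy, NUM_EXPERTS, D_MODEL)

    model = ExpertParallel(
        model, NUM_EXPERTS, mapping=mapping, router=router, parallel_context=parallel_context
    ).parallelize()
    loss_func = ExpertLoss(nn.CrossEntropyLoss(), aux_weight=0.1, z_weight=0.1)
    optim = Adam(model.parameters(), lr=1e-3)

    outputs = model(**inputs)  # each ExpertLayer pushes its aux/z losses

    # next-token shift, then the combined task + router loss
    logits = outputs.logits[..., :-1, :].view(-1, outputs.logits.shape[-1])
    labels = input_ids[..., 1:].view(-1).to(logits.device)
    loss = loss_func(logits, labels)

    optim.zero_grad()
    loss.backward()
    optim.step()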