
fix TorchEngine stuck when benchmarking with tp>1 (#942)
* fix benchmark tp

* silent warning

* fix profile batch

* support torch2.0

* fix dtensor

* fix get type dtype
grimoire authored Jan 22, 2024
1 parent e96e2b4 commit 9ff13ba
Showing 4 changed files with 60 additions and 12 deletions.
34 changes: 29 additions & 5 deletions benchmark/profile_torch_generation.py
@@ -123,8 +123,11 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
     from lmdeploy.messages import PytorchEngineConfig
     from lmdeploy.pytorch.engine import Engine
 
-    tm_model = Engine(model_path, PytorchEngineConfig(model_name='llama',
-                                                      tp=tp))
+    tm_model = Engine(
+        model_path,
+        PytorchEngineConfig(model_name='llama',
+                            tp=tp,
+                            max_batch_size=concurrency))
 
     # make up a dummy `input_ids` with the length of `input_seqlen` exactly
     assert input_seqlen > 0, 'input_seqlen should > 0'
@@ -342,6 +345,28 @@ def parse_args():
     return args
 
 
+def _process_map(target, iterable):
+    from multiprocessing import Pipe, Process
+
+    def __proc_cb(*args, ret_pipe: Pipe):
+        try:
+            ret = target(*args)
+            ret_pipe[1].send(ret)
+        except Exception as e:
+            ret_pipe[1].send(e)
+
+    pipe = Pipe(False)
+    proc = Process(target=__proc_cb, args=iterable, kwargs=dict(ret_pipe=pipe))
+    proc.start()
+    proc.join()
+
+    ret = pipe[0].recv()
+    if isinstance(ret, Exception):
+        raise ret
+
+    return ret
+
+
 def main():
     args = parse_args()
     assert len(args.prompt_tokens) == len(args.completion_tokens), \
@@ -355,7 +380,6 @@ def main():
                                                     args.completion_tokens):
             MemoryMonitor.start()
             from functools import partial
-            from multiprocessing import Pool
             profile_target = partial(profile_throughput,
                                      concurrency=batch,
                                      input_seqlen=prompt_tokens,
@@ -366,9 +390,9 @@ def main():
                                      temperature=args.temperature,
                                      test_round=args.test_round,
                                      warmup_round=args.warmup_round)
-            output = Pool(1).map(profile_target, (args.model_path, ))
+            output = _process_map(profile_target, (args.model_path, ))
             model_name, first_token_latency, percentiles, \
-                throughput_per_proc, tp = output[0]
+                throughput_per_proc, tp = output
             time.sleep(5)  # wait a while for releasing GPU mem
             memory = MemoryMonitor.terminate()
             device_count = MemoryMonitor.device_count.value
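The `_process_map` helper added above replaces `multiprocessing.Pool(1).map` for running each profiling configuration. A plausible reason for the switch: `Pool` workers are daemonic, and daemonic processes are not allowed to spawn children of their own, which a PyTorch engine with tp>1 needs for its per-rank worker processes. Below is a minimal, standalone sketch of the same run-in-a-child-process-and-collect-via-Pipe pattern; `run_in_child` and `slow_square` are made-up names for illustration, and, like the helper above, the sketch relies on the default 'fork' start method on Linux because the nested worker function is not picklable.

from multiprocessing import Pipe, Process


def run_in_child(target, *args):
    """Run target(*args) in a fresh child process and return its result.

    The child sends either the return value or the raised exception back
    through the write end of a one-way pipe, so failures surface in the
    parent instead of hanging the benchmark.
    """
    recv_end, send_end = Pipe(False)  # duplex=False: child writes, parent reads

    def _worker(*wargs):
        try:
            send_end.send(target(*wargs))
        except Exception as e:  # forward the error rather than dying silently
            send_end.send(e)

    proc = Process(target=_worker, args=args)
    proc.start()
    proc.join()

    ret = recv_end.recv()
    if isinstance(ret, Exception):
        raise ret
    return ret


def slow_square(x):  # hypothetical stand-in for profile_throughput
    return x * x


if __name__ == '__main__':
    print(run_in_child(slow_square, 7))  # -> 49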
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/config.py
@@ -13,6 +13,8 @@ def _get_torch_dtype(config: Any, default: str = 'float16'):
         default (str): default device type.
     """
     torch_dtype = getattr(config, 'torch_dtype', default)
+    # torch_dtype in config could be none
+    torch_dtype = torch_dtype or default
     return eval(f'torch.{torch_dtype}')
 
 
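Hugging Face model configs may carry `torch_dtype: null`, in which case `getattr(config, 'torch_dtype', default)` returns `None` instead of falling back to the default; the added `or default` guard covers that case. Below is a standalone sketch of the same resolution logic, not the library function itself: the `SimpleNamespace` config is a stand-in for a transformers config object, and `getattr(torch, ...)` is used in place of `eval`.

from types import SimpleNamespace

import torch


def get_torch_dtype(config, default: str = 'float16') -> torch.dtype:
    """Resolve a torch dtype name from a config, tolerating a None value."""
    torch_dtype = getattr(config, 'torch_dtype', default)
    # the attribute may exist but be explicitly set to None
    torch_dtype = torch_dtype or default
    return getattr(torch, torch_dtype)


cfg = SimpleNamespace(torch_dtype=None)  # mimics a config with torch_dtype: null
print(get_torch_dtype(cfg))              # -> torch.float16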
28 changes: 24 additions & 4 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -585,6 +585,24 @@ def _tp_build_model(
     patched_model = None
     cache_engine = None
 
+    def __load_params_and_buffers(param_mod, mod):
+        """load param and buffer."""
+        for name, param in param_mod.named_parameters(recurse=False):
+            mod.register_parameter(name, param)
+        for name, buffer in param_mod.named_buffers(recurse=False):
+            mod.register_buffer(name, buffer)
+
+    def __load_state_dict_assign(param_model, model):
+        """load state dict assign."""
+        try:
+            model.load_state_dict(param_model.state_dict(), assign=True)
+        except Exception:
+            __load_params_and_buffers(param_model, model)
+            mods = dict(model.named_modules())
+            for mod_name, param_mod in param_model.named_modules():
+                mod = mods[mod_name]
+                __load_params_and_buffers(param_mod, mod)
+
     def _broadcast_config(cache_config):
         """broadcast cache config, use minimum cache."""
         if rank == 0:
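The `assign=True` keyword of `nn.Module.load_state_dict` rebinds the state-dict tensors instead of copying them into freshly allocated parameters, but it only exists in newer PyTorch releases (it is missing in 2.0, which appears to be what the fallback path targets, per the "support torch2.0" note in the commit message). A small self-contained sketch of the same idea, using a made-up two-layer model; the helper names mirror, but are not, the functions above.

import torch
from torch import nn


def load_params_and_buffers(src_mod: nn.Module, dst_mod: nn.Module) -> None:
    """Re-register src's own parameters/buffers on dst (non-recursive)."""
    for name, param in src_mod.named_parameters(recurse=False):
        dst_mod.register_parameter(name, param)
    for name, buffer in src_mod.named_buffers(recurse=False):
        dst_mod.register_buffer(name, buffer)


def load_state_dict_assign(src: nn.Module, dst: nn.Module) -> None:
    """Make dst share src's tensors without copying them."""
    try:
        # PyTorch >= 2.1: assign=True rebinds the tensors instead of copying
        dst.load_state_dict(src.state_dict(), assign=True)
    except TypeError:
        # Older PyTorch: walk the module tree and rebind manually
        modules = dict(dst.named_modules())
        for mod_name, src_mod in src.named_modules():
            load_params_and_buffers(src_mod, modules[mod_name])


# made-up example model just to exercise the helpers
src = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
dst = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
load_state_dict_assign(src, dst)
assert dst[0].weight.data_ptr() == src[0].weight.data_ptr()  # shared storage, no copy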
@@ -631,7 +649,7 @@ def _broadcast_config(cache_config):
                 device_map=device_map,
                 trust_remote_code=trust_remote_code)
             _load_adapters(param_model, adapters, device_map=device_map)
-            model.load_state_dict(param_model.state_dict(), assign=True)
+            __load_state_dict_assign(param_model, model)
             del param_model
 
         patched_model = patch(
@@ -654,6 +672,7 @@ def _broadcast_config(cache_config):
             rank=rank,
             world_size=world_size)
     except Exception as e:
+        logger.error(f'rank[{rank}] failed with error: {e}')
         error_code = 1
         error_type = e
 
@@ -701,6 +720,7 @@ def _tp_get_input(rank: int, in_que: mp.Queue, world_size: int):
                 device_mesh=device_mesh,
                 placements=[Replicate()
                             ]).to_local()
+            torch.cuda.synchronize()
 
     inputs = updated_inputs
     inputs.update(other_metas)
@@ -894,11 +914,11 @@ def __init__(self,
                  world_size: int,
                  adapters: Dict[str, str] = None,
                  trust_remote_code: bool = True) -> None:
-        mp.set_start_method('spawn')
+        self.mp_ctx = mp.get_context('spawn')
         super().__init__(model_config=model_config, cache_config=cache_config)
         self.world_size = world_size
-        self.tp_model_in_que = mp.Queue(10)
-        self.tp_model_out_que = mp.Queue(10)
+        self.tp_model_in_que = self.mp_ctx.Queue(10)
+        self.tp_model_out_que = self.mp_ctx.Queue(10)
 
         self.patch_model_tp(model_path,
                             model_config=model_config,
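The last hunk swaps the global `mp.set_start_method('spawn')` for a private spawn context. `set_start_method` may only be called once per process, so a second engine instance, or a caller that has already chosen a start method (such as the benchmark script above), would hit "RuntimeError: context has already been set"; a context obtained from `mp.get_context('spawn')` avoids the global state entirely. A small standalone sketch of the difference (not lmdeploy code):

import multiprocessing as mp

# Old pattern: process-global. Calling it a second time (or after any other
# component has already picked a start method) raises RuntimeError.
try:
    mp.set_start_method('spawn')
    mp.set_start_method('spawn')
except RuntimeError as e:
    print(f'second call failed: {e}')

# New pattern: a private context. Nothing global is mutated, and every queue
# or process created from it uses the 'spawn' start method.
ctx = mp.get_context('spawn')
tp_in_que = ctx.Queue(10)
tp_out_que = ctx.Queue(10)

tp_in_que.put('ping')
print(tp_in_que.get())  # -> ping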
8 changes: 5 additions & 3 deletions lmdeploy/pytorch/engine/request.py
@@ -2,7 +2,7 @@
 import enum
 from dataclasses import dataclass, field
 from queue import Empty, Queue
-from threading import Lock, Thread, ThreadError
+from threading import Lock, Thread
 from typing import Any, Callable, ClassVar, Dict, List
 
 from lmdeploy.messages import ResponseType
@@ -75,7 +75,8 @@ def _resp_que_get(self, block: bool = True, timeout: float = None):
             except Empty:
                 timeout_counter -= self.THREAD_ALIVE_INTERVAL
                 if self._thread and not self._thread.is_alive():
-                    raise ThreadError('Engine main loop stopped.')
+                    logger.error('Engine main loop stopped.')
+                    exit(1)
 
         return self.resp_que.get(timeout=timeout_counter)
 
@@ -110,7 +111,8 @@ def batched_send_async(self, req_types: List[RequestType],
                            data: List[Any]) -> List[int]:
         """Batched send request asynchronize."""
         if self._thread and not self._thread.is_alive():
-            raise ThreadError('Engine main loop stopped.')
+            logger.error('Engine main loop stopped.')
+            exit(1)
         assert len(req_types) == len(data)
         batch_size = len(req_types)
 
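Both call sites now log and terminate the process when the engine's main-loop thread has died, instead of raising `ThreadError` into the caller, which could leave a multi-process benchmark wedged. Below is a simplified, self-contained sketch of the liveness-check pattern that `_resp_que_get` uses; the names, the interval, and the timeout are made up, and this is not the lmdeploy implementation.

import sys
import threading
import time
from queue import Empty, Queue

THREAD_ALIVE_INTERVAL = 1.0  # seconds between liveness checks (illustrative value)


def get_or_die(resp_que: Queue, worker: threading.Thread, timeout: float = 10.0):
    """Wait for a response, but bail out if the worker thread has died.

    Polling in THREAD_ALIVE_INTERVAL slices means a dead worker is noticed
    promptly instead of blocking on the queue forever.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            return resp_que.get(timeout=THREAD_ALIVE_INTERVAL)
        except Empty:
            if not worker.is_alive():
                print('engine main loop stopped', file=sys.stderr)
                sys.exit(1)
            if time.monotonic() > deadline:
                raise  # re-raise Empty once the overall timeout is exhausted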
