Report first-token-latency and token-latency percentiles (#736)
* update profile scripts

* add top_p, top_k and temperature as input arguments

* fix input_ids

* update profile_throughput

* update profile_restful_api

* update profile_serving

* update

* update

* add progress bar

* remove TODO comments

* update

* remove useless profile_* argument

* remove log level

* change concurrency default value to 64

* update restful_api.md

* update according to review comments

* fix docstring
lvhan028 authored Nov 29, 2023
1 parent 8add942 commit 5c9e1e2
Showing 12 changed files with 711 additions and 476 deletions.
2 changes: 1 addition & 1 deletion benchmark/README.md
@@ -29,7 +29,7 @@ pip install nvidia-ml-py

```bash
python profile_generation.py \
--model-path /path/to/your/model \
/path/to/your/model \
--concurrency 1 8 --prompt-tokens 1 512 --completion-tokens 2048 512
```

203 changes: 134 additions & 69 deletions benchmark/profile_generation.py
@@ -1,10 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# import multiprocessing as mp
import argparse
import csv
import logging
import os
import os.path as osp
import time
from dataclasses import dataclass
from queue import Queue
@@ -18,37 +16,47 @@
nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
from tqdm import tqdm

from lmdeploy.tokenizer import Tokenizer
from lmdeploy.turbomind import TurboMind


def infer(model, session_id: int, input_ids: str, output_seqlen: int,
def infer(model, session_id: int, input_ids: List, output_seqlen: int,
test_round: int, que: Queue):
chatbot = model.create_instance()
stats = []
for i in range(test_round):
start = time.perf_counter()
timestamps = []
tokens = []
for _ in range(test_round):
token_latency_stats = [0] * (output_seqlen + 1)
prev = time.perf_counter()
n_pre_token = 0
"""
The iterator provided by `stream_infer` denotes the number of generated tokens so far,
which is represented by the variable `n_token`.
Please note that `n_token` is not a continuous value. In other words, during the iteration,
its value might be 5, 7, 8, 16, and so on, rather than 1, 2, 3, 4, etc.
So, it is quite difficult to get the latency of each generated token.
As a work-around, we set the latency `new-prev` of each iteration to the first token of
the new generated tokens, and leave the latency of the rest tokens being 0.
For example, in the first iteration, 5 tokens are generated.
The time elapsing in this iteration `now-prev` is set to the latency of first token of
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
""" # noqa: E501
for outputs in chatbot.stream_infer(session_id,
input_ids,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True,
ignore_eos=True):
res, token = outputs[0]
timestamps.append(time.perf_counter())
tokens.append(token)

# TODO: ignore first token
first_token_latency = np.round(timestamps[0] - start, 2)
if len(timestamps) == 1:
token_latency = np.round(timestamps[0] - start, 2)
token = tokens[0]
else:
token_latency = np.round(timestamps[-1] - timestamps[0], 2)
token = tokens[-1] - tokens[0]
stats.append([first_token_latency, token, token_latency])
ignore_eos=True,
stream_output=True):
_, n_token = outputs[0]
now = time.perf_counter()
if n_pre_token != n_token:
token_latency_stats[n_pre_token] = np.round(now - prev, 3)
n_pre_token = n_token
prev = now

assert output_seqlen <= n_token <= output_seqlen + 1, \
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens'
stats.append(token_latency_stats[:output_seqlen])
que.put((session_id, stats))
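
For readers following the `infer` rewrite above, here is a minimal, self-contained sketch of the latency-bucketing workaround described in the docstring. `fake_stream` is a made-up stand-in for `stream_infer`, and the counts 5, 7, 8, 16 mirror the docstring's example:

```python
import time


def fake_stream(counts=(5, 7, 8, 16)):
    """Made-up iterator that, like `stream_infer`, yields cumulative token counts."""
    for n in counts:
        time.sleep(0.01)  # stand-in for decoding time
        yield n


output_seqlen = 16
token_latency_stats = [0.0] * (output_seqlen + 1)
prev = time.perf_counter()
n_pre_token = 0
for n_token in fake_stream():
    now = time.perf_counter()
    if n_pre_token != n_token:
        # charge the whole iteration to the first newly generated token
        token_latency_stats[n_pre_token] = round(now - prev, 3)
        n_pre_token = n_token
        prev = now

print(token_latency_stats[:output_seqlen])
# non-zero entries at indices 0, 5, 7, 8; the rest stay 0
```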


@@ -93,15 +101,19 @@ def profile_throughput(model_path: str,
input_seqlen: int = 1,
output_seqlen: int = 512,
test_round: int = 10,
tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = TurboMind(model_path=model_path, tp=tp)
tp: int = 1,
**kwargs):
# avoid turbomind checking chat template name by setting
# `model_name='llama'`
tm_model = TurboMind(model_path=model_path,
tp=tp,
model_name='llama',
**kwargs)
tokenizer = tm_model.tokenizer

# make up a prompt that can be tokenized into {input_seqlen} tokens
assert input_seqlen > 0, 'input_seqlen should > 0'
prompt = 'hi'
input_ids = tokenizer.encode(prompt)
input_ids = tokenizer('hi').input_ids
input_ids = input_ids * input_seqlen

warmup(tm_model, concurrency, input_ids, output_seqlen)
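
A small aside on the prompt construction above: `input_ids * input_seqlen` is plain Python list repetition, so the result has exactly `input_seqlen` ids only when the prompt encodes to a single id. A hypothetical sketch (the `DummyTokenizer` class and the id 101 are invented for illustration):

```python
class DummyTokenizer:
    """Invented stand-in for the real tokenizer: 'hi' encodes to one id."""

    def __call__(self, text):
        class Encoding:
            input_ids = [101]

        return Encoding()


tokenizer = DummyTokenizer()
input_seqlen = 4
input_ids = tokenizer('hi').input_ids
input_ids = input_ids * input_seqlen  # list repetition, not element-wise math
print(input_ids)       # [101, 101, 101, 101]
print(len(input_ids))  # 4 == input_seqlen
```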
@@ -110,7 +122,6 @@ def profile_throughput(model_path: str,
procs = []
_start = time.perf_counter()

# TODO: update to the multithread version
for i in range(concurrency):
proc = Thread(target=infer,
args=(tm_model, i + 1, input_ids, output_seqlen,
@@ -128,33 +139,49 @@ def profile_throughput(model_path: str,
_end = time.perf_counter()
elapsed_time = _end - _start

stats = []
token_latency_stats = []
while not que.empty():
session_id, _stats = que.get()
print(f'\n{"-" * 50}\n'
f'session {session_id} stats: \n{_stats}\n{"-" * 50}\n')
stats.append(_stats)

stats = np.array(stats).reshape(-1, 3)

first_token_latency_min = np.min(stats[:, 0], axis=0)
first_token_latency_max = np.max(stats[:, 0], axis=0)
first_token_latency_ave = np.mean(stats[:, 0], axis=0)
token_latency_min = np.min(stats[:, 2], axis=0)
token_latency_max = np.max(stats[:, 2], axis=0)
token_latency_ave = np.mean(stats[:, 2], axis=0)
throughput = np.sum(stats[:, 1], axis=0) / np.sum(stats[:, 2],
axis=0) * concurrency
print(f'\n{"-" * 50}\nconcurrency: {concurrency}, input_tokens: '
f'{input_seqlen}, output_tokens: {output_seqlen}\n'
f'elapsed_time: {elapsed_time:.2f}s\n'
_, _stats = que.get()
token_latency_stats += _stats

# The shape is [concurrency*test_round, output_seqlen]
token_latency_stats = np.stack(token_latency_stats, axis=0)

first_token_latency_min = np.round(
np.min(token_latency_stats[:, 0], axis=0), 3)
first_token_latency_max = np.round(
np.max(token_latency_stats[:, 0], axis=0), 3)
first_token_latency_ave = np.round(
np.mean(token_latency_stats[:, 0], axis=0), 3)
token_latency_max = np.round(np.max(np.sum(token_latency_stats, axis=1)),
3)
token_latency_min = np.round(np.min(np.sum(token_latency_stats, axis=1)),
3)
token_latency_ave = np.round(np.mean(np.sum(token_latency_stats, axis=1)),
3)
# sort token_latency without the first token's latency
sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten())
percentiles = [
np.round(
sorted_token_latency[int(percent * len(sorted_token_latency))], 3)
for percent in [0.5, 0.75, 0.95, 0.99]
]

throughput = np.round(token_latency_stats.size / elapsed_time, 2)
print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n'
f'concurrency: {concurrency}, test_round: {test_round}\n'
f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n'
f'first_token latency(min, max, ave): '
f'{first_token_latency_min:.2f}s, {first_token_latency_max:.2f}s, '
f'{first_token_latency_ave:.2f}s\ntoken latency(min, max, ave): '
f'{token_latency_min:.2f}s, {token_latency_max:.2f}s, '
f'{token_latency_ave:.2f}s\n'
f'throughput: {throughput:.2f} token/s\n{"-" * 50}')
return tm_model.model_name, throughput, tm_model.gpu_count
f'{first_token_latency_min}s, {first_token_latency_max}s, '
f'{first_token_latency_ave}s\ntotal_token latency(min, max, ave): '
f'{token_latency_min}s, {token_latency_max}s, '
f'{token_latency_ave}s\n'
f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n'
f'throughput: {throughput} token/s\n{"-" * 50}')
return tm_model.model_name, \
[first_token_latency_min, first_token_latency_max,
first_token_latency_ave], \
percentiles, throughput, tm_model.gpu_count
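
To make the aggregation above easier to follow, here is a self-contained sketch with a made-up latency matrix (4 runs × 6 generated tokens, in seconds); it applies the same first-token statistics and sorted-index percentile lookup as the code above:

```python
import numpy as np

# Made-up per-token latency matrix: 4 runs x 6 generated tokens (seconds).
# Column 0 holds each run's first-token latency; zeros are tokens whose latency
# was folded into their iteration's first token (see the docstring in `infer`).
token_latency_stats = np.array([
    [0.20, 0.03, 0.00, 0.02, 0.03, 0.00],
    [0.25, 0.02, 0.03, 0.00, 0.02, 0.03],
    [0.18, 0.03, 0.02, 0.03, 0.00, 0.02],
    [0.22, 0.00, 0.03, 0.02, 0.03, 0.02],
])

first_token_latency_min = np.round(np.min(token_latency_stats[:, 0]), 3)
first_token_latency_max = np.round(np.max(token_latency_stats[:, 0]), 3)
first_token_latency_ave = np.round(np.mean(token_latency_stats[:, 0]), 3)

# Percentiles over all non-first tokens, via indexing into the sorted array.
sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten())
percentiles = [
    np.round(sorted_token_latency[int(percent * len(sorted_token_latency))], 3)
    for percent in [0.5, 0.75, 0.95, 0.99]
]

print(first_token_latency_min, first_token_latency_max, first_token_latency_ave)
print(percentiles)  # p50/p75/p95/p99 of the 20 non-first-token latencies
```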


class MemoryMonitor:
@@ -235,6 +262,8 @@ class ProfileResult:
batch: int
prompt_tokens: int
completion_tokens: int
first_token_latency: List
percentiles: List
throughput_per_proc: float
throughput_per_node: float
mem_per_proc: float
@@ -244,42 +273,67 @@ class ProfileResult:

def parse_args():
parser = argparse.ArgumentParser(description='Regression Test')
parser.add_argument('--model-path',
parser.add_argument('model_path',
type=str,
help='benchmark test model path')
help='the path of the model in localhost or '
'the repo_id of the model in huggingface.co')
parser.add_argument('--concurrency',
nargs='+',
type=int,
help='how many requests launched concurrently',
default=[1, 8, 16, 32])
default=[1, 16, 32, 64])
parser.add_argument(
'--prompt-tokens',
nargs='+',
type=int,
help='number of prompt tokens. One-to-one '
'correspondence with completion-tokens',
default=[64, 512, 512, 1024])
default=[1, 128, 128, 2048, 2048])
parser.add_argument('--completion-tokens',
nargs='+',
type=int,
help='how many tokens to generate. One-to-one '
'correspondence with prompt-tokens',
default=[512, 512, 1024, 1024])
default=[128, 128, 2048, 128, 2048])
parser.add_argument('--tp', type=int, help='Tensor parallel', default=1)
parser.add_argument('--dst-csv',
parser.add_argument('--top_k',
type=int,
help='The number of highest probability vocabulary '
'tokens to keep for top-k-filtering',
default=1)
parser.add_argument('--top_p',
type=float,
help='the smallest set of most probable tokens with '
'probabilities that add up to top_p or higher '
'is kept for generation',
default=1.0)
parser.add_argument('--temperature',
type=float,
help='The value used to modulate the next token '
'probabilities',
default=1.0)
parser.add_argument('--csv',
type=str,
help='Where to save the result.',
default='profile_generation.csv')
parser.add_argument('--log-level',
help='set log level',
default='INFO',
default='ERROR',
choices=list(logging._nameToLevel.keys()))
parser.add_argument('--test-round',
type=int,
help='number of test rounds',
default=10)
args = parser.parse_args()
return args
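
A quick illustration of how the new `--prompt-tokens` and `--completion-tokens` defaults pair up one-to-one (a hypothetical zip, mirroring the correspondence stated in the help text and checked by the assertion in `main`):

```python
# Defaults taken from the argument definitions above.
prompt_tokens = [1, 128, 128, 2048, 2048]
completion_tokens = [128, 128, 2048, 128, 2048]

# main() asserts the two lists have the same length before profiling.
assert len(prompt_tokens) == len(completion_tokens)
for p, c in zip(prompt_tokens, completion_tokens):
    print(f'prompt_tokens={p}, completion_tokens={c}')
# -> (1, 128), (128, 128), (128, 2048), (2048, 128), (2048, 2048)
```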


def main():
args = parse_args()
assert len(args.prompt_tokens) == len(args.completion_tokens), \
f'mismatched size between `prompt-tokens` and `completion-tokens`' \
f', {len(args.prompt_tokens)} vs {len(args.completion_tokens)}'

os.environ['TM_LOG_LEVEL'] = args.log_level
results: List[ProfileResult] = []
for batch in tqdm(args.concurrency):
@@ -292,9 +346,14 @@ def main():
concurrency=batch,
input_seqlen=prompt_tokens,
output_seqlen=completion_tokens,
tp=args.tp)
tp=args.tp,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
test_round=args.test_round)
output = Pool(1).map(profile_target, (args.model_path, ))
model_name, throughput_per_proc, tp = output[0]
model_name, first_token_latency, percentiles, \
throughput_per_proc, tp = output[0]
time.sleep(5) # wait a while for releasing GPU mem
memory = MemoryMonitor.terminate()
device_count = MemoryMonitor.device_count.value
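
The `Pool(1).map(...)` call above runs each measurement in a single worker process, so the worker's GPU memory is released when it exits (consistent with the `time.sleep(5)` comment). A minimal sketch of that pattern, assuming `Pool` is `multiprocessing.Pool` (its import is outside the shown hunks) and using an invented `profile_dummy` in place of `profile_throughput`:

```python
from functools import partial
from multiprocessing import Pool


def profile_dummy(model_path, concurrency=1, output_seqlen=128):
    """Invented stand-in for profile_throughput: pretend to load and measure."""
    return f'{model_path} (concurrency={concurrency})', 1234.5


if __name__ == '__main__':
    # Bind the per-run settings, then run the measurement in one worker process.
    profile_target = partial(profile_dummy, concurrency=16, output_seqlen=2048)
    output = Pool(1).map(profile_target, ('/path/to/model', ))
    print(output[0])
```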
@@ -303,25 +362,31 @@ def main():
batch=batch,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
first_token_latency=first_token_latency,
percentiles=percentiles,
throughput_per_proc=throughput_per_proc,
throughput_per_node=throughput_per_proc / tp *
device_count,
mem_per_proc=memory,
mem_per_gpu=memory / tp,
mem_per_node=memory / tp * device_count))
with open(args.dst_csv, 'w') as csvfile:
with open(args.csv, 'w') as csvfile:
writer = csv.writer(csvfile)
writer.writerow([
'batch', 'prompt_tokens', 'completion_tokens',
'throughput_per_proc(token/s)', 'throughput_per_node(token/s)',
'mem_per_proc(GB)', 'mem_per_gpu(GB)', 'mem_per_node(GB)'
'1st_token_latency(min)(s)', '1st_token_latency(max)(s)',
'1st_token_latency(ave)(s)', 'percentile50(s)', 'percentile75(s)',
'percentile95(s)', 'percentile99(s)', 'throughput(token/s)',
'mem_per_proc(GB)', 'mem_per_gpu(GB)'
])
for re in results:
writer.writerow([
re.batch, re.prompt_tokens, re.completion_tokens,
f'{re.throughput_per_proc:.2f}',
f'{re.throughput_per_node:.2f}', f'{re.mem_per_proc:.2f}',
f'{re.mem_per_gpu:.2f}', f'{re.mem_per_node:.2f}'
re.first_token_latency[0], re.first_token_latency[1],
re.first_token_latency[2], re.percentiles[0],
re.percentiles[1], re.percentiles[2], re.percentiles[3],
f'{re.throughput_per_proc:.2f}', f'{re.mem_per_proc:.2f}',
f'{re.mem_per_gpu:.2f}'
])
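
For completeness, a short sketch of reading the resulting CSV back with `csv.DictReader`, using the column names from the header row added in this commit (it assumes the default `profile_generation.csv` has been produced by a prior run):

```python
import csv

with open('profile_generation.csv') as csvfile:
    for row in csv.DictReader(csvfile):
        print(f"batch={row['batch']} "
              f"first_token_ave={row['1st_token_latency(ave)(s)']}s "
              f"p99={row['percentile99(s)']}s "
              f"throughput={row['throughput(token/s)']} token/s")
```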


