Turbomind prefix caching (#1450)
* turbomind prefix caching

* update

* fix lint

* refine

* change block_ptrs_ allocate size

* refactor cache

* fix default

* move verify

* verify before cache

* fix format

* fix cache

* fix typo

* add api

* add lint

* format code
ispobock authored May 15, 2024
1 parent 75d0e73 commit 14b6c02
Showing 18 changed files with 291 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -21,7 +21,7 @@ jobs:
uses: DoozyX/[email protected]
with:
source: src
extensions: h,c,cpp,hpp,cu,cuh
extensions: h,c,cpp,hpp,cu,cuh,cc
clangFormatVersion: 11
style: file
- name: Check markdown link
10 changes: 8 additions & 2 deletions benchmark/profile_generation.py
@@ -336,13 +336,15 @@ def parse_args():
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group, default=2048)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)

# turbomind engine args
tb_group = parser.add_argument_group('TurboMind engine argument')
tb_group._group_actions.append(tp_act)
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(cache_count_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
ArgumentHelper.model_format(tb_group, default='hf')
args = parser.parse_args()
return args
@@ -399,14 +401,18 @@ def main():
cache_block_seq_len=args.cache_block_seq_len,
model_format=args.model_format,
session_len=session_len,
tp=args.tp)
tp=args.tp,
enable_prefix_caching=args.enable_prefix_caching,
)
elif args.backend == 'pytorch':
engine_config = PytorchEngineConfig(
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
session_len=session_len,
tp=args.tp,
thread_safe=True)
thread_safe=True,
enable_prefix_caching=args.enable_prefix_caching,
)
gen_config = EngineGenerationConfig(
top_k=args.top_k,
top_p=args.top_p,
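For context, a minimal sketch (not part of this commit) of the shared argument-group pattern used in the benchmark and CLI diffs: the option is registered once on the PyTorch group and the returned action is appended to the TurboMind group so it is listed under both headings in --help. The flag name '--enable-prefix-caching' and its store_true behaviour are assumptions inferred from how args.enable_prefix_caching is consumed below.

import argparse


def enable_prefix_caching(group):
    # Hypothetical stand-in for ArgumentHelper.enable_prefix_caching: register
    # the flag on one argument group and return the argparse action object.
    return group.add_argument(
        '--enable-prefix-caching',
        action='store_true',
        default=False,
        help='reuse k/v blocks across requests that share a prompt prefix')


parser = argparse.ArgumentParser(description='demo')
pt_group = parser.add_argument_group('PyTorch engine arguments')
tb_group = parser.add_argument_group('TurboMind engine arguments')

prefix_caching_act = enable_prefix_caching(pt_group)
# Re-list the same action under the TurboMind heading, as the diffs in this commit do.
tb_group._group_actions.append(prefix_caching_act)

args = parser.parse_args(['--enable-prefix-caching'])
print(args.enable_prefix_caching)  # True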
10 changes: 8 additions & 2 deletions benchmark/profile_pipeline_api.py
@@ -198,13 +198,15 @@ def parse_args():
session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)

# turbomind engine args
tb_group = parser.add_argument_group('TurboMind engine argument')
tb_group._group_actions.append(tp_act)
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(cache_count_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
ArgumentHelper.model_format(tb_group, default='hf')
ArgumentHelper.quant_policy(tb_group, default=0)
ArgumentHelper.num_tokens_per_iter(tb_group)
@@ -228,15 +230,19 @@ def main():
model_format=args.model_format,
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters)
max_prefill_iters=args.max_prefill_iters,
enable_prefix_caching=args.enable_prefix_caching,
)
elif args.backend == 'pytorch':
engine_config = PytorchEngineConfig(
session_len=args.session_len,
cache_max_entry_count=args.cache_max_entry_count,
block_size=args.cache_block_seq_len,
max_batch_size=args.concurrency,
tp=args.tp,
thread_safe=False)
thread_safe=False,
enable_prefix_caching=args.enable_prefix_caching,
)

engine = Engine(args.model_path, engine_config, csv=args.csv)

7 changes: 5 additions & 2 deletions benchmark/profile_throughput.py
@@ -283,14 +283,15 @@ def parse_args():
session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
ArgumentHelper.enable_prefix_caching(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)

# turbomind engine args
tb_group = parser.add_argument_group('TurboMind engine argument')
tb_group._group_actions.append(tp_act)
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(cache_count_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
ArgumentHelper.model_format(tb_group, default='hf')
ArgumentHelper.quant_policy(tb_group, default=0)
ArgumentHelper.num_tokens_per_iter(tb_group)
@@ -314,7 +315,9 @@ def main():
model_format=args.model_format,
quant_policy=args.quant_policy,
num_tokens_per_iter=args.num_tokens_per_iter,
max_prefill_iters=args.max_prefill_iters)
max_prefill_iters=args.max_prefill_iters,
enable_prefix_caching=args.enable_prefix_caching,
)
elif args.backend == 'pytorch':
engine_config = PytorchEngineConfig(
session_len=args.session_len,
3 changes: 2 additions & 1 deletion lmdeploy/cli/cli.py
@@ -111,14 +111,14 @@ def add_parser_chat():
# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.adapters(pt_group)
ArgumentHelper.enable_prefix_caching(pt_group)

# common engine args
tp_act = ArgumentHelper.tp(pt_group)
model_name_act = ArgumentHelper.model_name(pt_group)
session_len_act = ArgumentHelper.session_len(pt_group)
max_batch_size_act = ArgumentHelper.max_batch_size(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
@@ -128,6 +128,7 @@ def add_parser_chat():
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(max_batch_size_act)
tb_group._group_actions.append(cache_max_entry_act)
tb_group._group_actions.append(prefix_caching_act)
ArgumentHelper.model_format(tb_group)
ArgumentHelper.quant_policy(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
14 changes: 10 additions & 4 deletions lmdeploy/cli/serve.py
@@ -50,7 +50,6 @@ def add_parser_gradio():

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.enable_prefix_caching(pt_group)

# common engine args
tp_act = ArgumentHelper.tp(pt_group)
@@ -59,6 +58,7 @@ def add_parser_gradio():
max_batch_size_act = ArgumentHelper.max_batch_size(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
@@ -69,6 +69,7 @@ def add_parser_gradio():
tb_group._group_actions.append(max_batch_size_act)
tb_group._group_actions.append(cache_max_entry_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
ArgumentHelper.model_format(tb_group)
ArgumentHelper.quant_policy(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
@@ -138,7 +139,6 @@ def add_parser_api_server():

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.enable_prefix_caching(pt_group)

ArgumentHelper.adapters(pt_group)
# common engine args
@@ -148,6 +148,7 @@ def add_parser_api_server():
max_batch_size_act = ArgumentHelper.max_batch_size(pt_group)
cache_max_entry_act = ArgumentHelper.cache_max_entry_count(pt_group)
cache_block_seq_len_act = ArgumentHelper.cache_block_seq_len(pt_group)
prefix_caching_act = ArgumentHelper.enable_prefix_caching(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
@@ -158,6 +159,7 @@ def add_parser_api_server():
tb_group._group_actions.append(max_batch_size_act)
tb_group._group_actions.append(cache_max_entry_act)
tb_group._group_actions.append(cache_block_seq_len_act)
tb_group._group_actions.append(prefix_caching_act)
ArgumentHelper.model_format(tb_group)
ArgumentHelper.quant_policy(tb_group)
ArgumentHelper.rope_scaling_factor(tb_group)
@@ -233,7 +235,9 @@ def gradio(args):
quant_policy=args.quant_policy,
rope_scaling_factor=args.rope_scaling_factor,
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len)
cache_block_seq_len=args.cache_block_seq_len,
enable_prefix_caching=args.enable_prefix_caching,
)
chat_template_config = ChatTemplateConfig(
model_name=args.model_name,
meta_instruction=args.meta_instruction,
@@ -283,7 +287,9 @@ def api_server(args):
quant_policy=args.quant_policy,
rope_scaling_factor=args.rope_scaling_factor,
cache_max_entry_count=args.cache_max_entry_count,
cache_block_seq_len=args.cache_block_seq_len)
cache_block_seq_len=args.cache_block_seq_len,
enable_prefix_caching=args.enable_prefix_caching,
)
chat_template_config = None
if args.chat_template:
chat_template_config = ChatTemplateConfig.from_json(
2 changes: 2 additions & 0 deletions lmdeploy/messages.py
@@ -124,6 +124,7 @@ class TurbomindEngineConfig:
For versions of lmdeploy between `v0.2.0` and `v0.2.1`, it defaults to 0.5, depicting the percentage of TOTAL GPU memory to be allocated to the k/v cache.
For lmdeploy versions greater than `v0.2.1`, it defaults to 0.8, signifying the percentage of FREE GPU memory to be reserved for the k/v cache
cache_block_seq_len (int): the length of the token sequence in a k/v block, default to 64
enable_prefix_caching (bool): enable cache prompts for block reuse, default to False
quant_policy (int): default to 0. When k/v is quantized into 8 bit, set it to 4
rope_scaling_factor (int): scaling factor used for dynamic ntk, default to 0. TurboMind follows the implementation of transformer LlamaAttention
use_logn_attn (bool): whether or not to use log attn: default to False
@@ -141,6 +142,7 @@ class TurbomindEngineConfig:
max_batch_size: int = 128
cache_max_entry_count: float = 0.8
cache_block_seq_len: int = 64
enable_prefix_caching: bool = False
quant_policy: int = 0
rope_scaling_factor: float = 0.0
use_logn_attn: bool = False
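A minimal usage sketch for the new config field (not part of the diff), assuming the high-level pipeline API accepts a TurbomindEngineConfig as backend_config; the model path is only illustrative.

from lmdeploy import pipeline, TurbomindEngineConfig

backend_config = TurbomindEngineConfig(
    cache_max_entry_count=0.8,   # fraction of free GPU memory reserved for the k/v cache
    cache_block_seq_len=64,      # tokens per k/v block; prefix reuse is block-granular
    enable_prefix_caching=True,  # new option introduced by this commit
)

pipe = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)
# Requests sharing a prompt prefix (e.g. a long system prompt) can now reuse
# the cached k/v blocks of that prefix instead of recomputing them.
responses = pipe(['<long system prompt> question A', '<long system prompt> question B'])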
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -58,6 +58,7 @@ class TurbomindModelConfig:
cache_max_entry_count: float = 0.8
cache_block_seq_len: int = 64
cache_chunk_size: int = -1
enable_prefix_caching: bool = False
num_tokens_per_iter: int = 0
max_prefill_iters: int = 1
extra_tokens_per_iter: int = 0
1 change: 0 additions & 1 deletion src/turbomind/models/llama/BlockManager.cc
@@ -222,7 +222,6 @@ int BlockManager::Lock(const BlockIds& ids)

for (const auto& i : ids) {
auto& b = blocks_[i];
FT_CHECK_WITH_INFO(is_cached(b), to_string(b));
if (++b.use_count == 1) {
lock.push_back(i);
FT_CHECK(is_active(b));
128 changes: 128 additions & 0 deletions src/turbomind/models/llama/BlockTrie.cc
@@ -0,0 +1,128 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/models/llama/BlockTrie.h"
#include "src/turbomind/models/llama/SequenceManager.h"

namespace turbomind {

size_t hash(const std::vector<int>& vec)
{
size_t seed = vec.size();
for (const auto& i : vec) {
seed ^= std::hash<int>{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}

BlockTrie::BlockTrie(size_t block_seq_len, std::shared_ptr<BlockManager> block_manager, bool enable_prefix_caching):
block_seq_len_(block_seq_len), block_manager_(block_manager), enable_prefix_caching_(enable_prefix_caching)
{
root_ = std::make_shared<TrieNode>();
}

void BlockTrie::match(Sequence& seq)
{
BlockIds matched_blocks;
UniqueIds matched_unique_ids;

std::shared_ptr<TrieNode> curr_node = root_;
int num_matched = 0;

while (num_matched + block_seq_len_ < seq.prompt.size()) {
std::vector<int> curr_tokens(seq.prompt.begin() + num_matched,
seq.prompt.begin() + num_matched + block_seq_len_);
size_t hash_key = hash(curr_tokens);

auto it = curr_node->children.find(hash_key);

if (it == curr_node->children.end()) {
break;
}

if (curr_tokens != it->second->tokens) {
break;
}

matched_blocks.push_back(it->second->block_id);
matched_unique_ids.push_back(it->second->block_unique_id);
curr_node = it->second;
num_matched += block_seq_len_;
}

if (matched_blocks.size() > 0) {
// add use count
block_manager_->Lock(matched_blocks);
block_manager_->Touch(matched_blocks);
// only applied to sequences without existing history blocks
seq.blocks.insert(seq.blocks.end(), matched_blocks.begin(), matched_blocks.end());
seq.block_unique_ids.insert(seq.block_unique_ids.end(), matched_unique_ids.begin(), matched_unique_ids.end());
}
}

void BlockTrie::cache(const Sequence& seq)
{
std::shared_ptr<TrieNode> curr_node = root_;
int num_matched = 0;
int idx = 0;
BlockIds cached_blocks;

while (num_matched + block_seq_len_ <= seq.prompt.size()) {
std::vector<int> curr_tokens(seq.prompt.begin() + num_matched,
seq.prompt.begin() + num_matched + block_seq_len_);
size_t hash_key = hash(curr_tokens);

auto it = curr_node->children.find(hash_key);

int block_id = seq.blocks[idx];
uint64_t block_unique_id = seq.block_unique_ids[idx];

if (it != curr_node->children.end()) {
if (curr_tokens != it->second->tokens) {
break;
}
curr_node = it->second;
curr_node->block_id = block_id;
curr_node->block_unique_id = block_unique_id;
}
else {
// insert new node
std::shared_ptr<TrieNode> node = std::make_shared<TrieNode>();
node->hash_key = hash_key;
node->tokens = curr_tokens;
node->block_id = block_id;
node->block_unique_id = block_unique_id;
node->num_matched = num_matched + block_seq_len_;
curr_node->children[hash_key] = node;
curr_node = node;
}

cached_blocks.push_back(curr_node->block_id);
num_matched += block_seq_len_;
idx++;
}

block_manager_->Touch(cached_blocks);
}

int BlockTrie::verify()
{
return verify_traverse(root_);
}

int BlockTrie::verify_traverse(std::shared_ptr<TrieNode>& node)
{
int valid_count = 1;
for (auto it = node->children.begin(); it != node->children.end();) {
if (block_manager_->unique_id(it->second->block_id) != it->second->block_unique_id) {
// child invalid
it = node->children.erase(it);
}
else {
valid_count += verify_traverse(it->second);
it++;
}
}
return valid_count;
}

} // namespace turbomind
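To make the control flow of BlockTrie::match and BlockTrie::cache easier to follow, here is a simplified Python sketch of the same idea (illustration only, not part of the commit): the prompt is split into fixed-size blocks, each trie child is keyed by a hash of one block's tokens, and the stored token list guards against hash collisions, mirroring the curr_tokens != it->second->tokens checks above. BlockManager bookkeeping (Lock/Touch, unique ids, verify) is omitted.

from dataclasses import dataclass, field


@dataclass
class TrieNode:
    tokens: tuple = ()
    block_id: int = -1
    children: dict = field(default_factory=dict)


class SimpleBlockTrie:
    def __init__(self, block_seq_len):
        self.block_seq_len = block_seq_len
        self.root = TrieNode()

    def match(self, prompt):
        """Return block ids of the longest cached prefix, whole blocks only."""
        node, matched, pos = self.root, [], 0
        while pos + self.block_seq_len < len(prompt):   # strict '<', like match() above
            chunk = tuple(prompt[pos:pos + self.block_seq_len])
            child = node.children.get(hash(chunk))
            if child is None or child.tokens != chunk:  # miss, or hash collision
                break
            matched.append(child.block_id)
            node, pos = child, pos + self.block_seq_len
        return matched

    def cache(self, prompt, block_ids):
        """Insert or refresh one node per full block of a finished prompt."""
        node, pos, idx = self.root, 0, 0
        while pos + self.block_seq_len <= len(prompt) and idx < len(block_ids):
            chunk = tuple(prompt[pos:pos + self.block_seq_len])
            child = node.children.get(hash(chunk))
            if child is not None and child.tokens != chunk:
                break                                   # collision: stop caching
            if child is None:
                child = TrieNode(tokens=chunk)
                node.children[hash(chunk)] = child
            child.block_id = block_ids[idx]             # refresh mapping, as cache() does
            node, pos, idx = child, pos + self.block_seq_len, idx + 1

For example, with block_seq_len=64, caching a finished request's prompt and block ids lets a later request whose first 128 tokens are identical get those two block ids back from match() and skip recomputing their k/v values.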