feat: support multi cards in ascend graph mode (#2755)
* support multi cards in ascend graph mode

* update warning info

* update warning info
tangzhiyi11 authored Nov 14, 2024
1 parent 59c1c63 commit 8e0076a
Showing 1 changed file with 12 additions and 16 deletions.
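Context for the change: with this commit, the PyTorch engine's graph mode on Ascend no longer falls back to eager mode when tensor parallelism is enabled. A minimal usage sketch follows, assuming lmdeploy's pipeline and PytorchEngineConfig API; the model path and tp=2 value are illustrative placeholders, not taken from the commit:

    from lmdeploy import pipeline, PytorchEngineConfig

    # Hypothetical example: model path and tp value are placeholders.
    # After this commit, graph mode (eager_mode=False, the default) is
    # no longer forced back to eager mode when tp > 1 on
    # device_type='ascend'.
    backend_config = PytorchEngineConfig(device_type='ascend', tp=2)
    pipe = pipeline('internlm/internlm2_5-7b-chat',
                    backend_config=backend_config)
    print(pipe(['Hello from Ascend graph mode on two cards']))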
lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py (12 additions, 16 deletions)
@@ -22,7 +22,6 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
         super().__init__(model, model_config, cache_config, backend_config,
                          device)
 
-        self.supported_model = ['Llama3-8B', 'Llama2-7B', 'Qwen2-7B']
         self.enable_graph = self.check_enable_graph()
         if self.enable_graph:
             import dlinfer.graph
@@ -39,26 +38,23 @@ def check_enable_graph(self):
         # eager_mode
         if self.backend_config.eager_mode:
             return False
-        # tp
-        if torch.distributed.is_initialized():
-            warnings.warn(
-                "Graph mode of device_type 'ascend' only supports tp=1 "
-                'for now, fallback to eager mode', RuntimeWarning)
-            return False
 
         warnings.warn(
             '\n\n'
-            '**********************************************************\n'
-            ' The following models were tested in graph mode of\n'
-            " device_type 'ascend' when tp=1:\n"
-            f" {', '.join(self.supported_model)}\n"
-            ' Other LLaMa-like models may work in graph mode, please\n'
-            ' check the result yourself!\n'
-            ' If graph mode does not work correctly with your model,\n'
-            ' please use eager mode instead.\n'
-            '**********************************************************\n\n',
+            '************************************************************\n'
+            ' Graph mode is an experimental feature. We currently\n'
+            ' support both dense and Mixture of Experts (MoE) models\n'
+            ' with bf16 and fp16 data types.\n'
+            ' If graph mode does not function correctly with your model,\n'
+            ' please consider using eager mode as an alternative.\n'
+            '************************************************************\n\n',
             RuntimeWarning)
 
+        # tp
+        if torch.distributed.is_initialized():
+            torch._inductor.config.compile_threads = 1
+            return True
+
         return True
 
     def patch_kernels_custom_op(self):
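Condensed, the post-commit check_enable_graph logic reads as below. This is a sketch of the diff above, and the comment on why compile_threads is set to 1 is an assumption, since the commit does not state the rationale:

    import warnings

    import torch


    def check_enable_graph(self):
        # Respect an explicit request for eager mode.
        if self.backend_config.eager_mode:
            return False

        # Experimental-feature notice (full banner shown in the diff).
        warnings.warn('Graph mode is an experimental feature.',
                      RuntimeWarning)

        # Multi-card case: torch.distributed is initialized when tp > 1.
        # Assumption: a single inductor compile thread per rank avoids
        # parallel compile workers clashing across distributed processes.
        if torch.distributed.is_initialized():
            torch._inductor.config.compile_threads = 1
            return True

        return True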