Support pytorch engine kv int4/int8 quantization #2438

Merged on Oct 14, 2024 (18 commits)

Conversation

AllentDan
Collaborator

Only the internlm and llama models are updated for now. After #2104, all the models should be updated.

Conflicts:
	lmdeploy/pytorch/config.py
	lmdeploy/pytorch/engine/cache_engine.py
	lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
	lmdeploy/pytorch/kernels/cuda/pagedattention.py
	lmdeploy/pytorch/models/internlm2.py
	lmdeploy/pytorch/models/llama.py
@AllentDan
Collaborator Author

AllentDan commented Sep 13, 2024

Benchmark results from benchmark/profile_throughput.py on Meta-Llama-3.1-8B-Instruct, using the pytorch engine with triton 2.3.0:

| quant-policy | fp16 | int8 | int4 |
|---|---|---|---|
| TTFT (s) | 2.354 | 2.035 | 2.569 |
| RPS (req/s) | 16.767 | 19.179 | 16.810 |
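
To make the comparison easier to read, the table can be turned into relative changes against the fp16 baseline. This is only a small arithmetic sketch over the numbers reported above; the helper names `relative_change` and `compare_to_fp16` are made up for illustration and are not part of lmdeploy.

```python
# Figures copied from the benchmark table above (Meta-Llama-3.1-8B-Instruct).
results = {
    "fp16": {"ttft_s": 2.354, "rps": 16.767},
    "int8": {"ttft_s": 2.035, "rps": 19.179},
    "int4": {"ttft_s": 2.569, "rps": 16.810},
}


def relative_change(baseline: float, value: float) -> float:
    """Return the change of `value` relative to `baseline`, in percent."""
    return (value - baseline) / baseline * 100.0


def compare_to_fp16(table: dict) -> dict:
    """Compute per-metric percentage changes of each policy against fp16."""
    base = table["fp16"]
    return {
        name: {metric: relative_change(base[metric], row[metric]) for metric in row}
        for name, row in table.items()
        if name != "fp16"
    }


if __name__ == "__main__":
    for name, changes in compare_to_fp16(results).items():
        print(f"{name}: TTFT {changes['ttft_s']:+.1f}%, RPS {changes['rps']:+.1f}%")
```

By this calculation, int8 gives roughly 14% more throughput and roughly 14% lower TTFT than fp16 in this run, while int4 throughput is essentially the same as fp16 and its TTFT is about 9% higher.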

Tested gsm8k accuracy:

| dataset | version | metric | mode | llama3-chat-8b | llama3-chat-8b-kv8 | llama3-chat-8b-kv4 |
|---|---|---|---|---|---|---|
| gsm8k | 1d7fe4 | accuracy | gen | 77.41 | 77.56 | 73.24 |

@lvhan028
Collaborator

lvhan028 commented Oct 7, 2024

Could you please resolve the conflicts?

@lvhan028 lvhan028 requested review from lvhan028 and grimoire October 7, 2024 14:00
@lvhan028 lvhan028 added the enhancement New feature or request label Oct 7, 2024
Conflicts:
	lmdeploy/pytorch/config.py
	lmdeploy/pytorch/kernels/cuda/pagedattention.py
@grimoire
Collaborator

grimoire commented Oct 8, 2024

Since kv int4 requires triton>=2.3.0, it would be good to add a check in the engine: https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/pytorch/check_env/__init__.py
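
A check along these lines could look like the sketch below. It is not the code that was added to `check_env`; the names `parse_version` and `check_kv_int4_env` are hypothetical, and the installed triton version is passed in as a string so the example has no dependency on triton itself. The only facts taken from the thread are the `>=2.3.0` requirement and the use of `quant_policy=4` for kv int4.

```python
import re

# From the review comment above: kv int4 requires triton>=2.3.0.
MIN_TRITON_FOR_KV_INT4 = (2, 3, 0)


def parse_version(version: str) -> tuple:
    """Parse up to three leading numeric components, e.g. '2.3.0+cu121' -> (2, 3, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    while len(parts) < 3:
        parts.append(0)  # treat '3.0' as (3, 0, 0)
    return tuple(parts)


def check_kv_int4_env(quant_policy: int, triton_version: str) -> None:
    """Raise early if kv int4 is requested with a triton that is too old.

    Hypothetical helper: quant_policy=4 is taken to mean kv int4, as in this thread.
    """
    if quant_policy != 4:
        return  # fp16 and kv int8 are not covered by this requirement
    if parse_version(triton_version) < MIN_TRITON_FOR_KV_INT4:
        required = ".".join(str(n) for n in MIN_TRITON_FOR_KV_INT4)
        raise RuntimeError(
            f"kv int4 (quant_policy=4) requires triton>={required}, "
            f"but found triton=={triton_version}"
        )
```

In a real environment the version string would come from something like `triton.__version__`; failing at startup gives users a clearer message than a kernel compilation error later on.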

@grimoire
Collaborator

grimoire commented Oct 8, 2024

quant_policy might not be a good argument name, since users might mistake it for online quantization (GEMM).

@AllentDan
Collaborator Author

> quant_policy might not be a good argument name, since users might mistake it for online quantization (GEMM).

Yes, but it is the same argument name that turbomind uses.

@zhulinJulia24
Collaborator

@AllentDan Can you update the supported models list? I will add test cases based on it. https://github.com/InternLM/lmdeploy/blob/main/docs/en/supported_models/supported_models.md

@AllentDan
Collaborator Author

I did not test all the models, since some models may fail when quant_policy=4. Among the models I tested, InternLM/internlm2-chat-1_8b, baichuan2/Baichuan2-13B-Chat, and Meta-Llama-3.1-8B-Instruct worked, while Qwen/Qwen2-1.5B-Instruct failed.

@zhulinJulia24
Collaborator

zhulinJulia24 commented Oct 10, 2024

All models supported by the pytorch backend, including the 4-bit models, have been tested. The following errors were found.

  1. deepseek-ai/DeepSeek-V2-Lite-Chat is not supported with either kvint4 or kvint8.
    The config and error are:
pipe = pipeline("/nvme/qa_test_models/deepseek-ai/DeepSeek-V2-Lite-Chat",  backend_config=engine_config)
res = pipe("Hi, pls introduce shanghai")
2024-10-10 17:15:27,701 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error: 
Traceback (most recent call last):
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 947, in async_loop
    await self._async_loop()
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 941, in _async_loop
    await __step()
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 929, in __step
    raise e
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 923, in __step
    raise out
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 857, in _async_loop_background
    await self._async_step_background(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 739, in _async_step_background
    output = await self._async_model_forward(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/utils.py", line 239, in __tmp
    return (await func(*args, **kwargs))
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
    ret = await __forward(inputs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
    return await self.model_agent.async_forward(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 303, in async_forward
    output = self._forward_impl(inputs,
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 270, in _forward_impl
    output = model_forward(
  File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 153, in model_forward
    output = model(**input_dict)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
    runner.capture(**kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
    self.model(**padded_kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 636, in forward
    hidden_states = self.model(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 591, in forward
    hidden_states, residual = decoder_layer(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 491, in forward
    hidden_states = self.self_attn(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 265, in forward
    attn_output = self.attn_fwd(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/nn/attention.py", line 67, in forward
    return self.impl.forward(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/attention.py", line 109, in forward
    self.paged_attention_fwd(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 1064, in paged_attention_fwd
    assert Lq == Lk * 2 and Lv * 2 == o.shape[-1]
AssertionError
  2. Some models produce nonsensical responses with kvint4 (screenshot omitted).
    The model list is:
  • microsoft/Phi-3-mini-4k-instruct-inner-4bits
  • microsoft/Phi-3-mini-4k-instruct
  • microsoft/Phi-3-vision-128k-instruct
  • OpenGVLab/InternVL2-4B
  • Qwen/Qwen2-VL-2B-Instruct
  • Qwen/Qwen2-VL-7B-Instruct
  • openbmb/MiniCPM-V-2_6
  3. [already fixed] Qwen/Qwen2-VL-2B-Instruct and Qwen/Qwen2-VL-7B-Instruct are not supported. The config and error are:
engine_config = PytorchEngineConfig(dtype='auto', tp=1, session_len=None, max_batch_size=None, cache_max_entry_count=0.8, prefill_interval=16, block_size=64, num_cpu_blocks=0, num_gpu_blocks=0, adapters=None, max_prefill_token_num=4096, thread_safe=False, enable_prefix_caching=False, device_type='cuda', eager_mode=False, custom_module_map=None, download_dir=None, revision=None, quant_policy=8)
pipe = pipeline("/nvme/qa_test_models/Qwen/Qwen2-VL-2B-Instruct",  backend_config=engine_config)
res = pipe("Hi, pls introduce shanghai")
2024-10-10 17:17:41,974 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error: 'NoneType' object has no attribute 'stride'
Traceback (most recent call last):
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 947, in async_loop
    await self._async_loop()
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 941, in _async_loop
    await __step()
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 929, in __step
    raise e
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 923, in __step
    raise out
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 857, in _async_loop_background
    await self._async_step_background(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 739, in _async_step_background
    output = await self._async_model_forward(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/utils.py", line 239, in __tmp
    return (await func(*args, **kwargs))
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
    ret = await __forward(inputs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
    return await self.model_agent.async_forward(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 303, in async_forward
    output = self._forward_impl(inputs,
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 270, in _forward_impl
    output = model_forward(
  File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 153, in model_forward
    output = model(**input_dict)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
    runner.capture(**kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
    self.model(**padded_kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 379, in forward
    hidden_states = self.model(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 318, in forward
    hidden_states, residual = decoder_layer(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 226, in forward
    hidden_states = self.self_attn(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 121, in forward
    attn_output = self.attn_fwd(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/nn/attention.py", line 67, in forward
    return self.impl.forward(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/attention.py", line 86, in forward
    self.fill_kv_cache(
  File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py", line 461, in fill_kv_cache
    stride_kszn=k_scales_zeros.stride(0),
AttributeError: 'NoneType' object has no attribute 'stride'
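
For readers unfamiliar with the two failures above: the assertion `Lq == Lk * 2` in the first traceback and the `k_scales_zeros` argument in the second suggest that the int4 cache stores two 4-bit values per byte (so the cached head dimension is half the query head dimension) and that each quantized vector carries its own scale and zero point. The sketch below illustrates that idea in plain Python. It is an assumption-laden illustration, not the Triton kernels from this PR: the function names, the per-vector granularity, and the byte layout (adjacent pairs, low nibble first) are all hypothetical.

```python
def quantize_asymmetric(values, nbits):
    """Quantize floats to unsigned `nbits` integers with one scale and zero point."""
    qmax = (1 << nbits) - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax or 1.0  # fall back to 1.0 for a constant vector
    quantized = [min(qmax, max(0, round((v - lo) / scale))) for v in values]
    return quantized, scale, lo  # `lo` plays the role of the zero point


def dequantize(quantized, scale, zero):
    """Map quantized integers back to approximate float values."""
    return [q * scale + zero for q in quantized]


def pack_int4(quantized):
    """Pack adjacent pairs of 4-bit values into one byte each (low nibble first)."""
    assert len(quantized) % 2 == 0, "need an even number of 4-bit values"
    return [
        quantized[i] | (quantized[i + 1] << 4)
        for i in range(0, len(quantized), 2)
    ]


def unpack_int4(packed):
    """Inverse of pack_int4: split every byte back into two 4-bit values."""
    unpacked = []
    for byte in packed:
        unpacked.append(byte & 0x0F)
        unpacked.append(byte >> 4)
    return unpacked


if __name__ == "__main__":
    key = [0.12, -0.53, 0.98, 0.07, -1.25, 0.44, 0.31, -0.09]  # toy head_dim = 8
    q, scale, zero = quantize_asymmetric(key, nbits=4)
    packed = pack_int4(q)
    print(len(key), len(packed))  # the packed vector is half as long
    print(dequantize(unpack_int4(packed), scale, zero))
```

Under these assumptions, a model whose key and value head dimensions do not follow the expected halved shape would trip the shape assertion, and a model whose attention layer never allocates the scale and zero-point tensors would pass `None` where `k_scales_zeros` is expected.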

@zhulinJulia24
Collaborator

@AllentDan qwen2-vl-2b and 7b pass on kvint8.

@lvhan028 lvhan028 merged commit 4126067 into InternLM:main Oct 14, 2024
5 checks passed
4 participants