Skip to content

Commit

Permalink
Fix vl session-len (#1860)
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan authored Jun 26, 2024
1 parent b23ba4b commit 9a20a21
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion lmdeploy/archs.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def get_model_arch(model_path: str):
trust_remote_code=True)
except Exception as e: # noqa
from transformers import PretrainedConfig
cfg = PretrainedConfig.from_pretrained(model_path)
cfg = PretrainedConfig.from_pretrained(model_path,
trust_remote_code=True)

_cfg = cfg.to_dict()
if _cfg.get('architectures', None):
Expand Down
7 changes: 6 additions & 1 deletion lmdeploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def __tmp():
return __inner


# copy from https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/config.py#L1082-L1150 # noqa
# modified from https://github.com/vllm-project/vllm/blob/0650e5935b0f6af35fb2acf71769982c47b804d7/vllm/config.py#L1082-L1150 # noqa
def _get_and_verify_max_len(
hf_tm_config: Union[PretrainedConfig,
TypeVar('TurbomindModelConfig')],
Expand All @@ -276,6 +276,11 @@ def _get_and_verify_max_len(
session_len = getattr(hf_tm_config, 'session_len')
return max_model_len if max_model_len else session_len

# VL model configs nest session_len inside an inner LLM config
llm_keys = ['language_config', 'llm_config']
for key in llm_keys:
hf_tm_config = getattr(hf_tm_config, key, hf_tm_config)

logger = get_logger('lmdeploy')
derived_max_model_len = float('inf')
possible_keys = [
Expand Down
7 changes: 7 additions & 0 deletions tests/test_lmdeploy/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@


def test_get_and_verify_max_len():
# with PretrainedConfig of a VL model (session_len nested in inner llm_config)
config = AutoConfig.from_pretrained('OpenGVLab/InternVL-Chat-V1-5-AWQ',
trust_remote_code=True)
assert (_get_and_verify_max_len(config, None) == 98304)
assert (_get_and_verify_max_len(config, 1024) == 1024)
assert (_get_and_verify_max_len(config, 102400) == 102400)

# with PretrainedConfig
config = AutoConfig.from_pretrained('internlm/internlm2-chat-7b',
trust_remote_code=True)
Expand Down

0 comments on commit 9a20a21

Please sign in to comment.