Skip to content

Commit

Permalink
fix torch_dtype (#2933)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon authored Dec 23, 2024
1 parent 182d1c8 commit 92475b0
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
return config

torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
# deal with case when torch_dtype is not string but torch.dtype
if isinstance(torch_dtype, torch.dtype):
torch_dtype = str(torch_dtype).split('.')[1]

if torch_dtype is None:
_dtype = 'float16' if dtype == 'auto' else dtype
logger.warning('Model config does not have `torch_dtype`,'
Expand Down

0 comments on commit 92475b0

Please sign in to comment.