Skip to content

Commit

Permalink
fix(mixtral): setup hack atm to load weights from pt specifically
Browse files Browse the repository at this point in the history
instead of safetensors

Signed-off-by: Aaron <[email protected]>
  • Loading branch information
aarnphm committed Dec 13, 2023
1 parent 2dbcfa8 commit 22d7144
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions openllm-python/src/openllm/serialisation/transformers/weights.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
import attr, traceback, pathlib, typing as t
import attr, traceback, functools, pathlib, typing as t
from huggingface_hub import HfApi
from openllm_core.exceptions import Error
from openllm_core.utils import resolve_filepath, validate_is_path
Expand Down Expand Up @@ -27,10 +27,13 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
traceback.print_exc()
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err

def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool:
if validate_is_path(model_id):
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob('*.safetensors')), False)
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False)
return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings)

has_safetensors_weights = functools.partial(has_weights, extensions='safetensors')
has_pt_weights = functools.partial(has_weights, extensions='pt')

@attr.define(slots=True)
class HfIgnore:
Expand All @@ -43,8 +46,10 @@ class HfIgnore:
def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]:
if llm.__llm_backend__ in {'vllm', 'pt'}:
base = [cls.tf, cls.flax, cls.gguf]
if has_safetensors_weights(llm.model_id):
base.append(cls.pt)
if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm
base.append(cls.safetensors)
elif has_safetensors_weights(llm.model_id):
base.extend([cls.pt, '*.pt'])
else:
base.append(cls.safetensors)
elif llm.__llm_backend__ == 'ggml':
Expand Down

0 comments on commit 22d7144

Please sign in to comment.