From b5424c622bbea3fb949764d8a711dc6c44db8ada Mon Sep 17 00:00:00 2001
From: Plutonium-239
Date: Mon, 6 May 2024 00:30:09 +0530
Subject: [PATCH] fix

---
 experiments/util/models.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/experiments/util/models.py b/experiments/util/models.py
index 3b20504..fb60062 100644
--- a/experiments/util/models.py
+++ b/experiments/util/models.py
@@ -3,7 +3,7 @@
 import itertools
 import math
 from functools import partial
-from typing import Any, List, Tuple, Optional
+from typing import Any, List, Tuple, Optional, Dict
 
 import torchvision.models as tvm
 from torch.nn import (
@@ -309,7 +309,7 @@ class _HF_model:
     def __init__(
         self,
         hf_name: str,
-        extra_kwargs: Optional[dict] = None,
+        extra_kwargs: Dict[str, Any],
         model_cls: Any = AutoModelForCausalLM,
     ) -> None:
         self.hf_name = hf_name
@@ -339,8 +339,8 @@ def __init__(
     "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()),
 }
 
-fused = lambda cls, name, kwargs: convert_to_memory_saving(  # noqa: E731
-    cls.from_pretrained(name, **kwargs)
+fused = lambda fn, name, kwargs: convert_to_memory_saving(  # noqa: E731
+    fn(name, **kwargs)
 )
 
 for m in hf_transformers_models:
@@ -355,7 +355,7 @@ def __init__(
     else:
         props = hf_transformers_models_map[m.split("memsave_", 1)[1]]
         transformer_model_fns[m] = partial(
-            fused, props.model_cls, props.hf_name, **props.extra_kwargs
+            fused, props.model_cls.from_pretrained, props.hf_name, props.extra_kwargs
         )
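
Context for the last two hunks: before this patch, `partial(fused, props.model_cls, props.hf_name, **props.extra_kwargs)` unpacked `extra_kwargs` into `partial`'s own keyword arguments, so the eventual call passed keywords that `fused`'s three positional parameters (`cls`, `name`, `kwargs`) do not declare, raising a TypeError; for models without extra kwargs, `**None` failed outright, which is also why the annotation becomes a required `Dict[str, Any]`. Passing the bound `props.model_cls.from_pretrained` instead of the class additionally frees `fused` from hard-coding `.from_pretrained`. A minimal, self-contained sketch of the failure mode; the stand-in lambdas and the `torch_dtype` kwarg below are illustrative, not taken from the repo:

    from functools import partial

    # Stand-ins for convert_to_memory_saving and HF's from_pretrained
    # (hypothetical; the real ones live in the repo and in transformers).
    convert_to_memory_saving = lambda model: f"memsave({model})"
    fake_from_pretrained = lambda name, **kwargs: f"{name}{kwargs}"

    fused = lambda fn, name, kwargs: convert_to_memory_saving(fn(name, **kwargs))
    extra_kwargs = {"torch_dtype": "bfloat16"}  # hypothetical extra kwargs

    # Pre-patch pattern: the dict is unpacked into partial itself, so fused
    # is later called with torch_dtype=..., a keyword it does not declare.
    broken = partial(fused, fake_from_pretrained, "gpt2", **extra_kwargs)
    try:
        broken()
    except TypeError as e:
        print(e)  # <lambda>() got an unexpected keyword argument 'torch_dtype'

    # Post-patch pattern: the dict is bound positionally and unpacked only
    # inside fused, at the actual loader call site.
    fixed = partial(fused, fake_from_pretrained, "gpt2", extra_kwargs)
    print(fixed())  # memsave(gpt2{'torch_dtype': 'bfloat16'})

The same deferred-call shape also explains the renamed lambda parameter: `fused` now takes any loader callable `fn`, not a class, so the `partial` in the last hunk can bind whichever constructor a given `_HF_model` entry specifies.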