Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
plutonium-239 committed May 5, 2024
1 parent 52993fb commit b5424c6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import math
from functools import partial
from typing import Any, List, Tuple, Optional
from typing import Any, List, Tuple, Optional, Dict

Check failure on line 6 in experiments/util/models.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

experiments/util/models.py:6:38: F401 `typing.Optional` imported but unused

import torchvision.models as tvm
from torch.nn import (
Expand Down Expand Up @@ -309,7 +309,7 @@ class _HF_model:
def __init__(
self,
hf_name: str,
extra_kwargs: Optional[dict] = None,
extra_kwargs: Dict[str, Any],
model_cls: Any = AutoModelForCausalLM,
) -> None:
self.hf_name = hf_name
Expand Down Expand Up @@ -339,8 +339,8 @@ def __init__(
"memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()),
}

fused = lambda cls, name, kwargs: convert_to_memory_saving( # noqa: E731
cls.from_pretrained(name, **kwargs)
fused = lambda fn, name, kwargs: convert_to_memory_saving( # noqa: E731
fn(name, **kwargs)
)

for m in hf_transformers_models:
Expand All @@ -355,7 +355,7 @@ def __init__(
else:
props = hf_transformers_models_map[m.split("memsave_", 1)[1]]
transformer_model_fns[m] = partial(
fused, props.model_cls, props.hf_name, **props.extra_kwargs
fused, props.model_cls.from_pretrained, props.hf_name, props.extra_kwargs
)


Expand Down

0 comments on commit b5424c6

Please sign in to comment.