Skip to content

Commit

Permalink
format + prep for merge
Browse files: browse the repository at this point in the history
(cherry picked from commit c444edd)
  • Loading branch information
plutonium-239 committed Aug 22, 2024
1 parent b36baa6 commit 3d5ec71
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
*.txt
*.csv
!requirements.txt
torchviz-output/
torchview-output/

# generated docs
docs_src/_build/
Expand Down
2 changes: 1 addition & 1 deletion experiments/util/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def estimate_mem_savings(
is_vlm, vis_model, vis_model_arch, llm = args.model.split("!")
assert is_vlm == "vlm"
assert vis_model_arch in ["transformer", "conv"]
model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm)
model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm) # noqa: E731
config = models.get_transformers_config(llm)
vocab_dim = config.vocab_size
embed_dim = config.hidden_size
Expand Down
1 change: 1 addition & 0 deletions experiments/util/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Parameter,
)
from torchvision.models.convnext import LayerNorm2d
from transformers import Conv1D

Check failure on line 29 in experiments/util/measurements.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

experiments/util/measurements.py:29:26: F401 `transformers.Conv1D` imported but unused

from memsave_torch.nn.Conv2d import MemSaveConv2d
from memsave_torch.nn.Linear import MemSaveLinear
Expand Down
28 changes: 26 additions & 2 deletions experiments/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import itertools
import math
from typing import List, Tuple
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torchvision.models as tvm
from torch.nn import (
Expand Down Expand Up @@ -110,7 +112,18 @@ def get_transformers_config(model_name: str) -> AutoConfig:
return AutoConfig.from_pretrained(props.hf_name, **props.extra_kwargs)


def get_arch_models(arch: str):
def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
"""Get the dict of all defined functions for an architecture
Args:
arch (str): The architecture
Returns:
Tuple[Dict[str, Callable], Any]: Dict of all defined functions
Raises:
ValueError: Invalid architecture
"""
if arch == "conv":
return conv_model_fns, conv_input_shape
if arch == "transformer":
Expand Down Expand Up @@ -467,13 +480,16 @@ def forward(self, x):

# VLM
class VLM(Module):
"""Small wrapper for making a VLM model with transformer llm and conv/transformer vision model"""

def __init__(
self,
vision_model_name: str,
vision_model_arch: str,
llm_name: str,
nc: int = 1000,
) -> None:
"""Init"""
super().__init__()
self.vision_model_name = vision_model_name
self.vm_arch = vision_model_arch
Expand All @@ -488,6 +504,14 @@ def __init__(
self.patchify = Unfold(kernel_size=16, stride=16)

def forward(self, x):
"""Forward through vlm
Args:
x: x
Returns:
output: model output
"""
if self.vm_arch == "transformer" and self.vm.config.image_size != x.shape[-1]:
x = functional.interpolate(
x, size=self.vm.config.image_size, mode="bicubic"
Expand Down

0 comments on commit 3d5ec71

Please sign in to comment.