From 3d5ec71771d76fb93708d80f8de5bd968d404bc2 Mon Sep 17 00:00:00 2001
From: Plutonium-239
Date: Sun, 2 Jun 2024 02:40:10 +0530
Subject: [PATCH] format + prep for merge

(cherry picked from commit c444edd2c1532dd6b1cfb6120018a46e51666a25)
---
 .gitignore                       |  2 ++
 experiments/util/estimate.py     |  2 +-
 experiments/util/measurements.py |  1 +
 experiments/util/models.py       | 28 ++++++++++++++++++++++++++--
 4 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6ad7868..7f71819 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,8 @@
 *.txt
 *.csv
 !requirements.txt
+torchviz-output/
+torchview-output/
 
 # generated docs
 docs_src/_build/
diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py
index 4613c2c..51e78b2 100644
--- a/experiments/util/estimate.py
+++ b/experiments/util/estimate.py
@@ -305,7 +305,7 @@ def estimate_mem_savings(
     is_vlm, vis_model, vis_model_arch, llm = args.model.split("!")
     assert is_vlm == "vlm"
     assert vis_model_arch in ["transformer", "conv"]
-    model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm)
+    model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm)  # noqa: E731
     config = models.get_transformers_config(llm)
     vocab_dim = config.vocab_size
     embed_dim = config.hidden_size
diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py
index 63cfa16..7ee16bd 100644
--- a/experiments/util/measurements.py
+++ b/experiments/util/measurements.py
@@ -26,6 +26,7 @@
     Parameter,
 )
 from torchvision.models.convnext import LayerNorm2d
+from transformers import Conv1D
 
 from memsave_torch.nn.Conv2d import MemSaveConv2d
 from memsave_torch.nn.Linear import MemSaveLinear
diff --git a/experiments/util/models.py b/experiments/util/models.py
index ccda151..3ba0d37 100644
--- a/experiments/util/models.py
+++ b/experiments/util/models.py
@@ -2,7 +2,9 @@
 
 import itertools
 import math
-from typing import List, Tuple
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torchvision.models as tvm
 from torch.nn import (
@@ -110,7 +112,18 @@ def get_transformers_config(model_name: str) -> AutoConfig:
     return AutoConfig.from_pretrained(props.hf_name, **props.extra_kwargs)
 
 
-def get_arch_models(arch: str):
+def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]:
+    """Get the dict of all defined model functions for an architecture.
+
+    Args:
+        arch (str): The architecture
+
+    Returns:
+        Tuple[Dict[str, Callable], Any]: Dict of all defined model functions and the input shape
+
+    Raises:
+        ValueError: Invalid architecture
+    """
     if arch == "conv":
         return conv_model_fns, conv_input_shape
     if arch == "transformer":
@@ -467,6 +480,8 @@ def forward(self, x):
 
 # VLM
 class VLM(Module):
+    """Small wrapper for building a VLM from a transformer LLM and a conv/transformer vision model."""
+
     def __init__(
         self,
         vision_model_name: str,
@@ -474,6 +489,7 @@ def __init__(
         llm_name: str,
         nc: int = 1000,
     ) -> None:
+        """Initialize the VLM from a vision model (name and arch) and an LLM name."""
         super().__init__()
         self.vision_model_name = vision_model_name
         self.vm_arch = vision_model_arch
@@ -488,6 +504,14 @@ def __init__(
         self.patchify = Unfold(kernel_size=16, stride=16)
 
     def forward(self, x):
+        """Forward pass through the VLM.
+
+        Args:
+            x: Input batch of images
+
+        Returns:
+            output: Model output
+        """
         if self.vm_arch == "transformer" and self.vm.config.image_size != x.shape[-1]:
             x = functional.interpolate(
                 x, size=self.vm.config.image_size, mode="bicubic"
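
Reviewer note: a minimal sketch of how the pieces added here fit together, mirroring the "vlm!"-encoded model string handled in estimate.py. The model names "vit" and "gpt2", the 224x224 dummy batch, and running this outside the experiment harness are illustrative assumptions, not part of the patch:

import torch

from experiments.util import models

# "vlm!<vision_model>!<vision_arch>!<llm>" mirrors the split done in estimate.py
is_vlm, vis_model, vis_model_arch, llm = "vlm!vit!transformer!gpt2".split("!")
assert is_vlm == "vlm"
assert vis_model_arch in ["transformer", "conv"]

model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm)  # noqa: E731
model = model_fn()

x = torch.rand(2, 3, 224, 224)  # dummy image batch; shape is an assumption
out = model(x)  # forward() resizes x to the vision model's image_size when they differ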