diff --git a/.gitignore b/.gitignore index 6ad7868..7f71819 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ *.txt *.csv !requirements.txt +torchviz-output/ +torchview-output/ # generated docs docs_src/_build/ diff --git a/experiments/get_best_results.py b/experiments/get_best_results.py index 1321679..2146a4f 100644 --- a/experiments/get_best_results.py +++ b/experiments/get_best_results.py @@ -7,7 +7,7 @@ import pandas as pd -from experiments.util.collect_results import case_mapping +from experiments.util.collect_results import case_inv_mapping def main(base_dir: str): @@ -16,17 +16,25 @@ def main(base_dir: str): Args: base_dir (str): The base results dir """ - for device, arch in product(["cuda", "cpu"], ["linear", "conv"]): + # Don't recognize None as NaN + custom_na_values = pd._libs.parsers.STR_NA_VALUES - {"None"} + for device, arch in product(["cuda", "cpu"], ["linear", "conv", "transformer"]): # usage stats df = None idx_col = ["model", "case"] for fname in glob(os.path.join(base_dir, f"usage_stats-{arch}-{device}-*.csv")): with open(fname) as f: - f.readline() - temp_df = pd.read_csv(f, index_col=idx_col) + # f.readline() + temp_df = pd.read_csv( + f, + index_col=idx_col, + header=1, + na_values=custom_na_values, + keep_default_na=False, + ) df = temp_df if df is None else pd.concat([df, temp_df]) if df is not None: - df = df.rename(index=case_mapping, level=1) + df = df.rename(index=case_inv_mapping, level=1) df["Memory Usage (GB)"] = df["Memory Usage (MB)"] / 1024 df = df.drop(columns=["Memory Usage (MB)"]) best_results = df.groupby(idx_col).min() diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py index 27b9b87..ddc375e 100644 --- a/experiments/paper_demo.py +++ b/experiments/paper_demo.py @@ -14,7 +14,7 @@ from experiments.util.models import prefix_in_pairs estimators = ["time", "memory"] -estimators = ["memory"] +# estimators = ["memory"] # estimators = ["time"] # improvements can be either speedups or savings based on context @@ -28,35 +28,55 @@ # repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`) n_repeat = 5 -# CONV +# ============== CONV CONFIG ============== # Valid choices for models are in models.conv_model_fns +# models = [ +# "deepmodel", +# "resnet101", +# "resnet18", +# "vgg16", # "convnext_base", +# "fasterrcnn_resnet50_fpn_v2", +# "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2", +# "deeplabv3_resnet101", +# "fcn_resnet101", +# "efficientnet_v2_l", +# "mobilenet_v3_large", +# "resnext101_64x4d", +# ] +# models = prefix_in_pairs("memsave_", models) +# batch_size = 64 +# input_channels = 3 +# input_HW = 224 +# num_classes = 1000 +# device = "cuda" +# architecture = "conv" +# cases = collect_results.select_cases(['All', 'Input', 'Conv', 'Norm']) + +# ============== TRANSFORMER CONFIG ============== +# Valid choices for models are in models.transformer_model_fns models = [ - "deepmodel", - "resnet101", - "resnet18", - "vgg16", # "convnext_base", - "fasterrcnn_resnet50_fpn_v2", - "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2", - "deeplabv3_resnet101", - "fcn_resnet101", - "efficientnet_v2_l", - "mobilenet_v3_large", - "resnext101_64x4d", + "transformer", + "gpt2", + "bert", + "bart", + "roberta", + "t5", + "flan-t5", + # "xlm-roberta", + "mistral-7b", + "llama3-8b", + "phi3-4b", ] - -# models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"] -# models = ["resnet101", "memsave_resnet101_conv_full"] - models = prefix_in_pairs("memsave_", models) -# models = ["memsave_resnet101"] batch_size = 64 -input_channels = 3 -input_HW = 224 -num_classes = 1000 +input_channels = 2048 +input_HW = 256 +num_classes = 5000 device = "cuda" -architecture = "conv" +architecture = "transformer" +cases = collect_results.select_cases(["All", "Input", "Norm"]) -# LINEAR +# ============== LINEAR CONFIG ============== # Valid choices for models are in models.linear_model_fns # models = ['deeplinearmodel'] # models += [f"memsave_{m}" for m in models] # add memsave versions for each model @@ -66,31 +86,7 @@ # num_classes = 1000 # device = 'cuda' # architecture = 'linear' # use high batch size - -cases = [ - None, # ALL - [ # INPUT - "grad_input", - "no_grad_conv_weights", - "no_grad_conv_bias", - "no_grad_linear_weights", - "no_grad_linear_bias", - "no_grad_norm_weights", - "no_grad_norm_bias", - ], - [ # CONV - "no_grad_linear_weights", - "no_grad_linear_bias", - "no_grad_norm_weights", - "no_grad_norm_bias", - ], - [ # NORM - "no_grad_conv_weights", - "no_grad_conv_bias", - "no_grad_linear_weights", - "no_grad_linear_bias", - ], -] +# cases = collect_results.select_cases(['All', 'Input', 'Linear']) if __name__ == "__main__": @@ -110,17 +106,27 @@ ) for model in models: + B = batch_size + if model in prefix_in_pairs("memsave_", ["flan-t5"]): + B = 56 + if model in prefix_in_pairs("memsave_", ["mistral-7b", "phi3-4b"]): + B = 16 + if model in prefix_in_pairs("memsave_", ["llama3-8b"]): + B = 8 for estimate in estimators: outputs = [] collector.clear_file(estimate) for case in cases: pbar.update() - pbar.set_description(f"{model} {estimate} case {case}") + case_display = collect_results.case_inv_mapping[ + collect_results.make_case_str(case) + ] case_str = f"--case {' '.join(case)}" if case is not None else "" + pbar.set_description(f"{model} {estimate} case {case_display}") cmd = ( f"python experiments/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} " - + f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}" + + f"--device {device} -B {B} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}" ) proc = subprocess.run(shlex.split(cmd), capture_output=True) assert proc.stderr in [ diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py index 393fe0a..d6a918b 100644 --- a/experiments/util/collect_results.py +++ b/experiments/util/collect_results.py @@ -27,14 +27,50 @@ ], } -case_mapping = { - "None": "All", - "grad_input + no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Input", - "no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv", - "no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm", +cases = { + "All": None, # ALL + "Input": [ # INPUT + "grad_input", + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_linear_weights", + "no_grad_linear_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Conv": [ # CONV + "no_grad_linear_weights", + "no_grad_linear_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Linear": [ # LINEAR + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Norm": [ # NORM + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_linear_weights", + "no_grad_linear_bias", + ], } +def select_cases(selected: List[str]) -> List[Union[List[str], None]]: + """Helper function to return cases selected by their names + + Args: + selected (List[str]): Which cases to select, strings can be keys of the cases table + + Returns: + List[Union[List[str], None]]: Selected cases + """ + return [cases[s] for s in selected] + + def make_case_str(case: Union[None, List[str]]) -> str: """Format case into a string @@ -47,6 +83,9 @@ def make_case_str(case: Union[None, List[str]]) -> str: return "None" if case is None else " + ".join(case) +case_inv_mapping = {make_case_str(v): k for k, v in cases.items()} + + def hyperparam_str(args: SimpleNamespace) -> str: """Format hyperparams into a string @@ -172,12 +211,15 @@ def _display_run( """ # print(f"{model} input ({input_channels},{input_HW},{input_HW}) {device}") # print('='*78) - s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}" + if self.architecture == "conv": + s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}" + elif self.architecture == "transformer": + s = f"{model} input ({self.batch_size},{self.input_HW},{self.input_channels}(or model hidden size)) {self.device}" print(s.center(78, "=")) for out, case in zip(outputs, self.cases): print( - f"{strings[estimate][1]} ({case_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}" + f"{strings[estimate][1]} ({case_inv_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}" ) # CODE ONLY APPLIES WITH OLD RUNDEMO.PY diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py index 8486931..8df7417 100644 --- a/experiments/util/estimate.py +++ b/experiments/util/estimate.py @@ -38,6 +38,8 @@ "no_grad_norm_bias", "grad_input", "no_grad_input", + "grad_embed_weights", + "no_grad_embed_weights", ] @@ -62,7 +64,10 @@ def parse_case(case: Optional[List[str]]) -> Dict[str, bool]: def skip_case_check(args: argparse.Namespace) -> bool: - """Decide whether to skip the case (when case has grad_norm_* but model does not have any normalization layers) + """Decide whether to skip the case: + + 1. when case has grad_norm_* but model does not have any normalization layers + 2. when case has no_grad_embed_weights but no grad_input: there is a backward error (no input requires_grad) Args: args (argparse.Namespace): args @@ -73,12 +78,16 @@ def skip_case_check(args: argparse.Namespace) -> bool: invalid = False if args.case is None: return invalid + # 1. for c in ["grad_norm_bias", "grad_norm_weights"]: if c in args.case and args.model in models.models_without_norm: invalid = True for c in ["no_grad_norm_bias", "no_grad_norm_weights"]: if c not in args.case and args.model in models.models_without_norm: invalid = True + # 2. + if "no_grad_embed_weights" in args.case and "grad_input" not in args.case: + invalid = True if invalid: if args.print: print("-1") @@ -226,7 +235,7 @@ def estimate_mem_savings( type=str, required=True, help="Which architecture to run", - choices=["conv", "linear"], + choices=["conv", "linear", "transformer", "VLM"], ) parser.add_argument( "--estimate", @@ -269,6 +278,7 @@ def estimate_mem_savings( input_shape = (args.input_channels, args.input_hw, args.input_hw) models.conv_input_shape = input_shape model_fn = models.conv_model_fns.get(args.model) + y_args = {"size": (batch_size,), "low": 0, "high": num_classes} assert ( model_fn is not None ), f"Conv model name {args.model} not found, must be one of {list(models.conv_model_fns.keys())}" @@ -276,16 +286,51 @@ def estimate_mem_savings( input_shape = [args.input_hw**2] models.linear_input_shape = input_shape[0] model_fn = models.linear_model_fns.get(args.model) + y_args = {"size": (batch_size,), "low": 0, "high": num_classes} assert ( model_fn is not None ), f"Linear model name {args.model} not found, must be one of {list(models.linear_model_fns.keys())}" + elif args.architecture == "transformer": + vocab_dim = args.num_classes + embed_dim = args.input_channels + seq_len = args.input_hw + model_fn = models.transformer_model_fns.get(args.model) + if args.model in models.hf_transformers_models: + model_fn_orig = model_fn + model_fn = lambda: models.TransformersModelWrapper( # noqa: E731 + model_fn_orig, args.model + ) + config = models.get_transformers_config(args.model) + # as per transformers.PretrainedConfig these 2 should be present in all models: + vocab_dim = config.vocab_size + embed_dim = config.hidden_size + models.transformer_input_shape = (vocab_dim, embed_dim) + input_shape = [seq_len, embed_dim] + y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim} + assert ( + model_fn is not None + ), f"Transformer model name {args.model} not found, must be one of {list(models.transformer_model_fns.keys())}" + elif args.architecture == "VLM": + # model format: `vlm!!!` + # eg: `vlm!vit!transformer!memsave_gpt2` + is_vlm, vis_model, vis_model_arch, llm = args.model.split("!") + assert is_vlm == "vlm" + assert vis_model_arch in ["transformer", "conv"] + model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm) # noqa: E731 + config = models.get_transformers_config(llm) + vocab_dim = config.vocab_size + embed_dim = config.hidden_size + seq_len = (args.input_hw // 16) ** 2 + y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim} + input_shape = (args.input_channels, args.input_hw, args.input_hw) + models.conv_input_shape = input_shape loss_fn = CrossEntropyLoss manual_seed(0) # make deterministic x = rand(batch_size, *input_shape, device=dev) - y = randint(size=(batch_size,), low=0, high=num_classes, device=dev) + y = randint(**y_args, device=dev) targets = None if args.model in models.detection_models: # pred is a dictionary of losses diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py index 63cfa16..321843d 100644 --- a/experiments/util/measurements.py +++ b/experiments/util/measurements.py @@ -20,12 +20,14 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Embedding, LayerNorm, Linear, Module, Parameter, ) from torchvision.models.convnext import LayerNorm2d +from transformers import Conv1D from memsave_torch.nn.Conv2d import MemSaveConv2d from memsave_torch.nn.Linear import MemSaveLinear @@ -116,28 +118,33 @@ def forward_backward( grad_norm_weights: bool = True, grad_norm_bias: bool = True, grad_input: bool = False, + grad_embed_weights: bool = False, ) -> float: """Perform a forward and backward pass and return the run time. Syncs CUDA threads if the device is a GPU. + Note: We directly pass input embeddings to transformers so embed weights are never used and their + grad will be None. Args: - grad_linear_weights: Whether to compute the gradient of the linear + grad_linear_weights (bool, optional): Whether to compute the gradient of the linear layer weights. Default: `True`. - grad_linear_bias: Whether to compute the gradient of the linear + grad_linear_bias (bool, optional): Whether to compute the gradient of the linear layer bias. Default: `True`. - grad_conv_weights: Whether to compute the gradient of the convolution + grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution layer weights. Default: `True`. - grad_conv_bias: Whether to compute the gradient of the convolution + grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution layer bias. Default: `True`. - grad_norm_weights: Whether to compute the gradient of the normalization + grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization layer weights. Default: `True`. - grad_norm_bias: Whether to compute the gradient of the normalization + grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization layer bias. Default: `True`. - grad_input: Whether to compute the gradient of the input. Default: `False`. + grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`. + grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding + layer weights. Default: `True`. Returns: - The run time in seconds. + float: The run time in seconds. """ model, loss_fn, x, y, targets = self.set_up() @@ -149,6 +156,7 @@ def forward_backward( grad_conv_bias, grad_norm_weights, grad_norm_bias, + grad_embed_weights, ) leafs = ([x] if grad_input else []) + leafs no_leafs = ([y] if grad_input else [x, y]) + no_leafs @@ -191,9 +199,13 @@ def after_forward( grad_norm_weights: bool = True, grad_norm_bias: bool = True, grad_input: bool = False, + grad_embed_weights: bool = False, ) -> float: """Return memory usage after a forward pass. + Note: We directly pass input embeddings to transformers so embed weights are never used and their + grad will be None. + Args: grad_linear_weights: Whether to compute the gradient of the linear layer weights. Default: `True`. @@ -208,9 +220,11 @@ def after_forward( grad_norm_bias: Whether to compute the gradient of the normalization layer bias. Default: `True`. grad_input: Whether to compute the gradient of the input. Default: `False`. + grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding + layer weights. Default: `True`. Returns: - The memory usage in bytes. + float: The memory usage in bytes. """ model, loss_fn, x, y, targets = self.set_up() @@ -222,6 +236,7 @@ def after_forward( grad_conv_bias, grad_norm_weights, grad_norm_bias, + grad_embed_weights, ) leafs = ([x] if grad_input else []) + leafs no_leafs = ([y] if grad_input else [x, y]) + no_leafs @@ -289,6 +304,7 @@ def separate_grad_arguments( grad_conv_bias: bool, grad_norm_weights: bool, grad_norm_bias: bool, + grad_embed_weights: bool, ) -> Tuple[List[Parameter], List[Parameter]]: """Separate the parameters of a model into leafs and non-leafs. @@ -303,6 +319,7 @@ def separate_grad_arguments( grad_norm_weights: Whether to compute the gradient of the normalization layer weights. grad_norm_bias: Whether to compute the gradient of the normalization layer bias. + grad_embed_weights: Whether to compute the gradient of the embedding layer weights Returns: A tuple of lists of parameters. The first list contains the leafs, the second @@ -311,7 +328,7 @@ def separate_grad_arguments( Raises: NotImplementedError: If an unknown layer with parameters is encountered. """ - linear = (Linear, MemSaveLinear) + linear = (Linear, MemSaveLinear, Conv1D) conv = ( Conv1d, Conv2d, @@ -322,6 +339,7 @@ def separate_grad_arguments( MemSaveConv2d, ) norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d) + embed = Embedding leafs, no_leafs = [], [] @@ -334,10 +352,29 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool): grad_bias: Whether to compute the gradient of the layer bias. """ leafs.append(layer.weight) if grad_weight else no_leafs.append(layer.weight) - if layer.bias is not None: + if "bias" in layer._parameters and layer.bias is not None: leafs.append(layer.bias) if grad_bias else no_leafs.append(layer.bias) - layers = [m for m in model.modules() if len(list(m.modules())) == 1] + def check_lm_head(n) -> bool: + """Checks if the module name n corresponds to an LM Head + + LM Head for transformers is not trainable (i.e. it is a module but it's weight is not a parameter) + and the weights are tied to the embedding layer + + Args: + n (str): name of the module + + Returns: + bool: Whether n is a LM head + """ + lm_head_name = getattr(model, "lm_head_name", None) + return lm_head_name is not None and lm_head_name in n + + layers = [ + m + for n, m in model.named_modules() + if len(list(m.modules())) == 1 and not check_lm_head(n) + ] for layer in layers: if isinstance(layer, linear): @@ -346,6 +383,8 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool): separate_layer(layer, grad_conv_weights, grad_conv_bias) elif isinstance(layer, norm): separate_layer(layer, grad_norm_weights, grad_norm_bias) + elif isinstance(layer, embed): + separate_layer(layer, grad_embed_weights, False) elif list(layer.parameters()): raise NotImplementedError(f"Unknown layer with parameters: {layer}.") diff --git a/experiments/util/models.py b/experiments/util/models.py index 222c8c4..951ce03 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -2,10 +2,34 @@ import itertools import math -from typing import List, Tuple +import warnings +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple +import torch import torchvision.models as tvm -from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential +from torch.nn import ( + Conv2d, + Flatten, + Linear, + MaxPool2d, + Module, + ReLU, + Sequential, + Transformer, + Unfold, + functional, +) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoModelForPreTraining, + AutoModelForSeq2SeqLM, + BartForConditionalGeneration, +) +from transformers import logging as tf_logging +from transformers import utils as tf_utils from memsave_torch.nn import ( MemSaveConv2d, @@ -26,7 +50,7 @@ def prefix_in_pairs(prefix: str, it: List[str]) -> List[str]: Args: prefix (str): Prefix to be added - it (List[str]): Description + it (List[str]): The list to be prefixed Returns: List[str]: The output iterable with items prefixed in pairs @@ -73,6 +97,42 @@ def convert_to_memory_saving_defaultsoff( ) +def get_transformers_config(model_name: str) -> AutoConfig: + """Get the config for the given `model_name` from huggingface transformers. Handles memsave_ as well. + + Args: + model_name (str): Model name + + Returns: + AutoConfig: Config for given model + """ + if model_name.startswith("memsave_"): + model_name = model_name.split("memsave_")[1] + props = hf_transformers_models_map[model_name] + return AutoConfig.from_pretrained(props.hf_name, **props.extra_kwargs) + + +def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]: + """Get the dict of all defined functions for an architecture + + Args: + arch (str): The architecture + + Returns: + Tuple[Dict[str, Callable], Any]: Dict of all defined functions + + Raises: + ValueError: Invalid architecture + """ + if arch == "conv": + return conv_model_fns, conv_input_shape + if arch == "transformer": + return transformer_model_fns, transformer_input_shape + if arch == "linear": + return linear_model_fns, linear_input_shape + raise ValueError(f"arch={arch} not in allowed architectures") + + # CONV conv_input_shape: Tuple[int, int, int] = (1, 1, 1) @@ -213,6 +273,7 @@ def _convrelupool_model1(num_blocks=5) -> Module: "memsave_resnext101_64x4d": lambda: convert_to_memory_saving( tvm.resnext101_64x4d() ), + # For paper "memsave_resnet101_conv": lambda: convert_to_memory_saving_defaultsoff( tvm.resnet101(), conv2d=True ), @@ -267,6 +328,210 @@ def forward(self, loss_dict): return sum(loss_dict.values()) +# TRANSFORMER +transformer_input_shape: Tuple[int, int] = (1, 1) # (vocab_dim, embed_dim) + + +class _HF_model: + def __init__( + self, + hf_name: str, + extra_kwargs: Dict[str, Any], + model_cls: Any = AutoModelForCausalLM, + lm_head_name: Optional[str] = None, + ) -> None: + self.hf_name = hf_name + self.extra_kwargs = extra_kwargs + if self.extra_kwargs is None: + self.extra_kwargs = {} + self.model_cls = model_cls + self.lm_head_name = lm_head_name + + +tf_logging.disable_progress_bar() +tf_logging.set_verbosity_error() +tf_utils.logging.captureWarnings(True) + + +hf_transformers_models_map = { + "gpt2": _HF_model("gpt2", {}, lm_head_name="lm_head"), + "vit": _HF_model("facebook/vit-mae-base", {}, AutoModelForPreTraining), + "bert": _HF_model( + "google-bert/bert-base-uncased", + {"is_decoder": True}, + lm_head_name="cls.predictions.decoder", + ), + "bart": _HF_model( + "facebook/bart-base", {}, BartForConditionalGeneration, "lm_head" + ), + "roberta": _HF_model( + "FacebookAI/roberta-base", {"is_decoder": True}, lm_head_name="lm_head.decoder" + ), + "t5": _HF_model("google-t5/t5-base", {}, AutoModelForSeq2SeqLM, "lm_head"), + "flan-t5": _HF_model("google/flan-t5-base", {}, AutoModelForSeq2SeqLM, "lm_head"), + "xlm-roberta": _HF_model( + "FacebookAI/xlm-roberta-base", {}, AutoModelForMaskedLM, "lm_head.decoder" + ), + "mistral-7b": _HF_model( + "mistralai/Mistral-7B-v0.1", + {"torch_dtype": torch.bfloat16}, + lm_head_name="lm_head", + ), + "llama3-8b": _HF_model( + "meta-llama/Meta-Llama-3-8B", + {"torch_dtype": torch.bfloat16}, + lm_head_name="lm_head", + ), + "phi3-4b": _HF_model( + "microsoft/Phi-3-mini-4k-instruct", + {"torch_dtype": torch.bfloat16}, + lm_head_name="lm_head", + ), +} +hf_transformers_models = list(hf_transformers_models_map.keys()) +hf_transformers_models = prefix_in_pairs("memsave_", hf_transformers_models) + +transformer_model_fns = { + "transformer": lambda: TorchTransformer(), + "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()), +} + +fused = lambda fn, name, kwargs: convert_to_memory_saving( # noqa: E731 + fn(name, **kwargs) +) + +for m in hf_transformers_models: + if m in transformer_model_fns: + continue + # Can't use lambdas in loops :') + if not m.startswith("memsave_"): + props = hf_transformers_models_map[m] + transformer_model_fns[m] = partial( + props.model_cls.from_pretrained, props.hf_name, **props.extra_kwargs + ) + else: + props = hf_transformers_models_map[m.split("memsave_", 1)[1]] + transformer_model_fns[m] = partial( + fused, props.model_cls.from_pretrained, props.hf_name, props.extra_kwargs + ) + + +class TorchTransformer(Module): + """Small model to wrap `torch.nn.Transformer`""" + + def __init__(self) -> None: + """Init""" + super().__init__() + self.transformer = Transformer( + d_model=transformer_input_shape[1], batch_first=True + ) + self.pred = Linear(transformer_input_shape[1], transformer_input_shape[0]) + + def forward(self, x): + """Forward + + Args: + x: x + + Returns: + output: model output + """ + out = self.transformer.decoder(x, self.transformer.encoder(x)) + return self.pred(out).permute(0, 2, 1) + + +class TransformersModelWrapper(Module): + """Small wrapper around `transformers` models to support interop with existing measurement code""" + + def __init__(self, model_fn, model_name) -> None: + """Init""" + super().__init__() + with warnings.catch_warnings(): + # hf does not keep quiet sometimes even when transformers.logging is set to errors only + # https://github.com/huggingface/transformers/issues/30618 + warnings.simplefilter("ignore") + self.model = model_fn() + self.dec = self.model.config.is_encoder_decoder + self.model_name = model_name + model_name_pure = model_name + if model_name.startswith("memsave_"): + model_name_pure = model_name.split("memsave_")[1] + self.lm_head_name = hf_transformers_models_map[model_name_pure].lm_head_name + + self.cache_kw = {"use_cache": False} + if any("ForMaskedLM" in a for a in self.model.config.architectures): + self.cache_kw = {} + + def forward(self, x): + """Forward + + Args: + x: x + + Returns: + output: model output + """ + if self.model.dtype != torch.float32: + x = x.to(self.model.dtype) + # HF takes care of converting logits to float32 + if self.dec: + out = self.model(inputs_embeds=x, decoder_inputs_embeds=x, **self.cache_kw) + else: + out = self.model(inputs_embeds=x, **self.cache_kw) + return out.logits.permute(0, 2, 1) + + +# VLM +class VLM(Module): + """Small wrapper for making a VLM model with transformer llm and conv/transformer vision model""" + + def __init__( + self, + vision_model_name: str, + vision_model_arch: str, + llm_name: str, + nc: int = 1000, + ) -> None: + """Init""" + super().__init__() + self.vision_model_name = vision_model_name + self.vm_arch = vision_model_arch + self.llm_name = llm_name + model_fns, input_shape = get_arch_models(vision_model_arch) + if vision_model_arch == "conv": + assert vision_model_name in segmentation_models + self.vm = model_fns[vision_model_name]() + self.llm = TransformersModelWrapper(transformer_model_fns[llm_name], llm_name) + vision_final_dim = 3 * 16 * 16 if vision_model_arch == "transformer" else nc + self.proj = Linear(vision_final_dim, self.llm.model.config.hidden_size) + self.patchify = Unfold(kernel_size=16, stride=16) + + def forward(self, x): + """Forward through vlm + + Args: + x: x + + Returns: + output: model output + """ + if self.vm_arch == "transformer" and self.vm.config.image_size != x.shape[-1]: + x = functional.interpolate( + x, size=self.vm.config.image_size, mode="bicubic" + ) + x = self.vm(x) + if self.vm_arch == "conv": + import ipdb + + ipdb.set_trace() + x = self.patchify(x["out"]).permute(0, 2, 1) + # [B, nc*n_patches, patch_size**2] + else: + x = x.logits + x = self.proj(x) + # [B, patch_size**2, llm_hidden] + return self.llm(x) + # LINEAR linear_input_shape: int = 1 diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py index 9deb3c1..1ba78c6 100644 --- a/memsave_torch/nn/Linear.py +++ b/memsave_torch/nn/Linear.py @@ -3,10 +3,18 @@ This is done by not saving the inputs/weights if weight/inputs dont require grad. """ +import sys + import torch.nn as nn from memsave_torch.nn.functional import linearMemSave +transformers_imported = False +if "transformers" in sys.modules: + import transformers + + transformers_imported = True + class MemSaveLinear(nn.Linear): """MemSaveLinear.""" @@ -36,14 +44,21 @@ def forward(self, x): @classmethod def from_nn_Linear(cls, linear: nn.Linear): - """Converts a nn.Linear layer to MemSaveLinear. + """Converts a nn.Linear/transformers.Conv1D layer to MemSaveLinear. Args: - linear : The nn.Linear layer + linear : The nn.Linear/transformers.Conv1D layer Returns: obj: The MemSaveLinear object """ + isTransformersConv1D = False + if transformers_imported: + isTransformersConv1D = isinstance(linear, transformers.Conv1D) + if isTransformersConv1D: + # it only saves output features in the model (linear.nf); need to take input features from weight anyway + # weight and bias are still defined + linear.in_features, linear.out_features = linear.weight.shape obj = cls( linear.in_features, linear.out_features, @@ -51,6 +66,9 @@ def from_nn_Linear(cls, linear: nn.Linear): device=getattr(linear, "device", None), dtype=getattr(linear, "dtype", None), ) - obj.weight = linear.weight + if isTransformersConv1D: + obj.weight = nn.Parameter(linear.weight.T) + else: + obj.weight = linear.weight obj.bias = linear.bias return obj diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index b9bba81..4fbc756 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -6,17 +6,26 @@ - BatchNorm2d """ +import sys + import torch.nn as nn from memsave_torch.nn import functional # noqa: F401 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d from memsave_torch.nn.Conv1d import MemSaveConv1d from memsave_torch.nn.Conv2d import MemSaveConv2d +from memsave_torch.nn.Dropout import MemSaveDropout from memsave_torch.nn.LayerNorm import MemSaveLayerNorm from memsave_torch.nn.Linear import MemSaveLinear from memsave_torch.nn.MaxPool import MemSaveMaxPool2d from memsave_torch.nn.ReLU import MemSaveReLU +transformers_imported = False +if "transformers" in sys.modules: + import transformers + + transformers_imported = True + def convert_to_memory_saving( model: nn.Module, @@ -27,6 +36,7 @@ def convert_to_memory_saving( relu=True, maxpool2d=True, layernorm=True, + dropout=True, verbose=False, clone_params=False, ) -> nn.Module: @@ -45,16 +55,20 @@ def convert_to_memory_saving( relu (bool, optional): Whether to replace `nn.ReLU` layers maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers layernorm (bool, optional): Whether to replace `nn.LayerNorm` layers + dropout (bool, optional): Whether to replace `nn.Dropout` layers verbose (bool, optional): Whether to print which layers were replaced clone_params (bool, optional): Whether to clone the layer parameters or use directly Returns: memsavemodel (nn.Module): The converted memory saving model """ + linear_cls = nn.Linear + if transformers_imported: + linear_cls = (nn.Linear, transformers.Conv1D) layers = [ { "allowed": linear, - "cls": nn.Linear, + "cls": linear_cls, "convert_fn": MemSaveLinear.from_nn_Linear, }, {"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU}, @@ -83,6 +97,11 @@ def convert_to_memory_saving( "cls": nn.LayerNorm, "convert_fn": MemSaveLayerNorm.from_nn_LayerNorm, }, + { + "allowed": dropout, + "cls": nn.Dropout, + "convert_fn": MemSaveDropout.from_nn_dropout, + }, ] import copy diff --git a/memsave_torch/nn/functional/Dropout.py b/memsave_torch/nn/functional/Dropout.py index d20e725..f6fe992 100644 --- a/memsave_torch/nn/functional/Dropout.py +++ b/memsave_torch/nn/functional/Dropout.py @@ -10,7 +10,9 @@ class _MemSaveDropout(torch.autograd.Function): @staticmethod def forward(ctx, x, p, train): - out, mask = torch.ops.aten.native_dropout(x, p, train) + rng = torch.get_rng_state() + # dont need mask here, so dont call torch.ops, torch.dropout is faster + out = torch.dropout(x, p, train) if ctx.needs_input_grad[0]: ctx.p = p ctx.mask = mask @@ -21,11 +23,17 @@ def backward(ctx, grad_output): grad_x = None if ctx.needs_input_grad[0]: - grad_x = torch.ops.aten.native_dropout_backward( - grad_output, ctx.mask, scale=1 / (1 - ctx.p) - ) - - return grad_x + orig_rng = torch.get_rng_state() + torch.set_rng_state(ctx.rng) + mask = torch.empty_like(grad_output) + mask = mask.bernoulli_(0.5).bool() + torch.set_rng_state(orig_rng) + grad_x = grad_output * mask / (1 - ctx.p) + # grad_x = torch.ops.aten.native_dropout_backward( + # grad_output, mask, scale=1 / (1 - ctx.p) + # ) + + return grad_x, None, None def dropoutMemSave(x, p, training) -> torch.Tensor: diff --git a/pyproject.toml b/pyproject.toml index 762e45d..1183b68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,12 @@ test = [ 'pytest', 'pytest-cov', 'pytest-optional-tests', + "transformers" ] exp = [ "codetiming", - "memory_profiler" + "memory_profiler", + "transformers>=4.41" ] [tool.pytest.ini_options] diff --git a/test/test_layers.py b/test/test_layers.py index 549e7cc..74fd46f 100644 --- a/test/test_layers.py +++ b/test/test_layers.py @@ -4,6 +4,7 @@ import pytest import torch +import transformers import memsave_torch @@ -14,44 +15,66 @@ torch.manual_seed(239) cases = [ - {"layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 3)}, { + "name": "Linear1dims", + "layer_fn": lambda: torch.nn.Linear(3, 5), + "data_fn": lambda: torch.rand(7, 3), + }, + { + "name": "Linear2dims", "layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 12, 3), # weight sharing }, { + "name": "Linear3dims", "layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 12, 12, 3), # weight sharing }, { + "name": "Conv2d", "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, { + "name": "Conv1d", "layer_fn": lambda: torch.nn.Conv1d(3, 5, 3), "data_fn": lambda: torch.rand(7, 3, 12), }, { + "name": "BatchNorm2d", "layer_fn": lambda: torch.nn.BatchNorm2d(3), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, { + "name": "LayerNorm", "layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, { + # TODO: add testing for dropout (save and load rng state) + # { + # "name": "Dropout" + # "layer_fn": lambda: torch.nn.Dropout(), + # "data_fn": lambda: torch.rand(7, 3, 12, 12), + # }, + { + "name": "MaxPool2d", "layer_fn": lambda: torch.nn.MaxPool2d(3), "data_fn": lambda: torch.rand(7, 3, 12, 12), }, - {"layer_fn": lambda: torch.nn.ReLU(), "data_fn": lambda: torch.rand(7, 3, 12, 12)}, + { + "name": "ReLU", + "layer_fn": lambda: torch.nn.ReLU(), + "data_fn": lambda: torch.rand(7, 3, 12, 12), + }, ] @pytest.mark.quick -@pytest.mark.parametrize("case", cases) +@pytest.mark.parametrize("case", cases, ids=[case["name"] for case in cases]) @pytest.mark.parametrize("device", devices) def test_single_layer( - case: Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]], + case: Dict[str, Union[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]]], device: str, ): """Runs tests for the layer_cls defined by `layer`