diff --git a/.gitignore b/.gitignore index 6ad7868..7f71819 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ *.txt *.csv !requirements.txt +torchviz-output/ +torchview-output/ # generated docs docs_src/_build/ diff --git a/experiments/get_best_results.py b/experiments/get_best_results.py index 1321679..2146a4f 100644 --- a/experiments/get_best_results.py +++ b/experiments/get_best_results.py @@ -7,7 +7,7 @@ import pandas as pd -from experiments.util.collect_results import case_mapping +from experiments.util.collect_results import case_inv_mapping def main(base_dir: str): @@ -16,17 +16,25 @@ def main(base_dir: str): Args: base_dir (str): The base results dir """ - for device, arch in product(["cuda", "cpu"], ["linear", "conv"]): + # Don't recognize None as NaN + custom_na_values = pd._libs.parsers.STR_NA_VALUES - {"None"} + for device, arch in product(["cuda", "cpu"], ["linear", "conv", "transformer"]): # usage stats df = None idx_col = ["model", "case"] for fname in glob(os.path.join(base_dir, f"usage_stats-{arch}-{device}-*.csv")): with open(fname) as f: - f.readline() - temp_df = pd.read_csv(f, index_col=idx_col) + # f.readline() + temp_df = pd.read_csv( + f, + index_col=idx_col, + header=1, + na_values=custom_na_values, + keep_default_na=False, + ) df = temp_df if df is None else pd.concat([df, temp_df]) if df is not None: - df = df.rename(index=case_mapping, level=1) + df = df.rename(index=case_inv_mapping, level=1) df["Memory Usage (GB)"] = df["Memory Usage (MB)"] / 1024 df = df.drop(columns=["Memory Usage (MB)"]) best_results = df.groupby(idx_col).min() diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py index 27b9b87..9c5d53e 100644 --- a/experiments/paper_demo.py +++ b/experiments/paper_demo.py @@ -14,7 +14,7 @@ from experiments.util.models import prefix_in_pairs estimators = ["time", "memory"] -estimators = ["memory"] +# estimators = ["memory"] # estimators = ["time"] # improvements can be either speedups or savings based on context @@ -27,70 +27,68 @@ # repeat the experiment multiple times (generates multiple files to be aggregated by `get_best_results`) n_repeat = 5 +batchnorm_eval = True # BatchNorm in eval mode -# CONV +# ============== CONV CONFIG ============== # Valid choices for models are in models.conv_model_fns +# models = [ +# "deepmodel", +# "resnet101", +# "resnet18", +# "vgg16", # "convnext_base", +# "fasterrcnn_resnet50_fpn_v2", +# "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2", +# "deeplabv3_resnet101", +# "fcn_resnet101", +# "efficientnet_v2_l", +# "mobilenet_v3_large", +# "resnext101_64x4d", +# ] +# models = prefix_in_pairs("memsave_", models) +# # models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"] +# batch_size = 64 +# input_channels = 3 +# input_HW = 224 +# num_classes = 1000 +# device = "cuda" +# architecture = "conv" +# cases = collect_results.select_cases(['All', 'Input', 'Conv', 'Norm']) + +# ============== TRANSFORMER CONFIG ============== +# Valid choices for models are in models.transformer_model_fns models = [ - "deepmodel", - "resnet101", - "resnet18", - "vgg16", # "convnext_base", - "fasterrcnn_resnet50_fpn_v2", - "ssdlite320_mobilenet_v3_large", # "retinanet_resnet50_fpn_v2", - "deeplabv3_resnet101", - "fcn_resnet101", - "efficientnet_v2_l", - "mobilenet_v3_large", - "resnext101_64x4d", + "transformer", + "gpt2", + "bert", + "bart", + "roberta", + "t5", + "flan-t5", + # "xlm-roberta", + "mistral-7b", + "llama3-8b", + "phi3-4b", ] - -# models = 
["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"] -# models = ["resnet101", "memsave_resnet101_conv_full"] - models = prefix_in_pairs("memsave_", models) -# models = ["memsave_resnet101"] batch_size = 64 -input_channels = 3 -input_HW = 224 -num_classes = 1000 +input_channels = 2048 +input_HW = 256 +num_classes = 5000 device = "cuda" -architecture = "conv" +architecture = "transformer" +cases = collect_results.select_cases(["All", "Input", "Norm"]) -# LINEAR +# ============== LINEAR CONFIG ============== # Valid choices for models are in models.linear_model_fns -# models = ['deeplinearmodel'] +# models = ["deeplinearmodel"] # models += [f"memsave_{m}" for m in models] # add memsave versions for each model -# batch_size = 32768 -# input_channels = 3 -# input_HW = 64 +# batch_size = 1024 +# input_channels = 3 # ignored +# input_HW = 64 # square of this is passed in estimate.py # num_classes = 1000 -# device = 'cuda' -# architecture = 'linear' # use high batch size - -cases = [ - None, # ALL - [ # INPUT - "grad_input", - "no_grad_conv_weights", - "no_grad_conv_bias", - "no_grad_linear_weights", - "no_grad_linear_bias", - "no_grad_norm_weights", - "no_grad_norm_bias", - ], - [ # CONV - "no_grad_linear_weights", - "no_grad_linear_bias", - "no_grad_norm_weights", - "no_grad_norm_bias", - ], - [ # NORM - "no_grad_conv_weights", - "no_grad_conv_bias", - "no_grad_linear_weights", - "no_grad_linear_bias", - ], -] +# device = "cuda" +# architecture = "linear" # use high batch size +# cases = collect_results.select_cases(["All", "Input", "Linear"]) if __name__ == "__main__": @@ -108,19 +106,30 @@ cases, "results", ) + bn_eval_str = "--bn_eval" if batchnorm_eval else "" for model in models: + B = batch_size + if model in prefix_in_pairs("memsave_", ["flan-t5"]): + B = 56 + if model in prefix_in_pairs("memsave_", ["mistral-7b", "phi3-4b"]): + B = 16 + if model in prefix_in_pairs("memsave_", ["llama3-8b"]): + B = 8 for estimate in estimators: outputs = [] collector.clear_file(estimate) for case in cases: pbar.update() - pbar.set_description(f"{model} {estimate} case {case}") + case_display = collect_results.case_inv_mapping[ + collect_results.make_case_str(case) + ] case_str = f"--case {' '.join(case)}" if case is not None else "" + pbar.set_description(f"{model} {estimate} case {case_display}") cmd = ( f"python experiments/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} " - + f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}" + + f"--device {device} -B {B} -C_in {input_channels} -HW {input_HW} -n_class {num_classes} {bn_eval_str}" ) proc = subprocess.run(shlex.split(cmd), capture_output=True) assert proc.stderr in [ diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py index ca64a8f..6d4681f 100644 --- a/experiments/util/collect_results.py +++ b/experiments/util/collect_results.py @@ -61,6 +61,18 @@ } +def select_cases(selected: List[str]) -> List[Union[List[str], None]]: + """Helper function to return cases selected by their names + + Args: + selected (List[str]): Which cases to select, strings can be keys of the cases table + + Returns: + List[Union[List[str], None]]: Selected cases + """ + return [cases[s] for s in selected] + + def make_case_str(case: Union[None, List[str]]) -> str: """Format case into a string @@ -73,6 +85,9 @@ def make_case_str(case: Union[None, List[str]]) -> str: return "None" if case is None else " 
+ ".join(case) +case_inv_mapping = {make_case_str(v): k for k, v in cases.items()} + + def hyperparam_str(args: SimpleNamespace) -> str: """Format hyperparams into a string @@ -198,12 +213,15 @@ def _display_run( """ # print(f"{model} input ({input_channels},{input_HW},{input_HW}) {device}") # print('='*78) - s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}" + if self.architecture == "conv": + s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}" + elif self.architecture == "transformer": + s = f"{model} input ({self.batch_size},{self.input_HW},{self.input_channels}(or model hidden size)) {self.device}" print(s.center(78, "=")) for out, case in zip(outputs, self.cases): print( - f"{strings[estimate][1]} ({case_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}" + f"{strings[estimate][1]} ({case_inv_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}" ) # CODE ONLY APPLIES WITH OLD RUNDEMO.PY diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py index 528f153..3c0dabe 100644 --- a/experiments/util/estimate.py +++ b/experiments/util/estimate.py @@ -66,7 +66,10 @@ def parse_case(case: Optional[List[str]]) -> Dict[str, bool]: def skip_case_check(args: argparse.Namespace) -> bool: - """Decide whether to skip the case (when case has grad_norm_* but model does not have any normalization layers) + """Decide whether to skip the case: + + 1. when case has grad_norm_* but model does not have any normalization layers + 2. when case has no_grad_embed_weights but no grad_input: there is a backward error (no input requires_grad) Args: args (argparse.Namespace): args @@ -85,6 +88,9 @@ def skip_case_check(args: argparse.Namespace) -> bool: for c in ["no_grad_norm_bias", "no_grad_norm_weights"]: if c not in args.case and args.model in models.models_without_norm and not is_surgical: invalid = True + # 2. 
+ if "no_grad_embed_weights" in args.case and "grad_input" not in args.case: + invalid = True if invalid: if args.print: print("-1") @@ -232,7 +238,7 @@ def estimate_mem_savings( type=str, required=True, help="Which architecture to run", - choices=["conv", "linear"], + choices=["conv", "linear", "transformer", "VLM"], ) parser.add_argument( "--estimate", @@ -250,6 +256,12 @@ def estimate_mem_savings( help=f"Which case to run, allowed values are {allowed_cases}", ) parser.add_argument("--device", type=str, default="cpu", help="torch device name") + parser.add_argument( + "--bn_eval", + action="store_true", + default=False, + help="Set all BN layers to eval mode", + ) parser.add_argument( "--print", action="store_true", @@ -275,6 +287,7 @@ def estimate_mem_savings( input_shape = (args.input_channels, args.input_hw, args.input_hw) models.conv_input_shape = input_shape model_fn = models.conv_model_fns.get(args.model) + y_args = {"size": (batch_size,), "low": 0, "high": num_classes} assert ( model_fn is not None ), f"Conv model name {args.model} not found, must be one of {list(models.conv_model_fns.keys())}" @@ -282,16 +295,51 @@ def estimate_mem_savings( input_shape = [args.input_hw**2] models.linear_input_shape = input_shape[0] model_fn = models.linear_model_fns.get(args.model) + y_args = {"size": (batch_size,), "low": 0, "high": num_classes} assert ( model_fn is not None ), f"Linear model name {args.model} not found, must be one of {list(models.linear_model_fns.keys())}" + elif args.architecture == "transformer": + vocab_dim = args.num_classes + embed_dim = args.input_channels + seq_len = args.input_hw + model_fn = models.transformer_model_fns.get(args.model) + if args.model in models.hf_transformers_models: + model_fn_orig = model_fn + model_fn = lambda: models.TransformersModelWrapper( # noqa: E731 + model_fn_orig, args.model + ) + config = models.get_transformers_config(args.model) + # as per transformers.PretrainedConfig these 2 should be present in all models: + vocab_dim = config.vocab_size + embed_dim = config.hidden_size + models.transformer_input_shape = (vocab_dim, embed_dim) + input_shape = [seq_len, embed_dim] + y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim} + assert ( + model_fn is not None + ), f"Transformer model name {args.model} not found, must be one of {list(models.transformer_model_fns.keys())}" + elif args.architecture == "VLM": + # model format: `vlm!!!` + # eg: `vlm!vit!transformer!memsave_gpt2` + is_vlm, vis_model, vis_model_arch, llm = args.model.split("!") + assert is_vlm == "vlm" + assert vis_model_arch in ["transformer", "conv"] + model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm) # noqa: E731 + config = models.get_transformers_config(llm) + vocab_dim = config.vocab_size + embed_dim = config.hidden_size + seq_len = (args.input_hw // 16) ** 2 + y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim} + input_shape = (args.input_channels, args.input_hw, args.input_hw) + models.conv_input_shape = input_shape loss_fn = CrossEntropyLoss manual_seed(0) # make deterministic x = rand(batch_size, *input_shape, device=dev) - y = randint(size=(batch_size,), low=0, high=num_classes, device=dev) + y = randint(**y_args, device=dev) targets = None if args.model in models.detection_models: # pred is a dictionary of losses @@ -318,6 +366,10 @@ def estimate_mem_savings( loss_fn_orig = loss_fn loss_fn = lambda: models.SegmentationLossWrapper(loss_fn_orig) # noqa: E731 + if args.bn_eval: + model_fn_orig_bn = model_fn + model_fn = lambda: 
models.set_BN_to_eval(model_fn_orig_bn()) # noqa: E731 + # warm-up # with redirect_stdout(open(devnull, "w")): # estimate_speedup(model_fn, loss_fn, x, y, dev, vjp_speedups[:1]) diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py index 5f4b0f9..62747c2 100644 --- a/experiments/util/measurements.py +++ b/experiments/util/measurements.py @@ -28,19 +28,9 @@ ) from torchvision.models.convnext import LayerNorm2d from transformers import Conv1D -from transformers.models.llama.modeling_llama import LlamaRMSNorm -from transformers.models.mistral.modeling_mistral import MistralRMSNorm -from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm -from transformers.models.t5.modeling_t5 import T5LayerNorm - -from memsave_torch.nn import ( - MemSaveBatchNorm2d, - MemSaveConv2d, - MemSaveLayerNorm, - MemSaveLinear, - MemSaveRMSLayerNorm, - RMSLayerNorm, -) + +from memsave_torch.nn.Conv2d import MemSaveConv2d +from memsave_torch.nn.Linear import MemSaveLinear def maybe_synchronize(dev: device): @@ -327,14 +317,6 @@ def separate_grad_arguments( LayerNorm, LayerNorm2d, MemSaveBatchNorm2d, - MemSaveLayerNorm, - RMSLayerNorm, - MemSaveRMSLayerNorm, - T5LayerNorm, - MistralRMSNorm, - LlamaRMSNorm, - Phi3RMSNorm, - ) embed = Embedding leafs, no_leafs = [], [] diff --git a/experiments/util/models.py b/experiments/util/models.py index 222c8c4..8fcfe91 100644 --- a/experiments/util/models.py +++ b/experiments/util/models.py @@ -2,12 +2,38 @@ import itertools import math -from typing import List, Tuple +import warnings +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple +import torch import torchvision.models as tvm -from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential +from torch.nn import ( + BatchNorm2d, + Conv2d, + Flatten, + Linear, + MaxPool2d, + Module, + ReLU, + Sequential, + Transformer, + Unfold, + functional, +) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoModelForPreTraining, + AutoModelForSeq2SeqLM, + BartForConditionalGeneration, +) +from transformers import logging as tf_logging +from transformers import utils as tf_utils from memsave_torch.nn import ( + MemSaveBatchNorm2d, MemSaveConv2d, MemSaveLinear, convert_to_memory_saving, @@ -26,7 +52,7 @@ def prefix_in_pairs(prefix: str, it: List[str]) -> List[str]: Args: prefix (str): Prefix to be added - it (List[str]): Description + it (List[str]): The list to be prefixed Returns: List[str]: The output iterable with items prefixed in pairs @@ -73,6 +99,58 @@ def convert_to_memory_saving_defaultsoff( ) +def get_transformers_config(model_name: str) -> AutoConfig: + """Get the config for the given `model_name` from huggingface transformers. Handles memsave_ as well. 
+ + Args: + model_name (str): Model name + + Returns: + AutoConfig: Config for given model + """ + if model_name.startswith("memsave_"): + model_name = model_name.split("memsave_")[1] + props = hf_transformers_models_map[model_name] + return AutoConfig.from_pretrained(props.hf_name, **props.extra_kwargs) + + +def get_arch_models(arch: str) -> Tuple[Dict[str, Callable], Any]: + """Get the dict of all defined functions for an architecture + + Args: + arch (str): The architecture + + Returns: + Tuple[Dict[str, Callable], Any]: Dict of all defined functions + + Raises: + ValueError: Invalid architecture + """ + if arch == "conv": + return conv_model_fns, conv_input_shape + if arch == "transformer": + return transformer_model_fns, transformer_input_shape + if arch == "linear": + return linear_model_fns, linear_input_shape + raise ValueError(f"arch={arch} not in allowed architectures") + + +def set_BN_to_eval(model: Module) -> Module: + """Sets all BatchNorm layers in the input `model` to eval mode (i.e. bn.eval()). + + Args: + model (Module): Input model + + Returns: + Module: Model with BN layers in eval mode + """ + known_bn_layers = (BatchNorm2d, MemSaveBatchNorm2d) + for layer in model.modules(): + if isinstance(layer, known_bn_layers): + layer.eval() + return model + + # CONV conv_input_shape: Tuple[int, int, int] = (1, 1, 1) @@ -213,6 +291,7 @@ def _convrelupool_model1(num_blocks=5) -> Module: "memsave_resnext101_64x4d": lambda: convert_to_memory_saving( tvm.resnext101_64x4d() ), + # For paper "memsave_resnet101_conv": lambda: convert_to_memory_saving_defaultsoff( tvm.resnet101(), conv2d=True ), @@ -267,6 +346,211 @@ def forward(self, loss_dict): return sum(loss_dict.values()) +# TRANSFORMER +transformer_input_shape: Tuple[int, int] = (1, 1) # (vocab_dim, embed_dim) + + +class _HF_model: + def __init__( + self, + hf_name: str, + extra_kwargs: Dict[str, Any], + model_cls: Any = AutoModelForCausalLM, + lm_head_name: Optional[str] = None, + ) -> None: + self.hf_name = hf_name + self.extra_kwargs = extra_kwargs + if self.extra_kwargs is None: + self.extra_kwargs = {} + self.model_cls = model_cls + self.lm_head_name = lm_head_name + + +tf_logging.disable_progress_bar() +tf_logging.set_verbosity_error() +tf_utils.logging.captureWarnings(True) + + +hf_transformers_models_map = { + "gpt2": _HF_model("gpt2", {}, lm_head_name="lm_head"), + "vit": _HF_model("facebook/vit-mae-base", {}, AutoModelForPreTraining), + "bert": _HF_model( + "google-bert/bert-base-uncased", + {"is_decoder": True}, + lm_head_name="cls.predictions.decoder", + ), + "bart": _HF_model( + "facebook/bart-base", {}, BartForConditionalGeneration, "lm_head" + ), + "roberta": _HF_model( + "FacebookAI/roberta-base", {"is_decoder": True}, lm_head_name="lm_head.decoder" + ), + "t5": _HF_model("google-t5/t5-base", {}, AutoModelForSeq2SeqLM, "lm_head"), + "flan-t5": _HF_model("google/flan-t5-base", {}, AutoModelForSeq2SeqLM, "lm_head"), + "xlm-roberta": _HF_model( + "FacebookAI/xlm-roberta-base", {}, AutoModelForMaskedLM, "lm_head.decoder" + ), + "mistral-7b": _HF_model( + "mistralai/Mistral-7B-v0.1", + {"torch_dtype": torch.bfloat16}, + lm_head_name="lm_head", + ), + "llama3-8b": _HF_model( + "meta-llama/Meta-Llama-3-8B", + {"torch_dtype": torch.bfloat16}, + lm_head_name="lm_head", + ), + "phi3-4b": _HF_model( + "microsoft/Phi-3-mini-4k-instruct", + {"torch_dtype": torch.bfloat16}, + lm_head_name="lm_head", + ), +} +hf_transformers_models = list(hf_transformers_models_map.keys()) +hf_transformers_models = prefix_in_pairs("memsave_", 
hf_transformers_models) + +transformer_model_fns = { + "transformer": lambda: TorchTransformer(), + "memsave_transformer": lambda: convert_to_memory_saving(TorchTransformer()), +} + +fused = lambda fn, name, kwargs: convert_to_memory_saving( # noqa: E731 + fn(name, **kwargs) +) + +for m in hf_transformers_models: + if m in transformer_model_fns: + continue + # Can't use lambdas in loops :') + if not m.startswith("memsave_"): + props = hf_transformers_models_map[m] + transformer_model_fns[m] = partial( + props.model_cls.from_pretrained, props.hf_name, **props.extra_kwargs + ) + else: + props = hf_transformers_models_map[m.split("memsave_", 1)[1]] + transformer_model_fns[m] = partial( + fused, props.model_cls.from_pretrained, props.hf_name, props.extra_kwargs + ) + + +class TorchTransformer(Module): + """Small model to wrap `torch.nn.Transformer`""" + + def __init__(self) -> None: + """Init""" + super().__init__() + self.transformer = Transformer( + d_model=transformer_input_shape[1], batch_first=True + ) + self.pred = Linear(transformer_input_shape[1], transformer_input_shape[0]) + + def forward(self, x): + """Forward + + Args: + x: x + + Returns: + output: model output + """ + out = self.transformer.decoder(x, self.transformer.encoder(x)) + return self.pred(out).permute(0, 2, 1) + + +class TransformersModelWrapper(Module): + """Small wrapper around `transformers` models to support interop with existing measurement code""" + + def __init__(self, model_fn, model_name) -> None: + """Init""" + super().__init__() + with warnings.catch_warnings(): + # hf does not keep quiet sometimes even when transformers.logging is set to errors only + # https://github.com/huggingface/transformers/issues/30618 + warnings.simplefilter("ignore") + self.model = model_fn() + self.dec = self.model.config.is_encoder_decoder + self.model_name = model_name + model_name_pure = model_name + if model_name.startswith("memsave_"): + model_name_pure = model_name.split("memsave_")[1] + self.lm_head_name = hf_transformers_models_map[model_name_pure].lm_head_name + + self.cache_kw = {"use_cache": False} + if any("ForMaskedLM" in a for a in self.model.config.architectures): + self.cache_kw = {} + + def forward(self, x): + """Forward + + Args: + x: x + + Returns: + output: model output + """ + if self.model.dtype != torch.float32: + x = x.to(self.model.dtype) + # HF takes care of converting logits to float32 + if self.dec: + out = self.model(inputs_embeds=x, decoder_inputs_embeds=x, **self.cache_kw) + else: + out = self.model(inputs_embeds=x, **self.cache_kw) + return out.logits.permute(0, 2, 1) + + +# VLM +class VLM(Module): + """Small wrapper for making a VLM model with transformer llm and conv/transformer vision model""" + + def __init__( + self, + vision_model_name: str, + vision_model_arch: str, + llm_name: str, + nc: int = 1000, + ) -> None: + """Init""" + super().__init__() + self.vision_model_name = vision_model_name + self.vm_arch = vision_model_arch + self.llm_name = llm_name + model_fns, input_shape = get_arch_models(vision_model_arch) + if vision_model_arch == "conv": + assert vision_model_name in segmentation_models + self.vm = model_fns[vision_model_name]() + self.llm = TransformersModelWrapper(transformer_model_fns[llm_name], llm_name) + vision_final_dim = 3 * 16 * 16 if vision_model_arch == "transformer" else nc + self.proj = Linear(vision_final_dim, self.llm.model.config.hidden_size) + self.patchify = Unfold(kernel_size=16, stride=16) + + def forward(self, x): + """Forward through vlm + + Args: + x: x + + 
Returns:
+            output: model output
+        """
+        if self.vm_arch == "transformer" and self.vm.config.image_size != x.shape[-1]:
+            x = functional.interpolate(
+                x, size=self.vm.config.image_size, mode="bicubic"
+            )
+        x = self.vm(x)
+        if self.vm_arch == "conv":
+            x = self.patchify(x["out"]).permute(0, 2, 1)
+            # [B, nc*n_patches, patch_size**2]
+        else:
+            x = x.logits
+        x = self.proj(x)
+        # [B, patch_size**2, llm_hidden]
+        return self.llm(x)
+
+
 # LINEAR
 linear_input_shape: int = 1
 
diff --git a/memsave_torch/nn/BatchNorm.py b/memsave_torch/nn/BatchNorm.py
index a8de50c..79a1507 100644
--- a/memsave_torch/nn/BatchNorm.py
+++ b/memsave_torch/nn/BatchNorm.py
@@ -79,4 +79,5 @@ def from_nn_BatchNorm2d(cls, bn2d: nn.BatchNorm2d):
         obj.bias = bn2d.bias
         obj.running_mean = bn2d.running_mean
         obj.running_var = bn2d.running_var
+        obj.training = bn2d.training
         return obj
diff --git a/memsave_torch/nn/Dropout.py b/memsave_torch/nn/Dropout.py
new file mode 100644
index 0000000..5d682e5
--- /dev/null
+++ b/memsave_torch/nn/Dropout.py
@@ -0,0 +1,46 @@
+"""Implementation of a memory saving Dropout (sort of).
+
+This is done by not saving the whole input/output `float32` tensor and instead just saving the `bool` mask (8bit).
+"""
+
+import torch
+import torch.nn as nn
+
+from memsave_torch.nn.functional import dropoutMemSave
+
+
+class MemSaveDropout(nn.Dropout):
+    """MemSaveDropout."""
+
+    def __init__(self, p=0.5):
+        """Inits a MemSaveDropout layer with the given params.
+
+        Args:
+            p: Probability of elements being zeroed
+        """
+        super().__init__(p)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x: Input to the network
+
+        Returns:
+            torch.Tensor: Output
+        """
+        return dropoutMemSave(x, self.p, self.training)
+
+    @classmethod
+    def from_nn_dropout(cls, dropout: nn.Dropout):
+        """Converts a nn.Dropout layer to MemSaveDropout.
+
+        Args:
+            dropout : The nn.Dropout layer
+
+        Returns:
+            obj: The MemSaveDropout object
+        """
+        obj = cls(dropout.p)
+        return obj
+
diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py
index 9deb3c1..1ba78c6 100644
--- a/memsave_torch/nn/Linear.py
+++ b/memsave_torch/nn/Linear.py
@@ -3,10 +3,18 @@
 This is done by not saving the inputs/weights if weight/inputs dont require grad.
 """
 
+import sys
+
 import torch.nn as nn
 
 from memsave_torch.nn.functional import linearMemSave
 
+transformers_imported = False
+if "transformers" in sys.modules:
+    import transformers
+
+    transformers_imported = True
+
 
 class MemSaveLinear(nn.Linear):
     """MemSaveLinear."""
@@ -36,14 +44,21 @@ def forward(self, x):
 
     @classmethod
     def from_nn_Linear(cls, linear: nn.Linear):
-        """Converts a nn.Linear layer to MemSaveLinear.
+        """Converts a nn.Linear/transformers.Conv1D layer to MemSaveLinear.
Args: - linear : The nn.Linear layer + linear : The nn.Linear/transformers.Conv1D layer Returns: obj: The MemSaveLinear object """ + isTransformersConv1D = False + if transformers_imported: + isTransformersConv1D = isinstance(linear, transformers.Conv1D) + if isTransformersConv1D: + # it only saves output features in the model (linear.nf); need to take input features from weight anyway + # weight and bias are still defined + linear.in_features, linear.out_features = linear.weight.shape obj = cls( linear.in_features, linear.out_features, @@ -51,6 +66,9 @@ def from_nn_Linear(cls, linear: nn.Linear): device=getattr(linear, "device", None), dtype=getattr(linear, "dtype", None), ) - obj.weight = linear.weight + if isTransformersConv1D: + obj.weight = nn.Parameter(linear.weight.T) + else: + obj.weight = linear.weight obj.bias = linear.bias return obj diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index b9bba81..4fbc756 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -6,17 +6,26 @@ - BatchNorm2d """ +import sys + import torch.nn as nn from memsave_torch.nn import functional # noqa: F401 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d from memsave_torch.nn.Conv1d import MemSaveConv1d from memsave_torch.nn.Conv2d import MemSaveConv2d +from memsave_torch.nn.Dropout import MemSaveDropout from memsave_torch.nn.LayerNorm import MemSaveLayerNorm from memsave_torch.nn.Linear import MemSaveLinear from memsave_torch.nn.MaxPool import MemSaveMaxPool2d from memsave_torch.nn.ReLU import MemSaveReLU +transformers_imported = False +if "transformers" in sys.modules: + import transformers + + transformers_imported = True + def convert_to_memory_saving( model: nn.Module, @@ -27,6 +36,7 @@ def convert_to_memory_saving( relu=True, maxpool2d=True, layernorm=True, + dropout=True, verbose=False, clone_params=False, ) -> nn.Module: @@ -45,16 +55,20 @@ def convert_to_memory_saving( relu (bool, optional): Whether to replace `nn.ReLU` layers maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers layernorm (bool, optional): Whether to replace `nn.LayerNorm` layers + dropout (bool, optional): Whether to replace `nn.Dropout` layers verbose (bool, optional): Whether to print which layers were replaced clone_params (bool, optional): Whether to clone the layer parameters or use directly Returns: memsavemodel (nn.Module): The converted memory saving model """ + linear_cls = nn.Linear + if transformers_imported: + linear_cls = (nn.Linear, transformers.Conv1D) layers = [ { "allowed": linear, - "cls": nn.Linear, + "cls": linear_cls, "convert_fn": MemSaveLinear.from_nn_Linear, }, {"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU}, @@ -83,6 +97,11 @@ def convert_to_memory_saving( "cls": nn.LayerNorm, "convert_fn": MemSaveLayerNorm.from_nn_LayerNorm, }, + { + "allowed": dropout, + "cls": nn.Dropout, + "convert_fn": MemSaveDropout.from_nn_dropout, + }, ] import copy diff --git a/memsave_torch/nn/functional/BatchNorm.py b/memsave_torch/nn/functional/BatchNorm.py index 33efcf0..934b6ef 100644 --- a/memsave_torch/nn/functional/BatchNorm.py +++ b/memsave_torch/nn/functional/BatchNorm.py @@ -38,8 +38,12 @@ def forward( need_grad = [] # save_mean and save_invstd if ctx.needs_input_grad[0]: need_grad.append(weight) - if any(ctx.needs_input_grad): - need_grad.append(x) + if ctx.training: + if any(ctx.needs_input_grad): + need_grad.append(x) + else: + if ctx.needs_input_grad[3]: + need_grad.append(x) # bias doesnt need anything for calc 
ctx.save_for_backward(*need_grad)
 
@@ -54,7 +58,8 @@ def backward(ctx, grad_output):
         if ctx.needs_input_grad[0]:
             weight = ctx.saved_tensors[current_idx]
             current_idx += 1
-        x = ctx.saved_tensors[current_idx]
+        if ctx.training:
+            x = ctx.saved_tensors[current_idx]
 
         if ctx.needs_input_grad[3]:
             x = ctx.saved_tensors[current_idx]
diff --git a/memsave_torch/nn/functional/Dropout.py b/memsave_torch/nn/functional/Dropout.py
index d20e725..614a31e 100644
--- a/memsave_torch/nn/functional/Dropout.py
+++ b/memsave_torch/nn/functional/Dropout.py
@@ -10,10 +10,13 @@ class _MemSaveDropout(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, p, train):
-        out, mask = torch.ops.aten.native_dropout(x, p, train)
+        rng = torch.get_rng_state()
+        # don't need mask here, so don't call torch.ops, torch.dropout is faster
+        out = torch.dropout(x, p, train)
         if ctx.needs_input_grad[0]:
             ctx.p = p
-            ctx.mask = mask
+            ctx.train = train
+            ctx.rng = rng
         return out
 
     @staticmethod
@@ -21,11 +24,17 @@ def backward(ctx, grad_output):
         grad_x = None
 
         if ctx.needs_input_grad[0]:
-            grad_x = torch.ops.aten.native_dropout_backward(
-                grad_output, ctx.mask, scale=1 / (1 - ctx.p)
-            )
-
-        return grad_x
+            orig_rng = torch.get_rng_state()
+            torch.set_rng_state(ctx.rng)
+            mask = torch.empty_like(grad_output)
+            mask = mask.bernoulli_(1 - ctx.p).bool()
+            torch.set_rng_state(orig_rng)
+            grad_x = grad_output * mask / (1 - ctx.p)
+            # grad_x = torch.ops.aten.native_dropout_backward(
+            #     grad_output, mask, scale=1 / (1 - ctx.p)
+            # )
+
+        return grad_x, None, None
 
 
 def dropoutMemSave(x, p, training) -> torch.Tensor:
diff --git a/pyproject.toml b/pyproject.toml
index 762e45d..1183b68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,10 +34,12 @@ test = [
     'pytest',
     'pytest-cov',
     'pytest-optional-tests',
+    "transformers"
 ]
 exp = [
     "codetiming",
-    "memory_profiler"
+    "memory_profiler",
+    "transformers>=4.41"
 ]
 
 [tool.pytest.ini_options]
diff --git a/test/test_layers.py b/test/test_layers.py
index 549e7cc..74fd46f 100644
--- a/test/test_layers.py
+++ b/test/test_layers.py
@@ -4,6 +4,7 @@
 
 import pytest
 import torch
+import transformers
 
 import memsave_torch
 
@@ -14,44 +15,66 @@
 torch.manual_seed(239)
 
 cases = [
-    {"layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 3)},
     {
+        "name": "Linear1dims",
+        "layer_fn": lambda: torch.nn.Linear(3, 5),
+        "data_fn": lambda: torch.rand(7, 3),
+    },
+    {
+        "name": "Linear2dims",
         "layer_fn": lambda: torch.nn.Linear(3, 5),
         "data_fn": lambda: torch.rand(7, 12, 3),  # weight sharing
     },
     {
+        "name": "Linear3dims",
         "layer_fn": lambda: torch.nn.Linear(3, 5),
         "data_fn": lambda: torch.rand(7, 12, 12, 3),  # weight sharing
     },
     {
+        "name": "Conv2d",
         "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
     {
+        "name": "Conv1d",
         "layer_fn": lambda: torch.nn.Conv1d(3, 5, 3),
         "data_fn": lambda: torch.rand(7, 3, 12),
     },
     {
+        "name": "BatchNorm2d",
         "layer_fn": lambda: torch.nn.BatchNorm2d(3),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
     {
+        "name": "LayerNorm",
         "layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
+    # TODO: add testing for dropout (save and load rng state)
+    # {
+    #     "name": "Dropout"
+    #     "layer_fn": lambda: torch.nn.Dropout(),
+    #     "data_fn": lambda: torch.rand(7, 3, 12, 12),
+    # },
     {
+        "name": "MaxPool2d",
         "layer_fn": lambda: torch.nn.MaxPool2d(3),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
-    {"layer_fn": lambda: torch.nn.ReLU(), "data_fn": lambda: torch.rand(7, 3, 12, 12)},
+    {
+        "name": "ReLU",
"layer_fn": lambda: torch.nn.ReLU(), + "data_fn": lambda: torch.rand(7, 3, 12, 12), + }, ] @pytest.mark.quick -@pytest.mark.parametrize("case", cases) +@pytest.mark.parametrize("case", cases, ids=[case["name"] for case in cases]) @pytest.mark.parametrize("device", devices) def test_single_layer( - case: Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]], + case: Dict[str, Union[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]]], device: str, ): """Runs tests for the layer_cls defined by `layer`