diff --git a/.gitignore b/.gitignore index 6ad7868..7f71819 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ *.txt *.csv !requirements.txt +torchviz-output/ +torchview-output/ # generated docs docs_src/_build/ diff --git a/experiments/get_best_results.py b/experiments/get_best_results.py index 1321679..2146a4f 100644 --- a/experiments/get_best_results.py +++ b/experiments/get_best_results.py @@ -7,7 +7,7 @@ import pandas as pd -from experiments.util.collect_results import case_mapping +from experiments.util.collect_results import case_inv_mapping def main(base_dir: str): @@ -16,17 +16,25 @@ def main(base_dir: str): Args: base_dir (str): The base results dir """ - for device, arch in product(["cuda", "cpu"], ["linear", "conv"]): + # Don't recognize None as NaN + custom_na_values = pd._libs.parsers.STR_NA_VALUES - {"None"} + for device, arch in product(["cuda", "cpu"], ["linear", "conv", "transformer"]): # usage stats df = None idx_col = ["model", "case"] for fname in glob(os.path.join(base_dir, f"usage_stats-{arch}-{device}-*.csv")): with open(fname) as f: - f.readline() - temp_df = pd.read_csv(f, index_col=idx_col) + # f.readline() + temp_df = pd.read_csv( + f, + index_col=idx_col, + header=1, + na_values=custom_na_values, + keep_default_na=False, + ) df = temp_df if df is None else pd.concat([df, temp_df]) if df is not None: - df = df.rename(index=case_mapping, level=1) + df = df.rename(index=case_inv_mapping, level=1) df["Memory Usage (GB)"] = df["Memory Usage (MB)"] / 1024 df = df.drop(columns=["Memory Usage (MB)"]) best_results = df.groupby(idx_col).min() diff --git a/experiments/paper_demo.py b/experiments/paper_demo.py index 0312850..9c5d53e 100644 --- a/experiments/paper_demo.py +++ b/experiments/paper_demo.py @@ -52,7 +52,7 @@ # num_classes = 1000 # device = "cuda" # architecture = "conv" -# cases = collect_results.select_cases(['All', 'Input', 'Conv', 'Norm', 'SurgicalFirst', 'SurgicalLast']) +# cases = collect_results.select_cases(['All', 'Input', 'Conv', 'Norm']) # ============== TRANSFORMER CONFIG ============== # Valid choices for models are in models.transformer_model_fns diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py index 393fe0a..d6a918b 100644 --- a/experiments/util/collect_results.py +++ b/experiments/util/collect_results.py @@ -27,14 +27,50 @@ ], } -case_mapping = { - "None": "All", - "grad_input + no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Input", - "no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv", - "no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm", +cases = { + "All": None, # ALL + "Input": [ # INPUT + "grad_input", + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_linear_weights", + "no_grad_linear_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Conv": [ # CONV + "no_grad_linear_weights", + "no_grad_linear_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Linear": [ # LINEAR + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Norm": [ # NORM + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_linear_weights", + "no_grad_linear_bias", + ], } +def select_cases(selected: List[str]) -> List[Union[List[str], None]]: + """Helper function to return cases selected by their names + + Args: + selected (List[str]): Which cases to select, strings 
can be keys of the cases table
+
+    Returns:
+        List[Union[List[str], None]]: Selected cases
+    """
+    return [cases[s] for s in selected]
+
+
 def make_case_str(case: Union[None, List[str]]) -> str:
     """Format case into a string

@@ -47,6 +83,9 @@ def make_case_str(case: Union[None, List[str]]) -> str:
     return "None" if case is None else " + ".join(case)


+case_inv_mapping = {make_case_str(v): k for k, v in cases.items()}
+
+
 def hyperparam_str(args: SimpleNamespace) -> str:
     """Format hyperparams into a string

@@ -172,12 +211,15 @@ def _display_run(
     """
     # print(f"{model} input ({input_channels},{input_HW},{input_HW}) {device}")
     # print('='*78)
-    s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}"
+    if self.architecture == "transformer":
+        s = f"{model} input ({self.batch_size},{self.input_HW},{self.input_channels} (or model hidden size)) {self.device}"
+    else:
+        s = f"{model} input ({self.batch_size},{self.input_channels},{self.input_HW},{self.input_HW}) {self.device}"
     print(s.center(78, "="))

     for out, case in zip(outputs, self.cases):
         print(
-            f"{strings[estimate][1]} ({case_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}"
+            f"{strings[estimate][1]} ({case_inv_mapping[make_case_str(case)]}): {out:.3f}{strings[estimate][0]}"
         )

 # CODE ONLY APPLIES WITH OLD RUNDEMO.PY
diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py
index 4d997c5..d3b8375 100644
--- a/experiments/util/estimate.py
+++ b/experiments/util/estimate.py
@@ -38,6 +38,8 @@
     "no_grad_norm_bias",
     "grad_input",
     "no_grad_input",
+    "grad_embed_weights",
+    "no_grad_embed_weights",
 ]


@@ -62,7 +64,10 @@ def parse_case(case: Optional[List[str]]) -> Dict[str, bool]:


 def skip_case_check(args: argparse.Namespace) -> bool:
-    """Decide whether to skip the case (when case has grad_norm_* but model does not have any normalization layers)
+    """Decide whether to skip the case:
+
+    1. when the case has grad_norm_* but the model has no normalization layers
+    2. when the case has no_grad_embed_weights but not grad_input: backward fails because no input requires grad

     Args:
         args (argparse.Namespace): args
@@ -73,12 +78,16 @@ def skip_case_check(args: argparse.Namespace) -> bool:
     invalid = False
     if args.case is None:
         return invalid
+    # 1.
     for c in ["grad_norm_bias", "grad_norm_weights"]:
         if c in args.case and args.model in models.models_without_norm:
             invalid = True
     for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
         if c not in args.case and args.model in models.models_without_norm:
             invalid = True
+    # 2.
+    if "no_grad_embed_weights" in args.case and "grad_input" not in args.case:
+        invalid = True
     if invalid:
         if args.print:
             print("-1")
@@ -226,7 +235,7 @@ def estimate_mem_savings(
         type=str,
         required=True,
         help="Which architecture to run",
-        choices=["conv", "linear"],
+        choices=["conv", "linear", "transformer", "VLM"],
     )
     parser.add_argument(
         "--estimate",
@@ -275,6 +284,7 @@ def estimate_mem_savings(
         input_shape = (args.input_channels, args.input_hw, args.input_hw)
         models.conv_input_shape = input_shape
         model_fn = models.conv_model_fns.get(args.model)
+        y_args = {"size": (batch_size,), "low": 0, "high": num_classes}
         assert (
             model_fn is not None
         ), f"Conv model name {args.model} not found, must be one of {list(models.conv_model_fns.keys())}"
@@ -282,16 +292,51 @@
         input_shape = [args.input_hw**2]
         models.linear_input_shape = input_shape[0]
         model_fn = models.linear_model_fns.get(args.model)
+        y_args = {"size": (batch_size,), "low": 0, "high": num_classes}
         assert (
             model_fn is not None
         ), f"Linear model name {args.model} not found, must be one of {list(models.linear_model_fns.keys())}"
+    elif args.architecture == "transformer":
+        vocab_dim = args.num_classes
+        embed_dim = args.input_channels
+        seq_len = args.input_hw
+        model_fn = models.transformer_model_fns.get(args.model)
+        if args.model in models.hf_transformers_models:
+            model_fn_orig = model_fn
+            model_fn = lambda: models.TransformersModelWrapper(  # noqa: E731
+                model_fn_orig, args.model
+            )
+            config = models.get_transformers_config(args.model)
+            # as per transformers.PretrainedConfig, these two attributes should be present in all models:
+            vocab_dim = config.vocab_size
+            embed_dim = config.hidden_size
+        models.transformer_input_shape = (vocab_dim, embed_dim)
+        input_shape = [seq_len, embed_dim]
+        y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim}
+        assert (
+            model_fn is not None
+        ), f"Transformer model name {args.model} not found, must be one of {list(models.transformer_model_fns.keys())}"
+    elif args.architecture == "VLM":
+        # model format: `vlm!<vis_model>!<vis_model_arch>!<llm>`
+        # e.g. `vlm!vit!transformer!memsave_gpt2`
+        is_vlm, vis_model, vis_model_arch, llm = args.model.split("!")
+        assert is_vlm == "vlm"
+        assert vis_model_arch in ["transformer", "conv"]
+        model_fn = lambda: models.VLM(vis_model, vis_model_arch, llm)  # noqa: E731
+        config = models.get_transformers_config(llm)
+        vocab_dim = config.vocab_size
+        embed_dim = config.hidden_size
+        seq_len = (args.input_hw // 16) ** 2
+        y_args = {"size": (batch_size, seq_len), "low": 0, "high": vocab_dim}
+        input_shape = (args.input_channels, args.input_hw, args.input_hw)
+        models.conv_input_shape = input_shape

     loss_fn = CrossEntropyLoss

     manual_seed(0)  # make deterministic
     x = rand(batch_size, *input_shape, device=dev)
-    y = randint(size=(batch_size,), low=0, high=num_classes, device=dev)
+    y = randint(**y_args, device=dev)
     targets = None
     if args.model in models.detection_models:
         # pred is a dictionary of losses
diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py
index 63cfa16..321843d 100644
--- a/experiments/util/measurements.py
+++ b/experiments/util/measurements.py
@@ -20,12 +20,14 @@
     ConvTranspose1d,
     ConvTranspose2d,
     ConvTranspose3d,
+    Embedding,
     LayerNorm,
     Linear,
     Module,
     Parameter,
 )
 from torchvision.models.convnext import LayerNorm2d
+from transformers import Conv1D

 from memsave_torch.nn.Conv2d import MemSaveConv2d
 from memsave_torch.nn.Linear import MemSaveLinear
@@ -116,28 +118,33 @@ def forward_backward(
         grad_norm_weights: bool = True,
         grad_norm_bias: bool = True,
         grad_input: bool = False,
+        grad_embed_weights: bool = False,
     ) -> float:
         """Perform a forward and backward pass and return the run time.

         Syncs CUDA threads if the device is a GPU.
+        Note: We pass input embeddings directly to transformer models, so the embedding weights are
+        never used and their grad will be None.

         Args:
-            grad_linear_weights: Whether to compute the gradient of the linear
+            grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
                 layer weights. Default: `True`.
-            grad_linear_bias: Whether to compute the gradient of the linear
+            grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
                 layer bias. Default: `True`.
-            grad_conv_weights: Whether to compute the gradient of the convolution
+            grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
                 layer weights. Default: `True`.
-            grad_conv_bias: Whether to compute the gradient of the convolution
+            grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
                 layer bias. Default: `True`.
-            grad_norm_weights: Whether to compute the gradient of the normalization
+            grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
                 layer weights. Default: `True`.
-            grad_norm_bias: Whether to compute the gradient of the normalization
+            grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
                 layer bias. Default: `True`.
-            grad_input: Whether to compute the gradient of the input. Default: `False`.
+            grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
+            grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
+                layer weights. Default: `False`.

         Returns:
-            The run time in seconds.
+            float: The run time in seconds.
         """
         model, loss_fn, x, y, targets = self.set_up()
@@ -149,6 +156,7 @@ def forward_backward(
             grad_conv_bias,
             grad_norm_weights,
             grad_norm_bias,
+            grad_embed_weights,
         )
         leafs = ([x] if grad_input else []) + leafs
         no_leafs = ([y] if grad_input else [x, y]) + no_leafs
@@ -191,9 +199,13 @@ def after_forward(
         grad_norm_weights: bool = True,
         grad_norm_bias: bool = True,
         grad_input: bool = False,
+        grad_embed_weights: bool = False,
     ) -> float:
         """Return memory usage after a forward pass.

+        Note: We pass input embeddings directly to transformer models, so the embedding weights are
+        never used and their grad will be None.
+
         Args:
             grad_linear_weights: Whether to compute the gradient of the linear
                 layer weights. Default: `True`.
@@ -208,9 +220,11 @@ def after_forward(
             grad_norm_bias: Whether to compute the gradient of the normalization
                 layer bias. Default: `True`.
             grad_input: Whether to compute the gradient of the input. Default: `False`.
+            grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
+                layer weights. Default: `False`.

         Returns:
-            The memory usage in bytes.
+            float: The memory usage in bytes.
         """
         model, loss_fn, x, y, targets = self.set_up()
@@ -222,6 +236,7 @@ def after_forward(
             grad_conv_bias,
             grad_norm_weights,
             grad_norm_bias,
+            grad_embed_weights,
         )
         leafs = ([x] if grad_input else []) + leafs
         no_leafs = ([y] if grad_input else [x, y]) + no_leafs
@@ -289,6 +304,7 @@ def separate_grad_arguments(
     grad_conv_bias: bool,
     grad_norm_weights: bool,
     grad_norm_bias: bool,
+    grad_embed_weights: bool,
 ) -> Tuple[List[Parameter], List[Parameter]]:
     """Separate the parameters of a model into leafs and non-leafs.
@@ -303,6 +319,7 @@ def separate_grad_arguments(
         grad_norm_weights: Whether to compute the gradient of the normalization layer weights.
         grad_norm_bias: Whether to compute the gradient of the normalization layer bias.
+        grad_embed_weights: Whether to compute the gradient of the embedding layer weights.

     Returns:
         A tuple of lists of parameters. The first list contains the leafs, the second
@@ -311,7 +328,7 @@ def separate_grad_arguments(
     Raises:
         NotImplementedError: If an unknown layer with parameters is encountered.
     """
-    linear = (Linear, MemSaveLinear)
+    linear = (Linear, MemSaveLinear, Conv1D)
     conv = (
         Conv1d,
         Conv2d,
@@ -322,6 +339,7 @@ def separate_grad_arguments(
         MemSaveConv2d,
     )
     norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d)
+    embed = Embedding

     leafs, no_leafs = [], []

@@ -334,10 +352,29 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool):
             grad_bias: Whether to compute the gradient of the layer bias.
         """
         leafs.append(layer.weight) if grad_weight else no_leafs.append(layer.weight)
-        if layer.bias is not None:
+        if "bias" in layer._parameters and layer.bias is not None:
             leafs.append(layer.bias) if grad_bias else no_leafs.append(layer.bias)

-    layers = [m for m in model.modules() if len(list(m.modules())) == 1]
+    def check_lm_head(n) -> bool:
+        """Checks whether the module name n corresponds to an LM head.
+
+        The LM head of a transformer is not trainable (it is a module, but its weight is not a parameter)
+        and its weights are tied to the embedding layer.
+
+        Args:
+            n (str): name of the module
+
+        Returns:
+            bool: Whether n is an LM head
+        """
+        lm_head_name = getattr(model, "lm_head_name", None)
+        return lm_head_name is not None and lm_head_name in n
+
+    layers = [
+        m
+        for n, m in model.named_modules()
+        if len(list(m.modules())) == 1 and not check_lm_head(n)
+    ]

     for layer in layers:
         if isinstance(layer, linear):
@@ -346,6 +383,8 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool):
             separate_layer(layer, grad_conv_weights, grad_conv_bias)
         elif isinstance(layer, norm):
             separate_layer(layer, grad_norm_weights, grad_norm_bias)
+        elif isinstance(layer, embed):
+            separate_layer(layer, grad_embed_weights, False)
         elif list(layer.parameters()):
             raise NotImplementedError(f"Unknown layer with parameters: {layer}.")

diff --git a/memsave_torch/nn/Dropout.py b/memsave_torch/nn/Dropout.py
new file mode 100644
index 0000000..5d682e5
--- /dev/null
+++ b/memsave_torch/nn/Dropout.py
@@ -0,0 +1,46 @@
+"""Implementation of a memory saving Dropout (sort of).
+
+This is done by not saving the whole input/output `float32` tensor and instead just saving the `bool` mask (8bit).
+"""
+
+import torch
+import torch.nn as nn
+
+from memsave_torch.nn.functional import dropoutMemSave
+
+
+class MemSaveDropout(nn.Dropout):
+    """MemSaveDropout."""
+
+    def __init__(self, p=0.5):
+        """Inits a MemSaveDropout layer with the given params.
+
+        Args:
+            p: Probability of elements being zeroed
+        """
+        super().__init__(p)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x: Input to the network
+
+        Returns:
+            torch.Tensor: Output
+        """
+        return dropoutMemSave(x, self.p, self.training)
+
+    @classmethod
+    def from_nn_dropout(cls, dropout: nn.Dropout):
+        """Converts a nn.Dropout layer to MemSaveDropout.
+ + Args: + dropout : The nn.Dropout layer + + Returns: + obj: The MemSaveDropout object + """ + obj = cls(dropout.p) + return obj + diff --git a/memsave_torch/nn/Linear.py b/memsave_torch/nn/Linear.py index 9deb3c1..1ba78c6 100644 --- a/memsave_torch/nn/Linear.py +++ b/memsave_torch/nn/Linear.py @@ -3,10 +3,18 @@ This is done by not saving the inputs/weights if weight/inputs dont require grad. """ +import sys + import torch.nn as nn from memsave_torch.nn.functional import linearMemSave +transformers_imported = False +if "transformers" in sys.modules: + import transformers + + transformers_imported = True + class MemSaveLinear(nn.Linear): """MemSaveLinear.""" @@ -36,14 +44,21 @@ def forward(self, x): @classmethod def from_nn_Linear(cls, linear: nn.Linear): - """Converts a nn.Linear layer to MemSaveLinear. + """Converts a nn.Linear/transformers.Conv1D layer to MemSaveLinear. Args: - linear : The nn.Linear layer + linear : The nn.Linear/transformers.Conv1D layer Returns: obj: The MemSaveLinear object """ + isTransformersConv1D = False + if transformers_imported: + isTransformersConv1D = isinstance(linear, transformers.Conv1D) + if isTransformersConv1D: + # it only saves output features in the model (linear.nf); need to take input features from weight anyway + # weight and bias are still defined + linear.in_features, linear.out_features = linear.weight.shape obj = cls( linear.in_features, linear.out_features, @@ -51,6 +66,9 @@ def from_nn_Linear(cls, linear: nn.Linear): device=getattr(linear, "device", None), dtype=getattr(linear, "dtype", None), ) - obj.weight = linear.weight + if isTransformersConv1D: + obj.weight = nn.Parameter(linear.weight.T) + else: + obj.weight = linear.weight obj.bias = linear.bias return obj diff --git a/memsave_torch/nn/__init__.py b/memsave_torch/nn/__init__.py index b9bba81..4fbc756 100644 --- a/memsave_torch/nn/__init__.py +++ b/memsave_torch/nn/__init__.py @@ -6,17 +6,26 @@ - BatchNorm2d """ +import sys + import torch.nn as nn from memsave_torch.nn import functional # noqa: F401 from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d from memsave_torch.nn.Conv1d import MemSaveConv1d from memsave_torch.nn.Conv2d import MemSaveConv2d +from memsave_torch.nn.Dropout import MemSaveDropout from memsave_torch.nn.LayerNorm import MemSaveLayerNorm from memsave_torch.nn.Linear import MemSaveLinear from memsave_torch.nn.MaxPool import MemSaveMaxPool2d from memsave_torch.nn.ReLU import MemSaveReLU +transformers_imported = False +if "transformers" in sys.modules: + import transformers + + transformers_imported = True + def convert_to_memory_saving( model: nn.Module, @@ -27,6 +36,7 @@ def convert_to_memory_saving( relu=True, maxpool2d=True, layernorm=True, + dropout=True, verbose=False, clone_params=False, ) -> nn.Module: @@ -45,16 +55,20 @@ def convert_to_memory_saving( relu (bool, optional): Whether to replace `nn.ReLU` layers maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers layernorm (bool, optional): Whether to replace `nn.LayerNorm` layers + dropout (bool, optional): Whether to replace `nn.Dropout` layers verbose (bool, optional): Whether to print which layers were replaced clone_params (bool, optional): Whether to clone the layer parameters or use directly Returns: memsavemodel (nn.Module): The converted memory saving model """ + linear_cls = nn.Linear + if transformers_imported: + linear_cls = (nn.Linear, transformers.Conv1D) layers = [ { "allowed": linear, - "cls": nn.Linear, + "cls": linear_cls, "convert_fn": MemSaveLinear.from_nn_Linear, }, 
         {"allowed": relu, "cls": nn.ReLU, "convert_fn": MemSaveReLU.from_nn_ReLU},
@@ -83,6 +97,11 @@
             "cls": nn.LayerNorm,
             "convert_fn": MemSaveLayerNorm.from_nn_LayerNorm,
         },
+        {
+            "allowed": dropout,
+            "cls": nn.Dropout,
+            "convert_fn": MemSaveDropout.from_nn_dropout,
+        },
     ]

     import copy
diff --git a/memsave_torch/nn/functional/Dropout.py b/memsave_torch/nn/functional/Dropout.py
index d20e725..614a31e 100644
--- a/memsave_torch/nn/functional/Dropout.py
+++ b/memsave_torch/nn/functional/Dropout.py
@@ -10,10 +10,13 @@ class _MemSaveDropout(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, p, train):
-        out, mask = torch.ops.aten.native_dropout(x, p, train)
+        rng = torch.get_rng_state()
+        # don't need the mask here, so skip torch.ops; torch.dropout is faster
+        out = torch.dropout(x, p, train)
         if ctx.needs_input_grad[0]:
             ctx.p = p
-            ctx.mask = mask
+            ctx.train = train
+            ctx.rng = rng
         return out

     @staticmethod
@@ -21,11 +24,17 @@ def backward(ctx, grad_output):
         grad_x = None

         if ctx.needs_input_grad[0]:
-            grad_x = torch.ops.aten.native_dropout_backward(
-                grad_output, ctx.mask, scale=1 / (1 - ctx.p)
-            )
-
-        return grad_x
+            if not ctx.train:
+                # dropout was the identity in the forward pass
+                return grad_output, None, None
+            orig_rng = torch.get_rng_state()
+            torch.set_rng_state(ctx.rng)
+            mask = torch.empty_like(grad_output)
+            mask = mask.bernoulli_(1 - ctx.p).bool()
+            torch.set_rng_state(orig_rng)
+            grad_x = grad_output * mask / (1 - ctx.p)
+
+        return grad_x, None, None


 def dropoutMemSave(x, p, training) -> torch.Tensor:
diff --git a/pyproject.toml b/pyproject.toml
index 762e45d..1183b68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,10 +34,12 @@ test = [
     'pytest',
     'pytest-cov',
     'pytest-optional-tests',
+    "transformers"
 ]
 exp = [
     "codetiming",
-    "memory_profiler"
+    "memory_profiler",
+    "transformers>=4.41"
 ]

 [tool.pytest.ini_options]
diff --git a/test/test_layers.py b/test/test_layers.py
index 549e7cc..74fd46f 100644
--- a/test/test_layers.py
+++ b/test/test_layers.py
@@ -4,6 +4,7 @@

 import pytest
 import torch
+import transformers

 import memsave_torch

@@ -14,44 +15,66 @@
 torch.manual_seed(239)

 cases = [
-    {"layer_fn": lambda: torch.nn.Linear(3, 5), "data_fn": lambda: torch.rand(7, 3)},
     {
+        "name": "Linear1dims",
+        "layer_fn": lambda: torch.nn.Linear(3, 5),
+        "data_fn": lambda: torch.rand(7, 3),
+    },
+    {
+        "name": "Linear2dims",
         "layer_fn": lambda: torch.nn.Linear(3, 5),
         "data_fn": lambda: torch.rand(7, 12, 3),  # weight sharing
     },
     {
+        "name": "Linear3dims",
         "layer_fn": lambda: torch.nn.Linear(3, 5),
         "data_fn": lambda: torch.rand(7, 12, 12, 3),  # weight sharing
     },
     {
+        "name": "Conv2d",
         "layer_fn": lambda: torch.nn.Conv2d(3, 5, 3),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
     {
+        "name": "Conv1d",
         "layer_fn": lambda: torch.nn.Conv1d(3, 5, 3),
         "data_fn": lambda: torch.rand(7, 3, 12),
     },
     {
+        "name": "BatchNorm2d",
         "layer_fn": lambda: torch.nn.BatchNorm2d(3),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
     {
+        "name": "LayerNorm",
         "layer_fn": lambda: torch.nn.LayerNorm([3, 12, 12]),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
+    # TODO: add testing for dropout (save and load rng state)
+    # {
+    #     "name": "Dropout",
+    #     "layer_fn": lambda: torch.nn.Dropout(),
+    #     "data_fn": lambda: torch.rand(7, 3, 12, 12),
+    # },
     {
+        "name": "MaxPool2d",
         "layer_fn": lambda: torch.nn.MaxPool2d(3),
         "data_fn": lambda: torch.rand(7, 3, 12, 12),
     },
-    {"layer_fn": lambda: torch.nn.ReLU(), "data_fn": lambda: torch.rand(7, 3, 12, 12)},
+    {
+        "name": "ReLU",
+        "layer_fn": lambda: torch.nn.ReLU(),
+        "data_fn": lambda: torch.rand(7, 3, 12, 12),
+    },
 ]


 @pytest.mark.quick
-@pytest.mark.parametrize("case", cases)
+@pytest.mark.parametrize("case", cases, ids=[case["name"] for case in cases])
 @pytest.mark.parametrize("device", devices)
 def test_single_layer(
-    case: Dict[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]],
+    case: Dict[str, Union[str, Callable[[], Union[torch.Tensor, torch.nn.Module]]]],
     device: str,
 ):
     """Runs tests for the layer_cls defined by `layer`