From 5dd246e1427f12959dd5520a2c4ec5edccbaddff Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Mon, 3 Jun 2024 03:08:22 +0530 Subject: [PATCH 1/3] add surgical fine tuning cases, batchnorm and layernorm improvements (cherry picked from commit 262a4bbfb569476eaf3f8ba395c15881e9f16b01) --- experiments/util/collect_results.py | 36 ++- experiments/util/estimate.py | 4 + experiments/util/measurements.py | 298 +++++++++++++++-------- memsave_torch/nn/functional/BatchNorm.py | 4 +- memsave_torch/nn/functional/LayerNorm.py | 4 +- 5 files changed, 235 insertions(+), 111 deletions(-) diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py index 393fe0a..c90301e 100644 --- a/experiments/util/collect_results.py +++ b/experiments/util/collect_results.py @@ -27,11 +27,37 @@ ], } -case_mapping = { - "None": "All", - "grad_input + no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Input", - "no_grad_linear_weights + no_grad_linear_bias + no_grad_norm_weights + no_grad_norm_bias": "Conv", - "no_grad_conv_weights + no_grad_conv_bias + no_grad_linear_weights + no_grad_linear_bias": "Norm", +cases = { + "All": None, # ALL + "Input": [ # INPUT + "grad_input", + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_linear_weights", + "no_grad_linear_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Conv": [ # CONV + "no_grad_linear_weights", + "no_grad_linear_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Linear": [ # LINEAR + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_norm_weights", + "no_grad_norm_bias", + ], + "Norm": [ # NORM + "no_grad_conv_weights", + "no_grad_conv_bias", + "no_grad_linear_weights", + "no_grad_linear_bias", + ], + "SurgicalFirst": ["surgical_first"], + "SurgicalLast": ["surgical_last"], } diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py index 8486931..1036d1e 100644 --- a/experiments/util/estimate.py +++ b/experiments/util/estimate.py @@ -38,6 +38,10 @@ "no_grad_norm_bias", "grad_input", "no_grad_input", + "grad_embed_weights", + "no_grad_embed_weights", + "surgical_first", + "surgical_last", ] diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py index 63cfa16..5f4b0f9 100644 --- a/experiments/util/measurements.py +++ b/experiments/util/measurements.py @@ -20,15 +20,27 @@ ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, + Embedding, LayerNorm, Linear, Module, Parameter, ) from torchvision.models.convnext import LayerNorm2d - -from memsave_torch.nn.Conv2d import MemSaveConv2d -from memsave_torch.nn.Linear import MemSaveLinear +from transformers import Conv1D +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from transformers.models.mistral.modeling_mistral import MistralRMSNorm +from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm +from transformers.models.t5.modeling_t5 import T5LayerNorm + +from memsave_torch.nn import ( + MemSaveBatchNorm2d, + MemSaveConv2d, + MemSaveLayerNorm, + MemSaveLinear, + MemSaveRMSLayerNorm, + RMSLayerNorm, +) def maybe_synchronize(dev: device): @@ -41,8 +53,8 @@ def maybe_synchronize(dev: device): cuda.synchronize() -class _Measurement: - """Base class for measurements.""" +class Measurement: + """Base class for measurements. 
This is meant to be subclassed and extended.""" def __init__( self, @@ -71,16 +83,52 @@ def __init__( self.targets = targets def set_up( - self, synchronize: bool = True - ) -> Tuple[Module, Module, Tensor, Tensor, Optional[List[Dict[str, Tensor]]]]: + self, + synchronize: bool = True, + grad_linear_weights: bool = True, + grad_linear_bias: bool = True, + grad_conv_weights: bool = True, + grad_conv_bias: bool = True, + grad_norm_weights: bool = True, + grad_norm_bias: bool = True, + grad_input: bool = False, + grad_embed_weights: bool = False, + surgical_first: bool = False, + surgical_last: bool = False, + ) -> Tuple[ + Module, + Module, + Tensor, + Tensor, + Optional[List[Dict[str, Tensor]]], + List[Tensor], + List[Tensor], + ]: """Initialize model and loss function, load to device (including data). Syncs CUDA threads if the device is a GPU to avoid leaking run time of this function into the measurement. Args: - synchronize: Whether to synchronize CUDA threads after loading the + synchronize (bool, optional): Whether to synchronize CUDA threads after loading the model, loss function, and data to the device. Default: `True`. + grad_linear_weights (bool, optional): Whether to compute the gradient of the linear + layer weights. Default: `True`. + grad_linear_bias (bool, optional): Whether to compute the gradient of the linear + layer bias. Default: `True`. + grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution + layer weights. Default: `True`. + grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution + layer bias. Default: `True`. + grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization + layer weights. Default: `True`. + grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization + layer bias. Default: `True`. + grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`. + grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding + layer weights. Default: `True`. + surgical_first (bool, optional): Corresponds to computing gradient only for the first quarter of layers + surgical_last (bool, optional): Corresponds to computing gradient only for the last quarter of layers Returns: The model, loss function, input tensor, and output tensor. 
All are loaded @@ -98,67 +146,52 @@ def set_up( else: targets = None + if surgical_first or surgical_last: + leafs, no_leafs = separate_surgical(model, surgical_first, surgical_last) + else: + leafs, no_leafs = separate_grad_arguments( + model, + grad_linear_weights, + grad_linear_bias, + grad_conv_weights, + grad_conv_bias, + grad_norm_weights, + grad_norm_bias, + grad_embed_weights, + ) + leafs = ([x] if grad_input else []) + leafs + no_leafs = ([y] if grad_input else [x, y]) + no_leafs + # targets will never require grad + + # make leaves differentiable, turn off non-leafs + for leaf in leafs: + leaf.requires_grad_(True) + for no_leaf in no_leafs: + no_leaf.requires_grad_(False) + if synchronize: maybe_synchronize(self.dev) - return model, loss_fn, x, y, targets + return model, loss_fn, x, y, targets, leafs, no_leafs -class RuntimeMeasurement(_Measurement): +class RuntimeMeasurement(Measurement): """A class to perform run time measurements of forward+backward pass.""" - def forward_backward( - self, - grad_linear_weights: bool = True, - grad_linear_bias: bool = True, - grad_conv_weights: bool = True, - grad_conv_bias: bool = True, - grad_norm_weights: bool = True, - grad_norm_bias: bool = True, - grad_input: bool = False, - ) -> float: + def forward_backward(self, **case_kwargs) -> float: """Perform a forward and backward pass and return the run time. Syncs CUDA threads if the device is a GPU. + Note: We directly pass input embeddings to transformers so embed weights are never used and their + grad will be None. Args: - grad_linear_weights: Whether to compute the gradient of the linear - layer weights. Default: `True`. - grad_linear_bias: Whether to compute the gradient of the linear - layer bias. Default: `True`. - grad_conv_weights: Whether to compute the gradient of the convolution - layer weights. Default: `True`. - grad_conv_bias: Whether to compute the gradient of the convolution - layer bias. Default: `True`. - grad_norm_weights: Whether to compute the gradient of the normalization - layer weights. Default: `True`. - grad_norm_bias: Whether to compute the gradient of the normalization - layer bias. Default: `True`. - grad_input: Whether to compute the gradient of the input. Default: `False`. + **case_kwargs: Strings denoting which grads to compute and which to not, check docs of `Measurement.set_up()` Returns: - The run time in seconds. + float: The run time in seconds. 
""" - model, loss_fn, x, y, targets = self.set_up() - - leafs, no_leafs = separate_grad_arguments( - model, - grad_linear_weights, - grad_linear_bias, - grad_conv_weights, - grad_conv_bias, - grad_norm_weights, - grad_norm_bias, - ) - leafs = ([x] if grad_input else []) + leafs - no_leafs = ([y] if grad_input else [x, y]) + no_leafs - # targets will never require grad - - # make leaves differentiable, turn off non-leafs - for leaf in leafs: - leaf.requires_grad_(True) - for no_leaf in no_leafs: - no_leaf.requires_grad_(False) + model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs) # obtain run time maybe_synchronize(self.dev) @@ -179,58 +212,22 @@ def forward_backward( return timer.last -class MemoryMeasurement(_Measurement): +class MemoryMeasurement(Measurement): """A class to measure memory usage after a forward pass.""" - def after_forward( - self, - grad_linear_weights: bool = True, - grad_linear_bias: bool = True, - grad_conv_weights: bool = True, - grad_conv_bias: bool = True, - grad_norm_weights: bool = True, - grad_norm_bias: bool = True, - grad_input: bool = False, - ) -> float: + def after_forward(self, **case_kwargs) -> float: """Return memory usage after a forward pass. + Note: We directly pass input embeddings to transformers so embed weights are never used and their + grad will be None. + Args: - grad_linear_weights: Whether to compute the gradient of the linear - layer weights. Default: `True`. - grad_linear_bias: Whether to compute the gradient of the linear - layer bias. Default: `True`. - grad_conv_weights: Whether to compute the gradient of the convolution - layer weights. Default: `True`. - grad_conv_bias: Whether to compute the gradient of the convolution - layer bias. Default: `True`. - grad_norm_weights: Whether to compute the gradient of the normalization - layer weights. Default: `True`. - grad_norm_bias: Whether to compute the gradient of the normalization - layer bias. Default: `True`. - grad_input: Whether to compute the gradient of the input. Default: `False`. + **case_kwargs: Strings denoting which grads to compute and which to not, check docs of `Measurement.set_up()` Returns: - The memory usage in bytes. + float: The memory usage in bytes. """ - model, loss_fn, x, y, targets = self.set_up() - - leafs, no_leafs = separate_grad_arguments( - model, - grad_linear_weights, - grad_linear_bias, - grad_conv_weights, - grad_conv_bias, - grad_norm_weights, - grad_norm_bias, - ) - leafs = ([x] if grad_input else []) + leafs - no_leafs = ([y] if grad_input else [x, y]) + no_leafs - - # make leaves differentiable, turn off non-leafs - for leaf in leafs: - leaf.requires_grad_(True) - for no_leaf in no_leafs: - no_leaf.requires_grad_(False) + model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs) if str(self.dev) == "cpu": @@ -289,6 +286,7 @@ def separate_grad_arguments( grad_conv_bias: bool, grad_norm_weights: bool, grad_norm_bias: bool, + grad_embed_weights: bool, ) -> Tuple[List[Parameter], List[Parameter]]: """Separate the parameters of a model into leafs and non-leafs. @@ -303,6 +301,7 @@ def separate_grad_arguments( grad_norm_weights: Whether to compute the gradient of the normalization layer weights. grad_norm_bias: Whether to compute the gradient of the normalization layer bias. + grad_embed_weights: Whether to compute the gradient of the embedding layer weights Returns: A tuple of lists of parameters. 
The first list contains the leafs, the second @@ -311,7 +310,7 @@ def separate_grad_arguments( Raises: NotImplementedError: If an unknown layer with parameters is encountered. """ - linear = (Linear, MemSaveLinear) + linear = (Linear, MemSaveLinear, Conv1D) conv = ( Conv1d, Conv2d, @@ -321,7 +320,22 @@ def separate_grad_arguments( ConvTranspose3d, MemSaveConv2d, ) - norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d) + norm = ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + LayerNorm, + LayerNorm2d, + MemSaveBatchNorm2d, + MemSaveLayerNorm, + RMSLayerNorm, + MemSaveRMSLayerNorm, + T5LayerNorm, + MistralRMSNorm, + LlamaRMSNorm, + Phi3RMSNorm, + ) + embed = Embedding leafs, no_leafs = [], [] @@ -334,10 +348,29 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool): grad_bias: Whether to compute the gradient of the layer bias. """ leafs.append(layer.weight) if grad_weight else no_leafs.append(layer.weight) - if layer.bias is not None: + if "bias" in layer._parameters and layer.bias is not None: leafs.append(layer.bias) if grad_bias else no_leafs.append(layer.bias) - layers = [m for m in model.modules() if len(list(m.modules())) == 1] + def check_lm_head(n) -> bool: + """Checks if the module name n corresponds to an LM Head + + LM Head for transformers is not trainable (i.e. it is a module but it's weight is not a parameter) + and the weights are tied to the embedding layer + + Args: + n (str): name of the module + + Returns: + bool: Whether n is a LM head + """ + lm_head_name = getattr(model, "lm_head_name", None) + return lm_head_name is not None and lm_head_name in n + + layers = [ + m + for n, m in model.named_modules() + if len(list(m.modules())) == 1 and not check_lm_head(n) + ] for layer in layers: if isinstance(layer, linear): @@ -346,7 +379,68 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool): separate_layer(layer, grad_conv_weights, grad_conv_bias) elif isinstance(layer, norm): separate_layer(layer, grad_norm_weights, grad_norm_bias) + elif isinstance(layer, embed): + separate_layer(layer, grad_embed_weights, False) elif list(layer.parameters()): raise NotImplementedError(f"Unknown layer with parameters: {layer}.") return leafs, no_leafs + + +def separate_surgical( + model: Module, surgical_first: bool, surgical_last: bool +) -> Tuple[List[Parameter], List[Parameter]]: + """Separate the parameters of a model into leafs and no-leafs for surgical fine-tuning + + One and only one of `surgical_first` and `surgical_last` must be True + Args: + model (Module): The model to separate the parameters of. + surgical_first (bool): Whether to compute the gradients of the first quarter of layers with parameters. + surgical_last (bool): Whether to compute the gradients of the last quarter of layers with parameters. + """ + assert ( + surgical_first ^ surgical_last + ), "One and only one of surgical_first and surgical_last must be True" + leafs, no_leafs = [], [] + counted_modules, total_modules = 0, 0 + + layers = [ + m + for n, m in model.named_modules() + if len(list(m.modules())) == 1 and list(m.parameters()) + ] + + total_modules = len(layers) + + def separate_layer(layer: Module, leaf: bool): + """Add parameters of layer to leafs or non-leafs. + + Args: + layer: The layer whose parameters to add to (non-)leafs. + leaf: Whether the layer is a leaf or not. 
+ """ + leafs.append(layer.weight) if leaf else no_leafs.append(layer.weight) + if "bias" in layer._parameters and layer.bias is not None: + leafs.append(layer.bias) if leaf else no_leafs.append(layer.bias) + + def check_condition(counted, total): + if surgical_first: + return counted <= total / 4 + return counted >= 3 * total / 4 + + if surgical_last: + layers = layers[::-1] + # import ipdb; ipdb.set_trace() + for c in layers: + if not list(c.parameters()): + continue + if counted_modules <= total_modules / 4: + # Leaf + separate_layer(c, True) + counted_modules += 1 + else: + # No Leaf + separate_layer(c, False) + # counted_modules += surgical_last + + return leafs, no_leafs diff --git a/memsave_torch/nn/functional/BatchNorm.py b/memsave_torch/nn/functional/BatchNorm.py index f5be409..33efcf0 100644 --- a/memsave_torch/nn/functional/BatchNorm.py +++ b/memsave_torch/nn/functional/BatchNorm.py @@ -38,7 +38,7 @@ def forward( need_grad = [] # save_mean and save_invstd if ctx.needs_input_grad[0]: need_grad.append(weight) - if ctx.needs_input_grad[3]: + if any(ctx.needs_input_grad): need_grad.append(x) # bias doesnt need anything for calc @@ -54,9 +54,9 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[0]: weight = ctx.saved_tensors[current_idx] current_idx += 1 + x = ctx.saved_tensors[current_idx] if ctx.needs_input_grad[3]: x = ctx.saved_tensors[current_idx] - current_idx += 1 if x is None: x = torch.zeros(ctx.x_shape, device=ctx.device) diff --git a/memsave_torch/nn/functional/LayerNorm.py b/memsave_torch/nn/functional/LayerNorm.py index 1f92081..6cf72f8 100644 --- a/memsave_torch/nn/functional/LayerNorm.py +++ b/memsave_torch/nn/functional/LayerNorm.py @@ -26,7 +26,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps): need_grad = [] # save_mean and save_invstd if ctx.needs_input_grad[0]: need_grad.append(weight) - if ctx.needs_input_grad[2]: + if any(ctx.needs_input_grad): need_grad.append(x) # bias doesnt need anything for calc @@ -41,9 +41,9 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[0]: weight = ctx.saved_tensors[current_idx] current_idx += 1 + x = ctx.saved_tensors[current_idx] if ctx.needs_input_grad[3]: x = ctx.saved_tensors[current_idx] - current_idx += 1 if x is None: x = torch.zeros(ctx.x_shape, device=ctx.device) From 669c2503d800563e9b0df0aa459aa0ad744669df Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Mon, 3 Jun 2024 03:26:22 +0530 Subject: [PATCH 2/3] surgical case add (cherry picked from commit fad9f25b7d344243e12a51b6dcdcea6c9e8cd3a6) --- experiments/util/collect_results.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py index c90301e..ca64a8f 100644 --- a/experiments/util/collect_results.py +++ b/experiments/util/collect_results.py @@ -57,7 +57,7 @@ "no_grad_linear_bias", ], "SurgicalFirst": ["surgical_first"], - "SurgicalLast": ["surgical_last"], + "SurgicalLast": ["surgical_last", "grad_input"], } From 45146a390d56e7304808c3bc4a77d14891e9e250 Mon Sep 17 00:00:00 2001 From: Plutonium-239 Date: Mon, 3 Jun 2024 06:44:27 +0530 Subject: [PATCH 3/3] case invalidation minor fix (only affects models_without_norm) (cherry picked from commit 0b77085ff31563a485b2ce8da80d99d26892e9e9) --- experiments/util/estimate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py index 1036d1e..528f153 100644 --- a/experiments/util/estimate.py +++ b/experiments/util/estimate.py @@ -77,11 
+77,13 @@ def skip_case_check(args: argparse.Namespace) -> bool:
     invalid = False
     if args.case is None:
         return invalid
+    is_surgical = "surgical_last" in args.case or "surgical_first" in args.case
+    # 1. norm-related cases are invalid for models without norm layers
     for c in ["grad_norm_bias", "grad_norm_weights"]:
         if c in args.case and args.model in models.models_without_norm:
             invalid = True
     for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
-        if c not in args.case and args.model in models.models_without_norm:
+        if c not in args.case and args.model in models.models_without_norm and not is_surgical:
             invalid = True
     if invalid:
         if args.print:
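
Note: the case lists defined in `cases` (e.g. `cases["SurgicalLast"] == ["surgical_last", "grad_input"]`) are ultimately consumed as boolean keyword arguments by `Measurement.set_up()`. The helper below is a minimal illustrative sketch of that mapping, assuming only the flag names listed in `estimate.py` ("no_grad_*", "grad_input", "grad_embed_weights", "surgical_first", "surgical_last"); the actual wiring in the experiment driver is not part of this diff, so treat `case_to_kwargs` as a hypothetical helper rather than the project's API.

# Illustrative sketch (not part of the patch): turning a case list from
# `cases` into keyword arguments for Measurement.set_up().
from typing import Dict, List, Optional


def case_to_kwargs(case: Optional[List[str]]) -> Dict[str, bool]:
    """Map flags such as 'no_grad_conv_weights' or 'surgical_first' to set_up() kwargs."""
    kwargs: Dict[str, bool] = {}
    if case is None:  # the "All" case: keep set_up() defaults (train everything)
        return kwargs
    for flag in case:
        if flag.startswith("no_grad_"):
            # e.g. "no_grad_conv_weights" -> grad_conv_weights=False
            kwargs[flag.replace("no_grad_", "grad_", 1)] = False
        else:
            # e.g. "grad_input", "grad_embed_weights", "surgical_first", "surgical_last"
            kwargs[flag] = True
    return kwargs


# Example: case_to_kwargs(cases["SurgicalLast"])
# -> {"surgical_last": True, "grad_input": True}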