diff --git a/experiments/util/collect_results.py b/experiments/util/collect_results.py index d6a918b..6d4681f 100644 --- a/experiments/util/collect_results.py +++ b/experiments/util/collect_results.py @@ -56,6 +56,8 @@ "no_grad_linear_weights", "no_grad_linear_bias", ], + "SurgicalFirst": ["surgical_first"], + "SurgicalLast": ["surgical_last", "grad_input"], } diff --git a/experiments/util/estimate.py b/experiments/util/estimate.py index d3b8375..3c0dabe 100644 --- a/experiments/util/estimate.py +++ b/experiments/util/estimate.py @@ -40,6 +40,8 @@ "no_grad_input", "grad_embed_weights", "no_grad_embed_weights", + "surgical_first", + "surgical_last", ] @@ -78,12 +80,13 @@ def skip_case_check(args: argparse.Namespace) -> bool: invalid = False if args.case is None: return invalid + is_surgical = 'surgical_last' in args.case or 'surgical_first' in args.case # 1. for c in ["grad_norm_bias", "grad_norm_weights"]: if c in args.case and args.model in models.models_without_norm: invalid = True for c in ["no_grad_norm_bias", "no_grad_norm_weights"]: - if c not in args.case and args.model in models.models_without_norm: + if c not in args.case and args.model in models.models_without_norm and not is_surgical: invalid = True # 2. if "no_grad_embed_weights" in args.case and "grad_input" not in args.case: diff --git a/experiments/util/measurements.py b/experiments/util/measurements.py index 321843d..62747c2 100644 --- a/experiments/util/measurements.py +++ b/experiments/util/measurements.py @@ -43,8 +43,8 @@ def maybe_synchronize(dev: device): cuda.synchronize() -class _Measurement: - """Base class for measurements.""" +class Measurement: + """Base class for measurements. This is meant to be subclassed and extended.""" def __init__( self, @@ -73,16 +73,52 @@ def __init__( self.targets = targets def set_up( - self, synchronize: bool = True - ) -> Tuple[Module, Module, Tensor, Tensor, Optional[List[Dict[str, Tensor]]]]: + self, + synchronize: bool = True, + grad_linear_weights: bool = True, + grad_linear_bias: bool = True, + grad_conv_weights: bool = True, + grad_conv_bias: bool = True, + grad_norm_weights: bool = True, + grad_norm_bias: bool = True, + grad_input: bool = False, + grad_embed_weights: bool = False, + surgical_first: bool = False, + surgical_last: bool = False, + ) -> Tuple[ + Module, + Module, + Tensor, + Tensor, + Optional[List[Dict[str, Tensor]]], + List[Tensor], + List[Tensor], + ]: """Initialize model and loss function, load to device (including data). Syncs CUDA threads if the device is a GPU to avoid leaking run time of this function into the measurement. Args: - synchronize: Whether to synchronize CUDA threads after loading the + synchronize (bool, optional): Whether to synchronize CUDA threads after loading the model, loss function, and data to the device. Default: `True`. + grad_linear_weights (bool, optional): Whether to compute the gradient of the linear + layer weights. Default: `True`. + grad_linear_bias (bool, optional): Whether to compute the gradient of the linear + layer bias. Default: `True`. + grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution + layer weights. Default: `True`. + grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution + layer bias. Default: `True`. + grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization + layer weights. Default: `True`. + grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization + layer bias. Default: `True`. 
+            grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
+            grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
+                layer weights. Default: `False`.
+            surgical_first (bool, optional): Whether to compute gradients only for the first quarter of layers. Default: `False`.
+            surgical_last (bool, optional): Whether to compute gradients only for the last quarter of layers. Default: `False`.
 
         Returns:
             The model, loss function, input tensor, and output tensor. All are loaded
@@ -100,26 +136,39 @@ def set_up(
         else:
             targets = None
 
+        if surgical_first or surgical_last:
+            leafs, no_leafs = separate_surgical(model, surgical_first, surgical_last)
+        else:
+            leafs, no_leafs = separate_grad_arguments(
+                model,
+                grad_linear_weights,
+                grad_linear_bias,
+                grad_conv_weights,
+                grad_conv_bias,
+                grad_norm_weights,
+                grad_norm_bias,
+                grad_embed_weights,
+            )
+        leafs = ([x] if grad_input else []) + leafs
+        no_leafs = ([y] if grad_input else [x, y]) + no_leafs
+        # targets will never require grad
+
+        # make leaves differentiable, turn off non-leafs
+        for leaf in leafs:
+            leaf.requires_grad_(True)
+        for no_leaf in no_leafs:
+            no_leaf.requires_grad_(False)
+
         if synchronize:
             maybe_synchronize(self.dev)
 
-        return model, loss_fn, x, y, targets
+        return model, loss_fn, x, y, targets, leafs, no_leafs
 
 
-class RuntimeMeasurement(_Measurement):
+class RuntimeMeasurement(Measurement):
     """A class to perform run time measurements of forward+backward pass."""
 
-    def forward_backward(
-        self,
-        grad_linear_weights: bool = True,
-        grad_linear_bias: bool = True,
-        grad_conv_weights: bool = True,
-        grad_conv_bias: bool = True,
-        grad_norm_weights: bool = True,
-        grad_norm_bias: bool = True,
-        grad_input: bool = False,
-        grad_embed_weights: bool = False,
-    ) -> float:
+    def forward_backward(self, **case_kwargs) -> float:
         """Perform a forward and backward pass and return the run time.
 
         Syncs CUDA threads if the device is a GPU.
 
@@ -127,46 +176,12 @@ def forward_backward(
         grad will be None.
 
         Args:
-            grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
-                layer weights. Default: `True`.
-            grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
-                layer bias. Default: `True`.
-            grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
-                layer weights. Default: `True`.
-            grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
-                layer bias. Default: `True`.
-            grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
-                layer weights. Default: `True`.
-            grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
-                layer bias. Default: `True`.
-            grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
-            grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
-                layer weights. Default: `True`.
+            **case_kwargs: Keyword arguments specifying which gradients to compute and which to skip; see `Measurement.set_up()`.
 
         Returns:
             float: The run time in seconds.
""" - model, loss_fn, x, y, targets = self.set_up() - - leafs, no_leafs = separate_grad_arguments( - model, - grad_linear_weights, - grad_linear_bias, - grad_conv_weights, - grad_conv_bias, - grad_norm_weights, - grad_norm_bias, - grad_embed_weights, - ) - leafs = ([x] if grad_input else []) + leafs - no_leafs = ([y] if grad_input else [x, y]) + no_leafs - # targets will never require grad - - # make leaves differentiable, turn off non-leafs - for leaf in leafs: - leaf.requires_grad_(True) - for no_leaf in no_leafs: - no_leaf.requires_grad_(False) + model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs) # obtain run time maybe_synchronize(self.dev) @@ -187,65 +202,22 @@ def forward_backward( return timer.last -class MemoryMeasurement(_Measurement): +class MemoryMeasurement(Measurement): """A class to measure memory usage after a forward pass.""" - def after_forward( - self, - grad_linear_weights: bool = True, - grad_linear_bias: bool = True, - grad_conv_weights: bool = True, - grad_conv_bias: bool = True, - grad_norm_weights: bool = True, - grad_norm_bias: bool = True, - grad_input: bool = False, - grad_embed_weights: bool = False, - ) -> float: + def after_forward(self, **case_kwargs) -> float: """Return memory usage after a forward pass. Note: We directly pass input embeddings to transformers so embed weights are never used and their grad will be None. Args: - grad_linear_weights: Whether to compute the gradient of the linear - layer weights. Default: `True`. - grad_linear_bias: Whether to compute the gradient of the linear - layer bias. Default: `True`. - grad_conv_weights: Whether to compute the gradient of the convolution - layer weights. Default: `True`. - grad_conv_bias: Whether to compute the gradient of the convolution - layer bias. Default: `True`. - grad_norm_weights: Whether to compute the gradient of the normalization - layer weights. Default: `True`. - grad_norm_bias: Whether to compute the gradient of the normalization - layer bias. Default: `True`. - grad_input: Whether to compute the gradient of the input. Default: `False`. - grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding - layer weights. Default: `True`. + **case_kwargs: Strings denoting which grads to compute and which to not, check docs of `Measurement.set_up()` Returns: float: The memory usage in bytes. 
""" - model, loss_fn, x, y, targets = self.set_up() - - leafs, no_leafs = separate_grad_arguments( - model, - grad_linear_weights, - grad_linear_bias, - grad_conv_weights, - grad_conv_bias, - grad_norm_weights, - grad_norm_bias, - grad_embed_weights, - ) - leafs = ([x] if grad_input else []) + leafs - no_leafs = ([y] if grad_input else [x, y]) + no_leafs - - # make leaves differentiable, turn off non-leafs - for leaf in leafs: - leaf.requires_grad_(True) - for no_leaf in no_leafs: - no_leaf.requires_grad_(False) + model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs) if str(self.dev) == "cpu": @@ -338,7 +310,13 @@ def separate_grad_arguments( ConvTranspose3d, MemSaveConv2d, ) - norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d) + norm = ( + BatchNorm1d, + BatchNorm2d, + BatchNorm3d, + LayerNorm, + LayerNorm2d, + MemSaveBatchNorm2d, embed = Embedding leafs, no_leafs = [], [] @@ -389,3 +367,62 @@ def check_lm_head(n) -> bool: raise NotImplementedError(f"Unknown layer with parameters: {layer}.") return leafs, no_leafs + + +def separate_surgical( + model: Module, surgical_first: bool, surgical_last: bool +) -> Tuple[List[Parameter], List[Parameter]]: + """Separate the parameters of a model into leafs and no-leafs for surgical fine-tuning + + One and only one of `surgical_first` and `surgical_last` must be True + Args: + model (Module): The model to separate the parameters of. + surgical_first (bool): Whether to compute the gradients of the first quarter of layers with parameters. + surgical_last (bool): Whether to compute the gradients of the last quarter of layers with parameters. + """ + assert ( + surgical_first ^ surgical_last + ), "One and only one of surgical_first and surgical_last must be True" + leafs, no_leafs = [], [] + counted_modules, total_modules = 0, 0 + + layers = [ + m + for n, m in model.named_modules() + if len(list(m.modules())) == 1 and list(m.parameters()) + ] + + total_modules = len(layers) + + def separate_layer(layer: Module, leaf: bool): + """Add parameters of layer to leafs or non-leafs. + + Args: + layer: The layer whose parameters to add to (non-)leafs. + leaf: Whether the layer is a leaf or not. 
+ """ + leafs.append(layer.weight) if leaf else no_leafs.append(layer.weight) + if "bias" in layer._parameters and layer.bias is not None: + leafs.append(layer.bias) if leaf else no_leafs.append(layer.bias) + + def check_condition(counted, total): + if surgical_first: + return counted <= total / 4 + return counted >= 3 * total / 4 + + if surgical_last: + layers = layers[::-1] + # import ipdb; ipdb.set_trace() + for c in layers: + if not list(c.parameters()): + continue + if counted_modules <= total_modules / 4: + # Leaf + separate_layer(c, True) + counted_modules += 1 + else: + # No Leaf + separate_layer(c, False) + # counted_modules += surgical_last + + return leafs, no_leafs diff --git a/memsave_torch/nn/functional/LayerNorm.py b/memsave_torch/nn/functional/LayerNorm.py index 1f92081..6cf72f8 100644 --- a/memsave_torch/nn/functional/LayerNorm.py +++ b/memsave_torch/nn/functional/LayerNorm.py @@ -26,7 +26,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps): need_grad = [] # save_mean and save_invstd if ctx.needs_input_grad[0]: need_grad.append(weight) - if ctx.needs_input_grad[2]: + if any(ctx.needs_input_grad): need_grad.append(x) # bias doesnt need anything for calc @@ -41,9 +41,9 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[0]: weight = ctx.saved_tensors[current_idx] current_idx += 1 + x = ctx.saved_tensors[current_idx] if ctx.needs_input_grad[3]: x = ctx.saved_tensors[current_idx] - current_idx += 1 if x is None: x = torch.zeros(ctx.x_shape, device=ctx.device)