
Commit

Merge pull request #11 from plutonium-239/surgical
merge surgical into main
plutonium-239 authored Aug 22, 2024
2 parents 795d46e + 41480ba · commit 1eecbfa
Showing 4 changed files with 147 additions and 105 deletions.
2 changes: 2 additions & 0 deletions experiments/util/collect_results.py
@@ -56,6 +56,8 @@
"no_grad_linear_weights",
"no_grad_linear_bias",
],
"SurgicalFirst": ["surgical_first"],
"SurgicalLast": ["surgical_last", "grad_input"],
}


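The two new entries map a human-readable case name to the list of `--case` flags it stands for. As a minimal illustration of how such a mapping can be expanded into CLI arguments (the `build_cli_args` helper and the trimmed-down mapping below are hypothetical, not the repository's actual code):

case_mapping = {
    "SurgicalFirst": ["surgical_first"],
    "SurgicalLast": ["surgical_last", "grad_input"],
}

def build_cli_args(case_name: str) -> list:
    # Expand a named case into repeated --case flags for estimate.py.
    flags = []
    for c in case_mapping.get(case_name, []):
        flags += ["--case", c]
    return flags

print(build_cli_args("SurgicalLast"))
# ['--case', 'surgical_last', '--case', 'grad_input']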
5 changes: 4 additions & 1 deletion experiments/util/estimate.py
@@ -40,6 +40,8 @@
"no_grad_input",
"grad_embed_weights",
"no_grad_embed_weights",
"surgical_first",
"surgical_last",
]


@@ -78,12 +80,13 @@ def skip_case_check(args: argparse.Namespace) -> bool:
invalid = False
if args.case is None:
return invalid
is_surgical = 'surgical_last' in args.case or 'surgical_first' in args.case
# 1.
for c in ["grad_norm_bias", "grad_norm_weights"]:
if c in args.case and args.model in models.models_without_norm:
invalid = True
for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
if c not in args.case and args.model in models.models_without_norm:
if c not in args.case and args.model in models.models_without_norm and not is_surgical:
invalid = True
# 2.
if "no_grad_embed_weights" in args.case and "grad_input" not in args.case:
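The added `is_surgical` flag exempts surgical cases from the rule that models without norm layers must spell out the `no_grad_norm_*` cases. A self-contained sketch of the check, re-implemented for illustration (the `models_without_norm` list here is a placeholder, not the repository's data):

import argparse

models_without_norm = ["toy-model-without-norm"]  # placeholder for illustration

def skip_case_check(args: argparse.Namespace) -> bool:
    # Return True if this (model, case) combination should be skipped.
    invalid = False
    if args.case is None:
        return invalid
    is_surgical = "surgical_first" in args.case or "surgical_last" in args.case
    for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
        if c not in args.case and args.model in models_without_norm and not is_surgical:
            invalid = True
    return invalid

args = argparse.Namespace(model="toy-model-without-norm", case=["surgical_last", "grad_input"])
print(skip_case_check(args))  # False: surgical cases are exempt from the norm-case requirement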
241 changes: 139 additions & 102 deletions experiments/util/measurements.py
@@ -43,8 +43,8 @@ def maybe_synchronize(dev: device):
cuda.synchronize()


class _Measurement:
"""Base class for measurements."""
class Measurement:
"""Base class for measurements. This is meant to be subclassed and extended."""

def __init__(
self,
@@ -73,16 +73,52 @@ def __init__(
self.targets = targets

def set_up(
self, synchronize: bool = True
) -> Tuple[Module, Module, Tensor, Tensor, Optional[List[Dict[str, Tensor]]]]:
self,
synchronize: bool = True,
grad_linear_weights: bool = True,
grad_linear_bias: bool = True,
grad_conv_weights: bool = True,
grad_conv_bias: bool = True,
grad_norm_weights: bool = True,
grad_norm_bias: bool = True,
grad_input: bool = False,
grad_embed_weights: bool = False,
surgical_first: bool = False,
surgical_last: bool = False,
) -> Tuple[
Module,
Module,
Tensor,
Tensor,
Optional[List[Dict[str, Tensor]]],
List[Tensor],
List[Tensor],
]:
"""Initialize model and loss function, load to device (including data).
Syncs CUDA threads if the device is a GPU to avoid leaking run time
of this function into the measurement.
Args:
synchronize: Whether to synchronize CUDA threads after loading the
synchronize (bool, optional): Whether to synchronize CUDA threads after loading the
model, loss function, and data to the device. Default: `True`.
grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
layer weights. Default: `True`.
grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
layer bias. Default: `True`.
grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
layer weights. Default: `True`.
grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
layer bias. Default: `True`.
grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
layer weights. Default: `True`.
grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
layer bias. Default: `True`.
grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
layer weights. Default: `False`.
surgical_first (bool, optional): Whether to compute gradients only for the first quarter of layers with parameters. Default: `False`.
surgical_last (bool, optional): Whether to compute gradients only for the last quarter of layers with parameters. Default: `False`.
Returns:
The model, loss function, input tensor, and output tensor. All are loaded
@@ -100,73 +136,52 @@ def set_up(
else:
targets = None

if surgical_first or surgical_last:
leafs, no_leafs = separate_surgical(model, surgical_first, surgical_last)
else:
leafs, no_leafs = separate_grad_arguments(
model,
grad_linear_weights,
grad_linear_bias,
grad_conv_weights,
grad_conv_bias,
grad_norm_weights,
grad_norm_bias,
grad_embed_weights,
)
leafs = ([x] if grad_input else []) + leafs
no_leafs = ([y] if grad_input else [x, y]) + no_leafs
# targets will never require grad

# make leaves differentiable, turn off non-leafs
for leaf in leafs:
leaf.requires_grad_(True)
for no_leaf in no_leafs:
no_leaf.requires_grad_(False)

if synchronize:
maybe_synchronize(self.dev)

return model, loss_fn, x, y, targets
return model, loss_fn, x, y, targets, leafs, no_leafs


class RuntimeMeasurement(_Measurement):
class RuntimeMeasurement(Measurement):
"""A class to perform run time measurements of forward+backward pass."""

def forward_backward(
self,
grad_linear_weights: bool = True,
grad_linear_bias: bool = True,
grad_conv_weights: bool = True,
grad_conv_bias: bool = True,
grad_norm_weights: bool = True,
grad_norm_bias: bool = True,
grad_input: bool = False,
grad_embed_weights: bool = False,
) -> float:
def forward_backward(self, **case_kwargs) -> float:
"""Perform a forward and backward pass and return the run time.
Syncs CUDA threads if the device is a GPU.
Note: We directly pass input embeddings to transformers so embed weights are never used and their
grad will be None.
Args:
grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
layer weights. Default: `True`.
grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
layer bias. Default: `True`.
grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
layer weights. Default: `True`.
grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
layer bias. Default: `True`.
grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
layer weights. Default: `True`.
grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
layer bias. Default: `True`.
grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
layer weights. Default: `True`.
**case_kwargs: Keyword arguments (boolean flags) denoting which grads to compute and which to skip; see the docs of `Measurement.set_up()`.
Returns:
float: The run time in seconds.
"""
model, loss_fn, x, y, targets = self.set_up()

leafs, no_leafs = separate_grad_arguments(
model,
grad_linear_weights,
grad_linear_bias,
grad_conv_weights,
grad_conv_bias,
grad_norm_weights,
grad_norm_bias,
grad_embed_weights,
)
leafs = ([x] if grad_input else []) + leafs
no_leafs = ([y] if grad_input else [x, y]) + no_leafs
# targets will never require grad

# make leaves differentiable, turn off non-leafs
for leaf in leafs:
leaf.requires_grad_(True)
for no_leaf in no_leafs:
no_leaf.requires_grad_(False)
model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs)

# obtain run time
maybe_synchronize(self.dev)
@@ -187,65 +202,22 @@ def forward_backward(
return timer.last


class MemoryMeasurement(_Measurement):
class MemoryMeasurement(Measurement):
"""A class to measure memory usage after a forward pass."""

def after_forward(
self,
grad_linear_weights: bool = True,
grad_linear_bias: bool = True,
grad_conv_weights: bool = True,
grad_conv_bias: bool = True,
grad_norm_weights: bool = True,
grad_norm_bias: bool = True,
grad_input: bool = False,
grad_embed_weights: bool = False,
) -> float:
def after_forward(self, **case_kwargs) -> float:
"""Return memory usage after a forward pass.
Note: We directly pass input embeddings to transformers so embed weights are never used and their
grad will be None.
Args:
grad_linear_weights: Whether to compute the gradient of the linear
layer weights. Default: `True`.
grad_linear_bias: Whether to compute the gradient of the linear
layer bias. Default: `True`.
grad_conv_weights: Whether to compute the gradient of the convolution
layer weights. Default: `True`.
grad_conv_bias: Whether to compute the gradient of the convolution
layer bias. Default: `True`.
grad_norm_weights: Whether to compute the gradient of the normalization
layer weights. Default: `True`.
grad_norm_bias: Whether to compute the gradient of the normalization
layer bias. Default: `True`.
grad_input: Whether to compute the gradient of the input. Default: `False`.
grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
layer weights. Default: `True`.
**case_kwargs: Keyword arguments (boolean flags) denoting which grads to compute and which to skip; see the docs of `Measurement.set_up()`.
Returns:
float: The memory usage in bytes.
"""
model, loss_fn, x, y, targets = self.set_up()

leafs, no_leafs = separate_grad_arguments(
model,
grad_linear_weights,
grad_linear_bias,
grad_conv_weights,
grad_conv_bias,
grad_norm_weights,
grad_norm_bias,
grad_embed_weights,
)
leafs = ([x] if grad_input else []) + leafs
no_leafs = ([y] if grad_input else [x, y]) + no_leafs

# make leaves differentiable, turn off non-leafs
for leaf in leafs:
leaf.requires_grad_(True)
for no_leaf in no_leafs:
no_leaf.requires_grad_(False)
model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs)

if str(self.dev) == "cpu":

@@ -338,7 +310,13 @@ def separate_grad_arguments(
ConvTranspose3d,
MemSaveConv2d,
)
norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d)
norm = (
BatchNorm1d,
BatchNorm2d,
BatchNorm3d,
LayerNorm,
LayerNorm2d,
MemSaveBatchNorm2d,
)
embed = Embedding

leafs, no_leafs = [], []
@@ -389,3 +367,62 @@ def check_lm_head(n) -> bool:
raise NotImplementedError(f"Unknown layer with parameters: {layer}.")

return leafs, no_leafs


def separate_surgical(
model: Module, surgical_first: bool, surgical_last: bool
) -> Tuple[List[Parameter], List[Parameter]]:
"""Separate the parameters of a model into leafs and no-leafs for surgical fine-tuning
One and only one of `surgical_first` and `surgical_last` must be True
Args:
model (Module): The model to separate the parameters of.
surgical_first (bool): Whether to compute the gradients of the first quarter of layers with parameters.
surgical_last (bool): Whether to compute the gradients of the last quarter of layers with parameters.
"""
assert (
surgical_first ^ surgical_last
), "One and only one of surgical_first and surgical_last must be True"
leafs, no_leafs = [], []
counted_modules, total_modules = 0, 0

layers = [
m
for n, m in model.named_modules()
if len(list(m.modules())) == 1 and list(m.parameters())
]

total_modules = len(layers)

def separate_layer(layer: Module, leaf: bool):
"""Add parameters of layer to leafs or non-leafs.
Args:
layer: The layer whose parameters to add to (non-)leafs.
leaf: Whether the layer is a leaf or not.
"""
leafs.append(layer.weight) if leaf else no_leafs.append(layer.weight)
if "bias" in layer._parameters and layer.bias is not None:
leafs.append(layer.bias) if leaf else no_leafs.append(layer.bias)

def check_condition(counted, total):
if surgical_first:
return counted <= total / 4
return counted >= 3 * total / 4

if surgical_last:
layers = layers[::-1]
# import ipdb; ipdb.set_trace()
for c in layers:
if not list(c.parameters()):
continue
if counted_modules <= total_modules / 4:
# Leaf
separate_layer(c, True)
counted_modules += 1
else:
# No Leaf
separate_layer(c, False)
# counted_modules += surgical_last

return leafs, no_leafs
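As a usage sketch of the quarter split, `separate_surgical` can be exercised on a toy model; the import path assumes `experiments/util/measurements.py` is importable as a module, and the model below is made up for illustration:

import torch.nn as nn
from experiments.util.measurements import separate_surgical  # assumed import path

# Toy model: five modules carry parameters (2x Conv2d, 2x BatchNorm2d, 1x Linear).
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Flatten(), nn.Linear(8, 10),
)

# surgical_first=True marks roughly the first quarter of the parameterized
# modules as leafs; set_up() then turns requires_grad on for the leafs and
# off for everything in no_leafs.
leafs, no_leafs = separate_surgical(model, surgical_first=True, surgical_last=False)
print(len(leafs), len(no_leafs))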
4 changes: 2 additions & 2 deletions memsave_torch/nn/functional/LayerNorm.py
@@ -26,7 +26,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps):
need_grad = [] # save_mean and save_invstd
if ctx.needs_input_grad[0]:
need_grad.append(weight)
if ctx.needs_input_grad[2]:
if any(ctx.needs_input_grad):
need_grad.append(x)
# bias doesn't need anything for calc

@@ -41,9 +41,9 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
x = ctx.saved_tensors[current_idx]
if ctx.needs_input_grad[3]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

if x is None:
x = torch.zeros(ctx.x_shape, device=ctx.device)
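The forward change saves `x` whenever any input needs a gradient, and the backward change reads `x` from `ctx.saved_tensors` at the matching index. A generic, self-contained sketch of this save-only-what-backward-needs pattern (an elementwise scale function invented for illustration; it is not the repository's LayerNorm implementation and assumes `x` and `weight` share a shape):

import torch

class ScaleFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        to_save = []
        if ctx.needs_input_grad[0]:    # grad w.r.t. x needs weight
            to_save.append(weight)
        if any(ctx.needs_input_grad):  # save x if any grad is needed (mirrors the change above)
            to_save.append(x)
        ctx.save_for_backward(*to_save)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        idx = 0
        weight = x = None
        if ctx.needs_input_grad[0]:
            weight = ctx.saved_tensors[idx]
            idx += 1
        if any(ctx.needs_input_grad):
            x = ctx.saved_tensors[idx]
        grad_x = grad_output * weight if ctx.needs_input_grad[0] else None
        grad_w = grad_output * x if ctx.needs_input_grad[1] else None
        return grad_x, grad_w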
