merge surgical into main #11

Merged 4 commits into main from surgical on Aug 22, 2024
2 changes: 2 additions & 0 deletions experiments/util/collect_results.py
@@ -56,6 +56,8 @@
         "no_grad_linear_weights",
         "no_grad_linear_bias",
     ],
+    "SurgicalFirst": ["surgical_first"],
+    "SurgicalLast": ["surgical_last", "grad_input"],
 }


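The two new entries extend the mapping from human-readable case names to the `--case` flags consumed by `estimate.py`. As a quick illustration, a label could be expanded into a command-line fragment like this (the helper below is a hypothetical sketch, not code from this PR):

```python
# Hypothetical helper: expand a case label into the --case flags for estimate.py,
# using the mapping added in the hunk above.
case_flags = {
    "SurgicalFirst": ["surgical_first"],
    "SurgicalLast": ["surgical_last", "grad_input"],
}


def build_case_args(label: str) -> list:
    """Return the CLI fragment for a named case."""
    return ["--case", *case_flags[label]]


print(build_case_args("SurgicalLast"))  # ['--case', 'surgical_last', 'grad_input']
```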
5 changes: 4 additions & 1 deletion experiments/util/estimate.py
@@ -40,6 +40,8 @@
     "no_grad_input",
     "grad_embed_weights",
     "no_grad_embed_weights",
+    "surgical_first",
+    "surgical_last",
 ]


@@ -78,12 +80,13 @@ def skip_case_check(args: argparse.Namespace) -> bool:
     invalid = False
     if args.case is None:
         return invalid
+    is_surgical = "surgical_last" in args.case or "surgical_first" in args.case
     # 1.
     for c in ["grad_norm_bias", "grad_norm_weights"]:
         if c in args.case and args.model in models.models_without_norm:
             invalid = True
     for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
-        if c not in args.case and args.model in models.models_without_norm:
+        if c not in args.case and args.model in models.models_without_norm and not is_surgical:
             invalid = True
     # 2.
     if "no_grad_embed_weights" in args.case and "grad_input" not in args.case:
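For models without normalization layers, the check used to mark every case invalid unless it explicitly listed the `no_grad_norm_*` flags; the new `is_surgical` guard exempts the surgical cases, which never spell those flags out. A simplified, standalone re-implementation of that guard for illustration (not the repository's function):

```python
def skip_case_check_sketch(case, model, models_without_norm):
    """Simplified version of the guard shown above; returns True if the case is invalid."""
    if case is None:
        return False
    is_surgical = "surgical_last" in case or "surgical_first" in case
    invalid = False
    for c in ["no_grad_norm_bias", "no_grad_norm_weights"]:
        # Norm-free models cannot satisfy the explicit no_grad_norm_* flags,
        # but surgical cases never list them, so they are now exempt.
        if c not in case and model in models_without_norm and not is_surgical:
            invalid = True
    return invalid


print(skip_case_check_sketch(["surgical_first"], "toy_model", ["toy_model"]))  # False (now allowed)
print(skip_case_check_sketch(["grad_input"], "toy_model", ["toy_model"]))      # True (still skipped)
```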
241 changes: 139 additions & 102 deletions experiments/util/measurements.py
@@ -43,8 +43,8 @@ def maybe_synchronize(dev: device):
         cuda.synchronize()


-class _Measurement:
-    """Base class for measurements."""
+class Measurement:
+    """Base class for measurements. This is meant to be subclassed and extended."""

     def __init__(
         self,
@@ -73,16 +73,52 @@ def __init__(
         self.targets = targets

     def set_up(
-        self, synchronize: bool = True
-    ) -> Tuple[Module, Module, Tensor, Tensor, Optional[List[Dict[str, Tensor]]]]:
+        self,
+        synchronize: bool = True,
+        grad_linear_weights: bool = True,
+        grad_linear_bias: bool = True,
+        grad_conv_weights: bool = True,
+        grad_conv_bias: bool = True,
+        grad_norm_weights: bool = True,
+        grad_norm_bias: bool = True,
+        grad_input: bool = False,
+        grad_embed_weights: bool = False,
+        surgical_first: bool = False,
+        surgical_last: bool = False,
+    ) -> Tuple[
+        Module,
+        Module,
+        Tensor,
+        Tensor,
+        Optional[List[Dict[str, Tensor]]],
+        List[Tensor],
+        List[Tensor],
+    ]:
"""Initialize model and loss function, load to device (including data).

Syncs CUDA threads if the device is a GPU to avoid leaking run time
of this function into the measurement.

Args:
synchronize: Whether to synchronize CUDA threads after loading the
synchronize (bool, optional): Whether to synchronize CUDA threads after loading the
model, loss function, and data to the device. Default: `True`.
grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
layer weights. Default: `True`.
grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
layer bias. Default: `True`.
grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
layer weights. Default: `True`.
grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
layer bias. Default: `True`.
grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
layer weights. Default: `True`.
grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
layer bias. Default: `True`.
grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
layer weights. Default: `True`.
surgical_first (bool, optional): Corresponds to computing gradient only for the first quarter of layers
surgical_last (bool, optional): Corresponds to computing gradient only for the last quarter of layers

Returns:
The model, loss function, input tensor, and output tensor. All are loaded
Expand All @@ -100,73 +136,52 @@ def set_up(
         else:
             targets = None

+        if surgical_first or surgical_last:
+            leafs, no_leafs = separate_surgical(model, surgical_first, surgical_last)
+        else:
+            leafs, no_leafs = separate_grad_arguments(
+                model,
+                grad_linear_weights,
+                grad_linear_bias,
+                grad_conv_weights,
+                grad_conv_bias,
+                grad_norm_weights,
+                grad_norm_bias,
+                grad_embed_weights,
+            )
+        leafs = ([x] if grad_input else []) + leafs
+        no_leafs = ([y] if grad_input else [x, y]) + no_leafs
+        # targets will never require grad
+
+        # make leaves differentiable, turn off non-leafs
+        for leaf in leafs:
+            leaf.requires_grad_(True)
+        for no_leaf in no_leafs:
+            no_leaf.requires_grad_(False)
+
         if synchronize:
             maybe_synchronize(self.dev)

-        return model, loss_fn, x, y, targets
+        return model, loss_fn, x, y, targets, leafs, no_leafs


-class RuntimeMeasurement(_Measurement):
+class RuntimeMeasurement(Measurement):
     """A class to perform run time measurements of forward+backward pass."""

-    def forward_backward(
-        self,
-        grad_linear_weights: bool = True,
-        grad_linear_bias: bool = True,
-        grad_conv_weights: bool = True,
-        grad_conv_bias: bool = True,
-        grad_norm_weights: bool = True,
-        grad_norm_bias: bool = True,
-        grad_input: bool = False,
-        grad_embed_weights: bool = False,
-    ) -> float:
+    def forward_backward(self, **case_kwargs) -> float:
         """Perform a forward and backward pass and return the run time.

         Syncs CUDA threads if the device is a GPU.
         Note: We directly pass input embeddings to transformers so embed weights are never used and their
         grad will be None.

         Args:
-            grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
-                layer weights. Default: `True`.
-            grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
-                layer bias. Default: `True`.
-            grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
-                layer weights. Default: `True`.
-            grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
-                layer bias. Default: `True`.
-            grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
-                layer weights. Default: `True`.
-            grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
-                layer bias. Default: `True`.
-            grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
-            grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
-                layer weights. Default: `True`.
+            **case_kwargs: Boolean flags selecting which gradients to compute; see
+                `Measurement.set_up()` for the accepted keywords and their defaults.

         Returns:
             float: The run time in seconds.
         """
-        model, loss_fn, x, y, targets = self.set_up()
-
-        leafs, no_leafs = separate_grad_arguments(
-            model,
-            grad_linear_weights,
-            grad_linear_bias,
-            grad_conv_weights,
-            grad_conv_bias,
-            grad_norm_weights,
-            grad_norm_bias,
-            grad_embed_weights,
-        )
-        leafs = ([x] if grad_input else []) + leafs
-        no_leafs = ([y] if grad_input else [x, y]) + no_leafs
-        # targets will never require grad
-
-        # make leaves differentiable, turn off non-leafs
-        for leaf in leafs:
-            leaf.requires_grad_(True)
-        for no_leaf in no_leafs:
-            no_leaf.requires_grad_(False)
+        model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs)

         # obtain run time
         maybe_synchronize(self.dev)
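With this refactor, both measurement classes receive the `leafs` (tensors that should get gradients) and `no_leafs` lists from `set_up()` instead of building them themselves. A self-contained sketch of the underlying requires-grad toggling on a toy model (it mirrors the pattern above and is not the repository's API):

```python
import torch
from torch.nn import Linear, Sequential

model = Sequential(Linear(4, 4), Linear(4, 2))
x, y = torch.randn(8, 4), torch.randn(8, 2)

# Pretend only the last layer is "surgical": its parameters are the leafs.
leafs = list(model[1].parameters())
no_leafs = [x, y] + list(model[0].parameters())

for leaf in leafs:
    leaf.requires_grad_(True)
for no_leaf in no_leafs:
    no_leaf.requires_grad_(False)

loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# Gradients exist only for the selected (leaf) parameters.
print([p.grad is not None for p in model.parameters()])  # [False, False, True, True]
```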
Expand All @@ -187,65 +202,22 @@ def forward_backward(
return timer.last


-class MemoryMeasurement(_Measurement):
+class MemoryMeasurement(Measurement):
     """A class to measure memory usage after a forward pass."""

-    def after_forward(
-        self,
-        grad_linear_weights: bool = True,
-        grad_linear_bias: bool = True,
-        grad_conv_weights: bool = True,
-        grad_conv_bias: bool = True,
-        grad_norm_weights: bool = True,
-        grad_norm_bias: bool = True,
-        grad_input: bool = False,
-        grad_embed_weights: bool = False,
-    ) -> float:
+    def after_forward(self, **case_kwargs) -> float:
         """Return memory usage after a forward pass.

         Note: We directly pass input embeddings to transformers so embed weights are never used and their
         grad will be None.

         Args:
-            grad_linear_weights: Whether to compute the gradient of the linear
-                layer weights. Default: `True`.
-            grad_linear_bias: Whether to compute the gradient of the linear
-                layer bias. Default: `True`.
-            grad_conv_weights: Whether to compute the gradient of the convolution
-                layer weights. Default: `True`.
-            grad_conv_bias: Whether to compute the gradient of the convolution
-                layer bias. Default: `True`.
-            grad_norm_weights: Whether to compute the gradient of the normalization
-                layer weights. Default: `True`.
-            grad_norm_bias: Whether to compute the gradient of the normalization
-                layer bias. Default: `True`.
-            grad_input: Whether to compute the gradient of the input. Default: `False`.
-            grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
-                layer weights. Default: `True`.
+            **case_kwargs: Boolean flags selecting which gradients to compute; see
+                `Measurement.set_up()` for the accepted keywords and their defaults.

         Returns:
             float: The memory usage in bytes.
         """
-        model, loss_fn, x, y, targets = self.set_up()
-
-        leafs, no_leafs = separate_grad_arguments(
-            model,
-            grad_linear_weights,
-            grad_linear_bias,
-            grad_conv_weights,
-            grad_conv_bias,
-            grad_norm_weights,
-            grad_norm_bias,
-            grad_embed_weights,
-        )
-        leafs = ([x] if grad_input else []) + leafs
-        no_leafs = ([y] if grad_input else [x, y]) + no_leafs
-
-        # make leaves differentiable, turn off non-leafs
-        for leaf in leafs:
-            leaf.requires_grad_(True)
-        for no_leaf in no_leafs:
-            no_leaf.requires_grad_(False)
+        model, loss_fn, x, y, targets, leafs, no_leafs = self.set_up(**case_kwargs)

         if str(self.dev) == "cpu":

@@ -338,7 +310,13 @@ def separate_grad_arguments(
         ConvTranspose3d,
         MemSaveConv2d,
     )
-    norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d)
+    norm = (
+        BatchNorm1d,
+        BatchNorm2d,
+        BatchNorm3d,
+        LayerNorm,
+        LayerNorm2d,
+        MemSaveBatchNorm2d,
+    )
     embed = Embedding

     leafs, no_leafs = [], []
@@ -389,3 +367,62 @@ def check_lm_head(n) -> bool:
             raise NotImplementedError(f"Unknown layer with parameters: {layer}.")

     return leafs, no_leafs


+def separate_surgical(
+    model: Module, surgical_first: bool, surgical_last: bool
+) -> Tuple[List[Parameter], List[Parameter]]:
+    """Separate the parameters of a model into leafs and no-leafs for surgical fine-tuning.
+
+    Exactly one of `surgical_first` and `surgical_last` must be True.
+
+    Args:
+        model (Module): The model to separate the parameters of.
+        surgical_first (bool): Whether to compute gradients only for the first quarter of layers with parameters.
+        surgical_last (bool): Whether to compute gradients only for the last quarter of layers with parameters.
+
+    Returns:
+        Tuple of leaf parameters (which receive gradients) and non-leaf parameters (which do not).
+    """
+    assert (
+        surgical_first ^ surgical_last
+    ), "Exactly one of surgical_first and surgical_last must be True"
+    leafs, no_leafs = [], []
+
+    # Leaf modules (no children) that own parameters, in registration order.
+    layers = [
+        m
+        for m in model.modules()
+        if len(list(m.modules())) == 1 and list(m.parameters())
+    ]
+    total_modules = len(layers)
+    counted_modules = 0
+
+    def separate_layer(layer: Module, leaf: bool):
+        """Add parameters of layer to leafs or non-leafs.
+
+        Args:
+            layer: The layer whose parameters to add to (non-)leafs.
+            leaf: Whether the layer is a leaf or not.
+        """
+        leafs.append(layer.weight) if leaf else no_leafs.append(layer.weight)
+        if "bias" in layer._parameters and layer.bias is not None:
+            leafs.append(layer.bias) if leaf else no_leafs.append(layer.bias)
+
+    if surgical_last:
+        # Walk the layers from the output end so the "first quarter" below
+        # corresponds to the last quarter of the network.
+        layers = layers[::-1]
+
+    for c in layers:
+        if counted_modules <= total_modules / 4:
+            # leaf: gradients are computed for this layer's parameters
+            separate_layer(c, True)
+            counted_modules += 1
+        else:
+            # non-leaf: gradients are not computed
+            separate_layer(c, False)
+
+    return leafs, no_leafs
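The selection is purely positional: the leaf modules that own parameters are listed in registration order, the list is reversed for `surgical_last`, and roughly the first quarter becomes trainable. A standalone sketch reproducing that logic on a toy model (for illustration only; it does not import the repository's helper):

```python
from torch.nn import Linear, ReLU, Sequential

model = Sequential(
    Linear(4, 8), ReLU(), Linear(8, 8), ReLU(), Linear(8, 8), ReLU(), Linear(8, 2)
)

# Leaf modules (no children) that own parameters, as in separate_surgical.
layers = [
    m for m in model.modules() if len(list(m.modules())) == 1 and list(m.parameters())
]

surgical_last = True
if surgical_last:
    layers = layers[::-1]  # walk from the output end

leafs, no_leafs, counted = [], [], 0
for layer in layers:
    # note: "<=" admits floor(n/4) + 1 layers, slightly more than a quarter
    if counted <= len(layers) / 4:
        leafs.extend(layer.parameters())
        counted += 1
    else:
        no_leafs.extend(layer.parameters())

print(len(leafs), len(no_leafs))  # 4 parameters get gradients, 4 do not (2 of 4 Linear layers)
```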
4 changes: 2 additions & 2 deletions memsave_torch/nn/functional/LayerNorm.py
@@ -26,7 +26,7 @@ def forward(ctx, x, normalized_shape, weight, bias, eps):
         need_grad = []  # save_mean and save_invstd
         if ctx.needs_input_grad[0]:
             need_grad.append(weight)
-        if ctx.needs_input_grad[2]:
+        if any(ctx.needs_input_grad):
             need_grad.append(x)
         # bias doesnt need anything for calc

Expand All @@ -41,9 +41,9 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[current_idx]
current_idx += 1
x = ctx.saved_tensors[current_idx]
if ctx.needs_input_grad[3]:
x = ctx.saved_tensors[current_idx]
current_idx += 1

if x is None:
x = torch.zeros(ctx.x_shape, device=ctx.device)
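The fix rests on how `ctx.needs_input_grad` is indexed: each entry corresponds positionally to an argument of `forward`, so for `forward(ctx, x, normalized_shape, weight, bias, eps)` index 0 is `x`, index 2 is `weight`, and index 3 is `bias`, and `backward` must read saved tensors back in the order `forward` stored them. A minimal toy `autograd.Function` showing that bookkeeping (an illustrative sketch, not the library's LayerNorm):

```python
import torch


class Scale(torch.autograd.Function):
    """Toy y = x * weight, saving only the tensors the requested grads need."""

    @staticmethod
    def forward(ctx, x, weight):
        # needs_input_grad[0] refers to x, needs_input_grad[1] to weight.
        saved = []
        if ctx.needs_input_grad[1]:
            saved.append(x)  # d(loss)/d(weight) needs x
        if ctx.needs_input_grad[0]:
            saved.append(weight)  # d(loss)/d(x) needs weight
        ctx.save_for_backward(*saved)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        # Read saved tensors back in exactly the order forward stored them.
        idx = 0
        grad_x = grad_weight = None
        if ctx.needs_input_grad[1]:
            x = ctx.saved_tensors[idx]
            idx += 1
            grad_weight = (grad_output * x).sum()
        if ctx.needs_input_grad[0]:
            weight = ctx.saved_tensors[idx]
            grad_x = grad_output * weight
        return grad_x, grad_weight


x = torch.randn(3)
w = torch.tensor(2.0, requires_grad=True)
Scale.apply(x, w).sum().backward()
print(torch.allclose(w.grad, x.sum()))  # True
```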