Commit

add llm trial code
plutonium-239 committed Apr 16, 2024
1 parent d47c05f commit 4df05a3
Showing 4 changed files with 71 additions and 15 deletions.
4 changes: 2 additions & 2 deletions experiments/paper_demo.py
@@ -43,7 +43,7 @@

# models = ["resnet101", "memsave_resnet101_conv", "memsave_resnet101_conv+relu+bn", "memsave_resnet101_conv_full"]
# models = ["resnet101", "memsave_resnet101_conv_full"]

models = ["gpt2"]
models = prefix_in_pairs("memsave_", models)
# models = ["memsave_resnet101"]
batch_size = 64
@@ -99,7 +99,7 @@
architecture,
vjp_improvements,
cases,
"results",
)

for model in models:
38 changes: 26 additions & 12 deletions experiments/util/measurements.py
@@ -24,6 +24,7 @@
Linear,
Module,
Parameter,
Embedding,
)
from torchvision.models.convnext import LayerNorm2d

@@ -116,28 +117,31 @@ def forward_backward(
grad_norm_weights: bool = True,
grad_norm_bias: bool = True,
grad_input: bool = False,
grad_embed_weights: bool = True,
) -> float:
"""Perform a forward and backward pass and return the run time.
Syncs CUDA threads if the device is a GPU.
Args:
grad_linear_weights: Whether to compute the gradient of the linear
grad_linear_weights (bool, optional): Whether to compute the gradient of the linear
layer weights. Default: `True`.
grad_linear_bias: Whether to compute the gradient of the linear
grad_linear_bias (bool, optional): Whether to compute the gradient of the linear
layer bias. Default: `True`.
grad_conv_weights: Whether to compute the gradient of the convolution
grad_conv_weights (bool, optional): Whether to compute the gradient of the convolution
layer weights. Default: `True`.
grad_conv_bias: Whether to compute the gradient of the convolution
grad_conv_bias (bool, optional): Whether to compute the gradient of the convolution
layer bias. Default: `True`.
grad_norm_weights: Whether to compute the gradient of the normalization
grad_norm_weights (bool, optional): Whether to compute the gradient of the normalization
layer weights. Default: `True`.
grad_norm_bias: Whether to compute the gradient of the normalization
grad_norm_bias (bool, optional): Whether to compute the gradient of the normalization
layer bias. Default: `True`.
grad_input: Whether to compute the gradient of the input. Default: `False`.
grad_input (bool, optional): Whether to compute the gradient of the input. Default: `False`.
grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
layer weights. Default: `True`.
Returns:
The run time in seconds.
float: The run time in seconds.
"""
model, loss_fn, x, y, targets = self.set_up()

@@ -149,6 +153,7 @@
grad_conv_bias,
grad_norm_weights,
grad_norm_bias,
grad_embed_weights,
)
leafs = ([x] if grad_input else []) + leafs
no_leafs = ([y] if grad_input else [x, y]) + no_leafs
@@ -191,6 +196,7 @@ def after_forward(
grad_norm_weights: bool = True,
grad_norm_bias: bool = True,
grad_input: bool = False,
grad_embed_weights: bool = True,
) -> float:
"""Return memory usage after a forward pass.
@@ -208,9 +214,11 @@
grad_norm_bias: Whether to compute the gradient of the normalization
layer bias. Default: `True`.
grad_input: Whether to compute the gradient of the input. Default: `False`.
grad_embed_weights (bool, optional): Whether to compute the gradient of the embedding
layer weights. Default: `True`.
Returns:
float: The memory usage in bytes.
"""
model, loss_fn, x, y, targets = self.set_up()

@@ -222,6 +230,7 @@
grad_conv_bias,
grad_norm_weights,
grad_norm_bias,
grad_embed_weights,
)
leafs = ([x] if grad_input else []) + leafs
no_leafs = ([y] if grad_input else [x, y]) + no_leafs
@@ -289,6 +298,7 @@ def separate_grad_arguments(
grad_conv_bias: bool,
grad_norm_weights: bool,
grad_norm_bias: bool,
grad_embed_weights: bool,
) -> Tuple[List[Parameter], List[Parameter]]:
"""Separate the parameters of a model into leafs and non-leafs.
@@ -303,6 +313,7 @@
grad_norm_weights: Whether to compute the gradient of the normalization layer
weights.
grad_norm_bias: Whether to compute the gradient of the normalization layer bias.
grad_embed_weights: Whether to compute the gradient of the embedding layer weights.
Returns:
A tuple of lists of parameters. The first list contains the leafs, the second
@@ -322,6 +333,7 @@
MemSaveConv2d,
)
norm = (BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, LayerNorm2d)
embed = (Embedding,)

leafs, no_leafs = [], []

@@ -346,6 +358,8 @@ def separate_layer(layer: Module, grad_weight: bool, grad_bias: bool):
separate_layer(layer, grad_conv_weights, grad_conv_bias)
elif isinstance(layer, norm):
separate_layer(layer, grad_norm_weights, grad_norm_bias)
elif isinstance(layer, embed):
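# Embedding layers have no bias, so the bias flag is hard-coded to False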
separate_layer(layer, grad_embed_weights, False)
elif list(layer.parameters()):
raise NotImplementedError(f"Unknown layer with parameters: {layer}.")
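
To make the leaf/no-leaf split concrete, here is a small, self-contained sketch (not part of the commit) of how the two lists can be used to toggle gradient computation before a timed backward pass; the toy model and the all-leafs split are illustrative assumptions only:

import torch
from torch.nn import Embedding, Linear, Sequential

# Toy model containing an embedding layer, mirroring the new `embed` branch above
model = Sequential(Embedding(100, 16), Linear(16, 2))

# Assumed split for the grad_embed_weights=True case: every parameter is a leaf
leafs, no_leafs = list(model.parameters()), []

for p in leafs:
    p.requires_grad_(True)   # gradients are computed and stored for these
for p in no_leafs:
    p.requires_grad_(False)  # these are excluded from the backward pass

loss = model(torch.randint(0, 100, (4, 10))).sum()
loss.backward()  # only leaf parameters receive .grad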

36 changes: 35 additions & 1 deletion experiments/util/models.py
@@ -6,6 +6,7 @@

import torchvision.models as tvm
from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Module, ReLU, Sequential
from transformers import AutoConfig, AutoModelForCausalLM

from memsave_torch.nn import (
MemSaveConv2d,
@@ -26,7 +27,7 @@ def prefix_in_pairs(prefix: str, it: List[str]) -> List[str]:
Args:
prefix (str): Prefix to be added
it (List[str]): The list to be prefixed
Returns:
List[str]: The output iterable with items prefixed in pairs
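
A minimal usage sketch (not part of the commit), assuming the function keeps each original entry and inserts its prefixed counterpart right after it, which is how paper_demo.py and the model lists below appear to use it; the import path assumes the repository root is on PYTHONPATH:

from experiments.util.models import prefix_in_pairs

models = prefix_in_pairs("memsave_", ["gpt2"])
# expected: ["gpt2", "memsave_gpt2"]
models = prefix_in_pairs("memsave_", ["resnet101", "vgg16"])
# expected: ["resnet101", "memsave_resnet101", "vgg16", "memsave_vgg16"]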
@@ -73,6 +74,12 @@ def convert_to_memory_saving_defaultsoff(
)


def get_transformers_config(model_name: str) -> AutoConfig:
"""Return the Hugging Face config for `model_name`, stripping the "memsave_" prefix if present."""
if model_name.startswith("memsave_"):
model_name = model_name.split("memsave_")[1]
return AutoConfig.from_pretrained(model_name)
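
A short usage sketch (not part of the commit); the import path is an assumption and the call downloads the config on first use:

from experiments.util.models import get_transformers_config

config = get_transformers_config("memsave_gpt2")
# equivalent to AutoConfig.from_pretrained("gpt2"); for GPT-2 small, config.n_embd == 768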


# CONV
conv_input_shape: Tuple[int, int, int] = (1, 1, 1)

@@ -161,6 +168,8 @@ def _convrelupool_model1(num_blocks=5) -> Module:
segmentation_models = prefix_in_pairs("memsave_", segmentation_models)
models_without_norm = ["deepmodel", "vgg16"]
models_without_norm = prefix_in_pairs("memsave_", models_without_norm)
transformers_models = ["gpt2"]
transformers_models = prefix_in_pairs("memsave_", transformers_models)

conv_model_fns = {
"deepmodel": _conv_model1,
@@ -213,6 +222,11 @@ def _convrelupool_model1(num_blocks=5) -> Module:
"memsave_resnext101_64x4d": lambda: convert_to_memory_saving(
tvm.resnext101_64x4d()
),
"gpt2": lambda: AutoModelForCausalLM.from_pretrained("gpt2"),
"memsave_gpt2": lambda: convert_to_memory_saving(
AutoModelForCausalLM.from_pretrained("gpt2")
),
# For paper
"memsave_resnet101_conv": lambda: convert_to_memory_saving_defaultsoff(
tvm.resnet101(), conv2d=True
),
@@ -267,6 +281,26 @@ def forward(self, loss_dict):
return sum(loss_dict.values())


class TransformersModelWrapper(Module):
"""Small wrapper around `transformers` models to support interop with existing measurement code"""

def __init__(self, model_fn) -> None:
"""Init"""
super().__init__()
self.model = model_fn()

def forward(self, x):
"""Forward
Args:
x: x
Returns:
output: model output
"""
return self.model(input_ids=x, use_cache=False)["logits"]
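
A rough usage sketch (not part of the commit) showing how the wrapper exposes a vision-model-like interface; it assumes the gpt2 checkpoint can be downloaded and that conv_model_fns is importable from the repository root:

import torch
from experiments.util.models import TransformersModelWrapper, conv_model_fns

wrapper = TransformersModelWrapper(conv_model_fns["memsave_gpt2"])
input_ids = torch.randint(0, 50257, (2, 16))  # GPT-2 vocabulary size is 50257
logits = wrapper(input_ids)
# logits has shape (2, 16, 50257) and can be fed to CrossEntropyLoss after flattening
# the batch and sequence dimensions, just like the logits of the vision models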


# LINEAR
linear_input_shape: int = 1

8 changes: 8 additions & 0 deletions memsave_torch/nn/__init__.py
@@ -11,6 +11,7 @@
from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv1d import MemSaveConv1d
from memsave_torch.nn.Conv2d import MemSaveConv2d
from memsave_torch.nn.Dropout import MemSaveDropout
from memsave_torch.nn.LayerNorm import MemSaveLayerNorm
from memsave_torch.nn.Linear import MemSaveLinear
from memsave_torch.nn.MaxPool import MemSaveMaxPool2d
@@ -26,6 +27,7 @@ def convert_to_memory_saving(
relu=True,
maxpool2d=True,
layernorm=True,
dropout=True,
verbose=False,
clone_params=False,
) -> nn.Module:
@@ -44,6 +46,7 @@
relu (bool, optional): Whether to replace `nn.ReLU` layers
maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers
layernorm (bool, optional): Whether to replace `nn.LayerNorm` layers
dropout (bool, optional): Whether to replace `nn.Dropout` layers
verbose (bool, optional): Whether to print which layers were replaced
clone_params (bool, optional): Whether to clone the layer parameters or use directly
@@ -82,6 +85,11 @@ def convert_to_memory_saving(
"cls": nn.LayerNorm,
"convert_fn": MemSaveLayerNorm.from_nn_LayerNorm,
},
{
"allowed": dropout,
"cls": nn.Dropout,
"convert_fn": MemSaveDropout.from_nn_dropout,
},
]

import copy
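A minimal sketch of the new dropout conversion path (not part of the commit); the toy model is illustrative and dropout=True is already the default:

import torch.nn as nn
from memsave_torch.nn import convert_to_memory_saving

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1), nn.ReLU())
memsave_model = convert_to_memory_saving(model, dropout=True, verbose=True)
# verbose=True prints which layers were replaced; with dropout=False the
# nn.Dropout layer would be left untouched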
