move utils to experiments, add conv1d
plutonium-239 committed Apr 11, 2024
1 parent c3b1c46 commit cc4c916
Showing 20 changed files with 122 additions and 82 deletions.
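
In short, utility code that previously lived under memsave_torch.util now lives under experiments.util, and the demo/benchmark scripts move with it. A minimal sketch of the resulting import change for downstream code (module paths taken from the diffs below; the surrounding script is hypothetical):

# before this commit
from memsave_torch.util import collect_results
from memsave_torch.util.models import prefix_in_pairs

# after this commit
from experiments.util import collect_results
from experiments.util.models import prefix_in_pairs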
1 change: 1 addition & 0 deletions experiments/exp01_llm_finetuning/__init__.py
@@ -0,0 +1 @@
"""Measure memory savings on fine-tuning an LLM."""
106 changes: 106 additions & 0 deletions experiments/exp01_llm_finetuning/run.py
@@ -0,0 +1,106 @@
"""Measure peak memory of the forward pass."""

# pip install transformers peft

import sys
from os import path

import torch
from peft import LoraConfig, get_peft_model
from torch import manual_seed
from torch.nn import Conv1d, LayerNorm, Linear
from transformers import (
    AutoModelForCausalLM,
)

HEREDIR = path.dirname(path.abspath(__file__))
LIBDIR = path.join(HEREDIR, "memsave_torch")
if LIBDIR not in sys.path:
    sys.path.append(LIBDIR)

from memsave_torch.nn import (
    MemSaveConv1d,
    MemSaveLayerNorm,
    MemSaveLinear,
    recursive_setattr,
)


def print_trainable_parameters(model):
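    """Print the number of trainable parameters, the total parameter count, and their ratio in percent."""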
    trainable_params = 0
    all_param = 0
    for name, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )


def main():
    manual_seed(0)

    memsave = True

    # config = GPT2Config.from_pretrained("gpt2")
    # config.hidden_dropout_prob = 0
    # config.attention_probs_dropout_prob = 0
    # model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")

    lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        # target_modules=["c_attn"], # LoRA on the attention weights, GPT2
        target_modules=["q_proj", "v_proj"], # LoRA on the attention weight, GPT neo
        lora_dropout=0.1,
        bias="none",
    )
    lora_model = get_peft_model(model, lora_config)
    # print_trainable_parameters(lora_model)

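    # Swap supported layers (Linear, Conv1d, LayerNorm) for their MemSave counterparts in place,
    # copying each parameter's requires_grad flag so the LoRA setup stays intact.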
    if memsave:
        for name, layer in model.named_modules():
            if isinstance(layer, Linear):
                new_layer = MemSaveLinear.from_nn_Linear(layer)
                for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
                    p2.requires_grad = p1.requires_grad
                recursive_setattr(model, name, new_layer)
            elif isinstance(layer, Conv1d):
                new_layer = MemSaveConv1d.from_nn_Conv1d(layer)
                for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
                    p2.requires_grad = p1.requires_grad
                recursive_setattr(model, name, new_layer)
            elif isinstance(layer, LayerNorm):
                new_layer = MemSaveLayerNorm.from_nn_LayerNorm(layer)
                for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
                    p2.requires_grad = p1.requires_grad
                recursive_setattr(model, name, new_layer)

    batch_size = 8
    seq_len = 512
    input_ids = torch.randint(10, (batch_size, seq_len))

    out = lora_model(input_ids)

    # print(out)
    print({type(layer) for layer in model.modules()})

    # for name, layer in model.named_modules():
    # print(name, type(layer))

    # for name, param in model.named_parameters():
    # if param.requires_grad:
    # print(f"{name}")

    # print(f"{name} requires_grad = {param.requires_grad}")
    # print(out["logits"].flatten()[0:10])
    return out


if __name__ == "__main__":
    main()
    # max_usage = memory_usage(main, interval=1e-3, max_usage=True)
    # print(f"Peak mem: {max_usage}.")
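
The commented-out lines above point at measuring peak memory with memory-profiler's memory_usage; a minimal sketch of that pattern, assuming pip install memory-profiler (the package is not a dependency added by this commit):

from memory_profiler import memory_usage

max_usage = memory_usage(main, interval=1e-3, max_usage=True)  # peak resident memory in MiB
print(f"Peak mem: {max_usage}.")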
@@ -7,7 +7,7 @@

import pandas as pd

from memsave_torch.util.collect_results import case_mapping
from experiments.util.collect_results import case_mapping


def main(base_dir: str):
6 changes: 3 additions & 3 deletions memsave_torch/paper_demo.py → experiments/paper_demo.py
@@ -10,8 +10,8 @@

from tqdm import tqdm

from memsave_torch.util import collect_results
from memsave_torch.util.models import prefix_in_pairs
from experiments.util import collect_results
from experiments.util.models import prefix_in_pairs

estimators = ["time", "memory"]
estimators = ["memory"]
@@ -111,7 +111,7 @@
pbar.set_description(f"{model} {estimate} case {case}")
case_str = f"--case {' '.join(case)}" if case is not None else ""
cmd = (
f"python memsave_torch/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} "
f"python experiments/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} "
+ f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes}"
)
proc = subprocess.run(shlex.split(cmd), capture_output=True)
File renamed without changes.
File renamed without changes.
@@ -17,8 +17,8 @@
from torch import Tensor, device, manual_seed, rand, randint
from torch.nn import CrossEntropyLoss, Module

from memsave_torch.util import models
from memsave_torch.util.measurements import (
from experiments.util import models
from experiments.util.measurements import (
    MemoryMeasurement,
    RuntimeMeasurement,
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 0 additions & 1 deletion memsave_torch/__init__.py
@@ -1,4 +1,3 @@
"""memsave_torch package"""

import memsave_torch.nn as nn # noqa: F401
import memsave_torch.util as util # noqa: F401
70 changes: 2 additions & 68 deletions memsave_torch/nn/Conv1d.py
@@ -5,6 +5,7 @@

import torch
import torch.nn as nn
from Conv2d import _MemSaveConv


class MemSaveConv1d(nn.Conv1d):
@@ -100,73 +101,6 @@ def from_nn_Conv1d(cls, conv1d: nn.Conv1d):
        return obj


class _MemSaveConv1d(torch.autograd.Function):
    @staticmethod
    def forward(x, weight, bias, stride, padding, dilation, groups):
        return nn.functional.conv1d(x, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, weight, bias, stride, padding, dilation, groups = inputs
        # print('setting up context', ctx.needs_input_grad)
        need_grad = []
        if ctx.needs_input_grad[0]:
            # print('weight saved')
            need_grad.append(weight)
        if ctx.needs_input_grad[1]:
            # print('x saved')
            need_grad.append(x)
        # bias doesnt need anything for calc
        ctx.bias_exists = bias is not None
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.x_shape = x.shape
        ctx.weight_shape = weight.shape

        ctx.save_for_backward(*need_grad)

    @staticmethod
    def backward(ctx, grad_output):
        x = weight = None

        current_idx = 0
        if ctx.needs_input_grad[0]:
            # print('0 needs weight')
            weight = ctx.saved_tensors[current_idx]
            current_idx += 1
        elif ctx.needs_input_grad[1]:
            # print('1 needs x')
            x = ctx.saved_tensors[current_idx]
            current_idx += 1

        if weight is not None:
            x = torch.zeros(ctx.x_shape, device=weight.device)
        if x is not None:
            weight = torch.zeros(ctx.weight_shape, device=x.device)

        # print(current_idx)

        grad_x, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
            grad_output,
            x,
            weight,
            weight.shape[0] if ctx.bias_exists else None,
            ctx.stride,
            ctx.padding,
            ctx.dilation,
            False,
            [0],
            ctx.groups,
            ctx.needs_input_grad[:3],
        )

        # print('grads are ', (grad_x is not None), (grad_weight is not None), (grad_bias is not None))

        return grad_x, grad_weight, grad_bias, None, None, None, None, None


def conv1dMemSave(
    input, weight, bias, stride, padding, dilation, groups
) -> torch.Tensor:
@@ -184,4 +118,4 @@ def conv1dMemSave(
    Returns:
        torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
    """
    return _MemSaveConv1d.apply(input, weight, bias, stride, padding, dilation, groups)
    return _MemSaveConv.apply(input, weight, bias, stride, padding, dilation, groups)
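
For reference, a minimal sketch of converting an existing torch.nn.Conv1d with the constructor shown above, mirroring the replacement pattern used in experiments/exp01_llm_finetuning/run.py (layer sizes are arbitrary illustration values; only the conversion path is exercised here, not the forward/backward pass):

import torch.nn as nn

from memsave_torch.nn import MemSaveConv1d, recursive_setattr

model = nn.Sequential(nn.Conv1d(in_channels=3, out_channels=8, kernel_size=5, padding=2), nn.ReLU())
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv1d):
        new_layer = MemSaveConv1d.from_nn_Conv1d(layer)  # build a MemSave layer from the existing Conv1d (see from_nn_Conv1d above)
        for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
            p2.requires_grad = p1.requires_grad
        recursive_setattr(model, name, new_layer)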
4 changes: 2 additions & 2 deletions memsave_torch/nn/Conv2d.py
@@ -100,7 +100,7 @@ def from_nn_Conv2d(cls, conv2d: nn.Conv2d):
        return obj


class _MemSaveConv2d(torch.autograd.Function):
class _MemSaveConv(torch.autograd.Function):
    @staticmethod
    def forward(x, weight, bias, stride, padding, dilation, groups):
        return nn.functional.conv2d(x, weight, bias, stride, padding, dilation, groups)
@@ -185,4 +185,4 @@ def conv2dMemSave(
    Returns:
        torch.Tensor: Output of the conv operation [B, C_out, H_out, W_out]
    """
    return _MemSaveConv2d.apply(input, weight, bias, stride, padding, dilation, groups)
    return _MemSaveConv.apply(input, weight, bias, stride, padding, dilation, groups)
10 changes: 5 additions & 5 deletions test/test_utils.py
@@ -11,8 +11,8 @@ def test_all():
    import subprocess
    from time import sleep

    from memsave_torch.util import collect_results
    from memsave_torch.util.models import prefix_in_pairs
    from experiments.util import collect_results
    from experiments.util.models import prefix_in_pairs

    estimators = ["time", "memory"]

@@ -66,7 +66,7 @@ def test_all():
for case in cases:
case_str = f"--case {' '.join(case)}" if case is not None else ""
cmd = (
f"python memsave_torch/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} "
f"python experiments/util/estimate.py --architecture {architecture} --model {model} --estimate {estimate} {case_str} "
+ f"--device {device} -B {batch_size} -C_in {input_channels} -HW {input_HW} -n_class {num_classes} "
+ f"--results_dir {results_dir}"
)
@@ -81,6 +81,6 @@

collector.finish()

import memsave_torch.get_best_results
import experiments.get_best_results

memsave_torch.get_best_results.main(results_dir)
experiments.get_best_results.main(results_dir)
