Commit

conv1d + minor
plutonium-239 committed Apr 11, 2024
1 parent cc4c916 commit 574b9ea
Showing 4 changed files with 37 additions and 36 deletions.
47 changes: 13 additions & 34 deletions experiments/exp01_llm_finetuning/run.py
@@ -2,34 +2,26 @@

# pip install transformers peft

import sys
from os import path

import torch
from peft import LoraConfig, get_peft_model
from torch import manual_seed
from torch.nn import Conv1d, LayerNorm, Linear
from torch.nn import Module
from transformers import (
AutoModelForCausalLM,
)

HEREDIR = path.dirname(path.abspath(__file__))
LIBDIR = path.join(HEREDIR, "memsave_torch")
if LIBDIR not in sys.path:
sys.path.append(LIBDIR)
from memsave_torch.nn import convert_to_memory_saving

from memsave_torch.nn import (
MemSaveConv1d,
MemSaveLayerNorm,
MemSaveLinear,
recursive_setattr,
)

def print_trainable_parameters(model: Module):
"""Function that prints how many parameters are trainable in the given model
def print_trainable_parameters(model):
Args:
model (Module): The model
"""
trainable_params = 0
all_param = 0
for name, param in model.named_parameters():
for param in model.parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
@@ -39,6 +31,7 @@ def print_trainable_parameters(model):


def main():
"""Runs the LLM experiment after replacing layers"""
manual_seed(0)

memsave = True
@@ -58,26 +51,12 @@ def main():
lora_dropout=0.1,
bias="none",
)
lora_model = get_peft_model(model, lora_config)
# print_trainable_parameters(lora_model)

if memsave:
for name, layer in model.named_modules():
if isinstance(layer, Linear):
new_layer = MemSaveLinear.from_nn_Linear(layer)
for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
p2.requires_grad = p1.requires_grad
recursive_setattr(model, name, new_layer)
elif isinstance(layer, Conv1d):
new_layer = MemSaveConv1d.from_nn_Conv1d(layer)
for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
p2.requires_grad = p1.requires_grad
recursive_setattr(model, name, new_layer)
elif isinstance(layer, LayerNorm):
new_layer = MemSaveLayerNorm.from_nn_LayerNorm(layer)
for p1, p2 in zip(layer.parameters(), new_layer.parameters()):
p2.requires_grad = p1.requires_grad
recursive_setattr(model, name, new_layer)
model = convert_to_memory_saving(model)

lora_model = get_peft_model(model, lora_config)
# print_trainable_parameters(lora_model)

batch_size = 8
seq_len = 512
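
Note on this hunk: the removed loop copied each layer's requires_grad flags by hand, while the single convert_to_memory_saving call shown above no longer makes that bookkeeping visible. A minimal sanity check along these lines (hypothetical helper, not part of the commit; assumes the script's model and the convert_to_memory_saving import from the diff) would confirm the trainable-parameter count is unchanged:

    # Hypothetical check, not in the commit: conversion should leave the set of
    # trainable parameters untouched.
    def count_trainable(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    n_before = count_trainable(model)
    model = convert_to_memory_saving(model)
    assert count_trainable(model) == n_before, "conversion changed trainable parameters"
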
15 changes: 14 additions & 1 deletion experiments/util/models.py
@@ -40,24 +40,37 @@ def convert_to_memory_saving_defaultsoff(
model: Module,
linear=False,
conv2d=False,
conv1d=False,
batchnorm2d=False,
relu=False,
maxpool2d=False,
layernorm=False,
) -> Module:
"""Extension of the `convert_to_memory_saving` function with all defaults as off
Args:
model (Module): Input model
linear (bool, optional): Whether to replace linear layers
conv2d (bool, optional): Whether to replace conv2d layers
conv1d (bool, optional): Whether to replace conv1d layers
batchnorm2d (bool, optional): Whether to replace batchnorm2d layers
relu (bool, optional): Whether to replace relu layers
maxpool2d (bool, optional): Whether to replace maxpool2d layers
layernorm (bool, optional): Whether to replace layernorm layers
Returns:
Module: The converted memory saving model
"""
return convert_to_memory_saving(model, linear, conv2d, batchnorm2d, relu, maxpool2d)
return convert_to_memory_saving(
model,
linear=linear,
conv2d=conv2d,
conv1d=conv1d,
batchnorm2d=batchnorm2d,
relu=relu,
maxpool2d=maxpool2d,
layernorm=layernorm,
)


# CONV
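
For illustration only (this usage is not part of the diff): with the new conv1d flag threaded through to convert_to_memory_saving, converting just the Conv1d layers of a model would presumably look like the following sketch.

    # Illustrative sketch: replace only Conv1d layers, leaving every other layer type alone.
    # Assumes experiments/util/models.py is importable as experiments.util.models.
    from torch.nn import Conv1d, ReLU, Sequential

    from experiments.util.models import convert_to_memory_saving_defaultsoff

    model = Sequential(Conv1d(3, 8, kernel_size=5), ReLU(), Conv1d(8, 8, kernel_size=5))
    model = convert_to_memory_saving_defaultsoff(model, conv1d=True)  # all other flags stay off
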
3 changes: 2 additions & 1 deletion experiments/visual_abstract/run.py
@@ -33,7 +33,8 @@
args = parser.parse_args()


def main():
def main(): # noqa: C901
"""Runs exps for generating the data of the visual abstract"""
manual_seed(0)

# create the input
8 changes: 8 additions & 0 deletions memsave_torch/nn/__init__.py
@@ -9,6 +9,7 @@
import torch.nn as nn

from memsave_torch.nn.BatchNorm import MemSaveBatchNorm2d
from memsave_torch.nn.Conv1d import MemSaveConv1d
from memsave_torch.nn.Conv2d import MemSaveConv2d
from memsave_torch.nn.LayerNorm import MemSaveLayerNorm
from memsave_torch.nn.Linear import MemSaveLinear
@@ -20,6 +21,7 @@ def convert_to_memory_saving(
model: nn.Module,
linear=True,
conv2d=True,
conv1d=False,
batchnorm2d=True,
relu=True,
maxpool2d=True,
@@ -37,6 +39,7 @@ def convert_to_memory_saving(
model (nn.Module): The input model
linear (bool, optional): Whether to replace `nn.Linear` layers
conv2d (bool, optional): Whether to replace `nn.Conv2d` layers
conv1d (bool, optional): Whether to replace `nn.Conv1d` layers
batchnorm2d (bool, optional): Whether to replace `nn.BatchNorm2d` layers
relu (bool, optional): Whether to replace `nn.ReLU` layers
maxpool2d (bool, optional): Whether to replace `nn.MaxPool2d` layers
@@ -64,6 +67,11 @@ def convert_to_memory_saving(
"cls": nn.Conv2d,
"convert_fn": MemSaveConv2d.from_nn_Conv2d,
},
{
"allowed": conv1d,
"cls": nn.Conv1d,
"convert_fn": MemSaveConv1d.from_nn_Conv1d,
},
{
"allowed": batchnorm2d,
"cls": nn.BatchNorm2d,
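
Each entry in this list gates a torch.nn class behind an "allowed" flag and pairs it with a MemSave conversion function; the loop that consumes the list lies outside the shown hunk. A hedged sketch of how such a dispatch table is typically applied (not the actual memsave_torch implementation):

    # Illustrative sketch only: walk direct children, replace a child when an enabled
    # table entry matches its class, otherwise recurse into it.
    import torch.nn as nn

    def apply_layer_table(module: nn.Module, layers: list) -> nn.Module:
        for name, child in module.named_children():
            for entry in layers:
                if entry["allowed"] and isinstance(child, entry["cls"]):
                    setattr(module, name, entry["convert_fn"](child))
                    break
            else:
                apply_layer_table(child, layers)
        return module
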
