Multi-head attention version with a loop over the batch dimension to reduce memory usage #10

Open. Wants to merge 4 commits into base: main.

Changes from 3 commits
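For context: the memory saving comes from computing attention scores for only a slice of the flattened (batch × heads) dimension at a time, so the full `[batch * heads, seq_len, seq_len]` score tensor is never materialized at once. A minimal standalone sketch of the idea (not this PR's exact API; `chunked_attention` and `chunk_size` are illustrative names):

```python
import torch
import torch.nn.functional as F


def chunked_attention(q, k, v, chunk_size):
    """q, k, v: [batch * heads, seq_len, head_dim]; scores are built chunk_size matrices at a time."""
    out = torch.empty_like(q)
    scale = q.shape[-1] ** -0.5
    for start in range(0, q.shape[0], chunk_size):
        end = min(start + chunk_size, q.shape[0])
        scores = torch.bmm(q[start:end], k[start:end].transpose(1, 2)) * scale  # [chunk, seq, seq]
        probs = F.softmax(scores, dim=-1)  # only this chunk's probabilities are alive at any point
        out[start:end] = torch.bmm(probs, v[start:end])
    return out
```

The PR implements this loop inside a custom `autograd.Function` (added below), so the chunked probabilities are also recomputed, rather than stored, for the backward pass.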
29 changes: 23 additions & 6 deletions lean_transformer/attn.py
@@ -7,6 +7,8 @@
 
 from lean_transformer.rotary import RotaryEmbeddings
 
+from . import batch_step_attn_core_func
+
 
 class LeanSelfAttention(nn.Module):
     def __init__(
@@ -15,6 +17,7 @@ def __init__(
         num_attention_heads: int,
         dropout: float = 0,
         layer_norm_eps: float = 1e-12,
+        pre_layer_norm: bool = True,
         post_layer_norm: bool = False,
         qkv_proj: Optional[nn.Linear] = None,
         out_proj: Optional[nn.Linear] = None,
@@ -34,6 +37,7 @@ def __init__(
         :param hidden_size: base hidden size of the transformer, before q/k/v projections
         :param num_attention_heads: number of heads, as defined in the original transformer
         :param dropout: hidden dropout probability, applied to the output projection (before adding residual)
+        :param pre_layer_norm: if set, applies layer norm to input tensor
         :param layer_norm_eps: see torch.nn.functional.layer_norm
         :param post_layer_norm: if set, applies an additional layer norm to projected attention outputs before residuals,
             as proposed in the CogView paper ( arXiv:2105.13290 ). This is meant to make fp16 training
@@ -58,13 +62,13 @@ def __init__(
         assert self.qkv_proj.in_features == self.out_proj.in_features == self.out_proj.out_features == hidden_size
         assert self.qkv_proj.out_features == hidden_size * 3
 
-        self.pre_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.pre_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if pre_layer_norm else None
         self.post_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if post_layer_norm else None
         self.output_dropout = nn.Dropout(dropout, inplace=False)
         self.residual, self.checkpoint_attention_core = residual, checkpoint_attention_core
 
     def forward(self, hidden_states, attention_mask=None, output_attentions=False):
-        hidden_states_ln = self.pre_layer_norm(hidden_states)
+        hidden_states_ln = self.pre_layer_norm(hidden_states) if self.pre_layer_norm else hidden_states
         qkv_output = self.qkv_proj(hidden_states_ln)
         query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1)
         attention_output, attention_probs = self._maybe_checkpoint(
@@ -83,12 +87,14 @@ def _maybe_checkpoint(self, func, *args):
 
 
 class SimpleAttentionCore(nn.Module):
-    def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0):
+    def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0,
+                 batched_attention_size: int = -1):
        super().__init__()
         assert hidden_size % num_attention_heads == 0
         self.attention_dropout = nn.Dropout(attention_probs_dropout)
         self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads
         self.attention_head_size = hidden_size // num_attention_heads
+        self.batched_attention_size = batched_attention_size
 
     def forward(self, query, key, value, attention_mask):
         """
@@ -105,7 +111,7 @@ def forward(self, query, key, value, attention_mask):
             assert torch.is_floating_point(attention_mask), "expected float mask with negative values for masked items"
         return self._attention_core_forward(
             query, key, value, attention_mask, self.num_attention_heads, self.attention_dropout.p,
-            self.training, scale_inplace=False,
+            self.training, scale_inplace=False, batched_attention_size=self.batched_attention_size
         )
 
     @staticmethod
@@ -114,8 +120,18 @@ def _attention_core_forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
-        num_attention_heads: int, attention_dropout: float, training: bool, scale_inplace: bool
+        num_attention_heads: int, attention_dropout: float, training: bool, scale_inplace: bool,
+        batched_attention_size: int = -1
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        if batched_attention_size != -1:
+            hidden_size = query.shape[-1]
+            attention_head_size = hidden_size // num_attention_heads
+            scaling = attention_head_size ** -0.5
+            ret = batch_step_attn_core_func.batch_step_attn_core_func(num_attention_heads, scaling,
+                                                                      batched_attention_size, query, key, value, attention_mask)
+            return ret, None
+
         # transpose from [batch, seq_length, full_hid_size] to [batch, num_heads, seq_length, head_size]
         new_query_shape = query.shape[:-1] + (num_attention_heads, -1)
         new_kv_shape = key.shape[:-1] + (num_attention_heads, -1)
@@ -168,4 +184,5 @@ def rotate(self, tensor: torch.Tensor):
     def forward(self, query, key, value, attention_mask):
         return self._attention_core_forward(
             self.rotate(query), self.rotate(key), value, attention_mask, self.num_attention_heads,
-            self.attention_dropout.p, self.training, scale_inplace=True)
+            self.attention_dropout.p, self.training, scale_inplace=True,
+            batched_attention_size=self.batched_attention_size)
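Taken on its own, the modified core can be exercised directly. A small smoke test, assuming `SimpleAttentionCore` keeps the constructor shown in this diff and that a `None` mask is accepted (as in the PR's own test):

```python
import torch
from lean_transformer.attn import SimpleAttentionCore

core = SimpleAttentionCore(hidden_size=128, num_attention_heads=8, batched_attention_size=4)
q = k = v = torch.randn(2, 16, 128)  # [batch, seq_len, hidden_size]
out, probs = core(q, k, v, attention_mask=None)

assert out.shape == (2, 16, 128)
assert probs is None  # the chunked path does not return attention probabilities
```

Note that the chunked path returns `None` in place of the attention probabilities, so callers that request them will not get a tensor back on this path.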
173 changes: 173 additions & 0 deletions lean_transformer/batch_step_attn_core_func.py
@@ -0,0 +1,173 @@

import torch
import torch.nn.functional as F


class BatchStepAttnCoreFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        heads,
        scale,
        loop_batch_step,
        queries,
        keys,
        values,
        attention_mask
    ):
        num_seqs = keys.size(0)
        seq_len = keys.size(1)
        hidden_dim = keys.size(2)
        head_dim = hidden_dim // heads

        heads_t = torch.tensor([heads])
        scale_t = torch.tensor([scale])
        loop_batch_step_t = torch.tensor([loop_batch_step])
        num_seqs_t = torch.tensor([num_seqs])
        seq_len_t = torch.tensor([seq_len])
        hidden_dim_t = torch.tensor([hidden_dim])

        # [num_seqs, seq_len, hidden_dim] -> [num_seqs * heads, seq_len, head_dim]
        queries = queries.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim)
        keys = keys.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim)
        values = values.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim)

        matmul2_results = torch.empty(
            (num_seqs * heads, seq_len, head_dim), dtype=keys.dtype, device=keys.device
        )

        iter_step = int(loop_batch_step_t.item())
        iter_count = num_seqs * heads
        for iter_idx in range(0, iter_count, iter_step):
            ibatch_range = [iter_idx, min(iter_idx + iter_step, iter_count)]

            # output: [batch, seql_q, seql_k]
            matmul1_results = torch.bmm(
                queries[ibatch_range[0]:ibatch_range[1], :, :],
                keys[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2)
            ) * scale_t[0]  # 0-dim scale_t[0] also works when inputs live on CUDA

            if attention_mask is not None:
                # the mask is indexed by sequence, so map every flattened (sequence, head)
                # index in this chunk back to its sequence before slicing the mask
                seq_ids = torch.arange(ibatch_range[0], ibatch_range[1], device=keys.device) // heads
                matmul1_results += attention_mask[seq_ids, 0, :, :]

            # output: [batch, seql_q, seql_k]
            softmax_results = F.softmax(matmul1_results, dim=-1)

            matmul2_results[ibatch_range[0]:ibatch_range[1], :, :] = torch.bmm(
                softmax_results,
                values[ibatch_range[0]:ibatch_range[1], :, :])

        outputs = matmul2_results.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim)

        ctx.save_for_backward(
            heads_t,
            scale_t,
            loop_batch_step_t,
            num_seqs_t,
            seq_len_t,
            hidden_dim_t,
            queries,
            keys,
            values,
            attention_mask
        )

        return outputs.detach()

    @staticmethod
    def backward(ctx, output_grads):
        (
            heads_t,
            scale_t,
            loop_batch_step_t,
            num_seqs_t,
            seq_len_t,
            hidden_dim_t,
            queries,
            keys,
            values,
            attention_mask
        ) = ctx.saved_tensors

        heads = heads_t[0].item()
        num_seqs = int(num_seqs_t.item())
        seq_len = int(seq_len_t.item())
        hidden_dim = int(hidden_dim_t.item())
        head_dim = hidden_dim // heads

        # [seqs * heads, seql, head_dim]
        queries_grads = torch.empty((num_seqs * heads, seq_len, head_dim), dtype=queries.dtype, device=queries.device)
        keys_grads = torch.empty((num_seqs * heads, seq_len, head_dim), dtype=keys.dtype, device=keys.device)
        values_grads = torch.empty((num_seqs * heads, seq_len, head_dim), dtype=values.dtype, device=values.device)

        # output_grads: [num_seqs, seq_len, hidden_dim] -> [num_seqs * heads, seq_len, head_dim]
        output_grads = output_grads.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim)

        iter_step = int(loop_batch_step_t.item())
        iter_count = num_seqs * heads
        for iter_idx in range(0, iter_count, iter_step):
            ibatch_range = [iter_idx, min(iter_idx + iter_step, iter_count)]
            ibatch_sz = ibatch_range[1] - ibatch_range[0]

            # reconstruct softmax_results instead of storing it in forward: trades compute for memory
            # output: [seqs*heads, seql_q, seql_k]
            matmul1_results = torch.bmm(
                queries[ibatch_range[0]:ibatch_range[1], :, :],
                keys[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2)
            ) * scale_t[0]

            if attention_mask is not None:
                # same per-sequence mask indexing as in forward
                seq_ids = torch.arange(ibatch_range[0], ibatch_range[1], device=keys.device) // heads
                matmul1_results += attention_mask[seq_ids, 0, :, :]

            # output: [seqs*heads, seql_q, seql_k]
            softmax_results = F.softmax(matmul1_results, dim=-1)

            # output_grads [ seqs * heads, seql, head_dim ]
            # values       [ seqs * heads, seql, head_dim ]
            # output:      [ seqs * heads, seql, seql ]
            matmul2_dgrad1 = torch.bmm(output_grads[ibatch_range[0]:ibatch_range[1], :, :],
                                       values[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2))

            # softmax_results [ seqs * heads, seql, seql ]
            # output_grads    [ seqs * heads, seql, head_dim ]
            # output:         [ seqs * heads, seql, head_dim ]
            values_grads[ibatch_range[0]:ibatch_range[1], :, :] = torch.bmm(
                softmax_results.transpose(1, 2),
                output_grads[ibatch_range[0]:ibatch_range[1], :, :])
            # output: [ seqs * heads, seql, seql ]
            softmax_grads = torch._softmax_backward_data(matmul2_dgrad1, softmax_results, -1, softmax_results.dtype)

            softmax_grads = softmax_grads.view(ibatch_sz, seq_len, seq_len)

            # dQ = scale * dS @ K; beta=0.0 makes baddbmm ignore the uninitialized input buffer
            queries_grads[ibatch_range[0]:ibatch_range[1], :, :] = torch.baddbmm(
                queries_grads[ibatch_range[0]:ibatch_range[1], :, :],
                softmax_grads,
                keys[ibatch_range[0]:ibatch_range[1], :, :],
                beta=0.0,
                alpha=scale_t[0],
            )

            # dK = scale * dS^T @ Q
            keys_grads[ibatch_range[0]:ibatch_range[1], :, :] = torch.baddbmm(
                keys_grads[ibatch_range[0]:ibatch_range[1], :, :],
                softmax_grads.transpose(1, 2),
                queries[ibatch_range[0]:ibatch_range[1], :, :],
                beta=0.0,
                alpha=scale_t[0],
            )

        queries_grads = queries_grads.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim)
        keys_grads = keys_grads.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim)
        values_grads = values_grads.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim)

        # gradients are returned in the same order as forward's inputs (None for non-tensor args)
        return (
            None,  # heads
            None,  # scale
            None,  # loop_batch_step
            queries_grads,  # queries
            keys_grads,  # keys
            values_grads,  # values
            None,  # attention_mask
        )


batch_step_attn_core_func = BatchStepAttnCoreFunc.apply
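Since the backward pass recomputes the chunked softmax instead of saving it, it may be worth checking the hand-written gradients numerically. A possible double-precision check (not part of this PR; tiny shapes keep `gradcheck` fast, and the mask is left out):

```python
import torch

from lean_transformer.batch_step_attn_core_func import batch_step_attn_core_func

heads, num_seqs, seq_len, head_dim = 2, 2, 4, 8
hidden_dim = heads * head_dim
q, k, v = (
    torch.randn(num_seqs, seq_len, hidden_dim, dtype=torch.double, requires_grad=True)
    for _ in range(3)
)

# compares the custom backward against finite differences; loop_batch_step=2 exercises the chunking
torch.autograd.gradcheck(
    lambda q, k, v: batch_step_attn_core_func(heads, head_dim ** -0.5, 2, q, k, v, None),
    (q, k, v),
)
```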
8 changes: 6 additions & 2 deletions lean_transformer/transformer.py
@@ -95,6 +95,7 @@ def set_optimizations(
         checkpoint_attention_core: Optional[bool] = None,
         ffn_custom_grad: Optional[bool] = None,
         update_triton_blocksparse_ops: bool = False,
+        batched_attention_size: Optional[int] = None
     ):
         """
         Set one or more memory saving options for all compatible sub-modules. Options set to None remain unchanged.
@@ -137,8 +138,11 @@
             sequential.preserve_rng_state = preserve_rng_state
 
             for module in sequential.modules():
-                if checkpoint_attention_core is not None and isinstance(module, LeanSelfAttention):
-                    module.checkpoint_attention_core = checkpoint_attention_core
+                if isinstance(module, LeanSelfAttention):
+                    if checkpoint_attention_core is not None:
+                        module.checkpoint_attention_core = checkpoint_attention_core
+                    if batched_attention_size is not None:
+                        module.attention_core.batched_attention_size = batched_attention_size
                 elif ffn_custom_grad is not None and isinstance(module, LeanFFN):
                     module.ffn_custom_grad = ffn_custom_grad
                 else:
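For reference, a hypothetical call that switches the new option on for a model exposing the `set_optimizations()` above (only the `batched_attention_size` argument comes from this diff; `model` and the value 16 are illustrative):

```python
# `model` stands for any module exposing the set_optimizations() shown above
model.set_optimizations(batched_attention_size=16)  # each inner bmm covers 16 (sequence, head) slices

# passing None leaves the setting unchanged; the attention-core default of -1 disables the loop entirely
```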
46 changes: 46 additions & 0 deletions tests/test_attn.py
@@ -0,0 +1,46 @@
import pytest

import torch
import torch.nn as nn

import numpy as np

from lean_transformer.attn import LeanSelfAttention


@pytest.mark.forked
def test_lean_attn():
    torch.use_deterministic_algorithms(True)

    seq_length = 64
    num_seqs = 8
    hidden_dim = 128
    heads = 16

    gtruth_mha = nn.MultiheadAttention(hidden_dim, heads, bias=True,
                                       dropout=0, batch_first=True)

    for batch_step in [1, 2, 8, num_seqs * heads]:
        test_mha = LeanSelfAttention(hidden_dim, heads, dropout=0,
                                     pre_layer_norm=False, residual=False,
                                     checkpoint_attention_core=False,
                                     batched_attention_size=batch_step)

        test_mha.qkv_proj.weight = gtruth_mha.in_proj_weight
        test_mha.qkv_proj.bias = gtruth_mha.in_proj_bias
        test_mha.out_proj.weight = gtruth_mha.out_proj.weight
        test_mha.out_proj.bias = gtruth_mha.out_proj.bias

        device = torch.device('cpu')

        atol = 1e-6

        for _ in range(10):
            a = torch.randn((num_seqs, seq_length, hidden_dim), device=device)
            out0 = gtruth_mha(a, a, a)[0]
            out1 = test_mha(a)[0]
            out0.mean().backward()
            out1.mean().backward()
            out0 = out0.cpu().detach().numpy()
            out1 = out1.cpu().detach().numpy()
            assert np.allclose(out0, out1, atol=atol), f"{out0} {out1}"
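The test above only compares forward outputs; since the new path ships a hand-written backward, comparing gradients with respect to the input could be a cheap addition. A sketch reusing the same setup (the inputs need `requires_grad` and separate leaves so each module accumulates its own `.grad`):

```python
# possible extension of the inner loop: also compare gradients w.r.t. the input
a = torch.randn((num_seqs, seq_length, hidden_dim), device=device)
a0 = a.clone().requires_grad_(True)
a1 = a.clone().requires_grad_(True)

gtruth_mha(a0, a0, a0)[0].mean().backward()
test_mha(a1)[0].mean().backward()

assert np.allclose(a0.grad.numpy(), a1.grad.numpy(), atol=atol)
```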