From e2b1513b5f3d65bb1799c310714104fddde44749 Mon Sep 17 00:00:00 2001 From: Alexey Kuts Date: Sat, 19 Mar 2022 03:12:32 +0300 Subject: [PATCH 1/3] multihead attention version with loop by batch dimension to reduce memory usage --- lean_transformer/__init__.py | 2 +- lean_transformer/attn.py | 26 ++- lean_transformer/batch_step_attn_core_func.py | 163 ++++++++++++++++++ tests/test_attn.py | 48 ++++++ 4 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 lean_transformer/batch_step_attn_core_func.py create mode 100644 tests/test_attn.py diff --git a/lean_transformer/__init__.py b/lean_transformer/__init__.py index d3d8d80..3b538ef 100644 --- a/lean_transformer/__init__.py +++ b/lean_transformer/__init__.py @@ -1,5 +1,5 @@ from .ffn import LeanFFN -from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore +from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore, BatchStepAttentionCore from .rotary import RotaryEmbeddings, rotate from .sequence import SequentialWithKwargs, ReversibleWithKwargs, ActiveKwargs from .config import LeanTransformerConfig diff --git a/lean_transformer/attn.py b/lean_transformer/attn.py index 0a35cc8..0d74eed 100644 --- a/lean_transformer/attn.py +++ b/lean_transformer/attn.py @@ -7,6 +7,8 @@ from lean_transformer.rotary import RotaryEmbeddings +from . import batch_step_attn_core_func + class LeanSelfAttention(nn.Module): def __init__( @@ -16,6 +18,7 @@ def __init__( dropout: float = 0, layer_norm_eps: float = 1e-12, sandwich_norm: bool = False, + layer_norm: bool = True, dense_qkv: Optional[nn.Linear] = None, dense_out: Optional[nn.Linear] = None, residual: bool = True, @@ -34,6 +37,7 @@ def __init__( :param hidden_size: base hidden size of the transformer, before q/k/v projections :param num_attention_heads: number of heads, as defined in the original transformer :param dropout: hidden dropout probability, applied to the output projection (before adding residual) + :param layer_norm: if set, applies layer norm to input tensor :param layer_norm_eps: see torch.nn.functional.layer_norm :param sandwich_norm: if set, applies an additional layer norm to projected attention outputs before residuals, as proposed in the CogView paper ( arXiv:2105.13290 ). 
This is meant to make fp16 training @@ -58,13 +62,13 @@ def __init__( assert self.dense_qkv.in_features == self.dense_out.in_features == self.dense_out.out_features == hidden_size assert self.dense_qkv.out_features == hidden_size * 3 - self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if layer_norm else None self.sandwich_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) if sandwich_norm else None self.output_dropout = nn.Dropout(dropout, inplace=False) self.residual, self.checkpoint_attention_core = residual, checkpoint_attention_core def forward(self, hidden_states, attention_mask=None, output_attentions=False): - hidden_states_ln = self.layer_norm(hidden_states) + hidden_states_ln = self.layer_norm(hidden_states) if self.layer_norm else hidden_states qkv_output = self.dense_qkv(hidden_states_ln) query, key, value = qkv_output.split(self.hidden_size, dim=qkv_output.ndim - 1) attention_output, attention_probs = self._maybe_checkpoint( @@ -169,3 +173,21 @@ def forward(self, query, key, value, attention_mask): return self._attention_core_forward( self.rotate(query), self.rotate(key), value, attention_mask, self.num_attention_heads, self.attention_dropout.p, self.training, scale_inplace=True) + + +class BatchStepAttentionCore(SimpleAttentionCore): + def __init__(self, hidden_size: int, num_attention_heads: int, batch_step: int = 8, **kwargs): + super().__init__(hidden_size, num_attention_heads, **kwargs) + assert hidden_size % num_attention_heads == 0 + self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads + self.attention_head_size = hidden_size // num_attention_heads + self.scaling = self.attention_head_size ** -0.5 + self.batch_step = batch_step + + def forward(self, query, key, value, attention_mask): + if attention_mask is not None: + raise NotImplementedError("not implemented yet") + + ret = batch_step_attn_core_func.batch_step_attn_core_func(self.num_attention_heads, self.scaling, self.batch_step, query, key, value) + + return ret, None diff --git a/lean_transformer/batch_step_attn_core_func.py b/lean_transformer/batch_step_attn_core_func.py new file mode 100644 index 0000000..6aac328 --- /dev/null +++ b/lean_transformer/batch_step_attn_core_func.py @@ -0,0 +1,163 @@ + +import torch +import torch.nn.functional as F + + +class BatchStepAttnCoreFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + heads, + scale, + loop_batch_step, + queries, + keys, + values + ): + num_seqs = keys.size(0) + seq_len = keys.size(1) + hidden_dim = keys.size(2) + head_dim = hidden_dim // heads + + heads_t = torch.tensor([heads]) + scale_t = torch.tensor([scale]) + loop_batch_step_t = torch.tensor([loop_batch_step]) + num_seqs_t = torch.tensor([num_seqs]) + seq_len_t = torch.tensor([seq_len]) + hidden_dim_t = torch.tensor([hidden_dim]) + + queries = queries.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim) + keys = keys.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim) + values = values.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim) + + matmul2_results = torch.empty( + (num_seqs * heads, seq_len, head_dim), dtype=keys.dtype, device=keys.device + ) + + iter_step = int(loop_batch_step_t.item()) + iter_count = num_seqs * heads + for iter_idx in range(0, iter_count, iter_step): + ibatch_range = 
[iter_idx, min(iter_idx + iter_step, iter_count)] + + # output: [batch, seql_q, seql_k] + matmul1_results = torch.bmm( + queries[ibatch_range[0]:ibatch_range[1], :, :], + keys[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2) + ) * scale_t + + # output: [batch, seql_q, seql_k] + softmax_results = F.softmax(matmul1_results, dim=-1) + + matmul2_results[ibatch_range[0]:ibatch_range[1], :, :] = torch.bmm( + softmax_results, + values[ibatch_range[0]:ibatch_range[1], :, :]) + + outputs = matmul2_results.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim) + + ctx.save_for_backward( + heads_t, + scale_t, + loop_batch_step_t, + num_seqs_t, + seq_len_t, + hidden_dim_t, + queries, + keys, + values + ) + + return outputs.detach() + + @staticmethod + def backward(ctx, output_grads): + ( + heads_t, + scale_t, + loop_batch_step_t, + num_seqs_t, + seq_len_t, + hidden_dim_t, + queries, + keys, + values + ) = ctx.saved_tensors + + heads = heads_t[0].item() + num_seqs = int(num_seqs_t.item()) + seq_len = int(seq_len_t.item()) + hidden_dim = int(hidden_dim_t.item()) + head_dim = hidden_dim // heads + + # [seqs * heads, seql, emb_dim] + queries_grads = torch.empty((num_seqs * heads, seq_len, head_dim), dtype=queries.dtype, device=queries.device) + keys_grads = torch.empty((num_seqs * heads, seq_len, head_dim), dtype=keys.dtype, device=keys.device) + values_grads = torch.empty((num_seqs * heads, seq_len, head_dim), dtype=values.dtype, device=values.device) + + output_grads = output_grads.view(num_seqs, seq_len, heads, head_dim).transpose(1, 2).contiguous().view(num_seqs * heads, seq_len, head_dim) + + # output_grads [seqs, seql, emb_dim] + iter_step = int(loop_batch_step_t.item()) + iter_count = num_seqs * heads + for iter_idx in range(0, iter_count, iter_step): + ibatch_range = [iter_idx, min(iter_idx + iter_step, iter_count)] + ibatch_sz = ibatch_range[1] - ibatch_range[0] + + # reconstruct softmax_results + # output: [seqs*heads, seql_q, seql_k] + matmul1_results = torch.bmm( + queries[ibatch_range[0]:ibatch_range[1], :, :], + keys[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2) + ) * scale_t + + # output: [seqs*heads, seql_q, seql_k] + softmax_results = F.softmax(matmul1_results, dim=-1) + + # output_grads [ seqs * heads, seql, head_dim ] + # values [ seqs * heads, seql, head_dim ] + # output: [ seqs * heads, seql, seql ] + matmul2_dgrad1 = torch.bmm(output_grads[ibatch_range[0]:ibatch_range[1], :, :], + values[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2)) + + # softmax_results [ seqs * heads, seql, seql ] + # output_grads [ seqs * heads, seql, head_dim ] + # output: [ seqs * heads, seql, head_dim ] + values_grads[ibatch_range[0]:ibatch_range[1], :, :] = torch.bmm( + softmax_results.transpose(1, 2), + output_grads[ibatch_range[0]:ibatch_range[1], :, :]) + # output: [ seqs * heads, seql, seql ] + softmax_grads = torch._softmax_backward_data(matmul2_dgrad1, softmax_results, -1, softmax_results.dtype) + + softmax_grads = softmax_grads.view(ibatch_sz, seq_len, seq_len) + + queries_grads[ibatch_range[0]:ibatch_range[1], :, :] = torch.baddbmm( + queries_grads[ibatch_range[0]:ibatch_range[1], :, :], + softmax_grads, + keys[ibatch_range[0]:ibatch_range[1], :, :], + beta=0.0, + alpha=scale_t[0], + ) + + keys_grads[ibatch_range[0]:ibatch_range[1], :, :] = torch.baddbmm( + keys_grads[ibatch_range[0]:ibatch_range[1], :, :], + softmax_grads.transpose(1, 2), + queries[ibatch_range[0]:ibatch_range[1], :, :], + beta=0.0, + alpha=scale_t[0], + ) + + 
queries_grads = queries_grads.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim) + keys_grads = keys_grads.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim) + values_grads = values_grads.reshape(num_seqs, heads, seq_len, head_dim).transpose(1, 2).reshape(num_seqs, seq_len, hidden_dim) + + # [ seqs * heads, seql, head_dim ] + return ( + None, # heads + None, # scale + None, # loop_batch_step + queries_grads, # queries + keys_grads, # keys + values_grads, # values + ) + + +batch_step_attn_core_func = BatchStepAttnCoreFunc.apply diff --git a/tests/test_attn.py b/tests/test_attn.py new file mode 100644 index 0000000..fb85f74 --- /dev/null +++ b/tests/test_attn.py @@ -0,0 +1,48 @@ +import pytest + +import torch +import torch.nn as nn + +import numpy as np + +from lean_transformer.attn import LeanSelfAttention, BatchStepAttentionCore + + +@pytest.mark.forked +def test_lean_attn(): + torch.use_deterministic_algorithms(True) + + seq_length = 64 + num_seqs = 8 + hidden_dim = 128 + heads = 16 + + gtruth_mha = nn.MultiheadAttention(hidden_dim, heads, bias=True, + dropout=0, batch_first=True) + + for batch_step in [1, 2, 8, num_seqs * heads]: + attention_core = BatchStepAttentionCore(hidden_dim, + heads, batch_step=batch_step) + test_mha = LeanSelfAttention(hidden_dim, heads, dropout=0, + layer_norm=False, residual=False, + attention_core=attention_core, + checkpoint_attention_core=False) + + test_mha.dense_qkv.weight = gtruth_mha.in_proj_weight + test_mha.dense_qkv.bias = gtruth_mha.in_proj_bias + test_mha.dense_out.weight = gtruth_mha.out_proj.weight + test_mha.dense_out.bias = gtruth_mha.out_proj.bias + + device = torch.device('cpu') + + atol = 1e-6 + + for _ in range(10): + a = torch.randn((num_seqs, seq_length, hidden_dim), device=device) + out0 = gtruth_mha(a, a, a)[0] + out1 = test_mha(a)[0] + out0.mean().backward() + out1.mean().backward() + out0 = out0.cpu().detach().numpy() + out1 = out1.cpu().detach().numpy() + assert np.allclose(out0, out1, atol=atol), f"{out0} {out1}" From 504bfa47f4b564cac5c0ef3c4b3f8461fc251813 Mon Sep 17 00:00:00 2001 From: Alexey Kuts Date: Fri, 25 Mar 2022 21:49:54 +0300 Subject: [PATCH 2/3] fixes for review --- lean_transformer/__init__.py | 2 +- lean_transformer/attn.py | 41 ++++++++----------- lean_transformer/batch_step_attn_core_func.py | 16 ++++++-- lean_transformer/transformer.py | 8 +++- tests/test_attn.py | 8 ++-- 5 files changed, 41 insertions(+), 34 deletions(-) diff --git a/lean_transformer/__init__.py b/lean_transformer/__init__.py index 3b538ef..d3d8d80 100644 --- a/lean_transformer/__init__.py +++ b/lean_transformer/__init__.py @@ -1,5 +1,5 @@ from .ffn import LeanFFN -from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore, BatchStepAttentionCore +from .attn import LeanSelfAttention, SimpleAttentionCore, RotaryAttentionCore from .rotary import RotaryEmbeddings, rotate from .sequence import SequentialWithKwargs, ReversibleWithKwargs, ActiveKwargs from .config import LeanTransformerConfig diff --git a/lean_transformer/attn.py b/lean_transformer/attn.py index 1700721..b166e2e 100644 --- a/lean_transformer/attn.py +++ b/lean_transformer/attn.py @@ -17,7 +17,7 @@ def __init__( num_attention_heads: int, dropout: float = 0, layer_norm_eps: float = 1e-12, - pre_layer_norm: bool = False, + pre_layer_norm: bool = True, post_layer_norm: bool = False, qkv_proj: Optional[nn.Linear] = None, out_proj: Optional[nn.Linear] = None, @@ -87,12 +87,14 @@ 
def _maybe_checkpoint(self, func, *args): class SimpleAttentionCore(nn.Module): - def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0): + def __init__(self, hidden_size: int, num_attention_heads: int, attention_probs_dropout: float = 0.0, + batched_attention_size: int = -1): super().__init__() assert hidden_size % num_attention_heads == 0 self.attention_dropout = nn.Dropout(attention_probs_dropout) self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads self.attention_head_size = hidden_size // num_attention_heads + self.batched_attention_size = batched_attention_size def forward(self, query, key, value, attention_mask): """ @@ -109,7 +111,7 @@ def forward(self, query, key, value, attention_mask): assert torch.is_floating_point(attention_mask), "expected float mask with negative values for masked items" return self._attention_core_forward( query, key, value, attention_mask, self.num_attention_heads, self.attention_dropout.p, - self.training, scale_inplace=False, + self.training, scale_inplace=False, batched_attention_size=self.batched_attention_size ) @staticmethod @@ -118,8 +120,18 @@ def _attention_core_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - num_attention_heads: int, attention_dropout: float, training: bool, scale_inplace: bool + num_attention_heads: int, attention_dropout: float, training: bool, scale_inplace: bool, + batched_attention_size: int = -1 ) -> Tuple[torch.Tensor, torch.Tensor]: + + if batched_attention_size != -1: + hidden_size = query.shape[-1] + attention_head_size = hidden_size // num_attention_heads + scaling = attention_head_size ** -0.5 + ret = batch_step_attn_core_func.batch_step_attn_core_func(num_attention_heads, scaling, + batched_attention_size, query, key, value, attention_mask) + return ret, None + # transpose from [batch, seq_length, full_hid_size] to [batch, num_heads, seq_length, head_size] new_query_shape = query.shape[:-1] + (num_attention_heads, -1) new_kv_shape = key.shape[:-1] + (num_attention_heads, -1) @@ -172,22 +184,5 @@ def rotate(self, tensor: torch.Tensor): def forward(self, query, key, value, attention_mask): return self._attention_core_forward( self.rotate(query), self.rotate(key), value, attention_mask, self.num_attention_heads, - self.attention_dropout.p, self.training, scale_inplace=True) - - -class BatchStepAttentionCore(SimpleAttentionCore): - def __init__(self, hidden_size: int, num_attention_heads: int, batch_step: int = 8, **kwargs): - super().__init__(hidden_size, num_attention_heads, **kwargs) - assert hidden_size % num_attention_heads == 0 - self.hidden_size, self.num_attention_heads = hidden_size, num_attention_heads - self.attention_head_size = hidden_size // num_attention_heads - self.scaling = self.attention_head_size ** -0.5 - self.batch_step = batch_step - - def forward(self, query, key, value, attention_mask): - if attention_mask is not None: - raise NotImplementedError("not implemented yet") - - ret = batch_step_attn_core_func.batch_step_attn_core_func(self.num_attention_heads, self.scaling, self.batch_step, query, key, value) - - return ret, None + self.attention_dropout.p, self.training, scale_inplace=True, + batched_attention_size=self.batched_attention_size) diff --git a/lean_transformer/batch_step_attn_core_func.py b/lean_transformer/batch_step_attn_core_func.py index 6aac328..cda88c4 100644 --- a/lean_transformer/batch_step_attn_core_func.py +++ b/lean_transformer/batch_step_attn_core_func.py @@ -12,7 
+12,8 @@ def forward( loop_batch_step, queries, keys, - values + values, + attention_mask ): num_seqs = keys.size(0) seq_len = keys.size(1) @@ -45,6 +46,9 @@ def forward( keys[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2) ) * scale_t + if attention_mask is not None: + matmul1_results += attention_mask[:, 0, :, :] + # output: [batch, seql_q, seql_k] softmax_results = F.softmax(matmul1_results, dim=-1) @@ -63,7 +67,8 @@ def forward( hidden_dim_t, queries, keys, - values + values, + attention_mask ) return outputs.detach() @@ -79,7 +84,8 @@ def backward(ctx, output_grads): hidden_dim_t, queries, keys, - values + values, + attention_mask ) = ctx.saved_tensors heads = heads_t[0].item() @@ -109,6 +115,9 @@ def backward(ctx, output_grads): keys[ibatch_range[0]:ibatch_range[1], :, :].transpose(1, 2) ) * scale_t + if attention_mask is not None: + matmul1_results += attention_mask[:, 0, :, :] + # output: [seqs*heads, seql_q, seql_k] softmax_results = F.softmax(matmul1_results, dim=-1) @@ -157,6 +166,7 @@ def backward(ctx, output_grads): queries_grads, # queries keys_grads, # keys values_grads, # values + None, # attention_mask ) diff --git a/lean_transformer/transformer.py b/lean_transformer/transformer.py index fd66302..353521e 100644 --- a/lean_transformer/transformer.py +++ b/lean_transformer/transformer.py @@ -95,6 +95,7 @@ def set_optimizations( checkpoint_attention_core: Optional[bool] = None, ffn_custom_grad: Optional[bool] = None, update_triton_blocksparse_ops: bool = False, + batched_attention_size: Optional[int] = None ): """ Set one or more memory saving options for all compatible sub-modules. Options set to None remain unchanged. @@ -137,8 +138,11 @@ def set_optimizations( sequential.preserve_rng_state = preserve_rng_state for module in sequential.modules(): - if checkpoint_attention_core is not None and isinstance(module, LeanSelfAttention): - module.checkpoint_attention_core = checkpoint_attention_core + if isinstance(module, LeanSelfAttention): + if checkpoint_attention_core is not None: + module.checkpoint_attention_core = checkpoint_attention_core + if batched_attention_size is not None: + module.attention_core.batched_attention_size = batched_attention_size elif ffn_custom_grad is not None and isinstance(module, LeanFFN): module.ffn_custom_grad = ffn_custom_grad else: diff --git a/tests/test_attn.py b/tests/test_attn.py index edf4803..ca7ad45 100644 --- a/tests/test_attn.py +++ b/tests/test_attn.py @@ -5,7 +5,7 @@ import numpy as np -from lean_transformer.attn import LeanSelfAttention, BatchStepAttentionCore +from lean_transformer.attn import LeanSelfAttention @pytest.mark.forked @@ -21,12 +21,10 @@ def test_lean_attn(): dropout=0, batch_first=True) for batch_step in [1, 2, 8, num_seqs * heads]: - attention_core = BatchStepAttentionCore(hidden_dim, - heads, batch_step=batch_step) test_mha = LeanSelfAttention(hidden_dim, heads, dropout=0, pre_layer_norm=False, residual=False, - attention_core=attention_core, - checkpoint_attention_core=False) + checkpoint_attention_core=False, + batched_attention_size=batch_step) test_mha.qkv_proj.weight = gtruth_mha.in_proj_weight test_mha.qkv_proj.bias = gtruth_mha.in_proj_bias From 90abdb87bb08566eaba0a45bc29ec6a3220333ac Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 1 Apr 2022 19:43:44 +0300 Subject: [PATCH 3/3] Update tests/test_attn.py --- tests/test_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_attn.py b/tests/test_attn.py index ca7ad45..bd3e5f2 100644 --- a/tests/test_attn.py +++ 
b/tests/test_attn.py @@ -9,7 +9,7 @@ @pytest.mark.forked -def test_lean_attn(): +def test_lean_attn(rotary: bool = False): torch.use_deterministic_algorithms(True) seq_length = 64
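
For reference, below is a minimal standalone sketch of the technique these patches implement: multi-head attention evaluated in slices along the flattened (batch * heads) dimension, with the softmax recomputed during backward so that only one [step, seq_len, seq_len] score matrix is alive at any time. The tensor layout ([batch * heads, seq_len, head_dim]), the names ChunkedAttention / chunked_attention, and the explicit softmax-backward formula (used here in place of torch._softmax_backward_data) are illustrative choices, not the API added by the PR; the actual implementation also folds an additive attention mask into the scores and reshapes to and from [batch, seq_len, hidden_dim].

import torch
import torch.nn.functional as F


class ChunkedAttention(torch.autograd.Function):
    """softmax(Q K^T * scale) V, computed `step` rows of the batch*heads axis at a time."""

    @staticmethod
    def forward(ctx, q, k, v, scale, step):
        # q, k, v: [batch * heads, seq_len, head_dim]
        out = torch.empty_like(v)
        for i in range(0, q.size(0), step):
            s = slice(i, min(i + step, q.size(0)))
            # the [step, seq_len, seq_len] score matrix exists only for the current slice
            probs = F.softmax(torch.bmm(q[s], k[s].transpose(1, 2)) * scale, dim=-1)
            out[s] = torch.bmm(probs, v[s])
        # only q, k, v are saved; the seq x seq matrices are recomputed in backward
        ctx.save_for_backward(q, k, v)
        ctx.scale, ctx.step = scale, step
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        scale, step = ctx.scale, ctx.step
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        for i in range(0, q.size(0), step):
            s = slice(i, min(i + step, q.size(0)))
            # recompute the softmax for this slice instead of having saved it
            probs = F.softmax(torch.bmm(q[s], k[s].transpose(1, 2)) * scale, dim=-1)
            dv[s] = torch.bmm(probs.transpose(1, 2), grad_out[s])
            dprobs = torch.bmm(grad_out[s], v[s].transpose(1, 2))
            # softmax backward: P * (dP - sum(dP * P) along the key axis)
            dscores = probs * (dprobs - (dprobs * probs).sum(dim=-1, keepdim=True))
            dq[s] = torch.bmm(dscores, k[s]) * scale
            dk[s] = torch.bmm(dscores.transpose(1, 2), q[s]) * scale
        return dq, dk, dv, None, None


def chunked_attention(q, k, v, scale, step=8):
    return ChunkedAttention.apply(q, k, v, scale, step)


if __name__ == "__main__":
    torch.manual_seed(0)
    bh, seq, dim = 6, 5, 4  # keep sizes tiny so gradcheck stays fast
    q = torch.randn(bh, seq, dim, dtype=torch.double, requires_grad=True)
    k = torch.randn(bh, seq, dim, dtype=torch.double, requires_grad=True)
    v = torch.randn(bh, seq, dim, dtype=torch.double, requires_grad=True)
    scale = dim ** -0.5

    # forward matches plain full-batch attention ...
    ref = torch.bmm(F.softmax(torch.bmm(q, k.transpose(1, 2)) * scale, dim=-1), v)
    assert torch.allclose(chunked_attention(q, k, v, scale, step=4), ref, atol=1e-10)
    # ... and the hand-written backward matches numeric gradients
    assert torch.autograd.gradcheck(
        lambda a, b, c: chunked_attention(a, b, c, scale, step=4), (q, k, v))

With ordinary autograd the per-slice softmax tensors would all be retained for the backward pass, so peak memory would match the full-batch matmul; the hand-written backward (the patches do the same via torch._softmax_backward_data) is what actually caps the live attention matrix at one slice, at the cost of recomputing Q K^T and the softmax once per slice in backward. In the PR itself the slice size is not passed to the core directly but enabled for every LeanSelfAttention through the model-level hook added in patch 2/3, e.g. set_optimizations(batched_attention_size=8); leaving batched_attention_size at its default of -1 keeps the ordinary full-batch path.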