
dump forward dropout
micmelesse committed Dec 2, 2024
1 parent d008a3c commit 118e705
Showing 3 changed files with 17 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -34,4 +34,5 @@ csrc/flash_attn_ck
 core.*
 *.csv
 *.png
-*.html
+*.html
+*.json
6 changes: 4 additions & 2 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE
+from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE, write_tensor
 
 # Convenience function to load with optional boundary checks.
 # "First" is the major dim, "second" is the minor dim.
@@ -137,6 +137,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
             # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
             exp_score_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
             tl.store(exp_scores_ptrs, tl.where(dropout_mask, p, -p), mask=exp_score_mask)
+            # tl.store(exp_scores_ptrs, dropout_mask, mask=exp_score_mask)
             p = tl.where(dropout_mask, p, 0.0)
         elif RETURN_SCORES:
             # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. We are not doing that
@@ -604,6 +605,7 @@ def attention_prefill_forward_triton_impl(
print("attention_prefill_forward_triton_impl outputs")
print("o:", o, o.shape)
print("softmax_lse:", softmax_lse, softmax_lse.shape)
print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None)
print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None, ",", "dropout fraction:", 1.0 - (sd_mask.sum()/ sd_mask.numel()).item())
# write_tensor(sd_mask)

return o, softmax_lse, sd_mask.to(o.dtype) if return_scores else None
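
For context, the new debug print reports an empirical dropout fraction derived from sd_mask. Because the kernel stores kept entries as p and dropped entries as -p (tl.where(dropout_mask, p, -p)), the dropout pattern can also be recovered on the host from the sign of each entry. A minimal, illustrative sketch of that check; the tensor shapes and names below are stand-ins, not data from the kernel:

# Illustrative host-side check, assuming the scores tensor was filled with
# tl.where(dropout_mask, p, -p): dropped positions carry a negated probability.
import torch

def observed_dropout_fraction(sd_mask: torch.Tensor) -> float:
    # Fraction of entries marked as dropped (stored with a negative sign).
    return (sd_mask < 0).float().mean().item()

# Synthetic stand-in: keep ~80% of 128x128 softmax probabilities.
p = torch.rand(128, 128)
keep = torch.rand(128, 128) >= 0.2
sd_mask = torch.where(keep, p, -p)
print(f"observed dropout fraction: {observed_dropout_fraction(sd_mask):.3f}")  # ~0.20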
11 changes: 11 additions & 0 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -1,4 +1,6 @@

+import csv
+import json
 import torch
 import os
 import triton
@@ -257,6 +259,15 @@ def get_padded_headsize(size):
     padded_d_model = max(padded_d_model, 16)
     return padded_d_model
 
+def write_tensor(x, tensor_name = "tensor"):
+    x = x.tolist()
+
+    with open(f'{tensor_name}.csv', 'w') as f:
+        writer = csv.writer(f)
+        writer.writerows(x)
+
+    with open(f'{tensor_name}.json', 'w') as f:
+        json.dump(x, f, indent=2)
 
 def _strides(x: torch.Tensor, *stride_names: str):
     if x is None:
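A hypothetical usage sketch of the new write_tensor helper follows; the tensor and file name are stand-ins, not output from the kernel, and the import path is assumed from the file's location in the repo. Since csv.writer.writerows treats x as a list of rows, the CSV dump is most useful for 2-D tensors; the JSON dump handles any rank.

# Hypothetical usage of write_tensor (stand-in data, assumed import path).
import torch
from flash_attn.flash_attn_triton_amd.utils import write_tensor

sd_mask = torch.rand(64, 64)                  # stand-in for a dumped dropout/score tensor
write_tensor(sd_mask, tensor_name="sd_mask")  # writes sd_mask.csv and sd_mask.json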
