Merge pull request #669 from ROCm/tianxing/FA-int8
Tianxing/fa int8
Chi-Chu319 authored Dec 18, 2024
2 parents 9cdcf1d + 296209a commit cd6f51b
Showing 2 changed files with 343 additions and 40 deletions.
18 changes: 18 additions & 0 deletions python/perf-kernels/README.md
@@ -42,9 +42,27 @@ This script contains the Flash Attention kernel with the following support
- Multi and Grouped Query attention
- ALiBi bias
- Matrix bias
- Int8 quantization

These are currently supported for the forward kernel only.

### INT8 Quantization Support

1. `q_descale`, `k_descale`, and `v_descale` provided:
   - The first QK GEMM runs in INT8, then the output is dequantized to the specified `dtype`.
   - The second PV GEMM runs in the specified `dtype`.

2. `q_descale`, `k_descale`, `p_descale`, and `v_descale` provided:
   - Both the first and second GEMM operations run in INT8.
   - The results are dequantized to the specified `dtype` after both GEMMs.

3. Only `k_descale` and `v_descale` provided:
   - K and V are dequantized before the first and second GEMM operations, respectively.
   - Both GEMMs run in the specified `dtype`.

Note: The softmax operation is always performed in `fp32`.
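
The three modes can be summarized with a small NumPy sketch of the dequantization math. This is only an illustration under stated assumptions, not the Triton kernel itself: the function name, argument layout, and the round-and-clip re-quantization of P are hypothetical and do not reflect the kernel's actual API.

```python
import numpy as np

def int8_attention_sketch(q, k, v, q_descale=None, k_descale=None,
                          v_descale=None, p_descale=None, dtype=np.float16):
    """Illustrative reference for the three descale modes (not the Triton kernel)."""
    if q_descale is not None:
        # Cases 1 and 2: the QK GEMM runs on INT8 inputs with an INT32
        # accumulator; the scores are dequantized with q_descale * k_descale.
        s = (q.astype(np.int32) @ k.astype(np.int32).T).astype(np.float32)
        s *= q_descale * k_descale
    else:
        # Case 3: only k_descale/v_descale given, so K is dequantized up front
        # and the first GEMM runs in the requested dtype.
        k_deq = (k.astype(np.float32) * k_descale).astype(dtype)
        s = (q.astype(dtype) @ k_deq.T).astype(np.float32)

    # The softmax is always computed in fp32.
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)

    if p_descale is not None:
        # Case 2: re-quantize P (assumed round-to-nearest with clipping) so the
        # PV GEMM also runs in INT8, then dequantize with p_descale * v_descale.
        p_i8 = np.clip(np.rint(p / p_descale), -128, 127).astype(np.int8)
        o = (p_i8.astype(np.int32) @ v.astype(np.int32)).astype(np.float32)
        o *= p_descale * v_descale
    else:
        # Cases 1 and 3: V is dequantized and the PV GEMM runs in the
        # requested dtype.
        v_deq = (v.astype(np.float32) * v_descale).astype(dtype)
        o = p.astype(dtype) @ v_deq
    return o.astype(dtype)
```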


## `06-attention-decode.py`

This script contains the Flash Decoding kernel.