feat: added fp8 to precision test
alexkranias-amd committed Dec 11, 2024
1 parent 937e814 commit fd342f7
Showing 1 changed file with 24 additions and 11 deletions: tests/test_ops_precision_error.py
@@ -3,6 +3,7 @@
 import triton.language as tl
 import pytest
 import pdb
+from flash_attn.flash_attn_triton_amd.utils import check_is_fp8
 
 @triton.jit
 def many_ops_triton(x_ptr,
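The new import brings in check_is_fp8 from the repository's Triton-AMD utils. As a rough, hypothetical sketch (the real helper in flash_attn.flash_attn_triton_amd.utils may be implemented differently), it is assumed to simply report whether a tensor carries one of torch's fp8 dtypes:

import torch

def check_is_fp8(t: torch.Tensor) -> bool:
    # Assumption: any of torch's fp8 variants counts as fp8.
    # Availability of these dtypes depends on the installed PyTorch version.
    fp8_dtypes = {
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
    }
    return t.dtype in fp8_dtypes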
@@ -43,7 +44,7 @@ def many_ops_triton(x_ptr,
     }
     """
     # Set input dtype (we will cast back to this for the output)
-    input_dtype = tl.float16 if DTYPE==0 else tl.float32 if DTYPE==1 else None
+    input_dtype = tl.float8e4b8 if DTYPE == 2 else tl.float16 if DTYPE==0 else tl.float32 if DTYPE==1 else None
 
     x_block_range = tl.arange(0, M)[:, None]*K + tl.arange(0, K)[None, :]
     y_block_range = tl.arange(0, K)[:, None]*N + tl.arange(0, N)[None, :]
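With this change the kernel's DTYPE switch covers three ids. Assuming tl.float8e4b8 is Triton's fp8 e4m3 type with exponent bias 8, i.e. the counterpart of the torch.float8_e4m3fnuz dtype used by the test below, the id-to-dtype correspondence looks like this (an illustrative mapping, not code from the repository):

import torch

# Hypothetical summary of the DTYPE ids shared by the kernel and the test.
DTYPE_TO_TORCH = {
    0: torch.float16,          # kernel uses tl.float16
    1: torch.float32,          # kernel uses tl.float32
    2: torch.float8_e4m3fnuz,  # kernel uses tl.float8e4b8 (fp8 e4m3, bias 8)
}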
@@ -94,6 +95,7 @@ def many_ops_triton(x_ptr,
     # Matmul
     o_block_range = tl.arange(0, M)[:, None]*N + tl.arange(0, N)[None, :]
     o = tl.dot(x, y) # tl.dot always outputs input dtype. ALSO REQUIRES INPUT SHAPES M >= 16, N >= 16 and K >= 16
+
     if IMITATE_PYTORCH:
         x = x.to(input_dtype)
         y = y.to(input_dtype)
@@ -150,13 +152,13 @@ def many_ops_torch(x: torch.Tensor,
 @pytest.mark.parametrize("M", [16, 32])
 @pytest.mark.parametrize("K", [16, 32, 64]) # 64 seems to cause some issues
 @pytest.mark.parametrize("N", [16, 32])
-@pytest.mark.parametrize("mult", [0.001, 0.74, 1.5251]) # mult = [0, 2.99]
-@pytest.mark.parametrize("dtype", [torch.float16]) # torch.float32
-@pytest.mark.parametrize("IMITATE_PYTORCH", [1]) # 0 = no casting (not imitating pytorch), 1 = cast after every op (imitating pytorch)
-@pytest.mark.parametrize("DO_MULTIPLY", [1]) # Include multiplication
+@pytest.mark.parametrize("mult", [0.7972]) # mult = [0, 2.99]
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float8_e4m3fnuz]) # torch.float32
+@pytest.mark.parametrize("IMITATE_PYTORCH", [0]) # 0 = no casting (not imitating pytorch), 1 = cast after every op (imitating pytorch)
+@pytest.mark.parametrize("DO_MULTIPLY", [0]) # Include multiplication
 @pytest.mark.parametrize("DO_SIGMOID", [0]) # Include sigmoid
 @pytest.mark.parametrize("DO_COS", [0]) # Include cosine
-@pytest.mark.parametrize("DO_EXPONENT", [0]) # Include exponentiation
+@pytest.mark.parametrize("DO_EXPONENT", [1]) # Include exponentiation
 @pytest.mark.parametrize("DO_SQRT", [0]) # Include square root
 def test_many_ops(seed, M, K, N, mult, dtype, IMITATE_PYTORCH, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT):
     """
@@ -188,23 +190,34 @@ def test_many_ops(seed, M, K, N, mult, dtype, IMITATE_PYTORCH, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT):
 
     torch_type_to_id = {
         torch.float16: 0,
-        torch.float32: 1
+        torch.float32: 1,
+        torch.float8_e4m3fnuz: 2
     }
 
     DTYPE = torch_type_to_id[dtype]
 
-    x = torch.rand(M, K, dtype=dtype, device=device)
-    y = torch.rand(K, N, dtype=dtype, device=device)
+    x = torch.rand(M, K, dtype=torch.float32, device=device)
+    y = torch.rand(K, N, dtype=torch.float32, device=device)
+    x, y = x.to(dtype), y.to(dtype)
 
     grid = (1,)
     out = torch.zeros(M, N, dtype=dtype, device=device)
-    out_torch = torch.zeros(M, N, dtype=dtype, device=device)
+
+    if check_is_fp8(x):
+        out_torch = torch.zeros(M, N, dtype=torch.float16, device=device)
+    else:
+        out_torch = torch.zeros(M, N, dtype=dtype, device=device)
 
     with torch.cuda.device(x.device):
         many_ops_triton[grid](x, y, out, M, K, N, mult, IMITATE_PYTORCH, DTYPE, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
-        many_ops_torch(x, y, out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
+        if check_is_fp8(x):
+            many_ops_torch(x.to(torch.float16), y.to(torch.float16), out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
+        else:
+            many_ops_torch(x, y, out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
 
     # print("torch - triton", (out_torch-out))
+    if check_is_fp8(x):
+        out = out.to(torch.float16)
 
     print(f'absolute error: {(out-out_torch).abs().max().item()}, relative error: {((out-out_torch)/out).abs().max().item()}')
 
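Taken together, the fp8 path samples inputs in float32 and casts them down (presumably because torch.rand cannot sample fp8 dtypes directly), runs the torch reference in float16 (presumably because many eager-mode ops are not implemented for fp8), and upcasts the Triton output before measuring error. A condensed, self-contained sketch of that flow (an illustration of the approach, not the literal test body; shapes and device are assumptions):

import torch

M, K, N = 16, 16, 16
device = "cuda"
dtype = torch.float8_e4m3fnuz

# fp8 cannot be sampled directly, so sample in float32 and cast down.
x = torch.rand(M, K, dtype=torch.float32, device=device).to(dtype)
y = torch.rand(K, N, dtype=torch.float32, device=device).to(dtype)

out = torch.zeros(M, N, dtype=dtype, device=device)                # Triton kernel output stays fp8
out_torch = torch.zeros(M, N, dtype=torch.float16, device=device)  # torch reference runs in fp16

# ... launch many_ops_triton on (x, y, out) and run many_ops_torch on
# x.to(torch.float16), y.to(torch.float16), as in the test above ...

out = out.to(torch.float16)  # upcast before comparing against the fp16 reference
abs_err = (out - out_torch).abs().max().item()
rel_err = ((out - out_torch) / out).abs().max().item()

Running pytest tests/test_ops_precision_error.py should now exercise all three dtypes, since dtype is parametrized over torch.float32, torch.float16, and torch.float8_e4m3fnuz.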
