diff --git a/tests/test_ops_precision_error.py b/tests/test_ops_precision_error.py
index 94ac6ed53..dae2bbb5b 100644
--- a/tests/test_ops_precision_error.py
+++ b/tests/test_ops_precision_error.py
@@ -3,6 +3,7 @@
 import triton.language as tl
 import pytest
 import pdb
+from flash_attn.flash_attn_triton_amd.utils import check_is_fp8
 
 @triton.jit
 def many_ops_triton(x_ptr,
@@ -43,7 +44,7 @@ def many_ops_triton(x_ptr,
     }
     """
     # Set input dtype (we will cast back to this for the output)
-    input_dtype = tl.float16 if DTYPE==0 else tl.float32 if DTYPE==1 else None
+    input_dtype = tl.float8e4b8 if DTYPE == 2 else tl.float16 if DTYPE==0 else tl.float32 if DTYPE==1 else None
 
     x_block_range = tl.arange(0, M)[:, None]*K + tl.arange(0, K)[None, :]
     y_block_range = tl.arange(0, K)[:, None]*N + tl.arange(0, N)[None, :]
@@ -94,6 +95,7 @@ def many_ops_triton(x_ptr,
     # Matmul
     o_block_range = tl.arange(0, M)[:, None]*N + tl.arange(0, N)[None, :]
     o = tl.dot(x, y) # tl.dot always outputs input dtype. ALSO REQUIRES INPUT SHAPES M >= 16, N >= 16 and K >= 16
+
     if IMITATE_PYTORCH:
         x = x.to(input_dtype)
         y = y.to(input_dtype)
@@ -150,13 +152,13 @@ def many_ops_torch(x: torch.Tensor,
 @pytest.mark.parametrize("M", [16, 32])
 @pytest.mark.parametrize("K", [16, 32, 64]) # 64 seems to cause some issues
 @pytest.mark.parametrize("N", [16, 32])
-@pytest.mark.parametrize("mult", [0.001, 0.74, 1.5251]) # mult = [0, 2.99]
-@pytest.mark.parametrize("dtype", [torch.float16]) # torch.float32
-@pytest.mark.parametrize("IMITATE_PYTORCH", [1]) # 0 = no casting (not imitating pytorch), 1 = cast after every op (imitating pytorch)
-@pytest.mark.parametrize("DO_MULTIPLY", [1]) # Include multiplication
+@pytest.mark.parametrize("mult", [0.7972]) # mult = [0, 2.99]
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float8_e4m3fnuz])
+@pytest.mark.parametrize("IMITATE_PYTORCH", [0]) # 0 = no casting (not imitating pytorch), 1 = cast after every op (imitating pytorch)
+@pytest.mark.parametrize("DO_MULTIPLY", [0]) # Include multiplication
 @pytest.mark.parametrize("DO_SIGMOID", [0]) # Include sigmoid
 @pytest.mark.parametrize("DO_COS", [0]) # Include cosine
-@pytest.mark.parametrize("DO_EXPONENT", [0]) # Include exponentiation
+@pytest.mark.parametrize("DO_EXPONENT", [1]) # Include exponentiation
 @pytest.mark.parametrize("DO_SQRT", [0]) # Include square root
 def test_many_ops(seed, M, K, N, mult, dtype, IMITATE_PYTORCH, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT):
     """
@@ -188,23 +190,34 @@ def test_many_ops(seed, M, K, N, mult, dtype, IMITATE_PYTORCH, DO_MULTIPLY, DO_S
     torch_type_to_id = {
         torch.float16: 0,
-        torch.float32: 1
+        torch.float32: 1,
+        torch.float8_e4m3fnuz: 2
     }
     DTYPE = torch_type_to_id[dtype]
 
-    x = torch.rand(M, K, dtype=dtype, device=device)
-    y = torch.rand(K, N, dtype=dtype, device=device)
+    x = torch.rand(M, K, dtype=torch.float32, device=device) # sample in fp32, then downcast: torch.rand cannot sample fp8 directly
+    y = torch.rand(K, N, dtype=torch.float32, device=device)
+    x, y = x.to(dtype), y.to(dtype)
 
     grid = (1,)
     out = torch.zeros(M, N, dtype=dtype, device=device)
-    out_torch = torch.zeros(M, N, dtype=dtype, device=device)
+
+    if check_is_fp8(x):
+        out_torch = torch.zeros(M, N, dtype=torch.float16, device=device) # torch eager lacks fp8 math, so the reference runs in fp16
+    else:
+        out_torch = torch.zeros(M, N, dtype=dtype, device=device)
 
     with torch.cuda.device(x.device):
         many_ops_triton[grid](x, y, out, M, K, N, mult, IMITATE_PYTORCH, DTYPE, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
 
-    many_ops_torch(x, y, out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
+    if check_is_fp8(x):
+        many_ops_torch(x.to(torch.float16), y.to(torch.float16), out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
+    else:
+        many_ops_torch(x, y, out_torch, M, K, N, mult, DO_MULTIPLY, DO_SIGMOID, DO_COS, DO_EXPONENT, DO_SQRT)
 
     # print("torch - triton", (out_torch-out))
+    if check_is_fp8(x):
+        out = out.to(torch.float16) # upcast the triton fp8 output so the comparison runs in fp16
     print(f'absolute error: {(out-out_torch).abs().max().item()}, relative error: {((out-out_torch)/out).abs().max().item()}')
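
For context, `check_is_fp8` is imported from `flash_attn.flash_attn_triton_amd.utils` and its body is not part of this diff. A minimal sketch of what such a helper plausibly does, assuming it only inspects the tensor's dtype (the set of fp8 dtypes below is an assumption, not the library's actual list):

```python
# Hypothetical sketch of the imported helper; the real implementation lives in
# flash_attn.flash_attn_triton_amd.utils and may differ.
import torch

def check_is_fp8(x: torch.Tensor) -> bool:
    # torch.float8_e4m3fnuz is the AMD "fnuz" variant this test exercises;
    # the other entries cover the remaining fp8 dtypes PyTorch defines.
    fp8_dtypes = {
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
    }
    return x.dtype in fp8_dtypes
```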
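The switch to sampling in fp32 and casting afterwards is needed because PyTorch's random samplers are not implemented for fp8 dtypes; `torch.rand` with an fp8 dtype fails at runtime. A small standalone check of that pattern, assuming a PyTorch build that defines `torch.float8_e4m3fnuz`:

```python
import torch

# torch.rand(16, 16, dtype=torch.float8_e4m3fnuz) raises at runtime, so the
# test samples in fp32 and quantizes to fp8 afterwards.
x = torch.rand(16, 16, dtype=torch.float32)
x_fp8 = x.to(torch.float8_e4m3fnuz)
print(x_fp8.dtype)  # torch.float8_e4m3fnuz
```

Since the test reports errors via `print` rather than asserting, run it with `pytest -s` so the absolute/relative error lines are visible.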