diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp index 4f703b7d92e54..21583b6a3e2f6 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-aot-amx.hpp @@ -1,4 +1,4 @@ -//===-------------- matrix-amx.hpp - SYCL matrix --------------*- C++ -*---===// +//===------------ matrix-aot-amx.hpp - SYCL matrix ------------*- C++ -*---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp index 9e092325206d8..14466fd5fafb4 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp @@ -1,4 +1,4 @@ -//==------------------ matrix.hpp - SYCL matrix ----------------*- C++ -*---==// +//==---------------- matrix-jit.hpp - SYCL matrix --------------*- C++ -*---==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp index 194dd075aeea0..5c6df9114b161 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp @@ -1,7 +1,12 @@ -#pragma once +//===---- matrix-tensorcore.hpp - SYCL tensor cores matrix ----*- C++ -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// ===--------------------------------------------------------------------=== // -#include -#include +#pragma once __SYCL_INLINE_NAMESPACE(cl) { namespace sycl { @@ -13,52 +18,81 @@ enum class matrix_use { a, b, accumulator }; enum class matrix_layout { row_major, col_major, packed_a, packed_b }; -template -struct joint_matrix { - joint_matrix(Group g) {} -}; - -// The enable_if_t usage in this file is used to disable the -// matrix_layout::packed case which is not compatible with the Nvidia cuda -// backend. 
-template -struct joint_matrix< - double, matrix_use::a, 8, 4, Layout, sycl::sub_group, - typename std::enable_if_t> { - double data[1]; -}; - -template -struct joint_matrix< - double, matrix_use::b, 4, 8, Layout, sycl::sub_group, - typename std::enable_if_t<(Layout == matrix_layout::row_major || - Layout == matrix_layout::col_major)>> { - double data[1]; -}; - -template -struct joint_matrix< - double, matrix_use::accumulator, 8, 8, Layout, sycl::sub_group, - typename std::enable_if_t> { - double data[2]; -}; - +struct joint_matrix; + +#define __SYCL_JOINT_MATRIX_OVERLOAD(type, use, M, N, frag_type, frag_size) \ + template \ + struct joint_matrix< \ + type, matrix_use::use, M, N, Layout, sycl::sub_group, \ + typename std::enable_if_t> { \ + frag_type data[frag_size]; \ + }; + +// m8n8k4 double only +__SYCL_JOINT_MATRIX_OVERLOAD(double, a, 8, 4, double, 1) +__SYCL_JOINT_MATRIX_OVERLOAD(double, b, 4, 8, double, 1) +__SYCL_JOINT_MATRIX_OVERLOAD(double, accumulator, 8, 8, double, 2) + +// m8n32k16 +// bf16 data format uses uint16_t data type +__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 8, 16, int32_t, 2) +__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 32, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 8, 16, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 32, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 8, 32, float, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 8, 32, int32_t, 4) + +__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 8, 16, int32_t, 1) +__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 32, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 8, 16, int32_t, 1) +__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 32, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 8, 32, int32_t, 8) + +// m32n8k16 +__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 32, 16, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 8, int32_t, 2) +__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 32, 16, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 8, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 32, 8, float, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 32, 8, int32_t, 4) + +__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 32, 16, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 8, int32_t, 1) +__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 32, 16, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 8, int32_t, 1) +__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 32, 8, int32_t, 8) + +// m16n16k16 +__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, a, 16, 16, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(uint16_t, b, 16, 16, int32_t, 4) +__SYCL_JOINT_MATRIX_OVERLOAD(half, a, 16, 16, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, b, 16, 16, int32_t, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(float, accumulator, 16, 16, float, 8) +__SYCL_JOINT_MATRIX_OVERLOAD(half, accumulator, 16, 16, int32_t, 4) + +__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, a, 16, 16, int32_t, 2) +__SYCL_JOINT_MATRIX_OVERLOAD(int8_t, b, 16, 16, int32_t, 2) +__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, a, 16, 16, int32_t, 2) +__SYCL_JOINT_MATRIX_OVERLOAD(uint8_t, b, 16, 16, int32_t, 2) +__SYCL_JOINT_MATRIX_OVERLOAD(int32_t, accumulator, 16, 16, int32_t, 8) + +#undef __SYCL_JOINT_MATRIX_OVERLOAD } // namespace experimental::matrix namespace detail { -template struct joint_matrix_load_impl { void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - T, MT, NumRows, NumCols, Layout> &res, + T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, multi_ptr src, size_t stride); }; @@ -77,70 +111,167 @@ constexpr int 
get_layout_id< return 1; } -template -struct joint_matrix_load_impl< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4, - Layout, Space, - typename std::enable_if_t> { - void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, - 8, 4, Layout> &res, - multi_ptr src, size_t stride) { - -#ifdef __NVPTX__ -#ifdef __SYCL_DEVICE_ONLY__ - __dmma_m8n8k4_ld_a(res.data, src.get(), stride, get_layout_id()); -#endif -#endif - } -}; - -template struct joint_matrix_load_impl< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8, - Layout, Space, + T, Use, NumRows, NumCols, Layout, Space, typename std::enable_if_t> { void load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, - 4, 8, Layout> &res, - multi_ptr src, size_t stride) { -#ifdef __NVPTX__ -#ifdef __SYCL_DEVICE_ONLY__ - __dmma_m8n8k4_ld_b(res.data, src.get(), stride, get_layout_id()); -#endif -#endif - } -}; - -template -struct joint_matrix_load_impl< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, - 8, Layout, Space, - typename std::enable_if_t> { - void - load(sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, - 8, Layout> &res, - multi_ptr src, size_t stride) { - -#ifdef __NVPTX__ -#ifdef __SYCL_DEVICE_ONLY__ - __dmma_m8n8k4_ld_c(res.data, src.get(), stride, get_layout_id()); -#endif -#endif + T, Use, NumRows, NumCols, Layout, sycl::sub_group> &res, + multi_ptr src, size_t stride) { + if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __mma_bf16_m16n16k16_ld_a(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __mma_bf16_m16n16k16_ld_b(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __mma_bf16_m8n32k16_ld_a(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __mma_bf16_m8n32k16_ld_b(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __mma_bf16_m32n8k16_ld_a(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __mma_bf16_m32n8k16_ld_b(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __imma_m16n16k16_ld_a_u8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __imma_m16n16k16_ld_b_u8(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_u8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_u8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_u8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols 
== 8) { + __imma_m32n8k16_ld_b_u8(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __imma_m16n16k16_ld_a_s8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __imma_m16n16k16_ld_b_s8(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __imma_m8n32k16_ld_a_s8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __imma_m8n32k16_ld_b_s8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __imma_m32n8k16_ld_a_s8(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __imma_m32n8k16_ld_b_s8(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(src.get()); + if constexpr (NumRows == 16 && NumCols == 16) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __hmma_m16n16k16_ld_a(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __hmma_m16n16k16_ld_b(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::accumulator) { + __hmma_m16n16k16_ld_c_f16(res.data, tileptr, stride, + get_layout_id()); + } + } else if constexpr (NumRows == 8 && NumCols == 16) { + __hmma_m8n32k16_ld_a(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 32) { + __hmma_m8n32k16_ld_b(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 16) { + __hmma_m32n8k16_ld_a(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 16 && NumCols == 8) { + __hmma_m32n8k16_ld_b(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f16(res.data, tileptr, stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f16(res.data, tileptr, stride, + get_layout_id()); + } + + } else if constexpr (std::is_same::value) { + if constexpr (NumRows == 16 && NumCols == 16) { + __imma_m16n16k16_ld_c(res.data, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __imma_m8n32k16_ld_c(res.data, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __imma_m32n8k16_ld_c(res.data, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + if constexpr (NumRows == 16 && NumCols == 16) { + __hmma_m16n16k16_ld_c_f32(res.data, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 8 && NumCols == 32) { + __hmma_m8n32k16_ld_c_f32(res.data, src.get(), stride, + get_layout_id()); + } else if constexpr (NumRows == 32 && NumCols == 8) { + __hmma_m32n8k16_ld_c_f32(res.data, src.get(), stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + if constexpr (Use == + sycl::ext::oneapi::experimental::matrix::matrix_use::a) { + __dmma_m8n8k4_ld_a(res.data, src.get(), stride, + 
get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::b) { + __dmma_m8n8k4_ld_b(res.data, src.get(), stride, + get_layout_id()); + } else if constexpr (Use == sycl::ext::oneapi::experimental::matrix:: + matrix_use::accumulator) { + __dmma_m8n8k4_ld_c(res.data, src.get(), stride, + get_layout_id()); + } + } } }; @@ -151,31 +282,64 @@ struct joint_matrix_store_impl { void store(sycl::ext::oneapi::experimental::matrix::joint_matrix< T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - NumRows, NumCols, Layout> &src, + NumRows, NumCols, Layout, sycl::sub_group> &src, multi_ptr dst, size_t stride); }; -template struct joint_matrix_store_impl< - double, 8, 8, Layout, Space, + T, NumRows, NumCols, Layout, Space, typename std::enable_if_t> { void store(sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, - 8, Layout> &src, - multi_ptr dst, size_t stride) { - -#ifdef __NVPTX__ -#ifdef __SYCL_DEVICE_ONLY__ - __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride, - get_layout_id()); -#endif -#endif + T, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + NumRows, NumCols, Layout, sycl::sub_group> &src, + multi_ptr dst, size_t stride) { + if (NumRows == 16 && NumCols == 16) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_st_c_f32(dst.get(), src.data, stride, + get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_st_c_i32(dst.get(), src.data, stride, + get_layout_id()); + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(dst.get()); + __hmma_m16n16k16_st_c_f16(tileptr, src.data, stride, + get_layout_id()); + } + } else if (NumRows == 8 && NumCols == 32) { + if constexpr (std::is_same::value) { + __hmma_m8n32k16_st_c_f32(dst.get(), src.data, stride, + get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m8n32k16_st_c_i32(dst.get(), src.data, stride, + get_layout_id()); + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(dst.get()); + __hmma_m8n32k16_st_c_f16(tileptr, src.data, stride, + get_layout_id()); + } + } else if (NumRows == 32 && NumCols == 8) { + if constexpr (std::is_same::value) { + __hmma_m32n8k16_st_c_f32(dst.get(), src.data, stride, + get_layout_id()); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_st_c_i32(dst.get(), src.data, stride, + get_layout_id()); + } else if constexpr (std::is_same::value) { + int32_t *tileptr = reinterpret_cast(dst.get()); + __hmma_m32n8k16_st_c_f16(tileptr, src.data, stride, + get_layout_id()); + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_st_c_f64(dst.get(), src.data, stride, + get_layout_id()); + } } }; @@ -187,18 +351,18 @@ template + N, LayoutC, sycl::sub_group> mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, - LayoutA> + LayoutA, sycl::sub_group> A, sycl::ext::oneapi::experimental::matrix::joint_matrix< T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, - LayoutB> + LayoutB, sycl::sub_group> B, sycl::ext::oneapi::experimental::matrix::joint_matrix< T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - M, N, LayoutC> + M, N, LayoutC, sycl::sub_group> C); }; @@ -234,11 +398,12 @@ constexpr int get_layout_pair_id< return 3; } -template struct joint_matrix_mad_impl< - double, double, 8, 4, 8, LayoutA, LayoutB, LayoutC, + T1, T2, M, 
K, N, LayoutA, LayoutB, LayoutC, typename std::enable_if_t< (LayoutA == sycl::ext::oneapi::experimental::matrix::matrix_layout:: row_major || @@ -253,34 +418,87 @@ struct joint_matrix_mad_impl< LayoutC == sycl::ext::oneapi::experimental::matrix::matrix_layout:: col_major)>> { sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, - 8, 8, LayoutC> + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, LayoutC, sycl::sub_group> mad(sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::a, 8, 4, - LayoutA> + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::a, M, K, + LayoutA, sycl::sub_group> A, sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, sycl::ext::oneapi::experimental::matrix::matrix_use::b, 4, 8, - LayoutB> + T1, sycl::ext::oneapi::experimental::matrix::matrix_use::b, K, N, + LayoutB, sycl::sub_group> B, sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, - 8, LayoutC> + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, + M, N, LayoutC, sycl::sub_group> C) { sycl::ext::oneapi::experimental::matrix::joint_matrix< - double, - sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, 8, 8, - LayoutC> + T2, sycl::ext::oneapi::experimental::matrix::matrix_use::accumulator, M, + N, LayoutC, sycl::sub_group> D; - -#ifdef __NVPTX__ -#ifdef __SYCL_DEVICE_ONLY__ - __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, - get_layout_pair_id(), 0); -#endif -#endif - + if constexpr (M == 16 && N == 16 && K == 16) { + if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_s8(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m16n16k16_mma_u8(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), + 0); + } else if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f16f16(D.data, A.data, B.data, C.data, + get_layout_pair_id(), + 0); + } + } else if constexpr (std::is_same::value) { + __mma_bf16_m16n16k16_mma_f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } + } else if constexpr (M == 8 && N == 32 && K == 16) { + if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_s8(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m8n32k16_mma_u8(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f32f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f16f16(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } + } else if constexpr (std::is_same::value) { + __mma_bf16_m8n32k16_mma_f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } + } else if constexpr (M == 32 && N == 8 && K == 16) { + if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_s8(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __imma_m32n8k16_mma_u8(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr 
(std::is_same::value) { + __mma_bf16_m32n8k16_mma_f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f32f32(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } else if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f16f16(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } + } + } else if constexpr (std::is_same::value) { + __dmma_m8n8k4_mma_f64(D.data, A.data, B.data, C.data, + get_layout_pair_id(), 0); + } return D; } }; @@ -289,14 +507,25 @@ struct joint_matrix_mad_impl< namespace experimental::matrix { -template void joint_matrix_load( - Group sg, joint_matrix &res, + Group sg, joint_matrix &res, multi_ptr src, size_t stride) { - sycl::ext::oneapi::detail::joint_matrix_load_impl{} .load(res, src, stride); +#else + (void)sg; + (void)res; + (void)src; + (void)stride; + throw runtime_error( + "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_load is " + "only supported by CUDA devices", + PI_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } template &src, multi_ptr dst, size_t stride) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) sycl::ext::oneapi::detail::joint_matrix_store_impl{} .store(src, dst, stride); +#else + (void)sg; + (void)src; + (void)dst; + (void)stride; + throw runtime_error( + "When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_store is " + "only supported by CUDA devices", + PI_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } template A, joint_matrix B, joint_matrix C) { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) return sycl::ext::oneapi::detail::joint_matrix_mad_impl< T1, T2, M, K, N, LayoutA, LayoutB, LayoutC>{} .mad(A, B, C); +#else + (void)sg; + (void)A; + (void)B; + (void)C; + throw runtime_error("When using SYCL_EXT_ONEAPI_MATRIX=3 joint_matrix_mad is " + "only supported by CUDA devices", + PI_INVALID_DEVICE); +#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } } // namespace experimental::matrix diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-bf16-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-bf16-test.cpp new file mode 100644 index 0000000000000..b79d964c951c4 --- /dev/null +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-bf16-test.cpp @@ -0,0 +1,199 @@ +// REQUIRES: cuda + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +constexpr int stride = 16; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + auto accD = bufD.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, 
accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } 
@llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + 
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i50.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + }); + + return 0; +}; diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp index 408899e0897ea..aac0356ea9a6f 100644 --- a/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-double-test.cpp @@ -1,4 +1,4 @@ -// REQUIRES: gpu, cuda +// REQUIRES: cuda // RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s @@ -36,8 +36,8 @@ int main() { auto accD = bufD.get_access(cgh); cgh.parallel_for( - nd_range<2>({1, 32}, {1, 32}), [= - ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); joint_matrix(cgh); cgh.parallel_for( - nd_range<2>({1, 32}, {1, 32}), [= - ](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { sycl::sub_group sg = item.get_sub_group(); joint_matrix + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +constexpr int stride = 16; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + auto accD = bufD.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, 
float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i48.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32(<2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %26, <2 x half> %27, <2 x half> %28, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i48.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32(<2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %26, <2 x half> %27, <2 x half> %28, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group 
sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i48.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.f16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32(<2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %26, <2 x half> %27, <2 x half> %28, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i48.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.f16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f32.f32(<2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %26, <2 x half> %27, <2 x half> %28, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + 
cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i48.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.f16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f32.f32(<2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %26, <2 x half> %27, <2 x half> %28, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %30, float %31, float %32, float %33, float %34, float %35, float %36, float %37, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i48.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.f16.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f32.f32(<2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %16, <2 x half> %17, <2 x half> %18, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %26, <2 x half> %27, <2 x half> %28, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_14, float %30, float %31, float %32, float %33, 
float %34, float %35, float %36, float %37, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + }); + + return 0; +}; diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp new file mode 100644 index 0000000000000..a622cd0b62243 --- /dev/null +++ b/sycl/test/check_device_code/matrix/matrix-nvptx-half-half-test.cpp @@ -0,0 +1,191 @@ +// REQUIRES: cuda + +// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_70 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s + +#include + +using namespace sycl; +using namespace sycl::ext::oneapi::experimental::matrix; + +constexpr int stride = 16; + +int main() { + + buffer bufA(nullptr, range<1>(1)); + buffer bufB(nullptr, range<1>(1)); + buffer bufC(nullptr, range<1>(1)); + buffer bufD(nullptr, range<1>(1)); + + queue q; + + q.submit([&](handler &cgh) { + auto accC = bufC.get_access(cgh); + auto accA = bufA.get_access(cgh); + auto accB = bufB.get_access(cgh); + auto accD = bufD.get_access(cgh); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i41.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16.p0i32(i32* %call.ascast.i.i33.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p0i32(i32* %call.ascast.i.i.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } 
@llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i41.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16.p0i32(i32* %call.ascast.i.i33.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i.i.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i41.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_a, accA.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.f16.p0i32(i32* %call.ascast.i.i33.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_b, accB.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.row.row.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) #{{.*}} + sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c); + // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f16.p0i32(i32* %call.ascast.i.i.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16) #{{.*}} + joint_matrix_store(sg, sub_c, accD.get_pointer(), stride); + }); + + cgh.parallel_for( + nd_range<2>({1, 32}, {1, 32}), + [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] { + sycl::sub_group sg = item.get_sub_group(); + + joint_matrix + sub_c; + + joint_matrix + sub_a; + + joint_matrix + sub_b; + + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}} + joint_matrix_load(sg, sub_c, accC.get_pointer(), stride); + // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } 
@llvm.nvvm.wmma.m32n8k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i41.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.f16.p0i32(i32* %call.ascast.i.i33.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m32n8k16.mma.col.col.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i.i.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.f16.p0i32(i32* %call.ascast.i.i41.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.f16.p0i32(i32* %call.ascast.i.i33.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.row.row.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f16.p0i32(i32* %call.ascast.i.i.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f16.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.f16.p0i32(i32* %call.ascast.i.i41.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.f16.p0i32(i32* %call.ascast.i.i33.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m8n32k16.mma.col.col.f16.f16(<2 x half> %8, <2 x half> %9, <2 x half> %10, <2 x half> %11, <2 x half> %12, <2 x half> %13, <2 x half> %14, <2 x half> %15, <2 x half> %18, <2 x half> %19, <2 x half> %20, <2 x half> %21, <2 x half> %22, <2 x half> %23, <2 x half> %24, <2 x half> %25, <2 x half> %2, <2 x half> %3, <2 x half> %4, <2 x half> %5) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f16.p0i32(i32* %call.ascast.i.i.i, <2 x half> %27, <2 x half> %28, <2 x half> %29, <2 x half> %30, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+  });
+
+  return 0;
+};
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp
new file mode 100644
index 0000000000000..6a62502ff62de
--- /dev/null
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-int8-test.cpp
@@ -0,0 +1,191 @@
+// REQUIRES: cuda
+
+// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+
+#include
+
+using namespace sycl;
+using namespace sycl::ext::oneapi::experimental::matrix;
+
+constexpr int stride = 16;
+
+int main() {
+
+  buffer bufA(nullptr, range<1>(1));
+  buffer bufB(nullptr, range<1>(1));
+  buffer bufC(nullptr, range<1>(1));
+  buffer bufD(nullptr, range<1>(1));
+
+  queue q;
+
+  q.submit([&](handler &cgh) {
+    auto accC = bufC.get_access(cgh);
+    auto accA = bufA.get_access(cgh);
+    auto accB = bufB.get_access(cgh);
+    auto accD = bufD.get_access(cgh);
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.s8.p0i32(i32* %call.ascast.i.i52.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.s8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.s8(i32 %11, i32 %12, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.s8.p0i32(i32* %call.ascast.i.i52.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.s8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.s8(i32 %11, i32 %12, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.s8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.s8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.s8(i32 %11, i32 %12, i32 %13, i32 %14, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.s8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.s8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.s8(i32 %11, i32 %12, i32 %13, i32 %14, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.s8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.s8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.s8(i32 %10, i32 %13, i32 %14, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.s8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.s8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.s8(i32 %10, i32 %13, i32 %14, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+  });
+
+  return 0;
+};
\ No newline at end of file
diff --git a/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp b/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
new file mode 100644
index 0000000000000..637a9e41bd2f7
--- /dev/null
+++ b/sycl/test/check_device_code/matrix/matrix-nvptx-uint8-test.cpp
@@ -0,0 +1,191 @@
+// REQUIRES: cuda
+
+// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_72 -DSYCL_EXT_ONEAPI_MATRIX=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
+
+#include
+
+using namespace sycl;
+using namespace sycl::ext::oneapi::experimental::matrix;
+
+constexpr int stride = 16;
+
+int main() {
+
+  buffer bufA(nullptr, range<1>(1));
+  buffer bufB(nullptr, range<1>(1));
+  buffer bufC(nullptr, range<1>(1));
+  buffer bufD(nullptr, range<1>(1));
+
+  queue q;
+
+  q.submit([&](handler &cgh) {
+    auto accC = bufC.get_access(cgh);
+    auto accA = bufA.get_access(cgh);
+    auto accB = bufB.get_access(cgh);
+    auto accD = bufD.get_access(cgh);
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i52.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.u8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.row.row.u8(i32 %11, i32 %12, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.u8.p0i32(i32* %call.ascast.i.i52.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.u8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.mma.col.col.u8(i32 %11, i32 %12, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.u8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.row.row.u8(i32 %11, i32 %12, i32 %13, i32 %14, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.u8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.u8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.mma.col.col.u8(i32 %11, i32 %12, i32 %13, i32 %14, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.u8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.u8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.row.row.u8(i32 %10, i32 %13, i32 %14, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+
+    cgh.parallel_for(
+        nd_range<2>({1, 32}, {1, 32}),
+        [=](nd_item<2> item) [[sycl::reqd_work_group_size(1, 1, 32)]] {
+          sycl::sub_group sg = item.get_sub_group();
+
+          joint_matrix
+              sub_c;
+
+          joint_matrix
+              sub_a;
+
+          joint_matrix
+              sub_b;
+
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
+          // CHECK: tail call i32 @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.u8.p0i32(i32* %call.ascast.i.i49.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.u8.p0i32(i32* %call.ascast.i.i.i, i32 16) #{{.*}}
+          joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
+          // CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.mma.col.col.u8(i32 %10, i32 %13, i32 %14, i32 %15, i32 %16, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8) #{{.*}}
+          sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          // CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.s32.p1i32(i32 addrspace(1)* %_arg_14, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, i32 %23, i32 %24, i32 %25, i32 16) #{{.*}}
+          joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
+        });
+  });
+
+  return 0;
+};