diff --git a/ark/gpu/gpu_compile.cc b/ark/gpu/gpu_compile.cc
index da033cfa4..327386d4c 100644
--- a/ark/gpu/gpu_compile.cc
+++ b/ark/gpu/gpu_compile.cc
@@ -14,13 +14,13 @@
 #include
 #include
 #include
+#include

 #include "cpu_timer.h"
 #include "env.h"
 #include "gpu/gpu_logging.h"
 #include "include/ark.h"
 #include "random.h"
-#include "threading.h"

 #define ARK_USE_NVRTC 0
 #define ARK_DEBUG_KERNEL 0
@@ -140,6 +140,35 @@ const string link(const vector<string> &ptxs) {

 #endif  // (ARK_USE_NVRTC)

+template <typename ItemType>
+static void para_exec(std::vector<ItemType> &items, int max_num_threads,
+                      const std::function<void(ItemType &)> &func) {
+    size_t nthread = (size_t)max_num_threads;
+    if (nthread > items.size()) {
+        nthread = items.size();
+    }
+    std::vector<std::thread> threads;
+    threads.reserve(nthread);
+    std::mutex mtx;
+    size_t idx = 0;
+    for (size_t i = 0; i < nthread; ++i) {
+        threads.emplace_back([&items, &mtx, &idx, &func] {
+            size_t local_idx = -1;
+            for (;;) {
+                {
+                    const std::lock_guard<std::mutex> lock(mtx);
+                    local_idx = idx++;
+                }
+                if (local_idx >= items.size()) break;
+                func(items[local_idx]);
+            }
+        });
+    }
+    for (auto &t : threads) {
+        t.join();
+    }
+}
+
 const string gpu_compile(const vector<string> &codes,
                          const GpuArchType &arch_type,
                          unsigned int max_reg_cnt) {
diff --git a/ark/include/ark.h b/ark/include/ark.h
index db53979d6..b187e82f9 100644
--- a/ark/include/ark.h
+++ b/ark/include/ark.h
@@ -82,73 +82,64 @@ class CodeGenerator;
 class BaseScheduler;
 class SchedOp;

-// TensorBuf refers to a data array that can be shared by multiple tensors.
-class TensorBuf {
-   public:
-    TensorBuf(const DimType &bytes = 0, int id = -1);
-    TensorBuf(const TensorBuf &) = default;
-
-    size_t get_buf_offset() const;
-
-    DimType bytes;
-    int id;
-    bool immutable = false;
-
-   protected:
-    void *buf = nullptr;
-
-    friend class Tensor;
-    friend class BaseScheduler;
-};
-
 /// Type of tensor data.
 class TensorType {
    private:
-    const int id_;
-    const int bytes_;
     const std::string name_;
-    const std::string pointer_name_;
+    const int bytes_;
+    const std::string type_str_;

    public:
-    TensorType(int id = -1, int bytes = 0, const std::string &name = "none",
-               const std::string &pointer_name = "void *");
+    TensorType(const std::string &name = "none", int bytes = 0,
+               const std::string &type_str = "void *");

     bool operator==(const TensorType &other) const;
     bool operator!=(const TensorType &other) const;

-    int id() const;
     int bytes() const;
     const std::string &name() const;
-    const std::string &pointer_name() const;
+    const std::string &type_str() const;
 };

-class Fp16 : public TensorType {
-   public:
-    Fp16();
-};
+const TensorType NONE;

-class Fp32 : public TensorType {
-   public:
-    Fp32();
-};
+std::ostream &operator<<(std::ostream &os, const TensorType &type);

-class Int32 : public TensorType {
-   public:
-    Int32();
-};
+#define REGISTER_TENSOR_TYPE(_type_name, _bytes, _type_str)    \
+    class TensorType_##_type_name : public TensorType {        \
+       public:                                                 \
+        TensorType_##_type_name()                               \
+            : TensorType{#_type_name, _bytes, _type_str} {}     \
+    };                                                         \
+    const TensorType_##_type_name _type_name;
+
+REGISTER_TENSOR_TYPE(FP32, 4, "float")
+REGISTER_TENSOR_TYPE(FP16, 2, "ark::half")
+REGISTER_TENSOR_TYPE(BF16, 2, "ark::bfloat16")
+REGISTER_TENSOR_TYPE(INT32, 4, "int32_t")
+REGISTER_TENSOR_TYPE(UINT32, 4, "uint32_t")
+REGISTER_TENSOR_TYPE(INT8, 1, "int8_t")
+REGISTER_TENSOR_TYPE(UINT8, 1, "uint8_t")
+REGISTER_TENSOR_TYPE(BYTE, 1, "unsigned char")

-class Byte : public TensorType {
+// TensorBuf refers to a data array that can be shared by multiple tensors.
+class TensorBuf { public: - Byte(); -}; + TensorBuf(const DimType &bytes = 0, int id = -1); + TensorBuf(const TensorBuf &) = default; -const TensorType NONE; -const Fp16 FP16; -const Fp32 FP32; -const Int32 INT32; -const Byte BYTE; + size_t get_buf_offset() const; -std::ostream &operator<<(std::ostream &os, const TensorType &type); + DimType bytes; + int id; + bool immutable = false; + + protected: + void *buf = nullptr; + + friend class Tensor; + friend class BaseScheduler; +}; /// Tensor is a view of a TensorBuf. /// diff --git a/ark/include/ark_utils.h b/ark/include/ark_utils.h index d9fd0a30a..d874929a0 100644 --- a/ark/include/ark_utils.h +++ b/ark/include/ark_utils.h @@ -22,6 +22,16 @@ struct alignas(2) half_t { operator float() const; }; +// bfloat16 type. +struct alignas(2) bfloat16_t { + uint16_t storage; + bfloat16_t() = default; + // Constructor with float parameter + bfloat16_t(float f); + // Conversion operator from bfloat16 to float + operator float() const; +}; + } // namespace ark ark::half_t operator+(ark::half_t const &lhs, ark::half_t const &rhs); @@ -32,6 +42,15 @@ ark::half_t &operator-=(ark::half_t &lhs, ark::half_t const &rhs); ark::half_t abs(ark::half_t const &val); +ark::bfloat16_t operator+(ark::bfloat16_t const &lhs, + ark::bfloat16_t const &rhs); +ark::bfloat16_t operator-(ark::bfloat16_t const &lhs, + ark::bfloat16_t const &rhs); +ark::bfloat16_t operator*(ark::bfloat16_t const &lhs, + ark::bfloat16_t const &rhs); +ark::bfloat16_t &operator+=(ark::bfloat16_t &lhs, ark::bfloat16_t const &rhs); +ark::bfloat16_t &operator-=(ark::bfloat16_t &lhs, ark::bfloat16_t const &rhs); + // A set of utility functions namespace ark { namespace utils { diff --git a/ark/include/kernels/activation.h b/ark/include/kernels/activation.h index e036d775a..1bdf1bbcd 100644 --- a/ark/include/kernels/activation.h +++ b/ark/include/kernels/activation.h @@ -9,9 +9,9 @@ namespace ark { struct Relu { - static DEVICE float compute(float input) { return max(input, 0.0f); } - static DEVICE __half2 compute(__half2 input) { - return __hmax2(input, (__half2_raw){0, 0}); + template + static DEVICE DataType compute(DataType input) { + return type::Max::compute(input, type::Constant::zero()); } }; @@ -22,6 +22,10 @@ struct Gelu { (input + 0.044715f * input * input * input))); } + static DEVICE bfloat16 compute(bfloat16 input) { + return bfloat16(Gelu::compute(float(input))); + } + static DEVICE __half2 compute(__half2 input) { __half2 half_pi = __float2half2_rn(0.7978845608f); // sqrt(2 / pi) = 0.7978845608 @@ -48,8 +52,11 @@ struct Gelu { }; struct Sigmoid { - static DEVICE float compute(float input) { - return 1.0f / (1.0f + expf(-input)); + template + static DEVICE DataType compute(DataType input) { + return type::Div::compute( + DataType(1.0f), + (type::Add::compute(DataType(1.0f), type::Exp::compute(-input)))); } static DEVICE __half2 compute(__half2 input) { __half2 one = __float2half2_rn(1.0f); @@ -59,91 +66,49 @@ struct Sigmoid { } }; -template -struct Activation; - -template -struct Activation<_ActivationType, _InShape, half, 2> { - using InputType = half; - using OutputType = half; - static const int NelemPerThread = 2; - - static DEVICE void compute(half *output, const half *input) { - __half2 *pout = (__half2 *)output; - if (_InShape::W == 1) { - *pout = - _ActivationType::compute(__half2half2(*(const __half *)input)); - } else { - __half2 *pin = (__half2 *)input; - *pout = _ActivationType::compute(*pin); - } - } -}; - -template -struct Activation<_ActivationType, _InShape, float, 1> { 
- using InputType = float; - using OutputType = float; - static const int NelemPerThread = 1; - - static DEVICE void compute(float *output, const float *input) { - *output = _ActivationType::compute(*input); - } -}; - -template -DEVICE void relu(float *out, float *in, int uop_idx, int) { - Broadcast1>::run(out, in, - uop_idx); -} - -template -DEVICE void relu(half *out, half *in, int uop_idx, int) { - Broadcast1>::run(out, in, - uop_idx); -} - -template -DEVICE void gelu(float *out, float *in, int uop_idx, int) { - Broadcast1>::run(out, in, - uop_idx); -} - template -DEVICE void gelu(half *out, half *in, int uop_idx, int) { + int SmemBytes, typename InDataType, typename OutDataType> +DEVICE void relu(OutDataType *out, InDataType *in, int uop_idx, + int smem_per_warp) { + constexpr int NelemPerThread = + (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0) + ? 8 + : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; Broadcast1>::run(out, in, - uop_idx); + SmemBytes, + Broadcast1Intrinsic>::run(out, in, uop_idx); } template -DEVICE void sigmoid(float *out, float *in, int uop_idx, int) { + int SmemBytes, typename InDataType, typename OutDataType> +DEVICE void gelu(OutDataType *out, InDataType *in, int uop_idx, + int smem_per_warp) { + constexpr int NelemPerThread = + (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0) + ? 8 + : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; Broadcast1>::run(out, in, - uop_idx); + SmemBytes, + Broadcast1Intrinsic>::run(out, in, uop_idx); } template -DEVICE void sigmoid(half *out, half *in, int uop_idx, int) { + int SmemBytes, typename InDataType, typename OutDataType> +DEVICE void sigmoid(OutDataType *out, InDataType *in, int uop_idx, + int smem_per_warp) { + constexpr int NelemPerThread = + (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0) + ? 8 + : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 
2 : 1; Broadcast1>::run(out, in, - uop_idx); + SmemBytes, + Broadcast1Intrinsic>::run(out, in, uop_idx); } } // namespace ark diff --git a/ark/include/kernels/arithmetic.h b/ark/include/kernels/arithmetic.h index 0058a8327..62c230dd1 100644 --- a/ark/include/kernels/arithmetic.h +++ b/ark/include/kernels/arithmetic.h @@ -8,447 +8,61 @@ namespace ark { -struct Add { - static DEVICE float compute(float a, float b) { return a + b; } - static DEVICE half compute(half a, half b) { return a + b; } - static DEVICE __half compute(__half a, __half b) { return __hadd(a, b); } - static DEVICE __half2 compute(__half2 a, __half2 b) { - return __hadd2(a, b); - } -}; - -struct Sub { - static DEVICE float compute(float a, float b) { return a - b; } - static DEVICE half compute(half a, half b) { return a - b; } - static DEVICE __half compute(__half a, __half b) { return __hsub(a, b); } - static DEVICE __half2 compute(__half2 a, __half2 b) { - return __hsub2(a, b); - } -}; - -struct Mul { - static DEVICE float compute(float a, float b) { return a * b; } - static DEVICE half compute(half a, half b) { return a * b; } - static DEVICE __half compute(__half a, __half b) { return __hmul(a, b); } - static DEVICE __half2 compute(__half2 a, __half2 b) { - return __hmul2(a, b); - } -}; - -struct Div { - static DEVICE float compute(float a, float b) { return a / b; } - static DEVICE half compute(half a, half b) { return a / b; } - static DEVICE __half compute(__half a, __half b) { return __hdiv(a, b); } - static DEVICE __half2 compute(__half2 a, __half2 b) { - return __h2div(a, b); - } -}; - -template -struct Arithmetic { - using InputType = _DataType; - using OutputType = _DataType; - static const int NelemPerThread = _NelemPerThread; - - static DEVICE void compute(_DataType *c, const _DataType *a, - const _DataType *b) { - *c = *a + *b; - if (_In0Shape::W == 1 && _In1Shape::W == 1) { - // do nothing - } else if (_In0Shape::W == 1) { -#pragma unroll - for (int i = 1; i < NelemPerThread; ++i) { - c[i] = _ArithmeticType::compute(*a, b[i]); - } - } else if (_In1Shape::W == 1) { -#pragma unroll - for (int i = 1; i < NelemPerThread; ++i) { - c[i] = _ArithmeticType::compute(a[i], *b); - } - } else { -#pragma unroll - for (int i = 1; i < NelemPerThread; ++i) { - c[i] = _ArithmeticType::compute(a[i], b[i]); - } - } - } -}; - -template -struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2> { - using InputType = float; - using OutputType = float; - static const int NelemPerThread = 2; - - static DEVICE void compute(float *c, const float *a, const float *b) { - if (_In0Shape::W == 1 && _In1Shape::W == 1) { - *c = _ArithmeticType::compute(*a, *b); - } else if (_In0Shape::W == 1) { - float2 *pb = (float2 *)b; - float2 *pc = (float2 *)c; - pc->x = _ArithmeticType::compute(*a, pb->x); - pc->y = _ArithmeticType::compute(*a, pb->y); - } else if (_In1Shape::W == 1) { - float2 *pa = (float2 *)a; - float2 *pc = (float2 *)c; - pc->x = _ArithmeticType::compute(pa->x, *b); - pc->y = _ArithmeticType::compute(pa->y, *b); - } else { - float2 *pa = (float2 *)a; - float2 *pb = (float2 *)b; - float2 *pc = (float2 *)c; - pc->x = _ArithmeticType::compute(pa->x, pb->x); - pc->y = _ArithmeticType::compute(pa->y, pb->y); - } - } -}; - -template -struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4> { - using InputType = float; - using OutputType = float; - static const int NelemPerThread = 4; - - static DEVICE void compute(float *c, const float *a, const float *b) { - if (_In0Shape::W == 1 && _In1Shape::W == 1) { - *c = 
_ArithmeticType::compute(*a, *b); - } else if (_In0Shape::W == 1) { - longlong2 reg_b; - longlong2 reg_c; - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_b.x), "=l"(reg_b.y) - : "l"(b) - : "memory"); - float4 *pb = (float4 *)®_b; - float4 *pc = (float4 *)®_c; - float v = *a; - pc->w = _ArithmeticType::compute(v, pb->w); - pc->x = _ArithmeticType::compute(v, pb->x); - pc->y = _ArithmeticType::compute(v, pb->y); - pc->z = _ArithmeticType::compute(v, pb->z); - asm volatile("st.global.v2.u64 [%0], {%1,%2};" - : - : "l"(c), "l"(reg_c.x), "l"(reg_c.y) - : "memory"); - } else if (_In1Shape::W == 1) { - longlong2 reg_a; - longlong2 reg_c; - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_a.x), "=l"(reg_a.y) - : "l"(a) - : "memory"); - float4 *pa = (float4 *)®_a; - float4 *pc = (float4 *)®_c; - float v = *b; - pc->w = _ArithmeticType::compute(pa->w, v); - pc->x = _ArithmeticType::compute(pa->x, v); - pc->y = _ArithmeticType::compute(pa->y, v); - pc->z = _ArithmeticType::compute(pa->z, v); - asm volatile("st.global.v2.u64 [%0], {%1,%2};" - : - : "l"(c), "l"(reg_c.x), "l"(reg_c.y) - : "memory"); - } else { - longlong2 reg_a; - longlong2 reg_b; - longlong2 reg_c; - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_a.x), "=l"(reg_a.y) - : "l"(a) - : "memory"); - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_b.x), "=l"(reg_b.y) - : "l"(b) - : "memory"); - float4 *pa = (float4 *)®_a; - float4 *pb = (float4 *)®_b; - float4 *pc = (float4 *)®_c; - pc->w = _ArithmeticType::compute(pa->w, pb->w); - pc->x = _ArithmeticType::compute(pa->x, pb->x); - pc->y = _ArithmeticType::compute(pa->y, pb->y); - pc->z = _ArithmeticType::compute(pa->z, pb->z); - asm volatile("st.global.v2.u64 [%0], {%1,%2};" - : - : "l"(c), "l"(reg_c.x), "l"(reg_c.y) - : "memory"); - } - } -}; - -template -struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2> { - using InputType = half; - using OutputType = half; - static const int NelemPerThread = 2; - - static DEVICE void compute(half *c, const half *a, const half *b) { - if (_In0Shape::W == 1 && _In1Shape::W == 1) { - *c = _ArithmeticType::compute(*a, *b); - } else if (_In0Shape::W == 1) { - __half2 *pb = (__half2 *)b; - *(__half2 *)c = - _ArithmeticType::compute(__half2half2(*(const __half *)a), *pb); - } else if (_In1Shape::W == 1) { - __half2 *pa = (__half2 *)a; - *(__half2 *)c = - _ArithmeticType::compute(*pa, __half2half2(*(const __half *)b)); - } else { - __half2 *pa = (__half2 *)a; - __half2 *pb = (__half2 *)b; - *(__half2 *)c = _ArithmeticType::compute(*pa, *pb); - } - } -}; - -template -struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4> { - using InputType = half; - using OutputType = half; - static const int NelemPerThread = 4; - - static DEVICE void compute(half *c, const half *a, const half *b) { - if (_In0Shape::W == 1 && _In1Shape::W == 1) { - *c = _ArithmeticType::compute(*a, *b); - } else if (_In0Shape::W == 1) { - uint64_t reg_b = *(uint64_t *)b; - uint64_t reg_c; - __half2 *pb = (__half2 *)®_b; - __half2 *pc = (__half2 *)®_c; - __half2 v = __half2half2(*(const __half *)a); - pc[0] = _ArithmeticType::compute(v, pb[0]); - pc[1] = _ArithmeticType::compute(v, pb[1]); - *(uint64_t *)c = reg_c; - } else if (_In1Shape::W == 1) { - uint64_t reg_a = *(uint64_t *)a; - uint64_t reg_c; - __half2 *pa = (__half2 *)®_a; - __half2 *pc = (__half2 *)®_c; - __half2 v = __half2half2(*(const __half *)b); - pc[0] = _ArithmeticType::compute(pa[0], v); - pc[1] = _ArithmeticType::compute(pa[1], v); - *(uint64_t *)c = 
reg_c; - } else { - uint64_t reg_a = *(uint64_t *)a; - uint64_t reg_b = *(uint64_t *)b; - uint64_t reg_c; - __half2 *pa = (__half2 *)®_a; - __half2 *pb = (__half2 *)®_b; - __half2 *pc = (__half2 *)®_c; - pc[0] = _ArithmeticType::compute(pa[0], pb[0]); - pc[1] = _ArithmeticType::compute(pa[1], pb[1]); - *(uint64_t *)c = reg_c; - } - } -}; - -template -struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8> { - using InputType = half; - using OutputType = half; - static const int NelemPerThread = 8; - - static DEVICE void compute(half *c, const half *a, const half *b) { - if (_In0Shape::W == 1 && _In1Shape::W == 1) { - *c = _ArithmeticType::compute(*a, *b); - } else if (_In0Shape::W == 1) { - longlong2 reg_b; - longlong2 reg_c; - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_b.x), "=l"(reg_b.y) - : "l"(b) - : "memory"); - __half2 *pb = (__half2 *)®_b; - __half2 *pc = (__half2 *)®_c; - __half2 v = __half2half2(*(const __half *)a); - pc[0] = _ArithmeticType::compute(v, pb[0]); - pc[1] = _ArithmeticType::compute(v, pb[1]); - pc[2] = _ArithmeticType::compute(v, pb[2]); - pc[3] = _ArithmeticType::compute(v, pb[3]); - asm volatile("st.global.v2.u64 [%0], {%1,%2};" - : - : "l"(c), "l"(reg_c.x), "l"(reg_c.y) - : "memory"); - } else if (_In1Shape::W == 1) { - longlong2 reg_a; - longlong2 reg_c; - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_a.x), "=l"(reg_a.y) - : "l"(a) - : "memory"); - __half2 *pa = (__half2 *)®_a; - __half2 *pc = (__half2 *)®_c; - __half2 v = __half2half2(*(const __half *)b); - pc[0] = _ArithmeticType::compute(pa[0], v); - pc[1] = _ArithmeticType::compute(pa[1], v); - pc[2] = _ArithmeticType::compute(pa[2], v); - pc[3] = _ArithmeticType::compute(pa[3], v); - asm volatile("st.global.v2.u64 [%0], {%1,%2};" - : - : "l"(c), "l"(reg_c.x), "l"(reg_c.y) - : "memory"); - } else { - longlong2 reg_a; - longlong2 reg_b; - longlong2 reg_c; - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_a.x), "=l"(reg_a.y) - : "l"(a) - : "memory"); - asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" - : "=l"(reg_b.x), "=l"(reg_b.y) - : "l"(b) - : "memory"); - __half2 *pa = (__half2 *)®_a; - __half2 *pb = (__half2 *)®_b; - __half2 *pc = (__half2 *)®_c; - pc[0] = _ArithmeticType::compute(pa[0], pb[0]); - pc[1] = _ArithmeticType::compute(pa[1], pb[1]); - pc[2] = _ArithmeticType::compute(pa[2], pb[2]); - pc[3] = _ArithmeticType::compute(pa[3], pb[3]); - asm volatile("st.global.v2.u64 [%0], {%1,%2};" - : - : "l"(c), "l"(reg_c.x), "l"(reg_c.y) - : "memory"); - } - } -}; - -template -DEVICE void add(float *c, const float *a, const float *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); -} - -template -DEVICE void add(half *c, const half *a, const half *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 8 == 0) - ? 8 - : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); -} - template -DEVICE void sub(float *c, const float *a, const float *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 
2 : 1; - Broadcast2>::run(c, a, b, uop_idx); + typename UnitOutDims, int NumThreads, int SmemBytes, + typename In0DataType, typename In1DataType, typename OutDataType> +DEVICE void add(OutDataType *c, const In0DataType *a, const In1DataType *b, + int uop_idx, int smem_per_warp) { + broadcast2(c, a, b, uop_idx, smem_per_warp); } template -DEVICE void sub(half *c, const half *a, const half *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 8 == 0) - ? 8 - : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); + typename UnitOutDims, int NumThreads, int SmemBytes, + typename In0DataType, typename In1DataType, typename OutDataType> +DEVICE void sub(OutDataType *c, const In0DataType *a, const In1DataType *b, + int uop_idx, int smem_per_warp) { + broadcast2(c, a, b, uop_idx, smem_per_warp); } template -DEVICE void mul(float *c, float *a, float *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); + typename UnitOutDims, int NumThreads, int SmemBytes, + typename In0DataType, typename In1DataType, typename OutDataType> +DEVICE void mul(OutDataType *c, const In0DataType *a, const In1DataType *b, + int uop_idx, int smem_per_warp) { + broadcast2(c, a, b, uop_idx, smem_per_warp); } template -DEVICE void mul(half *c, half *a, half *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 8 == 0) - ? 8 - : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); -} - -template -DEVICE void div(float *c, float *a, float *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); -} - -template -DEVICE void div(half *c, half *a, half *b, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 8 == 0) - ? 8 - : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - Broadcast2>::run(c, a, b, uop_idx); -} - -template -DEVICE void scale(half *y, half *x, float val, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 8 == 0) - ? 8 - : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1; - half val_h(val); - using ValDims = Vec<1, 1, 1, 1>; - using ValShape = Vec<1, 1, 1, 1>; - Broadcast2< - InDims, InShape, ValDims, ValShape, OutDims, OutShape, UnitOutDims, - NumThreads, SmemBytes, - Arithmetic>::run(y, x, - &val_h, - uop_idx); + typename UnitOutDims, int NumThreads, int SmemBytes, + typename In0DataType, typename In1DataType, typename OutDataType> +DEVICE void div(OutDataType *c, const In0DataType *a, const In1DataType *b, + int uop_idx, int smem_per_warp) { + broadcast2(c, a, b, uop_idx, smem_per_warp); } template -DEVICE void scale(float *y, float *x, float val, int uop_idx, int) { - constexpr int NelemPerThread = - (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 
2 : 1; + int SmemBytes, typename InDataType, typename OutDataType> +DEVICE void scale(OutDataType *y, InDataType *x, float val, int uop_idx, + int smem_per_warp) { + InDataType val_cast(val); using ValDims = Vec<1, 1, 1, 1>; using ValShape = Vec<1, 1, 1, 1>; - Broadcast2>::run(y, x, &val, uop_idx); + broadcast2(y, x, &val_cast, uop_idx, smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/bfloat16.h b/ark/include/kernels/bfloat16.h new file mode 100644 index 000000000..3a2ba4893 --- /dev/null +++ b/ark/include/kernels/bfloat16.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_KERNELS_BFLOAT16_H_ +#define ARK_KERNELS_BFLOAT16_H_ + +// clang-format off +#include "cutlass/numeric_types.h" +#include "cutlass/bfloat16.h" +// clang-format on + +namespace ark { +using bfloat16 = cutlass::bfloat16_t; +} // namespace ark + +#endif // ARK_KERNELS_BFLOAT16_H_ diff --git a/ark/include/kernels/broadcast.h b/ark/include/kernels/broadcast.h index 8d720ab60..16585b7fe 100644 --- a/ark/include/kernels/broadcast.h +++ b/ark/include/kernels/broadcast.h @@ -8,6 +8,418 @@ namespace ark { +template +struct Broadcast1Intrinsic { + using InputType = _InputType; + using OutputType = _OutputType; + static const int NelemPerThread = _NelemPerThread; + + static DEVICE void compute(OutputType *out, const InputType *in) { + if (_InShape::W == 1) { + *out = _IntrinsicType::compute(*in); + } else { +#pragma unroll + for (int i = 0; i < NelemPerThread; ++i) { + out[i] = _IntrinsicType::compute(in[i]); + } + } + } +}; + +template +struct Broadcast1Intrinsic<_IntrinsicType, _InShape, float, float, 2> { + using InputType = float; + using OutputType = float; + static const int NelemPerThread = 2; + + static DEVICE void compute(float *out, const float *in) { + if (_InShape::W == 1) { + *out = _IntrinsicType::compute(*in); + } else { + float2 *pout = (float2 *)out; + float2 *pin = (float2 *)in; + pout->x = _IntrinsicType::compute(pin->x); + pout->y = _IntrinsicType::compute(pin->y); + } + } +}; + +template +struct Broadcast1Intrinsic<_IntrinsicType, _InShape, float, float, 4> { + using InputType = float; + using OutputType = float; + static const int NelemPerThread = 4; + + static DEVICE void compute(float *out, const float *in) { + if (_InShape::W == 1) { + *out = _IntrinsicType::compute(*in); + } else { + longlong2 reg_out; + longlong2 reg_in; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_in.x), "=l"(reg_in.y) + : "l"(in) + : "memory"); + float4 *pout = (float4 *)®_out; + float4 *pin = (float4 *)®_in; + pout->w = _IntrinsicType::compute(pin->w); + pout->x = _IntrinsicType::compute(pin->x); + pout->y = _IntrinsicType::compute(pin->y); + pout->z = _IntrinsicType::compute(pin->z); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(out), "l"(reg_out.x), "l"(reg_out.y) + : "memory"); + } + } +}; + +template +struct Broadcast1Intrinsic<_IntrinsicType, _InShape, half, half, 2> { + using InputType = half; + using OutputType = half; + static const int NelemPerThread = 2; + + static DEVICE void compute(half *out, const half *in) { + if (_InShape::W == 1) { + *out = _IntrinsicType::compute(*in); + } else { + *(__half2 *)out = _IntrinsicType::compute(*(__half2 *)in); + } + } +}; + +template +struct Broadcast1Intrinsic<_IntrinsicType, _InShape, half, half, 4> { + using InputType = half; + using OutputType = half; + static const int NelemPerThread = 4; + + static DEVICE void compute(half *out, const half *in) { + if (_InShape::W == 1) { 
+ *out = _IntrinsicType::compute(*in); + } else { + uint64_t reg_in = *(uint64_t *)in; + uint64_t reg_out; + __half2 *pin = (__half2 *)®_in; + __half2 *pout = (__half2 *)®_out; + pout[0] = _IntrinsicType::compute(pin[0]); + pout[1] = _IntrinsicType::compute(pin[1]); + *(uint64_t *)out = reg_out; + } + } +}; + +template +struct Broadcast1Intrinsic<_IntrinsicType, _InShape, half, half, 8> { + using InputType = half; + using OutputType = half; + static const int NelemPerThread = 8; + + static DEVICE void compute(half *out, const half *in) { + if (_InShape::W == 1) { + *out = _IntrinsicType::compute(*in); + } else { + longlong2 reg_in; + longlong2 reg_out; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_in.x), "=l"(reg_in.y) + : "l"(in) + : "memory"); + __half2 *pin = (__half2 *)®_in; + __half2 *pout = (__half2 *)®_out; + pout[0] = _IntrinsicType::compute(pin[0]); + pout[1] = _IntrinsicType::compute(pin[1]); + pout[2] = _IntrinsicType::compute(pin[2]); + pout[3] = _IntrinsicType::compute(pin[3]); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(out), "l"(reg_out.x), "l"(reg_out.y) + : "memory"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct Broadcast2Intrinsic { + using InputType = _InputType; + using OutputType = _OutputType; + static const int NelemPerThread = _NelemPerThread; + + static DEVICE void compute(OutputType *c, const InputType *a, + const InputType *b) { + *c = _IntrinsicType::compute(*a, *b); + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + // do nothing + } else if (_In0Shape::W == 1) { +#pragma unroll + for (int i = 1; i < NelemPerThread; ++i) { + c[i] = _IntrinsicType::compute(*a, b[i]); + } + } else if (_In1Shape::W == 1) { +#pragma unroll + for (int i = 1; i < NelemPerThread; ++i) { + c[i] = _IntrinsicType::compute(a[i], *b); + } + } else { +#pragma unroll + for (int i = 1; i < NelemPerThread; ++i) { + c[i] = _IntrinsicType::compute(a[i], b[i]); + } + } + } +}; + +template +struct Broadcast2Intrinsic<_IntrinsicType, _In0Shape, _In1Shape, float, float, + 2> { + using InputType = float; + using OutputType = float; + static const int NelemPerThread = 2; + + static DEVICE void compute(float *c, const float *a, const float *b) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _IntrinsicType::compute(*a, *b); + } else if (_In0Shape::W == 1) { + float2 *pb = (float2 *)b; + float2 *pc = (float2 *)c; + pc->x = _IntrinsicType::compute(*a, pb->x); + pc->y = _IntrinsicType::compute(*a, pb->y); + } else if (_In1Shape::W == 1) { + float2 *pa = (float2 *)a; + float2 *pc = (float2 *)c; + pc->x = _IntrinsicType::compute(pa->x, *b); + pc->y = _IntrinsicType::compute(pa->y, *b); + } else { + float2 *pa = (float2 *)a; + float2 *pb = (float2 *)b; + float2 *pc = (float2 *)c; + pc->x = _IntrinsicType::compute(pa->x, pb->x); + pc->y = _IntrinsicType::compute(pa->y, pb->y); + } + } +}; + +template +struct Broadcast2Intrinsic<_IntrinsicType, _In0Shape, _In1Shape, float, float, + 4> { + using InputType = float; + using OutputType = float; + static const int NelemPerThread = 4; + + static DEVICE void compute(float *c, const float *a, const float *b) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _IntrinsicType::compute(*a, *b); + } else if (_In0Shape::W == 1) { + longlong2 reg_b; + longlong2 reg_c; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_b.x), "=l"(reg_b.y) + : "l"(b) + : "memory"); + float4 *pb = (float4 *)®_b; + float4 *pc = (float4 *)®_c; + float v = *a; + pc->w = 
_IntrinsicType::compute(v, pb->w); + pc->x = _IntrinsicType::compute(v, pb->x); + pc->y = _IntrinsicType::compute(v, pb->y); + pc->z = _IntrinsicType::compute(v, pb->z); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(c), "l"(reg_c.x), "l"(reg_c.y) + : "memory"); + } else if (_In1Shape::W == 1) { + longlong2 reg_a; + longlong2 reg_c; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_a.x), "=l"(reg_a.y) + : "l"(a) + : "memory"); + float4 *pa = (float4 *)®_a; + float4 *pc = (float4 *)®_c; + float v = *b; + pc->w = _IntrinsicType::compute(pa->w, v); + pc->x = _IntrinsicType::compute(pa->x, v); + pc->y = _IntrinsicType::compute(pa->y, v); + pc->z = _IntrinsicType::compute(pa->z, v); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(c), "l"(reg_c.x), "l"(reg_c.y) + : "memory"); + } else { + longlong2 reg_a; + longlong2 reg_b; + longlong2 reg_c; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_a.x), "=l"(reg_a.y) + : "l"(a) + : "memory"); + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_b.x), "=l"(reg_b.y) + : "l"(b) + : "memory"); + float4 *pa = (float4 *)®_a; + float4 *pb = (float4 *)®_b; + float4 *pc = (float4 *)®_c; + pc->w = _IntrinsicType::compute(pa->w, pb->w); + pc->x = _IntrinsicType::compute(pa->x, pb->x); + pc->y = _IntrinsicType::compute(pa->y, pb->y); + pc->z = _IntrinsicType::compute(pa->z, pb->z); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(c), "l"(reg_c.x), "l"(reg_c.y) + : "memory"); + } + } +}; + +template +struct Broadcast2Intrinsic<_IntrinsicType, _In0Shape, _In1Shape, half, half, + 2> { + using InputType = half; + using OutputType = half; + static const int NelemPerThread = 2; + + static DEVICE void compute(half *c, const half *a, const half *b) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _IntrinsicType::compute(*a, *b); + } else if (_In0Shape::W == 1) { + __half2 *pb = (__half2 *)b; + *(__half2 *)c = + _IntrinsicType::compute(__half2half2(*(const __half *)a), *pb); + } else if (_In1Shape::W == 1) { + __half2 *pa = (__half2 *)a; + *(__half2 *)c = + _IntrinsicType::compute(*pa, __half2half2(*(const __half *)b)); + } else { + __half2 *pa = (__half2 *)a; + __half2 *pb = (__half2 *)b; + *(__half2 *)c = _IntrinsicType::compute(*pa, *pb); + } + } +}; + +template +struct Broadcast2Intrinsic<_IntrinsicType, _In0Shape, _In1Shape, half, half, + 4> { + using InputType = half; + using OutputType = half; + static const int NelemPerThread = 4; + + static DEVICE void compute(half *c, const half *a, const half *b) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _IntrinsicType::compute(*a, *b); + } else if (_In0Shape::W == 1) { + uint64_t reg_b = *(uint64_t *)b; + uint64_t reg_c; + __half2 *pb = (__half2 *)®_b; + __half2 *pc = (__half2 *)®_c; + __half2 v = __half2half2(*(const __half *)a); + pc[0] = _IntrinsicType::compute(v, pb[0]); + pc[1] = _IntrinsicType::compute(v, pb[1]); + *(uint64_t *)c = reg_c; + } else if (_In1Shape::W == 1) { + uint64_t reg_a = *(uint64_t *)a; + uint64_t reg_c; + __half2 *pa = (__half2 *)®_a; + __half2 *pc = (__half2 *)®_c; + __half2 v = __half2half2(*(const __half *)b); + pc[0] = _IntrinsicType::compute(pa[0], v); + pc[1] = _IntrinsicType::compute(pa[1], v); + *(uint64_t *)c = reg_c; + } else { + uint64_t reg_a = *(uint64_t *)a; + uint64_t reg_b = *(uint64_t *)b; + uint64_t reg_c; + __half2 *pa = (__half2 *)®_a; + __half2 *pb = (__half2 *)®_b; + __half2 *pc = (__half2 *)®_c; + pc[0] = _IntrinsicType::compute(pa[0], pb[0]); + pc[1] = _IntrinsicType::compute(pa[1], 
pb[1]); + *(uint64_t *)c = reg_c; + } + } +}; + +template +struct Broadcast2Intrinsic<_IntrinsicType, _In0Shape, _In1Shape, half, half, + 8> { + using InputType = half; + using OutputType = half; + static const int NelemPerThread = 8; + + static DEVICE void compute(half *c, const half *a, const half *b) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _IntrinsicType::compute(*a, *b); + } else if (_In0Shape::W == 1) { + longlong2 reg_b; + longlong2 reg_c; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_b.x), "=l"(reg_b.y) + : "l"(b) + : "memory"); + __half2 *pb = (__half2 *)®_b; + __half2 *pc = (__half2 *)®_c; + __half2 v = __half2half2(*(const __half *)a); + pc[0] = _IntrinsicType::compute(v, pb[0]); + pc[1] = _IntrinsicType::compute(v, pb[1]); + pc[2] = _IntrinsicType::compute(v, pb[2]); + pc[3] = _IntrinsicType::compute(v, pb[3]); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(c), "l"(reg_c.x), "l"(reg_c.y) + : "memory"); + } else if (_In1Shape::W == 1) { + longlong2 reg_a; + longlong2 reg_c; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_a.x), "=l"(reg_a.y) + : "l"(a) + : "memory"); + __half2 *pa = (__half2 *)®_a; + __half2 *pc = (__half2 *)®_c; + __half2 v = __half2half2(*(const __half *)b); + pc[0] = _IntrinsicType::compute(pa[0], v); + pc[1] = _IntrinsicType::compute(pa[1], v); + pc[2] = _IntrinsicType::compute(pa[2], v); + pc[3] = _IntrinsicType::compute(pa[3], v); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(c), "l"(reg_c.x), "l"(reg_c.y) + : "memory"); + } else { + longlong2 reg_a; + longlong2 reg_b; + longlong2 reg_c; + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_a.x), "=l"(reg_a.y) + : "l"(a) + : "memory"); + asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" + : "=l"(reg_b.x), "=l"(reg_b.y) + : "l"(b) + : "memory"); + __half2 *pa = (__half2 *)®_a; + __half2 *pb = (__half2 *)®_b; + __half2 *pc = (__half2 *)®_c; + pc[0] = _IntrinsicType::compute(pa[0], pb[0]); + pc[1] = _IntrinsicType::compute(pa[1], pb[1]); + pc[2] = _IntrinsicType::compute(pa[2], pb[2]); + pc[3] = _IntrinsicType::compute(pa[3], pb[3]); + asm volatile("st.global.v2.u64 [%0], {%1,%2};" + : + : "l"(c), "l"(reg_c.x), "l"(reg_c.y) + : "memory"); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + // Static checker if InShape can be broadcasted into OutShape. 
template struct BroadcastShapeChecker1 { @@ -67,13 +479,13 @@ struct BroadcastShapeChecker2 { // https://numpy.org/doc/stable/user/basics.broadcasting.html template + int SmemBytes, typename Intrinsic> struct Broadcast1 { using UnitOp = UnitOp; - using InputType = typename CompType::InputType; - using OutputType = typename CompType::OutputType; - static const int NelemPerThread = CompType::NelemPerThread; + using InputType = typename Intrinsic::InputType; + using OutputType = typename Intrinsic::OutputType; + static const int NelemPerThread = Intrinsic::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); static_assert(UnitOutDims::W % NelemPerThread == 0, @@ -124,7 +536,7 @@ struct Broadcast1 { InDims::CHW; } - CompType::compute(&out[idx_out], &in[idx_in]); + Intrinsic::compute(&out[idx_out], &in[idx_in]); } } }; @@ -134,13 +546,13 @@ struct Broadcast1 { template + typename Intrinsic> struct Broadcast2 { using UnitOp = UnitOp; - using InputType = typename CompType::InputType; - using OutputType = typename CompType::OutputType; - static const int NelemPerThread = CompType::NelemPerThread; + using InputType = typename Intrinsic::InputType; + using OutputType = typename Intrinsic::OutputType; + static const int NelemPerThread = Intrinsic::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); static_assert(UnitOutDims::W % NelemPerThread == 0, @@ -211,11 +623,33 @@ struct Broadcast2 { In1Dims::CHW; } - CompType::compute(&out[idx_out], &in0[idx_in0], &in1[idx_in1]); + Intrinsic::compute(&out[idx_out], &in0[idx_in0], &in1[idx_in1]); } } }; +//////////////////////////////////////////////////////////////////////////////// + +// Broadcast2 with a default `NelemPerThread` and the intrinsic template. +template +DEVICE void broadcast2(OutDataType *c, const In0DataType *a, + const In1DataType *b, int uop_idx, int) { + constexpr int NelemPerThread = + (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0) + ? 8 + : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 
2 : 1; + Broadcast2< + In0Dims, In0Shape, In1Dims, In1Shape, OutDims, OutShape, UnitOutDims, + NumThreads, SmemBytes, + Broadcast2Intrinsic>::run(c, a, b, + uop_idx); +} + } // namespace ark #endif // ARK_KERNELS_BROADCAST_H_ diff --git a/ark/include/kernels/cast.h b/ark/include/kernels/cast.h index 560454ee8..4b0fa0062 100644 --- a/ark/include/kernels/cast.h +++ b/ark/include/kernels/cast.h @@ -12,6 +12,22 @@ template struct Cast; +template +struct Cast<_InShape, _FromType, _ToType, 2> { + using InputType = _FromType; + using OutputType = _ToType; + static const int NelemPerThread = 2; + + static DEVICE void compute(_ToType *output, const _FromType *input) { + if constexpr (_InShape::W == 1) { + *output = _ToType(*input); + } else { + output[0] = _ToType(input[0]); + output[1] = _ToType(input[1]); + } + } +}; + template struct Cast<_InShape, half, float, 2> { using InputType = half; @@ -118,46 +134,14 @@ struct Cast<_InShape, half, int, 2> { } }; -template -DEVICE void cast(float *out, half *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} - -template -DEVICE void cast(float *out, int *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} - -template -DEVICE void cast(half *out, float *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} - -template -DEVICE void cast(half *out, int *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} - -template -DEVICE void cast(int *out, float *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} +// TODO: specialization for bfloat16 template -DEVICE void cast(int *out, half *in, int uop_idx, int) { + typename OutShape, typename UnitOutDims, int NumThreads, + typename FromType, typename ToType> +DEVICE void cast(ToType *out, FromType *in, int uop_idx, int) { Broadcast1>::run(out, in, uop_idx); + Cast>::run(out, in, uop_idx); } } // namespace ark diff --git a/ark/include/kernels/common.h b/ark/include/kernels/common.h index 739a4a97f..7b0192f2b 100644 --- a/ark/include/kernels/common.h +++ b/ark/include/kernels/common.h @@ -7,11 +7,10 @@ #include "arch.h" #include "checker.h" #include "device.h" -#include "half.h" -#include "platform.h" #include "smem.h" #include "static_math.h" #include "sync.h" +#include "type_intrinsics.h" #include "unit_op.h" #include "vec.h" diff --git a/ark/include/kernels/embedding.h b/ark/include/kernels/embedding.h index 59460d263..3c8d9a1fe 100644 --- a/ark/include/kernels/embedding.h +++ b/ark/include/kernels/embedding.h @@ -4,6 +4,8 @@ #ifndef ARK_KERNELS_EMBEDDING_H_ #define ARK_KERNELS_EMBEDDING_H_ +#include + #include "common.h" namespace ark { @@ -41,24 +43,105 @@ struct RoPE { } }; -template -DEVICE void rope(float *c, float *a, float *b, int uop_idx, int) { - Broadcast2>::run(c, a, b, - uop_idx); -} +template <> +struct RoPE { + using InputType = bfloat16; + using OutputType = bfloat16; + static const int NelemPerThread = 2; + static DEVICE void compute(bfloat16 *c, const bfloat16 *a, + const bfloat16 *b) { + float2 pa; + float2 pb; + float2 pc; + pa.x = float(a[0]); + pa.y = float(a[1]); + pb.x = float(b[0]); + pb.y = float(b[1]); + RoPE::compute((float *)&pc, (const float *)&pa, + (const float *)&pb); + c[0] = bfloat16(pc.x); + c[1] = bfloat16(pc.y); + } +}; template -DEVICE void rope(half *c, half *a, half *b, int uop_idx, int) { + typename UnitOutDims, int NumThreads, int SmemBytes, + typename DataType> +DEVICE void rope(DataType *c, DataType *a, DataType *b, int uop_idx, int) { Broadcast2>::run(c, a, b, - uop_idx); + UnitOutDims, NumThreads, SmemBytes, 
+ RoPE>::run(c, a, b, uop_idx); } +// TODO: figure out why below doesn't pass the accuracy test for half + +// struct Rope { +// static DEVICE float2 compute(float2 a, float2 b) { +// float2 out; +// out.x = a.x * b.x - a.y * b.y; +// out.y = a.x * b.y + a.y * b.x; +// return out; +// } +// static DEVICE __half2 compute(__half2 a, __half2 b) { +// __half2 out; +// out.x = __hmul(a.x, b.x) - __hmul(a.y, b.y); +// out.y = __hmul(a.x, b.y) + __hmul(a.y, b.x); +// return out; +// } +// // struct DEVICE __nv_bfloat162 compute(__nv_bfloat162 a, __nv_bfloat162 +// b) +// // { +// // __nv_bfloat162 out; +// // out.x = __hmul(a.x, b.x) - __hmul(a.y, b.y); +// // out.y = __hmul(a.x, b.y) + __hmul(a.y, b.x); +// // return out; +// // } +// }; + +// template +// DEVICE void rope(DataType *c, DataType *a, DataType *b, int uop_idx, int) { +// static_assert(In0Dims::W % 2 == 0, ""); +// static_assert(In1Dims::W % 2 == 0, ""); +// static_assert(OutDims::W % 2 == 0, ""); +// static_assert(In0Shape::W % 2 == 0, ""); +// static_assert(In1Shape::W % 2 == 0, ""); +// static_assert(OutShape::W % 2 == 0, ""); + +// using VecType = typename std::conditional< +// std::is_same::value, float2, +// typename std::conditional< +// std::is_same::value, __half2, +// typename std::conditional::value, +// __nv_bfloat162, +// void>::type>::type>::type; + +// using In0VecDims = Vec; using In1VecDims = Vec; using OutVecDims = Vec; + +// using In0VecShape = Vec; +// using In1VecShape = Vec; +// using OutVecShape = Vec; + +// Broadcast2>::run((VecType *)c, (VecType +// *)a, +// (VecType *)b, uop_idx); +// } + // Embedding template @@ -66,40 +149,14 @@ struct Assign { using InputType = _DataType; using OutputType = _DataType; static const int NelemPerThread = 1; - static DEVICE void compute(_DataType *c, const _DataType *a) { *c = *a; } + static DEVICE void compute(_DataType *y, const _DataType *x) { *y = *x; } }; template -DEVICE void embedding(float *output, int *input, float *weight, int uop_idx, - int) { - // InShape: Vec - // WeightShape: Vec< 1, 1, ?, EmbeddingDim> (?: # of embeddings) - // OutShape: Vec - - static_assert(InShape::W == 1, ""); - - using UnitOutDims = Vec<1, 1, 1, OutDims::W>; - using UnitOp = UnitOp; - int un = UnitOp::uop_idx_n(uop_idx); - int uc = UnitOp::uop_idx_c(uop_idx); - int uh = UnitOp::uop_idx_h(uop_idx); - - // pWeight: Vec<1, 1, 1, EmbeddingDim> - int emb_idx = input[un * InDims::CH + uc * InDims::H + uh]; - float *pWeight = &weight[emb_idx * WeightDims::W]; - - Broadcast1, Vec<1, 1, 1, EmbeddingDim>, OutDims, - OutShape, UnitOutDims, NumThreads, 0, - Assign>::run(output, pWeight, uop_idx); -} - -template -DEVICE void embedding(half *output, int *input, half *weight, int uop_idx, - int) { + int EmbeddingDim, int NumThreads, typename DataType> +DEVICE void embedding(DataType *output, int *input, DataType *weight, + int uop_idx, int) { // InShape: Vec // WeightShape: Vec< 1, 1, ?, EmbeddingDim> (?: # of embeddings) // OutShape: Vec @@ -114,11 +171,11 @@ DEVICE void embedding(half *output, int *input, half *weight, int uop_idx, // pWeight: Vec<1, 1, 1, EmbeddingDim> int emb_idx = input[un * InDims::CH + uc * InDims::H + uh]; - half *pWeight = &weight[emb_idx * WeightDims::W]; + DataType *pWeight = &weight[emb_idx * WeightDims::W]; Broadcast1, Vec<1, 1, 1, EmbeddingDim>, OutDims, OutShape, UnitOutDims, NumThreads, 0, - Assign>::run(output, pWeight, uop_idx); + Assign>::run(output, pWeight, uop_idx); } } // namespace ark diff --git a/ark/include/kernels/gemm.h b/ark/include/kernels/gemm.h index 
1c8ba25d6..c01c9cf32 100644 --- a/ark/include/kernels/gemm.h +++ b/ark/include/kernels/gemm.h @@ -396,6 +396,345 @@ struct GemmConfiguration, 10>; }; +//////////////////////////////////////////////////////////////////////////////// +/// SM80 BF16 +//////////////////////////////////////////////////////////////////////////////// + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using 
ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 4>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 4>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 6>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + +template +struct GemmConfiguration> { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB, + ElementOutput, LayoutC, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + 
cutlass::gemm::GemmShape<16, 8, 16>,
+        cutlass::epilogue::thread::LinearCombination<
+            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+            ElementAccumulator, ElementAccumulator>,
+        ark::GemmThreadblockSwizzle, 4>;
+};
+
+template
+struct GemmConfiguration> {
+    using ElementOutput = cutlass::bfloat16_t;
+    using ElementAccumulator = float;
+
+    using Gemm = cutlass::gemm::device::Gemm<
+        cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB,
+        ElementOutput, LayoutC, ElementAccumulator,
+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<256, 64, 32>,
+        cutlass::gemm::GemmShape<64, 64, 32>,
+        cutlass::gemm::GemmShape<16, 8, 16>,
+        cutlass::epilogue::thread::LinearCombination<
+            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+            ElementAccumulator, ElementAccumulator>,
+        ark::GemmThreadblockSwizzle, 4>;
+};
+
+template
+struct GemmConfiguration> {
+    using ElementOutput = cutlass::bfloat16_t;
+    using ElementAccumulator = float;
+
+    using Gemm = cutlass::gemm::device::Gemm<
+        cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB,
+        ElementOutput, LayoutC, ElementAccumulator,
+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<64, 256, 32>,
+        cutlass::gemm::GemmShape<64, 64, 32>,
+        cutlass::gemm::GemmShape<16, 8, 16>,
+        cutlass::epilogue::thread::LinearCombination<
+            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+            ElementAccumulator, ElementAccumulator>,
+        ark::GemmThreadblockSwizzle, 4>;
+};
+
+template
+struct GemmConfiguration> {
+    using ElementOutput = cutlass::bfloat16_t;
+    using ElementAccumulator = float;
+
+    using Gemm = cutlass::gemm::device::Gemm<
+        cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB,
+        ElementOutput, LayoutC, ElementAccumulator,
+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<64, 128, 32>,
+        cutlass::gemm::GemmShape<32, 64, 32>,
+        cutlass::gemm::GemmShape<16, 8, 16>,
+        cutlass::epilogue::thread::LinearCombination<
+            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+            ElementAccumulator, ElementAccumulator>,
+        ark::GemmThreadblockSwizzle, 6>;
+};
+
+template
+struct GemmConfiguration> {
+    using ElementOutput = cutlass::bfloat16_t;
+    using ElementAccumulator = float;
+
+    using Gemm = cutlass::gemm::device::Gemm<
+        cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB,
+        ElementOutput, LayoutC, ElementAccumulator,
+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<128, 64, 32>,
+        cutlass::gemm::GemmShape<64, 32, 32>,
+        cutlass::gemm::GemmShape<16, 8, 16>,
+        cutlass::epilogue::thread::LinearCombination<
+            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+            ElementAccumulator, ElementAccumulator>,
+        ark::GemmThreadblockSwizzle, 6>;
+};
+
+template
+struct GemmConfiguration> {
+    using ElementOutput = cutlass::bfloat16_t;
+    using ElementAccumulator = float;
+
+    using Gemm = cutlass::gemm::device::Gemm<
+        cutlass::bfloat16_t, LayoutA, cutlass::bfloat16_t, LayoutB,
+        ElementOutput, LayoutC, ElementAccumulator,
+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<64, 64, 32>,
+        cutlass::gemm::GemmShape<32, 32, 32>,
+        cutlass::gemm::GemmShape<16, 8, 16>,
+        cutlass::epilogue::thread::LinearCombination<
+            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+            ElementAccumulator, ElementAccumulator>,
+        ark::GemmThreadblockSwizzle, 10>;
+};
+
 ////////////////////////////////////////////////////////////////////////////////
 /// SM80 FP32
 ////////////////////////////////////////////////////////////////////////////////
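The BF16 specializations above are picked at compile time by the surrounding GemmConfiguration machinery, whose template head is elided by this diff. As a rough illustration of what one of them expands to, the following is a minimal host-side sketch that drives the same kind of cutlass::gemm::device::Gemm directly. The row-major layouts, the function name, and the alpha/beta values are assumptions for the example, not part of the diff; it presumes densely packed device buffers whose leading dimensions satisfy the default BF16 tensor-op alignment (8 elements).

#include "cutlass/gemm/device/gemm.h"

using RowMajor = cutlass::layout::RowMajor;
// Mirrors the <128, 128, 32> / <64, 64, 32> / <16, 8, 16> BF16 configuration.
using Bf16Gemm = cutlass::gemm::device::Gemm<
    cutlass::bfloat16_t, RowMajor, cutlass::bfloat16_t, RowMajor,
    cutlass::bfloat16_t, RowMajor, float, cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>>;

// Computes C = A * B (alpha = 1, beta = 0) on row-major device buffers.
cutlass::Status run_bf16_gemm(int m, int n, int k, cutlass::bfloat16_t *a,
                              cutlass::bfloat16_t *b, cutlass::bfloat16_t *c) {
    Bf16Gemm gemm_op;
    return gemm_op({{m, n, k}, {a, k}, {b, n}, {c, n}, {c, n}, {1.0f, 0.0f}});
}

The specializations differ only in their threadblock/warp tile shapes and pipeline stage counts (3 to 10 here), trading register and shared-memory pressure against latency hiding for different problem shapes.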
diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index 00aad8961..feee4a791 100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -33,7 +33,6 @@ struct LayerNorm { static DEVICE void run(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { using InOutChk = LayerNormShapeChecker; - using ReduceTypeMean = ReduceTypeMean; constexpr int NonReduceDimLength = UnitOutDims::NCH; // The reduction dimension of the final stage. @@ -65,19 +64,19 @@ struct LayerNorm { (tid_n + un * UnitOutDims::N) * InDims::CHW; DataType reduced; - ReduceTypeMean::singleIdentity(&reduced); + ReduceTypeMean::identity<1>(&reduced); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; - ReduceTypeMean::singleReduce(&reduced, &reduced, &in[idx_in]); + ReduceTypeMean::reduce<1>(&reduced, &reduced, &in[idx_in]); } // final reduction on shared memory using warp shuffle. reduced = warpsReduce( reduced, tid, smem_per_warp); // get the average result. - ReduceTypeMean::singlePostReduce(&reduced, &reduced, UnitOutDims::W); + ReduceTypeMean::postReduce<1>(&reduced, &reduced, UnitOutDims::W); DataType variance; - ReduceTypeMean::singleIdentity(&variance); + ReduceTypeMean::identity<1>(&variance); // get the variance // TODO: Kahan sum for (int idx_in_w = tid_w; idx_in_w < InShape::W; @@ -87,7 +86,7 @@ struct LayerNorm { } variance = warpsReduce( variance, tid, smem_per_warp); - ReduceTypeMean::singlePostReduce(&variance, &variance, UnitOutDims::W); + ReduceTypeMean::postReduce<1>(&variance, &variance, UnitOutDims::W); // the output is (input - mean) / sqrt(variance) for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_w; @@ -99,24 +98,13 @@ struct LayerNorm { template -DEVICE void layernorm(float *out, const float *in, int uop_idx, + int SmemBytes, typename DataType> +DEVICE void layernorm(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { constexpr int NelemPerThread = 1; LayerNorm::run(out, in, uop_idx, - smem_per_warp); -} - -template -DEVICE void layernorm(ark::half *out, const ark::half *in, int uop_idx, - int smem_per_warp) { - constexpr int NelemPerThread = 1; - LayerNorm::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, DataType, NelemPerThread>::run(out, in, uop_idx, + smem_per_warp); } // Perform RMS normalization on input and write the result on output. @@ -132,7 +120,6 @@ struct RMSNorm { static DEVICE void run(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { using InOutChk = LayerNormShapeChecker; - using ReduceTypeMean = ReduceTypeMean; constexpr int NonReduceDimLength = UnitOutDims::NCH; // The reduction dimension of the final stage. 
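Stripped of the tiling and warp-reduction plumbing in the hunks above and below, the math that LayerNorm::run and RMSNorm::run implement is plain per-row normalization. The sketch below is a single-threaded host reference for clarity, not code from this diff; it accumulates in float for readability (the kernels accumulate in DataType, which is why RMSNorm carries a compensation term and a Kahan-sum TODO), and it assumes the elided RMSNorm output loop scales each input by rrms, as standard RMS normalization does.

#include <cmath>
#include <cstddef>

// LayerNorm maps x to (x - mean) / sqrt(variance);
// RMSNorm maps x to x * rsqrt(mean(x^2) + 1e-5).
template <typename DataType>
void norm_reference(DataType *layernorm_out, DataType *rmsnorm_out,
                    const DataType *in, size_t rows, size_t width) {
    for (size_t r = 0; r < rows; ++r) {
        const DataType *row = &in[r * width];
        float mean = 0.0f, variance = 0.0f, mean_square = 0.0f;
        for (size_t w = 0; w < width; ++w) mean += float(row[w]);
        mean /= float(width);
        for (size_t w = 0; w < width; ++w) {
            float d = float(row[w]) - mean;
            variance += d * d;
            mean_square += float(row[w]) * float(row[w]);
        }
        variance /= float(width);
        mean_square /= float(width);
        float rrms = 1.0f / std::sqrt(mean_square + 1e-5f);
        for (size_t w = 0; w < width; ++w) {
            layernorm_out[r * width + w] =
                DataType((float(row[w]) - mean) / std::sqrt(variance));
            rmsnorm_out[r * width + w] = DataType(float(row[w]) * rrms);
        }
    }
}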
@@ -166,8 +153,8 @@ struct RMSNorm { // calculate mean square DataType mean_square; DataType cmp; - ReduceTypeMean::singleIdentity(&mean_square); - ReduceTypeMean::singleIdentity(&cmp); + ReduceTypeMean::identity<1>(&mean_square); + ReduceTypeMean::identity<1>(&cmp); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; @@ -179,8 +166,8 @@ struct RMSNorm { } mean_square = warpsReduce( mean_square, tid, smem_per_warp); - ReduceTypeMean::singlePostReduce(&mean_square, &mean_square, - UnitOutDims::W); + ReduceTypeMean::postReduce<1>(&mean_square, &mean_square, + UnitOutDims::W); // the output is (input - mean) / sqrt(mean_square) DataType rrms(rsqrtf(mean_square + 1e-5f)); for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { @@ -193,24 +180,13 @@ struct RMSNorm { template -DEVICE void rmsnorm(float *out, const float *in, int uop_idx, - int smem_per_warp) { - constexpr int NelemPerThread = 1; - RMSNorm::run(out, in, uop_idx, - smem_per_warp); -} - -template -DEVICE void rmsnorm(ark::half *out, const ark::half *in, int uop_idx, + int SmemBytes, typename DataType> +DEVICE void rmsnorm(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { constexpr int NelemPerThread = 1; RMSNorm::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, DataType, NelemPerThread>::run(out, in, uop_idx, + smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/math_functions.h b/ark/include/kernels/math_functions.h index 511ce37c2..fcb304a48 100644 --- a/ark/include/kernels/math_functions.h +++ b/ark/include/kernels/math_functions.h @@ -4,82 +4,36 @@ #ifndef ARK_KERNELS_MATH_FUNCTIONS_H_ #define ARK_KERNELS_MATH_FUNCTIONS_H_ -#include "broadcast.h" +#include "common.h" namespace ark { -struct Exp { - static DEVICE float compute(float input) { return expf(input); } - static DEVICE __half2 compute(__half2 input) { return h2exp(input); } -}; - -struct Sqrt { - static DEVICE float compute(float input) { return sqrtf(input); } - static DEVICE __half2 compute(__half2 input) { return h2sqrt(input); } -}; - -template -struct Math; - -template -struct Math<_MathType, _InShape, half, 2> { - using InputType = half; - using OutputType = half; - static const int NelemPerThread = 2; - - static DEVICE void compute(half *output, const half *input) { - __half2 *pout = (__half2 *)output; - if (_InShape::W == 1) { - *pout = _MathType::compute(__half2half2(*(const __half *)input)); - } else { - __half2 *pin = (__half2 *)input; - *pout = _MathType::compute(*pin); - } - } -}; - -template -struct Math<_MathType, _InShape, float, 1> { - using InputType = float; - using OutputType = float; - static const int NelemPerThread = 1; - - static DEVICE void compute(float *output, const float *input) { - *output = _MathType::compute(*input); - } -}; - -template -DEVICE void exp(half *out, half *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} - -template -DEVICE void exp(float *out, float *in, int uop_idx, int) { - Broadcast1>::run(out, in, uop_idx); -} - template -DEVICE void sqrt(half *out, half *in, int uop_idx, int) { + int SmemBytes, typename InDataType, typename OutDataType> +DEVICE void exp(OutDataType *out, const InDataType *in, int uop_idx, int) { + constexpr int NelemPerThread = + (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0) + ? 8 + : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 
2 : 1;
     Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Math<Sqrt, InShape, half, 2>>::run(out, in, uop_idx);
+               SmemBytes,
+               Broadcast1Intrinsic<type::Exp, InShape, InDataType, OutDataType,
+                                   NelemPerThread>>::run(out, in, uop_idx);
 }
 
 template <typename InDims, typename InShape, typename OutDims,
           typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void sqrt(float *out, float *in, int uop_idx, int) {
+          int SmemBytes, typename InDataType, typename OutDataType>
+DEVICE void sqrt(OutDataType *out, const InDataType *in, int uop_idx, int) {
+    constexpr int NelemPerThread =
+        (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0)
+            ? 8
+            : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1;
     Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Math<Sqrt, InShape, float, 1>>::run(out, in, uop_idx);
+               SmemBytes,
+               Broadcast1Intrinsic<type::Sqrt, InShape, InDataType, OutDataType,
+                                   NelemPerThread>>::run(out, in, uop_idx);
 }
 
 }  // namespace ark
diff --git a/ark/include/kernels/matmul.h b/ark/include/kernels/matmul.h
index 000665ecf..83d30d399 100644
--- a/ark/include/kernels/matmul.h
+++ b/ark/include/kernels/matmul.h
@@ -32,43 +32,24 @@ namespace ark {
 template <typename OutDims, typename NCA, typename NCB, typename Shape,
           typename ProblemSize, typename LeadingDims, int InnerLdimA,
           int InnerLdimB, bool IsColumnA, bool IsColumnB, int NumThreads,
-          int SmemBytes>
-DEVICE void matmul(float *C, float *A, float *B, int uop_idx,
+          int SmemBytes, typename DataTypeA, typename DataTypeB,
+          typename DataTypeC>
+DEVICE void matmul(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx,
                    int smem_per_warp) {
     gemm<OutDims, NCA, NCB, Shape, ProblemSize, LeadingDims, InnerLdimA,
-         InnerLdimB, IsColumnA, IsColumnB, NumThreads, SmemBytes, float, float,
-         float, float>(C, A, B, uop_idx, smem_per_warp);
+         InnerLdimB, IsColumnA, IsColumnB, NumThreads, SmemBytes, DataTypeA,
+         DataTypeB, DataTypeC, DataTypeC>(C, A, B, uop_idx, smem_per_warp);
 }
 
-/// Matrix multiplication.
-///
-/// Reuse GEMM kernels. The output is row-major, and the input matrices are
-/// row-major by default. If the input matrices are column-major, the
-/// corresponding @p IsColumnA or @p IsColumnB should be set to true.
-///
-/// @tparam OutDims (ark::Vec) Output tensor leading dimensions.
-/// @tparam NCA (ark::Vec) A 2D vector with N and C dimensions of matrix A.
-/// @tparam NCB (ark::Vec) A 2D vector with N and C dimensions of matrix B.
-/// @tparam Shape (ark::Vec) The tile shape of matmul computation (m, n, k).
-/// @tparam ProblemSize (ark::Vec) The problem size of matmul computation
-/// (m, n, k).
-/// @tparam LeadingDims (ark::Vec) The leading dimensions of matrix inputs
-/// and outputs. (lda, ldc, ldc, ldb).
-/// @tparam InnerLdimA (int) The leading dimension of the inner dimension of A.
-/// @tparam InnerLdimB (int) The leading dimension of the inner dimension of B.
-/// @tparam IsColumnA (bool) Whether matrix A is column-major.
-/// @tparam IsColumnB (bool) Whether matrix B is column-major.
-/// @tparam NumThreads (int) The number of threads per uop.
-/// @tparam SmemBytes (int) The size of shared memory per uop.
-///
 template <typename OutDims, typename NCA, typename NCB, typename Shape,
           typename ProblemSize, typename LeadingDims, int InnerLdimA,
           int InnerLdimB, bool IsColumnA, bool IsColumnB, int NumThreads,
           int SmemBytes>
-DEVICE void matmul(half *C, half *A, half *B, int uop_idx, int smem_per_warp) {
+DEVICE void matmul(bfloat16 *C, bfloat16 *A, bfloat16 *B, int uop_idx,
+                   int smem_per_warp) {
     gemm<OutDims, NCA, NCB, Shape, ProblemSize, LeadingDims, InnerLdimA,
-         InnerLdimB, IsColumnA, IsColumnB, NumThreads, SmemBytes, half, half,
-         half, half>(C, A, B, uop_idx, smem_per_warp);
+         InnerLdimB, IsColumnA, IsColumnB, NumThreads, SmemBytes, bfloat16,
+         bfloat16, bfloat16, float>(C, A, B, uop_idx, smem_per_warp);
 }
 
 }  // namespace ark
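// Note (illustrative, not part of this change): the last four type arguments
// to gemm<> are (DataTypeA, DataTypeB, DataTypeC, ElementAccumulator). The
// generic matmul overload passes the output type as the accumulator, so FP32
// and FP16 behave as before, while the bfloat16 overload pins the accumulator
// to float, matching the SM80 BF16 GemmConfigurations earlier in this patch.
// With the tile parameters abbreviated as "...", the two forwarding calls are:
//
//     gemm<..., DataTypeA, DataTypeB, DataTypeC, DataTypeC>(...);  // generic
//     gemm<..., bfloat16, bfloat16, bfloat16, float>(...);  // bf16 overload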
diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h
index 301e5f58c..662be7334 100644
--- a/ark/include/kernels/reduce.h
+++ b/ark/include/kernels/reduce.h
@@ -23,51 +23,57 @@ struct ReduceSharedStorage {
     DataType storage[32];
 };
 
-/* Reduce single-precision `val` within a single warp. */
+// Reduce single-precision `val` within a single warp.
+template <int LanesNum, typename ReduceType, typename DataType>
 DEVICE DataType warpReduce(DataType val) {
     DataType res = val;
     DataType tmp;
     if (LanesNum >= 32) {
         tmp = __shfl_xor_sync(0xffffffff, res, 16, 32);
-        ReduceType::singleReduce(&res, &res, &tmp);
+        ReduceType::reduce<1>(&res, &res, &tmp);
         tmp = __shfl_xor_sync(0xffffffff, res, 8, 16);
-        ReduceType::singleReduce(&res, &res, &tmp);
+        ReduceType::reduce<1>(&res, &res, &tmp);
         tmp = __shfl_xor_sync(0xffffffff, res, 4, 8);
-        ReduceType::singleReduce(&res, &res, &tmp);
+        ReduceType::reduce<1>(&res, &res, &tmp);
         tmp = __shfl_xor_sync(0xffffffff, res, 2, 4);
-        ReduceType::singleReduce(&res, &res, &tmp);
+        ReduceType::reduce<1>(&res, &res, &tmp);
         tmp = __shfl_xor_sync(0xffffffff, res, 1, 2);
-        ReduceType::singleReduce(&res, &res, &tmp);
+        ReduceType::reduce<1>(&res, &res, &tmp);
     } else {
         if (LanesNum > 16) {
             tmp = __shfl_xor_sync(0xffffffff, res, 16, 32);
-            ReduceType::singleReduce(&res, &res, &tmp);
+            ReduceType::reduce<1>(&res, &res, &tmp);
         }
         if (LanesNum > 8) {
             tmp = __shfl_xor_sync(0xffffffff, res, 8, 16);
-            ReduceType::singleReduce(&res, &res, &tmp);
+            ReduceType::reduce<1>(&res, &res, &tmp);
         }
         if (LanesNum > 4) {
             tmp = __shfl_xor_sync(0xffffffff, res, 4, 8);
-            ReduceType::singleReduce(&res, &res, &tmp);
+            ReduceType::reduce<1>(&res, &res, &tmp);
        }
         if (LanesNum > 2) {
             tmp = __shfl_xor_sync(0xffffffff, res, 2, 4);
-            ReduceType::singleReduce(&res, &res, &tmp);
+            ReduceType::reduce<1>(&res, &res, &tmp);
         }
         if (LanesNum > 1) {
             tmp = __shfl_xor_sync(0xffffffff, res, 1, 2);
-            ReduceType::singleReduce(&res, &res, &tmp);
+            ReduceType::reduce<1>(&res, &res, &tmp);
         }
     }
     return res;
 }
 
+// Reduce bfloat16 `val` within a single warp.
+template <int LanesNum, typename ReduceType>
+DEVICE bfloat16 warpReduce(bfloat16 val) {
+    float tmp(val);
+    tmp = warpReduce<LanesNum, ReduceType, float>(tmp);
+    return bfloat16(tmp);
+}
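// Usage sketch (illustrative, not from this diff): one bfloat16 partial per
// lane, summed across a full warp. The overload above promotes each lane's
// value to float, runs the float butterfly, and converts back once at the
// end, so intermediate steps never round to bf16.
//
//     bfloat16 partial(1.0f);  // this lane's partial result
//     bfloat16 total = warpReduce<32, ReduceTypeSum>(partial);
//     // total == bfloat16(32.0f) when every lane contributes 1.0f
+
 // Reduce single-precision `val` within multiple warps.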
-template +template DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { val = warpReduce(val); if (LanesNum > 32) { @@ -83,7 +89,7 @@ DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { if (laneId < (LanesNum >> 5)) { val = shared->storage[laneId]; } else { - ReduceType::singleIdentity(&val); + ReduceType::identity<1>(&val); } val = warpReduce(val); } @@ -115,24 +121,25 @@ struct ReduceShapeChecker { "Invalid UnitOutDims::W"); }; -template struct ReduceTypeSum { - using DataType = _DataType; - static const int NelemPerThread = _NelemPerThread; - + template static DEVICE void identity(DataType *v) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - v[elem] = 0; + v[elem] = type::Constant::zero(); } } + + template static DEVICE void reduce(DataType *out, const DataType *in0, const DataType *in1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in0[elem] + in1[elem]; + out[elem] = type::Add::compute(in0[elem], in1[elem]); } } + + template static DEVICE void postReduce(DataType *out, const DataType *in, int nelem = 1) { #pragma unroll @@ -140,65 +147,71 @@ struct ReduceTypeSum { out[elem] = in[elem]; } } - static DEVICE void singleIdentity(DataType *v) { *v = 0; } - static DEVICE void singleReduce(DataType *out, const DataType *in0, - const DataType *in1) { - *out = *in0 + *in1; - } - static DEVICE void singlePostReduce(DataType *out, const DataType *in, - int nelem = 1) { - *out = *in; - } + + // template <> + // static DEVICE void identity<2, half>(half *v) { + // *reinterpret_cast<__half2 *>(v) = type::Constant<__half2>::zero(); + // } + + // template <> + // static DEVICE void reduce<2, half>(half *out, const half *in0, + // const half *in1) { + // __half2 *out2 = reinterpret_cast<__half2 *>(out); + // const __half2 *in02 = reinterpret_cast(in0); + // const __half2 *in12 = reinterpret_cast(in1); + // *out2 = type::Add::compute(*in02, *in12); + // } + + // template <> + // static DEVICE void postReduce<2, half>(half *out, const half *in, + // int nelem = 1) { + // __half2 *out2 = reinterpret_cast<__half2 *>(out); + // const __half2 *in2 = reinterpret_cast(in); + // *out2 = *in2; + // } }; template <> -struct ReduceTypeSum { - using DataType = half; - static const int NelemPerThread = 2; +DEVICE void ReduceTypeSum::identity<2, half>(half *v) { + *reinterpret_cast<__half2 *>(v) = type::Constant<__half2>::zero(); +} - static DEVICE void identity(half *v) { - *reinterpret_cast<__half2 *>(v) = (__half2_raw){0, 0}; - } - static DEVICE void reduce(half *out, const half *in0, const half *in1) { - __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in02 = reinterpret_cast(in0); - const __half2 *in12 = reinterpret_cast(in1); - *out2 = __hadd2(*in02, *in12); - } - static DEVICE void postReduce(half *out, const half *in, int nelem = 1) { - __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in2 = reinterpret_cast(in); - *out2 = *in2; - } - static DEVICE void singleIdentity(half *v) { *v = 0; } - static DEVICE void singleReduce(half *out, const half *in0, - const half *in1) { - *out = *in0 + *in1; - } - static DEVICE void singlePostReduce(half *out, const half *in, - int nelem = 1) { - *out = *in; - } -}; +template <> +DEVICE void ReduceTypeSum::reduce<2, half>(half *out, const half *in0, + const half *in1) { + __half2 *out2 = reinterpret_cast<__half2 *>(out); + const __half2 *in02 = reinterpret_cast(in0); + const __half2 *in12 = reinterpret_cast(in1); + *out2 = 
type::Add::compute(*in02, *in12); +} -template -struct ReduceTypeMax { - using DataType = _DataType; - static const int NelemPerThread = _NelemPerThread; +template <> +DEVICE void ReduceTypeSum::postReduce<2, half>(half *out, const half *in, + int nelem) { + __half2 *out2 = reinterpret_cast<__half2 *>(out); + const __half2 *in2 = reinterpret_cast(in); + *out2 = *in2; +} +struct ReduceTypeMax { + template static DEVICE void identity(DataType *v) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - v[elem] = platform::numeric_limits::lowest(); + v[elem] = type::Constant::lowest(); } } + + template static DEVICE void reduce(DataType *out, const DataType *in0, const DataType *in1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = (in0[elem] > in1[elem]) ? in0[elem] : in1[elem]; + out[elem] = type::Max::compute(in0[elem], in1[elem]); } } + + template static DEVICE void postReduce(DataType *out, const DataType *in, int nelem = 1) { #pragma unroll @@ -206,135 +219,136 @@ struct ReduceTypeMax { out[elem] = in[elem]; } } - static DEVICE void singleIdentity(DataType *v) { - *v = platform::numeric_limits::lowest(); - } - static DEVICE void singleReduce(DataType *out, const DataType *in0, - const DataType *in1) { - *out = (*in0 > *in1) ? *in0 : *in1; - } - static DEVICE void singlePostReduce(DataType *out, const DataType *in, - int nelem = 1) { - *out = *in; - } + // template <> + // static DEVICE void identity<2, half>(half *v) { + // *reinterpret_cast<__half2 *>(v) = type::Constant<__half2>::lowest(); + // } + + // template <> + // static DEVICE void reduce<2, half>(half *out, const half *in0, const half + // *in1) { + // __half2 *out2 = reinterpret_cast<__half2 *>(out); + // const __half2 *in02 = reinterpret_cast(in0); + // const __half2 *in12 = reinterpret_cast(in1); + // *out2 = type::Max::compute(*in02, *in12); + // } + + // template <> + // static DEVICE void postReduce<2, half>(half *out, const half *in, + // int nelem = 1) { + // __half2 *out2 = reinterpret_cast<__half2 *>(out); + // const __half2 *in2 = reinterpret_cast(in); + // *out2 = *in2; + // } }; template <> -struct ReduceTypeMax { - using DataType = half; - static const int NelemPerThread = 2; +DEVICE void ReduceTypeMax::identity<2, half>(half *v) { + *reinterpret_cast<__half2 *>(v) = type::Constant<__half2>::lowest(); +} - static DEVICE void identity(half *v) { - *reinterpret_cast<__half2 *>(v) = (__half2_raw){0xfbff, 0xfbff}; - } - static DEVICE void reduce(half *out, const half *in0, const half *in1) { -#if (__CUDA_ARCH__ >= 800) - __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in02 = reinterpret_cast(in0); - const __half2 *in12 = reinterpret_cast(in1); - *out2 = __hmax2(*in02, *in12); -#else -#pragma unroll - for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = (in0[elem] > in1[elem]) ? in0[elem] : in1[elem]; - } -#endif // (__CUDA_ARCH__ >= 800) - } - static DEVICE void postReduce(half *out, const half *in, int nelem = 1) { - __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in2 = reinterpret_cast(in); - *out2 = *in2; - } - static DEVICE void singleIdentity(half *v) { - *v = platform::numeric_limits::lowest(); - } - static DEVICE void singleReduce(half *out, const half *in0, - const half *in1) { - *out = (*in0 > *in1) ? 
*in0 : *in1; - } - static DEVICE void singlePostReduce(half *out, const half *in, - int nelem = 1) { - *out = *in; - } -}; +template <> +DEVICE void ReduceTypeMax::reduce<2, half>(half *out, const half *in0, + const half *in1) { + __half2 *out2 = reinterpret_cast<__half2 *>(out); + const __half2 *in02 = reinterpret_cast(in0); + const __half2 *in12 = reinterpret_cast(in1); + *out2 = type::Max::compute(*in02, *in12); +} -template -struct ReduceTypeMean { - using DataType = _DataType; - static const int NelemPerThread = _NelemPerThread; +template <> +DEVICE void ReduceTypeMax::postReduce<2, half>(half *out, const half *in, + int nelem) { + __half2 *out2 = reinterpret_cast<__half2 *>(out); + const __half2 *in2 = reinterpret_cast(in); + *out2 = *in2; +} +struct ReduceTypeMean { + template static DEVICE void identity(DataType *v) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - v[elem] = 0; + v[elem] = type::Constant::zero(); } } + + template static DEVICE void reduce(DataType *out, const DataType *in0, const DataType *in1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in0[elem] + in1[elem]; + out[elem] = type::Add::compute(in0[elem], in1[elem]); } } + + template static DEVICE void postReduce(DataType *out, const DataType *in, int nelem = 1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in[elem] / nelem; + out[elem] = type::Div::compute(in[elem], DataType(nelem)); } } - static DEVICE void singleIdentity(DataType *v) { *v = 0; } - static DEVICE void singleReduce(DataType *out, const DataType *in0, - const DataType *in1) { - *out = *in0 + *in1; - } - static DEVICE void singlePostReduce(DataType *out, const DataType *in, - int nelem = 1) { - *out = *in / nelem; - } + + // template <> + // static DEVICE void identity<2, half>(half *v) { + // *reinterpret_cast<__half2 *>(v) = type::Constant<__half2>::zero(); + // } + + // template <> + // static DEVICE void reduce<2, half>(half *out, const half *in0, + // const half *in1) { + // __half2 *out2 = reinterpret_cast<__half2 *>(out); + // const __half2 *in02 = reinterpret_cast(in0); + // const __half2 *in12 = reinterpret_cast(in1); + // *out2 = type::Add::compute(*in02, *in12); + // } + + // template <> + // static DEVICE void postReduce<2, half>(half *out, const half *in, + // int nelem = 1) { + // __half2 *out2 = reinterpret_cast<__half2 *>(out); + // const __half2 *in2 = reinterpret_cast(in); + // *out2 = type::Div::compute(*in2, __float2half2_rn((float)nelem)); + // } }; template <> -struct ReduceTypeMean { - using DataType = half; - static const int NelemPerThread = 2; +DEVICE void ReduceTypeMean::identity<2, half>(half *v) { + *reinterpret_cast<__half2 *>(v) = type::Constant<__half2>::zero(); +} - static DEVICE void identity(half *v) { - *reinterpret_cast<__half2 *>(v) = (__half2_raw){0, 0}; - } - static DEVICE void reduce(half *out, const half *in0, const half *in1) { - __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in02 = reinterpret_cast(in0); - const __half2 *in12 = reinterpret_cast(in1); - *out2 = __hadd2(*in02, *in12); - } - static DEVICE void postReduce(half *out, const half *in, int nelem = 1) { - __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in2 = reinterpret_cast(in); - *out2 = __h2div(*in2, __float2half2_rn((float)nelem)); - } - static DEVICE void singleIdentity(half *v) { *v = 0; } - static DEVICE void singleReduce(half *out, const half *in0, - const half *in1) { - *out = *in0 + *in1; - } - static DEVICE 
void singlePostReduce(half *out, const half *in,
-                                        int nelem = 1) {
-        *out = *in / nelem;
-    }
-};
+template <>
+DEVICE void ReduceTypeMean::reduce<2, half>(half *out, const half *in0,
+                                            const half *in1) {
+    __half2 *out2 = reinterpret_cast<__half2 *>(out);
+    const __half2 *in02 = reinterpret_cast<const __half2 *>(in0);
+    const __half2 *in12 = reinterpret_cast<const __half2 *>(in1);
+    *out2 = type::Add::compute(*in02, *in12);
+}
+
+template <>
+DEVICE void ReduceTypeMean::postReduce<2, half>(half *out, const half *in,
+                                                int nelem) {
+    __half2 *out2 = reinterpret_cast<__half2 *>(out);
+    const __half2 *in2 = reinterpret_cast<const __half2 *>(in);
+    *out2 = type::Div::compute(*in2, __float2half2_rn((float)nelem));
+}
 
 template <typename InDims, typename InShape, typename OutDims, typename OutShape,
-          typename ReduceType, int Axis>
+          typename ReduceType, typename _DataType, int _NelemPerThread,
+          int Axis>
 struct EwiseReduceCompType;
 
 // Conduct reduction on N dimension of the input.
 template <typename InDims, typename InShape, typename OutDims, typename OutShape,
-          typename ReduceType>
-struct EwiseReduceCompType<InDims, InShape, OutDims, OutShape, ReduceType,
-                           AxisType::N> {
-    using DataType = typename ReduceType::DataType;
-    static const int NelemPerThread = ReduceType::NelemPerThread;
+          typename ReduceType, typename _DataType, int _NelemPerThread>
+struct EwiseReduceCompType<InDims, InShape, OutDims, OutShape, ReduceType,
+                           _DataType, _NelemPerThread, AxisType::N> {
+    using DataType = _DataType;
+    static const int NelemPerThread = _NelemPerThread;
 
     static DEVICE void compute(DataType *out, DataType *in, int idx_n,
                                int idx_c, int idx_h, int idx_w) {
@@ -342,21 +356,24 @@ struct EwiseReduceCompType {
         int idx_in = idx_c * InDims::HW + idx_h * InDims::W + idx_w;
 
         DataType reduced[NelemPerThread];
 
-        ReduceType::identity(reduced);
+        ReduceType::identity<NelemPerThread>(reduced);
 #pragma unroll
         for (int i = 0; i < InShape::N; ++i) {
-            ReduceType::reduce(reduced, reduced, &in[idx_in + i * InDims::CHW]);
+            ReduceType::reduce<NelemPerThread>(reduced, reduced,
+                                               &in[idx_in + i * InDims::CHW]);
         }
-        ReduceType::postReduce(&out[idx_out], reduced, InShape::N);
+        ReduceType::postReduce<NelemPerThread>(&out[idx_out], reduced,
+                                               InShape::N);
     }
 };
 
 // Conduct reduction on C dimension of the input.
 template <typename InDims, typename InShape, typename OutDims, typename OutShape,
-          typename ReduceType>
-struct EwiseReduceCompType<InDims, InShape, OutDims, OutShape, ReduceType,
-                           AxisType::C> {
-    using DataType = typename ReduceType::DataType;
-    static const int NelemPerThread = ReduceType::NelemPerThread;
+          typename ReduceType, typename _DataType, int _NelemPerThread>
+struct EwiseReduceCompType<InDims, InShape, OutDims, OutShape, ReduceType,
+                           _DataType, _NelemPerThread, AxisType::C> {
+    using DataType = _DataType;
+    static const int NelemPerThread = _NelemPerThread;
 
     static DEVICE void compute(DataType *out, DataType *in, int idx_n,
                                int idx_c, int idx_h, int idx_w) {
@@ -364,21 +381,24 @@ struct EwiseReduceCompType {
         int idx_in = idx_n * InDims::CHW + idx_h * InDims::W + idx_w;
 
         DataType reduced[NelemPerThread];
 
-        ReduceType::identity(reduced);
+        ReduceType::identity<NelemPerThread>(reduced);
 #pragma unroll
         for (int i = 0; i < InShape::C; ++i) {
-            ReduceType::reduce(reduced, reduced, &in[idx_in + i * InDims::HW]);
+            ReduceType::reduce<NelemPerThread>(reduced, reduced,
+                                               &in[idx_in + i * InDims::HW]);
        }
-        ReduceType::postReduce(&out[idx_out], reduced, InShape::C);
+        ReduceType::postReduce<NelemPerThread>(&out[idx_out], reduced,
+                                               InShape::C);
     }
 };
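// Index sketch for the axis specializations (assumed, illustrative values,
// not from this diff): with InDims = (2, 4, 8, 8), InDims::HW == 64, so the
// C-axis reduction above walks a strided column of the flattened input from
// the thread assigned (idx_n, idx_h, idx_w) and writes a single output:
//
//     int idx_in = idx_n * InDims::CHW + idx_h * InDims::W + idx_w;
//     for (int i = 0; i < InShape::C; ++i) {  // touches idx_in + {0, 64, ...}
//         ReduceType::reduce<NelemPerThread>(reduced, reduced,
//                                            &in[idx_in + i * InDims::HW]);
//     }
 
 // Conduct reduction on H dimension of the input.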
template -struct EwiseReduceCompType { - using DataType = typename ReduceType::DataType; - static const int NelemPerThread = ReduceType::NelemPerThread; + typename ReduceType, typename _DataType, int _NelemPerThread> +struct EwiseReduceCompType { + using DataType = _DataType; + static const int NelemPerThread = _NelemPerThread; static DEVICE void compute(DataType *out, DataType *in, int idx_n, int idx_c, int idx_h, int idx_w) { @@ -386,21 +406,24 @@ struct EwiseReduceCompType { int idx_in = idx_n * InDims::CHW + idx_c * InDims::HW + idx_w; DataType reduced[NelemPerThread]; - ReduceType::identity(reduced); + ReduceType::identity(reduced); #pragma unroll for (int i = 0; i < InShape::H; ++i) { - ReduceType::reduce(reduced, reduced, &in[idx_in + i * InDims::W]); + ReduceType::reduce(reduced, reduced, + &in[idx_in + i * InDims::W]); } - ReduceType::postReduce(&out[idx_out], reduced, InShape::H); + ReduceType::postReduce(&out[idx_out], reduced, + InShape::H); } }; // Conduct reduction on W dimension of the input. template -struct EwiseReduceCompType { - using DataType = typename ReduceType::DataType; - static const int NelemPerThread = ReduceType::NelemPerThread; + typename ReduceType, typename _DataType, int _NelemPerThread> +struct EwiseReduceCompType { + using DataType = _DataType; + static const int NelemPerThread = _NelemPerThread; static DEVICE void compute(DataType *out, DataType *in, int idx_n, int idx_c, int idx_h, int idx_w) { @@ -410,32 +433,31 @@ struct EwiseReduceCompType { idx_n * InDims::CHW + idx_c * InDims::HW + idx_h * InDims::W; DataType reduced[NelemPerThread]; - ReduceType::identity(reduced); + ReduceType::identity(reduced); #pragma unroll for (int i = 0; i < InShape::W; ++i) { - ReduceType::reduce(reduced, reduced, &in[idx_in + i]); + ReduceType::reduce(reduced, reduced, + &in[idx_in + i]); } DataType finalSum; - ReduceType::singleIdentity(&finalSum); + ReduceType::identity<1>(&finalSum); #pragma unroll for (int i = 0; i < NelemPerThread; ++i) { - ReduceType::singleReduce(&finalSum, &finalSum, &reduced[i]); + ReduceType::reduce<1>(&finalSum, &finalSum, &reduced[i]); } - ReduceType::singlePostReduce(&out[idx_out], &finalSum, InShape::W); + ReduceType::postReduce<1>(&out[idx_out], &finalSum, InShape::W); } }; // Reduce one dimension of input into output. template + int SmemBytes, typename ReduceType, int NelemPerThread, int Axis> struct EwiseReduce { using UnitOp = UnitOp; - using DataType = typename ReduceType::DataType; - static const int NelemPerThread = ReduceType::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); static_assert(UnitOutDims::W % NelemPerThread == 0, "UnitOutDims::W must be divisible by NelemPerThread"); @@ -444,6 +466,7 @@ struct EwiseReduce { /// @param out Output tensor. /// @param in Input tensor. /// @param uop_idx Index of the unit operator. + template static DEVICE void run(DataType *out, DataType *in, int uop_idx) { static_assert(Axis == AxisType::N || Axis == AxisType::C || Axis == AxisType::H || Axis == AxisType::W, @@ -452,21 +475,20 @@ struct EwiseReduce { using ShapeChecker = ReduceShapeChecker; - Ewise1>::run(out, in, uop_idx); + Ewise1< + OutDims, OutShape, UnitOutDims, NumThreads, SmemBytes, + EwiseReduceCompType>::run(out, in, uop_idx); } }; // Warp-wise reduction. Only support reduction along the W dimension. 
template + int SmemBytes, typename ReduceType, int NelemPerThread, int Axis> struct WwiseReduce { using UnitOp = UnitOp; - using DataType = typename ReduceType::DataType; - static const int NelemPerThread = ReduceType::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); static_assert(UnitOutDims::W % NelemPerThread == 0, @@ -480,6 +502,7 @@ struct WwiseReduce { /// @param out Output tensor. /// @param in Input tensor. /// @param uop_idx Index of the unit operator. + template static DEVICE void runW(DataType *out, DataType *in, int uop_idx, int smem_per_warp) { using ShapeChecker = @@ -516,19 +539,19 @@ struct WwiseReduce { DataType reduced[NelemPerThread]; - ReduceType::identity(reduced); + ReduceType::identity(reduced); #pragma unroll for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; - ReduceType::reduce(reduced, reduced, &in[idx_in]); + ReduceType::reduce(reduced, reduced, &in[idx_in]); } DataType finalSum; - ReduceType::singleIdentity(&finalSum); + ReduceType::identity<1>(&finalSum); #pragma unroll for (int i = 0; i < NelemPerThread; ++i) { - ReduceType::singleReduce(&finalSum, &finalSum, &reduced[i]); + ReduceType::reduce<1>(&finalSum, &finalSum, &reduced[i]); } UnitOp::sync_threads(); @@ -539,122 +562,63 @@ struct WwiseReduce { // write the result to output. if (tid % ThreadsPerRow == 0) { - ReduceType::singlePostReduce(&out[idx_out], &finalSum, InShape::W); + ReduceType::postReduce<1>(&out[idx_out], &finalSum, InShape::W); } } }; template -DEVICE void reduce_e_sum(half *out, half *in, int uop_idx, int) { - EwiseReduce, Axis>::run(out, in, uop_idx); -} - -template -DEVICE void reduce_e_sum(float *out, float *in, int uop_idx, int) { + int SmemBytes, int Axis, typename DataType> +DEVICE void reduce_e_sum(DataType *out, DataType *in, int uop_idx, int) { EwiseReduce, Axis>::run(out, in, - uop_idx); + SmemBytes, ReduceTypeSum, 1, Axis>::run(out, in, uop_idx); } template -DEVICE void reduce_e_mean(half *out, half *in, int uop_idx, int) { + int SmemBytes, int Axis, typename DataType> +DEVICE void reduce_e_mean(DataType *out, DataType *in, int uop_idx, int) { EwiseReduce, Axis>::run(out, in, - uop_idx); + SmemBytes, ReduceTypeMean, 1, Axis>::run(out, in, uop_idx); } template -DEVICE void reduce_e_mean(float *out, float *in, int uop_idx, int) { + int SmemBytes, int Axis, typename DataType> +DEVICE void reduce_e_max(DataType *out, DataType *in, int uop_idx, int) { EwiseReduce, Axis>::run(out, in, - uop_idx); -} - -template -DEVICE void reduce_e_max(half *out, half *in, int uop_idx, int) { - EwiseReduce, Axis>::run(out, in, uop_idx); -} - -template -DEVICE void reduce_e_max(float *out, float *in, int uop_idx, int) { - EwiseReduce, Axis>::run(out, in, - uop_idx); -} - -template -DEVICE void reduce_w_sum(half *out, half *in, int uop_idx, int smem_per_warp) { - WwiseReduce, Axis>::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMax, 1, Axis>::run(out, in, uop_idx); } template -DEVICE void reduce_w_sum(float *out, float *in, int uop_idx, + int SmemBytes, int Axis, typename DataType> +DEVICE void reduce_w_sum(DataType *out, DataType *in, int uop_idx, int smem_per_warp) { WwiseReduce, Axis>::runW(out, in, - uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeSum, 1, Axis>::runW(out, in, uop_idx, + smem_per_warp); } template -DEVICE void reduce_w_mean(half *out, half *in, int uop_idx, int smem_per_warp) { - WwiseReduce, Axis>::runW(out, in, - uop_idx, - smem_per_warp); -} - -template 
-DEVICE void reduce_w_mean(float *out, float *in, int uop_idx, + int SmemBytes, int Axis, typename DataType> +DEVICE void reduce_w_mean(DataType *out, DataType *in, int uop_idx, int smem_per_warp) { WwiseReduce, Axis>::runW(out, in, - uop_idx, - smem_per_warp); -} - -template -DEVICE void reduce_w_max(half *out, half *in, int uop_idx, int smem_per_warp) { - WwiseReduce, Axis>::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMean, 1, Axis>::runW(out, in, uop_idx, + smem_per_warp); } template -DEVICE void reduce_w_max(float *out, float *in, int uop_idx, + int SmemBytes, int Axis, typename DataType> +DEVICE void reduce_w_max(DataType *out, DataType *in, int uop_idx, int smem_per_warp) { WwiseReduce, Axis>::runW(out, in, - uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMax, 1, Axis>::runW(out, in, uop_idx, + smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/softmax.h b/ark/include/kernels/softmax.h index e0329d836..860b5adf7 100644 --- a/ark/include/kernels/softmax.h +++ b/ark/include/kernels/softmax.h @@ -37,8 +37,6 @@ struct Softmax { static DEVICE void run(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { using InOutChk = SoftmaxShapeChecker; - using ReduceTypeMax = ReduceTypeMax; - using ReduceTypeSum = ReduceTypeSum; constexpr int NonReduceDimLength = UnitOutDims::NCH; // The reduction dimension of the final stage. @@ -71,24 +69,24 @@ struct Softmax { // get the max input. DataType max_input; - ReduceTypeMax::singleIdentity(&max_input); + ReduceTypeMax::identity<1>(&max_input); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; - ReduceTypeMax::singleReduce(&max_input, &max_input, &in[idx_in]); + ReduceTypeMax::reduce<1>(&max_input, &max_input, &in[idx_in]); } // final reduction on shared memory using warp shuffle. max_input = warpsReduce( max_input, tid, smem_per_warp); // get the max input. - ReduceTypeMax::singlePostReduce(&max_input, &max_input, UnitOutDims::W); + ReduceTypeMax::postReduce<1>(&max_input, &max_input, UnitOutDims::W); // get the exp input sum, use float to avoid overflow. 
DataType exp_sum_input; DataType cmp; - ReduceTypeSum::singleIdentity(&exp_sum_input); - ReduceTypeSum::singleIdentity(&cmp); + ReduceTypeSum::identity<1>(&exp_sum_input); + ReduceTypeSum::identity<1>(&cmp); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; @@ -99,36 +97,27 @@ struct Softmax { } exp_sum_input = warpsReduce( exp_sum_input, tid, smem_per_warp); - ReduceTypeSum::singlePostReduce(&exp_sum_input, &exp_sum_input); + ReduceTypeSum::postReduce<1>(&exp_sum_input, &exp_sum_input); // the output is for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_w; int idx_out = idx_out_base + idx_w; - out[idx_out] = expf(in[idx_in] - max_input) / exp_sum_input; + out[idx_out] = + DataType(expf(in[idx_in] - max_input) / exp_sum_input); } } }; template -DEVICE void softmax(float *out, float *in, int uop_idx, int smem_per_warp) { - constexpr int NelemPerThread = 1; - Softmax::run(out, in, uop_idx, - smem_per_warp); -} - -template -DEVICE void softmax(ark::half *out, ark::half *in, int uop_idx, + int SmemBytes, typename DataType> +DEVICE void softmax(DataType *out, DataType *in, int uop_idx, int smem_per_warp) { constexpr int NelemPerThread = 1; Softmax::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, DataType, NelemPerThread>::run(out, in, uop_idx, + smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/transpose.h b/ark/include/kernels/transpose.h index 09e12e83e..dc14df492 100644 --- a/ark/include/kernels/transpose.h +++ b/ark/include/kernels/transpose.h @@ -539,21 +539,11 @@ struct Transpose3210 { //////////////////////////////////////////////////////////////////////////////// +// TODO: support NelemPerThread > 1 template -DEVICE void _transpose(float *out, float *in, int uop_idx) { - Ewise1::run(out, in, uop_idx); -} - -// TODO: we need to use NelemPerThread=2 for half in the future, if out is a -// __half pointer, this can cause a memory bug, because GPU DRAM access should -// be always 4-byte aligned -template -DEVICE void _transpose(ark::half *out, ark::half *in, int uop_idx) { + typename Transpose, typename DataType> +DEVICE void _transpose(DataType *out, DataType *in, int uop_idx) { Ewise1::run(out, in, uop_idx); } diff --git a/ark/include/kernels/type_intrinsics.h b/ark/include/kernels/type_intrinsics.h new file mode 100644 index 000000000..c103ebd7b --- /dev/null +++ b/ark/include/kernels/type_intrinsics.h @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+
+#ifndef ARK_KERNELS_TYPE_INTRINSICS_H_
+#define ARK_KERNELS_TYPE_INTRINSICS_H_
+
+#include "bfloat16.h"
+#include "half.h"
+#include "platform.h"
+
+namespace ark {
+namespace type {
+
+// TODO: add __nv_bfloat162 support
+
+template <typename DataType>
+struct Constant {
+    static DEVICE DataType zero() { return DataType(0); }
+    static DEVICE DataType lowest() {
+        return platform::numeric_limits<DataType>::lowest();
+    }
+};
+
+template <>
+struct Constant<__half2> {
+    static DEVICE __half2 zero() { return __half2_raw{0, 0}; }
+    static DEVICE __half2 lowest() { return __half2_raw{0xfbff, 0xfbff}; }
+};
+
+template <>
+struct Constant<bfloat16> {
+    static DEVICE bfloat16 zero() { return bfloat16(0); }
+    static DEVICE bfloat16 lowest() { return bfloat16::bitcast(0xff7f); }
+};
+
+struct Add {
+    template <typename DataType>
+    static DEVICE DataType compute(DataType a, DataType b) {
+        return a + b;
+    }
+    static DEVICE __half2 compute(__half2 a, __half2 b) {
+        return __hadd2(a, b);
+    }
+    // static DEVICE __nv_bfloat162 compute(__nv_bfloat162 a, __nv_bfloat162 b)
+    // {
+    //     return __hadd2(a, b);
+    // }
+};
+
+struct Sub {
+    template <typename DataType>
+    static DEVICE DataType compute(DataType a, DataType b) {
+        return a - b;
+    }
+    static DEVICE __half2 compute(__half2 a, __half2 b) {
+        return __hsub2(a, b);
+    }
+    // static DEVICE __nv_bfloat162 compute(__nv_bfloat162 a, __nv_bfloat162 b)
+    // {
+    //     return __hsub2(a, b);
+    // }
+};
+
+struct Mul {
+    template <typename DataType>
+    static DEVICE DataType compute(DataType a, DataType b) {
+        return a * b;
+    }
+    static DEVICE __half2 compute(__half2 a, __half2 b) {
+        return __hmul2(a, b);
+    }
+    // static DEVICE __nv_bfloat162 compute(__nv_bfloat162 a, __nv_bfloat162 b)
+    // {
+    //     return __hmul2(a, b);
+    // }
+};
+
+struct Div {
+    template <typename DataType>
+    static DEVICE DataType compute(DataType a, DataType b) {
+        return a / b;
+    }
+    static DEVICE __half2 compute(__half2 a, __half2 b) {
+        return __h2div(a, b);
+    }
+    // static DEVICE __nv_bfloat162 compute(__nv_bfloat162 a, __nv_bfloat162 b)
+    // {
+    //     return __h2div(a, b);
+    // }
+};
+
+struct Exp {
+    static DEVICE float compute(float input) { return expf(input); }
+    static DEVICE __half compute(__half input) { return hexp(input); }
+    static DEVICE bfloat16 compute(bfloat16 input) {
+        return bfloat16(expf(float(input)));
+    }
+    static DEVICE __half2 compute(__half2 input) { return h2exp(input); }
+    // static DEVICE __nv_bfloat162 compute(__nv_bfloat162 input) {
+    //     return h2exp(input);
+    // }
+};
+
+struct Sqrt {
+    static DEVICE float compute(float input) { return sqrtf(input); }
+    static DEVICE __half compute(__half input) { return hsqrt(input); }
+    static DEVICE bfloat16 compute(bfloat16 input) {
+        return bfloat16(sqrtf(float(input)));
+    }
+    static DEVICE __half2 compute(__half2 input) { return h2sqrt(input); }
+    // static DEVICE __nv_bfloat162 compute(__nv_bfloat162 input) {
+    //     return h2sqrt(input);
+    // }
+};
+
+struct Max {
+    template <typename DataType>
+    static DEVICE DataType compute(DataType a, DataType b) {
+        return (a > b) ? a : b;
+    }
+    static DEVICE float compute(float a, float b) { return max(a, b); }
+    static DEVICE __half2 compute(__half2 a, __half2 b) {
+#if (__CUDA_ARCH__ >= 800)
+        return __hmax2(a, b);
+#else
+        return __halves2half2((a.x > b.x) ? a.x : b.x, (a.y > b.y) ?
a.y : b.y); +#endif // (__CUDA_ARCH__ >= 800) + } +}; + +} // namespace type +} // namespace ark + +#endif // ARK_KERNELS_TYPE_INTRINSICS_H_ diff --git a/ark/model.cc b/ark/model.cc index 2d60bfa5e..bb6b35c15 100644 --- a/ark/model.cc +++ b/ark/model.cc @@ -42,7 +42,7 @@ void Model::Impl::destroy_tensor_buf(const TensorBuf *buf) { } std::vector Model::Impl::add_op( - const OpType type, const OpPrecType prec_type, + const OpType type, const std::string &prec_type, const vector &inputs, const vector &outputs, const OpArgs &args, const string &name, const OpConfigMap *cfg_map, int gran_lev) { diff --git a/ark/model.h b/ark/model.h index e50bf21c0..96e255527 100644 --- a/ark/model.h +++ b/ark/model.h @@ -44,7 +44,8 @@ class Model::Impl { /// should indicate finer-grained Ops. If it is -1, the granularity level /// will be automatically determined by the scheduler. /// @return the output tensors of the @ref Op. - std::vector add_op(const OpType type, const OpPrecType prec_type, + std::vector add_op(const OpType type, + const std::string &prec_type, const std::vector &inputs, const std::vector &output_refs, const OpArgs &args, const std::string &name, diff --git a/ark/ops/ops_add.cc b/ark/ops/ops_add.cc index f85e1d332..38d684d1f 100644 --- a/ark/ops/ops_add.cc +++ b/ark/ops/ops_add.cc @@ -4,14 +4,12 @@ #include "logging.h" #include "model.h" -using namespace std; - namespace ark { extern const OpConfigMap ArithmeticConfigMap; -AddOp::AddOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const string &name) +AddOp::AddOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name) : Op{OP_ADD, prec_type, {input, other}, {output}, {}, name, &ArithmeticConfigMap, -1, true} {} @@ -45,17 +43,9 @@ std::string AddOp::function_name(const OpConfig &cfg) const { } Tensor *Model::add(Tensor *input, Tensor *other, Tensor *output, - const string &name) { + const std::string &name) { CHECK(input != nullptr); CHECK(other != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (input->type != other->type) { LOG(ERROR, "input data types mismatch: ", input->type, ", ", other->type); @@ -71,12 +61,12 @@ Tensor *Model::add(Tensor *input, Tensor *other, Tensor *output, } else if (output == input) { output = this->identity(output); } - AddOp op{pt, input, other, output, name}; + AddOp op{output->type.name(), input, other, output, name}; return this->impl->add_op(op)[0]; } const OpConfigMap ArithmeticConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_FP32}, + {{OP_ARCH_CUDA_ANY, "fp32"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, false, false}, @@ -93,7 +83,24 @@ const OpConfigMap ArithmeticConfigMap = { {1, 0, {{1, 64}, {1, 64}}, {{1, 64}}, false, false}, {1, 0, {{1, 32}, {1, 32}}, {{1, 32}}, false, false}, }}, - {{OP_ARCH_CUDA_ANY, OP_PREC_FP16}, + {{OP_ARCH_CUDA_ANY, "fp16"}, + { + // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost + {8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, false, false}, + {8, 0, {{256, 128}, {256, 128}}, {{256, 128}}, false, false}, + {8, 0, {{128, 128}, {128, 128}}, {{128, 128}}, false, false}, + {4, 0, {{64, 64}, {64, 64}}, {{64, 64}}, false, false}, + {2, 0, {{32, 64}, {32, 64}}, {{32, 64}}, false, false}, + {1, 0, {{16, 64}, {16, 64}}, {{16, 64}}, false, false}, + {1, 
0, {{8, 64}, {8, 64}}, {{8, 64}}, false, false}, + {1, 0, {{2, 128}, {2, 128}}, {{2, 128}}, false, false}, + {1, 0, {{4, 64}, {4, 64}}, {{4, 64}}, false, false}, + {1, 0, {{2, 64}, {2, 64}}, {{2, 64}}, false, false}, + {1, 0, {{1, 256}, {1, 256}}, {{1, 256}}, false, false}, + {1, 0, {{1, 128}, {1, 128}}, {{1, 128}}, false, false}, + {1, 0, {{1, 64}, {1, 64}}, {{1, 64}}, false, false}, + }}, + {{OP_ARCH_CUDA_ANY, "bf16"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, false, false}, diff --git a/ark/ops/ops_add_test.cc b/ark/ops/ops_add_test.cc index fb20d5330..d478b1b0f 100644 --- a/ark/ops/ops_add_test.cc +++ b/ark/ops/ops_add_test.cc @@ -63,6 +63,19 @@ ark::unittest::State test_add_fp16() { return ark::unittest::SUCCESS; } +ark::unittest::State test_add_bf16() { + ark::Model m; + ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::BF16); + ark::Tensor *t1 = m.tensor(ark::Dims(8192), ark::BF16); + ark::Tensor *out = m.add(t0, t1); + + auto result = ark::op_test("add_bf16", m, {t0, t1}, {out}, + baseline_add); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + ark::unittest::State test_add_overwrite() { ark::Model m; ark::Tensor *t0 = m.tensor(ark::Dims(8192), ark::FP16); @@ -117,6 +130,7 @@ int main() { ark::init(); UNITTEST(test_add_fp32); UNITTEST(test_add_fp16); + UNITTEST(test_add_bf16); UNITTEST(test_add_overwrite); UNITTEST(test_add_broadcast); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_cast.cc b/ark/ops/ops_cast.cc index aceeadf4e..386fad726 100644 --- a/ark/ops/ops_cast.cc +++ b/ark/ops/ops_cast.cc @@ -11,7 +11,7 @@ namespace ark { extern const OpConfigMap CastConfigMap; CastOp::CastOp(Tensor *input, Tensor *output, const std::string &name) - : Op{OP_CAST, OP_PREC_NONE, {input}, {output}, {}, + : Op{OP_CAST, "none", {input}, {output}, {}, name, &CastConfigMap, -1, true} {} std::string CastOp::function_name(const OpConfig &cfg) const { @@ -153,7 +153,7 @@ Tensor *Model::cast(Tensor *input, const TensorType &ttype, Tensor *output, } const OpConfigMap CastConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_NONE}, + {{OP_ARCH_CUDA_ANY, "none"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}}, {{128, 256}}, false, false}, diff --git a/ark/ops/ops_cast_test.cc b/ark/ops/ops_cast_test.cc index 480f7c2d5..e3b31b9e3 100644 --- a/ark/ops/ops_cast_test.cc +++ b/ark/ops/ops_cast_test.cc @@ -231,6 +231,30 @@ ark::unittest::State test_cast_int32_to_byte() { return ark::unittest::SUCCESS; } +ark::unittest::State test_cast_bf16_to_float() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); + ark::Tensor *out = m.cast(t, ark::FP32); + + auto result = ark::op_test("cast_bf16_to_float", m, {t}, {out}, + baseline_cast); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_float_to_bf16() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); + ark::Tensor *out = m.cast(t, ark::BF16); + + auto result = ark::op_test("cast_float_to_bf16", m, {t}, {out}, + baseline_cast); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_cast_fp16_to_fp32); @@ -245,5 +269,7 @@ int main() { UNITTEST(test_cast_fp32_to_byte); UNITTEST(test_cast_fp16_to_byte); UNITTEST(test_cast_int32_to_byte); + UNITTEST(test_cast_bf16_to_float); + 
UNITTEST(test_cast_float_to_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_common.cc b/ark/ops/ops_common.cc index 27c7dc9fb..4682c8e6b 100644 --- a/ark/ops/ops_common.cc +++ b/ark/ops/ops_common.cc @@ -61,7 +61,7 @@ const std::vector &OpConfigMap::get(const OpConfigKey &key) const { if (search != this->cfg_map.end()) { return search->second; } - search = this->cfg_map.find({key.arch_type, OP_PREC_ANY}); + search = this->cfg_map.find({key.arch_type, "any"}); if (search != this->cfg_map.end()) { return search->second; } @@ -69,7 +69,7 @@ const std::vector &OpConfigMap::get(const OpConfigKey &key) const { if (search != this->cfg_map.end()) { return search->second; } - search = this->cfg_map.find({OP_ARCH_CUDA_ANY, OP_PREC_ANY}); + search = this->cfg_map.find({OP_ARCH_CUDA_ANY, "any"}); if (search == this->cfg_map.end()) { return NoneConfigs; } @@ -384,7 +384,7 @@ bool operator!=(const OpArgs &opargs1, const OpArgs &opargs2) { return !(opargs1 == opargs2); } -Op::Op(const OpType &type_, const OpPrecType &prec_type_, +Op::Op(const OpType &type_, const std::string &prec_type_, const vector &inputs_, const vector &output_refs_, const OpArgs &args_, const string &name_, const OpConfigMap *cfg_map_, int gran_lev_, bool force_inline_) diff --git a/ark/ops/ops_common.h b/ark/ops/ops_common.h index d6d4e963a..e224cbe0e 100644 --- a/ark/ops/ops_common.h +++ b/ark/ops/ops_common.h @@ -129,21 +129,13 @@ typedef enum { OP_CAST, } OpType; -/// Type of precision of @ref Op. -typedef enum { - OP_PREC_NONE, - OP_PREC_ANY, - OP_PREC_FP16, - OP_PREC_FP32, -} OpPrecType; - /// Type of hardware architecture support. typedef enum { - OP_ARCH_CUDA_ANY, - OP_ARCH_CUDA_60, - OP_ARCH_CUDA_70, - OP_ARCH_CUDA_80, - OP_ARCH_CUDA_90, + OP_ARCH_CUDA_60 = 0x1, + OP_ARCH_CUDA_70 = 0x2, + OP_ARCH_CUDA_80 = 0x4, + OP_ARCH_CUDA_90 = 0x8, + OP_ARCH_CUDA_ANY = -1, } OpArchType; struct Tensor; @@ -167,7 +159,7 @@ struct OpConfig { /// Key to find a list of OpConfigs from OpConfigMap. struct OpConfigKey { OpArchType arch_type; - OpPrecType prec_type; + std::string prec_type; }; bool operator<(const OpConfigKey &ops1, const OpConfigKey &ops2); @@ -208,7 +200,7 @@ class Op { /// should indicate finer-grained Ops. If it is -1, the granularity level /// will be automatically determined by the scheduler. /// @param force_inline whether to force inline the kernel of @ref Op. - Op(const OpType &type, const OpPrecType &prec_type, + Op(const OpType &type, const std::string &prec_type, const std::vector &inputs, const std::vector &output_refs, const OpArgs &args, const std::string &name, const OpConfigMap *cfg_map = nullptr, @@ -240,7 +232,7 @@ class Op { /// Type of the operator. OpType type; /// Precision type of the operator. - OpPrecType prec_type; + std::string prec_type; /// The input tensors of the operator. std::vector inputs; /// The output tensors of the operator. 
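// Lookup sketch (illustrative, not part of this change): with string
// precision keys, OpConfigMap::get falls back from the exact {arch, prec}
// pair to {arch, "any"} and then to the OP_ARCH_CUDA_ANY variants, as in
// ops_common.cc above. For the BF16 add op registered in ops_add.cc:
//
//     OpConfigKey key{OP_ARCH_CUDA_80, "bf16"};
//     // No CUDA_80-specific entry exists, so this resolves to the
//     // {OP_ARCH_CUDA_ANY, "bf16"} tile configurations.
//     const std::vector<OpConfig> &cfgs = ArithmeticConfigMap.get(key);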
@@ -272,63 +264,63 @@ std::ostream &operator<<(std::ostream &os, const OpType &s); class AddOp : public Op { public: - AddOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const std::string &name); + AddOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class SubOp : public Op { public: - SubOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const std::string &name); + SubOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class MulOp : public Op { public: - MulOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const std::string &name); + MulOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class DivOp : public Op { public: - DivOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const std::string &name); + DivOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class GeluOp : public Op { public: - GeluOp(OpPrecType prec_type, Tensor *input, Tensor *output, + GeluOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ExpOp : public Op { public: - ExpOp(OpPrecType prec_type, Tensor *input, Tensor *output, + ExpOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class SqrtOp : public Op { public: - SqrtOp(OpPrecType prec_type, Tensor *input, Tensor *output, + SqrtOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class RopeOp : public Op { public: - RopeOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const std::string &name); + RopeOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class Im2colOp : public Op { public: - Im2colOp(OpPrecType prec_type, Tensor *input, Tensor *output, + Im2colOp(const std::string &prec_type, Tensor *input, Tensor *output, int kernel_height, int kernel_width, int stride_height, int stride_width, int pad_height, int pad_width, int dilation_height, int dilation_width, const std::string &name); @@ -337,36 +329,36 @@ class Im2colOp : public Op { class LayernormOp : public Op { public: - LayernormOp(OpPrecType prec_type, Tensor *input, Tensor *output, + LayernormOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class RMSnormOp : public Op { public: - RMSnormOp(OpPrecType prec_type, Tensor *input, Tensor *output, + RMSnormOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class MatmulOp : public Op { public: - MatmulOp(OpPrecType prec_type, Tensor *mat_a, Tensor *mat_b, Tensor *mat_y, - Dims nca, Dims ncb, Dims problem_size, Dims leading_dims, - bool is_column_a, bool is_column_b, const std::string &name, - int gran_lev); + 
MatmulOp(const std::string &prec_type, Tensor *mat_a, Tensor *mat_b, + Tensor *mat_y, Dims nca, Dims ncb, Dims problem_size, + Dims leading_dims, bool is_column_a, bool is_column_b, + const std::string &name, int gran_lev); std::string function_name(const OpConfig &cfg) const; }; class MaxPoolOp : public Op { public: - MaxPoolOp(OpPrecType prec_type, Tensor *input, Tensor *output, + MaxPoolOp(const std::string &prec_type, Tensor *input, Tensor *output, DimType kernel_size, DimType stride, const std::string &name); }; class ReduceOp : public Op { public: - ReduceOp(const OpType &type, const OpPrecType &prec_type, + ReduceOp(const OpType &type, const std::string &prec_type, const std::vector &inputs, const std::vector &outputs, const OpArgs &args, const std::string &name, const OpConfigMap *cfg_map, int gran_lev); @@ -378,77 +370,77 @@ class ReduceOp : public Op { class ReduceWSumOp : public ReduceOp { public: - ReduceWSumOp(OpPrecType prec_type, Tensor *input, Tensor *output, int axis, - const std::string &name); + ReduceWSumOp(const std::string &prec_type, Tensor *input, Tensor *output, + int axis, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReduceESumOp : public ReduceOp { public: - ReduceESumOp(OpPrecType prec_type, Tensor *input, Tensor *output, int axis, - const std::string &name); + ReduceESumOp(const std::string &prec_type, Tensor *input, Tensor *output, + int axis, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReduceWMaxOp : public ReduceOp { public: - ReduceWMaxOp(OpPrecType prec_type, Tensor *input, Tensor *output, int axis, - const std::string &name); + ReduceWMaxOp(const std::string &prec_type, Tensor *input, Tensor *output, + int axis, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReduceEMaxOp : public ReduceOp { public: - ReduceEMaxOp(OpPrecType prec_type, Tensor *input, Tensor *output, int axis, - const std::string &name); + ReduceEMaxOp(const std::string &prec_type, Tensor *input, Tensor *output, + int axis, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReduceWMeanOp : public ReduceOp { public: - ReduceWMeanOp(OpPrecType prec_type, Tensor *input, Tensor *output, int axis, - const std::string &name); + ReduceWMeanOp(const std::string &prec_type, Tensor *input, Tensor *output, + int axis, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReduceEMeanOp : public ReduceOp { public: - ReduceEMeanOp(OpPrecType prec_type, Tensor *input, Tensor *output, int axis, - const std::string &name); + ReduceEMeanOp(const std::string &prec_type, Tensor *input, Tensor *output, + int axis, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReluOp : public Op { public: - ReluOp(OpPrecType prec_type, Tensor *input, Tensor *output, + ReluOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class SigmoidOp : public Op { public: - SigmoidOp(OpPrecType prec_type, Tensor *input, Tensor *output, + SigmoidOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class ReshapeOp : public Op { public: - ReshapeOp(OpPrecType prec_type, Tensor *input, Tensor *output, + ReshapeOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); }; class ScaleOp : 
public Op { public: - ScaleOp(OpPrecType prec_type, Tensor *input, Tensor *output, float val, - const std::string &name); + ScaleOp(const std::string &prec_type, Tensor *input, Tensor *output, + float val, const std::string &name); std::string function_name(const OpConfig &cfg) const; OpArgs function_call_args(const OpConfig &) const; }; class SendMMOp : public Op { public: - SendMMOp(OpPrecType prec_type, Tensor *input, Tensor *recvbuf, + SendMMOp(const std::string &prec_type, Tensor *input, Tensor *recvbuf, Tensor *send_ready_flag, Tensor *output, int id, int gpu_dst, size_t bytes, const std::string &name); std::string function_name(const OpConfig &cfg) const; @@ -457,7 +449,7 @@ class SendMMOp : public Op { class RecvMMOp : public Op { public: - RecvMMOp(OpPrecType prec_type, Tensor *input, Tensor *recvbuf, + RecvMMOp(const std::string &prec_type, Tensor *input, Tensor *recvbuf, Tensor *send_ready_flag, Tensor *output, int id, int gpu_src, size_t bytes, const std::string &name); std::string function_name(const OpConfig &cfg) const; @@ -466,7 +458,7 @@ class RecvMMOp : public Op { class SendOp : public Op { public: - SendOp(OpPrecType prec_type, Tensor *input, int sid, int rank, int dst_rank, + SendOp(const std::string &prec_type, Tensor *input, int sid, int rank, int dst_rank, size_t bytes, const std::string &name); std::string function_name(const OpConfig &cfg) const; OpArgs function_call_args(const OpConfig &cfg) const; @@ -474,7 +466,7 @@ class SendOp : public Op { class SendDoneOp : public Op { public: - SendDoneOp(OpPrecType prec_type, Tensor *input, int sid, int rank, + SendDoneOp(const std::string &prec_type, Tensor *input, int sid, int rank, int dst_rank, const std::string &name); std::string function_name(const OpConfig &cfg) const; OpArgs function_call_args(const OpConfig &cfg) const; @@ -482,7 +474,7 @@ class SendDoneOp : public Op { class RecvOp : public Op { public: - RecvOp(OpPrecType prec_type, Tensor *output, int sid, int rank, + RecvOp(const std::string &prec_type, Tensor *output, int sid, int rank, int src_rank, size_t bytes, const std::string &name); std::string function_name(const OpConfig &cfg) const; OpArgs function_call_args(const OpConfig &cfg) const; @@ -490,7 +482,7 @@ class RecvOp : public Op { class SoftmaxOp : public Op { public: - SoftmaxOp(OpPrecType prec_type, Tensor *input, Tensor *output, + SoftmaxOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; @@ -503,14 +495,14 @@ class TensorOp : public Op { class TransposeOp : public Op { public: - TransposeOp(OpPrecType prec_type, Tensor *input, Tensor *output, + TransposeOp(const std::string &prec_type, Tensor *input, Tensor *output, int tp_type, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; class EmbeddingOp : public Op { public: - EmbeddingOp(OpPrecType prec_type, Tensor *input, Tensor *weight, + EmbeddingOp(const std::string &prec_type, Tensor *input, Tensor *weight, Tensor *output, const std::string &name); std::string function_name(const OpConfig &cfg) const; }; diff --git a/ark/ops/ops_div.cc b/ark/ops/ops_div.cc index 915219767..ee11d86d0 100644 --- a/ark/ops/ops_div.cc +++ b/ark/ops/ops_div.cc @@ -12,8 +12,8 @@ namespace ark { extern const OpConfigMap ArithmeticConfigMap; -DivOp::DivOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const string &name) +DivOp::DivOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const 
string &name) : Op{OP_DIV, prec_type, {input, other}, {output}, {}, name, &ArithmeticConfigMap, -1, true} {} @@ -50,14 +50,6 @@ Tensor *Model::div(Tensor *input, Tensor *other, Tensor *output, const string &name) { assert(input != nullptr); assert(other != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (input->type != other->type) { LOG(ERROR, "input data types mismatch: ", input->type, ", ", other->type); @@ -73,7 +65,7 @@ Tensor *Model::div(Tensor *input, Tensor *other, Tensor *output, } else if (output == input) { output = this->identity(output); } - DivOp op{pt, input, other, output, name}; + DivOp op{output->type.name(), input, other, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_embedding.cc b/ark/ops/ops_embedding.cc index cd833a580..addadca9e 100644 --- a/ark/ops/ops_embedding.cc +++ b/ark/ops/ops_embedding.cc @@ -10,8 +10,9 @@ namespace ark { extern const OpConfigMap EmbeddingConfigMap; -EmbeddingOp::EmbeddingOp(OpPrecType prec_type, Tensor *input, Tensor *weight, - Tensor *output, const std::string &name) +EmbeddingOp::EmbeddingOp(const std::string &prec_type, Tensor *input, + Tensor *weight, Tensor *output, + const std::string &name) : Op{OP_EMBEDDING, prec_type, {input, weight}, {output}, {}, name, &EmbeddingConfigMap, -1, true} {} @@ -54,14 +55,6 @@ Tensor *Model::embedding(Tensor *input, Tensor *weight, Tensor *output, if (weight->shape.ndims() != 2) { LOG(ERROR, "weight shape ndims != 2: ", weight->shape); } - OpPrecType pt = OP_PREC_NONE; - if (weight->type == FP16) { - pt = OP_PREC_FP16; - } else if (weight->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported weight data type: ", weight->type); - } auto emb_dim = weight->shape[-1]; std::vector output_dims; @@ -73,12 +66,12 @@ Tensor *Model::embedding(Tensor *input, Tensor *weight, Tensor *output, if (output == nullptr) { output = this->tensor(out_shape, weight->type); } - EmbeddingOp op{pt, input, weight, output, name}; + EmbeddingOp op{output->type.name(), input, weight, output, name}; return this->impl->add_op(op)[0]; } const OpConfigMap EmbeddingConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_ANY}, + {{OP_ARCH_CUDA_ANY, "any"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {1, 0, {{1, 1}, {1, -1}}, {{1, -1}}, true, false}, diff --git a/ark/ops/ops_embedding_test.cc b/ark/ops/ops_embedding_test.cc index 9a8305303..453ba0102 100644 --- a/ark/ops/ops_embedding_test.cc +++ b/ark/ops/ops_embedding_test.cc @@ -77,9 +77,14 @@ ark::unittest::State test_embedding_fp16() { return test_embedding(); } +ark::unittest::State test_embedding_bf16() { + return test_embedding(); +} + int main() { ark::init(); UNITTEST(test_embedding_fp32); UNITTEST(test_embedding_fp16); + UNITTEST(test_embedding_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_exp.cc b/ark/ops/ops_exp.cc index 85906ca82..96f5fee30 100644 --- a/ark/ops/ops_exp.cc +++ b/ark/ops/ops_exp.cc @@ -12,7 +12,7 @@ namespace ark { extern const OpConfigMap MathConfigMap; -ExpOp::ExpOp(OpPrecType prec_type, Tensor *input, Tensor *output, +ExpOp::ExpOp(const std::string &prec_type, Tensor *input, Tensor *output, const string &name) : Op{OP_EXP, prec_type, {input}, {output}, {}, name, &MathConfigMap, -1, true} {} @@ -44,14 +44,6 @@ std::string ExpOp::function_name(const OpConfig &cfg) const { Tensor *Model::exp(Tensor 
*input, Tensor *output, const string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -60,7 +52,7 @@ Tensor *Model::exp(Tensor *input, Tensor *output, const string &name) { } else if (output->shape != input->shape) { LOG(ERROR, "invalid output shape: ", output->shape); } - ExpOp op{pt, input, output, name}; + ExpOp op{output->type.name(), input, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_exp_test.cc b/ark/ops/ops_exp_test.cc index a1208fcf3..343f9ca3f 100644 --- a/ark/ops/ops_exp_test.cc +++ b/ark/ops/ops_exp_test.cc @@ -32,8 +32,34 @@ ark::unittest::State test_exp_fp32() { return ark::unittest::SUCCESS; } +ark::unittest::State test_exp_fp16() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); + ark::Tensor *out = m.exp(t); + + auto result = + ark::op_test("exp_fp16", m, {t}, {out}, baseline_exp); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_exp_bf16() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); + ark::Tensor *out = m.exp(t); + + auto result = + ark::op_test("exp_bf16", m, {t}, {out}, baseline_exp); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_exp_fp32); + UNITTEST(test_exp_fp16); + UNITTEST(test_exp_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_gelu.cc b/ark/ops/ops_gelu.cc index 7f3d392dd..62aabf7d0 100644 --- a/ark/ops/ops_gelu.cc +++ b/ark/ops/ops_gelu.cc @@ -10,7 +10,7 @@ namespace ark { extern const OpConfigMap ActivationConfigMap; -GeluOp::GeluOp(OpPrecType prec_type, Tensor *input, Tensor *output, +GeluOp::GeluOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name) : Op{OP_GELU, prec_type, {input}, {output}, {}, name, &ActivationConfigMap, -1, true} {} @@ -43,14 +43,6 @@ std::string GeluOp::function_name(const OpConfig &cfg) const { Tensor *Model::gelu(Tensor *input, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -59,7 +51,7 @@ Tensor *Model::gelu(Tensor *input, Tensor *output, const std::string &name) { } else if (output->shape != input->shape) { LOG(ERROR, "invalid output shape: ", output->shape); } - GeluOp op{pt, input, output, name}; + GeluOp op{output->type.name(), input, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_gelu_test.cc b/ark/ops/ops_gelu_test.cc index 2b27d81c0..8ee2248c3 100644 --- a/ark/ops/ops_gelu_test.cc +++ b/ark/ops/ops_gelu_test.cc @@ -37,8 +37,21 @@ ark::unittest::State test_gelu_fp32() { return ark::unittest::SUCCESS; } +ark::unittest::State test_gelu_bf16() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); + ark::Tensor *out = m.gelu(t); + + auto result = ark::op_test("gelu_bf16", m, {t}, {out}, + 
baseline_gelu); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-6f); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_gelu_fp32); + UNITTEST(test_gelu_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_im2col.cc b/ark/ops/ops_im2col.cc index eda7aae8d..ddc3edcf9 100644 --- a/ark/ops/ops_im2col.cc +++ b/ark/ops/ops_im2col.cc @@ -10,7 +10,7 @@ namespace ark { extern const OpConfigMap Im2colConfigMap; -Im2colOp::Im2colOp(OpPrecType prec_type, Tensor *input, Tensor *output, +Im2colOp::Im2colOp(const std::string &prec_type, Tensor *input, Tensor *output, int kernel_height, int kernel_width, int stride_height, int stride_width, int pad_height, int pad_width, int dilation_height, int dilation_width, @@ -104,14 +104,6 @@ Tensor *Model::im2col(Tensor *input, int kernel_height, int kernel_width, "invalid # of input dimensions. Expected 2, 3, or 4, but given ", input_ndims); } - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } DimType out_h = (h + 2 * pad_height - kernel_height) / stride_height + 1; DimType out_w = (w + 2 * pad_width - kernel_width) / stride_width + 1; assert((out_h > 0) && (out_w > 0)); @@ -128,14 +120,15 @@ Tensor *Model::im2col(Tensor *input, int kernel_height, int kernel_width, } else { assert(output->shape == out_shape); } - Im2colOp op{pt, input, output, kernel_height, - kernel_width, stride_height, stride_width, pad_height, - pad_width, dilation_height, dilation_width, name}; + Im2colOp op{output->type.name(), input, output, + kernel_height, kernel_width, stride_height, + stride_width, pad_height, pad_width, + dilation_height, dilation_width, name}; return this->impl->add_op(op)[0]; } const OpConfigMap Im2colConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_FP16}, + {{OP_ARCH_CUDA_ANY, "fp16"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{1, 1}}, {{128, 128}}, true, false}, diff --git a/ark/ops/ops_layernorm.cc b/ark/ops/ops_layernorm.cc index 9fb89c066..9fc738a4e 100644 --- a/ark/ops/ops_layernorm.cc +++ b/ark/ops/ops_layernorm.cc @@ -10,8 +10,8 @@ namespace ark { extern const OpConfigMap LayernormConfigMap; -LayernormOp::LayernormOp(OpPrecType prec_type, Tensor *input, Tensor *output, - const std::string &name) +LayernormOp::LayernormOp(const std::string &prec_type, Tensor *input, + Tensor *output, const std::string &name) : Op{OP_LAYERNORM, prec_type, {input}, {output}, {}, name, &LayernormConfigMap, -1, true} {} @@ -44,14 +44,6 @@ std::string LayernormOp::function_name(const OpConfig &cfg) const { Tensor *Model::layernorm(Tensor *input, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -60,12 +52,12 @@ Tensor *Model::layernorm(Tensor *input, Tensor *output, } else if (output == input) { output = this->identity(output); } - LayernormOp op{pt, input, output, name}; + LayernormOp op{output->type.name(), input, output, name}; return this->impl->add_op(op)[0]; } const OpConfigMap LayernormConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_ANY}, + {{OP_ARCH_CUDA_ANY, "any"}, { // NumWarps, 
SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {1, 128, {{32, -1}}, {{32, -1}}, true, false}, diff --git a/ark/ops/ops_matmul.cc b/ark/ops/ops_matmul.cc index 522d6827d..6a280e67d 100644 --- a/ark/ops/ops_matmul.cc +++ b/ark/ops/ops_matmul.cc @@ -11,7 +11,7 @@ namespace ark { extern const OpConfigMap MatmulConfigMap; -MatmulOp::MatmulOp(OpPrecType prec_type, Tensor *mat_a, Tensor *mat_b, +MatmulOp::MatmulOp(const std::string &prec_type, Tensor *mat_a, Tensor *mat_b, Tensor *mat_y, Dims nca, Dims ncb, Dims problem_size, Dims leading_dims, bool is_column_a, bool is_column_b, const string &name, int gran_lev) @@ -137,14 +137,6 @@ Tensor *Model::matmul(Tensor *mat_a, Tensor *mat_b, Tensor *mat_y, LOG(ERROR, "inner dimensions mismatch: ", k, " and ", k2); } - OpPrecType pt = OP_PREC_NONE; - if (mat_a->type == FP16) { - pt = OP_PREC_FP16; - } else if (mat_a->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", mat_a->type); - } if (mat_a->type != mat_b->type) { LOG(ERROR, "input data types mismatch: ", mat_a->type, ", ", mat_b->type); @@ -221,9 +213,9 @@ Tensor *Model::matmul(Tensor *mat_a, Tensor *mat_b, Tensor *mat_y, ldims_y[ldims_y.ndims() - 1], ldims_y[ldims_y.ndims() - 1], trans_b ? ldims_b[ndims_b - 2] : ldims_b[ndims_b - 1]}; Dims problem_size{m, n, k}; - MatmulOp op{pt, mat_a, mat_b, mat_y, - nca, ncb, problem_size, leading_dims, - trans_a, trans_b, name, gran_lev}; + MatmulOp op{ + mat_y->type.name(), mat_a, mat_b, mat_y, nca, ncb, + problem_size, leading_dims, trans_a, trans_b, name, gran_lev}; return this->impl->add_op(op)[0]; } else if (split_k > k) { LOG(ERROR, "Split-K given larger than the K dimension size."); @@ -280,19 +272,26 @@ Tensor *Model::matmul(Tensor *mat_a, Tensor *mat_b, Tensor *mat_y, } const OpConfigMap MatmulConfigMap = { - {{OP_ARCH_CUDA_70, OP_PREC_FP16}, + {{OP_ARCH_CUDA_70, "fp16"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 49152, {{128, 32}, {32, 256}}, {{128, 256}}, true, false}, }}, - {{OP_ARCH_CUDA_80, OP_PREC_FP16}, + {{OP_ARCH_CUDA_80, "bf16"}, + { + // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost + {8, 147456, {{128, 64}, {64, 256}}, {{128, 256}}, true, false}, + {4, 98304, {{128, 64}, {64, 128}}, {{128, 128}}, true, false}, + {4, 98304, {{64, 64}, {64, 64}}, {{64, 64}}, true, false}, + }}, + {{OP_ARCH_CUDA_80, "fp16"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 147456, {{128, 64}, {64, 256}}, {{128, 256}}, true, false}, {4, 98304, {{128, 64}, {64, 128}}, {{128, 128}}, true, false}, {4, 98304, {{64, 64}, {64, 64}}, {{64, 64}}, true, false}, }}, - {{OP_ARCH_CUDA_80, OP_PREC_FP32}, + {{OP_ARCH_CUDA_80, "fp32"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 147456, {{128, 32}, {32, 256}}, {{128, 256}}, true, false}, diff --git a/ark/ops/ops_matmul_test.cu b/ark/ops/ops_matmul_test.cu index 3796c1de7..1c936e2c8 100644 --- a/ark/ops/ops_matmul_test.cu +++ b/ark/ops/ops_matmul_test.cu @@ -21,163 +21,32 @@ cublasHandle_t get_cublas_handle() { return globalCublasHandle; } -void cublas_matmul_float_nn(int m, int n, int k, const float *a, int lda, - const float *b, int ldb, float *c, int ldc, - int batch_size = 1) { +template +void cublas_matmul(int m, int n, int k, void *alpha, const void *a, int lda, + const void *b, int ldb, void *beta, void *c, int ldc, + int batch_size = 1) { auto cublasH = get_cublas_handle(); - float alpha = 1; - float beta = 0; cublasStatus_t status; + cublasOperation_t optypeA 
= (cublasOperation_t)CubOpTypeA; + cublasOperation_t optypeB = (cublasOperation_t)CubOpTypeB; + cudaDataType dtype = (cudaDataType)CudaDataType; + cublasComputeType_t ctype = (cublasComputeType_t)CubComputeType; if (batch_size == 1) { - status = cublasSgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasSgemmStridedBatched( - cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasSgemm"); - } -} - -void cublas_matmul_float_nt(int m, int n, int k, const float *a, int lda, - const float *b, int ldb, float *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - float alpha = 1; - float beta = 0; - cublasStatus_t status; - if (batch_size == 1) { - status = cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasSgemmStridedBatched( - cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasSgemm"); - } -} - -void cublas_matmul_float_tn(int m, int n, int k, const float *a, int lda, - const float *b, int ldb, float *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - float alpha = 1; - float beta = 0; - cublasStatus_t status; - if (batch_size == 1) { - status = cublasSgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasSgemmStridedBatched( - cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasSgemm"); - } -} - -void cublas_matmul_float_tt(int m, int n, int k, const float *a, int lda, - const float *b, int ldb, float *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - float alpha = 1; - float beta = 0; - cublasStatus_t status; - if (batch_size == 1) { - status = cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasSgemmStridedBatched( - cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasSgemm"); - } -} - -void cublas_matmul_half_nn(int m, int n, int k, const half *a, int lda, - const half *b, int ldb, half *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - half alpha = half(ark::half_t(1)); - half beta = half(ark::half_t(0)); - cublasStatus_t status; - if (batch_size == 1) { - status = cublasHgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasHgemmStridedBatched( - cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasHgemm"); - } -} - -void cublas_matmul_half_nt(int m, int n, int k, const half *a, int lda, - const half *b, int ldb, half *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - half alpha = half(ark::half_t(1)); - half beta = half(ark::half_t(0)); 
- cublasStatus_t status; - if (batch_size == 1) { - status = cublasHgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasHgemmStridedBatched( - cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasHgemm"); - } -} - -void cublas_matmul_half_tn(int m, int n, int k, const half *a, int lda, - const half *b, int ldb, half *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - half alpha = half(ark::half_t(1)); - half beta = half(ark::half_t(0)); - cublasStatus_t status; - if (batch_size == 1) { - status = cublasHgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); - } else { - status = cublasHgemmStridedBatched( - cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasHgemm"); - } -} - -void cublas_matmul_half_tt(int m, int n, int k, const half *a, int lda, - const half *b, int ldb, half *c, int ldc, - int batch_size = 1) { - auto cublasH = get_cublas_handle(); - half alpha = half(ark::half_t(1)); - half beta = half(ark::half_t(0)); - cublasStatus_t status; - if (batch_size == 1) { - status = cublasHgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, - b, ldb, a, lda, &beta, c, ldc); + status = cublasGemmEx(cublasH, optypeB, optypeA, n, m, k, alpha, b, + dtype, ldb, a, dtype, lda, beta, c, dtype, ldc, + ctype, CUBLAS_GEMM_DEFAULT); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to call cublasGemmEx"); + } } else { - status = cublasHgemmStridedBatched( - cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k, - a, lda, k * m, &beta, c, ldc, n * m, batch_size); - } - if (status != CUBLAS_STATUS_SUCCESS) { - throw std::runtime_error("Failed to call cublasHgemm"); + status = cublasGemmStridedBatchedEx( + cublasH, optypeB, optypeA, n, m, k, alpha, b, dtype, ldb, n * k, a, + dtype, lda, k * m, beta, c, dtype, ldc, n * m, batch_size, ctype, + CUBLAS_GEMM_DEFAULT); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error( + "Failed to call cublasGemmStridedBatchedEx"); + } } } @@ -208,11 +77,24 @@ void baseline_matmul_nn(std::vector &outputs, // matmul using cublas if constexpr (std::is_same_v) { - cublas_matmul_float_nn(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, + devB, ldb, &beta, devC, ldc, + batch_size); } else if constexpr (std::is_same_v) { - cublas_matmul_half_nn(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + half alpha = 1; + half beta = 0; + cublas_matmul( + m, n, k, &alpha, devA, lda, devB, ldb, &beta, devC, ldc, + batch_size); + } else if constexpr (std::is_same_v) { + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, devB, ldb, + &beta, devC, ldc, batch_size); } else { throw std::runtime_error("Unsupported data type"); } @@ -249,11 +131,24 @@ void baseline_matmul_nt(std::vector &outputs, // matmul using cublas if constexpr (std::is_same_v) { - cublas_matmul_float_nt(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, + devB, ldb, &beta, devC, ldc, + batch_size); } else if constexpr 
(std::is_same_v) { - cublas_matmul_half_nt(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + half alpha = 1; + half beta = 0; + cublas_matmul( + m, n, k, &alpha, devA, lda, devB, ldb, &beta, devC, ldc, + batch_size); + } else if constexpr (std::is_same_v) { + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, devB, ldb, + &beta, devC, ldc, batch_size); } else { throw std::runtime_error("Unsupported data type"); } @@ -290,11 +185,24 @@ void baseline_matmul_tn(std::vector &outputs, // matmul using cublas if constexpr (std::is_same_v) { - cublas_matmul_float_tn(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, + devB, ldb, &beta, devC, ldc, + batch_size); } else if constexpr (std::is_same_v) { - cublas_matmul_half_tn(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + half alpha = 1; + half beta = 0; + cublas_matmul( + m, n, k, &alpha, devA, lda, devB, ldb, &beta, devC, ldc, + batch_size); + } else if constexpr (std::is_same_v) { + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, devB, ldb, + &beta, devC, ldc, batch_size); } else { throw std::runtime_error("Unsupported data type"); } @@ -331,11 +239,24 @@ void baseline_matmul_tt(std::vector &outputs, // matmul using cublas if constexpr (std::is_same_v) { - cublas_matmul_float_tt(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, + devB, ldb, &beta, devC, ldc, + batch_size); } else if constexpr (std::is_same_v) { - cublas_matmul_half_tt(m, n, k, devA, lda, devB, ldb, devC, ldc, - batch_size); + half alpha = 1; + half beta = 0; + cublas_matmul( + m, n, k, &alpha, devA, lda, devB, ldb, &beta, devC, ldc, + batch_size); + } else if constexpr (std::is_same_v) { + float alpha = 1; + float beta = 0; + cublas_matmul(m, n, k, &alpha, devA, lda, devB, ldb, + &beta, devC, ldc, batch_size); } else { throw std::runtime_error("Unsupported data type"); } @@ -345,14 +266,14 @@ void baseline_matmul_tt(std::vector &outputs, ark::from_gpu(memC, outputs[0]); } -ark::unittest::State test_matmul_gran0() { +ark::unittest::State test_matmul_fp16_gran0() { { ark::Model m; ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); - auto result = ark::op_test("matmul_gran0", m, {a, b}, {c}, + auto result = ark::op_test("matmul_fp16_gran0", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); @@ -363,7 +284,7 @@ ark::unittest::State test_matmul_gran0() { ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); - auto result = ark::op_test("matmul_gran0", m, {a, b}, {c}, + auto result = ark::op_test("matmul_fp16_gran0", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); @@ -371,14 +292,14 @@ ark::unittest::State test_matmul_gran0() { return ark::unittest::SUCCESS; } -ark::unittest::State test_matmul_gran1() { +ark::unittest::State test_matmul_fp16_gran1() { { ark::Model m; ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 1); - auto result = ark::op_test("matmul_gran1", m, {a, b}, {c}, + auto result = 
ark::op_test("matmul_fp16_gran1", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); @@ -389,7 +310,7 @@ ark::unittest::State test_matmul_gran1() { ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 1); - auto result = ark::op_test("matmul_gran1", m, {a, b}, {c}, + auto result = ark::op_test("matmul_fp16_gran1", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); @@ -397,14 +318,14 @@ ark::unittest::State test_matmul_gran1() { return ark::unittest::SUCCESS; } -ark::unittest::State test_matmul_gran2() { +ark::unittest::State test_matmul_fp16_gran2() { { ark::Model m; ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16); ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 2); - auto result = ark::op_test("matmul_gran2", m, {a, b}, {c}, + auto result = ark::op_test("matmul_fp16_gran2", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); @@ -415,7 +336,7 @@ ark::unittest::State test_matmul_gran2() { ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::FP16); ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 2); - auto result = ark::op_test("matmul_gran2", m, {a, b}, {c}, + auto result = ark::op_test("matmul_fp16_gran2", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); @@ -433,6 +354,7 @@ ark::unittest::State test_matmul_split() { auto result = ark::op_test("matmul_split", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); } return ark::unittest::SUCCESS; } @@ -447,6 +369,7 @@ ark::unittest::State test_matmul_fp32() { auto result = ark::op_test("matmul_fp32", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); } { ark::Model m; @@ -457,11 +380,38 @@ ark::unittest::State test_matmul_fp32() { auto result = ark::op_test("matmul_fp32", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_matmul_bf16() { + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::BF16); + ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::BF16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_bf16", m, {a, b}, {c}, + baseline_matmul_nn); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + } + { + ark::Model m; + ark::Tensor *a = m.tensor(ark::Dims(4096, 8192), ark::BF16); + ark::Tensor *b = m.tensor(ark::Dims(8192, 16384), ark::BF16); + ark::Tensor *c = m.matmul(a, b, nullptr, 1, false, false, "matmul", 0); + + auto result = ark::op_test("matmul_bf16", m, {a, b}, {c}, + baseline_matmul_nn); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); } return ark::unittest::SUCCESS; } -ark::unittest::State test_matmul_nt() { +ark::unittest::State test_matmul_fp16_nt() { { ark::Model m; ark::Tensor *a = m.tensor(ark::Dims(128, 64), ark::FP16); @@ -487,7 +437,7 @@ ark::unittest::State test_matmul_nt() { return ark::unittest::SUCCESS; } -ark::unittest::State test_matmul_tn() { +ark::unittest::State test_matmul_fp16_tn() { { ark::Model m; ark::Tensor *a = m.tensor(ark::Dims(64, 128), ark::FP16); @@ -516,7 +466,7 @@ ark::unittest::State 
test_matmul_tn() { return ark::unittest::SUCCESS; } -ark::unittest::State test_matmul_tt() { +ark::unittest::State test_matmul_fp16_tt() { { ark::Model m; ark::Tensor *a = m.tensor(ark::Dims(64, 128), ark::FP16); @@ -563,19 +513,21 @@ ark::unittest::State test_matmul_batched_padded() { auto result = ark::op_test("matmul_batched_padded", m, {a, b}, {c}, baseline_matmul_nn); UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-3f); return ark::unittest::SUCCESS; } int main() { ark::init(); - UNITTEST(test_matmul_gran0); - UNITTEST(test_matmul_gran1); - UNITTEST(test_matmul_gran2); + UNITTEST(test_matmul_fp16_gran0); + UNITTEST(test_matmul_fp16_gran1); + UNITTEST(test_matmul_fp16_gran2); UNITTEST(test_matmul_split); UNITTEST(test_matmul_fp32); - UNITTEST(test_matmul_nt); - UNITTEST(test_matmul_tn); - UNITTEST(test_matmul_tt); + UNITTEST(test_matmul_bf16); + UNITTEST(test_matmul_fp16_nt); + UNITTEST(test_matmul_fp16_tn); + UNITTEST(test_matmul_fp16_tt); UNITTEST(test_matmul_batched); UNITTEST(test_matmul_batched_padded); diff --git a/ark/ops/ops_max_pool.cc b/ark/ops/ops_max_pool.cc index 83d60ac0f..54183c4d6 100644 --- a/ark/ops/ops_max_pool.cc +++ b/ark/ops/ops_max_pool.cc @@ -8,8 +8,8 @@ namespace ark { -MaxPoolOp::MaxPoolOp(OpPrecType prec_type, Tensor *input, Tensor *output, - DimType kernel_size, DimType stride, +MaxPoolOp::MaxPoolOp(const std::string &prec_type, Tensor *input, + Tensor *output, DimType kernel_size, DimType stride, const std::string &name) : Op{OP_MAX_POOL, prec_type, {input}, {output}, {{kernel_size, stride}}, name, nullptr, -1} {} @@ -18,14 +18,6 @@ MaxPoolOp::MaxPoolOp(OpPrecType prec_type, Tensor *input, Tensor *output, Tensor *Model::max_pool(Tensor *input, DimType kernel_size, DimType stride, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -35,7 +27,7 @@ Tensor *Model::max_pool(Tensor *input, DimType kernel_size, DimType stride, if (output == nullptr) { output = this->tensor(os, input->type); } - MaxPoolOp op{pt, input, output, kernel_size, stride, name}; + MaxPoolOp op{output->type.name(), input, output, kernel_size, stride, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_mul.cc b/ark/ops/ops_mul.cc index b22fb77e2..211aa6a32 100644 --- a/ark/ops/ops_mul.cc +++ b/ark/ops/ops_mul.cc @@ -10,8 +10,8 @@ namespace ark { extern const OpConfigMap ArithmeticConfigMap; -MulOp::MulOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output, - const std::string &name) +MulOp::MulOp(const std::string &prec_type, Tensor *input, Tensor *other, + Tensor *output, const std::string &name) : Op{OP_MUL, prec_type, {input, other}, {output}, {}, name, &ArithmeticConfigMap, -1, true} {} @@ -48,14 +48,6 @@ Tensor *Model::mul(Tensor *input, Tensor *other, Tensor *output, const std::string &name) { assert(input != nullptr); assert(other != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (input->type != other->type) { LOG(ERROR, "input data types mismatch: ", input->type, ", ", other->type); @@ -71,7 +63,7 @@ Tensor 
*Model::mul(Tensor *input, Tensor *other, Tensor *output, } else if (output == input) { output = this->identity(output); } - MulOp op{pt, input, other, output, name}; + MulOp op{output->type.name(), input, other, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_reduce.cc b/ark/ops/ops_reduce.cc index 5d812bde8..6c1856928 100644 --- a/ark/ops/ops_reduce.cc +++ b/ark/ops/ops_reduce.cc @@ -8,7 +8,7 @@ namespace ark { -ReduceOp::ReduceOp(const OpType &type, const OpPrecType &prec_type, +ReduceOp::ReduceOp(const OpType &type, const std::string &prec_type, const std::vector &inputs, const std::vector &outputs, const OpArgs &args, const std::string &name, const OpConfigMap *cfg_map, @@ -63,8 +63,8 @@ std::string ReduceOp::function_name(const OpConfig &cfg, extern const OpConfigMap ReduceWConfigMap; extern const OpConfigMap ReduceEConfigMap; -ReduceWSumOp::ReduceWSumOp(OpPrecType prec_type, Tensor *input, Tensor *output, - int axis, const std::string &name) +ReduceWSumOp::ReduceWSumOp(const std::string &prec_type, Tensor *input, + Tensor *output, int axis, const std::string &name) : ReduceOp{OP_REDUCE_W_SUM, prec_type, {input}, {output}, {{axis}}, name, &ReduceWConfigMap, -1} {} @@ -72,8 +72,8 @@ std::string ReduceWSumOp::function_name(const OpConfig &cfg) const { return ReduceOp::function_name(cfg, "w_sum"); } -ReduceESumOp::ReduceESumOp(OpPrecType prec_type, Tensor *input, Tensor *output, - int axis, const std::string &name) +ReduceESumOp::ReduceESumOp(const std::string &prec_type, Tensor *input, + Tensor *output, int axis, const std::string &name) : ReduceOp{OP_REDUCE_E_SUM, prec_type, {input}, {output}, {{axis}}, name, &ReduceEConfigMap, -1} {} @@ -81,8 +81,8 @@ std::string ReduceESumOp::function_name(const OpConfig &cfg) const { return ReduceOp::function_name(cfg, "e_sum"); } -ReduceWMaxOp::ReduceWMaxOp(OpPrecType prec_type, Tensor *input, Tensor *output, - int axis, const std::string &name) +ReduceWMaxOp::ReduceWMaxOp(const std::string &prec_type, Tensor *input, + Tensor *output, int axis, const std::string &name) : ReduceOp{OP_REDUCE_W_MAX, prec_type, {input}, {output}, {{axis}}, name, &ReduceWConfigMap, -1} {} @@ -90,8 +90,8 @@ std::string ReduceWMaxOp::function_name(const OpConfig &cfg) const { return ReduceOp::function_name(cfg, "w_max"); } -ReduceEMaxOp::ReduceEMaxOp(OpPrecType prec_type, Tensor *input, Tensor *output, - int axis, const std::string &name) +ReduceEMaxOp::ReduceEMaxOp(const std::string &prec_type, Tensor *input, + Tensor *output, int axis, const std::string &name) : ReduceOp{OP_REDUCE_E_MAX, prec_type, {input}, {output}, {{axis}}, name, &ReduceEConfigMap, -1} {} @@ -99,7 +99,7 @@ std::string ReduceEMaxOp::function_name(const OpConfig &cfg) const { return ReduceOp::function_name(cfg, "e_max"); } -ReduceWMeanOp::ReduceWMeanOp(OpPrecType prec_type, Tensor *input, +ReduceWMeanOp::ReduceWMeanOp(const std::string &prec_type, Tensor *input, Tensor *output, int axis, const std::string &name) : ReduceOp{OP_REDUCE_W_MEAN, prec_type, {input}, {output}, {{axis}}, name, &ReduceWConfigMap, -1} {} @@ -108,7 +108,7 @@ std::string ReduceWMeanOp::function_name(const OpConfig &cfg) const { return ReduceOp::function_name(cfg, "w_mean"); } -ReduceEMeanOp::ReduceEMeanOp(OpPrecType prec_type, Tensor *input, +ReduceEMeanOp::ReduceEMeanOp(const std::string &prec_type, Tensor *input, Tensor *output, int axis, const std::string &name) : ReduceOp{OP_REDUCE_E_MEAN, prec_type, {input}, {output}, {{axis}}, name, &ReduceEConfigMap, -1} {} @@ -121,14 +121,6 @@ template Tensor 
*Model::reduce(Tensor *input, int axis, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -148,7 +140,7 @@ Tensor *Model::reduce(Tensor *input, int axis, Tensor *output, "reduce_sum op"); } } - ReduceOpType op{pt, input, output, axis, name}; + ReduceOpType op{output->type.name(), input, output, axis, name}; return this->impl->add_op(op)[0]; } @@ -180,7 +172,7 @@ Tensor *Model::reduce_max(Tensor *input, int axis, Tensor *output, } const OpConfigMap ReduceEConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_ANY}, + {{OP_ARCH_CUDA_ANY, "any"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, true, false}, @@ -199,7 +191,7 @@ const OpConfigMap ReduceEConfigMap = { }; const OpConfigMap ReduceWConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_ANY}, + {{OP_ARCH_CUDA_ANY, "any"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {1, 128, {{32, 1}}, {{32, 1}}, true, false}, diff --git a/ark/ops/ops_reduce_test.cc b/ark/ops/ops_reduce_test.cc index 583e7dd39..43f8177ea 100644 --- a/ark/ops/ops_reduce_test.cc +++ b/ark/ops/ops_reduce_test.cc @@ -24,12 +24,12 @@ void baseline_reduce_sum_axis0(std::vector &outputs, for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType h = 0; h < ish[2]; ++h) { for (ark::DimType w = 0; w < ish[3]; ++w) { - T sum = 0; + float sum = 0; for (ark::DimType n = 0; n < ish[0]; ++n) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } - out[c * osh[2] * osh[3] + h * osh[3] + w] = sum; + out[c * osh[2] * osh[3] + h * osh[3] + w] = T(sum); } } } @@ -51,12 +51,12 @@ void baseline_reduce_sum_axis1(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType h = 0; h < ish[2]; ++h) { for (ark::DimType w = 0; w < ish[3]; ++w) { - T sum = 0; + float sum = 0; for (ark::DimType c = 0; c < ish[1]; ++c) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } - out[n * osh[1] * osh[2] * osh[3] + h * osh[3] + w] = sum; + out[n * osh[1] * osh[2] * osh[3] + h * osh[3] + w] = T(sum); } } } @@ -78,13 +78,13 @@ void baseline_reduce_sum_axis2(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType w = 0; w < ish[3]; ++w) { - T sum = 0; + float sum = 0; for (ark::DimType h = 0; h < ish[2]; ++h) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + w] = - sum; + T(sum); } } } @@ -106,13 +106,13 @@ void baseline_reduce_sum_axis3(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType h = 0; h < ish[2]; ++h) { - T sum = 0; + float sum = 0; for (ark::DimType w = 0; w < ish[3]; ++w) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + 
sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + - h * osh[3]] = sum; + h * osh[3]] = T(sum); } } } @@ -189,7 +189,7 @@ ark::unittest::State test_reduce_fp16() { auto result = ark::op_test("reduce_fp16_axis0", m, {t}, {out}, baseline_reduce_sum_axis0); UNITTEST_LOG(result); - UNITTEST_EQ(result.max_diff[0], 0.0f); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); } { ark::Model m; @@ -203,6 +203,29 @@ ark::unittest::State test_reduce_fp16() { return ark::unittest::SUCCESS; } +ark::unittest::State test_reduce_bf16() { + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/0); + + auto result = ark::op_test("reduce_bf16_axis0", m, {t}, {out}, + baseline_reduce_sum_axis0); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::BF16); + ark::Tensor *out = m.reduce_sum(t, /*axis=*/3); + + auto result = ark::op_test("reduce_bf16_axis3", m, {t}, {out}, + baseline_reduce_sum_axis3); + UNITTEST_LOG(result); + } + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_reduce_axis0); @@ -211,5 +234,6 @@ int main() { UNITTEST(test_reduce_axis3); UNITTEST(test_reduce_axis3_padded); UNITTEST(test_reduce_fp16); + UNITTEST(test_reduce_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_relu.cc b/ark/ops/ops_relu.cc index 575a628ab..ab71471bc 100644 --- a/ark/ops/ops_relu.cc +++ b/ark/ops/ops_relu.cc @@ -10,7 +10,7 @@ namespace ark { extern const OpConfigMap ActivationConfigMap; -ReluOp::ReluOp(OpPrecType prec_type, Tensor *input, Tensor *output, +ReluOp::ReluOp(const std::string &prec_type, Tensor *input, Tensor *output, const std::string &name) : Op{OP_RELU, prec_type, {input}, {output}, {}, name, &ActivationConfigMap, -1, true} {} @@ -43,14 +43,6 @@ std::string ReluOp::function_name(const OpConfig &cfg) const { Tensor *Model::relu(Tensor *input, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -59,12 +51,12 @@ Tensor *Model::relu(Tensor *input, Tensor *output, const std::string &name) { } else if (output->shape != input->shape) { LOG(ERROR, "invalid output shape: ", output->shape); } - ReluOp op{pt, input, output, name}; + ReluOp op{output->type.name(), input, output, name}; return this->impl->add_op(op)[0]; } const OpConfigMap ActivationConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_FP32}, + {{OP_ARCH_CUDA_ANY, "fp32"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}}, {{128, 256}}, false, false}, @@ -81,7 +73,22 @@ const OpConfigMap ActivationConfigMap = { {1, 0, {{1, 64}}, {{1, 64}}, false, false}, {1, 0, {{1, 32}}, {{1, 32}}, false, false}, }}, - {{OP_ARCH_CUDA_ANY, OP_PREC_FP16}, + {{OP_ARCH_CUDA_ANY, "fp16"}, + { + // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost + {8, 0, {{128, 256}}, {{128, 256}}, false, false}, + {8, 0, {{256, 128}}, {{256, 128}}, false, false}, + {8, 0, {{128, 128}}, {{128, 128}}, false, false}, + {4, 0, {{64, 64}}, {{64, 64}}, false, false}, + {2, 0, {{32, 64}}, {{32, 64}}, 
false, false}, + {1, 0, {{16, 64}}, {{16, 64}}, false, false}, + {1, 0, {{8, 64}}, {{8, 64}}, false, false}, + {1, 0, {{2, 128}}, {{2, 128}}, false, false}, + {1, 0, {{4, 64}}, {{4, 64}}, false, false}, + {1, 0, {{2, 64}}, {{2, 64}}, false, false}, + {1, 0, {{1, 64}}, {{1, 64}}, false, false}, + }}, + {{OP_ARCH_CUDA_ANY, "bf16"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}}, {{128, 256}}, false, false}, diff --git a/ark/ops/ops_relu_test.cc b/ark/ops/ops_relu_test.cc index 362b2e5ef..1b9b708a6 100644 --- a/ark/ops/ops_relu_test.cc +++ b/ark/ops/ops_relu_test.cc @@ -35,8 +35,21 @@ ark::unittest::State test_relu_fp32() { return ark::unittest::SUCCESS; } +ark::unittest::State test_relu_bf16() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); + ark::Tensor *out = m.relu(t); + + auto result = ark::op_test("relu_bf16", m, {t}, {out}, + baseline_relu); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_relu_fp32); + UNITTEST(test_relu_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_reshape.cc b/ark/ops/ops_reshape.cc index 96dce95e6..0aa16a946 100644 --- a/ark/ops/ops_reshape.cc +++ b/ark/ops/ops_reshape.cc @@ -8,8 +8,8 @@ namespace ark { -ReshapeOp::ReshapeOp(OpPrecType prec_type, Tensor *input, Tensor *output, - const std::string &name) +ReshapeOp::ReshapeOp(const std::string &prec_type, Tensor *input, + Tensor *output, const std::string &name) : Op{OP_RESHAPE, prec_type, {input}, {output}, {}, name, nullptr, -1, true} {} @@ -74,7 +74,7 @@ static Tensor *_reshape(Model *model, Tensor *input, const Dims &shape, Tensor *Model::reshape(Tensor *input, const Dims &shape, bool allowzero, Tensor *output, const std::string &name) { output = _reshape(this, input, shape, allowzero, output, name); - ReshapeOp op{OP_PREC_NONE, input, output, name}; + ReshapeOp op{"none", input, output, name}; return this->impl->add_op(op)[0]; } @@ -147,7 +147,7 @@ Tensor *Model::reshape(Tensor *input, const std::vector &shape, } output = _reshape(this, input, Dims{inferred_shape}, allowzero, output, name); - ReshapeOp op{OP_PREC_NONE, input, output, name}; + ReshapeOp op{"none", input, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_rmsnorm.cc b/ark/ops/ops_rmsnorm.cc index e1208d09f..1ef702b22 100644 --- a/ark/ops/ops_rmsnorm.cc +++ b/ark/ops/ops_rmsnorm.cc @@ -10,8 +10,8 @@ namespace ark { extern const OpConfigMap LayernormConfigMap; -RMSnormOp::RMSnormOp(OpPrecType prec_type, Tensor *input, Tensor *output, - const std::string &name) +RMSnormOp::RMSnormOp(const std::string &prec_type, Tensor *input, + Tensor *output, const std::string &name) : Op{OP_RMSNORM, prec_type, {input}, {output}, {}, name, &LayernormConfigMap, -1, true} {} @@ -43,14 +43,6 @@ std::string RMSnormOp::function_name(const OpConfig &cfg) const { Tensor *Model::rmsnorm(Tensor *input, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -59,7 +51,7 @@ Tensor *Model::rmsnorm(Tensor *input, Tensor *output, const std::string &name) { } else if (output == input) { output = this->identity(output); } - 
RMSnormOp op{pt, input, output, name}; + RMSnormOp op{output->type.name(), input, output, name}; return this->impl->add_op(op)[0]; }; diff --git a/ark/ops/ops_rmsnorm_test.cc b/ark/ops/ops_rmsnorm_test.cc index 9bd350586..0721a24ac 100644 --- a/ark/ops/ops_rmsnorm_test.cc +++ b/ark/ops/ops_rmsnorm_test.cc @@ -65,9 +65,20 @@ ark::unittest::State test_rmsnorm_fp16() { return ark::unittest::SUCCESS; } +ark::unittest::State test_rmsnorm_bf16() { + ark::Model model; + ark::Tensor *input = model.tensor(ark::Dims(1, 32, 32, 8192), ark::BF16); + ark::Tensor *output = model.rmsnorm(input); + auto result = ark::op_test("rmsnorm_bf16", model, {input}, {output}, + baseline_rmsnorm); + UNITTEST_LOG(result); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_rmsnorm_fp32); UNITTEST(test_rmsnorm_fp16); + UNITTEST(test_rmsnorm_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_rope.cc b/ark/ops/ops_rope.cc index 26bdebeea..d693087ab 100644 --- a/ark/ops/ops_rope.cc +++ b/ark/ops/ops_rope.cc @@ -10,7 +10,7 @@ namespace ark { extern const OpConfigMap ArithmeticConfigMap; -RopeOp::RopeOp(OpPrecType prec_type, Tensor *input, Tensor *other, +RopeOp::RopeOp(const std::string &prec_type, Tensor *input, Tensor *other, Tensor *output, const std::string &name) : Op{OP_ROPE, prec_type, {input, other}, {output}, {}, name, &ArithmeticConfigMap, -1, @@ -49,14 +49,6 @@ Tensor *Model::rope(Tensor *input, Tensor *other, Tensor *output, const std::string &name) { assert(input != nullptr); assert(other != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (input->type != other->type) { LOG(ERROR, "input data types mismatch: ", input->type, ", ", other->type); @@ -72,7 +64,7 @@ Tensor *Model::rope(Tensor *input, Tensor *other, Tensor *output, } else if (output == input) { output = this->identity(output); } - RopeOp op{pt, input, other, output, name}; + RopeOp op{output->type.name(), input, other, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_rope_test.cc b/ark/ops/ops_rope_test.cc index 2b1788ff1..9d3f7e1a2 100644 --- a/ark/ops/ops_rope_test.cc +++ b/ark/ops/ops_rope_test.cc @@ -62,9 +62,22 @@ ark::unittest::State test_rope_fp16() { return ark::unittest::SUCCESS; } +ark::unittest::State test_rope_bf16() { + ark::Model model; + ark::Tensor *input = model.tensor(ark::Dims(1, 32, 32, 256), ark::BF16); + ark::Tensor *other = model.tensor(ark::Dims(1, 32, 32, 256), ark::BF16); + ark::Tensor *out = model.rope(input, other); + auto result = ark::op_test("rope", model, {input, other}, {out}, + baseline_rope); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-3f); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_rope_fp32); UNITTEST(test_rope_fp16); + UNITTEST(test_rope_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_scale.cc b/ark/ops/ops_scale.cc index 574b341fd..793ddcf3b 100644 --- a/ark/ops/ops_scale.cc +++ b/ark/ops/ops_scale.cc @@ -10,8 +10,8 @@ namespace ark { extern const OpConfigMap ScaleConfigMap; -ScaleOp::ScaleOp(OpPrecType prec_type, Tensor *input, Tensor *output, float val, - const std::string &name) +ScaleOp::ScaleOp(const std::string &prec_type, Tensor *input, Tensor *output, + float val, const std::string &name) : Op{OP_SCALE, prec_type, {input}, {output}, {{val}}, name, &ScaleConfigMap, -1, true} {} @@ -58,14 
+58,6 @@ OpArgs ScaleOp::function_call_args(const OpConfig &) const { Tensor *Model::scale(Tensor *input, float val, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -74,12 +66,12 @@ Tensor *Model::scale(Tensor *input, float val, Tensor *output, } else if (output->shape != input->shape) { LOG(ERROR, "invalid output shape: ", output->shape); } - ScaleOp op{pt, input, output, val, name}; + ScaleOp op{output->type.name(), input, output, val, name}; return this->impl->add_op(op)[0]; } const OpConfigMap ScaleConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_FP32}, + {{OP_ARCH_CUDA_ANY, "fp32"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}}, {{128, 256}}, false, false}, @@ -96,7 +88,22 @@ const OpConfigMap ScaleConfigMap = { {1, 0, {{1, 64}}, {{1, 64}}, false, false}, {1, 0, {{1, 32}}, {{1, 32}}, false, false}, }}, - {{OP_ARCH_CUDA_ANY, OP_PREC_FP16}, + {{OP_ARCH_CUDA_ANY, "fp16"}, + { + // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost + {8, 0, {{128, 256}}, {{128, 256}}, false, false}, + {8, 0, {{256, 128}}, {{256, 128}}, false, false}, + {8, 0, {{128, 128}}, {{128, 128}}, false, false}, + {4, 0, {{64, 64}}, {{64, 64}}, false, false}, + {2, 0, {{32, 64}}, {{32, 64}}, false, false}, + {1, 0, {{16, 64}}, {{16, 64}}, false, false}, + {1, 0, {{8, 64}}, {{8, 64}}, false, false}, + {1, 0, {{2, 128}}, {{2, 128}}, false, false}, + {1, 0, {{4, 64}}, {{4, 64}}, false, false}, + {1, 0, {{2, 64}}, {{2, 64}}, false, false}, + {1, 0, {{1, 64}}, {{1, 64}}, false, false}, + }}, + {{OP_ARCH_CUDA_ANY, "bf16"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {8, 0, {{128, 256}}, {{128, 256}}, false, false}, diff --git a/ark/ops/ops_scale_test.cc b/ark/ops/ops_scale_test.cc index 50b785016..0b9f0750d 100644 --- a/ark/ops/ops_scale_test.cc +++ b/ark/ops/ops_scale_test.cc @@ -68,9 +68,34 @@ ark::unittest::State test_scale_fp16() { return ark::unittest::SUCCESS; } +ark::unittest::State test_scale_bf16() { + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::BF16); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); + + auto result = ark::op_test("scale_bf16_small", m, {t}, {out}, + baseline_scale); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); + + auto result = ark::op_test("scale_bf16", m, {t}, {out}, + baseline_scale); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + } + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_scale_fp32); UNITTEST(test_scale_fp16); + UNITTEST(test_scale_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_sendrecv.cc b/ark/ops/ops_sendrecv.cc index 3af3f3ece..913b77097 100644 --- a/ark/ops/ops_sendrecv.cc +++ b/ark/ops/ops_sendrecv.cc @@ -10,7 +10,7 @@ namespace ark { extern const OpConfigMap CommConfigMap; -SendOp::SendOp(OpPrecType prec_type, Tensor *input, int sid, int rank, +SendOp::SendOp(const std::string &prec_type, Tensor *input, int sid, int rank, int dst_rank, size_t bytes, const std::string &name) : Op{OP_SEND, prec_type, @@ -45,7 
+45,7 @@ std::string SendOp::function_name(const OpConfig &) const { OpArgs SendOp::function_call_args(const OpConfig &) const { return {}; } -SendDoneOp::SendDoneOp(OpPrecType prec_type, Tensor *input, int sid, int rank, +SendDoneOp::SendDoneOp(const std::string &prec_type, Tensor *input, int sid, int rank, int dst_rank, const std::string &name) : Op{OP_SEND_DONE, prec_type, @@ -74,7 +74,7 @@ std::string SendDoneOp::function_name(const OpConfig &) const { OpArgs SendDoneOp::function_call_args(const OpConfig &) const { return {}; } -RecvOp::RecvOp(OpPrecType prec_type, Tensor *output, int sid, int rank, +RecvOp::RecvOp(const std::string &prec_type, Tensor *output, int sid, int rank, int src_rank, size_t bytes, const std::string &name) : Op{OP_RECV, prec_type, @@ -117,14 +117,14 @@ Tensor *Model::send(Tensor *input, int id, int dst_rank, size_t bytes, bytes = max_bytes; } input->exported = true; - SendOp op{OP_PREC_NONE, input, id, this->impl->rank, dst_rank, bytes, name}; + SendOp op{"none", input, id, this->impl->rank, dst_rank, bytes, name}; return this->impl->add_op(op)[0]; } // Tensor *Model::send_done(Tensor *input, int id, int dst_rank, const std::string &name) { - SendDoneOp op{OP_PREC_NONE, input, id, this->impl->rank, dst_rank, name}; + SendDoneOp op{"none", input, id, this->impl->rank, dst_rank, name}; return this->impl->add_op(op)[0]; } @@ -145,13 +145,13 @@ Tensor *Model::recv(int id, int src_rank, size_t bytes, Tensor *output, if (bytes == 0) { bytes = max_bytes; } - RecvOp op{OP_PREC_NONE, output, id, this->impl->rank, + RecvOp op{"none", output, id, this->impl->rank, src_rank, bytes, name}; return this->impl->add_op(op)[0]; } const OpConfigMap CommConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_NONE}, + {{OP_ARCH_CUDA_ANY, "none"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {1, 0, {{-1, -1}}, {{-1, -1}}, true, true}, diff --git a/ark/ops/ops_sendrecv_mm.cc b/ark/ops/ops_sendrecv_mm.cc index adaee8701..c8ef85ddf 100644 --- a/ark/ops/ops_sendrecv_mm.cc +++ b/ark/ops/ops_sendrecv_mm.cc @@ -11,7 +11,7 @@ namespace ark { extern const OpConfigMap SendRecvMMConfigMap; -SendMMOp::SendMMOp(OpPrecType prec_type, Tensor *input, Tensor *recvbuf, +SendMMOp::SendMMOp(const std::string &prec_type, Tensor *input, Tensor *recvbuf, Tensor *send_ready_flag, Tensor *output, int id, int gpu_dst, size_t bytes, const std::string &name) : Op{OP_SEND_MM, @@ -69,7 +69,7 @@ OpArgs SendMMOp::function_call_args(const OpConfig &) const { return opargs; } -RecvMMOp::RecvMMOp(OpPrecType prec_type, Tensor *input, Tensor *recvbuf, +RecvMMOp::RecvMMOp(const std::string &prec_type, Tensor *input, Tensor *recvbuf, Tensor *send_ready_flag, Tensor *output, int id, int gpu_src, size_t bytes, const std::string &name) : Op{OP_RECV_MM, @@ -161,8 +161,8 @@ Tensor *Model::send_mm(Tensor *input, int id, int gpu_dst, size_t bytes, }, INT32); send_ready_flag->exported = true; - SendMMOp op{OP_PREC_NONE, input, recvbuf, send_ready_flag, output, id, - gpu_dst, bytes, name}; + SendMMOp op{"none", input, recvbuf, send_ready_flag, output, id, + gpu_dst, bytes, name}; return this->impl->add_op(op)[0]; } @@ -201,13 +201,13 @@ Tensor *Model::recv_mm(Tensor *input, int id, int gpu_src, size_t bytes, }, INT32); send_ready_flag->imported_rank = gpu_src; - RecvMMOp op{OP_PREC_NONE, input, recvbuf, send_ready_flag, output, id, - gpu_src, bytes, name}; + RecvMMOp op{"none", input, recvbuf, send_ready_flag, output, id, + gpu_src, bytes, name}; return this->impl->add_op(op)[0]; } const OpConfigMap 
SendRecvMMConfigMap = { - {{OP_ARCH_CUDA_ANY, OP_PREC_NONE}, + {{OP_ARCH_CUDA_ANY, "none"}, { // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost {4, 0, {{64, 64}, {64, 64}, {1, 1}}, {{64, 64}}, false, false}, diff --git a/ark/ops/ops_sigmoid.cc b/ark/ops/ops_sigmoid.cc index a2ba7789e..fe5ea8d70 100644 --- a/ark/ops/ops_sigmoid.cc +++ b/ark/ops/ops_sigmoid.cc @@ -10,8 +10,8 @@ namespace ark { extern const OpConfigMap ActivationConfigMap; -SigmoidOp::SigmoidOp(OpPrecType prec_type, Tensor *input, Tensor *output, - const std::string &name) +SigmoidOp::SigmoidOp(const std::string &prec_type, Tensor *input, + Tensor *output, const std::string &name) : Op{OP_SIGMOID, prec_type, {input}, {output}, {}, name, &ActivationConfigMap, -1, true} {} @@ -43,14 +43,6 @@ std::string SigmoidOp::function_name(const OpConfig &cfg) const { Tensor *Model::sigmoid(Tensor *input, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -59,7 +51,7 @@ Tensor *Model::sigmoid(Tensor *input, Tensor *output, const std::string &name) { } else if (output->shape != input->shape) { LOG(ERROR, "invalid output shape: ", output->shape); } - SigmoidOp op{pt, input, output, name}; + SigmoidOp op{output->type.name(), input, output, name}; return this->impl->add_op(op)[0]; } diff --git a/ark/ops/ops_sigmoid_test.cc b/ark/ops/ops_sigmoid_test.cc index e3410beb9..07b5fa7d7 100644 --- a/ark/ops/ops_sigmoid_test.cc +++ b/ark/ops/ops_sigmoid_test.cc @@ -35,8 +35,21 @@ ark::unittest::State test_sigmoid_fp32() { return ark::unittest::SUCCESS; } +ark::unittest::State test_sigmoid_bf16() { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BF16); + ark::Tensor *out = m.sigmoid(t); + + auto result = ark::op_test("sigmoid_bf16", m, {t}, {out}, + baseline_sigmoid); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_sigmoid_fp32); + UNITTEST(test_sigmoid_bf16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_softmax.cc b/ark/ops/ops_softmax.cc index a4f406e47..4905ecea3 100644 --- a/ark/ops/ops_softmax.cc +++ b/ark/ops/ops_softmax.cc @@ -10,8 +10,8 @@ namespace ark { extern const OpConfigMap SoftmaxConfigMap; -SoftmaxOp::SoftmaxOp(OpPrecType prec_type, Tensor *input, Tensor *output, - const std::string &name) +SoftmaxOp::SoftmaxOp(const std::string &prec_type, Tensor *input, + Tensor *output, const std::string &name) : Op{OP_SOFTMAX, prec_type, {input}, {output}, {}, name, &SoftmaxConfigMap, -1, true} {} @@ -36,14 +36,6 @@ std::string SoftmaxOp::function_name(const OpConfig &cfg) const { Tensor *Model::softmax(Tensor *input, Tensor *output, const std::string &name) { assert(input != nullptr); - OpPrecType pt = OP_PREC_NONE; - if (input->type == FP16) { - pt = OP_PREC_FP16; - } else if (input->type == FP32) { - pt = OP_PREC_FP32; - } else { - LOG(ERROR, "unsupported input data type: ", input->type); - } if (output != nullptr && input->type != output->type) { LOG(ERROR, "invalid output data type: ", output->type); } @@ -52,12 +44,12 @@ Tensor *Model::softmax(Tensor *input, Tensor *output, const std::string &name) { } else if (output == input) { output = 
diff --git a/ark/ops/ops_softmax.cc b/ark/ops/ops_softmax.cc
index a4f406e47..4905ecea3 100644
--- a/ark/ops/ops_softmax.cc
+++ b/ark/ops/ops_softmax.cc
@@ -10,8 +10,8 @@ namespace ark {
 
 extern const OpConfigMap SoftmaxConfigMap;
 
-SoftmaxOp::SoftmaxOp(OpPrecType prec_type, Tensor *input, Tensor *output,
-                     const std::string &name)
+SoftmaxOp::SoftmaxOp(const std::string &prec_type, Tensor *input,
+                     Tensor *output, const std::string &name)
     : Op{OP_SOFTMAX, prec_type, {input}, {output}, {},
          name,       &SoftmaxConfigMap, -1, true} {}
 
@@ -36,14 +36,6 @@ std::string SoftmaxOp::function_name(const OpConfig &cfg) const {
 Tensor *Model::softmax(Tensor *input, Tensor *output,
                        const std::string &name) {
     assert(input != nullptr);
-    OpPrecType pt = OP_PREC_NONE;
-    if (input->type == FP16) {
-        pt = OP_PREC_FP16;
-    } else if (input->type == FP32) {
-        pt = OP_PREC_FP32;
-    } else {
-        LOG(ERROR, "unsupported input data type: ", input->type);
-    }
     if (output != nullptr && input->type != output->type) {
         LOG(ERROR, "invalid output data type: ", output->type);
     }
@@ -52,12 +44,12 @@ Tensor *Model::softmax(Tensor *input, Tensor *output, const std::string &name) {
     } else if (output == input) {
         output = this->identity(output);
     }
-    SoftmaxOp op{pt, input, output, name};
+    SoftmaxOp op{output->type.name(), input, output, name};
     return this->impl->add_op(op)[0];
 }
 
 const OpConfigMap SoftmaxConfigMap = {
-    {{OP_ARCH_CUDA_ANY, OP_PREC_ANY},
+    {{OP_ARCH_CUDA_ANY, "any"},
     {
         // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
         {1, 128, {{32, -1}}, {{32, -1}}, true, false},
diff --git a/ark/ops/ops_softmax_test.cc b/ark/ops/ops_softmax_test.cc
index d98f2b642..2f711fd24 100644
--- a/ark/ops/ops_softmax_test.cc
+++ b/ark/ops/ops_softmax_test.cc
@@ -136,6 +136,30 @@ ark::unittest::State test_softmax_fp16_small_magnitude() {
     return ark::unittest::SUCCESS;
 }
 
+ark::unittest::State test_softmax_bf16() {
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(64, 8192), ark::BF16);
+        ark::Tensor *out = m.softmax(t);
+
+        auto result = ark::op_test("softmax_bf16", m, {t}, {out},
+                                   baseline_softmax<ark::bfloat16_t>);
+        UNITTEST_LOG(result);
+        UNITTEST_TRUE(result.max_diff[0] < 1e-3f);
+    }
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(1, 32, 64, 64), ark::BF16);
+        ark::Tensor *out = m.softmax(t);
+
+        auto result = ark::op_test("softmax_bf16", m, {t}, {out},
+                                   baseline_softmax<ark::bfloat16_t>);
+        UNITTEST_LOG(result);
+        UNITTEST_TRUE(result.max_diff[0] < 1e-3f);
+    }
+    return ark::unittest::SUCCESS;
+}
+
 int main() {
     ark::init();
     UNITTEST(test_softmax_fp32);
@@ -143,5 +167,6 @@ int main() {
     UNITTEST(test_softmax_fp16_padded);
     UNITTEST(test_softmax_fp16_big_magnitude);
     UNITTEST(test_softmax_fp16_small_magnitude);
+    UNITTEST(test_softmax_bf16);
     return ark::unittest::SUCCESS;
 }
diff --git a/ark/ops/ops_sqrt.cc b/ark/ops/ops_sqrt.cc
index 52adf166d..962116620 100644
--- a/ark/ops/ops_sqrt.cc
+++ b/ark/ops/ops_sqrt.cc
@@ -12,7 +12,7 @@ namespace ark {
 
 extern const OpConfigMap MathConfigMap;
 
-SqrtOp::SqrtOp(OpPrecType prec_type, Tensor *input, Tensor *output,
+SqrtOp::SqrtOp(const std::string &prec_type, Tensor *input, Tensor *output,
                const string &name)
     : Op{OP_SQRT, prec_type, {input}, {output}, {},
          name,    &MathConfigMap, -1, true} {}
 
@@ -45,14 +45,6 @@ std::string SqrtOp::function_name(const OpConfig &cfg) const {
 
 Tensor *Model::sqrt(Tensor *input, Tensor *output, const string &name) {
     assert(input != nullptr);
-    OpPrecType pt = OP_PREC_NONE;
-    if (input->type == FP16) {
-        pt = OP_PREC_FP16;
-    } else if (input->type == FP32) {
-        pt = OP_PREC_FP32;
-    } else {
-        LOG(ERROR, "unsupported input data type: ", input->type);
-    }
     if (output != nullptr && input->type != output->type) {
         LOG(ERROR, "invalid output data type: ", output->type);
     }
@@ -61,12 +53,12 @@ Tensor *Model::sqrt(Tensor *input, Tensor *output, const string &name) {
     } else if (output->shape != input->shape) {
         LOG(ERROR, "invalid output shape: ", output->shape);
     }
-    SqrtOp op{pt, input, output, name};
+    SqrtOp op{output->type.name(), input, output, name};
     return this->impl->add_op(op)[0];
 }
 
 const OpConfigMap MathConfigMap = {
-    {{OP_ARCH_CUDA_ANY, OP_PREC_FP32},
+    {{OP_ARCH_CUDA_ANY, "fp32"},
     {
         // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
         {8, 0, {{128, 256}}, {{128, 256}}, false, false},
@@ -83,7 +75,22 @@ const OpConfigMap MathConfigMap = {
         {1, 0, {{1, 64}}, {{1, 64}}, false, false},
         {1, 0, {{1, 32}}, {{1, 32}}, false, false},
     }},
-    {{OP_ARCH_CUDA_ANY, OP_PREC_FP16},
+    {{OP_ARCH_CUDA_ANY, "fp16"},
+    {
+        // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
+        {8, 0, {{128, 256}}, {{128, 256}}, false, false},
+        {8, 0, {{256, 128}}, {{256, 128}}, false, false},
+        {8, 0, {{128, 128}}, {{128, 128}}, false, false},
+        {4, 0, {{64, 64}}, {{64, 64}}, false, false},
+        {2, 0, {{32, 64}}, {{32, 64}}, false, false},
+        {1, 0, {{16, 64}}, {{16, 64}}, false, false},
+        {1, 0, {{8, 64}}, {{8, 64}}, false, false},
+        {1, 0, {{2, 128}}, {{2, 128}}, false, false},
+        {1, 0, {{4, 64}}, {{4, 64}}, false, false},
+        {1, 0, {{2, 64}}, {{2, 64}}, false, false},
+        {1, 0, {{1, 64}}, {{1, 64}}, false, false},
+    }},
+    {{OP_ARCH_CUDA_ANY, "bf16"},
     {
         // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
         {8, 0, {{128, 256}}, {{128, 256}}, false, false},
diff --git a/ark/ops/ops_sub.cc b/ark/ops/ops_sub.cc
index c958ff7e7..dc046793b 100644
--- a/ark/ops/ops_sub.cc
+++ b/ark/ops/ops_sub.cc
@@ -12,8 +12,8 @@ namespace ark {
 
 extern const OpConfigMap ArithmeticConfigMap;
 
-SubOp::SubOp(OpPrecType prec_type, Tensor *input, Tensor *other, Tensor *output,
-             const string &name)
+SubOp::SubOp(const std::string &prec_type, Tensor *input, Tensor *other,
+             Tensor *output, const string &name)
     : Op{OP_SUB, prec_type, {input, other}, {output}, {},
          name,   &ArithmeticConfigMap, -1, true} {}
 
@@ -50,14 +50,6 @@ Tensor *Model::sub(Tensor *input, Tensor *other, Tensor *output,
                    const string &name) {
     assert(input != nullptr);
     assert(other != nullptr);
-    OpPrecType pt = OP_PREC_NONE;
-    if (input->type == FP16) {
-        pt = OP_PREC_FP16;
-    } else if (input->type == FP32) {
-        pt = OP_PREC_FP32;
-    } else {
-        LOG(ERROR, "unsupported input data type: ", input->type);
-    }
     if (input->type != other->type) {
         LOG(ERROR, "input data types mismatch: ", input->type, ", ",
             other->type);
@@ -73,7 +65,7 @@ Tensor *Model::sub(Tensor *input, Tensor *other, Tensor *output,
     } else if (output == input) {
         output = this->identity(output);
     }
-    SubOp op{pt, input, other, output, name};
+    SubOp op{output->type.name(), input, other, output, name};
     return this->impl->add_op(op)[0];
 }
 
diff --git a/ark/ops/ops_tensor.cc b/ark/ops/ops_tensor.cc
index 57e5bb822..6902b9489 100644
--- a/ark/ops/ops_tensor.cc
+++ b/ark/ops/ops_tensor.cc
@@ -10,7 +10,7 @@ namespace ark {
 
 TensorOp::TensorOp(const std::vector<Tensor *> &deps, Tensor *output,
                    const std::string &name)
-    : Op{OP_TENSOR, OP_PREC_NONE, deps, {output}, {}, name, nullptr, -1} {}
+    : Op{OP_TENSOR, "none", deps, {output}, {}, name, nullptr, -1} {}
 
 Tensor *Model::tensor(const Dims &shape, const TensorType &ttype,
                       TensorBuf *buf, const Dims &ldims, const Dims &offs,
diff --git a/ark/ops/ops_test_common.cc b/ark/ops/ops_test_common.cc
index c136b7927..562d18b02 100644
--- a/ark/ops/ops_test_common.cc
+++ b/ark/ops/ops_test_common.cc
@@ -185,6 +185,11 @@ OpsTestResult op_test(const std::string &test_name_prefix, Model &model,
         } else if (t->type == FP16) {
             ::memcpy(buf, utils::rand_halfs(t->shape.size(), 0.1).get(),
                      t->shape_bytes());
+        } else if (t->type == BF16) {
+            ::memcpy(
+                buf,
+                utils::rand_array<bfloat16_t>(t->shape.size(), 0.1).get(),
+                t->shape_bytes());
         } else if (t->type == INT32) {
             ::memcpy(buf,
                      utils::rand_array<int>(t->shape.size(), 10000).get(),
@@ -268,6 +273,10 @@ OpsTestResult op_test(const std::string &test_name_prefix, Model &model,
             comp = tensor_compare<half_t>(static_cast<half_t *>(gt[i]),
                                           static_cast<half_t *>(res[i]),
                                           outputs[i]->shape.dims4(), print_on_error);
+        } else if (outputs[i]->type == BF16) {
+            comp = tensor_compare<bfloat16_t>(static_cast<bfloat16_t *>(gt[i]),
+                                              static_cast<bfloat16_t *>(res[i]),
+                                              outputs[i]->shape.dims4(), print_on_error);
         } else if (outputs[i]->type == INT32) {
             comp = tensor_compare<int>(static_cast<int *>(gt[i]),
                                        static_cast<int *>(res[i]),
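The BF16 branch added to op_test above fills staging buffers with the same rand_array<T> helper that backs rand_halfs. A hedged usage sketch, assuming rand_array<T> is the header template from include/ark_utils.h that the code above calls:

    #include <cstddef>
    #include <cstring>
    #include "include/ark_utils.h"

    // Fill a host staging buffer for a BF16 tensor with small random values,
    // mirroring the FP16 path that uses rand_halfs(num, 0.1).
    void fill_bf16(void *buf, size_t num_elems) {
        auto vals = ark::utils::rand_array<ark::bfloat16_t>(num_elems, 0.1f);
        ::memcpy(buf, vals.get(), num_elems * sizeof(ark::bfloat16_t));
    }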
diff --git a/ark/ops/ops_transpose.cc b/ark/ops/ops_transpose.cc
index e024c6792..380a2b20f 100644
--- a/ark/ops/ops_transpose.cc
+++ b/ark/ops/ops_transpose.cc
@@ -10,8 +10,8 @@ namespace ark {
 
 extern const OpConfigMap TransposeConfigMap;
 
-TransposeOp::TransposeOp(OpPrecType prec_type, Tensor *input, Tensor *output,
-                         int tp_type, const std::string &name)
+TransposeOp::TransposeOp(const std::string &prec_type, Tensor *input,
+                         Tensor *output, int tp_type, const std::string &name)
     : Op{OP_TRANSPOSE, prec_type, {input}, {output}, {{tp_type}},
          name,         &TransposeConfigMap, -1, true} {}
 
@@ -45,14 +45,6 @@ std::string TransposeOp::function_name(const OpConfig &cfg) const {
 
 Tensor *Model::transpose(Tensor *input, Dims perm, Tensor *output,
                          const std::string &name) {
-    OpPrecType pt = OP_PREC_NONE;
-    if (input->type == FP16) {
-        pt = OP_PREC_FP16;
-    } else if (input->type == FP32) {
-        pt = OP_PREC_FP32;
-    } else {
-        LOG(ERROR, "unsupported input data type: ", input->type);
-    }
     int input_ndims = input->ndims();
     Dims in_shape{1, 1, 1, 1};
     if (input_ndims < 2 || input_ndims > 4) {
@@ -96,12 +88,12 @@ Tensor *Model::transpose(Tensor *input, Dims perm, Tensor *output,
     } else {
         assert(output->shape == out_shape);
     }
-    TransposeOp op{pt, input, output, tp_type, name};
+    TransposeOp op{output->type.name(), input, output, tp_type, name};
     return this->impl->add_op(op)[0];
 }
 
 const OpConfigMap TransposeConfigMap = {
-    {{OP_ARCH_CUDA_ANY, OP_PREC_FP32},
+    {{OP_ARCH_CUDA_ANY, "fp32"},
     {
         // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
         {8, 0, {{1, 1}}, {{128, 128}}, true, false},
@@ -110,7 +102,19 @@ const OpConfigMap TransposeConfigMap = {
         {4, 0, {{1, 1}}, {{64, 64}}, true, false},
         {2, 0, {{1, 1}}, {{32, 32}}, true, false},
     }},
-    {{OP_ARCH_CUDA_ANY, OP_PREC_FP16},
+    {{OP_ARCH_CUDA_ANY, "fp16"},
+    {
+        // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
+        {8, 0, {{1, 1}}, {{128, 128}}, true, false},
+        {4, 0, {{1, 1}}, {{64, 128}}, true, false},
+        {4, 0, {{1, 1}}, {{128, 64}}, true, false},
+        {4, 0, {{1, 1}}, {{64, 64}}, true, false},
+        {2, 0, {{1, 1}}, {{32, 32}}, true, false},
+        {1, 0, {{1, 1}}, {{16, 16}}, true, false},
+        {1, 0, {{1, 1}}, {{8, 16}}, true, false},
+        {1, 0, {{1, 1}}, {{4, 8}}, true, false},
+    }},
+    {{OP_ARCH_CUDA_ANY, "bf16"},
     {
         // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
         {8, 0, {{1, 1}}, {{128, 128}}, true, false},
diff --git a/ark/ops/ops_transpose_test.cc b/ark/ops/ops_transpose_test.cc
index 7ed14cacd..fcb6585be 100644
--- a/ark/ops/ops_transpose_test.cc
+++ b/ark/ops/ops_transpose_test.cc
@@ -78,6 +78,18 @@ ark::unittest::State test_transpose_0132_fp16() {
     return ark::unittest::SUCCESS;
 }
 
+ark::unittest::State test_transpose_0132_bf16() {
+    ark::Model m;
+    ark::Tensor *t = m.tensor({5, 3, 32, 128}, ark::BF16);
+    ark::Tensor *out = m.transpose(t, {0, 1, 3, 2});
+
+    auto result = ark::op_test("transpose_0132_bf16", m, {t}, {out},
+                               baseline_transpose_0132<ark::bfloat16_t>);
+    UNITTEST_LOG(result);
+    UNITTEST_EQ(result.max_diff[0], 0.0f);
+    return ark::unittest::SUCCESS;
+}
+
 ark::unittest::State test_transpose_0231_fp32() {
     ark::Model m;
     ark::Tensor *t = m.tensor({5, 3, 32, 128}, ark::FP32);
@@ -102,11 +114,25 @@ ark::unittest::State test_transpose_0231_fp16() {
     return ark::unittest::SUCCESS;
 }
 
+ark::unittest::State test_transpose_0231_bf16() {
+    ark::Model m;
+    ark::Tensor *t = m.tensor({5, 3, 32, 128}, ark::BF16);
+    ark::Tensor *out = m.transpose(t, {0, 2, 3, 1});
+
+    auto result = ark::op_test("transpose_0231_bf16", m, {t}, {out},
+                               baseline_transpose_0231<ark::bfloat16_t>);
+    UNITTEST_LOG(result);
+    UNITTEST_EQ(result.max_diff[0], 0.0f);
+    return ark::unittest::SUCCESS;
+}
+
 int main() {
     ark::init();
     UNITTEST(test_transpose_0132_fp32);
     UNITTEST(test_transpose_0132_fp16);
+    UNITTEST(test_transpose_0132_bf16);
     UNITTEST(test_transpose_0231_fp32);
     UNITTEST(test_transpose_0231_fp16);
+    UNITTEST(test_transpose_0231_bf16);
     return ark::unittest::SUCCESS;
 }
diff --git a/ark/sched/sched_codegen.cc b/ark/sched/sched_codegen.cc
index bd2023679..f3663a0dd 100644
--- a/ark/sched/sched_codegen.cc
+++ b/ark/sched/sched_codegen.cc
@@ -71,7 +71,7 @@ std::ostream &CodeGenerator::sync_stream(std::ostream &os, int stream_id,
 
 ostream &CodeGenerator::tensor(ostream &os, const Tensor *tensor) const {
     size_t off = this->get_tensor_offset(tensor);
-    os << "(" << tensor->type.pointer_name() << ")";
+    os << "(" << tensor->type.type_str() << " *)";
     std::string buf_name = ARK_BUF_NAME;
     if (tensor->imported_rank >= 0) {
         buf_name += std::to_string(tensor->imported_rank);
@@ -85,7 +85,7 @@ std::ostream &CodeGenerator::def_oparg(std::ostream &os, const OpArg &arg,
     if (arg.type == OP_ARG_TENSOR) {
         Tensor *tns;
         arg.get(&tns);
-        os << tns->type.pointer_name() << name;
+        os << tns->type.type_str() << " *" << name;
     } else if (arg.type == OP_ARG_FLOAT) {
        os << "float " << name;
     } else if (arg.type == OP_ARG_INT) {
diff --git a/ark/sched/sched_op.cc b/ark/sched/sched_op.cc
index 8eec85f70..171a14a69 100644
--- a/ark/sched/sched_op.cc
+++ b/ark/sched/sched_op.cc
@@ -119,7 +119,7 @@ const string SchedOp::serialize() const {
         if (arg.type == OP_ARG_TENSOR) {
             Tensor *tns;
             arg.get(&tns);
-            ss << tns->type.pointer_name();
+            ss << tns->type.type_str() << " *";
         } else if (arg.type == OP_ARG_FLOAT) {
             ss << "float";
         } else if (arg.type == OP_ARG_INT) {
diff --git a/ark/tensor.cc b/ark/tensor.cc
index 89f636d90..dac9e3bae 100644
--- a/ark/tensor.cc
+++ b/ark/tensor.cc
@@ -20,39 +20,6 @@ size_t TensorBuf::get_buf_offset() const {
     return static_cast<GpuBuf *>(this->buf)->get_offset();
 }
 
-TensorType::TensorType(int id, int bytes, const std::string &name,
-                       const std::string &pointer_name)
-    : id_{id}, bytes_{bytes}, name_{name}, pointer_name_{pointer_name} {}
-
-bool TensorType::operator==(const TensorType &other) const {
-    return id_ == other.id();
-}
-
-bool TensorType::operator!=(const TensorType &other) const {
-    return id_ != other.id();
-}
-
-int TensorType::id() const { return id_; }
-
-int TensorType::bytes() const { return bytes_; }
-
-const std::string &TensorType::name() const { return name_; }
-
-const std::string &TensorType::pointer_name() const { return pointer_name_; }
-
-std::ostream &operator<<(std::ostream &os, const TensorType &type) {
-    os << type.name();
-    return os;
-}
-
-Fp16::Fp16() : TensorType{0, 2, "fp16", "ark::half *"} {}
-
-Fp32::Fp32() : TensorType{1, 4, "fp32", "float *"} {}
-
-Int32::Int32() : TensorType{2, 4, "int32", "int *"} {}
-
-Byte::Byte() : TensorType{3, 1, "byte", "char *"} {}
-
 // Tensor constructor
 Tensor::Tensor(const Dims &shape_, const TensorType &type_, TensorBuf *buf_,
                const Dims &ldims_, const Dims &offs_, const Dims &pads_,
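With pointer_name() removed, the code generator composes the pointer cast from type_str() plus " *", as in the sched_codegen change above. A standalone sketch of that string composition; the names here are simplified stand-ins for the CodeGenerator members:

    #include <cstddef>
    #include <iostream>
    #include <sstream>
    #include <string>

    // Build a cast-and-offset expression the way CodeGenerator::tensor()
    // now does: "(<type_str> *)" followed by the buffer expression.
    std::string tensor_expr(const std::string &type_str,
                            const std::string &buf_name, size_t off) {
        std::ostringstream os;
        os << "(" << type_str << " *)&" << buf_name << "[" << off << "]";
        return os.str();
    }

    int main() {
        // FP16 is registered with type_str "ark::half", so this prints:
        // (ark::half *)&ARK_BUF[1024]
        std::cout << tensor_expr("ark::half", "ARK_BUF", 1024) << "\n";
    }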
diff --git a/ark/tensor_type.cc b/ark/tensor_type.cc
new file mode 100644
index 000000000..e162e005b
--- /dev/null
+++ b/ark/tensor_type.cc
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <algorithm>
+
+#include "ark.h"
+
+static std::string to_lowercase(const std::string &str) {
+    std::string result = str;
+    std::transform(result.begin(), result.end(), result.begin(),
+                   [](unsigned char c) { return std::tolower(c); });
+    return result;
+}
+
+namespace ark {
+
+TensorType::TensorType(const std::string &name, int bytes,
+                       const std::string &type_str)
+    : name_{to_lowercase(name)}, bytes_{bytes}, type_str_{type_str} {}
+
+bool TensorType::operator==(const TensorType &other) const {
+    return name_ == other.name();
+}
+
+bool TensorType::operator!=(const TensorType &other) const {
+    return name_ != other.name();
+}
+
+int TensorType::bytes() const { return bytes_; }
+
+const std::string &TensorType::name() const { return name_; }
+
+const std::string &TensorType::type_str() const { return type_str_; }
+
+std::ostream &operator<<(std::ostream &os, const TensorType &type) {
+    os << type.name();
+    return os;
+}
+
+} // namespace ark
diff --git a/ark/threading.h b/ark/threading.h
deleted file mode 100644
index d245f1d08..000000000
--- a/ark/threading.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT license.
-
-#ifndef ARK_THREADING_H_
-#define ARK_THREADING_H_
-
-#include <functional>
-#include <mutex>
-#include <thread>
-#include <vector>
-
-namespace ark {
-
-template <typename ItemType>
-void para_exec(std::vector<ItemType> &items, int max_num_threads,
-               const std::function<void(ItemType &)> &func) {
-    size_t nthread = (size_t)max_num_threads;
-    if (nthread > items.size()) {
-        nthread = items.size();
-    }
-    std::vector<std::thread> threads;
-    threads.reserve(nthread);
-    std::mutex mtx;
-    size_t idx = 0;
-    for (size_t i = 0; i < nthread; ++i) {
-        threads.emplace_back([&items, &mtx, &idx, &func] {
-            size_t local_idx = -1;
-            for (;;) {
-                {
-                    const std::lock_guard<std::mutex> lock(mtx);
-                    local_idx = idx++;
-                }
-                if (local_idx >= items.size()) break;
-                func(items[local_idx]);
-            }
-        });
-    }
-    for (auto &t : threads) {
-        t.join();
-    }
-}
-
-} // namespace ark
-
-#endif // ARK_THREADING_H_
diff --git a/ark/utils.cc b/ark/utils.cc
deleted file mode 100644
index 47ee37111..000000000
--- a/ark/utils.cc
+++ /dev/null
@@ -1,207 +0,0 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT license.
-
-#include <sys/types.h>
-#include <sys/wait.h>
-#include <unistd.h>
-
-#include <functional>
-#include <limits>
-#include <memory>
-
-#include "include/ark.h"
-#include "include/ark_utils.h"
-
-// clang-format off
-#include "vector_types.h"
-#include "cutlass/half.h"
-// clang-format on
-
-using namespace std;
-
-/// Convert cutlass::half_t to @ref ark::half_t
-/// @param cuh cutlass::half_t
-/// @return @ref ark::half_t
-inline static const ark::half_t convert(const cutlass::half_t &cuh) {
-    ark::half_t ret;
-    ret.storage = cuh.raw();
-    return ret;
-}
-
-/// Numeric limits of @ref ark::half_t
-template <>
-struct std::numeric_limits<ark::half_t> {
-    static ark::half_t max() {
-        return convert(std::numeric_limits<cutlass::half_t>::max());
-    }
-    static ark::half_t min() {
-        return convert(std::numeric_limits<cutlass::half_t>::min());
-    }
-    static ark::half_t epsilon() {
-        return convert(std::numeric_limits<cutlass::half_t>::epsilon());
-    }
-};
-
-ark::half_t operator+(ark::half_t const &lhs, ark::half_t const &rhs) {
-    return convert(cutlass::half_t::bitcast(lhs.storage) +
-                   cutlass::half_t::bitcast(rhs.storage));
-}
-
-ark::half_t operator-(ark::half_t const &lhs, ark::half_t const &rhs) {
-    return convert(cutlass::half_t::bitcast(lhs.storage) -
-                   cutlass::half_t::bitcast(rhs.storage));
-}
-
-ark::half_t operator*(ark::half_t const &lhs, ark::half_t const &rhs) {
-    return convert(cutlass::half_t::bitcast(lhs.storage) *
-                   cutlass::half_t::bitcast(rhs.storage));
-}
-
-ark::half_t &operator+=(ark::half_t &lhs, ark::half_t const &rhs) {
-    cutlass::half_t v = cutlass::half_t::bitcast(lhs.storage) +
-                        cutlass::half_t::bitcast(rhs.storage);
-    lhs.storage = v.raw();
-    return lhs;
-}
-
-ark::half_t &operator-=(ark::half_t &lhs, ark::half_t const &rhs) {
-    cutlass::half_t v = cutlass::half_t::bitcast(lhs.storage) -
-                        cutlass::half_t::bitcast(rhs.storage);
-    lhs.storage = v.raw();
-    return lhs;
-}
-
-/// Return the absolute value of a @ref ark::half_t
-/// @param val Input value
-/// @return @ref Absolute value of `val`
-ark::half_t abs(ark::half_t const &val) {
-    return convert(cutlass::abs(cutlass::half_t::bitcast(val.storage)));
-}
-
-namespace ark {
-
-/// Construct a @ref half_t from a float
-/// @param f Input value
-half_t::half_t(float f) { this->storage = cutlass::half_t(f).raw(); }
-
-/// Convert a @ref half_t to a float
-/// @return float
-half_t::operator float() const {
-    return float(cutlass::half_t::bitcast(this->storage));
-}
-
-namespace utils {
-
-/// Return a random @ref half_t array.
-/// @param num Number of elements
-/// @param max_val Maximum value
-/// @return std::unique_ptr<half_t[]>
-unique_ptr<half_t[]> rand_halfs(size_t num, float max_val) {
-    return rand_array<half_t>(num, max_val);
-}
-
-/// Return a random float array.
-/// @param num Number of elements
-/// @param max_val Maximum value
-/// @return std::unique_ptr<float[]>
-unique_ptr<float[]> rand_floats(size_t num, float max_val) {
-    return rand_array<float>(num, max_val);
-}
-
-/// Return a random bytes array.
-/// @param num Number of elements
-/// @return std::unique_ptr<uint8_t[]>
-unique_ptr<uint8_t[]> rand_bytes(size_t num) {
-    return rand_array<uint8_t>(num, 255);
-}
-
-/// Return an array of values starting from `begin` with difference `diff`.
-/// @tparam T Type of the array
-/// @param num Number of elements
-/// @param begin First value
-/// @param diff Difference between two values
-/// @return std::unique_ptr<T[]>
-template <typename T>
-unique_ptr<T[]> range_array(size_t num, float begin, float diff) {
-    T *ret = new T[num];
-    for (size_t i = 0; i < num; ++i) {
-        ret[i] = T(begin);
-        begin += diff;
-    }
-    return unique_ptr<T[]>(ret);
-}
-
-/// Return a @ref half_t range array.
-/// @param num Number of elements
-/// @param begin First value
-/// @param diff Difference between two values
-/// @return std::unique_ptr<half_t[]>
-unique_ptr<half_t[]> range_halfs(size_t num, float begin, float diff) {
-    return range_array<half_t>(num, begin, diff);
-}
-
-/// Return a float range array.
-/// @param num Number of elements
-/// @param begin First value
-/// @param diff Difference between two values
-/// @return std::unique_ptr<float[]>
-unique_ptr<float[]> range_floats(size_t num, float begin, float diff) {
-    return range_array<float>(num, begin, diff);
-}
-
-/// Spawn a process that runs `func`.
-/// @param func function to run in the spawned process.
-/// @return PID of the spawned process.
-int proc_spawn(const function<int()> &func) {
-    pid_t pid = fork();
-    if (pid < 0) {
-        return -1;
-    } else if (pid == 0) {
-        int ret = func();
-        std::exit(ret);
-    }
-    return (int)pid;
-}
-
-/// Wait for a spawned process with PID `pid`.
-/// @param pid PID of the spawned process.
-/// @return -1 on any unexpected failure, otherwise return the exit status.
-int proc_wait(int pid) {
-    int status;
-    if (waitpid(pid, &status, 0) == -1) {
-        return -1;
-    }
-    if (WIFEXITED(status)) {
-        return WEXITSTATUS(status);
-    }
-    return -1;
-}
-
-/// Wait for multiple child processes.
-/// @param pids PIDs of the spawned processes.
-/// @return 0 on success, -1 on any unexpected failure, otherwise the first seen
-/// non-zero exit status.
-int proc_wait(const vector<int> &pids) {
-    int ret = 0;
-    for (auto &pid : pids) {
-        int status;
-        if (waitpid(pid, &status, 0) == -1) {
-            return -1;
-        }
-        int r;
-        if (WIFEXITED(status)) {
-            r = WEXITSTATUS(status);
-        } else if (WIFSIGNALED(status)) {
-            r = -1;
-        } else {
-            r = -1;
-        }
-        if ((ret == 0) && (r != 0)) {
-            ret = r;
-        }
-    }
-    return ret;
-}
-
-} // namespace utils
-} // namespace ark
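With tensor_type.cc in place, TensorType equality is purely name-based (names are lowercased at construction), so the integer id from the old tensor.cc is gone. A small usage sketch against the registered types, assuming the public header is included as "ark.h" per the sources above:

    #include <cassert>
    #include <string>
    #include "ark.h"

    int main() {
        // FP16 is registered via REGISTER_TENSOR_TYPE(FP16, 2, "ark::half"):
        assert(ark::FP16 == ark::FP16);  // same lowercased name "fp16"
        assert(ark::FP16 != ark::BF16);  // "fp16" vs "bf16"
        assert(ark::FP16.bytes() == 2);
        assert(ark::FP16.type_str() == std::string("ark::half"));
        return 0;
    }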
+ +#include "include/ark.h" +#include "include/ark_utils.h" + +// clang-format off +#include "vector_types.h" +#include "cutlass/bfloat16.h" +// clang-format on + +/// Convert cutlass::bfloat16_t to @ref ark::bfloat16_t +/// @param cub cutlass::bfloat16_t +/// @return @ref ark::bfloat16_t +inline static const ark::bfloat16_t convert(const cutlass::bfloat16_t &cub) { + ark::bfloat16_t ret; + ret.storage = cub.raw(); + return ret; +} + +/// Numeric limits of @ref ark::bfloat16_t +template <> +struct std::numeric_limits { + static ark::bfloat16_t max() { + return convert(std::numeric_limits::max()); + } + static ark::bfloat16_t min() { + return convert(std::numeric_limits::min()); + } + static ark::bfloat16_t epsilon() { + return convert(std::numeric_limits::epsilon()); + } +}; + +ark::bfloat16_t operator+(ark::bfloat16_t const &lhs, + ark::bfloat16_t const &rhs) { + return convert(cutlass::bfloat16_t::bitcast(lhs.storage) + + cutlass::bfloat16_t::bitcast(rhs.storage)); +} + +ark::bfloat16_t operator-(ark::bfloat16_t const &lhs, + ark::bfloat16_t const &rhs) { + return convert(cutlass::bfloat16_t::bitcast(lhs.storage) - + cutlass::bfloat16_t::bitcast(rhs.storage)); +} + +ark::bfloat16_t operator*(ark::bfloat16_t const &lhs, + ark::bfloat16_t const &rhs) { + return convert(cutlass::bfloat16_t::bitcast(lhs.storage) * + cutlass::bfloat16_t::bitcast(rhs.storage)); +} + +ark::bfloat16_t &operator+=(ark::bfloat16_t &lhs, ark::bfloat16_t const &rhs) { + cutlass::bfloat16_t v = cutlass::bfloat16_t::bitcast(lhs.storage) + + cutlass::bfloat16_t::bitcast(rhs.storage); + lhs.storage = v.raw(); + return lhs; +} + +ark::bfloat16_t &operator-=(ark::bfloat16_t &lhs, ark::bfloat16_t const &rhs) { + cutlass::bfloat16_t v = cutlass::bfloat16_t::bitcast(lhs.storage) - + cutlass::bfloat16_t::bitcast(rhs.storage); + lhs.storage = v.raw(); + return lhs; +} + +/// Return the absolute value of a @ref ark::bfloat16_t +/// @param val Input value +/// @return @ref Absolute value of `val` +ark::bfloat16_t abs(ark::bfloat16_t const &val) { + return convert(cutlass::abs(cutlass::bfloat16_t::bitcast(val.storage))); +} + +namespace ark { + +/// Construct a @ref bfloat16_t from a float +/// @param f Input value +bfloat16_t::bfloat16_t(float f) { + this->storage = cutlass::bfloat16_t(f).raw(); +} + +/// Convert a @ref bfloat16_t to a float +/// @return float +bfloat16_t::operator float() const { + return float(cutlass::bfloat16_t::bitcast(this->storage)); +} + +} // namespace ark diff --git a/ark/utils/utils_half.cc b/ark/utils/utils_half.cc new file mode 100644 index 000000000..d6e8092a3 --- /dev/null +++ b/ark/utils/utils_half.cc @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +#include "include/ark.h" +#include "include/ark_utils.h" + +// clang-format off +#include "vector_types.h" +#include "cutlass/half.h" +// clang-format on + +/// Convert cutlass::half_t to @ref ark::half_t +/// @param cuh cutlass::half_t +/// @return @ref ark::half_t +inline static const ark::half_t convert(const cutlass::half_t &cuh) { + ark::half_t ret; + ret.storage = cuh.raw(); + return ret; +} + +/// Numeric limits of @ref ark::half_t +template <> +struct std::numeric_limits { + static ark::half_t max() { + return convert(std::numeric_limits::max()); + } + static ark::half_t min() { + return convert(std::numeric_limits::min()); + } + static ark::half_t epsilon() { + return convert(std::numeric_limits::epsilon()); + } +}; + +ark::half_t operator+(ark::half_t const &lhs, ark::half_t const &rhs) { + return convert(cutlass::half_t::bitcast(lhs.storage) + + cutlass::half_t::bitcast(rhs.storage)); +} + +ark::half_t operator-(ark::half_t const &lhs, ark::half_t const &rhs) { + return convert(cutlass::half_t::bitcast(lhs.storage) - + cutlass::half_t::bitcast(rhs.storage)); +} + +ark::half_t operator*(ark::half_t const &lhs, ark::half_t const &rhs) { + return convert(cutlass::half_t::bitcast(lhs.storage) * + cutlass::half_t::bitcast(rhs.storage)); +} + +ark::half_t &operator+=(ark::half_t &lhs, ark::half_t const &rhs) { + cutlass::half_t v = cutlass::half_t::bitcast(lhs.storage) + + cutlass::half_t::bitcast(rhs.storage); + lhs.storage = v.raw(); + return lhs; +} + +ark::half_t &operator-=(ark::half_t &lhs, ark::half_t const &rhs) { + cutlass::half_t v = cutlass::half_t::bitcast(lhs.storage) - + cutlass::half_t::bitcast(rhs.storage); + lhs.storage = v.raw(); + return lhs; +} + +/// Return the absolute value of a @ref ark::half_t +/// @param val Input value +/// @return @ref Absolute value of `val` +ark::half_t abs(ark::half_t const &val) { + return convert(cutlass::abs(cutlass::half_t::bitcast(val.storage))); +} + +namespace ark { + +/// Construct a @ref half_t from a float +/// @param f Input value +half_t::half_t(float f) { this->storage = cutlass::half_t(f).raw(); } + +/// Convert a @ref half_t to a float +/// @return float +half_t::operator float() const { + return float(cutlass::half_t::bitcast(this->storage)); +} + +} // namespace ark diff --git a/ark/utils/utils_proc.cc b/ark/utils/utils_proc.cc new file mode 100644 index 000000000..c69d4b603 --- /dev/null +++ b/ark/utils/utils_proc.cc @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include + +#include + +#include "include/ark.h" +#include "include/ark_utils.h" + +using namespace std; + +namespace ark { +namespace utils { + +/// Spawn a process that runs `func`. +/// @param func function to run in the spawned process. +/// @return PID of the spawned process. +int proc_spawn(const function &func) { + pid_t pid = fork(); + if (pid < 0) { + return -1; + } else if (pid == 0) { + int ret = func(); + std::exit(ret); + } + return (int)pid; +} + +/// Wait for a spawned process with PID `pid`. +/// @param pid PID of the spawned process. +/// @return -1 on any unexpected failure, otherwise return the exit status. +int proc_wait(int pid) { + int status; + if (waitpid(pid, &status, 0) == -1) { + return -1; + } + if (WIFEXITED(status)) { + return WEXITSTATUS(status); + } + return -1; +} + +/// Wait for multiple child processes. +/// @param pids PIDs of the spawned processes. 
+/// @return 0 on success, -1 on any unexpected failure, otherwise the first seen +/// non-zero exit status. +int proc_wait(const vector &pids) { + int ret = 0; + for (auto &pid : pids) { + int status; + if (waitpid(pid, &status, 0) == -1) { + return -1; + } + int r; + if (WIFEXITED(status)) { + r = WEXITSTATUS(status); + } else if (WIFSIGNALED(status)) { + r = -1; + } else { + r = -1; + } + if ((ret == 0) && (r != 0)) { + ret = r; + } + } + return ret; +} + +} // namespace utils +} // namespace ark diff --git a/ark/utils/utils_random.cc b/ark/utils/utils_random.cc new file mode 100644 index 000000000..5ad58410b --- /dev/null +++ b/ark/utils/utils_random.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "include/ark.h" +#include "include/ark_utils.h" + +using namespace std; + +namespace ark { +namespace utils { + +/// Return a random @ref half_t array. +/// @param num Number of elements +/// @param max_val Maximum value +/// @return std::unique_ptr +unique_ptr rand_halfs(size_t num, float max_val) { + return rand_array(num, max_val); +} + +/// Return a random float array. +/// @param num Number of elements +/// @param max_val Maximum value +/// @return std::unique_ptr +unique_ptr rand_floats(size_t num, float max_val) { + return rand_array(num, max_val); +} + +/// Return a random bytes array. +/// @param num Number of elements +/// @return std::unique_ptr +unique_ptr rand_bytes(size_t num) { + return rand_array(num, 255); +} + +/// Return an array of values starting from `begin` with difference `diff`. +/// @tparam T Type of the array +/// @param num Number of elements +/// @param begin First value +/// @param diff Difference between two values +/// @return std::unique_ptr +template +unique_ptr range_array(size_t num, float begin, float diff) { + T *ret = new T[num]; + for (size_t i = 0; i < num; ++i) { + ret[i] = T(begin); + begin += diff; + } + return unique_ptr(ret); +} + +/// Return a @ref half_t range array. +/// @param num Number of elements +/// @param begin First value +/// @param diff Difference between two values +/// @return std::unique_ptr +unique_ptr range_halfs(size_t num, float begin, float diff) { + return range_array(num, begin, diff); +} + +/// Return a float range array. +/// @param num Number of elements +/// @param begin First value +/// @param diff Difference between two values +/// @return std::unique_ptr +unique_ptr range_floats(size_t num, float begin, float diff) { + return range_array(num, begin, diff); +} + +} // namespace utils +} // namespace ark diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 3b97fc713..efb9aea3e 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -16,6 +16,7 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(pybind11) -pybind11_add_module(ark_py ${CMAKE_CURRENT_SOURCE_DIR}/bindings.cpp) +file(GLOB_RECURSE BIND_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) +pybind11_add_module(ark_py ${BIND_SOURCES}) set_target_properties(ark_py PROPERTIES OUTPUT_NAME _ark_core) target_link_libraries(ark_py PRIVATE ark_static) diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 4d0687980..74050f177 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -7,59 +7,31 @@ os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__)) from . 
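The process helpers moved into ark/utils/utils_proc.cc keep their original semantics: proc_spawn forks and runs a std::function<int()> whose return value becomes the child's exit status, and the vector overload of proc_wait reports the first non-zero exit status it sees. A usage sketch:

    #include <vector>
    #include "include/ark_utils.h"

    int main() {
        std::vector<int> pids;
        for (int i = 0; i < 2; ++i) {
            int pid = ark::utils::proc_spawn([i]() -> int {
                // Child process body; the return value is the exit status.
                return i;  // the second child exits with status 1
            });
            if (pid < 0) return -1;
            pids.push_back(pid);
        }
        // Returns 0 only if every child exited with 0; here it returns 1.
        return ark::utils::proc_wait(pids);
    }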
diff --git a/python/ark/__init__.py b/python/ark/__init__.py
index 4d0687980..74050f177 100644
--- a/python/ark/__init__.py
+++ b/python/ark/__init__.py
@@ -7,59 +7,31 @@ os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__))
 
 from . import _ark_core
+from .data_type import DataType, _REGISTRY_DATA_TYPE
+from .tensor import Dims, Tensor, TensorBuf, Parameter
+from .model import Model, _REGISTRY_OPERATOR
+from .module import Module
+from .runtime import Runtime
+from .serialize import save, load
+
+# Read the version.
 __version__ = _ark_core.version()
 
+# Import data types.
+for type_name in _REGISTRY_DATA_TYPE.keys():
+    globals()[type_name] = DataType.from_name(type_name)
+
+# Import operators.
+for op_name, op_func in _REGISTRY_OPERATOR.items():
+    globals()[op_name] = op_func
+
 
 def version():
     """Returns the version of ARK."""
     return __version__
 
 
-from .runtime import Runtime
-from .data_type import DataType, fp32, fp16, int32, byte
-from .tensor import Dims, Tensor, TensorBuf, Parameter
-from .module import Module
-from .serialize import save, load
-from .model import (
-    Model,
-    tensor,
-    parameter,
-    reshape,
-    identity,
-    sharding,
-    reduce_sum,
-    reduce_mean,
-    reduce_max,
-    layernorm,
-    rmsnorm,
-    softmax,
-    transpose,
-    matmul,
-    im2col,
-    scale,
-    relu,
-    gelu,
-    sigmoid,
-    exp,
-    sqrt,
-    rope,
-    add,
-    sub,
-    mul,
-    div,
-    send,
-    send_done,
-    recv,
-    send_mm,
-    recv_mm,
-    all_gather,
-    all_reduce,
-    embedding,
-    cast,
-)
-
-
 def init():
     """Initializes ARK."""
     _ark_core.init()
diff --git a/python/ark/data_type.py b/python/ark/data_type.py
index 82102af48..e6691ecac 100644
--- a/python/ark/data_type.py
+++ b/python/ark/data_type.py
@@ -1,110 +1,104 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-import numpy as np
-from ._ark_core import _TensorType, _FP32, _FP16, _INT32, _BYTE
+import numpy
+import string
+from . import _ark_core
+
+
+_REGISTRY_DATA_TYPE = {
+    "fp32": {
+        "np": numpy.float32,
+        "doc": """32-bit floating point.""",
+    },
+    "fp16": {
+        "np": numpy.float16,
+        "doc": """16-bit floating point.""",
+    },
+    "bf16": {
+        "np": None,
+        "doc": """bfloat16 floating point.""",
+    },
+    "int32": {
+        "np": numpy.int32,
+        "doc": """32-bit signed integer.""",
+    },
+    "uint32": {
+        "np": numpy.uint32,
+        "doc": """32-bit unsigned integer.""",
+    },
+    "int8": {
+        "np": numpy.int8,
+        "doc": """8-bit signed integer.""",
+    },
+    "uint8": {
+        "np": numpy.uint8,
+        "doc": """8-bit unsigned integer.""",
+    },
+    "byte": {
+        "np": numpy.ubyte,
+        "doc": """
+Represent the data as bytes, supposed to be untyped binary.
+
+Unlike other data types, casting to/from `byte` from/to another data type
+is considered as reinterpretation of the data, instead of conversion.
+""",
+    },
+}
 
 
 class DataType:
     @staticmethod
-    def from_numpy(np_type: np.dtype) -> "DataType":
-        if np_type == np.float32:
-            return fp32
-        elif np_type == np.float16:
-            return fp16
-        elif np_type == np.int32:
-            return int32
-        elif np_type == np.uint8:
-            return byte
-        else:
-            raise NotImplementedError
+    def from_numpy(np_type: numpy.dtype) -> "DataType":
+        for type_name, reg in _REGISTRY_DATA_TYPE.items():
+            if reg["np"] == np_type:
+                return DataType.from_name(type_name)
+        raise ValueError(
+            f"Undefined conversion from numpy data type {np_type}"
+            f" to ark data type."
+        )
 
     @staticmethod
-    def from_ttype(ttype: _TensorType) -> "DataType":
-        if ttype == _FP32:
-            return fp32
-        elif ttype == _FP16:
-            return fp16
-        elif ttype == _INT32:
-            return int32
-        elif ttype == _BYTE:
-            return byte
-        else:
-            raise NotImplementedError
+    def from_name(type_name: str) -> "DataType":
+        return globals()[type_name]
 
     @staticmethod
-    def from_str(type_str: str) -> "DataType":
-        if type_str == "fp32":
-            return fp32
-        elif type_str == "fp16":
-            return fp16
-        elif type_str == "int32":
-            return int32
-        elif type_str == "byte":
-            return byte
-        else:
-            raise NotImplementedError
+    def from_ttype(ttype: _ark_core._TensorType) -> "DataType":
+        return DataType.from_name(ttype.name())
 
     @staticmethod
-    def to_numpy() -> np.dtype:
+    def to_numpy() -> numpy.dtype:
+        """Return the corresponding numpy data type."""
         ...
 
     @staticmethod
-    def ttype() -> _TensorType:
+    def ttype() -> _ark_core._TensorType:
+        """Return the corresponding tensor type."""
         ...
 
-
-class fp32(DataType):
-    @staticmethod
-    def to_numpy() -> np.float32:
-        return np.float32
-
-    @staticmethod
-    def ttype() -> _TensorType:
-        return _FP32
-
     @staticmethod
     def element_size() -> int:
-        return 4
-
-
-class fp16(DataType):
-    @staticmethod
-    def to_numpy() -> np.float16:
-        return np.float16
-
-    @staticmethod
-    def ttype() -> _TensorType:
-        return _FP16
-
-    @staticmethod
-    def element_size() -> int:
-        return 2
+        """Return the size of the data type in bytes."""
+        ...
 
 
-class int32(DataType):
+_DATA_TYPE_TEMPLATE = string.Template(
+    """
+class $type_name(DataType):
     @staticmethod
-    def to_numpy() -> np.int32:
-        return np.int32
+    def to_numpy() -> numpy.dtype:
+        return _REGISTRY_DATA_TYPE[__class__.__name__]["np"]
 
     @staticmethod
-    def ttype() -> _TensorType:
-        return _INT32
+    def ttype() -> _ark_core._TensorType:
+        return getattr(_ark_core, "_" + __class__.__name__.upper())
 
     @staticmethod
     def element_size() -> int:
-        return 4
-
-
-class byte(DataType):
-    @staticmethod
-    def to_numpy() -> np.uint8:
-        return np.uint8
-
-    @staticmethod
-    def ttype() -> _TensorType:
-        return _BYTE
+        return __class__.ttype().bytes()
+"""
+)
 
-    @staticmethod
-    def element_size() -> int:
-        return 1
+for type_name, reg in _REGISTRY_DATA_TYPE.items():
+    exec(_DATA_TYPE_TEMPLATE.substitute(type_name=type_name))
+    globals()[type_name].__doc__ = reg["doc"]
diff --git a/python/ark/model.py b/python/ark/model.py
index f696888b4..9c8e4c848 100644
--- a/python/ark/model.py
+++ b/python/ark/model.py
@@ -9,6 +9,15 @@ from .data_type import DataType, fp32
 
+_REGISTRY_OPERATOR = {}
+
+
+# Decorator for registering operators.
+def register_op(func):
+    _REGISTRY_OPERATOR[func.__name__] = func
+    return func
+
+
 class _ModelState:
     """
     The _ModelState class is used to store the state of the model.
@@ -75,6 +84,7 @@ def _is_list_or_tuple(obj):
     return isinstance(obj, list) or isinstance(obj, tuple)
 
 
+@register_op
 def tensor(
     shape: Iterable[int],
     dtype: DataType = fp32,
@@ -113,6 +123,7 @@ def tensor(
     return Tensor(_tensor)
 
 
+@register_op
 def parameter(
     shape: Iterable[int],
     dtype: DataType = fp32,
@@ -148,6 +159,7 @@ def parameter(
     return Parameter(_tensor)
 
 
+@register_op
 def reshape(
     input: Tensor,
     shape: Iterable[int],
@@ -180,6 +192,7 @@ def reshape(
     return Tensor(_tensor)
 
 
+@register_op
 def identity(
     input: Tensor,
     deps: List[Tensor] = [],
@@ -199,6 +212,7 @@ def identity(
     return Tensor(_tensor)
 
 
+@register_op
 def sharding(
     input: Tensor,
     axis: int,
@@ -223,6 +237,7 @@ def sharding(
     return tensor_list
 
 
+@register_op
 def reduce_sum(
     input: Tensor,
     axis: int,
@@ -243,6 +258,7 @@ def reduce_sum(
     return Tensor(_tensor)
 
 
+@register_op
 def reduce_mean(
     input: Tensor,
     axis: int,
@@ -261,6 +277,7 @@ def reduce_mean(
     return Tensor(_tensor)
 
 
+@register_op
 def reduce_max(
     input: Tensor,
     axis: int,
@@ -279,6 +296,7 @@ def reduce_max(
     return Tensor(_tensor)
 
 
+@register_op
 def layernorm(
     input: Tensor,
     output: Tensor = None,
@@ -296,6 +314,7 @@ def layernorm(
     return Tensor(_tensor)
 
 
+@register_op
 def rmsnorm(
     input: Tensor,
     output: Tensor = None,
@@ -313,6 +332,7 @@ def rmsnorm(
     return Tensor(_tensor)
 
 
+@register_op
 def softmax(
     input: Tensor,
     output: Tensor = None,
@@ -329,6 +349,7 @@ def softmax(
     return Tensor(_tensor)
 
 
+@register_op
 def transpose(
     input: Tensor,
     perm: list,
@@ -357,6 +378,7 @@ def transpose(
     return Tensor(_tensor)
 
 
+@register_op
 def matmul(
     input: Tensor,
     other: Tensor,
@@ -391,6 +413,7 @@ def matmul(
     return Tensor(_tensor)
 
 
+@register_op
 def im2col(
     input: Tensor,
     kernel_height: int,
@@ -428,6 +451,7 @@ def im2col(
     return Tensor(_tensor)
 
 
+@register_op
 def scale(
     input: Tensor,
     val: float,
@@ -450,6 +474,7 @@ def scale(
     return Tensor(_tensor)
 
 
+@register_op
 def exp(
     input: Tensor,
     output: Tensor = None,
@@ -466,6 +491,7 @@ def exp(
     return Tensor(_tensor)
 
 
+@register_op
 def sqrt(
     input: Tensor,
     output: Tensor = None,
@@ -482,6 +508,7 @@ def sqrt(
     return Tensor(_tensor)
 
 
+@register_op
 def rope(
     input: Tensor,
     other: Tensor,
@@ -499,6 +526,7 @@ def rope(
     return Tensor(_tensor)
 
 
+@register_op
 def relu(
     input: Tensor,
     output: Tensor = None,
@@ -516,6 +544,7 @@ def relu(
     return Tensor(_tensor)
 
 
+@register_op
 def gelu(
     input: Tensor,
     output: Tensor = None,
@@ -535,6 +564,7 @@ def gelu(
     return Tensor(_tensor)
 
 
+@register_op
 def sigmoid(
     input: Tensor,
     output: Tensor = None,
@@ -552,6 +582,7 @@ def sigmoid(
     return Tensor(_tensor)
 
 
+@register_op
 def add(
     input: Tensor,
     other: Tensor,
@@ -570,6 +601,7 @@ def add(
     return Tensor(_tensor)
 
 
+@register_op
 def sub(
     input: Tensor,
     other: Tensor,
@@ -588,6 +620,7 @@ def sub(
     return Tensor(_tensor)
 
 
+@register_op
 def mul(
     input: Tensor,
     other: Tensor,
@@ -606,6 +639,7 @@ def mul(
     return Tensor(_tensor)
 
 
+@register_op
 def div(
     input: Tensor,
     other: Tensor,
@@ -624,6 +658,7 @@ def div(
     return Tensor(_tensor)
 
 
+@register_op
 def send(
     input: Tensor,
     id: int,
@@ -654,6 +689,7 @@ def send(
     return Tensor(_tensor)
 
 
+@register_op
 def send_done(
     input: Tensor,
     id: int,
@@ -673,6 +709,7 @@ def send_done(
     return Tensor(_tensor)
 
 
+@register_op
 def recv(
     id: int,
     src_rank: int,
@@ -697,6 +734,7 @@ def recv(
     return Tensor(_tensor)
 
 
+@register_op
 def send_mm(
     input: Tensor,
     id: int,
@@ -728,6 +766,7 @@ def send_mm(
     return Tensor(_tensor)
 
 
+@register_op
 def recv_mm(
     input: Tensor,
     id: int,
@@ -753,6 +792,7 @@ def recv_mm(
     return Tensor(_tensor)
 
 
+@register_op
 def all_gather(
     input: Tensor,
     rank: int,
@@ -794,6 +834,7 @@ def all_gather(
     return [Tensor(_tensor) for _tensor in tensor_shards]
 
 
+@register_op
 def all_reduce(
     input: Tensor,
     rank: int,
@@ -822,6 +863,7 @@ def all_reduce(
     return Tensor(_tensor)
 
 
+@register_op
 def embedding(
     input: Tensor,
     weight: Tensor,
@@ -840,6 +882,7 @@ def embedding(
     return Tensor(_tensor)
 
 
+@register_op
 def cast(
     input: Tensor,
     dtype: DataType,
diff --git a/python/ark_py.cpp b/python/ark_py.cpp
new file mode 100644
index 000000000..71af96d33
--- /dev/null
+++ b/python/ark_py.cpp
@@ -0,0 +1,33 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <pybind11/operators.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <sstream>
+
+#include "ark.h"
+
+namespace py = pybind11;
+
+extern void register_dims(py::module &m);
+extern void register_tensor_type(py::module &m);
+extern void register_tensor(py::module &m);
+extern void register_model(py::module &m);
+extern void register_executor(py::module &m);
+
+PYBIND11_MODULE(_ark_core, m) {
+    m.doc() = "Bind ARK C++ APIs to Python";
+
+    m.def("version", &ark::version);
+    m.def("init", &ark::init);
+    m.def("srand", &ark::srand, py::arg("seed") = -1);
+    m.def("rand", &ark::rand);
+
+    register_dims(m);
+    register_tensor_type(m);
+    register_tensor(m);
+    register_model(m);
+    register_executor(m);
+}
diff --git a/python/dims_py.cpp b/python/dims_py.cpp
new file mode 100644
index 000000000..78461a5f4
--- /dev/null
+++ b/python/dims_py.cpp
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <pybind11/operators.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <sstream>
+
+#include "ark.h"
+
+namespace py = pybind11;
+
+void register_dims(py::module &m) {
+    m.attr("DIMS_LEN") = py::int_(static_cast<int>(ark::DIMS_LEN));
+    m.attr("NO_DIM") = py::int_(static_cast<int>(ark::NO_DIM));
+
+    py::class_<ark::Dims>(m, "_Dims")
+        .def(py::init([](ark::DimType d0, ark::DimType d1, ark::DimType d2,
+                         ark::DimType d3) {
+                 return std::make_unique<ark::Dims>(d0, d1, d2, d3);
+             }),
+             py::arg_v("d0", static_cast<ark::DimType>(ark::NO_DIM)),
+             py::arg_v("d1", static_cast<ark::DimType>(ark::NO_DIM)),
+             py::arg_v("d2", static_cast<ark::DimType>(ark::NO_DIM)),
+             py::arg_v("d3", static_cast<ark::DimType>(ark::NO_DIM)))
+        .def(py::init<const ark::Dims &>())
+        .def(py::init<const std::vector<ark::DimType> &>())
+        .def("size", &ark::Dims::size)
+        .def("ndims", &ark::Dims::ndims)
+        .def("__getitem__",
+             [](const ark::Dims &d, ark::DimType idx) { return d[idx]; })
+        .def("__setitem__", [](ark::Dims &d, ark::DimType idx,
+                               ark::DimType value) { d[idx] = value; })
+        .def("__repr__", [](const ark::Dims &d) {
+            std::ostringstream os;
+            os << d;
+            return os.str();
+        });
+}
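The binding code is now one translation unit per concept, stitched together in ark_py.cpp through extern register_* declarations. Extending the module follows the same pattern; the following is a hypothetical example (register_foo and foo_py.cpp do not exist in this patch):

    // python/foo_py.cpp (hypothetical)
    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    void register_foo(py::module &m) {
        // Any bindings for the new concept go here.
        m.def("foo_twice", [](int x) { return 2 * x; });
    }

    // ark_py.cpp would then add
    //     extern void register_foo(py::module &m);
    // and call register_foo(m) inside PYBIND11_MODULE(_ark_core, m).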
diff --git a/python/executor_py.cpp b/python/executor_py.cpp
new file mode 100644
index 000000000..7db7a9eb9
--- /dev/null
+++ b/python/executor_py.cpp
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <pybind11/operators.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "ark.h"
+
+namespace py = pybind11;
+
+void register_executor(py::module &m) {
+    py::class_<ark::Executor>(m, "_Executor")
+        .def(py::init<int, int, ark::Model &, const std::string &, int>(),
+             py::arg("rank"), py::arg("world_size"), py::arg("model"),
+             py::arg("name"), py::arg("num_warps_per_sm") = 16)
+        .def("compile", &ark::Executor::compile)
+        .def("launch", &ark::Executor::launch)
+        .def("run", &ark::Executor::run, py::arg("iter"))
+        .def("wait", &ark::Executor::wait)
+        .def("stop", &ark::Executor::stop);
+}
diff --git a/python/bindings.cpp b/python/model_py.cpp
similarity index 62%
rename from python/bindings.cpp
rename to python/model_py.cpp
index aa07c0642..3a2a1b4ee 100644
--- a/python/bindings.cpp
+++ b/python/model_py.cpp
@@ -5,157 +5,11 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
-#include <sstream>
-#include <string>
-
 #include "ark.h"
 
 namespace py = pybind11;
 
-void tensor_write(ark::Tensor *tns, py::buffer host_buffer) {
-    py::buffer_info info = host_buffer.request();
-    tns->write(info.ptr);
-}
-
-void tensor_read(ark::Tensor *tns, py::buffer host_buffer) {
-    py::buffer_info info = host_buffer.request();
-    tns->read(info.ptr);
-}
-
-PYBIND11_MODULE(_ark_core, m) {
-    m.doc() = "ARK python module interface";
-
-    m.def("version", &ark::version, "Return a version string.");
-
-    m.def("init", &ark::init,
-          "Init an ark program. Call this function to clean up the shared "
-          "memory directory. This is useful when the previous run crashed, as "
-          "this forces to remove locks generated by previous runs. This may "
-          "crash other ARK processes running on the same machine, if there are "
-          "any.");
-
-    m.def("srand", &ark::srand, py::arg("seed") = -1,
-          "Sets the seed for the random number generator");
-    m.def("rand", &ark::rand, "Generates a random integer");
-
-    m.attr("DIMS_LEN") = py::int_(static_cast<int>(ark::DIMS_LEN));
-    m.attr("NO_DIM") = py::int_(static_cast<int>(ark::NO_DIM));
-
-    py::class_<ark::Dims>(m, "_Dims", "Up-to-`DIMS_LEN`-dimensional vector.")
-        .def(py::init([](ark::DimType d0, ark::DimType d1, ark::DimType d2,
-                         ark::DimType d3) {
-                 return std::make_unique<ark::Dims>(d0, d1, d2, d3);
-             }),
-             py::arg_v("d0", static_cast<ark::DimType>(ark::NO_DIM),
-                       "default value: NO_DIM"),
-             py::arg_v("d1", static_cast<ark::DimType>(ark::NO_DIM),
-                       "default value: NO_DIM"),
-             py::arg_v("d2", static_cast<ark::DimType>(ark::NO_DIM),
-                       "default value: NO_DIM"),
-             py::arg_v("d3", static_cast<ark::DimType>(ark::NO_DIM),
-                       "default value: NO_DIM"))
-        .def(py::init<const ark::Dims &>(), "Copy another Dims object.")
-        .def(py::init<const std::vector<ark::DimType> &>(),
-             "Construct from a vector. If the vector is shorter than DIMS_LEN, "
-             "put following NO_DIMs. Raise an error if the vector is longer "
-             "than DIMS_LEN.")
-        .def("size", &ark::Dims::size,
-             "Return the volume of dimensions. If the dimensions are invalid, "
-             "return -1")
-        .def("ndims", &ark::Dims::ndims,
-             "Return the number of valid dimensions.")
-        .def("__getitem__",
-             [](const ark::Dims &d, ark::DimType idx) { return d[idx]; })
-        .def("__setitem__", [](ark::Dims &d, ark::DimType idx,
-                               ark::DimType value) { d[idx] = value; })
-        .def("__repr__", [](const ark::Dims &d) {
-            std::ostringstream os;
-            os << d;
-            return os.str();
-        });
-
-    py::class_<ark::TensorType>(m, "_TensorType", "Type of tensor data.")
-        .def(pybind11::self == pybind11::self)
-        .def(pybind11::self != pybind11::self)
-        .def("bytes", &ark::TensorType::bytes, "Number of bytes of this type.")
-        .def("name", &ark::TensorType::name, "Name of this type.");
-
-    py::class_<ark::Fp16, ark::TensorType>(m, "_Fp16", "16-bit floating point.")
-        .def(py::init<>());
-
-    py::class_<ark::Fp32, ark::TensorType>(m, "_Fp32", "32-bit floating point.")
-        .def(py::init<>());
-
-    py::class_<ark::Int32, ark::TensorType>(m, "_Int32", "32-bit integer.")
-        .def(py::init<>());
-
-    py::class_<ark::Byte, ark::TensorType>(m, "_Byte", "8-bit integer.")
-        .def(py::init<>());
-
-    m.attr("_FP16") = py::cast(&ark::FP16, py::return_value_policy::reference);
-    m.attr("_FP32") = py::cast(&ark::FP32, py::return_value_policy::reference);
-    m.attr("_INT32") =
-        py::cast(&ark::INT32, py::return_value_policy::reference);
-    m.attr("_BYTE") = py::cast(&ark::BYTE, py::return_value_policy::reference);
-
-    py::class_<ark::TensorBuf>(m, "_TensorBuf",
-                               "TensorBuf refers to a data array that can be "
-                               "shared by multiple tensors.")
-        .def(py::init<const ark::DimType &, int>(), py::arg("bytes") = 0,
-             py::arg("id") = -1)
-        .def_readwrite("bytes", &ark::TensorBuf::bytes)
-        .def_readwrite("id", &ark::TensorBuf::id)
-        .def_readwrite("immutable", &ark::TensorBuf::immutable);
-
-    py::class_<ark::Tensor>(m, "_Tensor")
-        .def(py::init<const ark::Dims &, const ark::TensorType &,
-                      ark::TensorBuf *, const ark::Dims &, const ark::Dims &,
-                      const ark::Dims &, bool, int, int,
-                      const std::string &>(),
-             py::arg("shape"), py::arg("type"), py::arg("buf"),
-             py::arg("ldims"), py::arg("offs"), py::arg("pads"),
-             py::arg("exported"), py::arg("imported_rank"), py::arg("id"),
-             py::arg("name"))
-        .def_property_readonly("shape",
-                               [](const ark::Tensor &t) {
-                                   py::list shape_list;
-                                   for (int i = 0; i < t.ndims(); ++i) {
-                                       shape_list.append((int)t.shape[i]);
-                                   }
-                                   return shape_list;
-                               })
-        .def_property_readonly("ldims",
-                               [](const ark::Tensor &t) {
-                                   py::list ldims_list;
-                                   for (int i = 0; i < t.ndims(); ++i) {
-                                       ldims_list.append((int)t.ldims[i]);
-                                   }
-                                   return ldims_list;
-                               })
-        .def_property_readonly("type",
-                               [](const ark::Tensor &t) { return t.type; })
-        .def("write", &tensor_write, py::arg("buf"),
-             "Copy contiguous data from a host buffer to the given tensor's "
-             "(possibly non-contiguous) data range.")
-        .def("read", &tensor_read, py::arg("buf"),
-             "Copy (possibly non-contiguous) data from a tensor on GPU to a "
-             "contiguous host buffer.")
-        .def("clear", &ark::Tensor::clear)
-        .def("offset", &ark::Tensor::offset, py::arg("i0") = 0,
-             py::arg("i1") = 0, py::arg("i2") = 0, py::arg("i3") = 0)
-        .def("size", &ark::Tensor::size,
-             "Number of elements in the tensor excluding padding.")
-        .def("ndims", &ark::Tensor::ndims,
-             "Number of dimensions in the tensor.")
-        .def("type_bytes", &ark::Tensor::type_bytes,
-             "Number of bytes of each element in the tensor.")
-        .def("shape_bytes", &ark::Tensor::shape_bytes,
-             "Number of bytes of the tensor.")
-        .def("ldims_bytes", &ark::Tensor::ldims_bytes,
-             "Should be the same as the number of bytes of the TensorBuf.")
-        .def("offset_bytes", &ark::Tensor::offset_bytes, py::arg("i0") = 0,
-             py::arg("i1") = 0, py::arg("i2") = 0, py::arg("i3") = 0)
-        .def("is_alloced", &ark::Tensor::is_alloced)
-        .def("is_sequential", &ark::Tensor::is_sequential);
-
+void register_model(py::module &m) {
     py::class_<ark::Model>(m, "_Model")
         .def(py::init<int>(), py::arg("rank") = 0)
         .def("tensor", &ark::Model::tensor,
@@ -377,23 +231,4 @@ PYBIND11_MODULE(_ark_core, m) {
              py::return_value_policy::reference_internal, py::arg("input"),
             py::arg("ttype"), py::arg("output") = nullptr,
             py::arg("name") = "cast");
-
-    py::class_<ark::Executor>(m, "_Executor",
-                              "Convenience class for executing a model.")
-        .def(py::init<int, int, ark::Model &, const std::string &, int>(),
-             py::arg("rank"), py::arg("world_size"), py::arg("model"),
-             py::arg("name"), py::arg("num_warps_per_sm") = 16)
-        .def("compile", &ark::Executor::compile,
-             "Compile the model. This must be called before `launch()`.")
-        .def("launch", &ark::Executor::launch,
-             "Launch the model (not running yet). This must be called after "
-             "`compile()`.")
-        .def("run", &ark::Executor::run, py::arg("iter"),
-             "Run the model for `iter` iterations.")
-        .def("wait", &ark::Executor::wait,
-             "Wait for the previous run to finish.")
-        .def("stop", &ark::Executor::stop,
-             "Stop the model and return the elapsed time in milliseconds. Once "
-             "this is called, we need to call `launch()` again to run the "
-             "model again.");
 }
diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp
new file mode 100644
index 000000000..5783544b5
--- /dev/null
+++ b/python/tensor_py.cpp
@@ -0,0 +1,70 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <pybind11/operators.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "ark.h"
+
+namespace py = pybind11;
+
+void tensor_write(ark::Tensor *tns, py::buffer host_buffer) {
+    py::buffer_info info = host_buffer.request();
+    tns->write(info.ptr);
+}
+
+void tensor_read(ark::Tensor *tns, py::buffer host_buffer) {
+    py::buffer_info info = host_buffer.request();
+    tns->read(info.ptr);
+}
+
+void register_tensor(py::module &m) {
+    py::class_<ark::TensorBuf>(m, "_TensorBuf")
+        .def(py::init<const ark::DimType &, int>(), py::arg("bytes") = 0,
+             py::arg("id") = -1)
+        .def_readwrite("bytes", &ark::TensorBuf::bytes)
+        .def_readwrite("id", &ark::TensorBuf::id)
+        .def_readwrite("immutable", &ark::TensorBuf::immutable);
+
+    py::class_<ark::Tensor>(m, "_Tensor")
+        .def(py::init<const ark::Dims &, const ark::TensorType &,
+                      ark::TensorBuf *, const ark::Dims &, const ark::Dims &,
+                      const ark::Dims &, bool, int, int,
+                      const std::string &>(),
+             py::arg("shape"), py::arg("type"), py::arg("buf"),
+             py::arg("ldims"), py::arg("offs"), py::arg("pads"),
+             py::arg("exported"), py::arg("imported_rank"), py::arg("id"),
+             py::arg("name"))
+        .def_property_readonly("shape",
+                               [](const ark::Tensor &t) {
+                                   py::list shape_list;
+                                   for (int i = 0; i < t.ndims(); ++i) {
+                                       shape_list.append((int)t.shape[i]);
+                                   }
+                                   return shape_list;
+                               })
+        .def_property_readonly("ldims",
+                               [](const ark::Tensor &t) {
+                                   py::list ldims_list;
+                                   for (int i = 0; i < t.ndims(); ++i) {
+                                       ldims_list.append((int)t.ldims[i]);
+                                   }
+                                   return ldims_list;
+                               })
+        .def_property_readonly("type",
+                               [](const ark::Tensor &t) { return t.type; })
+        .def("write", &tensor_write, py::arg("buf"))
+        .def("read", &tensor_read, py::arg("buf"))
+        .def("clear", &ark::Tensor::clear)
+        .def("offset", &ark::Tensor::offset, py::arg("i0") = 0,
+             py::arg("i1") = 0, py::arg("i2") = 0, py::arg("i3") = 0)
+        .def("size", &ark::Tensor::size)
+        .def("ndims", &ark::Tensor::ndims)
+        .def("type_bytes", &ark::Tensor::type_bytes)
+        .def("shape_bytes", &ark::Tensor::shape_bytes)
+        .def("ldims_bytes", &ark::Tensor::ldims_bytes)
+        .def("offset_bytes", &ark::Tensor::offset_bytes, py::arg("i0") = 0,
+             py::arg("i1") = 0, py::arg("i2") = 0, py::arg("i3") = 0)
+        .def("is_alloced", &ark::Tensor::is_alloced)
+        .def("is_sequential", &ark::Tensor::is_sequential);
+}
diff --git a/python/tensor_type_py.cpp b/python/tensor_type_py.cpp
new file mode 100644
index 000000000..c6a2b13dc
--- /dev/null
+++ b/python/tensor_type_py.cpp
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+#include <pybind11/operators.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "ark.h"
+
+namespace py = pybind11;
+
+#define PY_REGISTER_TENSOR_TYPE(_type_name)                                 \
+    py::class_<ark::TensorType_##_type_name, ark::TensorType>(              \
+        m, "_TensorType_" #_type_name)                                      \
+        .def(py::init<>());                                                 \
+    m.attr("_" #_type_name) =                                               \
+        py::cast(&ark::_type_name, py::return_value_policy::reference);
+
+void register_tensor_type(py::module &m) {
+    py::class_<ark::TensorType>(m, "_TensorType")
+        .def(pybind11::self == pybind11::self)
+        .def(pybind11::self != pybind11::self)
+        .def("bytes", &ark::TensorType::bytes)
+        .def("name", &ark::TensorType::name);
+
+    PY_REGISTER_TENSOR_TYPE(FP32)
+    PY_REGISTER_TENSOR_TYPE(FP16)
+    PY_REGISTER_TENSOR_TYPE(BF16)
+    PY_REGISTER_TENSOR_TYPE(INT32)
+    PY_REGISTER_TENSOR_TYPE(UINT32)
+    PY_REGISTER_TENSOR_TYPE(INT8)
+    PY_REGISTER_TENSOR_TYPE(UINT8)
+    PY_REGISTER_TENSOR_TYPE(BYTE)
+}
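For readers tracing the macro: inside register_tensor_type(m), PY_REGISTER_TENSOR_TYPE(FP32) expands (modulo formatting) to the following, binding the generated TensorType_FP32 subclass and exposing the ark::FP32 singleton as the module attribute _FP32 that data_type.py later looks up via getattr(_ark_core, "_" + name.upper()):

    py::class_<ark::TensorType_FP32, ark::TensorType>(m, "_TensorType_FP32")
        .def(py::init<>());
    m.attr("_FP32") =
        py::cast(&ark::FP32, py::return_value_policy::reference);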