From fa91fd84fce58ba75c0a3ef5712cf275f6f51aa6 Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Wed, 27 Sep 2023 22:01:53 +0800
Subject: [PATCH 1/2] Fix #129 (#130)

---
 ark/ops/ops_matmul.cc      |  2 +-
 ark/ops/ops_matmul_test.cu | 36 +++++++++++++++---------------------
 2 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/ark/ops/ops_matmul.cc b/ark/ops/ops_matmul.cc
index 37a15422a..100429355 100644
--- a/ark/ops/ops_matmul.cc
+++ b/ark/ops/ops_matmul.cc
@@ -67,7 +67,7 @@ std::string MatmulOp::function_name(const OpConfig &cfg) const
     leading_dims[2] = ldims_y[ldims_y.ndims() - 1];
     leading_dims[3] = ldims_b[ndims_b - 1];
 
-    DimType in_ldim_a = ldims_a[ndims_a - 2];
+    DimType in_ldim_a = ldims_a[ndims_a - 1];
     DimType in_ldim_b = ldims_b[ndims_b - 2];
 
     // TODO: verify `leading_dims`

diff --git a/ark/ops/ops_matmul_test.cu b/ark/ops/ops_matmul_test.cu
index cd3738d1b..6e751e29a 100644
--- a/ark/ops/ops_matmul_test.cu
+++ b/ark/ops/ops_matmul_test.cu
@@ -33,7 +33,7 @@ void cublas_matmul_float_nn(int m, int n, int k, const float *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -54,7 +54,7 @@ void cublas_matmul_float_nt(int m, int n, int k, const float *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -75,7 +75,7 @@ void cublas_matmul_float_tn(int m, int n, int k, const float *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -96,7 +96,7 @@ void cublas_matmul_float_tt(int m, int n, int k, const float *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -117,7 +117,7 @@ void cublas_matmul_half_nn(int m, int n, int k, const half *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -138,7 +138,7 @@ void cublas_matmul_half_nt(int m, int n, int k, const half *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -159,7 +159,7 @@ void cublas_matmul_half_tn(int m, int n, int k, const half *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -180,7 +180,7 @@ void cublas_matmul_half_tt(int m, int n, int k, const half *a, int lda,
                             b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -552,15 +552,12 @@ ark::unittest::State test_matmul_tt()
 ark::unittest::State test_matmul_batched()
 {
     ark::Model m;
-    ark::Tensor *a = m.tensor(ark::Dims(2, 64, 64), ark::FP16);
-    ark::Tensor *b = m.tensor(ark::Dims(2, 64, 64), ark::FP16);
+    ark::Tensor *a = m.tensor(ark::Dims(3, 7, 64, 128), ark::FP16);
+    ark::Tensor *b = m.tensor(ark::Dims(3, 7, 128, 256), ark::FP16);
     ark::Tensor *c = m.matmul(a, b);
 
-    auto ones_a = ark::utils::ones(a->shape.size());
-    auto ones_b = ark::utils::ones(b->shape.size());
-    auto result =
-        ark::op_test("matmul_batched", m, {a, b}, {c}, baseline_matmul_nn,
-                     {ones_a.get(), ones_b.get()}, true);
+    auto result = ark::op_test("matmul_batched", m, {a, b}, {c},
+                               baseline_matmul_nn);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 }
@@ -568,15 +565,12 @@ ark::unittest::State test_matmul_batched()
 ark::unittest::State test_matmul_batched_padded()
 {
     ark::Model m;
-    ark::Tensor *a = m.tensor(ark::Dims(2, 1, 128), ark::FP16);
-    ark::Tensor *b = m.tensor(ark::Dims(2, 128, 1), ark::FP16);
+    ark::Tensor *a = m.tensor(ark::Dims(3, 7, 2, 9), ark::FP16);
+    ark::Tensor *b = m.tensor(ark::Dims(3, 7, 9, 2), ark::FP16);
     ark::Tensor *c = m.matmul(a, b);
 
-    auto ones_a = ark::utils::ones(a->shape.size());
-    auto ones_b = ark::utils::ones(b->shape.size());
     auto result = ark::op_test("matmul_batched_padded", m, {a, b}, {c},
-                               baseline_matmul_nn,
-                               {ones_a.get(), ones_b.get()}, true);
+                               baseline_matmul_nn);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 }

From 2446b183ef70ede9bc3da3beaf455a88c47f7434 Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Wed, 27 Sep 2023 23:47:43 +0800
Subject: [PATCH 2/2] Fix arithmetic kernels (#131)

Fix for cases where both inputs' W dimension length is 1.
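
Previously each kernel only distinguished `_In0Shape::W == 1` (broadcast
input 0) from `_In1Shape::W == 1` (broadcast input 1). When both W lengths
are 1, the first branch still ran and vectorized across `b`, touching lanes
past the single valid element. The corrected dispatch, as a simplified
sketch (`In0W`, `In1W`, and `compute_add` are illustrative stand-ins, not
the actual template parameters in arithmetic.h):

    // Sketch of the fixed branch order; stands in for
    // Arithmetic<...>::compute. T is assumed to support operator+.
    template <int In0W, int In1W, typename T, int NelemPerThread>
    __device__ void compute_add(T *c, const T *a, const T *b)
    {
        if (In0W == 1 && In1W == 1) {
            // Both inputs broadcast on W: exactly one output element.
            *c = *a + *b;
        } else if (In0W == 1) {
            // Broadcast a across b's W dimension.
            for (int i = 0; i < NelemPerThread; ++i) c[i] = *a + b[i];
        } else if (In1W == 1) {
            // Broadcast b across a's W dimension.
            for (int i = 0; i < NelemPerThread; ++i) c[i] = a[i] + *b;
        } else {
            // No broadcast: plain elementwise.
            for (int i = 0; i < NelemPerThread; ++i) c[i] = a[i] + b[i];
        }
    }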
---
 ark/include/kernels/arithmetic.h | 35 ++++++++++++++++--------
 ark/ops/ops_scale_test.cc        | 47 ++++++++++++++++++++++++++++----
 2 files changed, 65 insertions(+), 17 deletions(-)

diff --git a/ark/include/kernels/arithmetic.h b/ark/include/kernels/arithmetic.h
index 6b35e8002..c09ccf023 100644
--- a/ark/include/kernels/arithmetic.h
+++ b/ark/include/kernels/arithmetic.h
@@ -100,7 +100,9 @@ struct Arithmetic
                                const _DataType *b)
     {
         *c = *a + *b;
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            // do nothing
+        } else if (_In0Shape::W == 1) {
 #pragma unroll
             for (int i = 1; i < NelemPerThread; ++i) {
                 c[i] = _ArithmeticType::compute(*a, b[i]);
             }
@@ -128,18 +130,22 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2>
     static DEVICE void compute(float *c, const float *a, const float *b)
     {
-        float2 *pc = (float2 *)c;
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             float2 *pb = (float2 *)b;
+            float2 *pc = (float2 *)c;
             pc->x = _ArithmeticType::compute(*a, pb->x);
             pc->y = _ArithmeticType::compute(*a, pb->y);
         } else if (_In1Shape::W == 1) {
             float2 *pa = (float2 *)a;
+            float2 *pc = (float2 *)c;
             pc->x = _ArithmeticType::compute(pa->x, *b);
             pc->y = _ArithmeticType::compute(pa->y, *b);
         } else {
             float2 *pa = (float2 *)a;
             float2 *pb = (float2 *)b;
+            float2 *pc = (float2 *)c;
             pc->x = _ArithmeticType::compute(pa->x, pb->x);
             pc->y = _ArithmeticType::compute(pa->y, pb->y);
         }
     }
@@ -155,7 +161,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4>
     static DEVICE void compute(float *c, const float *a, const float *b)
     {
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             longlong2 reg_b;
             longlong2 reg_c;
             asm volatile("ld.global.v2.u64 {%0,%1}, [%2];"
@@ -227,19 +235,20 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2>
     static DEVICE void compute(half *c, const half *a, const half *b)
     {
-        __half2 *pc = (__half2 *)c;
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             __half2 *pb = (__half2 *)b;
-            *pc =
+            *(__half2 *)c =
                 _ArithmeticType::compute(__half2half2(*(const __half *)a), *pb);
         } else if (_In1Shape::W == 1) {
             __half2 *pa = (__half2 *)a;
-            *pc =
+            *(__half2 *)c =
                 _ArithmeticType::compute(*pa, __half2half2(*(const __half *)b));
         } else {
             __half2 *pa = (__half2 *)a;
             __half2 *pb = (__half2 *)b;
-            *pc = _ArithmeticType::compute(*pa, *pb);
+            *(__half2 *)c = _ArithmeticType::compute(*pa, *pb);
         }
     }
 };
@@ -253,7 +262,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4>
     static DEVICE void compute(half *c, const half *a, const half *b)
     {
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             uint64_t reg_b = *(uint64_t *)b;
             uint64_t reg_c;
             __half2 *pb = (__half2 *)&reg_b;
@@ -294,7 +305,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8>
     static DEVICE void compute(half *c, const half *a, const half *b)
     {
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             longlong2 reg_b;
             longlong2 reg_c;
             asm volatile("ld.global.v2.u64 {%0,%1}, [%2];"
diff --git a/ark/ops/ops_scale_test.cc b/ark/ops/ops_scale_test.cc
index 48f096685..36c791dc1 100644
--- a/ark/ops/ops_scale_test.cc
+++ b/ark/ops/ops_scale_test.cc
@@ -22,21 +22,56 @@ void baseline_scale(std::vector<void *> &outputs,
     }
 };
 
+ark::unittest::State test_scale_fp32()
+{
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP32);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+
+        auto result = ark::op_test("scale_fp32_small", m, {t}, {out},
+                                   baseline_scale);
+        ark::op_test_log(result);
+    }
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+
+        auto result =
+            ark::op_test("scale_fp32", m, {t}, {out}, baseline_scale);
+        ark::op_test_log(result);
+    }
+    return ark::unittest::SUCCESS;
+}
+
 ark::unittest::State test_scale_fp16()
 {
-    ark::Model m;
-    ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16);
-    ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP16);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
 
-    auto result =
-        ark::op_test("scale_fp16", m, {t}, {out}, baseline_scale);
-    ark::op_test_log(result);
+        auto result = ark::op_test("scale_fp16_small", m, {t}, {out},
+                                   baseline_scale);
+        ark::op_test_log(result);
+    }
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+
+        auto result = ark::op_test("scale_fp16", m, {t}, {out},
+                                   baseline_scale);
+        ark::op_test_log(result);
+    }
     return ark::unittest::SUCCESS;
 }
 
 int main()
 {
     ark::init();
+    UNITTEST(test_scale_fp32);
     UNITTEST(test_scale_fp16);
     return ark::unittest::SUCCESS;
 }
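
Note on the stride fix in PATCH 1/2: the strided-batched cuBLAS calls take
per-operand batch strides in elements, one full matrix per batch. With
row-major per-batch shapes A: m x k, B: k x n, C: m x n, the row-major
product C = A * B is computed as the column-major product C^T = B^T * A^T,
so `b` is passed in the first operand slot and its batch stride is the
element count of one B matrix, n * k. The old stride n * m is correct only
when k == m, which held for the old square 64x64 test tensors but not for
the new 64x128 by 128x256 case. A minimal self-contained sketch of the
corrected call for the NN case (the wrapper name `batched_gemm_nn` is
illustrative, not ARK code):

    #include <cublas_v2.h>

    // Row-major C = A * B for each of `batch_size` batches.
    // cuBLAS is column-major, so compute C^T = B^T * A^T by passing B in
    // the first operand slot. Each batch stride is the element count of
    // one matrix.
    static cublasStatus_t batched_gemm_nn(cublasHandle_t h, int m, int n,
                                          int k, const float *a,
                                          const float *b, float *c,
                                          int batch_size)
    {
        const float alpha = 1.0f;
        const float beta = 0.0f;
        return cublasSgemmStridedBatched(
            h, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
            b, n, (long long)n * k, // one B matrix: k * n elements
            a, k, (long long)k * m, // one A matrix: m * k elements
            &beta,
            c, n, (long long)n * m, // one C matrix: m * n elements
            batch_size);
    }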