From 2446b183ef70ede9bc3da3beaf455a88c47f7434 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 27 Sep 2023 23:47:43 +0800 Subject: [PATCH] Fix arithmetic kernels (#131) Fix for cases where both inputs' W dimension length is 1. --- ark/include/kernels/arithmetic.h | 35 ++++++++++++++++-------- ark/ops/ops_scale_test.cc | 47 ++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/ark/include/kernels/arithmetic.h b/ark/include/kernels/arithmetic.h index 6b35e8002..c09ccf023 100644 --- a/ark/include/kernels/arithmetic.h +++ b/ark/include/kernels/arithmetic.h @@ -100,7 +100,9 @@ struct Arithmetic const _DataType *b) { *c = *a + *b; - if (_In0Shape::W == 1) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + // do nothing + } else if (_In0Shape::W == 1) { #pragma unroll for (int i = 1; i < NelemPerThread; ++i) { c[i] = _ArithmeticType::compute(*a, b[i]); @@ -128,18 +130,22 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2> static DEVICE void compute(float *c, const float *a, const float *b) { - float2 *pc = (float2 *)c; - if (_In0Shape::W == 1) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _ArithmeticType::compute(*a, *b); + } else if (_In0Shape::W == 1) { float2 *pb = (float2 *)b; + float2 *pc = (float2 *)c; pc->x = _ArithmeticType::compute(*a, pb->x); pc->y = _ArithmeticType::compute(*a, pb->y); } else if (_In1Shape::W == 1) { float2 *pa = (float2 *)a; + float2 *pc = (float2 *)c; pc->x = _ArithmeticType::compute(pa->x, *b); pc->y = _ArithmeticType::compute(pa->y, *b); } else { float2 *pa = (float2 *)a; float2 *pb = (float2 *)b; + float2 *pc = (float2 *)c; pc->x = _ArithmeticType::compute(pa->x, pb->x); pc->y = _ArithmeticType::compute(pa->y, pb->y); } @@ -155,7 +161,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4> static DEVICE void compute(float *c, const float *a, const float *b) { - if (_In0Shape::W == 1) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _ArithmeticType::compute(*a, *b); + } else if (_In0Shape::W == 1) { longlong2 reg_b; longlong2 reg_c; asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" @@ -227,19 +235,20 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2> static DEVICE void compute(half *c, const half *a, const half *b) { - __half2 *pc = (__half2 *)c; - if (_In0Shape::W == 1) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _ArithmeticType::compute(*a, *b); + } else if (_In0Shape::W == 1) { __half2 *pb = (__half2 *)b; - *pc = + *(__half2 *)c = _ArithmeticType::compute(__half2half2(*(const __half *)a), *pb); } else if (_In1Shape::W == 1) { __half2 *pa = (__half2 *)a; - *pc = + *(__half2 *)c = _ArithmeticType::compute(*pa, __half2half2(*(const __half *)b)); } else { __half2 *pa = (__half2 *)a; __half2 *pb = (__half2 *)b; - *pc = _ArithmeticType::compute(*pa, *pb); + *(__half2 *)c = _ArithmeticType::compute(*pa, *pb); } } }; @@ -253,7 +262,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4> static DEVICE void compute(half *c, const half *a, const half *b) { - if (_In0Shape::W == 1) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _ArithmeticType::compute(*a, *b); + } else if (_In0Shape::W == 1) { uint64_t reg_b = *(uint64_t *)b; uint64_t reg_c; __half2 *pb = (__half2 *)®_b; @@ -294,7 +305,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8> static DEVICE void compute(half *c, const half *a, const half *b) { - if (_In0Shape::W == 1) { + if (_In0Shape::W == 1 && _In1Shape::W == 1) { + *c = _ArithmeticType::compute(*a, *b); + } else if (_In0Shape::W == 1) { longlong2 reg_b; longlong2 reg_c; asm volatile("ld.global.v2.u64 {%0,%1}, [%2];" diff --git a/ark/ops/ops_scale_test.cc b/ark/ops/ops_scale_test.cc index 48f096685..36c791dc1 100644 --- a/ark/ops/ops_scale_test.cc +++ b/ark/ops/ops_scale_test.cc @@ -22,21 +22,56 @@ void baseline_scale(std::vector &outputs, } }; +ark::unittest::State test_scale_fp32() +{ + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP32); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); + + auto result = ark::op_test("scale_fp32_small", m, {t}, {out}, + baseline_scale); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); + + auto result = + ark::op_test("scale_fp32", m, {t}, {out}, baseline_scale); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + ark::unittest::State test_scale_fp16() { - ark::Model m; - ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); - ark::Tensor *out = m.scale(t, SCALE_FACTOR); + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP16); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); - auto result = - ark::op_test("scale_fp16", m, {t}, {out}, baseline_scale); - ark::op_test_log(result); + auto result = ark::op_test("scale_fp16_small", m, {t}, {out}, + baseline_scale); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); + ark::Tensor *out = m.scale(t, SCALE_FACTOR); + + auto result = ark::op_test("scale_fp16", m, {t}, {out}, + baseline_scale); + ark::op_test_log(result); + } return ark::unittest::SUCCESS; } int main() { ark::init(); + UNITTEST(test_scale_fp32); UNITTEST(test_scale_fp16); return ark::unittest::SUCCESS; }