Skip to content

Commit

Permalink
Fix arithmetic kernels (#131)
Browse files Browse the repository at this point in the history
Fix for cases where both inputs' W dimension length is 1.
  • Loading branch information
chhwang authored Sep 27, 2023
1 parent fa91fd8 commit 2446b18
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 17 deletions.
35 changes: 24 additions & 11 deletions ark/include/kernels/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ struct Arithmetic
const _DataType *b)
{
*c = *a + *b;
if (_In0Shape::W == 1) {
if (_In0Shape::W == 1 && _In1Shape::W == 1) {
// do nothing
} else if (_In0Shape::W == 1) {
#pragma unroll
for (int i = 1; i < NelemPerThread; ++i) {
c[i] = _ArithmeticType::compute(*a, b[i]);
Expand Down Expand Up @@ -128,18 +130,22 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2>

static DEVICE void compute(float *c, const float *a, const float *b)
{
float2 *pc = (float2 *)c;
if (_In0Shape::W == 1) {
if (_In0Shape::W == 1 && _In1Shape::W == 1) {
*c = _ArithmeticType::compute(*a, *b);
} else if (_In0Shape::W == 1) {
float2 *pb = (float2 *)b;
float2 *pc = (float2 *)c;
pc->x = _ArithmeticType::compute(*a, pb->x);
pc->y = _ArithmeticType::compute(*a, pb->y);
} else if (_In1Shape::W == 1) {
float2 *pa = (float2 *)a;
float2 *pc = (float2 *)c;
pc->x = _ArithmeticType::compute(pa->x, *b);
pc->y = _ArithmeticType::compute(pa->y, *b);
} else {
float2 *pa = (float2 *)a;
float2 *pb = (float2 *)b;
float2 *pc = (float2 *)c;
pc->x = _ArithmeticType::compute(pa->x, pb->x);
pc->y = _ArithmeticType::compute(pa->y, pb->y);
}
Expand All @@ -155,7 +161,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4>

static DEVICE void compute(float *c, const float *a, const float *b)
{
if (_In0Shape::W == 1) {
if (_In0Shape::W == 1 && _In1Shape::W == 1) {
*c = _ArithmeticType::compute(*a, *b);
} else if (_In0Shape::W == 1) {
longlong2 reg_b;
longlong2 reg_c;
asm volatile("ld.global.v2.u64 {%0,%1}, [%2];"
Expand Down Expand Up @@ -227,19 +235,20 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2>

static DEVICE void compute(half *c, const half *a, const half *b)
{
__half2 *pc = (__half2 *)c;
if (_In0Shape::W == 1) {
if (_In0Shape::W == 1 && _In1Shape::W == 1) {
*c = _ArithmeticType::compute(*a, *b);
} else if (_In0Shape::W == 1) {
__half2 *pb = (__half2 *)b;
*pc =
*(__half2 *)c =
_ArithmeticType::compute(__half2half2(*(const __half *)a), *pb);
} else if (_In1Shape::W == 1) {
__half2 *pa = (__half2 *)a;
*pc =
*(__half2 *)c =
_ArithmeticType::compute(*pa, __half2half2(*(const __half *)b));
} else {
__half2 *pa = (__half2 *)a;
__half2 *pb = (__half2 *)b;
*pc = _ArithmeticType::compute(*pa, *pb);
*(__half2 *)c = _ArithmeticType::compute(*pa, *pb);
}
}
};
Expand All @@ -253,7 +262,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4>

static DEVICE void compute(half *c, const half *a, const half *b)
{
if (_In0Shape::W == 1) {
if (_In0Shape::W == 1 && _In1Shape::W == 1) {
*c = _ArithmeticType::compute(*a, *b);
} else if (_In0Shape::W == 1) {
uint64_t reg_b = *(uint64_t *)b;
uint64_t reg_c;
__half2 *pb = (__half2 *)&reg_b;
Expand Down Expand Up @@ -294,7 +305,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8>

static DEVICE void compute(half *c, const half *a, const half *b)
{
if (_In0Shape::W == 1) {
if (_In0Shape::W == 1 && _In1Shape::W == 1) {
*c = _ArithmeticType::compute(*a, *b);
} else if (_In0Shape::W == 1) {
longlong2 reg_b;
longlong2 reg_c;
asm volatile("ld.global.v2.u64 {%0,%1}, [%2];"
Expand Down
47 changes: 41 additions & 6 deletions ark/ops/ops_scale_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,56 @@ void baseline_scale(std::vector<void *> &outputs,
}
};

// Exercises the scale op on FP32 tensors twice: once with a last dimension of
// length 1 and once with a last dimension of length 1024.
ark::unittest::State test_scale_fp32()
{
    // Builds a fresh model holding a (4, 2, width) FP32 tensor, applies
    // scale() with SCALE_FACTOR, hands the model to ark::op_test together
    // with baseline_scale<float>, and logs whatever op_test returns.
    auto run_case = [](const char *test_name, int width) {
        ark::Model model;
        ark::Tensor *input = model.tensor(ark::Dims(4, 2, width), ark::FP32);
        ark::Tensor *scaled = model.scale(input, SCALE_FACTOR);

        auto report = ark::op_test(test_name, model, {input}, {scaled},
                                   baseline_scale<float>);
        ark::op_test_log(report);
    };

    run_case("scale_fp32_small", 1);
    run_case("scale_fp32", 1024);
    return ark::unittest::SUCCESS;
}

ark::unittest::State test_scale_fp16()
{
ark::Model m;
ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16);
ark::Tensor *out = m.scale(t, SCALE_FACTOR);
{
ark::Model m;
ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP16);
ark::Tensor *out = m.scale(t, SCALE_FACTOR);

auto result =
ark::op_test("scale_fp16", m, {t}, {out}, baseline_scale<ark::half_t>);
ark::op_test_log(result);
auto result = ark::op_test("scale_fp16_small", m, {t}, {out},
baseline_scale<ark::half_t>);
ark::op_test_log(result);
}
{
ark::Model m;
ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16);
ark::Tensor *out = m.scale(t, SCALE_FACTOR);

auto result = ark::op_test("scale_fp16", m, {t}, {out},
baseline_scale<ark::half_t>);
ark::op_test_log(result);
}
return ark::unittest::SUCCESS;
}

// Test driver entry point: initializes ark, then runs the FP32 and FP16 scale
// tests through the UNITTEST macro, and finally returns
// ark::unittest::SUCCESS.
// NOTE(review): how UNITTEST reacts to a failing test is defined outside this
// view — confirm before relying on the exit code.
int main()
{
ark::init();
UNITTEST(test_scale_fp32);
UNITTEST(test_scale_fp16);
return ark::unittest::SUCCESS;
}

0 comments on commit 2446b18

Please sign in to comment.