Commit 0b569fe

Merge branch 'main' into chhwang/llama

chhwang authored Sep 27, 2023
2 parents b879fb5 + 2446b18

Showing 4 changed files with 81 additions and 39 deletions.

ark/include/kernels/arithmetic.h (35 changes: 24 additions & 11 deletions)

@@ -100,7 +100,9 @@ struct Arithmetic
                                const _DataType *b)
     {
         *c = *a + *b;
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            // do nothing
+        } else if (_In0Shape::W == 1) {
 #pragma unroll
             for (int i = 1; i < NelemPerThread; ++i) {
                 c[i] = _ArithmeticType::compute(*a, b[i]);
@@ -128,18 +130,22 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2>

     static DEVICE void compute(float *c, const float *a, const float *b)
     {
-        float2 *pc = (float2 *)c;
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             float2 *pb = (float2 *)b;
+            float2 *pc = (float2 *)c;
             pc->x = _ArithmeticType::compute(*a, pb->x);
             pc->y = _ArithmeticType::compute(*a, pb->y);
         } else if (_In1Shape::W == 1) {
             float2 *pa = (float2 *)a;
+            float2 *pc = (float2 *)c;
             pc->x = _ArithmeticType::compute(pa->x, *b);
             pc->y = _ArithmeticType::compute(pa->y, *b);
         } else {
             float2 *pa = (float2 *)a;
             float2 *pb = (float2 *)b;
+            float2 *pc = (float2 *)c;
             pc->x = _ArithmeticType::compute(pa->x, pb->x);
             pc->y = _ArithmeticType::compute(pa->y, pb->y);
         }
@@ -155,7 +161,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4>

     static DEVICE void compute(float *c, const float *a, const float *b)
     {
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             longlong2 reg_b;
             longlong2 reg_c;
             asm volatile("ld.global.v2.u64 {%0,%1}, [%2];"
@@ -227,19 +235,20 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2>

     static DEVICE void compute(half *c, const half *a, const half *b)
     {
-        __half2 *pc = (__half2 *)c;
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             __half2 *pb = (__half2 *)b;
-            *pc =
+            *(__half2 *)c =
                 _ArithmeticType::compute(__half2half2(*(const __half *)a), *pb);
         } else if (_In1Shape::W == 1) {
             __half2 *pa = (__half2 *)a;
-            *pc =
+            *(__half2 *)c =
                 _ArithmeticType::compute(*pa, __half2half2(*(const __half *)b));
         } else {
             __half2 *pa = (__half2 *)a;
             __half2 *pb = (__half2 *)b;
-            *pc = _ArithmeticType::compute(*pa, *pb);
+            *(__half2 *)c = _ArithmeticType::compute(*pa, *pb);
         }
     }
 };
@@ -253,7 +262,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4>

     static DEVICE void compute(half *c, const half *a, const half *b)
     {
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             uint64_t reg_b = *(uint64_t *)b;
             uint64_t reg_c;
             __half2 *pb = (__half2 *)&reg_b;
@@ -294,7 +305,9 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8>

     static DEVICE void compute(half *c, const half *a, const half *b)
     {
-        if (_In0Shape::W == 1) {
+        if (_In0Shape::W == 1 && _In1Shape::W == 1) {
+            *c = _ArithmeticType::compute(*a, *b);
+        } else if (_In0Shape::W == 1) {
             longlong2 reg_b;
             longlong2 reg_c;
             asm volatile("ld.global.v2.u64 {%0,%1}, [%2];"
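
Note (context added for this writeup, not part of the commit): across all of these specializations, the change inserts a leading branch for the case where both inputs broadcast along the innermost axis (_In0Shape::W == 1 && _In1Shape::W == 1). Only a single scalar result is needed then, so the vectorized float2/__half2 paths, which would read past a width-1 operand, must be skipped. The generic template computes element 0 before branching (its loop starts at 1), which is why its new branch can "do nothing". A minimal host-side sketch of the same dispatch, with hypothetical Shape and Add stand-ins, folding the element-0 computation into each branch:

    #include <cstdio>

    // Hypothetical stand-ins for the kernel's _In0Shape/_In1Shape and
    // _ArithmeticType template parameters.
    template <int W_> struct Shape { static constexpr int W = W_; };
    struct Add {
        static float compute(float a, float b) { return a + b; }
    };

    // Same dispatch as the kernel: pick a branch from the compile-time
    // innermost widths of the two inputs.
    template <typename In0, typename In1, int NelemPerThread>
    void compute(float *c, const float *a, const float *b) {
        if (In0::W == 1 && In1::W == 1) {
            c[0] = Add::compute(a[0], b[0]);      // both scalar: one element only
        } else if (In0::W == 1) {
            for (int i = 0; i < NelemPerThread; ++i)
                c[i] = Add::compute(a[0], b[i]);  // broadcast a over b
        } else if (In1::W == 1) {
            for (int i = 0; i < NelemPerThread; ++i)
                c[i] = Add::compute(a[i], b[0]);  // broadcast b over a
        } else {
            for (int i = 0; i < NelemPerThread; ++i)
                c[i] = Add::compute(a[i], b[i]);  // plain elementwise
        }
    }

    int main() {
        float a[1] = {1.f}, b[2] = {10.f, 20.f}, c[2] = {0.f, 0.f};
        compute<Shape<1>, Shape<2>, 2>(c, a, b);
        std::printf("%g %g\n", c[0], c[1]); // prints: 11 21
        return 0;
    }
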
ark/ops/ops_matmul.cc (2 changes: 1 addition & 1 deletion)

@@ -67,7 +67,7 @@ std::string MatmulOp::function_name(const OpConfig &cfg) const
     leading_dims[2] = ldims_y[ldims_y.ndims() - 1];
     leading_dims[3] = ldims_b[ndims_b - 1];

-    DimType in_ldim_a = ldims_a[ndims_a - 2];
+    DimType in_ldim_a = ldims_a[ndims_a - 1];
     DimType in_ldim_b = ldims_b[ndims_b - 2];

     // TODO: verify `leading_dims`
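
Note (an inference from the surrounding lines, not stated in the commit): in_ldim_a now reads the innermost axis of ldims_a, consistent with how leading_dims[3] reads ldims_b[ndims_b - 1] a few lines up. For a row-major operand the leading dimension is the length of the last axis; the old ndims_a - 2 index picked the second-to-last axis instead, which agrees only when the two are equal (e.g. square operands). A tiny standalone illustration of the convention:

    #include <cassert>

    int main() {
        // Row-major 2x3 matrix: element (i, j) lives at i * ld + j, where the
        // leading dimension ld is the last-axis length (3), not the row count.
        const int m = 2, k = 3;
        float a[m * k] = {0, 1, 2, 3, 4, 5};
        const int ld = k;               // i.e. ldims_a[ndims_a - 1]
        assert(a[1 * ld + 2] == 5.0f);  // a(1, 2) is the final element
        return 0;
    }
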
ark/ops/ops_matmul_test.cu (36 changes: 15 additions & 21 deletions)

@@ -33,7 +33,7 @@ void cublas_matmul_float_nn(int m, int n, int k, const float *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -54,7 +54,7 @@ void cublas_matmul_float_nt(int m, int n, int k, const float *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -75,7 +75,7 @@ void cublas_matmul_float_tn(int m, int n, int k, const float *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -96,7 +96,7 @@ void cublas_matmul_float_tt(int m, int n, int k, const float *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasSgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -117,7 +117,7 @@ void cublas_matmul_half_nn(int m, int n, int k, const half *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -138,7 +138,7 @@ void cublas_matmul_half_nt(int m, int n, int k, const half *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_N, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -159,7 +159,7 @@ void cublas_matmul_half_tn(int m, int n, int k, const half *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_N, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -180,7 +180,7 @@ void cublas_matmul_half_tt(int m, int n, int k, const half *a, int lda,
                              b, ldb, a, lda, &beta, c, ldc);
     } else {
         status = cublasHgemmStridedBatched(
-            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * m,
+            cublasH, CUBLAS_OP_T, CUBLAS_OP_T, n, m, k, &alpha, b, ldb, n * k,
             a, lda, k * m, &beta, c, ldc, n * m, batch_size);
     }
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -552,31 +552,25 @@ ark::unittest::State test_matmul_tt()
 ark::unittest::State test_matmul_batched()
 {
     ark::Model m;
-    ark::Tensor *a = m.tensor(ark::Dims(2, 64, 64), ark::FP16);
-    ark::Tensor *b = m.tensor(ark::Dims(2, 64, 64), ark::FP16);
+    ark::Tensor *a = m.tensor(ark::Dims(3, 7, 64, 128), ark::FP16);
+    ark::Tensor *b = m.tensor(ark::Dims(3, 7, 128, 256), ark::FP16);
     ark::Tensor *c = m.matmul(a, b);

-    auto ones_a = ark::utils::ones<ark::half_t>(a->shape.size());
-    auto ones_b = ark::utils::ones<ark::half_t>(b->shape.size());
-    auto result =
-        ark::op_test("matmul_batched", m, {a, b}, {c}, baseline_matmul_nn<half>,
-                     {ones_a.get(), ones_b.get()}, true);
+    auto result = ark::op_test("matmul_batched", m, {a, b}, {c},
+                               baseline_matmul_nn<half>);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 }

 ark::unittest::State test_matmul_batched_padded()
 {
     ark::Model m;
-    ark::Tensor *a = m.tensor(ark::Dims(2, 1, 128), ark::FP16);
-    ark::Tensor *b = m.tensor(ark::Dims(2, 128, 1), ark::FP16);
+    ark::Tensor *a = m.tensor(ark::Dims(3, 7, 2, 9), ark::FP16);
+    ark::Tensor *b = m.tensor(ark::Dims(3, 7, 9, 2), ark::FP16);
     ark::Tensor *c = m.matmul(a, b);

-    auto ones_a = ark::utils::ones<ark::half_t>(a->shape.size());
-    auto ones_b = ark::utils::ones<ark::half_t>(b->shape.size());
     auto result = ark::op_test("matmul_batched_padded", m, {a, b}, {c},
-                               baseline_matmul_nn<half>,
-                               {ones_a.get(), ones_b.get()}, true);
+                               baseline_matmul_nn<half>);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 }
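
Note (reasoning inferred from the diff, not spelled out in the commit): these wrappers compute the row-major product C = A x B by asking column-major cuBLAS for C^T = B^T x A^T, which is why b is passed as the first operand. The batch stride of that first operand is the size of one B matrix, k * n elements, not n * m; the old value coincided only when m == k, so the previous all-square 64x64 test could never catch it. The new 4-D shapes, (3, 7, 64, 128) times (3, 7, 128, 256), make m, n, and k pairwise distinct, and the tests drop the all-ones inputs so random data is checked. A CPU sketch (illustrative, not the test code) of strided-batched matmul showing where each stride comes from:

    #include <cstdio>
    #include <vector>

    // Batch i of each operand starts at base + i * stride. For row-major
    // A[m x k], B[k x n], C[m x n], the per-batch strides are m*k, k*n and
    // m*n elements; using n * m for B, as the old code did, is only correct
    // when m == k.
    void sgemm_strided_batched(int m, int n, int k, const float *a,
                               long strideA, const float *b, long strideB,
                               float *c, long strideC, int batch) {
        for (int i = 0; i < batch; ++i) {
            const float *ai = a + i * strideA;
            const float *bi = b + i * strideB;
            float *ci = c + i * strideC;
            for (int r = 0; r < m; ++r)
                for (int col = 0; col < n; ++col) {
                    float acc = 0.f;
                    for (int p = 0; p < k; ++p)
                        acc += ai[r * k + p] * bi[p * n + col];
                    ci[r * n + col] = acc;
                }
        }
    }

    int main() {
        const int m = 2, n = 4, k = 3, batch = 2;
        std::vector<float> a(batch * m * k, 1.f), b(batch * k * n, 1.f),
            c(batch * m * n, 0.f);
        // strideB must be k * n; n * m (= 8 here) would start batch 1 of B
        // in the middle of batch 0's data.
        sgemm_strided_batched(m, n, k, a.data(), m * k, b.data(), k * n,
                              c.data(), m * n, batch);
        std::printf("%g\n", c[batch * m * n - 1]); // 3 (= k for all-ones)
        return 0;
    }
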
ark/ops/ops_scale_test.cc (47 changes: 41 additions & 6 deletions)

@@ -22,21 +22,56 @@ void baseline_scale(std::vector<void *> &outputs,
     }
 };

+ark::unittest::State test_scale_fp32()
+{
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP32);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+
+        auto result = ark::op_test("scale_fp32_small", m, {t}, {out},
+                                   baseline_scale<float>);
+        ark::op_test_log(result);
+    }
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+
+        auto result =
+            ark::op_test("scale_fp32", m, {t}, {out}, baseline_scale<float>);
+        ark::op_test_log(result);
+    }
+    return ark::unittest::SUCCESS;
+}
+
 ark::unittest::State test_scale_fp16()
 {
-    ark::Model m;
-    ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16);
-    ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1), ark::FP16);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);

-    auto result =
-        ark::op_test("scale_fp16", m, {t}, {out}, baseline_scale<ark::half_t>);
-    ark::op_test_log(result);
+        auto result = ark::op_test("scale_fp16_small", m, {t}, {out},
+                                   baseline_scale<ark::half_t>);
+        ark::op_test_log(result);
+    }
+    {
+        ark::Model m;
+        ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16);
+        ark::Tensor *out = m.scale(t, SCALE_FACTOR);
+
+        auto result = ark::op_test("scale_fp16", m, {t}, {out},
+                                   baseline_scale<ark::half_t>);
+        ark::op_test_log(result);
+    }
     return ark::unittest::SUCCESS;
 }

 int main()
 {
     ark::init();
+    UNITTEST(test_scale_fp32);
     UNITTEST(test_scale_fp16);
     return ark::unittest::SUCCESS;
 }
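
Note (added context, not part of the commit): the new Dims(4, 2, 1) cases give the tensor an innermost width of 1, which plausibly exercises the scalar fallback added to the arithmetic kernels above, while the 1024-wide cases keep covering the vectorized paths. For reference, a standalone sketch of what a baseline scale check does (hypothetical code; the file's actual baseline_scale helper and SCALE_FACTOR value are defined above the excerpted hunk):

    #include <cstdio>
    #include <vector>

    // Hypothetical value; the real constant is defined elsewhere in the file.
    constexpr float SCALE_FACTOR = 0.7f;

    // Reference behavior of the scale op: out[i] = in[i] * SCALE_FACTOR.
    template <typename T>
    void baseline_scale(std::vector<T> &out, const std::vector<T> &in) {
        for (size_t i = 0; i < in.size(); ++i)
            out[i] = in[i] * static_cast<T>(SCALE_FACTOR);
    }

    int main() {
        // Dims(4, 2, 1) flattens to 8 elements with innermost width 1, the
        // shape class that hits the scalar kernel branch.
        std::vector<float> in(4 * 2 * 1, 2.0f), out(in.size(), 0.0f);
        baseline_scale(out, in);
        std::printf("%g\n", out[0]); // 1.4
        return 0;
    }
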
