Commit 1cc9656: Add the embedding layer

chhwang committed Sep 18, 2023
1 parent be491a8 commit 1cc9656
Showing 16 changed files with 435 additions and 67 deletions.
3 changes: 3 additions & 0 deletions ark/include/ark.h
@@ -468,6 +468,9 @@ class Model
std::vector<Tensor *> all_gather(Tensor *input, int gpu_id, int gpu_num,
const std::vector<Tensor *> &output = {},
const std::string &name = "all_gather");
/// Embedding layer.
Tensor *embedding(Tensor *input, Tensor *weight, Tensor *output = nullptr,
const std::string &name = "embedding");

/// Verify if this model is valid.
/// @return true if the model is valid, false otherwise.
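For orientation, here is a minimal usage sketch of the new Model::embedding() API. It assumes the usual ARK model-building flow; the shapes mirror the unit test added in this commit, and the tensor names are illustrative.

#include "ark.h"

void build_embedding_model()
{
    const int num_emb = 1000; // number of embeddings (vocabulary size)
    const int emb_dim = 8192; // width of each embedding row

    ark::Model m;
    // Token indices: int32 IDs in [0, num_emb).
    ark::Tensor *input = m.tensor(ark::Dims(8, 3, 64), ark::INT32);
    // Embedding table: one emb_dim-wide row per vocabulary entry.
    ark::Tensor *weight = m.tensor(ark::Dims(num_emb, emb_dim), ark::FP16);
    // With output left as nullptr, the op allocates a (8, 3, 64, emb_dim) tensor.
    ark::Tensor *output = m.embedding(input, weight);
    (void)output;
}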
66 changes: 66 additions & 0 deletions ark/include/kernels/embedding.h
@@ -60,6 +60,72 @@ DEVICE void rope(half *c, half *a, half *b, int uop_idx, int)
uop_idx);
}

// Embedding

template <typename _DataType> struct Assign
{
using DataType = _DataType;
static const int NelemPerThread = 1;
static DEVICE void compute(DataType *c, const DataType *a)
{
*c = *a;
}
};

template <typename InDims, typename InShape, typename WeightDims,
typename WeightShape, typename OutDims, typename OutShape,
int EmbeddingDim, int NumThreads>
DEVICE void embedding(float *output, int *input, float *weight,
int uop_idx, int)
{
// InShape: Vec<D0, D1, D2, 1>
// WeightShape: Vec< 1, 1, ?, EmbeddingDim> (?: # of embeddings)
// OutShape: Vec<D0, D1, D2, EmbeddingDim>

static_assert(InShape::W == 1, "");

using UnitOutDims = Vec<1, 1, 1, OutDims::W>;
using UnitOp = UnitOp<OutDims, OutShape, UnitOutDims, NumThreads, 0>;
int un = UnitOp::uop_idx_n(uop_idx);
int uc = UnitOp::uop_idx_c(uop_idx);
int uh = UnitOp::uop_idx_h(uop_idx);

// pWeight: Vec<1, 1, 1, EmbeddingDim>
int emb_idx = input[un * InDims::CH + uc * InDims::H + uh];
float *pWeight = &weight[emb_idx * WeightDims::W];

Broadcast1<Vec<1, 1, 1, WeightDims::W>, Vec<1, 1, 1, EmbeddingDim>, OutDims,
OutShape, UnitOutDims, NumThreads, 0,
Assign<float>>::run(output, pWeight, uop_idx);
}

template <typename InDims, typename InShape, typename WeightDims,
typename WeightShape, typename OutDims, typename OutShape,
int EmbeddingDim, int NumThreads>
DEVICE void embedding(half *output, int *input, half *weight,
int uop_idx, int)
{
// InShape: Vec<D0, D1, D2, 1>
// WeightShape: Vec< 1, 1, ?, EmbeddingDim> (?: # of embeddings)
// OutShape: Vec<D0, D1, D2, EmbeddingDim>

static_assert(InShape::W == 1, "");

using UnitOutDims = Vec<1, 1, 1, OutDims::W>;
using UnitOp = UnitOp<OutDims, OutShape, UnitOutDims, NumThreads, 0>;
int un = UnitOp::uop_idx_n(uop_idx);
int uc = UnitOp::uop_idx_c(uop_idx);
int uh = UnitOp::uop_idx_h(uop_idx);

// pWeight: Vec<1, 1, 1, EmbeddingDim>
int emb_idx = input[un * InDims::CH + uc * InDims::H + uh];
half *pWeight = &weight[emb_idx * WeightDims::W];

Broadcast1<Vec<1, 1, 1, WeightDims::W>, Vec<1, 1, 1, EmbeddingDim>, OutDims,
OutShape, UnitOutDims, NumThreads, 0,
Assign<half>>::run(output, pWeight, uop_idx);
}

} // namespace ark

#endif // ARK_KERNELS_EMBEDDING_H_
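For reference, the semantics these two kernels implement, written as a host-side loop. This is a sketch for clarity only (the device code above distributes the row copy across threads via Broadcast1, and the function name here is illustrative); it mirrors the baseline_embedding() checker in ops_embedding_test.cc below.

#include <cstddef>
#include <vector>

template <typename T>
void embedding_reference(std::vector<T> &out, const std::vector<int> &ids,
                         const std::vector<T> &weight, int emb_dim)
{
    // Each input index selects one emb_dim-wide row of the weight table,
    // which is copied into the corresponding slot of the output.
    for (std::size_t i = 0; i < ids.size(); ++i) {
        const T *row = &weight[static_cast<std::size_t>(ids[i]) * emb_dim];
        for (int w = 0; w < emb_dim; ++w) {
            out[i * emb_dim + w] = row[w];
        }
    }
}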
5 changes: 3 additions & 2 deletions ark/ops/ops_all_gather_test.cc
@@ -29,10 +29,11 @@ void test_all_gather_4gpus_internal(size_t nelem, int iter)
     ark::Tensor *data = m.scale(ones, float(gpu_id + 1));
     auto outputs = m.all_gather(data, gpu_id, num_gpus);
 
+    auto ones_data = ark::utils::ones<ark::half_t>(ones->shape.size());
     auto result =
         ark::op_test("all_gather", m, {ones}, outputs,
-                     baseline_all_gather<ark::half_t, num_gpus>, "ones",
-                     true, gpu_id, num_gpus);
+                     baseline_all_gather<ark::half_t, num_gpus>,
+                     {ones_data.get()}, true, gpu_id, num_gpus);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 });
Expand Down
5 changes: 3 additions & 2 deletions ark/ops/ops_all_reduce_test.cc
@@ -34,10 +34,11 @@ void test_all_reduce_4gpus_internal(size_t nelem, int iter)
     ark::Tensor *data = m.scale(ones, float(gpu_id + 1));
     ark::Tensor *output = m.all_reduce(data, gpu_id, num_gpus);
 
+    auto ones_data = ark::utils::ones<ark::half_t>(ones->shape.size());
     auto result =
         ark::op_test("all_reduce", m, {ones}, {output},
-                     baseline_all_reduce<ark::half_t, num_gpus>, "ones",
-                     true, gpu_id, num_gpus);
+                     baseline_all_reduce<ark::half_t, num_gpus>,
+                     {ones_data.get()}, true, gpu_id, num_gpus);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 });
Expand Down
3 changes: 3 additions & 0 deletions ark/ops/ops_common.cc
@@ -118,6 +118,7 @@ ostream &operator<<(ostream &os, const OpType &s)
case OP_EXP: os << "OP_EXP"; break;
case OP_SQRT: os << "OP_SQRT"; break;
case OP_ROPE: os << "OP_ROPE"; break;
case OP_EMBEDDING: os << "OP_EMBEDDING"; break;
}
// clang-format on
return os;
@@ -505,6 +506,8 @@ std::string Op::function_name(const OpConfig &cfg) const
return static_cast<const SqrtOp *>(this)->function_name(cfg);
case OP_ROPE:
return static_cast<const RopeOp *>(this)->function_name(cfg);
case OP_EMBEDDING:
return static_cast<const EmbeddingOp *>(this)->function_name(cfg);
default:
return "";
}
9 changes: 9 additions & 0 deletions ark/ops/ops_common.h
@@ -128,6 +128,7 @@ typedef enum
OP_RECV,
OP_SEND_MM,
OP_RECV_MM,
OP_EMBEDDING,
} OpType;

/// Type of precision of @ref Op.
@@ -548,6 +549,14 @@ class TransposeOp : public Op
std::string function_name(const OpConfig &cfg) const;
};

class EmbeddingOp : public Op
{
public:
EmbeddingOp(OpPrecType prec_type, Tensor *input, Tensor *weight,
Tensor *output, const std::string &name);
std::string function_name(const OpConfig &cfg) const;
};

} // namespace ark

#endif // ARK_OPS_COMMON_H_
94 changes: 94 additions & 0 deletions ark/ops/ops_embedding.cc
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "logging.h"
#include "model.h"
#include <cassert>

namespace ark {

extern const OpConfigMap EmbeddingConfigMap;

EmbeddingOp::EmbeddingOp(OpPrecType prec_type, Tensor *input, Tensor *weight,
Tensor *output, const std::string &name)
: Op{OP_EMBEDDING, prec_type, {input, weight}, {output},
{}, name, &EmbeddingConfigMap, -1,
true}
{
}

std::string EmbeddingOp::function_name(const OpConfig &cfg) const
{
Tensor *input = this->inputs[0];
Tensor *weight = this->inputs[1];
Tensor *output = this->outputs[0];

auto in_dims = input->ldims.dims4();
auto in_shape = input->shape.dims4();

assert(in_dims[0] == 1);
assert(in_shape[0] == 1);

Dims new_in_dims{in_dims[1], in_dims[2], in_dims[3], 1};
Dims new_in_shape{in_shape[1], in_shape[2], in_shape[3], 1};

int emb_dim = weight->shape[-1];
return Op::function_name("ark::embedding",
{{
new_in_dims, // InDims
new_in_shape, // InShape
weight->ldims.dims4(), // WeightDims
weight->shape.dims4(), // WeightShape
output->ldims.dims4(), // OutDims
output->shape.dims4(), // OutShape
emb_dim, // EmbeddingDim
cfg.num_warps * 32, // NumThreads
}});
}

Tensor *Model::embedding(Tensor *input, Tensor *weight, Tensor *output,
const std::string &name)
{
assert(input != nullptr);
assert(weight != nullptr);
if (input->shape.ndims() > 3) {
LOG(ERROR, "input shape ndims > 3: ", input->shape);
}
if (weight->shape.ndims() != 2) {
LOG(ERROR, "weight shape ndims != 2: ", weight->shape);
}
OpPrecType pt = OP_PREC_NONE;
if (weight->type == FP16) {
pt = OP_PREC_FP16;
} else if (weight->type == FP32) {
pt = OP_PREC_FP32;
} else {
LOG(ERROR, "unsupported weight data type: ", weight->type);
}
auto emb_dim = weight->shape[-1];

std::vector<DimType> output_dims;
for (int i = 0; i < input->shape.ndims(); ++i) {
output_dims.push_back(input->shape[i]);
}
output_dims.push_back(emb_dim);
Dims out_shape{output_dims};
if (output == nullptr) {
output = this->tensor(out_shape, weight->type);
}
EmbeddingOp op{pt, input, weight, output, name};
return this->impl->add_op(op)[0];
}

const OpConfigMap EmbeddingConfigMap = {
{{OP_ARCH_CUDA_ANY, OP_PREC_ANY},
{
// NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
{1, 0, {{1, 1}, {1, -1}}, {{1, -1}}, true, false},
{2, 0, {{1, 1}, {1, -1}}, {{1, -1}}, true, false},
{4, 0, {{1, 1}, {1, -1}}, {{1, -1}}, true, false},
{8, 0, {{1, 1}, {1, -1}}, {{1, -1}}, true, false},
}},
};

} // namespace ark
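The output-shape rule that Model::embedding() implements above, extracted into a standalone sketch (plain vectors stand in for ark::Dims, and the helper name is hypothetical):

#include <vector>

// Output shape = input shape with the weight's last axis (emb_dim) appended,
// e.g. input (8, 3, 64) + weight (1000, 8192) -> output (8, 3, 64, 8192).
std::vector<long long> embedding_out_shape(const std::vector<long long> &in_shape,
                                           const std::vector<long long> &weight_shape)
{
    std::vector<long long> out = in_shape;
    out.push_back(weight_shape.back());
    return out;
}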
89 changes: 89 additions & 0 deletions ark/ops/ops_embedding_test.cc
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "include/ark.h"
#include "include/ark_utils.h"
#include "ops_test_common.h"
#include "unittest/unittest_utils.h"
#include <cassert>
#include <type_traits>

template <typename T>
void baseline_embedding(std::vector<void *> &outputs,
const std::vector<ark::Dims> &output_shapes,
const std::vector<void *> &inputs,
const std::vector<ark::Dims> &input_shapes)
{
T *out = static_cast<T *>(outputs[0]);
int *in = static_cast<int *>(inputs[0]);
T *weight = static_cast<T *>(inputs[1]);

ark::Dims osh = output_shapes[0].dims4();
ark::Dims wsh = input_shapes[1].dims4();

assert(osh[3] == wsh[3]);

int in_idx = 0;
for (ark::DimType n = 0; n < osh[0]; ++n) {
for (ark::DimType c = 0; c < osh[1]; ++c) {
for (ark::DimType h = 0; h < osh[2]; ++h) {
int weight_idx = in[in_idx++];
T *ptr = &weight[weight_idx * wsh[3]];
for (ark::DimType w = 0; w < osh[3]; ++w) {
out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] +
h * osh[3] + w] = ptr[w];
}
}
}
}
}

template <typename T> ark::unittest::State test_embedding()
{
const int num_emb = 1000;
const int emb_dim = 8192;

ark::TensorType weight_type;
if (std::is_same<T, float>::value) {
weight_type = ark::FP32;
} else {
weight_type = ark::FP16;
}

ark::Model m;
ark::Tensor *ti = m.tensor(ark::Dims(8, 3, 64), ark::INT32);
ark::Tensor *tw = m.tensor(ark::Dims(num_emb, emb_dim), weight_type);
ark::Tensor *to = m.embedding(ti, tw);

ark::srand();

std::vector<int> ti_data;
for (auto i = 0; i < ti->shape.size(); ++i) {
// Random indices in [0, num_emb)
ti_data.push_back(ark::rand() % num_emb);
}
auto tw_data = ark::utils::rand_array<T>(tw->shape.size(), 1.0);
    auto result =
        ark::op_test("embedding", m, {ti, tw}, {to}, baseline_embedding<T>,
                     {ti_data.data(), tw_data.get()});
ark::op_test_log(result);
return ark::unittest::SUCCESS;
}

ark::unittest::State test_embedding_fp32()
{
return test_embedding<float>();
}

ark::unittest::State test_embedding_fp16()
{
return test_embedding<ark::half_t>();
}

int main()
{
ark::init();
UNITTEST(test_embedding_fp32);
UNITTEST(test_embedding_fp16);
return ark::unittest::SUCCESS;
}
14 changes: 10 additions & 4 deletions ark/ops/ops_matmul_test.cu
@@ -504,8 +504,11 @@ ark::unittest::State test_matmul_tn()
     ark::Tensor *b = m.tensor(ark::Dims(64, 256), ark::FP16);
     ark::Tensor *c = m.matmul(a, b, nullptr, 1, true, false, "matmul", 0);
 
-    auto result = ark::op_test("matmul_tn", m, {a, b}, {c},
-                               baseline_matmul_tn<half>, "ones", true);
+    auto ones_a = ark::utils::ones<ark::half_t>(a->shape.size());
+    auto ones_b = ark::utils::ones<ark::half_t>(b->shape.size());
+    auto result =
+        ark::op_test("matmul_tn", m, {a, b}, {c}, baseline_matmul_tn<half>,
+                     {ones_a.get(), ones_b.get()}, true);
     ark::op_test_log(result);
 }
 {
@@ -553,8 +556,11 @@ ark::unittest::State test_matmul_batched()
     ark::Tensor *b = m.tensor(ark::Dims(2, 64, 64), ark::FP16);
     ark::Tensor *c = m.matmul(a, b);
 
-    auto result = ark::op_test("matmul_batched", m, {a, b}, {c},
-                               baseline_matmul_nn<half>, "ones", true);
+    auto ones_a = ark::utils::ones<ark::half_t>(a->shape.size());
+    auto ones_b = ark::utils::ones<ark::half_t>(b->shape.size());
+    auto result =
+        ark::op_test("matmul_batched", m, {a, b}, {c}, baseline_matmul_nn<half>,
+                     {ones_a.get(), ones_b.get()}, true);
     ark::op_test_log(result);
     return ark::unittest::SUCCESS;
 }