Skip to content

Commit

Permalink
Merge branch 'main' into chhwang/bin-caching
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Oct 16, 2023
2 parents 290eee8 + fdc76bd commit fee2c96
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
4 changes: 4 additions & 0 deletions ark/include/kernels/embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ DEVICE void embedding(DataType *output, int *input, DataType *weight,

// pWeight: Vec<1, 1, 1, EmbeddingDim>
int emb_idx = input[un * InDims::CH + uc * InDims::H + uh];
if (emb_idx < 0) {
emb_idx += WeightShape::H;
}
// TODO: assert if emb_idx is still negative
DataType *pWeight = &weight[emb_idx * WeightDims::W];

Broadcast1<Vec<1, 1, 1, WeightDims::W>, Vec<1, 1, 1, EmbeddingDim>, OutDims,
Expand Down
10 changes: 9 additions & 1 deletion ark/ops/ops_embedding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ void baseline_embedding(std::vector<void *> &outputs,
for (ark::DimType c = 0; c < osh[1]; ++c) {
for (ark::DimType h = 0; h < osh[2]; ++h) {
int weight_idx = in[in_idx++];
if (weight_idx < 0) {
weight_idx += wsh[2];
}
T *ptr = &weight[weight_idx * wsh[3]];
for (ark::DimType w = 0; w < osh[3]; ++w) {
out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] +
Expand Down Expand Up @@ -60,7 +63,12 @@ ark::unittest::State test_embedding() {
std::vector<int> ti_data;
for (auto i = 0; i < ti->shape.size(); ++i) {
// Random indices in [0, num_emb)
ti_data.push_back(ark::rand() % num_emb);
int rand_idx = ark::rand() % num_emb;
if (i % 9 == 0) {
// test negative tokens (padding)
rand_idx = -rand_idx;
}
ti_data.push_back(rand_idx);
}
auto tw_data = ark::utils::rand_array<T>(tw->shape.size(), 1.0);
auto result =
Expand Down

0 comments on commit fee2c96

Please sign in to comment.