Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Aug 12, 2024
1 parent 802d84f commit 18a391f
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions ark/api/executor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
m.noop(tensor);

ark::DefaultExecutor executor(m, 0);
executor.launch();

UNITTEST_NE(executor.tensor_address(tensor), nullptr);

// Copy data from CPU array to ARK tensor
Expand All @@ -102,20 +102,28 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
dev_data[i] = -1;
}

ark::gpuStream stream;
UNITTEST_EQ(
ark::gpuMemcpy(dev_data.data(), dev_ptr, shape.nelems() * sizeof(float),
ark::gpuMemcpyDeviceToHost),
ark::gpuStreamCreateWithFlags(&stream, ark::gpuStreamNonBlocking),
ark::gpuSuccess);

UNITTEST_EQ(ark::gpuMemcpyAsync(dev_data.data(), dev_ptr,
shape.nelems() * sizeof(float),
ark::gpuMemcpyDeviceToHost, stream),
ark::gpuSuccess);
UNITTEST_EQ(ark::gpuStreamSynchronize(stream), ark::gpuSuccess);

for (size_t i = 0; i < dev_data.size(); ++i) {
UNITTEST_EQ(dev_data[i], static_cast<float>(i));
dev_data[i] = -1;
}

// Copy -1s back to GPU array
UNITTEST_EQ(
ark::gpuMemcpy(dev_ptr, dev_data.data(), shape.nelems() * sizeof(float),
ark::gpuMemcpyHostToDevice),
ark::gpuSuccess);
UNITTEST_EQ(ark::gpuMemcpyAsync(dev_ptr, dev_data.data(),
shape.nelems() * sizeof(float),
ark::gpuMemcpyHostToDevice, stream),
ark::gpuSuccess);
UNITTEST_EQ(ark::gpuStreamSynchronize(stream), ark::gpuSuccess);

// Copy data from GPU array to ARK tensor
executor.tensor_write(tensor, dev_ptr, shape.nelems() * sizeof(float),
Expand All @@ -131,10 +139,6 @@ ark::unittest::State test_executor_tensor_read_write(ark::Dims shape,
}

// Provide a stream
ark::gpuStream stream;
UNITTEST_EQ(
ark::gpuStreamCreateWithFlags(&stream, ark::gpuStreamNonBlocking),
ark::gpuSuccess);
executor.tensor_read(tensor, host_data.data(),
shape.nelems() * sizeof(float), stream);
executor.tensor_write(tensor, host_data.data(),
Expand Down

0 comments on commit 18a391f

Please sign in to comment.