Skip to content

Commit

Permalink
Merge branch 'main' into chhwang/parameter
Browse files (browse the repository at this point in the history)
  • Loading branch information
chhwang authored Sep 19, 2023
2 parents: 76081b8 + e508ef5; commit: e38cf00
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 11 deletions.
5 changes: 3 additions & 2 deletions ark/gpu/gpu_mem.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,8 +51,6 @@ static int mem_expose(ExposalInfo *info, GpuPtr addr, uint64_t bytes)
LOG(ERROR, "gpumem driver is not loaded");
}

int flag = 1;
CULOG(cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, addr));
// Convert virtual into physical address.
int fd = open(GPUMEM_DRIVER_PATH, O_RDWR, 0);
if (fd < 0) {
Expand Down Expand Up @@ -163,6 +161,9 @@ void GpuMem::init(size_t bytes, bool expose)
addr_ =
(CUdeviceptr)(((uint64_t)raw_addr_ + GPU_PAGE_OFFSET) & GPU_PAGE_MASK);

int one = 1;
CULOG(cuPointerSetAttribute(&one, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, addr_));

ExposalInfo exp_info;
if (expose) {
int err = mem_expose(&exp_info, addr_, bytes + GPU_PAGE_SIZE);
Expand Down
4 changes: 2 additions & 2 deletions ark/gpu/gpu_mgr.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -364,15 +364,15 @@ void GpuMgrCtx::reg_sendrecv(int sid, int remote_gpu_id, size_t bytes,
}

//
void GpuMgrCtx::freeze()
void GpuMgrCtx::freeze(bool expose)
{
//
this->gpu_mgr->validate_total_bytes();

//
if (total_bytes > 0) {
LOG(INFO, "Allocating ", total_bytes, " bytes of GPU memory");
this->data_mem.init(total_bytes, false);
this->data_mem.init(total_bytes, expose);
// init the data mem
CULOG(cuMemsetD32(this->data_mem.ref(), 0, total_bytes >> 2));
}
Expand Down
2 changes: 1 addition & 1 deletion ark/gpu/gpu_mgr.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -113,7 +113,7 @@ class GpuMgrCtx
void mem_export(GpuBuf *buf, size_t offset, int sid);
GpuBuf *mem_import(size_t bytes, int sid, int gpu_id);
void reg_sendrecv(int sid, int gpu_dst, std::size_t bytes, bool is_recv);
void freeze();
void freeze(bool expose = false);
// void send(int sid, int rank, size_t bytes);
GpuState set_current();
int get_world_size() const
Expand Down
4 changes: 2 additions & 2 deletions ark/gpu/gpu_mgr_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -144,7 +144,7 @@ unittest::State test_gpu_mgr_remote()
GpuBuf *gpu1_eid5 = ctx->mem_import(sizeof(int), 5, 1);
GpuBuf *gpu1_eid6 = ctx->mem_import(sizeof(int), 6, 1);

ctx->freeze();
ctx->freeze(true);

volatile int *ptr = (volatile int *)gpu0_eid3->href();
while (*ptr != 7890) {
Expand Down Expand Up @@ -176,7 +176,7 @@ unittest::State test_gpu_mgr_remote()
GpuBuf *gpu0_eid3 = ctx->mem_import(sizeof(int), 3, 0);
GpuBuf *gpu0_eid4 = ctx->mem_import(sizeof(int), 4, 0);

ctx->freeze();
ctx->freeze(true);

gpu_memset(gpu0_eid3, 7890, 1);

Expand Down
3 changes: 2 additions & 1 deletion ark/ops/ops_all_reduce_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,8 +38,9 @@ void test_all_reduce_4gpus_internal(size_t nelem, int iter)
auto result =
ark::op_test("all_reduce", m, {ones}, {output},
baseline_all_reduce<ark::half_t, num_gpus>,
{ones_data.get()}, true, gpu_id, num_gpus);
{ones_data.get()}, false, gpu_id, num_gpus);
ark::op_test_log(result);
UNITTEST_EQ(result.max_diff[0], 0.0f);
return ark::unittest::SUCCESS;
});
}
Expand Down
20 changes: 17 additions & 3 deletions ark/ops/ops_sendrecv_test.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,8 +8,6 @@
#include "logging.h"
#include "unittest/unittest_utils.h"

using namespace std;

void test_sendrecv_internal()
{
for (int gpu_id = 0; gpu_id < 2; ++gpu_id) {
Expand All @@ -18,7 +16,7 @@ void test_sendrecv_internal()
ark::Model model{gpu_id};
ark::Tensor *tns_x = model.tensor({1024}, ark::FP16);
if (gpu_id == 0) {
model.send(tns_x, 0, 1, 1024);
model.send(tns_x, 0, 1, tns_x->shape_bytes());
model.send_done(tns_x, 0, 1);
}
if (gpu_id == 1) {
Expand All @@ -28,6 +26,13 @@ void test_sendrecv_internal()
ark::Executor exe{gpu_id, 2, model, "test_sendrecv"};
exe.compile();

if (gpu_id == 0) {
std::vector<ark::half_t> data(1024);
for (int i = 0; i < 1024; ++i) {
data[i] = ark::half_t(i + 1);
}
tns_x->write(data.data());
}
exe.launch();
exe.run(1);
exe.stop();
Expand All @@ -36,6 +41,15 @@ void test_sendrecv_internal()
ark::IpcAllGather barrier{"test_sendrecv_barrier", gpu_id, 2, tmp,
sizeof(int)};
barrier.sync();

if (gpu_id == 1) {
std::vector<ark::half_t> data(1024);
tns_x->read(data.data());
for (int i = 0; i < 1024; ++i) {
UNITTEST_EQ(data[i], ark::half_t(i + 1));
}
}

return ark::unittest::SUCCESS;
});
}
Expand Down

0 comments on commit e38cf00

Please sign in to comment.