Skip to content

Commit

Permalink
Sync queue and async copy on client finalizer (#950)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmadsen authored Jun 25, 2024
1 parent 27fa455 commit 62ec95e
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 28 deletions.
14 changes: 11 additions & 3 deletions source/lib/rocprofiler-sdk/hsa/async_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ active_signals::create()
if(m_signal.handle != 0) return;

// function pointer may be null during unit testing
if(get_core_table()->hsa_signal_create_fn)
if(hsa::get_hsa_ref_count() > 0 && get_core_table()->hsa_signal_create_fn)
{
ROCP_HSA_TABLE_CALL(ERROR,
get_core_table()->hsa_signal_create_fn(0, 0, nullptr, &m_signal));
Expand All @@ -252,7 +252,7 @@ active_signals::destroy()
if(m_signal.handle == 0) return;

// function pointer may be null during unit testing
if(get_core_table()->hsa_signal_destroy_fn)
if(hsa::get_hsa_ref_count() > 0 && get_core_table()->hsa_signal_destroy_fn)
{
ROCP_HSA_TABLE_CALL(ERROR, get_core_table()->hsa_signal_destroy_fn(m_signal));
m_signal.handle = 0;
Expand Down Expand Up @@ -853,11 +853,19 @@ async_copy_init(hsa_api_table_t* _orig, uint64_t _tbl_instance)
}

void
async_copy_fini()
async_copy_sync()
{
// nothing to synchronize when no active-signals object is available
if(!async_copy::get_active_signals()) return;

// delegate to the active-signals object's sync()
async_copy::get_active_signals()->sync();
}

// Finalizes async-copy tracking: when an active-signals object exists, it is
// first synced (via async_copy_sync) and then destroyed. Does nothing when no
// active-signals object is available.
void
async_copy_fini()
{
    if(async_copy::get_active_signals())
    {
        // sync must precede destroy
        async_copy_sync();
        async_copy::get_active_signals()->destroy();
    }
}
} // namespace hsa
Expand Down
3 changes: 3 additions & 0 deletions source/lib/rocprofiler-sdk/hsa/async_copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ get_ids();
void
async_copy_init(hsa_api_table_t* _orig, uint64_t _tbl_instance);

// Invokes sync() on the active async-copy signals object when one exists;
// otherwise does nothing.
void
async_copy_sync();

// Syncs (via async_copy_sync) and then destroys the active async-copy signals
// object when one exists; otherwise does nothing.
void
async_copy_fini();
} // namespace hsa
Expand Down
47 changes: 47 additions & 0 deletions source/lib/rocprofiler-sdk/hsa/hsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,31 @@ should_wrap_functor(const context::context_array_t& _contexts,
return false;
}

// Counter maintained by the hsa_init / hsa_shut_down wrappers in this file
// (incremented by hsa_init_refcnt_impl, decremented by hsa_shut_down_refcnt_impl)
// and exposed read-only through get_hsa_ref_count().
auto hsa_reference_count_value = std::atomic<int>{0};

// Replacement for the core-table hsa_init entry which mirrors the runtime's
// reference count in hsa_reference_count_value.
//
// Forwards to the hsa_init_fn of the table returned by get_core_table() and
// returns its status unchanged. The local counter is incremented only when the
// runtime reports HSA_STATUS_SUCCESS: a failed hsa_init does not add a
// reference in the runtime, so counting it here would make
// get_hsa_ref_count() report an initialized runtime when there is none.
hsa_status_t
hsa_init_refcnt_impl()
{
    const hsa_status_t _status = get_core_table()->hsa_init_fn();

    // increment strictly after the runtime call has returned so the counter
    // never becomes positive before initialization has completed
    if(_status == HSA_STATUS_SUCCESS) ++hsa_reference_count_value;

    return _status;
}

// Replacement for the core-table hsa_shut_down entry which mirrors the
// runtime's reference count in hsa_reference_count_value.
//
// When the local counter is positive it is decremented and the call is
// forwarded to the hsa_shut_down_fn of the table returned by get_core_table();
// the runtime's status is returned unchanged. When the counter is already zero
// (or less) the runtime is not called and HSA_STATUS_SUCCESS is returned.
hsa_status_t
hsa_shut_down_refcnt_impl()
{
    // The "is positive" test and the decrement must be one atomic step:
    // with a separate check and decrement, two concurrent callers could both
    // observe a count of 1, drive the counter negative, and both forward the
    // shutdown to the runtime.
    auto _count = hsa_reference_count_value.load();
    while(_count > 0)
    {
        // on failure, _count is refreshed with the current value and re-tested
        if(hsa_reference_count_value.compare_exchange_weak(_count, _count - 1))
            return get_core_table()->hsa_shut_down_fn();
    }

    return HSA_STATUS_SUCCESS;
}

template <size_t TableIdx, typename LookupT = internal_table, typename Tp, size_t OpIdx>
void
copy_table(Tp* _orig, uint64_t _tbl_instance, std::integral_constant<size_t, OpIdx>)
Expand Down Expand Up @@ -573,6 +598,20 @@ copy_table(Tp* _orig, uint64_t _tbl_instance, std::integral_constant<size_t, OpI
ROCP_TRACE << "skipping copying table entry for " << _info.name
<< " from table instance " << _tbl_instance;
}

if constexpr(TableIdx == ROCPROFILER_HSA_TABLE_ID_Core &&
OpIdx == ROCPROFILER_HSA_CORE_API_ID_hsa_init)
{
auto& _func = _info.get_table_func(_info.get_table(_orig));
_func = hsa_init_refcnt_impl;
if(get_hsa_ref_count() == 0) ++hsa_reference_count_value;
}
else if constexpr(TableIdx == ROCPROFILER_HSA_TABLE_ID_Core &&
OpIdx == ROCPROFILER_HSA_CORE_API_ID_hsa_shut_down)
{
auto& _func = _info.get_table_func(_info.get_table(_orig));
_func = hsa_shut_down_refcnt_impl;
}
}
}

Expand Down Expand Up @@ -769,5 +808,13 @@ copy_table<hsa_pc_sampling_ext_table_t>(hsa_pc_sampling_ext_table_t* _tbl, uint6

#endif
#undef INSTANTIATE_HSA_TABLE_FUNC

// Reads the current value of hsa_reference_count_value, emits it to the trace
// log, and returns it.
int
get_hsa_ref_count()
{
    const int _current = hsa_reference_count_value.load();
    ROCP_TRACE << "hsa reference count: " << _current;
    return _current;
}
} // namespace hsa
} // namespace rocprofiler
3 changes: 3 additions & 0 deletions source/lib/rocprofiler-sdk/hsa/hsa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,8 @@ copy_table(TableT* _orig, uint64_t _tbl_instance);
template <typename TableT>
void
update_table(TableT* _orig, uint64_t _tbl_instance);

// Returns the current value of the library's internal HSA reference counter
// (the value is also written to the trace log).
int
get_hsa_ref_count();
} // namespace hsa
} // namespace rocprofiler
7 changes: 7 additions & 0 deletions source/lib/rocprofiler-sdk/hsa/queue_controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ queue_controller_init(HsaApiTable* table)
CHECK_NOTNULL(get_queue_controller())->init(*table->core_, *table->amd_ext_);
}

// Calls sync() on every queue visited by the queue controller's
// iterate_queues(). Does nothing when no queue controller is available.
void
queue_controller_sync()
{
    if(!get_queue_controller()) return;

    get_queue_controller()->iterate_queues([](const Queue* _q) { _q->sync(); });
}

void
queue_controller_fini()
{
Expand Down
3 changes: 3 additions & 0 deletions source/lib/rocprofiler-sdk/hsa/queue_controller.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ queue_controller_init(HsaApiTable* table);
void
queue_controller_fini();

// Calls sync() on each queue visited by the queue controller's
// iterate_queues(); does nothing when no queue controller exists.
void
queue_controller_sync();

void
profiler_serializer_kernel_completion_signal(hsa_signal_t queue_block_signal);

Expand Down
3 changes: 3 additions & 0 deletions source/lib/rocprofiler-sdk/registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ invoke_client_finalizer(rocprofiler_client_id_t client_id)
rocprofiler_tool_finalize_t _finalize_func = nullptr;
std::swap(_finalize_func, itr->configure_result->finalize);

hsa::async_copy_sync();
hsa::queue_controller_sync();

auto _fini_status = get_fini_status();
if(_fini_status == 0) set_fini_status(-1);
_finalize_func(itr->configure_result->tool_data);
Expand Down
37 changes: 12 additions & 25 deletions source/lib/rocprofiler-sdk/tests/intercept_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,8 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing)

static auto& cb_data = get_client_callback_data();

static auto cfg_result =
rocprofiler_tool_configure_result_t{sizeof(rocprofiler_tool_configure_result_t),
tool_init,
tool_fini,
static_cast<void*>(&cb_data)};
static auto cfg_result = rocprofiler_tool_configure_result_t{
sizeof(rocprofiler_tool_configure_result_t), tool_init, tool_fini, &cb_data};

static rocprofiler_configure_func_t rocp_init =
[](uint32_t version,
Expand All @@ -251,7 +248,7 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing)
{
ROCPROFILER_CALL_EXPECT(
rocprofiler_at_intercept_table_registration(
api_registration_callback, itr, static_cast<void*>(&cb_data)),
api_registration_callback, itr, &cb_data),
"test should be updated if new (non-HSA, non-HIP) intercept table is supported",
ROCPROFILER_STATUS_SUCCESS);
}
Expand All @@ -273,9 +270,11 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing)
return status;
};

hsa_init();
hsa_init();
auto _agent_data = agent_data{};
hsa_status_t itr_status = hsa_iterate_agents(agent_cb, static_cast<void*>(&_agent_data));
hsa_status_t itr_status = hsa_iterate_agents(agent_cb, &_agent_data);
hsa_shut_down();
hsa_shut_down();

EXPECT_EQ(itr_status, HSA_STATUS_SUCCESS);
Expand All @@ -300,27 +299,18 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing)
EXPECT_EQ(itr.second.first, itr.second.second)
<< "mismatched wrap counts for " << itr.first
<< " (lhs=tool_wrapper, rhs=rocprofiler_wrapper)";
if(itr.first != "hsa_init")
{
EXPECT_GT(itr.second.first, 0) << itr.first << " not wrapped";
}
else
{
EXPECT_EQ(itr.second.first, 0) << itr.first
<< " was wrapped. If hsa runtime has been updated to "
"include first call to hsa_init, update this test";
}
EXPECT_GT(itr.second.first, 0) << itr.first << " not wrapped";
}

auto get_count = [](std::string_view func_name) {
// we already checked that first == second so we can just check first here
return cb_data.client_callback_count.at(func_name).first;
};

EXPECT_EQ(get_count("hsa_init"), 0);
EXPECT_EQ(get_count("hsa_init"), 1);
EXPECT_EQ(get_count("hsa_iterate_agents"), 1);
EXPECT_EQ(get_count("hsa_agent_get_info"), _agent_data.agent_count);
EXPECT_EQ(get_count("hsa_shut_down"), 1);
EXPECT_EQ(get_count("hsa_shut_down"), 2);
}

TEST(rocprofiler_lib, intercept_table_and_callback_tracing_disable_context)
Expand Down Expand Up @@ -392,11 +382,8 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing_disable_context)
static auto& cb_data = get_client_callback_data();
cb_data = callback_data_ext{};

static auto cfg_result =
rocprofiler_tool_configure_result_t{sizeof(rocprofiler_tool_configure_result_t),
tool_init,
tool_fini,
static_cast<void*>(&cb_data)};
static auto cfg_result = rocprofiler_tool_configure_result_t{
sizeof(rocprofiler_tool_configure_result_t), tool_init, tool_fini, &cb_data};

static rocprofiler_configure_func_t rocp_init =
[](uint32_t version,
Expand All @@ -415,7 +402,7 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing_disable_context)
{
ROCPROFILER_CALL_EXPECT(
rocprofiler_at_intercept_table_registration(
api_registration_callback, itr, static_cast<void*>(&cb_data)),
api_registration_callback, itr, &cb_data),
"test should be updated if new (non-HSA, non-HIP) intercept table is supported",
ROCPROFILER_STATUS_SUCCESS);
}
Expand Down

0 comments on commit 62ec95e

Please sign in to comment.