diff --git a/source/lib/rocprofiler-sdk/hsa/async_copy.cpp b/source/lib/rocprofiler-sdk/hsa/async_copy.cpp index 5160c3eb..7b2d97a4 100644 --- a/source/lib/rocprofiler-sdk/hsa/async_copy.cpp +++ b/source/lib/rocprofiler-sdk/hsa/async_copy.cpp @@ -239,7 +239,7 @@ active_signals::create() if(m_signal.handle != 0) return; // function pointer may be null during unit testing - if(get_core_table()->hsa_signal_create_fn) + if(hsa::get_hsa_ref_count() > 0 && get_core_table()->hsa_signal_create_fn) { ROCP_HSA_TABLE_CALL(ERROR, get_core_table()->hsa_signal_create_fn(0, 0, nullptr, &m_signal)); @@ -252,7 +252,7 @@ active_signals::destroy() if(m_signal.handle == 0) return; // function pointer may be null during unit testing - if(get_core_table()->hsa_signal_destroy_fn) + if(hsa::get_hsa_ref_count() > 0 && get_core_table()->hsa_signal_destroy_fn) { ROCP_HSA_TABLE_CALL(ERROR, get_core_table()->hsa_signal_destroy_fn(m_signal)); m_signal.handle = 0; @@ -853,11 +853,19 @@ async_copy_init(hsa_api_table_t* _orig, uint64_t _tbl_instance) } void -async_copy_fini() +async_copy_sync() { if(!async_copy::get_active_signals()) return; async_copy::get_active_signals()->sync(); +} + +void +async_copy_fini() +{ + if(!async_copy::get_active_signals()) return; + + async_copy_sync(); async_copy::get_active_signals()->destroy(); } } // namespace hsa diff --git a/source/lib/rocprofiler-sdk/hsa/async_copy.hpp b/source/lib/rocprofiler-sdk/hsa/async_copy.hpp index 1d549971..782bd7e8 100644 --- a/source/lib/rocprofiler-sdk/hsa/async_copy.hpp +++ b/source/lib/rocprofiler-sdk/hsa/async_copy.hpp @@ -48,6 +48,9 @@ get_ids(); void async_copy_init(hsa_api_table_t* _orig, uint64_t _tbl_instance); +void +async_copy_sync(); + void async_copy_fini(); } // namespace hsa diff --git a/source/lib/rocprofiler-sdk/hsa/hsa.cpp b/source/lib/rocprofiler-sdk/hsa/hsa.cpp index 9381192f..694ff671 100644 --- a/source/lib/rocprofiler-sdk/hsa/hsa.cpp +++ b/source/lib/rocprofiler-sdk/hsa/hsa.cpp @@ -536,6 +536,31 @@ should_wrap_functor(const context::context_array_t& _contexts, return false; } +auto hsa_reference_count_value = std::atomic{0}; + +hsa_status_t +hsa_init_refcnt_impl() +{ + struct scoped_dtor + { + scoped_dtor() = default; + ~scoped_dtor() { ++hsa_reference_count_value; } + }; + auto _dtor = scoped_dtor{}; + return get_core_table()->hsa_init_fn(); +} + +hsa_status_t +hsa_shut_down_refcnt_impl() +{ + if(hsa_reference_count_value > 0) + { + --hsa_reference_count_value; + return get_core_table()->hsa_shut_down_fn(); + } + return HSA_STATUS_SUCCESS; +} + template void copy_table(Tp* _orig, uint64_t _tbl_instance, std::integral_constant) @@ -573,6 +598,20 @@ copy_table(Tp* _orig, uint64_t _tbl_instance, std::integral_constant(hsa_pc_sampling_ext_table_t* _tbl, uint6 #endif #undef INSTANTIATE_HSA_TABLE_FUNC + +int +get_hsa_ref_count() +{ + auto _val = hsa_reference_count_value.load(); + ROCP_TRACE << "hsa reference count: " << _val; + return _val; +} } // namespace hsa } // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/hsa/hsa.hpp b/source/lib/rocprofiler-sdk/hsa/hsa.hpp index 52aa3a67..eb76a49a 100644 --- a/source/lib/rocprofiler-sdk/hsa/hsa.hpp +++ b/source/lib/rocprofiler-sdk/hsa/hsa.hpp @@ -176,5 +176,8 @@ copy_table(TableT* _orig, uint64_t _tbl_instance); template void update_table(TableT* _orig, uint64_t _tbl_instance); + +int +get_hsa_ref_count(); } // namespace hsa } // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp b/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp index b3d284ce..83f2815c 100644 --- a/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp +++ b/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp @@ -392,6 +392,13 @@ queue_controller_init(HsaApiTable* table) CHECK_NOTNULL(get_queue_controller())->init(*table->core_, *table->amd_ext_); } +void +queue_controller_sync() +{ + if(get_queue_controller()) + get_queue_controller()->iterate_queues([](const Queue* _queue) { _queue->sync(); }); +} + void queue_controller_fini() { diff --git a/source/lib/rocprofiler-sdk/hsa/queue_controller.hpp b/source/lib/rocprofiler-sdk/hsa/queue_controller.hpp index 59a14f5c..eb0c7423 100644 --- a/source/lib/rocprofiler-sdk/hsa/queue_controller.hpp +++ b/source/lib/rocprofiler-sdk/hsa/queue_controller.hpp @@ -121,6 +121,9 @@ queue_controller_init(HsaApiTable* table); void queue_controller_fini(); +void +queue_controller_sync(); + void profiler_serializer_kernel_completion_signal(hsa_signal_t queue_block_signal); diff --git a/source/lib/rocprofiler-sdk/registration.cpp b/source/lib/rocprofiler-sdk/registration.cpp index 76251c97..b2fc6a1e 100644 --- a/source/lib/rocprofiler-sdk/registration.cpp +++ b/source/lib/rocprofiler-sdk/registration.cpp @@ -486,6 +486,9 @@ invoke_client_finalizer(rocprofiler_client_id_t client_id) rocprofiler_tool_finalize_t _finalize_func = nullptr; std::swap(_finalize_func, itr->configure_result->finalize); + hsa::async_copy_sync(); + hsa::queue_controller_sync(); + auto _fini_status = get_fini_status(); if(_fini_status == 0) set_fini_status(-1); _finalize_func(itr->configure_result->tool_data); diff --git a/source/lib/rocprofiler-sdk/tests/intercept_table.cpp b/source/lib/rocprofiler-sdk/tests/intercept_table.cpp index 2ab135f3..311b7eb4 100644 --- a/source/lib/rocprofiler-sdk/tests/intercept_table.cpp +++ b/source/lib/rocprofiler-sdk/tests/intercept_table.cpp @@ -228,11 +228,8 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing) static auto& cb_data = get_client_callback_data(); - static auto cfg_result = - rocprofiler_tool_configure_result_t{sizeof(rocprofiler_tool_configure_result_t), - tool_init, - tool_fini, - static_cast(&cb_data)}; + static auto cfg_result = rocprofiler_tool_configure_result_t{ + sizeof(rocprofiler_tool_configure_result_t), tool_init, tool_fini, &cb_data}; static rocprofiler_configure_func_t rocp_init = [](uint32_t version, @@ -251,7 +248,7 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing) { ROCPROFILER_CALL_EXPECT( rocprofiler_at_intercept_table_registration( - api_registration_callback, itr, static_cast(&cb_data)), + api_registration_callback, itr, &cb_data), "test should be updated if new (non-HSA, non-HIP) intercept table is supported", ROCPROFILER_STATUS_SUCCESS); } @@ -273,9 +270,11 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing) return status; }; + hsa_init(); hsa_init(); auto _agent_data = agent_data{}; - hsa_status_t itr_status = hsa_iterate_agents(agent_cb, static_cast(&_agent_data)); + hsa_status_t itr_status = hsa_iterate_agents(agent_cb, &_agent_data); + hsa_shut_down(); hsa_shut_down(); EXPECT_EQ(itr_status, HSA_STATUS_SUCCESS); @@ -300,16 +299,7 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing) EXPECT_EQ(itr.second.first, itr.second.second) << "mismatched wrap counts for " << itr.first << " (lhs=tool_wrapper, rhs=rocprofiler_wrapper)"; - if(itr.first != "hsa_init") - { - EXPECT_GT(itr.second.first, 0) << itr.first << " not wrapped"; - } - else - { - EXPECT_EQ(itr.second.first, 0) << itr.first - << " was wrapped. If hsa runtime has been updated to " - "include first call to hsa_init, update this test"; - } + EXPECT_GT(itr.second.first, 0) << itr.first << " not wrapped"; } auto get_count = [](std::string_view func_name) { @@ -317,10 +307,10 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing) return cb_data.client_callback_count.at(func_name).first; }; - EXPECT_EQ(get_count("hsa_init"), 0); + EXPECT_EQ(get_count("hsa_init"), 1); EXPECT_EQ(get_count("hsa_iterate_agents"), 1); EXPECT_EQ(get_count("hsa_agent_get_info"), _agent_data.agent_count); - EXPECT_EQ(get_count("hsa_shut_down"), 1); + EXPECT_EQ(get_count("hsa_shut_down"), 2); } TEST(rocprofiler_lib, intercept_table_and_callback_tracing_disable_context) @@ -392,11 +382,8 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing_disable_context) static auto& cb_data = get_client_callback_data(); cb_data = callback_data_ext{}; - static auto cfg_result = - rocprofiler_tool_configure_result_t{sizeof(rocprofiler_tool_configure_result_t), - tool_init, - tool_fini, - static_cast(&cb_data)}; + static auto cfg_result = rocprofiler_tool_configure_result_t{ + sizeof(rocprofiler_tool_configure_result_t), tool_init, tool_fini, &cb_data}; static rocprofiler_configure_func_t rocp_init = [](uint32_t version, @@ -415,7 +402,7 @@ TEST(rocprofiler_lib, intercept_table_and_callback_tracing_disable_context) { ROCPROFILER_CALL_EXPECT( rocprofiler_at_intercept_table_registration( - api_registration_callback, itr, static_cast(&cb_data)), + api_registration_callback, itr, &cb_data), "test should be updated if new (non-HSA, non-HIP) intercept table is supported", ROCPROFILER_STATUS_SUCCESS); }