From 8da0c35079c6c9a449beef991a9aa31e228c7989 Mon Sep 17 00:00:00 2001 From: Giovanni Lenzi Baraldi Date: Tue, 25 Jun 2024 13:46:13 -0300 Subject: [PATCH] Adding wrappers on HSA for executable load/unload and allowing multiple agents per context on ATT (#951) * Codeobj wrappers around HSA calls for ATT * Formatting * Bookkeeping * Tidy * Tidy * Update source/lib/rocprofiler-sdk/thread_trace/code_object.hpp Co-authored-by: Vladimir Indic <139573562+vlaindic@users.noreply.github.com> * Update source/lib/rocprofiler-sdk/thread_trace/att_core.hpp Co-authored-by: Vladimir Indic <139573562+vlaindic@users.noreply.github.com> * Variable naming --------- Co-authored-by: Vladimir Indic <139573562+vlaindic@users.noreply.github.com> --- source/lib/rocprofiler-sdk/hsa/aql_packet.hpp | 2 +- .../rocprofiler-sdk/hsa/queue_controller.cpp | 6 +- source/lib/rocprofiler-sdk/registration.cpp | 1 + .../thread_trace/CMakeLists.txt | 5 +- .../rocprofiler-sdk/thread_trace/att_core.cpp | 124 ++++++--------- .../rocprofiler-sdk/thread_trace/att_core.hpp | 43 ++--- .../thread_trace/att_service.cpp | 82 ++++------ .../thread_trace/code_object.cpp | 147 ++++++++++++++++++ .../thread_trace/code_object.hpp | 63 ++++++++ 9 files changed, 316 insertions(+), 157 deletions(-) create mode 100644 source/lib/rocprofiler-sdk/thread_trace/code_object.cpp create mode 100644 source/lib/rocprofiler-sdk/thread_trace/code_object.hpp diff --git a/source/lib/rocprofiler-sdk/hsa/aql_packet.hpp b/source/lib/rocprofiler-sdk/hsa/aql_packet.hpp index 10004f12..e0216e2e 100644 --- a/source/lib/rocprofiler-sdk/hsa/aql_packet.hpp +++ b/source/lib/rocprofiler-sdk/hsa/aql_packet.hpp @@ -212,7 +212,7 @@ class TraceControlAQLPacket : public AQLPacket loaded_codeobj[id] = std::make_shared(*tracepool, id, addr, size, true, false); } - void remove_codeobj(code_object_id_t id) { loaded_codeobj.erase(id); } + bool remove_codeobj(code_object_id_t id) { return loaded_codeobj.erase(id) != 0; } protected: std::shared_ptr tracepool; 
diff --git a/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp b/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp index 0a778b0f..3398b97a 100644 --- a/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp +++ b/source/lib/rocprofiler-sdk/hsa/queue_controller.cpp @@ -277,15 +277,11 @@ QueueController::init(CoreApiTable& core_table, AmdExtTable& ext_table) itr->buffered_tracer->domains(ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH)); if(itr->counter_collection || itr->pc_sampler || has_kernel_tracing || - itr->agent_counter_collection) + itr->agent_counter_collection || itr->thread_trace) { enable_intercepter = true; break; } - else if(itr->thread_trace) - { - enable_intercepter = true; - } } if(enable_intercepter) diff --git a/source/lib/rocprofiler-sdk/registration.cpp b/source/lib/rocprofiler-sdk/registration.cpp index b2fc6a1e..8561c959 100644 --- a/source/lib/rocprofiler-sdk/registration.cpp +++ b/source/lib/rocprofiler-sdk/registration.cpp @@ -762,6 +762,7 @@ rocprofiler_set_api_table(const char* name, rocprofiler::hsa::async_copy_init(hsa_api_table, lib_instance); rocprofiler::code_object::initialize(hsa_api_table); + rocprofiler::thread_trace::code_object::initialize(hsa_api_table); #if ROCPROFILER_SDK_HSA_PC_SAMPLING > 0 rocprofiler::pc_sampling::code_object::initialize(hsa_api_table); #endif diff --git a/source/lib/rocprofiler-sdk/thread_trace/CMakeLists.txt b/source/lib/rocprofiler-sdk/thread_trace/CMakeLists.txt index 802ad64f..60f552d6 100644 --- a/source/lib/rocprofiler-sdk/thread_trace/CMakeLists.txt +++ b/source/lib/rocprofiler-sdk/thread_trace/CMakeLists.txt @@ -1,5 +1,6 @@ -set(ROCPROFILER_LIB_THREAD_TRACE_SOURCES att_core.cpp att_service.cpp att_parser.cpp) -set(ROCPROFILER_LIB_THREAD_TRACE_HEADERS att_core.hpp) +set(ROCPROFILER_LIB_THREAD_TRACE_SOURCES att_core.cpp att_service.cpp att_parser.cpp + code_object.cpp) +set(ROCPROFILER_LIB_THREAD_TRACE_HEADERS att_core.hpp code_object.hpp) target_sources(rocprofiler-object-library PRIVATE 
${ROCPROFILER_LIB_THREAD_TRACE_SOURCES} ${ROCPROFILER_LIB_THREAD_TRACE_HEADERS}) diff --git a/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp b/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp index a5d1c9e7..5be0eed9 100644 --- a/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp +++ b/source/lib/rocprofiler-sdk/thread_trace/att_core.cpp @@ -100,6 +100,7 @@ ThreadTracerQueue::ThreadTracerQueue(thread_trace_parameter_pack _params, const CoreApiTable& coreapi, const AmdExtTable& ext) : params(std::move(_params)) +, agent_id(cache.get_rocp_agent()->id) { factory = std::make_unique(cache, this->params, coreapi, ext); control_packet = factory->construct_control_packet(); @@ -122,6 +123,14 @@ ThreadTracerQueue::ThreadTracerQueue(thread_trace_parameter_pack _params, signal_store_screlease_fn = coreapi.hsa_signal_store_screlease_fn; add_write_index_relaxed_fn = coreapi.hsa_queue_add_write_index_relaxed_fn; load_read_index_relaxed_fn = coreapi.hsa_queue_load_read_index_relaxed_fn; + + codeobj_reg = std::make_unique( + [this](rocprofiler_agent_id_t agent, uint64_t codeobj_id, uint64_t addr, uint64_t size) { + if(agent == this->agent_id) this->load_codeobj(codeobj_id, addr, size); + }, + [this](uint64_t codeobj_id) { this->unload_codeobj(codeobj_id); }); + + codeobj_reg->IterateLoaded(); } ThreadTracerQueue::~ThreadTracerQueue() @@ -203,8 +212,7 @@ ThreadTracerQueue::unload_codeobj(code_object_id_t id) { std::unique_lock lk(trace_resources_mut); - control_packet->remove_codeobj(id); - + if(!control_packet->remove_codeobj(id)) return; if(!queue || active_traces.load() < 1) return; auto packet = factory->construct_unload_marker_packet(id); @@ -214,33 +222,6 @@ ThreadTracerQueue::unload_codeobj(code_object_id_t id) packet.release(); } -// TODO: make this a wrapper on HSA load instead of registering -void -DispatchThreadTracer::codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record, - rocprofiler_user_data_t* /* user_data */, - void* callback_data) -{ - 
if(!callback_data) return; - if(record.kind != ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT) return; - if(record.operation != ROCPROFILER_CODE_OBJECT_LOAD) return; - - auto* rec = static_cast(record.payload); - assert(rec); - - DispatchThreadTracer& tracer = *static_cast(callback_data); - auto agent = rec->hsa_agent; - - std::shared_lock lk(tracer.agents_map_mut); - - auto tracer_it = tracer.agents.find(agent); - if(tracer_it == tracer.agents.end()) return; - - if(record.phase == ROCPROFILER_CALLBACK_PHASE_LOAD) - tracer_it->second->load_codeobj(rec->code_object_id, rec->load_delta, rec->load_size); - else if(record.phase == ROCPROFILER_CALLBACK_PHASE_UNLOAD) - tracer_it->second->unload_codeobj(rec->code_object_id); -} - void DispatchThreadTracer::resource_init(const hsa::AgentCache& cache, const CoreApiTable& coreapi, @@ -377,12 +358,6 @@ void DispatchThreadTracer::start_context() { using corr_id_map_t = hsa::Queue::queue_info_session_t::external_corr_id_map_t; - if(codeobj_client_ctx.handle != 0) - { - auto status = rocprofiler_start_context(codeobj_client_ctx); - if(status != ROCPROFILER_STATUS_SUCCESS) throw std::exception(); - } - CHECK_NOTNULL(hsa::get_queue_controller())->enable_serialization(); // Only one thread should be attempting to enable/disable this context @@ -427,82 +402,75 @@ AgentThreadTracer::resource_init(const hsa::AgentCache& cache, const CoreApiTable& coreapi, const AmdExtTable& ext) { - if(cache.get_rocp_agent()->id != this->agent_id) return; + auto id = cache.get_rocp_agent()->id; + std::unique_lock lk(agent_mut); - std::unique_lock lk(mut); + if(params.find(id) == params.end()) return; - if(tracer != nullptr) + if(tracers.find(id) != tracers.end()) { - tracer->active_queues.fetch_add(1); + tracers.at(id)->active_queues.fetch_add(1); return; } - - tracer = std::make_unique(this->params, cache, coreapi, ext); + tracers.emplace(id, std::make_unique(params.at(id), cache, coreapi, ext)); } void AgentThreadTracer::resource_deinit(const 
hsa::AgentCache& cache) { - if(cache.get_rocp_agent()->id != this->agent_id) return; + auto id = cache.get_rocp_agent()->id; + std::unique_lock lk(agent_mut); - std::unique_lock lk(mut); - if(tracer == nullptr) return; + if(params.find(id) == params.end()) return; + if(tracers.find(id) == tracers.end()) return; - if(tracer->active_queues.fetch_sub(1) == 1) tracer.reset(); + auto& tracer = *tracers.at(id); + if(tracer.active_queues.fetch_sub(1) == 1) tracers.erase(id); } void AgentThreadTracer::start_context() { - std::unique_lock lk(mut); + std::unique_lock lk(agent_mut); - if(tracer == nullptr) + if(tracers.empty()) { ROCP_FATAL << "Thread trace context not present for agent!"; return; } - auto packet = tracer->get_control(true); - packet->populate_before(); + for(auto& [_, tracer] : tracers) + { + auto packet = tracer->get_control(true); + packet->populate_before(); - for(auto& start : packet->before_krn_pkt) - tracer->Submit(&start); + for(auto& start : packet->before_krn_pkt) + tracer->Submit(&start); + } } void AgentThreadTracer::stop_context() { - std::unique_lock lk(mut); - - auto packet = tracer->get_control(false); - packet->populate_after(); + std::unique_lock lk(agent_mut); - for(auto& stop : packet->after_krn_pkt) - tracer->Submit(&stop); - - rocprofiler_user_data_t userdata{.ptr = params.callback_userdata}; - tracer->iterate_data(packet->GetHandle(), userdata); -} - -void -AgentThreadTracer::codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record, - rocprofiler_user_data_t* /* user_data */, - void* callback_data) -{ - if(!callback_data) return; - if(record.kind != ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT) return; - if(record.operation != ROCPROFILER_CODE_OBJECT_LOAD) return; + if(tracers.empty()) + { + ROCP_FATAL << "Thread trace context not present for agent!"; + return; + } - auto* rec = static_cast(record.payload); - assert(rec); + for(auto& [_, tracer] : tracers) + { + auto packet = tracer->get_control(false); + 
packet->populate_after(); - AgentThreadTracer& tracer = *static_cast(callback_data); - std::unique_lock lk(tracer.mut); + for(auto& stop : packet->after_krn_pkt) + tracer->Submit(&stop); - if(record.phase == ROCPROFILER_CALLBACK_PHASE_LOAD) - tracer.tracer->load_codeobj(rec->code_object_id, rec->load_delta, rec->load_size); - else if(record.phase == ROCPROFILER_CALLBACK_PHASE_UNLOAD) - tracer.tracer->unload_codeobj(rec->code_object_id); + rocprofiler_user_data_t userdata{.ptr = tracer->params.callback_userdata}; + tracer->iterate_data(packet->GetHandle(), userdata); + } } } // namespace thread_trace diff --git a/source/lib/rocprofiler-sdk/thread_trace/att_core.hpp b/source/lib/rocprofiler-sdk/thread_trace/att_core.hpp index caee1fb4..76e6a86a 100644 --- a/source/lib/rocprofiler-sdk/thread_trace/att_core.hpp +++ b/source/lib/rocprofiler-sdk/thread_trace/att_core.hpp @@ -22,9 +22,11 @@ #pragma once +#include "lib/rocprofiler-sdk/hsa/agent_cache.hpp" +#include "lib/rocprofiler-sdk/thread_trace/code_object.hpp" + #include #include -#include "lib/rocprofiler-sdk/hsa/agent_cache.hpp" #include #include @@ -100,6 +102,10 @@ class ThreadTracerQueue bool Submit(hsa_ext_amd_aql_pm4_packet_t* packet); private: + std::unique_ptr codeobj_reg{nullptr}; + + rocprofiler_agent_id_t agent_id; + decltype(hsa_queue_load_read_index_relaxed)* load_read_index_relaxed_fn{nullptr}; decltype(hsa_queue_add_write_index_relaxed)* add_write_index_relaxed_fn{nullptr}; decltype(hsa_signal_store_screlease)* signal_store_screlease_fn{nullptr}; @@ -135,10 +141,6 @@ class DispatchThreadTracer : public ThreadTracerInterface void resource_init(const hsa::AgentCache&, const CoreApiTable&, const AmdExtTable&) override; void resource_deinit(const hsa::AgentCache&) override; - static void codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record, - rocprofiler_user_data_t* user_data, - void* callback_data); - std::unique_ptr pre_kernel_call(const hsa::Queue& queue, uint64_t kernel_id, 
rocprofiler_dispatch_id_t dispatch_id, @@ -153,16 +155,12 @@ class DispatchThreadTracer : public ThreadTracerInterface std::atomic post_move_data{0}; thread_trace_parameter_pack params; - rocprofiler_context_id_t codeobj_client_ctx{0}; }; class AgentThreadTracer : public ThreadTracerInterface { public: - AgentThreadTracer(thread_trace_parameter_pack _params, rocprofiler_agent_id_t _id) - : agent_id(_id) - , params(std::move(_params)) - {} + AgentThreadTracer() = default; ~AgentThreadTracer() override = default; void start_context() override; @@ -170,16 +168,21 @@ class AgentThreadTracer : public ThreadTracerInterface void resource_init(const hsa::AgentCache&, const CoreApiTable&, const AmdExtTable&) override; void resource_deinit(const hsa::AgentCache&) override; - static void codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record, - rocprofiler_user_data_t* user_data, - void* callback_data); - - rocprofiler_agent_id_t agent_id; - std::mutex mut; - std::unique_ptr tracer{nullptr}; - - thread_trace_parameter_pack params; - rocprofiler_context_id_t codeobj_client_ctx{0}; + void add_agent(rocprofiler_agent_id_t id, thread_trace_parameter_pack _params) + { + std::unique_lock lk(agent_mut); + params[id] = std::move(_params); + } + bool has_agent(rocprofiler_agent_id_t id) + { + std::unique_lock lk(agent_mut); + return params.find(id) != params.end(); + } + + std::map> tracers{}; + std::map params; + + std::mutex agent_mut; }; }; // namespace thread_trace diff --git a/source/lib/rocprofiler-sdk/thread_trace/att_service.cpp b/source/lib/rocprofiler-sdk/thread_trace/att_service.cpp index 60a7fe9e..36efa1e3 100644 --- a/source/lib/rocprofiler-sdk/thread_trace/att_service.cpp +++ b/source/lib/rocprofiler-sdk/thread_trace/att_service.cpp @@ -62,12 +62,12 @@ rocprofiler_configure_dispatch_thread_trace_service( if(!ctx) return ROCPROFILER_STATUS_ERROR_CONTEXT_NOT_STARTED; if(ctx->thread_trace) return ROCPROFILER_STATUS_ERROR_SERVICE_ALREADY_CONFIGURED; - auto 
param_pack = rocprofiler::thread_trace::thread_trace_parameter_pack{}; + auto pack = rocprofiler::thread_trace::thread_trace_parameter_pack{}; - param_pack.context_id = context_id; - param_pack.dispatch_cb_fn = dispatch_callback; - param_pack.shader_cb_fn = shader_callback; - param_pack.callback_userdata = callback_userdata; + pack.context_id = context_id; + pack.dispatch_cb_fn = dispatch_callback; + pack.shader_cb_fn = shader_callback; + pack.callback_userdata = callback_userdata; const auto& id_map = *CHECK_NOTNULL(rocprofiler::counters::getPerfCountersIdMap()); for(size_t p = 0; p < num_parameters; p++) @@ -78,39 +78,26 @@ rocprofiler_configure_dispatch_thread_trace_service( switch(param.type) { - case ROCPROFILER_ATT_PARAMETER_TARGET_CU: param_pack.target_cu = param.value; break; + case ROCPROFILER_ATT_PARAMETER_TARGET_CU: pack.target_cu = param.value; break; case ROCPROFILER_ATT_PARAMETER_SHADER_ENGINE_MASK: - param_pack.shader_engine_mask = param.value; + pack.shader_engine_mask = param.value; break; - case ROCPROFILER_ATT_PARAMETER_BUFFER_SIZE: param_pack.buffer_size = param.value; break; - case ROCPROFILER_ATT_PARAMETER_SIMD_SELECT: param_pack.simd_select = param.value; break; + case ROCPROFILER_ATT_PARAMETER_BUFFER_SIZE: pack.buffer_size = param.value; break; + case ROCPROFILER_ATT_PARAMETER_SIMD_SELECT: pack.simd_select = param.value; break; case ROCPROFILER_ATT_PARAMETER_PERFCOUNTER: if(const auto* metric_ptr = rocprofiler::common::get_val(id_map, param.counter_id.handle)) - param_pack.perfcounters.push_back(get_mask(metric_ptr, param.simd_mask)); + pack.perfcounters.push_back(get_mask(metric_ptr, param.simd_mask)); break; case ROCPROFILER_ATT_PARAMETER_PERFCOUNTERS_CTRL: - param_pack.perfcounter_ctrl = param.value; + pack.perfcounter_ctrl = param.value; break; case ROCPROFILER_ATT_PARAMETER_LAST: return ROCPROFILER_STATUS_ERROR_INVALID_ARGUMENT; } } - auto tracer = std::make_unique(param_pack); - - rocprofiler_status_t status = 
rocprofiler_create_context(&tracer->codeobj_client_ctx); - if(status != ROCPROFILER_STATUS_SUCCESS) return status; - - status = rocprofiler_configure_callback_tracing_service( - tracer->codeobj_client_ctx, - ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT, - nullptr, - 0, - rocprofiler::thread_trace::DispatchThreadTracer::codeobj_tracing_callback, - tracer.get()); - - ctx->thread_trace = std::move(tracer); - return status; + ctx->thread_trace = std::make_unique(pack); + return ROCPROFILER_STATUS_SUCCESS; } rocprofiler_status_t ROCPROFILER_API @@ -122,18 +109,20 @@ rocprofiler_configure_agent_thread_trace_service( rocprofiler_att_shader_data_callback_t shader_callback, void* callback_userdata) { + using AgentThreadTracer = rocprofiler::thread_trace::AgentThreadTracer; if(rocprofiler::registration::get_init_status() > -1) return ROCPROFILER_STATUS_ERROR_CONFIGURATION_LOCKED; auto* ctx = rocprofiler::context::get_mutable_registered_context(context_id); if(!ctx) return ROCPROFILER_STATUS_ERROR_CONTEXT_NOT_STARTED; - if(ctx->thread_trace) return ROCPROFILER_STATUS_ERROR_SERVICE_ALREADY_CONFIGURED; - auto param_pack = rocprofiler::thread_trace::thread_trace_parameter_pack{}; + if(!ctx->thread_trace) ctx->thread_trace = std::make_unique(); - param_pack.context_id = context_id; - param_pack.shader_cb_fn = shader_callback; - param_pack.callback_userdata = callback_userdata; + auto pack = rocprofiler::thread_trace::thread_trace_parameter_pack{}; + + pack.context_id = context_id; + pack.shader_cb_fn = shader_callback; + pack.callback_userdata = callback_userdata; const auto& id_map = *CHECK_NOTNULL(rocprofiler::counters::getPerfCountersIdMap()); for(size_t p = 0; p < num_parameters; p++) @@ -144,38 +133,29 @@ rocprofiler_configure_agent_thread_trace_service( switch(param.type) { - case ROCPROFILER_ATT_PARAMETER_TARGET_CU: param_pack.target_cu = param.value; break; + case ROCPROFILER_ATT_PARAMETER_TARGET_CU: pack.target_cu = param.value; break; case 
ROCPROFILER_ATT_PARAMETER_SHADER_ENGINE_MASK: - param_pack.shader_engine_mask = param.value; + pack.shader_engine_mask = param.value; break; - case ROCPROFILER_ATT_PARAMETER_BUFFER_SIZE: param_pack.buffer_size = param.value; break; - case ROCPROFILER_ATT_PARAMETER_SIMD_SELECT: param_pack.simd_select = param.value; break; + case ROCPROFILER_ATT_PARAMETER_BUFFER_SIZE: pack.buffer_size = param.value; break; + case ROCPROFILER_ATT_PARAMETER_SIMD_SELECT: pack.simd_select = param.value; break; case ROCPROFILER_ATT_PARAMETER_PERFCOUNTER: if(const auto* metric_ptr = rocprofiler::common::get_val(id_map, param.counter_id.handle)) - param_pack.perfcounters.push_back(get_mask(metric_ptr, param.simd_mask)); + pack.perfcounters.push_back(get_mask(metric_ptr, param.simd_mask)); break; case ROCPROFILER_ATT_PARAMETER_PERFCOUNTERS_CTRL: - param_pack.perfcounter_ctrl = param.value; + pack.perfcounter_ctrl = param.value; break; case ROCPROFILER_ATT_PARAMETER_LAST: return ROCPROFILER_STATUS_ERROR_INVALID_ARGUMENT; } } - auto tracer = std::make_unique(param_pack, agent); - - rocprofiler_status_t status = rocprofiler_create_context(&tracer->codeobj_client_ctx); - if(status != ROCPROFILER_STATUS_SUCCESS) return status; - - status = rocprofiler_configure_callback_tracing_service( - tracer->codeobj_client_ctx, - ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT, - nullptr, - 0, - rocprofiler::thread_trace::AgentThreadTracer::codeobj_tracing_callback, - tracer.get()); + auto* agent_tracer = dynamic_cast(ctx->thread_trace.get()); + if(agent_tracer == nullptr || agent_tracer->has_agent(agent)) + return ROCPROFILER_STATUS_ERROR_SERVICE_ALREADY_CONFIGURED; - ctx->thread_trace = std::move(tracer); - return status; + agent_tracer->add_agent(agent, pack); + return ROCPROFILER_STATUS_SUCCESS; } } diff --git a/source/lib/rocprofiler-sdk/thread_trace/code_object.cpp b/source/lib/rocprofiler-sdk/thread_trace/code_object.cpp new file mode 100644 index 00000000..10498fd3 --- /dev/null +++ 
b/source/lib/rocprofiler-sdk/thread_trace/code_object.cpp @@ -0,0 +1,147 @@ +// MIT License +// +// Copyright (c) 2024 ROCm Developer Tools +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "lib/rocprofiler-sdk/thread_trace/code_object.hpp" +#include "lib/rocprofiler-sdk/code_object/code_object.hpp" + +namespace rocprofiler +{ +namespace thread_trace +{ +namespace code_object +{ +std::mutex CodeobjCallbackRegistry::mut; +std::set CodeobjCallbackRegistry::all_registries{}; + +CodeobjCallbackRegistry::CodeobjCallbackRegistry(LoadCallback _ld, UnloadCallback _unld) +: ld_fn(std::move(_ld)) +, unld_fn(std::move(_unld)) +{ + std::unique_lock lg(mut); + all_registries.insert(this); +} + +CodeobjCallbackRegistry::~CodeobjCallbackRegistry() +{ + std::unique_lock lg(mut); + all_registries.erase(this); +} + +void +CodeobjCallbackRegistry::Load(rocprofiler_agent_id_t agent, + uint64_t id, + uint64_t addr, + uint64_t size) +{ + std::unique_lock lg(mut); + for(auto* reg : all_registries) + reg->ld_fn(agent, id, addr, size); +} + +void +CodeobjCallbackRegistry::Unload(uint64_t id) +{ + std::unique_lock lg(mut); + for(auto* reg : all_registries) + reg->unld_fn(id); +} + +void +CodeobjCallbackRegistry::IterateLoaded() const +{ + std::unique_lock lg(mut); + + rocprofiler::code_object::iterate_loaded_code_objects( + [&](const rocprofiler::code_object::hsa::code_object& code_object) { + const auto& data = code_object.rocp_data; + ld_fn(data.rocp_agent, data.code_object_id, data.load_delta, data.load_size); + }); +} + +namespace +{ +auto& +get_freeze_function() +{ + static decltype(::hsa_executable_freeze)* _v = nullptr; + return _v; +} + +auto& +get_destroy_function() +{ + static decltype(::hsa_executable_destroy)* _v = nullptr; + return _v; +} + +hsa_status_t +executable_freeze(hsa_executable_t executable, const char* options) +{ + // Call underlying function + hsa_status_t status = CHECK_NOTNULL(get_freeze_function())(executable, options); + if(status != HSA_STATUS_SUCCESS) return status; + + rocprofiler::code_object::iterate_loaded_code_objects( + [&](const rocprofiler::code_object::hsa::code_object& code_object) { + if(code_object.hsa_executable != 
executable) return; + + const auto& data = code_object.rocp_data; + CodeobjCallbackRegistry::Load( + data.rocp_agent, data.code_object_id, data.load_delta, data.load_size); + }); + + return HSA_STATUS_SUCCESS; +} + +hsa_status_t +executable_destroy(hsa_executable_t executable) +{ + rocprofiler::code_object::iterate_loaded_code_objects( + [&](const rocprofiler::code_object::hsa::code_object& code_object) { + if(code_object.hsa_executable == executable) + CodeobjCallbackRegistry::Unload(code_object.rocp_data.code_object_id); + }); + + // Call underlying function + return CHECK_NOTNULL(get_destroy_function())(executable); +} +} // namespace + +void +initialize(HsaApiTable* table) +{ + (void) table; + auto& core_table = *table->core_; + + get_freeze_function() = CHECK_NOTNULL(core_table.hsa_executable_freeze_fn); + get_destroy_function() = CHECK_NOTNULL(core_table.hsa_executable_destroy_fn); + core_table.hsa_executable_freeze_fn = executable_freeze; + core_table.hsa_executable_destroy_fn = executable_destroy; + LOG_IF(FATAL, get_freeze_function() == core_table.hsa_executable_freeze_fn) + << "infinite recursion"; + LOG_IF(FATAL, get_destroy_function() == core_table.hsa_executable_destroy_fn) + << "infinite recursion"; +} + +} // namespace code_object +} // namespace thread_trace +} // namespace rocprofiler diff --git a/source/lib/rocprofiler-sdk/thread_trace/code_object.hpp b/source/lib/rocprofiler-sdk/thread_trace/code_object.hpp new file mode 100644 index 00000000..2ecebf9e --- /dev/null +++ b/source/lib/rocprofiler-sdk/thread_trace/code_object.hpp @@ -0,0 +1,63 @@ +// MIT License +// +// Copyright (c) 2024 ROCm Developer Tools +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the 
Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include + +#include + +#include +#include +#include + +namespace rocprofiler +{ +namespace thread_trace +{ +namespace code_object +{ +struct CodeobjCallbackRegistry +{ + using LoadCallback = std::function; + using UnloadCallback = std::function; + + CodeobjCallbackRegistry(LoadCallback ld, UnloadCallback unld); + virtual ~CodeobjCallbackRegistry(); + + void IterateLoaded() const; + static void Load(rocprofiler_agent_id_t agent, uint64_t id, uint64_t addr, uint64_t size); + static void Unload(uint64_t id); + +private: + LoadCallback ld_fn; + UnloadCallback unld_fn; + + static std::mutex mut; + static std::set all_registries; +}; + +void +initialize(HsaApiTable* table); +} // namespace code_object +} // namespace thread_trace +} // namespace rocprofiler