Skip to content

Commit

Permalink
Adding wrappers on HSA for executable load/unload and allowing multip…
Browse files Browse the repository at this point in the history
…le agents per context on ATT (#951)

* Codeobj wrappers around HSA calls for ATT

* Formatting

* Bookeeping

* Tidy

* Tidy

* Update source/lib/rocprofiler-sdk/thread_trace/code_object.hpp

Co-authored-by: Vladimir Indic <[email protected]>

* Update source/lib/rocprofiler-sdk/thread_trace/att_core.hpp

Co-authored-by: Vladimir Indic <[email protected]>

* Variable naming

---------

Co-authored-by: Vladimir Indic <[email protected]>
  • Loading branch information
ApoKalipse-V and vlaindic authored Jun 25, 2024
1 parent b62ba5f commit 8da0c35
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 157 deletions.
2 changes: 1 addition & 1 deletion source/lib/rocprofiler-sdk/hsa/aql_packet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class TraceControlAQLPacket : public AQLPacket
loaded_codeobj[id] =
std::make_shared<CodeobjMarkerAQLPacket>(*tracepool, id, addr, size, true, false);
}
void remove_codeobj(code_object_id_t id) { loaded_codeobj.erase(id); }
bool remove_codeobj(code_object_id_t id) { return loaded_codeobj.erase(id) != 0; }

protected:
std::shared_ptr<TraceMemoryPool> tracepool;
Expand Down
6 changes: 1 addition & 5 deletions source/lib/rocprofiler-sdk/hsa/queue_controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,11 @@ QueueController::init(CoreApiTable& core_table, AmdExtTable& ext_table)
itr->buffered_tracer->domains(ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH));

if(itr->counter_collection || itr->pc_sampler || has_kernel_tracing ||
itr->agent_counter_collection)
itr->agent_counter_collection || itr->thread_trace)
{
enable_intercepter = true;
break;
}
else if(itr->thread_trace)
{
enable_intercepter = true;
}
}

if(enable_intercepter)
Expand Down
1 change: 1 addition & 0 deletions source/lib/rocprofiler-sdk/registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ rocprofiler_set_api_table(const char* name,

rocprofiler::hsa::async_copy_init(hsa_api_table, lib_instance);
rocprofiler::code_object::initialize(hsa_api_table);
rocprofiler::thread_trace::code_object::initialize(hsa_api_table);
#if ROCPROFILER_SDK_HSA_PC_SAMPLING > 0
rocprofiler::pc_sampling::code_object::initialize(hsa_api_table);
#endif
Expand Down
5 changes: 3 additions & 2 deletions source/lib/rocprofiler-sdk/thread_trace/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(ROCPROFILER_LIB_THREAD_TRACE_SOURCES att_core.cpp att_service.cpp att_parser.cpp)
set(ROCPROFILER_LIB_THREAD_TRACE_HEADERS att_core.hpp)
set(ROCPROFILER_LIB_THREAD_TRACE_SOURCES att_core.cpp att_service.cpp att_parser.cpp
code_object.cpp)
set(ROCPROFILER_LIB_THREAD_TRACE_HEADERS att_core.hpp code_object.hpp)
target_sources(rocprofiler-object-library PRIVATE ${ROCPROFILER_LIB_THREAD_TRACE_SOURCES}
${ROCPROFILER_LIB_THREAD_TRACE_HEADERS})

Expand Down
124 changes: 46 additions & 78 deletions source/lib/rocprofiler-sdk/thread_trace/att_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ ThreadTracerQueue::ThreadTracerQueue(thread_trace_parameter_pack _params,
const CoreApiTable& coreapi,
const AmdExtTable& ext)
: params(std::move(_params))
, agent_id(cache.get_rocp_agent()->id)
{
factory = std::make_unique<aql::ThreadTraceAQLPacketFactory>(cache, this->params, coreapi, ext);
control_packet = factory->construct_control_packet();
Expand All @@ -122,6 +123,14 @@ ThreadTracerQueue::ThreadTracerQueue(thread_trace_parameter_pack _params,
signal_store_screlease_fn = coreapi.hsa_signal_store_screlease_fn;
add_write_index_relaxed_fn = coreapi.hsa_queue_add_write_index_relaxed_fn;
load_read_index_relaxed_fn = coreapi.hsa_queue_load_read_index_relaxed_fn;

codeobj_reg = std::make_unique<code_object::CodeobjCallbackRegistry>(
[this](rocprofiler_agent_id_t agent, uint64_t codeobj_id, uint64_t addr, uint64_t size) {
if(agent == this->agent_id) this->load_codeobj(codeobj_id, addr, size);
},
[this](uint64_t codeobj_id) { this->unload_codeobj(codeobj_id); });

codeobj_reg->IterateLoaded();
}

ThreadTracerQueue::~ThreadTracerQueue()
Expand Down Expand Up @@ -203,8 +212,7 @@ ThreadTracerQueue::unload_codeobj(code_object_id_t id)
{
std::unique_lock<std::mutex> lk(trace_resources_mut);

control_packet->remove_codeobj(id);

if(!control_packet->remove_codeobj(id)) return;
if(!queue || active_traces.load() < 1) return;

auto packet = factory->construct_unload_marker_packet(id);
Expand All @@ -214,33 +222,6 @@ ThreadTracerQueue::unload_codeobj(code_object_id_t id)
packet.release();
}

// TODO: make this a wrapper on HSA load instead of registering
void
DispatchThreadTracer::codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record,
rocprofiler_user_data_t* /* user_data */,
void* callback_data)
{
if(!callback_data) return;
if(record.kind != ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT) return;
if(record.operation != ROCPROFILER_CODE_OBJECT_LOAD) return;

auto* rec = static_cast<rocprofiler_callback_tracing_code_object_load_data_t*>(record.payload);
assert(rec);

DispatchThreadTracer& tracer = *static_cast<DispatchThreadTracer*>(callback_data);
auto agent = rec->hsa_agent;

std::shared_lock<std::shared_mutex> lk(tracer.agents_map_mut);

auto tracer_it = tracer.agents.find(agent);
if(tracer_it == tracer.agents.end()) return;

if(record.phase == ROCPROFILER_CALLBACK_PHASE_LOAD)
tracer_it->second->load_codeobj(rec->code_object_id, rec->load_delta, rec->load_size);
else if(record.phase == ROCPROFILER_CALLBACK_PHASE_UNLOAD)
tracer_it->second->unload_codeobj(rec->code_object_id);
}

void
DispatchThreadTracer::resource_init(const hsa::AgentCache& cache,
const CoreApiTable& coreapi,
Expand Down Expand Up @@ -377,12 +358,6 @@ void
DispatchThreadTracer::start_context()
{
using corr_id_map_t = hsa::Queue::queue_info_session_t::external_corr_id_map_t;
if(codeobj_client_ctx.handle != 0)
{
auto status = rocprofiler_start_context(codeobj_client_ctx);
if(status != ROCPROFILER_STATUS_SUCCESS) throw std::exception();
}

CHECK_NOTNULL(hsa::get_queue_controller())->enable_serialization();

// Only one thread should be attempting to enable/disable this context
Expand Down Expand Up @@ -427,82 +402,75 @@ AgentThreadTracer::resource_init(const hsa::AgentCache& cache,
const CoreApiTable& coreapi,
const AmdExtTable& ext)
{
if(cache.get_rocp_agent()->id != this->agent_id) return;
auto id = cache.get_rocp_agent()->id;
std::unique_lock<std::mutex> lk(agent_mut);

std::unique_lock<std::mutex> lk(mut);
if(params.find(id) == params.end()) return;

if(tracer != nullptr)
if(tracers.find(id) != tracers.end())
{
tracer->active_queues.fetch_add(1);
tracers.at(id)->active_queues.fetch_add(1);
return;
}

tracer = std::make_unique<ThreadTracerQueue>(this->params, cache, coreapi, ext);
tracers.emplace(id, std::make_unique<ThreadTracerQueue>(params.at(id), cache, coreapi, ext));
}

void
AgentThreadTracer::resource_deinit(const hsa::AgentCache& cache)
{
if(cache.get_rocp_agent()->id != this->agent_id) return;
auto id = cache.get_rocp_agent()->id;
std::unique_lock<std::mutex> lk(agent_mut);

std::unique_lock<std::mutex> lk(mut);
if(tracer == nullptr) return;
if(params.find(id) == params.end()) return;
if(tracers.find(id) == tracers.end()) return;

if(tracer->active_queues.fetch_sub(1) == 1) tracer.reset();
auto& tracer = *tracers.at(id);
if(tracer.active_queues.fetch_sub(1) == 1) tracers.erase(id);
}

void
AgentThreadTracer::start_context()
{
std::unique_lock<std::mutex> lk(mut);
std::unique_lock<std::mutex> lk(agent_mut);

if(tracer == nullptr)
if(tracers.empty())
{
ROCP_FATAL << "Thread trace context not present for agent!";
return;
}

auto packet = tracer->get_control(true);
packet->populate_before();
for(auto& [_, tracer] : tracers)
{
auto packet = tracer->get_control(true);
packet->populate_before();

for(auto& start : packet->before_krn_pkt)
tracer->Submit(&start);
for(auto& start : packet->before_krn_pkt)
tracer->Submit(&start);
}
}

void
AgentThreadTracer::stop_context()
{
std::unique_lock<std::mutex> lk(mut);

auto packet = tracer->get_control(false);
packet->populate_after();
std::unique_lock<std::mutex> lk(agent_mut);

for(auto& stop : packet->after_krn_pkt)
tracer->Submit(&stop);

rocprofiler_user_data_t userdata{.ptr = params.callback_userdata};
tracer->iterate_data(packet->GetHandle(), userdata);
}

void
AgentThreadTracer::codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record,
rocprofiler_user_data_t* /* user_data */,
void* callback_data)
{
if(!callback_data) return;
if(record.kind != ROCPROFILER_CALLBACK_TRACING_CODE_OBJECT) return;
if(record.operation != ROCPROFILER_CODE_OBJECT_LOAD) return;
if(tracers.empty())
{
ROCP_FATAL << "Thread trace context not present for agent!";
return;
}

auto* rec = static_cast<rocprofiler_callback_tracing_code_object_load_data_t*>(record.payload);
assert(rec);
for(auto& [_, tracer] : tracers)
{
auto packet = tracer->get_control(false);
packet->populate_after();

AgentThreadTracer& tracer = *static_cast<AgentThreadTracer*>(callback_data);
std::unique_lock<std::mutex> lk(tracer.mut);
for(auto& stop : packet->after_krn_pkt)
tracer->Submit(&stop);

if(record.phase == ROCPROFILER_CALLBACK_PHASE_LOAD)
tracer.tracer->load_codeobj(rec->code_object_id, rec->load_delta, rec->load_size);
else if(record.phase == ROCPROFILER_CALLBACK_PHASE_UNLOAD)
tracer.tracer->unload_codeobj(rec->code_object_id);
rocprofiler_user_data_t userdata{.ptr = tracer->params.callback_userdata};
tracer->iterate_data(packet->GetHandle(), userdata);
}
}

} // namespace thread_trace
Expand Down
43 changes: 23 additions & 20 deletions source/lib/rocprofiler-sdk/thread_trace/att_core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@

#pragma once

#include "lib/rocprofiler-sdk/hsa/agent_cache.hpp"
#include "lib/rocprofiler-sdk/thread_trace/code_object.hpp"

#include <rocprofiler-sdk/cxx/hash.hpp>
#include <rocprofiler-sdk/cxx/operators.hpp>
#include "lib/rocprofiler-sdk/hsa/agent_cache.hpp"

#include <rocprofiler-sdk/amd_detail/thread_trace.h>
#include <rocprofiler-sdk/intercept_table.h>
Expand Down Expand Up @@ -100,6 +102,10 @@ class ThreadTracerQueue
bool Submit(hsa_ext_amd_aql_pm4_packet_t* packet);

private:
std::unique_ptr<code_object::CodeobjCallbackRegistry> codeobj_reg{nullptr};

rocprofiler_agent_id_t agent_id;

decltype(hsa_queue_load_read_index_relaxed)* load_read_index_relaxed_fn{nullptr};
decltype(hsa_queue_add_write_index_relaxed)* add_write_index_relaxed_fn{nullptr};
decltype(hsa_signal_store_screlease)* signal_store_screlease_fn{nullptr};
Expand Down Expand Up @@ -135,10 +141,6 @@ class DispatchThreadTracer : public ThreadTracerInterface
void resource_init(const hsa::AgentCache&, const CoreApiTable&, const AmdExtTable&) override;
void resource_deinit(const hsa::AgentCache&) override;

static void codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record,
rocprofiler_user_data_t* user_data,
void* callback_data);

std::unique_ptr<hsa::AQLPacket> pre_kernel_call(const hsa::Queue& queue,
uint64_t kernel_id,
rocprofiler_dispatch_id_t dispatch_id,
Expand All @@ -153,33 +155,34 @@ class DispatchThreadTracer : public ThreadTracerInterface
std::atomic<int> post_move_data{0};

thread_trace_parameter_pack params;
rocprofiler_context_id_t codeobj_client_ctx{0};
};

class AgentThreadTracer : public ThreadTracerInterface
{
public:
AgentThreadTracer(thread_trace_parameter_pack _params, rocprofiler_agent_id_t _id)
: agent_id(_id)
, params(std::move(_params))
{}
AgentThreadTracer() = default;
~AgentThreadTracer() override = default;

void start_context() override;
void stop_context() override;
void resource_init(const hsa::AgentCache&, const CoreApiTable&, const AmdExtTable&) override;
void resource_deinit(const hsa::AgentCache&) override;

static void codeobj_tracing_callback(rocprofiler_callback_tracing_record_t record,
rocprofiler_user_data_t* user_data,
void* callback_data);

rocprofiler_agent_id_t agent_id;
std::mutex mut;
std::unique_ptr<ThreadTracerQueue> tracer{nullptr};

thread_trace_parameter_pack params;
rocprofiler_context_id_t codeobj_client_ctx{0};
void add_agent(rocprofiler_agent_id_t id, thread_trace_parameter_pack _params)
{
std::unique_lock<std::mutex> lk(agent_mut);
params[id] = std::move(_params);
}
bool has_agent(rocprofiler_agent_id_t id)
{
std::unique_lock<std::mutex> lk(agent_mut);
return params.find(id) != params.end();
}

std::map<rocprofiler_agent_id_t, std::unique_ptr<ThreadTracerQueue>> tracers{};
std::map<rocprofiler_agent_id_t, thread_trace_parameter_pack> params;

std::mutex agent_mut;
};

}; // namespace thread_trace
Expand Down
Loading

0 comments on commit 8da0c35

Please sign in to comment.