Skip to content

Commit

Permalink
SWDEV-489158: Adding consumer+producer model to AST evaluation (#13)
Browse files Browse the repository at this point in the history
* Rebased optimizations for rocprofv3 tool

* Fixing merge conflicts

* Formatting

* Open from within mutex

* Small name changes

* Added operator

* removed some parameters

* Optimizing counter collection

* Re-arrange code

* Adding back dimension query

* Formatting

* Update source/lib/rocprofiler-sdk/thread_trace/att_core.cpp

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Formatting 2

* Fix for test compilation

* Fix for yield

* Adding back check for zero

* Improved thread handling

* Formatting

* Remove automatic start

* Adding test

* Small fixes

* Adding lock for buffer callbacks

* Fix for race condition in AST

* Adding check for ptr

---------

Co-authored-by: Giovanni Baraldi <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 5, 2024
1 parent c42bdc3 commit b7661bc
Show file tree
Hide file tree
Showing 15 changed files with 577 additions and 153 deletions.
23 changes: 19 additions & 4 deletions source/lib/rocprofiler-sdk/counters/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
set(ROCPROFILER_LIB_COUNTERS_SOURCES
metrics.cpp dimensions.cpp evaluate_ast.cpp core.cpp id_decode.cpp
dispatch_handlers.cpp controller.cpp device_counting.cpp)
metrics.cpp
dimensions.cpp
evaluate_ast.cpp
core.cpp
id_decode.cpp
dispatch_handlers.cpp
sample_processing.cpp
controller.cpp
device_counting.cpp)
set(ROCPROFILER_LIB_COUNTERS_HEADERS
metrics.hpp dimensions.hpp evaluate_ast.hpp core.hpp id_decode.hpp
dispatch_handlers.hpp controller.hpp device_counting.hpp)
metrics.hpp
dimensions.hpp
evaluate_ast.hpp
core.hpp
id_decode.hpp
dispatch_handlers.hpp
sample_processing.hpp
controller.hpp
device_counting.hpp
sample_consumer.hpp)
target_sources(rocprofiler-sdk-object-library PRIVATE ${ROCPROFILER_LIB_COUNTERS_SOURCES}
${ROCPROFILER_LIB_COUNTERS_HEADERS})

Expand Down
17 changes: 11 additions & 6 deletions source/lib/rocprofiler-sdk/counters/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "lib/rocprofiler-sdk/aql/packet_construct.hpp"
#include "lib/rocprofiler-sdk/context/context.hpp"
#include "lib/rocprofiler-sdk/counters/dispatch_handlers.hpp"
#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"
#include "lib/rocprofiler-sdk/hsa/queue_controller.hpp"
#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp"

Expand Down Expand Up @@ -157,6 +158,8 @@ start_context(const context::context* ctx)

if(!already_enabled)
{
callback_thread_start();

for(auto& cb : ctx->counter_collection->callbacks)
{
// Insert our callbacks into HSA Interceptor. This
Expand All @@ -182,12 +185,12 @@ start_context(const context::context* ctx)
correlation_id);
},
// Completion CB
[=](const hsa::Queue& q,
hsa::rocprofiler_packet kern_pkt,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& aql,
kernel_dispatch::profiling_time dispatch_time) {
completed_cb(ctx, cb, q, kern_pkt, session, aql, dispatch_time);
[=](const hsa::Queue& /* q */,
hsa::rocprofiler_packet /* kern_pkt */,
std::shared_ptr<hsa::Queue::queue_info_session_t>& session,
inst_pkt_t& aql,
kernel_dispatch::profiling_time dispatch_time) {
completed_cb(ctx, cb, session, aql, dispatch_time);
});
}
}
Expand All @@ -206,6 +209,8 @@ stop_context(const context::context* ctx)
});

if(controller) controller->disable_serialization();

callback_thread_stop();
}

rocprofiler_status_t
Expand Down
109 changes: 10 additions & 99 deletions source/lib/rocprofiler-sdk/counters/dispatch_handlers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "lib/rocprofiler-sdk/buffer.hpp"
#include "lib/rocprofiler-sdk/context/context.hpp"
#include "lib/rocprofiler-sdk/counters/core.hpp"
#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"
#include "lib/rocprofiler-sdk/hsa/queue_controller.hpp"
#include "lib/rocprofiler-sdk/kernel_dispatch/profiling_time.hpp"

Expand Down Expand Up @@ -162,14 +163,13 @@ queue_cb(const context::context* ctx,
* Callback called by HSA interceptor when the kernel has completed processing.
*/
void
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
const hsa::Queue& /*queue*/,
hsa::rocprofiler_packet /*packet*/,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time)
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
std::shared_ptr<hsa::Queue::queue_info_session_t>& ptr_session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time)
{
auto& session = *ptr_session;
CHECK(info && ctx);

std::shared_ptr<profile_config> prof_config;
Expand Down Expand Up @@ -198,98 +198,9 @@ completed_cb(const context::context* ctx,
// We have no profile config, nothing to output.
if(!prof_config) return;

auto decoded_pkt = EvaluateAST::read_pkt(prof_config->pkt_generator.get(), *pkt);
EvaluateAST::read_special_counters(
*prof_config->agent, prof_config->required_special_counters, decoded_pkt);

prof_config->packets.wlock([&](auto& pkt_vector) {
if(pkt)
{
pkt_vector.emplace_back(std::move(pkt));
}
});

common::container::small_vector<rocprofiler_record_counter_t, 128> out;
rocprofiler::buffer::instance* buf = nullptr;

if(info->buffer)
{
buf = CHECK_NOTNULL(buffer::get_buffer(info->buffer->handle));
}

auto _corr_id_v =
rocprofiler_correlation_id_t{.internal = 0, .external = context::null_user_data};
if(const auto* _corr_id = session.correlation_id)
{
_corr_id_v.internal = _corr_id->internal;
if(const auto* external = rocprofiler::common::get_val(
session.tracing_data.external_correlation_ids, info->internal_context))
{
_corr_id_v.external = *external;
}
}

auto _dispatch_id = session.callback_record.dispatch_info.dispatch_id;
for(auto& ast : prof_config->asts)
{
std::vector<std::unique_ptr<std::vector<rocprofiler_record_counter_t>>> cache;
auto* ret = ast.evaluate(decoded_pkt, cache);
CHECK(ret);
ast.set_out_id(*ret);

out.reserve(out.size() + ret->size());
for(auto& val : *ret)
{
val.agent_id = prof_config->agent->id;
val.dispatch_id = _dispatch_id;
out.emplace_back(val);
}
}

if(!out.empty())
{
if(buf)
{
auto _header =
common::init_public_api_struct(rocprofiler_dispatch_counting_service_record_t{});
_header.num_records = out.size();
_header.correlation_id = _corr_id_v;
if(dispatch_time.status == HSA_STATUS_SUCCESS)
{
_header.start_timestamp = dispatch_time.start;
_header.end_timestamp = dispatch_time.end;
}
_header.dispatch_info = session.callback_record.dispatch_info;
buf->emplace(ROCPROFILER_BUFFER_CATEGORY_COUNTERS,
ROCPROFILER_COUNTER_RECORD_PROFILE_COUNTING_DISPATCH_HEADER,
_header);

for(auto itr : out)
buf->emplace(
ROCPROFILER_BUFFER_CATEGORY_COUNTERS, ROCPROFILER_COUNTER_RECORD_VALUE, itr);
}
else
{
CHECK(info->record_callback);

auto dispatch_data =
common::init_public_api_struct(rocprofiler_dispatch_counting_service_data_t{});

dispatch_data.dispatch_info = session.callback_record.dispatch_info;
dispatch_data.correlation_id = _corr_id_v;
if(dispatch_time.status == HSA_STATUS_SUCCESS)
{
dispatch_data.start_timestamp = dispatch_time.start;
dispatch_data.end_timestamp = dispatch_time.end;
}

info->record_callback(dispatch_data,
out.data(),
out.size(),
session.user_data,
info->record_callback_args);
}
}
completed_cb_params_t params{info, ptr_session, dispatch_time, prof_config, std::move(pkt)};
process_callback_data(std::move(params));
}

} // namespace counters
} // namespace rocprofiler
13 changes: 6 additions & 7 deletions source/lib/rocprofiler-sdk/counters/dispatch_handlers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ queue_cb(const context::context* ctx,
const context::correlation_id* correlation_id);

void
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
const hsa::Queue& queue,
hsa::rocprofiler_packet packet,
const hsa::Queue::queue_info_session_t& session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time);
completed_cb(const context::context* ctx,
const std::shared_ptr<counter_callback_info>& info,
std::shared_ptr<hsa::Queue::queue_info_session_t>& session,
inst_pkt_t& pkts,
kernel_dispatch::profiling_time dispatch_time);

} // namespace counters
} // namespace rocprofiler
20 changes: 13 additions & 7 deletions source/lib/rocprofiler-sdk/counters/evaluate_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
// SOFTWARE.

#include "lib/rocprofiler-sdk/counters/evaluate_ast.hpp"
#include "lib/common/static_object.hpp"
#include "lib/common/synchronized.hpp"

#include <algorithm>
#include <cstdint>
Expand Down Expand Up @@ -569,7 +571,9 @@ using property_function_t = int64_t (*)(const rocprofiler_agent_t&);
int64_t
get_agent_property(std::string_view property, const rocprofiler_agent_t& agent)
{
static std::unordered_map<std::string_view, property_function_t> props = {
using map_t = std::unordered_map<std::string_view, property_function_t>;

static auto*& _props = common::static_object<common::Synchronized<map_t>>::construct(map_t{
GEN_MAP_ENTRY("cpu_cores_count", agent_info.cpu_cores_count),
GEN_MAP_ENTRY("simd_count", agent_info.simd_count),
GEN_MAP_ENTRY("mem_banks_count", agent_info.mem_banks_count),
Expand Down Expand Up @@ -599,13 +603,15 @@ get_agent_property(std::string_view property, const rocprofiler_agent_t& agent)
GEN_MAP_ENTRY("num_sdma_queues_per_engine", agent_info.num_sdma_queues_per_engine),
GEN_MAP_ENTRY("num_cp_queues", agent_info.num_cp_queues),
GEN_MAP_ENTRY("max_engine_clk_ccompute", agent_info.max_engine_clk_ccompute),
};
if(const auto* func = rocprofiler::common::get_val(props, property))
{
return (*func)(agent);
}
});

return 0.0;
return CHECK_NOTNULL(_props)->wlock([&property, &agent](map_t& props) -> int64_t {
if(const auto* func = rocprofiler::common::get_val(props, property))
{
return (*func)(agent);
}
return 0;
});
}

void
Expand Down
109 changes: 109 additions & 0 deletions source/lib/rocprofiler-sdk/counters/sample_consumer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// MIT License
//
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once

#include "lib/rocprofiler-sdk/counters/sample_processing.hpp"

#include <array>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

namespace rocprofiler
{
namespace counters
{
/**
 * @brief Hands work items to a single background thread through a fixed-size
 *        ring buffer, falling back to the calling thread when that is not
 *        possible.
 *
 * Producers call add(). While the worker is running and the ring buffer has a
 * free slot, the item is stored and the worker is woken up; otherwise the item
 * is consumed synchronously by the caller. The worker is created by start() and
 * is drained and joined by exit() (also invoked from the destructor).
 *
 * @tparam DataType Item type. It is stored in a std::array and move-assigned,
 *                  so it has to be default constructible and move assignable.
 */
template <typename DataType>
class consumer_thread_t
{
    // Number of slots in the ring buffer.
    static constexpr size_t SIZE = 128;
    using consume_func_t         = std::function<void(DataType&&)>;

public:
    /// @param func Callback invoked once for every item passed to add().
    consumer_thread_t(consume_func_t func)
    : consume_fn{std::move(func)}
    {}

    /// Drains pending items and joins the worker thread, if one is running.
    virtual ~consumer_thread_t() { exit(); }

    // The worker thread keeps a pointer to this object, so it must not be copied.
    consumer_thread_t(const consumer_thread_t&) = delete;
    consumer_thread_t& operator=(const consumer_thread_t&) = delete;

    /**
     * @brief Starts the worker thread. Does nothing if it is already running.
     *
     * The thread is created while `mut` is held and `valid` is only set once the
     * creation succeeded. That way exit() can never observe `valid == true`
     * together with a non-joinable `consumer` (join() would throw), and a failed
     * thread creation does not leave the object marked as running. The freshly
     * created worker blocks on `mut` before inspecting `valid`, so it cannot
     * mistake the not-yet-set flag for a shutdown request.
     */
    void start()
    {
        // Serializes start()/exit() so a restart cannot overlap a pending join().
        std::lock_guard<std::mutex> lifecycle_lk(lifecycle_mut);
        std::lock_guard<std::mutex> lk(mut);
        if(valid) return;
        consumer = std::thread{&consumer_thread_t::consumer_loop, this};
        valid    = true;
    }

    /**
     * @brief Stops the worker thread. Does nothing if it is not running.
     *
     * Items already stored in the ring buffer are consumed by the worker before
     * it terminates; this call returns after the worker has been joined.
     */
    void exit()
    {
        std::lock_guard<std::mutex> lifecycle_lk(lifecycle_mut);
        {
            std::unique_lock<std::mutex> lk(mut);
            if(!valid.exchange(false)) return;
            cv.notify_one();
        }
        // Joined without holding `mut`: the worker needs it to leave its wait.
        consumer.join();
    }

    /**
     * @brief Submits one item.
     *
     * @param params Item to consume; it is moved from.
     *
     * Note that the synchronous fallback invokes the callback while `mut` is
     * held, i.e. other producers are blocked for the duration of that call.
     */
    void add(DataType&& params)
    {
        std::unique_lock<std::mutex> lk(mut);

        if(read_ptr + buffer.size() <= write_ptr || !valid)
        {
            // Ring buffer is full or no worker is running: process with this thread
            consume_fn(std::move(params));
            return;
        }

        // The slot is written before write_ptr is advanced so the worker never
        // reads a slot that has not been filled yet.
        buffer.at(write_ptr % buffer.size()) = std::move(params);
        write_ptr.fetch_add(1);
        cv.notify_one();
    }

protected:
    /**
     * @brief Body of the worker thread.
     *
     * Sleeps on `cv` while the ring buffer is empty and returns once a stop was
     * requested and every stored item has been consumed. Items are taken out of
     * the buffer and passed to the callback without holding `mut`.
     */
    void consumer_loop()
    {
        while(true)
        {
            while(read_ptr == write_ptr)
            {
                std::unique_lock<std::mutex> lk(mut);
                cv.wait(lk, [&] { return read_ptr != write_ptr || !valid; });
                if(!valid && read_ptr == write_ptr) return;
            }

            // Free the slot (advance read_ptr) before running the callback so
            // producers can reuse it while the item is being processed.
            auto retrieved = std::move(buffer.at(read_ptr % buffer.size()));
            read_ptr.fetch_add(1);
            consume_fn(std::move(retrieved));
        }
    }

    consume_func_t             consume_fn;
    std::atomic<bool>          valid{false};  // true while a worker thread is running
    std::mutex                 mut;           // guards valid/write_ptr updates and cv
    std::mutex                 lifecycle_mut; // serializes start() and exit()
    std::atomic<size_t>        write_ptr{0};  // total number of items stored so far
    std::atomic<size_t>        read_ptr{0};   // total number of items taken so far
    std::array<DataType, SIZE> buffer;
    std::thread                consumer;
    std::condition_variable    cv;
};

} // namespace counters
} // namespace rocprofiler
Loading

0 comments on commit b7661bc

Please sign in to comment.