diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 32ec96b3..9a2f18ce 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -25,6 +25,7 @@ add_subdirectory(common) # actual samples add_subdirectory(api_callback_tracing) +add_subdirectory(api_callback_tracing_memcpy_bug) add_subdirectory(api_buffered_tracing) add_subdirectory(code_object_tracing) add_subdirectory(counter_collection) diff --git a/samples/api_callback_tracing_memcpy_bug/BUG_DESCRIPTION b/samples/api_callback_tracing_memcpy_bug/BUG_DESCRIPTION new file mode 100644 index 00000000..ecb4fe1c --- /dev/null +++ b/samples/api_callback_tracing_memcpy_bug/BUG_DESCRIPTION @@ -0,0 +1,2 @@ +For a given internal correlation id, the external correlation id that is popped is supposed to be the same as the external correlation id that was pushed. +They are not the same when using callback API tracing for ROCPROFILER_CALLBACK_TRACING_MEMORY_COPY: the value popped is always 0. diff --git a/samples/api_callback_tracing_memcpy_bug/CMakeLists.txt b/samples/api_callback_tracing_memcpy_bug/CMakeLists.txt new file mode 100644 index 00000000..14be8a04 --- /dev/null +++ b/samples/api_callback_tracing_memcpy_bug/CMakeLists.txt @@ -0,0 +1,61 @@ +# +# +# +cmake_minimum_required(VERSION 3.21.0 FATAL_ERROR) + +if(NOT CMAKE_HIP_COMPILER) + find_program( + amdclangpp_EXECUTABLE + NAMES amdclang++ + HINTS ${ROCM_PATH} ENV ROCM_PATH /opt/rocm + PATHS ${ROCM_PATH} ENV ROCM_PATH /opt/rocm + PATH_SUFFIXES bin llvm/bin NO_CACHE) + mark_as_advanced(amdclangpp_EXECUTABLE) + + if(amdclangpp_EXECUTABLE) + set(CMAKE_HIP_COMPILER "${amdclangpp_EXECUTABLE}") + endif() +endif() + +project(rocprofiler-sdk-samples-callback-api-tracing-memcpy-bug LANGUAGES CXX HIP) + +foreach(_TYPE DEBUG MINSIZEREL RELEASE RELWITHDEBINFO) + if("${CMAKE_HIP_FLAGS_${_TYPE}}" STREQUAL "") + set(CMAKE_HIP_FLAGS_${_TYPE} "${CMAKE_CXX_FLAGS_${_TYPE}}") + endif() +endforeach() + +find_package(rocprofiler-sdk REQUIRED) + 
# Tool library: the rocprofiler-sdk client that is preloaded into the sample application.
add_library(callback-api-tracing-client-memcpy-bug SHARED)
target_sources(callback-api-tracing-client-memcpy-bug PRIVATE client.cpp client.hpp)
target_link_libraries(
    callback-api-tracing-client-memcpy-bug
    PRIVATE rocprofiler-sdk::rocprofiler-sdk rocprofiler-sdk::samples-build-flags
            rocprofiler-sdk::samples-common-library)

# main.cpp contains a __global__ kernel, so it must be compiled as HIP
set_source_files_properties(main.cpp PROPERTIES LANGUAGE HIP)

find_package(Threads REQUIRED)
find_package(rocprofiler-sdk-roctx REQUIRED)

# Sample application: multi-threaded transpose that issues the traced memory copies.
add_executable(callback-api-tracing-memcpy-bug)
target_sources(callback-api-tracing-memcpy-bug PRIVATE main.cpp)
target_link_libraries(
    callback-api-tracing-memcpy-bug
    PRIVATE callback-api-tracing-client-memcpy-bug Threads::Threads
            rocprofiler-sdk-roctx::rocprofiler-sdk-roctx
            rocprofiler-sdk::samples-build-flags)

# environment needed to run the test: preload of the tool + library search path for roctx
rocprofiler_samples_get_preload_env(PRELOAD_ENV callback-api-tracing-client-memcpy-bug)
rocprofiler_samples_get_ld_library_path_env(
    LIBRARY_PATH_ENV rocprofiler-sdk-roctx::rocprofiler-sdk-roctx-shared-library)

set(callback-api-tracing-memcpy-bug-env ${PRELOAD_ENV} ${LIBRARY_PATH_ENV})

# NOTE(review): the generator expression after COMMAND was reduced to a bare "$" when this
# patch was extracted; $<TARGET_FILE:...> of the executable defined above is the only
# sensible command for this test -- confirm against the upstream patch.
add_test(NAME callback-api-tracing-memcpy-bug
         COMMAND $<TARGET_FILE:callback-api-tracing-memcpy-bug>)

set_tests_properties(
    callback-api-tracing-memcpy-bug
    PROPERTIES TIMEOUT 45 LABELS "samples" ENVIRONMENT "${callback-api-tracing-memcpy-bug-env}"
               FAIL_REGULAR_EXPRESSION "${ROCPROFILER_DEFAULT_FAIL_REGEX}")

# ---- next file in this patch: samples/api_callback_tracing_memcpy_bug/client.cpp (new file) ----
# // MIT License
# //
# // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +// undefine NDEBUG so asserts are implemented +#ifdef NDEBUG +# undef NDEBUG +#endif + +/** + * @file samples/api_callback_tracing/client.cpp + * + * @brief Example rocprofiler client (tool) + */ + +#include "client.hpp" + +#include +#include +#include +#include +#include + +#include "common/call_stack.hpp" +#include "common/defines.hpp" +#include "common/filesystem.hpp" +#include "common/name_info.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace client +{ +namespace +{ +using common::call_stack_t; +using common::callback_name_info; +using common::source_location; + +rocprofiler_client_id_t* client_id = nullptr; +rocprofiler_client_finalize_t client_fini_func = nullptr; +rocprofiler_context_id_t client_ctx = {}; + +void +print_call_stack(const call_stack_t& _call_stack) +{ + common::print_call_stack("api_callback_trace.log", _call_stack); +} + +void +tool_tracing_ctrl_callback(rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t*, + void* client_data) +{ + auto* ctx = static_cast(client_data); + + if(record.phase == ROCPROFILER_CALLBACK_PHASE_ENTER && + record.kind == ROCPROFILER_CALLBACK_TRACING_MARKER_CONTROL_API && + record.operation == ROCPROFILER_MARKER_CONTROL_API_ID_roctxProfilerPause) + { + ROCPROFILER_CALL(rocprofiler_stop_context(*ctx), "pausing client context"); + } + else if(record.phase == ROCPROFILER_CALLBACK_PHASE_EXIT && + record.kind == ROCPROFILER_CALLBACK_TRACING_MARKER_CONTROL_API && + record.operation == ROCPROFILER_MARKER_CONTROL_API_ID_roctxProfilerResume) + { + ROCPROFILER_CALL(rocprofiler_start_context(*ctx), "resuming client context"); + } +} + + +#define ROCPROFILER_CALL2(fn, args, msg) /* return status using GNU ext */ \ +({ \ + rocprofiler_status_t __status = fn args; \ + if (__status != ROCPROFILER_STATUS_SUCCESS) { \ + const char *__status_msg = 
rocprofiler_get_status_string(__status); \ + fprintf(stderr, "hpcrun: rocprofiler failure '%s' " \ + " status = %d, status_msg = '%s'\n", msg, __status, \ + __status_msg); \ + exit(-1); \ + } \ + __status; \ +}) + + +rocprofiler_thread_id_t +rocm_threads_self +( + void +) +{ + rocprofiler_thread_id_t my_thread_id; + + ROCPROFILER_CALL2 + ( + rocprofiler_get_thread_id, + (&my_thread_id), + "get thread id" + ); + + return my_thread_id; +} + + +uint64_t +rocm_cid_push +( + rocprofiler_context_id_t context_id +) +{ + static uint64_t correlation_id = 17; + + rocprofiler_user_data_t rud; + rud.value = correlation_id++; + + ROCPROFILER_CALL2 + ( + rocprofiler_push_external_correlation_id, + (context_id, rocm_threads_self(), rud), + "correlation id push" + ); + + return rud.value; +} + + +uint64_t +rocm_cid_pop +( + rocprofiler_context_id_t context_id +) +{ + rocprofiler_user_data_t rud; + ROCPROFILER_CALL2 + ( + rocprofiler_pop_external_correlation_id, + (context_id, rocm_threads_self(), &rud), + "correlation id pop" + ); + + return rud.value; +} + +void +tool_tracing_callback(rocprofiler_callback_tracing_record_t record, + rocprofiler_user_data_t* user_data, + void* callback_data) +{ + uint64_t ext_correlation_id; + + switch(record.phase) { + case ROCPROFILER_CALLBACK_PHASE_ENTER: + ext_correlation_id = rocm_cid_push(client_ctx); + + printf("push correlation id: external 0x%lx internal = 0x%lx\n", + ext_correlation_id, + record.correlation_id.internal); + break; + + case ROCPROFILER_CALLBACK_PHASE_EXIT: + ext_correlation_id = rocm_cid_pop(client_ctx); + + printf("pop correlation id: external 0x%lx internal = 0x%lx\n", + ext_correlation_id, + record.correlation_id.internal); + break; + + default: + break; + } +} + + +void +tool_control_init(rocprofiler_context_id_t& primary_ctx) +{ + // Create a specialized (throw-away) context for handling ROCTx profiler pause and resume. 
+ // A separate context is used because if the context that is associated with roctxProfilerPause + // disabled that same context, a call to roctxProfilerResume would be ignored because the + // context that enables the callback for that API call is disabled. + auto cntrl_ctx = rocprofiler_context_id_t{}; + ROCPROFILER_CALL(rocprofiler_create_context(&cntrl_ctx), "control context creation failed"); + + // enable callback marker tracing with only the pause/resume operations + ROCPROFILER_CALL(rocprofiler_configure_callback_tracing_service( + cntrl_ctx, + ROCPROFILER_CALLBACK_TRACING_MARKER_CONTROL_API, + nullptr, + 0, + tool_tracing_ctrl_callback, + &primary_ctx), + "callback tracing service failed to configure"); + + // start the context so that it is always active + ROCPROFILER_CALL(rocprofiler_start_context(cntrl_ctx), "start of control context"); +} + +int +tool_init(rocprofiler_client_finalize_t fini_func, void* tool_data) +{ + assert(tool_data != nullptr); + + auto* call_stack_v = static_cast(tool_data); + + call_stack_v->emplace_back(source_location{__FUNCTION__, __FILE__, __LINE__, ""}); + + callback_name_info name_info = common::get_callback_id_names(); + + for(const auto& itr : name_info) + { + auto name_idx = std::stringstream{}; + name_idx << " [" << std::setw(3) << itr.value << "]"; + call_stack_v->emplace_back( + source_location{"rocprofiler_callback_tracing_kind_names " + name_idx.str(), + __FILE__, + __LINE__, + std::string{itr.name}}); + + for(auto [didx, ditr] : itr.items()) + { + auto operation_idx = std::stringstream{}; + operation_idx << " [" << std::setw(3) << didx << "]"; + call_stack_v->emplace_back(source_location{ + "rocprofiler_callback_tracing_kind_operation_names" + operation_idx.str(), + __FILE__, + __LINE__, + std::string{"- "} + std::string{*ditr}}); + } + } + + client_fini_func = fini_func; + + ROCPROFILER_CALL(rocprofiler_create_context(&client_ctx), "context creation failed"); + + // enable the control + 
tool_control_init(client_ctx); + + ROCPROFILER_CALL( + rocprofiler_configure_callback_tracing_service(client_ctx, + ROCPROFILER_CALLBACK_TRACING_MEMORY_COPY, + nullptr, + 0, + tool_tracing_callback, + tool_data), + "callback tracing service failed to configure"); + + int valid_ctx = 0; + ROCPROFILER_CALL(rocprofiler_context_is_valid(client_ctx, &valid_ctx), + "failure checking context validity"); + if(valid_ctx == 0) + { + // notify rocprofiler that initialization failed + // and all the contexts, buffers, etc. created + // should be ignored + return -1; + } + + ROCPROFILER_CALL(rocprofiler_start_context(client_ctx), "rocprofiler context start failed"); + + // no errors + return 0; +} + +void +tool_fini(void* tool_data) +{ + assert(tool_data != nullptr); + + auto* _call_stack = static_cast(tool_data); + _call_stack->emplace_back(source_location{__FUNCTION__, __FILE__, __LINE__, ""}); + + print_call_stack(*_call_stack); + + delete _call_stack; +} +} // namespace + +void +setup() +{} + +void +shutdown() +{ + if(client_id) client_fini_func(*client_id); +} + +void +start() +{ + ROCPROFILER_CALL(rocprofiler_start_context(client_ctx), "rocprofiler context start failed"); +} + +void +stop() +{ + int status = 0; + ROCPROFILER_CALL(rocprofiler_is_initialized(&status), "failed to retrieve init status"); + if(status != 0) + { + ROCPROFILER_CALL(rocprofiler_stop_context(client_ctx), "rocprofiler context stop failed"); + } +} +} // namespace client + +extern "C" rocprofiler_tool_configure_result_t* +rocprofiler_configure(uint32_t version, + const char* runtime_version, + uint32_t priority, + rocprofiler_client_id_t* id) +{ + // set the client name + id->name = "ExampleTool"; + + // store client info + client::client_id = id; + + // compute major/minor/patch version info + uint32_t major = version / 10000; + uint32_t minor = (version % 10000) / 100; + uint32_t patch = version % 100; + + // generate info string + auto info = std::stringstream{}; + info << id->name << " 
(priority=" << priority << ") is using rocprofiler-sdk v" << major << "." + << minor << "." << patch << " (" << runtime_version << ")"; + + std::clog << info.str() << std::endl; + + // demonstration of alternative way to get the version info + { + auto version_info = std::array{}; + ROCPROFILER_CALL( + rocprofiler_get_version(&version_info.at(0), &version_info.at(1), &version_info.at(2)), + "failed to get version info"); + + if(std::array{major, minor, patch} != version_info) + { + throw std::runtime_error{"version info mismatch"}; + } + } + + // data passed around all the callbacks + auto* client_tool_data = new std::vector{}; + + // add first entry + client_tool_data->emplace_back( + client::source_location{__FUNCTION__, __FILE__, __LINE__, info.str()}); + + // create configure data + static auto cfg = + rocprofiler_tool_configure_result_t{sizeof(rocprofiler_tool_configure_result_t), + &client::tool_init, + &client::tool_fini, + static_cast(client_tool_data)}; + + // return pointer to configure data + return &cfg; +} diff --git a/samples/api_callback_tracing_memcpy_bug/client.hpp b/samples/api_callback_tracing_memcpy_bug/client.hpp new file mode 100644 index 00000000..3d5b7ac3 --- /dev/null +++ b/samples/api_callback_tracing_memcpy_bug/client.hpp @@ -0,0 +1,44 @@ +// MIT License +// +// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once

// CMake defines <target>_EXPORTS (target name converted to a C identifier) while compiling
// the sources of a shared library. This sample's library target is
// "callback-api-tracing-client-memcpy-bug", so the symbol to test is
// callback_api_tracing_client_memcpy_bug_EXPORTS; the name used by the original
// api_callback_tracing sample is still accepted for compatibility.
#if defined(callback_api_tracing_client_memcpy_bug_EXPORTS) ||                                     \
    defined(callback_api_tracing_client_EXPORTS)
#    define CLIENT_API __attribute__((visibility("default")))
#else
#    define CLIENT_API
#endif

/// Entry points that the sample application calls on the tool library.
namespace client
{
/// Placeholder hook called at application start-up.
void
setup() CLIENT_API;

/// Invoke the rocprofiler finalizer for this client.
void
shutdown() CLIENT_API;

/// Start the client's tracing context.
void
start() CLIENT_API;

/// Stop the client's tracing context.
void
stop() CLIENT_API;
}  // namespace client

// ---- next file in this patch: samples/api_callback_tracing_memcpy_bug/main.cpp (new file) ----
// MIT License
//
// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "client.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define HIP_API_CALL(CALL) \ + { \ + hipError_t error_ = (CALL); \ + if(error_ != hipSuccess) \ + { \ + auto _hip_api_print_lk = auto_lock_t{print_lock}; \ + fprintf(stderr, \ + "%s:%d :: HIP error %i: %s\n", \ + __FILE__, \ + __LINE__, \ + (int) error_, \ + hipGetErrorString(error_)); \ + throw std::runtime_error("hip_api_call"); \ + } \ + } + +namespace +{ +using auto_lock_t = std::unique_lock; +auto print_lock = std::mutex{}; +size_t nthreads = 2; +size_t nitr = 500; +size_t nsync = 10; +constexpr unsigned shared_mem_tile_dim = 32; + +void +check_hip_error(void); + +void +verify(int* in, int* out, int M, int N); +} // namespace + +__global__ void +transpose_a(const int* in, int* out, int M, int N); + +void +run(int rank, int tid, hipStream_t stream, int argc, char** argv); + +int +main(int argc, char** argv) +{ + client::setup(); // currently does nothing + // client::start(); // currently will fail + + auto range_id = roctxRangeStart("main"); + + int rank = 0; + for(int i = 1; i < argc; ++i) + { + auto _arg = std::string{argv[i]}; + if(_arg == "?" 
|| _arg == "-h" || _arg == "--help") + { + fprintf(stderr, + "usage: transpose [NUM_THREADS (%zu)] [NUM_ITERATION (%zu)] " + "[SYNC_EVERY_N_ITERATIONS (%zu)]\n", + nthreads, + nitr, + nsync); + exit(EXIT_SUCCESS); + } + } + if(argc > 1) nthreads = atoll(argv[1]); + if(argc > 2) nitr = atoll(argv[2]); + if(argc > 3) nsync = atoll(argv[3]); + + printf("[transpose] Number of threads: %zu\n", nthreads); + printf("[transpose] Number of iterations: %zu\n", nitr); + printf("[transpose] Syncing every %zu iterations\n", nsync); + + // this is a temporary workaround in omnitrace when HIP + MPI is enabled + int ndevice = 0; + int devid = rank; + HIP_API_CALL(hipGetDeviceCount(&ndevice)); + printf("[transpose] Number of devices found: %i\n", ndevice); + if(ndevice > 0) + { + devid = rank % ndevice; + HIP_API_CALL(hipSetDevice(devid)); + printf("[transpose] Rank %i assigned to device %i\n", rank, devid); + } + if(rank == devid && rank < ndevice) + { + std::vector _threads{}; + std::vector _streams(nthreads); + roctxMark("stream creation"); + for(size_t i = 0; i < nthreads; ++i) + HIP_API_CALL(hipStreamCreate(&_streams.at(i))); + roctxMark("thread creation"); + for(size_t i = 1; i < nthreads; ++i) + _threads.emplace_back(run, rank, i, _streams.at(i), argc, argv); + run(rank, 0, _streams.at(0), argc, argv); + roctxMark("thread sync"); + for(auto& itr : _threads) + itr.join(); + roctxMark("stream destroy"); + for(size_t i = 0; i < nthreads; ++i) + HIP_API_CALL(hipStreamDestroy(_streams.at(i))); + } + + HIP_API_CALL(hipDeviceSynchronize()); + + auto tid = roctx_thread_id_t{}; + // get the thread id recognized by rocprofiler-sdk from roctx + roctxGetThreadId(&tid); + // pause API tracing + roctxProfilerPause(tid); + // would not expect below to show up in profiler (depends on tool) + HIP_API_CALL(hipDeviceReset()); + // resume API tracing + roctxProfilerResume(tid); + + roctxRangeStop(range_id); + + client::stop(); + client::shutdown(); + + return 0; +} + +__global__ void 
+transpose_a(const int* in, int* out, int M, int N) +{ + __shared__ int tile[shared_mem_tile_dim][shared_mem_tile_dim]; + + int idx = (blockIdx.y * blockDim.y + threadIdx.y) * M + blockIdx.x * blockDim.x + threadIdx.x; + tile[threadIdx.y][threadIdx.x] = in[idx]; + __syncthreads(); + idx = (blockIdx.x * blockDim.x + threadIdx.y) * N + blockIdx.y * blockDim.y + threadIdx.x; + out[idx] = tile[threadIdx.x][threadIdx.y]; +} + +void +run(int rank, int tid, hipStream_t stream, int argc, char** argv) +{ + auto run_name = std::stringstream{}; + run_name << __FUNCTION__ << "(" << rank << ", " << tid << ")"; + roctxRangePush(run_name.str().c_str()); + + unsigned int M = 4960 * 2; + unsigned int N = 4960 * 2; + if(argc > 2) nitr = atoll(argv[2]); + if(argc > 3) nsync = atoll(argv[3]); + + auto_lock_t _lk{print_lock}; + std::cout << "[transpose][" << rank << "][" << tid << "] M: " << M << " N: " << N << std::endl; + _lk.unlock(); + + auto _seed = std::random_device{}() * (rank + 1) * (tid + 1); + auto _engine = std::default_random_engine{_seed}; + auto _dist = std::uniform_int_distribution{0, 1000}; + + size_t size = sizeof(int) * M * N; + int* inp_matrix = new int[size]; + int* out_matrix = new int[size]; + for(size_t i = 0; i < M * N; i++) + { + inp_matrix[i] = _dist(_engine); + out_matrix[i] = 0; + } + int* in = nullptr; + int* out = nullptr; + + HIP_API_CALL(hipMalloc(&in, size)); + HIP_API_CALL(hipMalloc(&out, size)); + HIP_API_CALL(hipMemsetAsync(in, 0, size, stream)); + HIP_API_CALL(hipMemsetAsync(out, 0, size, stream)); + HIP_API_CALL(hipMemcpyAsync(in, inp_matrix, size, hipMemcpyHostToDevice, stream)); + HIP_API_CALL(hipStreamSynchronize(stream)); + + dim3 grid(M / 32, N / 32, 1); + dim3 block(32, 32, 1); // transpose_a + + auto t1 = std::chrono::high_resolution_clock::now(); + for(size_t i = 0; i < nitr; ++i) + { + transpose_a<<>>(in, out, M, N); + check_hip_error(); + if(i % nsync == (nsync - 1)) HIP_API_CALL(hipStreamSynchronize(stream)); + } + auto t2 = 
std::chrono::high_resolution_clock::now(); + HIP_API_CALL(hipStreamSynchronize(stream)); + HIP_API_CALL(hipMemcpyAsync(out_matrix, out, size, hipMemcpyDeviceToHost, stream)); + double time = std::chrono::duration_cast>(t2 - t1).count(); + float GB = (float) size * nitr * 2 / (1 << 30); + + print_lock.lock(); + std::cout << "[transpose][" << rank << "][" << tid << "] Runtime of transpose is " << time + << " sec\n"; + std::cout << "[transpose][" << rank << "][" << tid + << "] The average performance of transpose is " << GB / time << " GBytes/sec" + << std::endl; + print_lock.unlock(); + + HIP_API_CALL(hipStreamSynchronize(stream)); + + // cpu_transpose(matrix, out_matrix, M, N); + verify(inp_matrix, out_matrix, M, N); + + HIP_API_CALL(hipFree(in)); + HIP_API_CALL(hipFree(out)); + + delete[] inp_matrix; + delete[] out_matrix; + + roctxRangePop(); +} + +namespace +{ +void +check_hip_error(void) +{ + hipError_t err = hipGetLastError(); + if(err != hipSuccess) + { + auto_lock_t _lk{print_lock}; + std::cerr << "Error: " << hipGetErrorString(err) << std::endl; + throw std::runtime_error("hip_api_call"); + } +} + +void +verify(int* in, int* out, int M, int N) +{ + for(int i = 0; i < 10; i++) + { + int row = rand() % M; + int col = rand() % N; + if(in[row * N + col] != out[col * M + row]) + { + auto_lock_t _lk{print_lock}; + std::cout << "mismatch: " << row << ", " << col << " : " << in[row * N + col] << " | " + << out[col * M + row] << "\n"; + } + } +} +} // namespace