Skip to content

Commit

Permalink
Update Logging and Version check with more details and fix for static…
Browse files Browse the repository at this point in the history
… loader build

Signed-off-by: Neil R. Spruit <[email protected]>
  • Loading branch information
nrspruit committed Dec 10, 2024
1 parent a9842e3 commit acf49cf
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 47 deletions.
19 changes: 14 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,22 @@ endif()
include(FetchContent)

if(BUILD_L0_LOADER_TESTS)
FetchContent_Declare(
googletest
URL https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip
)
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG v1.14.0
)
add_library(GTest::GTest INTERFACE IMPORTED)
target_link_libraries(GTest::GTest INTERFACE gtest_main)

# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
if(MSVC)
if (BUILD_STATIC)
set(gtest_force_shared_crt OFF CACHE BOOL "" FORCE)
else()
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
endif()
endif()

FetchContent_MakeAvailable(googletest)

Expand Down
68 changes: 33 additions & 35 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ namespace ze_lib

///////////////////////////////////////////////////////////////////////////////
context_t::context_t()
{};
{
debugTraceEnabled = getenv_tobool( "ZE_ENABLE_LOADER_DEBUG_TRACE" );
};

///////////////////////////////////////////////////////////////////////////////
context_t::~context_t()
Expand All @@ -35,12 +37,6 @@ namespace ze_lib
ze_lib::destruction = true;
};

//////////////////////////////////////////////////////////////////////////
void debug_trace_message(const std::string &message, const std::string &extra)
{
std::cerr << message << " " << extra << std::endl;
}

//////////////////////////////////////////////////////////////////////////
__zedlllocal ze_result_t context_t::Init(ze_init_flags_t flags, bool sysmanOnly, ze_init_driver_type_desc_t* desc)
{
Expand All @@ -55,8 +51,8 @@ namespace ze_lib
loader = LOAD_DRIVER_LIBRARY(loaderFullLibraryPath.c_str());

if( NULL == loader ) {
std::string message = "ze_lib Context Init() Loader Library Load Failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() Loader Library Load Failed with ";
debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED));
return ZE_RESULT_ERROR_UNINITIALIZED;
}

Expand All @@ -65,25 +61,25 @@ namespace ze_lib
GET_FUNCTION_PTR(loader, "zeLoaderInit") );
result = loaderInit();
if( ZE_RESULT_SUCCESS != result ) {
std::string message = "ze_lib Context Init() Loader Init Failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() Loader Init Failed with ";
debug_trace_message(message, to_string(result));
return result;
}
typedef HMODULE (ZE_APICALL *getTracing_t)();
auto getTracing = reinterpret_cast<getTracing_t>(
GET_FUNCTION_PTR(loader, "zeLoaderGetTracingHandle") );
if (getTracing == nullptr) {
std::string message = "ze_lib Context Init() zeLoaderGetTracingHandle missing";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zeLoaderGetTracingHandle missing, returning ";
debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED));
return ZE_RESULT_ERROR_UNINITIALIZED;
}
tracing_lib = getTracing();
typedef ze_result_t (ZE_APICALL *zelLoaderTracingLayerInit_t)(std::atomic<ze_dditable_t *> &zeDdiTable);
auto loaderTracingLayerInit = reinterpret_cast<zelLoaderTracingLayerInit_t>(
GET_FUNCTION_PTR(loader, "zelLoaderTracingLayerInit") );
if (loaderTracingLayerInit == nullptr) {
std::string message = "ze_lib Context Init() zelLoaderTracingLayerInit missing";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelLoaderTracingLayerInit missing, returning ";
debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED));
return ZE_RESULT_ERROR_UNINITIALIZED;
}
typedef loader::context_t * (ZE_APICALL *zelLoaderGetContext_t)();
Expand All @@ -97,16 +93,16 @@ namespace ze_lib
size_t size = 0;
result = zelLoaderGetVersions(&size, nullptr);
if (ZE_RESULT_SUCCESS != result) {
std::string message = "ze_lib Context Init() zelLoaderGetVersions Failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelLoaderGetVersions Failed with";
debug_trace_message(message, to_string(result));
return result;
}

std::vector<zel_component_version_t> versions(size);
result = zelLoaderGetVersions(&size, versions.data());
if (ZE_RESULT_SUCCESS != result) {
std::string message = "ze_lib Context Init() zelLoaderGetVersions Failed to read component versions";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelLoaderGetVersions Failed to read component versions with ";
debug_trace_message(message, to_string(result));
return result;
}
bool zeInitDriversSupport = true;
Expand All @@ -121,12 +117,14 @@ namespace ze_lib
zeInitDriversSupport = false;
}
} else {
std::string message = "ze_lib Context Init() Loader version is too new";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() Loader version is too new, returning ";
debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNSUPPORTED_VERSION));
return ZE_RESULT_ERROR_UNSUPPORTED_VERSION;
}
}
}
std::string version_message = "Loader API Version to be requested is v" + std::to_string(ZE_MAJOR_VERSION(version)) + "." + std::to_string(ZE_MINOR_VERSION(version));
debug_trace_message(version_message, "");
#else
result = zeLoaderInit();
if( ZE_RESULT_SUCCESS == result ) {
Expand Down Expand Up @@ -162,35 +160,35 @@ namespace ze_lib
{
result = zeDdiTableInit(version);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zeDdiTableInit failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zeDdiTableInit failed with ";
debug_trace_message(message, to_string(result));
}
}
// Init the ZET DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zetDdiTableInit(version);
if( ZE_RESULT_SUCCESS != result ) {
std::string message = "ze_lib Context Init() zetDdiTableInit failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zetDdiTableInit failed with ";
debug_trace_message(message, to_string(result));
}
}
// Init the ZES DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zesDdiTableInit(version);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zesDdiTableInit failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zesDdiTableInit failed with ";
debug_trace_message(message, to_string(result));
}
}
// Init the Tracing API DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zelTracingDdiTableInit(version);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zelTracingDdiTableInit failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelTracingDdiTableInit failed with ";
debug_trace_message(message, to_string(result));
}
}
// Init the stored ddi tables for the tracing layer
Expand All @@ -216,8 +214,8 @@ namespace ze_lib
auto loaderDriverCheck = reinterpret_cast<zelLoaderDriverCheck_t>(
GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") );
if (loaderDriverCheck == nullptr) {
std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing, returning ";
debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED));
return ZE_RESULT_ERROR_UNINITIALIZED;
}
result = loaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
Expand All @@ -226,8 +224,8 @@ namespace ze_lib
auto loaderDriverCheck = reinterpret_cast<zelLoaderDriverCheck_t>(
GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") );
if (loaderDriverCheck == nullptr) {
std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing, returning ";
debug_trace_message(message, to_string(ZE_RESULT_ERROR_UNINITIALIZED));
return ZE_RESULT_ERROR_UNINITIALIZED;
}
result = loaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
Expand All @@ -236,8 +234,8 @@ namespace ze_lib
result = zelLoaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
#endif
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zelLoaderDriverCheck failed";
debug_trace_message(message, "");
std::string message = "ze_lib Context Init() zelLoaderDriverCheck failed with ";
debug_trace_message(message, to_string(result));
}
// If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver.
if (requireDdiReinit && loaderContextAccessAllowed) {
Expand Down
116 changes: 109 additions & 7 deletions source/lib/ze_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <vector>
#include <mutex>
#include <atomic>
#include <typeinfo>
#include <iostream>

namespace ze_lib
{
Expand All @@ -34,14 +36,113 @@ namespace ze_lib
context_t();
~context_t();

///////////////////////////////////////////////////////////////////////////////
template <typename T, typename TableType>
ze_result_t getTableWithCheck(T getTable, ze_api_version_t version, TableType* table) {
if (getTable == nullptr) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
return getTable(version, table);
//////////////////////////////////////////////////////////////////////////
std::string to_string(const ze_result_t result) {
if (result == ZE_RESULT_SUCCESS) {
return "ZE_RESULT_SUCCESS";
} else if (result == ZE_RESULT_NOT_READY) {
return "ZE_RESULT_NOT_READY";
} else if (result == ZE_RESULT_ERROR_UNINITIALIZED) {
return "ZE_RESULT_ERROR_UNINITIALIZED";
} else if (result == ZE_RESULT_ERROR_DEVICE_LOST) {
return "ZE_RESULT_ERROR_DEVICE_LOST";
} else if (result == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
return "ZE_RESULT_ERROR_INVALID_ARGUMENT";
} else if (result == ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY) {
return "ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY";
} else if (result == ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY) {
return "ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY";
} else if (result == ZE_RESULT_ERROR_MODULE_BUILD_FAILURE) {
return "ZE_RESULT_ERROR_MODULE_BUILD_FAILURE";
} else if (result == ZE_RESULT_ERROR_MODULE_LINK_FAILURE) {
return "ZE_RESULT_ERROR_MODULE_LINK_FAILURE";
} else if (result == ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS) {
return "ZE_RESULT_ERROR_INSUFFICIENT_PERMISSIONS";
} else if (result == ZE_RESULT_ERROR_NOT_AVAILABLE) {
return "ZE_RESULT_ERROR_NOT_AVAILABLE";
} else if (result == ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE) {
return "ZE_RESULT_ERROR_DEPENDENCY_UNAVAILABLE";
} else if (result == ZE_RESULT_WARNING_DROPPED_DATA) {
return "ZE_RESULT_WARNING_DROPPED_DATA";
} else if (result == ZE_RESULT_ERROR_UNSUPPORTED_VERSION) {
return "ZE_RESULT_ERROR_UNSUPPORTED_VERSION";
} else if (result == ZE_RESULT_ERROR_UNSUPPORTED_FEATURE) {
return "ZE_RESULT_ERROR_UNSUPPORTED_FEATURE";
} else if (result == ZE_RESULT_ERROR_INVALID_NULL_HANDLE) {
return "ZE_RESULT_ERROR_INVALID_NULL_HANDLE";
} else if (result == ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE) {
return "ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE";
} else if (result == ZE_RESULT_ERROR_INVALID_NULL_POINTER) {
return "ZE_RESULT_ERROR_INVALID_NULL_POINTER";
} else if (result == ZE_RESULT_ERROR_INVALID_SIZE) {
return "ZE_RESULT_ERROR_INVALID_SIZE";
} else if (result == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
return "ZE_RESULT_ERROR_UNSUPPORTED_SIZE";
} else if (result == ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT) {
return "ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT";
} else if (result == ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT) {
return "ZE_RESULT_ERROR_INVALID_SYNCHRONIZATION_OBJECT";
} else if (result == ZE_RESULT_ERROR_INVALID_ENUMERATION) {
return "ZE_RESULT_ERROR_INVALID_ENUMERATION";
} else if (result == ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION) {
return "ZE_RESULT_ERROR_UNSUPPORTED_ENUMERATION";
} else if (result == ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT) {
return "ZE_RESULT_ERROR_UNSUPPORTED_IMAGE_FORMAT";
} else if (result == ZE_RESULT_ERROR_INVALID_NATIVE_BINARY) {
return "ZE_RESULT_ERROR_INVALID_NATIVE_BINARY";
} else if (result == ZE_RESULT_ERROR_INVALID_GLOBAL_NAME) {
return "ZE_RESULT_ERROR_INVALID_GLOBAL_NAME";
} else if (result == ZE_RESULT_ERROR_INVALID_KERNEL_NAME) {
return "ZE_RESULT_ERROR_INVALID_KERNEL_NAME";
} else if (result == ZE_RESULT_ERROR_INVALID_FUNCTION_NAME) {
return "ZE_RESULT_ERROR_INVALID_FUNCTION_NAME";
} else if (result == ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION) {
return "ZE_RESULT_ERROR_INVALID_GROUP_SIZE_DIMENSION";
} else if (result == ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION) {
return "ZE_RESULT_ERROR_INVALID_GLOBAL_WIDTH_DIMENSION";
} else if (result == ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX) {
return "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX";
} else if (result == ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE) {
return "ZE_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE";
} else if (result == ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE) {
return "ZE_RESULT_ERROR_INVALID_KERNEL_ATTRIBUTE_VALUE";
} else if (result == ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED) {
return "ZE_RESULT_ERROR_INVALID_MODULE_UNLINKED";
} else if (result == ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE) {
return "ZE_RESULT_ERROR_INVALID_COMMAND_LIST_TYPE";
} else if (result == ZE_RESULT_ERROR_OVERLAPPING_REGIONS) {
return "ZE_RESULT_ERROR_OVERLAPPING_REGIONS";
} else if (result == ZE_RESULT_ERROR_UNKNOWN) {
return "ZE_RESULT_ERROR_UNKNOWN";
} else {
return std::to_string(static_cast<int>(result));
}
}

//////////////////////////////////////////////////////////////////////////
void debug_trace_message(std::string message, std::string result) {
if (debugTraceEnabled){
std::string debugTracePrefix = "ZE_LOADER_DEBUG_TRACE:";
std::cerr << debugTracePrefix << message << result << std::endl;
}
}

///////////////////////////////////////////////////////////////////////////////
template <typename T, typename TableType>
ze_result_t getTableWithCheck(T getTable, ze_api_version_t version, TableType* table) {
ze_result_t result = ZE_RESULT_ERROR_UNINITIALIZED;
if (getTable == nullptr) {
std::string message = "getTableWithCheck Failed for " + std::string(typeid(TableType).name()) + " with ";
debug_trace_message(message, to_string(result));
return result;
}
result = getTable(version, table);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "getTableWithCheck Failed for " + std::string(typeid(TableType).name()) + " with ";
debug_trace_message(message, to_string(result));
}
return result;
}

std::once_flag initOnce;
std::once_flag initOnceDrivers;
Expand Down Expand Up @@ -71,6 +172,7 @@ namespace ze_lib
bool inTeardown = false;
bool zesInuse = false;
bool zeInuse = false;
bool debugTraceEnabled = false;
};

extern context_t *context;
Expand Down

0 comments on commit acf49cf

Please sign in to comment.