Skip to content

Commit

Permalink
[L0 v2] use UMF for memory info queries
Browse files Browse the repository at this point in the history
This avoid going to the driver and speeds up
UR_USM_ALLOC_INFO_DEVICE query ~2 times in a microbenchmark.
  • Loading branch information
igchor committed Dec 17, 2024
1 parent 230d19e commit 527f2f3
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 45 deletions.
76 changes: 32 additions & 44 deletions source/adapters/level_zero/v2/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,

if (!poolParams) {
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
umfProxyPoolOps(), std::move(provider), nullptr);
umfProxyPoolOps(), std::move(provider), nullptr, poolDescriptor);
if (ret != UMF_RESULT_SUCCESS)
throw umf::umf2urResult(ret);
return std::move(poolHandle);
} else {
auto umfParams = getUmfParamsHandle(*poolParams);

auto [ret, poolHandle] =
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider),
static_cast<void *>(umfParams.get()));
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
umfDisjointPoolOps(), std::move(provider),
static_cast<void *>(umfParams.get()), poolDescriptor);
if (ret != UMF_RESULT_SUCCESS)
throw umf::umf2urResult(ret);
return std::move(poolHandle);
Expand Down Expand Up @@ -356,6 +356,19 @@ urUSMFree(ur_context_handle_t hContext, ///< [in] handle of the context object
return exceptionToResult(std::current_exception());
}

static usm::pool_descriptor *getPoolDescriptor(const void *ptr) {
auto umfPool = umfPoolByPtr(ptr);
if (!umfPool) {
logger::error("urUSMGetMemAllocInfo: no memory associated with given ptr");
throw UR_RESULT_ERROR_INVALID_VALUE;
}

usm::pool_descriptor *poolDesc;
UMF_CALL_THROWS(umfPoolGetTag(umfPool, reinterpret_cast<void **>(&poolDesc)));

return poolDesc;
}

ur_result_t urUSMGetMemAllocInfo(
ur_context_handle_t hContext, ///< [in] handle of the context object
const void *ptr, ///< [in] pointer to USM memory object
Expand All @@ -367,48 +380,22 @@ ur_result_t urUSMGetMemAllocInfo(
size_t *pPropValueSizeRet ///< [out][optional] bytes returned in USM
///< allocation property
) try {
ze_device_handle_t zeDeviceHandle;
ZeStruct<ze_memory_allocation_properties_t> zeMemoryAllocationProperties;

// TODO: implement this using UMF once
// https://github.com/oneapi-src/unified-memory-framework/issues/686
// https://github.com/oneapi-src/unified-memory-framework/issues/687
// are implemented
ZE2UR_CALL(zeMemGetAllocProperties,
(hContext->getZeHandle(), ptr, &zeMemoryAllocationProperties,
&zeDeviceHandle));

UrReturnHelper ReturnValue(propValueSize, pPropValue, pPropValueSizeRet);
switch (propName) {
case UR_USM_ALLOC_INFO_TYPE: {
ur_usm_type_t memAllocType;
switch (zeMemoryAllocationProperties.type) {
case ZE_MEMORY_TYPE_UNKNOWN:
memAllocType = UR_USM_TYPE_UNKNOWN;
break;
case ZE_MEMORY_TYPE_HOST:
memAllocType = UR_USM_TYPE_HOST;
break;
case ZE_MEMORY_TYPE_DEVICE:
memAllocType = UR_USM_TYPE_DEVICE;
break;
case ZE_MEMORY_TYPE_SHARED:
memAllocType = UR_USM_TYPE_SHARED;
break;
default:
logger::error("urUSMGetMemAllocInfo: unexpected usm memory type");
return UR_RESULT_ERROR_INVALID_VALUE;
}
return ReturnValue(memAllocType);
auto poolDesc = getPoolDescriptor(ptr);

assert(poolDesc->type != UR_USM_TYPE_UNKNOWN);
return ReturnValue(poolDesc->type);
}
case UR_USM_ALLOC_INFO_DEVICE:
if (zeDeviceHandle) {
auto Platform = hContext->getPlatform();
auto Device = Platform->getDeviceFromNativeHandle(zeDeviceHandle);
return Device ? ReturnValue(Device) : UR_RESULT_ERROR_INVALID_VALUE;
} else {
return UR_RESULT_ERROR_INVALID_VALUE;
}
case UR_USM_ALLOC_INFO_DEVICE: {
auto poolDesc = getPoolDescriptor(ptr);
return ReturnValue(poolDesc->hDevice);
}
// TODO: implement this using UMF once
// https://github.com/oneapi-src/unified-memory-framework/issues/686
// is implemented
case UR_USM_ALLOC_INFO_BASE_PTR: {
void *base;
ZE2UR_CALL(zeMemGetAddressRange,
Expand All @@ -422,9 +409,10 @@ ur_result_t urUSMGetMemAllocInfo(
return ReturnValue(size);
}
case UR_USM_ALLOC_INFO_POOL: {
// TODO
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
default:
auto poolDesc = getPoolDescriptor(ptr);
return ReturnValue(poolDesc->poolHandle);
}
default: {
logger::error("urUSMGetMemAllocInfo: unsupported ParamName");
return UR_RESULT_ERROR_INVALID_VALUE;
}
Expand Down
39 changes: 39 additions & 0 deletions source/common/umf_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,45 @@ static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops,
UMF_RESULT_SUCCESS, pool_unique_handle_t(hPool, umfPoolDestroy)};
}

template <typename Tag>
static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops,
provider_unique_handle_t provider,
void *params, const Tag &tag) {
auto poolTag = new Tag(tag);

umf_memory_pool_handle_t hPool;
auto ret = umfPoolCreate(ops, provider.get(), params,
UMF_POOL_CREATE_FLAG_OWN_PROVIDER, &hPool);
if (ret != UMF_RESULT_SUCCESS) {
return std::pair<umf_result_t, pool_unique_handle_t>{
ret, pool_unique_handle_t(nullptr, nullptr)};
}

ret = umfPoolSetTag(hPool, poolTag, nullptr);
if (ret != UMF_RESULT_SUCCESS) {
umfPoolDestroy(hPool);
return std::pair<umf_result_t, pool_unique_handle_t>{
ret, pool_unique_handle_t(nullptr, nullptr)};
}

provider.release(); // pool now owns the provider

return std::pair<umf_result_t, pool_unique_handle_t>{
UMF_RESULT_SUCCESS,
pool_unique_handle_t(hPool, [](umf_memory_pool_handle_t hPool) {
Tag *tag = nullptr;
umfPoolGetTag(hPool, reinterpret_cast<void **>(&tag));

if (tag) {
delete tag;
} else {
logger::error("Failed to get tag from pool");
}

umfPoolDestroy(hPool);
})};
}

static inline auto providerMakeUniqueFromOps(umf_memory_provider_ops_t *ops,
void *params) {
umf_memory_provider_handle_t hProvider;
Expand Down
1 change: 0 additions & 1 deletion test/conformance/usm/usm_adapter_level_zero_v2.match

This file was deleted.

0 comments on commit 527f2f3

Please sign in to comment.