diff --git a/source/adapters/level_zero/v2/usm.cpp b/source/adapters/level_zero/v2/usm.cpp index 7559fd9815..d5306494ed 100644 --- a/source/adapters/level_zero/v2/usm.cpp +++ b/source/adapters/level_zero/v2/usm.cpp @@ -127,16 +127,16 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams, if (!poolParams) { auto [ret, poolHandle] = umf::poolMakeUniqueFromOps( - umfProxyPoolOps(), std::move(provider), nullptr); + umfProxyPoolOps(), std::move(provider), nullptr, poolDescriptor); if (ret != UMF_RESULT_SUCCESS) throw umf::umf2urResult(ret); return std::move(poolHandle); } else { auto umfParams = getUmfParamsHandle(*poolParams); - auto [ret, poolHandle] = - umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider), - static_cast(umfParams.get())); + auto [ret, poolHandle] = umf::poolMakeUniqueFromOps( + umfDisjointPoolOps(), std::move(provider), + static_cast(umfParams.get()), poolDescriptor); if (ret != UMF_RESULT_SUCCESS) throw umf::umf2urResult(ret); return std::move(poolHandle); @@ -356,6 +356,19 @@ urUSMFree(ur_context_handle_t hContext, ///< [in] handle of the context object return exceptionToResult(std::current_exception()); } +static usm::pool_descriptor *getPoolDescriptor(const void *ptr) { + auto umfPool = umfPoolByPtr(ptr); + if (!umfPool) { + logger::error("urUSMGetMemAllocInfo: no memory associated with given ptr"); + throw UR_RESULT_ERROR_INVALID_VALUE; + } + + usm::pool_descriptor *poolDesc; + UMF_CALL_THROWS(umfPoolGetTag(umfPool, reinterpret_cast(&poolDesc))); + + return poolDesc; +} + ur_result_t urUSMGetMemAllocInfo( ur_context_handle_t hContext, ///< [in] handle of the context object const void *ptr, ///< [in] pointer to USM memory object @@ -367,48 +380,24 @@ ur_result_t urUSMGetMemAllocInfo( size_t *pPropValueSizeRet ///< [out][optional] bytes returned in USM ///< allocation property ) try { - ze_device_handle_t zeDeviceHandle; - ZeStruct zeMemoryAllocationProperties; - - // TODO: implement this using UMF once - // https://github.com/oneapi-src/unified-memory-framework/issues/686 - // https://github.com/oneapi-src/unified-memory-framework/issues/687 - // are implemented - ZE2UR_CALL(zeMemGetAllocProperties, - (hContext->getZeHandle(), ptr, &zeMemoryAllocationProperties, - &zeDeviceHandle)); UrReturnHelper ReturnValue(propValueSize, pPropValue, pPropValueSizeRet); switch (propName) { case UR_USM_ALLOC_INFO_TYPE: { - ur_usm_type_t memAllocType; - switch (zeMemoryAllocationProperties.type) { - case ZE_MEMORY_TYPE_UNKNOWN: - memAllocType = UR_USM_TYPE_UNKNOWN; - break; - case ZE_MEMORY_TYPE_HOST: - memAllocType = UR_USM_TYPE_HOST; - break; - case ZE_MEMORY_TYPE_DEVICE: - memAllocType = UR_USM_TYPE_DEVICE; - break; - case ZE_MEMORY_TYPE_SHARED: - memAllocType = UR_USM_TYPE_SHARED; - break; - default: - logger::error("urUSMGetMemAllocInfo: unexpected usm memory type"); - return UR_RESULT_ERROR_INVALID_VALUE; + try { + auto poolDesc = getPoolDescriptor(ptr); + return ReturnValue(poolDesc->type); + } catch (...) { + return ReturnValue(UR_USM_TYPE_UNKNOWN); } - return ReturnValue(memAllocType); } - case UR_USM_ALLOC_INFO_DEVICE: - if (zeDeviceHandle) { - auto Platform = hContext->getPlatform(); - auto Device = Platform->getDeviceFromNativeHandle(zeDeviceHandle); - return Device ? ReturnValue(Device) : UR_RESULT_ERROR_INVALID_VALUE; - } else { - return UR_RESULT_ERROR_INVALID_VALUE; - } + case UR_USM_ALLOC_INFO_DEVICE: { + auto poolDesc = getPoolDescriptor(ptr); + return ReturnValue(poolDesc->hDevice); + } + // TODO: implement this using UMF once + // https://github.com/oneapi-src/unified-memory-framework/issues/686 + // is implemented case UR_USM_ALLOC_INFO_BASE_PTR: { void *base; ZE2UR_CALL(zeMemGetAddressRange, @@ -422,9 +411,10 @@ ur_result_t urUSMGetMemAllocInfo( return ReturnValue(size); } case UR_USM_ALLOC_INFO_POOL: { - // TODO - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - default: + auto poolDesc = getPoolDescriptor(ptr); + return ReturnValue(poolDesc->poolHandle); + } + default: { logger::error("urUSMGetMemAllocInfo: unsupported ParamName"); return UR_RESULT_ERROR_INVALID_VALUE; } diff --git a/source/common/umf_helpers.hpp b/source/common/umf_helpers.hpp index 4b7a4a7b6f..352e07836b 100644 --- a/source/common/umf_helpers.hpp +++ b/source/common/umf_helpers.hpp @@ -222,6 +222,45 @@ static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops, UMF_RESULT_SUCCESS, pool_unique_handle_t(hPool, umfPoolDestroy)}; } +template +static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops, + provider_unique_handle_t provider, + void *params, const Tag &tag) { + auto poolTag = new Tag(tag); + + umf_memory_pool_handle_t hPool; + auto ret = umfPoolCreate(ops, provider.get(), params, + UMF_POOL_CREATE_FLAG_OWN_PROVIDER, &hPool); + if (ret != UMF_RESULT_SUCCESS) { + return std::pair{ + ret, pool_unique_handle_t(nullptr, nullptr)}; + } + + ret = umfPoolSetTag(hPool, poolTag, nullptr); + if (ret != UMF_RESULT_SUCCESS) { + umfPoolDestroy(hPool); + return std::pair{ + ret, pool_unique_handle_t(nullptr, nullptr)}; + } + + provider.release(); // pool now owns the provider + + return std::pair{ + UMF_RESULT_SUCCESS, + pool_unique_handle_t(hPool, [](umf_memory_pool_handle_t hPool) { + Tag *tag = nullptr; + umfPoolGetTag(hPool, reinterpret_cast(&tag)); + + if (tag) { + delete tag; + } else { + logger::error("Failed to get tag from pool"); + } + + umfPoolDestroy(hPool); + })}; +} + static inline auto providerMakeUniqueFromOps(umf_memory_provider_ops_t *ops, void *params) { umf_memory_provider_handle_t hProvider; diff --git a/test/conformance/usm/usm_adapter_level_zero_v2.match b/test/conformance/usm/usm_adapter_level_zero_v2.match deleted file mode 100644 index ad8e1888d4..0000000000 --- a/test/conformance/usm/usm_adapter_level_zero_v2.match +++ /dev/null @@ -1 +0,0 @@ -urUSMGetMemAllocInfoTest.Success/*___UR_USM_ALLOC_INFO_POOL