Skip to content

Commit

Permalink
[SYCL] Remove unnecessary template parameter (#5127)
Browse files Browse the repository at this point in the history
Replacing unnecessary KernelName parameter with a bool value that is actually used in `HostKernel` class reduces the number of instantiated templates and may improve host-side frontend time by ~9%.
  • Loading branch information
alexbatashev authored Dec 13, 2021
1 parent d3649d8 commit cabb43f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 29 deletions.
14 changes: 1 addition & 13 deletions sycl/include/CL/sycl/detail/cg_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class HostTask {
};

// Class which stores specific lambda object.
template <class KernelType, class KernelArgType, int Dims, typename KernelName>
template <class KernelType, class KernelArgType, int Dims, bool StoreLocation>
class HostKernel : public HostKernelBase {
using IDBuilder = sycl::detail::Builder;
KernelType MKernel;
Expand Down Expand Up @@ -290,9 +290,6 @@ class HostKernel : public HostKernelBase {
template <class ArgT = KernelArgType>
typename detail::enable_if_t<std::is_same<ArgT, sycl::id<Dims>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::range<Dims> Range(InitializedVal<Dims, range>::template get<0>());
sycl::id<Dims> Offset;
sycl::range<Dims> Stride(
Expand Down Expand Up @@ -323,9 +320,6 @@ class HostKernel : public HostKernelBase {
typename detail::enable_if_t<
std::is_same<ArgT, item<Dims, /*Offset=*/false>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::id<Dims> ID;
sycl::range<Dims> Range(InitializedVal<Dims, range>::template get<0>());
for (int I = 0; I < Dims; ++I)
Expand All @@ -348,9 +342,6 @@ class HostKernel : public HostKernelBase {
typename detail::enable_if_t<
std::is_same<ArgT, item<Dims, /*Offset=*/true>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::range<Dims> Range(InitializedVal<Dims, range>::template get<0>());
sycl::id<Dims> Offset;
sycl::range<Dims> Stride(
Expand Down Expand Up @@ -380,9 +371,6 @@ class HostKernel : public HostKernelBase {
template <class ArgT = KernelArgType>
typename detail::enable_if_t<std::is_same<ArgT, nd_item<Dims>>::value>
runOnHost(const NDRDescT &NDRDesc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

sycl::range<Dims> GroupSize(InitializedVal<Dims, range>::template get<0>());
for (int I = 0; I < Dims; ++I) {
if (NDRDesc.LocalSize[I] == 0 ||
Expand Down
32 changes: 18 additions & 14 deletions sycl/include/CL/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,22 +576,22 @@ class __SYCL_EXPORT handler {

// For 'id, item w/wo offset, nd_item' kernel arguments
template <class KernelType, class NormalizedKernelType, int Dims,
typename KernelName>
bool StoreLocation>
KernelType *ResetHostKernelHelper(const KernelType &KernelFunc) {
NormalizedKernelType NormalizedKernel(KernelFunc);
auto NormalizedKernelFunc =
std::function<void(const sycl::nd_item<Dims> &)>(NormalizedKernel);
auto HostKernelPtr =
new detail::HostKernel<decltype(NormalizedKernelFunc),
sycl::nd_item<Dims>, Dims, KernelName>(
sycl::nd_item<Dims>, Dims, StoreLocation>(
NormalizedKernelFunc);
MHostKernel.reset(HostKernelPtr);
return &HostKernelPtr->MKernel.template target<NormalizedKernelType>()
->MKernelFunc;
}

// For 'sycl::id<Dims>' kernel argument
template <class KernelType, typename ArgT, int Dims, typename KernelName>
template <class KernelType, typename ArgT, int Dims, bool StoreLocation>
typename std::enable_if<std::is_same<ArgT, sycl::id<Dims>>::value,
KernelType *>::type
ResetHostKernel(const KernelType &KernelFunc) {
Expand All @@ -604,11 +604,11 @@ class __SYCL_EXPORT handler {
}
};
return ResetHostKernelHelper<KernelType, struct NormalizedKernelType, Dims,
KernelName>(KernelFunc);
StoreLocation>(KernelFunc);
}

// For 'sycl::nd_item<Dims>' kernel argument
template <class KernelType, typename ArgT, int Dims, typename KernelName>
template <class KernelType, typename ArgT, int Dims, bool StoreLocation>
typename std::enable_if<std::is_same<ArgT, sycl::nd_item<Dims>>::value,
KernelType *>::type
ResetHostKernel(const KernelType &KernelFunc) {
Expand All @@ -621,11 +621,11 @@ class __SYCL_EXPORT handler {
}
};
return ResetHostKernelHelper<KernelType, struct NormalizedKernelType, Dims,
KernelName>(KernelFunc);
StoreLocation>(KernelFunc);
}

// For 'sycl::item<Dims, without_offset>' kernel argument
template <class KernelType, typename ArgT, int Dims, typename KernelName>
template <class KernelType, typename ArgT, int Dims, bool StoreLocation>
typename std::enable_if<std::is_same<ArgT, sycl::item<Dims, false>>::value,
KernelType *>::type
ResetHostKernel(const KernelType &KernelFunc) {
Expand All @@ -640,11 +640,11 @@ class __SYCL_EXPORT handler {
}
};
return ResetHostKernelHelper<KernelType, struct NormalizedKernelType, Dims,
KernelName>(KernelFunc);
StoreLocation>(KernelFunc);
}

// For 'sycl::item<Dims, with_offset>' kernel argument
template <class KernelType, typename ArgT, int Dims, typename KernelName>
template <class KernelType, typename ArgT, int Dims, bool StoreLocation>
typename std::enable_if<std::is_same<ArgT, sycl::item<Dims, true>>::value,
KernelType *>::type
ResetHostKernel(const KernelType &KernelFunc) {
Expand All @@ -659,7 +659,7 @@ class __SYCL_EXPORT handler {
}
};
return ResetHostKernelHelper<KernelType, struct NormalizedKernelType, Dims,
KernelName>(KernelFunc);
StoreLocation>(KernelFunc);
}

/* 'wrapper'-based approach using 'NormalizedKernelType' struct is
Expand All @@ -669,13 +669,14 @@ class __SYCL_EXPORT handler {
* not supported in ESIMD.
*/
// For 'void' and 'sycl::group<Dims>' kernel argument
template <class KernelType, typename ArgT, int Dims, typename KernelName>
template <class KernelType, typename ArgT, int Dims, bool StoreLocation>
typename std::enable_if<std::is_same<ArgT, void>::value ||
std::is_same<ArgT, sycl::group<Dims>>::value,
KernelType *>::type
ResetHostKernel(const KernelType &KernelFunc) {
MHostKernel.reset(
new detail::HostKernel<KernelType, ArgT, Dims, KernelName>(KernelFunc));
new detail::HostKernel<KernelType, ArgT, Dims, StoreLocation>(
KernelFunc));
return (KernelType *)(MHostKernel->getPtr());
}

Expand All @@ -697,6 +698,9 @@ class __SYCL_EXPORT handler {
template <typename KernelName, typename KernelType, int Dims,
typename LambdaArgType>
void StoreLambda(KernelType KernelFunc) {
using KI = detail::KernelInfo<KernelName>;
constexpr bool StoreLocation = KI::callsAnyThisFreeFunction();

constexpr bool IsCallableWithKernelHandler =
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
LambdaArgType>::value;
Expand All @@ -707,7 +711,7 @@ class __SYCL_EXPORT handler {
PI_INVALID_OPERATION);
}
KernelType *KernelPtr =
ResetHostKernel<KernelType, LambdaArgType, Dims, KernelName>(
ResetHostKernel<KernelType, LambdaArgType, Dims, StoreLocation>(
KernelFunc);

using KI = sycl::detail::KernelInfo<KernelName>;
Expand Down Expand Up @@ -1481,7 +1485,7 @@ class __SYCL_EXPORT handler {

MArgs = std::move(MAssociatedAccesors);
MHostKernel.reset(
new detail::HostKernel<FuncT, void, 1, void>(std::move(Func)));
new detail::HostKernel<FuncT, void, 1, false>(std::move(Func)));
setType(detail::CG::RunOnHostIntel);
}

Expand Down
3 changes: 1 addition & 2 deletions sycl/unittests/scheduler/StreamInitDependencyOnHost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class MockHandler : public sycl::handler {
typename KernelName>
void setHostKernel(KernelType Kernel) {
static_cast<sycl::handler *>(this)->MHostKernel.reset(
new sycl::detail::HostKernel<KernelType, ArgType, Dims, KernelName>(
Kernel));
new sycl::detail::HostKernel<KernelType, ArgType, Dims, false>(Kernel));
}

template <int Dims> void setNDRangeDesc(sycl::nd_range<Dims> Range) {
Expand Down

0 comments on commit cabb43f

Please sign in to comment.