Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resource argument to PluginStrategy methods #1775

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/plugin/counter-plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class CounterPlugin :
public RAJA::util::PluginStrategy
{
public:
void preCapture(const RAJA::util::PluginContext& p) override {
if (p.platform == RAJA::Platform::host)
void preCapture(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's double check that we don't need a non-const resource reference.

if (p.platform == RAJA::Platform::host)
{
std::cout << " [CounterPlugin]: Capturing host kernel for the " << ++host_capture_counter << " time!" << std::endl;
}
Expand All @@ -25,7 +25,7 @@ class CounterPlugin :
}
}

void preLaunch(const RAJA::util::PluginContext& p) override {
void preLaunch(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override {
if (p.platform == RAJA::Platform::host)
{
std::cout << " [CounterPlugin]: Launching host kernel for the " << ++host_launch_counter << " time!" << std::endl;
Expand Down
9 changes: 5 additions & 4 deletions include/RAJA/pattern/WorkGroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,16 @@ struct WorkPool<WorkGroupPolicy<EXEC_POLICY_T,
}

util::PluginContext context{util::make_context<exec_policy>()};
util::callPreCapturePlugins(context);
// todo(bowen) do we want default resource here?
util::callPreCapturePlugins(context, resource_type::get_default());
Comment on lines +271 to +272
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's think about this


using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

m_runner.enqueue(
m_storage, std::forward<segment_T>(seg), std::move(body));

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, resource_type::get_default());
}

inline workgroup_type instantiate();
Expand Down Expand Up @@ -497,12 +498,12 @@ WorkGroup<
Args... args)
{
util::PluginContext context{util::make_context<EXEC_POLICY_T>()};
util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

// move any per run storage into worksite
worksite_type site(r, m_runner.run(m_storage, r, std::forward<Args>(args)...));

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);

return site;
}
Expand Down
32 changes: 16 additions & 16 deletions include/RAJA/pattern/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,14 @@ RAJA_INLINE resources::EventProxy<Res> forall_Icount(ExecutionPolicy&& p,
//expt::check_forall_optional_args(loop_body, f_params);

util::PluginContext context{util::make_context<camp::decay<ExecutionPolicy>>()};
util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, r);

using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, r);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

RAJA::resources::EventProxy<Res> e = wrap::forall_Icount(
r,
Expand All @@ -325,7 +325,7 @@ RAJA_INLINE resources::EventProxy<Res> forall_Icount(ExecutionPolicy&& p,
std::move(body),
f_params);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);
return e;
}
template <typename ExecutionPolicy, typename IdxSet, typename LoopBody,
Expand Down Expand Up @@ -364,14 +364,14 @@ forall(ExecutionPolicy&& p, Res r, IdxSet&& c, Params&&... params)
expt::check_forall_optional_args(loop_body, f_params);

util::PluginContext context{util::make_context<camp::decay<ExecutionPolicy>>()};
util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, r);

using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, r);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

resources::EventProxy<Res> e = wrap::forall(
r,
Expand All @@ -380,7 +380,7 @@ forall(ExecutionPolicy&& p, Res r, IdxSet&& c, Params&&... params)
std::move(body),
f_params);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);
return e;
}
template <typename ExecutionPolicy, typename IdxSet, typename LoopBody,
Expand Down Expand Up @@ -457,14 +457,14 @@ forall_Icount(ExecutionPolicy&& p,
//expt::check_forall_optional_args(loop_body, f_params);

util::PluginContext context{util::make_context<camp::decay<ExecutionPolicy>>()};
util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, r);

using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, r);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

resources::EventProxy<Res> e = wrap::forall_Icount(
r,
Expand All @@ -474,7 +474,7 @@ forall_Icount(ExecutionPolicy&& p,
std::move(body),
f_params);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);
return e;
}
template <typename ExecutionPolicy,
Expand Down Expand Up @@ -525,14 +525,14 @@ forall(ExecutionPolicy&& p, Res r, Container&& c, Params&&... params)
expt::check_forall_optional_args(loop_body, f_params);

util::PluginContext context{util::make_context<camp::decay<ExecutionPolicy>>()};
util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, r);

using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, r);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

resources::EventProxy<Res> e = wrap::forall(
r,
Expand All @@ -541,7 +541,7 @@ forall(ExecutionPolicy&& p, Res r, Container&& c, Params&&... params)
std::move(body),
f_params);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);
return e;
}

Expand Down
8 changes: 4 additions & 4 deletions include/RAJA/pattern/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ RAJA_INLINE resources::EventProxy<Resource> kernel_param_resource(SegmentTuple &
camp::decay<Bodies>...>;


util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, resource);

// Create the LoopData object, which contains our policy object,
// our segments, loop bodies, and the tuple of loop indices
Expand All @@ -137,17 +137,17 @@ RAJA_INLINE resources::EventProxy<Resource> kernel_param_resource(SegmentTuple &
resource,
std::forward<Bodies>(bodies)...);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, resource);

using loop_types_t = internal::makeInitialLoopTypes<loop_data_t>;

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, resource);

// Execute!
RAJA_FORCEINLINE_RECURSIVE
internal::execute_statement_list<PolicyType, loop_types_t>(loop_data);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, resource);

return resources::EventProxy<Resource>(resource);
}
Expand Down
46 changes: 24 additions & 22 deletions include/RAJA/pattern/launch/launch_core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,24 +229,24 @@ void launch(LaunchParams const &launch_params, const char *kernel_name, ReducePa
auto&& launch_body = expt::get_lambda(std::forward<ReduceParams>(rest_of_launch_args)...);

//Take the first policy as we assume the second policy is not user defined.
//We rely on the user to pair launch and loop policies correctly.
//We rely on the user to pair launch and loop policies core_protly.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spell correctly?

util::PluginContext context{util::make_context<typename LAUNCH_POLICY::host_policy_t>()};
util::callPreCapturePlugins(context);
using Res = typename resources::get_resource<typename LAUNCH_POLICY::host_policy_t>::type;
util::callPreCapturePlugins(context, Res::get_default());

using RAJA::util::trigger_updates_before;
auto p_body = trigger_updates_before(launch_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, Res::get_default());

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, Res::get_default());

using launch_t = LaunchExecute<typename LAUNCH_POLICY::host_policy_t>;

using Res = typename resources::get_resource<typename LAUNCH_POLICY::host_policy_t>::type;

launch_t::exec(Res::get_default(), launch_params, kernel_name, p_body, reducers);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, Res::get_default());
}


Expand All @@ -264,23 +264,25 @@ void launch(LaunchParams const &launch_params, ReduceParams&&... rest_of_launch_

//Take the first policy as we assume the second policy is not user defined.
//We rely on the user to pair launch and loop policies correctly.
using Res = typename resources::get_resource<typename LAUNCH_POLICY::host_policy_t>::type;
using RAJA::util::trigger_updates_before;

util::PluginContext context{util::make_context<typename LAUNCH_POLICY::host_policy_t>()};
util::callPreCapturePlugins(context);

using RAJA::util::trigger_updates_before;
util::callPreCapturePlugins(context, Res::get_default());

auto p_body = trigger_updates_before(launch_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, Res::get_default());

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, Res::get_default());

using launch_t = LaunchExecute<typename LAUNCH_POLICY::host_policy_t>;

using Res = typename resources::get_resource<typename LAUNCH_POLICY::host_policy_t>::type;

launch_t::exec(Res::get_default(), launch_params, kernel_name, p_body, reducers);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, Res::get_default());
}

//=================================================
Expand Down Expand Up @@ -420,27 +422,27 @@ launch(RAJA::resources::Resource res, LaunchParams const &launch_params,
util::PluginContext context{util::make_context<typename POLICY_LIST::host_policy_t>()};
#endif

util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, res);

using RAJA::util::trigger_updates_before;
auto p_body = trigger_updates_before(launch_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, res);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, res);

switch (place) {
case ExecPlace::HOST: {
using launch_t = LaunchExecute<typename POLICY_LIST::host_policy_t>;
resources::EventProxy<resources::Resource> e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers);
util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, res);
return e_proxy;
}
#if defined(RAJA_GPU_ACTIVE)
case ExecPlace::DEVICE: {
using launch_t = LaunchExecute<typename POLICY_LIST::device_policy_t>;
resources::EventProxy<resources::Resource> e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers);
util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, res);
return e_proxy;
}
#endif
Expand Down Expand Up @@ -488,27 +490,27 @@ launch(RAJA::resources::Resource res, LaunchParams const &launch_params,
util::PluginContext context{util::make_context<typename POLICY_LIST::host_policy_t>()};
#endif

util::callPreCapturePlugins(context);
util::callPreCapturePlugins(context, res);

using RAJA::util::trigger_updates_before;
auto p_body = trigger_updates_before(launch_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, res);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, res);

switch (place) {
case ExecPlace::HOST: {
using launch_t = LaunchExecute<typename POLICY_LIST::host_policy_t>;
resources::EventProxy<resources::Resource> e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers);
util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, res);
return e_proxy;
}
#if defined(RAJA_GPU_ACTIVE)
case ExecPlace::DEVICE: {
using launch_t = LaunchExecute<typename POLICY_LIST::device_policy_t>;
resources::EventProxy<resources::Resource> e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers);
util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, res);
return e_proxy;
}
#endif
Expand Down
20 changes: 10 additions & 10 deletions include/RAJA/policy/MultiPolicy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,21 +174,21 @@ struct policy_invoker : public policy_invoker<index - 1, size, rest...> {
if (offset == size - index - 1) {

util::PluginContext context{util::make_context<Policy>()};
util::callPreCapturePlugins(context);
auto r = resources::get_resource<Policy>::type::get_default();
util::callPreCapturePlugins(context, r);

using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, r);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

using policy::multi::forall_impl;
RAJA_FORCEINLINE_RECURSIVE
auto r = resources::get_resource<Policy>::type::get_default();
forall_impl(r, _p, std::forward<Iterable>(iter), body);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);
} else {
NextInvoker::invoke(offset, std::forward<Iterable>(iter), std::forward<LoopBody>(loop_body));
}
Expand All @@ -205,22 +205,22 @@ struct policy_invoker<0, size, Policy, rest...> {
if (offset == size - 1) {

util::PluginContext context{util::make_context<Policy>()};
util::callPreCapturePlugins(context);
auto r = resources::get_resource<Policy>::type::get_default();
util::callPreCapturePlugins(context, r);

using RAJA::util::trigger_updates_before;
auto body = trigger_updates_before(loop_body);

util::callPostCapturePlugins(context);
util::callPostCapturePlugins(context, r);

util::callPreLaunchPlugins(context);
util::callPreLaunchPlugins(context, r);

//std::cout <<"policy_invoker: No index\n";
using policy::multi::forall_impl;
RAJA_FORCEINLINE_RECURSIVE
auto r = resources::get_resource<Policy>::type::get_default();
forall_impl(r, _p, std::forward<Iterable>(iter), body);

util::callPostLaunchPlugins(context);
util::callPostLaunchPlugins(context, r);
} else {
throw std::runtime_error("unknown offset invoked");
}
Expand Down
Loading
Loading