From a01590b28b3fd54fefe73f54ae1c9b8f54f31e03 Mon Sep 17 00:00:00 2001 From: john bowen Date: Mon, 4 Nov 2024 16:58:18 -0800 Subject: [PATCH] Add resource argument to PluginStrategy methods --- examples/plugin/counter-plugin.cpp | 6 +-- include/RAJA/pattern/WorkGroup.hpp | 9 ++-- include/RAJA/pattern/forall.hpp | 32 +++++++------- include/RAJA/pattern/kernel.hpp | 8 ++-- include/RAJA/pattern/launch/launch_core.hpp | 46 +++++++++++---------- include/RAJA/policy/MultiPolicy.hpp | 20 ++++----- include/RAJA/util/PluginStrategy.hpp | 12 +++--- include/RAJA/util/plugins.hpp | 26 ++++++------ src/PluginStrategy.cpp | 12 +++--- test/integration/plugin/plugin_to_test.cpp | 8 ++-- 10 files changed, 92 insertions(+), 87 deletions(-) diff --git a/examples/plugin/counter-plugin.cpp b/examples/plugin/counter-plugin.cpp index 8134cd9b83..d96e4f8ba8 100644 --- a/examples/plugin/counter-plugin.cpp +++ b/examples/plugin/counter-plugin.cpp @@ -14,8 +14,8 @@ class CounterPlugin : public RAJA::util::PluginStrategy { public: - void preCapture(const RAJA::util::PluginContext& p) override { - if (p.platform == RAJA::Platform::host) + void preCapture(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override { + if (p.platform == RAJA::Platform::host) { std::cout << " [CounterPlugin]: Capturing host kernel for the " << ++host_capture_counter << " time!" << std::endl; } @@ -25,7 +25,7 @@ class CounterPlugin : } } - void preLaunch(const RAJA::util::PluginContext& p) override { + void preLaunch(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override { if (p.platform == RAJA::Platform::host) { std::cout << " [CounterPlugin]: Launching host kernel for the " << ++host_launch_counter << " time!" << std::endl; diff --git a/include/RAJA/pattern/WorkGroup.hpp b/include/RAJA/pattern/WorkGroup.hpp index 767821b8d8..28928cc18e 100644 --- a/include/RAJA/pattern/WorkGroup.hpp +++ b/include/RAJA/pattern/WorkGroup.hpp @@ -268,7 +268,8 @@ struct WorkPool()}; - util::callPreCapturePlugins(context); + // todo(bowen) do we want default resource here? + util::callPreCapturePlugins(context, resource_type::get_default()); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); @@ -276,7 +277,7 @@ struct WorkPool(seg), std::move(body)); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, resource_type::get_default()); } inline workgroup_type instantiate(); @@ -497,12 +498,12 @@ WorkGroup< Args... args) { util::PluginContext context{util::make_context()}; - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); // move any per run storage into worksite worksite_type site(r, m_runner.run(m_storage, r, std::forward(args)...)); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); return site; } diff --git a/include/RAJA/pattern/forall.hpp b/include/RAJA/pattern/forall.hpp index 686f0e8c6b..19b83790d9 100644 --- a/include/RAJA/pattern/forall.hpp +++ b/include/RAJA/pattern/forall.hpp @@ -309,14 +309,14 @@ RAJA_INLINE resources::EventProxy forall_Icount(ExecutionPolicy&& p, //expt::check_forall_optional_args(loop_body, f_params); util::PluginContext context{util::make_context>()}; - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, r); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, r); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); RAJA::resources::EventProxy e = wrap::forall_Icount( r, @@ -325,7 +325,7 @@ RAJA_INLINE resources::EventProxy forall_Icount(ExecutionPolicy&& p, std::move(body), f_params); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); return e; } template >()}; - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, r); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, r); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); resources::EventProxy e = wrap::forall( r, @@ -380,7 +380,7 @@ forall(ExecutionPolicy&& p, Res r, IdxSet&& c, Params&&... params) std::move(body), f_params); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); return e; } template >()}; - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, r); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, r); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); resources::EventProxy e = wrap::forall_Icount( r, @@ -474,7 +474,7 @@ forall_Icount(ExecutionPolicy&& p, std::move(body), f_params); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); return e; } template >()}; - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, r); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, r); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); resources::EventProxy e = wrap::forall( r, @@ -541,7 +541,7 @@ forall(ExecutionPolicy&& p, Res r, Container&& c, Params&&... params) std::move(body), f_params); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); return e; } diff --git a/include/RAJA/pattern/kernel.hpp b/include/RAJA/pattern/kernel.hpp index 1875fe27d9..bd69fa777b 100644 --- a/include/RAJA/pattern/kernel.hpp +++ b/include/RAJA/pattern/kernel.hpp @@ -125,7 +125,7 @@ RAJA_INLINE resources::EventProxy kernel_param_resource(SegmentTuple & camp::decay...>; - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, resource); // Create the LoopData object, which contains our policy object, // our segments, loop bodies, and the tuple of loop indices @@ -137,17 +137,17 @@ RAJA_INLINE resources::EventProxy kernel_param_resource(SegmentTuple & resource, std::forward(bodies)...); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, resource); using loop_types_t = internal::makeInitialLoopTypes; - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, resource); // Execute! RAJA_FORCEINLINE_RECURSIVE internal::execute_statement_list(loop_data); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, resource); return resources::EventProxy(resource); } diff --git a/include/RAJA/pattern/launch/launch_core.hpp b/include/RAJA/pattern/launch/launch_core.hpp index f1d70aeacb..b40fac3806 100644 --- a/include/RAJA/pattern/launch/launch_core.hpp +++ b/include/RAJA/pattern/launch/launch_core.hpp @@ -229,24 +229,24 @@ void launch(LaunchParams const &launch_params, const char *kernel_name, ReducePa auto&& launch_body = expt::get_lambda(std::forward(rest_of_launch_args)...); //Take the first policy as we assume the second policy is not user defined. - //We rely on the user to pair launch and loop policies correctly. + //We rely on the user to pair launch and loop policies core_protly. util::PluginContext context{util::make_context()}; - util::callPreCapturePlugins(context); + using Res = typename resources::get_resource::type; + util::callPreCapturePlugins(context, Res::get_default()); using RAJA::util::trigger_updates_before; auto p_body = trigger_updates_before(launch_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, Res::get_default()); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, Res::get_default()); using launch_t = LaunchExecute; - using Res = typename resources::get_resource::type; launch_t::exec(Res::get_default(), launch_params, kernel_name, p_body, reducers); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, Res::get_default()); } @@ -264,23 +264,25 @@ void launch(LaunchParams const &launch_params, ReduceParams&&... rest_of_launch_ //Take the first policy as we assume the second policy is not user defined. //We rely on the user to pair launch and loop policies correctly. + using Res = typename resources::get_resource::type; + using RAJA::util::trigger_updates_before; + util::PluginContext context{util::make_context()}; - util::callPreCapturePlugins(context); - using RAJA::util::trigger_updates_before; + util::callPreCapturePlugins(context, Res::get_default()); + auto p_body = trigger_updates_before(launch_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, Res::get_default()); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, Res::get_default()); using launch_t = LaunchExecute; - using Res = typename resources::get_resource::type; launch_t::exec(Res::get_default(), launch_params, kernel_name, p_body, reducers); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, Res::get_default()); } //================================================= @@ -420,27 +422,27 @@ launch(RAJA::resources::Resource res, LaunchParams const &launch_params, util::PluginContext context{util::make_context()}; #endif - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, res); using RAJA::util::trigger_updates_before; auto p_body = trigger_updates_before(launch_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, res); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, res); switch (place) { case ExecPlace::HOST: { using launch_t = LaunchExecute; resources::EventProxy e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, res); return e_proxy; } #if defined(RAJA_GPU_ACTIVE) case ExecPlace::DEVICE: { using launch_t = LaunchExecute; resources::EventProxy e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, res); return e_proxy; } #endif @@ -488,27 +490,27 @@ launch(RAJA::resources::Resource res, LaunchParams const &launch_params, util::PluginContext context{util::make_context()}; #endif - util::callPreCapturePlugins(context); + util::callPreCapturePlugins(context, res); using RAJA::util::trigger_updates_before; auto p_body = trigger_updates_before(launch_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, res); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, res); switch (place) { case ExecPlace::HOST: { using launch_t = LaunchExecute; resources::EventProxy e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, res); return e_proxy; } #if defined(RAJA_GPU_ACTIVE) case ExecPlace::DEVICE: { using launch_t = LaunchExecute; resources::EventProxy e_proxy = launch_t::exec(res, launch_params, kernel_name, p_body, reducers); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, res); return e_proxy; } #endif diff --git a/include/RAJA/policy/MultiPolicy.hpp b/include/RAJA/policy/MultiPolicy.hpp index defa08585a..4d40822598 100644 --- a/include/RAJA/policy/MultiPolicy.hpp +++ b/include/RAJA/policy/MultiPolicy.hpp @@ -174,21 +174,21 @@ struct policy_invoker : public policy_invoker { if (offset == size - index - 1) { util::PluginContext context{util::make_context()}; - util::callPreCapturePlugins(context); + auto r = resources::get_resource::type::get_default(); + util::callPreCapturePlugins(context, r); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, r); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); using policy::multi::forall_impl; RAJA_FORCEINLINE_RECURSIVE - auto r = resources::get_resource::type::get_default(); forall_impl(r, _p, std::forward(iter), body); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); } else { NextInvoker::invoke(offset, std::forward(iter), std::forward(loop_body)); } @@ -205,22 +205,22 @@ struct policy_invoker<0, size, Policy, rest...> { if (offset == size - 1) { util::PluginContext context{util::make_context()}; - util::callPreCapturePlugins(context); + auto r = resources::get_resource::type::get_default(); + util::callPreCapturePlugins(context, r); using RAJA::util::trigger_updates_before; auto body = trigger_updates_before(loop_body); - util::callPostCapturePlugins(context); + util::callPostCapturePlugins(context, r); - util::callPreLaunchPlugins(context); + util::callPreLaunchPlugins(context, r); //std::cout <<"policy_invoker: No index\n"; using policy::multi::forall_impl; RAJA_FORCEINLINE_RECURSIVE - auto r = resources::get_resource::type::get_default(); forall_impl(r, _p, std::forward(iter), body); - util::callPostLaunchPlugins(context); + util::callPostLaunchPlugins(context, r); } else { throw std::runtime_error("unknown offset invoked"); } diff --git a/include/RAJA/util/PluginStrategy.hpp b/include/RAJA/util/PluginStrategy.hpp index 3935559bba..dcd0711e0c 100644 --- a/include/RAJA/util/PluginStrategy.hpp +++ b/include/RAJA/util/PluginStrategy.hpp @@ -11,6 +11,8 @@ #include "RAJA/util/PluginContext.hpp" #include "RAJA/util/PluginOptions.hpp" #include "RAJA/util/Registry.hpp" +#include "RAJA/util/resource.hpp" +#include "camp/resource.hpp" namespace RAJA { namespace util { @@ -22,15 +24,15 @@ class PluginStrategy virtual ~PluginStrategy() = default; - virtual RAJASHAREDDLL_API void init(const PluginOptions& p); + virtual RAJASHAREDDLL_API void init(const PluginOptions&); - virtual RAJASHAREDDLL_API void preCapture(const PluginContext& p); + virtual RAJASHAREDDLL_API void preCapture(const PluginContext&, const resources::Resource&); - virtual RAJASHAREDDLL_API void postCapture(const PluginContext& p); + virtual RAJASHAREDDLL_API void postCapture(const PluginContext&, const resources::Resource&); - virtual RAJASHAREDDLL_API void preLaunch(const PluginContext& p); + virtual RAJASHAREDDLL_API void preLaunch(const PluginContext&, const resources::Resource&); - virtual RAJASHAREDDLL_API void postLaunch(const PluginContext& p); + virtual RAJASHAREDDLL_API void postLaunch(const PluginContext&, const resources::Resource&); virtual RAJASHAREDDLL_API void finalize(); }; diff --git a/include/RAJA/util/plugins.hpp b/include/RAJA/util/plugins.hpp index d5f42efde0..08ddf9dea3 100644 --- a/include/RAJA/util/plugins.hpp +++ b/include/RAJA/util/plugins.hpp @@ -30,49 +30,49 @@ RAJA_INLINE auto trigger_updates_before(T&& item) RAJA_INLINE void -callPreCapturePlugins(const PluginContext& p) +callPreCapturePlugins(const PluginContext& p, const resources::Resource& resource) { for (auto plugin = PluginRegistry::begin(); plugin != PluginRegistry::end(); ++plugin) { - (*plugin).get()->preCapture(p); + (*plugin).get()->preCapture(p, resource); } } RAJA_INLINE void -callPostCapturePlugins(const PluginContext& p) +callPostCapturePlugins(const PluginContext& p, const resources::Resource& resource) { for (auto plugin = PluginRegistry::begin(); plugin != PluginRegistry::end(); ++plugin) { - (*plugin).get()->postCapture(p); + (*plugin).get()->postCapture(p, resource); } } RAJA_INLINE void -callPreLaunchPlugins(const PluginContext& p) +callPreLaunchPlugins(const PluginContext& p, const resources::Resource& resource) { for (auto plugin = PluginRegistry::begin(); plugin != PluginRegistry::end(); ++plugin) { - (*plugin).get()->preLaunch(p); + (*plugin).get()->preLaunch(p, resource); } } RAJA_INLINE void -callPostLaunchPlugins(const PluginContext& p) +callPostLaunchPlugins(const PluginContext& p, const resources::Resource& resource) { for (auto plugin = PluginRegistry::begin(); plugin != PluginRegistry::end(); ++plugin) { - (*plugin).get()->postLaunch(p); + (*plugin).get()->postLaunch(p, resource); } } @@ -80,7 +80,7 @@ RAJA_INLINE void callInitPlugins(const PluginOptions p) { - for (auto plugin = PluginRegistry::begin(); + for (auto plugin = PluginRegistry::begin(); plugin != PluginRegistry::end(); ++plugin) { @@ -91,22 +91,22 @@ callInitPlugins(const PluginOptions p) RAJA_INLINE void init_plugins(const std::string& path) -{ +{ callInitPlugins(make_options(path)); } RAJA_INLINE void init_plugins() -{ +{ callInitPlugins(make_options("")); } RAJA_INLINE void finalize_plugins() -{ - for (auto plugin = PluginRegistry::begin(); +{ + for (auto plugin = PluginRegistry::begin(); plugin != PluginRegistry::end(); ++plugin) { diff --git a/src/PluginStrategy.cpp b/src/PluginStrategy.cpp index e39c5718a8..df072230a2 100644 --- a/src/PluginStrategy.cpp +++ b/src/PluginStrategy.cpp @@ -14,17 +14,17 @@ namespace util { PluginStrategy::PluginStrategy() = default; -void PluginStrategy::init(const PluginOptions&) { } +RAJASHAREDDLL_API void PluginStrategy::init(const PluginOptions& p) { } -void PluginStrategy::preCapture(const PluginContext&) { } +RAJASHAREDDLL_API void PluginStrategy::preCapture(const PluginContext&, const RAJA::resources::Resource&) { } -void PluginStrategy::postCapture(const PluginContext&) { } +RAJASHAREDDLL_API void PluginStrategy::postCapture(const PluginContext&, const RAJA::resources::Resource&) { } -void PluginStrategy::preLaunch(const PluginContext&) { } +RAJASHAREDDLL_API void PluginStrategy::preLaunch(const PluginContext&, const RAJA::resources::Resource&) { } -void PluginStrategy::postLaunch(const PluginContext&) { } +RAJASHAREDDLL_API void PluginStrategy::postLaunch(const PluginContext&, const RAJA::resources::Resource&) { } -void PluginStrategy::finalize() { } +RAJASHAREDDLL_API void PluginStrategy::finalize() { } } } diff --git a/test/integration/plugin/plugin_to_test.cpp b/test/integration/plugin/plugin_to_test.cpp index 8290804191..6112fa4b09 100644 --- a/test/integration/plugin/plugin_to_test.cpp +++ b/test/integration/plugin/plugin_to_test.cpp @@ -16,7 +16,7 @@ class CounterPlugin : public RAJA::util::PluginStrategy { public: - void preCapture(const RAJA::util::PluginContext& p) override { + void preCapture(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override { ASSERT_NE(plugin_test_data, nullptr); ASSERT_NE(plugin_test_resource, nullptr); @@ -30,7 +30,7 @@ class CounterPlugin : plugin_test_resource->memcpy(plugin_test_data, &data, sizeof(CounterData)); } - void postCapture(const RAJA::util::PluginContext& p) override { + void postCapture(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override { ASSERT_NE(plugin_test_data, nullptr); ASSERT_NE(plugin_test_resource, nullptr); @@ -44,7 +44,7 @@ class CounterPlugin : plugin_test_resource->memcpy(plugin_test_data, &data, sizeof(CounterData)); } - void preLaunch(const RAJA::util::PluginContext& p) override { + void preLaunch(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override { ASSERT_NE(plugin_test_data, nullptr); ASSERT_NE(plugin_test_resource, nullptr); @@ -58,7 +58,7 @@ class CounterPlugin : plugin_test_resource->memcpy(plugin_test_data, &data, sizeof(CounterData)); } - void postLaunch(const RAJA::util::PluginContext& p) override { + void postLaunch(const RAJA::util::PluginContext& p, const RAJA::resources::Resource&) override { ASSERT_NE(plugin_test_data, nullptr); ASSERT_NE(plugin_test_resource, nullptr);