From 59020907f8b0ef83ae6a123fe9f253935a51fd1d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 13 Oct 2023 11:40:18 +0000 Subject: [PATCH] Enable OpGraph traversal from public --- ark/sched/sched.cc | 11 +- ark/sched/sched.h | 5 +- ark/sched/sched/sched_default.cc | 2 +- ark/sched/sched_opgraph.cc | 220 ++++++++++++++++--------------- ark/sched/sched_opgraph.h | 18 +++ 5 files changed, 140 insertions(+), 116 deletions(-) diff --git a/ark/sched/sched.cc b/ark/sched/sched.cc index e64934e93..05b0b3495 100644 --- a/ark/sched/sched.cc +++ b/ark/sched/sched.cc @@ -22,18 +22,17 @@ BaseScheduler::BaseScheduler(Model &model, int gpu_id, int rank_, (int)(gpu_info.max_threads_per_block / gpu_info.threads_per_warp); this->num_warps_per_sm = std::min(num_warps_per_sm_, max_warps_per_sm); this->codegen = - std::make_unique(gpu_info, num_warps_per_sm_); + std::make_unique(gpu_info, this->num_warps_per_sm); +} + +void BaseScheduler::init_op_graph() { + this->op_graph = make_unique(*(this->model)); } // create context on gpu for the model GpuMgrCtx *BaseScheduler::create_context(const std::string &name) { GpuMgrCtx *ctx = this->gpu_mgr->create_context(name, this->rank, this->world_size); - // sort by buf ID. - std::sort(this->buf_infos.begin(), this->buf_infos.end(), - [](const BufInfo &a, const BufInfo &b) { - return a.tbuf->id < b.tbuf->id; - }); for (BufInfo &bi : this->buf_infos) { GpuBuf *buf; if (bi.gpu_id == this->gpu_mgr->gpu_id) { diff --git a/ark/sched/sched.h b/ark/sched/sched.h index c19832d4c..659503824 100644 --- a/ark/sched/sched.h +++ b/ark/sched/sched.h @@ -52,6 +52,8 @@ class BaseScheduler { virtual std::vector gen_code() = 0; protected: + void init_op_graph(); + Model *model; GpuMgr *gpu_mgr; int rank; @@ -67,6 +69,8 @@ class BaseScheduler { std::vector send_recv_ops; GpuMgrCtx *ctx; + + std::unique_ptr op_graph; }; class DefaultScheduler : public BaseScheduler { @@ -93,7 +97,6 @@ class DefaultScheduler : public BaseScheduler { void recursive_schedule(std::list &nodes, std::set &seen_nodes); - std::unique_ptr op_graph; std::vector> comp_stream; std::vector> comm_stream; }; diff --git a/ark/sched/sched/sched_default.cc b/ark/sched/sched/sched_default.cc index 78cb07868..80a87f93f 100644 --- a/ark/sched/sched/sched_default.cc +++ b/ark/sched/sched/sched_default.cc @@ -175,7 +175,7 @@ DefaultScheduler::DefaultScheduler(Model &model, int gpu_id, int rank_, heuristic_optimize_model(model, model.impl.get(), gpu_info, num_sm_calc); - this->op_graph = make_unique(model); + this->init_op_graph(); } void DefaultScheduler::schedule() { diff --git a/ark/sched/sched_opgraph.cc b/ark/sched/sched_opgraph.cc index bff012a32..4ae576022 100644 --- a/ark/sched/sched_opgraph.cc +++ b/ark/sched/sched_opgraph.cc @@ -120,24 +120,37 @@ void OpGraph::create_nodes(const Model &model) { recursive_merge(this->nodes_storage, seen_nodes, leaf_nodes); } -/// Helper of @ref create_nodes(). -/// Traverse the model graph and remove virtual Ops that perform no computation. -/// -/// @param nodes The list of @ref OpNode. -/// @param boundary_nodes The list of boundary @ref OpNode. -/// -void OpGraph::recursive_rm_virt(std::list> &nodes, - std::set &seen_nodes, - const std::list &boundary_nodes) { +std::vector OpGraph::get_nodes_in_order() { + std::set seen_nodes; + std::list leaf_nodes; + for (auto &node : this->nodes_storage) { + if (node->users.empty()) { + leaf_nodes.emplace_back(node.get()); + } + } + std::vector nodes; + OpGraph::recursive_traverse_internal( + this->nodes_storage, seen_nodes, leaf_nodes, []() {}, + [&nodes](OpNode *boundary_node) { + nodes.emplace_back(boundary_node); + return true; + }); + // Reverse the order. + std::reverse(nodes.begin(), nodes.end()); + return nodes; +} + +void OpGraph::recursive_traverse_internal( + std::list> &nodes, std::set &seen_nodes, + const std::list &boundary_nodes, + const std::function &hook_boundary, + const std::function &hook_boundary_node) { if (boundary_nodes.size() == 0) { return; } - OPGRAPH_DEBUG("remove virtual ops"); + hook_boundary(); std::list new_boundary_nodes; for (auto &boundary_node : boundary_nodes) { - if (boundary_node->ops.size() != 1) { - LOG(ERROR, "unexpected error"); - } OPGRAPH_DEBUG(" boundary node"); OPGRAPH_DEBUG(" op: ", boundary_node->get_name()); for (auto &producer : boundary_node->producers) { @@ -164,26 +177,52 @@ void OpGraph::recursive_rm_virt(std::list> &nodes, " to next boundary"); new_boundary_nodes.emplace_back(producer); } - if (boundary_node->ops[0]->is_virtual()) { - OPGRAPH_DEBUG(" remove op: ", boundary_node->get_name()); - // Remove this node from the graph. - boundary_node->remove_self(); - // Remove this node from the list of nodes. - auto it = std::find_if( - nodes.begin(), nodes.end(), - [boundary_node](const std::unique_ptr &node) { - return node.get() == boundary_node; - }); - if (it == nodes.end()) { - LOG(ERROR, "unexpected error"); - } - nodes.erase(it); - OPGRAPH_DEBUG(" nodes.size() ", nodes.size()); - } else { + bool seen = hook_boundary_node(boundary_node); + if (seen) { seen_nodes.insert(boundary_node); } } - recursive_rm_virt(nodes, seen_nodes, new_boundary_nodes); + recursive_traverse_internal(nodes, seen_nodes, new_boundary_nodes, + hook_boundary, hook_boundary_node); +} + +/// Helper of @ref create_nodes(). +/// Traverse the model graph and remove virtual Ops that perform no computation. +/// +/// @param nodes The list of @ref OpNode. +/// @param boundary_nodes The list of boundary @ref OpNode. +/// +void OpGraph::recursive_rm_virt(std::list> &nodes, + std::set &seen_nodes, + const std::list &boundary_nodes) { + OpGraph::recursive_traverse_internal( + nodes, seen_nodes, boundary_nodes, + []() { OPGRAPH_DEBUG("remove virtual ops"); }, + [&nodes](OpNode *boundary_node) { + bool seen = false; + if (boundary_node->ops.size() != 1) { + LOG(ERROR, "unexpected error"); + } + if (boundary_node->ops[0]->is_virtual()) { + OPGRAPH_DEBUG(" remove op: ", boundary_node->get_name()); + // Remove this node from the graph. + boundary_node->remove_self(); + // Remove this node from the list of nodes. + auto it = std::find_if( + nodes.begin(), nodes.end(), + [boundary_node](const std::unique_ptr &node) { + return node.get() == boundary_node; + }); + if (it == nodes.end()) { + LOG(ERROR, "unexpected error"); + } + nodes.erase(it); + OPGRAPH_DEBUG(" nodes.size() ", nodes.size()); + } else { + seen = true; + } + return seen; + }); } /// Helper of @ref create_nodes(). @@ -197,89 +236,54 @@ void OpGraph::recursive_rm_virt(std::list> &nodes, void OpGraph::recursive_merge(std::list> &nodes, std::set &seen_nodes, const std::list &boundary_nodes) { - if (boundary_nodes.size() == 0) { - return; - } - OPGRAPH_DEBUG("merge ops"); - std::list new_boundary_nodes; - for (auto &boundary_node : boundary_nodes) { - OPGRAPH_DEBUG(" boundary node"); - OPGRAPH_DEBUG(" op: ", boundary_node->get_name()); - if (boundary_node->producers.size() == 0) { - // This node is a root. - seen_nodes.insert(boundary_node); - OPGRAPH_DEBUG(" root"); - continue; - } - // Add all producers of this node to the next boundary. - for (auto &producer : boundary_node->producers) { - // Exception: if any user of the producer (rather than the current - // boundary_node) is unseen, we should not add the producer to the - // next boundary. - bool should_add = true; - for (auto &user : producer->users) { - if (user == boundary_node) { - continue; - } - if (seen_nodes.find(user) == seen_nodes.end()) { - should_add = false; - break; - } + OpGraph::recursive_traverse_internal( + nodes, seen_nodes, boundary_nodes, []() { OPGRAPH_DEBUG("merge ops"); }, + [&nodes](OpNode *boundary_node) { + if (boundary_node->producers.size() == 0) { + // This node is a root. + OPGRAPH_DEBUG(" root"); + return true; } - if (!should_add) { - continue; + if (boundary_node->producers.size() > 1) { + // This node has multiple producers. It cannot be merged. + OPGRAPH_DEBUG(" multiple producers"); + return true; } - if (seen_nodes.find(producer) != seen_nodes.end()) { - LOG(ERROR, "unexpected error: circular dependency detected"); + // This node has only one producer. + OpNode *producer = *(boundary_node->producers.begin()); + if (producer->users.size() == 0) { + LOG(ERROR, "unexpected error: graph is incomplete"); } - new_boundary_nodes.emplace_back(producer); - } - if (boundary_node->producers.size() > 1) { - // This node has multiple producers. It cannot be merged. - seen_nodes.insert(boundary_node); - OPGRAPH_DEBUG(" multiple producers"); - continue; - } - // This node has only one producer. - OpNode *producer = *(boundary_node->producers.begin()); - if (producer->users.size() == 0) { - LOG(ERROR, "unexpected error: graph is incomplete"); - } - if (producer->users.size() > 1) { - // The producer has multiple users. It cannot be merged. - seen_nodes.insert(boundary_node); - OPGRAPH_DEBUG(" multiple users"); - continue; - } - // The producer has only one user. Merge the two nodes. - - // Merge `boundary_node` into `producer`. - OPGRAPH_DEBUG(" merge ops: ", producer->get_name(), " -> ", - boundary_node->get_name()); - auto &ops = boundary_node->ops; - producer->ops.insert(producer->ops.end(), ops.begin(), ops.end()); - producer->users = boundary_node->users; - for (auto &user : producer->users) { - user->producers.erase(boundary_node); - user->producers.insert(producer); - } + if (producer->users.size() > 1) { + // The producer has multiple users. It cannot be merged. + OPGRAPH_DEBUG(" multiple users"); + return true; + } + // The producer has only one user. Merge the two nodes. - // Remove `boundary_node` from `nodes`. - auto it = - std::find_if(nodes.begin(), nodes.end(), - [boundary_node](const std::unique_ptr &node) { - return node.get() == boundary_node; - }); - if (it == nodes.end()) { - LOG(ERROR, "unexpected error"); - } - nodes.erase(it); + // Merge `boundary_node` into `producer`. + OPGRAPH_DEBUG(" merge ops: ", producer->get_name(), " -> ", + boundary_node->get_name()); + auto &ops = boundary_node->ops; + producer->ops.insert(producer->ops.end(), ops.begin(), ops.end()); + producer->users = boundary_node->users; + for (auto &user : producer->users) { + user->producers.erase(boundary_node); + user->producers.insert(producer); + } - // Since producer is already in the next boundary and boundary_node is - // merged into producer, we don't need to add anything to - // seen_nodes here. - } - recursive_merge(nodes, seen_nodes, new_boundary_nodes); + // Remove `boundary_node` from `nodes`. + auto it = std::find_if( + nodes.begin(), nodes.end(), + [boundary_node](const std::unique_ptr &node) { + return node.get() == boundary_node; + }); + if (it == nodes.end()) { + LOG(ERROR, "unexpected error"); + } + nodes.erase(it); + return false; + }); } OpNode *OpGraph::break_node(OpNode *node, int op_idx) { diff --git a/ark/sched/sched_opgraph.h b/ark/sched/sched_opgraph.h index e828f8c31..c2bbbc305 100644 --- a/ark/sched/sched_opgraph.h +++ b/ark/sched/sched_opgraph.h @@ -4,6 +4,7 @@ #ifndef _ARK_SCHED_OPGRAPH_H_ #define _ARK_SCHED_OPGRAPH_H_ +#include #include #include #include @@ -84,13 +85,30 @@ class OpGraph { /// @return The new @ref OpNode. OpNode *break_node(OpNode *node, int op_idx); + /// Traverse OpGraph and return the nodes in the order of execution. + /// + /// If there are multiple nodes that can be executed at the same time, they + /// will appear in a random order in the returned vector. + /// + /// @return The nodes in the order of execution. + std::vector get_nodes_in_order(); + private: std::list> nodes_storage; void create_nodes(const Model &model); + + static void recursive_traverse_internal( + std::list> &nodes, + std::set &seen_nodes, + const std::list &boundary_nodes, + const std::function &hook_boundary, + const std::function &hook_boundary_node); + static void recursive_rm_virt(std::list> &nodes, std::set &seen_nodes, const std::list &boundary_nodes); + static void recursive_merge(std::list> &nodes, std::set &seen_nodes, const std::list &boundary_nodes);