Skip to content

Commit

Permalink
Enable OpGraph traversal from public
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Oct 13, 2023
1 parent 4281953 commit 5902090
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 116 deletions.
11 changes: 5 additions & 6 deletions ark/sched/sched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,17 @@ BaseScheduler::BaseScheduler(Model &model, int gpu_id, int rank_,
(int)(gpu_info.max_threads_per_block / gpu_info.threads_per_warp);
this->num_warps_per_sm = std::min(num_warps_per_sm_, max_warps_per_sm);
this->codegen =
std::make_unique<CodeGenerator>(gpu_info, num_warps_per_sm_);
std::make_unique<CodeGenerator>(gpu_info, this->num_warps_per_sm);
}

void BaseScheduler::init_op_graph() {
this->op_graph = make_unique<OpGraph>(*(this->model));
}

// create context on gpu for the model
GpuMgrCtx *BaseScheduler::create_context(const std::string &name) {
GpuMgrCtx *ctx =
this->gpu_mgr->create_context(name, this->rank, this->world_size);
// sort by buf ID.
std::sort(this->buf_infos.begin(), this->buf_infos.end(),
[](const BufInfo &a, const BufInfo &b) {
return a.tbuf->id < b.tbuf->id;
});
for (BufInfo &bi : this->buf_infos) {
GpuBuf *buf;
if (bi.gpu_id == this->gpu_mgr->gpu_id) {
Expand Down
5 changes: 4 additions & 1 deletion ark/sched/sched.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class BaseScheduler {
virtual std::vector<std::string> gen_code() = 0;

protected:
void init_op_graph();

Model *model;
GpuMgr *gpu_mgr;
int rank;
Expand All @@ -67,6 +69,8 @@ class BaseScheduler {
std::vector<const Op *> send_recv_ops;

GpuMgrCtx *ctx;

std::unique_ptr<OpGraph> op_graph;
};

class DefaultScheduler : public BaseScheduler {
Expand All @@ -93,7 +97,6 @@ class DefaultScheduler : public BaseScheduler {
void recursive_schedule(std::list<OpNode *> &nodes,
std::set<OpNode *> &seen_nodes);

std::unique_ptr<OpGraph> op_graph;
std::vector<std::unique_ptr<SchedStream>> comp_stream;
std::vector<std::unique_ptr<SchedStream>> comm_stream;
};
Expand Down
2 changes: 1 addition & 1 deletion ark/sched/sched/sched_default.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ DefaultScheduler::DefaultScheduler(Model &model, int gpu_id, int rank_,

heuristic_optimize_model(model, model.impl.get(), gpu_info, num_sm_calc);

this->op_graph = make_unique<OpGraph>(model);
this->init_op_graph();
}

void DefaultScheduler::schedule() {
Expand Down
220 changes: 112 additions & 108 deletions ark/sched/sched_opgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,37 @@ void OpGraph::create_nodes(const Model &model) {
recursive_merge(this->nodes_storage, seen_nodes, leaf_nodes);
}

/// Helper of @ref create_nodes().
/// Traverse the model graph and remove virtual Ops that perform no computation.
///
/// @param nodes The list of @ref OpNode.
/// @param boundary_nodes The list of boundary @ref OpNode.
///
void OpGraph::recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes) {
std::vector<OpNode *> OpGraph::get_nodes_in_order() {
std::set<OpNode *> seen_nodes;
std::list<OpNode *> leaf_nodes;
for (auto &node : this->nodes_storage) {
if (node->users.empty()) {
leaf_nodes.emplace_back(node.get());
}
}
std::vector<OpNode *> nodes;
OpGraph::recursive_traverse_internal(
this->nodes_storage, seen_nodes, leaf_nodes, []() {},
[&nodes](OpNode *boundary_node) {
nodes.emplace_back(boundary_node);
return true;
});
// Reverse the order.
std::reverse(nodes.begin(), nodes.end());
return nodes;
}

void OpGraph::recursive_traverse_internal(
std::list<std::unique_ptr<OpNode>> &nodes, std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes,
const std::function<void()> &hook_boundary,
const std::function<bool(OpNode *)> &hook_boundary_node) {
if (boundary_nodes.size() == 0) {
return;
}
OPGRAPH_DEBUG("remove virtual ops");
hook_boundary();
std::list<OpNode *> new_boundary_nodes;
for (auto &boundary_node : boundary_nodes) {
if (boundary_node->ops.size() != 1) {
LOG(ERROR, "unexpected error");
}
OPGRAPH_DEBUG(" boundary node");
OPGRAPH_DEBUG(" op: ", boundary_node->get_name());
for (auto &producer : boundary_node->producers) {
Expand All @@ -164,26 +177,52 @@ void OpGraph::recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
" to next boundary");
new_boundary_nodes.emplace_back(producer);
}
if (boundary_node->ops[0]->is_virtual()) {
OPGRAPH_DEBUG(" remove op: ", boundary_node->get_name());
// Remove this node from the graph.
boundary_node->remove_self();
// Remove this node from the list of nodes.
auto it = std::find_if(
nodes.begin(), nodes.end(),
[boundary_node](const std::unique_ptr<OpNode> &node) {
return node.get() == boundary_node;
});
if (it == nodes.end()) {
LOG(ERROR, "unexpected error");
}
nodes.erase(it);
OPGRAPH_DEBUG(" nodes.size() ", nodes.size());
} else {
bool seen = hook_boundary_node(boundary_node);
if (seen) {
seen_nodes.insert(boundary_node);
}
}
recursive_rm_virt(nodes, seen_nodes, new_boundary_nodes);
recursive_traverse_internal(nodes, seen_nodes, new_boundary_nodes,
hook_boundary, hook_boundary_node);
}

/// Helper of @ref create_nodes().
/// Traverse the model graph and remove virtual Ops that perform no computation.
///
/// @param nodes The list of @ref OpNode.
/// @param boundary_nodes The list of boundary @ref OpNode.
///
void OpGraph::recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes) {
OpGraph::recursive_traverse_internal(
nodes, seen_nodes, boundary_nodes,
[]() { OPGRAPH_DEBUG("remove virtual ops"); },
[&nodes](OpNode *boundary_node) {
bool seen = false;
if (boundary_node->ops.size() != 1) {
LOG(ERROR, "unexpected error");
}
if (boundary_node->ops[0]->is_virtual()) {
OPGRAPH_DEBUG(" remove op: ", boundary_node->get_name());
// Remove this node from the graph.
boundary_node->remove_self();
// Remove this node from the list of nodes.
auto it = std::find_if(
nodes.begin(), nodes.end(),
[boundary_node](const std::unique_ptr<OpNode> &node) {
return node.get() == boundary_node;
});
if (it == nodes.end()) {
LOG(ERROR, "unexpected error");
}
nodes.erase(it);
OPGRAPH_DEBUG(" nodes.size() ", nodes.size());
} else {
seen = true;
}
return seen;
});
}

/// Helper of @ref create_nodes().
Expand All @@ -197,89 +236,54 @@ void OpGraph::recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
void OpGraph::recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes) {
if (boundary_nodes.size() == 0) {
return;
}
OPGRAPH_DEBUG("merge ops");
std::list<OpNode *> new_boundary_nodes;
for (auto &boundary_node : boundary_nodes) {
OPGRAPH_DEBUG(" boundary node");
OPGRAPH_DEBUG(" op: ", boundary_node->get_name());
if (boundary_node->producers.size() == 0) {
// This node is a root.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" root");
continue;
}
// Add all producers of this node to the next boundary.
for (auto &producer : boundary_node->producers) {
// Exception: if any user of the producer (rather than the current
// boundary_node) is unseen, we should not add the producer to the
// next boundary.
bool should_add = true;
for (auto &user : producer->users) {
if (user == boundary_node) {
continue;
}
if (seen_nodes.find(user) == seen_nodes.end()) {
should_add = false;
break;
}
OpGraph::recursive_traverse_internal(
nodes, seen_nodes, boundary_nodes, []() { OPGRAPH_DEBUG("merge ops"); },
[&nodes](OpNode *boundary_node) {
if (boundary_node->producers.size() == 0) {
// This node is a root.
OPGRAPH_DEBUG(" root");
return true;
}
if (!should_add) {
continue;
if (boundary_node->producers.size() > 1) {
// This node has multiple producers. It cannot be merged.
OPGRAPH_DEBUG(" multiple producers");
return true;
}
if (seen_nodes.find(producer) != seen_nodes.end()) {
LOG(ERROR, "unexpected error: circular dependency detected");
// This node has only one producer.
OpNode *producer = *(boundary_node->producers.begin());
if (producer->users.size() == 0) {
LOG(ERROR, "unexpected error: graph is incomplete");
}
new_boundary_nodes.emplace_back(producer);
}
if (boundary_node->producers.size() > 1) {
// This node has multiple producers. It cannot be merged.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" multiple producers");
continue;
}
// This node has only one producer.
OpNode *producer = *(boundary_node->producers.begin());
if (producer->users.size() == 0) {
LOG(ERROR, "unexpected error: graph is incomplete");
}
if (producer->users.size() > 1) {
// The producer has multiple users. It cannot be merged.
seen_nodes.insert(boundary_node);
OPGRAPH_DEBUG(" multiple users");
continue;
}
// The producer has only one user. Merge the two nodes.

// Merge `boundary_node` into `producer`.
OPGRAPH_DEBUG(" merge ops: ", producer->get_name(), " -> ",
boundary_node->get_name());
auto &ops = boundary_node->ops;
producer->ops.insert(producer->ops.end(), ops.begin(), ops.end());
producer->users = boundary_node->users;
for (auto &user : producer->users) {
user->producers.erase(boundary_node);
user->producers.insert(producer);
}
if (producer->users.size() > 1) {
// The producer has multiple users. It cannot be merged.
OPGRAPH_DEBUG(" multiple users");
return true;
}
// The producer has only one user. Merge the two nodes.

// Remove `boundary_node` from `nodes`.
auto it =
std::find_if(nodes.begin(), nodes.end(),
[boundary_node](const std::unique_ptr<OpNode> &node) {
return node.get() == boundary_node;
});
if (it == nodes.end()) {
LOG(ERROR, "unexpected error");
}
nodes.erase(it);
// Merge `boundary_node` into `producer`.
OPGRAPH_DEBUG(" merge ops: ", producer->get_name(), " -> ",
boundary_node->get_name());
auto &ops = boundary_node->ops;
producer->ops.insert(producer->ops.end(), ops.begin(), ops.end());
producer->users = boundary_node->users;
for (auto &user : producer->users) {
user->producers.erase(boundary_node);
user->producers.insert(producer);
}

// Since producer is already in the next boundary and boundary_node is
// merged into producer, we don't need to add anything to
// seen_nodes here.
}
recursive_merge(nodes, seen_nodes, new_boundary_nodes);
// Remove `boundary_node` from `nodes`.
auto it = std::find_if(
nodes.begin(), nodes.end(),
[boundary_node](const std::unique_ptr<OpNode> &node) {
return node.get() == boundary_node;
});
if (it == nodes.end()) {
LOG(ERROR, "unexpected error");
}
nodes.erase(it);
return false;
});
}

OpNode *OpGraph::break_node(OpNode *node, int op_idx) {
Expand Down
18 changes: 18 additions & 0 deletions ark/sched/sched_opgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#ifndef _ARK_SCHED_OPGRAPH_H_
#define _ARK_SCHED_OPGRAPH_H_

#include <functional>
#include <list>
#include <memory>
#include <set>
Expand Down Expand Up @@ -84,13 +85,30 @@ class OpGraph {
/// @return The new @ref OpNode.
OpNode *break_node(OpNode *node, int op_idx);

/// Traverse OpGraph and return the nodes in the order of execution.
///
/// If there are multiple nodes that can be executed at the same time, they
/// will appear in a random order in the returned vector.
///
/// @return The nodes in the order of execution.
std::vector<OpNode *> get_nodes_in_order();

private:
std::list<std::unique_ptr<OpNode>> nodes_storage;

void create_nodes(const Model &model);

static void recursive_traverse_internal(
std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes,
const std::function<void()> &hook_boundary,
const std::function<bool(OpNode *)> &hook_boundary_node);

static void recursive_rm_virt(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes);

static void recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
std::set<OpNode *> &seen_nodes,
const std::list<OpNode *> &boundary_nodes);
Expand Down

0 comments on commit 5902090

Please sign in to comment.