Skip to content

Commit

Permalink
Fix unnecessarily large paddings
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Oct 20, 2023
1 parent 7ff517d commit bb5bfa9
Show file tree
Hide file tree
Showing 26 changed files with 81 additions and 60 deletions.
4 changes: 3 additions & 1 deletion ark/ops/ops_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string AddOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
7 changes: 0 additions & 7 deletions ark/ops/ops_all_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,11 @@ Tensor *Model::all_reduce(Tensor *input, int gpu_id, int gpu_num,
if (output != nullptr) {
LOG(ERROR, "all_reduce output is not supported");
}
if (input->ndims() > 1) {
LOG(ERROR, "supports only 1D input");
}
if (!input->is_sequential()) {
LOG(WARN,
"all_reduce may not work correctly if the input tensor is "
"not contiguous");
}
if (math::pad(input->shape[0], input->pads[0]) < (size_t)input->ldims[0]) {
LOG(ERROR, "all_reduce of a split tensor is not supported");
}

int base = this->impl->next_eid;
Tensor *prev_recv = nullptr;
Tensor *cumulate = input;
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ std::string CastOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_div.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ std::string DivOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_exp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ std::string ExpOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string GeluOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_im2col.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ std::string Im2colOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string LayernormOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::string MatmulOp::function_name(const OpConfig &cfg) const {
Tensor *mat_y = this->outputs[0];

int ndims_y = mat_y->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = mat_y->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = mat_y->ldims.dims4()[3];
CHECK(mat_y->ldims[ndims_y - 1] % tile_out.y == 0);
if (ndims_y > 1) {
CHECK(mat_y->ldims[ndims_y - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ std::string MulOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ std::string ReduceOp::function_name(const OpConfig &cfg,
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string ReluOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_rmsnorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string RMSnormOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_rope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ std::string RopeOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_scale.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string ScaleOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
12 changes: 12 additions & 0 deletions ark/ops/ops_sendrecv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ RecvOp::RecvOp(const std::string &prec_type, Tensor *output, int sid, int rank,

std::string RecvOp::function_name(const OpConfig &) const {
Tensor *output = this->outputs[0];
if (!output->is_sequential()) {
LOG(INFO, "output shape: ", output->shape);
LOG(INFO, "output ldims: ", output->ldims);
LOG(INFO, "output offs: ", output->offs);
LOG(INFO, "output pads: ", output->pads);
}
CHECK(output->is_sequential());

int sid;
Expand Down Expand Up @@ -110,6 +116,9 @@ Tensor *Model::send(Tensor *input, int id, int dst_rank, size_t bytes,
bytes = max_bytes;
}
input->exported = true;
if (!input->is_sequential()) {
LOG(ERROR, "input tensor must be sequential");
}
SendOp op{"none", input, id, this->impl->rank, dst_rank, bytes, name};
return this->impl->add_op(op)[0];
}
Expand Down Expand Up @@ -138,6 +147,9 @@ Tensor *Model::recv(int id, int src_rank, size_t bytes, Tensor *output,
if (bytes == 0) {
bytes = max_bytes;
}
if (!output->is_sequential()) {
LOG(ERROR, "output tensor must be sequential");
}
RecvOp op{"none", output, id, this->impl->rank, src_rank, bytes, name};
return this->impl->add_op(op)[0];
}
Expand Down
8 changes: 6 additions & 2 deletions ark/ops/ops_sendrecv_mm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ std::string SendMMOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down Expand Up @@ -88,7 +90,9 @@ std::string RecvMMOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_sigmoid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ std::string SigmoidOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
5 changes: 4 additions & 1 deletion ark/ops/ops_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ std::string SoftmaxOp::function_name(const OpConfig &cfg) const {
Tensor *input = this->inputs[0];
Tensor *output = this->outputs[0];

const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];

Dims unit_out_dims{1, 1, tile_out.x, tile_out.y};

return Op::function_name("ark::softmax",
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_sqrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ std::string SqrtOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_sub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ std::string SubOp::function_name(const OpConfig &cfg) const {
Tensor *output = this->outputs[0];

int ndims = output->shape.ndims();
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
CHECK(output->ldims[ndims - 1] % tile_out.y == 0);
if (ndims > 1) {
CHECK(output->ldims[ndims - 2] % tile_out.x == 0);
Expand Down
4 changes: 3 additions & 1 deletion ark/ops/ops_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ std::string TransposeOp::function_name(const OpConfig &cfg) const {

Tensor *input = this->inputs[0];
Tensor *output = this->outputs[0];
const OpTile &tile_out = cfg.output_tiles[0];
OpTile tile_out = cfg.output_tiles[0];
if (tile_out.x < 0) tile_out.x = output->ldims.dims4()[2];
if (tile_out.y < 0) tile_out.y = output->ldims.dims4()[3];
Dims unit_out_dims{1, 1, tile_out.x, tile_out.y};

return Op::function_name("ark::transpose" + tp_type_str,
Expand Down
10 changes: 1 addition & 9 deletions ark/sched/sched.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,7 @@ const OpConfig *BaseScheduler::sched_op_config(const Op *op) {
} else {
cfg = high_priority_candidates[0];
}
OpConfig *cfg_new = new OpConfig(*cfg);
OpTile &op_tile = cfg_new->output_tiles[0];
if (op_tile.x == -1) {
op_tile.x = ldims4[2];
}
if (op_tile.y == -1) {
op_tile.y = ldims4[3];
}
return cfg_new;
return cfg;
}

} // namespace ark
20 changes: 1 addition & 19 deletions ark/sched/sched_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace std;
namespace ark {

SchedOp::SchedOp(const Op *op_, const OpConfig *cfg_, const string name)
: op{op_}, cfg{cfg_}, name{name}, tnums{} {
: op{op_}, cfg{cfg_}, name{name} {
if (op_ == nullptr) {
return;
}
Expand Down Expand Up @@ -82,24 +82,6 @@ SchedOp::SchedOp(const Op *op_, const OpConfig *cfg_, const string name)
}
this->op->outputs[i]->update_pads(pads);
}
// calculate the tile size for the SchedOp
if ((this->op->outputs.size() == 1) && (this->cfg != nullptr)) {
const OpTile &tile = this->cfg->output_tiles[0];
const Dims &s = this->op->outputs[0]->shape;
int ndims = s.ndims();
vector<DimType> vec;
if (ndims == 1) {
vec.emplace_back((DimType)math::div_up(s[0], tile.y));
} else {
int i = 0;
for (; i < ndims - 2; ++i) {
vec.emplace_back(s[i]);
}
vec.emplace_back((DimType)math::div_up(s[i], tile.x));
vec.emplace_back((DimType)math::div_up(s[i + 1], tile.y));
}
this->tnums = Dims{vec};
}
}

const string SchedOp::function_name() const {
Expand Down
3 changes: 0 additions & 3 deletions ark/sched/sched_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class SchedOp {
SchedOp(const Op *op_, const OpConfig *cfg_, const std::string name);
const Op *get_op() const { return op; }
int get_num_warps() const { return const_cast<OpConfig *>(cfg)->num_warps; }
const Dims &get_tnums() const { return tnums; }
const std::string &get_name() const { return name; }
const OpConfig *get_cfg() const { return cfg; }
const std::string function_name() const;
Expand All @@ -27,8 +26,6 @@ class SchedOp {
const Op *op;
const OpConfig *cfg;
std::string name;
// The number of tiles along each axis of the operator.
Dims tnums;
};

} // namespace ark
Expand Down
4 changes: 3 additions & 1 deletion ark/sched/sched_opseq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ bool SchedOpSeq::append(const Op *op, const OpConfig *cfg) {
}
if ((op->outputs.size() > 0) && (wn > 0)) {
const Dims &s = op->outputs[0]->shape;
const OpTile &tile = cfg->output_tiles[0];
OpTile tile = cfg->output_tiles[0];
if (tile.x < 0) tile.x = s.dims4()[2];
if (tile.y < 0) tile.y = s.dims4()[3];

int ndims = s.ndims();
assert(ndims != 0);
Expand Down

0 comments on commit bb5bfa9

Please sign in to comment.