Skip to content

Commit

Permalink
support padding
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed May 19, 2024
1 parent 2daa05c commit 169c127
Show file tree
Hide file tree
Showing 30 changed files with 473 additions and 372 deletions.
4 changes: 2 additions & 2 deletions ark/api/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ Dims Tensor::offsets() const {
return Dims();

Check warning on line 36 in ark/api/tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/tensor.cpp#L36

Added line #L36 was not covered by tests
}

Dims Tensor::pads() const {
Dims Tensor::padded_shape() const {
if (ref_) {
return ref_->pads();
return ref_->padded_shape();

Check warning on line 41 in ark/api/tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/tensor.cpp#L39-L41

Added lines #L39 - L41 were not covered by tests
}
return Dims();

Check warning on line 43 in ark/api/tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/api/tensor.cpp#L43

Added line #L43 was not covered by tests
}
Expand Down
27 changes: 14 additions & 13 deletions ark/include/ark/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,29 @@ class Model : public ModelGraph {
///
/// @param shape Shape of the tensor, where the data of interest is.
/// @param dtype Type of the tensor data.
/// @param strides Leading dimensions (ldim) of the tensor, which may be
/// @param strides Strides of each dimension of the tensor, which may be
/// different from the shape. @p strides can be considered as the actual
/// shape of the underlying data buffer (@ref TensorBuf).
/// shape of the underlying data buffer (@ref ModelBuffer).
/// @param offsets Offsets of the tensor. The data of interest starts at
/// @p offsets and ends at @p offsets + @p shape.
/// @param pads If a dimension of @p pads is set to larger than 1, the
/// corresponding ldim will be set to the minimum multiple of @p pads that
/// is larger than or equal to the previous ldim. Padding is accumulated
/// across all tensors that share the same @ref TensorBuf. For example, if
/// one tensor sets the last dimension of @p pads to 2, and another tensor
/// sets the last dimension of @p pads to 3, then the corresponding ldim
/// will be the minimum multiple of 2x3=6 that is larger than or equal to
/// the corresponding dimension of @p offsets + @p shape.
/// @p offsets and ends at @p offsets + @p padded_shape.
/// @param padded_shape Padded shape of the tensor. Padding is used to
/// reserve extra space for the tensor when computation requires it.
/// Data on the padded region is allowed to be accessed by computation,
/// but it is not considered as the data of interest. The padded region is
/// initialized to zero only once when the Executor is launched. The padded
/// shape should be greater than or equal to the @p shape, and the
/// @p strides should be greater than or equal to the padded shape. If the
/// @p strides are not provided, they are set to the padded shape. If the
/// padded shape is not provided, it is set to the @p shape.
/// @param name Name of the tensor.
/// @return The created tensor object.
///
Tensor tensor(const Dims &shape, const DataType &data_type,
const Dims &strides = {}, const Dims &offsets = {},
const Dims &pads = {}, const std::string &name = "");
const Dims &padded_shape = {}, const std::string &name = "");

Tensor refer(Tensor input, const Dims &shape = {}, const Dims &strides = {},
const Dims &offsets = {}, const Dims &pads = {},
const Dims &offsets = {}, const Dims &padded_shape = {},
const std::string &name = "");

// Reshape `input` to `shape`. If one dimension of `shape` is -1, it will be
Expand Down
2 changes: 1 addition & 1 deletion ark/include/ark/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Tensor {

Dims offsets() const;

Dims pads() const;
Dims padded_shape() const;

const DataType &data_type() const;
};
Expand Down
59 changes: 28 additions & 31 deletions ark/model/model_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ark {

ModelTensor::ModelTensor(ModelDataType data_type, ModelBufferRef buffer,
const Dims &shape, const Dims &strides,
const Dims &offsets, const Dims &pads)
const Dims &offsets, const Dims &padded_shape)
: data_type_(data_type), buffer_(buffer) {
if (shape.nelems() == 0) {
ERR(InvalidUsageError,

Check warning on line 18 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L18

Added line #L18 was not covered by tests
Expand All @@ -24,8 +24,26 @@ ModelTensor::ModelTensor(ModelDataType data_type, ModelBufferRef buffer,
shape_ = shape;
}
int ndims = shape_.ndims();
if (padded_shape.is_no_dim()) {
padded_shape_ = shape_;
} else {
if (ndims != padded_shape.ndims()) {
ERR(InvalidUsageError,

Check warning on line 31 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L31

Added line #L31 was not covered by tests
"Tensor shape and padded shape should have the same number of "
"dimensions. Given: shape ",
shape_, " padded_shape ", padded_shape);
}
padded_shape_ = padded_shape;
}
for (int i = 0; i < ndims; ++i) {
if (shape_[i] > padded_shape_[i]) {
ERR(InvalidUsageError,

Check warning on line 40 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L40

Added line #L40 was not covered by tests
"Tensor shape exceeds the padded shape. shape ", shape_,
" padded_shape ", padded_shape_);
}
}
if (strides.is_no_dim()) {
strides_ = shape_;
strides_ = padded_shape_;
} else {
if (ndims != strides.ndims()) {
ERR(InvalidUsageError,

Check warning on line 49 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L49

Added line #L49 was not covered by tests
Expand All @@ -50,32 +68,11 @@ ModelTensor::ModelTensor(ModelDataType data_type, ModelBufferRef buffer,
}
offsets_ = offsets;
}
if (pads.is_no_dim()) {
std::vector<DimType> dims_vec;
for (int i = 0; i < ndims; ++i) {
dims_vec.push_back(1);
}
pads_ = Dims{dims_vec};
} else {
if (ndims != pads.ndims()) {
ERR(InvalidUsageError,
"Tensor shape and pads should have the same number of "
"dimensions. Given: shape ",
shape_, " pads ", pads);
}
pads_ = pads;
}
for (int i = 0; i < ndims; ++i) {
if (strides_[i] % pads_[i] != 0) {
ERR(InvalidUsageError,
"Tensor strides should be a multiple of pads. strides ",
strides_, " pads ", pads_);
}
}
for (int i = 0; i < ndims; ++i) {
if (offsets_[i] + shape_[i] > strides_[i]) {
if (offsets_[i] + padded_shape_[i] > strides_[i]) {
ERR(InvalidUsageError, "Tensor exceeds the memory boundary. offs ",

Check warning on line 73 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L73

Added line #L73 was not covered by tests
offsets_, " shape ", shape_, " strides ", strides_);
offsets_, " padded_shape ", padded_shape_, " strides ",
strides_);
}
}
id_ = next_id();
Expand All @@ -88,7 +85,7 @@ ModelTensor::ModelTensor(const ModelTensor &other) {
shape_ = other.shape_;
strides_ = other.strides_;
offsets_ = other.offsets_;
pads_ = other.pads_;
padded_shape_ = other.padded_shape_;
}

size_t ModelTensor::shape_bytes() const {
Expand All @@ -103,7 +100,7 @@ Json ModelTensor::serialize() const {
j["Shape"] = shape_.vector();
j["Strides"] = strides_.vector();
j["Offsets"] = offsets_.vector();
j["Pads"] = pads_.vector();
j["PaddedShape"] = padded_shape_.vector();
return j;
}

Expand All @@ -123,9 +120,9 @@ std::shared_ptr<ModelTensor> ModelTensor::deserialize(const Json &serialized) {
} else if (!serialized.contains("Offsets")) {
ERR(InvalidUsageError,

Check warning on line 121 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L121

Added line #L121 was not covered by tests
"ModelTensor deserialization failed: missing Offsets");
} else if (!serialized.contains("Pads")) {
} else if (!serialized.contains("PaddedShape")) {
ERR(InvalidUsageError,

Check warning on line 124 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L124

Added line #L124 was not covered by tests
"ModelTensor deserialization failed: missing Pads");
"ModelTensor deserialization failed: missing PaddedShape");
} else if (!serialized.contains("Id")) {
ERR(InvalidUsageError,

Check warning on line 127 in ark/model/model_tensor.cpp

View check run for this annotation

Codecov / codecov/patch

ark/model/model_tensor.cpp#L127

Added line #L127 was not covered by tests
"ModelTensor deserialization failed: missing Id");
Expand All @@ -136,7 +133,7 @@ std::shared_ptr<ModelTensor> ModelTensor::deserialize(const Json &serialized) {
serialized["Shape"].get<std::vector<DimType>>(),
serialized["Strides"].get<std::vector<DimType>>(),
serialized["Offsets"].get<std::vector<DimType>>(),
serialized["Pads"].get<std::vector<DimType>>());
serialized["PaddedShape"].get<std::vector<DimType>>());
ret->id_ = serialized["Id"];
return ret;
}
Expand Down
6 changes: 3 additions & 3 deletions ark/model/model_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ModelTensor {
public:
ModelTensor(ModelDataType data_type, ModelBufferRef buffer,
const Dims &shape, const Dims &strides = {},
const Dims &offsets = {}, const Dims &pads = {});
const Dims &offsets = {}, const Dims &padded_shape = {});

ModelTensor(const ModelTensor &other);

Expand All @@ -33,7 +33,7 @@ class ModelTensor {

const Dims &offsets() const { return offsets_; }

const Dims &pads() const { return pads_; }
const Dims &padded_shape() const { return padded_shape_; }

size_t shape_bytes() const;

Expand All @@ -50,7 +50,7 @@ class ModelTensor {
Dims shape_;
Dims strides_;
Dims offsets_;
Dims pads_;
Dims padded_shape_;
};

} // namespace ark
Expand Down
44 changes: 19 additions & 25 deletions ark/ops/ops_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ModelOpCast::ModelOpCast(ModelTensorRef input, ModelDataType data_type,

static void byte_cast_helper(ModelTensorRef input, ModelDataType data_type,
Dims &new_shape, Dims &new_strides,
Dims &new_offsets, Dims &new_pads) {
Dims &new_offsets, Dims &new_padded_shape) {
if (input->data_type() == BYTE.ref()) {
if (input->shape_bytes() < data_type->bytes()) {
ERR(InvalidUsageError, "input tensor is too small to be casted to ",
Expand Down Expand Up @@ -64,48 +64,42 @@ static void byte_cast_helper(ModelTensorRef input, ModelDataType data_type,
"tensor type (",
data_type->bytes(), ")");
}
if (input->pads()[last_dim] > 1) {
// we can ignore pads if it is 1
if ((input->pads()[last_dim] % data_type->bytes()) != 0) {
ERR(InvalidUsageError,
"the last greater-than-1 dimension of the "
"input tensor pads ",
input->pads()[last_dim],
" is not divisible by the size of the output "
"tensor type (",
data_type->bytes(), ")");
}
if ((input->padded_shape()[last_dim] % data_type->bytes()) != 0) {
ERR(InvalidUsageError,

Check warning on line 68 in ark/ops/ops_cast.cpp

View check run for this annotation

Codecov / codecov/patch

ark/ops/ops_cast.cpp#L68

Added line #L68 was not covered by tests
"the last greater-than-1 dimension of the "
"input tensor padded_shape ",
input->padded_shape()[last_dim],
" is not divisible by the size of the output "
"tensor type (",
data_type->bytes(), ")");
}
new_shape = input->shape();
new_strides = input->strides();
new_offsets = input->offsets();
new_pads = input->pads();
new_padded_shape = input->padded_shape();
new_shape[last_dim] /= data_type->bytes();
new_strides[last_dim] /= data_type->bytes();
new_offsets[last_dim] /= data_type->bytes();
if (new_pads[last_dim] > 1) {
new_pads[last_dim] /= data_type->bytes();
}
new_padded_shape[last_dim] /= data_type->bytes();
} else if (data_type == BYTE.ref()) {
new_shape = input->shape();
new_strides = input->strides();
new_offsets = input->offsets();
new_pads = input->pads();
new_padded_shape = input->padded_shape();
new_shape[-1] *= input->data_type()->bytes();
new_strides[-1] *= input->data_type()->bytes();
new_offsets[-1] *= input->data_type()->bytes();
if (new_pads[-1] > 1) {
new_pads[-1] *= input->data_type()->bytes();
}
new_padded_shape[-1] *= input->data_type()->bytes();
} else {
ERR(ModelError, "unexpected error");

Check warning on line 94 in ark/ops/ops_cast.cpp

View check run for this annotation

Codecov / codecov/patch

ark/ops/ops_cast.cpp#L94

Added line #L94 was not covered by tests
}
}

ModelOpByteCast::ModelOpByteCast(ModelTensorRef input, ModelDataType data_type,
const Dims &shape, const Dims &strides,
const Dims &offsets, const Dims &pads)
: ModelOpTensor(input->buffer(), shape, data_type, strides, offsets, pads) {
const Dims &offsets, const Dims &padded_shape)
: ModelOpTensor(input->buffer(), shape, data_type, strides, offsets,
padded_shape) {
read_tensors_ = {input};
verify();
}
Expand All @@ -121,13 +115,13 @@ Tensor Model::cast(Tensor input, const DataType &data_type, Tensor output,
} else if (data_type == BYTE || input.data_type() == BYTE) {
// Casting to/from BYTE without the output tensor specified is
// handled by `ModelOpByteCast`.
Dims new_shape, new_strides, new_offsets, new_pads;
Dims new_shape, new_strides, new_offsets, new_padded_shape;
byte_cast_helper(input.ref(), data_type.ref(), new_shape,
new_strides, new_offsets, new_pads);
new_strides, new_offsets, new_padded_shape);
return impl_
->create_op<ModelOpByteCast>(name, input.ref(), data_type.ref(),
new_shape, new_strides,
new_offsets, new_pads)
new_offsets, new_padded_shape)
->result_tensors()[0];
}
}
Expand Down
2 changes: 1 addition & 1 deletion ark/ops/ops_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ModelOpByteCast : public ModelOpTensor {
ModelOpByteCast() = default;
ModelOpByteCast(ModelTensorRef input, ModelDataType data_type,
const Dims &shape, const Dims &strides, const Dims &offsets,
const Dims &pads);
const Dims &padded_shape);
};

} // namespace ark
Expand Down
2 changes: 1 addition & 1 deletion ark/ops/ops_cast_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ ark::unittest::State test_cast_invalid() {
ark::Model m;
ark::Tensor t0 = m.tensor({8, 1}, ark::BYTE);
m.cast(t0, ark::FP32); // ok
ark::Tensor t1 = m.tensor({8, 1}, ark::BYTE, {9, 1}, {0, 0}, {3, 1});
ark::Tensor t1 = m.tensor({8, 1}, ark::BYTE, {13, 1}, {0, 0}, {9, 1});
UNITTEST_THROW(m.cast(t1, ark::FP32), ark::InvalidUsageError);
}
{
Expand Down
15 changes: 15 additions & 0 deletions ark/ops/ops_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ void check_match_shape(ModelTensorRef tensor, const Dims &shape) {
}
}

/// Check that two tensors have identical padded shapes.
///
/// Compares `a->padded_shape()` with `b->padded_shape()` and, when they
/// differ, invokes `ERR` with `InvalidUsageError` and a message that lists
/// both padded shapes.
///
/// @param a First tensor to compare.
/// @param b Second tensor to compare.
void check_match_padded_shape(ModelTensorRef a, ModelTensorRef b) {
if (a->padded_shape() != b->padded_shape()) {
ERR(InvalidUsageError, "padded shapes mismatch: ", a->padded_shape(),

Check warning on line 53 in ark/ops/ops_common.cpp

View check run for this annotation

Codecov / codecov/patch

ark/ops/ops_common.cpp#L51-L53

Added lines #L51 - L53 were not covered by tests
" != ", b->padded_shape());
}
}

/// Check that a tensor's padded shape equals an expected padded shape.
///
/// Compares `tensor->padded_shape()` with @p padded_shape and, when they
/// differ, invokes `ERR` with `InvalidUsageError` and a message that lists
/// both values.
///
/// @param tensor Tensor whose padded shape is checked.
/// @param padded_shape Expected padded shape.
void check_match_padded_shape(ModelTensorRef tensor, const Dims &padded_shape) {
if (tensor->padded_shape() != padded_shape) {
ERR(InvalidUsageError,

Check warning on line 60 in ark/ops/ops_common.cpp

View check run for this annotation

Codecov / codecov/patch

ark/ops/ops_common.cpp#L60

Added line #L60 was not covered by tests
"padded shape mismatch: ", tensor->padded_shape(),
" != ", padded_shape);
}
}

Dims broadcast_shape(const Dims &dims1, const Dims &dims2) {
std::vector<DimType> output_dims_reversed;
int ndims = std::max(dims1.ndims(), dims2.ndims());
Expand Down
4 changes: 4 additions & 0 deletions ark/ops/ops_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ void check_match_shape(ModelTensorRef a, ModelTensorRef b);

void check_match_shape(ModelTensorRef tensor, const Dims &shape);

/// Check that the padded shapes of tensors @p a and @p b are equal;
/// reports an InvalidUsageError otherwise.
void check_match_padded_shape(ModelTensorRef a, ModelTensorRef b);

/// Check that the padded shape of @p tensor equals @p padded_shape;
/// reports an InvalidUsageError otherwise.
void check_match_padded_shape(ModelTensorRef tensor, const Dims &padded_shape);

/// Return the output shape of broadcasting between two shapes.
/// Follow NumPy rules.
/// https://numpy.org/doc/stable/user/basics.broadcasting.html
Expand Down
6 changes: 4 additions & 2 deletions ark/ops/ops_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ ModelOpSend::ModelOpSend(ModelTensorRef input, int remote_rank, int tag,
} else {
output = std::make_shared<ModelTensor>(
input->data_type(), std::make_shared<ModelBuffer>(remote_rank),
input->shape(), input->strides(), input->offsets(), input->pads());
input->shape(), input->strides(), input->offsets(),
input->padded_shape());
}
input->buffer()->tag_send(remote_rank, tag);
output->buffer()->tag_recv(-1, tag);
Expand Down Expand Up @@ -87,7 +88,8 @@ ModelOpRecv::ModelOpRecv(ModelTensorRef output, int remote_rank, int tag)
ModelTensorRef result = std::make_shared<ModelTensor>(*output);
ModelTensorRef input = std::make_shared<ModelTensor>(
output->data_type(), std::make_shared<ModelBuffer>(remote_rank),
output->shape(), output->strides(), output->offsets(), output->pads());
output->shape(), output->strides(), output->offsets(),
output->padded_shape());
input->buffer()->tag_send(-1, tag);
output->buffer()->tag_recv(remote_rank, tag);

Expand Down
2 changes: 1 addition & 1 deletion ark/ops/ops_identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ark {
ModelOpIdentity::ModelOpIdentity(ModelTensorRef input,
const std::vector<ModelTensorRef> &deps)
: ModelOpTensor(input->buffer(), input->shape(), input->data_type(),
input->strides(), input->offsets(), input->pads()) {
input->strides(), input->offsets(), input->padded_shape()) {
std::set<ModelTensorRef> dep_set;
dep_set.emplace(input);
read_tensors_.emplace_back(input);
Expand Down
Loading

0 comments on commit 169c127

Please sign in to comment.