Skip to content

Commit

Permalink
byte casting
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 25, 2023
1 parent 2d03ec4 commit a97a562
Show file tree
Hide file tree
Showing 10 changed files with 404 additions and 111 deletions.
8 changes: 6 additions & 2 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
{
"configurations": [
{
"name": "ops_matmul_test",
"name": "ops_cast_test",
"type": "cppdbg",
"request": "launch",
"program": "${workspaceFolder}/build/ark/ops_matmul_test.cu",
"program": "${workspaceFolder}/build/ark/ops_cast_test",
"args": [],
"stopAtEntry": false,
"cwd": "${fileDirname}",
"environment": [
{
"name": "ARK_ROOT",
"value": "${workspaceFolder}/build"
},
{
"name": "ARK_LOG_LEVEL",
"value": "DEBUG"
}
],
"externalConsole": false,
Expand Down
69 changes: 56 additions & 13 deletions ark/include/ark.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,59 @@ class TensorBuf
friend class BaseScheduler;
};

// Type of tensor data.
typedef enum
/// Type of tensor data.
class TensorType
{
FP16,
FP32,
INT32,
BYTE,
} TensorType;
private:
const int id_;
const int bytes_;
const std::string name_;
const std::string pointer_name_;

public:
TensorType(int id = -1, int bytes = 0, const std::string &name = "none",
const std::string &pointer_name = "void *");

bool operator==(const TensorType &other) const;
bool operator!=(const TensorType &other) const;

int id() const;
int bytes() const;
const std::string &name() const;
const std::string &pointer_name() const;
};

/// TensorType subclass that serves as the type of the `FP16` constant.
/// Only the default constructor is declared here; its definition (and thus the
/// id/bytes/name values it hands to the TensorType base) is not visible in
/// this header.
/// NOTE(review): presumably a 16-bit floating point type -- confirm against
/// the constructor definition.
class Fp16 : public TensorType
{
public:
Fp16();
};

/// TensorType subclass that serves as the type of the `FP32` constant.
/// Only the default constructor is declared here; its definition (and thus the
/// id/bytes/name values it hands to the TensorType base) is not visible in
/// this header.
/// NOTE(review): presumably a 32-bit floating point type -- confirm against
/// the constructor definition.
class Fp32 : public TensorType
{
public:
Fp32();
};

/// TensorType subclass that serves as the type of the `INT32` constant.
/// Only the default constructor is declared here; its definition (and thus the
/// id/bytes/name values it hands to the TensorType base) is not visible in
/// this header.
/// NOTE(review): presumably a 32-bit signed integer type -- confirm against
/// the constructor definition.
class Int32 : public TensorType
{
public:
Int32();
};

/// TensorType subclass that serves as the type of the `BYTE` constant.
/// Only the default constructor is declared here; its definition (and thus the
/// id/bytes/name values it hands to the TensorType base) is not visible in
/// this header.
/// NOTE(review): presumably a 1-byte raw data type -- confirm against the
/// constructor definition.
class Byte : public TensorType
{
public:
Byte();
};

/// Predefined tensor type constants.
/// `NONE` is a default-constructed TensorType, so it carries the constructor
/// defaults (id -1, 0 bytes, name "none", pointer name "void *").
/// NOTE(review): these are non-inline `const` objects defined in a header. If
/// they sit at namespace scope they have internal linkage, so every
/// translation unit gets its own copies -- compare them with `==` / `!=`
/// rather than by address.
const TensorType NONE;
const Fp16 FP16;
const Fp32 FP32;
const Int32 INT32;
const Byte BYTE;

std::ostream &operator<<(std::ostream &os, TensorType type);
std::ostream &operator<<(std::ostream &os, const TensorType &type);

/// Tensor is a view of a TensorBuf.
///
Expand All @@ -134,7 +177,7 @@ class Tensor
{
public:
/// Tensor constructor.
Tensor(const Dims &shape, TensorType type, TensorBuf *buf,
Tensor(const Dims &shape, const TensorType &type, TensorBuf *buf,
const Dims &ldims, const Dims &offs, const Dims &pads, bool exported,
int imported_rank, int id, const std::string &name);
Tensor(const Tensor &) = default;
Expand Down Expand Up @@ -273,7 +316,7 @@ class Model
/// Returns a tensor object.
///
/// @param shape Shape of the tensor, where the data of interest is.
/// @param type Type of the tensor data.
/// @param ttype Type of the tensor data.
/// @param buf The @ref TensorBuf that holds the entire data including the
/// padding.
/// @param ldims Leading dimensions (ldim) of the tensor, which may be
Expand All @@ -300,7 +343,7 @@ class Model
/// @param name Name of the tensor.
/// @return Pointer to a tensor object.
///
Tensor *tensor(const Dims &shape, TensorType dtype,
Tensor *tensor(const Dims &shape, const TensorType &ttype,
TensorBuf *buf = nullptr, const Dims &ldims = {},
const Dims &offs = {}, const Dims &pads = {},
const std::vector<Tensor *> &deps = {},
Expand Down Expand Up @@ -472,8 +515,8 @@ class Model
Tensor *embedding(Tensor *input, Tensor *weight, Tensor *output = nullptr,
const std::string &name = "embedding");
/// Tensor type casting.
Tensor *cast(Tensor *input, TensorType ttype, Tensor *output = nullptr,
const std::string &name = "cast");
/// Cast @p input to the tensor type @p ttype.
///
/// Behavior as implemented in ark/ops/ops_cast.cc:
/// - @p output is nullptr and the input type equals @p ttype: returns an
///   identity of @p input.
/// - @p output is nullptr and the input type is BYTE: no cast op is added; a
///   tensor on the same TensorBuf is returned whose shape/ldims/offs (and
///   pads, when greater than 1) in the selected last dimension are divided by
///   `ttype.bytes()`. An error is reported via LOG(ERROR, ...) if the input
///   is too small or those values are not divisible by `ttype.bytes()`.
/// - @p output is nullptr and @p ttype is BYTE: no cast op is added; a tensor
///   on the same TensorBuf is returned whose last-dimension shape/ldims/offs
///   (and pads, when greater than 1) are multiplied by the input type's byte
///   size.
/// - @p output is nullptr otherwise: a new tensor of the input's shape and
///   type @p ttype is created and a cast op writing to it is added.
/// - @p output is given: an error is reported via LOG(ERROR, ...) if its type
///   differs from @p ttype, its shape differs from the input's, the input is
///   already of type @p ttype, or @p ttype is BYTE; a cast op is then added.
///
/// @param input Tensor to cast; must not be nullptr (asserted).
/// @param ttype Target tensor data type.
/// @param output Optional destination tensor.
/// @param name Name of the operator.
/// @return Pointer to the resulting tensor.
Tensor *cast(Tensor *input, const TensorType &ttype,
Tensor *output = nullptr, const std::string &name = "cast");

/// Verify if this model is valid.
/// @return true if the model is valid, false otherwise.
Expand Down
109 changes: 103 additions & 6 deletions ark/ops/ops_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,114 @@ std::string CastOp::function_name(const OpConfig &cfg) const
}});
}

Tensor *Model::cast(Tensor *input, TensorType ttype, Tensor *output,
Tensor *Model::cast(Tensor *input, const TensorType &ttype, Tensor *output,
const std::string &name)
{
assert(input != nullptr);
if (output != nullptr && ttype != output->type) {
LOG(ERROR, "invalid output data type: ", output->type);
}
if (output == nullptr) {
if (input->type == ttype) {
// Casting to the same type is considered as an identity,
// only when the output tensor is not specified.
return this->identity(input, {}, name);
}
if (input->type == BYTE) {
// Casting BYTE to other types is considered as a reshape.
if (input->shape_bytes() < ttype.bytes()) {
LOG(ERROR, "input tensor is too small to be casted to ", ttype);
}
// The last greater-than-1 dimension of the input tensor should be
// divisible by the size of the output type.
int last_dim = input->shape.ndims() - 1;
for (; last_dim >= 0; --last_dim) {
if (last_dim == 0 || input->ldims[last_dim] > 1) {
break;
}
}
if ((input->shape[last_dim] % ttype.bytes()) != 0) {
LOG(ERROR,
"the last greater-than-1 dimension of the "
"input tensor shape ",
input->shape[last_dim],
" is not divisible by the size of the output "
"tensor type (",
ttype.bytes(), ")");
}
if ((input->ldims[last_dim] % ttype.bytes()) != 0) {
LOG(ERROR,
"the last greater-than-1 dimension of the "
"input tensor ldims ",
input->ldims[last_dim],
" is not divisible by the size of the output "
"tensor type (",
ttype.bytes(), ")");
}
if ((input->offs[last_dim] % ttype.bytes()) != 0) {
LOG(ERROR,
"the last greater-than-1 dimension of the "
"input tensor offs ",
input->offs[last_dim],
" is not divisible by the size of the output "
"tensor type (",
ttype.bytes(), ")");
}
if (input->pads[last_dim] > 1) {
// we can ignore pads if it is 1
if ((input->pads[last_dim] % ttype.bytes()) != 0) {
LOG(ERROR,
"the last greater-than-1 dimension of the "
"input tensor pads ",
input->pads[last_dim],
" is not divisible by the size of the output "
"tensor type (",
ttype.bytes(), ")");
}
}

auto out_shape = input->shape;
auto out_ldims = input->ldims;
auto out_offs = input->offs;
auto out_pads = input->pads;
out_shape[last_dim] /= ttype.bytes();
out_ldims[last_dim] /= ttype.bytes();
out_offs[last_dim] /= ttype.bytes();
if (out_pads[last_dim] > 1) {
out_pads[last_dim] /= ttype.bytes();
}
return this->tensor(out_shape, ttype, input->buf, out_ldims,
out_offs, out_pads, {input}, input->exported,
input->imported_rank, name + "/cast");
}
if (ttype == BYTE) {
// Casting other types to BYTE is considered as a reshape.
auto out_shape = input->shape;
auto out_ldims = input->ldims;
auto out_offs = input->offs;
auto out_pads = input->pads;
out_shape[-1] *= input->type.bytes();
out_ldims[-1] *= input->type.bytes();
out_offs[-1] *= input->type.bytes();
if (out_pads[-1] > 1) {
out_pads[-1] *= input->type.bytes();
}
return this->tensor(out_shape, ttype, input->buf, out_ldims,
out_offs, out_pads, {input}, input->exported,
input->imported_rank, name + "/cast");
}
output = this->tensor(input->shape, ttype);
} else if (output->shape != input->shape) {
LOG(ERROR, "invalid output shape: ", output->shape);
} else {
if (output->type != ttype) {
LOG(ERROR, "invalid output data type: ", output->type);
}
if (output->shape != input->shape) {
LOG(ERROR, "invalid output shape: ", output->shape);
}
if (input->type == ttype) {
LOG(ERROR, "casting to the same type: ", ttype);
}
if (ttype == BYTE) {
LOG(ERROR, "casting to BYTE with a specified output tensor is not "
"supported as it implies a memory copy.");
}
}
CastOp op{input, output, name};
return this->impl->add_op(op)[0];
Expand Down
Loading

0 comments on commit a97a562

Please sign in to comment.