diff --git a/.vscode/launch.json b/.vscode/launch.json index 993a0d8dd..519417bfa 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,10 +1,10 @@ { "configurations": [ { - "name": "ops_matmul_test", + "name": "ops_cast_test", "type": "cppdbg", "request": "launch", - "program": "${workspaceFolder}/build/ark/ops_matmul_test.cu", + "program": "${workspaceFolder}/build/ark/ops_cast_test", "args": [], "stopAtEntry": false, "cwd": "${fileDirname}", @@ -12,6 +12,10 @@ { "name": "ARK_ROOT", "value": "${workspaceFolder}/build" + }, + { + "name": "ARK_LOG_LEVEL", + "value": "DEBUG" } ], "externalConsole": false, diff --git a/ark/gpu/gpu_mem.cc b/ark/gpu/gpu_mem.cc index 991741339..7bac8777e 100644 --- a/ark/gpu/gpu_mem.cc +++ b/ark/gpu/gpu_mem.cc @@ -51,8 +51,6 @@ static int mem_expose(ExposalInfo *info, GpuPtr addr, uint64_t bytes) LOG(ERROR, "gpumem driver is not loaded"); } - int flag = 1; - CULOG(cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, addr)); // Convert virtual into physical address. int fd = open(GPUMEM_DRIVER_PATH, O_RDWR, 0); if (fd < 0) { @@ -163,6 +161,9 @@ void GpuMem::init(size_t bytes, bool expose) addr_ = (CUdeviceptr)(((uint64_t)raw_addr_ + GPU_PAGE_OFFSET) & GPU_PAGE_MASK); + int one = 1; + CULOG(cuPointerSetAttribute(&one, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, addr_)); + ExposalInfo exp_info; if (expose) { int err = mem_expose(&exp_info, addr_, bytes + GPU_PAGE_SIZE); diff --git a/ark/gpu/gpu_mgr.cc b/ark/gpu/gpu_mgr.cc index 179aafae2..fa28a6eaa 100644 --- a/ark/gpu/gpu_mgr.cc +++ b/ark/gpu/gpu_mgr.cc @@ -364,7 +364,7 @@ void GpuMgrCtx::reg_sendrecv(int sid, int remote_gpu_id, size_t bytes, } // -void GpuMgrCtx::freeze() +void GpuMgrCtx::freeze(bool expose) { // this->gpu_mgr->validate_total_bytes(); @@ -372,7 +372,7 @@ void GpuMgrCtx::freeze() // if (total_bytes > 0) { LOG(INFO, "Allocating ", total_bytes, " bytes of GPU memory"); - this->data_mem.init(total_bytes, false); + this->data_mem.init(total_bytes, expose); // init the data mem CULOG(cuMemsetD32(this->data_mem.ref(), 0, total_bytes >> 2)); } diff --git a/ark/gpu/gpu_mgr.h b/ark/gpu/gpu_mgr.h index f1ac8b6f5..40dbcae0a 100644 --- a/ark/gpu/gpu_mgr.h +++ b/ark/gpu/gpu_mgr.h @@ -113,7 +113,7 @@ class GpuMgrCtx void mem_export(GpuBuf *buf, size_t offset, int sid); GpuBuf *mem_import(size_t bytes, int sid, int gpu_id); void reg_sendrecv(int sid, int gpu_dst, std::size_t bytes, bool is_recv); - void freeze(); + void freeze(bool expose = false); // void send(int sid, int rank, size_t bytes); GpuState set_current(); int get_world_size() const diff --git a/ark/gpu/gpu_mgr_test.cc b/ark/gpu/gpu_mgr_test.cc index cb673b516..f79d124ee 100644 --- a/ark/gpu/gpu_mgr_test.cc +++ b/ark/gpu/gpu_mgr_test.cc @@ -144,7 +144,7 @@ unittest::State test_gpu_mgr_remote() GpuBuf *gpu1_eid5 = ctx->mem_import(sizeof(int), 5, 1); GpuBuf *gpu1_eid6 = ctx->mem_import(sizeof(int), 6, 1); - ctx->freeze(); + ctx->freeze(true); volatile int *ptr = (volatile int *)gpu0_eid3->href(); while (*ptr != 7890) { @@ -176,7 +176,7 @@ unittest::State test_gpu_mgr_remote() GpuBuf *gpu0_eid3 = ctx->mem_import(sizeof(int), 3, 0); GpuBuf *gpu0_eid4 = ctx->mem_import(sizeof(int), 4, 0); - ctx->freeze(); + ctx->freeze(true); gpu_memset(gpu0_eid3, 7890, 1); diff --git a/ark/include/ark.h b/ark/include/ark.h index 98d4c70cd..8ad653207 100644 --- a/ark/include/ark.h +++ b/ark/include/ark.h @@ -107,16 +107,59 @@ class TensorBuf friend class BaseScheduler; }; -// Type of tensor data. -typedef enum +/// Type of tensor data. +class TensorType { - FP16, - FP32, - INT32, - BYTE, -} TensorType; + private: + const int id_; + const int bytes_; + const std::string name_; + const std::string pointer_name_; + + public: + TensorType(int id = -1, int bytes = 0, const std::string &name = "none", + const std::string &pointer_name = "void *"); + + bool operator==(const TensorType &other) const; + bool operator!=(const TensorType &other) const; + + int id() const; + int bytes() const; + const std::string &name() const; + const std::string &pointer_name() const; +}; + +class Fp16 : public TensorType +{ + public: + Fp16(); +}; + +class Fp32 : public TensorType +{ + public: + Fp32(); +}; + +class Int32 : public TensorType +{ + public: + Int32(); +}; + +class Byte : public TensorType +{ + public: + Byte(); +}; + +const TensorType NONE; +const Fp16 FP16; +const Fp32 FP32; +const Int32 INT32; +const Byte BYTE; -std::ostream &operator<<(std::ostream &os, TensorType type); +std::ostream &operator<<(std::ostream &os, const TensorType &type); /// Tensor is a view of a TensorBuf. /// @@ -134,7 +177,7 @@ class Tensor { public: /// Tensor constructor. - Tensor(const Dims &shape, TensorType type, TensorBuf *buf, + Tensor(const Dims &shape, const TensorType &type, TensorBuf *buf, const Dims &ldims, const Dims &offs, const Dims &pads, bool exported, int imported_rank, int id, const std::string &name); Tensor(const Tensor &) = default; @@ -273,7 +316,7 @@ class Model /// Returns a tensor object. /// /// @param shape Shape of the tensor, where the data of interest is. - /// @param type Type of the tensor data. + /// @param ttype Type of the tensor data. /// @param buf The @ref TensorBuf that holds the entire data including the /// padding. /// @param ldims Leading dimensions (ldim) of the tensor, which may be @@ -300,7 +343,7 @@ class Model /// @param name Name of the tensor. /// @return Pointer to a tensor object. /// - Tensor *tensor(const Dims &shape, TensorType dtype, + Tensor *tensor(const Dims &shape, const TensorType &ttype, TensorBuf *buf = nullptr, const Dims &ldims = {}, const Dims &offs = {}, const Dims &pads = {}, const std::vector &deps = {}, @@ -471,6 +514,9 @@ class Model /// Embedding layer. Tensor *embedding(Tensor *input, Tensor *weight, Tensor *output = nullptr, const std::string &name = "embedding"); + /// Tensor type casting. + Tensor *cast(Tensor *input, const TensorType &ttype, + Tensor *output = nullptr, const std::string &name = "cast"); /// Verify if this model is valid. /// @return true if the model is valid, false otherwise. diff --git a/ark/include/kernels/activation.h b/ark/include/kernels/activation.h index 579504dbb..aa70987b3 100644 --- a/ark/include/kernels/activation.h +++ b/ark/include/kernels/activation.h @@ -77,7 +77,8 @@ struct Activation; template struct Activation<_ActivationType, _InShape, half, 2> { - using DataType = half; + using InputType = half; + using OutputType = half; static const int NelemPerThread = 2; static DEVICE void compute(half *output, const half *input) @@ -96,7 +97,8 @@ struct Activation<_ActivationType, _InShape, half, 2> template struct Activation<_ActivationType, _InShape, float, 1> { - using DataType = float; + using InputType = float; + using OutputType = float; static const int NelemPerThread = 1; static DEVICE void compute(float *output, const float *input) diff --git a/ark/include/kernels/arithmetic.h b/ark/include/kernels/arithmetic.h index 65426c807..6b35e8002 100644 --- a/ark/include/kernels/arithmetic.h +++ b/ark/include/kernels/arithmetic.h @@ -92,11 +92,12 @@ template struct Arithmetic { - using DataType = _DataType; + using InputType = _DataType; + using OutputType = _DataType; static const int NelemPerThread = _NelemPerThread; - static DEVICE void compute(DataType *c, const DataType *a, - const DataType *b) + static DEVICE void compute(_DataType *c, const _DataType *a, + const _DataType *b) { *c = *a + *b; if (_In0Shape::W == 1) { @@ -121,7 +122,8 @@ struct Arithmetic template struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2> { - using DataType = float; + using InputType = float; + using OutputType = float; static const int NelemPerThread = 2; static DEVICE void compute(float *c, const float *a, const float *b) @@ -147,7 +149,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2> template struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4> { - using DataType = float; + using InputType = float; + using OutputType = float; static const int NelemPerThread = 4; static DEVICE void compute(float *c, const float *a, const float *b) @@ -218,7 +221,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4> template struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2> { - using DataType = half; + using InputType = half; + using OutputType = half; static const int NelemPerThread = 2; static DEVICE void compute(half *c, const half *a, const half *b) @@ -243,7 +247,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2> template struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4> { - using DataType = half; + using InputType = half; + using OutputType = half; static const int NelemPerThread = 4; static DEVICE void compute(half *c, const half *a, const half *b) @@ -283,7 +288,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4> template struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8> { - using DataType = half; + using InputType = half; + using OutputType = half; static const int NelemPerThread = 8; static DEVICE void compute(half *c, const half *a, const half *b) diff --git a/ark/include/kernels/ark_kernels.h b/ark/include/kernels/ark_kernels.h index 4264c3522..56ff8babe 100644 --- a/ark/include/kernels/ark_kernels.h +++ b/ark/include/kernels/ark_kernels.h @@ -11,6 +11,7 @@ #include "activation.h" #include "arithmetic.h" +#include "cast.h" #include "comm.h" #include "comm_mm.h" #include "embedding.h" diff --git a/ark/include/kernels/broadcast.h b/ark/include/kernels/broadcast.h index 830c17cf0..daf9d06c0 100644 --- a/ark/include/kernels/broadcast.h +++ b/ark/include/kernels/broadcast.h @@ -73,7 +73,8 @@ struct Broadcast1 { using UnitOp = UnitOp; - using DataType = typename CompType::DataType; + using InputType = typename CompType::InputType; + using OutputType = typename CompType::OutputType; static const int NelemPerThread = CompType::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); @@ -84,7 +85,7 @@ struct Broadcast1 /// @param out Output data. /// @param in1 Input data. /// @param uop_idx Index of the unit operator. - static DEVICE void run(DataType *out, const DataType *in, int uop_idx) + static DEVICE void run(OutputType *out, const InputType *in, int uop_idx) { using InOutChk = BroadcastShapeChecker1; @@ -141,7 +142,8 @@ struct Broadcast2 { using UnitOp = UnitOp; - using DataType = typename CompType::DataType; + using InputType = typename CompType::InputType; + using OutputType = typename CompType::OutputType; static const int NelemPerThread = CompType::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); @@ -153,8 +155,8 @@ struct Broadcast2 /// @param in0 Input data 0. /// @param in1 Input data 1. /// @param uop_idx Index of the unit operator. - static DEVICE void run(DataType *out, const DataType *in0, - const DataType *in1, int uop_idx) + static DEVICE void run(OutputType *out, const InputType *in0, + const InputType *in1, int uop_idx) { using InOutChk = BroadcastShapeChecker2; diff --git a/ark/include/kernels/cast.h b/ark/include/kernels/cast.h new file mode 100644 index 000000000..e8a61674e --- /dev/null +++ b/ark/include/kernels/cast.h @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_KERNELS_CAST_H_ +#define ARK_KERNELS_CAST_H_ + +#include "broadcast.h" + +namespace ark { + +template +struct Cast; + +template struct Cast<_InShape, half, float, 2> +{ + using InputType = half; + using OutputType = float; + static const int NelemPerThread = 2; + + static DEVICE void compute(float *output, const half *input) + { + if constexpr (_InShape::W == 1) { + *output = __half2float(*(const __half *)input); + } else { + float2 *pout = (float2 *)output; + __half2 *pin = (__half2 *)input; + *pout = __half22float2(*pin); + } + } +}; + +template struct Cast<_InShape, int, float, 2> +{ + using InputType = int; + using OutputType = float; + static const int NelemPerThread = 2; + + static DEVICE void compute(float *output, const int *input) + { + if constexpr (_InShape::W == 1) { + *output = float(*input); + } else { + float2 *pout = (float2 *)output; + int2 *pin = (int2 *)input; + pout->x = float(pin->x); + pout->y = float(pin->y); + } + } +}; + +template struct Cast<_InShape, float, half, 2> +{ + using InputType = float; + using OutputType = half; + static const int NelemPerThread = 2; + + static DEVICE void compute(half *output, const float *input) + { + if constexpr (_InShape::W == 1) { + *output = __float2half_rn(*input); + } else { + __half2 *pout = (__half2 *)output; + float2 *pin = (float2 *)input; + *pout = __float22half2_rn(*pin); + } + } +}; + +template struct Cast<_InShape, int, half, 2> +{ + using InputType = int; + using OutputType = half; + static const int NelemPerThread = 2; + + static DEVICE void compute(half *output, const int *input) + { + if constexpr (_InShape::W == 1) { + *output = __int2half_rn(*input); + } else { + __half2 *pout = (__half2 *)output; + int2 *pin = (int2 *)input; + *pout = + __halves2half2(__int2half_rn(pin->x), __int2half_rn(pin->y)); + } + } +}; + +template struct Cast<_InShape, float, int, 2> +{ + using InputType = float; + using OutputType = int; + static const int NelemPerThread = 2; + + static DEVICE void compute(int *output, const float *input) + { + if constexpr (_InShape::W == 1) { + *output = int(*input); + } else { + int2 *pout = (int2 *)output; + float2 *pin = (float2 *)input; + pout->x = int(pin->x); + pout->y = int(pin->y); + } + } +}; + +template struct Cast<_InShape, half, int, 2> +{ + using InputType = half; + using OutputType = int; + static const int NelemPerThread = 2; + + static DEVICE void compute(int *output, const half *input) + { + if constexpr (_InShape::W == 1) { + *output = __half2int_rn(*(const __half *)input); + } else { + int2 *pout = (int2 *)output; + __half2 *pin = (__half2 *)input; + pout->x = __half2int_rn(__low2half(*pin)); + pout->y = __half2int_rn(__high2half(*pin)); + } + } +}; + +template +DEVICE void cast(float *out, half *in, int uop_idx, int) +{ + Broadcast1>::run(out, in, uop_idx); +} + +template +DEVICE void cast(float *out, int *in, int uop_idx, int) +{ + Broadcast1>::run(out, in, uop_idx); +} + +template +DEVICE void cast(half *out, float *in, int uop_idx, int) +{ + Broadcast1>::run(out, in, uop_idx); +} + +template +DEVICE void cast(half *out, int *in, int uop_idx, int) +{ + Broadcast1>::run(out, in, uop_idx); +} + +template +DEVICE void cast(int *out, float *in, int uop_idx, int) +{ + Broadcast1>::run(out, in, uop_idx); +} + +template +DEVICE void cast(int *out, half *in, int uop_idx, int) +{ + Broadcast1>::run(out, in, uop_idx); +} + +} // namespace ark + +#endif // ARK_KERNELS_CAST_H_ diff --git a/ark/include/kernels/embedding.h b/ark/include/kernels/embedding.h index d5cd29ecc..8b876b7f4 100644 --- a/ark/include/kernels/embedding.h +++ b/ark/include/kernels/embedding.h @@ -14,7 +14,8 @@ template struct RoPE; template <> struct RoPE { - using DataType = float; + using InputType = float; + using OutputType = float; static const int NelemPerThread = 2; static DEVICE void compute(float *c, const float *a, const float *b) { @@ -28,7 +29,8 @@ template <> struct RoPE template <> struct RoPE { - using DataType = half; + using InputType = half; + using OutputType = half; static const int NelemPerThread = 2; static DEVICE void compute(half *c, const half *a, const half *b) { @@ -64,9 +66,10 @@ DEVICE void rope(half *c, half *a, half *b, int uop_idx, int) template struct Assign { - using DataType = _DataType; + using InputType = _DataType; + using OutputType = _DataType; static const int NelemPerThread = 1; - static DEVICE void compute(DataType *c, const DataType *a) + static DEVICE void compute(_DataType *c, const _DataType *a) { *c = *a; } diff --git a/ark/include/kernels/math_functions.h b/ark/include/kernels/math_functions.h index 62872fcfc..f39243d3c 100644 --- a/ark/include/kernels/math_functions.h +++ b/ark/include/kernels/math_functions.h @@ -39,7 +39,8 @@ struct Math; template struct Math<_MathType, _InShape, half, 2> { - using DataType = half; + using InputType = half; + using OutputType = half; static const int NelemPerThread = 2; static DEVICE void compute(half *output, const half *input) @@ -57,7 +58,8 @@ struct Math<_MathType, _InShape, half, 2> template struct Math<_MathType, _InShape, float, 1> { - using DataType = float; + using InputType = float; + using OutputType = float; static const int NelemPerThread = 1; static DEVICE void compute(float *output, const float *input) diff --git a/ark/ops/ops_all_reduce_test.cc b/ark/ops/ops_all_reduce_test.cc index 3b30e5ad8..7af9b8e91 100644 --- a/ark/ops/ops_all_reduce_test.cc +++ b/ark/ops/ops_all_reduce_test.cc @@ -38,8 +38,9 @@ void test_all_reduce_4gpus_internal(size_t nelem, int iter) auto result = ark::op_test("all_reduce", m, {ones}, {output}, baseline_all_reduce, - {ones_data.get()}, true, gpu_id, num_gpus); + {ones_data.get()}, false, gpu_id, num_gpus); ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; }); } diff --git a/ark/ops/ops_cast.cc b/ark/ops/ops_cast.cc new file mode 100644 index 000000000..94e7b5fda --- /dev/null +++ b/ark/ops/ops_cast.cc @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "logging.h" +#include "model.h" +#include + +namespace ark { + +extern const OpConfigMap CastConfigMap; + +CastOp::CastOp(Tensor *input, Tensor *output, const std::string &name) + : Op{OP_CAST, OP_PREC_NONE, {input}, {output}, {}, + name, &CastConfigMap, -1, true} +{ +} + +std::string CastOp::function_name(const OpConfig &cfg) const +{ + Tensor *input = this->inputs[0]; + Tensor *output = this->outputs[0]; + + int ndims = output->shape.ndims(); + const OpTile &tile_out = cfg.output_tiles[0]; + CHECK(output->ldims[ndims - 1] % tile_out.y == 0); + if (ndims > 1) { + CHECK(output->ldims[ndims - 2] % tile_out.x == 0); + } else { + CHECK(tile_out.x == 1); + } + + Dims unit_out_dims{1, 1, tile_out.x, tile_out.y}; + return Op::function_name("ark::cast", + {{ + input->ldims.dims4(), // InDims + input->shape.dims4(), // InShape + output->ldims.dims4(), // OutDims + output->shape.dims4(), // OutShape + unit_out_dims, // UnitOutDims + cfg.num_warps * 32, // NumThreads + }}); +} + +Tensor *Model::cast(Tensor *input, const TensorType &ttype, Tensor *output, + const std::string &name) +{ + assert(input != nullptr); + if (output == nullptr) { + if (input->type == ttype) { + // Casting to the same type is considered as an identity, + // only when the output tensor is not specified. + return this->identity(input, {}, name); + } + if (input->type == BYTE) { + // Casting BYTE to other types is considered as a reshape. + if (input->shape_bytes() < ttype.bytes()) { + LOG(ERROR, "input tensor is too small to be casted to ", ttype); + } + // The last greater-than-1 dimension of the input tensor should be + // divisible by the size of the output type. + int last_dim = input->shape.ndims() - 1; + for (; last_dim >= 0; --last_dim) { + if (last_dim == 0 || input->ldims[last_dim] > 1) { + break; + } + } + if ((input->shape[last_dim] % ttype.bytes()) != 0) { + LOG(ERROR, + "the last greater-than-1 dimension of the " + "input tensor shape ", + input->shape[last_dim], + " is not divisible by the size of the output " + "tensor type (", + ttype.bytes(), ")"); + } + if ((input->ldims[last_dim] % ttype.bytes()) != 0) { + LOG(ERROR, + "the last greater-than-1 dimension of the " + "input tensor ldims ", + input->ldims[last_dim], + " is not divisible by the size of the output " + "tensor type (", + ttype.bytes(), ")"); + } + if ((input->offs[last_dim] % ttype.bytes()) != 0) { + LOG(ERROR, + "the last greater-than-1 dimension of the " + "input tensor offs ", + input->offs[last_dim], + " is not divisible by the size of the output " + "tensor type (", + ttype.bytes(), ")"); + } + if (input->pads[last_dim] > 1) { + // we can ignore pads if it is 1 + if ((input->pads[last_dim] % ttype.bytes()) != 0) { + LOG(ERROR, + "the last greater-than-1 dimension of the " + "input tensor pads ", + input->pads[last_dim], + " is not divisible by the size of the output " + "tensor type (", + ttype.bytes(), ")"); + } + } + + auto out_shape = input->shape; + auto out_ldims = input->ldims; + auto out_offs = input->offs; + auto out_pads = input->pads; + out_shape[last_dim] /= ttype.bytes(); + out_ldims[last_dim] /= ttype.bytes(); + out_offs[last_dim] /= ttype.bytes(); + if (out_pads[last_dim] > 1) { + out_pads[last_dim] /= ttype.bytes(); + } + return this->tensor(out_shape, ttype, input->buf, out_ldims, + out_offs, out_pads, {input}, input->exported, + input->imported_rank, name + "/cast"); + } + if (ttype == BYTE) { + // Casting other types to BYTE is considered as a reshape. + auto out_shape = input->shape; + auto out_ldims = input->ldims; + auto out_offs = input->offs; + auto out_pads = input->pads; + out_shape[-1] *= input->type.bytes(); + out_ldims[-1] *= input->type.bytes(); + out_offs[-1] *= input->type.bytes(); + if (out_pads[-1] > 1) { + out_pads[-1] *= input->type.bytes(); + } + return this->tensor(out_shape, ttype, input->buf, out_ldims, + out_offs, out_pads, {input}, input->exported, + input->imported_rank, name + "/cast"); + } + output = this->tensor(input->shape, ttype); + } else { + if (output->type != ttype) { + LOG(ERROR, "invalid output data type: ", output->type); + } + if (output->shape != input->shape) { + LOG(ERROR, "invalid output shape: ", output->shape); + } + if (input->type == ttype) { + LOG(ERROR, "casting to the same type: ", ttype); + } + if (ttype == BYTE) { + LOG(ERROR, "casting to BYTE with a specified output tensor is not " + "supported as it implies a memory copy."); + } + } + CastOp op{input, output, name}; + return this->impl->add_op(op)[0]; +} + +const OpConfigMap CastConfigMap = { + {{OP_ARCH_CUDA_ANY, OP_PREC_NONE}, + { + // NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost + {8, 0, {{128, 256}}, {{128, 256}}, false, false}, + {8, 0, {{256, 128}}, {{256, 128}}, false, false}, + {8, 0, {{128, 128}}, {{128, 128}}, false, false}, + {4, 0, {{64, 64}}, {{64, 64}}, false, false}, + {2, 0, {{32, 64}}, {{32, 64}}, false, false}, + {1, 0, {{16, 64}}, {{16, 64}}, false, false}, + {1, 0, {{8, 64}}, {{8, 64}}, false, false}, + {1, 0, {{2, 128}}, {{2, 128}}, false, false}, + {1, 0, {{4, 64}}, {{4, 64}}, false, false}, + {1, 0, {{2, 64}}, {{2, 64}}, false, false}, + {1, 0, {{1, 128}}, {{1, 128}}, false, false}, + {1, 0, {{1, 64}}, {{1, 64}}, false, false}, + {1, 0, {{1, 32}}, {{1, 32}}, false, false}, + }}, +}; + +} // namespace ark diff --git a/ark/ops/ops_cast_test.cc b/ark/ops/ops_cast_test.cc new file mode 100644 index 000000000..042ece3b8 --- /dev/null +++ b/ark/ops/ops_cast_test.cc @@ -0,0 +1,265 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "include/ark.h" +#include "include/ark_utils.h" +#include "ops_test_common.h" +#include "unittest/unittest_utils.h" + +template +void baseline_cast(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &) +{ + ToType *out = static_cast(outputs[0]); + FromType *input = static_cast(inputs[0]); + ark::Dims osh = output_shapes[0]; + for (ark::DimType i = 0; i < osh.size(); ++i) { + out[i] = ToType(input[i]); + } +}; + +template +void baseline_cast_from_byte(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &) +{ + ToType *out = static_cast(outputs[0]); + // input is a byte array, but force read it as ToType. + ToType *input = reinterpret_cast(inputs[0]); + ark::Dims osh = output_shapes[0]; + for (ark::DimType i = 0; i < osh.size(); ++i) { + out[i] = input[i]; + } +}; + +template +void baseline_cast_to_byte(std::vector &outputs, + const std::vector &, + const std::vector &inputs, + const std::vector &input_shapes) +{ + // output is a byte array, but force write it as FromType. + FromType *out = reinterpret_cast(outputs[0]); + FromType *input = static_cast(inputs[0]); + ark::Dims ish = input_shapes[0]; + for (ark::DimType i = 0; i < ish.size(); ++i) { + out[i] = input[i]; + } +}; + +ark::unittest::State test_cast_fp16_to_fp32() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); + ark::Tensor *out = m.cast(t, ark::FP32); + + auto result = ark::op_test("cast_fp16_to_fp32", m, {t}, {out}, + baseline_cast); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_fp16_to_int32() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); + ark::Tensor *out = m.cast(t, ark::INT32); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = ark::half_t((i + 1) % 1000); + } + + auto result = + ark::op_test("cast_fp16_to_int32", m, {t}, {out}, + baseline_cast, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_fp32_to_fp16() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); + ark::Tensor *out = m.cast(t, ark::FP16); + + auto result = ark::op_test("cast_fp32_to_fp16", m, {t}, {out}, + baseline_cast); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_fp32_to_int32() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); + ark::Tensor *out = m.cast(t, ark::INT32); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = float((i + 1) % 1000); + } + + auto result = ark::op_test("cast_fp32_to_int32", m, {t}, {out}, + baseline_cast, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_int32_to_fp32() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::INT32); + ark::Tensor *out = m.cast(t, ark::FP32); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = (i + 1) % 1000; + } + + auto result = ark::op_test("cast_int32_to_fp32", m, {t}, {out}, + baseline_cast, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_int32_to_fp16() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::INT32); + ark::Tensor *out = m.cast(t, ark::FP16); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = (i + 1) % 1000; + } + + auto result = + ark::op_test("cast_int32_to_fp16", m, {t}, {out}, + baseline_cast, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_byte_to_fp32() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BYTE); + ark::Tensor *out = m.cast(t, ark::FP32); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = (i + 1) % 256; + } + + auto result = + ark::op_test("cast_byte_to_fp32", m, {t}, {out}, + baseline_cast_from_byte, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_byte_to_fp16() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BYTE); + ark::Tensor *out = m.cast(t, ark::FP16); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = (i + 1) % 256; + } + + auto result = + ark::op_test("cast_byte_to_fp16", m, {t}, {out}, + baseline_cast_from_byte, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_byte_to_int32() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::BYTE); + ark::Tensor *out = m.cast(t, ark::INT32); + + std::vector input_data(t->shape.size()); + for (size_t i = 0; i < input_data.size(); ++i) { + input_data[i] = (i + 1) % 256; + } + + auto result = + ark::op_test("cast_byte_to_int32", m, {t}, {out}, + baseline_cast_from_byte, {input_data.data()}); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_fp32_to_byte() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP32); + ark::Tensor *out = m.cast(t, ark::BYTE); + + auto result = ark::op_test("cast_fp32_to_byte", m, {t}, {out}, + baseline_cast_to_byte); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_fp16_to_byte() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::FP16); + ark::Tensor *out = m.cast(t, ark::BYTE); + + auto result = ark::op_test("cast_fp16_to_byte", m, {t}, {out}, + baseline_cast_to_byte); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_cast_int32_to_byte() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(4, 2, 1024), ark::INT32); + ark::Tensor *out = m.cast(t, ark::BYTE); + + auto result = ark::op_test("cast_int32_to_byte", m, {t}, {out}, + baseline_cast_to_byte); + ark::op_test_log(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); + return ark::unittest::SUCCESS; +} + +int main() +{ + ark::init(); + UNITTEST(test_cast_fp16_to_fp32); + UNITTEST(test_cast_fp16_to_int32); + UNITTEST(test_cast_fp32_to_fp16); + UNITTEST(test_cast_fp32_to_int32); + UNITTEST(test_cast_int32_to_fp32); + UNITTEST(test_cast_int32_to_fp16); + UNITTEST(test_cast_byte_to_fp32); + UNITTEST(test_cast_byte_to_fp16); + UNITTEST(test_cast_byte_to_int32); + UNITTEST(test_cast_fp32_to_byte); + UNITTEST(test_cast_fp16_to_byte); + UNITTEST(test_cast_int32_to_byte); + return ark::unittest::SUCCESS; +} diff --git a/ark/ops/ops_common.cc b/ark/ops/ops_common.cc index 261e50356..804da911d 100644 --- a/ark/ops/ops_common.cc +++ b/ark/ops/ops_common.cc @@ -119,6 +119,7 @@ ostream &operator<<(ostream &os, const OpType &s) case OP_SQRT: os << "OP_SQRT"; break; case OP_ROPE: os << "OP_ROPE"; break; case OP_EMBEDDING: os << "OP_EMBEDDING"; break; + case OP_CAST: os << "OP_CAST"; break; } // clang-format on return os; @@ -508,6 +509,8 @@ std::string Op::function_name(const OpConfig &cfg) const return static_cast(this)->function_name(cfg); case OP_EMBEDDING: return static_cast(this)->function_name(cfg); + case OP_CAST: + return static_cast(this)->function_name(cfg); default: return ""; } diff --git a/ark/ops/ops_common.h b/ark/ops/ops_common.h index ef1a406d4..57aec6561 100644 --- a/ark/ops/ops_common.h +++ b/ark/ops/ops_common.h @@ -129,6 +129,7 @@ typedef enum OP_SEND_MM, OP_RECV_MM, OP_EMBEDDING, + OP_CAST, } OpType; /// Type of precision of @ref Op. @@ -557,6 +558,13 @@ class EmbeddingOp : public Op std::string function_name(const OpConfig &cfg) const; }; +class CastOp : public Op +{ + public: + CastOp(Tensor *input, Tensor *output, const std::string &name); + std::string function_name(const OpConfig &cfg) const; +}; + } // namespace ark #endif // ARK_OPS_COMMON_H_ diff --git a/ark/ops/ops_embedding_test.cc b/ark/ops/ops_embedding_test.cc index d594559dd..f5dc60e60 100644 --- a/ark/ops/ops_embedding_test.cc +++ b/ark/ops/ops_embedding_test.cc @@ -43,16 +43,16 @@ template ark::unittest::State test_embedding() const int num_emb = 1000; const int emb_dim = 8192; - ark::TensorType weight_type; + const ark::TensorType *weight_type; if (std::is_same::value) { - weight_type = ark::FP32; + weight_type = &ark::FP32; } else { - weight_type = ark::FP16; + weight_type = &ark::FP16; } ark::Model m; ark::Tensor *ti = m.tensor(ark::Dims(8, 3, 64), ark::INT32); - ark::Tensor *tw = m.tensor(ark::Dims(num_emb, emb_dim), weight_type); + ark::Tensor *tw = m.tensor(ark::Dims(num_emb, emb_dim), *weight_type); ark::Tensor *to = m.embedding(ti, tw); ark::srand(); diff --git a/ark/ops/ops_sendrecv_test.cc b/ark/ops/ops_sendrecv_test.cc index 4031ddf1a..64e492f58 100644 --- a/ark/ops/ops_sendrecv_test.cc +++ b/ark/ops/ops_sendrecv_test.cc @@ -8,8 +8,6 @@ #include "logging.h" #include "unittest/unittest_utils.h" -using namespace std; - void test_sendrecv_internal() { for (int gpu_id = 0; gpu_id < 2; ++gpu_id) { @@ -18,7 +16,7 @@ void test_sendrecv_internal() ark::Model model{gpu_id}; ark::Tensor *tns_x = model.tensor({1024}, ark::FP16); if (gpu_id == 0) { - model.send(tns_x, 0, 1, 1024); + model.send(tns_x, 0, 1, tns_x->shape_bytes()); model.send_done(tns_x, 0, 1); } if (gpu_id == 1) { @@ -28,6 +26,13 @@ void test_sendrecv_internal() ark::Executor exe{gpu_id, 2, model, "test_sendrecv"}; exe.compile(); + if (gpu_id == 0) { + std::vector data(1024); + for (int i = 0; i < 1024; ++i) { + data[i] = ark::half_t(i + 1); + } + tns_x->write(data.data()); + } exe.launch(); exe.run(1); exe.stop(); @@ -36,6 +41,15 @@ void test_sendrecv_internal() ark::IpcAllGather barrier{"test_sendrecv_barrier", gpu_id, 2, tmp, sizeof(int)}; barrier.sync(); + + if (gpu_id == 1) { + std::vector data(1024); + tns_x->read(data.data()); + for (int i = 0; i < 1024; ++i) { + UNITTEST_EQ(data[i], ark::half_t(i + 1)); + } + } + return ark::unittest::SUCCESS; }); } diff --git a/ark/ops/ops_tensor.cc b/ark/ops/ops_tensor.cc index 6c2f9bc14..30448bab5 100644 --- a/ark/ops/ops_tensor.cc +++ b/ark/ops/ops_tensor.cc @@ -13,16 +13,16 @@ TensorOp::TensorOp(const std::vector &deps, Tensor *output, { } -Tensor *Model::tensor(const Dims &shape, TensorType type, TensorBuf *buf, - const Dims &ldims, const Dims &offs, const Dims &pads, - const std::vector &deps, bool exported, - int imported_rank, const std::string &name) +Tensor *Model::tensor(const Dims &shape, const TensorType &ttype, + TensorBuf *buf, const Dims &ldims, const Dims &offs, + const Dims &pads, const std::vector &deps, + bool exported, int imported_rank, const std::string &name) { if (buf == nullptr) { buf = this->impl->create_tensor_buf(); } Tensor *ret = - new Tensor{shape, type, buf, + new Tensor{shape, ttype, buf, ldims, offs, pads, exported, imported_rank, (int)this->impl->tns_storage.size(), name}; diff --git a/ark/sched/sched_codegen.cc b/ark/sched/sched_codegen.cc index 97b907be6..fb2e4f630 100644 --- a/ark/sched/sched_codegen.cc +++ b/ark/sched/sched_codegen.cc @@ -75,17 +75,7 @@ std::ostream &CodeGenerator::sync_stream(std::ostream &os, int stream_id, ostream &CodeGenerator::tensor(ostream &os, const Tensor *tensor) const { size_t off = this->get_tensor_offset(tensor); - if (tensor->type == FP16) { - os << "(ark::half *)"; - } else if (tensor->type == FP32) { - os << "(float *)"; - } else if (tensor->type == INT32) { - os << "(int *)"; - } else if (tensor->type == BYTE) { - os << "(void *)"; - } else { - LOG(ERROR, "unknown tensor type"); - } + os << "(" << tensor->type.pointer_name() << ")"; std::string buf_name = ARK_BUF_NAME; if (tensor->imported_rank >= 0) { buf_name += std::to_string(tensor->imported_rank); @@ -100,23 +90,7 @@ std::ostream &CodeGenerator::def_oparg(std::ostream &os, const OpArg &arg, if (arg.type == OP_ARG_TENSOR) { Tensor *tns; arg.get(&tns); - switch (tns->type) { - case FP16: - os << "ark::half *" << name; - break; - case FP32: - os << "float *" << name; - break; - case INT32: - os << "int *" << name; - break; - case BYTE: - os << "void *" << name; - break; - default: - LOG(ERROR, "Not implemented"); - break; - } + os << tns->type.pointer_name() << name; } else if (arg.type == OP_ARG_FLOAT) { os << "float " << name; } else if (arg.type == OP_ARG_INT) { diff --git a/ark/tensor.cc b/ark/tensor.cc index ce4b97527..830bdf1e1 100644 --- a/ark/tensor.cc +++ b/ark/tensor.cc @@ -22,8 +22,66 @@ size_t TensorBuf::get_buf_offset() const return static_cast(this->buf)->get_offset(); } +TensorType::TensorType(int id, int bytes, const std::string &name, + const std::string &pointer_name) + : id_{id}, bytes_{bytes}, name_{name}, pointer_name_{pointer_name} +{ +} + +bool TensorType::operator==(const TensorType &other) const +{ + return id_ == other.id(); +} + +bool TensorType::operator!=(const TensorType &other) const +{ + return id_ != other.id(); +} + +int TensorType::id() const +{ + return id_; +} + +int TensorType::bytes() const +{ + return bytes_; +} + +const std::string &TensorType::name() const +{ + return name_; +} + +const std::string &TensorType::pointer_name() const +{ + return pointer_name_; +} + +std::ostream &operator<<(std::ostream &os, const TensorType &type) +{ + os << type.name(); + return os; +} + +Fp16::Fp16() : TensorType{0, 2, "fp16", "ark::half *"} +{ +} + +Fp32::Fp32() : TensorType{1, 4, "fp32", "float *"} +{ +} + +Int32::Int32() : TensorType{2, 4, "int32", "int *"} +{ +} + +Byte::Byte() : TensorType{3, 1, "byte", "char *"} +{ +} + // Tensor constructor -Tensor::Tensor(const Dims &shape_, TensorType type_, TensorBuf *buf_, +Tensor::Tensor(const Dims &shape_, const TensorType &type_, TensorBuf *buf_, const Dims &ldims_, const Dims &offs_, const Dims &pads_, bool exported_, int imported_rank_, int id_, const std::string &name_) @@ -148,16 +206,7 @@ int Tensor::ndims() const // Number of bytes of each element in the tensor. int Tensor::type_bytes() const { - if (this->type == FP16) { - return 2; - } else if (this->type == FP32) { - return 4; - } else if (this->type == INT32) { - return 4; - } else if (this->type == BYTE) { - return 1; - } - return 0; + return this->type.bytes(); } // Number of bytes of the tensor. @@ -392,26 +441,4 @@ void Tensor::clear() assert(done == num); } -std::ostream &operator<<(std::ostream &os, TensorType type) -{ - switch (type) { - case BYTE: - os << "byte"; - break; - case INT32: - os << "int32"; - break; - case FP16: - os << "fp16"; - break; - case FP32: - os << "fp32"; - break; - default: - os << "none"; - break; - } - return os; -} - } // namespace ark diff --git a/docs/install.md b/docs/install.md index d77e1b205..0ac2ff5ab 100644 --- a/docs/install.md +++ b/docs/install.md @@ -16,11 +16,6 @@ - Compute capability 9.0 support will be added in the future. -* To run ARK in a Docker container, we need to mount `/dev` and `/lib/modules` into the container so that the container can use `gpumem` driver. Add the following options in the `docker run` command: - ``` - -v /dev:/dev -v /lib/modules:/lib/modules - ``` - * Mellanox OFED ## Docker Images @@ -34,9 +29,25 @@ docker pull ghcr.io/microsoft/ark/ark:base-cuda12.1 Check [ARK containers](https://github.com/microsoft/ark/pkgs/container/ark%2Fark) for all available Docker images. +To run ARK in a Docker container, we need to mount `/dev` and `/lib/modules` into the container so that the container can use `gpumem` driver. Specifically, add `--privileged -v /dev:/dev -v /lib/modules:/lib/modules` in the `docker run` command. The following is an example. +``` +docker run \ + --privileged \ + --cap-add=ALL \ + --shm-size=1g \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --net=host \ + --ipc=host \ + --gpus all \ + -v /dev:/dev \ + -v /lib/modules:/lib/modules \ + -it --name [Container Name] [Image Name] bash +``` + ## Install `gpudma` -*NOTE: if you are using a Docker container, the following steps should be done on the host.* +**NOTE: if you are using a Docker container, the steps in this section should be done on the host.** 1. Pull submodules. @@ -55,7 +66,7 @@ Check [ARK containers](https://github.com/microsoft/ark/pkgs/container/ark%2Fark 3. Load `gpumem` driver. ```bash - sudo insmod third_party/gpudma/module/gpumem.ko + sudo insmod gpudma/module/gpumem.ko sudo chmod 666 /dev/gpumem ``` @@ -83,7 +94,7 @@ Check [ARK containers](https://github.com/microsoft/ark/pkgs/container/ark%2Fark ```bash cd examples/tutorial - python3 tutorial.py + python3 quickstart_tutorial.py ``` ## (Optional) Install ARK C++ and Run Unit Tests @@ -97,7 +108,7 @@ If you want to use only the core C++ interfaces, follow the instructions below. ```bash mkdir build cd build - cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local .. + cmake -DCMAKE_BUILD_TYPE=Debug -DCMAKE_INSTALL_PREFIX=/usr/local .. ``` 2. Build ARK. diff --git a/docs/quickstart.md b/docs/quickstart.md index 7b99835fb..dfb3ecc93 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -8,45 +8,48 @@ Please refer to the [ARK Install Instructions](./install.md) to install ARK for ## Quick Start Tutorial -You can run a tutorial example at [tutorial.py](../examples/tutorial/quickstart_tutorial.py) to see how ARK works. +You can run a tutorial example at [quickstart_tutorial.py](../examples/tutorial/quickstart_tutorial.py) to see how ARK works. ```bash python examples/tutorial/quickstart_tutorial.py ``` -Before diving in, let's import the required modules and initialize ARK runtime: +Before diving in, let's initialize ARK runtime: ```python import ark import numpy as np # Initialize the ARK runtime -runtime = ark.Runtime() - +ark.init() ``` First, we need to create the operational graph for our model. In this example, we define a simple model with two input tensors. The output tensor is the sum of these input tensors. ```python M, N = 64, 64 # Create an input tensor -input_tensor = ark.tensor([M, N]) +input_tensor = ark.tensor([M, N], ark.fp32) # Create another tensor -other_tensor = ark.tensor([M, N]) +other_tensor = ark.tensor([M, N], ark.fp32) # Add the two tensors output_tensor = ark.add(input_tensor, other_tensor) ``` -Next, we need to launch the ARK runtime by calling `runtime.launch()`. This call will freeze the model, schedule GPU tasks, and allocate GPU memory. Then it will generate and compile the GPU kernel for the model. Finally, it will launch the GPU kernel that will be waiting for a `runtime.run()` call. Modifying the model after launching the runtime will take no effect. +Next, we need to construct and launch the ARK runtime. +```python +# Construct the ARK runtime +runtime = ark.Runtime() +# Launch the ARK runtime +runtime.launch() +``` +`runtime.launch()` will freeze the model, schedule GPU tasks, and allocate GPU memory. Then it will generate and compile the GPU kernel for the model. Finally, it will launch the GPU kernel that will be waiting for a `runtime.run()` call. Modifying the model after launching the runtime will take no effect. > **NOTE:** Note the difference from other GPU frameworks such as PyTorch. In PyTorch, each GPU kernel represents a single GPU task and a kernel launch will immediately start computation. In ARK, the GPU kernel represents the entire GPU tasks needed to run the model, throughout the entire lifetime of the model. Therefore, ARK launches only a single kernel and the kernel will be always running until the runtime stops. Instead of immediately starting computation after launch, the ARK kernel will run computation upon a `runtime.run()` call to ensure that the host side is ready to provide input data & read results. This design allows ARK to achieve better and stable performance by removing the overhead from the host side. Next, we need to initialize the input and output tensors. You can copy a numpy array into a tensor on GPU using `tensor.from_numpy(ndarray)`. Since `runtime.launch()` allocates GPU memory, it is necessary to call `runtime.launch()` before copying the tensor between the host and device. ```python -# Launch the ARK runtime -runtime.launch() - # Initialize the input and other tensor with random values input_tensor_host = np.random.rand(M, N).astype(np.float32) input_tensor.from_numpy(input_tensor_host) diff --git a/docs/tutorial/module_tutorial.md b/docs/tutorial/module_tutorial.md index 30440834a..6250e6b9d 100644 --- a/docs/tutorial/module_tutorial.md +++ b/docs/tutorial/module_tutorial.md @@ -45,8 +45,6 @@ class TestModelARK(ark.Module): Here, we can create this model and then launch it. ```python -# Initialize the ARK runtime -runtime = ark.Runtime() # Create an input tensor input_tensor = ark.tensor([batch_size, seq_len, d_model], ark.fp16) @@ -56,11 +54,14 @@ ark_model = TestModelARK() # Perform the forward pass output_tensor = ark_model(input_tensor) +# Construct the ARK runtime +runtime = ark.Runtime() + # Launch the ARK runtime runtime.launch() ``` -The initialization of the model can be done using a state_dict. Note that the parameters of this model in the state_dict must have the same name as the parameters defined in the module. Then, we can use `load_state_dict` to import the parameters of this model. +The initialization of the model can be done using a `state_dict`. Note that the parameters of this model in the `state_dict` must have the same name as the parameters defined in the module. Then, we can use `load_state_dict` to import the parameters of this model. ```python # Initialize the input tensor @@ -69,7 +70,7 @@ input_tensor_host = ( ).astype(np.float16) input_tensor.from_numpy(input_tensor_host) -# Initialize the parameters of the ARK module using numpy state_dict +# Initialize the parameters of the ARK module using numpy `state_dict` weight_1_host = ((np.random.rand(d_model, d_ff) - 0.5) * 0.1).astype( np.float16 ) @@ -85,7 +86,7 @@ state_dict = { ark_model.load_state_dict(state_dict) ``` -If needed, we can save this state_dict using `save`. We provide a set of modules for saving and loading this model's parameters using Python's `pickle` library. +If needed, we can save this `state_dict` using `save`. We provide a set of modules for saving and loading this model's parameters using Python's `pickle` library. ```python ark.save(ark_model.state_dict(), "test_model.pt") @@ -144,9 +145,9 @@ torch_input = torch.from_numpy(input_tensor_host_float32) torch_model = TestModelPytorch() ``` -We can also convert ARK's state_dict into a PyTorch state_dict. This way, we can directly import the parameters of this model into the corresponding PyTorch model. +We can also convert ARK's `state_dict` into a PyTorch `state_dict`. This way, we can directly import the parameters of this model into the corresponding PyTorch model. -ARK state_dict's format is +ARK `state_dict`'s format is ``` { "weight_1": weight_1_numpy, diff --git a/docs/tutorial/multi_gpu_tutorial.md b/docs/tutorial/multi_gpu_tutorial.md index 80af63d99..7d10aac6f 100644 --- a/docs/tutorial/multi_gpu_tutorial.md +++ b/docs/tutorial/multi_gpu_tutorial.md @@ -43,15 +43,15 @@ for process in processes: process.join() ``` -The following is the main function for the two processes. We first use `ark.Runtime(rank, world_size)` to create the ARK runtime. The `rank` parameter is the rank of the process, and the `world_size` parameter is the number of processes. In ARK, we assume that one process corresponds to one GPU. +The following is the main function for the two processes. We first set the `rank` and `world_size` of the current process. In ARK, we assume that one process corresponds to one GPU. ```python def sendrecv_test_ping_pong_function(rank, np_inputs): print("rank:", rank) - # Initialize the ARK runtime - runtime = ark.Runtime(rank, world_size) + ark.set_rank(rank) + ark.set_world_size(world_size) ``` @@ -62,8 +62,8 @@ For more information about the `send` and `recv` operator, please refer to the [ ```python # Define the behavior for rank 0 if rank == 0: - send_tensor = ark.tensor(ark.Dims(tensor_len), ark.FP16) - recv_tensor = ark.tensor(ark.Dims(tensor_len), ark.FP16) + send_tensor = ark.tensor([tensor_len], ark.fp16) + recv_tensor = ark.tensor([tensor_len], ark.fp16) # send the tensor to rank 1 send_id, dst_rank = 0, 1 @@ -84,7 +84,7 @@ The following is the model definition for GPU1. Here, GPU1 receives the tensor f # Define the behavior for rank 1 if rank == 1: # recv the tensor from rank 0 - recv_tensor = ark.tensor(ark.Dims(tensor_len), ark.FP16) + recv_tensor = ark.tensor([tensor_len], ark.fp16) recv_id, recv_rank = 0, 0 recv_dep = ark.recv(recv_tensor, recv_id, recv_rank) @@ -110,11 +110,14 @@ send_tensor = model.identity(recv_tensor, [recv_dep]) This is because the send operation must be executed after the recv operation. In the current scheduler, if this dependency is not specified, the send operation may be executed before the recv operation, causing an error. We will improve the scheduler in the future to automatically handle this situation. -Finally, we can use `runtime.launch()` to compile the kernel code and create contexts for each GPU. The connection between the two GPUs will be established automatically. +Finally, we can launch the runtime to compile the kernel code and create contexts for each GPU. The connection between the two GPUs will be established automatically. After we lauch the ARK model, we need to copy the send tensor to GPU0 to initialize the send tensor. Then we can run the ARK program. ```python + # Construct the ARK runtime + runtime = ark.Runtime() + # Launch the ARK runtime runtime.launch() @@ -139,4 +142,4 @@ Finally, we can copy the recv_tensor to the host to check the result. The recv_t mean_error = np.mean(np.abs(host_output - np_inputs)) print("max error:", max_error, "mean error:", mean_error) print("rank:", rank, "done") -``` \ No newline at end of file +``` diff --git a/examples/tutorial/model_tutorial.py b/examples/tutorial/model_tutorial.py deleted file mode 100644 index be40bcb97..000000000 --- a/examples/tutorial/model_tutorial.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import ark -import numpy as np - - -def model_tutorial(): - # Create a Model instance - model = ark.Model() - - # Create two tensors - input = model.tensor(ark.Dims(32), ark.TensorType.FP16) - other = model.tensor(ark.Dims(32), ark.TensorType.FP16) - - # Add input and other to get output tensor - output = model.add(input, other) - - # Create the executor instance, the scheduler will be created and - # start scheduling the model when the executor is created - exe = ark.Executor(0, 0, 1, model, "tutorial_model") - - # Compile the generated code from the code generator - exe.compile() - - # Initialize the input tensors - input_np = np.random.rand(1, 32).astype(np.float16) - other_np = np.random.rand(1, 32).astype(np.float16) - - input.from_numpy(input_np) - other.from_numpy(other_np) - - print("input: ", input_np) - print("other: ", other_np) - - # Launch the kernel and run for 1 iteration - exe.launch() - exe.run(1) - - # Wait for the kernel to finish - exe.stop() - - # Copy the output tensor back to host - output_np = np.zeros((1, 32), dtype=np.float16) - output.to_numpy(output_np) - - print("output: ", output_np) - - # test if the result is correct - assert np.allclose(output_np, input_np + other_np) - - max_error = np.max(np.abs(output_np - (input_np + other_np))) - mean_error = np.mean(np.abs(output_np - (input_np + other_np))) - - print("max error: ", max_error, "mean error: ", mean_error) - print("test_add passed") - - -if __name__ == "__main__": - model_tutorial() diff --git a/examples/tutorial/module_tutorial.py b/examples/tutorial/module_tutorial.py index 8b55c7313..4ce5c5ec9 100644 --- a/examples/tutorial/module_tutorial.py +++ b/examples/tutorial/module_tutorial.py @@ -30,7 +30,7 @@ class SubModuleARK(ark.Module): def __init__(self): super(SubModuleARK, self).__init__() # Define the parameters of the submodule - self.weight_2 = ark.parameter([d_ff, d_model], ark.FP16) + self.weight_2 = ark.parameter([d_ff, d_model], ark.fp16) def forward(self, inputs): # Perform the forward pass of the submodule @@ -42,7 +42,7 @@ class TestModelARK(ark.Module): def __init__(self): super(TestModelARK, self).__init__() # Define the parameters of the module - self.weight_1 = ark.parameter([d_model, d_ff], ark.FP16) + self.weight_1 = ark.parameter([d_model, d_ff], ark.fp16) # Create a submodule of the module self.submodule = SubModuleARK() @@ -86,10 +86,8 @@ def forward(self, inputs): # An example of using the ARK module def module_test(): - # Initialize the ARK runtime - runtime = ark.Runtime() # Create an input tensor - input_tensor = ark.tensor([batch_size, seq_len, d_model], ark.FP16) + input_tensor = ark.tensor([batch_size, seq_len, d_model], ark.fp16) # Create an ARK module ark_model = TestModelARK() @@ -97,6 +95,9 @@ def module_test(): # Perform the forward pass output_tensor = ark_model(input_tensor) + # Initialize the ARK runtime + runtime = ark.Runtime() + # Launch the ARK runtime runtime.launch() @@ -168,4 +169,5 @@ def module_test(): if __name__ == "__main__": + ark.init() module_test() diff --git a/examples/tutorial/multi_gpu_tutorial.py b/examples/tutorial/multi_gpu_tutorial.py index fc3e08b9b..69b12baf5 100644 --- a/examples/tutorial/multi_gpu_tutorial.py +++ b/examples/tutorial/multi_gpu_tutorial.py @@ -13,13 +13,13 @@ def sendrecv_test_ping_pong_function(rank, np_inputs): print("rank:", rank) - # Initialize the ARK runtime - runtime = ark.Runtime(rank, world_size) + ark.set_rank(rank) + ark.set_world_size(world_size) # Define the behavior for rank 0 if rank == 0: - send_tensor = ark.tensor(ark.Dims(tensor_len), ark.FP16) - recv_tensor = ark.tensor(ark.Dims(tensor_len), ark.FP16) + send_tensor = ark.tensor([tensor_len], ark.fp16) + recv_tensor = ark.tensor([tensor_len], ark.fp16) # send the tensor to rank 1 send_id, dst_rank = 0, 1 @@ -36,7 +36,7 @@ def sendrecv_test_ping_pong_function(rank, np_inputs): # Define the behavior for rank 1 if rank == 1: # recv the tensor from rank 0 - recv_tensor = ark.tensor(ark.Dims(tensor_len), ark.FP16) + recv_tensor = ark.tensor([tensor_len], ark.fp16) recv_id, recv_rank = 0, 0 recv_dep = ark.recv(recv_tensor, recv_id, recv_rank) @@ -53,6 +53,9 @@ def sendrecv_test_ping_pong_function(rank, np_inputs): ark.identity(send_tensor, [send_dep_tensor]), send_id, dst_rank ) + # Initialize the ARK runtime + runtime = ark.Runtime() + # Launch the ARK runtime runtime.launch() diff --git a/examples/tutorial/quickstart_tutorial.py b/examples/tutorial/quickstart_tutorial.py index e394b5cfc..da1894702 100644 --- a/examples/tutorial/quickstart_tutorial.py +++ b/examples/tutorial/quickstart_tutorial.py @@ -7,17 +7,20 @@ def quickstart_tutorial(): # Initialize the ARK runtime - runtime = ark.Runtime() + ark.init() M, N = 64, 64 # Create an input tensor - input_tensor = ark.tensor([M, N]) + input_tensor = ark.tensor([M, N], ark.fp32) # Create another tensor - other_tensor = ark.tensor([M, N]) + other_tensor = ark.tensor([M, N], ark.fp32) # Add the two tensors output_tensor = ark.add(input_tensor, other_tensor) + # Initialize the ARK runtime + runtime = ark.Runtime() + # Launch the ARK runtime runtime.launch() @@ -38,6 +41,8 @@ def quickstart_tutorial(): output_tensor_host, input_tensor_host + other_tensor_host ) + print("Quickstart tutorial is successful!") + if __name__ == "__main__": quickstart_tutorial() diff --git a/python/ark/__init__.py b/python/ark/__init__.py index beae034b4..4d0687980 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -56,6 +56,7 @@ def version(): all_gather, all_reduce, embedding, + cast, ) diff --git a/python/ark/data_type.py b/python/ark/data_type.py index 49ef312f5..21fe382a4 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import numpy as np -from ._ark_core import _TensorType +from ._ark_core import _TensorType, _FP32, _FP16, _INT32, _BYTE class DataType: @@ -21,13 +21,13 @@ def from_numpy(np_type: np.dtype) -> "DataType": @staticmethod def from_ttype(ttype: _TensorType) -> "DataType": - if ttype == _TensorType.FP32: + if ttype == _FP32: return fp32 - elif ttype == _TensorType.FP16: + elif ttype == _FP16: return fp16 - elif ttype == _TensorType.INT32: + elif ttype == _INT32: return int32 - elif ttype == _TensorType.BYTE: + elif ttype == _BYTE: return byte else: raise NotImplementedError @@ -61,7 +61,7 @@ def to_numpy() -> np.float32: @staticmethod def ttype() -> _TensorType: - return _TensorType.FP32 + return _FP32 class fp16(DataType): @@ -71,7 +71,7 @@ def to_numpy() -> np.float16: @staticmethod def ttype() -> _TensorType: - return _TensorType.FP16 + return _FP16 class int32(DataType): @@ -81,7 +81,7 @@ def to_numpy() -> np.int32: @staticmethod def ttype() -> _TensorType: - return _TensorType.INT32 + return _INT32 class byte(DataType): @@ -91,4 +91,4 @@ def to_numpy() -> np.uint8: @staticmethod def ttype() -> _TensorType: - return _TensorType.BYTE + return _BYTE diff --git a/python/ark/model.py b/python/ark/model.py index ab7cd4460..c0e10ff64 100644 --- a/python/ark/model.py +++ b/python/ark/model.py @@ -848,3 +848,21 @@ def embedding( name, ) return Tensor(_tensor) + + +def cast( + input: Tensor, + dtype: DataType, + output: Tensor = None, + name: str = "cast", +) -> Tensor: + """Type casting.""" + if output is not None: + output = output._tensor + _tensor = Model.get_model().cast( + input._tensor, + dtype.ttype(), + output, + name, + ) + return Tensor(_tensor) diff --git a/python/bindings.cpp b/python/bindings.cpp index 88a4b6340..b4d3a57b5 100644 --- a/python/bindings.cpp +++ b/python/bindings.cpp @@ -74,13 +74,29 @@ PYBIND11_MODULE(_ark_core, m) return os.str(); }); - py::enum_( - m, "_TensorType", "Type of tensor data. FP16, FP32, INT32, or BYTE") - .value("FP16", ark::TensorType::FP16) - .value("FP32", ark::TensorType::FP32) - .value("INT32", ark::TensorType::INT32) - .value("BYTE", ark::TensorType::BYTE) - .export_values(); + py::class_(m, "_TensorType", "Type of tensor data.") + .def(pybind11::self == pybind11::self) + .def(pybind11::self != pybind11::self) + .def("bytes", &ark::TensorType::bytes, "Number of bytes of this type.") + .def("name", &ark::TensorType::name, "Name of this type."); + + py::class_(m, "_Fp16", "16-bit floating point.") + .def(py::init<>()); + + py::class_(m, "_Fp32", "32-bit floating point.") + .def(py::init<>()); + + py::class_(m, "_Int32", "32-bit integer.") + .def(py::init<>()); + + py::class_(m, "_Byte", "8-bit integer.") + .def(py::init<>()); + + m.attr("_FP16") = py::cast(&ark::FP16, py::return_value_policy::reference); + m.attr("_FP32") = py::cast(&ark::FP32, py::return_value_policy::reference); + m.attr("_INT32") = + py::cast(&ark::INT32, py::return_value_policy::reference); + m.attr("_BYTE") = py::cast(&ark::BYTE, py::return_value_policy::reference); py::class_(m, "_TensorBuf", "TensorBuf refers to a data array that can be " @@ -92,9 +108,9 @@ PYBIND11_MODULE(_ark_core, m) .def_readwrite("immutable", &ark::TensorBuf::immutable); py::class_(m, "_Tensor") - .def(py::init(), + .def(py::init(), py::arg("shape"), py::arg("type"), py::arg("buf"), py::arg("ldims"), py::arg("offs"), py::arg("pads"), py::arg("exported"), py::arg("imported_rank"), py::arg("id"), @@ -146,7 +162,7 @@ PYBIND11_MODULE(_ark_core, m) .def("tensor", &ark::Model::tensor, "construct a tensor with given shape and data type.", py::return_value_policy::reference_internal, py::arg("shape"), - py::arg("dtype"), py::arg("buf") = nullptr, + py::arg("ttype"), py::arg("buf") = nullptr, py::arg("ldims") = ark::Dims(), py::arg("offs") = ark::Dims(), py::arg("pads") = ark::Dims(), py::arg("deps") = std::vector(), @@ -358,7 +374,11 @@ PYBIND11_MODULE(_ark_core, m) .def("embedding", &ark::Model::embedding, "Embedding layer.", py::return_value_policy::reference_internal, py::arg("input"), py::arg("weight"), py::arg("output") = nullptr, - py::arg("name") = "embedding"); + py::arg("name") = "embedding") + .def("cast", &ark::Model::cast, "Tensor type casting.", + py::return_value_policy::reference_internal, py::arg("input"), + py::arg("ttype"), py::arg("output") = nullptr, + py::arg("name") = "cast"); py::class_(m, "_Executor", "Convenience class for executing a model.") diff --git a/python/unittest/test_api.py b/python/unittest/test_api.py index fffdb0f94..bfadc8b37 100644 --- a/python/unittest/test_api.py +++ b/python/unittest/test_api.py @@ -32,8 +32,8 @@ def convert_state_dict(state_dict: dict, type="numpy"): class TestModelARK(ark.Module): def __init__(self): super(TestModelARK, self).__init__() - self.weight_1 = ark.parameter(ark.Dims(d_model, d_ff), ark.TensorType.FP16) - self.weight_2 = ark.parameter(ark.Dims(d_ff, d_model), ark.TensorType.FP16) + self.weight_1 = ark.parameter([d_model, d_ff], ark.fp16) + self.weight_2 = ark.parameter([d_ff, d_model], ark.fp16) def forward(self, inputs): output = ark.matmul(inputs, self.weight_1)