Skip to content

Commit

Permalink
Merge branch 'main' into chhwang/llama
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Sep 25, 2023
2 parents 93a303f + 5719df8 commit a6912de
Show file tree
Hide file tree
Showing 36 changed files with 965 additions and 245 deletions.
8 changes: 6 additions & 2 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
{
"configurations": [
{
"name": "ops_matmul_test",
"name": "ops_cast_test",
"type": "cppdbg",
"request": "launch",
"program": "${workspaceFolder}/build/ark/ops_matmul_test.cu",
"program": "${workspaceFolder}/build/ark/ops_cast_test",
"args": [],
"stopAtEntry": false,
"cwd": "${fileDirname}",
"environment": [
{
"name": "ARK_ROOT",
"value": "${workspaceFolder}/build"
},
{
"name": "ARK_LOG_LEVEL",
"value": "DEBUG"
}
],
"externalConsole": false,
Expand Down
5 changes: 3 additions & 2 deletions ark/gpu/gpu_mem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ static int mem_expose(ExposalInfo *info, GpuPtr addr, uint64_t bytes)
LOG(ERROR, "gpumem driver is not loaded");
}

int flag = 1;
CULOG(cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, addr));
// Convert virtual into physical address.
int fd = open(GPUMEM_DRIVER_PATH, O_RDWR, 0);
if (fd < 0) {
Expand Down Expand Up @@ -163,6 +161,9 @@ void GpuMem::init(size_t bytes, bool expose)
addr_ =
(CUdeviceptr)(((uint64_t)raw_addr_ + GPU_PAGE_OFFSET) & GPU_PAGE_MASK);

int one = 1;
CULOG(cuPointerSetAttribute(&one, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, addr_));

ExposalInfo exp_info;
if (expose) {
int err = mem_expose(&exp_info, addr_, bytes + GPU_PAGE_SIZE);
Expand Down
4 changes: 2 additions & 2 deletions ark/gpu/gpu_mgr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,15 +364,15 @@ void GpuMgrCtx::reg_sendrecv(int sid, int remote_gpu_id, size_t bytes,
}

//
void GpuMgrCtx::freeze()
void GpuMgrCtx::freeze(bool expose)
{
//
this->gpu_mgr->validate_total_bytes();

//
if (total_bytes > 0) {
LOG(INFO, "Allocating ", total_bytes, " bytes of GPU memory");
this->data_mem.init(total_bytes, false);
this->data_mem.init(total_bytes, expose);
// init the data mem
CULOG(cuMemsetD32(this->data_mem.ref(), 0, total_bytes >> 2));
}
Expand Down
2 changes: 1 addition & 1 deletion ark/gpu/gpu_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class GpuMgrCtx
void mem_export(GpuBuf *buf, size_t offset, int sid);
GpuBuf *mem_import(size_t bytes, int sid, int gpu_id);
void reg_sendrecv(int sid, int gpu_dst, std::size_t bytes, bool is_recv);
void freeze();
void freeze(bool expose = false);
// void send(int sid, int rank, size_t bytes);
GpuState set_current();
int get_world_size() const
Expand Down
4 changes: 2 additions & 2 deletions ark/gpu/gpu_mgr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ unittest::State test_gpu_mgr_remote()
GpuBuf *gpu1_eid5 = ctx->mem_import(sizeof(int), 5, 1);
GpuBuf *gpu1_eid6 = ctx->mem_import(sizeof(int), 6, 1);

ctx->freeze();
ctx->freeze(true);

volatile int *ptr = (volatile int *)gpu0_eid3->href();
while (*ptr != 7890) {
Expand Down Expand Up @@ -176,7 +176,7 @@ unittest::State test_gpu_mgr_remote()
GpuBuf *gpu0_eid3 = ctx->mem_import(sizeof(int), 3, 0);
GpuBuf *gpu0_eid4 = ctx->mem_import(sizeof(int), 4, 0);

ctx->freeze();
ctx->freeze(true);

gpu_memset(gpu0_eid3, 7890, 1);

Expand Down
68 changes: 57 additions & 11 deletions ark/include/ark.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,59 @@ class TensorBuf
friend class BaseScheduler;
};

// Type of tensor data.
typedef enum
/// Type of tensor data.
class TensorType
{
FP16,
FP32,
INT32,
BYTE,
} TensorType;
private:
const int id_;
const int bytes_;
const std::string name_;
const std::string pointer_name_;

public:
TensorType(int id = -1, int bytes = 0, const std::string &name = "none",
const std::string &pointer_name = "void *");

bool operator==(const TensorType &other) const;
bool operator!=(const TensorType &other) const;

int id() const;
int bytes() const;
const std::string &name() const;
const std::string &pointer_name() const;
};

class Fp16 : public TensorType
{
public:
Fp16();
};

class Fp32 : public TensorType
{
public:
Fp32();
};

class Int32 : public TensorType
{
public:
Int32();
};

class Byte : public TensorType
{
public:
Byte();
};

const TensorType NONE;
const Fp16 FP16;
const Fp32 FP32;
const Int32 INT32;
const Byte BYTE;

std::ostream &operator<<(std::ostream &os, TensorType type);
std::ostream &operator<<(std::ostream &os, const TensorType &type);

/// Tensor is a view of a TensorBuf.
///
Expand All @@ -134,7 +177,7 @@ class Tensor
{
public:
/// Tensor constructor.
Tensor(const Dims &shape, TensorType type, TensorBuf *buf,
Tensor(const Dims &shape, const TensorType &type, TensorBuf *buf,
const Dims &ldims, const Dims &offs, const Dims &pads, bool exported,
int imported_rank, int id, const std::string &name);
Tensor(const Tensor &) = default;
Expand Down Expand Up @@ -273,7 +316,7 @@ class Model
/// Returns a tensor object.
///
/// @param shape Shape of the tensor, where the data of interest is.
/// @param type Type of the tensor data.
/// @param ttype Type of the tensor data.
/// @param buf The @ref TensorBuf that holds the entire data including the
/// padding.
/// @param ldims Leading dimensions (ldim) of the tensor, which may be
Expand All @@ -300,7 +343,7 @@ class Model
/// @param name Name of the tensor.
/// @return Pointer to a tensor object.
///
Tensor *tensor(const Dims &shape, TensorType dtype,
Tensor *tensor(const Dims &shape, const TensorType &ttype,
TensorBuf *buf = nullptr, const Dims &ldims = {},
const Dims &offs = {}, const Dims &pads = {},
const std::vector<Tensor *> &deps = {},
Expand Down Expand Up @@ -471,6 +514,9 @@ class Model
/// Embedding layer.
Tensor *embedding(Tensor *input, Tensor *weight, Tensor *output = nullptr,
const std::string &name = "embedding");
/// Tensor type casting.
Tensor *cast(Tensor *input, const TensorType &ttype,
Tensor *output = nullptr, const std::string &name = "cast");

/// Verify if this model is valid.
/// @return true if the model is valid, false otherwise.
Expand Down
6 changes: 4 additions & 2 deletions ark/include/kernels/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ struct Activation;
template <typename _ActivationType, typename _InShape>
struct Activation<_ActivationType, _InShape, half, 2>
{
using DataType = half;
using InputType = half;
using OutputType = half;
static const int NelemPerThread = 2;

static DEVICE void compute(half *output, const half *input)
Expand All @@ -96,7 +97,8 @@ struct Activation<_ActivationType, _InShape, half, 2>
template <typename _ActivationType, typename _InShape>
struct Activation<_ActivationType, _InShape, float, 1>
{
using DataType = float;
using InputType = float;
using OutputType = float;
static const int NelemPerThread = 1;

static DEVICE void compute(float *output, const float *input)
Expand Down
22 changes: 14 additions & 8 deletions ark/include/kernels/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ template <typename _ArithmeticType, typename _In0Shape, typename _In1Shape,
typename _DataType, int _NelemPerThread>
struct Arithmetic
{
using DataType = _DataType;
using InputType = _DataType;
using OutputType = _DataType;
static const int NelemPerThread = _NelemPerThread;

static DEVICE void compute(DataType *c, const DataType *a,
const DataType *b)
static DEVICE void compute(_DataType *c, const _DataType *a,
const _DataType *b)
{
*c = *a + *b;
if (_In0Shape::W == 1) {
Expand All @@ -121,7 +122,8 @@ struct Arithmetic
template <typename _ArithmeticType, typename _In0Shape, typename _In1Shape>
struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2>
{
using DataType = float;
using InputType = float;
using OutputType = float;
static const int NelemPerThread = 2;

static DEVICE void compute(float *c, const float *a, const float *b)
Expand All @@ -147,7 +149,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 2>
template <typename _ArithmeticType, typename _In0Shape, typename _In1Shape>
struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4>
{
using DataType = float;
using InputType = float;
using OutputType = float;
static const int NelemPerThread = 4;

static DEVICE void compute(float *c, const float *a, const float *b)
Expand Down Expand Up @@ -218,7 +221,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, float, 4>
template <typename _ArithmeticType, typename _In0Shape, typename _In1Shape>
struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2>
{
using DataType = half;
using InputType = half;
using OutputType = half;
static const int NelemPerThread = 2;

static DEVICE void compute(half *c, const half *a, const half *b)
Expand All @@ -243,7 +247,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 2>
template <typename _ArithmeticType, typename _In0Shape, typename _In1Shape>
struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4>
{
using DataType = half;
using InputType = half;
using OutputType = half;
static const int NelemPerThread = 4;

static DEVICE void compute(half *c, const half *a, const half *b)
Expand Down Expand Up @@ -283,7 +288,8 @@ struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 4>
template <typename _ArithmeticType, typename _In0Shape, typename _In1Shape>
struct Arithmetic<_ArithmeticType, _In0Shape, _In1Shape, half, 8>
{
using DataType = half;
using InputType = half;
using OutputType = half;
static const int NelemPerThread = 8;

static DEVICE void compute(half *c, const half *a, const half *b)
Expand Down
1 change: 1 addition & 0 deletions ark/include/kernels/ark_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "activation.h"
#include "arithmetic.h"
#include "cast.h"
#include "comm.h"
#include "comm_mm.h"
#include "embedding.h"
Expand Down
12 changes: 7 additions & 5 deletions ark/include/kernels/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ struct Broadcast1
{
using UnitOp =
UnitOp<OutDims, OutShape, UnitOutDims, NumThreads, SmemBytes>;
using DataType = typename CompType::DataType;
using InputType = typename CompType::InputType;
using OutputType = typename CompType::OutputType;
static const int NelemPerThread = CompType::NelemPerThread;

static_assert(NelemPerThread > 0, "NelemPerThread must be positive");
Expand All @@ -84,7 +85,7 @@ struct Broadcast1
/// @param out Output data.
/// @param in1 Input data.
/// @param uop_idx Index of the unit operator.
static DEVICE void run(DataType *out, const DataType *in, int uop_idx)
static DEVICE void run(OutputType *out, const InputType *in, int uop_idx)
{
using InOutChk = BroadcastShapeChecker1<InShape, OutShape>;

Expand Down Expand Up @@ -141,7 +142,8 @@ struct Broadcast2
{
using UnitOp =
UnitOp<OutDims, OutShape, UnitOutDims, NumThreads, SmemBytes>;
using DataType = typename CompType::DataType;
using InputType = typename CompType::InputType;
using OutputType = typename CompType::OutputType;
static const int NelemPerThread = CompType::NelemPerThread;

static_assert(NelemPerThread > 0, "NelemPerThread must be positive");
Expand All @@ -153,8 +155,8 @@ struct Broadcast2
/// @param in0 Input data 0.
/// @param in1 Input data 1.
/// @param uop_idx Index of the unit operator.
static DEVICE void run(DataType *out, const DataType *in0,
const DataType *in1, int uop_idx)
static DEVICE void run(OutputType *out, const InputType *in0,
const InputType *in1, int uop_idx)
{
using InOutChk = BroadcastShapeChecker2<In0Shape, In1Shape, OutShape>;

Expand Down
Loading

0 comments on commit a6912de

Please sign in to comment.