
Commit

Merge branch 'main' into chhwang/send-recv-interface
chhwang committed Oct 3, 2023
2 parents 2a04c31 + be92315 commit 13d6aef
Showing 82 changed files with 2,866 additions and 2,209 deletions.
31 changes: 30 additions & 1 deletion ark/gpu/gpu_compile.cc
@@ -14,13 +14,13 @@
 #include <cstring>
 #include <fstream>
 #include <functional>
+#include <mutex>
 
 #include "cpu_timer.h"
 #include "env.h"
 #include "gpu/gpu_logging.h"
 #include "include/ark.h"
 #include "random.h"
-#include "threading.h"
 
 #define ARK_USE_NVRTC 0
 #define ARK_DEBUG_KERNEL 0
@@ -140,6 +140,35 @@ const string link(const vector<string> &ptxs) {
 
 #endif // (ARK_USE_NVRTC)
 
+template <typename ItemType>
+static void para_exec(std::vector<ItemType> &items, int max_num_threads,
+                      const std::function<void(ItemType &)> &func) {
+    size_t nthread = (size_t)max_num_threads;
+    if (nthread > items.size()) {
+        nthread = items.size();
+    }
+    std::vector<std::thread> threads;
+    threads.reserve(nthread);
+    std::mutex mtx;
+    size_t idx = 0;
+    for (size_t i = 0; i < nthread; ++i) {
+        threads.emplace_back([&items, &mtx, &idx, &func] {
+            size_t local_idx = -1;
+            for (;;) {
+                {
+                    const std::lock_guard<std::mutex> lock(mtx);
+                    local_idx = idx++;
+                }
+                if (local_idx >= items.size()) break;
+                func(items[local_idx]);
+            }
+        });
+    }
+    for (auto &t : threads) {
+        t.join();
+    }
+}
+
 const string gpu_compile(const vector<string> &codes,
                          const GpuArchType &arch_type,
                          unsigned int max_reg_cnt) {
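Note: para_exec replaces the helper previously pulled in from threading.h. Worker threads repeatedly claim the next index from a mutex-guarded shared counter until all items are consumed. Below is a minimal, self-contained usage sketch: it copies para_exec verbatim from the diff above, while the toy job list and main() are invented for illustration and are not part of this commit.

#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Copied from the diff above: run `func` over all items with at most
// `max_num_threads` workers; `idx` is a mutex-guarded shared work counter.
template <typename ItemType>
static void para_exec(std::vector<ItemType> &items, int max_num_threads,
                      const std::function<void(ItemType &)> &func) {
    size_t nthread = (size_t)max_num_threads;
    if (nthread > items.size()) {
        nthread = items.size();
    }
    std::vector<std::thread> threads;
    threads.reserve(nthread);
    std::mutex mtx;
    size_t idx = 0;
    for (size_t i = 0; i < nthread; ++i) {
        threads.emplace_back([&items, &mtx, &idx, &func] {
            size_t local_idx = -1;
            for (;;) {
                {
                    const std::lock_guard<std::mutex> lock(mtx);
                    local_idx = idx++;
                }
                if (local_idx >= items.size()) break;
                func(items[local_idx]);
            }
        });
    }
    for (auto &t : threads) {
        t.join();
    }
}

int main() {
    // Toy stand-in for per-architecture kernel compilation jobs.
    std::vector<std::string> jobs = {"kernel_a", "kernel_b", "kernel_c"};
    // ItemType must be given explicitly so the lambda can convert to
    // std::function<void(std::string &)>.
    para_exec<std::string>(jobs, 4, [](std::string &job) { job += ": done"; });
    for (const auto &job : jobs) std::cout << job << "\n";
}

Because each worker grabs one item at a time, a slow item (e.g. one long kernel compilation) does not stall the rest of the queue the way a static partition could.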
85 changes: 38 additions & 47 deletions ark/include/ark.h
@@ -82,73 +82,64 @@ class CodeGenerator;
 class BaseScheduler;
 class SchedOp;
 
-// TensorBuf refers to a data array that can be shared by multiple tensors.
-class TensorBuf {
-   public:
-    TensorBuf(const DimType &bytes = 0, int id = -1);
-    TensorBuf(const TensorBuf &) = default;
-
-    size_t get_buf_offset() const;
-
-    DimType bytes;
-    int id;
-    bool immutable = false;
-
-   protected:
-    void *buf = nullptr;
-
-    friend class Tensor;
-    friend class BaseScheduler;
-};
-
 /// Type of tensor data.
 class TensorType {
    private:
-    const int id_;
-    const int bytes_;
     const std::string name_;
-    const std::string pointer_name_;
+    const int bytes_;
+    const std::string type_str_;
 
    public:
-    TensorType(int id = -1, int bytes = 0, const std::string &name = "none",
-               const std::string &pointer_name = "void *");
+    TensorType(const std::string &name = "none", int bytes = 0,
+               const std::string &type_str = "void *");
 
     bool operator==(const TensorType &other) const;
     bool operator!=(const TensorType &other) const;
 
-    int id() const;
     int bytes() const;
     const std::string &name() const;
-    const std::string &pointer_name() const;
+    const std::string &type_str() const;
 };
 
-class Fp16 : public TensorType {
-   public:
-    Fp16();
-};
+const TensorType NONE;
 
-class Fp32 : public TensorType {
-   public:
-    Fp32();
-};
+std::ostream &operator<<(std::ostream &os, const TensorType &type);
 
-class Int32 : public TensorType {
-   public:
-    Int32();
-};
+#define REGISTER_TENSOR_TYPE(_type_name, _bytes, _type_str) \
+    class TensorType_##_type_name : public TensorType {     \
+       public:                                              \
+        TensorType_##_type_name()                           \
+            : TensorType{#_type_name, _bytes, _type_str} {} \
+    };                                                      \
+    const TensorType_##_type_name _type_name;
+
+REGISTER_TENSOR_TYPE(FP32, 4, "float")
+REGISTER_TENSOR_TYPE(FP16, 2, "ark::half")
+REGISTER_TENSOR_TYPE(BF16, 2, "ark::bfloat16")
+REGISTER_TENSOR_TYPE(INT32, 4, "int32_t")
+REGISTER_TENSOR_TYPE(UINT32, 4, "uint32_t")
+REGISTER_TENSOR_TYPE(INT8, 1, "int8_t")
+REGISTER_TENSOR_TYPE(UINT8, 1, "uint8_t")
+REGISTER_TENSOR_TYPE(BYTE, 1, "unsigned char")
 
-class Byte : public TensorType {
-   public:
-    Byte();
-};
-
-const TensorType NONE;
-const Fp16 FP16;
-const Fp32 FP32;
-const Int32 INT32;
-const Byte BYTE;
-
-std::ostream &operator<<(std::ostream &os, const TensorType &type);
+// TensorBuf refers to a data array that can be shared by multiple tensors.
+class TensorBuf {
+   public:
+    TensorBuf(const DimType &bytes = 0, int id = -1);
+    TensorBuf(const TensorBuf &) = default;
+
+    size_t get_buf_offset() const;
+
+    DimType bytes;
+    int id;
+    bool immutable = false;
+
+   protected:
+    void *buf = nullptr;
+
+    friend class Tensor;
+    friend class BaseScheduler;
+};
 
 /// Tensor is a view of a TensorBuf.
 ///
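Note: each REGISTER_TENSOR_TYPE(...) line expands to a TensorType subclass plus a global constant (FP32, FP16, BF16, ...), replacing the hand-written Fp16/Fp32/Int32/Byte classes. The sketch below shows that expansion with a trimmed-down TensorType; the inline constructor and the name-based operator== are assumptions for the demo, since the header only declares them.

#include <iostream>
#include <string>

class TensorType {
   private:
    const std::string name_;
    const int bytes_;
    const std::string type_str_;

   public:
    TensorType(const std::string &name = "none", int bytes = 0,
               const std::string &type_str = "void *")
        : name_{name}, bytes_{bytes}, type_str_{type_str} {}
    // Assumption for this demo: types compare equal iff their names match.
    bool operator==(const TensorType &other) const {
        return name_ == other.name();
    }
    int bytes() const { return bytes_; }
    const std::string &name() const { return name_; }
    const std::string &type_str() const { return type_str_; }
};

// Same macro as in the diff above: defines a subclass and a global constant.
#define REGISTER_TENSOR_TYPE(_type_name, _bytes, _type_str) \
    class TensorType_##_type_name : public TensorType {     \
       public:                                              \
        TensorType_##_type_name()                           \
            : TensorType{#_type_name, _bytes, _type_str} {} \
    };                                                      \
    const TensorType_##_type_name _type_name;

REGISTER_TENSOR_TYPE(FP32, 4, "float")

int main() {
    std::cout << FP32.name() << " is " << FP32.bytes()
              << " bytes, emitted as '" << FP32.type_str() << "'\n";
}

Running it prints: FP32 is 4 bytes, emitted as 'float'. The type_str is the C++ spelling that generated kernel code uses for the element type.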
19 changes: 19 additions & 0 deletions ark/include/ark_utils.h
@@ -22,6 +22,16 @@ struct alignas(2) half_t {
     operator float() const;
 };
 
+// bfloat16 type.
+struct alignas(2) bfloat16_t {
+    uint16_t storage;
+    bfloat16_t() = default;
+    // Constructor with float parameter
+    bfloat16_t(float f);
+    // Conversion operator from bfloat16 to float
+    operator float() const;
+};
+
 } // namespace ark
 
 ark::half_t operator+(ark::half_t const &lhs, ark::half_t const &rhs);
@@ -32,6 +42,15 @@ ark::half_t &operator-=(ark::half_t &lhs, ark::half_t const &rhs);
 
 ark::half_t abs(ark::half_t const &val);
 
+ark::bfloat16_t operator+(ark::bfloat16_t const &lhs,
+                          ark::bfloat16_t const &rhs);
+ark::bfloat16_t operator-(ark::bfloat16_t const &lhs,
+                          ark::bfloat16_t const &rhs);
+ark::bfloat16_t operator*(ark::bfloat16_t const &lhs,
+                          ark::bfloat16_t const &rhs);
+ark::bfloat16_t &operator+=(ark::bfloat16_t &lhs, ark::bfloat16_t const &rhs);
+ark::bfloat16_t &operator-=(ark::bfloat16_t &lhs, ark::bfloat16_t const &rhs);
+
 // A set of utility functions
 namespace ark {
 namespace utils {
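Note: the commit only declares bfloat16_t. A conventional implementation, assumed here and not taken from this commit, keeps the high 16 bits of an IEEE-754 float: float to bfloat16 is a round-to-nearest-even shift, and bfloat16 to float is a zero-filled shift back.

#include <cstdint>
#include <cstring>
#include <iostream>

struct alignas(2) bfloat16_t {
    uint16_t storage;
    bfloat16_t() = default;
    bfloat16_t(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        // Round to nearest even before dropping the low 16 mantissa bits
        // (NaN handling omitted for brevity).
        bits += 0x7fff + ((bits >> 16) & 1);
        storage = static_cast<uint16_t>(bits >> 16);
    }
    operator float() const {
        // Widen by zero-filling the dropped mantissa bits.
        uint32_t bits = static_cast<uint32_t>(storage) << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }
};

int main() {
    bfloat16_t x(3.14159f);
    std::cout << float(x) << "\n";  // ~3.140625: only 8 mantissa bits survive
}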
119 changes: 42 additions & 77 deletions ark/include/kernels/activation.h
@@ -9,9 +9,9 @@
 namespace ark {
 
 struct Relu {
-    static DEVICE float compute(float input) { return max(input, 0.0f); }
-    static DEVICE __half2 compute(__half2 input) {
-        return __hmax2(input, (__half2_raw){0, 0});
+    template <typename DataType>
+    static DEVICE DataType compute(DataType input) {
+        return type::Max::compute(input, type::Constant<DataType>::zero());
     }
 };

@@ -22,6 +22,10 @@ struct Gelu {
                              (input + 0.044715f * input * input * input)));
     }
 
+    static DEVICE bfloat16 compute(bfloat16 input) {
+        return bfloat16(Gelu::compute(float(input)));
+    }
+
     static DEVICE __half2 compute(__half2 input) {
         __half2 half_pi =
             __float2half2_rn(0.7978845608f);  // sqrt(2 / pi) = 0.7978845608
@@ -48,8 +52,11 @@ struct Gelu {
 };
 
 struct Sigmoid {
-    static DEVICE float compute(float input) {
-        return 1.0f / (1.0f + expf(-input));
+    template <typename DataType>
+    static DEVICE DataType compute(DataType input) {
+        return type::Div::compute(
+            DataType(1.0f),
+            (type::Add::compute(DataType(1.0f), type::Exp::compute(-input))));
     }
     static DEVICE __half2 compute(__half2 input) {
         __half2 one = __float2half2_rn(1.0f);
@@ -59,91 +66,49 @@ struct Sigmoid {
     }
 };
 
-template <typename _ActivationType, typename _InShape, typename _DataType,
-          int _NelemPerThread>
-struct Activation;
-
-template <typename _ActivationType, typename _InShape>
-struct Activation<_ActivationType, _InShape, half, 2> {
-    using InputType = half;
-    using OutputType = half;
-    static const int NelemPerThread = 2;
-
-    static DEVICE void compute(half *output, const half *input) {
-        __half2 *pout = (__half2 *)output;
-        if (_InShape::W == 1) {
-            *pout =
-                _ActivationType::compute(__half2half2(*(const __half *)input));
-        } else {
-            __half2 *pin = (__half2 *)input;
-            *pout = _ActivationType::compute(*pin);
-        }
-    }
-};
-
-template <typename _ActivationType, typename _InShape>
-struct Activation<_ActivationType, _InShape, float, 1> {
-    using InputType = float;
-    using OutputType = float;
-    static const int NelemPerThread = 1;
-
-    static DEVICE void compute(float *output, const float *input) {
-        *output = _ActivationType::compute(*input);
-    }
-};
-
 template <typename InDims, typename InShape, typename OutDims,
           typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void relu(float *out, float *in, int uop_idx, int) {
-    Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Activation<Relu, InShape, float, 1>>::run(out, in,
-                                                                    uop_idx);
-}
-
-template <typename InDims, typename InShape, typename OutDims,
-          typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void relu(half *out, half *in, int uop_idx, int) {
-    Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Activation<Relu, InShape, half, 2>>::run(out, in,
-                                                                   uop_idx);
-}
-
-template <typename InDims, typename InShape, typename OutDims,
-          typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void gelu(float *out, float *in, int uop_idx, int) {
-    Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Activation<Gelu, InShape, float, 1>>::run(out, in,
-                                                                    uop_idx);
-}
-
-template <typename InDims, typename InShape, typename OutDims,
-          typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void gelu(half *out, half *in, int uop_idx, int) {
+          int SmemBytes, typename InDataType, typename OutDataType>
+DEVICE void relu(OutDataType *out, InDataType *in, int uop_idx,
+                 int smem_per_warp) {
+    constexpr int NelemPerThread =
+        (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0)
+            ? 8
+            : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1;
     Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Activation<Gelu, InShape, half, 2>>::run(out, in,
-                                                                   uop_idx);
+               SmemBytes,
+               Broadcast1Intrinsic<Relu, InShape, InDataType, OutDataType,
+                                   NelemPerThread>>::run(out, in, uop_idx);
 }
 
 template <typename InDims, typename InShape, typename OutDims,
           typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void sigmoid(float *out, float *in, int uop_idx, int) {
+          int SmemBytes, typename InDataType, typename OutDataType>
+DEVICE void gelu(OutDataType *out, InDataType *in, int uop_idx,
+                 int smem_per_warp) {
+    constexpr int NelemPerThread =
+        (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0)
+            ? 8
+            : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1;
     Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Activation<Sigmoid, InShape, float, 1>>::run(out, in,
-                                                                       uop_idx);
+               SmemBytes,
+               Broadcast1Intrinsic<Gelu, InShape, InDataType, OutDataType,
+                                   NelemPerThread>>::run(out, in, uop_idx);
 }
 
 template <typename InDims, typename InShape, typename OutDims,
           typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes>
-DEVICE void sigmoid(half *out, half *in, int uop_idx, int) {
+          int SmemBytes, typename InDataType, typename OutDataType>
+DEVICE void sigmoid(OutDataType *out, InDataType *in, int uop_idx,
+                    int smem_per_warp) {
+    constexpr int NelemPerThread =
+        (sizeof(OutDataType) <= 2 && UnitOutDims::W % 8 == 0)
+            ? 8
+            : (UnitOutDims::W % 4 == 0) ? 4 : (UnitOutDims::W % 2 == 0) ? 2 : 1;
     Broadcast1<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-               SmemBytes, Activation<Sigmoid, InShape, half, 2>>::run(out, in,
-                                                                      uop_idx);
+               SmemBytes,
+               Broadcast1Intrinsic<Sigmoid, InShape, InDataType, OutDataType,
+                                   NelemPerThread>>::run(out, in, uop_idx);
 }
 
 } // namespace ark
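Note: the rewritten relu/gelu/sigmoid wrappers pick a vectorization width at compile time instead of hard-coding per-type Activation specializations: 16-bit outputs on a unit width divisible by 8 process 8 elements per thread, otherwise the width falls back to 4, 2, or 1. The stand-alone sketch below restates that selection logic on the host; `short` stands in for a 2-byte GPU type such as half so the snippet compiles without CUDA.

#include <cstdio>

// Mirrors the constexpr NelemPerThread expression from the diff above,
// with the unit-dimension width W passed as a template parameter.
template <int W, typename OutDataType>
constexpr int nelem_per_thread() {
    return (sizeof(OutDataType) <= 2 && W % 8 == 0)
               ? 8
               : (W % 4 == 0) ? 4 : (W % 2 == 0) ? 2 : 1;
}

int main() {
    printf("%d\n", nelem_per_thread<64, short>());  // 8: 2-byte type, W % 8 == 0
    printf("%d\n", nelem_per_thread<64, float>());  // 4: 4-byte type caps at 4
    printf("%d\n", nelem_per_thread<6, float>());   // 2: W only divisible by 2
}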