
Commit

[warpspec] Add experimental support for warp specialization with user annotations

This commit is a squash generated by:
```
git diff --stat b62b221a...06ccdadb -- . ':(exclude)python/gemmbench' ':(exclude)python/hstuBench' ':(exclude)third_party/proton'
```
bertmaher committed Nov 14, 2024
1 parent f4c48a9 commit fdf1c9e
Showing 78 changed files with 7,511 additions and 92 deletions.
@@ -102,6 +102,15 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateProtonOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

} // namespace triton
} // namespace mlir

5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -52,6 +52,11 @@ class TargetInfoBase {
virtual Value programId(RewriterBase &rewriter, Location loc,
ModuleOp moduleOp, int axis) const = 0;

virtual Value smId(RewriterBase &rewriter, Location loc) const = 0;

virtual Value clock(RewriterBase &rewriter, Location loc,
bool isClock64) const = 0;

virtual bool warpReduce(RewriterBase &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce,
42 changes: 42 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -13,6 +13,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "triton/Tools/Sys/GetEnv.hpp"
@@ -144,6 +145,20 @@ using namespace mlir::triton;
namespace mlir {
namespace triton {

static inline void insertBarrier(PatternRewriter &rewriter, Operation *op) {
auto barrierOp = rewriter.create<mlir::gpu::BarrierOp>(op->getLoc());
auto asyncTaskIds = getAsyncTaskIds(op);
if (asyncTaskIds.size() == 1) {
int asyncTaskId = asyncTaskIds[0];
int barId = asyncTaskId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
// TODO: Avoid hard-coding numThreads.
const int numThreads = 128;
barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId));
barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads));
}
}
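
When an op belongs to exactly one async task, `insertBarrier` scopes the barrier to that task's 128-thread warp group via a `bar_id` derived from the task id, instead of synchronizing the whole CTA. A standalone sketch of the id mapping, with hypothetical values for `nameBarrierIdBegin`/`nameBarrierIdEnd` (the real constants come from TritonNvidiaGPU's Transforms/Utility.h):

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical values; the real constants live in
// triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h.
constexpr int nameBarrierIdBegin = 8;
constexpr int nameBarrierIdEnd = 16;

// Mirror the bar_id assignment in insertBarrier above.
int barrierIdForTask(int asyncTaskId) {
  int barId = asyncTaskId + nameBarrierIdBegin;
  assert(barId < nameBarrierIdEnd && "ran out of named barriers");
  return barId;
}

int main() {
  const int numThreads = 128; // one warp group per async task
  for (int task = 0; task < 3; ++task)
    std::printf("async task %d -> bar.sync %d, %d\n", task,
                barrierIdForTask(task), numThreads);
}
```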

// Delinearize supposing order is [0, 1, .. , n]
template <typename T>
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
@@ -371,6 +386,20 @@ inline Value getStackPointer(RewriterBase &rewriter,
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

static Operation *getWarpGroupId(Operation *op) {
auto funcOp = op->getParentOfType<FunctionOpInterface>();
Operation *getWarpId = nullptr;
funcOp.walk([&](Operation *op) -> void {
if (isa<mlir::triton::nvidia_gpu::GetCanonicalWarpIdOp>(op)) {
assert(getWarpId == nullptr);
getWarpId = op;
}
});
assert(getWarpId);
getWarpId->dump();
return getWarpId;
}

inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Operation *op) {
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
@@ -381,6 +410,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
.getValue()
.getZExtValue();
Value offVal = i32_val(offset);
if (op->hasAttr("allocation.copy")) {
auto copy = cast<IntegerAttr>(op->getAttr("allocation.copy")).getValue().getZExtValue();
if (copy != 1) {
Operation *getWarpId = getWarpGroupId(op);
Value warpsPerWG = i32_val(4);
Value wgId = udiv(getWarpId->getResult(0), warpsPerWG);
// (wgId - 1) * allocation.size + offset
auto singleSize = cast<IntegerAttr>(op->getAttr("allocation.size")).getValue().getZExtValue();
Value sub1 = sub(wgId, i32_val(1));
Value temp = mul(sub1, i32_val(singleSize));
offVal = add(temp, offVal);
}
}
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
return base;
}
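
The `allocation.copy` path above gives each consumer warp group its own copy of a shared-memory allocation: the base offset is shifted by `(wgId - 1) * allocation.size`, so warp group 1 keeps the original offset and each later group gets the next copy. A minimal, self-contained sketch of that arithmetic with illustrative sizes:

```cpp
#include <cstdint>
#include <cstdio>

// Mirror the offset math in getSharedMemoryBase for allocation.copy > 1.
// Warp group 0 is assumed to be the producer; consumer groups start at 1,
// so (wgId - 1) selects the copy index.
uint64_t copyOffset(uint64_t baseOffset, uint64_t allocationSize,
                    uint32_t warpId, uint32_t warpsPerGroup = 4) {
  uint32_t wgId = warpId / warpsPerGroup;
  return (wgId - 1) * allocationSize + baseOffset;
}

int main() {
  // Illustrative numbers: a 16 KiB per-group buffer starting at offset 1024.
  for (uint32_t warpId = 4; warpId <= 8; warpId += 4) // first warp of groups 1, 2
    std::printf("warp %u -> smem offset %llu\n", warpId,
                (unsigned long long)copyOffset(1024, 16384, warpId));
}
```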
4 changes: 3 additions & 1 deletion include/triton/Conversion/TritonToTritonGPU/Passes.td
@@ -27,7 +27,9 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
Option<"numWarps", "num-warps",
"int32_t", /*default*/"4",
"number of warps">,

Option<"protonSlots", "proton-slots",
"int32_t", /*default*/"0",
"number of proton profiler slots">,
Option<"threadsPerWarp", "threads-per-warp",
"int32_t", /*default*/"32",
"number of threads per warp">,
@@ -14,7 +14,7 @@ namespace triton {
constexpr static char AttrNumWarpsName[] = "triton_gpu.num-warps";
constexpr static char AttrNumCTAsName[] = "triton_gpu.num-ctas";
constexpr static char AttrTargetName[] = "triton_gpu.target";

constexpr static char AttrProtonSlotsName[] = "triton_gpu.proton-slots";
constexpr static char AttrNumThreadsPerWarp[] = "triton_gpu.threads-per-warp";

// Create the pass with numWarps passed from cl::opt.
Expand All @@ -23,7 +23,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
// Create the pass with numWarps set explicitly.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertTritonToTritonGPUPass(const std::string &target, int numWarps,
int threadsPerWarp = 32, int numCTAs = 1);
int threadsPerWarp = 32, int numCTAs = 1,
int protonSlots = 0);

} // namespace triton
} // namespace mlir
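
Callers that build the pipeline in C++ can pass the new `protonSlots` argument directly. A hedged usage sketch; the include path, target string, and surrounding pass-manager setup are illustrative, only the argument order follows the declaration above:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
// Assumed include path for the pass declaration shown above.
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

void buildPipeline(mlir::ModuleOp module, int protonSlots) {
  mlir::PassManager pm(module.getContext());
  // Argument order follows the declaration: target, numWarps,
  // threadsPerWarp, numCTAs, protonSlots.
  pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass(
      /*target=*/"cuda:90", /*numWarps=*/4, /*threadsPerWarp=*/32,
      /*numCTAs=*/1, /*protonSlots=*/protonSlots));
  if (mlir::failed(pm.run(module)))
    module.emitError("ConvertTritonToTritonGPU failed");
}
```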
20 changes: 20 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
@@ -118,4 +118,24 @@ def TT_InputPrecisionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// Proton Profiling Metric.
def TT_ProtonMetricAttr : I32EnumAttr<
"ProtonMetric", "",
[
I32EnumAttrCase<"CYCLE", 0, "cycle">,
I32EnumAttrCase<"INVALID", 1, "invalid">,
]> {
let cppNamespace = "::mlir::triton";
}

// Proton Profiling Granularity.
def TT_ProtonGranularityAttr : I32EnumAttr<
"ProtonGranularity", "",
[
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
I32EnumAttrCase<"WARP", 1, "warp">,
]> {
let cppNamespace = "::mlir::triton";
}

#endif
18 changes: 18 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -1205,5 +1205,23 @@ def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op<
}];
}

def TT_ProtonRecordOp : TT_Op<"proton_record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Record a hardware event";

let description = [{
The intra-kernel profiler records an event from a special register or a hardware performance counter.
Currently only the cycle counter is supported.
}];
let arguments = (
ins BoolAttr: $isStart,
I32Attr: $regionId,
DefaultValuedAttr<TT_ProtonMetricAttr, "triton::ProtonMetric::CYCLE">:$metric,
DefaultValuedAttr<TT_ProtonGranularityAttr, "triton::ProtonGranularity::WARPGROUP">:$granularity
);

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}


#endif // Triton_OPS
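
A typical use of `tt.proton_record` is to bracket a region of interest with a start/stop pair sharing a `regionId`. A hedged C++ sketch of emitting such a pair with the ODS-generated builders; the exact generated builder signature is an assumption and may differ:

```cpp
// Sketch only: bracket a profiled region with a start/stop proton_record pair.
// The ODS-generated builder is assumed to take (isStart, regionId, metric,
// granularity) in declaration order; verify against the generated header.
#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

void instrumentRegion(mlir::OpBuilder &b, mlir::Location loc,
                      int32_t regionId) {
  using namespace mlir::triton;
  b.create<ProtonRecordOp>(loc, /*isStart=*/true, regionId,
                           ProtonMetric::CYCLE, ProtonGranularity::WARPGROUP);
  // ... the ops being profiled are emitted (or moved) here ...
  b.create<ProtonRecordOp>(loc, /*isStart=*/false, regionId,
                           ProtonMetric::CYCLE, ProtonGranularity::WARPGROUP);
}
```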
8 changes: 8 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -129,6 +129,14 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

// Return the number of warps in a warp group. Currently, hard-coded to 4.
// TODO(fywkevin): Put this as an attribute of the module
// so users can flexibly choose the granularity to profile.
const int getWarpGroupSize();

// Return the number of words (4 bytes each) in each proton entry.
const int getWordsPerProtonEntry();

} // namespace gpu
} // namespace triton
} // namespace mlir
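
A minimal sketch of plausible definitions for the two helpers declared above, assuming 4 warps per warp group (per the comment) and a hypothetical 2-word entry layout (a tag word plus a cycle word); the real values are defined in the dialect implementation:

```cpp
// Sketch under stated assumptions; signatures match the declarations above.
const int getWarpGroupSize() {
  return 4; // warps per warp group, hard-coded for now (see TODO above)
}

const int getWordsPerProtonEntry() {
  // Hypothetical layout: one 4-byte tag word plus one 4-byte cycle word.
  return 2;
}
```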
41 changes: 41 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -256,4 +256,45 @@ def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEf
}];
}

def TTG_LocalRecordOp : TTG_Op<"local_record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Record a hardware event";

let description = [{
The intra-kernel profiler records an event from a special register or a hardware performance counter
and saves it in shared memory. The bookkeeping is maintained automatically.
}];
let arguments = (
ins TT_MemDescType:$data,
TT_PtrLike :$indexPtr,
BoolAttr: $isStart,
I32Attr: $regionId,
DefaultValuedAttr<TT_ProtonMetricAttr, "triton::ProtonMetric::CYCLE">:$metric,
DefaultValuedAttr<TT_ProtonGranularityAttr, "triton::ProtonGranularity::WARPGROUP">:$granularity
);
let hasVerifier = 1;
}

def TTG_ProtonFinalizeOp : TTG_Op<"proton_finalize", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Finalize the intra kernel profiler";

let description = [{
Finalize the intra kernel profiler, including dumping the metadata and measurements to the global memory.
}];
let arguments = (
ins TT_MemDescType:$data,
TT_PtrLike :$indexPtr,
TT_PtrLike :$ptr
);
}

def TTG_ProtonInitOp : TTG_Op<"proton_init", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Initialize the intra kernel profiler";

let description = [{
Stack allocation and initialization for the intra kernel profiler.
`indexPtr` stores the number of entries proton has measured, initialized to 0.
}];
let arguments = (ins);
let results = (outs TT_PtrLike :$indexPtr);
}
#endif
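
Taken together, `proton_init`, `local_record`, and `proton_finalize` behave like a small append-only log: init zeroes an index, each record appends an entry at the index and bumps it, and finalize copies the count plus the valid entries out to global memory. A plain C++ model of that bookkeeping, with an assumed entry layout, purely to illustrate the flow:

```cpp
#include <cstdint>
#include <vector>

// Illustrative model of the proton_init / local_record / proton_finalize flow.
// The entry layout (tag + cycle) is an assumption.
struct ProtonEntry {
  uint32_t tag;   // assumed to pack regionId and isStart
  uint32_t cycle; // raw cycle-counter sample
};

struct ProtonBuffer {
  std::vector<ProtonEntry> data; // stands in for the shared-memory buffer
  uint32_t index = 0;            // what proton_init zero-initializes
};

// local_record: append one measurement if a slot is free and bump the index.
void localRecord(ProtonBuffer &buf, bool isStart, uint32_t regionId,
                 uint32_t cycle) {
  if (buf.index < buf.data.size())
    buf.data[buf.index++] = {(regionId << 1) | (isStart ? 1u : 0u), cycle};
}

// proton_finalize: copy the entry count and the valid prefix to global memory.
void protonFinalize(const ProtonBuffer &buf, std::vector<uint32_t> &global) {
  global.push_back(buf.index);
  for (uint32_t i = 0; i < buf.index; ++i) {
    global.push_back(buf.data[i].tag);
    global.push_back(buf.data[i].cycle);
  }
}
```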
117 changes: 117 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
@@ -179,4 +179,121 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init"
"mlir::triton::TritonDialect"];
}

def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> {
let summary = "Propagate async_task_id annotations based on dependencies";

let description = [{
This pass propagates the `async_task_id` annotation to the dependencies
of any op that has it set. This has the functional effect of partitioning
the graph into multiple async tasks, based on the initial annotation.
}];

let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];

let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}
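
Conceptually, the propagation is a backward fixed point over use-def edges: any op that feeds an annotated op inherits that op's task ids. A standalone sketch on a toy dependency graph (not the actual MLIR walk):

```cpp
#include <cstdio>
#include <map>
#include <set>
#include <vector>

// Toy graph: op -> operands (use-def edges), with async_task_id seeds on the
// annotated ops. Propagate ids from users to their operands to a fixed point,
// mirroring what the pass does over real MLIR ops.
int main() {
  std::map<int, std::vector<int>> operands = {
      {0, {}}, {1, {0}}, {2, {1}}, {3, {1}}}; // ops 2 and 3 both use op 1
  std::map<int, std::set<int>> taskIds = {{2, {1}}, {3, {2}}}; // user seeds

  bool changed = true;
  while (changed) {
    changed = false;
    for (auto &[op, ids] : taskIds)
      for (int def : operands[op])
        for (int id : ids)
          changed |= taskIds[def].insert(id).second;
  }

  for (auto &[op, ids] : taskIds) {
    std::printf("op %d: async_task_id = {", op);
    for (int id : ids)
      std::printf(" %d", id);
    std::printf(" }\n");
  }
}
```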

def TritonGPUWSCodePartition: Pass<"tritongpu-warp-spec-code-partition", "mlir::ModuleOp"> {
let summary = "TritonGPU warp specialization code partition";

let description = "This pass generates warp specialized code baed on task id attributes.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numBuffers", "num-buffers",
"int32_t", /*default*/"0",
"number of buffering for producer-consumer">,
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">,
Option<"regDecProducer", "producer-reg-dec",
"int32_t", /*default*/"40",
"register decrement for producer warp group">,
Option<"regIncConsumer", "consumer-reg-inc",
"int32_t", /*default*/"232",
"register indrement for consumer warp group">
];
}
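
The defaults of 40 and 232 are consistent with re-balancing the register file between one producer and two consumer warp groups, reading them as the target per-thread register counts (an assumption, as is the 64K-register file size). A quick arithmetic check:

```cpp
#include <cstdio>

int main() {
  const int threadsPerWarpGroup = 128;
  const int producerRegs = 40;        // producer-reg-dec default
  const int consumerRegs = 232;       // consumer-reg-inc default
  const int numConsumerGroups = 2;    // assumed warp-spec configuration
  const int regFilePerSM = 64 * 1024; // assumed 32-bit registers per SM

  int total = threadsPerWarpGroup *
              (producerRegs + numConsumerGroups * consumerRegs);
  // 128 * (40 + 2 * 232) = 64512 <= 65536, so the split fits.
  std::printf("registers used: %d of %d\n", total, regFilePerSM);
  return total <= regFilePerSM ? 0 : 1;
}
```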

def TritonGPUWSDataPartition : Pass<"tritongpu-warp-spec-data-partition", "mlir::ModuleOp"> {
let summary = "Warp specialization data partition";

let description = "This pass partitions operations into multiple suboperations which operate on smaller data shapes";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}

def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp"> {
let summary = "Warp specialization lowering";

let description = "This pass lowers warp specializtion related operations.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}

def TritonGPUPingPongSync: Pass<"tritongpu-ping-pong-sync", "mlir::ModuleOp"> {
let summary = "TritonGPU experiemental ping pong schedule";

let description = "This pass generates warp specialized code baed on warp group id attributes.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">,
Option<"partitionStyle", "partition-style",
"int32_t", /*default*/"0",
"partition style for multiple consumer warp groups">
];
}

// #ifdef __FACEBOOK__
def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> {
let summary = "Generate loop scheduling for SWP";

let description = "This pass sets up stages and clustering for software pipelining.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">
];
}
// #endif

def TritonGPUProtonLowering: Pass<"tritongpu-proton-lowering", "mlir::ModuleOp"> {
let summary = "Lower the Triton's ProtonRecordOp with scaffolding code (e.g., resource allocation and post processing).";

let description = "Allocate memory for local profiling buffers and convert the Triton's ProtonRecordOp into "
"TritonGPU's LocalRecordOp with resource binded. Before exiting the local profiling buffers "
"(with metadata) are copied back to GPU's global memory.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
}

#endif
@@ -29,6 +29,15 @@ void addOps(scf::ForOp forOp, int stage,
/// mutable.
void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
Value val);

// Begin __FACEBOOK__ CompPipe
/// Create a map from load ops to their indirection level and the
/// final use of the load op (another load op, or a dot op).
/// Indirection level is "0" for the load op directly used by the dot op,
/// "1" for the load op used by the load op used by the dot op, and so on.
llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
loadOpsToIndirectionLevelAndUse(scf::ForOp forOp);
// End __FACEBOOK__ CompPipe
} // namespace triton
} // namespace mlir
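
For `loadOpsToIndirectionLevelAndUse`, e.g. with `a = load(ptrs); b = load(a); dot(b, c)`, load `b` has indirection level 0 and `a` has level 1. A self-contained toy sketch of that computation (the real implementation walks MLIR ops and also records the final use of each load):

```cpp
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy def chain: each op names its operands. Walk back from the dot through
// loads, assigning level 0 to loads used by the dot and level N+1 to loads
// feeding a level-N load.
int main() {
  std::map<std::string, std::vector<std::string>> operands = {
      {"dot", {"b", "c"}}, {"b", {"a"}}, {"a", {"ptrs"}}};
  std::map<std::string, bool> isLoad = {{"a", true}, {"b", true}};

  std::vector<std::pair<std::string, int>> result;   // (load, level)
  std::vector<std::pair<std::string, int>> worklist = {{"b", 0}, {"c", 0}};
  while (!worklist.empty()) {
    auto [op, level] = worklist.back();
    worklist.pop_back();
    if (!isLoad[op])
      continue;
    result.push_back({op, level});
    for (const auto &def : operands[op])
      worklist.push_back({def, level + 1});
  }
  for (const auto &[load, level] : result)
    std::printf("load %s: indirection level %d (final use: dot)\n",
                load.c_str(), level);
}
```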

6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
@@ -84,8 +84,10 @@ class CoarseSchedule {
return true;
}

void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg);
void
insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg,
DenseMap<Operation *, Operation *> *additionalDep = nullptr);

void erase(Operation *op) { opToStageAndCluster.erase(op); }
