Skip to content

Commit

Permalink
[warpspec] Add experimental support for warp specialization with user…
Browse files Browse the repository at this point in the history
… annotations

This commit is a squash generated by:
```
git diff --stat b62b221a...06ccdadb -- . ':(exclude)python/gemmbench' ':(exclude)python/hstuBench' ':(exclude)third_party/proton'
```
  • Loading branch information
bertmaher committed Nov 14, 2024
1 parent f4c48a9 commit a2b8f27
Show file tree
Hide file tree
Showing 55 changed files with 6,274 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

} // namespace triton
} // namespace mlir

Expand Down
42 changes: 42 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "triton/Tools/Sys/GetEnv.hpp"
Expand Down Expand Up @@ -144,6 +145,20 @@ using namespace mlir::triton;
namespace mlir {
namespace triton {

static inline void insertBarrier(PatternRewriter &rewriter, Operation *op) {
auto barrierOp = rewriter.create<mlir::gpu::BarrierOp>(op->getLoc());
auto asyncTaskIds = getAsyncTaskIds(op);
if (asyncTaskIds.size() == 1) {
int asyncTaskId = asyncTaskIds[0];
int barId = asyncTaskId + nameBarrierIdBegin;
assert(barId < nameBarrierIdEnd);
// TODO: Change hard code style of numThreads.
const int numThreads = 128;
barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId));
barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads));
}
}

// Delinearize supposing order is [0, 1, .. , n]
template <typename T>
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
Expand Down Expand Up @@ -371,6 +386,20 @@ inline Value getStackPointer(RewriterBase &rewriter,
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

static Operation *getWarpGroupId(Operation *op) {
auto funcOp = op->getParentOfType<FunctionOpInterface>();
Operation *getWarpId = nullptr;
funcOp.walk([&](Operation *op) -> void {
if (isa<mlir::triton::nvidia_gpu::GetCanonicalWarpIdOp>(op)) {
assert(getWarpId == nullptr);
getWarpId = op;
}
});
assert(getWarpId);
getWarpId->dump();
return getWarpId;
}

inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Operation *op) {
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
Expand All @@ -381,6 +410,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
.getValue()
.getZExtValue();
Value offVal = i32_val(offset);
if (op->hasAttr("allocation.copy")) {
auto copy = cast<IntegerAttr>(op->getAttr("allocation.copy")).getValue().getZExtValue();
if (copy != 1) {
Operation *getWarpId = getWarpGroupId(op);
Value warpsPerWG = i32_val(4);
Value wgId = udiv(getWarpId->getResult(0), warpsPerWG);
// (wgId - 1) * allocation.size + offset
auto singleSize = cast<IntegerAttr>(op->getAttr("allocation.size")).getValue().getZExtValue();
Value sub1 = sub(wgId, i32_val(1));
Value temp = mul(sub1, i32_val(singleSize));
offVal = add(temp, offVal);
}
}
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
return base;
}
Expand Down
105 changes: 105 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,109 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init"
"mlir::triton::TritonDialect"];
}

def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> {
let summary = "Propagate async_task_id annotations based on dependencies";

let description = [{
This pass propagates the `async_task_id` annotation to the dependencies
of any op that has it set. This has the functional effect of partitioning
the graph into multiple async tasks, based on the initial annotation.
}];

let dependentDialects = [
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
];

let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}

def TritonGPUWSCodePartition: Pass<"tritongpu-warp-spec-code-partition", "mlir::ModuleOp"> {
let summary = "TritonGPU warp specialization code partition";

let description = "This pass generates warp specialized code baed on task id attributes.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numBuffers", "num-buffers",
"int32_t", /*default*/"0",
"number of buffering for producer-consumer">,
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">,
Option<"regDecProducer", "producer-reg-dec",
"int32_t", /*default*/"40",
"register decrement for producer warp group">,
Option<"regIncConsumer", "consumer-reg-inc",
"int32_t", /*default*/"232",
"register indrement for consumer warp group">
];
}

def TritonGPUWSDataPartition : Pass<"tritongpu-warp-spec-data-partition", "mlir::ModuleOp"> {
let summary = "Warp specialization data partition";

let description = "This pass partitions operations into multiple suboperations which operate on smaller data shapes";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}

def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp"> {
let summary = "Warp specialization lowering";

let description = "This pass lowers warp specializtion related operations.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}

def TritonGPUPingPongSync: Pass<"tritongpu-ping-pong-sync", "mlir::ModuleOp"> {
let summary = "TritonGPU experiemental ping pong schedule";

let description = "This pass generates warp specialized code baed on warp group id attributes.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">,
Option<"partitionStyle", "partition-style",
"int32_t", /*default*/"0",
"partition style for multiple consumer warp groups">
];
}

// #ifdef __FACEBOOK__
def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> {
let summary = "Generate loop scheduling for SWP";

let description = "This pass sets up stages and clustering for software pipelining.";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];
let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"3",
"number of pipeline stages">
];
}
// #endif
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ void addOps(scf::ForOp forOp, int stage,
/// mutable.
void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
Value val);

// Begin __FACEBOOK__ CompPipe
/// Create a map from load ops to their indirection level and the
/// final use of the load op (another load op, or a dot op).
/// Indirection level is "0" for the load op directly used by the dot op,
/// "1" for the load op used by the load op used by the dot op, and so on.
llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
loadOpsToIndirectionLevelAndUse(scf::ForOp forOp);
// End __FACEBOOK__ CompPipe
} // namespace triton
} // namespace mlir

Expand Down
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ class CoarseSchedule {
return true;
}

void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg);
void
insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
bool includeArg,
DenseMap<Operation *, Operation *> *additionalDep = nullptr);

void erase(Operation *op) { opToStageAndCluster.erase(op); }

Expand Down
119 changes: 119 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,38 @@ class TTNG_Op<string mnemonic, list<Trait> traits = []> :
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
}

def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments,
MemoryEffects<[MemWrite<SharedMemory>]>]> {
let summary = "mbarrier arrive";

let description = [{
This operation defining the arriving action for a mbarrier.
txCount:
An optional attribute that set tx-count. This Op will be lowered into
mbarrier.arrive.expect_tx if the optional attribute exist.
trackAsyncOp:
If true, this op will be lowered into cp.async.mbarrier.arrive.noinc.
pred:
Only perform arrive action when pred is true.
remoteCtaId:
if set, perform an remote arrive action.

Example:

triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr<i64>

}];

let arguments = (ins TT_MemDescType:$mbarrier,
Optional<I1>:$pred,
Optional<I32>:$remoteCtaId,
I1Attr: $trackAsyncOp,
DefaultValuedAttr<I32Attr, "0">: $txCount
);

let assemblyFormat = "operands attr-dict `:` type(operands)";
}

def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
let arguments = (ins BoolAttr:$bCluster);

Expand All @@ -57,6 +89,31 @@ def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
}];
}

def TTNG_GetCanonicalWarpIdOp : TTNG_Op<"get_canonical_warp_id", [Pure]> {
let description = [{
Returns the one dimensional warpId when it's used for producing warp uniform values.
}];

let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> {
let summary = "named barrier arrive";

let arguments = (ins I32:$bar, I32: $numThreads);

let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}

def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> {
let summary = "named barrier wait";

let arguments = (ins I32:$bar, I32: $numThreads);

let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
}

def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> {
let arguments = (ins I1Attr:$relaxed);
let assemblyFormat = "attr-dict";
Expand Down Expand Up @@ -249,4 +306,66 @@ def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> {
let assemblyFormat = "attr-dict";
}

def TTNG_GetAsyncTaskIdOp : TTNG_Op<"get_async_task_id", [Pure]> {
let results = (outs I32:$result);

let builders = [OpBuilder<(ins)>];

let assemblyFormat = "attr-dict `:` type($result)";
}

//
// Token
//

def TTNG_CreateTokenOp : TTNG_Op<"create_token"> {
let results = (outs TensorOf<[TTNG_TokenType]>:$result);

let arguments = (ins I32Attr:$num);

let builders = [OpBuilder<(ins "uint32_t":$num)>];

let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1:$phase);

let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)";
}

def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);

let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}

def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1: $phase);

let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)";
}

def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> {
let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx);

let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)";
}

def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> {
let summary = "register allocation";

let arguments = (ins I32Attr: $regCount);

let assemblyFormat = "$regCount attr-dict";
}

def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> {
let summary = "register deallocation";

let arguments = (ins I32Attr: $regCount);

let assemblyFormat = "$regCount attr-dict";
}

#endif
Loading

0 comments on commit a2b8f27

Please sign in to comment.