diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 29af2c5f7..ebb54a8a8 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -102,6 +102,10 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + } // namespace triton } // namespace mlir diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index b209a02b4..bbbba9459 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -13,6 +13,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "triton/Tools/Sys/GetEnv.hpp" @@ -144,6 +145,20 @@ using namespace mlir::triton; namespace mlir { namespace triton { +static inline void insertBarrier(PatternRewriter &rewriter, Operation *op) { + auto barrierOp = rewriter.create(op->getLoc()); + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + // TODO: Change hard code style of numThreads. + const int numThreads = 128; + barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads)); + } +} + // Delinearize supposing order is [0, 1, .. 
, n] template llvm::SmallVector getMultiDimIndexImpl(T linearIndex, @@ -371,6 +386,20 @@ inline Value getStackPointer(RewriterBase &rewriter, return funcOp.getArgument(funcOp.getNumArguments() - 1); } +static Operation *getWarpGroupId(Operation *op) { + auto funcOp = op->getParentOfType(); + Operation *getWarpId = nullptr; + funcOp.walk([&](Operation *op) -> void { + if (isa(op)) { + assert(getWarpId == nullptr); + getWarpId = op; + } + }); + assert(getWarpId); + getWarpId->dump(); + return getWarpId; +} + inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Operation *op) { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); @@ -381,6 +410,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, .getValue() .getZExtValue(); Value offVal = i32_val(offset); + if (op->hasAttr("allocation.copy")) { + auto copy = cast(op->getAttr("allocation.copy")).getValue().getZExtValue(); + if (copy != 1) { + Operation *getWarpId = getWarpGroupId(op); + Value warpsPerWG = i32_val(4); + Value wgId = udiv(getWarpId->getResult(0), warpsPerWG); + // (wgId - 1) * allocation.size + offset + auto singleSize = cast(op->getAttr("allocation.size")).getValue().getZExtValue(); + Value sub1 = sub(wgId, i32_val(1)); + Value temp = mul(sub1, i32_val(singleSize)); + offVal = add(temp, offVal); + } + } Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index f2b79d222..d375e3801 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -179,4 +179,109 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init" "mlir::triton::TritonDialect"]; } +def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> { + let summary = "Propagate async_task_id annotations based on dependencies"; + + let description = [{ + This pass propagates the `async_task_id` annotation to the dependencies + of any op that has it set. This has the functional effect of partitioning + the graph into multiple async tasks, based on the initial annotation. 
+  }];
+
+  let dependentDialects = [
+    "mlir::triton::gpu::TritonGPUDialect",
+    "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
+  ];
+
+  let options = [
+    Option<"numConsumerGroups", "num-consumer-groups",
+           "int32_t", /*default*/"0",
+           "number of consumer warp groups for warp specialization">
+  ];
+}
+
+def TritonGPUWSCodePartition: Pass<"tritongpu-warp-spec-code-partition", "mlir::ModuleOp"> {
+  let summary = "TritonGPU warp specialization code partition";
+
+  let description = "This pass generates warp specialized code based on task id attributes.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect",
+                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
+  let options = [
+    Option<"numBuffers", "num-buffers",
+           "int32_t", /*default*/"0",
+           "number of buffers for producer-consumer">,
+    Option<"numConsumerGroups", "num-consumer-groups",
+           "int32_t", /*default*/"0",
+           "number of consumer warp groups for warp specialization">,
+    Option<"regDecProducer", "producer-reg-dec",
+           "int32_t", /*default*/"40",
+           "register decrement for producer warp group">,
+    Option<"regIncConsumer", "consumer-reg-inc",
+           "int32_t", /*default*/"232",
+           "register increment for consumer warp group">
+  ];
+}
+
+def TritonGPUWSDataPartition : Pass<"tritongpu-warp-spec-data-partition", "mlir::ModuleOp"> {
+  let summary = "Warp specialization data partition";
+
+  let description = "This pass partitions operations into multiple suboperations that operate on smaller data shapes.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
+  let options = [
+    Option<"numConsumerGroups", "num-consumer-groups",
+           "int32_t", /*default*/"0",
+           "number of consumer warp groups for warp specialization">
+  ];
+}
+
+def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp"> {
+  let summary = "Warp specialization lowering";
+
+  let description = "This pass lowers warp-specialization-related operations.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
+  let options = [
+    Option<"numConsumerGroups", "num-consumer-groups",
+           "int32_t", /*default*/"0",
+           "number of consumer warp groups for warp specialization">
+  ];
+}
+
+def TritonGPUPingPongSync: Pass<"tritongpu-ping-pong-sync", "mlir::ModuleOp"> {
+  let summary = "TritonGPU experimental ping-pong schedule";
+
+  let description = "This pass generates warp specialized code based on warp group id attributes.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+  let options = [
+    Option<"numConsumerGroups", "num-consumer-groups",
+           "int32_t", /*default*/"0",
+           "number of consumer warp groups for warp specialization">,
+    Option<"partitionStyle", "partition-style",
+           "int32_t", /*default*/"0",
+           "partition style for multiple consumer warp groups">
+  ];
+}
+
+// #ifdef __FACEBOOK__
+def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> {
+  let summary = "Generate loop scheduling for SWP";
+
+  let description = "This pass sets up stages and clustering for software pipelining.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+  let options = [
+    Option<"numStages", "num-stages",
+           "int32_t", /*default*/"3",
+           "number of pipeline stages">
+  ];
+}
+// #endif
 #endif
diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
index 88f062a01..db89d0dec 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
@@ -29,6 +29,15 @@ void addOps(scf::ForOp forOp, int stage,
 /// mutable.
 void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
                                  Value val);
+
+// Begin __FACEBOOK__ CompPipe
+/// Create a map from load ops to their indirection level and the
+/// final use of the load op (another load op, or a dot op).
+/// Indirection level is "0" for the load op directly used by the dot op,
+/// "1" for the load op used by the load op used by the dot op, and so on.
+llvm::SmallVector>
+loadOpsToIndirectionLevelAndUse(scf::ForOp forOp);
+// End __FACEBOOK__ CompPipe
 } // namespace triton
 } // namespace mlir
diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h
index 1dd1fc686..4bd8ff79e 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h
@@ -84,8 +84,10 @@ class CoarseSchedule {
     return true;
   }
 
-  void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
-                      bool includeArg);
+  void
+  insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster,
+                 bool includeArg,
+                 DenseMap *additionalDep = nullptr);
 
   void erase(Operation *op) { opToStageAndCluster.erase(op); }
 
diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
index 243b93436..3ce1d80de 100644
--- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
+++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td
@@ -43,6 +43,38 @@ class TTNG_Op traits = []> :
                !listconcat(traits, [VerifyTensorLayoutsTrait])> {
 }
 
+def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments,
+                                                        MemoryEffects<[MemWrite]>]> {
+  let summary = "mbarrier arrive";
+
+  let description = [{
+    This operation defines the arrive action for an mbarrier.
+    txCount:
+      An optional attribute that sets tx-count. This Op will be lowered into
+      mbarrier.arrive.expect_tx if the optional attribute exists.
+    trackAsyncOp:
+      If true, this op will be lowered into cp.async.mbarrier.arrive.noinc.
+    pred:
+      Only perform the arrive action when pred is true.
+    remoteCtaId:
+      If set, perform a remote arrive action.
+
+    Example:
+
+    triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr
+
+  }];
+
+  let arguments = (ins TT_MemDescType:$mbarrier,
+                       Optional:$pred,
+                       Optional:$remoteCtaId,
+                       I1Attr: $trackAsyncOp,
+                       DefaultValuedAttr: $txCount
+  );
+
+  let assemblyFormat = "operands attr-dict `:` type(operands)";
+}
+
 def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
   let arguments = (ins BoolAttr:$bCluster);
 
@@ -57,6 +89,31 @@ def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> {
   }];
 }
 
+def TTNG_GetCanonicalWarpIdOp : TTNG_Op<"get_canonical_warp_id", [Pure]> {
+  let description = [{
+    Returns the one-dimensional warp ID when it's used for producing warp-uniform values.
+ }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> { + let summary = "named barrier arrive"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> { + let summary = "named barrier wait"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { let arguments = (ins I1Attr:$relaxed); let assemblyFormat = "attr-dict"; @@ -249,4 +306,66 @@ def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> { let assemblyFormat = "attr-dict"; } +def TTNG_GetAsyncTaskIdOp : TTNG_Op<"get_async_task_id", [Pure]> { + let results = (outs I32:$result); + + let builders = [OpBuilder<(ins)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +// +// Token +// + +def TTNG_CreateTokenOp : TTNG_Op<"create_token"> { + let results = (outs TensorOf<[TTNG_TokenType]>:$result); + + let arguments = (ins I32Attr:$num); + + let builders = [OpBuilder<(ins "uint32_t":$num)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1:$phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1: $phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> { + let summary = "register allocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> { + let summary = "register deallocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h new file mode 100644 index 000000000..6a200ebe9 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h @@ -0,0 +1,129 @@ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { + +// 0 is reserved for default sync. +// TODO: comprehensive mechanism to globally manage namedbarrier. 
+static int const nameBarrierIdBegin = 1; +static int nameBarrierIdEnd = 16; + +/// Helper functions for async task +typedef int AsyncTaskId; +SmallVector getAsyncTaskIds(Operation *op); +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds); +SmallVector getNestedAsyncTaskIds(Operation *op); +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks); +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void removeAsyncTaskIds(Operation *op); +SmallVector getMutexBarIds(Operation *op); +SmallVector getMutexNumThreads(Operation *op); + +static Value GetCanonicalWarpId(RewriterBase &rewriter, Location loc) { + return rewriter.create( + loc, rewriter.getI32Type()); +} + +static Value getClusterCTAId(RewriterBase &rewriter, Location loc) { + return rewriter.create(loc, + rewriter.getI32Type()); +} + +class OpBuilderWithAsyncTaskIds : public OpBuilder { +public: + OpBuilderWithAsyncTaskIds(MLIRContext *context) : OpBuilder(context) {} + + explicit OpBuilderWithAsyncTaskIds(Operation *op) + : OpBuilder(op) + { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy createWithAsyncTaskIds(Args &&...args) { + OpTy op = create(std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + +private: + SmallVector asyncTaskIds; +}; + +class PatternRewriterWithAsyncTaskIds { +public: + PatternRewriterWithAsyncTaskIds(PatternRewriter &rewriter, Operation *op) + : rewriter(&rewriter) { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy create(Location location, Args &&...args) { + OpTy op = rewriter->create(location, std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + + template + OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { + auto newOp = + rewriter->replaceOpWithNewOp(op, std::forward(args)...); + return newOp; + } + +private: + PatternRewriter* rewriter; + SmallVector asyncTaskIds; +}; + +/// Constant task ids +constexpr AsyncTaskId kLoadAsyncTaskId = 0; +constexpr AsyncTaskId kDotAsyncTaskId = 1; + +bool isWSCandidateLoad(Operation *op); +bool isWSSupported(ModuleOp m, int computeCapability); + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 43e7df135..ef77cccc6 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ 
-30,6 +30,11 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_LLVM_DEBUG_ONLY", "USE_IR_LOC", "NVPTX_ENABLE_DUMP", + "PEEL_LAST_ITER", + "ENABLE_PINGPONG", + "HACK_ASYNC_DOT", + "SWP_FOR_CONSUMER", + "HARDCODE_TASKID_FA", // clang-format on }; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index b44b75601..eda8628c7 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -13,6 +13,8 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" using ::mlir::triton::gpu::AMDMfmaEncodingAttr; @@ -189,6 +191,15 @@ class AllocationAnalysis { auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); auto bytes = product(shapePerCTA) * allocType.getElementTypeBitWidth() / 8; + if (op->hasAttr("allocation.copy")) { + auto copy = cast(op->getAttr("allocation.copy")) + .getValue() + .getZExtValue(); + op->setAttr( + "allocation.size", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), bytes)); + bytes = bytes * copy; + } auto alignment = alloc.getAlignmentOrDefault(); allocation->addBuffer(result, bytes, @@ -251,6 +262,15 @@ class AllocationAnalysis { isa(srcTy.getElementType()) ? elems * kPtrBitWidth / 8 : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + if (op->hasAttr("allocation.copy")) { + auto copy = cast(op->getAttr("allocation.copy")) + .getValue() + .getZExtValue(); + op->setAttr( + "allocation.size", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), bytes)); + bytes = bytes * copy; + } maybeAddScratchBuffer(op, bytes, scratchAlignment); } else if (isa(op)) { @@ -370,8 +390,18 @@ class AllocationAnalysis { // range. auto *op = opScratchIter.first; auto *buffer = opScratchIter.second; - bufferRange.insert({buffer, Interval(operationId.lookup(op), - operationId.lookup(op) + 1)}); + // Extend live range when asyncTaskId is not empty (i.e when we have + // warp spec). + if (getAsyncTaskIds(op).empty()) { + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } else { + // FIXME: This range makes scratch buffers used in warp-specialized + // regions conflict with everything else in the program, which is + // too conservative, but safe. A better approach would make them + // conflict with buffers live in other warp-specialized regions. 
+ bufferRange.insert({buffer, Interval(0, operationId.size())}); + } } }; processScratchMemory(allocation->opScratch); @@ -408,6 +438,11 @@ class AllocationAnalysis { auto maxId = std::numeric_limits::min(); std::for_each(liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { + if (!getAsyncTaskIds(liveOp).empty()) { + minId = 0; + maxId = operationId.size(); + return; + } if (operationId[liveOp] < minId) { minId = operationId[liveOp]; } diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index bb106238e..1bad78b5a 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -6,6 +6,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include namespace mlir { @@ -95,6 +96,16 @@ void MembarAnalysis::visitTerminator(Operation *op, void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { OpBuilder::InsertionGuard g(*builder); auto barrierOp = builder->create(op->getLoc()); + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + // TODO: Change hard code style of numThreads. + const int numThreads = 128; + barrierOp->setAttr("bar_id", builder->getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", builder->getI64IntegerAttr(numThreads)); + } } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp index aae9faf0e..ac2e77061 100644 --- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -40,6 +40,7 @@ struct AllocateSharedMemory } if (offset == -1) return; + if (op->hasAttr("allocation.offset")) return; op->setAttr("allocation.offset", IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index cca2830b0..ecf3d34c7 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -15,6 +15,7 @@ add_triton_library(TritonGPUToLLVM ConvertLayoutOpToLLVM.cpp ControlFlowOpToLLVM.cpp FuncOpToLLVM.cpp + RegReallocOpToLLVM.cpp SPMDOpToLLVM.cpp DecomposeUnsupportedConversions.cpp PrintOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 403cac9de..d9a24ddc6 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -213,7 +213,7 @@ struct ConvertLayoutOpConversion auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); if (repId != 0) { - barrier(); + insertBarrier(rewriter, op); } auto successful = targetInfo.processReplicaUsingStMatrix( rewriter, loc, smemBase, vals, srcTy, @@ -224,7 +224,7 @@ struct ConvertLayoutOpConversion multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, vals, smemBase); } - barrier(); + insertBarrier(rewriter, op); processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, origRepShape, outOrd, outVals, smemBase); @@ -581,7 +581,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion llvm::MapVector outVals; for (int i = 0; 
i < iterations; i++) { if (i != 0) - barrier(); + insertBarrier(rewriter, op); auto &inRegs = inRegsForIter[i]; auto &outRegs = outRegsForIter[i]; @@ -605,7 +605,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - barrier(); + insertBarrier(rewriter, op); for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { auto outRegSlice = outRegs[j]; diff --git a/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp new file mode 100644 index 000000000..d7dca0397 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp @@ -0,0 +1,47 @@ +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct RegAllocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegAllocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; + +struct RegDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; +} // namespace + +void mlir::triton::populateRegReallocOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 4d40e0f31..f796775cb 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -314,10 +314,14 @@ class RewriteTensorPointerPass loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); op->getResult(0).replaceAllUsesWith(newResult); + if (op->getAttr("async_task_id")) + newResult->setAttr("async_task_id", op->getAttr("async_task_id")); } else if (auto storeOp = dyn_cast(op)) { - builder.create(storeOp.getLoc(), newPtr, + auto newOp = builder.create(storeOp.getLoc(), newPtr, storeOp.getValue(), newMask, storeOp.getCache(), storeOp.getEvict()); + if (op->getAttr("async_task_id")) + newOp->setAttr("async_task_id", op->getAttr("async_task_id")); } // Erase the original operation @@ -413,6 +417,7 @@ class RewriteTensorPointerPass auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), newIterOperands); + newForOp->setAttrs(op->getAttrs()); // Create value mapping. Note that for tensor pointers, we use identity // mapping. 
It may refer to a value in the old loop, but we will rewrite it diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a454fef56..a1c2990fb 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2717,8 +2717,9 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - rewriter.replaceOpWithNewOp( + auto newAlloc = rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); + newAlloc->setAttrs(op->getAttrs()); return mlir::success(); } }; @@ -2734,8 +2735,9 @@ struct CanonicalizeConvertFromLocalStore auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - rewriter.replaceOpWithNewOp(op, convert.getSrc(), - op.getDst()); + auto store = rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + store->setAttrs(op->getAttrs()); return mlir::success(); } }; @@ -2854,8 +2856,10 @@ struct CanonicalizeConvertFromConvert // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (auto cvt = dyn_cast(arg)) { auto srcType = op.getSrc().getType(); - rewriter.replaceOpWithNewOp( + auto origAttrs = op->getAttrs(); + auto newOp = rewriter.replaceOpWithNewOp( op, op->getResultTypes().front(), cvt.getSrc()); + newOp->setAttrs(origAttrs); return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index d9bbd51bd..51d482aaa 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" @@ -291,6 +292,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { versionMinor, warpsPerTile, CTALayout, instrShape); } + PatternRewriterWithAsyncTaskIds taskIdRewriter(rewriter, dotOp); auto newRetType = RankedTensorType::get( oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); // convert accumulator @@ -305,7 +307,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { bool allowTranspose = eltType.isF16() || eltType.isBF16(); a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); - newDot = rewriter.create( + newDot = taskIdRewriter.create( dotOp.getLoc(), newRetType, a, b, newAcc, nullptr, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false); } else { @@ -327,9 +329,9 @@ class BlockedToMMA : public mlir::OpRewritePattern { auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), newBEncoding); b = rewriter.create(b.getLoc(), newBType, b); - newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, newAcc, - dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc()); + newDot = taskIdRewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); } // convert dot instruction rewriter.replaceOpWithNewOp(dotOp, oldRetType, diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 99e2ac3c9..d3ef9a145 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ 
add_triton_library(TritonGPUTransforms Coalesce.cpp F32DotTC.cpp CombineTensorSelectAndIf.cpp + LoopScheduling.cpp ReduceDataDuplication.cpp OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp @@ -18,6 +19,11 @@ add_triton_library(TritonGPUTransforms RemoveLayoutConversions.cpp ReorderInstructions.cpp Utility.cpp + TaskIdPropagate.cpp + WSDataPartition.cpp + WSCodePartition.cpp + WSLowering.cpp + PingPong.cpp DEPENDS TritonGPUTransformsIncGen diff --git a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp new file mode 100644 index 000000000..4b8eef70c --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp @@ -0,0 +1,622 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#define DEBUG_TYPE "triton-loop-schedule" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +// Begin __FACEBOOK__ CompPipe +static void scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return; + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + } + unsigned stagesBetweenLoads = + ceil(numStages - 2, maxIndirectionLevel + 1); + + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + // Put the root uses of the loads in the last stage. 
+ for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be + // always present in the opInfo + if (!isa(use)) { + rootUsers.insert(use); + schedule.insert(use, numStages - 1, rootUsersCluster); + } + } + + SmallVector loadsClusters; + for (int i = 0; i < maxIndirectionLevel + 1; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } +} + +static tt::CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + // afterPrologue : first cluster curently but we will add a cluster at front + // and a cluster at back + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + // Look for the IfOp that is in the forward slice of the root users and put it + // at the end of the loop. + tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto rootUser : rootUsers) { + SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + + int stage = schedule[rootUser].first; + for (auto op : forwardSlice) { + scf::IfOp ifOp = dyn_cast(op); + if (ifOp == nullptr) { + // check if the op is in the body of an if op that's part of the loop + auto parentOp = op->getParentOp(); + if (parentOp != nullptr && + parentOp->getParentOp() == forOp.getOperation()) { + ifOp = dyn_cast(parentOp); + } + } + if (ifOp) { + schedule.insertIfAbsent(ifOp, stage, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +static const char *kLoopScheduleAttrName = "tt.loop_schedule"; +std::string getLoopScheduleOrDefault(scf::ForOp forOp) { + if (!forOp->hasAttr(kLoopScheduleAttrName)) + return "default"; + return (cast(forOp->getAttr(kLoopScheduleAttrName))).str(); +} +// End __FACEBOOK__ CompPipe + +static bool isHeavyComputation(Operation *op) { + // include exp2, mulf, addf 1D. Somehow we don't go through reduction + // when checking dependencies + if (!isa(op) && !isa(op) && + !isa(op)) + return false; + auto tensorTy = dyn_cast(op->getOperand(0).getType()); + if (!tensorTy) + return false; + if (tensorTy.getRank() < 1) + return false; + return true; +} + +// Find all consumer_waits needed for a given dot. Assume we have this sequence +// consumer_wait -> subview -> local_load -> dot +// or +// consumer_wait -> subview -> dot +// with TMA +// wait_barrier -> subview -> trans -> dot +// We assume consumer_wait and subview are right next to each other. 
We start +// from consumer_wait or wait_barrier, find the subview and check if the subview +// feeds into the dot. +static DenseSet getConsumerWaits(Operation *dot, + scf::ForOp forOp) { + llvm::SmallVector deps; + DenseSet seen; + // Get dependencies of the DotOp, stop when hitting Subview or another Dot + std::function dfs = [&](Operation *op, + Operation *baseOp) { + if (!seen.insert(op).second) + return; + if (op != baseOp && + op->hasTrait()) // do not go through Dots + return; + if (isa(op)) { + deps.push_back(op); + return; + } + if (isa(op) || op->hasTrait()) + deps.push_back(op); + + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, baseOp); + } + } + }; + dfs(dot, dot); + DenseSet depSet; + for (auto *op : deps) { + depSet.insert(op); + } + // Go through loop body, check for the sequence. + Operation *currentWait = nullptr; + unsigned seqNum = 0; + DenseSet waits; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto wait = dyn_cast(op)) { + currentWait = &op; + seqNum = 1; + continue; + } + if (auto wait = dyn_cast(op)) { + currentWait = &op; + seqNum = 1; + continue; + } + if (currentWait && seqNum == 1) { + if (isa(op)) + continue; + // subview must be next to wait minus some constants + // we should try to associate a barrier with a buffer + if (auto view = dyn_cast(op)) { + seqNum = 2; + if (depSet.count(&op)) + waits.insert(currentWait); + } else { + currentWait = nullptr; + seqNum = 0; + } + continue; + } + } + return waits; +} + +static void +getListOfProducerAcquires(scf::ForOp forOp, + SmallVector &producerAquires) { + auto funcOp = forOp->getParentOfType(); + funcOp.walk([&](scf::ForOp forOp) { + auto taskArr = mlir::getAsyncTaskIds(forOp); + if (taskArr.size() == 1 && taskArr[0] == 0) { + // Producer warp group ForOp. + forOp.walk([&](Operation *op) { + if (isa(op)) + producerAquires.push_back(op); + }); + } + }); +} + +// FIXME: need to know the corresponding wait/release for a given load. +static Operation * +getConsumerReleaseForWait(Operation *wait, scf::ForOp forOp, + SmallVector &producerAquires, + bool firstLoad) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto release = dyn_cast(op)) { + if (isa(wait)) { + // TMA case, only match with producerAquires (1st operand). + // For data partitioning, 4 tokens inside the loop. First 2 + // producerAcquires correspond to firstLoad (loadK). Last 2 correspond + // to secondLoad (loadV). + assert(producerAquires.size() == 4); + if (release->getOperand(0) == + producerAquires[firstLoad ? 0 : 2]->getOperand(0)) + return release; + if (release->getOperand(0) == + producerAquires[firstLoad ? 
1 : 3]->getOperand(0)) + return release; + continue; + } + bool isMatch = true; + unsigned i = 0; + for (Value operand : wait->getOperands()) { + if (i >= release->getNumOperands()) + break; + if (operand != release->getOperand(i)) { + isMatch = false; + break; + } + ++i; + } + if (isMatch) + return release; + } + } + return nullptr; +} + +#define GEN_PASS_DEF_TRITONGPULOOPSCHEDULING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPULoopSchedulingPass + : public impl::TritonGPULoopSchedulingBase { +public: + using impl::TritonGPULoopSchedulingBase< + TritonGPULoopSchedulingPass>::TritonGPULoopSchedulingBase; + + // Begin __FACEBOOK__ CompPipe + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + } + + tt::CoarseSchedule::Cluster + getDefaultLoopSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + DenseSet rootUsers; + scheduleLoads(forOp, schedule, rootUsers, numStages); + return schedulePrologueAndEpilogue(forOp, schedule, rootUsers, numStages); + } + + // Check for warp spec consumer group. Assume two dots. + bool + isFlashAttention(scf::ForOp forOp, + llvm::SmallVector> + &loadOpToIndLevelAndUse, + SmallVector &keyOps, + DenseSet &heavyCompOps) { + SmallVector loads; + SmallVector dots; + for (Operation &op : forOp.getBody()->without_terminator()) { + // Check for loop-carried dependencies. + // We have two loadOps, one feeding the first dot, and the other feeding + // the second dot. + if (isa(op)) { + loads.push_back(&op); + } + if (op.hasTrait()) { + dots.push_back(&op); + } + } + // Check for async_task_id. + auto taskArr = mlir::getAsyncTaskIds(forOp); + bool isConsumerWG = taskArr.size() != 1 ? false : taskArr[0] != 0; + if (dots.size() != 2 || (loads.size() != 2 && !isConsumerWG)) + return false; + + Operation *secondDot = dots[1]; + DenseSet seen; + DenseSet tracedDots; + // Make sure there is a dependency path from firstDot to secondDot. + // This means we need to do computation pipelining to break the dependency. + std::function dfs = [&](Operation *op) { + if (!seen.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (defOp->hasTrait()) { + // Stop tracing when hitting a dot. + tracedDots.insert(defOp); + } else { + if (isHeavyComputation(defOp)) + heavyCompOps.insert(defOp); + dfs(defOp); + } + } + } + }; + dfs(secondDot); + if (tracedDots.size() != 1) + return false; + + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (dist != 0) + return false; + } + + keyOps.push_back(loads.size() == 0 ? nullptr : loads[0]); // FIXME + keyOps.push_back(loads.size() == 0 ? 
nullptr : loads[1]); + keyOps.push_back(dots[0]); + keyOps.push_back(secondDot); + return true; + } + + tt::CoarseSchedule::Cluster + getFAFirstDotSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + // if (loadOpToIndLevelAndUse.empty()) + // return schedule.clusters.begin(); + + // Check to see if the for loop matches the pattern for flash attention. + // If yes, move the first dot to its own stage (numStages - 2), the + // rest of the computation will be in stage (numStages - 1). The two loads + // will be in stage 0 and 1. + SmallVector keyOps; + DenseSet heavyCompOps; + if (!isFlashAttention(forOp, loadOpToIndLevelAndUse, keyOps, + heavyCompOps)) { + LDBG("isFlashAttention returns false"); + return schedule.clusters.begin(); + } + + // firstLoad: keyOps[0] + tt::CoarseSchedule::Cluster rootUsersCluster = + schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster loadCluster = schedule.clusters.newAtBack(); + bool isConsumerWG = keyOps[0] == nullptr; + if (!isConsumerWG) { + schedule.insert(keyOps[0], 0, loadCluster); + schedule.insert(keyOps[1], 1, loadCluster); + } else { + // Check producer warp group to get the list of ProducerAcquires (assume + // they are in order matching firstLoad and secondLoad). Then match + // ConsumerReleases with them. With TMA, align consumerRleases with + // consumerWaits, assuming consumerWaits happen in order matching + // firstLoad and secondLoad. + SmallVector producerAquires; + getListOfProducerAcquires(forOp, producerAquires); + // dependency from consumer_wait to subview, then to consumer_release + // Assume this group of ops: consumer_wait, subview, local_load. Find the + // corresponding consumer_release for the consumer_wait by checking the + // operands. The local_load needed by firstDot will be in the same stage + // cluseter as firstDot. + DenseSet ConsumerWaitsForDot1 = + getConsumerWaits(keyOps[2], forOp); + for (auto *op : ConsumerWaitsForDot1) { + schedule.insert(op, isConsumerWG ? 0 : numStages - 2, rootUsersCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, true); + schedule.insert(consumerRelease, isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + LLVM_DEBUG({ + LDBG("firstDot wait "); + op->dump(); + LDBG("firstDot release "); + consumerRelease->dump(); + }); + } + DenseSet ConsumerWaitsForDot2 = + getConsumerWaits(keyOps[3], forOp); + for (auto *op : ConsumerWaitsForDot2) { + schedule.insert(op, numStages - 1, rootUsersCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, false); + schedule.insert(consumerRelease, numStages - 1, rootUsersCluster); + LLVM_DEBUG({ + LDBG("secondDot wait "); + op->dump(); + LDBG("secondDot release "); + consumerRelease->dump(); + }); + } + } + schedule.insert(keyOps[2], isConsumerWG ? 
0 : numStages - 2, + rootUsersCluster); + schedule.insert(keyOps[3], numStages - 1, rootUsersCluster); + return schedule.clusters.begin(); + } + + tt::CoarseSchedule::Cluster + getFASecondDotSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + // if (loadOpToIndLevelAndUse.empty()) + // return schedule.clusters.begin(); + + // Check to see if the for loop matches the pattern for flash attention. + // If yes, move the second dot to its own stage (numStages - 1), the + // rest of the computation will be in stage (numStages - 2). The two loads + // will be in stage 0 and 1. + SmallVector keyOps; + DenseSet heavyCompOps; + if (!isFlashAttention(forOp, loadOpToIndLevelAndUse, keyOps, + heavyCompOps)) { + LDBG("isFlashAttention returns false"); + return schedule.clusters.begin(); + } + // Go through loop body + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isHeavyComputation(&op)) + heavyCompOps.insert(&op); + } + // keyOps: load0, load1, dot0, dot1 + // Dot0(i+1) + // Dot1(i) + // Softmax(i+1): includes MUL0(i+1) + // MUL1(i+1) + tt::CoarseSchedule::Cluster rootUsersCluster = + schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster nextCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster nextNextCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster loadCluster = schedule.clusters.newAtBack(); + bool isConsumerWG = keyOps[0] == nullptr; + if (!isConsumerWG) { + schedule.insert(keyOps[0], 0, loadCluster); + schedule.insert(keyOps[1], 1, loadCluster); + } else { + SmallVector producerAquires; + getListOfProducerAcquires(forOp, producerAquires); + + DenseSet ConsumerWaitsForDot1 = + getConsumerWaits(keyOps[2], forOp); + for (auto *op : ConsumerWaitsForDot1) { + schedule.insert(op, isConsumerWG ? 0 : numStages - 2, rootUsersCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, true); + assert(consumerRelease); + schedule.insert(consumerRelease, isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + LLVM_DEBUG({ + LDBG("firstDot wait "); + op->dump(); + LDBG("firstDot release "); + consumerRelease->dump(); + }); + } + DenseSet ConsumerWaitsForDot2 = + getConsumerWaits(keyOps[3], forOp); + for (auto *op : ConsumerWaitsForDot2) { + schedule.insert(op, numStages - 1, nextCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, false); + schedule.insert(consumerRelease, numStages - 1, nextCluster); + LLVM_DEBUG({ + LDBG("secondDot wait "); + op->dump(); + LDBG("secondDot release "); + consumerRelease->dump(); + }); + } + } + schedule.insert(keyOps[2], isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + schedule.insert(keyOps[3], numStages - 1, nextCluster); + // Softmax(i+1), MUL1(i+1) in nextNextCluster + for (auto *heavyOp : heavyCompOps) + schedule.insert(heavyOp, isConsumerWG ? 
0 : numStages - 2, + nextNextCluster); + return schedule.clusters.begin(); + } + // End __FACEBOOK__ CompPipe + + void runOnOperation() override { + // Begin __FACEBOOK__ CompPipe + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1 or loop without loop_schedule + if (getNumStagesOrDefault(forOp) > 1 && + forOp->hasAttr(kLoopScheduleAttrName)) + loops.push_back(forOp); + }); + + if (loops.empty()) + return; + for (scf::ForOp forOp : loops) { + int loopNumStages = getNumStagesOrDefault(forOp); + tt::CoarseSchedule coarseSchedule(loopNumStages); + tt::CoarseSchedule::Cluster afterPrologue; + + std::string loopSchedule = getLoopScheduleOrDefault(forOp); + if (loopSchedule == "default") { + afterPrologue = + getDefaultLoopSchedule(forOp, coarseSchedule, loopNumStages); + } else if (loopSchedule == "FA_firstDot") { + afterPrologue = + getFAFirstDotSchedule(forOp, coarseSchedule, loopNumStages); + } else if (loopSchedule == "FA_secondDot") { + afterPrologue = + getFASecondDotSchedule(forOp, coarseSchedule, loopNumStages); + } else { + assert(false && "unrecognized loop schedule"); + } + // Go through schedule and assign (stage, cluster). + // shift so afterPrologue will be at clusterId 0 + auto ctx = forOp.getContext(); + for (auto [op, stage_, cluster] : coarseSchedule.getOpsInOrder(forOp)) { + op->setAttr("loop.stage", + IntegerAttr::get(IntegerType::get(ctx, 32), stage_)); + op->setAttr("loop.cluster", + IntegerAttr::get(IntegerType::get(ctx, 32), + *cluster - *afterPrologue)); + LLVM_DEBUG({ + LDBG("set stage " << stage_ << " cluster " << (*cluster)); + op->dump(); + }); + } + } + // End __FACEBOOK__ CompPipe + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/PingPong.cpp b/lib/Dialect/TritonGPU/Transforms/PingPong.cpp new file mode 100644 index 000000000..757eba67e --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/PingPong.cpp @@ -0,0 +1,186 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +#define DEBUG_TYPE "triton-ping-pong-sync" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +// Returns the taskId if op has a single taskId, otherwise, returns -1. 
+static int getSingleTaskId(Operation *op) { + if (!op->hasAttr("async_task_id")) + return -1; + auto taskArray = op->getAttrOfType("async_task_id"); + if (taskArray.getValues().size() > 1) + return -1; + return taskArray.getValues()[0]; +} + +// Treat exp2, mulf, addf, reduce as expensive computation when data type is +// a tensor type of 1D or higher. +static bool isExpensiveComp(Operation *op) { + if (!isa(op) && !isa(op) && + !isa(op) && !isa(op)) + return false; + auto tensorTy = dyn_cast(op->getOperand(0).getType()); + return tensorTy && tensorTy.getRank() >= 1; +} + +static Value createGetAsyncTaskId(OpBuilder &builder, Operation *op) { + auto loc = op->getLoc(); + return builder.create(loc); +} + +#define GEN_PASS_DEF_TRITONGPUPINGPONGSYNC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUPingPongSyncPass + : public impl::TritonGPUPingPongSyncBase { +public: + using impl::TritonGPUPingPongSyncBase< + TritonGPUPingPongSyncPass>::TritonGPUPingPongSyncBase; + + enum class ResourceType { + Gemm, + OtherComp, + }; + + void getNestedFor(scf::IfOp ifOp, SmallVector &loops) { + ifOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + } + void runOnFuncOp(triton::FuncOp funcOp) { + // Insert sync points in ForOp for consumer warp groups. Enable this pass + // when number of consumer warp groups == 2. + if (numConsumerGroups != 2) + return; + + SmallVector loops; + // Identify ForOps for consumer warp groups. Here we assume taskId 0 is for + // producer. This pass handles the case of a single forOp for two consumer + // warp groups. + getOperation()->walk([&](scf::IfOp ifOp) { + int wgId = getSingleTaskId(ifOp); + // Assume taskId 0 is for producer. + if (wgId == 1 || wgId == 2) { + getNestedFor(ifOp, loops); + } + }); + + if (!mlir::triton::tools::getBoolEnv("ENABLE_PINGPONG")) + return; + if (loops.size() != 1) + return; + + Operation *startOfGemm = nullptr; + Operation *endOfGemm = nullptr; + // FIXME: only handle the first loop. + auto forOp = loops[0]; + OpBuilder builder(forOp); + // A simple heuristic for now: + // Mark the start of a gemm section when hitting a DotLike op. + // Mark the end of a gemm section once hitting a expensive cuda op. + for (auto &op : forOp.getBody()->without_terminator()) { + if (startOfGemm && endOfGemm) + break; + bool isCudaCore = isExpensiveComp(&op); + if (op.hasTrait() && !isCudaCore && + startOfGemm == nullptr) { + startOfGemm = &op; + continue; + } + if (!op.hasTrait() && isCudaCore && startOfGemm) { + endOfGemm = &op; + break; + } + } + if (!startOfGemm || !endOfGemm) + return; + + LLVM_DEBUG({ + LDBG("found start of tensor core ops"); + startOfGemm->dump(); + }); + LLVM_DEBUG({ + LDBG("found end of tensor core ops"); + endOfGemm->dump(); + }); + + // FIXME: hard-code using named barrier 9 and 10 in this pass. + // Prior to the forOp, add "bar.arrive 9, 256" only when task Id is 2. + // At startOfGemm, insert "bar.sync 8+taskId, 256" + // At endOfGemm, insert "bar.arrive 11-taskId, 256" + builder.setInsertionPoint(forOp); + auto forLoc = forOp->getLoc(); + + // FIXME: hard-code total number of threads to be 256 when numConsumerGroups + // is 2. + Value numThreads = builder.create(forLoc, 256, 32); + Value c_9 = builder.create(forLoc, 9, 32); + + // "bar.arrive 9, 256" only when task Id is 2. 
+ Value c_2 = builder.create(forLoc, 2, 32); + Value curTaskId = createGetAsyncTaskId(builder, forOp); + auto pred = builder.create(forLoc, arith::CmpIPredicate::eq, + curTaskId, c_2); + auto ifOp = builder.create(forLoc, pred, /*else=*/false); + builder.setInsertionPoint(ifOp.thenYield()); + builder.create(forLoc, c_9, numThreads); + + // At startOfGemm, insert "bar.sync 8+taskId, 256" + // 8 + taskId: 9 for taskId 1 and 10 for taskId 2. + builder.setInsertionPoint(startOfGemm); + auto loc = startOfGemm->getLoc(); + Value c_8 = builder.create(loc, 8, 32); + Value syncBarrier = builder.create(loc, c_8, curTaskId); + builder.create(loc, syncBarrier, numThreads); + + // At endOfGemm, insert "bar.arrive 11-taskId, 256" + // 11 - taskId: 10 for taskId 1 and 9 for taskId2. + builder.setInsertionPoint(endOfGemm); + auto loc2 = endOfGemm->getLoc(); + Value c_11 = builder.create(loc2, 11, 32); + Value arriveBarrier = builder.create(loc2, c_11, curTaskId); + builder.create(loc2, arriveBarrier, numThreads); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cc537d5f..54cdf22a1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -16,6 +16,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -56,7 +58,8 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, tt::CoarseSchedule &schedule, tt::CoarseSchedule::Cluster prefetchCluster, llvm::MapVector &loadToInfo, - int numStages) { + int numStages, + DenseMap &TMAUserToWait) { OpBuilder builder(forOp); Value zero = builder.create(forOp.getLoc(), 0, 32); // Replace the load with insert/extract slice. 
@@ -113,6 +116,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, loadOffsets[0] = extractIdx; auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + TMAUserToWait[viewLoad] = wait; // viewLoad will depend on barrierWait if (isMMV3Load) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); @@ -157,7 +161,8 @@ static void createTMAAsyncCopy( scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, Value phase, tt::CoarseSchedule &schedule, - llvm::MapVector &loadToInfo, int numStages) { + llvm::MapVector &loadToInfo, int numStages, + DenseMap &TMAUserToWait) { assert(phase && "Phase value is required for TMA async copy."); OpBuilder builder(forOp); Attribute sharedMemorySpace = @@ -189,6 +194,7 @@ static void createTMAAsyncCopy( loadOffsets[0] = extractIdx; auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait if (isMMV3Load) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); @@ -590,6 +596,156 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, return loadToInfo; } +// Begin __FACEBOOK__ CompPipe +static bool loopHasSchedule(scf::ForOp forOp) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("loop.stage") && op.hasAttr("loop.cluster")) { + return true; + } + } + return false; +} + +static tt::CoarseSchedule::Cluster +getLoopSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + /*DenseSet &rootUsers,*/ int numStages, + llvm::MapVector &loadToInfo) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + auto taskArr = mlir::getAsyncTaskIds(forOp); + // We either have a single task Id with a merged IfOp for all consumers + // or we have one task Id for each IfOp per consumer. + // We should not see a list of task Ids here. + bool isConsumerWG = taskArr.size() != 1 ? false : taskArr[0] != 0; + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + bool dataPartition = mlir::triton::tools::getBoolEnv("SWP_FOR_CONSUMER"); + // When there are no load operations, continue computation pipelining if + // dataPartition is true and isConsumerWG is true. Early exit otherwise. + if (!(isConsumerWG && dataPartition) && loadOpToIndLevelAndUse.empty()) + return afterPrologue; + + // Check which loads are good for pipelining, and assign them + // memory layouts. 
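+  // For reference, the schedule reconstructed below comes from plain integer
+  // attributes on the loop-body ops, e.g. (hypothetical IR):
+  //
+  //   %a = tt.load %pa {loop.stage = 0 : i32, loop.cluster = 2 : i32, ...}
+  //   %d = tt.dot %a, %b, %acc {loop.stage = 2 : i32, loop.cluster = 0 : i32, ...}
+  //
+  // Ops carrying both attributes are placed directly into the coarse
+  // schedule; everything else is picked up later by scheduleDependencies.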
+ llvm::MapVector loadToInfoT = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + loadToInfo = loadToInfoT; + + if (!(isConsumerWG && dataPartition) && loadToInfo.empty()) + return afterPrologue; + + // reconstrcut schedule from annotations of (stage, cluster) + int maxClusterId = 0, minClusterId = 0; + bool hasSchedule = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("loop.stage") && op.hasAttr("loop.cluster")) { + auto clusterId = cast(op.getAttr("loop.cluster")) + .getValue() + .getSExtValue(); + LLVM_DEBUG({ + LDBG("saw cluster " << clusterId); + op.dump(); + }); + if (!hasSchedule) { + minClusterId = clusterId; + maxClusterId = clusterId; + hasSchedule = true; + continue; + } + minClusterId = (clusterId < minClusterId) ? clusterId : minClusterId; + maxClusterId = (clusterId > maxClusterId) ? clusterId : maxClusterId; + } + } + assert(hasSchedule); + LDBG("minCluster " << minClusterId << " max " << maxClusterId); + DenseMap clusters; + for (int i = minClusterId; i < maxClusterId + 1; i++) { + clusters.insert({i, schedule.clusters.newAtBack()}); + } + // afterPrologue should be the first cluster after ifOps? + for (Operation &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("loop.stage") && op.hasAttr("loop.cluster")) { + auto stage = + cast(op.getAttr("loop.stage")).getValue().getZExtValue(); + auto clusterId = cast(op.getAttr("loop.cluster")) + .getValue() + .getSExtValue(); + schedule.insert(&op, stage, clusters[clusterId]); + LLVM_DEBUG({ + LDBG("insert stage " << stage << " cluster " << clusterId << " " + << *clusters[clusterId]); + op.dump(); + }); + } + } + + // Distance from the load to the use. This needs to be re-worked. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + loadToInfo[loadOp].distToUse = + schedule[use].first - schedule[loadOp].first; + } + return clusters[0]; + } + + // If there is a use chain of load -> dot -> dot, we can ignore the second dot + // here. + // Start from loadOp, check uses and stop the recursion when hitting a dot. + DenseSet seen; + llvm::SmallVector> loadOpToDirectUses; + std::function dfsUse = + [&](Operation *op, Operation *use) { + if (!seen.insert(use).second) + return; + if (use->hasTrait()) { + loadOpToDirectUses.push_back(std::make_tuple(op, use)); + return; + } + for (auto &tUse : use->getUses()) { + Operation *useOp = tUse.getOwner(); + if (useOp && useOp->getBlock() == op->getBlock()) { + dfsUse(op, useOp); + } + } + }; + DenseSet loadOps; + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + if (!loadOps.insert(loadOp).second) + continue; + seen.clear(); + dfsUse(loadOp, loadOp); + } + for (auto [loadOp, use] : loadOpToDirectUses) { + LLVM_DEBUG({ + LDBG("loadOpToDirectUses " << schedule[use].first << " " + << schedule[loadOp].first); + loadOp->dump(); + use->dump(); + }); + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + return clusters[0]; +} +// End __FACEBOOK__ CompPipe + // Schedule the prologue and epilogue `if` ops in the loop, pushing them as // close to the loop boundaries as possible. Return the cluster after the // prologue (or the beginning of the loop if there is no prologue). @@ -652,8 +808,10 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, // Add dependencies of anchor ops to the coarse schedule. 
Schedule them to // the same stage and ordering cluster as the anchor op. -static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, - int numStages) { +static void +scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages, + DenseMap &TMAUserToWait) { SmallVector> opsInOrder = schedule.getOpsInOrder(forOp); // Schedule dependencies stage by stage. @@ -661,7 +819,7 @@ static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, for (auto [op, stage_, cluster] : opsInOrder) { if (stage_ != stage) continue; - schedule.insertDepsOfOp(op, stage, cluster, false); + schedule.insertDepsOfOp(op, stage, cluster, false, &TMAUserToWait); } } } @@ -818,7 +976,7 @@ struct AsyncLoad { }; // Create barriers and wait ops for the async loads. Barriers may be shared by -// multiple loads is the schedule allows it. +// multiple loads if the schedule allows it. static void createTMABarrierAndWait( scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, Value extractIdx, Value phase, int numBuffers, tt::CoarseSchedule &schedule, @@ -905,7 +1063,7 @@ static void createTMABarrierAndWait( Value pred = builder.create(loc, 1, 1); Operation *expect = builder.create( forOp.getLoc(), barrier, sizeInBytes, pred); - auto [stage, cluster] = schedule[asyncLoads[0].loadOp]; + auto [stage, cluster] = schedule[group[0]->loadOp]; schedule.insert(expect, stage, cluster); builder.setInsertionPointAfter(group.back()->loadOp); @@ -926,7 +1084,8 @@ static void createTMABarrierAndWait( static SmallVector createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, llvm::MapVector &loadToInfo, - SmallVector &barriers, int numStages) { + SmallVector &barriers, int numStages, + DenseMap &TMAUserToWait) { // Calculate the number of buffers needed for each load. // TODO pawel: we could do more fine-grained allocation here and // allocate only the number of buffers that specific loads need. @@ -1017,12 +1176,13 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, for (AsyncLoad &asyncLoad : asyncLoads) { if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, - schedule, prefetchCluster, loadToInfo, numStages); + schedule, prefetchCluster, loadToInfo, numStages, + TMAUserToWait); } else { auto descLoad = cast(asyncLoad.loadOp); createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx, extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase, - schedule, loadToInfo, numStages); + schedule, loadToInfo, numStages, TMAUserToWait); } } SmallVector newYieldOperands = {insertIdx, extractIdx}; @@ -1060,9 +1220,24 @@ bool mlir::triton::preProcessLoopAndGetSchedule( // a scaffold for the final schedule. DenseSet rootUsers; tt::CoarseSchedule coarseSchedule(numStages); - llvm::MapVector loadToInfo = - scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); - if (loadToInfo.empty()) + // Begin __FACEBOOK__ CompPipe + bool hasSchedule = loopHasSchedule(forOp); + llvm::MapVector loadToInfo; + tt::CoarseSchedule::Cluster afterPrologue; + if (!hasSchedule) { + loadToInfo = scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + } else { + afterPrologue = getLoopSchedule(forOp, coarseSchedule, + /*rootUsers,*/ numStages, loadToInfo); + } + // vanilla + // llvm::MapVector loadToInfo = + // scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + // End __FACEBOOK__ CompPipe + auto taskArr = mlir::getAsyncTaskIds(forOp); + bool isConsumerWG = taskArr.size() != 1 ? 
false : taskArr[0] != 0; + bool dataPartition = mlir::triton::tools::getBoolEnv("SWP_FOR_CONSUMER"); + if (!(isConsumerWG && dataPartition) && loadToInfo.empty()) return false; LLVM_DEBUG({ @@ -1070,24 +1245,33 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); - tt::CoarseSchedule::Cluster afterPrologue = - schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + // Begin __FACEBOOK__ CompPipe + if (!hasSchedule) { + afterPrologue = schedulePrologueAndEpilogue(forOp, coarseSchedule, + rootUsers, numStages); + } + // vanilla + // tt::CoarseSchedule::Cluster afterPrologue = + // schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, + // numStages); + // End __FACEBOOK__ CompPipe LLVM_DEBUG({ LDBG("Coarse schedule with prologue and epilogue:"); coarseSchedule.dump(); }); SmallVector barriers; + DenseMap TMAUserToWait; // Convert the loads into async loads and create the allocs. - SmallVector allocs = - createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages); + SmallVector allocs = createAsyncOps( + forOp, coarseSchedule, loadToInfo, barriers, numStages, TMAUserToWait); LLVM_DEBUG({ LDBG("Coarse schedule with async loads:"); coarseSchedule.dump(); }); - scheduleDependencies(forOp, coarseSchedule, numStages); + scheduleDependencies(forOp, coarseSchedule, numStages, TMAUserToWait); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); coarseSchedule.dump(); @@ -1116,7 +1300,10 @@ bool mlir::triton::preProcessLoopAndGetSchedule( std::vector> &s) { s = std::move(schedule); }; - options.peelEpilogue = false; + bool hasLoopSchedule = forOp->hasAttr("tt.loop_schedule"); + bool PeelLastIter = + ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && hasLoopSchedule; + options.peelEpilogue = PeelLastIter ? true : false; options.predicateFn = tt::predicateOp; options.supportDynamicLoops = true; options.annotateFn = [](Operation *op, @@ -1467,6 +1654,10 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, return iterArgIdx; } + if (::triton::tools::getBoolEnv("HACK_ASYNC_DOT")) { + return iterArgIdx; + } + // Rule 3b: Are all users of the dot's result from iteration i-1 after the // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be // properly async, but we have to thread its result from iteration i-1 through diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index 2f186e3c5..6be8745e5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/MathExtras.h" #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Tools/Sys/GetEnv.hpp" #define DEBUG_TYPE "triton-loop-pipelining" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -452,8 +453,10 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( Type t = ub.getType(); Location loc = forOp.getLoc(); // newUb = ub - maxStage * step + // peel last iteration of ops, newUb = ub - step + bool PeelLastIter = ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && peelEpilogue; Value maxStageValue = rewriter.create( - loc, rewriter.getIntegerAttr(t, maxStage)); + loc, rewriter.getIntegerAttr(t, PeelLastIter ? 
1 : maxStage)); Value maxStageByStep = rewriter.create(loc, step, maxStageValue); newUb = rewriter.create(loc, ub, maxStageByStep); @@ -461,6 +464,7 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( auto newForOp = rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, forOp.getStep(), newLoopArg); + newForOp->setAttrs(forOp->getAttrs()); // When there are no iter args, the loop body terminator will be created. // Since we always create it below, remove the terminator if it was created. if (!newForOp.getBody()->empty()) @@ -485,11 +489,16 @@ LogicalResult LoopPipelinerInternal::createKernel( mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); } SmallVector predicates(maxStage + 1, nullptr); - if (!peelEpilogue) { + bool PeelLastIter = ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && peelEpilogue; + if (!peelEpilogue || PeelLastIter) { // Create a predicate for each stage except the last stage. Location loc = newForOp.getLoc(); Type t = ub.getType(); - for (unsigned i = 0; i < maxStage; i++) { + // predicates[i] = indVar < c = indVar < ub - (maxStage - i) * step + // if peeling last iteration only, S2 should always be executed. + // only create predicates for S0 to S1 + int iEnd = PeelLastIter ? maxStage - 1 : maxStage; + for (unsigned i = 0; i < iEnd; i++) { // c = ub - (maxStage - i) * step Value c = rewriter.create( loc, ub, @@ -619,12 +628,29 @@ LogicalResult LoopPipelinerInternal::createKernel( // If there is a live range spanning across more than 2 stages we need to // add extra arg. for (unsigned i = 1; i < numVersionReturned; i++) { + LLVM_DEBUG({ + llvm::dbgs() << "set valueMapping3: version " << version + << " lastUseStage " << it.second.lastUseStage + << " defStage " << it.second.defStage << " "; + it.first.dump(); + }); setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), version++); yieldOperands.push_back( newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + newForOp.getNumInductionVars()]); } + // Map [key, version] to result of newForOp. + if (PeelLastIter && it.second.lastUseStage == maxStage) { + // we only need version maxStage for ops in stage maxStage + version += maxStage - 1; // loop body contains the first epilogue + } + LLVM_DEBUG({ + llvm::dbgs() << "set valueMapping: version " << version + << " lastUseStage " << it.second.lastUseStage << " defStage " + << it.second.defStage << " "; + it.first.dump(); + }); setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), version++); yieldOperands.push_back(mapping.lookupOrDefault(it.first)); @@ -640,10 +666,14 @@ LogicalResult LoopPipelinerInternal::createKernel( for (unsigned int stage = 1; stage <= maxStage; stage++) setValueMapping(forOp.getRegionIterArgs()[retVal.index()], retVal.value(), stage); - } else if (defStage->second > 0) { + } else if (defStage->second > 0 && + (!PeelLastIter || defStage->second > maxStage - 1)) { + // If PeelLastIter is false, no change. If it is true, only enter when + // defStage->second is bigger than 1. setValueMapping(forOp.getRegionIterArgs()[retVal.index()], newForOp->getResult(retVal.index()), - maxStage - defStage->second + 1); + maxStage - defStage->second + 1 + + (PeelLastIter ? maxStage - 1 : 0)); } } rewriter.create(forOp.getLoc(), yieldOperands); @@ -693,17 +723,33 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // Emit `maxStage - 1` epilogue part that includes operations from stages // [i; maxStage]. 
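+  // Worked example (illustration only) with maxStage = 2 and a loop over
+  // [0, N) with step 1:
+  //   default:        the kernel's upper bound is N - 2, and the epilogue
+  //                   below emits the parts for i = 1 and i = 2;
+  //   PEEL_LAST_ITER: the kernel's upper bound is N - 1, earlier stages are
+  //                   predicated inside the kernel, and only the i = maxStage
+  //                   part is emitted here, i.e. just the final stage of the
+  //                   last iteration is peeled.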
- for (int64_t i = 1; i <= maxStage; i++) { + bool PeelLastIter = ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && peelEpilogue; + for (int64_t i = PeelLastIter ? maxStage : 1; i <= maxStage; i++) { SmallVector> returnMap(returnValues.size()); for (Operation *op : opOrder) { if (stages[op] < i) continue; + LLVM_DEBUG({ + llvm::errs() << "clone "; + op->dump(); + }); unsigned currentVersion = maxStage - stages[op] + i; unsigned nextVersion = currentVersion + 1; Operation *newOp = cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { auto it = valueMapping.find(newOperand->get()); if (it != valueMapping.end()) { + LLVM_DEBUG({ + llvm::errs() << "find valueMapping: version " + << (maxStage - stages[op] + i) << " "; + newOperand->get().dump(); + unsigned tmp = 0; + for (auto v : it->second) { + llvm::errs() << "idx " << tmp << ": "; + v.dump(); + ++tmp; + } + }); Value replacement = it->second[currentVersion]; newOperand->set(replacement); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 1a3162f17..eabb5fe7c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -81,6 +81,37 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, return op; } + if (isa(op)) + return op; + if (auto wait = dyn_cast(op)) { + rewriter.setInsertionPoint(wait); + auto ifOp = + rewriter.create(wait->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + if (auto wait = dyn_cast(op)) { + rewriter.setInsertionPoint(wait); + auto ifOp = + rewriter.create(wait->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + if (auto release = dyn_cast(op)) { + rewriter.setInsertionPoint(release); + auto ifOp = + rewriter.create(release->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(release, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + if (auto arrive = dyn_cast(op)) { + rewriter.setInsertionPoint(arrive); + auto ifOp = + rewriter.create(arrive->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(arrive, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + assert("don't know how to predicate this op" && false); return op; } @@ -159,6 +190,7 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, trans.getOrderAttr()); } assert(newVal); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); replaceUsesAndPropagateType(builder, user, newVal); opsToDelete.push_back(use.getOwner()); } @@ -173,3 +205,49 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, for (Operation *op : opsToDelete) op->erase(); } + +// Begin __FACEBOOK__ CompPipe +llvm::SmallVector> +mlir::triton::loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + if (!seen.insert(op).second) + return; + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? 
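+          // Illustration (assumed IR shape) of the recorded tuples for an
+          // indirect load feeding a dot:
+          //
+          //   %ptrs = tt.load %ptr_tensor   // recorded as (load, 1, inner load)
+          //   %vals = tt.load %ptrs         // recorded as (load, 0, dot)
+          //   %acc  = tt.dot %vals, %b, %c
+          //
+          // i.e. `distance` counts how many loads sit between this load and
+          // the op that anchors the chain.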
+ loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); + use = op; + distance++; + } + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} +// End __FACEBOOK__ CompPipe diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index 1116b70a0..1d10595c3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -15,9 +15,15 @@ namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; -void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, - tt::CoarseSchedule::Cluster cluster, - bool includeArg) { +void tt::CoarseSchedule::insertDepsOfOp( + Operation *op, int stage, tt::CoarseSchedule::Cluster cluster, + bool includeArg, DenseMap *additionalDep) { + // Look in additionalDep. + if (additionalDep && additionalDep->find(op) != additionalDep->end()) { + Operation *wait = (*additionalDep)[op]; + if (insertIfAbsent(wait, stage, cluster)) + insertDepsOfOp(wait, stage, cluster, includeArg, additionalDep); + } for (Value operand : op->getOperands()) { Value v = operand; llvm::SmallDenseSet seen; @@ -36,7 +42,7 @@ void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, Operation *defOp = v.getDefiningOp(); if (defOp && defOp->getBlock() == op->getBlock()) { if (insertIfAbsent(defOp, stage, cluster)) { - insertDepsOfOp(defOp, stage, cluster, includeArg); + insertDepsOfOp(defOp, stage, cluster, includeArg, additionalDep); } } } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 7985d25b9..eee8219ba 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -1,6 +1,8 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" using namespace mlir; namespace tt = mlir::triton; @@ -29,7 +31,7 @@ getTMAStores(scf::ForOp forOp) { static Value createAlloc(scf::ForOp &forOp, tt::ExperimentalDescriptorStoreOp storeOp) { - OpBuilder builder(forOp); + OpBuilderWithAsyncTaskIds builder(forOp); auto ty = cast(storeOp.getSrc().getType()); auto order = ttg::getOrder(ty.getEncoding()); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); @@ -44,7 +46,7 @@ static Value createAlloc(scf::ForOp &forOp, Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, sharedMemorySpace, /*mutableMemory*/ true); - Value alloc = builder.create(storeOp->getLoc(), + Value alloc = builder.createWithAsyncTaskIds(storeOp->getLoc(), memdescType, Value()); return alloc; } @@ 
-52,7 +54,7 @@ static Value createAlloc(scf::ForOp &forOp, static void createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorStoreOp storeOp, Value alloc) { - OpBuilder builder(storeOp); + OpBuilderWithAsyncTaskIds builder(storeOp); auto loc = storeOp.getLoc(); auto ty = cast(storeOp.getSrc().getType()); auto order = ttg::getOrder(ty.getEncoding()); @@ -60,10 +62,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp, // Put wait before the local_store make the store truly async. We know // that we are the only user of the CopyLocalToGlobal. - builder.create(loc, 0); - builder.create(loc, storeOp.getSrc(), alloc); - builder.create(loc, false); - builder.create( + builder.createWithAsyncTaskIds(loc, 0); + builder.createWithAsyncTaskIds(loc, storeOp.getSrc(), alloc); + builder.createWithAsyncTaskIds(loc, false); + builder.createWithAsyncTaskIds( loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); storeOp->erase(); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index b6d855a05..42549521f 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -528,6 +528,8 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { tensorType.getElementType(), encoding); Value converted = rewriter.create(value.getLoc(), tmpType, rewrittenValue); + if (value.getDefiningOp()) + converted.getDefiningOp()->setAttrs(value.getDefiningOp()->getAttrs()); // TODO: we could cache the conversion. return converted; } @@ -770,6 +772,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), encoding); auto cvt = rewriter.create(op->getLoc(), newType, src); + cvt->setAttrs(op->getAttrs()); map(op->getResult(0), cvt.getResult()); return cvt.getOperation(); } @@ -1171,6 +1174,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( tensorType.getShape(), tensorType.getElementType(), *srcEncoding); auto newConvertOp = builder.create( convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + newConvertOp->setAttrs(convertOp->getAttrs()); Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); auto oldExtOrBroadcastType = diff --git a/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp b/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp new file mode 100644 index 000000000..ddd85dee5 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp @@ -0,0 +1,407 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#define DEBUG_TYPE "triton-gpu-taskid-propagate" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTASKIDPROPAGATE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return all Ops that are marked with target task +void getAsyncTaskOps(triton::FuncOp funcOp, DenseSet &asyncTaskOps, + int asyncTaskId) { + 
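+  // The annotation being searched for is a dense integer attribute, e.g.
+  // (hypothetical IR):
+  //
+  //   %v = tt.load %ptr {async_task_id = dense<[1, 2]> : vector<2xi32>}
+  //
+  // which marks the op as belonging to consumer tasks 1 and 2.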
funcOp.walk([&](Operation *op) -> void { + if (auto attr = + op->getAttrOfType("async_task_id")) { + for (auto val : attr.getValues()) { + if (val == asyncTaskId) { + asyncTaskOps.insert(op); + break; + } + } + } + }); +} + +void getAllParentOps(DenseSet &parentOps, Operation *targetOp) { + auto op = targetOp; + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + parentOps.insert(parent); + op = parent; + } else { + break; + } + } +} + +void getAllParentOps(triton::FuncOp funcOp, DenseSet &parentOps, + int asyncTaskId) { + DenseSet targetOps; + getAsyncTaskOps(funcOp, targetOps, asyncTaskId); + for (auto op : targetOps) { + getAllParentOps(parentOps, op); + } +} + +void labelByUsers(Operation *op, ArrayRef allAsyncTasks) { + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) { + if (!userOp->hasAttr("async_task_id")) { + labelByUsers(userOp, allAsyncTasks); + } + addAsyncTaskIds(op, getAsyncTaskIds(userOp)); + } + } + if (!op->hasAttr("async_task_id")) { + addAsyncTaskIds(op, allAsyncTasks); + } +} + +/// Because we set some special filter rules in populateAsyncTaskRegion, +/// there may be unlabeled Ops, e.g. YieldOps, some definingOps of ForOps. +/// or Ops without relations to asyncTaskOps +void populateUnlabledOpsAtLast(triton::FuncOp funcOp, + ArrayRef allAsyncTasks) { + // Label asyncTasks' parentOps + for (int i : allAsyncTasks) { + DenseSet asyncTaskParentOps; + getAllParentOps(funcOp, asyncTaskParentOps, i); + for (auto op : asyncTaskParentOps) { + addAsyncTaskIds(op, {i}); + } + } + + // Get unlabeled Ops + DenseSet unlabeledOps; + funcOp.walk([&](Operation *op) -> void { + if (isa(op) || isa(op) || + isa(op)) { + return; + } + if (!op->hasAttr("async_task_id")) { + unlabeledOps.insert(op); + } + }); + + // Label Ops using its parentOp + for (auto op : unlabeledOps) { + if (auto parent = op->getParentOp()) { + if (!isa(parent)) { + if (!parent->hasAttr("async_task_id")) { + LLVM_DEBUG({ + LDBG("op and parent: "); + op->dump(); + parent->dump(); + }); + continue; + } + assert(parent->hasAttr("async_task_id")); + auto asyncTasks = getAsyncTaskIds(parent); + setAsyncTaskIds(op, asyncTasks); + unlabeledOps.erase(op); + } + } + } + + // Label Ops using dependency + for (auto op : unlabeledOps) { + labelByUsers(op, allAsyncTasks); + unlabeledOps.erase(op); + } + assert(unlabeledOps.size() == 0); +} + +#ifndef NDEBUG +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +struct AsyncTaskIdsCompare { + static SmallVector getEmptyKey() { + SmallVector V; + V.push_back(reinterpret_cast(-1)); + return V; + } + + static SmallVector getTombstoneKey() { + SmallVector V; + V.push_back(reinterpret_cast(-2)); + return V; + } + + static unsigned getHashValue(const SmallVector &V) { + return static_cast(llvm::hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector &LHS, + const SmallVector &RHS) { + return LHS == RHS; + } +}; + +// Make sure the def chain contains the right taskId. 
+bool verifyTaskId(triton::FuncOp &funcOp, + const llvm::DenseSet& anchorOps) { + bool retCode = true; + DenseSet, AsyncTaskIdsCompare> anchorAsyncTasks; + for (auto anchorOp : anchorOps) { + anchorAsyncTasks.insert(getAsyncTaskIds(anchorOp)); + } + + funcOp.walk([&](Operation *op) { + // Skip control ops + if (llvm::isa(op)) + return; + + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.empty()) { + LLVM_DEBUG({ + LDBG("Op does not have task id"); + op->dump(); + }); + llvm_unreachable("Op does not have task id"); + } + + auto partitionShouldBeUsedSpecified = [](Operation *op) { + if (isa(op)) + return true; + if (isa(op)) + return true; + if (op->hasTrait()) + return true; + return false; + }; + + if (!anchorAsyncTasks.contains(asyncTaskIds)) { + if (partitionShouldBeUsedSpecified(op)) { + LLVM_DEBUG({ + LDBG("async tasks not specified by user"); + op->dump(); + }); + llvm_unreachable("async tasks not specified by user"); + } + } + + assert(!asyncTaskIds.empty() && "Op does not have task id"); + + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + if (llvm::isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + LLVM_DEBUG({ + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + // print defOp and op + LDBG("Def op does not cover op"); + LDBG("Def op"); + defOp->dump(); + LDBG("op"); + op->dump(); + } + }); + assert(oneVecCoversTheOther(defTaskIds, asyncTaskIds) && + "defTaskIds should cover asyncTaskIds"); + } + }); + return retCode; +} +#endif + +void backwardPropagateTaskIds(Operation *op, + const llvm::DenseSet &anchors) { + SmallVector queue; + auto asyncTasks = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + + DenseSet seen; + for (auto anchor : anchors) { + if (anchor != op) + for (auto result : anchor->getResults()) + seen.insert(result); + } + + while (!queue.empty()) { + auto value = queue.pop_back_val(); + if (!seen.insert(value).second) { + continue; + } + + // Handle BlockArguments of for loops (i.e. loop carried dependences). + if (auto blockArg = dyn_cast(value)) { + auto parent = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parent)) { + // Propagate to the control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to the initializer. + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + queue.push_back(forOp.getTiedLoopInit(blockArg)->get()); + // Propagate to the yield. + auto idx = blockArg.getArgNumber() - forOp.getNumInductionVars(); + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + addAsyncTaskIds(forOp, asyncTasks); + } + } + continue; + } + + auto op = value.getDefiningOp(); + addAsyncTaskIds(op, asyncTasks); + + // Handle for loops. + if (auto forOp = dyn_cast(op)) { + // Propagate to control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to arguments. + unsigned idx = cast(value).getResultNumber(); + queue.push_back(forOp.getOperand(idx + forOp.getNumControlOperands())); + // Propagate to yield. + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + continue; + } + + // Handle conditionals. 
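+    // Illustration (sketch): when a consumer-annotated op uses %r defined by
+    //
+    //   %r = scf.if %cond -> (tensor<...>) {
+    //     scf.yield %a : tensor<...>
+    //   } else {
+    //     scf.yield %b : tensor<...>
+    //   }
+    //
+    // the task ids are pushed into %cond, %a and %b as well (the scf.if
+    // itself was already tagged just above), instead of stopping at the if.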
+ if (auto ifOp = dyn_cast(op)) { + queue.push_back(ifOp.getCondition()); + unsigned idx = cast(value).getResultNumber(); + if (ifOp.elseBlock()) { + queue.push_back(ifOp.elseYield()->getOperand(idx)); + } + queue.push_back(ifOp.thenYield()->getOperand(idx)); + continue; + } + + // Handle normal ops. + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + } +} + +void backwardPropagateTaskIds(llvm::DenseSet &anchorOps) { + for (Operation *op : anchorOps) { + backwardPropagateTaskIds(op, anchorOps); + } +} + +void populateTaskIdsForControlDependencies( + llvm::DenseSet &anchorOps) { + for (auto op : anchorOps) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (!asyncTaskIds.empty()) { + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + setAsyncTaskIds(parent, asyncTaskIds); + backwardPropagateTaskIds(parent, anchorOps); + op = parent; + } else { + break; + } + } + } + } +} + +class TritonGPUTaskIdPropagatePass + : public impl::TritonGPUTaskIdPropagateBase { +public: + using impl::TritonGPUTaskIdPropagateBase< + TritonGPUTaskIdPropagatePass>::TritonGPUTaskIdPropagateBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + llvm::DenseSet anchorOps; + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty() && + !isa(op)) + anchorOps.insert(op); + }); + + populateTaskIdsForControlDependencies(anchorOps); + + LLVM_DEBUG({ + LDBG("after populateTaskIdsForControlDependencies "); + funcOp->dump(); + }); + + backwardPropagateTaskIds(anchorOps); + + LLVM_DEBUG({ + LDBG("after backwardPropagateTaskIds "); + funcOp->dump(); + }); + + DenseSet allAsyncTasks; + funcOp->walk([&](Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); + }); + SmallVector allAsyncTasksVec(allAsyncTasks.begin(), + allAsyncTasks.end()); + populateUnlabledOpsAtLast(funcOp, allAsyncTasksVec); + + LLVM_DEBUG({ + LDBG("after populateUnlabledOpsAtLast "); + funcOp->dump(); + }); + +#ifndef NDEBUG + verifyTaskId(funcOp, anchorOps); +#endif + } + + void runOnOperation() override { + if (numConsumerGroups == 0) { + getOperation()->walk([&](triton::FuncOp funcOp) { + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty()) + op->removeAttr("async_task_id"); + }); + }); + return; + } + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp new file mode 100644 index 000000000..2ae27f467 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -0,0 +1,1424 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include 
"triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUWSCODEPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-warp-spec-code-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +std::pair scanRegUsage(ArrayRef opList, + AsyncTaskId asyncTaskId, int regDecProducer, + int regIncConsumer) { + // TODO: scan ops to estimate register usage + if (asyncTaskId == 0) { + // deallocate registers + return {regDecProducer == 0 ? 40 : regDecProducer, false}; + } else { + // allocate registers + return {regIncConsumer == 0 ? 232 : regIncConsumer, true}; + } +} + +// Create IfOp for each ayncTaskId. +DenseMap SpecializeRegion(triton::FuncOp funcOp, + int regDecProducer, + int regIncConsumer) { + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(context); + auto loc = funcOp.getLoc(); + + // Collect original operations + SmallVector opList; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &op : block.getOperations()) + opList.push_back(&op); + } + + // Create GetAsyncTaskIdOp. + Block *lastBlock = &funcOp.getBody().back(); + auto returnOp = llvm::cast(lastBlock->getTerminator()); + builder.setInsertionPoint(returnOp); + Value curAsyncTaskId = builder.create(loc); + + // Resources for each asyncTaskId: builder, IfOp, and IRMapping. + DenseMap> + tasksToBuilders; + DenseMap tasksToIfOp; + DenseMap tasksToIRMappings; + + for (AsyncTaskId asyncTaskId : getNestedAsyncTaskIds(funcOp)) { + // Create IfOp for each asyncTaskId. + Value cond = builder.create( + loc, arith::CmpIPredicate::eq, curAsyncTaskId, + builder.create(loc, asyncTaskId, 32)); + + auto ifOp = builder.create(loc, cond); + tasksToIfOp[asyncTaskId] = ifOp; + setAsyncTaskIds(ifOp, {asyncTaskId}); + + // Create OpBuilderWithAsyncTaskIds for each taskId. + auto taskBuilder = std::make_shared(context); + tasksToBuilders[asyncTaskId] = taskBuilder; + taskBuilder->setAsynTaskIdsFromArray({asyncTaskId}); + + // Decide if this taskId is a producer or a consumer, and create either + // RegAllocOp or RegDeallocOp accordingly. + auto regAlloc = + scanRegUsage(opList, asyncTaskId, regDecProducer, regIncConsumer); + taskBuilder->setInsertionPointToStart(&(ifOp.getThenRegion().front())); + if (regAlloc.second) + taskBuilder->create( + loc, taskBuilder->getI32IntegerAttr(regAlloc.first)); + else + taskBuilder->create( + loc, taskBuilder->getI32IntegerAttr(regAlloc.first)); + + // Set insertion point before yieldOp. + auto yieldOp = ifOp.thenYield(); + setAsyncTaskIds(yieldOp, {asyncTaskId}); + taskBuilder->setInsertionPoint(yieldOp); + } + + // Clone all operations into the corresponding if blocks. If the operation has + // multiple taskIds, it will be cloned for multiple if blocks. + // If the original code has an IfOp, we should only clone its + // body with the right asyncTaskId, instead of cloning the IfOp. 
+ SmallVector cloned; + for (Operation *op : opList) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 0) + continue; + cloned.push_back(op); + if (auto ifOp = dyn_cast(op)) { + DenseMap tasksToThisIfOp; + // TODO: handle outputs of this IfOp. + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(op)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + auto ifOpForTask = tasksToBuilders[asyncTaskId]->create( + loc, mapping.lookup(ifOp.getCondition())); + tasksToThisIfOp[asyncTaskId] = ifOpForTask; + auto newYieldOp = ifOpForTask.thenYield(); + tasksToBuilders[asyncTaskId]->setInsertionPoint(newYieldOp); + } + // Handle thenRegion of this IfOp. + for (Operation &thenOp : ifOp.thenBlock()->without_terminator()) { + LLVM_DEBUG({ + LDBG("specialize thenBlock inside ifOp "); + thenOp.dump(); + }); + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(&thenOp)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + Operation *newOp = + tasksToBuilders[asyncTaskId]->clone(thenOp, mapping); + for (unsigned i = 0; i < thenOp.getNumResults(); ++i) + mapping.map(thenOp.getResult(i), newOp->getResult(i)); + } + } + if (!ifOp.elseBlock()) + continue; // Done with this IfOp, continue to the next op. + // Handle elseRegion of the IfOp. + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(op)) { + auto newYieldOp = tasksToThisIfOp[asyncTaskId].elseYield(); + tasksToBuilders[asyncTaskId]->setInsertionPoint(newYieldOp); + } + for (Operation &thenOp : ifOp.elseBlock()->without_terminator()) { + LLVM_DEBUG({ + LDBG("specialize elseBlock inside ifOp "); + thenOp.dump(); + }); + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(&thenOp)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + Operation *newOp = + tasksToBuilders[asyncTaskId]->clone(thenOp, mapping); + for (unsigned i = 0; i < thenOp.getNumResults(); ++i) + mapping.map(thenOp.getResult(i), newOp->getResult(i)); + } + } + } else { + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(op)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + Operation *newOp = tasksToBuilders[asyncTaskId]->clone(*op, mapping); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + } + } + + LLVM_DEBUG({ + LDBG("\n\nWith task Id checks"); + funcOp.dump(); + }); + + // Remove original operations that have been cloned in reverse order. + for (auto it = cloned.rbegin(); it != cloned.rend(); ++it) { + Operation *op = *it; + LLVM_DEBUG({ + LDBG("erasing op "); + op->dump(); + }); + // For debugging purposes, check to see if the original op is still in use. + bool hasUse = false; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + for (Operation *user : op->getResult(i).getUsers()) { + hasUse = true; + LLVM_DEBUG({ + LDBG("op has use "); + user->dump(); + }); + } + } + op->erase(); + } + return tasksToIfOp; +} + +struct Channel { +public: + using Relation = std::pair>; + + Channel(int producer, SmallVector &consumers, Operation *src, + Operation *dst, Value srcOperand) + : relation(producer, consumers), srcOp(src), dstOp(dst), + srcOperand(srcOperand) {} + + bool operator==(const Channel &c) { + return relation == c.relation && srcOp == c.srcOp && dstOp == c.dstOp; + } + + Relation relation; // producer task Id, a list of consumer task Ids + Operation *srcOp; + Operation *dstOp; + Value srcOperand; +}; + +// Loads will be in producer warp groups. For now, we only allow a single +// warp group/task for a producer. 
For each LoadOp, create a channel from it +// to any direct user which belongs to a different taskId. +void collectAsyncChannels(SmallVector> &channels, + triton::FuncOp &funcOp) { + funcOp.walk([&](Operation *op) { + if (isa(op)) { + auto producerTaskIds = getAsyncTaskIds(op); + if (producerTaskIds.empty() || producerTaskIds.size() > 1) { + LLVM_DEBUG({ + LDBG(" ignoring load ops without async task id or with multiple task " + "ids: "); + op->dump(); + }); + return; + } + auto producerTaskId = producerTaskIds.front(); + + for (auto result : op->getResults()) { + if (result.use_empty()) { + continue; + } + for (Operation *userOp : result.getUsers()) { + auto consumerTaskIds = getAsyncTaskIds(userOp); + if (consumerTaskIds.empty()) + continue; + // Remove producer task id from consumerTaskIds. + auto iter = std::remove(consumerTaskIds.begin(), + consumerTaskIds.end(), producerTaskId); + consumerTaskIds.erase(iter, consumerTaskIds.end()); + // Add a channel from the single producer task to consumerTaskIds. + if (consumerTaskIds.size() > 0) { + channels.push_back(std::make_unique( + producerTaskId, consumerTaskIds, op, userOp, result)); + } + } + } + } + }); + + LLVM_DEBUG({ + LDBG("Async channels:"); + for (auto &channel : channels) { + LDBG("producer op: " << channel->relation.first); + channel->srcOp->dump(); + for (auto &asyncTaskId : channel->relation.second) + LDBG("consumer: " << asyncTaskId); + channel->dstOp->dump(); + } + }); +} + +// Update map, which will be keyed by dstOp of the channel. Use mapKeyVec to +// enforce deterministic order for map. +void groupChannels(SmallVector &channels, + DenseMap> &map, + SmallVector &mapKeyVec) { + // Two channels can be combined if + // src1 and src2 are in the same block and + // (dst1 == dst2 or + // (dst1 and dst2 are in the same block, both have a single user, and + // dst1User == dst2User and dst1User is in the same block as dst1)) + auto channelCanBeMerged = [](Channel *c1, Channel *c2) -> bool { + if (c1->srcOp->getBlock() != c2->srcOp->getBlock()) + return false; + Operation *dst1 = c1->dstOp, *dst2 = c2->dstOp; + if (dst1 == dst2) + return true; + if (dst1->getBlock() != dst2->getBlock() || !dst1->hasOneUse() || + !dst2->hasOneUse()) + return false; + Operation *dst1User = *(dst1->getUsers().begin()); + Operation *dst2User = *(dst2->getUsers().begin()); + return dst1User == dst2User && dst1User->getBlock() == dst1->getBlock(); + }; + assert(channels.size() > 0 && "channel size is zero"); + // Compare with existing channels in the map to see if it can be combined. + for (auto *c0 : channels) { + bool merged = false; + for (auto &kv : map) { + if (kv.second.size() > 0 && channelCanBeMerged(c0, kv.second.front())) { + kv.second.push_back(c0); + merged = true; + break; + } + } + if (!merged) { // Create a new entry. + auto *keyOp = c0->dstOp; + if (!map.count(keyOp)) + mapKeyVec.push_back(keyOp); + map[keyOp].push_back(c0); + } + } + + // Reorder channels associated with one entry based on program order of the + // producers. 
+ for (auto &kv : map) { + if (kv.second.size() > 1) { + auto &allOps = kv.second.front()->srcOp->getBlock()->getOperations(); + std::sort( + kv.second.begin(), kv.second.end(), [&](Channel *a, Channel *b) { + auto itrA = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == a->srcOp; + }); + auto itrB = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == b->srcOp; + }); + assert(itrA != allOps.end() && itrB != allOps.end()); + return std::distance(itrA, itrB) < 0; + }); + } + } +} + +// Reorder producer ops to unblock consumers interleavingly. +void reorderProducerOps(SmallVector &channels) { + if (channels.size() <= 1) + return; + + // Bail out if channels are not in the same block + auto block = channels.front()->srcOp->getBlock(); + for (auto &channel : channels) { + if (channel->srcOp->getBlock() != block) { + return; + } + } + + // Group channels by the first consumer taskId of each channel. Smaller taskId + // has higher priority. + // TODO: consider consumer priority + std::map> groupedProducerOps; + for (auto &channel : channels) { + auto asyncTaskId = channel->relation.second.front(); + groupedProducerOps[asyncTaskId].push_back(channel); + } + + // No need to reorder if all channels are in the same group. + if (groupedProducerOps.size() <= 1) + return; + + // Sort each group by number of consumers. + for (auto &group : groupedProducerOps) { + std::sort(group.second.begin(), group.second.end(), + [&](Channel *a, Channel *b) { + return a->relation.second.size() < b->relation.second.size(); + }); + } + + // Start from the first producer in channels. Iterate through the groups + // which are ordered by the first consumer taskId. Within each group, channels + // are ordered by number of consumers. + Operation *currOp = channels.front()->srcOp; + for (auto &group : groupedProducerOps) { + for (auto &channel : group.second) { + channel->srcOp->moveAfter(currOp); + currOp = channel->srcOp; + } + } + + // Move backward dependency slice close to producer ops. + // Start from the last producer op backwards and move backward slice to + // before each op. This guarantees that the backward slice of each op is + // scheduled as late as possible. + for (auto &group : reverse(groupedProducerOps)) { + for (auto &channel : reverse(group.second)) { + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + SetVector backwardSlice; + getBackwardSlice(channel->srcOp, &backwardSlice, opt); + for (auto &op : backwardSlice) { + if (op->getBlock() == block) + op->moveBefore(channel->srcOp); + } + } + } + + LLVM_DEBUG({ + LDBG("\n"); + LDBG("after reordering producer ops"); + currOp->getParentOfType().dump(); + LDBG("\n"); + }); +} + +bool isInnermostLoop(scf::ForOp forOp) { + for (Operation &nestedOp : forOp.getBody()->getOperations()) { + if (isa(nestedOp)) { + return false; + } + } + return true; +} + +// Add phase and bufferIndex to be used when lowering the producer. +scf::ForOp createNewLoop(scf::ForOp forOp, int numBuffers, + scf::ForOp &parentForOp) { + auto loc = forOp.getLoc(); + Block *body = forOp.getBody(); + + OpBuilderWithAsyncTaskIds builder(forOp.getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(forOp)); + builder.setInsertionPoint(forOp); + + Value numBuffersVal = + builder.createWithAsyncTaskIds(loc, numBuffers, 32); + + // Step 1: Append bufferIdx and phase as forOp arguments. 
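+  // Shape of the rewrite (illustrative syntax):
+  //
+  //   scf.for %i = %lb to %ub step %st iter_args(%acc = %init)
+  //     -> becomes ->
+  //   scf.for %i = %lb to %ub step %st
+  //       iter_args(%acc = %init, %phase = %initPhase, %bufIdx = %initBufIdx)
+  //
+  // With numBuffers = 3 the trailing (phase, bufferIdx) pair evolves as
+  //   (0,0) (0,1) (0,2) (1,0) (1,1) (1,2) (0,0) ...
+  // i.e. bufferIdx cycles through the buffers and phase flips on each wrap,
+  // which is what the later mbarrier waits key off.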
+ Value phase = + body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc); + Value bufferIdx = + body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc); + + // Step 2: Generate bufferIdx and phase for next iteration: + // nextBufferIdx = bufferIdx + 1 + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + // nextBufferIdx = nextBufferIdx >= numBuffers ? 0 : nextBufferIdx + auto yieldOp = llvm::cast(body->getTerminator()); + builder.setInsertionPoint(yieldOp); + Value one = builder.createWithAsyncTaskIds(loc, 1, 32); + Value zero = builder.createWithAsyncTaskIds(loc, 0, 32); + Value _1_1b = builder.createWithAsyncTaskIds(loc, 1, 1); + // nextBufferIdx = bufferIdx + 1 + Value nextBufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + Value bufferGECond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::uge, nextBufferIdx, numBuffersVal); + Value bufferLTCond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::ult, nextBufferIdx, numBuffersVal); + if (isInnermostLoop(forOp)) { + // nextBufferIdx >= numBuffers ? nextBufferIdx - numBuffers : nextBufferIdx + Value moduloBufferIdx = builder.createWithAsyncTaskIds( + loc, nextBufferIdx, numBuffersVal); + nextBufferIdx = builder.createWithAsyncTaskIds( + loc, bufferGECond, moduloBufferIdx, nextBufferIdx); + } + + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + Value flipPhase = + builder.createWithAsyncTaskIds(loc, phase, _1_1b); + Value cond0 = builder.createWithAsyncTaskIds( + loc, bufferGECond, flipPhase); + Value cond1 = builder.createWithAsyncTaskIds( + loc, bufferLTCond, phase); + Value nextPhase = + builder.createWithAsyncTaskIds(loc, cond0, cond1); + + // Step 3: Add nextBufferIdx and nextPhase to yieldOp. + yieldOp->insertOperands(yieldOp.getNumOperands(), {nextPhase, nextBufferIdx}); + + // Step 4: Create loop arguments for the new ForOp. + SmallVector newLoopArgs; + for (auto operand : forOp.getInitArgs()) + newLoopArgs.push_back(operand); + + builder.setInsertionPoint(forOp); + Value initBufferIdx, initPhase; + zero = builder.createWithAsyncTaskIds(loc, 0, 32); + // Set initial values for bufferIdx and phase. + if (parentForOp) { + // Assume parent ForOp has bufferIdx as the last argument. 
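+    // Worked example (illustration only): numBuffers = 3, the inner loop runs
+    // numSteps = 4 iterations, and the parent loop's bufferIdx is 2. Then
+    // tmpIdx = 2 * 4 = 8, so initBufferIdx = 8 - (8 / 3) * 3 = 2 and
+    // initPhase = (8 / 3) & 1 = 0, exactly where a flat count of 8 inner
+    // iterations would have landed.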
+ initBufferIdx = parentForOp.getBody()->getArguments().back(); + + // numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep + Value numSteps = builder.createWithAsyncTaskIds( + loc, forOp.getUpperBound(), forOp.getLowerBound()); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + forOp.getStep()); + Value one = + builder.createWithAsyncTaskIds(loc, 1, 32); + Value two = + builder.createWithAsyncTaskIds(loc, 2, 32); + numSteps = + builder.createWithAsyncTaskIds(loc, numSteps, one); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + forOp.getStep()); + + // initBufferIdx = (parentForOp.bufferIdx * numSteps) % numBuffers + // tmpIdx = parentForOp.bufferIdx * numSteps + // initBufferIdx = tmpIdx - tmpIdx / numBuffers * numBuffers + // initPhase = (tmpIdx / numBuffers) & 1 + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numSteps); + Value bufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numBuffersVal); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, + builder.createWithAsyncTaskIds(loc, bufferIdx, + numBuffersVal)); + bufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + initPhase = builder.createWithAsyncTaskIds( + loc, builder.getI1Type(), bufferIdx); + } else { + // Set initial phase to false, and initial bufferIdx to 0. + initBufferIdx = zero; + initPhase = builder.createWithAsyncTaskIds(loc, 0, 1); + } + newLoopArgs.append({initPhase, initBufferIdx}); + + // Step 5: Create newForOp and take the region of the original forOp. + auto newForOp = builder.createWithAsyncTaskIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + newForOp.getRegion().takeBody(forOp.getRegion()); + + // Step 6: Replace forOp with newForOp. + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + forOp.erase(); + + return newForOp; +} + +// Find top-level ops which contain at least one channel. If a channel's srcOp +// and dstOp belong to the inner loop, the outer loop will be part of +// asyncTaskOps. +SmallVector +getTaskTopRegion(triton::FuncOp funcOp, + const SmallVector &channels) { + SmallVector asyncTaskOps; + auto isAsyncTaskTopOp = [&](Operation *taskTopOp) -> bool { + for (auto c : channels) { + Operation *producer = c->srcOp, *consumer = c->dstOp; + while (producer && !isa(producer->getParentOp())) { + producer = producer->getParentOp(); + } + while (consumer && !isa(consumer->getParentOp())) { + consumer = consumer->getParentOp(); + } + if (producer == taskTopOp && consumer == taskTopOp) + return true; + } + return false; + }; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &bodyOp : block.getOperations()) { + Operation *op = &bodyOp; + if (op->getNumRegions() <= 0) + continue; + // If this op does not contain both a producer taskId and a consumer + // taskId, continue. + if (getAsyncTaskIds(op).size() == 1) + continue; + if (isAsyncTaskTopOp(op)) + asyncTaskOps.push_back(op); + } + } + return asyncTaskOps; +} + +// For ForOps in taskTopOps, create new ForOp for each by adding phase, +// bufferIdx to the arguments. 
+void appendBufferIdxArgs(SmallVector &taskTopOps, int numBuffers) { + SmallVector orderedForOps; + for (auto &op : taskTopOps) { + op->walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) { + orderedForOps.push_back(forOp); + } + }); + } + + for (auto &origForOp : orderedForOps) { + scf::ForOp parentForOp = origForOp->getParentOfType(); + scf::ForOp newForOp; + // for(...) -> for(..., phase, bufferIdx) + newForOp = createNewLoop(origForOp, numBuffers, parentForOp); + // origForOp is erased in createNewLoop. If origForOp is a top operation + // (i.e in taskTopOps), make sure taskTopOps is updated with the newForOp. + auto asyncTaskLoopForItr = std::find(taskTopOps.begin(), taskTopOps.end(), + origForOp.getOperation()); + if (asyncTaskLoopForItr != taskTopOps.end()) { + // Update taskTopOps. + *asyncTaskLoopForItr = newForOp.getOperation(); + } + } +} + +// Create an allocation to hold the mbarriers. +static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(funcOp.getContext()); + Location loc = funcOp.getLoc(); + auto context = funcOp.getContext(); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = tt::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = builder.create( + loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < distance; i++) { + Value idx = builder.create(loc, i, 32); + Value barrierView = builder.create( + loc, singleBarrierMemDescType, barrierAlloc, idx); + builder.create(funcOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} + +// map: channels are grouped together. +// Go through each group, check the first channel in the group, create a token +// for each consumer taskId. Return a map that maps each channel + consumer +// taskId to a token. Also update barrierAllocMap that maps each channel + +// consumer taskId to a BarrierAlloc. +DenseMap> +createToken(const DenseMap> &map, + const SmallVector &mapKeyVec, triton::FuncOp funcOp, + int numBuffers, int numConsumerGroups, + DenseMap> &barrierAllocMap) { + DenseMap> ret; + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (auto *key : mapKeyVec) { + auto it = map.find(key); + for (auto consumerAsyncTaskId : it->second.front()->relation.second) { + Value v; + if (it->second.front()->srcOp->getParentOfType()) { + v = builder.create(funcOp.getLoc(), numBuffers); + } else { + v = builder.create(funcOp.getLoc(), 1); + } + // Channels in the group share the same set of tokens. + for (auto &c : it->second) + ret[c][consumerAsyncTaskId] = v; + + auto producerOp = it->second.front()->srcOp; + if (isa(producerOp)) { + Value bAlloc = createBarrierAlloc(funcOp, numBuffers); + // Channels in the group share the same set of tokens. 
+ for (auto &c : it->second) { + ret[c][consumerAsyncTaskId] = v; + barrierAllocMap[c][consumerAsyncTaskId] = bAlloc; + } + } + } + } + return ret; +} + +// Create a buffer array for each channel, if the producer is in a ForOp, +// the buffer array will contain numBuffers. +DenseMap createBuffer(const SmallVector &channels, + triton::FuncOp funcOp, int numBuffers, + int numConsumerGroups) { + DenseMap bufferMap; + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (const auto &c : channels) { + if (auto tensorType = dyn_cast(c->srcOperand.getType())) { + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType); + auto sliceType = + RankedTensorType::get(sliceShape, elemType, sharedLayout); + + // Get shape, layout and type of the complete buffer + SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); + if (c->srcOp->getParentOfType()) + bufferShape.insert(bufferShape.begin(), numBuffers); + else + bufferShape.insert(bufferShape.begin(), 1); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto bufferType = + RankedTensorType::get(bufferShape, elemType, sharedLayout); + Type memdescType = + tt::MemDescType::get(bufferShape, elemType, sharedLayout, + sharedMemorySpace, /*mutableMemory*/ true); + Value buffer; + if (isa(c->srcOp)) { + buffer = + builder.create(funcOp.getLoc(), memdescType); + } else { + buffer = builder.create(funcOp.getLoc(), memdescType, + c->srcOperand); + } + bufferMap[c] = buffer; + } else { + llvm_unreachable("Unexpected result type"); + } + } + return bufferMap; +} + +static Operation *createAsyncCopy(const DenseMap &bufferMap, + Channel *c, Operation *op, + SmallVector &asyncTasksPC, + Value bufferIdx, Value bufferIdxExtract) { + auto loadOp = cast(op); + auto buffer = bufferMap.find(c)->second; + MLIRContext *context = loadOp->getContext(); + OpBuilderWithAsyncTaskIds builder(context); + builder.setInsertionPoint(loadOp->getParentOp()); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + builder.setInsertionPoint(loadOp); + Value loadResult = loadOp.getResult(); + auto tensorType = dyn_cast(loadResult.getType()); + if (!tensorType) + return nullptr; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + Value zero = builder.createWithAsyncTaskIds( + loadOp.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + builder.setAsyncTaskIdsFromOp(loadOp); + 
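+  // The rewrite below follows the usual multi-buffering pattern: a subview of
+  // the shared buffer at bufferIdx receives the cp.async copy on the producer
+  // side, and a second subview at bufferIdxExtract feeds a local_load placed
+  // right before the consumer op (c->dstOp).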
builder.setInsertionPointAfter(loadOp); + auto view = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, copyOffsets); + // Create cp.async + Operation *copy = + builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getPtr(), view, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + + // Extract part. + builder.setAsyncTaskIdsFromValueUsers(loadResult); + builder.setInsertionPoint(c->dstOp); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = bufferIdxExtract; + auto viewLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + // Replace all uses of loadResult + loadResult.replaceAllUsesWith(sharedLoad.getResult()); + loadOp.erase(); + return copy; +} + +static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { + auto tensorTy = cast(tmaLoad->getResult(0).getType()); + int loadSize = product(tensorTy.getShape()); + return loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; +} + +Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Value barrierAlloc, Value bufferIdx) { + auto context = barrierAlloc.getContext(); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, + /*mutableMemory=*/true); + + // Create barrierForTMA from barrierAlloc. + return builder.createWithAsyncTaskIds( + barrierAlloc.getLoc(), barrierTy, barrierAlloc, + ArrayRef({bufferIdx})); +} + +Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Type loadType, Value buffer, Value bufferIdx, + bool mutableMem) { + auto context = buffer.getContext(); + auto tensorType = dyn_cast(loadType); + assert(tensorType); + + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemOry=*/mutableMem); + + Value zero = builder.createWithAsyncTaskIds( + buffer.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + + return builder.createWithAsyncTaskIds( + buffer.getLoc(), subviewTy, buffer, copyOffsets); +} + +Operation * +optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, + SmallVector &tmaLoads, + SmallVector &buffers, Value barrierAlloc, + Value bufferIdx, Value bufferIdxExtract, Value phase, + Operation *headProducer, Operation *headConsumer) { + auto loc = barrierAlloc.getLoc(); + + // Compute the total size of the loads. + int sizeInBytes = 0; + for (auto &tmaLoad : tmaLoads) { + sizeInBytes += getTMALoadSize(tmaLoad); + } + + // For each of the following ops, we will operate on a subview of each value + // according to the pipeline stage. 
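+  // Roughly, the generated IR looks like this (names abbreviated, for
+  // illustration only):
+  //   producer: barrier_expect bar[i], sizeInBytes
+  //             async_tma_copy_global_to_local desc -> buf[i], bar[i]  (per load)
+  //   consumer: wait_barrier bar[i], phase
+  //             local_load buf[i]                                      (per load)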
+ + // Create a barrier_expect with the appropriate size and insert it before the + // first load. + builder.setInsertionPoint(headProducer); + builder.setAsyncTaskIdsFromOp(headProducer); + auto prodBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdx); + auto pred = builder.createWithAsyncTaskIds(loc, 1, 1); + auto expect = builder.createWithAsyncTaskIds( + loc, prodBarrier, sizeInBytes, pred); + + // Convert all the producers to async_tma_copy_global_to_local + Operation *copy = nullptr; + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), + buffer, bufferIdx, true); + copy = builder.createWithAsyncTaskIds( + loc, tmaLoad.getDescPtr(), tmaLoad.getIndices(), prodBarrier, + pipelineBuffer, pred); + } + + // Create a wait_barrier before the first consumer. + builder.setInsertionPoint(headConsumer); + builder.setAsyncTaskIdsFromOp(headConsumer); + auto consBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdxExtract); + phase = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), phase); + auto wait = builder.createWithAsyncTaskIds( + loc, consBarrier, phase); + + // Convert all the consumers to local_load + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage( + builder, tmaLoad.getType(), buffer, bufferIdxExtract, false); + auto sharedLoad = builder.createWithAsyncTaskIds( + loc, tmaLoad.getType(), pipelineBuffer); + + Value loadResult = tmaLoad.getResult(); + tmaLoad.getResult().replaceAllUsesWith(sharedLoad.getResult()); + tmaLoad.erase(); + } + return copy; +} + +// Lower producers for channels. Here channels are grouped in "map". tokenMap +// tracks the set of tokens for each channel. +void buildAsyncComm( + const DenseMap> &map, + const DenseMap> &tokenMap, + const DenseMap> &barrierAllocMap, + const DenseMap &bufferMap, int numBuffers, + int numConsumerGroups) { + + // Find the operation that is along producer's parent chain, and its parent + // is the same op as producer's parent. Here p is producer, and c is consumer. + auto getSameLevelOp = [](Operation *p, Operation *c) -> Operation * { + while (!isa(c)) { + if (c->getParentOp() == p->getParentOp()) { + return c; + } + c = c->getParentOp(); + } + llvm_unreachable("Failed to find consumer's same level Op with producer"); + }; + + auto consumerReleaseHeutistic = [&](Operation *p, Operation *c, + int consumerAsyncTaskId) -> Operation * { + if (c->getBlock() != p->getBlock()) + return getSameLevelOp(p, c); + for (auto it = c->getBlock()->rbegin(); it != c->getBlock()->rend(); ++it) { + if (!it->hasAttr("async_task_id")) + continue; + auto asyncAttr = it->getAttrOfType("async_task_id") + .getValues(); + if (asyncAttr.size() == 1 && asyncAttr[0] == consumerAsyncTaskId) + return &(*it); + } + return nullptr; + }; + + auto getAsyncTasks = [&](Operation *p, Operation *c, + SmallVector &asyncTaskP, + SmallVector &asyncTaskC, + SmallVector &asyncTasksPC) -> void { + asyncTaskP = getNestedAsyncTaskIds(p); + asyncTaskC = getNestedAsyncTaskIds(c); + asyncTasksPC.reserve(asyncTaskP.size() + asyncTaskC.size()); + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskP.begin(), + asyncTaskP.end()); + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), + asyncTaskC.end()); + }; + + // Go through each channel group. 
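+  // For each group, the loop below inserts, per consumer token:
+  //   - a ProducerAcquire before the head producer,
+  //   - a ProducerCommit after the tail producer (cp.async path only; the TMA
+  //     path arrives on its mbarrier instead),
+  //   - a ConsumerWait before the consumer (cp.async path only), and
+  //   - a ConsumerRelease after the last op of the consumer task (heuristic),
+  // and then rewrites the loads into cp.async copies or grouped TMA loads.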
+ for (auto kv : map) { + auto headProducer = kv.second.front()->srcOp; + auto tailProducer = kv.second.back()->srcOp; + auto headConsumer = kv.second.front()->dstOp; + auto tailConsumer = kv.second.back()->dstOp; + // We have one set of tokens for each channel group. + auto tokens = tokenMap.find(kv.second.front())->second; + + SmallVector asyncTaskP, asyncTaskC, asyncTasksPC; + getAsyncTasks(headProducer, headConsumer, asyncTaskP, asyncTaskC, + asyncTasksPC); + OpBuilderWithAsyncTaskIds builder(headProducer->getContext()); + if (auto funcOp = dyn_cast(headProducer->getParentOp())) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + } else { + builder.setInsertionPoint(headProducer->getParentOp()); + } + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + Value bufferIdx; + Value phase = Value(); + if (auto forOp = headProducer->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument + } else { + // Producer is not in a ForOp, create phase and bufferIdx here. + bufferIdx = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 32); + phase = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 1); + } + + assert((isa(headProducer)) && + "producer must be a LoadOp or tma LoadOp"); + builder.setAsynTaskIdsFromArray(asyncTaskP); + for (auto token : tokens) { + // Insert ProducerAcquireOp before the producer. + builder.setInsertionPoint(headProducer); + builder.createWithAsyncTaskIds( + headProducer->getLoc(), token.second, bufferIdx, phase); + + // Insert ProducerCommitOp if producer is LoadOp. For TMA, TMA lowering + // will handle the ProducerCommit. + if (isa(headProducer)) { + builder.setInsertionPointAfter(tailProducer); + builder.createWithAsyncTaskIds( + tailProducer->getLoc(), token.second, bufferIdx); + } + } + + for (auto token : tokens) { + builder.setAsynTaskIdsFromArray(token.first); + // Insert ConsumerWaitOp + if (!isa(headProducer)) { + auto consumerWaitPoint = getSameLevelOp(headProducer, headConsumer); + builder.setInsertionPoint(consumerWaitPoint); + builder.createWithAsyncTaskIds( + headConsumer->getLoc(), token.second, bufferIdx, phase); + } + + // Insert ConsumerReleaseOp. + auto consumerReleasePoint = + consumerReleaseHeutistic(tailProducer, tailConsumer, token.first); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAsyncTaskIds( + consumerReleasePoint->getLoc(), token.second, bufferIdx); + } + + SmallVector tmaLoads; + SmallVector buffers; + // Go through all channels in this channel group. + for (auto &c : kv.second) { + assert( + (isa(c->srcOp)) && + "producer must be a LoadOp or tma LoadOp"); + bool insideLoop = c->srcOp->getParentOfType() != nullptr; + if (isa(c->srcOp)) { + // After createAsyncCopy, c->srcOp/headProducer are no longer valid. + createAsyncCopy(bufferMap, c, c->srcOp, asyncTasksPC, bufferIdx, + bufferIdx); + } else if (auto tmaLoad = + dyn_cast(c->srcOp)) { + tmaLoads.push_back(tmaLoad); + buffers.push_back(bufferMap.find(c)->second); + } + } + + // Optimize TMA loads. + if (tmaLoads.size() > 0) { + auto barrierAllocs = barrierAllocMap.find(kv.second.front())->second; + // TODO: we created one Alloc for each consumer taskId, but here, we + // only use the first Alloc. 
+ auto barrierAlloc = barrierAllocs.begin()->second; + optimizeTMALoads(builder, tmaLoads, buffers, barrierAlloc, bufferIdx, + bufferIdx, phase, headProducer, headConsumer); + } + } +} + +// Collect argument indices that are used by the specific taskId. +static SmallVector collectBlockArgsForTask( + scf::ForOp forOp, int asyncTaskId, + DenseMap &blockArgToYieldOperand) { + DenseSet seen; + // Collect argument indices that can be reached along the definition chain. + // If reaching a BlockArgument, visit the corresponding yield operand. + SetVector argIndices; + std::function dfs = [&](Operation *op) { + if (!seen.insert(op).second) + return; + for (Value operand : op->getOperands()) { + if (auto blockArg = dyn_cast(operand)) { + if (!blockArgToYieldOperand[blockArg]) + continue; + argIndices.insert(blockArg.getArgNumber() - + forOp.getNumInductionVars()); + operand = blockArgToYieldOperand[blockArg]; + } + Operation *depOp = operand.getDefiningOp(); + assert(depOp && "Unexpected Value with no defining op"); + if (depOp->getBlock() != forOp.getBody()) + continue; + assert(hasAsyncTaskId(depOp, asyncTaskId) && "Dependency error"); + dfs(depOp); + } + }; + + // Start from operations that are marked with this asyncTaskId explicitly and + // check dependency with DFS traversal. + forOp.walk([&](Operation *op) { + if (hasAsyncTaskId(op, asyncTaskId) && !isa(op)) + dfs(op); + }); + + SmallVector args(argIndices.begin(), argIndices.end()); + llvm::sort(args); + return args; +} + +DenseMap +createForOpsForEachAsyncTaskId(scf::ForOp forOp) { + // Collect operation list for each asyncTaskId. + DenseMap> opList; + for (Operation &op : forOp.getBody()->without_terminator()) { + auto ids = getAsyncTaskIds(&op); + for (AsyncTaskId asyncTaskId : ids) + opList[asyncTaskId].push_back(&op); + } + + // Prepare blockArgToYieldOperand mapping. + DenseMap blockArgToYieldOperand; + auto yieldOp = llvm::cast(forOp.getBody()->getTerminator()); + assert(yieldOp.getNumOperands() == forOp.getNumRegionIterArgs()); + for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) + blockArgToYieldOperand[forOp.getRegionIterArg(i)] = yieldOp.getOperand(i); + + auto loc = forOp.getLoc(); + OpBuilderWithAsyncTaskIds builder(forOp.getContext()); + DenseMap asyncTasksToForOp; + + // Create newForOp for each task Id. + for (AsyncTaskId asyncTaskId : getNestedAsyncTaskIds(forOp)) { + auto usedArgs = + collectBlockArgsForTask(forOp, asyncTaskId, blockArgToYieldOperand); + + // Prepare newLoopArgs. + SmallVector newLoopArgs; + for (unsigned argNumber : usedArgs) + newLoopArgs.push_back(forOp.getInitArgs()[argNumber]); + + // Create newForOp. + builder.setAsynTaskIdsFromArray({asyncTaskId}); + builder.setInsertionPoint(forOp); + auto newForOp = builder.createWithAsyncTaskIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + + // Initialize Value mapping from forOp to newForOp + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldArg = forOp.getRegionIterArgs()[usedArgs[i]]; + auto newArg = newForOp.getRegionIterArgs()[i]; + mapping.map(oldArg, newArg); + } + + // Clone all operations with this asyncTaskId to newForOp. 
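+    // Illustrative example (not from the original patch): if the loop body
+    // holds ops tagged {0} and ops tagged {1}, two loops are materialized,
+    // one per task, each carrying only the iter_args its own ops reach (the
+    // subset computed by collectBlockArgsForTask above).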
+ builder.setInsertionPointToStart(newForOp.getBody()); + for (Operation *op : opList[asyncTaskId]) { + Operation *newOp = builder.clone(*op, mapping); + setAsyncTaskIds(newOp, {asyncTaskId}); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Create YieldOp for newForOp. + SmallVector newYieldOperands; + for (unsigned i : usedArgs) { + LDBG("lookup operand " << i); + newYieldOperands.push_back(mapping.lookup(yieldOp.getOperand(i))); + } + bool createNewYield = true; + if (newForOp.getBody()->mightHaveTerminator()) { + auto initialYield = + llvm::cast(newForOp.getBody()->getTerminator()); + if (newYieldOperands.size() == 0) { + setAsyncTaskIds(initialYield, {asyncTaskId}); + createNewYield = false; + } + } + if (createNewYield) { + auto newYieldOp = + builder.create(yieldOp.getLoc(), newYieldOperands); + setAsyncTaskIds(newYieldOp, {asyncTaskId}); + } + + // Replace results of forOp with results of newForOp. + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldResult = forOp.getResult(usedArgs[i]); + auto newResult = newForOp.getResult(i); + oldResult.replaceUsesWithIf(newResult, [&](OpOperand &operand) -> bool { + return hasAsyncTaskId(operand.getOwner(), asyncTaskId); + }); + } + + asyncTasksToForOp[asyncTaskId] = newForOp; + } + + return asyncTasksToForOp; +} + +// Input asyncTaskTopOp can be an IfOp that contains a ForOp. We clone +// the ForOp for each asyncTaskId. +DenseMap +asyncTaskDivision(Operation *asyncTaskTopOp) { + DenseMap asyncTaskTopOpMap; + Operation *mainForOp = asyncTaskTopOp; + if (auto ifOp = dyn_cast(asyncTaskTopOp)) { + // Find the outmost ForOp inside. Assume only a single ForOp. + Operation *nestedFor = nullptr; + asyncTaskTopOp->walk([&](Operation *op) { + if (auto forOp = dyn_cast(op)) { + assert(nestedFor == nullptr); + nestedFor = op; + } + }); + assert(nestedFor && "can't find ForOp in a top-level IfOp"); + mainForOp = nestedFor; + } + asyncTaskTopOp->walk([&](Operation *op) { + auto ids = getAsyncTaskIds(op); + if (op->getNumRegions() > 0 && ids.size() > 1) { + if (auto forOp = dyn_cast(op)) { + // Create a cloned ForOp for each taskId and return the map. + auto forOps = createForOpsForEachAsyncTaskId(forOp); + if (op == mainForOp) { + for (auto kv : forOps) { + auto f = kv.second; + auto id = getAsyncTaskIds(f.getOperation()); + assert(id.size() == 1 && + "generated ForOp doesn't have one and only one asyncTaskId"); + asyncTaskTopOpMap[id.front()] = f.getOperation(); + } + } + // For debugging purposes, check to see if it is safe to erase the + // original ForOp. + bool hasIssue = false; + for (Operation &opT : forOp.getBody()->without_terminator()) { + // Check to see if opT is used in another block. 
+ for (unsigned i = 0; i < opT.getNumResults(); ++i) + for (Operation *user : opT.getResult(i).getUsers()) { + if (user->getBlock() != opT.getBlock()) { + hasIssue = true; + LLVM_DEBUG({ + LDBG("-- op has user in another block"); + opT.dump(); + user->dump(); + }); + } + } + } + if (hasIssue) { + for (Operation &opT : forOp.getBody()->without_terminator()) { + LLVM_DEBUG({ + LDBG("addr " << (&opT) << ": "); + opT.dump(); + }); + } + } + bool hasUse = false; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + for (Operation *user : op->getResult(i).getUsers()) { + hasUse = true; + LLVM_DEBUG({ + LDBG("op has use "); + user->dump(); + }); + } + } + ModuleOp moduleOp = forOp->getParentOfType(); + LLVM_DEBUG({ + LDBG("erase ForOp"); + forOp.dump(); + }); + forOp.erase(); + LDBG("done erasing ForOp"); + } else if (auto ifOp = dyn_cast(op)) { + // The ForOp inside this ifOp will be cloned. + LDBG("IfOp in asyncTaskDivision"); + } else if (auto whileOp = dyn_cast(op)) { + LDBG("WhileOp in asyncTaskDivision"); + } else { + llvm_unreachable("Unexpected Op with regions"); + } + } + }); + assert(asyncTaskTopOpMap.size() > 0 && "AsyncTask division failed"); + return asyncTaskTopOpMap; +} + +void cloneAsyncTaskLoopForEachAsyncTaskId( + SmallVector &asyncTaskTopOps) { + SmallVector newBackBone; + + for (Operation *op : asyncTaskTopOps) { + auto loc = op->getLoc(); + OpBuilderWithAsyncTaskIds builder(op->getContext()); + builder.setInsertionPoint(op); + // Step 1: create a cloned forOp for each taskId based on the original + // ForOp that is in this top-level operation. + DenseMap newAsyncTaskLoops = + asyncTaskDivision(op); + + // Step 2: remove irrelevant Ops from the cloned ForOps. + for (auto kv : newAsyncTaskLoops) { + SmallVector deleteOps; + AsyncTaskId targetId = kv.first; + Operation *newAsyncTaskLoop = kv.second; + newAsyncTaskLoop->walk([&](Operation *subOp) { + auto ids = getAsyncTaskIds(subOp); + if (std::find(ids.begin(), ids.end(), targetId) == ids.end()) { + deleteOps.push_back(subOp); + } + }); + for (auto it = deleteOps.rbegin(); it != deleteOps.rend(); ++it) { + (*it)->erase(); + } + } + } +} + +class TritonGPUWSCodePartitionPass + : public impl::TritonGPUWSCodePartitionBase { +public: + using impl::TritonGPUWSCodePartitionBase< + TritonGPUWSCodePartitionPass>::TritonGPUWSCodePartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + // Disable code partitioning when numBuffers is 0. + if (numBuffers == 0) + return; + + // Step 1: collect all communications between producers and consumers. + SmallVector> channelsOrigin; + collectAsyncChannels(channelsOrigin, funcOp); + SmallVector channels; + for (const auto &c : channelsOrigin) { + channels.push_back(c.get()); + } + if (channels.empty()) { + return; + } + + // Step 2: group channels where each entry of the map is keyed by the dstOp. + DenseMap> map; + SmallVector mapKeyVec; + groupChannels(channels, map, mapKeyVec); + + // Step 3: reorder producer ops and the backward slices of the producer ops. + reorderProducerOps(channels); + + // Step 4: find top-level ops that contain a channel, also create new ForOps + // by adding phase and bufferIdx to the original ForOps, erase the original + // ForOps. + SmallVector asyncTaskTopOps = + getTaskTopRegion(funcOp, channels); + appendBufferIdxArgs(asyncTaskTopOps, numBuffers); + + // Step 5: Create tokens, and buffers. A set of tokens for each group of + // channels and an array of buffers for each channel. 
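+    // Resulting maps, for reference:
+    //   tokenMap        : channel -> {consumer taskId -> token}
+    //   bufferMap       : channel -> (multi-)buffered shared-memory alloc
+    //   barrierAllocMap : channel -> {consumer taskId -> mbarrier} (TMA only)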
+ DenseMap> barrierAllocMap; + DenseMap> tokenMap = createToken( + map, mapKeyVec, funcOp, numBuffers, numConsumerGroups, barrierAllocMap); + DenseMap bufferMap = + createBuffer(channels, funcOp, numBuffers, numConsumerGroups); + LLVM_DEBUG({ + LDBG("\n\nafter createBuffer"); + funcOp.dump(); + }); + + // Step 6: add async communication ops (ProducerAcquire etc). Also lower the + // loads. + buildAsyncComm(map, tokenMap, barrierAllocMap, bufferMap, numBuffers, + numConsumerGroups); + LLVM_DEBUG({ + LDBG("\n\nwith SyncOps"); + funcOp.dump(); + }); + + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + DenseMap opsToReplace; + funcOp.walk([&](ttg::LocalAllocOp localAlloc) { + if (auto src = localAlloc.getSrc()) { + if (auto localLoad = dyn_cast(src.getDefiningOp())) { + opsToReplace[localAlloc] = localLoad.getSrc(); + } + } + }); + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + for (auto kv : opsToReplace) + replaceUsesAndPropagateType(builder, kv.getFirst(), kv.getSecond()); + LLVM_DEBUG({ + LDBG("\n\nsimplify localLoad + localAlloc"); + funcOp.dump(); + }); + + // Clone taskTopOp, remove irrelevant blockArgument for {forOp, ifOp} + cloneAsyncTaskLoopForEachAsyncTaskId(asyncTaskTopOps); + LLVM_DEBUG({ + LDBG("\n\nwith Loop Split"); + funcOp.dump(); + }); + + auto ret = SpecializeRegion(funcOp, regDecProducer, regIncConsumer); + LLVM_DEBUG({ + LDBG("\n\nwith IfOps"); + funcOp.dump(); + }); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp new file mode 100644 index 000000000..9e9f0a6a3 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp @@ -0,0 +1,680 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-data-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. 
+ bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +// Make sure the def chain contains the right taskId. +bool fixTaskId(triton::FuncOp &funcOp) { + bool retCode = true; + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + // Do not update loads. + if (isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + retCode = false; + // Const ops with same value but different task ids can be folded. + if (isa(defOp)) { + LLVM_DEBUG({ + LDBG("fixing taskId for"); + defOp->dump(); + }); + addAsyncTaskIds(defOp, asyncTaskIds); + LLVM_DEBUG({ + LDBG("resulting"); + defOp->dump(); + }); + } + } + } + }); + return retCode; +} + +static SmallVector getShape(Value v) { + auto type = v.getType(); + if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } else if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } + return {}; +} + +bool needToSlice(Value v, int dim, int size) { + auto shape = getShape(v); + return shape.size() > dim && shape[dim] > size; +} + +bool getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &backwardSlice) { + auto newOpInserted = false; + SmallVector queue = {root}; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!needToSlice(v, dim, sliceSize)) + continue; + if (auto op = v.getDefiningOp()) { + auto inserted = backwardSlice.insert(op); + newOpInserted |= inserted; + if (inserted) { + if (op->hasTrait() || + isa(op)) { + for (Value operand : op->getOperands()) + queue.push_back(operand); + } else if (auto dotOp = dyn_cast(op)) { + queue.push_back(dim == 0 ? 
dotOp.getA() : dotOp.getB());
+          queue.push_back(dotOp.getC());
+        } else {
+          llvm_unreachable("Unexpected op");
+        }
+      }
+    } else {
+      assert(isa(v) && "value is not an operation result or a block argument");
+      auto bbArg = cast(v);
+      Operation *bbAargOwner = bbArg.getOwner()->getParentOp();
+      if (auto forOp = dyn_cast(bbAargOwner)) {
+        // track initial value
+        auto initArg = forOp.getInitArgs()[bbArg.getArgNumber() - 1];
+        queue.push_back(initArg);
+        // track yield value
+        auto yieldArg = forOp.getYieldedValues()[bbArg.getArgNumber() - 1];
+        queue.push_back(yieldArg);
+      }
+    }
+  }
+  return newOpInserted;
+};
+
+bool getForwardSliceToPartition(Value root, unsigned dim, int sliceSize,
+                                SetVector &forwardSlice) {
+  auto newOpInserted = false;
+  SmallVector queue = {root};
+  llvm::SmallDenseSet seen;
+  while (!queue.empty()) {
+    auto v = queue.back();
+    queue.pop_back();
+    if (!seen.insert(v).second)
+      continue;
+    if (!needToSlice(v, dim, sliceSize))
+      continue;
+    getForwardSlice(v, &forwardSlice);
+    for (Operation *op : forwardSlice) {
+      if (op->getNumResults() > 0)
+        seen.insert(op->getResult(0));
+      if (auto yieldOp = dyn_cast(op)) {
+        if (auto forOp = dyn_cast(yieldOp->getParentOp())) {
+          for (OpOperand &operand : yieldOp->getOpOperands()) {
+            if (seen.count(operand.get())) {
+              queue.push_back(forOp->getResult(operand.getOperandNumber()));
+              forwardSlice.insert(forOp);
+              newOpInserted = true;
+            }
+          }
+        }
+      }
+    }
+  }
+  return newOpInserted;
+};
+
+// Compute the closure of ops that the root value transitively depends on or
+// that transitively depend on it.
+void getSliceToPartition(Value root, unsigned dim, int sliceSize,
+                         SetVector &slice) {
+  auto newOpInserted = false;
+  while (!newOpInserted) {
+    newOpInserted |= getBackwardSliceToPartition(root, dim, sliceSize, slice);
+    SetVector forwardSlice;
+    newOpInserted |=
+        getForwardSliceToPartition(root, dim, sliceSize, forwardSlice);
+    slice.insert(forwardSlice.begin(), forwardSlice.end());
+    for (auto op : forwardSlice) {
+      if (op->hasTrait() ||
+          isa(op)) {
+        for (OpOperand &operand : op->getOpOperands()) {
+          newOpInserted |=
+              getBackwardSliceToPartition(operand.get(), dim, sliceSize, slice);
+        }
+      } else if (auto dotOp = dyn_cast(op)) {
+        newOpInserted |= getBackwardSliceToPartition(
+            dim == 0 ? dotOp.getA() : dotOp.getB(), dim, sliceSize, slice);
+        newOpInserted |=
+            getBackwardSliceToPartition(dotOp.getC(), dim, sliceSize, slice);
+      }
+    }
+  }
+}
+
+struct DataPartitionScheme {
+  // Which dimension to partition. For dot, dim 0 means along the M dimension,
+  // 1 means along the N dimension.
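+  // Example (illustrative): a 256x256 warp-group dot shared by two consumer
+  // tasks is partitioned with partitionDim = 0 (M) and numPartitions = 2,
+  // giving each task a 128x256 slice (sliceSizeM = 128 >= 64).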
+ unsigned partitionDim = 0; + unsigned numPartitions = 0; + SetVector ops; +}; + +bool computePartitionScheme(triton::FuncOp &funcOp, + DataPartitionScheme &partitionScheme) { + // Do not partition producer tasks + + // Use dot to drive the partition + SetVector dots; + + // check all dot ops that have more than one async task id + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 1) { + if (auto dotWaitOp = dyn_cast(op)) { + dots.insert(dotWaitOp); + } + } + }); + + // Checking if all dots can be partitioned in the same way + int numWarps = + TritonGPUDialect::getNumWarps(funcOp->getParentOfType()); + for (auto dotOp : dots) { + // partition along M first, otherwise along N + RankedTensorType dotType = dotOp.getType(); + LLVM_DEBUG({ + LDBG("Computing partition scheme for"); + dotOp.dump(); + LDBG("\n"); + }); + auto shapePerCTA = getShapePerCTA(dotType); + if (shapePerCTA.size() != 2) { + LDBG("partition not possible: shapePerCTA " << shapePerCTA.size()); + return false; + } + auto CTALayout = getCTALayout(dotType.getEncoding()); + auto asyncTaskIds = getAsyncTaskIds(dotOp); + int sliceSizeM = shapePerCTA[0] / asyncTaskIds.size(); + int sliceSizeN = shapePerCTA[1] / asyncTaskIds.size(); + int partitionDim, partitionSize; + Value partitionOperand; + + if (sliceSizeM >= 64) { + LLVM_DEBUG({ LDBG("partition along M\n"); }); + partitionDim = 0; + partitionSize = sliceSizeM; + partitionOperand = dotOp.getA(); + } else if (sliceSizeN >= 256) { + LLVM_DEBUG({ LDBG("partition along N\n"); }); + partitionDim = 1; + partitionSize = sliceSizeN; + partitionOperand = dotOp.getB(); + } else { + LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN); + return false; + } + + if (partitionScheme.numPartitions == 0) { + partitionScheme.partitionDim = partitionDim; + partitionScheme.numPartitions = asyncTaskIds.size(); + } else { + if (partitionScheme.partitionDim != partitionDim || + partitionScheme.numPartitions != asyncTaskIds.size()) { + LDBG("partition not possible, in conflict with previous partition\n"); + return false; + } + } + + // Partition the slice closure + SetVector &slice = partitionScheme.ops; + getSliceToPartition(dotOp.getD(), partitionDim, partitionSize, slice); + + LLVM_DEBUG({ + partitionOperand.dump(); + LDBG("\n"); + LDBG(" slice:"); + for (auto &op : slice) { + op->dump(); + } + LDBG("\n"); + }); + + for (auto op : partitionScheme.ops) { + auto opTaskIds = getAsyncTaskIds(op); + // skip check for control flow ops + if (isa(op)) + continue; +#if 0 + if (opTaskIds.size() > partitionScheme.numPartitions) { + LLVM_DEBUG({ + LDBG("partition not possible: numPartitions" << opTaskIds.size() << " " << partitionScheme.numPartitions); + op->dump(); + }); + return false; + } +#endif + } + } + + return !partitionScheme.ops.empty(); +} + +Operation *sliceOp(Value v, int offset, OpBuilderWithAsyncTaskIds &builder, + IRMapping &mappings, IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme); + +Operation *sliceOp(Operation *op, int offset, + OpBuilderWithAsyncTaskIds &builder, IRMapping &mappings, + IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (!partitionScheme.ops.contains(op)) + return op; + if (mappings.contains(op)) + return mappings.lookupOrNull(op); + if (reverseMappings.contains(op)) + return op; + + LLVM_DEBUG({ + LDBG("slicing:"); + op->dump(); + LDBG("\n"); + }); + + int dim = partitionScheme.partitionDim; + int numOfPartitions = partitionScheme.numPartitions; + + auto 
asyncTaskIds = getAsyncTaskIds(op); + SmallVector sliceTaskIds; + if (asyncTaskIds.size() == numOfPartitions) { + // We are slicing the op for consumer only + sliceTaskIds.push_back(asyncTaskIds[offset]); + } else if (asyncTaskIds.size() == 1) { + // We are slicing the op for producer only + sliceTaskIds.push_back(asyncTaskIds.front()); + } else if (asyncTaskIds.size() > numOfPartitions) { + // We are slicing the op for both producer and consumer + sliceTaskIds.push_back(asyncTaskIds.front()); + sliceTaskIds.push_back(asyncTaskIds[offset + 1]); + } else { + llvm_unreachable("Unexpected asyncTaskIds.size()"); + } + + builder.setAsynTaskIdsFromArray(sliceTaskIds); + auto cloneAndSetResultType = [&](Operation *op) { + builder.setInsertionPoint(op); + auto newOp = builder.clone(*op, mappings); + setAsyncTaskIds(newOp, sliceTaskIds); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + // set result shape + if (!op->getResults().empty()) { + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = + MemDescType::get(shape, type.getElementType(), type.getEncoding(), + type.getMemorySpace(), type.getMutableMemory()); + newV.setType(newType); + } else if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = RankedTensorType::get(shape, type.getElementType(), + type.getEncoding()); + newV.setType(newType); + } + + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + return newOp; + }; + + // slice operands first + Operation *newOp; + if (op->hasTrait() || + isa( + op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto constOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + auto valAttr = cast(constOp.getValueAttr()); + auto valType = cast(valAttr.getType()); + SmallVector shape{valType.getShape().begin(), + valType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newValType = valType.clone(shape); + auto newValAttr = valAttr.resizeSplat(newValType); + newOp = builder.createWithAsyncTaskIds(op->getLoc(), + newValAttr); + // Do not drop original task id as constant folding may lose one constant. 
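+    // (For example, two arith.constant splats that differ only in their
+    // async_task_id could otherwise be folded into one, dropping a task
+    // annotation; keeping the original ids avoids that.)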
+ setAsyncTaskIds(newOp, getAsyncTaskIds(op)); + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (auto makeRangeOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + int newRangeStart = makeRangeOp.getStart(); + int newRangeEnd = makeRangeOp.getEnd(); + int sliceSize = (newRangeEnd - newRangeStart) / numOfPartitions; + newRangeStart += offset * sliceSize; + newRangeEnd = newRangeStart + sliceSize; + auto v = op->getResult(0); + auto type = cast(v.getType()); + auto newType = RankedTensorType::get({sliceSize}, builder.getI32Type(), + type.getEncoding()); + newOp = builder.createWithAsyncTaskIds( + op->getLoc(), newType, newRangeStart, newRangeEnd); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (isa(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + // TODO: slice store base ptr + newOp = cloneAndSetResultType(op); + } else if (isa( + op)) { + SmallVector shape; + Value coordVal; + if (auto loadOp = dyn_cast(op)) { + coordVal = loadOp.getIndices()[dim]; + shape = getShape(loadOp.getResult()); + } else if (auto storeOp = dyn_cast(op)) { + coordVal = storeOp.getIndices()[dim]; + shape = getShape(storeOp.getSrc()); + } + auto newCoordVal = coordVal; + if (offset) { + builder.setInsertionPointAfter(coordVal.getDefiningOp()); + Value offsetVal = builder.createWithAsyncTaskIds( + op->getLoc(), offset * shape[dim] / numOfPartitions, 32); + newCoordVal = builder.createWithAsyncTaskIds( + op->getLoc(), coordVal, offsetVal); + mappings.map(coordVal, newCoordVal); + reverseMappings.map(newCoordVal, coordVal); + } + + newOp = cloneAndSetResultType(op); + if (isa(op)) { + // map load result + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + } else if (auto dotOp = dyn_cast(op)) { + // Only hanlde A and accumulator + sliceOp(dim == 0 ? 
dotOp.getA() : dotOp.getB(), offset, builder, mappings, + reverseMappings, partitionScheme); + sliceOp(dotOp.getC(), offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto forOp = dyn_cast(op)) { + // Add new loop arguments + SmallVector newLoopArgs; + for (auto initArg : forOp.getInitArgs()) + newLoopArgs.push_back(initArg); + DenseMap newArgIdices; + for (unsigned i = 0; i < forOp.getInitArgs().size(); i++) { + auto initArg = forOp.getInitArgs()[i]; + auto newInitArgOp = sliceOp(initArg.getDefiningOp(), offset, builder, + mappings, reverseMappings, partitionScheme); + auto newInitArg = newInitArgOp->getResult(0); + if (newInitArg != initArg) { + newLoopArgs.append({newInitArg}); + forOp.getBody()->insertArgument(forOp.getBody()->getNumArguments(), + newInitArg.getType(), forOp.getLoc()); + newArgIdices[i] = newLoopArgs.size() - 1; + } + } + + // Create newForOp and take the region of forOp + builder.setInsertionPoint(op); + auto newForOp = builder.createWithAsyncTaskIds( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newLoopArgs); + assert(newForOp.getRegionIterArgs().size() == + newForOp.getInitArgs().size()); + newForOp->setAttrs(forOp->getAttrs()); + partitionScheme.ops.insert(newForOp); + newOp = newForOp; + + // Replace forOp with newForOp + newForOp.getRegion().takeBody(forOp.getRegion()); + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + op->setAttr("to_be_removed", builder.getUnitAttr()); + + // Map new loop arguments + for (auto argIndex : newArgIdices) { + Value v = newForOp.getResult(argIndex.first); + Value newV = newForOp.getResult(argIndex.second); + mappings.map(v, newV); + reverseMappings.map(newV, v); + + auto regionArg = newForOp.getRegionIterArg(argIndex.first); + auto newRegionArg = newForOp.getRegionIterArg(argIndex.second); + mappings.map(regionArg, newRegionArg); + reverseMappings.map(newRegionArg, regionArg); + } + + } else if (auto yieldOp = dyn_cast(op)) { + int num = yieldOp.getNumOperands(); + for (int i = 0; i < num; i++) { + auto operand = yieldOp.getOperand(i); + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + if (auto newV = mappings.lookupOrNull(operand)) + yieldOp->insertOperands(op->getNumOperands(), newV); + } + newOp = op; + } else if (auto reduceOp = dyn_cast(op)) { + assert(reduceOp.getAxis() != partitionScheme.partitionDim && + "reduce should not happen on the partitioned dimension"); + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else { + llvm_unreachable("unsupported value type"); + } + + LLVM_DEBUG({ + LDBG("resulting"); + newOp->dump(); + LDBG("\n"); + }); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + return newOp; +} + +Operation *sliceOp(Value v, int offset, OpBuilderWithAsyncTaskIds &builder, + IRMapping &mappings, IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (auto op = v.getDefiningOp()) { + return sliceOp(op, offset, builder, mappings, reverseMappings, + partitionScheme); + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + return sliceOp(bbAargOwner, offset, builder, mappings, reverseMappings, + partitionScheme); + } +} + +void partitionTasks(triton::FuncOp &funcOp) { + + 
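+  // Overall flow: compute a single DataPartitionScheme for the function, then
+  // for each partition index clone/slice every op in the scheme via sliceOp,
+  // and finally erase the superseded ops and dead loop arguments.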
// op -> (partition dim, num of partitions) + DataPartitionScheme partitionScheme; + if (!computePartitionScheme(funcOp, partitionScheme)) + return; + + for (int i = 0; i < partitionScheme.numPartitions; i++) { + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + IRMapping mappings, reverseMappings; + + LLVM_DEBUG({ LDBG("partitioning op for task " << i << ":\n"); }); + + // TODO: compute a topological order for partitionScheme.ops and + // slice in that order. + int numOps = partitionScheme.ops.size(); + for (int j = 0; j < numOps; j++) { + auto op = partitionScheme.ops[j]; + sliceOp(op, i, builder, mappings, reverseMappings, partitionScheme); + } + + // clean up + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (op->hasAttr("to_be_removed")) + opsToDelete.push_back(op); + } + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + } + + // clean up + + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (isa(op)) + continue; + bool notUsed = true; + for (auto result : op->getResults()) { + if (!result.getUsers().empty()) { + notUsed = false; + break; + } + } + if (notUsed) + opsToDelete.push_back(op); + } + + LLVM_DEBUG({ + LDBG("opsToDelete:\n"); + for (auto op : opsToDelete) { + LDBG("op: "); + op->dump(); + } + LDBG("\n"); + }); + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + LLVM_DEBUG({ + LDBG("prior to clean up:"); + funcOp.dump(); + }); + + // delete block arguments + RewritePatternSet cleanUpPatterns(funcOp.getContext()); + populateForOpDeadArgumentElimination(cleanUpPatterns); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns, funcOp.getContext()); + if (applyPatternsAndFoldGreedily(funcOp, std::move(cleanUpPatterns)) + .failed()) { + llvm_unreachable("failed to clean up"); + // signalPassFailure(); + } + + // Make sure original ops are not used + LLVM_DEBUG({ + LDBG("after partition"); + funcOp.dump(); + LDBG("\n"); + }); + fixTaskId(funcOp); +} + +#define GEN_PASS_DEF_TRITONGPUWSDATAPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUWSDataPartitionPass + : public impl::TritonGPUWSDataPartitionBase { +public: + using impl::TritonGPUWSDataPartitionBase< + TritonGPUWSDataPartitionPass>::TritonGPUWSDataPartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + if (numConsumerGroups == 0) + return; + partitionTasks(funcOp); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp new file mode 100644 index 000000000..c2bf31fc5 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -0,0 +1,349 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include + +#include "mlir/IR/OperationSupport.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE 
"tritongpu-warp-spec-lowering" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +enum class LoadType { + LoadAsyncOp, + LoadTMAOp, +}; + +static Value createThreadIdOp(OpBuilder &builder, Location loc) { + Value threadId = builder.create<::mlir::gpu::ThreadIdOp>( + loc, builder.getIndexType(), ::mlir::gpu::Dimension::x); + auto cast = builder.create( + loc, TypeRange{builder.getIntegerType(32)}, ValueRange{threadId}); + return cast.getResult(0); +} + +// Lower to use GetCanonicalWarpIdOp. +// In Hopper, each task is a warpgroup consisting of 4 warps. +static const int WARPS_PER_TASK = 4; +static const int THREADS_PER_TASK = 128; +void lowerGetAsyncTaskIdOp(Operation *parentOp, int numConsumerGroups) { + DenseSet eraseOps; + parentOp->walk([&](ttng::GetAsyncTaskIdOp op) { + auto loc = op.getLoc(); + OpBuilder builder(op); + Value _4 = builder.create(loc, WARPS_PER_TASK, 32); + Value warpId = builder.create(loc); + Value asyncTaskId = builder.create(loc, warpId, _4); + op.getResult().replaceAllUsesWith(asyncTaskId); + + LLVM_DEBUG({ + LDBG("erasing GetAsyncTask"); + op->dump(); + }); + eraseOps.insert(op); + }); + for (Operation *op : eraseOps) + op->erase(); +} + +//===----------------------------------------------------------------------===// +// Lower token operations +//===----------------------------------------------------------------------===// + +LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) { + std::set loadTypes; + createTokenOp->getBlock()->walk([&](Operation *op) { + if (auto asyncCopy = dyn_cast(op)) { + loadTypes.insert(LoadType::LoadAsyncOp); + } else if (auto asyncCopy = + dyn_cast(op)) { + loadTypes.insert(LoadType::LoadTMAOp); + } + }); + assert(loadTypes.size() > 0 && "no async copy in the block"); + assert(loadTypes.size() == 1 && "block contains both async copy and tma"); + return *loadTypes.begin(); +} + +Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op, + bool emptyBarrier) { + auto loc = op->getLoc(); + assert(isa(op) || isa(op)); + Value curPhase; + if (auto acq = dyn_cast(op)) + curPhase = acq.getPhase(); + else if (auto wait = dyn_cast(op)) + curPhase = wait.getPhase(); + if (emptyBarrier) { + // curPhase = curPhase xor True for emptyBarrier. + Value _1_1b = builder.create(loc, 1, 1); + curPhase = builder.create(loc, curPhase, _1_1b); + } + LLVM_DEBUG(curPhase.dump()); + return curPhase; +} + +void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op, + Value bufferEmpty) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, true); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferEmpty, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op, + Value bufferFull, LoadType loadType) { + auto loc = op.getLoc(); + int txCnt = 0; + ttng::MBarrierArriveOp arriveOp; + + if (loadType == LoadType::LoadAsyncOp) { + // Each thread arrives. + Value pred = builder.create(loc, 1, 1); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ true, + txCnt); + } else { + // Only thread 0 arrives for TMA load. 
+ Value _0 = builder.create(loc, 0, 32); + Value threadId = createThreadIdOp(builder, loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, _0); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ false, + txCnt); + } + + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, + Value bufferFull) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, false); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferFull, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, + Value bufferEmpty, int numCTAs) { + auto loc = op.getLoc(); + Value _0 = builder.create(loc, 0, 32); + Value _4 = builder.create(loc, 4, 32); + Value _8 = builder.create(loc, 8, 32); + Value _32 = builder.create(loc, 32, 32); + Value _threadPerTask = + builder.create(loc, THREADS_PER_TASK, 32); + + // threadId = threadId % THREADS_PER_TASK + Value threadId = builder.create( + loc, createThreadIdOp(builder, loc), _threadPerTask); + // k = threadId / 8 + Value k = builder.create(loc, threadId, _8); + // row = k / 4 + Value row = builder.create(loc, k, _4); + // col = k % 4 + Value col = builder.create(loc, k, _4); + // remoteCTAId = (col ^ row) * 4 + col + Value remoteCTAId = builder.create( + loc, + Value{builder.create( + loc, Value{builder.create(loc, col, row)}, _4)}, + col); + + // pred0 = threadId % 8 == 0 + Value pred0 = builder.create( + loc, arith::CmpIPredicate::eq, + builder.create(loc, threadId, _8), _0); + // pred1 = remoteCTAId < numCTAs + Value pred1 = builder.create( + loc, arith::CmpIPredicate::ult, remoteCTAId, + builder.create(loc, numCTAs, 32)); + + // pred = pred0 & pred1 + Value pred = builder.create(loc, pred0, pred1); + // bufferEmpty arrive + auto arriveOp = builder.create(loc, bufferEmpty, pred, + remoteCTAId, false, 0); + + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void lowerTokenOperations(Operation *parentOp, int numCTAs, + int numConsumerGroups) { + SmallVector deprecatedOps; + parentOp->walk([&](ttng::CreateTokenOp createTokenOp) { + LoadType loadType = scanLoadTypes(createTokenOp); + MLIRContext *context = createTokenOp.getContext(); + OpBuilder builder(createTokenOp); + Location loc = createTokenOp.getLoc(); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = + tt::MemDescType::get({createTokenOp.getNum()}, builder.getI64Type(), + barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value bufferFullArray = builder.create( + loc, barrierMemDescType, Value()); + Value bufferEmptyArray = builder.create( + loc, barrierMemDescType, Value()); + + for (unsigned i = 0; i < createTokenOp.getNum(); i++) { + Value idx = 
builder.create(loc, i, 32); + Value barrierFullView = builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + unsigned bufferFullCount = + loadType == LoadType::LoadTMAOp ? 1 : THREADS_PER_TASK; + builder.create(loc, barrierFullView, + bufferFullCount); + + Value barrierEmptyView = builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + unsigned bufferEmptyCount = numCTAs; + builder.create(loc, barrierEmptyView, numCTAs); + } + + if (numCTAs == 1) { + builder.create(loc); + } else { + // Make sure that MBarriers are initialized in all CTAs. + builder.create(loc, false); + builder.create(loc); + } + + // Helper function for extracting one index from bufferFullArray. + auto extractBufferFull = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + }; + + // Helper function for extracting one index from bufferEmptyArray. + auto extractBufferEmpty = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + }; + + // Process token users: ProducerAcquireOp, ProducerCommitOp, ConsumerWaitOp, + // and ConsumerReleaseOp. + for (Operation *user : createTokenOp.getResult().getUsers()) { + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processProducerAcquireOp(builder, op, bufferEmpty); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processProducerCommitOp(builder, op, bufferFull, loadType); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerWaitOp(builder, op, bufferFull); + } else if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerReleaseOp(builder, op, bufferEmpty, numCTAs); + } else { + llvm_unreachable("Unexpected user of token"); + } + deprecatedOps.push_back(user); + } + + deprecatedOps.push_back(createTokenOp); + }); + for (auto op : deprecatedOps) { + op->erase(); + } + + // Insert a cluster barrier before the kernel exits. Without this barrier, + // mbarrier_remote_arrive will fail if the remote CTA already exits. + if (numCTAs > 1) { + parentOp->walk([&](triton::FuncOp funcOp) { + Block *block = &funcOp.getBody().front(); + auto returnOp = llvm::cast(block->getTerminator()); + OpBuilder builder(returnOp); + auto loc = returnOp.getLoc(); + builder.create(loc, false); + builder.create(loc); + }); + } +} + +#define GEN_PASS_DEF_TRITONGPUWSLOWERING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This pass lowers WS-specific operations. +class TritonGPUWSLowering + : public impl::TritonGPUWSLoweringBase { +public: + using impl::TritonGPUWSLoweringBase< + TritonGPUWSLowering>::TritonGPUWSLoweringBase; + + void runOnOperation() override { + // Disable WarpSpec if numConsumerGroups is zero. 
+ if (numConsumerGroups == 0) + return; + ModuleOp mod = getOperation(); + int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + lowerGetAsyncTaskIdOp(mod, numConsumerGroups); + lowerTokenOperations(mod, numCTAs, numConsumerGroups); + + // We assume number of warps per warp group is 4. + // With Warp Spec, the effective warps per CTA is + // number of warp groups * 4, but within each warp group, layout will use + // num_warps of 4, since tensors are not distributed between the groups. + // + // Loads usually happen in one producer warp groups. num_warps of 4 makes + // sense because only the 4 warps from the producer warp group are + // participating in the load. + // + // But at some point (at least when we launch the kernel!) we really do need + // to know that the CTA has 8 or 12 warps in it. Attribute + // "num-warp-groups-per-cta" can be used to calculate the total number of + // warps. + auto builder = OpBuilder::atBlockBegin(mod.getBody()); + mod->setAttr("triton_gpu.num-warp-groups-per-cta", + builder.getI32IntegerAttr(1 + numConsumerGroups)); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 37c69eef8..888e93bb0 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -93,6 +93,19 @@ LogicalResult WarpGroupDotWaitOp::inferReturnTypes( return mlir::success(); } +///--- Async related ops --- +void GetAsyncTaskIdOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state) { + build(builder, state, builder.getI32Type()); +} + +void CreateTokenOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num) { + auto tokenType = TokenType::get(builder.getContext()); + auto resultType = RankedTensorType::get({num}, tokenType); + build(builder, state, resultType, num); +} + static LogicalResult verifyBarrierType(Operation *op, MemDescType barrierType) { if (!barrierType.getElementType().isInteger(64) || barrierType.getShape() != ArrayRef({1})) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 5adebc352..001d96214 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonNvidiaGPUTransforms FenceInsertion.cpp PlanCTA.cpp TMALowering.cpp + Utility.cpp DEPENDS TritonNvidiaGPUTransformsIncGen diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 0938432c7..c1bf9ca8c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include @@ -25,7 +26,7 @@ class TMALoadLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter &baseRewriter) const override { MLIRContext *ctx = op.getContext(); Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); @@ -42,6 +43,7 @@ class TMALoadLowering : public OpRewritePattern { MemDescType memDescType = 
MemDescType::get(tensorType.getShape(), tensorType.getElementType(), encoding, sharedMemorySpace, /*mutableMemory=*/true); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); Value alloc = rewriter.create(loc, memDescType); auto barrierCTALayout = CTALayoutAttr::get( /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, @@ -49,7 +51,7 @@ class TMALoadLowering : public OpRewritePattern { auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); MemDescType barrierMemDescType = - MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding, + MemDescType::get({1}, baseRewriter.getI64Type(), barrierEncoding, sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = rewriter.create(loc, barrierMemDescType); rewriter.create(loc, barrierAlloc, 1); @@ -91,11 +93,17 @@ class TMAStoreLowering MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), encoding, sharedMemorySpace, /*mutableMemory=*/true); - Value alloc = rewriter.create(loc, memDescType, op.getSrc()); - rewriter.create(loc, false); - rewriter.create( + // If op has allocation.copy, the created LocalAlloc will have it. + auto alloc = rewriter.create(loc, memDescType, op.getSrc()); + auto attrs = op->getAttrs(); + alloc->setAttrs(attrs); + auto fence = rewriter.create(loc, false); + fence->setAttrs(attrs); + auto asyncCopy = rewriter.create( loc, op.getDescPtr(), op.getIndices(), alloc); - rewriter.create(loc, 0); + asyncCopy->setAttrs(attrs); + auto tma_wait = rewriter.create(loc, 0); + tma_wait->setAttrs(attrs); rewriter.eraseOp(op); return success(); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..83f21019f --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp @@ -0,0 +1,162 @@ + +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { + +namespace ttg = triton::gpu; + +namespace { + +bool knownSafeToIgnoreRegion(Operation *op) { + return isa(op); +} + +// Assigns `dependentSet` and returns ok if the analysis is successful. +// We do not support dependency analysis across load/store, thus a failure will +// be returned if encountering such cases. 
+LogicalResult getDependentPointers(Value ptr, DenseSet &dependentSet, + DenseSet &processedSet) { + // early return if processed + if (!processedSet.insert(ptr).second) + return success(); + + if (auto blockArg = dyn_cast(ptr)) { + if (!blockArg.getOwner()->isEntryBlock()) + return failure(); + auto parentOp = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(), + dependentSet, processedSet))) + return failure(); + + unsigned operandIdx = + blockArg.getArgNumber() - forOp.getNumInductionVars(); + return getDependentPointers( + forOp.getBody()->getTerminator()->getOperand(operandIdx), + dependentSet, processedSet); + } + } else if (auto funcOp = dyn_cast(parentOp)) { + dependentSet.insert(ptr); + return success(); + } + // unknown ops, return failure for correctness. + return failure(); + } + + auto definingOp = ptr.getDefiningOp(); + assert(definingOp); + if (auto makeTensorPtrOp = ptr.getDefiningOp()) { + return getDependentPointers(makeTensorPtrOp.getBase(), dependentSet, + processedSet); + } else if (auto advanceOp = ptr.getDefiningOp()) { + return getDependentPointers(advanceOp.getPtr(), dependentSet, processedSet); + } else if (auto addPtrOp = ptr.getDefiningOp()) { + return getDependentPointers(addPtrOp.getPtr(), dependentSet, processedSet); + } else if (auto forOp = ptr.getDefiningOp()) { + unsigned idx = cast(ptr).getResultNumber(); + return getDependentPointers( + forOp.getBody()->getTerminator()->getOperand(idx), dependentSet, + processedSet); + } else if (auto ifOp = ptr.getDefiningOp()) { + unsigned idx = cast(ptr).getResultNumber(); + if (ifOp.elseBlock() && + failed(getDependentPointers(ifOp.elseYield()->getOperand(idx), + dependentSet, processedSet))) + return failure(); + return getDependentPointers(ifOp.thenYield()->getOperand(idx), dependentSet, + processedSet); + } else if (!definingOp->getNumRegions() || + knownSafeToIgnoreRegion(definingOp)) { + for (Value operand : definingOp->getOperands()) + if (failed(getDependentPointers(operand, dependentSet, processedSet))) + return failure(); + return success(); + } + // unknown ops, return failure for correctness. 
+ return failure(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Helper functions for async task +//===----------------------------------------------------------------------===// + +SmallVector getAsyncTaskIds(Operation *op) { + SmallVector asyncTaskIds; + if (auto attr = op->getAttrOfType("async_task_id")) + for (AsyncTaskId asyncTaskId : attr.getValues()) + asyncTaskIds.push_back(asyncTaskId); + return asyncTaskIds; +} + +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + for (AsyncTaskId candidate : getAsyncTaskIds(op)) + if (candidate == asyncTaskId) + return true; + return false; +} + +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { + SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", DenseIntElementsAttr::get(vecTy, sortedAsyncTaskIds)); +} + +SmallVector getNestedAsyncTaskIds(Operation *op) { + SetVector asyncTaskIds; + op->walk([&](Operation *curOp) { + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(curOp)) + asyncTaskIds.insert(asyncTaskId); + }); + SmallVector res(asyncTaskIds.begin(), asyncTaskIds.end()); + llvm::sort(res); + return res; +} + +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks) { + auto asyncTasksVec = getAsyncTaskIds(op); + DenseSet asyncTasksSet(asyncTasksVec.begin(), asyncTasksVec.end()); + for (int a : asyncTasks) { + if (!asyncTasksSet.contains(a)) { + asyncTasksVec.push_back(a); + } + } + if (asyncTasksVec.size() > 0) { + setAsyncTaskIds(op, asyncTasksVec); + } +} + +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + auto origAsyncTaskIds = getAsyncTaskIds(op); + auto end = std::remove(origAsyncTaskIds.begin(), origAsyncTaskIds.end(), asyncTaskId); + origAsyncTaskIds.erase(end, origAsyncTaskIds.end()); + if (origAsyncTaskIds.empty()) + op->removeAttr("async_task_id"); + else + setAsyncTaskIds(op, origAsyncTaskIds); +} + +void removeAsyncTaskIds(Operation *op) { + op->removeAttr("async_task_id"); +} +//===----------------------------------------------------------------------===// +// Implementations for general auto WS +//===----------------------------------------------------------------------===// + + +} // namespace mlir diff --git a/python/src/ir.cc b/python/src/ir.cc index 95e48a692..fcd1b623b 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -37,6 +37,15 @@ namespace py = pybind11; using namespace mlir; using namespace triton; +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { + SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", DenseIntElementsAttr::get(vecTy, sortedAsyncTaskIds)); +} + // A custom op builder that keeps track of the last location class TritonOpBuilder { public: @@ -95,7 +104,10 @@ class TritonOpBuilder { template OpTy create(Args &&...args) { auto loc = getLastLoc(); - return builder->create(loc, std::forward(args)...); + auto ret = builder->create(loc, std::forward(args)...); + if (asyncTaskIds) + ::setAsyncTaskIds(ret, *asyncTaskIds); + return ret; } // Overload to create or fold a single result operation. 
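For reference, the per-builder asyncTaskIds state consulted in create() above is driven by the tl.async_task context manager introduced later in this patch; a minimal usage sketch (kernel name and pointer arguments are placeholders, not part of the patch):

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr, dst_ptr):
        offs = tl.arange(0, 128)
        with tl.async_task([0]):              # producer task
            x = tl.load(src_ptr + offs)       # ops built here get async_task_id = [0]
        with tl.async_task([1]):              # consumer task
            tl.store(dst_ptr + offs, x)       # ops built here get async_task_id = [1]

Entering the context calls builder.set_async_task_ids([...]) and exiting calls unset_async_task_ids(), so every op created while the body of the with-block is visited is tagged with the given task ids.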
@@ -114,9 +126,16 @@ class TritonOpBuilder { return builder->createOrFold(loc, std::forward(args)...); } + void setAsyncTaskIds(std::vector taskIds) { this->asyncTaskIds = taskIds; } + + void unsetAsyncTaskIds() { + this->asyncTaskIds = std::nullopt; + } + private: std::unique_ptr builder; std::unique_ptr lastLoc; + std::optional> asyncTaskIds; bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); }; @@ -368,6 +387,7 @@ void init_triton_ir(py::module &&m) { py::class_(m, "attribute", py::module_local()); py::class_(m, "integer_attr", py::module_local()); py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "string_attr", py::module_local()); // Ops py::class_(m, "OpState", py::module_local()) @@ -631,6 +651,12 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { self.restoreInsertionPoint(pt); }) + .def("set_async_task_ids", + [](TritonOpBuilder &self, std::vector v) { + self.setAsyncTaskIds(v); + }) + .def("unset_async_task_ids", + [](TritonOpBuilder &self) { self.unsetAsyncTaskIds(); }) // Attr .def("get_bool_attr", [](TritonOpBuilder &self, bool value) { @@ -640,6 +666,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, int32_t value) { return self.getBuilder().getI32IntegerAttr(value); }) + .def("get_string_attr", + [](TritonOpBuilder &self, const std::string &value) { + return self.getBuilder().getStringAttr(value); + }) // Use arith.ConstantOp to create constants // Constants .def("get_int1", diff --git a/python/src/passes.cc b/python/src/passes.cc index 98d8369d4..026d2f7c7 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -68,6 +68,17 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCombineTensorSelectAndIf); ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_OPTION_WRAPPER_1("add_ws_data_partition", + createTritonGPUWSDataPartition, int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_lowering", createTritonGPUWSLowering, int); + ADD_PASS_OPTION_WRAPPER_1("add_taskid_propagate", + createTritonGPUTaskIdPropagate, int); + ADD_PASS_OPTION_WRAPPER_4("add_ws_code_partition", + createTritonGPUWSCodePartition, int, int, int, int); + ADD_PASS_OPTION_WRAPPER_2("add_ping_pong_sync", createTritonGPUPingPongSync, + int, int); + ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling", + createTritonGPULoopScheduling, int); } void init_triton_passes_convert(py::module &&m) { diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 19d09de85..517a2fc4d 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -802,6 +802,20 @@ def visit_UnaryOp(self, node): ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' } + def visit_withitem(self, node): + return self.visit(node.context_expr) + + def visit_With(self, node): + assert len(node.items) == 1 + context = node.items[0].context_expr + withitemClass = self.visit(context.func) + if withitemClass == language.async_task: + args = [self.visit(arg) for arg in context.args] + with withitemClass(*args, _builder=self.builder): + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.body) + def visit_While(self, node): with enter_sub_region(self) as sr: liveins, insert_block = sr @@ -904,6 +918,7 @@ def visit_For(self, node): ast.NodeVisitor.generic_visit(self, stmt) return num_stages = None + loop_schedule = None if IteratorClass is 
language.range: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -913,6 +928,7 @@ def visit_For(self, node): ub = iterator.end step = iterator.step num_stages = iterator.num_stages + loop_schedule = iterator.loop_schedule elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -986,6 +1002,8 @@ def visit_For(self, node): for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) if num_stages is not None: for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if loop_schedule is not None: + for_op.set_attr("tt.loop_schedule", self.builder.get_string_attr(loop_schedule.value)) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0a84bd86a..d18701be2 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -32,6 +32,7 @@ arange, associative_scan, assume, + async_task, atomic_add, atomic_and, atomic_cas, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e16ca2dee..38cd9d899 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2505,6 +2505,21 @@ def __next__(self): raise RuntimeError("static_range can only be used in @triton.jit'd functions") +class async_task: + """ + Context manager to run code fragments asynchronously. + """ + def __init__(self, task_ids, _builder=None): + self.task_ids = task_ids + self.builder = _builder + + def __enter__(self): + self.builder.set_async_task_ids(self.task_ids) + + def __exit__(self, exc_type, exc_value, traceback): + self.builder.unset_async_task_ids() + + class range: """ Iterator that counts upward forever. @@ -2514,7 +2529,7 @@ class range: @triton.jit def kernel(...): - for i in tl.range(10, num_stages=3): + for i in tl.range(10, num_stages=3, loop_schedule="Default"): ... :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. @@ -2528,9 +2543,10 @@ def kernel(...): kernel argument. The kernel argument only pipelines loads that feed into :code:`dot` operations, while this attribute tries to pipeline most (though not all) loads in this loop. + :param loop_schedule: specify a scheduling policy for the loop. """ - def __init__(self, arg1, arg2=None, step=None, num_stages=None): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_schedule=None): if step is None: self.step = constexpr(1) else: @@ -2542,6 +2558,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None): self.start = arg1 self.end = arg2 self.num_stages = num_stages + self.loop_schedule = loop_schedule def __iter__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 59191a31b..c2d16b820 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -35,7 +35,7 @@ def __init__( 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
""" if not configs: - self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0, reg_inc_consumer=0)] else: self.configs = configs self.key_idx = [arg_names.index(k) for k in key] @@ -153,6 +153,17 @@ def run(self, *args, **kwargs): timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} bench_end = time.time() self.bench_time = bench_end - bench_start + + # __FACEBOOK__ (facebook) begin T203283446 + if os.getenv("TRITON_PRINT_AUTOTUNING_ALL", None) == "1": + print( + f'\nPrinting ALL Multiple Triton autotuning Configs with timings in sorted order for kernel {self.fn}:' + ) + sorted_configs = builtins.sorted(timings, key=timings.get) + for config in sorted_configs: + print(f'Triton autotune config: [{config}]; Triton autotune timing: {timings[config]}') + # __FACEBOOK__ (facebook) end T203283446 + self.cache[key] = builtins.min(timings, key=timings.get) self.pre_hook(args, reset_only=True) self.configs_timings = timings @@ -227,11 +238,15 @@ class Config: function are args. """ - def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer self.maxnreg = maxnreg self.pre_hook = pre_hook @@ -243,6 +258,10 @@ def all_kwargs(self): ("num_warps", self.num_warps), ("num_ctas", self.num_ctas), ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), ) if v is not None } @@ -255,6 +274,10 @@ def __str__(self): res.append(f"num_warps: {self.num_warps}") res.append(f"num_ctas: {self.num_ctas}") res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") return ", ".join(res) diff --git a/python/tutorials/10-warp-specialized-matmul.py b/python/tutorials/10-warp-specialized-matmul.py new file mode 100644 index 000000000..ed51de580 --- /dev/null +++ b/python/tutorials/10-warp-specialized-matmul.py @@ -0,0 +1,319 @@ +import os +import sys + +import torch +import triton +import triton.language as tl + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print( + "TMA benchmarks will be running with experimental grid constant TMA descriptor.", + ) +else: + print( + "TMA benchmarks will be running without grid constant TMA descriptor.", + ) + + +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + 
return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_1d_tma_descriptor + ) + self.fill_2d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_2d_tma_descriptor + ) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 + ) + else: + self.cuda_descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 + ) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def fill_2d_tma_descriptor( + self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size + ): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + num_consumer_groups=2, + num_buffers_warp_spec=3, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_persistent_tma_ws_cooperative_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + with tl.async_task([0]): + a = tl._experimental_descriptor_load( + a_ptr, + [offs_am, offs_k], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + tl.float16, + ) + b = tl._experimental_descriptor_load( + b_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16 + ) + + accumulator += tl.dot(a, b) + offs_k += BLOCK_SIZE_K + + c = accumulator.to(tl.float16) + + with tl.async_task([1, 2]): + tl._experimental_descriptor_store(c_ptr, c, [offs_am, offs_bn]) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul_persistent_tma_ws_cooperative(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("a") + desc_helper.init_tma_descriptor("b") + desc_helper.init_tma_descriptor("c") + + def grid(META): + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "a", + a.data_ptr(), + M, + K, + META["BLOCK_SIZE_M"] // 2, + META["BLOCK_SIZE_K"], + a.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "b", + b.data_ptr(), + K, + N, + META["BLOCK_SIZE_K"], + META["BLOCK_SIZE_N"], + b.element_size(), + ) + desc_helper.fill_2d_tma_descriptor( + "c", + c.data_ptr(), + M, + N, + META["BLOCK_SIZE_M"] // 2, + META["BLOCK_SIZE_N"], + c.element_size(), + ) + return ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + + desc_a = desc_helper.get_tma_descriptor_kernel_param("a") + desc_b = desc_helper.get_tma_descriptor_kernel_param("b") + desc_c = desc_helper.get_tma_descriptor_kernel_param("c") + matmul_persistent_tma_ws_cooperative_kernel[grid]( + desc_a, + desc_b, + desc_c, # + M, + N, + K, # + ) + return c + + +def aten_matmul(a, b): + return a.mm(b) + + +test_impls = [ + aten_matmul, + matmul_persistent_tma_ws_cooperative, +] + + +impl_map = {fn.__name__: fn for fn in test_impls} + + +def test(): + torch.manual_seed(0) + m = 4 * 11 * 64 + n = 12 * 256 + k = 64 * 4 + a = torch.randn((m, k), device="cuda", dtype=torch.float16) + b = torch.randn((k, n), device="cuda", dtype=torch.float16) + torch_output = torch.matmul(a, b) + rtol = 0 + for fn in test_impls: + triton_output = fn(a, b) + torch.cuda.synchronize() + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print(f" Torch matches {fn.__name__}") + else: + print(f" Torch DOES NOT match {fn.__name__}") + print("torch output:") + print(torch_output) + print("triton output:") + print(triton_output) + + +x_vals = [(8192, 8192, i) for i in range(256, 8192 + 1, 256)] +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + 
x_vals=x_vals, + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. + line_vals=[fn.__name__ for fn in test_impls], + line_names=[fn.__name__ for fn in test_impls], + # styles=[("red", "-"), ("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance-" + + ( + "fp16" + ), # Name for the plot, used also as a file name for saving the plot. + args={}, + ) +) + +@triton.testing.perf_report(configs) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + fn = impl_map[provider] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(lambda: fn(a, b), quantiles=quantiles) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +test() +benchmark.run(show_plots=True, print_data=True) diff --git a/python/tutorials/mm.py b/python/tutorials/mm.py new file mode 100644 index 000000000..2931fdf0b --- /dev/null +++ b/python/tutorials/mm.py @@ -0,0 +1,201 @@ +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + # fmt: off + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=4, num_buffers_warp_spec=3, num_consumer_groups=1), + # fmt: on + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel_ws( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + with tl.async_task([0]): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + with tl.async_task([1]): + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.autotune( + configs=[ + # fmt: off + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, 
num_stages=3, num_warps=4), + # fmt: on + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b, ws=True): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + kernel = matmul_kernel_ws if ws else matmul_kernel + kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ) + return c + + +def test(): + m, n, k = 8192, 8192, 8192 + a = torch.randn((m, k), device="cuda", dtype=torch.float16) + b = torch.randn((k, n), device="cuda", dtype=torch.float16) + triton_output = matmul(a, b, ws=True) + torch_output = torch.matmul(a, b) + + print("triton:", triton_output) + print(" torch:", torch_output) + + torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0.0) + + +@triton.testing.perf_report( + [ + triton.testing.Benchmark( + x_names=["M", "N", "K"], + x_vals=[128 * i for i in range(28, 33)], + line_arg="provider", + line_vals=["cublas", "triton-warpspec", "triton-multistage"], + line_names=["cuBLAS", "Triton:WarpSpec", "Triton:MultiStage"], + ylabel="TFLOPS", + plot_name="matmul-performance-fp16", + args={}, + ) + ] +) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + if provider == "cublas": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.matmul(a, b), quantiles=quantiles + ) + if "triton" in provider: + ws = "warpspec" in provider + ms, min_ms, max_ms = 
triton.testing.do_bench( + lambda: matmul(a, b, ws=ws), quantiles=quantiles + ) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +test() +benchmark.run(show_plots=True, print_data=True) diff --git a/test/TritonGPU/comp-pipeline.mlir b/test/TritonGPU/comp-pipeline.mlir new file mode 100644 index 000000000..492b1d508 --- /dev/null +++ b/test/TritonGPU/comp-pipeline.mlir @@ -0,0 +1,102 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=4 -debug-only=triton-matmul-loop-pipeline 2>&1 | FileCheck %s --check-prefix=DEBUG +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=4 | FileCheck %s + +// DEBUG: Final coarse schedule: +// DEBUG: Ops in stage 2 +// DEBUG-DAG: triton_nvidia_gpu.wait_barrier +// DEBUG-DAG: triton_nvidia_gpu.warp_group_dot +// DEBUG: Ops in stage 3 +// DEBUG: triton_nvidia_gpu.wait_barrier +// DEBUG: Original loop: + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @_attn_fwd_tma(%arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: f32, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32, %arg11: i64, %arg14: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %25 = tt.experimental_descriptor_load %arg3[%arg9, %c0_i32] : !tt.ptr -> tensor<128x128xf16, #blocked1> + %26 = triton_gpu.local_alloc %25 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %27 = arith.extsi %arg14 : i32 to i64 + %28 = tt.splat %arg6 : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %29 = tt.splat %arg6 : f32 -> tensor<128x128xf32, #mma> + %30 = arith.extsi %arg17 : i32 to i64 + // CHECK: tt.experimental_descriptor_load + // CHECK: %[[QLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<128x128xf16 + // CHECK: %[[KLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16 + // CHECK: %[[VLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16 + // CHECK: %[[KBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64 + // CHECK: %[[VBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64 + // stage 0 iteration 0 + // CHECK: %[[K0:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c0_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K0]] + // stage 0 iteration 1 + // CHECK: %[[K1:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c1_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K1]] + // stage 1 iteration 0 + // 
CHECK: %[[V0:.+]] = triton_gpu.memdesc_subview %[[VLOC]][%c0_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[V0]] + // stage 2 iteration 0 + // CHECK: %[[FIRSTDOT:.+]] = triton_nvidia_gpu.warp_group_dot + // stage 0 iteration 2 + // CHECK: %[[K2:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c2_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K2]] + // stage 1 iteration 1 + // CHECK: %[[V1:.+]] = triton_gpu.memdesc_subview %[[VLOC]][%c1_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[V1]] + // CHECK: scf.for {{.*}} %[[ARG:.+]] = %[[FIRSTDOT]] + // CHECK: %[[KBARSUB:.+]] = triton_gpu.memdesc_subview %[[KBAR]][%[[KBARIDX:.+]]] + // CHECK: scf.if + // CHECK: triton_nvidia_gpu.wait_barrier %[[KBARSUB]] + // CHECK: %[[KLOOP:.+]] = triton_gpu.memdesc_subview %[[KLOC]] + // CHECK: tt.trans %[[KLOOP]] + // CHECK: %[[FIRSTDOTLOOP:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK: %[[WAIT:.+]]:{{[0-9]+}} = triton_nvidia_gpu.warp_group_dot_wait + // CHECK: "tt.reduce"(%[[ARG]]) + // CHECK: %[[VBARSUB:.+]] = triton_gpu.memdesc_subview %[[VBAR]] + // CHECK: triton_nvidia_gpu.wait_barrier %[[VBARSUB]] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK: scf.yield {{.*}}%[[WAIT]]#0 + // arg26 is acc + %31:1 = scf.for %arg24 = %c0_i32 to %arg23 step %c128_i32 iter_args(%arg26 = %cst_2) -> (tensor<128x128xf32, #mma>) : i32 { + %48 = arith.divsi %arg11, %27 : i64 + %49 = arith.trunci %48 : i64 to i32 + %50 = arith.addi %arg24, %49 : i32 + // loads in different stages + %51 = tt.experimental_descriptor_load %arg4[%50, %c0_i32] {loop.stage = 0 : i32, loop.cluster = 1 : i32} : !tt.ptr -> tensor<128x128xf16, #blocked1> + %52 = triton_gpu.local_alloc %51 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %53 = tt.trans %52 {order = array} : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory> + %54 = triton_nvidia_gpu.warp_group_dot %26, %53, %cst_2 {inputPrecision = 0 : i32, loop.stage = 2 : i32, loop.cluster = 0 : i32} : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> + %55 = "tt.reduce"(%54) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32 loc(unknown), %arg29: f32 loc(unknown)): + %80 = arith.maxnumf %arg28, %arg29 : f32 + tt.reduce.return %80 : f32 + }) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %56 = arith.mulf %55, %28 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %58 = arith.mulf %54, %29 : tensor<128x128xf32, #mma> + %59 = tt.expand_dims %56 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %60 = tt.broadcast %59 : tensor<128x1xf32, #mma> -> tensor<128x128xf32, #mma> + %61 = arith.subf %58, %60 : tensor<128x128xf32, #mma> + %62 = math.exp2 %61 : tensor<128x128xf32, #mma> + %71 = arith.divsi %arg11, %30 : i64 + %72 = arith.extsi %arg24 : i32 to i64 + %73 = arith.addi %71, %72 : i64 + %74 = arith.trunci %73 : i64 to i32 + %75 = tt.experimental_descriptor_load %arg5[%74, %c0_i32] {loop.stage = 1 : i32, loop.cluster = 1 : i32} : !tt.ptr -> tensor<128x128xf16, #blocked1> + %76 = triton_gpu.local_alloc %75 : (tensor<128x128xf16, #blocked1>) -> 
!tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %77 = arith.truncf %62 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %78 = triton_gpu.convert_layout %77 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %79 = triton_nvidia_gpu.warp_group_dot %78, %76, %arg26 {inputPrecision = 0 : i32, loop.stage = 3 : i32, loop.cluster = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> + scf.yield %79 : tensor<128x128xf32, #mma> + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + %42 = arith.truncf %31#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %43 = triton_gpu.convert_layout %42 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1> + tt.experimental_descriptor_store %arg8[%arg10, %c0_i32], %43 : !tt.ptr, tensor<128x128xf16, #blocked1> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir b/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir new file mode 100644 index 000000000..1cca80d21 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir @@ -0,0 +1,63 @@ +// RUN: triton-opt %s -split-input-file --triton-gpu-taskid-propagate=num-consumer-groups=1 | FileCheck %s + +// CHECK-LABEL: @async_kernel +// CHECK: %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 +// CHECK: %5 = tt.splat %arg2 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1024xi32> +// CHECK: %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> +// CHECK: %10 = tt.splat %arg1 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + +module { + tt.func public @async_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-LABEL: @two_consumers +// CHECK: tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 +// CHECK: tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.splat %arg1 {async_task_id = dense<[1, 2]> : vector<2xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<1> : vector<1xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<2> : vector<1xi32>} + +module { + tt.func public 
@two_consumers(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.make_range {end = 2048 : i32, start = 1024 : i32} : tensor<1024xi32> + %4 = tt.splat %1 : i32 -> tensor<1024xi32> + %5 = arith.addi %4, %2 : tensor<1024xi32> + %6 = arith.addi %4, %3 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.addptr %7, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %11 = tt.load %9 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %5 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.addptr %12, %6 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %13, %10 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.store %14, %11 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir new file mode 100644 index 000000000..0461ce39b --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir @@ -0,0 +1,306 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-code-partition=num-buffers=1 | FileCheck %s + +// CHECK-LABEL: @matmul_kernel_one_consumer +// CHECK: %[[#TASKID:]] = triton_nvidia_gpu.get_async_task_id : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: tt.dot +// CHECK: triton_nvidia_gpu.consumer_release + + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_one_consumer(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility 
= 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked1> + %cst_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked2> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked2> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %16 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.splat %14 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %20 = arith.addi %18, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %21 = arith.addi %19, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.splat %arg3 {async_task_id = 
dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %23 = arith.remsi %20, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %24 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %25 = tt.splat %24 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %26 = arith.addi %25, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %27 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %28 = arith.remsi %26, %27 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %29 = tt.expand_dims %23 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %30 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked2> + %31 = arith.muli %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + %32 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = tt.expand_dims %32 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %34 = tt.broadcast %31 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %35 = tt.broadcast %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %36 = arith.addi %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked2> + %37 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked2> + %38 = tt.addptr %37, %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %39 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %40 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %41 = tt.expand_dims %39 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %42 = tt.expand_dims %40 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %43 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %44 = arith.muli %41, %43 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %45 = tt.expand_dims %28 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %46 = tt.broadcast %44 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %47 = tt.broadcast %45 {async_task_id = 
dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %48 = arith.addi %46, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked1> + %49 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %50 = tt.addptr %49, %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %51 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %52 = arith.divsi %51, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %53 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %54 = tt.splat %53 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked1> + %55:3 = scf.for %arg9 = %c0_i32 to %52 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %50) -> (tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1>) : i32 { + %74 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.subi %arg5, %74 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked2> + %77 = arith.cmpi slt, %33, %76 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> + %78 = tt.broadcast %77 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %79 = tt.load %arg11, %78, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2> + %80 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %81 = arith.cmpi slt, %42, %80 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %82 = tt.broadcast %81 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked1> -> tensor<256x128xi1, #blocked1> + %83 = tt.load %arg12, %82, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1> + %84 = triton_gpu.convert_layout %79 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked2> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %85 = triton_gpu.convert_layout %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %87 = tt.addptr %arg11, %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %88 = tt.addptr %arg12, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1]> : vector<2xi32>} %86, %87, %88 : tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1]> : vector<2xi32>} + %56 = arith.truncf %55#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + %57 = tt.expand_dims %21 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, 
parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %58 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %59 = arith.muli %58, %57 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %60 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %62 = tt.expand_dims %26 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %63 = tt.broadcast %61 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x128x!tt.ptr, #blocked1> + %64 = tt.broadcast %62 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<128x128xi32, #blocked1> + %65 = tt.addptr %63, %64 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> + %66 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %67 = arith.cmpi slt, %57, %66 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %68 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %69 = arith.cmpi slt, %62, %68 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> + %70 = tt.broadcast %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %71 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %72 = arith.andi %70, %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked1> + %73 = triton_gpu.convert_layout %56 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked1> + tt.store %65, %73, %72 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + + +// CHECK-LABEL: @matmul_kernel_two_consumers +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_nvidia_gpu.consumer_release + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 
0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_two_consumers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<64x64xi32, #blocked> + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x128xf16, #blocked1> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %cst_2 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.000000e+00> : tensor<64x128xf32, #mma> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.splat %14 {async_task_id = dense<0> : 
vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %18 = tt.splat %14 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = arith.addi %17, %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %20 = arith.addi %18, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = arith.remsi %19, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %24 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = arith.addi %17, %23 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %18, %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.remsi %25, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %29 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.splat %28 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %31 = arith.addi %30, %29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %32 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %33 = arith.remsi %31, %32 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %34 = tt.expand_dims %22 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %35 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %36 = arith.muli %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %37 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %38 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %39 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %40 = tt.broadcast %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> + %41 = arith.addi %39, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %42 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> 
tensor<64x64x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %44 = tt.expand_dims %27 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %45 = arith.muli %44, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %46 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %47 = arith.addi %46, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %48 = tt.addptr %42, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %49 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %50 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %51 = arith.muli %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %52 = tt.expand_dims %33 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %53 = tt.broadcast %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %54 = tt.broadcast %52 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %55 = arith.addi %53, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128xi32, #blocked1> + %56 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %58 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %59 = arith.divsi %58, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %60 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %61 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %62 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %63 = tt.splat %62 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x128xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %true_3 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false_4 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %64:5 = scf.for %arg9 = %c0_i32 to %59 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %cst_2, %arg12 = %43, %arg13 = %57, %arg14 = %48) -> (tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked>) : i32 { + %93 = arith.muli %arg9, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %94 = arith.subi %arg5, %93 {async_task_id = dense<0> : vector<1xi32>} : i32 + %95 = tt.splat 
%94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %96 = arith.cmpi slt, %60, %95 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %97 = tt.broadcast %96 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked> + %98 = tt.load %arg12, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %99 = triton_gpu.local_alloc %98 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> + %100 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %101 = arith.cmpi slt, %61, %100 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %102 = tt.broadcast %101 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %103 = tt.load %arg13, %102, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %104 = triton_gpu.local_alloc %103 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xf16, #blocked1>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> + %105 = tt.load %arg14, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %106 = triton_gpu.local_alloc %105 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> + %107 = triton_nvidia_gpu.warp_group_dot %99, %104, %arg10 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma> + %108 = triton_nvidia_gpu.warp_group_dot %106, %104, %arg11 {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma> + %109 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %110 = tt.addptr %arg14, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %111 = tt.addptr %arg13, %63 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %107, %108, %109, %111, %110 : tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %65 = arith.truncf %64#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %66 = arith.truncf %64#1 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %67 = tt.expand_dims %20 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %68 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %69 = arith.muli %68, %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %70 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, 
#blocked1> + %71 = tt.addptr %70, %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %72 = tt.expand_dims %31 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %74 = tt.broadcast %72 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %75 = tt.addptr %73, %74 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %76 = tt.expand_dims %26 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %77 = arith.muli %68, %76 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %78 = tt.addptr %70, %77 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %79 = tt.broadcast %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %80 = tt.addptr %79, %74 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %81 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %82 = arith.cmpi slt, %67, %81 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %83 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %84 = arith.cmpi slt, %72, %83 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> + %85 = tt.broadcast %82 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %86 = tt.broadcast %84 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %87 = arith.andi %85, %86 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %88 = arith.cmpi slt, %76, %81 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %89 = tt.broadcast %88 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %90 = arith.andi %89, %86 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %91 = triton_gpu.convert_layout %65 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %75, %91, %87 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %92 = triton_gpu.convert_layout %66 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %80, %92, %90 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir new file mode 100644 index 000000000..3816f5bc4 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir @@ -0,0 +1,136 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-data-partition=num-consumer-groups=2 | FileCheck %s + +// CHECK-LABEL: @matmul_persistent_ws_cooperative_kernel 
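+//
+// Note (descriptive only, not additional CHECK directives): with
+// num-consumer-groups=2 the checks below expect the 128x64 A tile to be
+// partitioned along M into two 64x64 loads, one per consumer warp group,
+// while the single 64x256 B tile is loaded and local_alloc'ed once and
+// shared by both warp_group_dot ops; each consumer then stores its own
+// 64x256 slice of the result.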
+// CHECK: %[[#GA1:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#GA2:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#LA1:]] = triton_gpu.local_alloc %[[#GA1]] +// CHECK: %[[#LA2:]] = triton_gpu.local_alloc %[[#GA2]] +// CHECK: %[[#GB:]] = tt.load {{.*}} : tensor<64x256x!tt.ptr +// CHECK: %[[#LB:]] = triton_gpu.local_alloc %[[#GB]] +// CHECK: %[[#C1:]] = triton_nvidia_gpu.warp_group_dot %[[#LA1]], %[[#LB]], {{.*}} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf32, #mma> +// CHECK: %[[#C2:]] = triton_nvidia_gpu.warp_group_dot %[[#LA2]], %[[#LB]], {{.*}} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf32, #mma> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> + + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_persistent_ws_cooperative_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<128x64xi32, #blocked> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x256xf16, #blocked1> + %cst_2 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg4, %c255_i32 
{async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.muli %1, %3 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = tt.get_num_programs x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %3, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %9 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %11 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %12 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %13 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %14 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %15 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %16 = tt.broadcast %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %17 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %18 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %20 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %21 = arith.muli %19, %20 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %22 = tt.broadcast %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %23 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %24 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %25 = arith.divsi %24, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %26 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %27 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %28 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %29 = tt.splat %28 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x256xi32, #blocked1> + %30 = tt.splat %arg8 {async_task_id = dense<[1, 2]> 
: vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %31 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %32 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x256xi32, #blocked1> + scf.for %arg9 = %5 to %4 step %6 : i32 { + %34 = arith.divsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %35 = arith.muli %34, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %36 = arith.subi %1, %35 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %37 = arith.minsi %36, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %38 = arith.remsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %39 = arith.remsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %40 = arith.addi %35, %39 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %41 = arith.divsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %42 = arith.muli %40, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %43 = tt.splat %42 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %44 = tt.splat %42 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.addi %43, %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %46 = arith.addi %44, %9 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %47 = arith.remsi %45, %10 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %48 = arith.muli %41, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %49 = tt.splat %48 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %50 = arith.addi %49, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %51 = arith.remsi %50, %12 {async_task_id = dense<0> : vector<1xi32>} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %52 = tt.expand_dims %47 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %53 = arith.muli %52, %13 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %54 = tt.broadcast %53 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> + %55 = arith.addi %54, %16 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64xi32, #blocked> + %56 = tt.addptr %17, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %57 = tt.expand_dims %51 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %58 = tt.broadcast %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %59 = arith.addi %22, %58 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256xi32, #blocked1> + %60 = tt.addptr %23, %59 
{async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %61:3 = scf.for %arg10 = %c0_i32 to %25 step %c1_i32 iter_args(%arg11 = %cst_2, %arg12 = %56, %arg13 = %60) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { + %76 = arith.muli %arg10, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.subi %arg5, %76 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %79 = arith.cmpi slt, %26, %78 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %80 = tt.broadcast %79 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> + %81 = tt.load %arg12, %80, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked> + %82 = triton_gpu.local_alloc %81 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %83 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %84 = arith.cmpi slt, %27, %83 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %85 = tt.broadcast %84 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %86 = tt.load %arg13, %85, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1> + %87 = triton_gpu.local_alloc %86 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %88 = triton_nvidia_gpu.warp_group_dot %82, %87, %arg11 {async_task_id = dense<[1, 2]> : vector<2xi32>, inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %89 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %90 = tt.addptr %arg13, %29 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %88, %89, %90 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %62 = arith.truncf %61#0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %63 = tt.expand_dims %46 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %64 = arith.muli %30, %63 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %65 = tt.addptr %31, %64 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %66 = tt.expand_dims %50 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %67 = tt.broadcast %65 {async_task_id = dense<[1, 2]> : vector<2xi32>} : 
tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %68 = tt.broadcast %66 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %69 = tt.addptr %67, %68 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %70 = arith.cmpi slt, %63, %32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %71 = arith.cmpi slt, %66, %33 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> + %72 = tt.broadcast %70 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %74 = arith.andi %72, %73 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xi1, #blocked1> + %75 = triton_gpu.convert_layout %62 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.store %69, %75, %74 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir new file mode 100644 index 000000000..de69a59b8 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir @@ -0,0 +1,237 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-lowering=num-consumer-groups=1 | FileCheck %s + +// CHECK: %[[#PBARRIER:]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64 +// CHECK: %[[#CBARRIER:]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64 +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]][%c0_i32] +// CHECK: triton_nvidia_gpu.init_barrier %[[#]], 128 +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]][%c0_i32] +// CHECK: triton_nvidia_gpu.init_barrier %[[#]], 1 +// CHECK: scf.for +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]] +// CHECK: triton_nvidia_gpu.wait_barrier %[[#]] +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]] +// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[#]] +// CHECK: scf.for +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]] +// CHECK: triton_nvidia_gpu.wait_barrier %[[#]] +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: tt.dot +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]] +// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[#]] + + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : 
i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %2 = triton_nvidia_gpu.create_token {num = 1 : i32} : tensor<1x!triton_nvidia_gpu.token> + %3 = triton_nvidia_gpu.get_async_task_id : i32 + %c0_i32 = arith.constant 0 : i32 + %4 = arith.cmpi eq, %3, %c0_i32 : i32 + scf.if %4 { + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %23 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = arith.addi %23, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : 
tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %26 = arith.remsi %24, %25 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %28 = tt.splat %27 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = arith.remsi %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %32 = tt.expand_dims %26 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %34 = arith.muli %32, %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %35 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %37 = tt.broadcast %34 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %38 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %39 = arith.addi %37, %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked1> + %40 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked1> + %41 = tt.addptr %40, %39 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %42 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %43 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %44 = tt.expand_dims %42 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %45 = tt.expand_dims %43 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %46 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %47 = arith.muli %44, %46 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %48 = tt.expand_dims %31 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %49 = tt.broadcast %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, 
#blocked> -> tensor<256x128xi32, #blocked> + %50 = tt.broadcast %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<256x128xi32, #blocked> + %51 = arith.addi %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked> + %52 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked> + %53 = tt.addptr %52, %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %54 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %55 = arith.divsi %54, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %56 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %57 = tt.splat %56 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked> + %c1_i32_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %58:4 = scf.for %arg9 = %c0_i32_1 to %55 step %c1_i32_0 iter_args(%arg10 = %41, %arg11 = %53, %arg12 = %false, %arg13 = %c0_i32_5) -> (tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32) : i32 { + %59 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %60 = arith.subi %arg5, %59 {async_task_id = dense<0> : vector<1xi32>} : i32 + %61 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked1> + %62 = arith.cmpi slt, %36, %61 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> + %63 = tt.broadcast %62 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + triton_nvidia_gpu.producer_acquire %2, %arg13, %false {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32, i1 + %c0_i32_6 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_7 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %64 = triton_gpu.memdesc_subview %0[%arg13, %c0_i32_6, %c0_i32_6] {async_task_id = dense<0> : vector<1xi32>} : !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %65 = triton_gpu.async_copy_global_to_local %arg10, %64 mask %63 other %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1> -> <128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %66 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %67 = arith.cmpi slt, %45, %66 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %68 = tt.broadcast %67 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked> -> tensor<256x128xi1, #blocked> + %c0_i32_8 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_9 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %69 = triton_gpu.memdesc_subview %1[%arg13, %c0_i32_8, %c0_i32_8] {async_task_id = dense<0> : vector<1xi32>} : !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %70 = triton_gpu.async_copy_global_to_local %arg11, %69 mask %68 other %cst {async_task_id = dense<0> : vector<1xi32>} : 
tensor<256x128x!tt.ptr, #blocked> -> <256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.producer_commit %2, %arg13 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32 + %71 = tt.addptr %arg10, %cst_3 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %72 = tt.addptr %arg11, %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %c1_i32_10 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %c0_i32_11 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<0> : vector<1xi32>} true + %73 = arith.addi %arg13, %c1_i32_10 {async_task_id = dense<0> : vector<1xi32>} : i32 + %74 = arith.cmpi uge, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.cmpi ult, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = arith.subi %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.select %74, %76, %73 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = arith.xori %arg12, %true {async_task_id = dense<0> : vector<1xi32>} : i1 + %79 = arith.andi %74, %78 {async_task_id = dense<0> : vector<1xi32>} : i1 + %80 = arith.andi %75, %arg12 {async_task_id = dense<0> : vector<1xi32>} : i1 + %81 = arith.ori %79, %80 {async_task_id = dense<0> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<0> : vector<1xi32>} %71, %72, %81, %77 : tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32 + } {async_task_id = dense<0> : vector<1xi32>} + } {async_task_id = dense<0> : vector<1xi32>} + %c1_i32 = arith.constant 1 : i32 + %5 = arith.cmpi eq, %3, %c1_i32 : i32 + scf.if %5 { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked2> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 
{async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %22 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %23 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %24 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %25 = tt.splat %20 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %26 = arith.addi %24, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %27 = arith.addi %25, %22 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
+    %28 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %29 = arith.remsi %26, %28 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
+    %30 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %31 = tt.splat %30 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %32 = arith.addi %31, %23 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    %33 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %34 = arith.divsi %33, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32
+    %c1_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32
+    %c0_i32_6 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32
+    %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false
+    %35:3 = scf.for %arg9 = %c0_i32_1 to %34 step %c1_i32_0 iter_args(%arg10 = %cst, %arg11 = %false, %arg12 = %c0_i32_6) -> (tensor<128x128xf32, #blocked2>, i1, i32) : i32 {
+      %c0_i32_7 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32
+      %c1_i32_8 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32
+      %c0_i32_9 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32
+      %c1_i32_10 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32
+      triton_nvidia_gpu.consumer_wait %2, %arg12, %false {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32, i1
+      %54 = triton_gpu.memdesc_subview %0[%arg12, %c0_i32_7, %c0_i32_7] {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable>
+      %55 = triton_gpu.local_load %54 {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #blocked1>
+      %56 = triton_gpu.convert_layout %55 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked1> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>
+      %57 = triton_gpu.memdesc_subview %1[%arg12, %c0_i32_9, %c0_i32_9] {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable>
+      %58 = triton_gpu.local_load %57 {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #blocked>
+      %59 = triton_gpu.convert_layout %58 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>
+      %60 = tt.dot %56, %59, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
+      triton_nvidia_gpu.consumer_release %2, %arg12 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32
+      %c1_i32_11 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32
+      %c0_i32_12 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32
+      %true = arith.constant {async_task_id = dense<1> : vector<1xi32>} true
+      %61 = arith.addi %arg12, %c1_i32_11 {async_task_id = dense<1> : vector<1xi32>} : i32
+      %62 = arith.cmpi uge, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32
+      %63 = arith.cmpi ult, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32
+      %64 = arith.subi %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32
+      %65 = arith.select %62, %64, %61 {async_task_id = dense<1> : vector<1xi32>} : i32
+      %66 = arith.xori %arg11, %true {async_task_id = dense<1> : vector<1xi32>} : i1
+      %67 = arith.andi %62, %66 {async_task_id = dense<1> : vector<1xi32>} : i1
+      %68 = arith.andi %63, %arg11 {async_task_id = dense<1> : vector<1xi32>} : i1
+      %69 = arith.ori %67, %68 {async_task_id = dense<1> : vector<1xi32>} : i1
+      scf.yield {async_task_id = dense<1> : vector<1xi32>} %60, %69, %65 : tensor<128x128xf32, #blocked2>, i1, i32
+    } {async_task_id = dense<1> : vector<1xi32>}
+    %36 = arith.truncf %35#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2>
+    %37 = tt.expand_dims %27 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %38 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked>
+    %39 = arith.muli %38, %37 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked>
+    %40 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked>
+    %41 = tt.addptr %40, %39 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked>
+    %42 = tt.expand_dims %32 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
+    %43 = tt.broadcast %41 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked>
+    %44 = tt.broadcast %42 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked>
+    %45 = tt.addptr %43, %44 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked>
+    %46 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked>
+    %47 = arith.cmpi slt, %37, %46 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked>
+    %48 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked>
+    %49 = arith.cmpi slt, %42, %48 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked>
+    %50 = tt.broadcast %47 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked>
+    %51 = tt.broadcast %49 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked>
+    %52 = arith.andi %50, %51 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked>
+    %53 = triton_gpu.convert_layout %36 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked>
+    tt.store %45, %53, %52 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked>
+    } {async_task_id = dense<1> : vector<1xi32>}
+    tt.return
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index 8649911a7..626f41a0e 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -93,7 +93,7 @@ struct ConvertTritonAMDGPUToLLVM
     int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
 
-    // Hack: WSMaterialization may have changed the effective number of warps,
+    // Hack: WSLowering may have changed the effective number of warps,
     // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to
     // respect that here.
     if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) {
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index adfde57b0..6d1122fb0 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -88,6 +88,11 @@ class CUDAOptions:
     num_warps: int = 4
     num_ctas: int = 1
     num_stages: int = 3
+    num_buffers_warp_spec: int = 0
+    num_consumer_groups: int = 0
+    reg_dec_producer: int = 0
+    reg_inc_consumer: int = 0
+    partition_style: int = 0
     # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
     # maximum number of 32-bit registers used by one thread.
     maxnreg: Optional[int] = None
@@ -221,7 +226,14 @@ def make_ttgir(mod, metadata, opt, capability):
         if capability // 10 >= 8:
             passes.ttgpuir.add_optimize_accumulator_init(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
+        passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups)
+        passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups)
+        passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups,
+                                             opt.reg_dec_producer, opt.reg_inc_consumer)
+        passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages)
         passes.ttgpuir.add_pipeline(pm, opt.num_stages)
+        passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups, opt.partition_style)
+        passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups)
         passes.ttgpuir.add_prefetch(pm)
         passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
         passes.ttgpuir.add_remove_layout_conversions(pm)
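Note (illustrative sketch, not part of the diff): assuming the new CUDAOptions fields are exposed as kernel launch options in the same way num_warps and num_stages already are, a warp-specialized launch could look like the following; the kernel name, arguments, and register budgets are hypothetical.

    # Hypothetical launch enabling warp specialization via the new options.
    matmul_kernel[grid](
        a_ptr, b_ptr, c_ptr, M, N, K,
        num_warps=4, num_stages=3,
        num_buffers_warp_spec=2,   # buffers in the producer/consumer channel
        num_consumer_groups=2,     # two consumer warp groups (ping-pong)
        reg_dec_producer=40,       # register budget dropped for the producer group
        reg_inc_consumer=232,      # register budget raised for the consumer groups
    )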
diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
index 31b2646db..840e0714c 100644
--- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
+++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td
@@ -39,6 +39,7 @@ def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> {
   let assemblyFormat = "attr-dict";
 }
+

 def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> {
   let assemblyFormat = "attr-dict";
 }
@@ -52,6 +53,32 @@ def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group",
   let assemblyFormat = "$input attr-dict `:` type($input)";
 }

+def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType",
+    "mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'",
+    [
+      I32EnumAttrCase<"normal", 0>,
+      I32EnumAttrCase<"cp_async", 1>,
+      I32EnumAttrCase<"expect_tx", 2>,
+      I32EnumAttrCase<"remote", 3>,
+    ]>{
+  let cppNamespace = "::mlir::triton::nvgpu";
+}
+
+def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> {
+  let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional<I32>:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr<I32Attr, "0">:$txCount);
+  let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)";
+}
+
+def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> {
+  let arguments = (ins I32:$bar, I32:$numThreads);
+  let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
+}
+
+def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> {
+  let arguments = (ins I32:$bar, I32:$numThreads);
+  let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)";
+}
+
 def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout",
     "wgmma layout, either 'row' or 'col'",
     [
@@ -112,4 +139,19 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
   let results = (outs I32:$result);
   let assemblyFormat = "attr-dict";
 }
+def NVGPU_CanonicalWarpIdOp : NVGPU_Op<"canonical_warp_id", [Pure]> {
+  let results = (outs I32:$result);
+  let assemblyFormat = "attr-dict";
+}
+
+def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> {
+  let arguments = (ins I32Attr: $regCount);
+  let assemblyFormat = "attr-dict";
+}
+
+def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> {
+  let arguments = (ins I32Attr: $regCount);
+  let assemblyFormat = "attr-dict";
+}
+
 #endif
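Note (illustrative, not part of the diff): given the assembly formats above, the new NVGPU ops print roughly as follows; the SSA names and attribute values are made up, and the exact attribute syntax may differ.

    %warp = nvgpu.canonical_warp_id
    nvgpu.reg_alloc {regCount = 232 : i32}
    nvgpu.reg_dealloc {regCount = 40 : i32}
    nvgpu.bar_arrive %barId, %numThreads : i32, i32
    nvgpu.bar_wait %barId, %numThreads : i32, i32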
llvm_unreachable(""); + break; + } + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::MBarrierArriveOp op) const { + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + uint32_t txCount = op.getTxCount(); + std::string ptxAsm; + switch (arriveType) { + case ttn::MBarriveType::normal: + ptxAsm = "@$1 mbarrier.arrive.shared.b64 _, [$0];"; + break; + case ttn::MBarriveType::cp_async: + ptxAsm = "@$1 cp.async.mbarrier.arrive.noinc.shared.b64 [$0];"; + break; + case ttn::MBarriveType::expect_tx: + assert(txCount > 0 && "txCount should be valid"); + ptxAsm = "@$1 mbarrier.arrive.expect_tx.shared.b64 _, [$0], " + + std::to_string(txCount) + ";"; + break; + case ttn::MBarriveType::remote: + assert(ctaId && "ctaId should have a valid value"); + ptxAsm = + " { .reg .b32 remAddr32; \n" + " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" + " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return ptxAsm; + } +}; + class WGMMAWaitGroupOpPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -507,17 +600,25 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { #define POPULATE_NVGPU_OP(SRC_OP, ASM) \ patterns.add>(context, ASM, Constraints(), \ Constraints()); + POPULATE_NVGPU_OP(ttn::RegAllocOp, Reg_Alloc_Op) POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op) + POPULATE_NVGPU_OP(ttn::RegDeallocOp, Reg_Dealloc_Op) #undef POPULATE_NVGPU_OP + patterns.add>( + context, Named_Barrier_Arrive_Op, Constraints(), + Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Wait_Op, Constraints(), Constraints({"r", "r"})); patterns.add>( context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Canonical_Warp_Id_Op, Constraints({"=r"}), Constraints()); - patterns - .add( - context); + patterns.add(context); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index 746b910e1..268d1dbf6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -55,6 +55,77 @@ struct BarrierOpConversion } }; +// -------------------------------------------------------------------------- +// -- MBarrier related Ops lowering, to be moved to a separate file --------- +// -------------------------------------------------------------------------- +struct MBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::MBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto mbarrier = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getMbarrier(), + typeConverter->convertType(op.getMbarrier().getType().getElementType()), + rewriter); + + bool trackAsyncOp = op.getTrackAsyncOp(); + triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal; + uint32_t txCount = op.getTxCount(); + auto remoteCtaId = adaptor.getRemoteCtaId(); + if 
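Note (illustrative, not part of the diff): once operands and the #regCount placeholder are substituted, the new strings emit PTX of roughly this shape; the concrete values below are made up.

    setmaxnreg.dec.sync.aligned.u32 40;    // nvgpu.reg_dealloc {regCount = 40}
    setmaxnreg.inc.sync.aligned.u32 232;   // nvgpu.reg_alloc {regCount = 232}
    bar.arrive 9, 128;                     // nvgpu.bar_arrive %bar, %numThreads
    bar.sync 9, 128;                       // nvgpu.bar_wait %bar, %numThreads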
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp
index 746b910e1..268d1dbf6 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp
@@ -55,6 +55,77 @@ struct BarrierOpConversion
   }
 };

+// --------------------------------------------------------------------------
+// -- MBarrier related Ops lowering, to be moved to a separate file ---------
+// --------------------------------------------------------------------------
+struct MBarrierArriveOpConversion
+    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::MBarrierArriveOp> {
+  using ConvertOpToLLVMPattern<
+      triton::nvidia_gpu::MBarrierArriveOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto mbarrier = LLVM::getSharedMemoryObjectFromStruct(
+        op.getLoc(), adaptor.getMbarrier(),
+        typeConverter->convertType(op.getMbarrier().getType().getElementType()),
+        rewriter);
+
+    bool trackAsyncOp = op.getTrackAsyncOp();
+    triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal;
+    uint32_t txCount = op.getTxCount();
+    auto remoteCtaId = adaptor.getRemoteCtaId();
+    if (trackAsyncOp) {
+      type = triton::nvgpu::MBarriveType::cp_async;
+    } else if (remoteCtaId) {
+      assert(txCount == 0 &&
+             "remote arrive of transaction mbarrier is not implemented yet");
+      type = triton::nvgpu::MBarriveType::remote;
+    } else if (txCount > 0) {
+      type = triton::nvgpu::MBarriveType::expect_tx;
+    }
+    Value pred = adaptor.getPred();
+    if (pred == nullptr) {
+      pred = int_val(/*width*/ 1, 1);
+    }
+    rewriter.replaceOpWithNewOp<triton::nvgpu::MBarrierArriveOp>(
+        op, mbarrier.getBase(), pred, remoteCtaId, type, txCount);
+    return success();
+  }
+};
+
+struct NamedBarrierArriveOpConversion
+    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::NamedBarrierArriveOp> {
+  using ConvertOpToLLVMPattern<
+      triton::nvidia_gpu::NamedBarrierArriveOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierArriveOp>(
+        op, adaptor.getBar(), adaptor.getNumThreads());
+    return success();
+  }
+};
+
+struct NamedBarrierWaitOpConversion
+    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::NamedBarrierWaitOp> {
+  using ConvertOpToLLVMPattern<
+      triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    rewriter.replaceOpWithNewOp<triton::nvgpu::NamedBarrierWaitOp>(
+        op, adaptor.getBar(), adaptor.getNumThreads());
+    return success();
+  }
+};
+
 struct FenceAsyncSharedOpConversion
     : public ConvertOpToLLVMPattern<triton::nvidia_gpu::FenceAsyncSharedOp> {
   using ConvertOpToLLVMPattern<
@@ -83,8 +154,18 @@ struct InitBarrierOpConversion
         typeConverter->convertType(op.getAlloc().getType().getElementType()),
         rewriter);

+    auto asyncTaskIds = getAsyncTaskIds(op);
+    int executingThreadId = 0;
+    if (!asyncTaskIds.empty()) {
+      assert(asyncTaskIds.size() == 1 && "only support single async task");
+      auto mod = op->getParentOfType<ModuleOp>();
+      int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+      int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+      executingThreadId = asyncTaskIds[0] * numWarps * warpSize;
+    }
+
     auto id = getThreadId(rewriter, loc);
-    auto pred = icmp_eq(id, i32_val(0));
+    auto pred = icmp_eq(id, i32_val(executingThreadId));
     ::mlir::triton::PTXBuilder ptxBuilder;
     const std::string ptx = "@$0 mbarrier.init.shared::cta.b64 [$1], " +
                             std::to_string(op.getCount()) + ";";
@@ -112,8 +193,17 @@ struct InvalBarrierOpConversion
         typeConverter->convertType(op.getAlloc().getType().getElementType()),
         rewriter);

+    auto asyncTaskIds = getAsyncTaskIds(op);
+    int executingThreadId = 0;
+    if (!asyncTaskIds.empty()) {
+      assert(asyncTaskIds.size() == 1 && "only support single async task");
+      auto mod = op->getParentOfType<ModuleOp>();
+      int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+      int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+      executingThreadId = asyncTaskIds[0] * numWarps * warpSize;
+    }
     auto id = getThreadId(rewriter, loc);
-    Value pred = icmp_eq(id, i32_val(0));
+    Value pred = icmp_eq(id, i32_val(executingThreadId));
     ::mlir::triton::PTXBuilder ptxBuilder;
     const std::string ptx = "@$0 mbarrier.inval.shared::cta.b64 [$1];";
     auto &barSyncOp = *ptxBuilder.create<>(ptx);
@@ -140,8 +230,17 @@ struct BarrierExpectConversion
         typeConverter->convertType(op.getAlloc().getType().getElementType()),
         rewriter);

+    auto asyncTaskIds = getAsyncTaskIds(op);
+    int executingThreadId = 0;
+    if (!asyncTaskIds.empty()) {
+      assert(asyncTaskIds.size() == 1 && "only support single async task");
+      auto mod = op->getParentOfType<ModuleOp>();
+      int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+      int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+      executingThreadId = asyncTaskIds[0] * numWarps * warpSize;
+    }
     auto id = getThreadId(rewriter, loc);
-    Value pred = icmp_eq(id, i32_val(0));
+    Value pred = icmp_eq(id, i32_val(executingThreadId));
     pred = and_(pred, adaptor.getPred());
     ::mlir::triton::PTXBuilder ptxBuilder;
     const std::string ptx =
@@ -194,6 +293,9 @@ void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   patterns.add<BarrierOpConversion>(typeConverter, benefit);
+  patterns.add<MBarrierArriveOpConversion>(typeConverter, benefit);
+  patterns.add<NamedBarrierArriveOpConversion>(typeConverter, benefit);
+  patterns.add<NamedBarrierWaitOpConversion>(typeConverter, benefit);
   patterns.add<FenceAsyncSharedOpConversion>(typeConverter, benefit);
   patterns.add<InitBarrierOpConversion>(typeConverter, benefit);
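Note (illustrative, not part of the diff): the three conversions above pick the executing thread with the same formula; as a standalone sketch with a worked example:

    // Only the first thread of the op's warp group executes
    // mbarrier.init / mbarrier.inval / mbarrier.arrive.expect_tx.
    // e.g. asyncTaskId = 1, numWarps = 4, warpSize = 32  ->  thread 128.
    int executingThreadId(int asyncTaskId, int numWarps, int warpSize) {
      return asyncTaskId * numWarps * warpSize;
    }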
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 8fb44ce64..37702f1d6 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -505,7 +505,7 @@ struct ConvertLayoutOpConversion
       auto multiDimRepId =
           getMultiDimIndex<unsigned>(repId, numReplicates, outOrd);
       if (repId != 0) {
-        barrier();
+        insertBarrier(rewriter, op);
       }

       if (isLayoutMmaV1(srcLayout))
@@ -517,7 +517,7 @@ struct ConvertLayoutOpConversion
                        multiDimRepId, inVec, paddedRepShape, origRepShape,
                        outOrd, vals, smemBase);
-      barrier();
+      insertBarrier(rewriter, op);

       if (isLayoutMmaV1(dstLayout))
         processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy,
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
index d950e0157..73e31104c 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -9,6 +9,8 @@
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
+#include "triton/Tools/Sys/GetEnv.hpp"

 using namespace mlir;
 using namespace mlir::triton;
@@ -490,10 +492,21 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
   }
 };

-void createBarrier(ConversionPatternRewriter &rewriter, Location loc,
+void createBarrier(ConversionPatternRewriter &rewriter, Operation *op,
                    int numCTAs) {
+  auto loc = op->getLoc();
   if (numCTAs == 1) {
-    barrier();
+    auto barrierOp = barrier();
+    auto asyncTaskIds = getAsyncTaskIds(op);
+    if (asyncTaskIds.size() == 1) {
+      int asyncTaskId = asyncTaskIds[0];
+      int barId = asyncTaskId + nameBarrierIdBegin;
+      assert(barId < nameBarrierIdEnd);
+      // TODO: Change hard code style of numThreads.
+      const int numThreads = 128;
+      barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId));
+      barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads));
+    }
   } else {
     rewriter.create<triton::nvidia_gpu::ClusterArriveOp>(loc, false);
     rewriter.create<triton::nvidia_gpu::ClusterWaitOp>(loc);
@@ -606,7 +619,7 @@ struct AtomicCASOpConversion
       st(dstOprStore, valOprStore).predicate(mask);
       auto ASMReturnTy = void_ty(ctx);
       ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
-      createBarrier(rewriter, loc, numCTAs);
+      createBarrier(rewriter, op, numCTAs);
       Value ret = load(valueElemTy, atomPtr);
       rewriter.replaceOp(op, {ret});
     }
@@ -778,7 +791,7 @@ struct AtomicRMWOpConversion
         auto *valOpr = ptxBuilderStore.newOperand(old, tyId);
         storeShared(ptrOpr, valOpr).predicate(rmwMask);
         ptxBuilderStore.launch(rewriter, loc, void_ty(ctx));
-        createBarrier(rewriter, loc, numCTAs);
+        createBarrier(rewriter, op, numCTAs);
         Value ret = load(valueElemTy, atomPtr);
         rewriter.replaceOp(op, {ret});
       }
@@ -988,6 +1001,13 @@ struct AsyncTMACopyGlobalToLocalOpConversion
     if (rank > 1)
       numCopies = ceil(contigDimSizeInByte, 128);

+    auto asyncTaskIds = getAsyncTaskIds(op);
+    int firstThreadId = 0;
+    if (!asyncTaskIds.empty()) {
+      assert(asyncTaskIds.size() == 1 && "only support single async task");
+      firstThreadId = asyncTaskIds[0] * numWarps * warpSize;
+    }
+
     // The bounding box inner dimension must be less than or equal to the
     // swizzle size.
     // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
@@ -997,8 +1017,9 @@ struct AsyncTMACopyGlobalToLocalOpConversion
       int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps);
       if (numWarpsToCopy == 1)
         warpID = i32_val(0);
-      Value boxPred =
-          and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize)));
+      Value boxPred = and_(
+          pred,
+          icmp_ult(id, i32_val(numWarpsToCopy * warpSize + firstThreadId)));
       ::mlir::triton::PTXBuilder ptxBuilderTMA;
       Type elemPtrTy = ptr_ty(rewriter.getContext(), 3);
       Value copyIdxVal = add(warpID, i32_val(copyIdx));
@@ -1037,6 +1058,14 @@ struct AsyncTMACopyGlobalToLocalOpConversion
   }
 };

+int getWarpOffset(Operation *op) {
+  auto asyncTaskIds = getAsyncTaskIds(op);
+  if (asyncTaskIds.size() > 0) {
+    return 4 * *std::min_element(asyncTaskIds.begin(), asyncTaskIds.end());
+  }
+  return 0;
+}
+
 struct AsyncTMACopyLocalToGlobalOpConversion
     : public ConvertOpToLLVMPattern<
           triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp> {
@@ -1082,6 +1111,9 @@ struct AsyncTMACopyLocalToGlobalOpConversion
       int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps);
       if (numWarpsToCopy == 1)
         warpID = i32_val(0);
+      auto warpOffset = getWarpOffset(op);
+      warpID = sub(warpID, i32_val(warpOffset));
+      id = sub(id, i32_val(warpOffset * warpSize));
       Value boxPred =
           and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize)));
       ::mlir::triton::PTXBuilder ptxBuilderTMA;
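Note (illustrative, not part of the diff): the TMA paths above rebase warp and thread ids so that a copy issued from a non-zero warp group still predicates on its own warps; a worked-example sketch of the two helpers, assuming 4 warps per warp group:

    // e.g. async task 2, 4 warps per group, 32 threads per warp.
    int warpOffsetForTask(int minAsyncTaskId) {
      return 4 * minAsyncTaskId;                  // 4 * 2 = 8 warps
    }
    int firstThreadIdForTask(int asyncTaskId, int numWarps, int warpSize) {
      return asyncTaskId * numWarps * warpSize;   // 2 * 4 * 32 = 256
    }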
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp
index 93ad46971..8bc55e187 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp
@@ -1,5 +1,6 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "Utility.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"

 namespace {

@@ -33,10 +34,23 @@ struct GetNumProgramsOpConversion
   }
 };

+struct GetCanonicalWarpIdConversion
+    : public ConvertOpToLLVMPattern<triton::nvidia_gpu::GetCanonicalWarpIdOp> {
+  using ConvertOpToLLVMPattern<
+      triton::nvidia_gpu::GetCanonicalWarpIdOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpIdOp op,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOp(op, GetCanonicalWarpId(rewriter, op->getLoc()));
+    return success();
+  }
+};
 } // namespace

 void mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
+  patterns.add<GetCanonicalWarpIdConversion>(typeConverter, benefit);
 }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
index 0a35176ec..1e1e7c488 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -96,6 +96,13 @@ struct ConvertTritonGPUToLLVM
     int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

+    // Hack: WSLowering may have changed the effective number of warps,
+    // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to
+    // respect that here.
+    if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) {
+      numWarps *= cast<IntegerAttr>(attr).getInt();
+    }
+
     // Allocate shared memory and set barrier
     ModuleAllocation allocation(mod);
     ModuleMembarAnalysis membarPass(&allocation, NVIDIA::canSkipBarSync);
@@ -175,6 +182,8 @@ struct ConvertTritonGPUToLLVM
                                                 patterns, benefit);
     mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
                                                    patterns, benefit);
+    mlir::triton::populateRegReallocOpToLLVMPatterns(typeConverter, patterns,
+                                                     benefit);

     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
       return signalPassFailure();
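Note (illustrative, not part of the diff): the effect of the num-warp-groups-per-cta adjustment above, as a standalone sketch; the attribute value is hypothetical.

    // e.g. num_warps = 4 with triton_gpu.num-warp-groups-per-cta = 3 means the
    // CTA actually launches 12 warps (384 threads), which is what the shared
    // memory allocation and membar analyses that follow must account for.
    int effectiveNumWarps(int numWarps, int numWarpGroupsPerCTA) {
      return numWarps * numWarpGroupsPerCTA;
    }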