Skip to content

Commit

Permalink
Add a loop peeling pass
Browse files Browse the repository at this point in the history
  • Loading branch information
htyu committed Dec 20, 2024
1 parent 8706035 commit 6aab9f8
Show file tree
Hide file tree
Showing 9 changed files with 390 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,12 @@ def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp">
"number of consumer warp groups for warp specialization">
];
}

// Pass definition for matmul loop peeling. Peeling the first iteration lets
// the in-loop dot op start from the peeled iteration's result instead of a
// materialized zero accumulator (see MatmulLoopPeeling.cpp).
def TritonMatmulLoopPeeling : Pass<"tritongpu-matmul-loop-peeling", "mlir::ModuleOp"> {
let summary = "Loop peeling for matmul loop";

let description = "Peel the first iteration of the matmul loop to avoid the initialization of the accumulator";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}
#endif
7 changes: 7 additions & 0 deletions include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ class OpBuilderWithAsyncTaskIds : public OpBuilder {
return op;
}

// Clone `op` (remapping operands through `mapper`) and propagate this
// builder's async task IDs onto the clone. If the builder carries no task
// IDs, the clone keeps whatever task IDs `OpBuilder::clone` copied over.
Operation *cloneWithAsyncTaskIds(Operation &op, IRMapping &mapper) {
Operation *newOp = OpBuilder::clone(op, mapper);
if (!asyncTaskIds.empty())
setAsyncTaskIds(newOp, asyncTaskIds);
return newOp;
}

private:
SmallVector<AsyncTaskId> asyncTaskIds;
};
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ add_triton_library(TritonGPUTransforms
WSDataPartition.cpp
WSCodePartition.cpp
WSLowering.cpp
MatmulLoopPeeling.cpp

DEPENDS
TritonGPUTransformsIncGen
Expand Down
182 changes: 182 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/MatmulLoopPeeling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include <memory>

#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "triton-loop-peeling"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace mlir {
namespace triton {
namespace gpu {

// Peel the first iteration of `forOp` out in front of the loop and turn the
// original loop into the remainder loop (lower bound advanced by one step).
// Returns the new remainder loop; the original loop is erased.
//
// NOTE(review): assumes the caller has proven the loop runs at least once
// (see loopMustBeRunAtLeastOnce) — the peeled prologue executes
// unconditionally.
scf::ForOp peelFirstIteration(scf::ForOp forOp) {
  // Extract the first iteration outside the loop.
  OpBuilderWithAsyncTaskIds builder(forOp);

  // Map block arguments to their first-iteration values: the induction
  // variable (argument 0) is the lower bound, and each iter_arg takes its
  // corresponding init value.
  IRMapping mapping;
  mapping.map(forOp.getBody()->getArguments()[0], forOp.getLowerBound());
  for (unsigned i = 1; i < forOp.getBody()->getArguments().size(); ++i) {
    mapping.map(forOp.getBody()->getArguments()[i], forOp.getInitArgs()[i - 1]);
    LLVM_DEBUG({
      LDBG("Mapping ");
      forOp.getBody()->getArguments()[i].dump();
      LDBG(" to ");
      forOp.getInitArgs()[i - 1].dump();
      LDBG("\n");
    });
  }

  // Clone the operations in the loop body for the first iteration. The
  // yield's (remapped) operands become the init values of the remainder loop.
  SmallVector<Value> peeledResults;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      for (auto result : yieldOp->getOperands()) {
        peeledResults.push_back(mapping.lookup(result));
      }
    } else {
      auto newOp = builder.cloneWithAsyncTaskIds(op, mapping);
      for (unsigned i = 0; i < op.getNumResults(); ++i) {
        mapping.map(op.getResult(i), newOp->getResult(i));
      }
    }
  }

  // Adjust the original loop to become the remainder loop. Use
  // createWithAsyncTaskIds so the new lower-bound computation carries the
  // same async task annotation as the rest of the peeled prologue
  // (the original used a plain create here, dropping the annotation).
  Value lb = forOp.getLowerBound();
  Value step = forOp.getStep();
  Value newLb = builder.createWithAsyncTaskIds<arith::AddIOp>(forOp->getLoc(),
                                                              lb, step);
  assert(peeledResults.size() == forOp.getNumResults() &&
         "peeled results size mismatch");
  auto newForOp = builder.createWithAsyncTaskIds<scf::ForOp>(
      forOp->getLoc(), newLb, forOp.getUpperBound(), step, peeledResults);
  newForOp->setAttrs(forOp->getAttrs());
  newForOp.getRegion().takeBody(forOp.getRegion());
  for (unsigned i = 0; i < forOp.getNumResults(); ++i)
    forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));

  // Erase the original loop.
  forOp.erase();
  return newForOp;
}

// Return true if `v` is an arith.constant splat tensor whose elements are
// all zero. Handles both float and integer element types; the original
// unconditionally read the splat as a FloatAttr, which misbehaves for
// integer tensors, and `convertToFloat` is only valid for single-precision
// APFloats. `isZero()` matches -0.0 as well, same as the old `== 0.0f`.
static bool isConstantZeroTensor(Value v) {
  auto constOp = v.getDefiningOp<arith::ConstantOp>();
  if (!constOp)
    return false;
  auto splat = mlir::dyn_cast<SplatElementsAttr>(constOp.getValue());
  if (!splat)
    return false;
  Attribute elem = splat.getSplatValue<Attribute>();
  if (auto floatElem = mlir::dyn_cast<FloatAttr>(elem))
    return floatElem.getValue().isZero();
  if (auto intElem = mlir::dyn_cast<IntegerAttr>(elem))
    return intElem.getValue().isZero();
  return false;
}

// Check if the loop is provably profitable to peel: its trip count must be
// known (or assumed) to be more than one iteration.
//
// Fixes vs. the original:
//  - when both bounds and the step are constant and the trip count is > 1,
//    the function now returns true (it previously fell through and always
//    returned false for the constant case);
//  - the `slt` assume pattern now checks `lb < ub` (it compared the RHS
//    against `lb` instead of `ub`);
//  - the switch cases no longer fall through into each other;
//  - the backward scan now visits every op preceding the loop, including
//    the block's first op (the old iterator loop stopped one short and
//    started at the loop itself).
bool loopMustBeRunAtLeastOnce(scf::ForOp forOp) {
  auto lb = forOp.getLowerBound();
  auto ub = forOp.getUpperBound();
  auto step = forOp.getStep();
  auto lbInt = getConstantIntValue(lb);
  auto ubInt = getConstantIntValue(ub);
  auto stepInt = getConstantIntValue(step);

  // Fully constant bounds: decide exactly. Peeling is not worthwhile if
  // there is one or fewer iterations. Integer ceil-div avoids the float
  // rounding of the original `ceil(float(...))` on large bounds.
  if (lbInt && ubInt && stepInt) {
    if (*stepInt <= 0 || *ubInt <= *lbInt)
      return false;
    int64_t tripCount = (*ubInt - *lbInt + *stepInt - 1) / *stepInt;
    return tripCount > 1;
  }

  // Non-constant bounds: look for an llvm.intr.assume before the loop that
  // states the loop range is non-empty (ub > lb or lb < ub).
  for (Operation *prev = forOp->getPrevNode(); prev;
       prev = prev->getPrevNode()) {
    auto assumeOp = dyn_cast<LLVM::AssumeOp>(prev);
    if (!assumeOp)
      continue;
    LLVM_DEBUG({
      LDBG("Found AssumeOp prior to ForOp:\n");
      assumeOp->dump();
    });
    auto truth = assumeOp->getOperand(0);
    auto cmpOp = truth.getDefiningOp<arith::CmpIOp>();
    if (!cmpOp)
      continue;
    switch (cmpOp.getPredicate()) {
    case arith::CmpIPredicate::sgt:
      if (cmpOp.getLhs() == ub && cmpOp.getRhs() == lb)
        return true;
      break;
    case arith::CmpIPredicate::slt:
      if (cmpOp.getLhs() == lb && cmpOp.getRhs() == ub)
        return true;
      break;
    default:
      break;
    }
  }

  return false;
}

// Decide whether `forOp` should be peeled: it must contain a dot-like op
// whose accumulator is a loop-carried value initialized to a zero tensor,
// and the loop must provably run more than one iteration. Peeling such a
// loop lets the accumulator start from the peeled iteration's result.
bool shouldPeel(scf::ForOp forOp) {
  SmallVector<Operation *> dotOps;
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (op.hasTrait<OpTrait::DotLike>())
      dotOps.push_back(&op);
  }

  bool hasZeroAccDotOp = false;
  for (Operation *dotOp : dotOps) {
    // Operand 2 of a dot-like op is the accumulator.
    auto acc = dotOp->getOperand(2);
    auto arg = dyn_cast<BlockArgument>(acc);
    // Skip accumulators that are not loop-carried values of this loop
    // (the old code asserted on the owner instead of skipping).
    if (!arg || arg.getOwner() != forOp.getBody())
      continue;
    // Argument 0 is the induction variable; only iter_args (>= 1) have a
    // corresponding init value. Guards the `getArgNumber() - 1` underflow.
    if (arg.getArgNumber() == 0)
      continue;
    if (isConstantZeroTensor(forOp.getInitArgs()[arg.getArgNumber() - 1])) {
      hasZeroAccDotOp = true;
      break;
    }
  }

  if (!hasZeroAccDotOp)
    return false;
  return loopMustBeRunAtLeastOnce(forOp);
}

#define GEN_PASS_DEF_TRITONMATMULLOOPPEELING
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

// Pass driver: peel the first iteration of every qualifying scf.for in the
// module (see shouldPeel for the selection criteria).
class TritonMatmulLoopPeelingPass
    : public impl::TritonMatmulLoopPeelingBase<TritonMatmulLoopPeelingPass> {

public:
  TritonMatmulLoopPeelingPass() = default;
  TritonMatmulLoopPeelingPass(const TritonMatmulLoopPeelingPass &) {}

  void runOnOperation() override {
    LDBG("Loop peeling pass");
    // Collect candidates up front: peeling erases the original scf.for,
    // which would invalidate a walk that rewrites in place.
    SmallVector<scf::ForOp, 4> candidates;
    getOperation()->walk(
        [&](scf::ForOp candidate) { candidates.push_back(candidate); });
    for (scf::ForOp candidate : candidates) {
      if (!shouldPeel(candidate))
        continue;
      (void)peelFirstIteration(candidate);
    }
    LLVM_DEBUG({
      LDBG("After loop peeling");
      getOperation()->dump();
    });
  }
};

} // namespace gpu
} // namespace triton
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def matmul_persistent_tma_ws_cooperative_kernel(
offs_bn = pid_n * BLOCK_SIZE_N
offs_k = 0

tl.assume(tl.cdiv(K, BLOCK_SIZE_K)>0)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
with tl.async_task([0]):
Expand Down
2 changes: 2 additions & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "passes.h"
Expand Down Expand Up @@ -68,6 +69,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPUCombineTensorSelectAndIf);
ADD_PASS_WRAPPER_0("add_optimize_accumulator_init",
createTritonGPUOptimizeAccumulatorInit);
ADD_PASS_WRAPPER_0("add_loop_peeling", createTritonMatmulLoopPeeling);
ADD_PASS_OPTION_WRAPPER_1("add_ws_data_partition",
createTritonGPUWSDataPartition, int);
ADD_PASS_OPTION_WRAPPER_1("add_ws_lowering", createTritonGPUWSLowering, int);
Expand Down
1 change: 1 addition & 0 deletions python/tutorials/03-matrix-multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def matmul_kernel(
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
tl.assume(tl.cdiv(K, BLOCK_SIZE_K)>0)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
Expand Down
Loading

0 comments on commit 6aab9f8

Please sign in to comment.