[Backend] Convert FMA dot operand to linear layout (#5469)
This PR
- Introduces a linear-layout converter for FMA dot operands, plus related tests
- Fixes FMA generation: the previous version produced repetitions that were
incompatible with the blocked layout

Fixes triton-lang/triton#5423
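For context, the entry point for the new conversion is DotOperandEncodingAttr::toLinearLayout (see the LinearLayoutConversions.cpp hunk below). A minimal caller sketch, assuming a tensor whose encoding is a dot operand with a blocked parent; the wrapper function itself is hypothetical:

// Sketch: with this PR, a dot operand whose parent is a BlockedEncodingAttr
// (the FMA path) resolves through the generic linear-layout conversion.
std::optional<LinearLayout> getOperandLinearLayout(RankedTensorType ty) {
  auto dotEnc =
      mlir::dyn_cast<triton::gpu::DotOperandEncodingAttr>(ty.getEncoding());
  if (!dotEnc)
    return std::nullopt;
  // Dispatches to fmaDotToLinearLayout when the parent is blocked.
  return dotEnc.toLinearLayout(ty.getShape());
}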
binarman authored Dec 30, 2024
1 parent 2939d86 commit d9facf3
Showing 6 changed files with 225 additions and 94 deletions.
@@ -98,13 +98,13 @@ void storeValuesInLinearVector(PatternRewriter &rewriter, Location loc,
   }
 }
 
-void verifyCTALayout(CTALayoutAttr ctaLayout) {
+bool verifyCTALayout(CTALayoutAttr ctaLayout) {
   auto ctaSplit = ctaLayout.getCTASplitNum();
   for (auto split : ctaSplit) {
     if (split != 1)
-      llvm::report_fatal_error("tensors splited in CGA(thread group clusters) "
-                               "are not supported in FMA dot yet.");
+      return false;
   }
+  return true;
 }
 
 /// Get a linear offset of first element loaded by thread.
@@ -216,7 +216,8 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
                 Value thread, Location loc,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, const int dotOpNo) {
-  verifyCTALayout(dLayout.getCTALayout());
+  if (!verifyCTALayout(dLayout.getCTALayout()))
+    return Value();
 
   DimIdx dim;
   dim.batch = 0;
@@ -292,6 +293,15 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
   auto numBTiles = std::max(1u, B / shapePerCTABTile);
   auto numNonKTiles = std::max(1u, NonK / shapePerCTANonKTile);
 
+  // Found a discrepancy with the blocked layout in this case; fall back to
+  // the linear-layout-based converter.
+  // TODO: split the batch and non-K dimension iterations into "repeat" and
+  // "inside-repeat" parts and pack them into the LLVM structure according
+  // to repeat and register order.
+  // See FMA.cpp:getValueTableFromStructFMA for reference.
+  if (numBTiles != 1 || numNonKTiles != 1)
+    return Value();
+
   auto perThreadShape =
       getElemsPerThreadInOp(opTensorShape, shapePerCTATile, sizePerThread);
 
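The two hunks above change the failure mode of the legacy shared-to-FMA-operand converter: instead of aborting via report_fatal_error, loadFMAOp now declines unsupported cases (CGA-split tensors and multi-tile repetitions) by returning an empty Value, letting the lowering fall through to the linear-layout path. A caller-side sketch under that assumption; the surrounding pattern code is hypothetical:

// Sketch: an empty Value() from loadFMAOp means "not handled here", not an
// error; the pattern reports match failure so a more general, linear-layout
// based pattern can convert the op instead.
Value res = loadFMAOp(srcVal, llVal, dLayout, thread, loc, typeConverter,
                      rewriter, dotOpNo);
if (!res)
  return failure();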
102 changes: 72 additions & 30 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp
@@ -13,24 +13,51 @@ using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
 using ::mlir::triton::gpu::getShapePerCTA;
 using ::mlir::triton::gpu::getSizePerThread;
 
-using ValueTableFMA = std::map<std::tuple<int, int, int>, Value>;
+/// \brief Spatial position of repetition and register of a given value.
+struct OperandValueKey {
+  unsigned bRepIdx, nonKRepIdx;
+  unsigned bIdx, nonKIdx, kIdx;
+
+  bool operator==(const OperandValueKey &other) const {
+    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
+            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
+            kIdx == other.kIdx);
+  }
+};
+
+template <> struct std::hash<OperandValueKey> {
+  std::size_t operator()(const OperandValueKey &k) const {
+    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
+                              k.kIdx);
+  }
+};
+
+using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
 
-static ValueTableFMA
-getValueTableFromStructFMA(Value val, ArrayRef<unsigned> perTileShape,
-                           unsigned kDim, unsigned nonKDim,
-                           ConversionPatternRewriter &rewriter, Location loc,
-                           ArrayRef<unsigned> order) {
+static ValueTableFMA getValueTableFromStructFMA(
+    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
+    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
+    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
   ValueTableFMA res;
   auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perTileShape.size() == 3);
-  assert(elems.size() == product(perTileShape));
+  assert(perRepShape.size() == 3);
+  auto numElemsRep = product(perRepShape);
+  assert(elems.size() == numElemsRep * product(repetitions));
   assert(kDim == 1 || kDim == 2);
   assert(nonKDim == 1 || nonKDim == 2);
   const unsigned bDim = 0;
 
   for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto spatialIdx = mlir::LLVM::delinearize(idx, perTileShape, order);
-    res[{spatialIdx[bDim], spatialIdx[nonKDim], spatialIdx[kDim]}] = elems[idx];
+    auto inRepLinearIdx = idx % numElemsRep;
+    auto repLinearIdx = idx / numElemsRep;
+    auto inRepSpatialIdx =
+        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
+    auto repSpatialIdx =
+        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
+    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
+                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
+                        inRepSpatialIdx[kDim]};
+    res[key] = elems[idx];
   }
   return res;
 }
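To see what the rep/in-rep decomposition above does with concrete numbers, here is a small self-contained sketch (plain C++; delinearize here is a simplified stand-in for mlir::LLVM::delinearize, and the shapes are assumed, not taken from the PR):

// index_split_demo.cpp -- illustrative only.
#include <array>
#include <cstdio>

// Decompose a linear index into a 3-d index, consuming dimensions in the
// given order (order[0] varies fastest).
std::array<unsigned, 3> delinearize(unsigned lin,
                                    const std::array<unsigned, 3> &shape,
                                    const std::array<unsigned, 3> &order) {
  std::array<unsigned, 3> idx{};
  for (unsigned d : order) {
    idx[d] = lin % shape[d];
    lin /= shape[d];
  }
  return idx;
}

int main() {
  // Assumed operand tile: perRepShape = {batch, nonK, K} = {1, 2, 4},
  // repetitions = {1, 2, 1}, so 16 elements per thread in total.
  const std::array<unsigned, 3> perRepShape{1, 2, 4};
  const std::array<unsigned, 3> repetitions{1, 2, 1};
  const std::array<unsigned, 3> order{2, 1, 0}; // K varies fastest
  const unsigned numElemsRep = 1 * 2 * 4;
  for (unsigned idx = 0; idx < numElemsRep * 2; ++idx) {
    auto inRep = delinearize(idx % numElemsRep, perRepShape, order);
    auto rep = delinearize(idx / numElemsRep, repetitions, order);
    std::printf("elem %2u -> rep(b=%u, nonK=%u), reg(b=%u, nonK=%u, k=%u)\n",
                idx, rep[0], rep[1], inRep[0], inRep[1], inRep[2]);
  }
  return 0;
}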
@@ -54,46 +81,61 @@ LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
 
   BlockedEncodingAttr dLayout =
       cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  auto order = expandMatrixOrderWithBatch(dLayout.getOrder());
+  // TODO: process the A and B operands separately.
+  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
+  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
   auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
 
   Value llA = adaptor.getA();
   Value llB = adaptor.getB();
 
   auto sizePerThread =
       expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
+  auto numElemsPerThread = product(sizePerThread);
   auto shapePerCTATile =
       expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
 
   unsigned K = aShapePerCTA[2];
 
-  unsigned perThreadShape[3];
+  unsigned threadTileShape[3];
+  unsigned repetitions[3];
   for (int i = 0; i < 3; ++i) {
-    unsigned numRep = dShapePerCTA[i] / shapePerCTATile[i];
-    numRep = std::max(static_cast<unsigned>(1), numRep);
-    perThreadShape[i] = numRep * sizePerThread[i];
+    repetitions[i] =
+        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
   }
 
   auto has = getValueTableFromStructFMA(
-      llA, {perThreadShape[0], perThreadShape[1], K},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, order);
+      llA, {sizePerThread[0], sizePerThread[1], K},
+      {repetitions[0], repetitions[1], 1},
+      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
   auto hbs = getValueTableFromStructFMA(
-      llB, {perThreadShape[0], K, perThreadShape[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, order);
+      llB, {sizePerThread[0], K, sizePerThread[2]},
+      {repetitions[0], 1, repetitions[2]},
+      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
 
   SmallVector<Value> acc = cc;
 
-  for (unsigned b = 0; b < perThreadShape[0]; ++b)
-    for (unsigned m = 0; m < perThreadShape[1]; ++m)
-      for (unsigned n = 0; n < perThreadShape[2]; ++n) {
-        SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-        unsigned linearAccumIdx =
-            linearize(multiDimAccumIdx, perThreadShape, order);
-        for (unsigned k = 0; k < K; ++k) {
-          acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-              loc, has[{b, m, k}], hbs[{b, n, k}], acc[linearAccumIdx]);
-        }
-      }
+  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
+    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
+      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
+        for (unsigned b = 0; b < sizePerThread[0]; ++b)
+          for (unsigned m = 0; m < sizePerThread[1]; ++m)
+            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
+              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
+              unsigned linearInRepIdx =
+                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
+              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
+              unsigned linearRepIdx =
+                  linearize(multiDimRepIdx, repetitions, repOrder);
+              unsigned linearAccumIdx =
+                  linearInRepIdx + linearRepIdx * numElemsPerThread;
+              for (unsigned k = 0; k < K; ++k) {
+                auto aOp = has[{bRep, mRep, b, m, k}];
+                auto bOp = hbs[{bRep, nRep, b, n, k}];
+                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
+                    loc, aOp, bOp, acc[linearAccumIdx]);
+              }
+            }
 
   auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
   rewriter.replaceOp(op, res);
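The accumulator indexing in the loop nest above keeps the per-thread elements of one repetition contiguous, with whole repetitions stacked after one another. A worked instance with assumed sizes and orders (not taken from the PR):

// Assume sizePerThread = {1, 2, 2} (so numElemsPerThread = 4),
// repetitions = {1, 2, 1}, and inRepOrder = repOrder = {2, 1, 0}.
// For (bRep, mRep, nRep) = (0, 1, 0) and (b, m, n) = (0, 1, 1):
//   linearInRepIdx = linearize({0, 1, 1}, {1, 2, 2}, {2, 1, 0}) = 3
//   linearRepIdx   = linearize({0, 1, 0}, {1, 2, 1}, {2, 1, 0}) = 1
//   linearAccumIdx = 3 + 1 * 4 = 7
// i.e. repetition 1 occupies acc[4..7], mirroring the rep/in-rep split used
// by getValueTableFromStructFMA for the operands.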
43 changes: 1 addition & 42 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -119,54 +119,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
   }
 
-  // FIXME [Dot LL]
-  // Do for all DotOperandEncodingAttr once we have LLs for all of them
-  static bool isSupportedLayout(Attribute dstLayout) {
-    if (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
-            LinearEncodingAttr>(dstLayout))
-      return true;
-    if (auto dot = dyn_cast<DotOperandEncodingAttr>(dstLayout)) {
-      if (isa<MmaEncodingTrait>(dot.getParent()))
-        return true;
-    }
-    return false;
-  };
-
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    if (isSupportedLayout(dstLayout)) {
-      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
-                                      rewriter);
-    }
-    if (isa<DotOperandEncodingAttr>(dstLayout) &&
-        isa<BlockedEncodingAttr>(
-            cast<DotOperandEncodingAttr>(dstLayout).getParent())) {
-      return lowerSharedToDotOpFMA(op, adaptor, getTypeConverter(), rewriter);
-    }
-    return failure();
+    return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
   }
 
 private:
-  LogicalResult
-  lowerSharedToDotOpFMA(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
-                        const LLVMTypeConverter *typeConverter,
-                        ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
-    auto blockedLayout = cast<BlockedEncodingAttr>(
-        cast<DotOperandEncodingAttr>(dstLayout).getParent());
-    auto thread = getThreadId(rewriter, loc);
-    Value res = SharedToDotOperandFMA::convertLayout(
-        dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
-        thread, loc, getTypeConverter(), rewriter);
-    rewriter.replaceOp(op, res);
-    return success();
-  }
   LogicalResult
   lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
79 changes: 63 additions & 16 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -239,8 +239,11 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }
 
-LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
-                             ArrayRef<unsigned> warpOrder, unsigned inner) {
+/// Function to generate lane and warp layout for dot operands.
+LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
+                                         ArrayRef<unsigned> shape,
+                                         ArrayRef<unsigned> order,
+                                         unsigned kDim, StringAttr inDimName) {
   // Let warpsPerCTAMma = {2, 2}, then
   // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
   // assume warpOrder = {1, 0}
@@ -255,24 +258,23 @@ LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
   //    - - | - -       - - | - -
   //    2 3 | 2 3       0 2 | 1 3
   // In other words, we need to broadcast along K
-  auto rank = warpShape.size();
+  auto rank = shape.size();
   auto dimNames = standardOutDimNames(ctx, rank);
-  LinearLayout warpLayout = LinearLayout::empty();
+  LinearLayout layout = LinearLayout::empty();
 
   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
   // For B, when moving along N we go from 0 to 1.
   // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
   // Same happens if the warpOrder is {0, 1}, like in Hopper
-  for (auto d : warpOrder) {
-    if (d == inner) {
-      warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
+  for (auto d : order) {
+    if (d == kDim) {
+      layout *= LinearLayout::zeros1D(shape[d], inDimName, dimNames[d]);
     } else {
-      warpLayout *=
-          LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
+      layout *= LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
     }
   }
-  return warpLayout;
+  return layout;
 }
 
 } // anonymous namespace
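Concretely, for shape = {2, 2}, order = {1, 0}, and kDim = 1, the loop above multiplies a zero basis for the K dimension with an identity basis for the other dimension. A hand-worked sketch of the resulting mapping, assuming the zeros1D/identity1D semantics shown in this hunk:

// layout = zeros1D(2, inDimName, dim1) * identity1D(2, inDimName, dim0)
// Input id w (2 bits, e.g. a warp id):
//   bit 0 (consumed by zeros1D)    -> no output offset
//   bit 1 (consumed by identity1D) -> offset 1 along dim0
// So ids {0, 1} map to (dim0 = 0, dim1 = 0) and ids {2, 3} map to
// (dim0 = 1, dim1 = 0): inputs that differ only along K cover the same
// elements, which is exactly the broadcast in the A/B tilings sketched above.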
@@ -620,7 +622,8 @@ wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
   // Generate warp layout
   auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
-  LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
+  LinearLayout warpLayout =
+      broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, S("warp"));
 
   // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
   // extension of layout
@@ -650,6 +653,48 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
 
+std::optional<LinearLayout>
+fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
+                     ArrayRef<int64_t> shape) {
+  int rank = shape.size();
+  auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
+  MLIRContext *ctx = operandLayout.getContext();
+
+  // TODO: introduce registerOrder or use getOrder(operandLayout).
+  // Currently this order is used in the legacy converter, because we do not
+  // have access to the full dot operand layout, only the parent part.
+  auto regOrder = blocked.getOrder();
+  // TODO: use operandLayout.getThreadOrder().
+  auto threadOrder = blocked.getThreadOrder();
+  auto warpOrder = blocked.getWarpOrder();
+  auto repOrder = blocked.getRepOrder();
+
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  SmallVector<unsigned> threadSize = blocked.getSizePerThread();
+  auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  threadSize[kDimIdx] = shape[kDimIdx];
+  auto threadShape = blocked.getThreadsPerWarp();
+  auto warpShape = blocked.getWarpsPerCTA();
+
+  SmallVector<StringAttr> repDimNames =
+      permuteDimNames(standardOutDimNames(ctx, rank), repOrder);
+
+  auto registersLayout = identityStandardND(kReg, threadSize, regOrder);
+  auto lanesLayout = broadcastedDotOperandLayout(ctx, threadShape, threadOrder,
+                                                 kDimIdx, kLane);
+  auto warpsLayout =
+      broadcastedDotOperandLayout(ctx, warpShape, warpOrder, kDimIdx, kWarp);
+
+  LinearLayout ctaLayout = registersLayout.transposeOuts(repDimNames) *
+                           lanesLayout.transposeOuts(repDimNames) *
+                           warpsLayout.transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(operandLayout), shape);
+}
+
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
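One detail worth spelling out: fmaDotToLinearLayout widens the register tile along K to the full tensor extent (threadSize[kDimIdx] = shape[kDimIdx]), because in the FMA lowering each thread keeps its entire K slice in registers while lanes and warps are broadcast along K. A numeric sketch with assumed shapes, for illustration only:

// Assume opIdx = 0 (operand A), shape = {32, 16} (M x K), and a blocked
// parent with sizePerThread = {2, 1}. Then kDimIdx = 1, and the register
// tile becomes threadSize = {2, 16}: two M rows times all 16 K values per
// thread, while lanesLayout and warpsLayout contribute zero bases along K.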
@@ -740,19 +785,21 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
   auto ctaLayout =
       nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
   auto kDim = isA ? rank - 1 : rank - 2;
-  ctaLayout *=
-      warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
-          .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
+  ctaLayout *= broadcastedDotOperandLayout(ctx, mma.getWarpsPerCTA(),
+                                           mma.getWarpOrder(), kDim, S("warp"))
+                   .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
 
 std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
-  if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
+  if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
+    return fmaDotToLinearLayout(*this, shape);
+  } else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
-  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+  } else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
     return wmmaDotOperandToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     return nvidiaDotToLinearLayout(shape, *this);
4 changes: 2 additions & 2 deletions test/Conversion/amd/decompose-unsupported-conversions.mlir
@@ -97,11 +97,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx940", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<32x32xf16, #blocked>) {
+  tt.func @neg_blocked_to_dot_op_incompatible_warp_gfx940(%arg0: tensor<128x128xf16, #blocked>) {
     // CHECK-NOT: ttg.convert_layout
     // CHECK: ttg.local_alloc
     // CHECK: ttg.local_load
-    %0 = ttg.convert_layout %arg0 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
+    %0 = ttg.convert_layout %arg0 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked1}>>
     tt.return
   }
 }