[LAYOUTS] Enable ReduceOp lowering with LinearEncodingAttr (#5477)
ReduceOp's lowering does not yet advertise support for linear layouts, so
propagating a linear layout across `tt.reduce` ops causes codegen to crash.
The codegen routine itself is generic enough to handle linear layouts, so
simply mark them as supported and add a few tests.
Mogball authored Dec 20, 2024
1 parent 75fb922 commit 755d416
Showing 3 changed files with 96 additions and 4 deletions.
7 changes: 3 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -219,15 +219,14 @@ bool ReduceOpHelper::isSupportedLayout() {
 }
 
   auto srcLayout = getSrcLayout();
-  if (isa<BlockedEncodingAttr>(srcLayout)) {
+  if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
+          srcLayout)) {
     return true;
   }
 
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
     return mmaLayout.supportReduction();
   }
-  if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
-    return true;
-  }
   return false;
 }
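
The crash mentioned in the commit message comes from a guard on this helper in the reduction lowering. A rough, hypothetical sketch of that call site (the function name, namespaces, and assertion text are assumptions; only `ReduceOpHelper::isSupportedLayout()` is taken from the diff above):

// Hypothetical call-site sketch: the function name, namespaces, and
// assertion text are assumptions; only ReduceOpHelper::isSupportedLayout()
// comes from the diff above.
#include <cassert>

#include "triton/Analysis/Utility.h"

static void lowerReduceSketch(mlir::triton::ReduceOp op) {
  mlir::ReduceOpHelper helper(op);
  // Before this patch a LinearEncodingAttr source fell through to the
  // `return false` above, so codegen hit this assertion instead of
  // emitting the reduction.
  assert(helper.isSupportedLayout() && "unsupported tt.reduce source layout");
  // ... the generic in-thread / intra-warp lowering would follow here ...
}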

70 changes: 70 additions & 0 deletions test/Conversion/reduce_to_llvm.mlir
@@ -0,0 +1,70 @@
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s

#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: @reduce_linear_layout
tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> {
// CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0
// CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1
// CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2
// CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3

// The layout looks like
// [[ T0:0, T32:0, T0:1, T32:1, ...
// [ T4:0, T36:0, T4:1, T36:1, ...
// [ T0:2, T32:2, T0:3, T32:3, ...
// [ T4:2, T36:2, T4:3, T36:3,
// ...
//
// A reduction along axis=0 consists of adding registers (0, 2) and (1, 3)
// before shuffling.
//
// Columns along axis=0 are contained within a warp, so reduction across warps
// is not needed. (A sketch deriving the shuffle offsets from these bases
// follows this test file.)

// Reduce within threads
// CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]]
// CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]]

// Reduce within warp.
// CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31)
// CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]]
// CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31)
// CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]]
// CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31)
// CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]]
// CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31)
// CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]]

// CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31)
// CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]]
// CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31)
// CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]]
// CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31)
// CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]]
// CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31)
// CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]]

// CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0
// CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1

%0 = "tt.reduce"(%arg0) ({
^bb0(%arg1: i32, %arg2: i32):
%1 = arith.addi %arg1, %arg2 : i32
tt.reduce.return %1 : i32
}) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>

// CHECK-NEXT: ret { i32, i32 } [[DST1]]
tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
}

tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) {
%0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
%1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)>
llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr
tt.return
}

}
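
The in-thread adds and the four butterfly-shuffle offsets in the CHECK lines follow directly from the layout's basis vectors. As an illustration (a standalone sketch, not Triton code; only the basis vectors are copied from the test above):

// Standalone illustration of how the #linear bases above determine the
// emitted reduction steps for axis = 0. The basis vectors are copied from
// the test; the derivation itself is a sketch, not Triton's implementation.
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Each basis maps one bit of the register index / lane id to an offset
  // {row, column} in the 32x16 tensor.
  std::vector<std::pair<int, int>> registerBases = {{0, 2}, {2, 0}};
  std::vector<std::pair<int, int>> laneBases = {
      {0, 8}, {8, 0}, {1, 0}, {4, 0}, {16, 0}};

  // A register bit whose basis has a nonzero row component pairs values the
  // thread can add locally: bit 1 ({2, 0}) pairs registers (0, 2) and (1, 3).
  for (std::size_t bit = 0; bit < registerBases.size(); ++bit)
    if (registerBases[bit].first != 0)
      std::printf("in-thread add across register bit %zu\n", bit);

  // A lane bit whose basis has a nonzero row component needs a butterfly
  // shuffle with offset 2^bit. Lane bit 0 ({0, 8}) moves along columns only,
  // so the offsets are {2, 4, 8, 16}: the same four
  // llvm.nvvm.shfl.sync.bfly.i32 offsets checked above, just in a
  // different order.
  for (std::size_t bit = 0; bit < laneBases.size(); ++bit)
    if (laneBases[bit].first != 0)
      std::printf("shfl.bfly offset %d\n", 1 << bit);

  // Both warp bases ({0, 1} and {0, 4}) move along columns only, so no
  // cross-warp (shared memory) step is required.
  return 0;
}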
23 changes: 23 additions & 0 deletions test/TritonGPU/combine.mlir
@@ -2829,3 +2829,26 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {
}

}

// -----

#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}>
#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: reduce_linear_layouts
tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> {
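// Propagating the #linear layout through tt.reduce is now legal, so the
// conversion to #blocked below should fold away and the reduce should yield
// a sliced #linear result.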
// CHECK-NOT: convert_layout
%0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked>
// CHECK-NEXT: tt.reduce
%1 = "tt.reduce" (%0) ({
^bb0(%arg1: i32, %arg2: i32):
tt.reduce.return %arg1 : i32
// CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>
}) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
}

}
