[LAYOUTS] Fix LinearEncodingAttr interpretation of elemsPerThread (#5476)

LinearEncodingAttr's interpretation of `elemsPerThread` does not match the TypeConverter's. For `register = [[0,0],[0,0],[0,0],[0,0]]`, for example, LinearEncodingAttr reports 1 element per thread, while the lowering expects 16 (a similar mismatch shows up for `tt.reshape`).

This fixes the mismatch by not skipping broadcasted bases when computing `::getElemsPerThread` for registers.
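
To make the mismatch concrete, here is a small standalone sketch (not the Triton sources; the simplified `basesPerDim` signature and the `main` driver are illustrative assumptions) of the counting logic this commit touches. With skipBroadcast = true, the four all-zero register bases contribute nothing and the layout reports a single element per thread; with skipBroadcast = false, each broadcast basis still doubles a dimension's register count, giving the 2^4 = 16 elements the lowering expects.

// Standalone sketch of the basesPerDim counting logic (illustrative only).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

using Basis = std::vector<int>;

std::vector<unsigned> basesPerDim(const std::vector<Basis> &bases, size_t rank,
                                  bool skipBroadcast = true) {
  std::vector<unsigned> ret(rank, 1);
  auto nonZero = [](int v) { return v != 0; };
  int nonZeroIdx = 0;
  for (const auto &basis : bases) {
    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
    if (it != basis.end()) {
      // Non-broadcast basis: doubles the dimension it points into.
      nonZeroIdx = std::distance(basis.begin(), it);
      ret[nonZeroIdx] *= 2;
    } else if (!skipBroadcast) {
      // Broadcast (all-zero) basis: still a distinct register per thread,
      // so count it against the last non-zero dimension seen.
      ret[nonZeroIdx] *= 2;
    }
  }
  return ret;
}

int main() {
  // register = [[0,0],[0,0],[0,0],[0,0]] from the added lit test below.
  std::vector<Basis> regs(4, Basis{0, 0});
  auto skipped = basesPerDim(regs, /*rank=*/2, /*skipBroadcast=*/true);
  auto counted = basesPerDim(regs, /*rank=*/2, /*skipBroadcast=*/false);
  std::printf("skipBroadcast=true : [%u, %u] -> 1 element per thread\n",
              skipped[0], skipped[1]);
  std::printf("skipBroadcast=false: [%u, %u] -> 16 elements per thread\n",
              counted[0], counted[1]);
}

The counted value is what the new lit tests check for: 16 i32 fields in the returned LLVM struct, and 16 extractvalue/insertvalue pairs for the reshape.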
Mogball authored Dec 20, 2024
1 parent f12eae0 commit 75fb922
Showing 2 changed files with 28 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1441,7 +1441,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
 
   SmallVector<unsigned> ret(rank, 1);
   auto nonZero = [](auto val) { return val != 0; };
-  int nonZeroIdx = -1;
+  int nonZeroIdx = 0;
   for (const auto &basis : bases) {
     auto it = std::find_if(basis.begin(), basis.end(), nonZero);
     // Bases can have one or zero non-zero elements
@@ -1453,7 +1453,6 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
     } else if (!skipBroadcast) {
       // If we've seen a non-zero basis, we double the size of the previous dim
       // This is just needed to count the CTAsPerCGA
-      assert(nonZeroIdx != -1);
       ret[nonZeroIdx] *= 2;
     }
   }
@@ -1604,7 +1603,8 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
   // the invariant that the shape of the LL is that of the tensor
   // We choose the former for BC
   auto ll = *toLinearLayout(shape);
-  return basesPerDim(ll, StringAttr::get(getContext(), "register"));
+  return basesPerDim(ll, StringAttr::get(getContext(), "register"),
+                     /*skipBroadcast=*/false);
 }

25 changes: 25 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -2091,3 +2091,28 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m
 }
 
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}>
+#linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+
+// CHECK-LABEL: expand_dims_linear_layout
+tt.func private @expand_dims_linear_layout() -> tensor<1x4xi32, #linear> {
+  %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>>
+  %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x4xi32, #linear>
+  // CHECK: return %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+  tt.return %1 : tensor<1x4xi32, #linear>
+}
+
+// CHECK-LABEL: reshape_linear_layout_broadcasting
+tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #linear>) -> tensor<32x4x1xbf16, #blocked> {
+  // CHECK-COUNT-16: extractvalue
+  // CHECK-COUNT-16: insertvalue
+  %0 = tt.reshape %arg0 : tensor<32x4xbf16, #linear> -> tensor<32x4x1xbf16, #blocked>
+  tt.return %0 : tensor<32x4x1xbf16, #blocked>
+}
+
+}
