Skip to content

Commit

Permalink
[CodePartition] Optimize the placement of consumer releases (#11)
Browse files Browse the repository at this point in the history
Fix the inefficient placement of consumer releases when producer and
consumer are not in the same scope. For example, given

```
Q = tl.load
for (..) 
   K = tl.load
   QK = dot(Q, K)
   ...
tl.store
```

Previously the consumer release corresponding to `Q` was placed after
the store. With the current fix the release would go right after the for
loop.


`TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only
triton_tutorial_flash_v2_tma_ws,triton_tutorial_flash_v2_tma_ws_persistent,triton_tutorial_flash_v2
--num-inputs 1 --seq-len 10 --metrics tflops --batch 1024 --n-heads 4
--d-head 128 --cudagraph`

Before:

```
(Batch, Heads, SeqLen, Dhead) triton_tutorial_flash_v2_tma_ws-tflops
triton_tutorial_flash_v2_tma_ws_persistent-tflops
triton_tutorial_flash_v2-tflops
------------------------------- ----------------------------------------
---------------------------------------------------
---------------------------------
(1024, 4, 1024, 128) 393.141 400.046 366.498

```
After:

```
(Batch, Heads, SeqLen, Dhead) triton_tutorial_flash_v2_tma_ws-tflops
triton_tutorial_flash_v2_tma_ws_persistent-tflops
triton_tutorial_flash_v2-tflops
------------------------------- ----------------------------------------
---------------------------------------------------
---------------------------------
(1024, 4, 1024, 128) 396.43 422.847 363.753

```
  • Loading branch information
htyu authored Dec 17, 2024
1 parent ab5f7c2 commit 8706035
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
46 changes: 37 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1302,6 +1302,7 @@ optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder,
// "channelsGroupedByConsumers". tokenMap tracks the set of tokens for each
// channel.
void insertAsyncComm(
triton::FuncOp funcOp,
const DenseMap<Channel *, SmallVector<Channel *>>
&channelsGroupedByConsumers,
const DenseMap<Channel *, DenseMap<int, Value>> &tokenMap,
Expand All @@ -1324,14 +1325,41 @@ void insertAsyncComm(
int consumerAsyncTaskId) -> Operation * {
if (c->getBlock() != p->getBlock())
return getSameLevelOp(p, c);
for (auto it = c->getBlock()->rbegin(); it != c->getBlock()->rend(); ++it) {
if (!it->hasAttr("async_task_id"))
continue;
auto asyncAttr = it->getAttrOfType<DenseIntElementsAttr>("async_task_id")
.getValues<int>();
if (asyncAttr.size() == 1 && asyncAttr[0] == consumerAsyncTaskId)
return &(*it);

// Find a common place for all users of the consumer, which would be the
// common post dominator.
mlir::PostDominanceInfo dom(funcOp);
std::unordered_set<Operation *> mutuallyNonDominatingUsers;
for (auto user : c->getUsers()) {
auto it = mutuallyNonDominatingUsers.begin();
while (it != mutuallyNonDominatingUsers.end()) {
if (dom.properlyPostDominates(user, *it)) {
it = mutuallyNonDominatingUsers.erase(it);
} else if (dom.properlyPostDominates(*it, user)) {
break;
} else {
++it;
}
}
if (it == mutuallyNonDominatingUsers.end())
mutuallyNonDominatingUsers.insert(user);
}

if (mutuallyNonDominatingUsers.size() == 1) {
// Find the common parent of this user and c
auto user = *mutuallyNonDominatingUsers.begin();
while (user && user->getParentOp() != c->getParentOp())
user = user->getParentOp();
assert(user && "Failed to find common parent of this user and c");
return user;
}

for (auto &op : reverse(c->getBlock()->getOperations())) {
auto asyncTasks = getAsyncTaskIds(&op);
if (asyncTasks.size() == 1 && asyncTasks[0] == consumerAsyncTaskId)
return &op;
}

return nullptr;
};

Expand Down Expand Up @@ -1623,8 +1651,8 @@ class TritonGPUWSCodePartitionPass

// Step 6: add async communication ops (ProducerAcquire etc). Also lower the
// loads.
insertAsyncComm(channelsGroupedByConsumers, tokenMap, barrierAllocMap,
bufferMap, numConsumerGroups);
insertAsyncComm(funcOp, channelsGroupedByConsumers, tokenMap,
barrierAllocMap, bufferMap, numConsumerGroups);
LLVM_DEBUG({
LDBG("\n\nwith SyncOps");
funcOp.dump();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: scf.for
// CHECK: triton_gpu.local_load
// CHECK: triton_nvidia_gpu.consumer_wait
// CHECK: tt.experimental_descriptor_store
// CHECK: triton_nvidia_gpu.consumer_release
// CHECK: tt.experimental_descriptor_store

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
Expand Down

0 comments on commit 8706035

Please sign in to comment.