<div align="center">
<img src="https://lh5.googleusercontent.com/wzQKEsTFkrgNQO9JjhGH5wFvslJr1saLtLaJ_a6Fp_gNENpvt3VG7BmztwngU9hFJaU4CPwGiw1opQtDvTkLrxWRbO_a12Q-pdESWHgtmheIHcPbOL5ZMC4TSiJVe5ty1w=w3517" alt="Triton logo">
</div>

The Triton Conference is happening again on September 17th, 2024 in Fremont (CA)! If you are interested in attending, please fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSecHC1lkalcm0h3JDUbspekDX5bmBvMxgVTLaK3e-61bzDDbg/viewform).

# Warp Specialization Support


Warp specialization enhances kernel performance by utilizing an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. The data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we’ve developed a Triton DSL extension that allows users to partition their kernel into asynchronous tasks (which map to warp groups on NVIDIA GPUs), which naturally execute concurrently, leveraging the hardware’s multitasking warp scheduler. The following sections provide a breakdown of the compiler features developed to enable warp specialization.


## Asynchronous Tasks

Warp specialization is built on top of the concept of partitioning the user’s program into asynchronous tasks (referred to as "async tasks" or “tasks” in the following sections). Each async task will be executed by a standalone warp group on the supported hardware, to achieve instruction level parallelism. Optimally and automatically partitioning async tasks is quite a challenge for the compiler. As a result, the Triton DSL has been extended to allow users to perform manual partitioning.

The language extension is built around a Python context manager and is designed to be simple and intuitive. It is also platform-agnostic: on platforms where warp specialization is not supported, the user annotation is ignored, with no impact on correctness or performance.

For instance, a warp-specialized GEMM implementation might look like this:

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Use tl.async_task to specify warp-specialized code
        with tl.async_task([0]):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    # Use tl.async_task to specify warp-specialized code
    with tl.async_task([1]):
        tl.store(c_ptrs, c)
```

By wrapping a code block in a `tl.async_task` statement, the user assigns that block to the async task(s) listed in the statement, each of which is executed by its own warp group. In the example above, the load operations are assigned to task 0, while the store operation is handled by task 1. Operations explicitly assigned a task id in this way are known as anchor operations, and they determine the task assignment for the remaining operations.


The compiler assigns the non-anchor operations to tasks as follows:

- Control dependencies exclusive to an anchor operation are included in the same task as the anchor operation.
- Data dependencies exclusive to an anchor operation are included in the same task as the anchor operation, unless they are themselves anchor operations.
- Control or data dependencies shared between tasks are included in all of those tasks.

For the GEMM example above, the compiler computes a task scheme and annotates it in the IR using MLIR attributes. To illustrate this more clearly, let's use source code annotations. After task propagation:

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)                        # async_task 0, 1
    num_pid_m = tl.cdiv(M, BLOCK_M)                    # async_task 0, 1
    num_pid_n = tl.cdiv(N, BLOCK_N)                    # async_task 0, 1
    pid_m = pid // num_pid_n                           # async_task 0, 1
    pid_n = pid % num_pid_n                            # async_task 0, 1
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)   # async_task 0, 1
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)   # async_task 0, 1
    offs_k = tl.arange(0, BLOCK_K)                     # async_task 0
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # async_task 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)                          # async_task 1
    for k in range(0, tl.cdiv(K, BLOCK_K)):            # async_task 0, 1
        a = tl.load(a_ptrs)                            # async_task 0
        b = tl.load(b_ptrs)                            # async_task 0
        acc += tl.dot(a, b)                            # async_task 1
        a_ptrs += BLOCK_K * stride_ak                  # async_task 0
        b_ptrs += BLOCK_K * stride_bk                  # async_task 0
    c = acc.to(tl.float16)                             # async_task 1
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]    # async_task 1
    tl.store(c_ptrs, c)                                # async_task 1
```

## Data Partitioning

To further improve performance, the user may choose to split the same workload across two async tasks. This way, when one task is blocked on a heavy computation (e.g., the dot operation), the other task can execute other operations in parallel. This can easily be achieved by annotating the store operation with two tasks:

```python
with tl.async_task([1, 2]):
    tl.store(c_ptrs, c)
```

The compiler determines how to divide the work between the two tasks to maximize performance. On the H100 GPU, the compiler will, by default, attempt to split the input tensor A along the M dimension so that each consumer computes half of the output tensor independently. This approach is known as cooperative partitioning. If this split is not advantageous—for instance, if it results in a smaller-than-native `wgmma` instruction—the compiler will instead attempt to split along the N dimension.

With a configured tile size of [128, 256, 64], the transformed code for the above GEMM kernel looks like the following (again using source annotations instead of IR for illustration):

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)                                    # async_task 0, 1, 2
    num_pid_m = tl.cdiv(M, BLOCK_M)                                # async_task 0, 1, 2
    num_pid_n = tl.cdiv(N, BLOCK_N)                                # async_task 0, 1, 2
    pid_m = pid // num_pid_n                                       # async_task 0, 1, 2
    pid_n = pid % num_pid_n                                        # async_task 0, 1, 2
    offs_m_1 = pid_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)        # async_task 0, 1, 2
    offs_m_2 = pid_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)  # async_task 0, 1, 2
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)               # async_task 0, 1, 2
    offs_k = tl.arange(0, BLOCK_K)                                 # async_task 0
    a_ptrs_1 = a_ptr + (offs_m_1[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    a_ptrs_2 = a_ptr + (offs_m_2[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)      # async_task 0
    acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)    # async_task 1
    acc_2 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)    # async_task 2
    for k in range(0, tl.cdiv(K, BLOCK_K)):                        # async_task 0, 1, 2
        a_1 = tl.load(a_ptrs_1)                                    # async_task 0
        a_2 = tl.load(a_ptrs_2)                                    # async_task 0
        b = tl.load(b_ptrs)                                        # async_task 0
        acc_1 += tl.dot(a_1, b)                                    # async_task 1
        acc_2 += tl.dot(a_2, b)                                    # async_task 2
        a_ptrs_1 += BLOCK_K * stride_ak                            # async_task 0
        a_ptrs_2 += BLOCK_K * stride_ak                            # async_task 0
        b_ptrs += BLOCK_K * stride_bk                              # async_task 0
    c_1 = acc_1.to(tl.float16)                                     # async_task 1
    c_2 = acc_2.to(tl.float16)                                     # async_task 2
    c_ptrs_1 = c_ptr + stride_cm * offs_m_1[:, None] + stride_cn * offs_n[None, :]    # async_task 1
    c_ptrs_2 = c_ptr + stride_cm * offs_m_2[:, None] + stride_cn * offs_n[None, :]    # async_task 2
    tl.store(c_ptrs_1, c_1)                                        # async_task 1
    tl.store(c_ptrs_2, c_2)                                        # async_task 2
```

## Code Partitioning

We assume all operations are already marked with a list of taskIds. We first find all communication channels required between warp groups. Each channel starts at a load operation with a single taskId and ends at a direct user of the load that belongs to a different taskId. For `ForOps` containing a communication channel, we add two additional arguments: `phase` and `bufferIndex`.
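
As a rough illustration of this analysis (a toy Python model, not the actual compiler pass), a channel is any def-use edge from a load owned by a single task to a user owned by a different task; the op names below are made up for the GEMM example:

```python
# Toy model of channel detection: a channel starts at a load with exactly one
# taskId and ends at a direct user that carries a different taskId.
def find_channels(ops):
    """ops maps op name -> (set of taskIds, list of direct user op names)."""
    channels = []
    for name, (task_ids, users) in ops.items():
        if not name.startswith("load") or len(task_ids) != 1:
            continue
        (producer_task,) = task_ids
        for user in users:
            user_tasks, _ = ops[user]
            channels += [(name, user, producer_task, t) for t in user_tasks if t != producer_task]
    return channels

# The cooperatively partitioned GEMM above: loads in task 0, dots in tasks 1 and 2.
ops = {
    "load_a_1": ({0}, ["dot_1"]),
    "load_a_2": ({0}, ["dot_2"]),
    "load_b":   ({0}, ["dot_1", "dot_2"]),
    "dot_1":    ({1}, []),
    "dot_2":    ({2}, []),
}
print(find_channels(ops))
# [('load_a_1', 'dot_1', 0, 1), ('load_a_2', 'dot_2', 0, 2),
#  ('load_b', 'dot_1', 0, 1), ('load_b', 'dot_2', 0, 2)]
```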

We introduce a tuning configuration, `num_buffers_warp_spec`. For each communication channel inside a `ForOp`, we use an array of buffers in SMEM to hold the results, with the size of the array determined by `num_buffers_warp_spec`, along with an array of barriers. In this pass, four new operations are introduced to correctly synchronize the producer and the consumer: `ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`. Each of the four new ops takes a token and a buffer index; `ProducerAcquireOp` and `ConsumerWaitOp` take an additional phase operand.
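
As a minimal sketch (not compiler code) of the buffering scheme, assuming the usual circular-buffer convention: iteration `k` of the loop uses buffer `k % num_buffers`, and the phase bit flips each time the buffer index wraps around, which is what lets `ProducerAcquireOp` and `ConsumerWaitOp` tell a freshly filled buffer from a stale one.

```python
# Conceptual model of how bufferIndex and phase evolve across loop iterations
# for a channel with num_buffers_warp_spec buffers; not the generated IR.
def buffer_slot(k, num_buffers):
    buffer_index = k % num_buffers
    phase = (k // num_buffers) % 2  # flips every time the index wraps
    return buffer_index, phase

num_buffers_warp_spec = 3
print([buffer_slot(k, num_buffers_warp_spec) for k in range(8)])
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0), (1, 0)]
```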


For `ForOps` with multiple taskIds, we clone one copy for each taskId; each copy contains only the operations with that taskId. In the end, we create multiple `IfOps`, one for each possible taskId. We walk the body of the function, clone each op once for every taskId attached to it, and place the clone in the corresponding `IfOp`.
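
The effect of the cloning can be sketched with a small toy model (plain Python, not the MLIR pass itself): every op is copied into the region of each taskId attached to it, yielding one op list per `IfOp` body.

```python
from collections import defaultdict

def partition_by_task(ops):
    """ops is a list of (op_name, set_of_taskIds); returns taskId -> cloned op list."""
    regions = defaultdict(list)
    for name, task_ids in ops:
        for t in sorted(task_ids):
            regions[t].append(name)  # clone the op into every task it is attached to
    return dict(regions)

ops = [
    ("tl.load a",  {0}),
    ("tl.load b",  {0}),
    ("tl.dot",     {1, 2}),
    ("tl.store c", {1, 2}),
]
print(partition_by_task(ops))
# {0: ['tl.load a', 'tl.load b'], 1: ['tl.dot', 'tl.store c'], 2: ['tl.dot', 'tl.store c']}
```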

To adjust register usage, we introduce two new ops, `RegAllocOp` and `RegDeallocOp`, both taking an integer operand. For each warp group, we insert either `RegAllocOp` or `RegDeallocOp`. The current heuristic is simple: if the taskId is 0 (the producer), we add `RegDeallocOp`; otherwise we use `RegAllocOp`. The amount of register adjustment can be tuned via `reg_dec_producer` and `reg_inc_consumer`.
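
As a usage sketch only: assuming the tuning knobs named above (`num_buffers_warp_spec`, `reg_dec_producer`, `reg_inc_consumer`) are exposed as extra keyword arguments on `triton.Config` in a warp-specialization-enabled build (the exact surface may differ), an autotune setup might look like the following, with placeholder values rather than recommended settings:

```python
# Hypothetical autotune configuration; the warp-specialization keyword arguments
# are assumptions based on the knob names in this document, not a guaranteed API.
import triton
import triton.language as tl

configs = [
    triton.Config(
        {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64},
        num_stages=3,
        num_warps=8,
        num_buffers_warp_spec=3,  # buffers per communication channel
        reg_dec_producer=24,      # registers released by the producer group
        reg_inc_consumer=240,     # registers granted to each consumer group
    ),
]

@triton.autotune(configs=configs, key=["M", "N", "K"])
@triton.jit
def matmul_persistent_ws_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                                stride_am, stride_ak, stride_bk, stride_bn,
                                stride_cm, stride_cn,
                                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                                BLOCK_K: tl.constexpr):
    ...
```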

This pass also lowers `loadOp`s to `AsyncTMACopyGlobalToLocalOp` or `AsyncCopyGlobalToLocalOp`, so the communication can be expressed via SMEM. For TMA loads, the producer side becomes `ProducerAcquireOp` -> `barrier_expect` -> `AsyncTMACopyGlobalToLocalOp`, and the consumer side becomes `wait_barrier` -> ops -> `ConsumerReleaseOp`. For non-TMA loads, the producer side becomes `ProducerAcquireOp` -> `AsyncCopyGlobalToLocalOp` -> `ProducerCommitOp`, and the consumer side becomes `ConsumerWaitOp` -> ops -> `ConsumerReleaseOp`.

| **`Documentation`** | **`Nightly Wheels`** |
|-------------------- | -------------------- |
| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |


# More about Triton

This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
