Flash Attention is a simple algorithm: tiled back-to-back matmuls with an online softmax algorithm in between. The algorithm fits in a few dozen lines of pseudocode. Yet Flash Attention 4's production kernel is 2,875 lines, and the hardest part to get right isn't the math. It's the async execution and pipelining synchronization, all hand-derived from a schedule that no standard debugging tool can verify.
In this multi-part series, we explore software pipelining for GPU kernels from first principles. We formalize dependencies as a graph, solve for the optimal schedule with a constraint solver, and show how it all integrates into MAX via pure Mojo. We've explored what this looks like in practice for matrix multiplication in Parts 2-4 of our matmul on Blackwell series.
Before diving into the details, this post gives an overview of how advanced kernels like Flash Attention 4 are expressed as dataflow pipelines. Future posts will discuss how this formalism reduces complexity and improves the composability of modern kernel designs.
Flash Attention 4 schedule complexity
Figure 1: Flash Attention 4 pipeline schedule on Blackwell (SM100) — 14 operations across 5 hardware units, T=4, 2 pipeline stages.
This is a legal pipeline schedule for Flash Attention 4 on Blackwell (SM100). It has 14 operations across 5 hardware units, with an initiation interval of T = 4 and 2 pipeline stages; we will explain what this means later. Each row is a scheduling unit, and each colored block is an operation from a specific loop iteration. The prologue fills the pipeline, the kernel runs at steady state, and the epilogue drains.
A constraint solver produced this schedule from the dependency graph below, in milliseconds:
Figure 2: FA4 dependency graph — 28 edges across same-iteration, loop-carried, and anti-dependency classes.
The graph has 28 dependency edges: 18 same-iteration data flows, 6 loop-carried recurrences, and 4 write-after-read anti-dependencies for quad-buffered shared memory. Each edge is a synchronization constraint that the Mojo production kernel (mha_sm100, 1,900+ lines) enforces by hand.
Pretty complicated, but let's rewind and explain what attention really is.
The attention mechanism
We'll use Attention as the motivating example in this post. Specifically, we'll examine the Flash Attention (Dao et al. 2022, Dao 2023, Shah et al. 2024, Zadouri et al. 2026) algorithm to showcase the GPU pipelining problem. Attention is a good case study because the algorithm is simple, but getting peak hardware utilization requires a complex pipelining scheme. We will refer to the algorithm as FA going forward, with FA3 being the algorithm applicable to Hopper and FA4 being applicable to Blackwell.
What is attention
The multi-head attention operation computes:

$$O = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

where $Q \in \mathbb{R}^{M \times d}$, $K, V \in \mathbb{R}^{N \times d}$, and output $O \in \mathbb{R}^{M \times d}$. Here $M$ is the query length (full sequence during prefill, 1 during decoding), $N$ is the KV-cache length, and $d$ is the head dimension.
As a concrete example, Qwen3-8B has $d = 128$, with 32 Q heads sharing 8 KV heads via Grouped-Query Attention (GQA, with a 4:1 Q-to-KV ratio), and context up to $N = 32{,}768$. GQA means 4 Q heads share each KV head, reducing KV-cache reads by $4\times$ for the same Q compute.
The computation has three steps: $S = QK^\top$ (matmul 1), $P = \operatorname{softmax}(S)$ (row-wise normalization), $O = PV$ (matmul 2). The full score matrix $S$ is $M \times N$. At the maximum context length, that is a $32{,}768 \times 32{,}768$ matrix per head. It's a very large matrix, but notice that it is just an intermediate in the computation. The Flash Attention algorithm eliminates it with loop fusion and loop tiling.
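To make the three steps concrete, here is a minimal reference implementation in plain Python (ours, for illustration only; it materializes the full score matrix, which is exactly what a real kernel must avoid):

```python
import math

def attention_naive(Q, K, V):
    """Reference attention: materializes the full M x N score matrix S."""
    M, d = len(Q), len(Q[0])
    N = len(K)
    # Matmul 1: S = Q K^T / sqrt(d) -- the M x N intermediate
    S = [[sum(Q[i][k] * K[n][k] for k in range(d)) / math.sqrt(d)
          for n in range(N)] for i in range(M)]
    # Row-wise softmax: P = softmax(S)
    P = []
    for row in S:
        m = max(row)                          # subtract rowmax for stability
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # Matmul 2: O = P V
    return [[sum(P[i][n] * V[n][j] for n in range(N))
             for j in range(len(V[0]))] for i in range(M)]
```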
Back-to-back matmul: why tiling is critical
To solve the large-intermediate problem, we can use tiling. To understand the tiling strategy, let's simplify attention to just two back-to-back matmuls by setting softmax aside for a moment: $O = (QK^\top)V$. Expanding element-wise:

$$O_{ij} = \sum_{n=1}^{N}\left(\sum_{k=1}^{d} Q_{ik}K_{nk}\right)V_{nj}$$
The naive approach (no intermediate $S$; recompute the inner sum from scratch for each output element $O_{ij}$) costs $O(MNd^2)$ FLOPs. Materializing $S = QK^\top$ first and then computing $SV$ costs $O(MNd)$, which is a factor of $d$ fewer FLOPs, but $S$ is $M \times N$ and far too large for registers or shared memory.
The solution is to observe that $d$ is small. A Q block of shape $B_M \times d$ fits comfortably in the 228 KB of shared memory on a single SM. Instead of tiling $d$, we tile only the sequence dimensions $M$ and $N$:

$$O \mathrel{+}= \left(Q_{\text{tile}}\,K_{\text{tile}}^\top\right)V_{\text{tile}}, \qquad Q_{\text{tile}} \in \mathbb{R}^{B_M \times d},\quad K_{\text{tile}}, V_{\text{tile}} \in \mathbb{R}^{B_N \times d}$$
This achieves both goals: $O(MNd)$ FLOPs, and $S_{\text{tile}} = Q_{\text{tile}}K_{\text{tile}}^\top$ is only $B_M \times B_N$, produced and consumed entirely in registers. The full score matrix is never materialized. Q is loaded once into shared memory and reused every inner-loop iteration, while K and V tiles stream through.
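The tiled two-matmul structure can be sketched in a few lines of Python (ours; shapes and tile size are illustrative):

```python
def two_matmul_tiled(Q, K, V, B_N=2):
    """O = (Q K^T) V, tiling only the sequence dimension N: only one
    M x B_N slice of S exists at any time (softmax still set aside)."""
    M, d = len(Q), len(Q[0])
    N, d_v = len(V), len(V[0])
    O = [[0.0] * d_v for _ in range(M)]
    for n0 in range(0, N, B_N):              # stream K and V tiles
        K_t, V_t = K[n0:n0 + B_N], V[n0:n0 + B_N]
        # S_tile = Q @ K_t^T  (M x B_N, lives only inside this iteration)
        S_t = [[sum(Q[i][k] * K_t[n][k] for k in range(d))
                for n in range(len(K_t))] for i in range(M)]
        # O += S_tile @ V_t  (accumulate, then discard S_tile)
        for i in range(M):
            for j in range(d_v):
                O[i][j] += sum(S_t[i][n] * V_t[n][j] for n in range(len(V_t)))
    return O
```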
Figure 3: Tiled attention — Q stays in shared memory while K and V tiles stream through.
How softmax breaks tiled attention
The tiling strategy above handles the two matmuls. But the attention computation has a softmax interleaved between them, so let's look at that. Here's the problem: softmax normalizes over the full row of $S$:

$$P_{ij} = \frac{e^{S_{ij} - m_i}}{\sum_{j'=1}^{N} e^{S_{ij'} - m_i}}, \qquad m_i = \max_{j'} S_{ij'}$$
With tiling, we only see one $B_N$-wide slice of $S$ at a time. We don't have the full row's max $m_i$ (needed for numerical stability) or the full denominator. A naive approach requires two passes over the KV tiles: one to compute the global max and sum, and one to normalize and multiply by $V$.
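A sketch of that two-pass variant (ours; shown per query row and untiled for brevity) makes the extra traffic explicit: every KV element is read twice.

```python
import math

def row_attention_two_pass(q, K, V):
    """Attention for one query row with TWO trips over K: pass 1 finds the
    global max and denominator, pass 2 normalizes and accumulates."""
    N, d = len(K), len(q)
    score = lambda n: sum(q[k] * K[n][k] for k in range(d))
    # Pass 1: global rowmax and denominator -- a full trip through K
    m = max(score(n) for n in range(N))
    l = sum(math.exp(score(n) - m) for n in range(N))
    # Pass 2: normalize and accumulate against V -- a second full trip
    o = [0.0] * len(V[0])
    for n in range(N):
        p = math.exp(score(n) - m) / l
        for j in range(len(V[0])):
            o[j] += p * V[n][j]
    return o
```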
Online softmax for single-pass attention
The solution is online softmax (Milakov and Gimelshein, "Online normalizer calculation for softmax", 2018, arXiv:1805.02867). Instead of two passes, Flash Attention maintains a running max $m$ and running sum $\ell$, and corrects the output accumulator $O$ on the fly. Let $\alpha = e^{m^{(t-1)} - m^{(t)}}$ be the correction factor. After processing the $t$-th KV tile:

$$\begin{aligned}
m^{(t)} &= \max\!\left(m^{(t-1)},\ \operatorname{rowmax}\!\left(S^{(t)}\right)\right)\\
\ell^{(t)} &= \alpha\,\ell^{(t-1)} + \operatorname{rowsum}\!\left(e^{S^{(t)} - m^{(t)}}\right)\\
O^{(t)} &= \alpha\,O^{(t-1)} + e^{S^{(t)} - m^{(t)}}\,V^{(t)}
\end{aligned}$$
The correction factor $\alpha$ rescales all prior terms to the new max, maintaining the invariant that after each tile, $\ell^{(t)}$ is the row sum of $e^{S - m^{(t)}}$ over all tiles seen so far, and $O^{(t)}$ is the corresponding unnormalized output. After the last tile, $O^{(T)} / \ell^{(T)}$ is exactly the softmax-attention output.
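The update rules can be verified in a few lines of Python (ours; per query row, with the $1/\sqrt{d}$ scaling elided):

```python
import math

def row_attention_online(q, K, V, B_N=2):
    """Attention for one query row in a SINGLE pass over KV tiles,
    using the running-max / running-sum update rules."""
    d, d_v, N = len(q), len(V[0]), len(K)
    m, l = float("-inf"), 0.0
    o = [0.0] * d_v                        # unnormalized output accumulator
    for n0 in range(0, N, B_N):            # one trip through the KV tiles
        s = [sum(q[k] * K[n][k] for k in range(d))
             for n in range(n0, min(n0 + B_N, N))]
        new_m = max(m, max(s))             # rowmax update
        alpha = math.exp(m - new_m)        # correction factor (0.0 on tile 0)
        p = [math.exp(x - new_m) for x in s]
        l = alpha * l + sum(p)             # running denominator
        for j in range(d_v):
            o[j] = alpha * o[j] + sum(p[i] * V[n0 + i][j]
                                      for i in range(len(p)))
        m = new_m
    return [x / l for x in o]              # exact softmax attention at the end
```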
When $m^{(t)} = m^{(t-1)}$ (the common case after the first few tiles), $\alpha = 1$ and the rescale is a no-op. More generally, when the change in $m$ is small enough not to cause numerical issues, the rescale can be skipped entirely. FA4 uses this optimization, with an update rule of the form `m = new_max if new_max > m + 8 else m`. The maximum value after exponentiation is then bounded by $e^8 \approx 2981$, which is still safe for later operations to be overflow-free.
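A toy version of the lazy-max rule (the function name is ours; the threshold of 8 is the one from the update rule above):

```python
import math

def update_max(m, tile_max, threshold=8.0):
    # Keep the running max unless the tile max beats it by more than
    # `threshold`; alpha is then exactly 1 and the rescale is skipped.
    return tile_max if tile_max > m + threshold else m

# Any score s seen so far satisfies s <= m + threshold, so the largest
# exponential is bounded by exp(threshold) ~ 2981 -- far below fp32 overflow.
```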
Figure 4: Online softmax correction across tiles.
In pseudocode, the algorithm looks like the following (we've added hardware annotations to convey which parts of the hardware are used):
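Here is a sketch of that pseudocode (ours, reconstructed from the update rules above; the bracketed annotations name the hardware unit that executes each op):

```
# One Q tile resident in shared memory; loop over KV tiles
m, l, O_acc = -inf, 0, 0
for t in range(num_kv_tiles):
    K_t = tma_load(K, t)                # [tma]  async copy GMEM -> SMEM
    V_t = tma_load(V, t)                # [tma]  async copy GMEM -> SMEM
    S = Q_tile @ K_t.T                  # [tc]   matmul 1 (async MMA)
    new_m = max(m, rowmax(S))           # [cuda] running max
    P = exp(S - new_m)                  # [cuda] unnormalized softmax
    alpha = exp(m - new_m)              # [cuda] correction factor
    l = alpha * l + rowsum(P)           # [cuda] running denominator
    O_acc = alpha * O_acc + P @ V_t     # [tc]   rescale + matmul 2
    m = new_m
O = O_acc / l                           # [cuda] final normalization
```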
This is a single pass over K and V with $O(MNd)$ work. The cost is that each iteration has a correction step (rescale) that depends on the previous iteration's result, creating a loop-carried dependency.
The unpipelined FA3 inner loop
Putting it all together, the FA3 kernel (SM90, Hopper) distills the online-softmax attention algorithm into 8 operations on 3 functional units. Two warp groups: @tma (producer, runs TMA loads) and @consumer (runs everything else). Three hardware units for scheduling: tma, tc (tensor cores), cuda (CUDA cores).
Two TMA ops (async, tracked by shared memory barriers), two tensor core ops (async, tracked via commit/wait groups), four CUDA core ops (sync, execute on the issuing warp).
The dependency structure has 13 edges: 9 same-iteration read-after-write dependencies, 2 loop-carried recurrences (distance 1, from the online-softmax correction), and 2 write-after-read anti-dependencies (distance 2) for double-buffered shared memory.
Figure 5: FA3 dependency graph — 13 edges: 9 same-iteration, 2 loop-carried, 2 write-after-read.
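In code, such a graph is just data. A hypothetical encoding as `(kind, producer, consumer, distance)` tuples (a small illustrative subset; op names are ours, not the production kernel's):

```python
from collections import Counter

# Each edge: (kind, producer, consumer, distance k). k = 0 is same-iteration;
# k >= 1 means the consumer in iteration i depends on (or must not clobber)
# the producer in iteration i - k (or i + k for anti-dependencies).
EDGES = [
    ("raw", "tma_load_K", "wgmma_qk",   0),  # S = Q K^T needs the K tile
    ("raw", "wgmma_qk",   "softmax",    0),  # softmax reads S
    ("raw", "softmax",    "wgmma_pv",   0),  # matmul 2 reads P
    ("raw", "softmax",    "softmax",    1),  # loop-carried running max m
    ("raw", "rescale",    "rescale",    1),  # loop-carried accumulator O
    ("war", "wgmma_qk",   "tma_load_K", 2),  # double-buffered K slot
    ("war", "wgmma_pv",   "tma_load_V", 2),  # double-buffered V slot
]

def edge_summary(edges):
    """Count edges the way the post does: same-iteration, loop-carried, anti."""
    c = Counter()
    for kind, _, _, k in edges:
        c["same-iteration" if k == 0 else
          "loop-carried" if kind == "raw" else "anti"] += 1
    return dict(c)
```

Every entry in such a table corresponds to one synchronization obligation in the hand-written kernel.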
Written sequentially, this is clean and correct. But this is also slow, since only one unit is active at a time.
The GPU pipelining problem
We have a correct, sequential inner loop. But the loop is slow since we are not fully utilizing the available hardware units. A modern GPU SM has at least three independently schedulable hardware units:
| Unit | What it does | Example ops |
|---|---|---|
| TMA (Tensor Memory Accelerator) | Async bulk memory transfers (GMEM ↔ SMEM) | tma_load, tma_store |
| Tensor Cores | Matrix multiply-accumulate | wgmma (SM90), umma (SM100) |
| CUDA Cores | Scalar/vector arithmetic, reductions, type conversions | softmax, rescale, cast |
These units can operate simultaneously, but the sequential loop above leaves most of them idle:
Pipelining overlaps iterations so all units run at once:
Figure 6: Unpipelined vs. pipelined execution — overlapping iterations keeps all hardware units active.
Figure 7: GPU SM block diagram — TMA, tensor cores, and CUDA cores operate as independent state machines.
The key architectural fact is that TMA, the tensor core MMA unit, and CUDA cores are separate hardware state machines with no implicit ordering between them.
TMA (available in SM90, SM100) is an autonomous DMA engine. You issue a tma_load and it runs independently, consuming no warp cycles. Completion is signaled through mbarrier objects: hardware barriers with a phase bit and an expected-bytes counter. The consumer calls mbar_wait() on the matching phase.
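A toy single-threaded model of the phase-bit protocol (ours; real mbarriers are hardware objects, and this ignores the expected-bytes path entirely):

```python
class ToyMBarrier:
    """Model of an mbarrier's arrival/phase protocol: once `expected`
    arrivals land, the phase bit flips and the count resets for reuse."""
    def __init__(self, expected):
        self.expected = expected
        self.count = 0
        self.phase = 0

    def arrive(self):
        self.count += 1
        if self.count == self.expected:
            self.count = 0
            self.phase ^= 1      # releases anyone waiting on the old phase

    def try_wait(self, phase):
        # A waiter captures the phase before waiting, then polls until
        # the barrier's phase differs from the captured value.
        return self.phase != phase
```

A consumer snapshots `p = bar.phase` before the producer's load is issued, then spins on `bar.try_wait(p)`; because the phase flips each time the barrier completes, the same object can be reused every loop iteration without resetting it.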
WGMMA (available in SM90) and UMMA (available in SM100) are async MMA units. You dispatch an asynchronous matmul with completion tracked by commit_group() / wait_group(N). On Blackwell, UMMA writes results to TMEM, a register-like space private to the tensor core unit that CUDA cores cannot access directly. On Hopper, WGMMA writes to registers. The scheduling problem is the same between the two generations (async dispatch, explicit completion tracking), but SM100 adds a third memory space to coordinate.
Every dependency between units must be manually synchronized, otherwise a silent data race is introduced. As you can imagine, this coordination is difficult.
Two levels of pipelining: inter-warp and intra-warp
To improve the performance of the inner loop, one has to overlap iterations. But the FA3 kernel actually has two levels of overlap unified in a single loop body:
Figure 8: Two levels of pipelining — inter-warp-group (producer/consumer) and intra-warp-group (tensor core/CUDA core overlap).
Level 1: Inter-warp-group. A dedicated producer warp group runs the TMA loads. A separate consumer warp group runs everything else. The two groups communicate through a shared memory ring buffer with mbarrier synchronization. K and V tiles are double-buffered (2 physical slots each), so the producer can load ahead while the consumer reads from a different slot.
In the dependency graph, this level is encoded by write-after-read edges with distance $k$ equal to the buffer depth. The edge wgmma_qk → tma_load_K means the reader in iteration $i$ must finish before the writer in iteration $i + k$ overwrites the same slot. The manual equivalent is mbar_arrive(empty[slot]) after reading and mbar_wait(empty[slot]) before writing. That pair of barrier calls is one graph edge. Changing the buffer depth just changes $k$.
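A single-threaded simulation of this protocol (ours; a toy stand-in for the producer and consumer warp groups) shows the distance-equals-depth constraint in action:

```python
def run_pipeline(n_tiles, depth=2):
    """Greedy simulation of a depth-`depth` ring buffer: the producer may
    write slot i % depth only after the consumer finished iteration
    i - depth (the write-after-read edge with distance = buffer depth)."""
    trace, produced, consumed = [], -1, -1
    while consumed < n_tiles - 1:
        # Producer runs ahead, but never more than `depth` iterations.
        if produced + 1 < n_tiles and produced - consumed < depth:
            produced += 1
            trace.append(("load", produced, produced % depth))
        else:
            consumed += 1
            trace.append(("compute", consumed, consumed % depth))
    return trace
```

With `depth=2` the producer fills both slots up front, and afterward load $i+2$ only ever fires after compute $i$ has released the slot they share.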
Level 2: Intra-warp-group. Within the consumer warp group, tensor core ops and CUDA core ops overlap. wgmma_qk fires asynchronously, and while it runs, the CUDA cores work on rowmax/softmax/rescale from the previous iteration. This is managed through commit_group() / wait_group(N).
In the dependency graph, this level is encoded by 11 read-after-write edges (distances 0 and 1) on S, P, m, and O. The stage assignments the solver produces directly determine where to place each commit_group() and wait_group(N).
Why pipelined GPU kernels are hard to debug
Sounds easy to program 🙂? While pipelining solves the utilization problem, it creates a new one: the resulting code is extremely complex, difficult to debug, and difficult to adapt and reuse.
Figure 9: The pipelining boilerplate iceberg — algorithm logic is a fraction of total kernel code.
FA3 (flash_attn/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp + hopper/flash_fwd_kernel_sm90.h) is 1717+458=2175 lines. FA4 (flash_attn/cute/flash_fwd_sm100.py) is 2,875 lines. The algorithm itself (online softmax, masking, output accumulation) is a fraction of this. The rest is dispatch specialization, hardware-specific memory layouts, and pipelining synchronization woven throughout. And there is a tight coupling between the algorithm and the synchronization:
- Ring buffer indexing. K and V tiles share a circular buffer with a fixed number of physical slots in shared memory.
- Barrier management. `mbar_expect_bytes()`, `mbar_wait()`, `mbar_arrive()` with phase-bit tracking.
- Async group tracking. `commit_group()` and `wait_group(N)` manage in-flight WGMMA/UMMA operations. The argument `N` depends on the schedule.
- Prologue and epilogue. The steady-state loop body assumes the pipeline is full, so the first and last few iterations need hand-written special cases.
- Cross-memory-space movement. In FA4, TMEM is private to the tensor core unit, and moving data between TMEM and registers requires explicit synchronized transfers.
Every one of these is derived from the schedule, and every one is a place where a simple typo can create a race condition. As a result, these pipelines are difficult to program.
But aren't there tools?
NVIDIA's compute-sanitizer does catch a large class of errors, but it does not track TMA or async WGMMA/UMMA instructions. For example, running `compute-sanitizer --tool racecheck` would report nothing for TMA races. Sanitizers are no help here.
Use a `printf`, you say? The bugs are non-deterministic and depend on timing, input size, and hardware configuration, so observing them with `printf` does not work: adding a `printf` can change the timing enough to mask a race. And even an unrelated refactor (changing a tile size, reordering a loop) can cause hidden races to surface.
The maintenance burden: adding an op means re-deriving the schedule, updating every wait_group, and adjusting prologue and epilogue. Porting from SM90 to SM100 means rewriting from scratch. FA4 shares almost no synchronization code with FA3 despite implementing the same algorithm.
Flash Attention 4: the problem gets worse
Everything above is for FA3 with 8 ops and 3 units. The problem scales up significantly with FA4, which processes two Q tiles simultaneously against the same KV stream on Datacenter Blackwell. The inner loop expands to 14 operations across 5 functional units:
- 5 inner-loop roles (512 threads total, 16 warps): `@load` (warp 13, TMA), `@mma` (warp 12, UMMA), `@s0` (warps 0-3, CUDA), `@s1` (warps 4-7, CUDA), `@corr` (warps 8-11, CUDA). The remaining warps only deallocate registers.
- 3 memory spaces: shared memory (K, V tiles), TMEM (S, P, O, private to tensor cores), and registers (m, old_m, row_sum).
- 28 dependency edges: 18 same-iteration, 6 loop-carried (distance 1), 4 write-after-read (distance 4) for quad-buffered shared memory.
- Parallel Q0/Q1 channels sharing K and V loads with independent statistics and accumulators.
- Consecutive group fusion: rowmax → softmax → rowsum runs as an indivisible sequence per tile within each statistics warp group.
Figure 10: FA4 full dependency graph — the schedule timeline in Figure 1 is one valid solution satisfying these constraints.
This is the dependency graph from the opening. The schedule timeline at the top of this post is one valid schedule that satisfies the dependence relations. For FA3, the solver's schedule matches the hand-written production kernel exactly.
GPU pipelining beyond NVIDIA: AMD and other platforms
And this is not an NVIDIA-specific problem. The pattern is universal across hardware:
async producers + async consumers + shared buffers = pipelining problem.
| Platform | Async copy | Matrix unit | Synchronization |
|---|---|---|---|
| NVIDIA Hopper / Blackwell | TMA | WGMMA / UMMA | mbarrier, commit/wait |
| AMD CDNA / RDNA | Async copy engine | MFMA | LDS barriers |
Any kernel whose inner loop decomposes into operations on independently schedulable units faces the same scheduling and synchronization challenge.
What’s next: solving the schedule with constraint programming
Now that we understand attention in detail and the complexity of programming its pipeline, Part 2 presents the solution. We formulate this as a software pipelining problem, representing all the dependencies in the loop body (8 operations, 13 edges for FA3). Next, we use modulo scheduling to restrict the search: every iteration uses the same schedule, shifted by the initiation interval $II$ cycles. This turns the problem into an integer linear program that can be solved in milliseconds.
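The legality check at the heart of that formulation fits in a few lines (ours; toy op names and latencies, and we assume the same-iteration edges are satisfiable on their own):

```python
def is_legal(start, latency, edges, II):
    """Modulo-schedule check: op o of iteration i starts at start[o] + i * II.
    Edge (a, b, k): b in iteration i + k may start only after a in
    iteration i finishes, i.e. start[b] + k * II >= start[a] + latency[a]."""
    return all(start[b] + k * II >= start[a] + latency[a]
               for a, b, k in edges)

def min_II(start, latency, edges, max_II=64):
    # Smallest initiation interval at which this fixed schedule is legal.
    for II in range(1, max_II + 1):
        if is_legal(start, latency, edges, II):
            return II
    return None

# Toy example: an async matmul feeding a softmax, with the softmax result
# needed by the next iteration's matmul (a loop-carried edge, k = 1).
start = {"mma": 0, "softmax": 2}
latency = {"mma": 2, "softmax": 1}
edges = [("mma", "softmax", 0), ("softmax", "mma", 1)]
```

Here `min_II(start, latency, edges)` returns 3: the loop-carried recurrence, not the operation count, bounds how fast iterations can issue.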
With a schedule in hand, we derive buffer lifetimes and copy counts, identifying how many physical slots each shared memory tile needs. We then compress register usage by exploiting variable liveness.
Finally, we validate the framework on FA3, FA3 with FlexAttention, and FA4 before implementing the scheduler all in Mojo (Part 3).
Figure 11: FA3 solver-produced schedule — matches the hand-written production kernel exactly.
The dependency graph is the input; the pipelined code is the output.

