[AIROCMLIR-499] Add support for sliding window masking for attention#2240
justinrosner merged 15 commits into develop
Conversation
Pull request overview
This PR implements sliding window attention support combined with KVCache in rocMLIR. Sliding window attention masks key positions before max(0, currentSeqLen - windowSize) with -inf, effectively limiting attention to the last W positions relative to the current decoded position. This is achieved through pattern detection in TosaToRock conversion and masking logic in GridwiseGemmToBlockwise lowering.
Changes:
- Added optional `slidingWindowSize` attribute to `rock.attention` and `rock.gridwise_attention_accel` operations, with validation requiring `currentSeqLen` to be set
- Implemented pattern detection in TosaToRock for sliding window masks (detecting `greater(seqLen + negative_offset, col_indices)`) and clip patterns on `currentSeqLen`
- Extended GridwiseGemmToBlockwise to apply sliding window masking (independent of causal/KVCache masking) using a precomputed lower bound
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Added slidingWindowSize optional attribute to AttentionOp and GridwiseAttentionAccelOp definitions with documentation |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Added validation logic for slidingWindowSize attribute (must be positive and requires currentSeqLen) |
| mlir/lib/Conversion/TosaToRock/TosaToRock.cpp | Implemented pattern detection for sliding window masks, clip pattern detection for currentSeqLen, and clip application during rewrite |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Added sliding window masking logic with lower bound computation and masking on every iteration |
| mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp | Updated to pass slidingWindowSize attribute through the lowering pipeline |
| mlir/lib/Dialect/Rock/Transforms/DetectFlashDecoding.cpp | Propagated slidingWindowSize attribute when creating new AttentionOp |
| mlir/lib/Dialect/Rock/Transforms/SortDimensionsMemoryLayout.cpp | Propagated slidingWindowSize attribute during dimension sorting transformation |
| mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Added nullptr for slidingWindowSize in generated attention kernels (feature not yet supported in codegen) |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir | Added comprehensive test with FileCheck patterns validating sliding window masking behavior |
| mlir/test/Conversion/TosaToRock/tosa-to-rock-attention-kvcache.mlir | Added test for sliding window pattern detection and clip pattern handling |
| mlir/test/fusion/pr-e2e/attention/mixr-attention-sliding-window-kvcache.mlir | Added end-to-end test for sliding window attention with KVCache |
```cpp
// Validate sliding window constraints
if (getSlidingWindowSize()) {
  int32_t windowSize = static_cast<int32_t>(*getSlidingWindowSize());
  if (windowSize <= 0)
    return emitError("slidingWindowSize must be positive");
  if (!getCurrentSeqLen())
    return emitError("slidingWindowSize requires currentSeqLen to be set");
}
```
Missing validation: slidingWindowSize should only be allowed when enableSoftmax is true (for attention operations). Similar to the checks for currentSeqLen (line 2826), prefixOffset (line 2829), and causal (line 2832), there should be a check that rejects slidingWindowSize when enableSoftmax is false, since sliding window masking only makes sense in the context of attention operations.
```cpp
// Sliding window masking: mask when key_pos < max(0, currentSeqLen -
// windowSize). This is independent of causal masking and applies
// alongside KV-cache masking.
if (slidingWindowSize > 0) {
```
I think all the "ifs" around setGemm0OutputOutOfScope() are not necessary, because setGemm0OutputOutOfScope() internally checks if it needs to run.
True, but do you find that it makes the code a bit more unreadable to do that? I'm fine either way, just wondering what your opinion is.
I'm fine either way, I just think we need to remove one of the ifs, either the outer one or the inner one.
I've removed it in all of the cases except for the if/else block for prefix causal. Because prefixCausal also requires that causal be set, running the logic for both masks would not be correct in this case.
umangyadav left a comment:
I had some similar concerns as daniels, but apart from that it looks good to me.
Force-pushed from b5382ea to 10cd27e
…2240)

* Initial TosaToRock changes for sliding window attention
* Add GridwiseGemmToBlockwise masking support
* Small bug fixes
* Small refactor
* More refactoring
* Clang-format
* Add LIT tests
* Attend to review comments
* Add rocmlir-gen and PR LIT tests
* Add E2E tests
* Clang-format
* Attend to more review comments
* Clang-format
Automated weekly review of merged PRs #2234 #2240 #2248 #2249 #2251 #2254 #2257 #2258 #2259 #2270 #2271. Identifies 6 areas with weak test coverage and meaningful business risk:

1. ConcurrentQueue (no unit tests, multi-threaded, silent deadlock risk)
2. parse_tuning_db_line / read_tuning_db key schema change (no Python tests)
3. BooleanElementwiseConverter missing f16/unsigned dtype coverage
4. Attention MaxNumFOp vs MaximumFOp NaN correctness (no dedicated test)
5. firstCausalMaskIter off-by-one risk (no non-trivial shape test)
6. Sliding window attention edge cases (windowSize=0/>=seqLen/unaligned)

The GitHub discussion API returned FORBIDDEN (read-only CI token); analysis committed here as a permanent record.

Co-authored-by: Djordje Antic <djordje.antic@amd.com>
Motivation
This PR adds support for sliding window attention combined with KVCache in rocMLIR. Sliding window attention limits each query to attend only to the last W key positions (relative to the current decoded position), masking earlier positions with -inf.

This implements: https://amd-hub.atlassian.net/browse/AIROCMLIR-499
Technical Details
- Added a `slidingWindowSize` attribute to `rock.attention` ops, validated to require that `currentSeqLen` is also set
- Added `trySlidingWindowPattern` in TosaToRock to detect the sliding window mask pattern
- Added `tryClipPattern` for KVCache pattern detection in TosaToRock, matching `clip(arg, lo, hi)`

Test Plan
Test Result
Submission Checklist