[Feat]: Add 2D-tiled causal_conv1d prefill kernel for gated delta net #1104
Open
yiijin wants to merge 9 commits into
Open
[Feat]: Add 2D-tiled causal_conv1d prefill kernel for gated delta net #1104yiijin wants to merge 9 commits into
yiijin wants to merge 9 commits into
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR introduces a new 2D-tiled Triton prefill kernel for causal_conv1d (used by Gated Delta Net / Mamba-style ops) to improve prefill performance by processing multiple tokens per program instance, with an env-var switch to fall back to the original 1D per-token kernel.
Changes:
- Added
_causal_conv1d_fwd_kernel_tileand_causal_conv1d_fn_tileimplementing a 2D-tiled[BLOCK_N, BLOCK_M]prefill path with fused SiLU. - Added a dispatcher in
causal_conv1d_fnto select tiled vs original kernel viaATOM_CAUSAL_CONV1D_KERNEL=nontile. - Updated
compute_causal_conv1d_metadatato generate metadata for bothBLOCK_M=8andBLOCK_M=64.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
atom/model_ops/mamba_ops/causal_conv1d.py |
Adds the 2D-tiled prefill kernel + Python wrapper and env-var dispatch. |
atom/model_ops/attentions/gdn_attn.py |
Extends causal-conv metadata generation to support the new BLOCK_M=64 tiled path. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+1828
to
+1831
| dim, cu_seqlen = x.shape | ||
| _, width = weight.shape | ||
| state_len = width - 1 | ||
| np2_statelen = triton.next_power_of_2(state_len) |
- Add `idx_feats < dim` guard to `is_v_block` to prevent out-of-bounds stores when dim is not a multiple of BLOCK_N. - Remove unused `original_x_dtype` assignment in `_causal_conv1d_fn_tile`. Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
…ride check
- Assert KERNEL_WIDTH in {2, 3, 4} to fail fast on unsupported widths.
- Validate metadata contains the selected BLOCK_M key.
- Tighten stride check to require stride(0)==1 (channel-last), consistent
with the original _causal_conv1d_fn.
Co-authored-by: Cursor <cursoragent@cursor.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Add a 2D tiled prefill kernel variant for causal conv1d that achieves ~3.4x speedup over the original atom kernel by processing multiple tokens per thread block simultaneously.
Technical Details
_causal_conv1d_fwd_kernel_tile: 2D tiled kernel with [BLOCK_N, BLOCK_M] coalesced loadsv_rcp_f32_causal_conv1d_fn_tilewrapper with configurable BLOCK_M/BLOCK_N/num_warpsATOM_CAUSAL_CONV1D_KERNEL=nontileto fallback to original kernelTest Plan
Test Result
Submission Checklist