[Feat]: Add 2D-tiled causal_conv1d prefill kernel for gated delta net #1104

Open: yiijin wants to merge 9 commits into ROCm:main from yiijin:conv

Conversation


yiijin commented Jun 5, 2026

Motivation

Add a 2D tiled prefill kernel variant for causal conv1d that achieves ~3.4x speedup over the original atom kernel by processing multiple tokens per thread block simultaneously.

Technical Details

  • Add _causal_conv1d_fwd_kernel_tile: 2D tiled kernel with [BLOCK_N, BLOCK_M] coalesced loads
  • Replace per-token loop with batch tile loads + vectorized convolution + fused SiLU via v_rcp_f32
  • Add _causal_conv1d_fn_tile wrapper with configurable BLOCK_M/BLOCK_N/num_warps
  • Setting the environment variable ATOM_CAUSAL_CONV1D_KERNEL=nontile falls back to the original kernel
  • Default config: BLOCK_M=64, BLOCK_N=32, num_warps=4
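
As a rough illustration of the bullets above, the following pure-Python sketch compares a per-token loop with a tile-at-a-time loop over [BLOCK_N, BLOCK_M] blocks. It is not the Triton kernel from this PR: the function names are hypothetical, the initial convolution state is assumed to be zero, and the block sizes are shrunk so the example stays small.

```python
import math

def silu(v):
    # SiLU(v) = v * sigmoid(v), written as v * (1 / (1 + exp(-v))) so the
    # division can map to a single reciprocal on hardware that has one.
    return v * (1.0 / (1.0 + math.exp(-v)))

def conv_per_token(x, weight):
    """Reference path: one token at a time. x is [dim][seqlen], weight is [dim][width]."""
    dim, seqlen, width = len(x), len(x[0]), len(weight[0])
    out = [[0.0] * seqlen for _ in range(dim)]
    for t in range(seqlen):
        for c in range(dim):
            acc = 0.0
            for k in range(width):
                src = t - (width - 1) + k  # causal: current and past tokens only
                if src >= 0:               # assumption: zero initial state
                    acc += weight[c][k] * x[c][src]
            out[c][t] = silu(acc)
    return out

def conv_tiled(x, weight, block_m=4, block_n=2):
    """Same math, computed one [block_n, block_m] tile at a time."""
    dim, seqlen, width = len(x), len(x[0]), len(weight[0])
    out = [[0.0] * seqlen for _ in range(dim)]
    for c0 in range(0, dim, block_n):
        for t0 in range(0, seqlen, block_m):
            chans = range(c0, min(c0 + block_n, dim))
            toks = range(t0, min(t0 + block_m, seqlen))
            acc = {(c, t): 0.0 for c in chans for t in toks}
            # One pass per filter tap: every element of the tile reads the
            # input shifted by the same amount.
            for k in range(width):
                shift = width - 1 - k
                for c in chans:
                    for t in toks:
                        if t - shift >= 0:
                            acc[(c, t)] += weight[c][k] * x[c][t - shift]
            for (c, t), v in acc.items():
                out[c][t] = silu(v)  # activation fused into the same tile
    return out
```

Both functions accumulate the filter taps in the same order, so they should agree element for element; the tiled form only changes how the work is grouped.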

Test Plan

Test Result

image

Submission Checklist

Copilot AI review requested due to automatic review settings June 5, 2026 10:05

Copilot AI left a comment

Pull request overview

This PR introduces a new 2D-tiled Triton prefill kernel for causal_conv1d (used by Gated Delta Net / Mamba-style ops) to improve prefill performance by processing multiple tokens per program instance, with an env-var switch to fall back to the original 1D per-token kernel.

Changes:

  • Added _causal_conv1d_fwd_kernel_tile and _causal_conv1d_fn_tile implementing a 2D-tiled [BLOCK_N, BLOCK_M] prefill path with fused SiLU.
  • Added a dispatcher in causal_conv1d_fn to select tiled vs original kernel via ATOM_CAUSAL_CONV1D_KERNEL=nontile.
  • Updated compute_causal_conv1d_metadata to generate metadata for both BLOCK_M=8 and BLOCK_M=64.
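
A minimal sketch of how the dispatch and the dual metadata described above might fit together. Only the environment variable name, the value `nontile`, and the two block sizes come from this PR; the helper names and the metadata fields (`batch_ptr`, `token_chunk_offset_ptr`) are assumptions made for illustration, as is the exact-match comparison on the variable's value.

```python
import os

def select_conv1d_kernel(env=None):
    """Return "nontile" only when explicitly requested, otherwise "tile"."""
    env = os.environ if env is None else env
    # Assumption: the value is compared exactly, with no case folding.
    choice = env.get("ATOM_CAUSAL_CONV1D_KERNEL", "")
    return "nontile" if choice == "nontile" else "tile"

def chunk_metadata(seq_lens, block_m):
    """Map each program to a (sequence index, chunk index) pair (hypothetical layout)."""
    batch_ptr, chunk_offset = [], []
    for seq_idx, n in enumerate(seq_lens):
        num_chunks = -(-n // block_m)  # ceiling division
        for chunk in range(num_chunks):
            batch_ptr.append(seq_idx)
            chunk_offset.append(chunk)
    return {"batch_ptr": batch_ptr, "token_chunk_offset_ptr": chunk_offset}

def metadata_for_both(seq_lens):
    # Build metadata for both block sizes so either kernel can be picked later.
    return {block_m: chunk_metadata(seq_lens, block_m) for block_m in (8, 64)}
```

With sequence lengths of 10 and 70 tokens, this layout needs 11 programs at BLOCK_M=8 and 3 at BLOCK_M=64, which shows how a larger token tile reduces the launch grid.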

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| atom/model_ops/mamba_ops/causal_conv1d.py | Adds the 2D-tiled prefill kernel + Python wrapper and env-var dispatch. |
| atom/model_ops/attentions/gdn_attn.py | Extends causal-conv metadata generation to support the new BLOCK_M=64 tiled path. |

Comment thread atom/model_ops/mamba_ops/causal_conv1d.py
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py Outdated
Copilot AI review requested due to automatic review settings June 5, 2026 10:27

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

Comment on lines +1828 to +1831
dim, cu_seqlen = x.shape
_, width = weight.shape
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
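
For context on the lines quoted above: Triton's `tl.arange` requires a power-of-two range length, which is presumably why `state_len` is rounded up. The pure-Python stand-in below is a sketch, not Triton's implementation, and `padded_state_len` is a hypothetical helper; it only shows what the rounding yields for the kernel widths discussed in this PR.

```python
def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def padded_state_len(width):
    # The conv state holds the previous (width - 1) tokens per channel.
    state_len = width - 1
    return state_len, next_power_of_2(state_len)

for width in (2, 3, 4):
    state_len, np2 = padded_state_len(width)
    print(f"width={width} state_len={state_len} np2_statelen={np2}")
```

Only width 4 pads beyond the true state length (3 rounds up to 4), so the extra lane has to be masked off when loading or storing the state.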
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py Outdated
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py Outdated
Copilot AI review requested due to automatic review settings June 5, 2026 10:38

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Comment thread atom/model_ops/mamba_ops/causal_conv1d.py
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py
Comment thread atom/model_ops/mamba_ops/causal_conv1d.py
yiijin and others added 9 commits June 7, 2026 23:19
- Add `idx_feats < dim` guard to `is_v_block` to prevent out-of-bounds
  stores when dim is not a multiple of BLOCK_N.
- Remove unused `original_x_dtype` assignment in `_causal_conv1d_fn_tile`.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
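
The out-of-bounds guard from the first commit message can be pictured with the sketch below. It is an illustration in plain Python rather than the kernel's actual mask logic: `tile_store_mask` and its parameters are hypothetical, and how the PR folds the guard into `is_v_block` is not shown here.

```python
def tile_store_mask(feat_block, block_n, dim, token_block, block_m, seqlen):
    """Boolean [block_n][block_m] mask of tile elements that are safe to store."""
    idx_feats = [feat_block * block_n + i for i in range(block_n)]
    idx_tokens = [token_block * block_m + j for j in range(block_m)]
    # Without the `f < dim` term, the last feature block would write past the
    # end of the tensor whenever dim is not a multiple of block_n.
    return [[(f < dim) and (t < seqlen) for t in idx_tokens] for f in idx_feats]

# dim=40 with BLOCK_N=32: the second feature block covers channels 32..63,
# but only channels 32..39 exist.
mask = tile_store_mask(feat_block=1, block_n=32, dim=40,
                       token_block=0, block_m=64, seqlen=100)
valid_rows = sum(1 for row in mask if all(row))
print(valid_rows)
```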
…ride check

- Assert KERNEL_WIDTH in {2, 3, 4} to fail fast on unsupported widths.
- Validate metadata contains the selected BLOCK_M key.
- Tighten stride check to require stride(0)==1 (channel-last), consistent
  with the original _causal_conv1d_fn.

Co-authored-by: Cursor <cursoragent@cursor.com>
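
The three checks listed in this commit message could be sketched as follows. The function name, argument names, and use of exceptions instead of `assert` are assumptions for illustration; the stride is passed as a plain tuple so the example runs without PyTorch.

```python
SUPPORTED_WIDTHS = (2, 3, 4)

def validate_tile_inputs(kernel_width, metadata, block_m, x_stride):
    """Fail fast before launching the tiled kernel (hypothetical helper)."""
    if kernel_width not in SUPPORTED_WIDTHS:
        raise ValueError(f"unsupported KERNEL_WIDTH {kernel_width}")
    if block_m not in metadata:
        raise KeyError(f"metadata has no entry for BLOCK_M={block_m}")
    # x has shape (dim, cu_seqlen); stride(0) == 1 means the channel axis is
    # contiguous in memory (channel-last layout).
    if x_stride[0] != 1:
        raise ValueError("x must be channel-last: expected stride(0) == 1")

# A (dim=128, cu_seqlen=512) tensor stored channel-last has strides (1, 128).
validate_tile_inputs(kernel_width=4, metadata={8: {}, 64: {}},
                     block_m=64, x_stride=(1, 128))
```

Checking on the host side keeps the failure message readable; an unsupported width or layout would otherwise surface as wrong results or an illegal memory access inside the kernel.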
Copilot AI review requested due to automatic review settings June 8, 2026 04:19
yiijin changed the title from "Add 2D-tiled causal_conv1d prefill kernel for gated delta net" to "[Feat]: Add 2D-tiled causal_conv1d prefill kernel for gated delta net" on Jun 8, 2026

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.

2 participants