Skip to content

[module_causal_conv1d_update] refactor hip kernel#3595

Open
amd-ruitang3 wants to merge 2 commits into
ROCm:mainfrom
amd-ruitang3:module_causal_conv1d_update_refactor
Open

[module_causal_conv1d_update] refactor hip kernel#3595
amd-ruitang3 wants to merge 2 commits into
ROCm:mainfrom
amd-ruitang3:module_causal_conv1d_update_refactor

Conversation

@amd-ruitang3
Copy link
Copy Markdown
Contributor

Motivation

[module_causal_conv1d_update] Build time reduced from 37.7s to 10.7s (-73.2%).

Technical Details

Test Plan

Test Result

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 8, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3595 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the module_causal_conv1d_update HIP kernel and its bindings to reduce build time by removing Torch/ATen dependencies from the kernel TU and switching the host interface to aiter_tensor_t + a thread-local HIP stream.

Changes:

  • Refactor causal_conv1d_update host interface from torch::Tensor to aiter_tensor_t, and switch stream handling to aiter::getCurrentHIPStream().
  • Update the pybind module to expose _set_current_hip_stream and adjust Python to compile in develop=True mode (so tensors are converted to aiter_tensor_t and the stream is forwarded).
  • Adjust pybind argument definitions for causal_conv1d_update (cache/index arguments no longer have Torch-tensor defaults).

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
csrc/pybind/causal_conv1d_update_pybind.cu Adds stream setter binding and wires CAUSAL_CONV1D_UPDATE_PYBIND for the module.
csrc/kernels/causal_conv1d_update.cu Refactors kernel TU to aiter_tensor_t + thread-local stream; removes Torch-specific includes.
csrc/include/rocm_ops.hpp Updates CAUSAL_CONV1D_UPDATE_PYBIND arg list for the refactored interface.
csrc/include/causal_conv1d.h Updates the public header signature to aiter_tensor_t.
aiter/ops/causal_conv1d.py Switches compile_ops(..., develop=True) to match the new pybind ABI expectations.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 223 to 227
{
CHECK_INPUT(x);
CHECK_INPUT(conv_state);
CHECK_INPUT(weight);
CHECK_INPUT(out);

// Extract dimensions
const int32_t batch = x.size(0);
const int32_t dim = x.size(1);
const int32_t seqlen = x.size(2);
const int32_t width = weight.size(1);
Comment on lines +230 to +232
AITER_CHECK(conv_state.size(0) == batch || conv_state_indices.numel() > 0, "conv_state batch mismatch");
AITER_CHECK(conv_state.size(1) == dim, "conv_state dim mismatch");
AITER_CHECK(conv_state_len >= width - 1, "conv_state_len must be >= width - 1");
Comment on lines +265 to 268
if(bias.numel() > 0)
{
CHECK_INPUT(bias);
TORCH_CHECK(bias.size(0) == dim, "Bias shape mismatch");
AITER_CHECK(bias.size(0) == dim, "Bias shape mismatch");
params.bias_ptr = bias.data_ptr();
Comment on lines +275 to +278
if (cache_seqlens.numel() > 0) {
AITER_CHECK(cache_seqlens.dtype() == AITER_DTYPE_i32, "cache_seqlens must be int32");
AITER_CHECK(cache_seqlens.size(0) == batch, "cache_seqlens batch mismatch");
params.cache_seqlens = reinterpret_cast<int32_t*>(cache_seqlens.data_ptr());
Comment on lines +283 to +286
if (conv_state_indices.numel() > 0) {
AITER_CHECK(conv_state_indices.dtype() == AITER_DTYPE_i32, "conv_state_indices must be int32");
AITER_CHECK(conv_state_indices.size(0) == batch, "conv_state_indices batch mismatch");
params.conv_state_indices_ptr = reinterpret_cast<int32_t*>(conv_state_indices.data_ptr());
Comment on lines 8 to 12
#include "aiter_hip_common.h"
#include "aiter_tensor.h"
#include "aiter_stream.h"
#include "causal_conv1d.h"
#include "ck_tile/core.hpp"
Comment on lines 1 to 11
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#include "rocm_ops.hpp"
#include "aiter_stream.h"
#include "causal_conv1d.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
AITER_SET_STREAM_PYBIND
CAUSAL_CONV1D_UPDATE_PYBIND;
}
Comment thread csrc/include/rocm_ops.hpp
Comment on lines 2113 to 2116
py::arg("use_silu"), \
py::arg("cache_seqlens") = torch::Tensor(), \
py::arg("conv_state_indices") = torch::Tensor(), \
py::arg("cache_seqlens"), \
py::arg("conv_state_indices"), \
py::arg("pad_slot_id") = -1);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants