[module_causal_conv1d_update] refactor hip kernel#3595
Open
amd-ruitang3 wants to merge 2 commits into
Open
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
There was a problem hiding this comment.
Pull request overview
This PR refactors the module_causal_conv1d_update HIP kernel and its bindings to reduce build time by removing Torch/ATen dependencies from the kernel TU and switching the host interface to aiter_tensor_t + a thread-local HIP stream.
Changes:
- Refactor
causal_conv1d_updatehost interface fromtorch::Tensortoaiter_tensor_t, and switch stream handling toaiter::getCurrentHIPStream(). - Update the pybind module to expose
_set_current_hip_streamand adjust Python to compile indevelop=Truemode (so tensors are converted toaiter_tensor_tand the stream is forwarded). - Adjust pybind argument definitions for
causal_conv1d_update(cache/index arguments no longer have Torch-tensor defaults).
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| csrc/pybind/causal_conv1d_update_pybind.cu | Adds stream setter binding and wires CAUSAL_CONV1D_UPDATE_PYBIND for the module. |
| csrc/kernels/causal_conv1d_update.cu | Refactors kernel TU to aiter_tensor_t + thread-local stream; removes Torch-specific includes. |
| csrc/include/rocm_ops.hpp | Updates CAUSAL_CONV1D_UPDATE_PYBIND arg list for the refactored interface. |
| csrc/include/causal_conv1d.h | Updates the public header signature to aiter_tensor_t. |
| aiter/ops/causal_conv1d.py | Switches compile_ops(..., develop=True) to match the new pybind ABI expectations. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
223
to
227
| { | ||
| CHECK_INPUT(x); | ||
| CHECK_INPUT(conv_state); | ||
| CHECK_INPUT(weight); | ||
| CHECK_INPUT(out); | ||
|
|
||
| // Extract dimensions | ||
| const int32_t batch = x.size(0); | ||
| const int32_t dim = x.size(1); | ||
| const int32_t seqlen = x.size(2); | ||
| const int32_t width = weight.size(1); |
Comment on lines
+230
to
+232
| AITER_CHECK(conv_state.size(0) == batch || conv_state_indices.numel() > 0, "conv_state batch mismatch"); | ||
| AITER_CHECK(conv_state.size(1) == dim, "conv_state dim mismatch"); | ||
| AITER_CHECK(conv_state_len >= width - 1, "conv_state_len must be >= width - 1"); |
Comment on lines
+265
to
268
| if(bias.numel() > 0) | ||
| { | ||
| CHECK_INPUT(bias); | ||
| TORCH_CHECK(bias.size(0) == dim, "Bias shape mismatch"); | ||
| AITER_CHECK(bias.size(0) == dim, "Bias shape mismatch"); | ||
| params.bias_ptr = bias.data_ptr(); |
Comment on lines
+275
to
+278
| if (cache_seqlens.numel() > 0) { | ||
| AITER_CHECK(cache_seqlens.dtype() == AITER_DTYPE_i32, "cache_seqlens must be int32"); | ||
| AITER_CHECK(cache_seqlens.size(0) == batch, "cache_seqlens batch mismatch"); | ||
| params.cache_seqlens = reinterpret_cast<int32_t*>(cache_seqlens.data_ptr()); |
Comment on lines
+283
to
+286
| if (conv_state_indices.numel() > 0) { | ||
| AITER_CHECK(conv_state_indices.dtype() == AITER_DTYPE_i32, "conv_state_indices must be int32"); | ||
| AITER_CHECK(conv_state_indices.size(0) == batch, "conv_state_indices batch mismatch"); | ||
| params.conv_state_indices_ptr = reinterpret_cast<int32_t*>(conv_state_indices.data_ptr()); |
Comment on lines
8
to
12
| #include "aiter_hip_common.h" | ||
| #include "aiter_tensor.h" | ||
| #include "aiter_stream.h" | ||
| #include "causal_conv1d.h" | ||
| #include "ck_tile/core.hpp" |
Comment on lines
1
to
11
| // SPDX-License-Identifier: MIT | ||
| // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| #include "rocm_ops.hpp" | ||
| #include "aiter_stream.h" | ||
| #include "causal_conv1d.h" | ||
|
|
||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||
| { | ||
| AITER_SET_STREAM_PYBIND | ||
| CAUSAL_CONV1D_UPDATE_PYBIND; | ||
| } |
Comment on lines
2113
to
2116
| py::arg("use_silu"), \ | ||
| py::arg("cache_seqlens") = torch::Tensor(), \ | ||
| py::arg("conv_state_indices") = torch::Tensor(), \ | ||
| py::arg("cache_seqlens"), \ | ||
| py::arg("conv_state_indices"), \ | ||
| py::arg("pad_slot_id") = -1); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
[module_causal_conv1d_update] Build time reduced from 37.7s to 10.7s (-73.2%).
Technical Details
Test Plan
Test Result
Submission Checklist