diff --git a/aiter/ops/causal_conv1d.py b/aiter/ops/causal_conv1d.py index 05bdf50191..afb8f6a8cd 100644 --- a/aiter/ops/causal_conv1d.py +++ b/aiter/ops/causal_conv1d.py @@ -7,7 +7,7 @@ MD_NAME = "module_causal_conv1d_update" -@compile_ops("module_causal_conv1d_update") +@compile_ops("module_causal_conv1d_update", develop=True) def causal_conv1d_update( x: Tensor, conv_state: Tensor, diff --git a/csrc/include/causal_conv1d.h b/csrc/include/causal_conv1d.h index 1fb2932c36..04ac58d6ee 100644 --- a/csrc/include/causal_conv1d.h +++ b/csrc/include/causal_conv1d.h @@ -1,19 +1,19 @@ #pragma once // SPDX-License-Identifier: MIT // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. -#include +#include "aiter_tensor.h" namespace aiter { void causal_conv1d_update( - torch::Tensor& x, - torch::Tensor& conv_state, - const torch::Tensor& weight, - const torch::Tensor& bias, - torch::Tensor& out, + aiter_tensor_t& x, + aiter_tensor_t& conv_state, + aiter_tensor_t& weight, + aiter_tensor_t& bias, + aiter_tensor_t& out, bool use_silu, - const torch::Tensor& cache_seqlens, - const torch::Tensor& conv_state_indices, + aiter_tensor_t& cache_seqlens, + aiter_tensor_t& conv_state_indices, int pad_slot_id); } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index fe53ffe11a..839c92c192 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -2111,8 +2111,8 @@ namespace py = pybind11; py::arg("bias"), \ py::arg("out"), \ py::arg("use_silu"), \ - py::arg("cache_seqlens") = torch::Tensor(), \ - py::arg("conv_state_indices") = torch::Tensor(), \ + py::arg("cache_seqlens"), \ + py::arg("conv_state_indices"), \ py::arg("pad_slot_id") = -1); #define CHUNK_GDR_FWD_H_PYBIND \ diff --git a/csrc/kernels/causal_conv1d_update.cu b/csrc/kernels/causal_conv1d_update.cu index 6ed420319e..9744f023df 100644 --- a/csrc/kernels/causal_conv1d_update.cu +++ b/csrc/kernels/causal_conv1d_update.cu @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2023-2026, Tri Dao. // Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. -// +// // Causal 1D Convolution Update Kernel for AIter Framework // Adapted for AMD MI308 GPU (ROCm/HIP) // @@ -17,17 +17,11 @@ // - Supports fp16, bf16, and fp32 data types // - Convolution widths: 2, 3, 4 -#include -#include -#include - #include "aiter_hip_common.h" +#include "aiter_tensor.h" +#include "aiter_stream.h" +#include "causal_conv1d.h" #include "ck_tile/core.hpp" -#include "dispatch_utils.h" -#include "py_itfs_common.h" - -// Helper macros -#define CHECK_INPUT(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA/HIP tensor") #define HIP_CHECK(err) \ do { \ @@ -39,34 +33,6 @@ } \ } while (0) -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - namespace aiter { // ============================================================================ @@ -86,11 +52,11 @@ struct ConvParamsBaseUpdate { index_t x_batch_stride; // Stride between batches in x index_t x_c_stride; // Stride between channels in x index_t x_l_stride; // Stride between sequence positions in x - + // Weight tensor strides index_t weight_c_stride; // Stride between channels in weight index_t weight_width_stride; // Stride within convolution width - + // Output tensor strides index_t out_batch_stride; // Stride between batches in output index_t out_c_stride; // Stride between channels in output @@ -149,7 +115,7 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) { const int tidx = threadIdx.x; const int batch_id = blockIdx.x; const int channel_id = blockIdx.y * kNThreads + tidx; - + // Early exit for out-of-bounds channels if (channel_id >= params.dim) return; @@ -162,17 +128,17 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) { const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr ? batch_id : params.conv_state_indices_ptr[batch_id]; - + // Skip processing if this is a padding slot if (conv_state_batch_coord == params.pad_slot_id){ return; } - + // Conv state pointer for this channel input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; - + // Weight and output pointers weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride @@ -193,7 +159,7 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) { // Sliding window buffer for input values float x_vals[kWidth] = {0}; - + // Initialize x_vals with historical state values if constexpr (!kIsCircularBuffer) { // Non-circular mode: Shift old data to make room for new data @@ -201,7 +167,7 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) { for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; } - + // Load the most recent (kWidth-1) historical states into x_vals #pragma unroll for (int i = 0; i < kWidth - 1; ++i) { @@ -219,13 +185,13 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) { x_vals[i] = float(state_val); } } - + // Main convolution loop: Process each new input token #pragma unroll 2 for (int i = 0; i < params.seqlen; ++i) { // Read new input input_t x_val = x[i * params.x_l_stride]; - + // Update conv_state with new input if constexpr (!kIsCircularBuffer) { // Non-circular: Write to the end of the buffer @@ -238,21 +204,21 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) { ++update_idx; update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; } - + // Add new input to the sliding window x_vals[kWidth - 1] = float(x_val); - + // Compute convolution output float out_val = bias_val; #pragma unroll for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } - + // Apply SiLU activation: x * sigmoid(x) = x / (1 + exp(-x)) if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - + // Write output out[i * params.out_l_stride] = input_t(out_val); - + // Shift the sliding window left by 1 position #pragma unroll for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } @@ -284,7 +250,7 @@ void causal_conv1d_update_launch(ConvParamsBaseUpdate ¶ms, hipStream_t strea template void causal_conv1d_update_dispatch(ConvParamsBaseUpdate ¶ms, hipStream_t stream) { constexpr int kNThreads = 64; // Optimized for AMD wavefront size - + if (params.width == 2) { causal_conv1d_update_launch(params, stream); } else if (params.width == 3) { @@ -295,27 +261,22 @@ void causal_conv1d_update_dispatch(ConvParamsBaseUpdate ¶ms, hipStream_t str } // ============================================================================ -// PyTorch Interface +// Host Interface // ============================================================================ -// Main entry point for Python/PyTorch +// Main entry point called from Python via pybind11 // Handles tensor validation, parameter setup, and kernel dispatch void causal_conv1d_update( - torch::Tensor& x, // [batch, dim, seqlen] - new input (typically seqlen=1 for decoding) - torch::Tensor& conv_state, // [batch, dim, state_len] - state buffer (updated in-place) - const torch::Tensor& weight, // [dim, width] - convolution weights - const torch::Tensor& bias, // [dim] - bias (or empty) - torch::Tensor& out, // [batch, dim, seqlen] - output - bool use_silu, // Whether to apply SiLU activation - const torch::Tensor& cache_seqlens, // [batch] - for circular buffer mode (or empty) - const torch::Tensor& conv_state_indices, // [batch] - for continuous batching (or empty) - int pad_slot_id) // Padding slot ID (-1 = no padding) + aiter_tensor_t& x, // [batch, dim, seqlen] - new input (typically seqlen=1 for decoding) + aiter_tensor_t& conv_state, // [batch, dim, state_len] - state buffer (updated in-place) + aiter_tensor_t& weight, // [dim, width] - convolution weights + aiter_tensor_t& bias, // [dim] - bias (or empty) + aiter_tensor_t& out, // [batch, dim, seqlen] - output + bool use_silu, // Whether to apply SiLU activation + aiter_tensor_t& cache_seqlens, // [batch] - for circular buffer mode (or empty) + aiter_tensor_t& conv_state_indices, // [batch] - for continuous batching (or empty) + int pad_slot_id) // Padding slot ID (-1 = no padding) { - CHECK_INPUT(x); - CHECK_INPUT(conv_state); - CHECK_INPUT(weight); - CHECK_INPUT(out); - // Extract dimensions const int32_t batch = x.size(0); const int32_t dim = x.size(1); @@ -324,14 +285,13 @@ void causal_conv1d_update( const int32_t conv_state_len = conv_state.size(2); // Validate tensor shapes - TORCH_CHECK(conv_state.size(0) == batch || conv_state_indices.defined(), "conv_state batch mismatch"); - TORCH_CHECK(conv_state.size(1) == dim, "conv_state dim mismatch"); - TORCH_CHECK(conv_state_len >= width - 1, "conv_state_len must be >= width - 1"); - TORCH_CHECK(out.size(0) == batch && out.size(1) == dim && out.size(2) == seqlen, "Output shape mismatch"); - TORCH_CHECK(weight.size(0) == dim, "Weight shape mismatch"); - TORCH_CHECK(width >= 2 && width <= 4, "Width must be 2, 3, or 4"); + AITER_CHECK(conv_state.size(0) == batch || conv_state_indices.numel() > 0, "conv_state batch mismatch"); + AITER_CHECK(conv_state.size(1) == dim, "conv_state dim mismatch"); + AITER_CHECK(conv_state_len >= width - 1, "conv_state_len must be >= width - 1"); + AITER_CHECK(out.size(0) == batch && out.size(1) == dim && out.size(2) == seqlen, "Output shape mismatch"); + AITER_CHECK(weight.size(0) == dim, "Weight shape mismatch"); + AITER_CHECK(width >= 2 && width <= 4, "Width must be 2, 3, or 4"); - // Setup kernel parameters // Setup kernel parameters ConvParamsBaseUpdate params; params.batch = batch; @@ -367,10 +327,9 @@ void causal_conv1d_update( params.conv_state_ptr = conv_state.data_ptr(); // Optional bias - if(bias.defined() && bias.numel() > 0) + if(bias.numel() > 0) { - CHECK_INPUT(bias); - TORCH_CHECK(bias.size(0) == dim, "Bias shape mismatch"); + AITER_CHECK(bias.size(0) == dim, "Bias shape mismatch"); params.bias_ptr = bias.data_ptr(); } else { params.bias_ptr = nullptr; @@ -380,38 +339,76 @@ void causal_conv1d_update( params.pad_slot_id = pad_slot_id; // Optional: cache_seqlens for circular buffer mode - if (cache_seqlens.defined() && cache_seqlens.numel() > 0) { - CHECK_INPUT(cache_seqlens); - TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32, "cache_seqlens must be int32"); - TORCH_CHECK(cache_seqlens.size(0) == batch, "cache_seqlens batch mismatch"); - params.cache_seqlens = cache_seqlens.data_ptr(); + if (cache_seqlens.numel() > 0) { + AITER_CHECK(cache_seqlens.dtype() == AITER_DTYPE_i32, "cache_seqlens must be int32"); + AITER_CHECK(cache_seqlens.size(0) == batch, "cache_seqlens batch mismatch"); + params.cache_seqlens = reinterpret_cast(cache_seqlens.data_ptr()); } else { params.cache_seqlens = nullptr; } // Optional: conv_state_indices for continuous batching - if (conv_state_indices.defined() && conv_state_indices.numel() > 0) { - CHECK_INPUT(conv_state_indices); - TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32, "conv_state_indices must be int32"); - TORCH_CHECK(conv_state_indices.size(0) == batch, "conv_state_indices batch mismatch"); - params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + if (conv_state_indices.numel() > 0) { + AITER_CHECK(conv_state_indices.dtype() == AITER_DTYPE_i32, "conv_state_indices must be int32"); + AITER_CHECK(conv_state_indices.size(0) == batch, "conv_state_indices batch mismatch"); + params.conv_state_indices_ptr = reinterpret_cast(conv_state_indices.data_ptr()); } else { params.conv_state_indices_ptr = nullptr; } // Get HIP device and stream - const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(x)); - const hipStream_t stream = at::hip::getCurrentHIPStream(); + HipDeviceGuard device_guard(x.device_id); + const hipStream_t stream = aiter::getCurrentHIPStream(); // Dispatch to appropriate kernel based on data types - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { - DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { + if (x.dtype() == AITER_DTYPE_fp16) { + using input_t = _Float16; + if (weight.dtype() == AITER_DTYPE_fp16) { + using weight_t = _Float16; + causal_conv1d_update_dispatch(params, stream); + } else if (weight.dtype() == AITER_DTYPE_bf16) { + using weight_t = hip_bfloat16; + causal_conv1d_update_dispatch(params, stream); + } else if (weight.dtype() == AITER_DTYPE_fp32) { + using weight_t = float; + causal_conv1d_update_dispatch(params, stream); + } else { + AITER_CHECK(false, "causal_conv1d_update not implemented for weight type"); + } + } else if (x.dtype() == AITER_DTYPE_bf16) { + using input_t = hip_bfloat16; + if (weight.dtype() == AITER_DTYPE_fp16) { + using weight_t = _Float16; + causal_conv1d_update_dispatch(params, stream); + } else if (weight.dtype() == AITER_DTYPE_bf16) { + using weight_t = hip_bfloat16; causal_conv1d_update_dispatch(params, stream); - }); - }); + } else if (weight.dtype() == AITER_DTYPE_fp32) { + using weight_t = float; + causal_conv1d_update_dispatch(params, stream); + } else { + AITER_CHECK(false, "causal_conv1d_update not implemented for weight type"); + } + } else if (x.dtype() == AITER_DTYPE_fp32) { + using input_t = float; + if (weight.dtype() == AITER_DTYPE_fp16) { + using weight_t = _Float16; + causal_conv1d_update_dispatch(params, stream); + } else if (weight.dtype() == AITER_DTYPE_bf16) { + using weight_t = hip_bfloat16; + causal_conv1d_update_dispatch(params, stream); + } else if (weight.dtype() == AITER_DTYPE_fp32) { + using weight_t = float; + causal_conv1d_update_dispatch(params, stream); + } else { + AITER_CHECK(false, "causal_conv1d_update not implemented for weight type"); + } + } else { + AITER_CHECK(false, "causal_conv1d_update not implemented for input type"); + } // Check for kernel launch errors HIP_CHECK(hipGetLastError()); } -} // namespace aiter \ No newline at end of file +} // namespace aiter diff --git a/csrc/pybind/causal_conv1d_update_pybind.cu b/csrc/pybind/causal_conv1d_update_pybind.cu index eb53b0e79c..7889e91fb1 100644 --- a/csrc/pybind/causal_conv1d_update_pybind.cu +++ b/csrc/pybind/causal_conv1d_update_pybind.cu @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. #include "rocm_ops.hpp" +#include "aiter_stream.h" #include "causal_conv1d.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + AITER_SET_STREAM_PYBIND CAUSAL_CONV1D_UPDATE_PYBIND; }