Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiter/ops/causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
MD_NAME = "module_causal_conv1d_update"


@compile_ops("module_causal_conv1d_update")
@compile_ops("module_causal_conv1d_update", develop=True)
def causal_conv1d_update(
x: Tensor,
conv_state: Tensor,
Expand Down
16 changes: 8 additions & 8 deletions csrc/include/causal_conv1d.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>
#include "aiter_tensor.h"

namespace aiter {

void causal_conv1d_update(
torch::Tensor& x,
torch::Tensor& conv_state,
const torch::Tensor& weight,
const torch::Tensor& bias,
torch::Tensor& out,
aiter_tensor_t& x,
aiter_tensor_t& conv_state,
aiter_tensor_t& weight,
aiter_tensor_t& bias,
aiter_tensor_t& out,
bool use_silu,
const torch::Tensor& cache_seqlens,
const torch::Tensor& conv_state_indices,
aiter_tensor_t& cache_seqlens,
aiter_tensor_t& conv_state_indices,
int pad_slot_id);

} // namespace aiter
4 changes: 2 additions & 2 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2111,8 +2111,8 @@ namespace py = pybind11;
py::arg("bias"), \
py::arg("out"), \
py::arg("use_silu"), \
py::arg("cache_seqlens") = torch::Tensor(), \
py::arg("conv_state_indices") = torch::Tensor(), \
py::arg("cache_seqlens"), \
py::arg("conv_state_indices"), \
py::arg("pad_slot_id") = -1);
Comment on lines 2113 to 2116

#define CHUNK_GDR_FWD_H_PYBIND \
Expand Down
191 changes: 94 additions & 97 deletions csrc/kernels/causal_conv1d_update.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2023-2026, Tri Dao.
// Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
//
//
// Causal 1D Convolution Update Kernel for AIter Framework
// Adapted for AMD MI308 GPU (ROCm/HIP)
//
Expand All @@ -17,17 +17,11 @@
// - Supports fp16, bf16, and fp32 data types
// - Convolution widths: 2, 3, 4

#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <torch/extension.h>

#include "aiter_hip_common.h"
#include "aiter_tensor.h"
#include "aiter_stream.h"
#include "causal_conv1d.h"
#include "ck_tile/core.hpp"
Comment on lines 20 to 24
#include "dispatch_utils.h"
#include "py_itfs_common.h"

// Helper macros
#define CHECK_INPUT(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA/HIP tensor")

#define HIP_CHECK(err) \
do { \
Expand All @@ -39,34 +33,6 @@
} \
} while (0)

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}

#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
if (WTYPE == at::ScalarType::Half) { \
using weight_t = at::Half; \
__VA_ARGS__(); \
} else if (WTYPE == at::ScalarType::BFloat16) { \
using weight_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (WTYPE == at::ScalarType::Float) { \
using weight_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
}

namespace aiter {

// ============================================================================
Expand All @@ -86,11 +52,11 @@ struct ConvParamsBaseUpdate {
index_t x_batch_stride; // Stride between batches in x
index_t x_c_stride; // Stride between channels in x
index_t x_l_stride; // Stride between sequence positions in x

// Weight tensor strides
index_t weight_c_stride; // Stride between channels in weight
index_t weight_width_stride; // Stride within convolution width

// Output tensor strides
index_t out_batch_stride; // Stride between batches in output
index_t out_c_stride; // Stride between channels in output
Expand Down Expand Up @@ -149,7 +115,7 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) {
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y * kNThreads + tidx;

// Early exit for out-of-bounds channels
if (channel_id >= params.dim) return;

Expand All @@ -162,17 +128,17 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];

// Skip processing if this is a padding slot
if (conv_state_batch_coord == params.pad_slot_id){
return;
}

// Conv state pointer for this channel
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;

// Weight and output pointers
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
Expand All @@ -193,15 +159,15 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) {

// Sliding window buffer for input values
float x_vals[kWidth] = {0};

// Initialize x_vals with historical state values
if constexpr (!kIsCircularBuffer) {
// Non-circular mode: Shift old data to make room for new data
#pragma unroll 2
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
}

// Load the most recent (kWidth-1) historical states into x_vals
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) {
Expand All @@ -219,13 +185,13 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) {
x_vals[i] = float(state_val);
}
}

// Main convolution loop: Process each new input token
#pragma unroll 2
for (int i = 0; i < params.seqlen; ++i) {
// Read new input
input_t x_val = x[i * params.x_l_stride];

// Update conv_state with new input
if constexpr (!kIsCircularBuffer) {
// Non-circular: Write to the end of the buffer
Expand All @@ -238,21 +204,21 @@ void causal_conv1d_update_kernel(ConvParamsBaseUpdate params) {
++update_idx;
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
}

// Add new input to the sliding window
x_vals[kWidth - 1] = float(x_val);

// Compute convolution output
float out_val = bias_val;
#pragma unroll
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }

// Apply SiLU activation: x * sigmoid(x) = x / (1 + exp(-x))
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }

// Write output
out[i * params.out_l_stride] = input_t(out_val);

// Shift the sliding window left by 1 position
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
Expand Down Expand Up @@ -284,7 +250,7 @@ void causal_conv1d_update_launch(ConvParamsBaseUpdate &params, hipStream_t strea
template<typename input_t, typename weight_t>
void causal_conv1d_update_dispatch(ConvParamsBaseUpdate &params, hipStream_t stream) {
constexpr int kNThreads = 64; // Optimized for AMD wavefront size

if (params.width == 2) {
causal_conv1d_update_launch<kNThreads, 2, input_t, weight_t>(params, stream);
} else if (params.width == 3) {
Expand All @@ -295,27 +261,22 @@ void causal_conv1d_update_dispatch(ConvParamsBaseUpdate &params, hipStream_t str
}

// ============================================================================
// PyTorch Interface
// Host Interface
// ============================================================================
// Main entry point for Python/PyTorch
// Main entry point called from Python via pybind11
// Handles tensor validation, parameter setup, and kernel dispatch

void causal_conv1d_update(
torch::Tensor& x, // [batch, dim, seqlen] - new input (typically seqlen=1 for decoding)
torch::Tensor& conv_state, // [batch, dim, state_len] - state buffer (updated in-place)
const torch::Tensor& weight, // [dim, width] - convolution weights
const torch::Tensor& bias, // [dim] - bias (or empty)
torch::Tensor& out, // [batch, dim, seqlen] - output
bool use_silu, // Whether to apply SiLU activation
const torch::Tensor& cache_seqlens, // [batch] - for circular buffer mode (or empty)
const torch::Tensor& conv_state_indices, // [batch] - for continuous batching (or empty)
int pad_slot_id) // Padding slot ID (-1 = no padding)
aiter_tensor_t& x, // [batch, dim, seqlen] - new input (typically seqlen=1 for decoding)
aiter_tensor_t& conv_state, // [batch, dim, state_len] - state buffer (updated in-place)
aiter_tensor_t& weight, // [dim, width] - convolution weights
aiter_tensor_t& bias, // [dim] - bias (or empty)
aiter_tensor_t& out, // [batch, dim, seqlen] - output
bool use_silu, // Whether to apply SiLU activation
aiter_tensor_t& cache_seqlens, // [batch] - for circular buffer mode (or empty)
aiter_tensor_t& conv_state_indices, // [batch] - for continuous batching (or empty)
int pad_slot_id) // Padding slot ID (-1 = no padding)
{
CHECK_INPUT(x);
CHECK_INPUT(conv_state);
CHECK_INPUT(weight);
CHECK_INPUT(out);

// Extract dimensions
const int32_t batch = x.size(0);
const int32_t dim = x.size(1);
Expand All @@ -324,14 +285,13 @@ void causal_conv1d_update(
const int32_t conv_state_len = conv_state.size(2);

// Validate tensor shapes
TORCH_CHECK(conv_state.size(0) == batch || conv_state_indices.defined(), "conv_state batch mismatch");
TORCH_CHECK(conv_state.size(1) == dim, "conv_state dim mismatch");
TORCH_CHECK(conv_state_len >= width - 1, "conv_state_len must be >= width - 1");
TORCH_CHECK(out.size(0) == batch && out.size(1) == dim && out.size(2) == seqlen, "Output shape mismatch");
TORCH_CHECK(weight.size(0) == dim, "Weight shape mismatch");
TORCH_CHECK(width >= 2 && width <= 4, "Width must be 2, 3, or 4");
AITER_CHECK(conv_state.size(0) == batch || conv_state_indices.numel() > 0, "conv_state batch mismatch");
AITER_CHECK(conv_state.size(1) == dim, "conv_state dim mismatch");
AITER_CHECK(conv_state_len >= width - 1, "conv_state_len must be >= width - 1");
Comment on lines +288 to +290
AITER_CHECK(out.size(0) == batch && out.size(1) == dim && out.size(2) == seqlen, "Output shape mismatch");
AITER_CHECK(weight.size(0) == dim, "Weight shape mismatch");
AITER_CHECK(width >= 2 && width <= 4, "Width must be 2, 3, or 4");

// Setup kernel parameters
// Setup kernel parameters
ConvParamsBaseUpdate params;
params.batch = batch;
Expand Down Expand Up @@ -367,10 +327,9 @@ void causal_conv1d_update(
params.conv_state_ptr = conv_state.data_ptr();

// Optional bias
if(bias.defined() && bias.numel() > 0)
if(bias.numel() > 0)
{
CHECK_INPUT(bias);
TORCH_CHECK(bias.size(0) == dim, "Bias shape mismatch");
AITER_CHECK(bias.size(0) == dim, "Bias shape mismatch");
params.bias_ptr = bias.data_ptr();
Comment on lines +330 to 333
} else {
params.bias_ptr = nullptr;
Expand All @@ -380,38 +339,76 @@ void causal_conv1d_update(
params.pad_slot_id = pad_slot_id;

// Optional: cache_seqlens for circular buffer mode
if (cache_seqlens.defined() && cache_seqlens.numel() > 0) {
CHECK_INPUT(cache_seqlens);
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32, "cache_seqlens must be int32");
TORCH_CHECK(cache_seqlens.size(0) == batch, "cache_seqlens batch mismatch");
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
if (cache_seqlens.numel() > 0) {
AITER_CHECK(cache_seqlens.dtype() == AITER_DTYPE_i32, "cache_seqlens must be int32");
AITER_CHECK(cache_seqlens.size(0) == batch, "cache_seqlens batch mismatch");
params.cache_seqlens = reinterpret_cast<int32_t*>(cache_seqlens.data_ptr());
Comment on lines +342 to +345
} else {
params.cache_seqlens = nullptr;
}

// Optional: conv_state_indices for continuous batching
if (conv_state_indices.defined() && conv_state_indices.numel() > 0) {
CHECK_INPUT(conv_state_indices);
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32, "conv_state_indices must be int32");
TORCH_CHECK(conv_state_indices.size(0) == batch, "conv_state_indices batch mismatch");
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
if (conv_state_indices.numel() > 0) {
AITER_CHECK(conv_state_indices.dtype() == AITER_DTYPE_i32, "conv_state_indices must be int32");
AITER_CHECK(conv_state_indices.size(0) == batch, "conv_state_indices batch mismatch");
params.conv_state_indices_ptr = reinterpret_cast<int32_t*>(conv_state_indices.data_ptr());
Comment on lines +351 to +354
} else {
params.conv_state_indices_ptr = nullptr;
}

// Get HIP device and stream
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(x));
const hipStream_t stream = at::hip::getCurrentHIPStream();
HipDeviceGuard device_guard(x.device_id);
const hipStream_t stream = aiter::getCurrentHIPStream();

// Dispatch to appropriate kernel based on data types
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
if (x.dtype() == AITER_DTYPE_fp16) {
using input_t = _Float16;
if (weight.dtype() == AITER_DTYPE_fp16) {
using weight_t = _Float16;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else if (weight.dtype() == AITER_DTYPE_bf16) {
using weight_t = hip_bfloat16;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else if (weight.dtype() == AITER_DTYPE_fp32) {
using weight_t = float;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else {
AITER_CHECK(false, "causal_conv1d_update not implemented for weight type");
}
} else if (x.dtype() == AITER_DTYPE_bf16) {
using input_t = hip_bfloat16;
if (weight.dtype() == AITER_DTYPE_fp16) {
using weight_t = _Float16;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else if (weight.dtype() == AITER_DTYPE_bf16) {
using weight_t = hip_bfloat16;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
});
});
} else if (weight.dtype() == AITER_DTYPE_fp32) {
using weight_t = float;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else {
AITER_CHECK(false, "causal_conv1d_update not implemented for weight type");
}
} else if (x.dtype() == AITER_DTYPE_fp32) {
using input_t = float;
if (weight.dtype() == AITER_DTYPE_fp16) {
using weight_t = _Float16;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else if (weight.dtype() == AITER_DTYPE_bf16) {
using weight_t = hip_bfloat16;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else if (weight.dtype() == AITER_DTYPE_fp32) {
using weight_t = float;
causal_conv1d_update_dispatch<input_t, weight_t>(params, stream);
} else {
AITER_CHECK(false, "causal_conv1d_update not implemented for weight type");
}
} else {
AITER_CHECK(false, "causal_conv1d_update not implemented for input type");
}

// Check for kernel launch errors
HIP_CHECK(hipGetLastError());
}

} // namespace aiter
} // namespace aiter
Loading
Loading