1 change: 1 addition & 0 deletions include/infinicore_infer.h
@@ -6,5 +6,6 @@

#include "infinicore_infer/models/deepseek.h"
#include "infinicore_infer/models/jiuge.h"
#include "infinicore_infer/models/Qwen3MoE.h"

#endif /* INFINICORE_INFER_H */
88 changes: 88 additions & 0 deletions include/infinicore_infer/models/Qwen3MoE.h
@@ -0,0 +1,88 @@
#ifndef QWEN3_MOE_H
#define QWEN3_MOE_H

#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>

struct Qwen3MoEWeights;

/// @brief Function pointer types for weight loading
typedef void (*load_global)(Qwen3MoEWeights *, void *cpu_ptr);
typedef void (*load_layer)(Qwen3MoEWeights *, void *cpu_ptr, size_t layer_id);
typedef void (*load_layer_linear)(Qwen3MoEWeights *, void *weight_ptr, size_t layer_id);
/// @brief Weight loader
typedef struct {
// Pre-Norm
load_layer load_attn_norm;

// Attention
load_layer_linear load_attn_q_proj;
load_layer_linear load_attn_k_proj;
load_layer_linear load_attn_v_proj;

// QKNorm(RMSNorm)
load_layer load_attn_q_norm;
load_layer load_attn_k_norm;

// output linear
load_layer_linear load_attn_o_proj;

} Qwen3MoEWeightLoader;

struct Qwen3MoEAttention;
struct Qwen3Cache; // forward-declared at file scope so the prototypes below share one definition

/// @brief Model parameters
typedef struct {
// Data type: BF16 / FP16
infiniDtype_t dtype;

// Linear args
size_t hidden_size;
size_t num_heads;
size_t num_kv_head; // number of KV heads; num_heads / num_kv_head is the GQA broadcast factor
size_t head_dim;

// RoPE args
float rope_theta;
size_t max_seq_len;

float rms_norm_eps; // guards against division by zero in RMSNorm
} Qwen3MoEAttentionMeta;

/// ==================== API ====================

/// @brief Create the attention module
__C __export struct Qwen3MoEAttention *
createQwen3MoEAttention(const Qwen3MoEAttentionMeta *,
const Qwen3MoEWeights *);
/// @brief Create the model weights
__C __export Qwen3MoEWeights *
createQwen3MoEWeights(const Qwen3MoEAttentionMeta *meta,
infiniDevice_t device,
int ndev,
const int *dev_ids);
/// @brief Create the weight loader
__C __export Qwen3MoEWeightLoader *
createQwen3MoEWeightLoader();
/// @brief Create the KV cache
__C __export struct Qwen3Cache *
createQwen3Cache(const Qwen3MoEAttentionMeta *meta,
size_t batch_size, size_t seq_len);
/// @brief Forward pass
__C __export void forwardQwen3MoEAttention(
struct Qwen3MoEAttention* context,
struct Qwen3Cache* kv_cache,
const void* input_tensor,
void* output_tensor,
int batch_size, // [new] number of requests in the batch
const int* seq_lens_ptr, // [new] new tokens per request
const int* past_lens_ptr, // [new] cached tokens per request
const int* pos_ids_ptr // [new] flattened position ids
);

/// @brief Destroy the attention module
__C __export void destroyQwen3MoEAttention(struct Qwen3MoEAttention* ctx);


#endif /* QWEN3_MOE_H */
115 changes: 115 additions & 0 deletions scripts/libinfinicore_infer/qwen3_moe.py
@@ -0,0 +1,115 @@
from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model
from ctypes import (
c_size_t,
c_int,
c_float,
c_void_p,
POINTER,
Structure,
CFUNCTYPE,
)


class Qwen3MoEAttentionMetaCStruct(Structure):
_fields_ = [
("dtype", DataType),
("hidden_size", c_size_t),
("num_heads", c_size_t),
("num_kv_head", c_size_t),
("head_dim", c_size_t),
("rope_theta", c_float),
("max_seq_len", c_size_t),
("rms_norm_eps", c_float),
]
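
For reference, the meta struct can be filled directly from Python. A minimal sketch follows; every numeric value is an illustrative assumption (a plausible GQA head geometry), not a value taken from this PR, and BF16_DTYPE is a hypothetical placeholder for whatever enum value the bindings use for bfloat16:

# Sketch only: all values below are assumptions for illustration.
meta = Qwen3MoEAttentionMetaCStruct(
    dtype=BF16_DTYPE,      # hypothetical placeholder for the BF16 enum value
    hidden_size=2048,
    num_heads=32,
    num_kv_head=4,         # 32 / 4 = 8-way GQA broadcast
    head_dim=128,
    rope_theta=1000000.0,
    max_seq_len=32768,
    rms_norm_eps=1e-6,
)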


class Qwen3MoEWeightsCStruct(Structure):
pass


class Qwen3MoEAttentionCStruct(Structure):
pass


class Qwen3CacheCStruct(Structure):
pass


load_layer_fn = CFUNCTYPE(None, POINTER(Qwen3MoEWeightsCStruct), c_void_p, c_size_t)
# Must mirror the C typedef load_layer_linear(Qwen3MoEWeights *, void *weight_ptr, size_t layer_id);
# the two extra void pointers previously listed here did not match the header's signature.
load_layer_linear_fn = CFUNCTYPE(
None, POINTER(Qwen3MoEWeightsCStruct), c_void_p, c_size_t
)


class Qwen3MoEWeightLoaderCStruct(Structure):
_fields_ = [
("load_attn_norm", load_layer_fn),
("load_attn_q_proj", load_layer_linear_fn),
("load_attn_k_proj", load_layer_linear_fn),
("load_attn_v_proj", load_layer_linear_fn),
("load_attn_q_norm", load_layer_fn),
("load_attn_k_norm", load_layer_fn),
("load_attn_o_proj", load_layer_linear_fn),
]
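
A ctypes pitfall worth flagging for this struct: a CFUNCTYPE wrapper must stay referenced from Python for as long as C may invoke it, or it is garbage-collected and the next call through the stale pointer crashes. A hedged sketch of wiring one callback into the loader (the callback body is a stand-in, not this PR's actual loading logic):

# Sketch only: load_q_proj's body is a placeholder. Keeping the CFUNCTYPE
# object referenced (here, in _keepalive) is what prevents Python from
# collecting it while C still holds the function pointer.
def load_q_proj(weights_ptr, weight_ptr, layer_id):
    print(f"would copy q_proj weights for layer {layer_id}")

_keepalive = []

def install_q_proj(loader_ptr):
    cb = load_layer_linear_fn(load_q_proj)
    _keepalive.append(cb)  # hold a reference for the callback's lifetime
    loader_ptr.contents.load_attn_q_proj = cb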


@register_model
class Qwen3MoEModel(BaseModel):
@classmethod
def register_lib(cls, lib):
"""Register Qwen3MoE model functions with the library"""
lib.createQwen3MoEWeightLoader.argtypes = []
lib.createQwen3MoEWeightLoader.restype = POINTER(
Qwen3MoEWeightLoaderCStruct
)

lib.createQwen3MoEWeights.argtypes = [
POINTER(Qwen3MoEAttentionMetaCStruct),
DeviceType,
c_int,
POINTER(c_int),
]
lib.createQwen3MoEWeights.restype = POINTER(Qwen3MoEWeightsCStruct)

lib.createQwen3MoEAttention.argtypes = [
POINTER(Qwen3MoEAttentionMetaCStruct),
POINTER(Qwen3MoEWeightsCStruct),
]
lib.createQwen3MoEAttention.restype = POINTER(Qwen3MoEAttentionCStruct)

lib.destroyQwen3MoEAttention.argtypes = [POINTER(Qwen3MoEAttentionCStruct)]

lib.createQwen3Cache.argtypes = [
POINTER(Qwen3MoEAttentionMetaCStruct),
c_size_t,
c_size_t,
]
lib.createQwen3Cache.restype = POINTER(Qwen3CacheCStruct)

lib.forwardQwen3MoEAttention.argtypes = [
POINTER(Qwen3MoEAttentionCStruct),
POINTER(Qwen3CacheCStruct),
c_void_p,
c_void_p,
c_int, # batch_size
POINTER(c_int), # seq_lens_ptr
POINTER(c_int), # past_lens_ptr
POINTER(c_int), # pos_ids_ptr
]

def create_weight_loader(self):
return self.lib.createQwen3MoEWeightLoader()

def create_weights(self, meta, device_type, ndev, dev_ids):
return self.lib.createQwen3MoEWeights(meta, device_type, ndev, dev_ids)

def create_model(self, meta, weights):
return self.lib.createQwen3MoEAttention(meta, weights)

def destroy_model(self, model):
self.lib.destroyQwen3MoEAttention(model)

def create_cache(self, meta, batch_size, seq_len):
return self.lib.createQwen3Cache(meta, batch_size, seq_len)

def forward_attention(self, model, kv_cache, input_tensor, output_tensor, batch_size, seq_lens, past_lens, pos_ids):
# Pass through the batching arguments the C entry point now requires.
self.lib.forwardQwen3MoEAttention(model, kv_cache, input_tensor, output_tensor, batch_size, seq_lens, past_lens, pos_ids)
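
Putting the pieces together, a hedged end-to-end sketch of driving these bindings for a single request. The buffer shapes, dtype, and device enum value are assumptions, and a real caller would populate the weights through the loader callbacks before running forward:

# Sketch only: assumes BaseModel exposes the loaded library as self.lib;
# DeviceType(0) is a hypothetical device enum value, and the buffer
# shapes/dtypes are illustrative.
import numpy as np
from ctypes import c_int, c_void_p

model = Qwen3MoEModel()
device_type = DeviceType(0)           # hypothetical device enum value
dev_ids = (c_int * 1)(0)
weights = model.create_weights(meta, device_type, 1, dev_ids)
attn = model.create_model(meta, weights)
cache = model.create_cache(meta, 1, meta.max_seq_len)

seq_len = 16
hidden = np.zeros((seq_len, meta.hidden_size), dtype=np.float16)  # placeholder input
out = np.empty_like(hidden)
seq_lens = (c_int * 1)(seq_len)
past_lens = (c_int * 1)(0)            # no cached tokens yet
pos_ids = (c_int * seq_len)(*range(seq_len))

model.forward_attention(
    attn, cache,
    hidden.ctypes.data_as(c_void_p), out.ctypes.data_as(c_void_p),
    1, seq_lens, past_lens, pos_ids,
)
model.destroy_model(attn)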

