@@ -89,10 +89,10 @@ namespace chatllm::qwen::v3_5
         bool vit_loaded = false;
     };
 
-    class QwenGatedAttention : public QKNormedRoPEAttention<RMSNorm, BaseAttention>
+    class QwenGatedAttention : public QKNormedRoPEAttention<RMSNormWeightPlus1, BaseAttention>
     {
     public:
-        typedef QKNormedRoPEAttention<RMSNorm, BaseAttention> Base;
+        typedef QKNormedRoPEAttention<RMSNormWeightPlus1, BaseAttention> Base;
         QwenGatedAttention(InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int head_dim, int max_length);
         int64_t get_param_num(bool effective_only) const override;
         ggml::tensor *forward(ComputeContext *ctx, ggml::tensor *input, int n_past) override;
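This hunk swaps the Q/K normalization inside `QwenGatedAttention` (and, further down, the final norm) from `RMSNorm` to `RMSNormWeightPlus1`. A minimal scalar sketch of the assumed semantics follows: the name suggests the learned weight is stored as an offset from 1 (Gemma-style), so a zero-initialized weight leaves the RMS-normalized activations unscaled. The function name and `eps` default here are illustrative, not the ggml implementation:

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical scalar reference (not the ggml kernel): RMS-normalize `x`
// in place and scale channel i by (1 + weight[i]), so a zero-initialized
// weight acts as the identity scale.
static void rms_norm_weight_plus_1(float *x, const float *weight, size_t n, float eps = 1e-6f)
{
    float sum_sq = 0.0f;
    for (size_t i = 0; i < n; i++)
        sum_sq += x[i] * x[i];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / n + eps);
    for (size_t i = 0; i < n; i++)
        x[i] *= inv_rms * (1.0f + weight[i]);
}
```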
@@ -485,7 +485,7 @@ namespace chatllm::qwen::v3_5
 
         transformer = new ModelClass(&w_ctx_, config.num_hidden_layers, config.hidden_size,
             vocab_size <= 0 ? create_embedding<Embedding>(&w_ctx_, config) : create_embedding<Embedding>(&w_ctx_, vocab_size, config.hidden_size),
-            create_final_norm<RMSNorm>(&w_ctx_, config),
+            create_final_norm<RMSNormWeightPlus1>(&w_ctx_, config),
             config.tie_word_embeddings ? (Block *)nullptr : create_lm_head(&w_ctx_, config, false),
             [&](InitContext *ctx, int layer_index) {
                 return create_layer(ctx, layer_index);
@@ -509,26 +509,79 @@ namespace chatllm::qwen::v3_5
         return r;
     }
 
-    Block * ConditionalGeneration::create_layer(InitContext *ctx, int layer_index)
+    template <int NUM_EXPERTS, int EXPERTS_PER_TOK> class QWenSparseMoE : public BaseSparseMLP
     {
-        CHATLLM_CHECK(config.num_experts < 0) << "TODO: MoE";
+    public:
+        QWenSparseMoE(InitContext *ctx, int hidden_size, int intermediate_size)
+            : BaseSparseMLP(ctx, hidden_size, intermediate_size, NUM_EXPERTS, EXPERTS_PER_TOK, ActFunc::SILU, false)
+        {
+        }
+    };
 
-        if (config.layer_is_la[layer_index])
+    Block *ConditionalGeneration::create_layer(InitContext *ctx, int layer_index)
+    {
+        if (config.num_experts <= 0)
         {
-            typedef LMBlock1<RMSNorm, QwenGatedDeltaNet, RMSNorm, SiLUMLP> Layer;
-            auto layer = new Layer(ctx, TypeLinearAttention(), config.hidden_size, config.intermediate_size,
-                config.linear_conv_kernel_dim, config.linear_num_key_heads, config.linear_num_value_heads, config.linear_key_head_dim, config.linear_value_head_dim);
-            return layer;
+            if (config.layer_is_la[layer_index])
+            {
+                typedef LMBlock1<RMSNormWeightPlus1, QwenGatedDeltaNet, RMSNormWeightPlus1, SiLUMLP> Layer;
+                auto layer = new Layer(ctx, TypeLinearAttention(), config.hidden_size, config.intermediate_size,
+                    config.linear_conv_kernel_dim, config.linear_num_key_heads, config.linear_num_value_heads, config.linear_key_head_dim, config.linear_value_head_dim);
+                return layer;
+            }
+            else
+            {
+                typedef LMBlock1<RMSNormWeightPlus1, QwenGatedAttention, RMSNormWeightPlus1, SiLUMLP> Layer;
+                auto layer = new Layer(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size, config.num_key_value_heads, config.head_dim, config.max_length);
+                layer->attention.mrope_sections = config.mrope_sections;
+                layer->attention.rope_mode = RoPEMode::IMROPE;
+                layer->attention.rope_dim = config.rope_dim;
+                layer->attention.freq_base = config.rope_theta;
+                return layer;
+            }
         }
         else
         {
-            typedef LMBlock1<RMSNorm, QwenGatedAttention, RMSNorm, SiLUMLP> Layer;
-            auto layer = new Layer(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size, config.num_key_value_heads, config.head_dim, config.max_length);
-            layer->attention.mrope_sections = config.mrope_sections;
-            layer->attention.rope_mode = RoPEMode::IMROPE;
-            layer->attention.rope_dim = config.rope_dim;
-            layer->attention.freq_base = config.rope_theta;
-            return layer;
+            typedef GatedMLP<SiLUMLP> QWenGatedMLP;
+
+            if (config.layer_is_la[layer_index])
+            {
+                if ((config.num_experts == 256) && (config.num_experts_per_tok == 8))
+                {
+                    typedef CombinedMLP<QWenSparseMoE<256, 8>, QWenGatedMLP> QWenMoEMLP;
+                    typedef LMBlock1<RMSNormWeightPlus1, QwenGatedDeltaNet, RMSNormWeightPlus1, QWenMoEMLP> Layer;
+                    auto layer = new Layer(ctx, TypeLinearAttention(), config.hidden_size, config.intermediate_size,
+                        config.moe_intermediate_size, config.shared_expert_intermediate_size,
+                        config.linear_conv_kernel_dim, config.linear_num_key_heads, config.linear_num_value_heads, config.linear_key_head_dim, config.linear_value_head_dim);
+                    return layer;
+                }
+                else
+                {
+                    CHATLLM_CHECK(false) << "unsupported MoE param: " << config.num_experts << ", " << config.num_experts_per_tok;
+                    return nullptr;
+                }
+            }
+            else
+            {
+                if ((config.num_experts == 256) && (config.num_experts_per_tok == 8))
+                {
+                    typedef CombinedMLP<QWenSparseMoE<256, 8>, QWenGatedMLP> QWenMoEMLP;
+                    typedef LMBlock1<RMSNormWeightPlus1, QwenGatedAttention, RMSNormWeightPlus1, QWenMoEMLP> Layer;
+                    auto layer = new Layer(ctx, config.hidden_size, config.num_attention_heads, config.intermediate_size,
+                        config.moe_intermediate_size, config.shared_expert_intermediate_size,
+                        config.num_key_value_heads, config.head_dim, config.max_length);
+                    layer->attention.mrope_sections = config.mrope_sections;
+                    layer->attention.rope_mode = RoPEMode::IMROPE;
+                    layer->attention.rope_dim = config.rope_dim;
+                    layer->attention.freq_base = config.rope_theta;
+                    return layer;
+                }
+                else
+                {
+                    CHATLLM_CHECK(false) << "unsupported MoE param: " << config.num_experts << ", " << config.num_experts_per_tok;
+                    return nullptr;
+                }
+            }
         }
     }
 
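For MoE checkpoints, each layer's feed-forward becomes `CombinedMLP<QWenSparseMoE<256, 8>, QWenGatedMLP>`: a routed sparse MoE built on `BaseSparseMLP` plus an always-on shared expert. Below is a rough scalar sketch of the per-token routing that `BaseSparseMLP` is assumed to perform with `NUM_EXPERTS = 256` and `EXPERTS_PER_TOK = 8`; the function name, the softmax-over-top-k weighting, and the precomputed expert outputs are assumptions for illustration, since the real class expresses this over ggml tensors in the compute graph:

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Hypothetical sketch: score every expert with the router, keep the top-k,
// softmax the selected scores, and sum those experts' outputs weighted by
// the resulting probabilities.
std::vector<float> sparse_moe_forward(
    const std::vector<float> &hidden,                   // one token, size = hidden_size
    const std::vector<float> &router_logits,            // size = NUM_EXPERTS (256)
    const std::vector<std::vector<float>> &expert_out,  // expert_out[e] = expert e applied to `hidden`
    int experts_per_tok = 8)
{
    // Rank experts by router logit and keep the top `experts_per_tok`.
    std::vector<int> order(router_logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + experts_per_tok, order.end(),
                      [&](int a, int b) { return router_logits[a] > router_logits[b]; });

    // Softmax over the selected logits only (max-subtracted for stability).
    std::vector<float> w(experts_per_tok);
    const float max_logit = router_logits[order[0]];
    float denom = 0.0f;
    for (int k = 0; k < experts_per_tok; k++)
        denom += (w[k] = std::exp(router_logits[order[k]] - max_logit));

    // Weighted sum of the selected experts' outputs.
    std::vector<float> out(hidden.size(), 0.0f);
    for (int k = 0; k < experts_per_tok; k++)
        for (size_t i = 0; i < out.size(); i++)
            out[i] += (w[k] / denom) * expert_out[order[k]][i];
    return out;
}
```

The shared-expert half of `CombinedMLP` (its `mlp2`) is presumably added on top of this routed output, which is why the loader hunk below maps `.mlp.mlp2.` onto `.mlp.shared_expert.`.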
@@ -544,7 +597,11 @@ namespace chatllm::qwen::v3_5
544597 {" .self_attn.in_proj_qkv." , " .linear_attn.in_proj_qkv." },
545598 {" .self_attn.norm." , " .linear_attn.norm." },
546599 {" .self_attn.out_proj." , " .linear_attn.out_proj." },
600+ {" .mlp.mlp1." , " .mlp." },
601+ {" .mlp.mlp2.gate." , " .mlp.shared_expert_gate." },
602+ {" .mlp.mlp2." , " .mlp.shared_expert." },
547603 });
604+
548605 BaseModelForConditionalGeneration::load (loader);
549606
550607 loader.clear_tensor_name_translations ();
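The three new entries extend the name translations the loader applies while matching checkpoint tensors to the MoE layers built above. A hypothetical sketch of the kind of substring rewriting assumed here follows; the helper name, the first-match-wins rule, and the mapping direction (internal graph name vs. on-disk name) are assumptions, not the actual `ModelLoader` API:

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical illustration: rewrite the first matching source substring in a
// tensor name, e.g. "...mlp.mlp2.gate.weight" -> "...mlp.shared_expert_gate.weight"
// when the pairs are applied left-to-right.
std::string translate_tensor_name(
    std::string name,
    const std::vector<std::pair<std::string, std::string>> &translations)
{
    for (const auto &t : translations)
    {
        const size_t pos = name.find(t.first);
        if (pos != std::string::npos)
        {
            name.replace(pos, t.first.size(), t.second);
            break; // assume only the first matching rule applies
        }
    }
    return name;
}
```

Under first-match semantics, rule order matters: the more specific `.mlp.mlp2.gate.` entry has to be tried before the bare `.mlp.mlp2.` one, which matches the order used in the hunk above.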