From b0cae5ff6dced1df99952c5cbef6b91ddc8d37bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 5 Dec 2025 20:47:23 +0100 Subject: [PATCH 01/14] Support LongCat Image model --- flux.hpp | 44 +++++++++++++++++++++------- ggml_extend.hpp | 69 ++++++++++++++++++++++++++++++++++++++++++++ model.cpp | 29 ++++++++++++------- model.h | 11 ++++++- name_conversion.cpp | 8 ++++- stable-diffusion.cpp | 48 +++++++++++++++++++----------- 6 files changed, 169 insertions(+), 40 deletions(-) diff --git a/flux.hpp b/flux.hpp index 5d94fc85d..751787bc9 100644 --- a/flux.hpp +++ b/flux.hpp @@ -90,10 +90,15 @@ namespace Flux { SelfAttention(int64_t dim, int64_t num_heads = 8, bool qkv_bias = false, - bool proj_bias = true) + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + if(diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + } else { + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + } blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } @@ -261,7 +266,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : idx(idx), prune_mod(prune_mod) { int64_t mlp_hidden_dim = static_cast(hidden_size * mlp_ratio); @@ -269,7 +275,7 @@ namespace Flux { blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -282,7 +288,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -424,6 +430,7 @@ namespace Flux { bool use_yak_mlp; bool use_mlp_silu_act; int64_t mlp_mult_factor; + bool diffusers_style = false; public: SingleStreamBlock(int64_t hidden_size, @@ -435,7 +442,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; @@ -447,8 +455,11 @@ namespace Flux { if (use_yak_mlp || use_mlp_silu_act) { mlp_mult_factor = 2; } - - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + if (diffusers_style) { + blocks["linear1"] = std::shared_ptr(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias)); + } else { + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + } blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); @@ -777,6 +788,7 @@ namespace Flux { bool use_mlp_silu_act = false; float ref_index_scale = 1.f; ChromaRadianceParams chroma_radiance_params; + bool diffusers_style = false; }; struct Flux : public GGMLBlock { @@ -822,7 +834,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } for (int i = 0; i < params.depth_single_blocks; i++) { @@ -835,7 +848,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } if (params.version == VERSION_CHROMA_RADIANCE) { @@ -1306,6 +1320,9 @@ namespace Flux { flux_params.share_modulation = true; flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + flux_params.context_in_dim = 3584; + flux_params.vec_in_dim = 0; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; @@ -1315,6 +1332,9 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") == std::string::npos) { + flux_params.diffusers_style = true; + } if (tensor_name.find("__x0__") != std::string::npos) { LOG_DEBUG("using x0 prediction"); flux_params.chroma_radiance_params.use_x0 = true; @@ -1353,6 +1373,10 @@ namespace Flux { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } + if (flux_params.diffusers_style) { + LOG_INFO("Using diffusers-style naming"); + } + flux = Flux(flux_params); flux.init(params_ctx, tensor_storage_map, prefix); } diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 1ff450116..834e48a5b 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2232,6 +2232,75 @@ class Linear : public UnaryBlock { } }; +class SplitLinear : public Linear { +protected: + int64_t in_features; + std::vector out_features_vec; + bool bias; + bool force_f32; + bool force_prec_f32; + float scale; + std::string prefix; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; + enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32); + if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { + wtype = GGML_TYPE_F32; + } + params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + // most likely same type as the first weight + params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]); + } + if (bias) { + enum ggml_type wtype = GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]); + } + } + } + +public: + SplitLinear(int64_t in_features, + std::vector out_features_vec, + bool bias = true, + bool force_f32 = false, + bool force_prec_f32 = false, + float scale = 1.f) + : Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale), + in_features(in_features), + out_features_vec(out_features_vec), + bias(bias), + force_f32(force_f32), + force_prec_f32(force_prec_f32), + scale(scale) {} + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + struct ggml_tensor* b = nullptr; + if (bias) { + b = params["bias"]; + } + // concat all weights and biases together + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); + } + } + if (ctx->weight_adapter) { + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; + forward_params.linear.force_prec_f32 = force_prec_f32; + forward_params.linear.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } + return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + } +}; + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; if (allow_types.find(wtype) != allow_types.end()) { diff --git a/model.cpp b/model.cpp index a19f180da..3c8adcd04 100644 --- a/model.cpp +++ b/model.cpp @@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight, input_block_weight; + TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight; bool has_multiple_encoders = false; bool is_unet = false; @@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() { for (auto& [name, tensor_storage] : tensor_storage_map) { if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { is_flux = true; } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { @@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; } + if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") { + context_ebedding_weight = tensor_storage; + } } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); @@ -1135,16 +1138,20 @@ SDVersion ModelLoader::get_sd_version() { } if (is_flux) { - if (input_block_weight.ne[0] == 384) { - return VERSION_FLUX_FILL; - } - if (input_block_weight.ne[0] == 128) { - return VERSION_FLUX_CONTROLS; - } - if (input_block_weight.ne[0] == 196) { - return VERSION_FLEX_2; + if (context_ebedding_weight.ne[0] == 3584) { + return VERSION_LONGCAT; + } else { + if (input_block_weight.ne[0] == 384) { + return VERSION_FLUX_FILL; + } + if (input_block_weight.ne[0] == 128) { + return VERSION_FLUX_CONTROLS; + } + if (input_block_weight.ne[0] == 196) { + return VERSION_FLEX_2; + } + return VERSION_FLUX; } - return VERSION_FLUX; } if (token_embedding_weight.ne[0] == 768) { diff --git a/model.h b/model.h index b9e50ad63..1d0c6bca0 100644 --- a/model.h +++ b/model.h @@ -46,6 +46,7 @@ enum SDVersion { VERSION_FLUX2, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_LONGCAT, VERSION_COUNT, }; @@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_longcat(SDVersion version) { + if (version == VERSION_LONGCAT) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_longcat(version)) { return true; } return false; diff --git a/name_conversion.cpp b/name_conversion.cpp index 3ae229b63..dad35f4d3 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { static std::unordered_map flux_name_map; if (flux_name_map.empty()) { + // --- time_embed (longcat) --- + flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + // --- time_text_embed --- flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; @@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) { name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 75689ff94..02855d4b7 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -49,6 +49,7 @@ const char* model_version_to_str[] = { "Flux.2", "Z-Image", "Ovis Image", + "Longcat-Image", }; const char* sampling_methods_str[] = { @@ -392,7 +393,7 @@ class StableDiffusionGGML { } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || @@ -424,8 +425,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -473,10 +474,23 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); + } else if (sd_version_is_longcat(version)) { + bool enable_vision = false; + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version, + "", + enable_vision); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -485,10 +499,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -528,10 +542,10 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -883,7 +897,7 @@ class StableDiffusionGGML { flow_shift = 3.f; } } - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) { pred_type = FLUX_FLOW_PRED; if (flow_shift == INFINITY) { @@ -1473,7 +1487,7 @@ class StableDiffusionGGML { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { From 2247731cfebec52924bb41920d3bcdd7d79ad93b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 5 Dec 2025 20:58:53 +0100 Subject: [PATCH 02/14] temp fix cuda error on quant concat for splitlinear --- ggml_extend.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 834e48a5b..1ff2af055 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2273,7 +2273,7 @@ class SplitLinear : public Linear { in_features(in_features), out_features_vec(out_features_vec), bias(bias), - force_f32(force_f32), + force_f32(true), force_prec_f32(force_prec_f32), scale(scale) {} From f1d0c955ed8648d7b73e7b29f7ad890cad4f1ff6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 02:43:46 +0100 Subject: [PATCH 03/14] pre-patchify --- flux.hpp | 1 + stable-diffusion.cpp | 21 ++++++++++++++++----- vae.hpp | 4 ++-- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/flux.hpp b/flux.hpp index 751787bc9..a90e60e81 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1323,6 +1323,7 @@ namespace Flux { } else if (sd_version_is_longcat(version)) { flux_params.context_in_dim = 3584; flux_params.vec_in_dim = 0; + flux_params.patch_size = 1; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 02855d4b7..232e24bb0 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -474,10 +474,10 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_longcat(version)) { bool enable_vision = false; cond_stage_model = std::make_shared(clip_backend, @@ -907,6 +907,9 @@ class StableDiffusionGGML { flow_shift = 1.15f; } } + if(sd_version_is_longcat(version)) { + flow_shift = 3.0f; + } } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; @@ -1470,6 +1473,12 @@ class StableDiffusionGGML { if (sd_version_is_flux2(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; + patch_sz = 2; + } + } else if (dim == 64) { + if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { + latent_rgb_proj = flux_latent_rgb_proj; + latent_rgb_bias = flux_latent_rgb_bias; patch_sz = 2; } } else if (dim == 48) { @@ -2258,7 +2267,7 @@ class StableDiffusionGGML { int vae_scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { vae_scale_factor = 16; - } else if (sd_version_is_flux2(version)) { + } else if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { vae_scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { vae_scale_factor = 1; @@ -2287,6 +2296,8 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_is_flux2(version)) { latent_channel = 128; + } else if (sd_version_is_longcat(version)) { + latent_channel = 64; } else { latent_channel = 16; } diff --git a/vae.hpp b/vae.hpp index cd055aa86..a007592c1 100644 --- a/vae.hpp +++ b/vae.hpp @@ -553,7 +553,7 @@ class AutoencodingEngine : public GGMLBlock { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version)) { + if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -592,7 +592,7 @@ class AutoencodingEngine : public GGMLBlock { auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] } - if (sd_version_is_flux2(version)) { + if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p] From 3c6c05b463fc3f55cd5c4db598aff7f576d01f05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 02:44:20 +0100 Subject: [PATCH 04/14] longcat rope ids --- conditioner.hpp | 11 +++++++++++ flux.hpp | 7 +++---- ggml_extend.hpp | 26 +++++++++++++++++--------- rope.hpp | 35 +++++++++++++++++++++++++---------- 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index b6d5646a7..b1813259e 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1815,6 +1815,17 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + } else if (sd_version_is_longcat(version)) { + prompt_template_encode_start_idx = 36; + // prompt_template_encode_end_idx = 5; + + prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; } else { prompt_template_encode_start_idx = 34; diff --git a/flux.hpp b/flux.hpp index a90e60e81..31cb0825c 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1375,7 +1375,7 @@ namespace Flux { } if (flux_params.diffusers_style) { - LOG_INFO("Using diffusers-style naming"); + LOG_INFO("Using diffusers-style attention blocks"); } flux = Flux(flux_params); @@ -1489,7 +1489,6 @@ namespace Flux { } else if (version == VERSION_OVIS_IMAGE) { txt_arange_dims = {1, 2}; } - pe_vec = Rope::gen_flux_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), flux_params.patch_size, @@ -1502,9 +1501,9 @@ namespace Flux { flux_params.theta, circular_y_enabled, circular_x_enabled, - flux_params.axes_dim); + flux_params.axes_dim, + sd_version_is_longcat(version)); int pos_len = static_cast(pe_vec.size() / flux_params.axes_dim_sum / 2); - // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 1ff2af055..8f1bf5a23 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2273,7 +2273,7 @@ class SplitLinear : public Linear { in_features(in_features), out_features_vec(out_features_vec), bias(bias), - force_f32(true), + force_f32(force_f32), force_prec_f32(force_prec_f32), scale(scale) {} @@ -2283,21 +2283,29 @@ class SplitLinear : public Linear { if (bias) { b = params["bias"]; } - // concat all weights and biases together - for (int i = 1; i < out_features_vec.size(); i++) { - w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); - if (bias) { - b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); - } - } if (ctx->weight_adapter) { + // concat all weights and biases together so it runs in one linear layer + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); + } + } WeightAdapter::ForwardParams forward_params; forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; forward_params.linear.force_prec_f32 = force_prec_f32; forward_params.linear.scale = scale; return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); } - return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + for (int i = 1; i < out_features_vec.size(); i++) { + auto wi = params["weight." + std::to_string(i)]; + auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; + auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); + x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0); + } + + return x0; } }; diff --git a/rope.hpp b/rope.hpp index 2d123b3cc..599a2ab0a 100644 --- a/rope.hpp +++ b/rope.hpp @@ -106,7 +106,16 @@ namespace Rope { return txt_ids; } - __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, + __STATIC_INLINE__ std::vector> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) { + auto txt_ids = std::vector>(bs * context_len, std::vector(axes_dim_num, 0.0f)); + for (int i = 0; i < bs * context_len; i++) { + txt_ids[i][1] = (i % context_len); + txt_ids[i][2] = (i % context_len); + } + return txt_ids; + } + + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, int bs, @@ -117,7 +126,6 @@ namespace Rope { bool scale_rope = false) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(axes_dim_num, 0.0)); int h_start = h_offset; @@ -206,6 +214,7 @@ namespace Rope { __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, int bs, int axes_dim_num, + int start_index, const std::vector& ref_latents, bool increase_ref_index, float ref_index_scale, @@ -213,7 +222,7 @@ namespace Rope { std::vector> ids; int curr_h_offset = 0; int curr_w_offset = 0; - int index = 1; + int index = start_index; for (ggml_tensor* ref : ref_latents) { int h_offset = 0; int w_offset = 0; @@ -256,13 +265,17 @@ namespace Rope { std::set txt_arange_dims, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { - auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); + float ref_index_scale, + bool is_longcat) { + int start_index = is_longcat ? 1 : 0; + + auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); + int offset = is_longcat ? context_len : 0; + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale, false); ids = concat_ids(ids, refs_ids, bs); } return ids; @@ -281,7 +294,8 @@ namespace Rope { int theta, bool circular_h, bool circular_w, - const std::vector& axes_dim) { + const std::vector& axes_dim, + bool is_longcat) { std::vector> ids = gen_flux_ids(h, w, patch_size, @@ -291,7 +305,8 @@ namespace Rope { txt_arange_dims, ref_latents, increase_ref_index, - ref_index_scale); + ref_index_scale, + is_longcat); std::vector> wrap_dims; if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { int h_len = (h + (patch_size / 2)) / patch_size; @@ -356,7 +371,7 @@ namespace Rope { auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f, true); ids = concat_ids(ids, refs_ids, bs); } return ids; From 0be35064b13a9264fa67c511fe3714af627de3a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 03:47:52 +0100 Subject: [PATCH 05/14] Fix diffusers_style detection --- flux.hpp | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/flux.hpp b/flux.hpp index 31cb0825c..55876bd25 100644 --- a/flux.hpp +++ b/flux.hpp @@ -88,19 +88,19 @@ namespace Flux { public: SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false, - bool proj_bias = true, - bool diffusers_style = false) + int64_t num_heads = 8, + bool qkv_bias = false, + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - if(diffusers_style) { - blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + if (diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); } else { - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); } - blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); - blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } std::vector pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { @@ -787,8 +787,8 @@ namespace Flux { bool use_yak_mlp = false; bool use_mlp_silu_act = false; float ref_index_scale = 1.f; + bool diffusers_style = false; ChromaRadianceParams chroma_radiance_params; - bool diffusers_style = false; }; struct Flux : public GGMLBlock { @@ -1333,7 +1333,7 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } - if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") == std::string::npos) { + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) { flux_params.diffusers_style = true; } if (tensor_name.find("__x0__") != std::string::npos) { @@ -1502,9 +1502,9 @@ namespace Flux { circular_y_enabled, circular_x_enabled, flux_params.axes_dim, - sd_version_is_longcat(version)); + sd_version_is_longcat(version)); int pos_len = static_cast(pe_vec.size() / flux_params.axes_dim_sum / 2); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = nullptr; From 56421a9328a21037912f0f3f021a73dceb1e4983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 16:05:58 +0100 Subject: [PATCH 06/14] Flux: simplify when patch_size is 1 --- flux.hpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/flux.hpp b/flux.hpp index 55876bd25..eb8283f14 100644 --- a/flux.hpp +++ b/flux.hpp @@ -896,6 +896,11 @@ namespace Flux { int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; + if (params.patch_size == 1) { + x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W] + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C] + return x; + } int64_t p = params.patch_size; int64_t h = H / params.patch_size; int64_t w = W / params.patch_size; @@ -930,6 +935,12 @@ namespace Flux { int64_t W = w * params.patch_size; int64_t p = params.patch_size; + if (params.patch_size == 1) { + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W] + return x; + } + GGML_ASSERT(C * p * p == x->ne[0]); x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] From 410d2697729343e5432efe5180e385a102c0cf48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sat, 6 Dec 2025 16:06:32 +0100 Subject: [PATCH 07/14] correct rope offset for image tokens stuff --- ggml_extend.hpp | 12 ++++++------ rope.hpp | 13 +++++++------ stable-diffusion.cpp | 6 +++++- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 8f1bf5a23..7a0213724 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2297,15 +2297,15 @@ class SplitLinear : public Linear { forward_params.linear.scale = scale; return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); } - auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); for (int i = 1; i < out_features_vec.size(); i++) { - auto wi = params["weight." + std::to_string(i)]; - auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; - auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); - x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0); + auto wi = params["weight." + std::to_string(i)]; + auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; + auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); + out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0); } - return x0; + return out; } }; diff --git a/rope.hpp b/rope.hpp index 599a2ab0a..2ee1d987a 100644 --- a/rope.hpp +++ b/rope.hpp @@ -218,10 +218,11 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, float ref_index_scale, - bool scale_rope) { + bool scale_rope, + int base_offset = 0) { std::vector> ids; - int curr_h_offset = 0; - int curr_w_offset = 0; + int curr_h_offset = base_offset; + int curr_w_offset = base_offset; int index = start_index; for (ggml_tensor* ref : ref_latents) { int h_offset = 0; @@ -267,15 +268,15 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, bool is_longcat) { - int start_index = is_longcat ? 1 : 0; + int x_index = is_longcat ? 1 : 0; auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); int offset = is_longcat ? context_len : 0; - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset); + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale, false); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, false, offset); ids = concat_ids(ids, refs_ids, bs); } return ids; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 232e24bb0..e73a2a665 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -480,6 +480,9 @@ class StableDiffusionGGML { sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_longcat(version)) { bool enable_vision = false; + if (!vae_decode_only) { + enable_vision = true; + } cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, tensor_storage_map, @@ -907,7 +910,7 @@ class StableDiffusionGGML { flow_shift = 1.15f; } } - if(sd_version_is_longcat(version)) { + if (sd_version_is_longcat(version)) { flow_shift = 3.0f; } } @@ -2598,6 +2601,7 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_wan(version) || sd_version_is_flux2(version) || + sd_version_is_longcat(version) || version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { From c01da4f047c562cf7d6a59c4f6d2575050007b1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 8 Dec 2025 01:36:17 +0100 Subject: [PATCH 08/14] Fix token length --- conditioner.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conditioner.hpp b/conditioner.hpp index b1813259e..a078a7a6a 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1706,7 +1706,7 @@ struct LLMEmbedder : public Conditioner { std::vector> image_embeds; std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; - int max_length = 0; + int max_length = 0; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { LOG_INFO("QwenImageEditPlusPipeline"); @@ -1818,6 +1818,7 @@ struct LLMEmbedder : public Conditioner { } else if (sd_version_is_longcat(version)) { prompt_template_encode_start_idx = 36; // prompt_template_encode_end_idx = 5; + max_length = 512; prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; From 17f76f55ad1681f1af4a5f29f754dfe293bf0628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 8 Dec 2025 02:28:29 +0100 Subject: [PATCH 09/14] Split quoted text into character-level tokens remove debug logs --- conditioner.hpp | 105 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 76 insertions(+), 29 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index a078a7a6a..36999f7c0 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1656,46 +1656,91 @@ struct LLMEmbedder : public Conditioner { } } - std::tuple, std::vector> tokenize(std::string text, - std::pair attn_range, - size_t max_length = 0, - bool padding = false) { + std::tuple, std::vector> tokenize( + std::string text, + std::pair attn_range, + size_t max_length = 0, + bool padding = false, + bool spell_quotes = false) { std::vector> parsed_attention; parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.second - attn_range.first > 0) { - auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); - parsed_attention.insert(parsed_attention.end(), - new_parsed_attention.begin(), - new_parsed_attention.end()); + auto new_parsed_attention = parse_prompt_attention( + text.substr(attn_range.first, attn_range.second - attn_range.first)); + parsed_attention.insert( + parsed_attention.end(), + new_parsed_attention.begin(), + new_parsed_attention.end()); } parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); - { - std::stringstream ss; - ss << "["; - for (const auto& item : parsed_attention) { - ss << "['" << item.first << "', " << item.second << "], "; - } - ss << "]"; - LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - } + + // { + // std::stringstream ss; + // ss << '['; + // for (const auto& item : parsed_attention) { + // ss << "['" << item.first << "', " << item.second << "], "; + // } + // ss << ']'; + // LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + // } std::vector tokens; std::vector weights; + for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); - tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); - weights.insert(weights.end(), curr_tokens.size(), curr_weight); - } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + if (spell_quotes) { + std::vector parts; + bool in_quote = false; + std::string current_part; - // for (int i = 0; i < tokens.size(); i++) { - // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; - // } - // std::cout << std::endl; + for (char c : curr_text) { + if (c == '\'') { + if (!current_part.empty()) { + parts.push_back(current_part); + current_part.clear(); + } + in_quote = !in_quote; + } else { + current_part += c; + if (in_quote && current_part.size() == 1) { + parts.push_back(current_part); + current_part.clear(); + } + } + } + if (!current_part.empty()) { + parts.push_back(current_part); + } + for (const auto& part : parts) { + if (part.empty()) + continue; + if (part[0] == '\'' && part.back() == '\'') { + std::string quoted_content = part.substr(1, part.size() - 2); + for (char ch : quoted_content) { + std::string char_str(1, ch); + std::vector char_tokens = tokenizer->tokenize(char_str, nullptr); + tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end()); + weights.insert(weights.end(), char_tokens.size(), curr_weight); + } + } else { + std::vector part_tokens = tokenizer->tokenize(part, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + } + } + } else { + std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); + tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); + weights.insert(weights.end(), curr_tokens.size(), curr_weight); + } + } + + tokenizer->pad_tokens(tokens, weights, max_length, padding); return {tokens, weights}; } @@ -1706,7 +1751,8 @@ struct LLMEmbedder : public Conditioner { std::vector> image_embeds; std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; - int max_length = 0; + int max_length = 0; + bool spell_quotes = false; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { LOG_INFO("QwenImageEditPlusPipeline"); @@ -1818,7 +1864,8 @@ struct LLMEmbedder : public Conditioner { } else if (sd_version_is_longcat(version)) { prompt_template_encode_start_idx = 36; // prompt_template_encode_end_idx = 5; - max_length = 512; + max_length = 512; + spell_quotes = true; prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; @@ -1839,7 +1886,7 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n"; } - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0); + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes); auto& tokens = std::get<0>(tokens_and_weights); auto& weights = std::get<1>(tokens_and_weights); From 358bb2cce89349e3a561f78edc49710dd7193791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 8 Dec 2025 13:43:37 +0100 Subject: [PATCH 10/14] support longcat-image-edit Fix base rope offset for ref images --- conditioner.hpp | 179 +++++++++++++++++++++++++++++++++--------------- rope.hpp | 13 ++-- 2 files changed, 131 insertions(+), 61 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 36999f7c0..e7c034495 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1698,7 +1698,7 @@ struct LLMEmbedder : public Conditioner { std::string current_part; for (char c : curr_text) { - if (c == '\'') { + if (c == '"') { if (!current_part.empty()) { parts.push_back(current_part); current_part.clear(); @@ -1719,7 +1719,7 @@ struct LLMEmbedder : public Conditioner { for (const auto& part : parts) { if (part.empty()) continue; - if (part[0] == '\'' && part.back() == '\'') { + if (part[0] == '"' && part.back() == '"') { std::string quoted_content = part.substr(1, part.size() - 2); for (char ch : quoted_content) { std::string char_str(1, ch); @@ -1755,68 +1755,139 @@ struct LLMEmbedder : public Conditioner { bool spell_quotes = false; std::set out_layers; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { - LOG_INFO("QwenImageEditPlusPipeline"); - prompt_template_encode_start_idx = 64; - int image_embed_idx = 64 + 6; - - int min_pixels = 384 * 384; - int max_pixels = 560 * 560; - std::string placeholder = "<|image_pad|>"; - std::string img_prompt; - - for (int i = 0; i < conditioner_params.ref_images.size(); i++) { - sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; - int height = image.height; - int width = image.width; - int h_bar = static_cast(std::round(height / factor) * factor); - int w_bar = static_cast(std::round(width / factor) * factor); - - if (static_cast(h_bar) * w_bar > max_pixels) { - double beta = std::sqrt((height * width) / static_cast(max_pixels)); - h_bar = std::max(static_cast(factor), - static_cast(std::floor(height / beta / factor)) * static_cast(factor)); - w_bar = std::max(static_cast(factor), - static_cast(std::floor(width / beta / factor)) * static_cast(factor)); - } else if (static_cast(h_bar) * w_bar < min_pixels) { - double beta = std::sqrt(static_cast(min_pixels) / (height * width)); - h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); - w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + if (sd_version_is_longcat(version)) { + LOG_INFO("LongCatEditPipeline"); + prompt_template_encode_start_idx = 67; + // prompt_template_encode_end_idx = 5; + int image_embed_idx = 36 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + + // Only one image is officicially supported by the model, not sure how it handles multiple images + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; + + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; + + img_prompt += "<|vision_start|>"; + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + max_length = 512; + spell_quotes = true; + prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + + } else { + LOG_INFO("QwenImageEditPlusPipeline"); + prompt_template_encode_start_idx = 64; + int image_embed_idx = 64 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor) * factor); + int w_bar = static_cast(std::round(width / factor) * factor); + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); - sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); - free(image.data); - image.data = nullptr; + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; - ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); - sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); - free(resized_image.data); - resized_image.data = nullptr; + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; - ggml_tensor* image_embed = nullptr; - llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); - image_embeds.emplace_back(image_embed_idx, image_embed); - image_embed_idx += 1 + static_cast(image_embed->ne[1]) + 6; + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + static_cast(image_embed->ne[1]) + 6; - img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] - int64_t num_image_tokens = image_embed->ne[1]; - img_prompt.reserve(num_image_tokens * placeholder.size()); - for (int j = 0; j < num_image_tokens; j++) { - img_prompt += placeholder; + img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - img_prompt += "<|vision_end|>"; - } - prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; - prompt += img_prompt; + prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } } else if (sd_version_is_flux2(version)) { prompt_template_encode_start_idx = 0; out_layers = {10, 20, 30}; diff --git a/rope.hpp b/rope.hpp index 2ee1d987a..ab15e458c 100644 --- a/rope.hpp +++ b/rope.hpp @@ -115,7 +115,7 @@ namespace Rope { return txt_ids; } - __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, int bs, @@ -138,7 +138,6 @@ namespace Rope { std::vector row_ids = linspace(1.f * h_start, 1.f * h_start + h_len - 1, h_len); std::vector col_ids = linspace(1.f * w_start, 1.f * w_start + w_len - 1, w_len); - for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { img_ids[i * w_len + j][0] = 1.f * index; @@ -219,10 +218,10 @@ namespace Rope { bool increase_ref_index, float ref_index_scale, bool scale_rope, - int base_offset = 0) { + int base_offset = 0) { std::vector> ids; - int curr_h_offset = base_offset; - int curr_w_offset = base_offset; + int curr_h_offset = 0; + int curr_w_offset = 0; int index = start_index; for (ggml_tensor* ref : ref_latents) { int h_offset = 0; @@ -242,8 +241,8 @@ namespace Rope { bs, axes_dim_num, static_cast(index * ref_index_scale), - h_offset, - w_offset, + h_offset + base_offset, + w_offset + base_offset, scale_rope); ids = concat_ids(ids, ref_ids, bs); From 5bd1335d16f8799a17581f08716e6ab1e13641ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 12 Dec 2025 02:10:46 +0100 Subject: [PATCH 11/14] Split quotes by utf8 characters rather than individual char --- conditioner.hpp | 87 +++++++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 42 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index e7c034495..66200ab4d 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1656,6 +1656,19 @@ struct LLMEmbedder : public Conditioner { } } + size_t get_utf8_char_len(char c) { + unsigned char uc = static_cast(c); + if ((uc & 0x80) == 0) + return 1; // ASCII (1 byte) + if ((uc & 0xE0) == 0xC0) + return 2; // 2-byte char + if ((uc & 0xF0) == 0xE0) + return 3; // 3-byte char (Common for Chinese/Japanese) + if ((uc & 0xF8) == 0xF0) + return 4; // 4-byte char (Emojis, etc.) + return 1; // Fallback (should not happen in valid UTF-8) + } + std::tuple, std::vector> tokenize( std::string text, std::pair attn_range, @@ -1675,16 +1688,6 @@ struct LLMEmbedder : public Conditioner { } parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); - // { - // std::stringstream ss; - // ss << '['; - // for (const auto& item : parsed_attention) { - // ss << "['" << item.first << "', " << item.second << "], "; - // } - // ss << ']'; - // LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - // } - std::vector tokens; std::vector weights; @@ -1693,46 +1696,47 @@ struct LLMEmbedder : public Conditioner { float curr_weight = item.second; if (spell_quotes) { - std::vector parts; + std::string buffer; bool in_quote = false; - std::string current_part; - for (char c : curr_text) { - if (c == '"') { - if (!current_part.empty()) { - parts.push_back(current_part); - current_part.clear(); + size_t i = 0; + while (i < curr_text.size()) { + // utf8 character can be 1-4 char + size_t char_len = get_utf8_char_len(curr_text[i]); + + // Safety check to prevent reading past end of string + if (i + char_len > curr_text.size()) { + char_len = curr_text.size() - i; + } + std::string uchar = curr_text.substr(i, char_len); + i += char_len; + + if (uchar == "\"") { + buffer += uchar; + // If we were accumulating normal text, flush it now + if (!in_quote) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + buffer.clear(); } in_quote = !in_quote; } else { - current_part += c; - if (in_quote && current_part.size() == 1) { - parts.push_back(current_part); - current_part.clear(); - } - } - } - if (!current_part.empty()) { - parts.push_back(current_part); - } - - for (const auto& part : parts) { - if (part.empty()) - continue; - if (part[0] == '"' && part.back() == '"') { - std::string quoted_content = part.substr(1, part.size() - 2); - for (char ch : quoted_content) { - std::string char_str(1, ch); - std::vector char_tokens = tokenizer->tokenize(char_str, nullptr); + if (in_quote) { + std::vector char_tokens = tokenizer->tokenize(uchar, nullptr); tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end()); weights.insert(weights.end(), char_tokens.size(), curr_weight); + } else { + buffer += uchar; } - } else { - std::vector part_tokens = tokenizer->tokenize(part, nullptr); - tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); - weights.insert(weights.end(), part_tokens.size(), curr_weight); } } + + if (!buffer.empty()) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + } } else { std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); @@ -1759,14 +1763,13 @@ struct LLMEmbedder : public Conditioner { LOG_INFO("LongCatEditPipeline"); prompt_template_encode_start_idx = 67; // prompt_template_encode_end_idx = 5; - int image_embed_idx = 36 + 6; + int image_embed_idx = 36 + 6; int min_pixels = 384 * 384; int max_pixels = 560 * 560; std::string placeholder = "<|image_pad|>"; std::string img_prompt; - // Only one image is officicially supported by the model, not sure how it handles multiple images for (int i = 0; i < conditioner_params.ref_images.size(); i++) { sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); From 1e855c84ee9fc3625b32df008c989a1e1c82f97c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 12 Dec 2025 02:39:09 +0100 Subject: [PATCH 12/14] patch size consistent with Flux1 --- flux.hpp | 1 - stable-diffusion.cpp | 11 +---------- vae.hpp | 4 ++-- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/flux.hpp b/flux.hpp index eb8283f14..a1d97cfed 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1334,7 +1334,6 @@ namespace Flux { } else if (sd_version_is_longcat(version)) { flux_params.context_in_dim = 3584; flux_params.vec_in_dim = 0; - flux_params.patch_size = 1; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e73a2a665..c7ee62d32 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1478,12 +1478,6 @@ class StableDiffusionGGML { latent_rgb_bias = flux2_latent_rgb_bias; patch_sz = 2; } - } else if (dim == 64) { - if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { - latent_rgb_proj = flux_latent_rgb_proj; - latent_rgb_bias = flux_latent_rgb_bias; - patch_sz = 2; - } } else if (dim == 48) { if (sd_version_is_wan(version)) { latent_rgb_proj = wan_22_latent_rgb_proj; @@ -2270,7 +2264,7 @@ class StableDiffusionGGML { int vae_scale_factor = 8; if (version == VERSION_WAN2_2_TI2V) { vae_scale_factor = 16; - } else if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { + } else if (sd_version_is_flux2(version)) { vae_scale_factor = 16; } else if (version == VERSION_CHROMA_RADIANCE) { vae_scale_factor = 1; @@ -2299,8 +2293,6 @@ class StableDiffusionGGML { latent_channel = 3; } else if (sd_version_is_flux2(version)) { latent_channel = 128; - } else if (sd_version_is_longcat(version)) { - latent_channel = 64; } else { latent_channel = 16; } @@ -2601,7 +2593,6 @@ class StableDiffusionGGML { sd_version_is_qwen_image(version) || sd_version_is_wan(version) || sd_version_is_flux2(version) || - sd_version_is_longcat(version) || version == VERSION_CHROMA_RADIANCE) { latent = vae_output; } else if (version == VERSION_SD1_PIX2PIX) { diff --git a/vae.hpp b/vae.hpp index a007592c1..cd055aa86 100644 --- a/vae.hpp +++ b/vae.hpp @@ -553,7 +553,7 @@ class AutoencodingEngine : public GGMLBlock { struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] - if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { + if (sd_version_is_flux2(version)) { // [N, C*p*p, h, w] -> [N, C, h*p, w*p] int64_t p = 2; @@ -592,7 +592,7 @@ class AutoencodingEngine : public GGMLBlock { auto quant_conv = std::dynamic_pointer_cast(blocks["quant_conv"]); z = quant_conv->forward(ctx, z); // [N, 2*embed_dim, h/8, w/8] } - if (sd_version_is_flux2(version) || sd_version_is_longcat(version)) { + if (sd_version_is_flux2(version)) { z = ggml_ext_chunk(ctx->ggml_ctx, z, 2, 2)[0]; // [N, C, H, W] -> [N, C*p*p, H/p, W/p] From 3e7ce7f5c33ee575118f8b051e9ef0e4716bc565 Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 15 Dec 2025 22:16:21 +0800 Subject: [PATCH 13/14] fix conditionner --- conditioner.hpp | 55 +++++++++++++++++++++++++++++++++++++++++-------- ggml_extend.hpp | 2 +- llm.hpp | 34 ++++++++++++++++++------------ 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 66200ab4d..98ad0d8fb 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1755,9 +1755,13 @@ struct LLMEmbedder : public Conditioner { std::vector> image_embeds; std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; + int prompt_template_encode_end_idx = 0; int max_length = 0; bool spell_quotes = false; std::set out_layers; + std::vector tokens; + std::vector weights; + std::vector mask; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { if (sd_version_is_longcat(version)) { LOG_INFO("LongCatEditPipeline"); @@ -1937,8 +1941,8 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; } else if (sd_version_is_longcat(version)) { prompt_template_encode_start_idx = 36; - // prompt_template_encode_end_idx = 5; - max_length = 512; + max_length = 512 + prompt_template_encode_start_idx; + prompt_template_encode_end_idx = 5; spell_quotes = true; prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; @@ -1947,7 +1951,24 @@ struct LLMEmbedder : public Conditioner { prompt += conditioner_params.text; prompt_attn_range.second = static_cast(prompt.size()); - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false, spell_quotes); + tokens = std::get<0>(tokens_and_weights); + weights = std::get<1>(tokens_and_weights); + + mask.insert(mask.end(), tokens.size(), 1.f); + if (tokens.size() < max_length) { + mask.insert(mask.end(), max_length - tokens.size(), 0.f); + tokenizer->pad_tokens(tokens, weights, max_length, true); + } + + std::string prompt_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"; + auto suffix_tokens = tokenizer->tokenize(prompt_template_suffix, nullptr); + + LOG_DEBUG("%zd", tokens.size()); + + tokens.insert(tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + weights.insert(weights.end(), suffix_tokens.size(), 1.f); + mask.insert(mask.end(), suffix_tokens.size(), 1.f); } else { prompt_template_encode_start_idx = 34; @@ -1960,17 +1981,33 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n"; } - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes); - auto& tokens = std::get<0>(tokens_and_weights); - auto& weights = std::get<1>(tokens_and_weights); + if (tokens.empty()) { + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes); + tokens = std::get<0>(tokens_and_weights); + weights = std::get<1>(tokens_and_weights); + } + int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584] auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); + ggml_tensor* attention_mask = nullptr; + if (!mask.empty()) { + attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size()); + ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = 0.f; + if (mask[i0] == 0.f || mask[i1] == 0.f) { + value = -INFINITY; + } + ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3); + }); + print_ggml_tensor(attention_mask); + } llm->compute(n_threads, input_ids, + attention_mask, image_embeds, out_layers, &hidden_states, @@ -2008,18 +2045,18 @@ struct LLMEmbedder : public Conditioner { ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, hidden_states->ne[0], - hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len, + hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len - prompt_template_encode_end_idx, hidden_states->ne[2]); ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = 0.f; - if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) { + if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1] - prompt_template_encode_end_idx) { value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3); } ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3); }); - // print_ggml_tensor(new_hidden_states); + print_ggml_tensor(new_hidden_states, true); int64_t t1 = ggml_time_ms(); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 7a0213724..c3c5bc76f 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2207,7 +2207,7 @@ class Linear : public UnaryBlock { bool bias = true, bool force_f32 = false, bool force_prec_f32 = false, - float scale = 1.f) + float scale = 1.f / 128.f) : in_features(in_features), out_features(out_features), bias(bias), diff --git a/llm.hpp b/llm.hpp index 67b1ea165..77dc0e5ca 100644 --- a/llm.hpp +++ b/llm.hpp @@ -837,7 +837,8 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask = nullptr) { // x: [N, n_token, hidden_size] int64_t n_token = x->ne[1]; int64_t N = x->ne[2]; @@ -880,7 +881,7 @@ namespace LLM { k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, true, false); // [N, n_token, hidden_size] x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] return x; @@ -898,7 +899,8 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask = nullptr) { // x: [N, n_token, hidden_size] auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); @@ -907,7 +909,7 @@ namespace LLM { auto residual = x; x = input_layernorm->forward(ctx, x); - x = self_attn->forward(ctx, x, input_pos); + x = self_attn->forward(ctx, x, input_pos, attention_mask); x = ggml_add_inplace(ctx->ggml_ctx, x, residual); residual = x; @@ -936,6 +938,7 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { // input_ids: [N, n_token] @@ -990,7 +993,7 @@ namespace LLM { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); - x = block->forward(ctx, x, input_pos); + x = block->forward(ctx, x, input_pos, attention_mask); if (out_layers.find(i + 1) != out_layers.end()) { intermediate_outputs.push_back(x); } @@ -1036,12 +1039,13 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers); + auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); return x; } @@ -1157,9 +1161,10 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { - auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers); // [N, n_token, hidden_size] + auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size] return hidden_states; } @@ -1174,11 +1179,13 @@ namespace LLM { } struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); + attention_mask = to_backend(attention_mask); for (auto& image_embed : image_embeds) { image_embed.second = to_backend(image_embed.second); @@ -1207,7 +1214,7 @@ namespace LLM { auto runner_ctx = get_context(); - struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers); + struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); ggml_build_forward_expand(gf, hidden_states); @@ -1216,12 +1223,13 @@ namespace LLM { bool compute(const int n_threads, struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers, ggml_tensor** output, ggml_context* output_ctx = nullptr) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids, image_embeds, out_layers); + return build_graph(input_ids, attention_mask, image_embeds, out_layers); }; return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } @@ -1525,7 +1533,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, image_embeds, {}, &out, work_ctx); + model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1565,7 +1573,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx); + model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1588,7 +1596,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, {}, {35}, &out, work_ctx); + model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1611,7 +1619,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, {}, {}, &out, work_ctx); + model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); From fbd8417acbfc3257c526232bfa320c0fc687d349 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Tue, 6 Jan 2026 17:26:42 +0100 Subject: [PATCH 14/14] apply conditioner fix to edit mode --- conditioner.hpp | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 98ad0d8fb..2c3f7522d 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1766,7 +1766,7 @@ struct LLMEmbedder : public Conditioner { if (sd_version_is_longcat(version)) { LOG_INFO("LongCatEditPipeline"); prompt_template_encode_start_idx = 67; - // prompt_template_encode_end_idx = 5; + prompt_template_encode_end_idx = 5; int image_embed_idx = 36 + 6; int min_pixels = 384 * 384; @@ -1820,16 +1820,33 @@ struct LLMEmbedder : public Conditioner { img_prompt += "<|vision_end|>"; } - max_length = 512; - spell_quotes = true; - prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; + max_length = 512 + prompt_template_encode_start_idx; + spell_quotes = true; + prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; prompt += img_prompt; prompt_attn_range.first = static_cast(prompt.size()); prompt += conditioner_params.text; prompt_attn_range.second = static_cast(prompt.size()); - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false, spell_quotes); + tokens = std::get<0>(tokens_and_weights); + weights = std::get<1>(tokens_and_weights); + + mask.insert(mask.end(), tokens.size(), 1.f); + if (tokens.size() < max_length) { + mask.insert(mask.end(), max_length - tokens.size(), 0.f); + tokenizer->pad_tokens(tokens, weights, max_length, true); + } + + std::string prompt_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"; + auto suffix_tokens = tokenizer->tokenize(prompt_template_suffix, nullptr); + + LOG_DEBUG("%zd", tokens.size()); + + tokens.insert(tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + weights.insert(weights.end(), suffix_tokens.size(), 1.f); + mask.insert(mask.end(), suffix_tokens.size(), 1.f); } else { LOG_INFO("QwenImageEditPlusPipeline");