diff --git a/conditioner.hpp b/conditioner.hpp index b6d5646a7..2c3f7522d 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1656,46 +1656,95 @@ struct LLMEmbedder : public Conditioner { } } - std::tuple, std::vector> tokenize(std::string text, - std::pair attn_range, - size_t max_length = 0, - bool padding = false) { + size_t get_utf8_char_len(char c) { + unsigned char uc = static_cast(c); + if ((uc & 0x80) == 0) + return 1; // ASCII (1 byte) + if ((uc & 0xE0) == 0xC0) + return 2; // 2-byte char + if ((uc & 0xF0) == 0xE0) + return 3; // 3-byte char (Common for Chinese/Japanese) + if ((uc & 0xF8) == 0xF0) + return 4; // 4-byte char (Emojis, etc.) + return 1; // Fallback (should not happen in valid UTF-8) + } + + std::tuple, std::vector> tokenize( + std::string text, + std::pair attn_range, + size_t max_length = 0, + bool padding = false, + bool spell_quotes = false) { std::vector> parsed_attention; parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.second - attn_range.first > 0) { - auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); - parsed_attention.insert(parsed_attention.end(), - new_parsed_attention.begin(), - new_parsed_attention.end()); + auto new_parsed_attention = parse_prompt_attention( + text.substr(attn_range.first, attn_range.second - attn_range.first)); + parsed_attention.insert( + parsed_attention.end(), + new_parsed_attention.begin(), + new_parsed_attention.end()); } parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); - { - std::stringstream ss; - ss << "["; - for (const auto& item : parsed_attention) { - ss << "['" << item.first << "', " << item.second << "], "; - } - ss << "]"; - LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); - } std::vector tokens; std::vector weights; + for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); - tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); - weights.insert(weights.end(), curr_tokens.size(), curr_weight); - } - tokenizer->pad_tokens(tokens, weights, max_length, padding); + if (spell_quotes) { + std::string buffer; + bool in_quote = false; - // for (int i = 0; i < tokens.size(); i++) { - // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl; - // } - // std::cout << std::endl; + size_t i = 0; + while (i < curr_text.size()) { + // utf8 character can be 1-4 char + size_t char_len = get_utf8_char_len(curr_text[i]); + + // Safety check to prevent reading past end of string + if (i + char_len > curr_text.size()) { + char_len = curr_text.size() - i; + } + std::string uchar = curr_text.substr(i, char_len); + i += char_len; + + if (uchar == "\"") { + buffer += uchar; + // If we were accumulating normal text, flush it now + if (!in_quote) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + buffer.clear(); + } + in_quote = !in_quote; + } else { + if (in_quote) { + std::vector char_tokens = tokenizer->tokenize(uchar, nullptr); + tokens.insert(tokens.end(), char_tokens.begin(), char_tokens.end()); + weights.insert(weights.end(), char_tokens.size(), curr_weight); + } else { + buffer += uchar; + } + } + } + + if (!buffer.empty()) { + std::vector part_tokens = tokenizer->tokenize(buffer, nullptr); + tokens.insert(tokens.end(), part_tokens.begin(), part_tokens.end()); + weights.insert(weights.end(), part_tokens.size(), curr_weight); + } + } else { + std::vector curr_tokens = tokenizer->tokenize(curr_text, nullptr); + tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); + weights.insert(weights.end(), curr_tokens.size(), curr_weight); + } + } + tokenizer->pad_tokens(tokens, weights, max_length, padding); return {tokens, weights}; } @@ -1706,71 +1755,163 @@ struct LLMEmbedder : public Conditioner { std::vector> image_embeds; std::pair prompt_attn_range; int prompt_template_encode_start_idx = 34; + int prompt_template_encode_end_idx = 0; int max_length = 0; + bool spell_quotes = false; std::set out_layers; + std::vector tokens; + std::vector weights; + std::vector mask; if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { - LOG_INFO("QwenImageEditPlusPipeline"); - prompt_template_encode_start_idx = 64; - int image_embed_idx = 64 + 6; - - int min_pixels = 384 * 384; - int max_pixels = 560 * 560; - std::string placeholder = "<|image_pad|>"; - std::string img_prompt; - - for (int i = 0; i < conditioner_params.ref_images.size(); i++) { - sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; - int height = image.height; - int width = image.width; - int h_bar = static_cast(std::round(height / factor) * factor); - int w_bar = static_cast(std::round(width / factor) * factor); - - if (static_cast(h_bar) * w_bar > max_pixels) { - double beta = std::sqrt((height * width) / static_cast(max_pixels)); - h_bar = std::max(static_cast(factor), - static_cast(std::floor(height / beta / factor)) * static_cast(factor)); - w_bar = std::max(static_cast(factor), - static_cast(std::floor(width / beta / factor)) * static_cast(factor)); - } else if (static_cast(h_bar) * w_bar < min_pixels) { - double beta = std::sqrt(static_cast(min_pixels) / (height * width)); - h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); - w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + if (sd_version_is_longcat(version)) { + LOG_INFO("LongCatEditPipeline"); + prompt_template_encode_start_idx = 67; + prompt_template_encode_end_idx = 5; + int image_embed_idx = 36 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + // Only one image is officicially supported by the model, not sure how it handles multiple images + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor)) * factor; + int w_bar = static_cast(std::round(width / factor)) * factor; + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; + + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + image_embed->ne[1] + 6; + + img_prompt += "<|vision_start|>"; + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; + } + + max_length = 512 + prompt_template_encode_start_idx; + spell_quotes = true; + prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false, spell_quotes); + tokens = std::get<0>(tokens_and_weights); + weights = std::get<1>(tokens_and_weights); + + mask.insert(mask.end(), tokens.size(), 1.f); + if (tokens.size() < max_length) { + mask.insert(mask.end(), max_length - tokens.size(), 0.f); + tokenizer->pad_tokens(tokens, weights, max_length, true); } - LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + std::string prompt_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"; + auto suffix_tokens = tokenizer->tokenize(prompt_template_suffix, nullptr); + + LOG_DEBUG("%zd", tokens.size()); + + tokens.insert(tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + weights.insert(weights.end(), suffix_tokens.size(), 1.f); + mask.insert(mask.end(), suffix_tokens.size(), 1.f); - sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); - free(image.data); - image.data = nullptr; + } else { + LOG_INFO("QwenImageEditPlusPipeline"); + prompt_template_encode_start_idx = 64; + int image_embed_idx = 64 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor) * factor); + int w_bar = static_cast(std::round(width / factor) * factor); + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } - ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); - sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); - free(resized_image.data); - resized_image.data = nullptr; + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); - ggml_tensor* image_embed = nullptr; - llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); - image_embeds.emplace_back(image_embed_idx, image_embed); - image_embed_idx += 1 + static_cast(image_embed->ne[1]) + 6; + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; - img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] - int64_t num_image_tokens = image_embed->ne[1]; - img_prompt.reserve(num_image_tokens * placeholder.size()); - for (int j = 0; j < num_image_tokens; j++) { - img_prompt += placeholder; + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + static_cast(image_embed->ne[1]) + 6; + + img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; } - img_prompt += "<|vision_end|>"; - } - prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; - prompt += img_prompt; + prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } } else if (sd_version_is_flux2(version)) { prompt_template_encode_start_idx = 0; out_layers = {10, 20, 30}; @@ -1815,6 +1956,36 @@ struct LLMEmbedder : public Conditioner { prompt_attn_range.second = static_cast(prompt.size()); prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + } else if (sd_version_is_longcat(version)) { + prompt_template_encode_start_idx = 36; + max_length = 512 + prompt_template_encode_start_idx; + prompt_template_encode_end_idx = 5; + spell_quotes = true; + + prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false, spell_quotes); + tokens = std::get<0>(tokens_and_weights); + weights = std::get<1>(tokens_and_weights); + + mask.insert(mask.end(), tokens.size(), 1.f); + if (tokens.size() < max_length) { + mask.insert(mask.end(), max_length - tokens.size(), 0.f); + tokenizer->pad_tokens(tokens, weights, max_length, true); + } + + std::string prompt_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"; + auto suffix_tokens = tokenizer->tokenize(prompt_template_suffix, nullptr); + + LOG_DEBUG("%zd", tokens.size()); + + tokens.insert(tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + weights.insert(weights.end(), suffix_tokens.size(), 1.f); + mask.insert(mask.end(), suffix_tokens.size(), 1.f); } else { prompt_template_encode_start_idx = 34; @@ -1827,17 +1998,33 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n"; } - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0); - auto& tokens = std::get<0>(tokens_and_weights); - auto& weights = std::get<1>(tokens_and_weights); + if (tokens.empty()) { + auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes); + tokens = std::get<0>(tokens_and_weights); + weights = std::get<1>(tokens_and_weights); + } + int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584] auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); + ggml_tensor* attention_mask = nullptr; + if (!mask.empty()) { + attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size()); + ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = 0.f; + if (mask[i0] == 0.f || mask[i1] == 0.f) { + value = -INFINITY; + } + ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3); + }); + print_ggml_tensor(attention_mask); + } llm->compute(n_threads, input_ids, + attention_mask, image_embeds, out_layers, &hidden_states, @@ -1875,18 +2062,18 @@ struct LLMEmbedder : public Conditioner { ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, hidden_states->ne[0], - hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len, + hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len - prompt_template_encode_end_idx, hidden_states->ne[2]); ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { float value = 0.f; - if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) { + if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1] - prompt_template_encode_end_idx) { value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3); } ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3); }); - // print_ggml_tensor(new_hidden_states); + print_ggml_tensor(new_hidden_states, true); int64_t t1 = ggml_time_ms(); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); diff --git a/flux.hpp b/flux.hpp index 5d94fc85d..a1d97cfed 100644 --- a/flux.hpp +++ b/flux.hpp @@ -88,14 +88,19 @@ namespace Flux { public: SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false, - bool proj_bias = true) + int64_t num_heads = 8, + bool qkv_bias = false, + bool proj_bias = true, + bool diffusers_style = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); - blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); - blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); + if (diffusers_style) { + blocks["qkv"] = std::shared_ptr(new SplitLinear(dim, {dim, dim, dim}, qkv_bias)); + } else { + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + } + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim, proj_bias)); } std::vector pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { @@ -261,7 +266,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : idx(idx), prune_mod(prune_mod) { int64_t mlp_hidden_dim = static_cast(hidden_size * mlp_ratio); @@ -269,7 +275,7 @@ namespace Flux { blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -282,7 +288,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, mlp_proj_bias, diffusers_style)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); if (use_yak_mlp) { @@ -424,6 +430,7 @@ namespace Flux { bool use_yak_mlp; bool use_mlp_silu_act; int64_t mlp_mult_factor; + bool diffusers_style = false; public: SingleStreamBlock(int64_t hidden_size, @@ -435,7 +442,8 @@ namespace Flux { bool share_modulation = false, bool mlp_proj_bias = true, bool use_yak_mlp = false, - bool use_mlp_silu_act = false) + bool use_mlp_silu_act = false, + bool diffusers_style = false) : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod), use_yak_mlp(use_yak_mlp), use_mlp_silu_act(use_mlp_silu_act) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; @@ -447,8 +455,11 @@ namespace Flux { if (use_yak_mlp || use_mlp_silu_act) { mlp_mult_factor = 2; } - - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + if (diffusers_style) { + blocks["linear1"] = std::shared_ptr(new SplitLinear(hidden_size, {hidden_size, hidden_size, hidden_size, mlp_hidden_dim * mlp_mult_factor}, mlp_proj_bias)); + } else { + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); + } blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); @@ -776,6 +787,7 @@ namespace Flux { bool use_yak_mlp = false; bool use_mlp_silu_act = false; float ref_index_scale = 1.f; + bool diffusers_style = false; ChromaRadianceParams chroma_radiance_params; }; @@ -822,7 +834,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } for (int i = 0; i < params.depth_single_blocks; i++) { @@ -835,7 +848,8 @@ namespace Flux { params.share_modulation, !params.disable_bias, params.use_yak_mlp, - params.use_mlp_silu_act); + params.use_mlp_silu_act, + params.diffusers_style); } if (params.version == VERSION_CHROMA_RADIANCE) { @@ -882,6 +896,11 @@ namespace Flux { int64_t C = x->ne[2]; int64_t H = x->ne[1]; int64_t W = x->ne[0]; + if (params.patch_size == 1) { + x = ggml_reshape_3d(ctx, x, H * W, C, N); // [N, C, H*W] + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, H*W, C] + return x; + } int64_t p = params.patch_size; int64_t h = H / params.patch_size; int64_t w = W / params.patch_size; @@ -916,6 +935,12 @@ namespace Flux { int64_t W = w * params.patch_size; int64_t p = params.patch_size; + if (params.patch_size == 1) { + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, C, H*W] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, H, W] + return x; + } + GGML_ASSERT(C * p * p == x->ne[0]); x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] @@ -1306,6 +1331,9 @@ namespace Flux { flux_params.share_modulation = true; flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + flux_params.context_in_dim = 3584; + flux_params.vec_in_dim = 0; } for (auto pair : tensor_storage_map) { std::string tensor_name = pair.first; @@ -1315,6 +1343,9 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("model.diffusion_model.single_blocks.0.linear1.weight.1") != std::string::npos) { + flux_params.diffusers_style = true; + } if (tensor_name.find("__x0__") != std::string::npos) { LOG_DEBUG("using x0 prediction"); flux_params.chroma_radiance_params.use_x0 = true; @@ -1353,6 +1384,10 @@ namespace Flux { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } + if (flux_params.diffusers_style) { + LOG_INFO("Using diffusers-style attention blocks"); + } + flux = Flux(flux_params); flux.init(params_ctx, tensor_storage_map, prefix); } @@ -1464,7 +1499,6 @@ namespace Flux { } else if (version == VERSION_OVIS_IMAGE) { txt_arange_dims = {1, 2}; } - pe_vec = Rope::gen_flux_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), flux_params.patch_size, @@ -1477,10 +1511,10 @@ namespace Flux { flux_params.theta, circular_y_enabled, circular_x_enabled, - flux_params.axes_dim); + flux_params.axes_dim, + sd_version_is_longcat(version)); int pos_len = static_cast(pe_vec.size() / flux_params.axes_dim_sum / 2); - // LOG_DEBUG("pos_len %d", pos_len); - auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); // print_ggml_tensor(pe); // pe->data = nullptr; diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 1ff450116..c3c5bc76f 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -2207,7 +2207,7 @@ class Linear : public UnaryBlock { bool bias = true, bool force_f32 = false, bool force_prec_f32 = false, - float scale = 1.f) + float scale = 1.f / 128.f) : in_features(in_features), out_features(out_features), bias(bias), @@ -2232,6 +2232,83 @@ class Linear : public UnaryBlock { } }; +class SplitLinear : public Linear { +protected: + int64_t in_features; + std::vector out_features_vec; + bool bias; + bool force_f32; + bool force_prec_f32; + float scale; + std::string prefix; + + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { + this->prefix = prefix; + enum ggml_type wtype = get_type(prefix + "weight", tensor_storage_map, GGML_TYPE_F32); + if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { + wtype = GGML_TYPE_F32; + } + params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + // most likely same type as the first weight + params["weight." + std::to_string(i)] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features_vec[i]); + } + if (bias) { + enum ggml_type wtype = GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[0]); + for (int i = 1; i < out_features_vec.size(); i++) { + params["bias." + std::to_string(i)] = ggml_new_tensor_1d(ctx, wtype, out_features_vec[i]); + } + } + } + +public: + SplitLinear(int64_t in_features, + std::vector out_features_vec, + bool bias = true, + bool force_f32 = false, + bool force_prec_f32 = false, + float scale = 1.f) + : Linear(in_features, out_features_vec[0], bias, force_f32, force_prec_f32, scale), + in_features(in_features), + out_features_vec(out_features_vec), + bias(bias), + force_f32(force_f32), + force_prec_f32(force_prec_f32), + scale(scale) {} + + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + struct ggml_tensor* b = nullptr; + if (bias) { + b = params["bias"]; + } + if (ctx->weight_adapter) { + // concat all weights and biases together so it runs in one linear layer + for (int i = 1; i < out_features_vec.size(); i++) { + w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1); + if (bias) { + b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0); + } + } + WeightAdapter::ForwardParams forward_params; + forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR; + forward_params.linear.force_prec_f32 = force_prec_f32; + forward_params.linear.scale = scale; + return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params); + } + auto out = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale); + for (int i = 1; i < out_features_vec.size(); i++) { + auto wi = params["weight." + std::to_string(i)]; + auto bi = bias ? params["bias." + std::to_string(i)] : nullptr; + auto curr_out = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale); + out = ggml_concat(ctx->ggml_ctx, out, curr_out, 0); + } + + return out; + } +}; + __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) { std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; if (allow_types.find(wtype) != allow_types.end()) { diff --git a/llm.hpp b/llm.hpp index 67b1ea165..77dc0e5ca 100644 --- a/llm.hpp +++ b/llm.hpp @@ -837,7 +837,8 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask = nullptr) { // x: [N, n_token, hidden_size] int64_t n_token = x->ne[1]; int64_t N = x->ne[2]; @@ -880,7 +881,7 @@ namespace LLM { k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, true, false); // [N, n_token, hidden_size] x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] return x; @@ -898,7 +899,8 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* input_pos) { + struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask = nullptr) { // x: [N, n_token, hidden_size] auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); @@ -907,7 +909,7 @@ namespace LLM { auto residual = x; x = input_layernorm->forward(ctx, x); - x = self_attn->forward(ctx, x, input_pos); + x = self_attn->forward(ctx, x, input_pos, attention_mask); x = ggml_add_inplace(ctx->ggml_ctx, x, residual); residual = x; @@ -936,6 +938,7 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { // input_ids: [N, n_token] @@ -990,7 +993,7 @@ namespace LLM { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); - x = block->forward(ctx, x, input_pos); + x = block->forward(ctx, x, input_pos, attention_mask); if (out_layers.find(i + 1) != out_layers.end()) { intermediate_outputs.push_back(x); } @@ -1036,12 +1039,13 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers); + auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); return x; } @@ -1157,9 +1161,10 @@ namespace LLM { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { - auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers); // [N, n_token, hidden_size] + auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size] return hidden_states; } @@ -1174,11 +1179,13 @@ namespace LLM { } struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); + attention_mask = to_backend(attention_mask); for (auto& image_embed : image_embeds) { image_embed.second = to_backend(image_embed.second); @@ -1207,7 +1214,7 @@ namespace LLM { auto runner_ctx = get_context(); - struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers); + struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); ggml_build_forward_expand(gf, hidden_states); @@ -1216,12 +1223,13 @@ namespace LLM { bool compute(const int n_threads, struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask, std::vector> image_embeds, std::set out_layers, ggml_tensor** output, ggml_context* output_ctx = nullptr) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids, image_embeds, out_layers); + return build_graph(input_ids, attention_mask, image_embeds, out_layers); }; return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } @@ -1525,7 +1533,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, image_embeds, {}, &out, work_ctx); + model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1565,7 +1573,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx); + model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1588,7 +1596,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, {}, {35}, &out, work_ctx); + model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -1611,7 +1619,7 @@ namespace LLM { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - model.compute(8, input_ids, {}, {}, &out, work_ctx); + model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/model.cpp b/model.cpp index a19f180da..3c8adcd04 100644 --- a/model.cpp +++ b/model.cpp @@ -1027,7 +1027,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight, input_block_weight; + TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight; bool has_multiple_encoders = false; bool is_unet = false; @@ -1041,7 +1041,7 @@ SDVersion ModelLoader::get_sd_version() { for (auto& [name, tensor_storage] : tensor_storage_map) { if (!(is_xl)) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { is_flux = true; } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { @@ -1108,6 +1108,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; } + if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") { + context_ebedding_weight = tensor_storage; + } } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); @@ -1135,16 +1138,20 @@ SDVersion ModelLoader::get_sd_version() { } if (is_flux) { - if (input_block_weight.ne[0] == 384) { - return VERSION_FLUX_FILL; - } - if (input_block_weight.ne[0] == 128) { - return VERSION_FLUX_CONTROLS; - } - if (input_block_weight.ne[0] == 196) { - return VERSION_FLEX_2; + if (context_ebedding_weight.ne[0] == 3584) { + return VERSION_LONGCAT; + } else { + if (input_block_weight.ne[0] == 384) { + return VERSION_FLUX_FILL; + } + if (input_block_weight.ne[0] == 128) { + return VERSION_FLUX_CONTROLS; + } + if (input_block_weight.ne[0] == 196) { + return VERSION_FLEX_2; + } + return VERSION_FLUX; } - return VERSION_FLUX; } if (token_embedding_weight.ne[0] == 768) { diff --git a/model.h b/model.h index b9e50ad63..1d0c6bca0 100644 --- a/model.h +++ b/model.h @@ -46,6 +46,7 @@ enum SDVersion { VERSION_FLUX2, VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, + VERSION_LONGCAT, VERSION_COUNT, }; @@ -126,6 +127,13 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_longcat(SDVersion version) { + if (version == VERSION_LONGCAT) { + return true; + } + return false; +} + static inline bool sd_version_is_inpaint(SDVersion version) { if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || @@ -143,7 +151,8 @@ static inline bool sd_version_is_dit(SDVersion version) { sd_version_is_sd3(version) || sd_version_is_wan(version) || sd_version_is_qwen_image(version) || - sd_version_is_z_image(version)) { + sd_version_is_z_image(version) || + sd_version_is_longcat(version)) { return true; } return false; diff --git a/name_conversion.cpp b/name_conversion.cpp index 3ae229b63..dad35f4d3 100644 --- a/name_conversion.cpp +++ b/name_conversion.cpp @@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { static std::unordered_map flux_name_map; if (flux_name_map.empty()) { + // --- time_embed (longcat) --- + flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + // --- time_text_embed --- flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; @@ -660,7 +666,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) { name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); diff --git a/rope.hpp b/rope.hpp index 2d123b3cc..ab15e458c 100644 --- a/rope.hpp +++ b/rope.hpp @@ -106,6 +106,15 @@ namespace Rope { return txt_ids; } + __STATIC_INLINE__ std::vector> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) { + auto txt_ids = std::vector>(bs * context_len, std::vector(axes_dim_num, 0.0f)); + for (int i = 0; i < bs * context_len; i++) { + txt_ids[i][1] = (i % context_len); + txt_ids[i][2] = (i % context_len); + } + return txt_ids; + } + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, @@ -117,7 +126,6 @@ namespace Rope { bool scale_rope = false) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(axes_dim_num, 0.0)); int h_start = h_offset; @@ -130,7 +138,6 @@ namespace Rope { std::vector row_ids = linspace(1.f * h_start, 1.f * h_start + h_len - 1, h_len); std::vector col_ids = linspace(1.f * w_start, 1.f * w_start + w_len - 1, w_len); - for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { img_ids[i * w_len + j][0] = 1.f * index; @@ -206,14 +213,16 @@ namespace Rope { __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, int bs, int axes_dim_num, + int start_index, const std::vector& ref_latents, bool increase_ref_index, float ref_index_scale, - bool scale_rope) { + bool scale_rope, + int base_offset = 0) { std::vector> ids; int curr_h_offset = 0; int curr_w_offset = 0; - int index = 1; + int index = start_index; for (ggml_tensor* ref : ref_latents) { int h_offset = 0; int w_offset = 0; @@ -232,8 +241,8 @@ namespace Rope { bs, axes_dim_num, static_cast(index * ref_index_scale), - h_offset, - w_offset, + h_offset + base_offset, + w_offset + base_offset, scale_rope); ids = concat_ids(ids, ref_ids, bs); @@ -256,13 +265,17 @@ namespace Rope { std::set txt_arange_dims, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { - auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); + float ref_index_scale, + bool is_longcat) { + int x_index = is_longcat ? 1 : 0; + + auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); + int offset = is_longcat ? context_len : 0; + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, false, offset); ids = concat_ids(ids, refs_ids, bs); } return ids; @@ -281,7 +294,8 @@ namespace Rope { int theta, bool circular_h, bool circular_w, - const std::vector& axes_dim) { + const std::vector& axes_dim, + bool is_longcat) { std::vector> ids = gen_flux_ids(h, w, patch_size, @@ -291,7 +305,8 @@ namespace Rope { txt_arange_dims, ref_latents, increase_ref_index, - ref_index_scale); + ref_index_scale, + is_longcat); std::vector> wrap_dims; if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { int h_len = (h + (patch_size / 2)) / patch_size; @@ -356,7 +371,7 @@ namespace Rope { auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f, true); ids = concat_ids(ids, refs_ids, bs); } return ids; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 75689ff94..c7ee62d32 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -49,6 +49,7 @@ const char* model_version_to_str[] = { "Flux.2", "Z-Image", "Ovis Image", + "Longcat-Image", }; const char* sampling_methods_str[] = { @@ -392,7 +393,7 @@ class StableDiffusionGGML { } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_is_wan(version) || @@ -424,8 +425,8 @@ class StableDiffusionGGML { offload_params_to_cpu, tensor_storage_map); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map); + offload_params_to_cpu, + tensor_storage_map); } else if (sd_version_is_flux(version)) { bool is_chroma = false; for (auto pair : tensor_storage_map) { @@ -473,10 +474,26 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - version, - sd_ctx_params->chroma_use_dit_mask); + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); + } else if (sd_version_is_longcat(version)) { + bool enable_vision = false; + if (!vae_decode_only) { + enable_vision = true; + } + cond_stage_model = std::make_shared(clip_backend, + offload_params_to_cpu, + tensor_storage_map, + version, + "", + enable_vision); + diffusion_model = std::make_shared(backend, + offload_params_to_cpu, + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (sd_version_is_wan(version)) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, @@ -485,10 +502,10 @@ class StableDiffusionGGML { 1, true); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) { high_noise_diffusion_model = std::make_shared(backend, offload_params_to_cpu, @@ -528,10 +545,10 @@ class StableDiffusionGGML { tensor_storage_map, version); diffusion_model = std::make_shared(backend, - offload_params_to_cpu, - tensor_storage_map, - "model.diffusion_model", - version); + offload_params_to_cpu, + tensor_storage_map, + "model.diffusion_model", + version); } else { // SD1.x SD2.x SDXL std::map embbeding_map; for (uint32_t i = 0; i < sd_ctx_params->embedding_count; i++) { @@ -883,7 +900,7 @@ class StableDiffusionGGML { flow_shift = 3.f; } } - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) { pred_type = FLUX_FLOW_PRED; if (flow_shift == INFINITY) { @@ -893,6 +910,9 @@ class StableDiffusionGGML { flow_shift = 1.15f; } } + if (sd_version_is_longcat(version)) { + flow_shift = 3.0f; + } } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; @@ -1456,7 +1476,7 @@ class StableDiffusionGGML { if (sd_version_is_flux2(version)) { latent_rgb_proj = flux2_latent_rgb_proj; latent_rgb_bias = flux2_latent_rgb_bias; - patch_sz = 2; + patch_sz = 2; } } else if (dim == 48) { if (sd_version_is_wan(version)) { @@ -1473,7 +1493,7 @@ class StableDiffusionGGML { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {