Skip to content

Commit 13a4744

Browse files
committed
support glm-ocr
1 parent f03a896 commit 13a4744

10 files changed

Lines changed: 513 additions & 248 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ LittleAcademia[<a href="https://github.com/foldl/little-academia" style="text-
3535

3636
**What's New:**
3737

38+
* 2026-03-03: GLM-OCR
3839
* 2026-02-22: Youtu-VL
3940
* 2026-02-18: Youtu-LLM
4041
* 2026-02-16: Voice Clone with Qwen3-TTS

convert.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ class ModelType(Enum):
263263
DotsOCR = ModelTypeTagChatImageIn + 0x0000020
264264
Mistral3 = ModelTypeTagChatImageIn + 0x0000030
265265
StepVL = ModelTypeTagChatImageIn + 0x0000040
266+
GLM_OCR = ModelTypeTagChatImageIn + 0x0000050
266267

267268
Qwen2Audio = ModelTypeTagChatAudioIn + 0x0000001
268269
Qwen3ForcedAligner = ModelTypeTagChatAudioIn + 0x0000002
@@ -3972,6 +3973,7 @@ def get_weight_names(config):
39723973

39733974
class GLM4VConverter(BaseConverter):
39743975
MODEL_TYPE = ModelType.GLM4V
3976+
ASSERT_HEAD_DIM = True
39753977

39763978
@classmethod
39773979
def state_dict_pp(cls, config, state_dict):
@@ -3987,7 +3989,7 @@ def state_dict_pp(cls, config, state_dict):
39873989
r[name.replace('gate_up_proj.weight', 'up_proj.weight')] = part(tensor, 1, 2).contiguous()
39883990
elif ('.k_proj.' in name) or ('.q_proj.' in name):
39893991
rope_dim = GLM4VConverter.rope_dim
3990-
head_dim = GLM4VConverter.txt_config.hidden_size // GLM4VConverter.txt_config.num_attention_heads
3992+
head_dim = GLM4VConverter.txt_config.head_dim
39913993
r[name] = permute_pair_rope_nope(tensor, tensor.shape[0] // head_dim, rope_dim)
39923994
else:
39933995
r[name] = tensor
@@ -4020,11 +4022,16 @@ def state_dict_pp(cls, config, state_dict):
40204022
def dump_config(f, config, ggml_type):
40214023
GLM4VConverter.txt_config = AttributeDict(config.text_config)
40224024
txt_config = GLM4VConverter.txt_config
4023-
assert txt_config.attention_bias
4025+
40244026
if isinstance(txt_config.eos_token_id, list):
40254027
txt_config.eos_token_id = txt_config.eos_token_id[0]
40264028

4027-
head_dim = txt_config.hidden_size // txt_config.num_attention_heads
4029+
if 'head_dim' not in txt_config:
4030+
txt_config.head_dim = txt_config.hidden_size // txt_config.num_attention_heads
4031+
head_dim = txt_config.head_dim
4032+
4033+
if GLM4VConverter.ASSERT_HEAD_DIM:
4034+
assert head_dim == txt_config.hidden_size // txt_config.num_attention_heads
40284035

40294036
rope_dim = int(txt_config.rope_parameters["partial_rotary_factor"] * head_dim)
40304037
GLM4VConverter.rope_dim = rope_dim
@@ -4076,6 +4083,68 @@ def get_weight_names(config):
40764083
weights += GLM4VConverter.get_vit_weight_names(config.vision_config['depth'])
40774084
return weights
40784085

4086+
class GLMOCRConverter(BaseConverter):
    """Converter for GLM-OCR checkpoints (`GlmOcrForConditionalGeneration`).

    Tensor post-processing and the bulk of the config header are delegated to
    `GLM4VConverter`; this class appends the text model's `head_dim` to the
    header and supplies the weight names of the GLM-OCR vision tower.
    """
    MODEL_TYPE = ModelType.GLM_OCR

    @classmethod
    def state_dict_pp(cls, config, state_dict):
        """Post-process a state dict exactly as `GLM4VConverter` does."""
        return GLM4VConverter.state_dict_pp(config, state_dict)

    @staticmethod
    def dump_config(f, config, ggml_type):
        """Write the GLM4V config header followed by one little-endian int32 `head_dim`.

        `GLM4VConverter.dump_config` asserts
        `head_dim == hidden_size // num_attention_heads` whenever
        `GLM4VConverter.ASSERT_HEAD_DIM` is set. That check is switched off
        only for the duration of the delegated call and then restored, so the
        shared class-level flag is not left modified for any other use of
        `GLM4VConverter` in the same process.
        """
        print("WARNING: MTP not supported!")

        saved_assert_flag = GLM4VConverter.ASSERT_HEAD_DIM
        GLM4VConverter.ASSERT_HEAD_DIM = False
        try:
            GLM4VConverter.dump_config(f, config, ggml_type)
        finally:
            # Restore even if the delegated dump raises.
            GLM4VConverter.ASSERT_HEAD_DIM = saved_assert_flag

        # `txt_config` (including `head_dim`) is populated by the call above.
        config_values = [
            GLM4VConverter.txt_config.head_dim
        ]
        f.write(struct.pack("<i", *config_values))

    @staticmethod
    def get_vit_weight_names(num_layer):
        """Return the vision-tower tensor names for `num_layer` transformer layers.

        The list starts with the tensors that are not per-layer, followed by
        the per-layer tensors for layers `0 .. num_layer - 1`, in that order.
        """
        weight_names = ["visual.downsample.weight",
                        "visual.downsample.bias",
                        "visual.merger.gate_proj.weight",
                        "visual.merger.up_proj.weight",
                        "visual.merger.down_proj.weight",
                        "visual.merger.proj.weight",
                        "visual.merger.post_projection_norm.weight",
                        "visual.merger.post_projection_norm.bias",
                        "visual.patch_embed.proj.0.weight",
                        "visual.patch_embed.proj.bias",
                        "visual.patch_embed.proj.1.weight",
                        "visual.post_layernorm.weight"]
        for i in range(num_layer):
            weight_names += [
                f"visual.layers.{i}.norm1.weight",
                f"visual.layers.{i}.norm2.weight",
                f"visual.layers.{i}.attn.q_proj.weight",
                f"visual.layers.{i}.attn.k_proj.weight",
                f"visual.layers.{i}.attn.v_proj.weight",
                f"visual.layers.{i}.attn.o_proj.weight",
                f"visual.layers.{i}.attn.q_norm.weight",
                f"visual.layers.{i}.attn.k_norm.weight",
                f"visual.layers.{i}.mlp.gate_proj.weight",
                f"visual.layers.{i}.mlp.up_proj.weight",
                f"visual.layers.{i}.mlp.down_proj.weight",
                f"visual.layers.{i}.attn.q_proj.bias",
                f"visual.layers.{i}.attn.k_proj.bias",
                f"visual.layers.{i}.attn.v_proj.bias",
                f"visual.layers.{i}.attn.o_proj.bias",
                f"visual.layers.{i}.mlp.gate_proj.bias",
                f"visual.layers.{i}.mlp.up_proj.bias",
                f"visual.layers.{i}.mlp.down_proj.bias",
            ]
        return weight_names

    @staticmethod
    def get_weight_names(config):
        """Return text-model weight names followed by vision-tower weight names.

        Reads `GLM4VConverter.txt_config`, which `GLM4VConverter.dump_config`
        assigns, so `dump_config` must have run before this is called.
        """
        weights = GLM4Converter.get_weight_names(GLM4VConverter.txt_config)
        weights += GLMOCRConverter.get_vit_weight_names(config.vision_config['depth'])
        return weights
4147+
40794148
class Phi2Converter(BaseConverter):
40804149
MODEL_TYPE = ModelType.Phi2
40814150

@@ -9694,6 +9763,8 @@ def main():
96949763
DotsOCRConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
96959764
elif arch.endswith('Glm4vForConditionalGeneration'):
96969765
GLM4VConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
9766+
elif arch.endswith('GlmOcrForConditionalGeneration'):
9767+
GLMOCRConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
96979768
elif arch == 'MegrezMoeForCausalLM':
96989769
MegrezMoEConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
96999770
elif arch == 'OuroForCausalLM':

docs/models.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,9 @@ Please use `--format completion` for these models.
460460
* Nanonets-OCR2 (`Qwen2VLForConditionalGeneration`, `Qwen2_5_VLForConditionalGeneration`)
461461
* [x] OCR2: [3B](https://huggingface.co/nanonets/Nanonets-OCR2-3B/tree/d0368059ad151ce9e38f526890cfd4f27b28be65), [1.5B](https://huggingface.co/nanonets/Nanonets-OCR2-1.5B-exp/tree/306a9b2a65672a3dbebd9bce9a9373a9a18674a2)
462462

463+
* GLM-OCR (`GlmOcrForConditionalGeneration`)
464+
* [x] [0.7B](https://huggingface.co/zai-org/GLM-OCR/tree/677c6baa60442a451f8a8c7eabdfab32d9801a0b)
465+
463466
## ASR Models
464467

465468
* GLM-ASR (`GlmAsrForConditionalGeneration`)

models/chatglm.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ namespace chatllm::glm::glm4_0414
680680
return r;
681681
}
682682

683-
ConditionalGeneration::ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type)
683+
ConditionalGeneration::ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type, int head_dim)
684684
: BaseModelForConditionalGeneration(type, config, runtime_config), config(config)
685685
{
686686
const size_t tensor_ovhd = ggml_tensor_overhead();
@@ -689,10 +689,12 @@ namespace chatllm::glm::glm4_0414
689689
w_ctx_.gctx = GGMLContext({.mem_size = ctx_size, .mem_buffer = nullptr, .no_alloc = true});
690690
w_ctx_.dtype = config.dtype;
691691

692+
if (head_dim < 0) head_dim = config.hidden_size / config.num_attention_heads;
693+
692694
transformer = new ModelClass(
693695
&w_ctx_, config, false,
694696
config.hidden_size, config.num_attention_heads, config.num_key_value_heads,
695-
config.intermediate_size, config.max_length, config.use_attention_bias != 0, false);
697+
config.intermediate_size, head_dim, config.max_length, config.use_attention_bias != 0, false);
696698

697699
for (int i = 0; i < config.num_hidden_layers; i++)
698700
{

models/chatglm.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ namespace chatllm::glm::glm4_0414
187187
class GLM4SelfAttention : public RoPESelfAttention<BaseAttention>
188188
{
189189
public:
190-
GLM4SelfAttention(InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length, bool qkv_bias, bool o_bias)
191-
: RoPESelfAttention<BaseAttention>(ctx, hidden_size, num_attention_heads, num_kv_heads, max_length, qkv_bias, o_bias)
190+
GLM4SelfAttention(InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int head_dim, int max_length, bool qkv_bias, bool o_bias)
191+
: RoPESelfAttention<BaseAttention>(ctx, hidden_size, num_attention_heads, num_kv_heads, head_dim, max_length, qkv_bias, o_bias)
192192
{
193193
}
194194
};
@@ -197,20 +197,20 @@ namespace chatllm::glm::glm4_0414
197197
{
198198
public:
199199
GLM4Block(InitContext *ctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
200-
int max_length, bool qkv_bias, bool o_bias)
201-
: LMBlock4(ctx, hidden_size, num_attention_heads, intermediate_size, num_kv_heads, max_length, qkv_bias, o_bias)
200+
int head_dim, int max_length, bool qkv_bias, bool o_bias)
201+
: LMBlock4(ctx, hidden_size, num_attention_heads, intermediate_size, num_kv_heads, head_dim, max_length, qkv_bias, o_bias)
202202
{
203203
mlp.set_prec(ggml::prec::GGML_PREC_F32);
204204
}
205205
};
206206

207-
typedef Model<Config, Embedding, RMSNorm, GLM4Block, int, int, int, int, int, bool, bool> ModelClass;
207+
typedef Model<Config, Embedding, RMSNorm, GLM4Block, int, int, int, int, int, int, bool, bool> ModelClass;
208208

209209
class ConditionalGeneration : public BaseModelForConditionalGeneration
210210
{
211211
public:
212212
ConditionalGeneration() = default;
213-
ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GLM4);
213+
ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GLM4, int head_dim = -1);
214214

215215
void load(ModelLoader &loader) override;
216216

0 commit comments

Comments
 (0)