generate_multimodal.py
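"""Generate an image from a text prompt with a multimodal LLM and a VQ-VAE.

The LLM autoregressively produces discrete image tokens after a <seg_start>
marker; the VQ-VAE decoder then maps those tokens back to pixels.
"""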
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from models.llm import MinimalLLM
from models.vqvae import VQVAE
from configs.multimodal_config import MultimodalConfig
from PIL import Image
import os


def generate(prompt, model_path, vqvae_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = MultimodalConfig()

    # 1. Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

    # 2. Load LLM
    model = MinimalLLM(config).to(device)
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # 3. Load VQ-VAE
    vq_model = VQVAE(num_embeddings=config.image_vocab_size).to(device)
    if os.path.exists(vqvae_path):
        vq_model.load_state_dict(torch.load(vqvae_path, map_location=device))
    vq_model.eval()
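    # Note: both load_state_dict calls are guarded by os.path.exists, so a
    # missing checkpoint silently leaves that model randomly initialized.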

    # 4. Prepare Prompt
    text_ids = tokenizer.encode(prompt, add_special_tokens=True)
    # Append <seg_start> to force image generation start
    input_ids = torch.tensor(text_ids + [config.seg_start_id], device=device).unsqueeze(0)
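    # Assumed vocabulary layout (inferred from this script, not verified against
    # the configs): text tokens come first, special markers such as
    # config.seg_start_id / config.seg_end_id follow, and image-codebook tokens
    # occupy [image_token_offset, image_token_offset + image_vocab_size).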

    # 5. Generate Autoregressively
    print(f"Generating image for prompt: {prompt}")
    generated_tokens = []
    temperature = 1.0
    top_k = 50
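    # temperature = 1.0 leaves the logits unscaled; top_k = 50 restricts each
    # sampling step to the 50 highest-probability tokens.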

    with torch.no_grad():
        for _ in range(config.num_image_tokens):
            logits = model(input_ids)
            last_logits = logits[:, -1, :] / temperature

            # Top-k sampling
            if top_k > 0:
                v, _ = torch.topk(last_logits, min(top_k, last_logits.size(-1)))
                last_logits[last_logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            generated_tokens.append(next_token.item())

            if next_token.item() == config.seg_end_id:
                break

    # 6. Extract and Decode Image Tokens
    # Remove the offset to map back into the VQ codebook range
    image_tokens = [
        t - config.image_token_offset
        for t in generated_tokens
        if config.image_token_offset <= t < config.image_token_offset + config.image_vocab_size
    ]
    if len(image_tokens) == 0:
        print("Error: No image tokens generated.")
        return
    print(f"Collected {len(image_tokens)} image tokens.")

    # Pad/truncate to the expected 32x32 = 1024 token grid
    if len(image_tokens) < 1024:
        image_tokens += [0] * (1024 - len(image_tokens))
    image_tokens = image_tokens[:1024]
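    # Zero-padding fills any missing grid positions with codebook entry 0, so a
    # truncated generation still decodes to a full (if partly degraded) image.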

    tokens_tensor = torch.tensor(image_tokens, device=device).unsqueeze(1)
    with torch.no_grad():
        # Our VQ-VAE expects 1024 tokens for a 128x128 image (32x32 grid)
        decoded_img = vq_model.decode(tokens_tensor, 32, 32)

    # 7. Save Image
    # Denormalize from [-1, 1] to [0, 1]
    decoded_img = (decoded_img + 1) / 2
    decoded_img = decoded_img.clamp(0, 1).cpu().squeeze(0).permute(1, 2, 0).numpy()
    decoded_img = (decoded_img * 255).astype("uint8")
    img = Image.fromarray(decoded_img)
    img.save("generated_image.png")
    print("✅ Image saved to generated_image.png")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    parser.add_argument("--model_path", type=str, default="checkpoints/multimodal_llm_epoch_5.pt")
    parser.add_argument("--vqvae_path", type=str, default="checkpoints/vqvae_epoch_5.pt")
    args = parser.parse_args()

    generate(args.prompt, args.model_path, args.vqvae_path)
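
# Example invocation (these paths are just the argparse defaults above; point
# them at whichever checkpoints your training runs produced):
#   python generate_multimodal.py --prompt "A photo of a cat" \
#       --model_path checkpoints/multimodal_llm_epoch_5.pt \
#       --vqvae_path checkpoints/vqvae_epoch_5.pt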