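"""Sampling script for a trained flow matching DiT2D model.

Provides Euler ODE integration (sample_euler), checkpoint loading with
optional EMA weights (load_model), and CLI entry points for generating a
batch of samples or a per-class grid.
"""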
import argparse
import os

import torch
import torchvision.utils as vutils

from models.dit_2d import DiT2D
from configs.flow_matching_config import FlowMatchingConfig


def sample_euler(model, config, device, num_samples=16, class_labels=None, num_steps=None):
    """
    Euler sampling for flow matching.

    Integrates the ODE dx/dt = v_theta(x_t, t) from t=0 to t=1.

    Args:
        model: The trained DiT2D model
        config: FlowMatchingConfig
        device: torch device
        num_samples: Number of samples to generate
        class_labels: Optional class labels for conditional generation
        num_steps: Number of integration steps (default: config.num_sampling_steps)

    Returns:
        Generated images tensor [num_samples, C, H, W]
    """
    model.eval()

    # Start from pure noise (t=0)
    channels = config.in_channels
    size = config.input_size
    x_t = torch.randn(num_samples, channels, size, size, device=device)

    # If no class labels are provided, sample random ones; a single int label
    # is broadcast to the whole batch.
    if class_labels is None:
        class_labels = torch.randint(0, config.num_classes, (num_samples,), device=device)
    elif isinstance(class_labels, int):
        class_labels = torch.full((num_samples,), class_labels, device=device)

    # Number of integration steps and resulting step size
    if num_steps is None:
        num_steps = config.num_sampling_steps
    dt = 1.0 / num_steps

    print(f"Sampling with {num_steps} steps...")
    with torch.no_grad():
        for i in range(num_steps):
            # Current time, shared across the batch
            t = torch.ones(num_samples, device=device) * (i * dt)

            # Predict the velocity field, with classifier-free guidance if enabled
            if config.use_cfg:
                v_pred = model.forward_with_cfg(x_t, t, class_labels, config.cfg_scale)
            else:
                v_pred = model(x_t, t, class_labels)

            # Euler integration step: x_{t+dt} = x_t + v(x_t, t) * dt
            x_t = x_t + v_pred * dt

    # Restore training mode for callers that sample mid-training
    model.train()
    return x_t
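

# A minimal sketch of a second-order alternative to sample_euler: Heun's
# method averages the velocity at both endpoints of each step, which usually
# tolerates fewer steps at the cost of two model evaluations per step. This
# is not part of the original pipeline; it assumes the same DiT2D call
# signature as above and omits CFG handling for brevity.
def sample_heun(model, config, device, num_samples=16, class_labels=None, num_steps=None):
    """Heun's method for the same ODE dx/dt = v_theta(x_t, t) (sketch)."""
    model.eval()
    x_t = torch.randn(num_samples, config.in_channels, config.input_size,
                      config.input_size, device=device)
    if class_labels is None:
        class_labels = torch.randint(0, config.num_classes, (num_samples,), device=device)
    elif isinstance(class_labels, int):
        class_labels = torch.full((num_samples,), class_labels, device=device)
    if num_steps is None:
        num_steps = config.num_sampling_steps
    dt = 1.0 / num_steps
    with torch.no_grad():
        for i in range(num_steps):
            t = torch.ones(num_samples, device=device) * (i * dt)
            # Euler predictor, then average the slopes at both endpoints
            v1 = model(x_t, t, class_labels)
            x_pred = x_t + v1 * dt
            v2 = model(x_pred, t + dt, class_labels)
            x_t = x_t + 0.5 * (v1 + v2) * dt
    model.train()
    return x_t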
def load_model(checkpoint_path, device):
    """Load a trained model from a checkpoint, preferring EMA weights."""
    print(f"Loading checkpoint from {checkpoint_path}...")
    # weights_only=False is needed on PyTorch >= 2.6, where it became the
    # default, because the checkpoint stores a pickled config object.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Extract config
    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        print("Warning: no config found in checkpoint, using default config")
        config = FlowMatchingConfig()

    # Create model
    model = DiT2D(config).to(device)

    # Load weights, substituting EMA shadow parameters when available
    if 'ema_shadow' in checkpoint:
        print("Loading EMA weights...")
        model_state = checkpoint['model_state_dict']
        for name, _ in model.named_parameters():
            if name in checkpoint['ema_shadow']:
                model_state[name] = checkpoint['ema_shadow'][name]
        model.load_state_dict(model_state)
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    print(f"Model loaded successfully! (Step: {checkpoint.get('step', 'unknown')})")
    return model, config
def generate_samples(
    checkpoint_path,
    num_samples=16,
    class_labels=None,
    num_steps=50,
    output_dir="generated_samples",
    cfg_scale=None,
):
    """
    Generate samples from a trained flow matching model.

    Args:
        checkpoint_path: Path to model checkpoint
        num_samples: Number of images to generate
        class_labels: Optional class label(s) for conditional generation
        num_steps: Number of sampling steps
        output_dir: Directory to save generated images
        cfg_scale: Optional classifier-free guidance scale override
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    model, config = load_model(checkpoint_path, device)

    # Override sampling steps and CFG scale if provided
    if num_steps is not None:
        config.num_sampling_steps = num_steps
    if cfg_scale is not None:
        config.cfg_scale = cfg_scale

    print(f"Generating {num_samples} samples...")
    print(f"CFG scale: {config.cfg_scale}")
    print(f"Sampling steps: {config.num_sampling_steps}")

    # Generate samples
    samples = sample_euler(
        model,
        config,
        device,
        num_samples=num_samples,
        class_labels=class_labels,
        num_steps=config.num_sampling_steps,
    )

    # Denormalize from [-1, 1] to [0, 1]
    samples = (samples + 1) / 2
    samples = torch.clamp(samples, 0, 1)

    # Save a grid of all samples
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "generated.png")
    vutils.save_image(
        samples,
        output_path,
        nrow=int(num_samples ** 0.5),
        normalize=False,
    )
    print(f"Saved samples to {output_path}")

    # Also save individual images for small batches (<= 64 samples)
    if num_samples <= 64:
        individual_dir = os.path.join(output_dir, "individual")
        os.makedirs(individual_dir, exist_ok=True)
        for i, sample in enumerate(samples):
            vutils.save_image(
                sample,
                os.path.join(individual_dir, f"sample_{i:03d}.png"),
                normalize=False,
            )
        print(f"Saved individual samples to {individual_dir}/")

    return samples
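

# Illustrative programmatic use of generate_samples; the checkpoint path is a
# placeholder, not a file shipped with this repo:
#
#   samples = generate_samples(
#       "checkpoints/flow_matching_last.pt",  # hypothetical path
#       num_samples=16,
#       class_labels=3,
#       num_steps=50,
#   )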
def generate_class_grid(checkpoint_path, samples_per_class=8, num_steps=50, output_dir="generated_samples"):
    """
    Generate a grid of samples covering all classes.

    Args:
        checkpoint_path: Path to model checkpoint
        samples_per_class: Number of samples per class
        num_steps: Number of sampling steps
        output_dir: Directory to save generated images
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model, config = load_model(checkpoint_path, device)
    config.num_sampling_steps = num_steps

    # Generate one batch per class; each row of the final grid is one class
    all_samples = []
    for class_idx in range(config.num_classes):
        print(f"Generating samples for class {class_idx}...")
        samples = sample_euler(
            model,
            config,
            device,
            num_samples=samples_per_class,
            class_labels=class_idx,
            num_steps=num_steps,
        )
        all_samples.append(samples)

    # Concatenate all samples
    all_samples = torch.cat(all_samples, dim=0)

    # Denormalize from [-1, 1] to [0, 1]
    all_samples = (all_samples + 1) / 2
    all_samples = torch.clamp(all_samples, 0, 1)

    # Save grid
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "class_grid.png")
    vutils.save_image(
        all_samples,
        output_path,
        nrow=samples_per_class,
        normalize=False,
    )
    print(f"Saved class grid to {output_path}")
    return all_samples
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate images using flow matching")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--num_samples", type=int, default=16, help="Number of samples to generate")
    parser.add_argument("--class_label", type=int, default=None, help="Optional class label for conditional generation")
    parser.add_argument("--num_steps", type=int, default=50, help="Number of sampling steps")
    parser.add_argument("--cfg_scale", type=float, default=None, help="Classifier-free guidance scale")
    parser.add_argument("--output_dir", type=str, default="generated_samples", help="Output directory")
    parser.add_argument("--class_grid", action="store_true", help="Generate a grid with all classes")
    parser.add_argument("--samples_per_class", type=int, default=8, help="Samples per class for class grid")
    args = parser.parse_args()

    if args.class_grid:
        generate_class_grid(
            checkpoint_path=args.checkpoint,
            samples_per_class=args.samples_per_class,
            num_steps=args.num_steps,
            output_dir=args.output_dir,
        )
    else:
        generate_samples(
            checkpoint_path=args.checkpoint,
            num_samples=args.num_samples,
            class_labels=args.class_label,
            num_steps=args.num_steps,
            output_dir=args.output_dir,
            cfg_scale=args.cfg_scale,
        )
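
# Example invocations, built from the flags defined above (checkpoint paths
# are placeholders):
#
#   # 16 samples with random class labels:
#   python generate_flow_matching.py --checkpoint checkpoints/last.pt
#
#   # 64 samples of class 3 with 100 Euler steps and CFG scale 4.0:
#   python generate_flow_matching.py --checkpoint checkpoints/last.pt \
#       --num_samples 64 --class_label 3 --num_steps 100 --cfg_scale 4.0
#
#   # One grid row per class, 8 samples each:
#   python generate_flow_matching.py --checkpoint checkpoints/last.pt --class_grid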