-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: train_vqvae.py
More file actions
57 lines (43 loc) · 1.84 KB
/
train_vqvae.py
File metadata and controls
57 lines (43 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.optim as optim
from models.vqvae import VQVAE
from data.vqvae_dataset import get_vqvae_dataloader
import os
def train():
    """Train a VQ-VAE on the project dataset and checkpoint it each epoch.

    Builds the model and dataloader from the project-local `models.vqvae`
    and `data.vqvae_dataset` modules, optimizes reconstruction + VQ loss
    with Adam, logs running metrics every 10 steps plus a per-epoch
    average, and writes `checkpoints/vqvae_epoch_<n>.pt` after each epoch.

    Returns:
        None. Side effects: stdout logging and checkpoint files on disk.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    num_hiddens = 128
    num_residual_layers = 2
    num_residual_hiddens = 32
    num_embeddings = 1024  # Small codebook for efficiency
    embedding_dim = 64
    commitment_cost = 0.25
    learning_rate = 1e-3
    batch_size = 64
    num_epochs = 10

    dataloader = get_vqvae_dataloader(batch_size=batch_size)
    model = VQVAE(num_hiddens, num_residual_layers, num_residual_hiddens,
                  num_embeddings, embedding_dim, commitment_cost).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)

    # Loop-invariant: create the checkpoint directory once, not every epoch.
    os.makedirs("checkpoints", exist_ok=True)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        total_recon_error = 0.0
        total_perplexity = 0.0
        num_batches = 0
        for i, data in enumerate(dataloader):
            data = data.to(device)
            optimizer.zero_grad()
            # Model forward returns (vq_loss, reconstruction, codebook perplexity)
            # — presumably; confirm against models.vqvae.VQVAE.forward.
            vq_loss, data_recon, perplexity = model(data)
            recon_error = torch.mean((data_recon - data)**2)
            loss = vq_loss + recon_error
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon_error += recon_error.item()
            total_perplexity += perplexity.item()
            num_batches += 1
            if (i+1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}, Recon: {recon_error.item():.4f}, Perplexity: {perplexity.item():.4f}")

        # Fix: the running totals were accumulated but never reported —
        # surface per-epoch averages so training progress is visible.
        if num_batches:
            print(f"Epoch [{epoch+1}/{num_epochs}] averages - "
                  f"Loss: {total_loss/num_batches:.4f}, "
                  f"Recon: {total_recon_error/num_batches:.4f}, "
                  f"Perplexity: {total_perplexity/num_batches:.4f}")
        torch.save(model.state_dict(), f"checkpoints/vqvae_epoch_{epoch+1}.pt")


if __name__ == "__main__":
    train()