-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_llm.py
More file actions
32 lines (25 loc) · 1005 Bytes
/
train_llm.py
File metadata and controls
32 lines (25 loc) · 1005 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
import torch
from configs.dit_config import DiTConfig
from models.dit import DiT
from data.video_loader import get_dataloaders
from training.video_trainer import train_video_model
from utils.logger import setup_logging
def _positive_int(text):
    """Argparse ``type=`` converter that accepts only integers greater than zero.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is <= 0;
            argparse reports this as a usage error and exits with status 2.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        ) from None
    if value <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {value}"
        )
    return value


def _build_parser():
    """Return the CLI parser for the training script's config overrides."""
    parser = argparse.ArgumentParser(description="Train DiT Video Model")
    parser.add_argument("--batch_size", type=_positive_int, help="Batch size")
    parser.add_argument("--steps", type=_positive_int, help="Training steps")
    return parser


def main(argv=None):
    """Train a DiT video model, optionally overriding config values from the CLI.

    Builds a default ``DiTConfig``, applies ``--batch_size`` / ``--steps``
    overrides when they are given, constructs the ``DiT`` model and the
    train/val dataloaders, and hands everything to ``train_video_model``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list is useful for tests and programmatic launches.
    """
    # Invalid values (non-integers, zero, negatives) are rejected here with a
    # usage error instead of leaking into the config.
    args = _build_parser().parse_args(argv)

    config = DiTConfig()
    # Compare against None so only flags the user actually passed override
    # the config defaults.
    if args.batch_size is not None:
        config.batch_size = args.batch_size
    if args.steps is not None:
        config.train_steps = args.steps

    logger = setup_logging()
    logger.info("Initializing Video DiT...")
    model = DiT(config)
    num_params = sum(p.numel() for p in model.parameters())
    # Report through the same logger as the rest of the run rather than print().
    logger.info(f"Model Parameters: {num_params:,}")

    train_loader, val_loader = get_dataloaders(config)
    train_video_model(model, train_loader, val_loader, config)
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()