diff --git a/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/README.md b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/README.md new file mode 100644 index 000000000..88706a4ce --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/README.md @@ -0,0 +1,54 @@ +# Atris Labs — 10L MLP3x + Int5/Int6 + BigramHash + SmearGate + SWA + +## Approach + +Stacked 8 independently validated techniques matching the current leaderboard winners: + +### Architecture (25.5M params) +- **10 transformer layers** with U-Net skip connections +- **MLP 3x** expansion (1536 hidden, relu-squared) +- **BigramHash(10240)**: Hash consecutive token pairs into 10240-bucket embedding table (dim=128), zero-init with learnable scale (0.05) +- **SmearGate**: Per-dimension learned gate blending each token with previous token embedding + +### Training +- **Muon optimizer**: matrix_lr=0.02, momentum=0.99 (warmup 0.92→0.99 over 1500 steps), weight decay=0.04 +- **AdamW**: tied_embed_lr=0.03, scalar_lr=0.02, weight decay=0.01 +- **Sequence length**: 2048 tokens, batch 786,432 tokens/step +- **Gradient clipping**: norm=0.3 +- **SWA**: Average 24 checkpoints during warmdown (when LR scale < 0.4) +- **Warmdown**: 3000 iterations + +### Quantization & Compression +- **Int5 MLP weights** (32 levels, per-row scale) — compresses ~1.88x under zstd +- **Int6 attention weights** (64 levels, per-row scale) — compresses ~1.51x +- **FP16 passthrough** for tied embeddings +- **3% magnitude pruning** before quantization +- **zstd-22** compression (or zlib fallback) + +### Evaluation +- **Reported score path:** standard final eval with `EVAL_SEQ_LEN=2048` +- **Sliding-window code path:** included in `train_gpt.py`, but not used for the reported metrics in this folder + +## Key Metrics (audited seed=42 run) + +- **val_bpb (int8+zlib roundtrip exact):** 1.18069496 +- **val_loss:** 1.99355398 +- **Artifact size:** 14,461,499 bytes (under 16MB) +- **Training steps:** 6428 in 600.039s on 8xH100 (93.35ms/step) +- **Peak memory:** 18,974 MiB +- **SWA:** 24 checkpoints averaged during warmdown +- **Train log:** included as `train.log` + +## Command + +```bash +NCCL_IB_DISABLE=1 \ +RUN_ID=atris_v8_submission \ +VAL_LOSS_EVERY=0 \ +TRAIN_LOG_EVERY=50 \ +WARMUP_STEPS=5 \ +MAX_WALLCLOCK_SECONDS=600 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +All other hyperparameters use defaults from `train_gpt.py`. diff --git a/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/submission.json b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/submission.json new file mode 100644 index 000000000..c44518b61 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/submission.json @@ -0,0 +1,11 @@ +{ + "author": "Atris Labs", + "github_id": "keshav55", + "name": "10L MLP3x Int5/Int6 + BigramHash + SmearGate + SWA", + "blurb": "Audited seed42 run. 25.5M param model with Int5 MLP + Int6 attn, BigramHash(10240), SmearGate, SWA(24 ckpts), WD=0.04, grad_clip=0.3, 3% pruning, seq_len=2048, 8xH100.", + "date": "2026-03-24T08:24:52Z", + "val_loss": 1.99355398, + "val_bpb": 1.18069496, + "bytes_total": 14461499, + "bytes_code": 65264 +} diff --git a/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/train.log b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/train.log new file mode 100644 index 000000000..3d87e09b4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/train.log @@ -0,0 +1,168 @@ +logs/atris_v8_audit_seed42.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25517137 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000 +seed:42 +v1:num_layers:10 int6_layers:[3,7) +warmup_step:1/5 +warmup_step:2/5 +warmup_step:3/5 +warmup_step:4/5 +warmup_step:5/5 +step:1/20000 train_loss:6.9334 train_time:273ms step_avg:273.50ms +step:2/20000 train_loss:8.1909 train_time:345ms step_avg:172.69ms +step:3/20000 train_loss:7.6563 train_time:442ms step_avg:147.20ms +step:4/20000 train_loss:6.8750 train_time:533ms step_avg:133.24ms +step:5/20000 train_loss:6.9017 train_time:625ms step_avg:124.97ms +step:6/20000 train_loss:6.8540 train_time:718ms step_avg:119.65ms +step:7/20000 train_loss:6.6549 train_time:808ms step_avg:115.49ms +step:8/20000 train_loss:6.6290 train_time:901ms step_avg:112.58ms +step:9/20000 train_loss:6.3737 train_time:992ms step_avg:110.28ms +step:10/20000 train_loss:6.0944 train_time:1084ms step_avg:108.44ms +step:50/20000 train_loss:3.8339 train_time:4748ms step_avg:94.96ms +step:100/20000 train_loss:3.1935 train_time:9327ms step_avg:93.27ms +step:150/20000 train_loss:2.9061 train_time:15008ms step_avg:100.05ms +step:200/20000 train_loss:2.4086 train_time:19575ms step_avg:97.87ms +step:250/20000 train_loss:2.5012 train_time:24165ms step_avg:96.66ms +step:300/20000 train_loss:2.5914 train_time:29300ms step_avg:97.67ms +step:350/20000 train_loss:2.5724 train_time:33885ms step_avg:96.81ms +step:400/20000 train_loss:2.4459 train_time:40431ms step_avg:101.08ms +step:450/20000 train_loss:2.3974 train_time:44999ms step_avg:100.00ms +step:500/20000 train_loss:2.4264 train_time:49587ms step_avg:99.17ms +step:550/20000 train_loss:2.3683 train_time:54704ms step_avg:99.46ms +step:600/20000 train_loss:2.3523 train_time:59291ms step_avg:98.82ms +step:650/20000 train_loss:2.3468 train_time:64385ms step_avg:99.05ms +step:700/20000 train_loss:2.3598 train_time:68973ms step_avg:98.53ms +step:750/20000 train_loss:2.3373 train_time:73724ms step_avg:98.30ms +step:800/20000 train_loss:2.2443 train_time:78958ms step_avg:98.70ms +step:850/20000 train_loss:2.2372 train_time:83526ms step_avg:98.27ms +step:900/20000 train_loss:2.1322 train_time:88501ms step_avg:98.33ms +step:950/20000 train_loss:2.2191 train_time:93087ms step_avg:97.99ms +step:1000/20000 train_loss:2.2727 train_time:97669ms step_avg:97.67ms +step:1050/20000 train_loss:2.2240 train_time:102687ms step_avg:97.80ms +step:1100/20000 train_loss:2.3209 train_time:107252ms step_avg:97.50ms +step:1150/20000 train_loss:2.2443 train_time:113722ms step_avg:98.89ms +step:1200/20000 train_loss:2.3511 train_time:118305ms step_avg:98.59ms +step:1250/20000 train_loss:2.2412 train_time:122871ms step_avg:98.30ms +step:1300/20000 train_loss:2.3554 train_time:127539ms step_avg:98.11ms +step:1350/20000 train_loss:2.1615 train_time:132116ms step_avg:97.86ms +step:1400/20000 train_loss:2.2085 train_time:136766ms step_avg:97.69ms +step:1450/20000 train_loss:2.1886 train_time:141353ms step_avg:97.48ms +step:1500/20000 train_loss:2.1808 train_time:145936ms step_avg:97.29ms +step:1550/20000 train_loss:2.1780 train_time:150615ms step_avg:97.17ms +step:1600/20000 train_loss:2.1859 train_time:155200ms step_avg:97.00ms +step:1650/20000 train_loss:1.9866 train_time:159760ms step_avg:96.82ms +step:1700/20000 train_loss:2.1853 train_time:164431ms step_avg:96.72ms +step:1750/20000 train_loss:2.1218 train_time:169010ms step_avg:96.58ms +step:1800/20000 train_loss:2.1355 train_time:173667ms step_avg:96.48ms +step:1850/20000 train_loss:2.1585 train_time:178239ms step_avg:96.35ms +step:1900/20000 train_loss:2.1987 train_time:182810ms step_avg:96.22ms +step:1950/20000 train_loss:2.1352 train_time:187462ms step_avg:96.13ms +step:2000/20000 train_loss:2.1734 train_time:192035ms step_avg:96.02ms +step:2050/20000 train_loss:2.0977 train_time:196683ms step_avg:95.94ms +step:2100/20000 train_loss:2.0760 train_time:201244ms step_avg:95.83ms +step:2150/20000 train_loss:2.0499 train_time:205812ms step_avg:95.73ms +step:2200/20000 train_loss:2.1803 train_time:210455ms step_avg:95.66ms +step:2250/20000 train_loss:2.1192 train_time:215016ms step_avg:95.56ms +step:2300/20000 train_loss:2.1180 train_time:219670ms step_avg:95.51ms +step:2350/20000 train_loss:2.1528 train_time:224246ms step_avg:95.42ms +step:2400/20000 train_loss:2.1788 train_time:228807ms step_avg:95.34ms +step:2450/20000 train_loss:2.1676 train_time:233459ms step_avg:95.29ms +step:2500/20000 train_loss:2.0742 train_time:238029ms step_avg:95.21ms +step:2550/20000 train_loss:2.1656 train_time:242752ms step_avg:95.20ms +step:2600/20000 train_loss:2.1594 train_time:247319ms step_avg:95.12ms +step:2650/20000 train_loss:2.0781 train_time:251878ms step_avg:95.05ms +step:2700/20000 train_loss:2.1168 train_time:256526ms step_avg:95.01ms +step:2750/20000 train_loss:2.1581 train_time:261090ms step_avg:94.94ms +step:2800/20000 train_loss:2.1330 train_time:265739ms step_avg:94.91ms +step:2850/20000 train_loss:2.1351 train_time:270313ms step_avg:94.85ms +step:2900/20000 train_loss:2.0638 train_time:274874ms step_avg:94.78ms +step:2950/20000 train_loss:2.1463 train_time:279522ms step_avg:94.75ms +step:3000/20000 train_loss:2.1591 train_time:284082ms step_avg:94.69ms +step:3050/20000 train_loss:2.0864 train_time:288643ms step_avg:94.64ms +step:3100/20000 train_loss:2.0954 train_time:293297ms step_avg:94.61ms +step:3150/20000 train_loss:2.1326 train_time:297856ms step_avg:94.56ms +step:3200/20000 train_loss:1.9112 train_time:302494ms step_avg:94.53ms +step:3250/20000 train_loss:2.1164 train_time:307060ms step_avg:94.48ms +step:3300/20000 train_loss:2.0477 train_time:311646ms step_avg:94.44ms +step:3350/20000 train_loss:2.0481 train_time:316293ms step_avg:94.42ms +step:3400/20000 train_loss:2.1343 train_time:320866ms step_avg:94.37ms +step:3450/20000 train_loss:2.1241 train_time:325512ms step_avg:94.35ms +step:3500/20000 train_loss:2.0931 train_time:330076ms step_avg:94.31ms +step:3550/20000 train_loss:2.0974 train_time:334633ms step_avg:94.26ms +step:3600/20000 train_loss:2.0907 train_time:339275ms step_avg:94.24ms +step:3650/20000 train_loss:2.0268 train_time:343845ms step_avg:94.20ms +step:3700/20000 train_loss:2.0418 train_time:348476ms step_avg:94.18ms +step:3750/20000 train_loss:2.1506 train_time:353036ms step_avg:94.14ms +step:3800/20000 train_loss:2.0752 train_time:357603ms step_avg:94.11ms +step:3850/20000 train_loss:2.1308 train_time:362252ms step_avg:94.09ms +step:3900/20000 train_loss:2.0688 train_time:366813ms step_avg:94.05ms +step:3950/20000 train_loss:2.0555 train_time:371462ms step_avg:94.04ms +step:4000/20000 train_loss:2.0031 train_time:376036ms step_avg:94.01ms +step:4050/20000 train_loss:2.1147 train_time:380599ms step_avg:93.98ms +step:4100/20000 train_loss:1.9607 train_time:385242ms step_avg:93.96ms +step:4150/20000 train_loss:2.1218 train_time:389800ms step_avg:93.93ms +step:4200/20000 train_loss:2.0984 train_time:394441ms step_avg:93.91ms +step:4250/20000 train_loss:2.0865 train_time:399010ms step_avg:93.88ms +step:4300/20000 train_loss:2.0624 train_time:403580ms step_avg:93.86ms +step:4350/20000 train_loss:1.9776 train_time:408229ms step_avg:93.85ms +step:4400/20000 train_loss:2.1091 train_time:412785ms step_avg:93.81ms +step:4450/20000 train_loss:1.9889 train_time:417353ms step_avg:93.79ms +step:4500/20000 train_loss:2.0520 train_time:422000ms step_avg:93.78ms +step:4550/20000 train_loss:1.9593 train_time:426573ms step_avg:93.75ms +step:4600/20000 train_loss:1.9923 train_time:431210ms step_avg:93.74ms +step:4650/20000 train_loss:2.0985 train_time:435784ms step_avg:93.72ms +step:4700/20000 train_loss:2.0307 train_time:440353ms step_avg:93.69ms +step:4750/20000 train_loss:2.0463 train_time:445001ms step_avg:93.68ms +step:4800/20000 train_loss:2.0601 train_time:449569ms step_avg:93.66ms +step:4850/20000 train_loss:2.0818 train_time:454220ms step_avg:93.65ms +step:4900/20000 train_loss:2.0382 train_time:458795ms step_avg:93.63ms +step:4950/20000 train_loss:1.9963 train_time:463356ms step_avg:93.61ms +step:5000/20000 train_loss:2.0558 train_time:467976ms step_avg:93.60ms +step:5050/20000 train_loss:1.9657 train_time:472542ms step_avg:93.57ms +step:5100/20000 train_loss:2.0261 train_time:477207ms step_avg:93.57ms +step:5150/20000 train_loss:2.0299 train_time:481782ms step_avg:93.55ms +step:5200/20000 train_loss:2.0474 train_time:486356ms step_avg:93.53ms +step:5250/20000 train_loss:1.9632 train_time:490998ms step_avg:93.52ms +step:5300/20000 train_loss:1.9431 train_time:495674ms step_avg:93.52ms +step:5350/20000 train_loss:2.1841 train_time:500354ms step_avg:93.52ms +step:5400/20000 train_loss:2.0361 train_time:504934ms step_avg:93.51ms +step:5450/20000 train_loss:2.1996 train_time:509528ms step_avg:93.49ms +step:5500/20000 train_loss:2.0501 train_time:514205ms step_avg:93.49ms +step:5550/20000 train_loss:2.0461 train_time:518806ms step_avg:93.48ms +step:5600/20000 train_loss:1.9633 train_time:523484ms step_avg:93.48ms +step:5650/20000 train_loss:2.1382 train_time:528067ms step_avg:93.46ms +step:5700/20000 train_loss:1.8887 train_time:532671ms step_avg:93.45ms +step:5750/20000 train_loss:1.9986 train_time:537340ms step_avg:93.45ms +step:5800/20000 train_loss:1.8035 train_time:541927ms step_avg:93.44ms +step:5850/20000 train_loss:1.9441 train_time:546606ms step_avg:93.44ms +step:5900/20000 train_loss:1.9655 train_time:551215ms step_avg:93.43ms +step:5950/20000 train_loss:1.8819 train_time:555793ms step_avg:93.41ms +step:6000/20000 train_loss:1.9213 train_time:560476ms step_avg:93.41ms +step:6050/20000 train_loss:1.9118 train_time:565058ms step_avg:93.40ms +step:6100/20000 train_loss:1.9396 train_time:569648ms step_avg:93.38ms +step:6150/20000 train_loss:1.9769 train_time:574320ms step_avg:93.39ms +step:6200/20000 train_loss:2.1091 train_time:578908ms step_avg:93.37ms +step:6250/20000 train_loss:1.9760 train_time:583589ms step_avg:93.37ms +step:6300/20000 train_loss:1.8647 train_time:588168ms step_avg:93.36ms +step:6350/20000 train_loss:1.9435 train_time:592757ms step_avg:93.35ms +step:6400/20000 train_loss:1.9207 train_time:597435ms step_avg:93.35ms +step:6428/20000 val_loss:1.9661 val_bpb:1.1644 train_time:600039ms step_avg:93.35ms +stopping_early: wallclock_cap train_time:600039ms step:6428/20000 +peak memory allocated: 18974 MiB reserved: 19176 MiB +swa:applying averaged 24 checkpoints +Serialized model: 98437483 bytes +Code size: 65264 bytes +Total submission size: 98502747 bytes +pruning:zeroed smallest 3.0% of large matrix weights +Serialized model int8+zlib: 14396235 bytes (payload:26268994 raw_torch:26320625 payload_ratio:3.75x) +v1:int6_tensors:41 +Total submission size int8+zlib: 14461499 bytes +final_int8_zlib_roundtrip val_loss:1.9936 val_bpb:1.1807 eval_time:1903ms +final_int8_zlib_roundtrip_exact val_loss:1.99355398 val_bpb:1.18069496 diff --git a/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/train_gpt.py b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/train_gpt.py new file mode 100644 index 000000000..836e1890f --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_AtrisLabs_v8/train_gpt.py @@ -0,0 +1,1494 @@ +"""Atris Labs — Parameter Golf submission. Int5 MLP, Int6 attn, BigramHash, SmearGate, SWA, sliding window eval.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import re +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# v1 changes from baseline: +# - 10 transformer blocks (was 9) +# - matrix_lr=0.02, scalar_lr=0.02, tied_embed_lr=0.03 (was 0.04/0.04/0.05) +# - eval_seq_len supports longer eval sequences (default: train_seq_len) + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) # v8: 3000 (was 1200) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) # v8: 786K (was 524K) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) # v8: 2048 (was 1024) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) # v1: 10 layers (was 9) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) # v8: 3x (Int5+zstd makes this fit in 16MB) + # v3: Weight sharing. num_unique_blocks unique blocks repeated to fill num_layers. + # Set to 0 to disable (each layer gets its own block, original behavior). + # E.g., num_unique_blocks=4, num_layers=12 → 4 unique blocks × 3 repeats = 12 effective layers. + num_unique_blocks = int(os.environ.get("NUM_UNIQUE_BLOCKS", 0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) # v1: 0.03 (was 0.05) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) # v1: 0.02 (was 0.04) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) # v1: 0.02 (was 0.04) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) # v6: 0.3 (was 0.0) + # v6: Weight decay (decoupled for Muon, standard for Adam) + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.04)) + adam_weight_decay = float(os.environ.get("ADAM_WEIGHT_DECAY", 0.01)) + # v6: Stochastic Weight Averaging — collect checkpoints during warmdown. + # Starts when LR has decayed below swa_start_frac of peak (i.e., deep in warmdown). + # Set swa_every to 0 to disable. + swa_every = int(os.environ.get("SWA_EVERY", 50)) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) # start when LR < 40% of peak + + # v1: Eval sequence length (can be longer than train for free BPB improvement) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", os.environ.get("TRAIN_SEQ_LEN", 1024))) + # v2: Sliding window eval stride. stride < eval_seq_len means overlapping windows. + # Each token gets scored with ~(eval_seq_len - stride) context tokens. + # stride=64 with seq_len=1024 → every token has 960+ context → ~0.03 BPB free. + # Set to 0 to disable (uses standard non-overlapping eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + wd = group.get("weight_decay", 0.0) + for p in params: + # v6: Decoupled weight decay (applied before gradient update) + if wd > 0: + p.data.mul_(1 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for seq_len={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + # + # v2: Sliding window eval. When eval_stride < eval_seq_len, we use overlapping + # windows so every token is scored with near-maximum context. This gives ~0.03 BPB + # improvement for free (no training changes, no artifact cost). + eval_seq_len = args.eval_seq_len + stride = args.eval_stride if args.eval_stride > 0 else eval_seq_len + + # Unwrap DDP to access forward_per_token_loss + raw_model = model.module if hasattr(model, "module") else model + # Handle torch.compile wrapper + if hasattr(raw_model, "_orig_mod"): + raw_model = raw_model._orig_mod + + use_sliding = stride < eval_seq_len and hasattr(raw_model, "forward_per_token_loss") + + if not use_sliding: + # Standard non-overlapping eval (original behavior) + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < eval_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, eval_seq_len={eval_seq_len}" + ) + local_batch_seqs = local_batch_tokens // eval_seq_len + total_seqs = (val_tokens.numel() - 1) // eval_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * eval_seq_len + raw_end = batch_seq_end * eval_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, eval_seq_len) + y = local[1:].reshape(-1, eval_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + # --- v2: Sliding window eval --- + # Process the validation set with overlapping windows of size eval_seq_len, + # advancing by `stride` tokens each step. Only score the last `stride` tokens + # per window (they all have near-full context). + total_tokens = val_tokens.numel() - 1 # -1 because we need (x, y) pairs + # Distribute windows across ranks + all_starts = list(range(0, total_tokens - eval_seq_len + 1, stride)) + rank_starts = all_starts[rank::world_size] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for win_start in rank_starts: + win_end = win_start + eval_seq_len + # x = tokens[win_start:win_end], y = tokens[win_start+1:win_end+1] + chunk = val_tokens[win_start : win_end + 1].to(device=device, dtype=torch.int64, non_blocking=True) + x = chunk[:-1].unsqueeze(0) # [1, eval_seq_len] + y = chunk[1:].unsqueeze(0) # [1, eval_seq_len] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + per_token_loss = raw_model.forward_per_token_loss(x, y).detach() + # per_token_loss shape: [eval_seq_len] + + # Only count the last `stride` positions (they have full context) + score_start = eval_seq_len - stride + scored_losses = per_token_loss[score_start:] + scored_x = x[0, score_start:] # prev tokens for byte counting + scored_y = y[0, score_start:] # target tokens + + val_loss_sum += scored_losses.to(torch.float64).sum() + val_token_count += float(stride) + + token_bytes = base_bytes_lut[scored_y].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[scored_y] & ~is_boundary_token_lut[scored_x]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# v1: Mixed-precision quantization — INT8 for edge layers (0-2, 7-9), INT6 for middle layers (3-6). +# INT6 uses only 64 levels (stored as int8 dtype) which compresses much better under zlib. +# This is the key insight from nanlliu's competitive submission. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# v8: Mixed-precision quantization config (matching winner's strategy) +# MLP weights → Int5 (32 levels, compresses 1.88x under zstd) +# Attention weights → Int6 (64 levels, compresses 1.51x under zstd) +# Embeddings → FP16 passthrough +# Control tensors → FP32 passthrough +QUANT_MLP_BITS = int(os.environ.get("QUANT_MLP_BITS", 5)) +QUANT_ATTN_BITS = int(os.environ.get("QUANT_ATTN_BITS", 6)) +QUANT_DEFAULT_BITS = int(os.environ.get("QUANT_DEFAULT_BITS", 6)) +# Magnitude pruning: zero out smallest N% of weights before quantization +PRUNE_PERCENT = float(os.environ.get("PRUNE_PERCENT", 3.0)) +# Compression: zstd (better ratio) or zlib (fallback) +USE_ZSTD = bool(int(os.environ.get("USE_ZSTD", 1))) +ZSTD_LEVEL = int(os.environ.get("ZSTD_LEVEL", 22)) +# Legacy compat +INT6_LAYER_START = int(os.environ.get("INT6_LAYER_START", 3)) +INT6_LAYER_END = int(os.environ.get("INT6_LAYER_END", 7)) + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _extract_layer_index(name: str) -> int | None: + """Extract transformer block layer index from tensor name, e.g. 'blocks.3.attn.c_q.weight' -> 3.""" + m = re.match(r"blocks\.(\d+)\.", name) + return int(m.group(1)) if m else None + +def _classify_param(name: str) -> str: + """Classify parameter for mixed-precision quantization (matching winner's strategy).""" + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name: + return "attn" + return "other" + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + """Quantize a float tensor to int8 storage with configurable bit-width. + + bits=8: standard INT8 (256 levels, range [-127, 127]) + bits=6: INT6 (64 levels, range [-32, 31]), stored as int8 but with step=4 rounding + for better zlib compression due to fewer unique byte values. + """ + if bits == 6: + qmin, qmax = -32, 31 + else: + qmin, qmax = -127, 127 + + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(clipped / scale[:, None]), qmin, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # v8: Mixed-precision quantization by parameter TYPE (matching winner): + # - MLP weights → Int5 (32 levels, best compression under zstd) + # - Attention weights → Int6 (64 levels) + # - BigramHash weights → Int6 + # - Embeddings (tok_emb) → FP16 passthrough (preserves quality) + # - Control tensors → FP32 passthrough + # - Small tensors → FP16 passthrough + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", + "baseline_tensor_bytes", "int8_payload_bytes", "int5_tensors", "int6_tensors"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Embeddings → FP16 passthrough (winner keeps tok_emb in FP16) + ptype = _classify_param(name) + if ptype == "embed": + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors → FP16 passthrough + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + + # v8: Determine bits by parameter type + if ptype == "mlp": + bits = QUANT_MLP_BITS # default 5 + stats["int5_tensors"] += 1 + elif ptype == "attn": + bits = QUANT_ATTN_BITS # default 6 + stats["int6_tensors"] += 1 + elif ptype == "bigram": + bits = QUANT_ATTN_BITS # same as attention + stats["int6_tensors"] += 1 + else: + bits = QUANT_DEFAULT_BITS # default 6 + + q, s = quantize_float_tensor(t, bits=bits) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if bits == 6: + meta["bits"] = 6 + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class _FakeQuantSTE(torch.autograd.Function): + """Fake quantization with straight-through estimator for QAT.""" + @staticmethod + def forward(ctx, w: Tensor, bits: int) -> Tensor: + qmax = (1 << (bits - 1)) - 1 + # Per-row scale for 2D, per-tensor for 1D + if w.ndim == 2: + amax = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) + else: + amax = w.abs().amax().clamp_min(1e-8) + scale = amax / qmax + return (w / scale).round().clamp(-qmax, qmax) * scale + + @staticmethod + def backward(ctx, grad_output: Tensor) -> tuple[Tensor, None]: + return grad_output, None # STE: pass gradient through + + +# v5: QAT bits. Set QAT_BITS=8 for INT8 QAT, QAT_BITS=6 for INT6, 0 to disable. +_QAT_BITS = int(os.environ.get("QAT_BITS", 0)) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + # v5: Optional fake quantization during forward pass (QAT) controlled by QAT_BITS env var. + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if _QAT_BITS > 0 and self.training: + w = _FakeQuantSTE.apply(w, _QAT_BITS) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +# v7: BigramHash — captures local bigram context via hash embedding +# Used by #1 (thwu1) and #2 (Raahil Shah) on the leaderboard +_BIGRAM_BUCKETS = int(os.environ.get("BIGRAM_BUCKETS", 10240)) # v8: 10240 (matching winner) +_BIGRAM_DIM = int(os.environ.get("BIGRAM_DIM", 128)) + +class BigramHash(nn.Module): + def __init__(self, num_buckets: int, bigram_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, bigram_dim) + nn.init.zeros_(self.embed.weight) # v8: zero-init (starts as no-op, learns gradually) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) + nn.init.zeros_(self.proj.weight) # v8: zero-init + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) # v8: learnable scale + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.num_buckets - 1 + out = torch.empty_like(t) + out[..., 0] = mod # first position → last bucket (no previous token) + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + h = self.proj(h) + return h * self.scale + + +# v7: SmearGate — learned gate blending current token with previous token +# Used by #2 and #4 on the leaderboard +_SMEAR_GATE = bool(int(os.environ.get("SMEAR_GATE", 1))) # v8: enabled by default + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: [batch, seq_len, dim] + gate = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + prev = torch.zeros_like(x) + prev[:, 1:] = x[:, :-1] + return gate * x + (1 - gate) * prev + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_unique_blocks: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_layers = num_layers + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # v7: BigramHash and SmearGate + self.bigram_hash = BigramHash(_BIGRAM_BUCKETS, _BIGRAM_DIM, model_dim) if _BIGRAM_BUCKETS > 0 else None + self.smear_gate = SmearGate(model_dim) if _SMEAR_GATE else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # v3: Weight sharing — create fewer unique blocks, reuse them + self.weight_sharing = num_unique_blocks > 0 and num_unique_blocks < num_layers + if self.weight_sharing: + self.num_unique = num_unique_blocks + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_unique_blocks) + ] + ) + # Per-layer adapters: lightweight scale + gate per virtual layer + # These differentiate repeated uses of the same block (tiny param cost) + self.layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(model_dim, dtype=torch.float32)) for _ in range(num_layers)] + ) + else: + self.num_unique = num_layers + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.layer_scales = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_block(self, layer_idx: int) -> Block: + """Get the block for a given virtual layer index.""" + if self.weight_sharing: + return self.blocks[layer_idx % self.num_unique] + return self.blocks[layer_idx] + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self._get_block(i)(x, x0) + if self.layer_scales is not None: + x = x * self.layer_scales[i].to(dtype=x.dtype)[None, None, :] + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self._get_block(self.num_encoder_layers + i)(x, x0) + if self.layer_scales is not None: + x = x * self.layer_scales[self.num_encoder_layers + i].to(dtype=x.dtype)[None, None, :] + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_per_token_loss(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + """Return per-token cross-entropy losses (no reduction) for sliding window eval.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self._get_block(i)(x, x0) + if self.layer_scales is not None: + x = x * self.layer_scales[i].to(dtype=x.dtype)[None, None, :] + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self._get_block(self.num_encoder_layers + i)(x, x0) + if self.layer_scales is not None: + x = x * self.layer_scales[self.num_encoder_layers + i].to(dtype=x.dtype)[None, None, :] + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="none") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + # v1: use eval_seq_len for validation tokens (supports longer eval sequences) + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if args.eval_seq_len != args.train_seq_len: + log0(f"v1:eval_seq_len:{args.eval_seq_len} (train_seq_len:{args.train_seq_len})") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_unique_blocks=args.num_unique_blocks, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # v4: Include layer_scales from weight sharing in optimizer + if base_model.layer_scales is not None: + for ls in base_model.layer_scales: + scalar_params.append(ls) + # v7: BigramHash and SmearGate params + if base_model.bigram_hash is not None: + scalar_params.append(base_model.bigram_hash.embed.weight) + matrix_params.append(base_model.bigram_hash.proj.weight) + scalar_params.append(base_model.bigram_hash.scale) + if base_model.smear_gate is not None: + scalar_params.append(base_model.smear_gate.gate) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + log0(f"v1:num_layers:{args.num_layers} int6_layers:[{INT6_LAYER_START},{INT6_LAYER_END})") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + # v6: SWA state — running sum (memory efficient, like winner's implementation) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + swa_active = args.swa_every > 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # v6: SWA — collect when LR scale drops below swa_start_frac (warmdown region) + if swa_active and step % args.swa_every == 0: + if scale < args.swa_start_frac: + if swa_state is None: + swa_state = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + swa_count = 1 + else: + for k, v in base_model.state_dict().items(): + swa_state[k] += v.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # v6: Apply SWA — average collected checkpoints (running sum / count) + if swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + k: (v / swa_count).to(dtype=current_state[k].dtype) + for k, v in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + del swa_state + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # v8: Magnitude pruning — zero out smallest N% of weights before quantization + if PRUNE_PERCENT > 0: + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), PRUNE_PERCENT / 100.0) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + log0(f"pruning:zeroed smallest {PRUNE_PERCENT}% of large matrix weights") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + # v8: Use zstd-22 for better compression (saves ~1-2MB vs zlib) + if USE_ZSTD: + try: + import zstandard as zstd + quant_blob = zstd.ZstdCompressor(level=ZSTD_LEVEL).compress(quant_raw) + except ImportError: + log0("WARNING: zstandard not installed, falling back to zlib") + quant_blob = zlib.compress(quant_raw, level=9) + else: + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"v1:int6_tensors:{quant_stats['int6_tensors']}") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + # Decompress (try zstd first, fall back to zlib) + try: + import zstandard as zstd + decompressed = zstd.ZstdDecompressor().decompress(quant_blob_disk) + except Exception: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()