Standalone repo to compare position→token and token→position factorizations on the same DiT backbone and masking distribution. Data loading, tokenizer handling, and model configs are adapted from the mdlm codebase but live entirely in this directory.
- Samples the same masked partial state for both factorizations and trains either the `pos_to_tok`, `tok_to_pos`, or `joint` (both) objective.
- Both branches include a position head so you can compare the joint log p(pos, tok | s) fairly (see the sketch after this list).
- Logs every metric and loss term to Weights & Biases (`wandb`).
- Uses a DiT backbone with rotary embeddings; keep `max_length` ≤ the DiT's supported length.
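By the chain rule, the per-step joint splits as log p(i, x_i | s) = log p(i | s) + log p(x_i | i, s) for pos→tok, or log p(x_i | s) + log p(i | x_i, s) for tok→pos. The minimal sketch below shows one way such losses could be assembled from head outputs; the function names and tensor shapes are illustrative assumptions and need not match `trainer.py`.

```python
# Illustrative sketch only (not this repo's code): the two chain-rule
# factorizations of the per-step joint log p(i, x_i | s).
import torch
import torch.nn.functional as F


def pos_to_tok_nll(pos_logits, tok_logits, target_pos, target_tok):
    """pos→tok: -[log p(i | s) + log p(x_i | i, s)].

    pos_logits: (B, L)    position scores given the partial state s
    tok_logits: (B, L, V) token scores at each position given s and i
    target_pos: (B,)      index i of the position being filled
    target_tok: (B,)      token id x_i placed at that position
    """
    b = torch.arange(pos_logits.size(0))
    log_p_i = F.log_softmax(pos_logits, dim=-1)[b, target_pos]
    log_p_x = F.log_softmax(tok_logits, dim=-1)[b, target_pos, target_tok]
    return -(log_p_i + log_p_x).mean()


def tok_to_pos_nll(tok_logits, pos_logits, target_pos, target_tok):
    """tok→pos: -[log p(x_i | s) + log p(i | x_i, s)].

    tok_logits: (B, V) token scores given only the partial state s
    pos_logits: (B, L) position scores given s and the chosen token x_i
    """
    b = torch.arange(tok_logits.size(0))
    log_p_x = F.log_softmax(tok_logits, dim=-1)[b, target_tok]
    log_p_i = F.log_softmax(pos_logits, dim=-1)[b, target_pos]
    return -(log_p_x + log_p_i).mean()
```

In the `joint` setting, both negative log-likelihoods would be optimized on the same sampled (position, token) pair from the shared masked state.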
To run the default comparison:

```bash
cd /home/yunseok/Workspace/token_ordering/factorization
bash scripts/train.sh  # runs pos→tok then tok→pos on wikitext-2 with default settings
```

To customize a run:
```bash
python trainer.py \
  --dataset wikitext2 \
  --tokenizer gpt2 \
  --max_length 256 \
  --batch_size 8 \
  --max_steps 5000 \
  --factorization joint \
  --wandb_project token-ordering-factorization
```

Key flags:
- `--factorization {pos_to_tok,tok_to_pos,joint}`: choose which factorization to optimize.
- `--disable_position_priors`: turn off the p(i | s) head on the pos→tok side (enabled by default).
- `--mask_ratio`: fraction of visible tokens replaced by `[MASK]` per example (see the sketch after this list).
- `--max_length`: keep at or below the DiT config length (default 512 in `ModelConfig`).
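As a rough illustration of the masking step behind `--mask_ratio` (the helper name, Bernoulli-style sampling, and argument names below are assumptions, not this repo's actual implementation):

```python
# Illustrative sketch of per-example masking with --mask_ratio; the repo's
# sampling may differ (e.g., exact-count masking instead of Bernoulli draws).
import torch


def mask_partial_state(input_ids, mask_token_id, mask_ratio, pad_token_id=None):
    """Replace roughly `mask_ratio` of the non-padding tokens with [MASK]."""
    maskable = torch.ones_like(input_ids, dtype=torch.bool)
    if pad_token_id is not None:
        maskable &= input_ids.ne(pad_token_id)          # never mask padding
    drop = torch.rand_like(input_ids, dtype=torch.float) < mask_ratio
    drop &= maskable
    masked_ids = input_ids.masked_fill(drop, mask_token_id)
    return masked_ids, drop                             # drop marks prediction targets
```

Both factorization branches then see the same masked state, so their losses are computed against identical (position, token) targets.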