This repository contains the implementation and experiments for our paper on stabilizing and improving looped language models. It includes two experimental settings:
- Addition task (train_addition/): explores the effects of normalization structures, random loop-count training, and Jacobian regularization on a synthetic addition task.
- Math reasoning (ouro_experiment/): fine-tunes ByteDance/Ouro-1.4B on math datasets with random loop sampling and Jacobian regularization.
LoopedModel_STARS/
├── README.md
├── ICML-looped-llm-cmr-v2.pdf # Paper
├── environment.yml # Conda environment specification
├── train_addition/ # Addition task experiments
│ ├── data/ # Training/testing data (4-digit addition)
│ ├── primary_train_formal.py # Train: norm structure & type study (Sec 4.1)
│ ├── primary_test_formal.py # Test: norm structure & type evaluation
│ ├── primary_pca_formal.py # PCA visualization of hidden state trajectories
│ ├── random_train_formal.py # Train: random loop sampling + model structure (Sec 4.2)
│ ├── random_test_formal.py # Test: random loop models evaluation
│ ├── train.py # Train: combined random loop + Jacobian regularization
│ ├── formal_logs/ # Logs for primary experiments
│ ├── random_logs/ # Logs for random loop experiments
│ ├── logs/ # Additional logs
│ └── reglogs/ # Logs for regularization experiments
└── ouro_experiment/ # Math reasoning experiments
  ├── Ouro-1.4B/ # Model weights & config (download from HuggingFace)
  │ ├── modeling_ouro.py # Modified: reg_loss + random loop support
  │ ├── configuration_ouro.py
  │ ├── config.json
  │ └── ...
  ├── train_large.py # Training script
  ├── run.sh # Launch training via torchrun
  ├── run_eval.sh # Evaluation with vLLM + lm-eval-harness
  ├── ouro_vllm.py # vLLM adapter (for older vLLM without Ouro support)
  └── outputs/ # Training checkpoints
A reference conda environment is provided in environment.yml. To create it:
conda env create -f environment.yml
conda activate ouro
For vLLM-based evaluation, additionally install vLLM (the exact version depends on your CUDA setup; we tested with vllm>=0.8.0):
pip install vllm

Synthetic 4-digit addition experiments on a small looped transformer. Refer to the paper for full experimental details and analysis.
primary_train_formal.py — Studies the effect of normalization placement and type on hidden state trajectory (Section 4.1). Configurable via Config class:
- norm_structure: pre_norm, post_norm, sandwich_branch, sandwich_dual
- norm_type: LayerNorm, RMSNorm, DeepNorm, SimpleNorm
- Fixed loop count (l_loops), single shared block (k_layers=1)
primary_test_formal.py — Evaluates models trained by primary_train_formal.py.
primary_pca_formal.py — PCA-based visualization of hidden state trajectories across loop iterations.
random_train_formal.py — Studies random loop-count training and model structure variants (Section 4.2). Key Config options:
- random_distribution: "log_norm", "poisson", "uniform" (loop count distribution)
- random_loop_mu / random_loop_sigma: log-normal parameters
- prelude / coda: toggles non-looped blocks before/after the recurrent module, enabling comparison of only-recurrent, with-prelude, and with-coda architectures
- l2_limit: optionally adds an L2 constraint on the final loop step
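The loop-count options above can be pictured with a small sampler. This is a minimal sketch, not the repository's code: the function name sample_loop_count, its argument names, the default values (borrowed from the math reasoning settings quoted later in this README), the use of exp(mu) as the Poisson rate, and the clamping to [min_loops, max_loops] are all assumptions made for illustration.

```python
import math
import random


def sample_loop_count(distribution="log_norm", mu=1.7, sigma=0.4,
                      min_loops=1, max_loops=16, rng=random):
    """Draw one integer loop count, clamped to [min_loops, max_loops].

    Hypothetical helper: names, defaults, and clamping are assumptions.
    """
    if distribution == "log_norm":
        # Log-normal draw: exp(N(mu, sigma^2)).
        raw = rng.lognormvariate(mu, sigma)
    elif distribution == "poisson":
        # Knuth's multiplication method; the rate exp(mu) is an assumption.
        threshold = math.exp(-math.exp(mu))
        raw, product = 0, rng.random()
        while product > threshold:
            raw += 1
            product *= rng.random()
    elif distribution == "uniform":
        raw = rng.randint(min_loops, max_loops)
    else:
        raise ValueError(f"unknown distribution: {distribution!r}")
    # Round to an integer and keep the result inside the allowed range.
    return max(min_loops, min(max_loops, int(round(raw))))


if __name__ == "__main__":
    rng = random.Random(0)  # seeded for reproducibility
    print([sample_loop_count("log_norm", rng=rng) for _ in range(5)])
```

Sampling a fresh count per training step exposes the shared block to a range of depths, which is the point of random loop-count training; the real Config fields may parameterize the distributions differently.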
random_test_formal.py — Evaluates models trained by random_train_formal.py.
train.py — Combines random loop sampling with Jacobian regularization (reg_loss), the proposed method in the paper. Uses power iteration to estimate the Jacobian Frobenius norm of the loop dynamics, penalizing unstable recurrent transformations. Key parameter: reg_weight.
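For readers unfamiliar with power iteration, the sketch below shows its textbook form on a small explicit matrix standing in for the loop Jacobian J. In this form, power iteration on J^T J converges to the largest singular value of J (the spectral norm); how train.py turns its estimate into the norm used by reg_loss is not shown here and may differ. The function names and the use of a dense Python matrix are assumptions for illustration; a real implementation would more likely use autograd Jacobian-vector products than an explicit matrix.

```python
import math
import random


def matvec(A, v):
    """Multiply matrix A (list of rows) by vector v."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def transpose(A):
    """Return the transpose of A as a list of rows."""
    return [list(col) for col in zip(*A)]


def spectral_norm_power_iteration(J, n_iters=50, seed=0):
    """Estimate the largest singular value of J via power iteration on J^T J."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(J[0]))]  # random start vector
    Jt = transpose(J)
    for _ in range(n_iters):
        w = matvec(Jt, matvec(J, v))            # w = (J^T J) v
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:                         # J maps v to zero
            return 0.0
        v = [x / norm for x in w]               # renormalize
    Jv = matvec(J, v)
    return math.sqrt(sum(x * x for x in Jv))    # ||J v|| with ||v|| = 1


if __name__ == "__main__":
    toy_jacobian = [[3.0, 0.0], [0.0, 1.0]]     # singular values 3 and 1
    print(spectral_norm_power_iteration(toy_jacobian))
```

A value above 1 indicates that some direction in hidden-state space is expanded by one loop step, which is the kind of unstable recurrent transformation the penalty is meant to discourage.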
All scripts are self-contained; modify the Config class in each file and run directly:
cd train_addition
python primary_train_formal.py
python random_train_formal.py
python train.py
Each script auto-detects available GPUs. Outputs (checkpoints, logs, loss curves) are saved to timestamped directories under formal_logs/, random_formal_logs/, or newlogs/.
Fine-tunes Ouro-1.4B on math datasets with loop regularization and random loop sampling.
1. Create environment and download the model:
conda env create -f environment.yml
conda activate ouro
pip install huggingface_hub
huggingface-cli download ByteDance/Ouro-1.4B --local-dir ouro_experiment/Ouro-1.4B
2. Install vLLM (for evaluation):
pip install vllm

- Jacobian regularization: During training, computes the Frobenius norm of the loop dynamics Jacobian via power iteration on the final hidden states, added as a regularization term (controlled by reg_weight).
- Dynamic total_ut_steps: The number of universal transformer iterations can be set per forward pass, enabling random loop-count training.
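The idea of choosing the iteration count per forward pass can be illustrated with a toy looped function. This is a sketch under stated assumptions, not the Ouro model: looped_forward and toy_block are hypothetical names, and the scalar update stands in for a shared transformer block.

```python
def looped_forward(x, block, total_ut_steps):
    """Apply the same block total_ut_steps times (weights shared across steps)."""
    h = x
    for _ in range(total_ut_steps):
        h = block(h)
    return h


def toy_block(h):
    """Contractive toy update with fixed point h = 2.0 (stand-in for a real block)."""
    return 0.5 * h + 1.0


if __name__ == "__main__":
    # The same "weights" are reused; only the loop count changes per call.
    for steps in (1, 4, 16):
        print(steps, looped_forward(0.0, toy_block, steps))
```

Because the toy update is contractive, outputs at different loop counts approach the same value, which is the behavior a stabilized looped model is meant to approximate when the loop count varies between training and inference.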
To change training parameters, edit the CLISettings dataclass in train_large.py.
Random loop parameters and regularization weight are set in tightly_scoped_fwd_bwd:
total_ut_steps = sample_random_loops(random_loop_mu=1.7, random_loop_sigma=0.4,
                                     random_loop_min=1, random_loop_max=16)
total_loss = (1 - reg_weight) * loss + reg_weight * reg_loss # reg_weight default: 0.1

cd ouro_experiment
bash run.sh
Uses torchrun on 8 GPUs. Edit CUDA_VISIBLE_DEVICES and --nproc_per_node in run.sh as needed. Checkpoints are saved to outputs/<run_name>/.
With native vLLM Ouro support:
cd ouro_experiment
bash run_eval.sh
Edit MODEL_PATH, total_ut_steps in hf_overrides, and --tasks in the script.
With older vLLM (no native Ouro support):
Place ouro_vllm.py in the same directory as run_eval.sh. This adapter unrolls the recurrent layers into N * T distinct layer objects for KV-cache compatibility and disables dynamic early-exit for batched inference.