Fig. 1 - Train and test accuracy over epochs (each x-axis step is 100 epochs). Training accuracy reaches 100% quickly, while test accuracy stays near 0 until ~6000 epochs before rising to 100%. Train set size: 2553; test set size: 10216; model parameters: 64,689.
This project aims to reproduce the "grokking" phenomenon on a modular addition task.
Note that we do not use causal attention; it is not a seq2seq model.
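For context, the modular addition dataset consists of all $p^2$ pairs $(a, b)$ with label $(a + b) \bmod p$, split into train and test sets. A minimal sketch of such a generator is below; the function name and split logic are illustrative assumptions (mirroring the `p` and `train_frac` CLI flags), not the repo's exact implementation. With `p=113` and `train_frac=0.20` it reproduces the 2553/10216 split reported in Fig. 1.

```python
import random

def make_modular_addition_data(p: int, train_frac: float, seed: int = 0):
    """Build all (a, b) pairs with label (a + b) mod p, then shuffle and split.

    Hypothetical helper mirroring the CLI flags; the repo's own split
    logic may differ.
    """
    pairs = [(a, b, (a + b) % p) for a in range(p) for b in range(p)]
    rng = random.Random(seed)
    rng.shuffle(pairs)
    n_train = int(train_frac * len(pairs))
    return pairs[:n_train], pairs[n_train:]

train, test = make_modular_addition_data(p=113, train_frac=0.20)
print(len(train), len(test))  # -> 2553 10216
```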
- Install `uv`.
- Clone this repo and `cd` into it. Set up the environment using `uv sync`.
- (Optional) Set up `wandb` using `uv run wandb login`.
- Check the available config options using `uv run run.py -h`.
- Run the following command:

```
uv run run.py --p=113 --train_frac=0.20 --lr=1e-3 --epochs=30_000 --weight_decay=1.0
```
- `weight_decay` is a key parameter to induce grokking: after the model reaches a "memorization basin" where the gradient of the loss term is negligible, the only descent signal comes from the gradient of the regularization term, say $\lambda\|\theta\|^2$, which pushes the learned parameters into a closed ball centred at zero with radius inversely proportional to $\lambda$; the larger the $\lambda$, the stronger this signal, and ideally it moves the parameters into a "generalization basin".
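The mechanism above can be seen in a toy sketch: once the loss gradient is (near) zero, a plain SGD step on loss plus an L2 penalty $\lambda\|\theta\|^2$ shrinks the parameters multiplicatively by $(1 - 2\,\mathrm{lr}\,\lambda)$ each step. The function and values here are illustrative, not the training loop from this repo.

```python
import numpy as np

def sgd_step(theta, loss_grad, lr, weight_decay):
    # Gradient of weight_decay * ||theta||^2 is 2 * weight_decay * theta,
    # so the regularizer alone rescales theta by (1 - 2 * lr * weight_decay).
    return theta - lr * (loss_grad + 2 * weight_decay * theta)

theta = np.ones(4)
# Inside the "memorization basin" the loss gradient is ~0, so only the
# regularization term moves the parameters, steadily shrinking their norm.
for _ in range(100):
    theta = sgd_step(theta, loss_grad=np.zeros(4), lr=1e-3, weight_decay=1.0)
print(np.linalg.norm(theta))  # norm decayed from 2.0 by a factor 0.998**100
```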
