athirai-s/Continual-Learning

Contradiction-Aware Sparse Memory Finetuning

Contradiction-Aware Sparse Memory Finetuning (CASM) is a continual-learning research codebase for updating large language models with new time-scoped facts without erasing older facts that were also correct in their historical context. The project compares full fine-tuning, LoRA, Sparse Memory Finetuning (SMF), and CASM on a controlled Synthetic TemporalWiki benchmark, using a shared PyTorch/HuggingFace training stack, checkpoint contract, offline smoke tests, and CARC/Colab experiment paths for Llama-scale models.

Why This Matters

Large language models are trained on static snapshots of the world, but the world keeps changing. A model may need to learn that an office holder, treaty signatory, system maintainer, or institutional role changed in 2024 while still answering correctly about 2018 or 2020. Naive sequential fine-tuning often causes catastrophic forgetting: new updates overwrite old knowledge rather than adding a time-aware version of that knowledge. For factual assistants, scientific tools, policy systems, and enterprise knowledge models, this is not just lower benchmark accuracy. It is a failure to distinguish "formerly true" from "currently true."

Methods

| Method | What Trains | Core Idea | Version-Aware? |
|---|---|---|---|
| Full FT (`full_ft`) | All backbone parameters | Highest-capacity baseline; every temporal period updates the full model. | No |
| LoRA (`lora`) | Low-rank attention adapters | Parameter-efficient baseline over `q_proj`, `k_proj`, `v_proj`, and `o_proj`. | No |
| SMF (`smf`) | One sparse memory block | Freezes the backbone and writes updates into a single sparse trainable memory. | No |
| CASM (`casm`) | Routed sparse memory slots | Adds contradiction detection, versioned memory slots, and inference-time routing. | Yes |
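The "What Trains" column can be made concrete as a parameter-name filter. The sketch below is illustrative only: backbone names follow the usual HuggingFace Llama convention, while the `lora_`, `sparse_memory.`, `slot_bank.`, and `router.` prefixes are hypothetical and are not taken from this repository's modules.

```python
# Hypothetical sketch: which parameters each method would leave trainable.
# Backbone names follow common HuggingFace Llama naming; the "lora_",
# "sparse_memory.", "slot_bank.", and "router." prefixes are ASSUMED names.

LORA_TARGETS = ("q_proj", "k_proj", "v_proj", "o_proj")


def is_trainable(method: str, param_name: str) -> bool:
    """Return True if `param_name` would receive gradients under `method`."""
    if method == "full_ft":
        # Every backbone parameter is updated.
        return True
    if method == "lora":
        # Only low-rank adapter weights attached to attention projections train.
        return "lora_" in param_name and any(t in param_name for t in LORA_TARGETS)
    if method == "smf":
        # Backbone frozen; only the single sparse memory block trains.
        return param_name.startswith("sparse_memory.")
    if method == "casm":
        # Backbone frozen; the bank of memory slots and the router train.
        return param_name.startswith(("slot_bank.", "router."))
    raise ValueError(f"unknown method: {method}")


params = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.lora_A.weight",
    "model.layers.0.mlp.up_proj.weight",
    "sparse_memory.values",
    "slot_bank.0.values",
    "router.weight",
]

for method in ("full_ft", "lora", "smf", "casm"):
    print(method, [p for p in params if is_trainable(method, p)])
```

In a PyTorch loop, a filter like this would typically be applied with `for name, p in model.named_parameters(): p.requires_grad_(is_trainable(method, name))`. For LoRA specifically, PEFT's `get_peft_model` already adds the adapter weights and freezes the base model.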

CASM extends SMF by replacing a single shared memory with a bank of slots. When a new fact contradicts a prior fact, the registry can close the older slot, create a new versioned slot, and preserve metadata such as subject, relation, validity window, parent slot, and contradiction links. A router then selects the relevant slot for a query, allowing old and new facts to coexist.
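The slot-versioning behavior described above can be sketched with a small in-memory registry. This is a hypothetical illustration, not the code in `casf_dataset_api/`: it treats a "contradiction" as a new object for an existing open (subject, relation) pair, and it replaces the inference-time router with an exact lookup on subject, relation, and year. The entity names are made up.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class MemorySlot:
    slot_id: int
    subject: str
    relation: str
    obj: str
    valid_from: int                     # first year the fact holds
    valid_to: Optional[int] = None      # None means the slot is still open
    parent_slot: Optional[int] = None   # slot this version branched from
    contradicts: List[int] = field(default_factory=list)


class SlotRegistry:
    """Hypothetical registry: one versioned slot per (subject, relation) window."""

    def __init__(self) -> None:
        self.slots: List[MemorySlot] = []

    def _open_slot(self, subject: str, relation: str) -> Optional[MemorySlot]:
        for s in self.slots:
            if s.subject == subject and s.relation == relation and s.valid_to is None:
                return s
        return None

    def write(self, subject: str, relation: str, obj: str, year: int) -> MemorySlot:
        prev = self._open_slot(subject, relation)
        if prev is not None and prev.obj == obj:
            return prev  # no contradiction: reuse the open slot
        slot = MemorySlot(len(self.slots), subject, relation, obj, valid_from=year)
        if prev is not None:
            # Contradiction: close the old slot and branch a new versioned slot.
            prev.valid_to = year
            slot.parent_slot = prev.slot_id
            slot.contradicts.append(prev.slot_id)
            prev.contradicts.append(slot.slot_id)
        self.slots.append(slot)
        return slot

    def route(self, subject: str, relation: str, year: int) -> Optional[MemorySlot]:
        """Stand-in for the learned router: pick the slot valid at `year`."""
        for s in self.slots:
            if (s.subject == subject and s.relation == relation
                    and s.valid_from <= year
                    and (s.valid_to is None or year < s.valid_to)):
                return s
        return None


reg = SlotRegistry()
reg.write("Velmora Council", "chair", "Ada Renn", 2018)
reg.write("Velmora Council", "chair", "Ada Renn", 2020)     # same fact: no new slot
reg.write("Velmora Council", "chair", "Tomas Quill", 2024)  # contradiction: new slot
print(reg.route("Velmora Council", "chair", 2020).obj)  # Ada Renn
print(reg.route("Velmora Council", "chair", 2024).obj)  # Tomas Quill
```

The point of the design is that the 2018 slot is closed rather than overwritten, so a query about 2020 still resolves to the older fact.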

[Figure: CASM architecture]

Results

Average results on Synthetic TemporalWiki across periods 2020, 2022, and 2024. Period 2018 is excluded because there is no previous period from which to measure retention.

| Method | 1B Plasticity | 1B Stability | 1B Token F1 | 3B Plasticity | 3B Stability | 3B Token F1 |
|---|---|---|---|---|---|---|
| Full FT | 0.089 | 0.139 | 0.048 | 0.082 | 0.143 | 0.027 |
| LoRA | 0.138 | 0.221 | 0.097 | 0.096 | 0.114 | 0.000 |
| SMF | 0.591 | 0.482 | 0.112 | 0.118 | 0.215 | 0.044 |
| CASM | 0.329 | 0.258 | 0.056 | 0.072 | 0.083 | 0.024 |

Plasticity measures learning on changed probes after each update. Stability measures retention on unchanged probes from earlier periods. Token F1 measures surface-form generation quality and is averaged across changed and unchanged probes. In the reported experiments, SMF scored highest on all three metrics at both the 1B and 3B scales under the available data and compute budget; CASM remains the version-aware framework intended for larger, contradiction-rich settings where explicit slot branching and routing can be exercised more fully.
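A minimal sketch of how these three metrics and the period averaging could be computed. Assumptions: a probe counts as correct under normalized exact match, Token F1 is the common SQuAD-style bag-of-tokens overlap pooled over all probes of a period, and the probe dictionaries use hypothetical keys (`changed`, `prediction`, `reference`). The repository's `eval_and_metrics/` utilities may score answers differently.

```python
import string
from collections import Counter

_PUNCT = str.maketrans("", "", string.punctuation)


def _tokens(text):
    # Lowercase, strip ASCII punctuation, split on whitespace.
    return text.lower().translate(_PUNCT).split()


def exact_match(prediction, reference):
    return _tokens(prediction) == _tokens(reference)


def token_f1(prediction, reference):
    """Bag-of-tokens F1 between a generated answer and the reference."""
    pred, ref = _tokens(prediction), _tokens(reference)
    if not pred or not ref:
        return float(pred == ref)
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def _mean(values):
    values = list(values)
    return sum(values) / len(values) if values else 0.0


def summarize(probes):
    """Plasticity, stability, and Token F1 for one period's evaluation."""
    changed = [p for p in probes if p["changed"]]
    unchanged = [p for p in probes if not p["changed"]]
    return {
        "plasticity": _mean(exact_match(p["prediction"], p["reference"]) for p in changed),
        "stability": _mean(exact_match(p["prediction"], p["reference"]) for p in unchanged),
        "token_f1": _mean(token_f1(p["prediction"], p["reference"]) for p in probes),
    }


def average_over_periods(by_period, excluded=(2018,)):
    """Average per-period summaries, skipping periods with no earlier period."""
    kept = [summarize(probes) for period, probes in sorted(by_period.items())
            if period not in excluded]
    return {metric: _mean(s[metric] for s in kept)
            for metric in ("plasticity", "stability", "token_f1")}


by_period = {
    2018: [{"changed": True, "prediction": "x", "reference": "y"}],  # excluded
    2020: [
        {"changed": True, "prediction": "Ada Renn", "reference": "Ada Renn"},
        {"changed": False, "prediction": "Corvid Hall", "reference": "Marrow Hall"},
    ],
}
result = average_over_periods(by_period)
print(result)  # {'plasticity': 1.0, 'stability': 0.0, 'token_f1': 0.75}
```

The unchanged probe above fails exact match but still earns partial Token F1 (0.5) for the shared token, which is why Token F1 and the two accuracy-style metrics can diverge.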

Synthetic TemporalWiki

The original TemporalWiki corpus is large enough to be expensive for repeated Llama-scale continual-learning experiments. Synthetic TemporalWiki was built as a compute-tractable diagnostic benchmark with complete control over factual change. It contains fictional entities, unusual relation types, four temporal periods (2018, 2020, 2022, 2024), changed and unchanged probes, and passages that state each target fact exactly once. Because the entities are fictional, the benchmark reduces leakage from pretraining and tests whether the model actually absorbs and preserves new temporal facts.
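The benchmark's structure can be modeled with a simple fact record. The sketch below is hypothetical: the field names and the classification rule are assumptions, not the schema used in `data/`. Here a fact is "changed" in a period if its object differs from the previous period's object for the same subject and relation, and "unchanged" if it is the same; the passage check is only a rough string-count proxy for "states each target fact exactly once".

```python
from dataclasses import dataclass

PERIODS = (2018, 2020, 2022, 2024)


@dataclass(frozen=True)
class TemporalFact:
    subject: str
    relation: str
    obj: str
    period: int


def changed_and_unchanged(facts, period):
    """Split `period`'s facts by comparison with the previous period."""
    idx = PERIODS.index(period)
    if idx == 0:
        return [], []  # 2018 has no earlier period to compare against
    previous = {(f.subject, f.relation): f.obj
                for f in facts if f.period == PERIODS[idx - 1]}
    changed, unchanged = [], []
    for f in facts:
        key = (f.subject, f.relation)
        if f.period != period or key not in previous:
            continue  # other periods, or facts with no earlier version
        (unchanged if previous[key] == f.obj else changed).append(f)
    return changed, unchanged


def states_exactly_once(passage, fact):
    """Crude proxy: the fact's object string appears exactly once."""
    return passage.count(fact.obj) == 1


# Fictional entities, in the spirit of the benchmark.
facts = (
    [TemporalFact("Velmora Council", "chair", "Ada Renn", y) for y in (2018, 2020)]
    + [TemporalFact("Velmora Council", "chair", "Tomas Quill", y) for y in (2022, 2024)]
    + [TemporalFact("Orrin Array", "maintainer", "Kesh Lab", y) for y in PERIODS]
)
changed, unchanged = changed_and_unchanged(facts, 2022)
print([f.obj for f in changed], [f.obj for f in unchanged])  # ['Tomas Quill'] ['Kesh Lab']

passage = "In 2022, Tomas Quill became chair of the Velmora Council."
print(states_exactly_once(passage, changed[0]))  # True
```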

Repository Layout

```
.
|-- main.py                         # supported training entrypoint
|-- run_job.sh                      # supported SLURM wrapper for main.py
|-- training/                       # runner, trainer, configs, CASM/SMF wrappers, routers
|-- casf_dataset_api/               # dataset abstraction, probes, memory registry, contradiction logic
|-- artifacts/                      # checkpoint manifests, run metadata, atomic checkpoint helpers
|-- eval_and_metrics/               # evaluation and reporting utilities
|-- dataset_utils/                  # augmented-data generation and export helpers
|-- data/                           # synthetic facts, probes, passages, and benchmark build scripts
|-- scripts/carc/                   # CARC training runner and SLURM wrapper
|-- scripts/experiments/            # historical/debug experiment runners and job scripts
|-- notebooks/                      # Colab notebooks and historical notebook runs
|-- experiments/legacy/             # legacy 3B training script kept outside the supported path
|-- docs/                           # paper text, architecture notes, testing docs, assets, archives
|-- tests/                          # unit, contract, integration, and smoke tests
|-- pyproject.toml                  # uv-managed Python project metadata
`-- uv.lock                         # locked dependency graph
```

The root intentionally keeps main.py and run_job.sh visible because they are the tested supported launch path. Older one-off experiment runners remain at root when moving them would change their relative execution contract.
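`artifacts/` is described as holding atomic checkpoint helpers. A common way to make checkpoint writes atomic is write-to-a-temporary-file-then-rename, sketched below with the standard library. The function names, file layout, and manifest fields are hypothetical and are not this repository's checkpoint contract; real checkpoints would be serialized tensors rather than raw bytes.

```python
import hashlib
import json
import os
import tempfile
from pathlib import Path


def atomic_write_bytes(path: Path, data: bytes) -> None:
    """Write `data` to `path` so readers see either the old file or the new one."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Temp file in the same directory, so the rename stays on one filesystem.
    fd, tmp = tempfile.mkstemp(dir=str(path.parent), prefix=path.name + ".", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # push bytes to disk before the rename
        os.replace(tmp, path)  # atomic rename; directory fsync omitted for brevity
    except BaseException:
        if os.path.exists(tmp):
            os.unlink(tmp)
        raise


def save_checkpoint(ckpt_dir: Path, run_id: str, period: int, state: bytes) -> Path:
    """Hypothetical helper: write the payload first, then a manifest pointing at it."""
    weights = ckpt_dir / f"period_{period}.bin"
    atomic_write_bytes(weights, state)
    manifest = {
        "run_id": run_id,
        "period": period,
        "file": weights.name,
        "sha256": hashlib.sha256(state).hexdigest(),
    }
    # Manifest is written last: if it exists, the payload it names is complete.
    atomic_write_bytes(ckpt_dir / f"period_{period}.manifest.json",
                       json.dumps(manifest, indent=2).encode("utf-8"))
    return weights


with tempfile.TemporaryDirectory() as d:
    save_checkpoint(Path(d), "local-smoke", 2020, b"fake-weights")
    print(sorted(p.name for p in Path(d).iterdir()))
    # ['period_2020.bin', 'period_2020.manifest.json']
```

Writing the manifest last means a job killed mid-save leaves either no manifest or a manifest whose payload and checksum are already on disk, so a resume step can trust any manifest it finds.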

Setup

This project uses uv for dependency management.

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync --group dev
```

For real Llama runs, authenticate with HuggingFace and request access to the gated Meta Llama 3.2 checkpoints:

```bash
huggingface-cli login
```

For Synthetic TemporalWiki generation with Gemini, create a .env file with the Google API key expected by google-genai:

```
GOOGLE_API_KEY=...
```

Local Smoke Test

The fast, supported offline path uses a synthetic tokenizer/model/dataset and does not require HuggingFace downloads.

```bash
uv run pytest -q
uv run python main.py --mode synthetic --run-id local-smoke --checkpoint-dir /tmp/casm-smoke
```

The same supported path is exercised through the SLURM wrapper in CI-style smoke tests:

```bash
CONTINUAL_LEARNING_MODE=synthetic \
CONTINUAL_LEARNING_SKIP_MODULES=1 \
CONTINUAL_LEARNING_SKIP_VENV=1 \
CONTINUAL_LEARNING_CHECKPOINT_DIR=/tmp/casm-smoke \
bash run_job.sh
```

Running Experiments

CARC / SLURM

Use the shared CARC runner for method comparisons:

```bash
PYTHONPATH=. python scripts/carc/train_carc.py --method full_ft --model meta-llama/Llama-3.2-3B-Instruct
PYTHONPATH=. python scripts/carc/train_carc.py --method lora    --model meta-llama/Llama-3.2-3B-Instruct
PYTHONPATH=. python scripts/carc/train_carc.py --method smf     --model meta-llama/Llama-3.2-3B-Instruct
PYTHONPATH=. python scripts/carc/train_carc.py --method casm    --model meta-llama/Llama-3.2-3B-Instruct
```

On USC CARC, submit the wrapper:

```bash
sbatch scripts/carc/train_carc_job.sh full_ft
sbatch scripts/carc/train_carc_job.sh lora
sbatch scripts/carc/train_carc_job.sh smf
sbatch scripts/carc/train_carc_job.sh casm
```
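For a method sweep, the commands above can be generated in a loop. This hypothetical helper only reuses the script paths, flags, and model ID shown in this section; it defaults to a dry run that prints the commands instead of executing them.

```python
import os
import shlex
import subprocess

METHODS = ("full_ft", "lora", "smf", "casm")
MODEL = "meta-llama/Llama-3.2-3B-Instruct"


def carc_commands(methods=METHODS, model=MODEL, use_slurm=True):
    """Build one command per method, mirroring the README commands."""
    if use_slurm:
        return [["sbatch", "scripts/carc/train_carc_job.sh", m] for m in methods]
    return [
        ["python", "scripts/carc/train_carc.py", "--method", m, "--model", model]
        for m in methods
    ]


def launch(commands, dry_run=True):
    """Print each command; only execute it when dry_run is False."""
    env = {**os.environ, "PYTHONPATH": "."}  # direct runs expect PYTHONPATH=.
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True, env=env)


launch(carc_commands())  # dry run: prints the four sbatch lines
```

Passing `use_slurm=False` produces the direct `train_carc.py` invocations instead, which is closer to what an interactive node would run.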

Colab

Use the notebooks in notebooks/ for interactive 1B-scale runs:

notebooks/train_colab.ipynb
notebooks/train_colab_synthetic.ipynb

Colab is suitable for shorter 1B experiments on A100 runtimes. CARC is the intended path for longer 3B runs, larger checkpoint directories, and repeated method sweeps.

Historical Step Runners

Several scripts under scripts/experiments/ preserve the exact experiment sequence used during development, including pretrain_period1_1b.py, run_step2_full_ft_1b.py, run_step2_lora_1b.py, run_step2_smf_1b.py, run_step2_casm_1b.py, and run_step3_eval_1b.py. These scripts contain hard-coded scratch paths from the original runs and are best treated as reproducibility records unless adapted to your environment.

Reproducing Synthetic Data

Generate the controlled fictional facts, build training passages, and build evaluation probes:

```bash
uv run python data/generate_synthetic.py --output data/synthetic_facts_raw.json
uv run python data/build_passages.py --facts data/synthetic_facts_raw.json --output data/passages.json
uv run python data/build_probes.py --facts data/synthetic_facts_raw.json --output data/probes.json
```
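Because `generate_synthetic.py` calls the Gemini API, it can be worth rerunning only the steps whose outputs are missing or older than their inputs. The hypothetical Make-style runner below encodes the three commands above with their dependency (passages and probes both depend on the raw facts file); it is a convenience sketch, not a script shipped in `data/`.

```python
import subprocess
from pathlib import Path

FACTS = "data/synthetic_facts_raw.json"
PASSAGES = "data/passages.json"
PROBES = "data/probes.json"

# (output, inputs, command) in dependency order; commands copied from the README.
STEPS = [
    (FACTS, [], ["uv", "run", "python", "data/generate_synthetic.py", "--output", FACTS]),
    (PASSAGES, [FACTS], ["uv", "run", "python", "data/build_passages.py",
                         "--facts", FACTS, "--output", PASSAGES]),
    (PROBES, [FACTS], ["uv", "run", "python", "data/build_probes.py",
                       "--facts", FACTS, "--output", PROBES]),
]


def is_stale(output, inputs, root):
    """An output is stale if it or any input is missing, or an input is newer."""
    out = root / output
    if not out.exists() or any(not (root / i).exists() for i in inputs):
        return True
    return any((root / i).stat().st_mtime > out.stat().st_mtime for i in inputs)


def plan(root=Path(".")):
    """Return the commands to run; rebuilding an input forces its dependents."""
    rebuilt, commands = set(), []
    for output, inputs, cmd in STEPS:
        if rebuilt.intersection(inputs) or is_stale(output, inputs, root):
            commands.append(cmd)
            rebuilt.add(output)
    return commands


def run(root=Path("."), dry_run=True):
    for cmd in plan(root):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True, cwd=root)
```

With a fresh checkout `plan()` returns all three steps; after editing only the raw facts file it returns just the passage and probe builds, so no Gemini calls are made.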

Optional helpers:

```bash
uv run python data/redistribute_changes.py
uv run python data/build_augmentation_prompts.py
uv run python dataset_utils/generate_dataset.py --prompts-dir dataset_utils/prompts --outdir data/augmented/TWiki_Diffsets
uv run python dataset_utils/export_eval_prompts.py
```

The Gemini-based generation scripts require API access and may incur cost. Use --dry-run on data/generate_synthetic.py to inspect prompts before making API calls.

Testing

Required local checks:

```bash
uv run pytest tests/unit -q
uv run pytest tests/contracts -q
uv run pytest tests/integration -q
uv run pytest tests/smoke -q
```

See docs/testing.md and ROADMAP.md for the merge contract. Behavioral changes require tests at the right layer; organization-only and documentation changes should still preserve the supported main.py plus run_job.sh path.

Compute And Acknowledgments

Experiments were implemented in PyTorch with HuggingFace Transformers and PEFT. The reported 1B experiments used Google Colab A100 resources, consuming roughly 200 A100 GPU-hours. The 3B experiments used the University of Southern California Center for Advanced Research Computing (USC CARC) cluster, consuming roughly 300 A100 GPU-hours. We acknowledge USC CARC for providing the compute resources used for the 3B continual-learning runs.
