A pre-training pipeline for RoBERTa masked language models on custom datasets, built with PyTorch and HuggingFace Transformers. Phi-3 causal language model training is also supported.
## Features

- Tokenize HuggingFace datasets with sliding-window stride support
- Pre-train RoBERTa (MLM) or Phi-3 (causal LM) from scratch
- Custom data collator that excludes special tokens from masking
- Weights & Biases integration (optional)
- Checkpoint resume support
- Ready-to-use SLURM scripts for HPC clusters
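The collator's core rule — special tokens are never masking candidates — can be sketched in plain Python. This is a simplified illustration of the idea, not the code in `custom_data_collator.py`; the helper name is hypothetical:

```python
import random

MASK_ID = 50264  # <mask> token id in the roberta-base vocabulary

def select_mask_positions(special_tokens_mask, masking_prob=0.15, rng=None):
    """Pick ~masking_prob of the non-special positions to mask."""
    rng = rng or random.Random(0)
    # Only positions where special_tokens_mask is 0 are candidates
    candidates = [i for i, is_special in enumerate(special_tokens_mask) if not is_special]
    n_mask = max(1, round(len(candidates) * masking_prob))
    return sorted(rng.sample(candidates, n_mask))

input_ids = [0, 11, 12, 13, 14, 15, 2]  # 0 and 2 are RoBERTa's <s>/</s> ids
special   = [1, 0, 0, 0, 0, 0, 1]       # so positions 0 and 6 are never chosen
positions = select_mask_positions(special, masking_prob=0.4)
masked = [MASK_ID if i in positions else t for i, t in enumerate(input_ids)]
```

A full MLM collator (such as HuggingFace's `DataCollatorForLanguageModeling`) additionally replaces 80% of the selected positions with the mask token, 10% with random tokens, and leaves 10% unchanged; the point here is only the candidate filter.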
## Requirements

- Python 3.11
- UV package manager
## Installation

```bash
git clone https://github.com/scilons/roberta-pretrain.git
cd roberta-pretrain
uv venv --python 3.11
source .venv/bin/activate
uv sync
```

## Dataset tokenization

Tokenize a HuggingFace dataset and save it to disk:
```bash
python -m roberta_pretrain.commands.dataset \
    --dataset-with-range "dataset_name:0-100" \
    --tokenizer "FacebookAI/roberta-base" \
    --output-dir ./dataset \
    --max-length 510
```

Arguments:
- `--dataset-with-range` — HuggingFace dataset name with percentage range (e.g. `"my-org/my-dataset:0-100"` for 100%, `"my-org/my-dataset:0-50"` for the first half)
- `--tokenizer` — HuggingFace tokenizer name or local path
- `--output-dir` — Directory to write the tokenized dataset
- `--max-length` — Maximum token sequence length. Must be at least 2 less than `max_position_embeddings` in your training config
- `--stride` — (optional, default 0) Sliding window stride for long documents
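The `--stride` option follows the HuggingFace tokenizer convention: consecutive windows over a long document overlap by `stride` tokens. A minimal sketch of that windowing (not the actual implementation in `dataset.py`):

```python
def chunk_with_stride(token_ids, max_length, stride=0):
    """Split token_ids into windows of max_length, overlapping by `stride` tokens."""
    if stride >= max_length:
        raise ValueError("stride must be smaller than max_length")
    step = max_length - stride
    chunks = []
    for start in range(0, len(token_ids), step):
        chunks.append(token_ids[start:start + max_length])
        if start + max_length >= len(token_ids):
            break  # last window already covers the tail
    return chunks

# 10 tokens, windows of 4, overlapping by 2:
windows = chunk_with_stride(list(range(10)), max_length=4, stride=2)
# -> [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```

With `stride=0` (the default) the windows are disjoint; a positive stride repeats context at window boundaries so no token loses its neighbours.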
## Training

```bash
python -m roberta_pretrain.commands.train \
    --seed 42 \
    --config configs/training.toml \
    --tokenizer "FacebookAI/roberta-base" \
    --datasets ./dataset \
    --checkpoint ./checkpoints \
    --output ./model
```

Arguments:
- `--seed` — Random seed for reproducibility
- `--config` — Path to a TOML training config (see Configuration)
- `--tokenizer` — HuggingFace tokenizer name or local path
- `--datasets` — One or more paths to tokenized datasets (accepts multiple paths)
- `--checkpoint` — Directory for training checkpoints
- `--output` — Directory to save the final model
- `--masking-prob` — (optional, default 0.15) MLM masking probability
- `--resume` — (optional) Resume training from the latest checkpoint
- `--is-phi3` — (optional) Train a Phi-3 causal LM instead of RoBERTa
- `--wandb-project` — (optional) Weights & Biases project name (enables W&B logging)
- `--wandb-entity` — (optional) Weights & Biases entity/team name
Use the 3-layer debug config to verify everything works before launching a full run:
```bash
python -m roberta_pretrain.commands.train \
    --seed 42 \
    --config configs/3-layer-roberta.toml \
    --tokenizer "FacebookAI/roberta-base" \
    --datasets ./dataset \
    --checkpoint ./checkpoints \
    --output ./model
```

## Configuration

Training is configured via TOML files with two sections:
```toml
[model]
max_position_embeddings = 514  # must be >= max_length + 2

[training]
per_device_train_batch_size = 256
num_train_epochs = 1
max_steps = 240000
warmup_steps = 2400
learning_rate = 1e-4
weight_decay = 0.1
# ... see configs/training.toml for all options
```

Available configs:
- `configs/training.toml` — Production RoBERTa config (based on the original paper)
- `configs/3-layer-roberta.toml` — Small 3-layer model for debugging
- `configs/phi3/03B.toml` — Phi-3 3B config
- `configs/phi3/debug.toml` — Small Phi-3 for debugging
## SLURM scripts

The `slurms/` directory contains example job scripts for running on HPC clusters. They were written for the DFKI Pegasus cluster using enroot containers, but can be adapted to any SLURM environment.
Dataset tokenization:
```bash
sbatch slurms/202507/datasets/512_roberta_base.slurm
```

Training:

```bash
sbatch slurms/202507/train/512/train.slurm
```

To adapt the SLURM scripts to your cluster:
- Update `--container-image` to your PyTorch container image (or remove the container options if not using enroot)
- Update `--partition` to match your cluster's available partitions
- Adjust `--container-workdir` and data paths to your environment
- Set your `HF_TOKEN` and `WANDB_API_KEY` environment variables as needed
## Development

Lint, format, and type-check:

```bash
ruff check .
black .
mypy .
```

Run the tests:

```bash
pytest
pytest tests/test_foo.py -k "test_name"  # run a single test
```

## Project structure

```
src/roberta_pretrain/
├── commands/
│   ├── dataset.py              # CLI: tokenize datasets
│   └── train.py                # CLI: run pre-training
├── custom_data_collator.py     # MLM collator excluding special tokens
├── dataset.py                  # Dataset loading and tokenization logic
├── train.py                    # RobertaTrainer and Phi3Trainer classes
└── utils.py                    # Logging utilities
configs/                        # TOML training configurations
slurms/                         # Example SLURM job scripts
```
## License

Apache License 2.0 — see LICENSE for details.