This repository extends ReactEMG to study few-shot adaptation of healthy-pretrained sEMG models to stroke survivors. It provides a systematic experimental framework for comparing fine-tuning strategies when adapting a model trained on healthy subjects to stroke participants with limited calibration data.
Prerequisites: This repository builds on ReactEMG. For installation instructions, model architecture details, and background on the Any2Any transformer, see the ReactEMG README.
The stroke dataset covers three participants:
- s1 (data folder: `data/s1`)
- s2 (data folder: `data/s2`)
- s3 (data folder: `data/s3`)
Each participant's data is organized into calibration and test sets:
```text
data/s1/                                                 # s2/ and s3/ share this layout
├── s1_open_1.csv, s1_close_1.csv                        ┐
├── s1_open_2.csv, s1_close_2.csv                        │ Calibration pool
├── s1_open_3.csv, s1_close_3.csv                        │ (4 baseline sets × 3 reps each = 12 paired reps)
├── s1_open_4.csv, s1_close_4.csv                        ┘
│
├── s1_open_5.csv, s1_close_5.csv                        # mid_session_baseline
├── s1_open_fatigue.csv, s1_close_fatigue.csv            # end_session_baseline
├── s1_open_hovering.csv, s1_close_hovering.csv          # unseen_posture
├── s1_open_sensor_shift.csv, s1_close_sensor_shift.csv  # sensor_shift
└── s1_close_from_open.csv                               # orthosis_actuated
```
Calibration Pool: 12 paired repetitions (g_0 through g_11) extracted from 4 baseline sets, used for training/validation splits.
Test Conditions (5 types):
| Condition | Description |
|---|---|
| `mid_session_baseline` | Mid-session recordings (`open_5`, `close_5`) |
| `end_session_baseline` | Post-fatigue recordings (`open_fatigue`, `close_fatigue`) |
| `unseen_posture` | Arm hovering posture (`open_hovering`, `close_hovering`) |
| `sensor_shift` | After sensor repositioning (`open_sensor_shift`, `close_sensor_shift`) |
| `orthosis_actuated` | Orthosis-driven close motion (`close_from_open`) |
This repository compares 5 adaptation strategies:
| Strategy | Description | Command Flags |
|---|---|---|
| Zero-shot | Frozen pretrained model (baseline) | No training |
| Stroke-only | Train from scratch on stroke data | No --saved_checkpoint_pth |
| Head-only | Freeze backbone, train classification head | --freeze_backbone 1 |
| LoRA | Low-rank adaptation of linear layers | --use_lora 1 |
| Full fine-tune | Update all parameters | Default behavior |
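To make the difference between these strategies concrete, the sketch below counts trainable parameters for a toy stack of linear layers under each one. The layer sizes, the 3-class head, the LoRA rank of 8, and the choice to attach a LoRA adapter to every linear layer are illustrative assumptions, not this repository's configuration; `LinearLayer` and `trainable_params` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearLayer:
    """A hypothetical linear layer, described only by its shape."""
    d_in: int
    d_out: int
    is_head: bool = False  # True for classification-head layers

    def dense_params(self) -> int:
        # Weight matrix plus bias vector.
        return self.d_in * self.d_out + self.d_out

    def lora_params(self, rank: int) -> int:
        # LoRA keeps W frozen and learns W + B @ A,
        # with B of shape (d_out, rank) and A of shape (rank, d_in).
        return rank * (self.d_in + self.d_out)


def trainable_params(layers: list[LinearLayer], strategy: str, rank: int = 8) -> int:
    """Count trainable parameters per adaptation strategy (illustrative only)."""
    if strategy == "zero_shot":
        return 0  # frozen pretrained model, nothing is trained
    if strategy in ("stroke_only", "full_finetune"):
        # Both update every parameter; they differ only in initialization
        # (random for stroke-only, the healthy checkpoint for full fine-tune).
        return sum(layer.dense_params() for layer in layers)
    if strategy == "head_only":
        return sum(layer.dense_params() for layer in layers if layer.is_head)
    if strategy == "lora":
        # Assumption: a low-rank adapter on every linear layer.
        return sum(layer.lora_params(rank) for layer in layers)
    raise ValueError(f"unknown strategy: {strategy}")


if __name__ == "__main__":
    # Hypothetical sizes: one 256x256 backbone layer and a 3-class head.
    toy_model = [LinearLayer(256, 256), LinearLayer(256, 3, is_head=True)]
    for name in ("zero_shot", "stroke_only", "head_only", "lora", "full_finetune"):
        print(f"{name:>13}: {trainable_params(toy_model, name):,}")
```

In this toy setting LoRA trains about a tenth as many parameters as full fine-tuning and head-only about 1%; the real ratios depend on the Any2Any architecture described in the ReactEMG README.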
The experiments follow a three-stage pipeline:
```text
EXPERIMENTAL PIPELINE

Stage 1: ZERO-SHOT BASELINE
└── Evaluate pretrained healthy model directly on stroke test sets

Stage 2: HYPERPARAMETER SEARCH + TRAINING
├── 4-fold CV across calibration pool
├── Search: 27 configs (3 LRs × 3 epochs × 3 dropouts)
├── Select best config per variant (primary: transition accuracy)
└── Train final model on full calibration pool

Stage 3: EVALUATION
└── Test all models on 5 test conditions with latency metrics
```
All commands assume you're in the reactemg/ directory.
The stroke dataset and the healthy-pretrained checkpoint do not live in this repo; set both up before running:
- Download the stroke dataset. It is not bundled with this repository. Unpack it at the repository root so the three subject folders land at `data/s1`, `data/s2`, and `data/s3` (the layout shown above). The run scripts resolve `PARTICIPANTS` from this location automatically, so there are no data paths to edit in the code. Download and unpack it from the repository root:

  ```bash
  curl -L -o data.zip "https://www.dropbox.com/scl/fi/jp28src297h9ndddkz97i/data.zip?rlkey=usa9qunoo22ir35rh1ui677mh&dl=1"
  unzip data.zip && rm data.zip  # the archive contains a top-level data/ -> data/s1, data/s2, data/s3
  ```
- Set the healthy-pretrained checkpoint. The healthy-pretrained Any2Any model that every fine-tuned variant adapts from (produced by base ReactEMG, e.g. a healthy leave-one-subject-out (LOSO) run) is not in this repo either. In `run_main_experiment.py`, `run_data_efficiency.py`, and `run_convergence.py`, replace the `PRETRAINED_CHECKPOINT` placeholder (marked `TODO` near the top of each file) with the path to your `.pth` checkpoint.
- Disable or configure Weights & Biases. Every training run calls `wandb.init`. To reproduce without a W&B account:

  ```bash
  export WANDB_MODE=disabled
  ```
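Before launching any runs, it can help to confirm that the dataset landed where the scripts expect it. The sketch below checks each subject folder for the calibration and test CSVs listed in the layout above and warns when `WANDB_MODE` is not `disabled`. It is not part of the repository: `expected_files` and `check_layout` are hypothetical helpers, and they assume every subject follows the `s1_open_1.csv` naming pattern shown for `s1`.

```python
import os
from pathlib import Path

# Calibration set indices and test files from the data layout above ("{p}" = participant ID).
CALIBRATION_SETS = ["1", "2", "3", "4"]
TEST_CONDITIONS = {
    "mid_session_baseline": ["{p}_open_5.csv", "{p}_close_5.csv"],
    "end_session_baseline": ["{p}_open_fatigue.csv", "{p}_close_fatigue.csv"],
    "unseen_posture": ["{p}_open_hovering.csv", "{p}_close_hovering.csv"],
    "sensor_shift": ["{p}_open_sensor_shift.csv", "{p}_close_sensor_shift.csv"],
    "orthosis_actuated": ["{p}_close_from_open.csv"],
}


def expected_files(participant: str) -> list[str]:
    """All CSV names the layout above lists for one participant."""
    names = [
        f"{participant}_{gesture}_{i}.csv"
        for i in CALIBRATION_SETS
        for gesture in ("open", "close")
    ]
    for templates in TEST_CONDITIONS.values():
        names.extend(t.format(p=participant) for t in templates)
    return names


def check_layout(root: Path, participants=("s1", "s2", "s3")) -> list[Path]:
    """Return the expected CSV paths that are missing under root/data/."""
    missing = []
    for p in participants:
        for name in expected_files(p):
            path = root / "data" / p / name
            if not path.is_file():
                missing.append(path)
    return missing


if __name__ == "__main__":
    # Run from the repository root, where data/ was unpacked.
    gaps = check_layout(Path("."))
    print("data layout OK" if not gaps else f"{len(gaps)} missing file(s), e.g. {gaps[0]}")
    if os.environ.get("WANDB_MODE") != "disabled":
        print("note: WANDB_MODE is not 'disabled'; runs will try to log to W&B")
```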
Recommended order for a full from-scratch reproduction. Steps 2–3 reuse the
per-variant CV configs that step 1 writes to temp_cv_checkpoints/, so run step 1 first:
| Step | Command | Produces |
|---|---|---|
| 1 | `python3 run_main_experiment.py --participant all` | Table 2 results and the per-variant CV configs in `temp_cv_checkpoints/` (reused below) |
| 2 | `run_data_efficiency.py` for all variants (see §2) | Data-efficiency results |
| 3 | `run_convergence.py` for all variants (see §3) | Convergence results |
| 4 | Analysis scripts (see §4) | The paper's tables and figures, built from the JSON under `results/` |
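Steps 1 to 3 can also be chained from one Python driver. The sketch below is a hypothetical convenience wrapper, not a script shipped with this repository: it builds the commands from the table (step 1 first, because steps 2 and 3 read its CV configs), followed by the data-efficiency and convergence loops, and only prints them unless `dry_run=False`. Step 4 is left out because which analysis commands you need depends on the figures you want.

```python
import subprocess

VARIANTS = ["stroke_only", "head_only", "lora", "full_finetune"]


def build_commands(participant: str = "all") -> list[list[str]]:
    """Steps 1-3 in the required order, as argv lists."""
    # Step 1 writes the per-variant CV configs that steps 2 and 3 reuse.
    commands = [["python3", "run_main_experiment.py", "--participant", participant]]
    for script in ("run_data_efficiency.py", "run_convergence.py"):
        for variant in VARIANTS:
            commands.append(
                ["python3", script, "--participant", participant, "--variant", variant]
            )
    return commands


def run_all(participant: str = "all", dry_run: bool = True) -> None:
    for argv in build_commands(participant):
        print(" ".join(argv))
        if not dry_run:
            # check=True raises CalledProcessError and stops the chain if a step fails.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```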
The main experiment script orchestrates zero-shot evaluation, hyperparameter search, final training, and evaluation for all strategies.
```bash
python3 run_main_experiment.py --participant all
```

This orchestrates:
- Zero-shot evaluation on stroke data
- 4-fold CV hyperparameter search per trained strategy: 27 configs (LR {5e-5, 1e-4, 5e-4} × epochs {5, 10, 15} × dropout {0, 0.1, 0.2}) × 4 folds = 108 runs
- Final training with the best hyperparameters (the selected config is saved to `temp_cv_checkpoints/{participant}_{variant}_cv_results.json`)
- Evaluation on all 5 test conditions
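The search space above is small enough to enumerate directly. The sketch below reproduces the run count (27 configs × 4 folds = 108) and shows one plausible way to form folds and pick a winner. Two details are assumptions rather than documented behavior: that each fold holds out one baseline set (its 3 paired repetitions, with g_0 to g_2 coming from set 1, and so on), and that the winner is the config with the highest mean transition accuracy across folds. `score_fn` is a hypothetical stand-in for an actual train-and-evaluate run.

```python
import itertools
from statistics import mean

LEARNING_RATES = [5e-5, 1e-4, 5e-4]
EPOCHS = [5, 10, 15]
DROPOUTS = [0.0, 0.1, 0.2]
REPS_PER_SET = 3
NUM_SETS = 4  # baseline sets 1-4 -> repetitions g_0 .. g_11


def configs():
    """All 27 (lr, epochs, dropout) combinations."""
    return list(itertools.product(LEARNING_RATES, EPOCHS, DROPOUTS))


def folds():
    """Assumed split: fold k validates on baseline set k+1 and trains on the rest."""
    reps = list(range(NUM_SETS * REPS_PER_SET))
    for k in range(NUM_SETS):
        val = reps[k * REPS_PER_SET:(k + 1) * REPS_PER_SET]
        train = [g for g in reps if g not in val]
        yield train, val


def select_best(score_fn):
    """Return the config with the highest mean fold score (e.g. transition accuracy)."""
    return max(
        configs(),
        key=lambda cfg: mean(score_fn(cfg, train, val) for train, val in folds()),
    )


if __name__ == "__main__":
    print(len(configs()) * len(list(folds())), "runs")  # 27 configs x 4 folds
```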
Evaluates performance with limited calibration data (K = 1, 4, 8 paired repetitions, 12 trials per K). Each run reuses the CV config the main experiment wrote for that variant. Run all subjects and all variants:
```bash
for v in stroke_only head_only lora full_finetune; do
  python3 run_data_efficiency.py --participant all --variant "$v"
done
```

Sampling: K=1 uses one distinct repetition per trial (trial i uses g_i); K>1 samples K of the 12 repetitions without replacement per trial.
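The sampling rule can be written down in a few lines. In this sketch the per-trial seed (`seed + trial`) is a hypothetical choice made for reproducibility; the repository may seed its sampler differently, so the exact subsets drawn here will not match its runs. Only the structure (trial i uses g_i when K=1, and K distinct repetitions per trial otherwise) follows the description above.

```python
import random

NUM_REPS = 12  # calibration pool g_0 .. g_11
NUM_TRIALS = 12  # trials per K


def sample_trials(k: int, seed: int = 0) -> list[list[int]]:
    """Repetition indices used for each of the 12 trials at a given K."""
    trials = []
    for trial in range(NUM_TRIALS):
        if k == 1:
            # K=1: trial i uses g_i, so every repetition is used exactly once.
            trials.append([trial])
        else:
            # K>1: draw K distinct repetitions (no replacement within a trial).
            rng = random.Random(seed + trial)
            trials.append(sorted(rng.sample(range(NUM_REPS), k)))
    return trials


if __name__ == "__main__":
    for k in (1, 4, 8):
        print(f"K={k}: first trial uses {sample_trials(k)[0]}")
```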
Trains for a fixed 100 epochs (far beyond the CV-selected optimum, since the search tops out at 15 epochs) and evaluates on the stroke test sets every 5 epochs (21 checkpoints, counting epoch 0) to track learning dynamics. Run all subjects and all variants:
```bash
for v in stroke_only head_only lora full_finetune; do
  python3 run_convergence.py --participant all --variant "$v"
done
```

With `results/` populated, these scripts produce the paper's tables and figures (subject mapping s1 = S1, s2 = S2, s3 = S3):
```bash
python3 extract_results.py                                              # Table 2 -> results/main_experiment/table2.txt
python3 plot_main_results.py                                            # Table 2 bars -> results/main_experiment/table2_bars.png
python3 analyze_data_efficiency.py --compare -o figure_dataeff.png      # data-efficiency figure
python3 analyze_convergence.py --combined -p s2 -o figure_conv_s2.png   # convergence figure (per subject; s2 = S2)
```

- The `--compare` and `--combined` figures require the corresponding experiment to have been run for every overlaid variant (data efficiency defaults to `head_only lora full_finetune`; convergence needs all four).
- Pass `--variant <v> --participant <p>` to either `analyze_*` script for a single numeric summary, or `-o <path>` to set the output file.
All stroke experiments use these evaluation settings for consistency:
| Parameter | Value |
|---|---|
| `buffer_range` | 800 |
| `lookahead` | 100 |
| `samples_between_prediction` | 100 |
| `allow_relax` | 1 |
| `stride` | 1 |
| `likelihood_format` | `logits` |
| `maj_vote_range` | `future` |
Refer to the ReactEMG README for how these parameters shape the online smoothing behavior and the transition-accuracy metric.
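Keeping these values in one place makes it harder for one script to drift from the others. The frozen dataclass below is an illustrative sketch, not code from this repository or ReactEMG. In particular, `StrokeEvalSettings` is a hypothetical name, and the `to_cli_args` helper assumes the evaluation entry point accepts `--<name> <value>` flags named after each parameter; confirm that against the ReactEMG README before relying on it.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class StrokeEvalSettings:
    """Shared online-evaluation settings from the table above."""
    buffer_range: int = 800
    lookahead: int = 100
    samples_between_prediction: int = 100
    allow_relax: int = 1
    stride: int = 1
    likelihood_format: str = "logits"
    maj_vote_range: str = "future"

    def to_cli_args(self) -> list[str]:
        # Hypothetical flag format: --<field name> <value>, in field order.
        args = []
        for name, value in asdict(self).items():
            args += [f"--{name}", str(value)]
        return args


EVAL_SETTINGS = StrokeEvalSettings()

if __name__ == "__main__":
    print(" ".join(EVAL_SETTINGS.to_cli_args()))
```

Making the dataclass frozen means a script cannot silently overwrite a setting at runtime; any deviation has to be an explicit new instance.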
For questions or support, please email Runsheng at runsheng.w@columbia.edu
This project is released under the MIT License; see the LICENSE file for details.