
Add SoftmaxMultiClass objective and multiclass APIs#117

Open
saulhs12 wants to merge 11 commits into
jinlow:mainfrom
saulhs12:softmax-multiclass

@saulhs12 saulhs12 commented Apr 1, 2026

Summary

  • add native SoftmaxMultiClass support for K-class classification in the Rust core
  • expose multiclass support end to end in the Python wrapper, including num_classes and predict_proba
  • add multiclass validation, XGBoost cross-validation fixtures, stress coverage, and API semantics fixes for prediction iteration

Details

This PR adds a softmax multiclass objective to forust, validates labels and eval sets for multiclass training, and wires the public APIs through both Rust and Python.

It also makes multiclass prediction semantics explicit:

  • predict() returns raw logits with shape (n_samples, K) in Python
  • predict_proba() returns class probabilities with shape (n_samples, K)
  • prediction_iteration now means boosting rounds, not an arbitrary tree count, so multiclass predictions always use complete class rounds
  • unsupported scalar-only APIs such as contributions and partial dependence now fail explicitly for multiclass instead of returning misleading results
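The relationship between the two prediction outputs, and the round-based meaning of prediction_iteration, can be sketched in plain Python. This is only an illustration of the semantics listed above, not the forust code: the helper names are hypothetical, and the assumption that one boosting round contributes exactly K trees comes from the commit notes in this PR.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one row of K raw logits."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def proba_from_logits(raw):
    """Map an (n_samples, K) list of logits to (n_samples, K) probabilities."""
    return [softmax(row) for row in raw]

def trees_used(prediction_iteration, num_classes):
    """Assumption: each boosting round holds one tree per class."""
    return prediction_iteration * num_classes

raw = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]  # stand-in for predict() output
proba = proba_from_logits(raw)             # stand-in for predict_proba() output
```

Each output row sums to 1, and the row of all-zero logits maps to a uniform distribution. Counting in rounds rather than trees means a truncated prediction never stops partway through a class round.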

Validation

  • cargo test
  • cd py-forust && cargo check

Python pytest could not be run in this environment because sklearn was not installed during test collection.

saulhs12 added 7 commits April 1, 2026 13:41
Port XGBoost C++'s multi-class softmax (multiclass_obj.cu, gbtree.cc)
into forust-ml. Adds SoftmaxMultiClass objective type, K trees per
boosting round, N×K strided predictions, predict_proba(), and
MultiClassLogLoss metric.
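For readers unfamiliar with the metric, a minimal sketch of multiclass log loss follows. It assumes the usual definition (weighted mean of the negative log probability assigned to the true class); the function name and the clipping constant are illustrative assumptions, not taken from this PR.

```python
import math

def multiclass_log_loss(proba, labels, weights=None, eps=1e-15):
    """Weighted mean of -log(p[true class]).

    proba:   list of rows, each holding K class probabilities
    labels:  integer class index per row, in [0, K)
    eps:     assumed clipping floor so log(0) never occurs
    """
    if weights is None:
        weights = [1.0] * len(labels)
    total = 0.0
    for row, y, w in zip(proba, labels, weights):
        p = min(max(row[y], eps), 1.0)
        total += -math.log(p) * w
    return total / sum(weights)

loss = multiclass_log_loss([[0.5, 0.5], [0.25, 0.75]], [0, 1])
```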

Key additions:
- SoftmaxMultiClass struct: softmax, calc_grad_hess (2x hessian),
  calc_init (centered log-priors), validate_labels
- fit_trees_multiclass: K-tree-per-round training loop with N×K
  grad/hess, GOSS row sampling via max-abs-grad reduction
- predict() multi-class branch: N×K raw logits
- predict_proba(): softmax over logits → probabilities
- MultiClassLogLoss metric with direct dispatch (bypasses fn-ptr)
- num_classes + base_scores fields with serde defaults for
  backward compatibility
- Guards: num_classes>1 requires SoftmaxMultiClass objective
- 19 new tests (unit + integration), 0 existing tests modified
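
The softmax gradient, the 2x hessian, and the centered log-prior initialization named in the list above can be sketched as follows. This is a hedged reconstruction, not the code in this PR: the function names are hypothetical, and the hessian floor and the prior clamp are assumptions added so the sketch stays finite for empty classes.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def grad_hess(logits, label, weight=1.0, floor=1e-16):
    """Per-class gradient and hessian for one sample.

    grad_k = (p_k - 1[label == k]) * weight
    hess_k = 2 * p_k * (1 - p_k) * weight   (the "2x hessian" convention)
    floor is an assumed lower bound on the hessian.
    """
    p = softmax(logits)
    grad = [(pk - (1.0 if k == label else 0.0)) * weight for k, pk in enumerate(p)]
    hess = [max(2.0 * pk * (1.0 - pk) * weight, floor) for pk in p]
    return grad, hess

def centered_log_priors(labels, num_classes, eps=1e-12):
    """Log class frequencies shifted to have zero mean across classes."""
    counts = [0] * num_classes
    for y in labels:
        counts[y] += 1
    n = len(labels)
    logs = [math.log(max(c / n, eps)) for c in counts]  # eps guards empty classes
    mean = sum(logs) / num_classes
    return [v - mean for v in logs]

g, h = grad_hess([0.0, 0.0, 0.0], label=0)
init = centered_log_priors([0] * 95 + [1] * 3 + [2] * 2, num_classes=3)
```

Because softmax is invariant to adding a constant to every logit, centering does not change the initial probabilities: applying softmax to the centered scores recovers the class frequencies.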

Verify our SoftmaxMultiClass implementation matches XGBoost C++:
- Softmax math: exact match (< 1e-12 tolerance)
- Gradient/Hessian: exact match (< 1e-6 tolerance, f32 precision)
- Base scores (InitEstimation): exact match (< 1e-10 tolerance)
- End-to-end: both achieve 100% on separable 3-class dataset

Reference values generated by tests/generate_xgb_reference.py.

Covers edge cases, numerical limits, and feature interactions:
- K=2 binary, K=10 many-class, K=50 softmax math
- Single sample per class, all same class, empty class
- Very imbalanced (95/3/2), extreme sample weights
- GOSS sampling, random subsampling, column subsampling
- Missing values (branch splitter + imputer splitter)
- Large dataset (5000 samples), 1-iteration minimum
- Tiny/large learning rates, refit resets state
- JSON roundtrip K=5, deterministic seeds
- Early stopping metric tracking, parallel vs sequential
- Label validation boundary conditions
- Grad/hess at uniform distribution
- CalcInit log-prior magnitudes for imbalanced data
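
GOSS sampling needs a single importance score per row, while multiclass gradients are N×K. A sketch of the max-abs-grad reduction mentioned in the commit notes, followed by a standard GOSS selection, is shown here. The row-major layout, the parameter names top_rate and other_rate, and the (1 - top_rate) / other_rate reweighting are assumptions based on the usual GOSS formulation, not details confirmed by this PR.

```python
import random

def row_importance(grad, n_rows, n_classes):
    """Reduce a flat row-major N x K gradient buffer to one score per row."""
    return [
        max(abs(grad[i * n_classes + k]) for k in range(n_classes))
        for i in range(n_rows)
    ]

def goss_sample(importance, top_rate, other_rate, seed=0):
    """Keep the largest-gradient rows, sample some of the rest, and upweight them.

    Returns {row_index: weight multiplier} for the selected rows.
    """
    n = len(importance)
    order = sorted(range(n), key=lambda i: importance[i], reverse=True)
    n_top = int(n * top_rate)
    n_other = int(n * other_rate)
    top, rest = order[:n_top], order[n_top:]
    rng = random.Random(seed)  # seeded so the sketch is deterministic
    sampled = rng.sample(rest, min(n_other, len(rest)))
    scale = (1.0 - top_rate) / other_rate  # compensates for the dropped rows
    selected = {i: 1.0 for i in top}
    selected.update({i: scale for i in sampled})
    return selected

grad = [0.1, -0.9, 0.2, 0.1, -0.3, 0.0, 0.05, 0.5]  # 4 rows, K = 2
imp = row_importance(grad, n_rows=4, n_classes=2)
picked = goss_sample(imp, top_rate=0.5, other_rate=0.25)
```

Using one selection per row keeps all K trees in a round training on the same sampled rows.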
@jinlow
Owner

jinlow commented Apr 2, 2026

Thanks! Will try and review in the next day or so.
For any AI-generated output, it would be great to see some artifacts from your interactions to understand the model's reasoning. For instance, have it create a summary of each of your conversations in a conversation log. If you used any SDD methods, those artifacts could be included as well, even if just in a separate branch on your fork, if we don't want to muddy this commit.
Please feel free to add an agent/ directory with this info.

Adds ObjectiveType::QuantileLoss { alpha } for quantile regression, used
by downstream xgboost-v3 multi-horizon MFE/MAE quantile heads.

- pinball loss ρ_α(r) = max(α·r, (α−1)·r), r = y − ŷ
- gradient: g = I[y > ŷ] − α (matches XGBoost reg:quantileerror)
- constant pseudo-hessian = weight (same pattern as XGBoost/LightGBM)
- weighted α-quantile base score
- default metric = RMSE

QuantileLoss uses direct-call dispatch (like SoftmaxMultiClass) because
it carries an alpha parameter. fit_trees boxes a closure to capture alpha
without disturbing the existing function-pointer dispatch path.

Tests: 5 new unit tests covering init, gradient signs for α=0.30/0.70,
pinball loss values, and weighted quantile base score.
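
One way to compute a weighted α-quantile base score is sketched below. It assumes the simplest definition (the smallest value whose cumulative weight reaches α times the total weight, with no interpolation); the function name is hypothetical and the implementation in this PR may differ.

```python
def weighted_quantile(values, weights, alpha):
    """Smallest value whose cumulative weight reaches alpha * total weight."""
    pairs = sorted(zip(values, weights))
    total = sum(w for _, w in pairs)
    threshold = alpha * total
    cum = 0.0
    for v, w in pairs:
        cum += w
        if cum >= threshold:
            return v
    return pairs[-1][0]  # only reached through floating-point rounding

median_equal = weighted_quantile([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 1.0], 0.5)
median_heavy = weighted_quantile([1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 1.0, 10.0], 0.5)
```

With equal weights this reduces to an ordinary empirical quantile; a single heavy sample pulls the result toward its value.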
@jinlow
Owner

jinlow commented Apr 11, 2026

Thanks, I am currently making some changes to the package for optimization that will likely make this PR out of date. Will need to wait to merge/review this till after that.

@jinlow
Owner

jinlow commented Apr 11, 2026

Also, you should be able to run the Python tests locally. It's even set up with uv now, so you can just run uv sync followed by uv run pytest in the Python package folder.

The gradient was `(indicator - alpha)` which is the negation of the
standard pinball loss gradient. This caused gradient boosting to
diverge: predictions drifted ±20 after 150+ iterations instead of
converging to the target quantile.

Correct gradient: `(alpha - indicator)`
- y > ŷ (underprediction): grad = α - 1 < 0 → pushes ŷ up ✓
- y ≤ ŷ (overprediction): grad = α > 0 → pushes ŷ down ✓
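
The sign reasoning above can be checked numerically. The sketch below is plain Python with hypothetical function names, not the forust code. It differentiates the pinball loss as defined earlier in this PR, ρ_α(r) = max(α·r, (α−1)·r) with r = y − ŷ. Under that definition the derivative with respect to ŷ works out to −α for underprediction and 1 − α for overprediction: the signs agree with the list above, but the magnitudes are mirrored (α and 1 − α swapped), so it may be worth confirming which convention the implementation intends.

```python
def pinball(y, pred, alpha):
    """Pinball loss with r = y - pred."""
    r = y - pred
    return max(alpha * r, (alpha - 1.0) * r)

def pinball_grad(y, pred, alpha):
    """Derivative of pinball() with respect to pred (a subgradient at y == pred)."""
    return -alpha if y > pred else 1.0 - alpha

def numeric_grad(y, pred, alpha, h=1e-6):
    """Central finite difference, used only to cross-check pinball_grad."""
    return (pinball(y, pred + h, alpha) - pinball(y, pred - h, alpha)) / (2.0 * h)

def fit_constant(ys, alpha, start, lr=1.0, steps=2000):
    """Gradient descent on a single constant prediction."""
    pred = start
    for _ in range(steps):
        g = sum(pinball_grad(y, pred, alpha) for y in ys) / len(ys)
        pred -= lr * g
    return pred

ys = [float(v) for v in range(1, 101)]  # values 1..100
q30 = fit_constant(ys, alpha=0.30, start=50.0)
```

With this gradient, descent on the values 1 through 100 at α = 0.30 settles between 30 and 31, which is where the 30th percentile lies.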
2 participants