Add SoftmaxMultiClass objective and multiclass APIs #117
Port XGBoost C++'s multi-class softmax (multiclass_obj.cu, gbtree.cc) into forust-ml. Adds SoftmaxMultiClass objective type, K trees per boosting round, N×K strided predictions, predict_proba(), and MultiClassLogLoss metric.
Key additions:
- SoftmaxMultiClass struct: softmax, calc_grad_hess (2x hessian), calc_init (centered log-priors), validate_labels
- fit_trees_multiclass: K-tree-per-round training loop with N×K grad/hess, GOSS row sampling via max-abs-grad reduction
- predict() multi-class branch: N×K raw logits
- predict_proba(): softmax over logits → probabilities
- MultiClassLogLoss metric with direct dispatch (bypasses fn-ptr)
- num_classes + base_scores fields with serde defaults for backward compatibility
- Guards: num_classes>1 requires SoftmaxMultiClass objective
- 19 new tests (unit + integration), 0 existing tests modified
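As an illustration of the softmax and "2x hessian" items above, here is a minimal per-sample sketch. It assumes the usual XGBoost-style form (gradient `p_k - 1[y = k]`, hessian `2 * p_k * (1 - p_k)`, both scaled by the sample weight, with a small floor on the hessian). The function names, the `f64` precision, and the `EPS` value are assumptions for illustration, not the PR's actual code.

```rust
/// Numerically stable softmax over one row of K logits.
fn softmax(logits: &[f64]) -> Vec<f64> {
    // Subtract the row maximum so exp() cannot overflow.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Gradient and hessian for one sample across all K classes.
/// Assumed form: g_k = (p_k - 1[y == k]) * w, h_k = max(2 * p_k * (1 - p_k) * w, EPS).
fn grad_hess_row(logits: &[f64], label: usize, weight: f64) -> (Vec<f64>, Vec<f64>) {
    const EPS: f64 = 1e-16; // hypothetical floor, keeps the hessian strictly positive
    let p = softmax(logits);
    let grad: Vec<f64> = p
        .iter()
        .enumerate()
        .map(|(k, &pk)| {
            let target = if k == label { 1.0 } else { 0.0 };
            (pk - target) * weight
        })
        .collect();
    let hess: Vec<f64> = p
        .iter()
        .map(|&pk| (2.0 * pk * (1.0 - pk) * weight).max(EPS))
        .collect();
    (grad, hess)
}
```

At uniform logits the gradients over the K classes sum to zero, which is a convenient sanity check for this kind of objective.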
Verify our SoftmaxMultiClass implementation matches XGBoost C++:
- Softmax math: exact match (< 1e-12 tolerance)
- Gradient/Hessian: exact match (< 1e-6 tolerance, f32 precision)
- Base scores (InitEstimation): exact match (< 1e-10 tolerance)
- End-to-end: both achieve 100% on separable 3-class dataset
Reference values generated by tests/generate_xgb_reference.py.
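The "centered log-priors" base score can be sketched as follows, assuming it means the log of each class's weighted frequency minus the mean of those logs. The function name, the handling of empty classes via a small floor, and the `EPS` value are assumptions, not details confirmed by the PR.

```rust
/// Per-class base scores as centered log-priors:
/// b_k = ln(prior_k) - mean_j ln(prior_j), where prior_k is the weighted class frequency.
fn centered_log_priors(labels: &[usize], weights: &[f64], num_classes: usize) -> Vec<f64> {
    const EPS: f64 = 1e-16; // hypothetical floor so an empty class does not yield ln(0)
    let mut class_weight = vec![0.0_f64; num_classes];
    for (&y, &w) in labels.iter().zip(weights) {
        class_weight[y] += w;
    }
    let total: f64 = class_weight.iter().sum();
    let log_p: Vec<f64> = class_weight
        .iter()
        .map(|&cw| (cw / total).max(EPS).ln())
        .collect();
    let mean = log_p.iter().sum::<f64>() / num_classes as f64;
    // Centering leaves softmax(b) unchanged but makes the scores sum to zero.
    log_p.iter().map(|lp| lp - mean).collect()
}
```

Because softmax is invariant to adding a constant to every logit, applying softmax to these scores recovers the class priors whenever every class is non-empty.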
Covers edge cases, numerical limits, and feature interactions:
- K=2 binary, K=10 many-class, K=50 softmax math
- Single sample per class, all same class, empty class
- Very imbalanced (95/3/2), extreme sample weights
- GOSS sampling, random subsampling, column subsampling
- Missing values (branch splitter + imputer splitter)
- Large dataset (5000 samples), 1-iteration minimum
- Tiny/large learning rates, refit resets state
- JSON roundtrip K=5, deterministic seeds
- Early stopping metric tracking, parallel vs sequential
- Label validation boundary conditions
- Grad/hess at uniform distribution
- CalcInit log-prior magnitudes for imbalanced data
Thanks! Will try and review in the next day or so.
Adds ObjectiveType::QuantileLoss { alpha } for quantile regression, used
by downstream xgboost-v3 multi-horizon MFE/MAE quantile heads.
- pinball loss ρ_α(r) = max(α·r, (α−1)·r), r = y − ŷ
- gradient: g = I[y > ŷ] − α (matches XGBoost reg:quantileerror)
- constant pseudo-hessian = weight (same pattern as XGBoost/LightGBM)
- weighted α-quantile base score
- default metric = RMSE
QuantileLoss uses direct-call dispatch (like SoftmaxMultiClass) because
it carries an alpha parameter. fit_trees boxes a closure to capture alpha
without disturbing the existing function-pointer dispatch path.
Tests: 5 new unit tests covering init, gradient signs for α=0.30/0.70,
pinball loss values, and weighted quantile base score.
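A weighted α-quantile base score could be computed along these lines. This sketch uses one common definition (the smallest value whose cumulative weight reaches α times the total weight); the PR may use a different interpolation rule, and the function name is hypothetical.

```rust
/// Weighted alpha-quantile: the smallest value whose cumulative weight
/// reaches alpha * total_weight. One of several common definitions.
fn weighted_quantile(values: &[f64], weights: &[f64], alpha: f64) -> f64 {
    let mut pairs: Vec<(f64, f64)> = values
        .iter()
        .cloned()
        .zip(weights.iter().cloned())
        .collect();
    // Sort by value; NaN inputs are not supported in this sketch.
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("values must not be NaN"));
    let total: f64 = pairs.iter().map(|p| p.1).sum();
    let threshold = alpha * total;
    let mut cum = 0.0_f64;
    for (v, w) in &pairs {
        cum += *w;
        if cum >= threshold {
            return *v;
        }
    }
    // Reached only for empty input or rounding at alpha == 1.
    pairs.last().map(|p| p.0).unwrap_or(f64::NAN)
}
```

With uniform weights this reduces to an ordinary (non-interpolated) sample quantile; a heavily weighted row pulls the result towards its own value.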
Thanks, I am currently making some changes to the package for optimization that will likely make this PR out of date. Will need to wait to merge/review this till after that.
Also, you should be able to run the Python tests locally. It's even set up with uv now, so you can just run
The gradient was `(indicator - alpha)`, which is the negation of the standard pinball loss gradient. This caused gradient boosting to diverge: predictions drifted ±20 after 150+ iterations instead of converging to the target quantile.
Correct gradient: `(alpha - indicator)`
- y > ŷ (underprediction): grad = α - 1 < 0 → pushes ŷ up ✓
- y ≤ ŷ (overprediction): grad = α > 0 → pushes ŷ down ✓
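As a cross-check, differentiating the pinball loss as defined in the QuantileLoss commit message (ρ_α(r) = max(α·r, (α−1)·r), r = y − ŷ) with respect to ŷ gives −α when y > ŷ and 1 − α when y < ŷ. The signs agree with the corrected gradient above, but the magnitudes α and 1 − α fall the other way round from the listed values, so it may be worth checking the implementation against a finite-difference or convergence test like the one sketched here. All function names are hypothetical and not taken from the PR.

```rust
/// Pinball loss rho_alpha(r) = max(alpha * r, (alpha - 1) * r), with r = y - y_hat.
fn pinball_loss(y: f64, y_hat: f64, alpha: f64) -> f64 {
    let r = y - y_hat;
    (alpha * r).max((alpha - 1.0) * r)
}

/// Derivative of the pinball loss with respect to y_hat.
/// At the kink y == y_hat this picks the overprediction branch (a valid subgradient).
fn pinball_grad(y: f64, y_hat: f64, alpha: f64) -> f64 {
    if y > y_hat {
        -alpha // underprediction: negative gradient, pushes y_hat up
    } else {
        1.0 - alpha // overprediction: positive gradient, pushes y_hat down
    }
}

/// Fit a single constant prediction by plain (sub)gradient descent on the mean loss.
/// A toy stand-in for boosting, used only to see which quantile the gradient targets.
fn fit_constant_quantile(ys: &[f64], alpha: f64, lr: f64, iters: usize) -> f64 {
    let mut y_hat = 0.0_f64;
    for _ in 0..iters {
        let mean_grad =
            ys.iter().map(|&y| pinball_grad(y, y_hat, alpha)).sum::<f64>() / ys.len() as f64;
        y_hat -= lr * mean_grad;
    }
    y_hat
}
```

For example, on the values 1 through 100 with α = 0.3, a learning rate of 10 and 500 iterations, `fit_constant_quantile` settles between 30 and 31, i.e. at the 0.3-quantile.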
Summary
- `SoftmaxMultiClass` support for K-class classification in the Rust core
- `num_classes` and `predict_proba`
Details
This PR adds a softmax multiclass objective to forust, validates labels and eval sets for multiclass training, and wires the public APIs through both Rust and Python.
It also makes multiclass prediction semantics explicit:
- `predict()` returns raw logits with shape `(n_samples, K)` in Python
- `predict_proba()` returns class probabilities with shape `(n_samples, K)`
- `prediction_iteration` now means boosting rounds, not an arbitrary tree count, so multiclass predictions always use complete class rounds
Validation
- `cargo test`
- `cd py-forust && cargo check`
- Python `pytest` could not be run in this environment because `sklearn` was not installed during test collection.
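The prediction semantics described under Details can be sketched as follows. The sketch assumes the raw logits are stored as a flat, row-major buffer of length N×K (one reading of "N×K strided") and that limiting prediction to a number of boosting rounds means using that many complete groups of K trees. Both function names are hypothetical.

```rust
/// Turn a flat, row-major N×K logit buffer into per-row class probabilities.
/// Assumed layout: sample i, class k lives at index i * num_classes + k.
fn predict_proba_from_logits(logits: &[f64], num_classes: usize) -> Vec<f64> {
    assert!(num_classes > 1, "multiclass needs at least two classes");
    assert!(logits.len() % num_classes == 0, "buffer must hold whole rows");
    let mut out = Vec::with_capacity(logits.len());
    for row in logits.chunks(num_classes) {
        // Stable softmax per row: shift by the row maximum before exp().
        let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = row.iter().map(|&z| (z - max).exp()).collect();
        let sum: f64 = exps.iter().sum();
        out.extend(exps.iter().map(|e| e / sum));
    }
    out
}

/// Number of trees consulted when prediction is limited to `rounds` boosting rounds,
/// assuming each round contributes exactly one tree per class.
fn trees_for_rounds(rounds: usize, num_classes: usize) -> usize {
    rounds * num_classes
}
```

Counting in rounds rather than trees means a prediction never stops partway through a round, where some classes would have one more tree than others.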