Dev#73
Merged
Merged
Conversation
…metric extraction) - Expand flat_hessian and calculate_period_metrics docstrings - Fix stale for_cycle_duration docstring (removed params) - Add Sphinx doc comments for FIXED_TRAINING_DEFAULTS and CONSERVATIVE_INITIAL_PARAMS - Add metric_extraction module to walk_forward.rst API reference - Update walk_forward.rst exclude list for new CycleEvaluation/EvaluationResult fields - Add user guide sections: Deflated Sharpe Ratio, block bootstrap CIs, return decomposition, regime-tagged evaluation - Update analysis.rst with usage examples for new post_train_analysis functions
Python 3.12+ warns on unrecognised escape sequences like \_ and future versions will make them SyntaxErrors. Use RST inline-code markup instead.
Add a third optimization method ("bfgs") to train_on_historic_data,
targeting the small-param-count regime where quasi-Newton converges
faster than first-order methods.
- New elif branch in jax_runners.py: builds deterministic objective
over fixed evaluation points, flattens params via ravel_pytree,
vmaps jax.scipy.optimize.minimize over n_parameter_sets for
multi-start optimization
- bfgs_settings defaults in default_run_fingerprint.py
(maxiter=100, tol=1e-6, n_evaluation_points=20)
- 6 tests covering end-to-end, multi-start, objective improvement,
metadata structure, validation fraction, and config defaults
Add bfgs_maxiter, bfgs_n_evaluation_points, bfgs_tol mappings to HyperparamTuner so outer Optuna can search over BFGS settings. Also add n_parameter_sets to opt_settings_keys (was silently dropped). New experiment scripts: - train_bfgs_example.py: single-run BFGS training with GPU auto-sizing - tune_training_hyperparams_innerbfgs.py: outer Optuna over 13D space (BFGS settings, multi-start init, training window, initial params) using power_channel rule
- Probe GPU forward-pass budget once at hyperopt startup, constrain search space and pass memory_budget through bfgs_settings - Per-trial cap in BFGS branch: n_parameter_sets ≤ budget // n_eval_points - Save initial (step 0) and optimized (step 1) params via save_multi_params - Fix probe_max_n_parameter_sets double-binding of prices arg - Slim BFGS tests from 13min to 3min (shorter windows, fewer iters)
- BFGS save now emits 2 entries (step 0, step 1) with batched params matching the SGD format, instead of interleaved per-set entries - Add dynamic_high support to HyperparamSpace for per-trial bounds: n_parameter_sets upper bound adapts to sampled n_eval_points - Remove worst-case static range cap in favour of dynamic constraint
- Derive bout_offset_days upper bound from max val_fraction to prevent effective_train_length < bout_offset crashes during hyperopt - Fix train_objective in BFGS save to emit metric dicts (matching SGD) - Fix objective in BFGS save to use actual BFGS objective values - Add dynamic_high support to HyperparamSpace for conditional bounds
Wrap partial_training_step with jax.checkpoint before vmap/jit batching. Trades ~2x forward recompute for dramatically reduced VRAM by discarding intermediates during the backward pass. Adds gradient_checkpointing flag to bfgs_settings (default True) so it can be toggled for A/B comparison. Includes profile_bfgs_memory.py script that polls nvidia-smi during BFGS runs to measure peak VRAM, power draw, and utilisation with/without checkpointing.
nvidia-smi shows JAX's pre-allocated memory pool (~75-90% of VRAM), not actual peak allocation. Now uses device.memory_stats()["peak_bytes_in_use"] as primary metric. Adds --no-pool flag to disable JAX's pool allocator (XLA_PYTHON_CLIENT_ALLOCATOR=platform) for cases where real nvidia-smi readings are needed.
Each trial now runs in a separate subprocess so JAX's peak_bytes_in_use counter resets between trials (was cumulative, making ON/OFF comparison useless). Defaults updated to match production scale: 12 months window, 20 eval points.
Replace the runtime-based profiler (subprocess isolation + peak_bytes_in_use) with XLA's compiled.memory_analysis(). This gives deterministic, accurate temp memory numbers directly from XLA's allocation plan — no runtime noise, no nvidia-smi polling, no process isolation needed. The profiler builds the actual quantammsim BFGS computation pipeline (data loading, forward pass, batched objective, vmapped solve), compiles with and without jax.checkpoint, and reports: - temp_size_in_bytes (XLA's planned scratch allocation) - cost_analysis flops (compute cost) - inner value_and_grad stats (the BFGS hot loop) Initial CPU results show checkpoint increases temp memory ~15-23% for the quantammsim forward pass, contrary to expectations. GPU results pending.
Replace O(n*k) jnp.convolve with O(n log n) FFT convolution in all 9 estimator call sites. Add float32 compute_dtype option for BFGS so the forward pass runs in reduced precision (BFGS itself stays float64 for Hessian stability). Fix float64 promotion throughout fine_weights.py scan bodies (Python float literals, int64 intermediates, jnp.ones defaults).
… vs float64 Gradient checkpointing was proven counterproductive (XLA already does its own remat; explicit checkpoint added 15-23% overhead). Remove the code path from jax_runners.py and the setting from defaults. Rewrite profile_bfgs_memory.py to compare float32 vs float64 XLA memory/FLOP profiles instead of checkpoint on/off.
Raw lax ops (lax.add, lax.mul, etc.) require exact dtype match between operands — no type promotion. When float32 inputs met float64 Python literals (4.0, 1.0, 0.5), MLIR verification failed. jnp ops handle promotion correctly. Fixed in both estimator_primitives.py and param_utils.py. Removed unused lax imports.
…calls - Remove all 36 jax_enable_x64=True calls from library code - Add centralized x64 toggle at top of train_on_historic_data with try/finally save/restore so callers aren't affected by state leaks - Fix profiler: move x64 toggle before data loading and param init - Fix conftest: function-scoped fixture resets x64 between tests - Revert per-site dtype casting (preserved in tag dtype-casting-approach) - Fix squareplus/inverse_squareplus: use jnp ops not raw lax - Add float32 forward pass integration tests (x64 toggle approach) - Update test_weight_calculations.py: remove redundant x64 call
Adds execution timing alongside the existing compile-time memory analysis. With --execute, the profiler warms up and runs the compiled inner value_and_grad (N reps, median) and full vmapped solve (1 run), reporting wall-clock ms, effective GFLOP/s, and speedup ratios. Also suppresses noisy data-loading prints (start_date/end_date/unix_values).
The vectorized weight path explicitly transferred data CPU↔GPU via device_put, splitting the forward pass into 3 separate XLA programs with 4 cross-device copies. Removing these lets the entire pipeline (estimators → coarse weights → fine interpolation → reserves) compile as a single XLA program, enabling cross-stage kernel fusion.
Add CMA-ES (Covariance Matrix Adaptation Evolution Strategy) as an alternative to BFGS for derivative-free optimization of strategy parameters. Purpose-built for the 5-50 parameter, expensive-evaluation regime with essentially zero hyperparameters. New files: - quantammsim/training/cma_es.py: Pure JAX CMA-ES following Hansen's tutorial — CMAESState, ask/tell interface, default_params, should_stop - tests/unit/test_cma_es.py: 7 unit tests (sphere, Rosenbrock, shapes) + 5 integration tests (end-to-end, restarts, metadata, config, val) - experiments/tune_training_hyperparams_innercmaes.py: Outer Optuna tuning script with CMA-ES inner loop (13D search space) Modified files: - jax_runners.py: x64 toggle case + elif method=="cma_es" branch mirroring BFGS structure (eval points, ravel/unravel, Python loop over restarts, BestParamsTracker, save_multi_params, metadata) - default_run_fingerprint.py: cma_es_settings defaults - hyperparam_tuner.py: cma_es_* param mappings for outer Optuna Also fixes 5 pre-existing test failures: - estimator_primitives.py: replace 10 hardcoded jnp.float64 in scan carry inits with arr_in.dtype (fixes float32 forward pass) - fine_weights.py: cast scan carry output to input dtype to prevent float64 promotion from Python literals; replace jnp.float64 in jnp.ones with initial_weights.dtype - test_variance_calc.py: widen tolerance from -1e-15 to -1e-10 for first-row EWMA warm-up artifact
Mirrors scripts/profile_bfgs_memory.py but profiles the CMA-ES population evaluation (forward-only, no grad). Compiles jit(vmap(eval_single)) and measures XLA temp memory, FLOPs, and optionally wall-clock time per generation (ask + eval + tell). Supports --sweep over population sizes and --execute for timing.
Fuse the inner CMA-ES generation loop into a single XLA program via lax.while_loop, eliminating ~6ms/gen Python dispatch overhead. - Add run_cmaes() with lax.while_loop, _should_stop_jax() returning JAX bool - Harden init_cmaes/ask/tell dtypes for while_loop carry compatibility: explicit dtype on all state fields, weights cast to state dtype, math.sqrt for constant coefficients, bool.astype for h_sigma - Replace Python generation loop in jax_runners with JIT-compiled _run_one_restart calling run_cmaes - Add fused-loop timing to CMA-ES profiler - Fix Optuna JSON serialization for float32 metrics (_json_safe helper) - Add tests: sphere convergence, python-loop match, early stop, float32-under-x64 dtype preservation
Add auto-sizing of CMA-ES λ based on GPU memory capacity. A single startup probe determines max concurrent forward passes; the runner then computes λ = budget // n_eval_points per trial, filling available VRAM. Fixes pre-existing bug where population_size_override left stale weights/learning rates in cma_params by adding optional lam parameter to default_params that recomputes all dependent quantities.
Default max_sets=64 was too low — GPU saturated the search range without hitting OOM, yielding a useless budget for high n_eval trials.
1024 still saturated without OOM on A100-class GPUs.
Replace the flat vmap forward-pass probe with one that binary-searches on λ by running real CMA-ES (1 generation via train_on_historic_data). This captures the true memory footprint of the fused lax.while_loop — nested vmap, carry state, XLA constant-folding — instead of guessing a safety margin. Price data is loaded once and passed via price_data= to avoid repeated disk I/O across binary-search steps.
Reclamm phase 1
Feature/bfgs optimizer
JAX 0.6 changed the default for jax_threefry_partitionable from False to True. This switches jax.random.split to a different algorithm that produces entirely different subkeys from the same seed. Since training uses random batch sampling (random.choice in get_indices), different subkeys produce different training trajectories, causing pinned objective values to drift by up to 10%. Pin the original algorithm in conftest so tests are reproducible across JAX versions. Production code is unaffected and uses whatever JAX defaults to for the target hardware.
Increase base_lr from 0.05 to 0.5 and change seed from 42 to 123 so that training produces meaningful val-metric separation across iterations. With the old config, val metrics varied by only ~1e-13 between steps, making best_iteration selection depend on FP noise at the precision limit — the root cause of the flaky test_partial_last_chunk_matches CI failure. The new config gives ~0.59 val-metric separation, making best_iteration deterministic regardless of scan chunking strategy. Both momentum and mean_reversion_channel pinned objectives are re-pinned accordingly.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.