
Dev #73

Merged
MatthewWilletts merged 82 commits into main from dev on Mar 5, 2026

Conversation

@MatthewWilletts
Contributor

No description provided.

…metric extraction)

- Expand flat_hessian and calculate_period_metrics docstrings
- Fix stale for_cycle_duration docstring (removed params)
- Add Sphinx doc comments for FIXED_TRAINING_DEFAULTS and CONSERVATIVE_INITIAL_PARAMS
- Add metric_extraction module to walk_forward.rst API reference
- Update walk_forward.rst exclude list for new CycleEvaluation/EvaluationResult fields
- Add user guide sections: Deflated Sharpe Ratio, block bootstrap CIs,
  return decomposition, regime-tagged evaluation
- Update analysis.rst with usage examples for new post_train_analysis functions
Python 3.12+ warns on unrecognised escape sequences like \_ and future
versions will make them SyntaxErrors.  Use RST inline-code markup instead.
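
As a rough illustration of the warning described above, the sketch below compiles two docstring variants and records what Python reports. The docstring texts are invented for the example and are not taken from the repository.

```python
import warnings

# Source text for two hypothetical docstrings. Raw strings keep the
# backslashes literal, so compile() sees them exactly as written.
WITH_ESCAPE = r'''"""See train\_on\_historic\_data for details."""'''
WITH_RST = r'''"""See ``train_on_historic_data`` for details."""'''


def escape_warnings(source):
    """Return the warning category names raised while compiling source."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        compile(source, "<docstring-example>", "exec")
    return [w.category.__name__ for w in caught]


# The escaped variant warns (SyntaxWarning on 3.12+, DeprecationWarning
# on earlier versions); the RST inline-code variant compiles cleanly.
print(escape_warnings(WITH_ESCAPE))
print(escape_warnings(WITH_RST))
```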
Add a third optimization method ("bfgs") to train_on_historic_data,
targeting the small-param-count regime where quasi-Newton converges
faster than first-order methods.

- New elif branch in jax_runners.py: builds deterministic objective
  over fixed evaluation points, flattens params via ravel_pytree,
  vmaps jax.scipy.optimize.minimize over n_parameter_sets for
  multi-start optimization
- bfgs_settings defaults in default_run_fingerprint.py
  (maxiter=100, tol=1e-6, n_evaluation_points=20)
- 6 tests covering end-to-end, multi-start, objective improvement,
  metadata structure, validation fraction, and config defaults
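
The multi-start pattern described above can be sketched roughly as follows. This is a toy under stated assumptions, not the code in jax_runners.py: the parameter pytree, the quadratic objective and the mapping of tol onto the solver's gtol option are all invented for illustration.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize

# Toy parameter pytree standing in for strategy parameters (hypothetical).
template = {"log_k": jnp.zeros(3), "bias": jnp.zeros(())}
flat_template, unravel = ravel_pytree(template)


def objective(params):
    # Deterministic toy loss; the real objective is evaluated over a
    # fixed set of evaluation points instead.
    return jnp.sum((params["log_k"] - 1.0) ** 2) + (params["bias"] + 0.5) ** 2


def flat_objective(x):
    # BFGS works on a flat vector, so unravel back into the pytree.
    return objective(unravel(x))


def solve_one(x0):
    result = minimize(flat_objective, x0, method="BFGS",
                      options={"maxiter": 100, "gtol": 1e-6})
    return result.x, result.fun


n_parameter_sets = 4
starts = jax.random.normal(jax.random.PRNGKey(0),
                           (n_parameter_sets, flat_template.size))

# One BFGS solve per starting point, batched with vmap.
xs, final_values = jax.vmap(solve_one)(starts)
best_params = unravel(xs[jnp.argmin(final_values)])
```

Because the objective is deterministic, every start sees the same loss surface and the best of the batched solves can be picked by a plain argmin.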
Add bfgs_maxiter, bfgs_n_evaluation_points, bfgs_tol mappings to
HyperparamTuner so outer Optuna can search over BFGS settings.
Also add n_parameter_sets to opt_settings_keys (was silently dropped).

New experiment scripts:
- train_bfgs_example.py: single-run BFGS training with GPU auto-sizing
- tune_training_hyperparams_innerbfgs.py: outer Optuna over 13D space
  (BFGS settings, multi-start init, training window, initial params)
  using power_channel rule
- Probe GPU forward-pass budget once at hyperopt startup, constrain
  search space and pass memory_budget through bfgs_settings
- Per-trial cap in BFGS branch: n_parameter_sets ≤ budget // n_eval_points
- Save initial (step 0) and optimized (step 1) params via save_multi_params
- Fix probe_max_n_parameter_sets double-binding of prices arg
- Slim BFGS tests from 13min to 3min (shorter windows, fewer iters)
- BFGS save now emits 2 entries (step 0, step 1) with batched params
  matching the SGD format, instead of interleaved per-set entries
- Add dynamic_high support to HyperparamSpace for per-trial bounds:
  n_parameter_sets upper bound adapts to sampled n_eval_points
- Remove worst-case static range cap in favour of dynamic constraint
- Derive bout_offset_days upper bound from max val_fraction to prevent
  effective_train_length < bout_offset crashes during hyperopt
- Fix train_objective in BFGS save to emit metric dicts (matching SGD)
- Fix objective in BFGS save to use actual BFGS objective values
- Add dynamic_high support to HyperparamSpace for conditional bounds
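
The per-trial cap above is plain integer arithmetic. A minimal sketch, assuming the budget is expressed as a number of concurrent forward passes; the function names and the floor of one set are illustrative assumptions, not the repository's API:

```python
def max_parameter_sets(memory_budget, n_evaluation_points):
    """Largest n_parameter_sets whose forward passes fit the budget.

    memory_budget is the number of concurrent forward passes reported by
    the GPU probe; each parameter set costs n_evaluation_points of them.
    """
    if n_evaluation_points < 1:
        raise ValueError("n_evaluation_points must be at least 1")
    # Floor at 1 so a trial always has something to optimize (assumption).
    return max(1, memory_budget // n_evaluation_points)


def clamp_parameter_sets(requested, memory_budget, n_evaluation_points):
    """Apply the per-trial cap to a requested n_parameter_sets."""
    return min(requested, max_parameter_sets(memory_budget, n_evaluation_points))


# A budget of 256 forward passes with 20 evaluation points allows 12 sets.
print(clamp_parameter_sets(32, memory_budget=256, n_evaluation_points=20))  # 12
print(clamp_parameter_sets(8, memory_budget=256, n_evaluation_points=20))   # 8
```

A dynamic upper bound in the search space plays the same role one step earlier: the sampler never proposes a value that the clamp would have to cut down.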
Wrap partial_training_step with jax.checkpoint before vmap/jit batching.
This trades roughly 2x forward recomputation for much lower VRAM use:
intermediates are discarded after the forward pass and recomputed during
the backward pass.
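
A minimal sketch of the wrapping order, assuming a toy step function in place of partial_training_step. It only shows that the checkpointed and plain versions return the same gradients; it says nothing about memory use.

```python
import jax
import jax.numpy as jnp


def training_step(params, x):
    # Toy stand-in for partial_training_step: chained nonlinearities whose
    # intermediates would normally be kept for the backward pass.
    h = jnp.tanh(x @ params)
    h = jnp.tanh(h @ params.T)
    return jnp.sum(h ** 2)


# Wrap with jax.checkpoint first, then differentiate, batch and jit.
checkpointed_step = jax.checkpoint(training_step)
batched_grad = jax.jit(jax.vmap(jax.grad(checkpointed_step), in_axes=(0, None)))
plain_grad = jax.jit(jax.vmap(jax.grad(training_step), in_axes=(0, None)))

params = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4))  # 3 parameter sets
x = jnp.ones((5, 4))

g_ckpt = batched_grad(params, x)
g_plain = plain_grad(params, x)
```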

Adds gradient_checkpointing flag to bfgs_settings (default True) so it
can be toggled for A/B comparison. Includes profile_bfgs_memory.py script
that polls nvidia-smi during BFGS runs to measure peak VRAM, power draw,
and utilisation with/without checkpointing.
nvidia-smi reports JAX's pre-allocated memory pool (~75-90% of VRAM),
not the actual peak allocation. The profiler now uses
device.memory_stats()["peak_bytes_in_use"] as its primary metric, and adds
a --no-pool flag that disables JAX's pool allocator
(XLA_PYTHON_CLIENT_ALLOCATOR=platform) for cases where real nvidia-smi
readings are needed.
Each trial now runs in a separate subprocess so that JAX's
peak_bytes_in_use counter resets between trials (it was cumulative, which
made the ON/OFF comparison useless). Defaults updated to match production
scale: 12-month window, 20 eval points.
Replace the runtime-based profiler (subprocess isolation + peak_bytes_in_use)
with XLA's compiled.memory_analysis(). This gives deterministic, accurate
temp memory numbers directly from XLA's allocation plan — no runtime noise,
no nvidia-smi polling, no process isolation needed.

The profiler builds the actual quantammsim BFGS computation pipeline
(data loading, forward pass, batched objective, vmapped solve), compiles
with and without jax.checkpoint, and reports:
  - temp_size_in_bytes (XLA's planned scratch allocation)
  - cost_analysis flops (compute cost)
  - inner value_and_grad stats (the BFGS hot loop)

Initial CPU results show checkpoint increases temp memory ~15-23% for the
quantammsim forward pass, contrary to expectations. GPU results pending.
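
The compile-time approach can be sketched as below, assuming a toy loss instead of the quantammsim pipeline. The helper name is hypothetical, and the defensive handling (memory analysis possibly unavailable, cost analysis returned as a list on older JAX versions) is an assumption about backend differences rather than something the profiler is known to do.

```python
import jax
import jax.numpy as jnp


def loss(params, x):
    h = jnp.tanh(x @ params)
    return jnp.sum(h ** 2)


def xla_profile(fn, *example_args):
    """Compile fn ahead of time and read XLA's planned memory and FLOPs."""
    compiled = jax.jit(fn).lower(*example_args).compile()
    memory = compiled.memory_analysis()   # may be None on some backends
    cost = compiled.cost_analysis()       # dict, or list of dicts on older JAX
    if isinstance(cost, (list, tuple)):
        cost = cost[0] if cost else {}
    return {
        "temp_bytes": getattr(memory, "temp_size_in_bytes", None),
        "flops": (cost or {}).get("flops"),
    }


params = jnp.ones((64, 64))
x = jnp.ones((128, 64))

# Same computation compiled with and without jax.checkpoint.
plain = xla_profile(jax.value_and_grad(loss), params, x)
remat = xla_profile(jax.value_and_grad(jax.checkpoint(loss)), params, x)
print(plain, remat)
```

Nothing is executed here: both numbers come from the compiled program, so they are repeatable across runs on the same backend.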
Replace O(n*k) jnp.convolve with O(n log n) FFT convolution in all 9
estimator call sites. Add float32 compute_dtype option for BFGS so the
forward pass runs in reduced precision (BFGS itself stays float64 for
Hessian stability). Fix float64 promotion throughout fine_weights.py
scan bodies (Python float literals, int64 intermediates, jnp.ones
defaults).
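
For reference, a minimal FFT-based linear convolution looks roughly like this. It is a sketch under the assumption of 1-D inputs and "full" output; how the estimator call sites pad or truncate is not shown in this PR text.

```python
import jax.numpy as jnp


def fft_convolve_full(signal, kernel):
    """Linear ('full') convolution of two 1-D arrays via real FFTs."""
    n = signal.shape[0] + kernel.shape[0] - 1
    # Zero-pad both inputs to the full output length so the circular
    # convolution computed by the FFT equals the linear convolution.
    spectrum = jnp.fft.rfft(signal, n=n) * jnp.fft.rfft(kernel, n=n)
    return jnp.fft.irfft(spectrum, n=n).astype(signal.dtype)


signal = jnp.arange(1.0, 9.0)        # 8 samples
kernel = 0.5 ** jnp.arange(4.0)      # short decaying kernel
direct = jnp.convolve(signal, kernel, mode="full")
via_fft = fft_convolve_full(signal, kernel)
```

The two results agree up to floating-point rounding; the FFT route pays off when the kernel is long, which is where the O(n*k) direct form becomes expensive.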
… vs float64

Gradient checkpointing proved counterproductive (XLA already performs
its own rematerialisation; explicit checkpoint added 15-23% overhead).
Remove the code path from jax_runners.py and the setting from defaults.

Rewrite profile_bfgs_memory.py to compare float32 vs float64 XLA
memory/FLOP profiles instead of checkpoint on/off.
Raw lax ops (lax.add, lax.mul, etc.) require exact dtype match between
operands — no type promotion. When float32 inputs met float64 Python
literals (4.0, 1.0, 0.5), MLIR verification failed. jnp ops handle
promotion correctly.

Fixed in both estimator_primitives.py and param_utils.py. Removed
unused lax imports.
…calls

- Remove all 36 jax_enable_x64=True calls from library code
- Add centralized x64 toggle at top of train_on_historic_data with
  try/finally save/restore so callers aren't affected by state leaks
- Fix profiler: move x64 toggle before data loading and param init
- Fix conftest: function-scoped fixture resets x64 between tests
- Revert per-site dtype casting (preserved in tag dtype-casting-approach)
- Fix squareplus/inverse_squareplus: use jnp ops not raw lax
- Add float32 forward pass integration tests (x64 toggle approach)
- Update test_weight_calculations.py: remove redundant x64 call
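
The save/restore idea behind the centralized toggle can be sketched as a context manager. The commit describes a try/finally inside train_on_historic_data; the context-manager form and the name x64_mode are assumptions made for a compact example.

```python
from contextlib import contextmanager

import jax
import jax.numpy as jnp


@contextmanager
def x64_mode(enabled):
    """Temporarily set jax_enable_x64 and restore the caller's value."""
    previous = jax.config.jax_enable_x64
    jax.config.update("jax_enable_x64", enabled)
    try:
        yield
    finally:
        # Runs even if the body raises, so the flag never leaks.
        jax.config.update("jax_enable_x64", previous)


before = jax.config.jax_enable_x64
with x64_mode(True):
    inside_dtype = jnp.ones(3).dtype   # float64 while the flag is on
after = jax.config.jax_enable_x64
```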
Adds execution timing alongside the existing compile-time memory analysis.
With --execute, the profiler warms up and runs the compiled inner
value_and_grad (N reps, median) and full vmapped solve (1 run), reporting
wall-clock ms, effective GFLOP/s, and speedup ratios.
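
A warm-up-then-median timing loop of the kind described can be sketched as follows. The function name, the toy loss and the FLOP estimate (forward matmul only) are assumptions for illustration; the profiler itself takes FLOPs from XLA's cost analysis.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def median_runtime_ms(fn, *args, reps=5):
    """Median wall-clock time of fn(*args) in milliseconds, after a warm-up."""
    jax.block_until_ready(fn(*args))          # warm-up: triggers compilation
    samples = []
    for _ in range(reps):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))      # wait for async dispatch to finish
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def loss(params, x):
    return jnp.sum(jnp.tanh(x @ params) ** 2)


step = jax.jit(jax.value_and_grad(loss))
params = jnp.ones((64, 64))
x = jnp.ones((128, 64))

elapsed_ms = median_runtime_ms(step, params, x)
flops = 2 * 128 * 64 * 64                     # rough forward-matmul estimate only
print(f"{elapsed_ms:.3f} ms, ~{flops / (elapsed_ms * 1e-3) / 1e9:.2f} GFLOP/s")
```

Blocking on the result matters because JAX dispatches asynchronously; without it the loop would mostly measure how long it takes to enqueue work.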

Also suppresses noisy data-loading prints (start_date/end_date/unix_values).
The vectorized weight path explicitly transferred data CPU↔GPU via
device_put, splitting the forward pass into 3 separate XLA programs
with 4 cross-device copies. Removing these lets the entire pipeline
(estimators → coarse weights → fine interpolation → reserves) compile
as a single XLA program, enabling cross-stage kernel fusion.
Add CMA-ES (Covariance Matrix Adaptation Evolution Strategy) as an
alternative to BFGS for derivative-free optimization of strategy
parameters. Purpose-built for the 5-50 parameter, expensive-evaluation
regime with essentially zero hyperparameters.

New files:
- quantammsim/training/cma_es.py: Pure JAX CMA-ES following Hansen's
  tutorial — CMAESState, ask/tell interface, default_params, should_stop
- tests/unit/test_cma_es.py: 7 unit tests (sphere, Rosenbrock, shapes)
  + 5 integration tests (end-to-end, restarts, metadata, config, val)
- experiments/tune_training_hyperparams_innercmaes.py: Outer Optuna
  tuning script with CMA-ES inner loop (13D search space)

Modified files:
- jax_runners.py: x64 toggle case + elif method=="cma_es" branch
  mirroring BFGS structure (eval points, ravel/unravel, Python loop
  over restarts, BestParamsTracker, save_multi_params, metadata)
- default_run_fingerprint.py: cma_es_settings defaults
- hyperparam_tuner.py: cma_es_* param mappings for outer Optuna

Also fixes 5 pre-existing test failures:
- estimator_primitives.py: replace 10 hardcoded jnp.float64 in scan
  carry inits with arr_in.dtype (fixes float32 forward pass)
- fine_weights.py: cast scan carry output to input dtype to prevent
  float64 promotion from Python literals; replace jnp.float64 in
  jnp.ones with initial_weights.dtype
- test_variance_calc.py: widen tolerance from -1e-15 to -1e-10 for
  first-row EWMA warm-up artifact
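
The dtype fixes listed above follow one pattern: derive the scan carry's dtype from the input instead of hardcoding it. A toy EWMA sketch of that pattern, not the code in estimator_primitives.py or fine_weights.py:

```python
import jax.numpy as jnp
from jax import lax


def ewma(arr_in, alpha=0.1):
    """Exponentially weighted moving average that preserves the input dtype."""
    dtype = arr_in.dtype
    # Cast the Python float once so it cannot promote the carry to float64.
    alpha = jnp.asarray(alpha, dtype=dtype)

    def body(carry, x):
        new = (alpha * x + (1 - alpha) * carry).astype(dtype)
        return new, new

    # The carry init follows the input dtype rather than jnp.float64.
    init = jnp.zeros((), dtype=dtype)
    _, out = lax.scan(body, init, arr_in)
    return out


prices = jnp.linspace(1.0, 2.0, 8, dtype=jnp.float32)
smoothed = ewma(prices)
```

lax.scan requires the carry to keep the same dtype on every iteration, which is why a float64 init combined with float32 inputs breaks the float32 forward pass.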
Mirrors scripts/profile_bfgs_memory.py but profiles the CMA-ES
population evaluation (forward-only, no grad). Compiles
jit(vmap(eval_single)) and measures XLA temp memory, FLOPs, and
optionally wall-clock time per generation (ask + eval + tell).

Supports --sweep over population sizes and --execute for timing.
Fuse the inner CMA-ES generation loop into a single XLA program via
lax.while_loop, eliminating ~6ms/gen Python dispatch overhead.

- Add run_cmaes() with lax.while_loop, _should_stop_jax() returning JAX bool
- Harden init_cmaes/ask/tell dtypes for while_loop carry compatibility:
  explicit dtype on all state fields, weights cast to state dtype,
  math.sqrt for constant coefficients, bool.astype for h_sigma
- Replace Python generation loop in jax_runners with JIT-compiled
  _run_one_restart calling run_cmaes
- Add fused-loop timing to CMA-ES profiler
- Fix Optuna JSON serialization for float32 metrics (_json_safe helper)
- Add tests: sphere convergence, python-loop match, early stop,
  float32-under-x64 dtype preservation
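
The fused-loop structure can be sketched with a toy update rule in place of a CMA-ES generation. The names and the halving update are assumptions; the points being illustrated are the JAX-boolean stop condition and the explicitly typed carry.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_until_converged(x0, max_generations=200, tol=1e-6):
    """Shrink x towards zero inside one compiled loop (toy update rule)."""
    dtype = x0.dtype
    # Every carry field gets an explicit dtype so its type is identical on
    # loop entry and after each iteration, as lax.while_loop requires.
    state = (x0, jnp.asarray(0, dtype=jnp.int32))

    def should_continue(state):
        x, gen = state
        # Returns a JAX boolean scalar rather than a Python bool.
        return jnp.logical_and(gen < max_generations,
                               jnp.max(jnp.abs(x)) > tol)

    def one_generation(state):
        x, gen = state
        return (x * jnp.asarray(0.5, dtype=dtype), gen + 1)

    return lax.while_loop(should_continue, one_generation, state)


x0 = jnp.array([1.0, -2.0], dtype=jnp.float32)
final_x, generations = jax.jit(run_until_converged)(x0)
```

Because the whole loop is one XLA program, the Python interpreter is only involved once per call rather than once per generation.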
Add auto-sizing of CMA-ES λ based on GPU memory capacity. A single
startup probe determines max concurrent forward passes; the runner
then computes λ = budget // n_eval_points per trial, filling available
VRAM. Fixes pre-existing bug where population_size_override left stale
weights/learning rates in cma_params by adding optional lam parameter
to default_params that recomputes all dependent quantities.
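
To show why an overridden λ must trigger a full recompute, here is a sketch of the standard strategy-parameter formulas from Hansen's tutorial (positive recombination weights only). The function name and return format are hypothetical and do not mirror the signature of default_params in cma_es.py.

```python
import math


def cmaes_defaults(n_dim, lam=None):
    """CMA-ES strategy parameters for dimension n_dim.

    Passing lam overrides the population size; everything derived from it
    (mu, weights, mu_eff, learning rates) is recomputed, never reused.
    """
    if lam is None:
        lam = 4 + int(3 * math.log(n_dim))
    mu = lam // 2
    raw = [math.log((lam + 1) / 2) - math.log(i) for i in range(1, mu + 1)]
    total = sum(raw)
    weights = [w / total for w in raw]
    mu_eff = 1.0 / sum(w * w for w in weights)
    c_sigma = (mu_eff + 2) / (n_dim + mu_eff + 5)
    d_sigma = 1 + 2 * max(0.0, math.sqrt((mu_eff - 1) / (n_dim + 1)) - 1) + c_sigma
    c_c = (4 + mu_eff / n_dim) / (n_dim + 4 + 2 * mu_eff / n_dim)
    c_1 = 2 / ((n_dim + 1.3) ** 2 + mu_eff)
    c_mu = min(1 - c_1,
               2 * (mu_eff - 2 + 1 / mu_eff) / ((n_dim + 2) ** 2 + mu_eff))
    return {"lam": lam, "mu": mu, "weights": weights, "mu_eff": mu_eff,
            "c_sigma": c_sigma, "d_sigma": d_sigma, "c_c": c_c,
            "c_1": c_1, "c_mu": c_mu}


# Auto-sizing as described above: lambda = budget // n_eval_points.
budget, n_eval_points = 512, 20
auto = cmaes_defaults(n_dim=10, lam=budget // n_eval_points)   # lam = 25
default = cmaes_defaults(n_dim=10)                             # lam = 10
```

Since the weights have length mu = λ // 2, reusing weights computed for a different λ would give the wrong length as well as the wrong learning rates.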
The default max_sets=64 was too low: the GPU saturated the search range
without hitting OOM, yielding a useless budget for high-n_eval trials.
Even 1024 still saturated without OOM on A100-class GPUs.
Replace the flat vmap forward-pass probe with one that binary-searches
on λ by running real CMA-ES (1 generation via train_on_historic_data).
This captures the true memory footprint of the fused lax.while_loop —
nested vmap, carry state, XLA constant-folding — instead of guessing
a safety margin.

Price data is loaded once and passed via price_data= to avoid repeated
disk I/O across binary-search steps.
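
The search itself is an ordinary binary search over a monotone predicate. In the sketch below the probe is faked with a fixed limit; in the PR it would correspond to running one real CMA-ES generation and checking whether it fits in memory. All names are hypothetical.

```python
def largest_fitting(fits, low, high):
    """Largest value in [low, high] for which fits(value) is True.

    Assumes fits is monotone: once a value fails, all larger values fail.
    Returns None when even low does not fit.
    """
    if not fits(low):
        return None
    while low < high:
        mid = (low + high + 1) // 2     # round up so the loop always progresses
        if fits(mid):
            low = mid
        else:
            high = mid - 1
    return low


def fake_probe(lam, limit=300):
    # Stand-in for one real CMA-ES generation that either runs or hits OOM.
    return lam <= limit


print(largest_fitting(fake_probe, 4, 1024))   # 300
```

Each probe is expensive, so the number of steps matters: a range of about a thousand candidates needs roughly ten probes.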
bulkcade and others added 29 commits February 27, 2026 18:28
JAX 0.6 changed the default for jax_threefry_partitionable from False
to True. This switches jax.random.split to a different algorithm that
produces entirely different subkeys from the same seed. Since training
uses random batch sampling (random.choice in get_indices), different
subkeys produce different training trajectories, causing pinned
objective values to drift by up to 10%.

Pin the original algorithm in conftest so tests are reproducible across
JAX versions. Production code is unaffected and uses whatever JAX
defaults to for the target hardware.
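
A minimal sketch of such a pin, assuming it sits at module level in conftest.py so that it runs before any test creates keys. The key and split below are only there to show usage.

```python
import jax

# Pin the pre-change key-splitting behaviour for the test session only.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.PRNGKey(42)
first = jax.random.split(key, 3)
second = jax.random.split(key, 3)   # same key, same flag: identical subkeys
```

With the flag fixed, pinned objective values depend only on the seed, not on which JAX default happens to be installed.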
Increase base_lr from 0.05 to 0.5 and change seed from 42 to 123 so
that training produces meaningful val-metric separation across
iterations. With the old config, val metrics varied by only ~1e-13
between steps, making best_iteration selection depend on FP noise at the
precision limit — the root cause of the flaky
test_partial_last_chunk_matches CI failure.

The new config gives ~0.59 val-metric separation, making best_iteration
deterministic regardless of scan chunking strategy. Both momentum and
mean_reversion_channel pinned objectives are re-pinned accordingly.
@MatthewWilletts MatthewWilletts merged commit 3a111c3 into main Mar 5, 2026
3 checks passed
2 participants