From 19ee13e0de06f5ff14e1ed8a0c0f73ee7473bd54 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Thu, 5 Mar 2026 14:40:36 +0000 Subject: [PATCH 1/3] initial implementation --- quantammsim/core_simulator/dynamic_inputs.py | 30 +- quantammsim/core_simulator/forward_pass.py | 3 + quantammsim/hooks/dynamic_fee_base_hook.py | 1 + quantammsim/pools/reCLAMM/reclamm.py | 2 + quantammsim/pools/reCLAMM/reclamm_reserves.py | 360 +++++++++++++++++- quantammsim/runners/jax_runner_utils.py | 198 ++++++++++ quantammsim/runners/jax_runners.py | 1 + .../finance/param_financial_calculator.py | 3 + tests/pools/reCLAMM/helpers.py | 15 + tests/pools/reCLAMM/test_reclamm_e2e.py | 2 +- .../pools/reCLAMM/test_reclamm_fee_revenue.py | 1 + tests/pools/reCLAMM/test_reclamm_math.py | 4 +- .../test_reclamm_price_ratio_updates.py | 221 +++++++++++ tests/unit/test_jax_runner_utils.py | 208 ++++++++++ tests/unit/test_jax_runners_comprehensive.py | 109 ++++++ 15 files changed, 1134 insertions(+), 24 deletions(-) create mode 100644 tests/pools/reCLAMM/helpers.py create mode 100644 tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py diff --git a/quantammsim/core_simulator/dynamic_inputs.py b/quantammsim/core_simulator/dynamic_inputs.py index c249088..598d3f7 100644 --- a/quantammsim/core_simulator/dynamic_inputs.py +++ b/quantammsim/core_simulator/dynamic_inputs.py @@ -13,6 +13,7 @@ class DynamicInputFrames: gas_cost: Optional[Any] = None arb_fees: Optional[Any] = None lp_supply: Optional[Any] = None + reclamm_price_ratio_updates: Optional[Any] = None class DynamicInputArrays(NamedTuple): @@ -23,6 +24,7 @@ class DynamicInputArrays(NamedTuple): gas_cost: jnp.ndarray arb_fees: jnp.ndarray lp_supply: jnp.ndarray + reclamm_price_ratio_updates: jnp.ndarray def default_dynamic_input_flags() -> dict: @@ -34,6 +36,7 @@ def default_dynamic_input_flags() -> dict: "has_dynamic_gas_cost": False, "has_dynamic_arb_fees": False, "has_lp_supply": False, + "has_reclamm_price_ratio_updates": False, } @@ -49,6 +52,9 @@ def dynamic_input_flags_from_frames(dynamic_input_frames: Optional[DynamicInputF "has_dynamic_gas_cost": dynamic_input_frames.gas_cost is not None, "has_dynamic_arb_fees": dynamic_input_frames.arb_fees is not None, "has_lp_supply": dynamic_input_frames.lp_supply is not None, + "has_reclamm_price_ratio_updates": ( + dynamic_input_frames.reclamm_price_ratio_updates is not None + ), } flags["use_dynamic_inputs"] = any(flags.values()) return flags @@ -59,11 +65,9 @@ def resolve_dynamic_input_flags( dynamic_input_flags: Optional[dict] = None, ) -> dict: """Return a safe dispatch flag set for the provided hot-path bundle.""" - flags = ( - default_dynamic_input_flags() - if dynamic_input_flags is None - else dict(dynamic_input_flags) - ) + flags = default_dynamic_input_flags() + if dynamic_input_flags is not None: + flags.update(dict(dynamic_input_flags)) if dynamic_inputs is not None: flags["use_dynamic_inputs"] = True return flags @@ -77,6 +81,10 @@ def empty_dynamic_input_arrays() -> DynamicInputArrays: gas_cost=jnp.zeros((1,), dtype=jnp.float64), arb_fees=jnp.zeros((1,), dtype=jnp.float64), lp_supply=jnp.ones((1,), dtype=jnp.float64), + # Columns: has_event, target_price_ratio, end_step, start_price_ratio_override + reclamm_price_ratio_updates=jnp.array( + [[0.0, 0.0, 0.0, jnp.nan]], dtype=jnp.float64 + ), ) @@ -109,6 +117,11 @@ def resolve_dynamic_input_components( if dynamic_input_flags["has_lp_supply"] else jnp.ones((1,), dtype=jnp.float64) ), + "reclamm_price_ratio_updates": ( + arrays.reclamm_price_ratio_updates + if dynamic_input_flags["has_reclamm_price_ratio_updates"] + else empty_dynamic_input_arrays().reclamm_price_ratio_updates + ), } @@ -148,6 +161,7 @@ def materialize_dynamic_inputs( "has_dynamic_gas_cost": True, "has_dynamic_arb_fees": True, "has_lp_supply": True, + "has_reclamm_price_ratio_updates": True, } else: flags = resolve_dynamic_input_flags(dynamic_inputs, dynamic_input_flags) @@ -174,4 +188,10 @@ def materialize_dynamic_inputs( lp_supply=_broadcast_dynamic_input_leaf( "lp_supply", resolved["lp_supply"], scan_len, dtype ), + reclamm_price_ratio_updates=_broadcast_dynamic_input_leaf( + "reclamm_price_ratio_updates", + resolved["reclamm_price_ratio_updates"], + scan_len, + dtype, + ), ) diff --git a/quantammsim/core_simulator/forward_pass.py b/quantammsim/core_simulator/forward_pass.py index d2a32cd..2125ff1 100644 --- a/quantammsim/core_simulator/forward_pass.py +++ b/quantammsim/core_simulator/forward_pass.py @@ -1113,6 +1113,9 @@ def forward_pass_nograd( gas_cost=stop_gradient(dynamic_inputs.gas_cost), arb_fees=stop_gradient(dynamic_inputs.arb_fees), lp_supply=stop_gradient(dynamic_inputs.lp_supply), + reclamm_price_ratio_updates=stop_gradient( + dynamic_inputs.reclamm_price_ratio_updates + ), ) return forward_pass( params, diff --git a/quantammsim/hooks/dynamic_fee_base_hook.py b/quantammsim/hooks/dynamic_fee_base_hook.py index ad64a5f..1ab0266 100644 --- a/quantammsim/hooks/dynamic_fee_base_hook.py +++ b/quantammsim/hooks/dynamic_fee_base_hook.py @@ -124,6 +124,7 @@ def calculate_reserves_with_fees( gas_cost=jnp.asarray(run_fingerprint["gas_cost"], dtype=jnp.float64), arb_fees=jnp.asarray(run_fingerprint["arb_fees"], dtype=jnp.float64), lp_supply=empty_inputs.lp_supply, + reclamm_price_ratio_updates=empty_inputs.reclamm_price_ratio_updates, ) return self.calculate_reserves_with_dynamic_inputs( diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index da36710..762301c 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -310,6 +310,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( fees=materialized_inputs.fees, arb_thresh=materialized_inputs.gas_cost, arb_fees=materialized_inputs.arb_fees, + price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), @@ -387,6 +388,7 @@ def calculate_reserves_with_dynamic_inputs( fees=materialized_inputs.fees, arb_thresh=materialized_inputs.gas_cost, arb_fees=materialized_inputs.arb_fees, + price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates, all_sig_variations=jnp.array( run_fingerprint["all_sig_variations"] ), diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 81ad48e..b098eb1 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -17,7 +17,7 @@ import jax.numpy as jnp from jax import jit -from jax.lax import scan +from jax.lax import scan, cond from jax.tree_util import Partial from functools import partial @@ -556,6 +556,40 @@ def initialise_reclamm_reserves(initial_pool_value, initial_prices, price_ratio) # Scan-based reserve calculations # --------------------------------------------------------------------------- +def apply_target_price_ratio_to_virtual_balances(Ra, Rb, Va, Vb, target_price_ratio): + """Retarget virtual balances to a desired price ratio while preserving orientation. + + The overvalued-side virtual balance is preserved (subject to floor), and the + undervalued-side virtual balance is solved from the reCLAMM ratio constraint. + """ + safe_ratio = jnp.maximum(target_price_ratio, 1.0 + 1e-12) + sqrt_ratio = jnp.sqrt(safe_ratio) + fourth_root_ratio = jnp.sqrt(sqrt_ratio) + centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) + + # Above center => B overvalued, so keep Vb and solve Va. + v_over_b_floor = Rb / jnp.maximum(fourth_root_ratio - 1.0, 1e-30) + Vb_kept = jnp.maximum(Vb, v_over_b_floor) + Va_from_b = Ra * (Vb_kept + Rb) / jnp.maximum( + (sqrt_ratio - 1.0) * Vb_kept - Rb, 1e-30 + ) + + # Below center => A overvalued, so keep Va and solve Vb. + v_over_a_floor = Ra / jnp.maximum(fourth_root_ratio - 1.0, 1e-30) + Va_kept = jnp.maximum(Va, v_over_a_floor) + Vb_from_a = Rb * (Va_kept + Ra) / jnp.maximum( + (sqrt_ratio - 1.0) * Va_kept - Ra, 1e-30 + ) + + Va_new = jnp.where(is_above, Va_from_b, Va_kept) + Vb_new = jnp.where(is_above, Vb_kept, Vb_from_a) + + # When centeredness is degenerate (e.g. both sides zero), preserve current virtuals. + invalid_centeredness = ~jnp.isfinite(centeredness) + Va_new = jnp.where(invalid_centeredness, Va, Va_new) + Vb_new = jnp.where(invalid_centeredness, Vb, Vb_new) + return Va_new, Vb_new + def _reclamm_scan_step_zero_fees( carry_list, prices, @@ -653,6 +687,13 @@ def _reclamm_scan_step_zero_fees( return [new_reserves, Va, Vb], new_reserves +# --------------------------------------------------------------------------- +# Test-only diagnostic helpers (virtual-balance history) +# --------------------------------------------------------------------------- +# These helpers mirror production kernels but additionally return Va/Vb +# trajectories for assertions in tests. Production pool paths should use the +# reserve-only kernels above. + def _reclamm_scan_step_zero_fees_full_state( carry_list, prices, @@ -662,7 +703,7 @@ def _reclamm_scan_step_zero_fees_full_state( arc_length_speed=0.0, centeredness_scaling=False, ): - """Like _reclamm_scan_step_zero_fees but outputs (reserves, Va, Vb).""" + """TEST-ONLY: scan step that outputs (reserves, Va, Vb).""" new_carry, new_reserves = _reclamm_scan_step_zero_fees( carry_list, prices, centeredness_margin, daily_price_shift_base, seconds_per_step, arc_length_speed=arc_length_speed, @@ -689,9 +730,10 @@ def _reclamm_scan_step_with_fees_and_revenue( Primary implementation — ``_reclamm_scan_step_with_fees`` wraps this. - Carry: [real_reserves (2,), Va (0-d), Vb (0-d)] + Carry: [real_reserves (2,), Va, Vb, step_idx, active_start_ratio, + active_target_ratio, active_start_step, active_end_step, active_enabled] Input: [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees] + all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_update] Returns ------- @@ -702,6 +744,12 @@ def _reclamm_scan_step_with_fees_and_revenue( prev_reserves = carry_list[0] Va = carry_list[1] Vb = carry_list[2] + step_idx = carry_list[3] + active_start_ratio = carry_list[4] + active_target_ratio = carry_list[5] + active_start_step = carry_list[6] + active_end_step = carry_list[7] + active_enabled = carry_list[8] Ra = prev_reserves[0] Rb = prev_reserves[1] @@ -713,6 +761,86 @@ def _reclamm_scan_step_with_fees_and_revenue( gamma = input_list[4] arb_thresh = input_list[5] arb_fees = input_list[6] + price_ratio_update = input_list[7] + + event_has = price_ratio_update[0] > 0.5 + event_target_ratio = jnp.maximum( + jnp.where(jnp.isfinite(price_ratio_update[1]), price_ratio_update[1], 1.0), + 1.0 + 1e-12, + ) + event_end_step = jnp.where( + jnp.isfinite(price_ratio_update[2]), price_ratio_update[2], step_idx + ) + event_start_override = price_ratio_update[3] + + def _apply_schedule_state(_): + current_price_ratio = compute_price_ratio(Ra, Rb, Va, Vb) + start_ratio_from_event = jnp.where( + jnp.isfinite(event_start_override), + event_start_override, + current_price_ratio, + ) + next_active_start_ratio = jnp.where( + event_has, start_ratio_from_event, active_start_ratio + ) + next_active_target_ratio = jnp.where( + event_has, event_target_ratio, active_target_ratio + ) + next_active_start_step = jnp.where(event_has, step_idx, active_start_step) + next_active_end_step = jnp.where( + event_has, jnp.maximum(event_end_step, step_idx), active_end_step + ) + next_active_enabled = jnp.where(event_has, True, active_enabled) + + schedule_duration = next_active_end_step - next_active_start_step + schedule_progress = jnp.where( + schedule_duration <= 0.0, + 1.0, + jnp.clip((step_idx - next_active_start_step) / schedule_duration, 0.0, 1.0), + ) + scheduled_price_ratio = ( + next_active_start_ratio + + (next_active_target_ratio - next_active_start_ratio) * schedule_progress + ) + Va_scheduled, Vb_scheduled = apply_target_price_ratio_to_virtual_balances( + Ra, Rb, Va, Vb, scheduled_price_ratio + ) + return ( + Va_scheduled, + Vb_scheduled, + next_active_start_ratio, + next_active_target_ratio, + next_active_start_step, + next_active_end_step, + next_active_enabled, + ) + + def _skip_schedule_state(_): + return ( + Va, + Vb, + active_start_ratio, + active_target_ratio, + active_start_step, + active_end_step, + active_enabled, + ) + + schedule_active = jnp.logical_or(event_has, active_enabled) + ( + Va, + Vb, + active_start_ratio, + active_target_ratio, + active_start_step, + active_end_step, + active_enabled, + ) = cond( + schedule_active, + _apply_schedule_state, + _skip_schedule_state, + operand=None, + ) # Step 1: Update virtual balances if out of range centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) @@ -827,7 +955,17 @@ def _reclamm_scan_step_with_fees_and_revenue( lp_fee_revenue_usd = (lp_fee_income * prices).sum() new_reserves = jnp.array([Ra_new, Rb_new]) - return [new_reserves, Va, Vb], (new_reserves, lp_fee_revenue_usd) + return [ + new_reserves, + Va, + Vb, + step_idx + 1.0, + active_start_ratio, + active_target_ratio, + active_start_step, + active_end_step, + active_enabled, + ], (new_reserves, lp_fee_revenue_usd) def _reclamm_scan_step_with_fees( @@ -865,6 +1003,37 @@ def _reclamm_scan_step_with_fees( return new_carry, new_reserves +def _reclamm_scan_step_with_fees_full_state( + carry_list, + input_list, + weights, + tokens_to_drop, + active_trade_directions, + n, + centeredness_margin, + daily_price_shift_base, + seconds_per_step, + arc_length_speed=0.0, + centeredness_scaling=False, + protocol_fee_split=0.0, +): + """TEST-ONLY: fee scan step that also outputs virtual balances.""" + new_carry, (new_reserves, _fee_rev) = _reclamm_scan_step_with_fees_and_revenue( + carry_list, input_list, + weights=weights, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + n=n, + centeredness_margin=centeredness_margin, + daily_price_shift_base=daily_price_shift_base, + seconds_per_step=seconds_per_step, + arc_length_speed=arc_length_speed, + centeredness_scaling=centeredness_scaling, + protocol_fee_split=protocol_fee_split, + ) + return new_carry, (new_reserves, new_carry[1], new_carry[2]) + + @jit def _jax_calc_reclamm_reserves_zero_fees( initial_reserves, @@ -929,7 +1098,7 @@ def _jax_calc_reclamm_reserves_zero_fees_full_state( arc_length_speed=0.0, centeredness_scaling=False, ): - """Like _jax_calc_reclamm_reserves_zero_fees but also returns virtual balances. + """TEST-ONLY: Like _jax_calc_reclamm_reserves_zero_fees but returns Va/Vb. Returns ------- @@ -992,6 +1161,8 @@ def _jax_calc_reclamm_reserves_with_fees( gamma_array = jnp.full(prices.shape[0], gamma) arb_thresh_array = jnp.full(prices.shape[0], arb_thresh) arb_fees_array = jnp.full(prices.shape[0], arb_fees) + price_ratio_updates = jnp.zeros((prices.shape[0], 4), dtype=prices.dtype) + price_ratio_updates = price_ratio_updates.at[:, 3].set(jnp.nan) scan_fn = Partial( _reclamm_scan_step_with_fees, @@ -1007,17 +1178,27 @@ def _jax_calc_reclamm_reserves_with_fees( protocol_fee_split=protocol_fee_split, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] _, reserves = scan( scan_fn, carry_init, [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array], + all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, price_ratio_updates], ) return reserves -@partial(jit, static_argnums=(10,)) +@partial(jit, static_argnums=(11,)) def _jax_calc_reclamm_reserves_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1029,6 +1210,7 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( fees, arb_thresh, arb_fees, + price_ratio_updates=None, do_trades=False, trades=None, all_sig_variations=None, @@ -1048,6 +1230,18 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees ) + if price_ratio_updates is None: + price_ratio_updates = jnp.zeros((prices.shape[0], 4), dtype=prices.dtype) + price_ratio_updates = price_ratio_updates.at[:, 3].set(jnp.nan) + else: + if price_ratio_updates.ndim == 1: + price_ratio_updates = jnp.broadcast_to( + price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[0]) + ) + elif price_ratio_updates.shape[0] == 1 and prices.shape[0] != 1: + price_ratio_updates = jnp.broadcast_to( + price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[1]) + ) _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) @@ -1074,16 +1268,115 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( protocol_fee_split=protocol_fee_split, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] _, reserves = scan( scan_fn, carry_init, [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees], + all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], ) return reserves +@partial(jit, static_argnums=(11,)) +def _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( + initial_reserves, + initial_Va, + initial_Vb, + prices, + centeredness_margin, + daily_price_shift_base, + seconds_per_step, + fees, + arb_thresh, + arb_fees, + price_ratio_updates=None, + do_trades=False, + trades=None, + all_sig_variations=None, + arc_length_speed=0.0, + centeredness_scaling=False, + protocol_fee_split=0.0, +): + """TEST-ONLY: dynamic-input reserve path returning virtual-balance history.""" + n_assets = 2 + weights = jnp.array([0.5, 0.5]) + + gamma = jnp.where(fees.size == 1, jnp.full(prices.shape[0], 1.0 - fees), 1.0 - fees) + arb_thresh = jnp.where( + arb_thresh.size == 1, jnp.full(prices.shape[0], arb_thresh), arb_thresh + ) + arb_fees = jnp.where( + arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees + ) + if price_ratio_updates is None: + price_ratio_updates = jnp.zeros((prices.shape[0], 4), dtype=prices.dtype) + price_ratio_updates = price_ratio_updates.at[:, 3].set(jnp.nan) + else: + if price_ratio_updates.ndim == 1: + price_ratio_updates = jnp.broadcast_to( + price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[0]) + ) + elif price_ratio_updates.shape[0] == 1 and prices.shape[0] != 1: + price_ratio_updates = jnp.broadcast_to( + price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[1]) + ) + + _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( + precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) + ) + + active_initial_weights, per_asset_ratios, all_other_assets_ratios = ( + precalc_components_of_optimal_trade_across_prices_and_dynamic_fees( + weights, prices, gamma, tokens_to_drop, + active_trade_directions, leave_one_out_idxs, + ) + ) + + scan_fn = Partial( + _reclamm_scan_step_with_fees_full_state, + weights=weights, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + n=n_assets, + centeredness_margin=centeredness_margin, + daily_price_shift_base=daily_price_shift_base, + seconds_per_step=seconds_per_step, + arc_length_speed=arc_length_speed, + centeredness_scaling=centeredness_scaling, + protocol_fee_split=protocol_fee_split, + ) + + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] + _, (reserves, Va_history, Vb_history) = scan( + scan_fn, + carry_init, + [prices, active_initial_weights, per_asset_ratios, + all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], + ) + return reserves, Va_history, Vb_history + + @jit def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( initial_reserves, @@ -1127,6 +1420,8 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( gamma_array = jnp.full(prices.shape[0], gamma) arb_thresh_array = jnp.full(prices.shape[0], arb_thresh) arb_fees_array = jnp.full(prices.shape[0], arb_fees) + price_ratio_updates = jnp.zeros((prices.shape[0], 4), dtype=prices.dtype) + price_ratio_updates = price_ratio_updates.at[:, 3].set(jnp.nan) scan_fn = Partial( _reclamm_scan_step_with_fees_and_revenue, @@ -1142,17 +1437,27 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( protocol_fee_split=protocol_fee_split, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] _, (reserves, fee_revenue) = scan( scan_fn, carry_init, [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array], + all_other_assets_ratios, gamma_array, arb_thresh_array, arb_fees_array, price_ratio_updates], ) return reserves, fee_revenue -@partial(jit, static_argnums=(10,)) +@partial(jit, static_argnums=(11,)) def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( initial_reserves, initial_Va, @@ -1164,6 +1469,7 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( fees, arb_thresh, arb_fees, + price_ratio_updates=None, do_trades=False, trades=None, all_sig_variations=None, @@ -1189,6 +1495,18 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( arb_fees = jnp.where( arb_fees.size == 1, jnp.full(prices.shape[0], arb_fees), arb_fees ) + if price_ratio_updates is None: + price_ratio_updates = jnp.zeros((prices.shape[0], 4), dtype=prices.dtype) + price_ratio_updates = price_ratio_updates.at[:, 3].set(jnp.nan) + else: + if price_ratio_updates.ndim == 1: + price_ratio_updates = jnp.broadcast_to( + price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[0]) + ) + elif price_ratio_updates.shape[0] == 1 and prices.shape[0] != 1: + price_ratio_updates = jnp.broadcast_to( + price_ratio_updates, (prices.shape[0], price_ratio_updates.shape[1]) + ) _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) @@ -1215,11 +1533,21 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( protocol_fee_split=protocol_fee_split, ) - carry_init = [initial_reserves, initial_Va, initial_Vb] + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), # step_idx + jnp.float64(0.0), # active_start_ratio + jnp.float64(0.0), # active_target_ratio + jnp.float64(0.0), # active_start_step + jnp.float64(0.0), # active_end_step + jnp.array(False), # active_enabled + ] _, (reserves, fee_revenue) = scan( scan_fn, carry_init, [prices, active_initial_weights, per_asset_ratios, - all_other_assets_ratios, gamma, arb_thresh, arb_fees], + all_other_assets_ratios, gamma, arb_thresh, arb_fees, price_ratio_updates], ) return reserves, fee_revenue diff --git a/quantammsim/runners/jax_runner_utils.py b/quantammsim/runners/jax_runner_utils.py index e9c74f9..a7598e1 100644 --- a/quantammsim/runners/jax_runner_utils.py +++ b/quantammsim/runners/jax_runner_utils.py @@ -1228,6 +1228,7 @@ def _to_dynamic_input_arrays( gas_cost_array, arb_fees_array, lp_supply_array, + reclamm_price_ratio_updates_array, ) -> DynamicInputArrays: """Normalize optional numpy arrays into the hot-path container.""" empty = empty_dynamic_input_arrays() @@ -1237,9 +1238,156 @@ def _to_dynamic_input_arrays( gas_cost=empty.gas_cost if gas_cost_array is None else jnp.asarray(gas_cost_array, dtype=jnp.float64), arb_fees=empty.arb_fees if arb_fees_array is None else jnp.asarray(arb_fees_array, dtype=jnp.float64), lp_supply=empty.lp_supply if lp_supply_array is None else jnp.asarray(lp_supply_array, dtype=jnp.float64), + reclamm_price_ratio_updates=( + empty.reclamm_price_ratio_updates + if reclamm_price_ratio_updates_array is None + else jnp.asarray(reclamm_price_ratio_updates_array, dtype=jnp.float64) + ), ) +def _coerce_reclamm_price_ratio_updates_to_frame(raw_updates) -> pd.DataFrame: + """Accept DataFrame / CSV path / list[dict] / dict payload and return a DataFrame.""" + if raw_updates is None: + return pd.DataFrame( + columns=["unix", "end_unix", "price_ratio", "start_price_ratio"] + ) + if isinstance(raw_updates, pd.DataFrame): + return raw_updates.copy() + if isinstance(raw_updates, (str, Path)): + return pd.read_csv(raw_updates) + if isinstance(raw_updates, list): + return pd.DataFrame(raw_updates) + if isinstance(raw_updates, dict): + for key in ( + "updates", + "rows", + "reclamm_price_ratio_updates", + "price_ratio_updates", + ): + value = raw_updates.get(key) + if isinstance(value, list): + return pd.DataFrame(value) + if isinstance(value, (str, Path)): + return pd.read_csv(value) + return pd.DataFrame(raw_updates) + raise TypeError( + "reclamm_price_ratio_updates must be a DataFrame, CSV path, list of dicts, or dict payload" + ) + + +def _ceil_div_nonnegative(delta: int, denom: int) -> int: + """Ceiling division for non-negative integers.""" + if delta <= 0: + return 0 + return (delta + denom - 1) // denom + + +def _normalize_reclamm_price_ratio_updates_for_window( + raw_updates, + start_date_string: str, + end_date_string: str, + arb_frequency: int, +) -> np.ndarray: + """Normalize manual reCLAMM price-ratio updates into per-step event rows. + + Output columns per step: + 0. has_event (0/1) + 1. target_price_ratio + 2. end_step + 3. start_price_ratio_override (NaN when not supplied) + """ + start_unix = pd.to_datetime(start_date_string, format="%Y-%m-%d %H:%M:%S").value // 10**6 + end_unix = pd.to_datetime(end_date_string, format="%Y-%m-%d %H:%M:%S").value // 10**6 + step_ms = int(arb_frequency) * 60 * 1000 + if step_ms <= 0: + raise ValueError("arb_frequency must be >= 1 for reCLAMM price-ratio updates") + scan_len = int(max((end_unix - start_unix) // step_ms, 0)) + default_matrix = np.zeros((scan_len, 4), dtype=np.float64) + if scan_len > 0: + default_matrix[:, 3] = np.nan + + updates_df = _coerce_reclamm_price_ratio_updates_to_frame(raw_updates) + if updates_df.empty: + return default_matrix + + required = {"unix", "end_unix", "price_ratio"} + missing = sorted(required.difference(updates_df.columns)) + if missing: + raise ValueError( + "reclamm_price_ratio_updates missing required columns: " + + ", ".join(missing) + ) + + updates = updates_df.copy() + updates["unix"] = pd.to_numeric(updates["unix"], errors="coerce") + updates["end_unix"] = pd.to_numeric(updates["end_unix"], errors="coerce") + updates["price_ratio"] = pd.to_numeric(updates["price_ratio"], errors="coerce") + if "start_price_ratio" in updates.columns: + updates["start_price_ratio"] = pd.to_numeric( + updates["start_price_ratio"], errors="coerce" + ) + else: + updates["start_price_ratio"] = np.nan + + invalid_required = updates["unix"].isna() | updates["end_unix"].isna() | updates["price_ratio"].isna() + if invalid_required.any(): + raise ValueError( + "reclamm_price_ratio_updates contains non-numeric unix/end_unix/price_ratio values" + ) + if (updates["price_ratio"] <= 1.0).any(): + raise ValueError("reclamm price_ratio values must be > 1.0") + if (updates["end_unix"] < updates["unix"]).any(): + raise ValueError("reclamm end_unix must be >= unix for every update") + + updates = updates.sort_values("unix", kind="stable") + + for _, row in updates.iterrows(): + event_start_unix = int(row["unix"]) + event_end_unix = int(row["end_unix"]) + target_price_ratio = float(row["price_ratio"]) + start_price_ratio_override = row["start_price_ratio"] + + # Event completes before window start - no effect. + if event_end_unix <= start_unix: + continue + + in_progress_pre_window = event_start_unix < start_unix and event_end_unix > start_unix + if in_progress_pre_window and pd.isna(start_price_ratio_override): + raise ValueError( + "reclamm pre-window in-progress event requires start_price_ratio" + ) + + effective_start_unix = max(event_start_unix, start_unix) + start_step = _ceil_div_nonnegative(effective_start_unix - start_unix, step_ms) + end_step = _ceil_div_nonnegative(event_end_unix - start_unix, step_ms) + + # Starts after current window. + if start_step >= scan_len: + continue + + end_step = min(max(end_step, start_step), scan_len - 1) + default_matrix[start_step, 0] = 1.0 + default_matrix[start_step, 1] = target_price_ratio + default_matrix[start_step, 2] = float(end_step) + default_matrix[start_step, 3] = ( + float(start_price_ratio_override) + if not pd.isna(start_price_ratio_override) + else np.nan + ) + + return default_matrix + + +def _has_reclamm_schedule_events(schedule_array: Optional[np.ndarray]) -> bool: + """Return True when a normalized schedule contains at least one event row.""" + if schedule_array is None: + return False + if schedule_array.size == 0: + return False + return bool(np.any(np.asarray(schedule_array)[:, 0] > 0.5)) + + def prepare_dynamic_inputs( run_fingerprint, dynamic_input_frames: Optional[DynamicInputFrames] = None, @@ -1254,6 +1402,7 @@ def prepare_dynamic_inputs( gas_cost_df = dynamic_input_frames.gas_cost arb_fees_df = dynamic_input_frames.arb_fees lp_supply_df = dynamic_input_frames.lp_supply + reclamm_price_ratio_updates = dynamic_input_frames.reclamm_price_ratio_updates dynamic_input_flags = dynamic_input_flags_from_frames(dynamic_input_frames) if raw_trades is not None: @@ -1369,6 +1518,51 @@ def prepare_dynamic_inputs( else None ) + reclamm_price_ratio_updates_array = ( + _normalize_reclamm_price_ratio_updates_for_window( + reclamm_price_ratio_updates, + run_fingerprint["startDateString"], + run_fingerprint["endDateString"], + run_fingerprint["arb_frequency"], + ) + if reclamm_price_ratio_updates is not None + else None + ) + if do_test_period: + test_reclamm_price_ratio_updates_array = ( + _normalize_reclamm_price_ratio_updates_for_window( + reclamm_price_ratio_updates, + run_fingerprint["endDateString"], + run_fingerprint["endTestDateString"], + run_fingerprint["arb_frequency"], + ) + if reclamm_price_ratio_updates is not None + else None + ) + + train_has_reclamm_schedule = _has_reclamm_schedule_events( + reclamm_price_ratio_updates_array + ) + test_has_reclamm_schedule = False + if do_test_period: + test_has_reclamm_schedule = _has_reclamm_schedule_events( + test_reclamm_price_ratio_updates_array + ) + + if not train_has_reclamm_schedule: + reclamm_price_ratio_updates_array = None + if do_test_period and not test_has_reclamm_schedule: + test_reclamm_price_ratio_updates_array = None + if not train_has_reclamm_schedule and ( + not do_test_period or not test_has_reclamm_schedule + ): + dynamic_input_flags["has_reclamm_price_ratio_updates"] = False + dynamic_input_flags["use_dynamic_inputs"] = any( + value + for key, value in dynamic_input_flags.items() + if key != "use_dynamic_inputs" + ) + # Unit LP supply is the neutral case; keep it on the static hot path. if lp_supply_array is not None and np.allclose(lp_supply_array, 1.0): lp_supply_array = None @@ -1388,6 +1582,7 @@ def prepare_dynamic_inputs( gas_cost_array, arb_fees_array, lp_supply_array, + reclamm_price_ratio_updates_array, ), "test_dynamic_inputs": _to_dynamic_input_arrays( test_period_trades, @@ -1395,6 +1590,7 @@ def prepare_dynamic_inputs( test_gas_cost_array, test_arb_fees_array, test_lp_supply_array, + test_reclamm_price_ratio_updates_array, ), "dynamic_input_flags": dynamic_input_flags, } @@ -1405,6 +1601,7 @@ def prepare_dynamic_inputs( gas_cost_array, arb_fees_array, lp_supply_array, + reclamm_price_ratio_updates_array, ), "dynamic_input_flags": dynamic_input_flags, } @@ -1657,6 +1854,7 @@ def try_forward_pass(n_sets: int) -> bool: "has_dynamic_gas_cost": False, "has_dynamic_arb_fees": False, "has_lp_supply": False, + "has_reclamm_price_ratio_updates": False, }, }, ) diff --git a/quantammsim/runners/jax_runners.py b/quantammsim/runners/jax_runners.py index aad11bd..cbc8cdc 100644 --- a/quantammsim/runners/jax_runners.py +++ b/quantammsim/runners/jax_runners.py @@ -657,6 +657,7 @@ def train_on_historic_data( "has_dynamic_gas_cost": False, "has_dynamic_arb_fees": False, "has_lp_supply": False, + "has_reclamm_price_ratio_updates": False, }, }, ) diff --git a/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py b/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py index aa31519..cc000b0 100644 --- a/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py +++ b/quantammsim/simulator_analysis_tools/finance/param_financial_calculator.py @@ -243,6 +243,9 @@ def run_pool_simulation(simulationRunDto): trades=raw_trades, fees=fee_steps_df, gas_cost=gas_cost_df, + reclamm_price_ratio_updates=run_fingerprint.get( + "reclamm_price_ratio_updates" + ), ) print("run fingerprint-------------------", run_fingerprint) diff --git a/tests/pools/reCLAMM/helpers.py b/tests/pools/reCLAMM/helpers.py new file mode 100644 index 0000000..4d06f70 --- /dev/null +++ b/tests/pools/reCLAMM/helpers.py @@ -0,0 +1,15 @@ +"""Test-only exports for reCLAMM diagnostic kernels. + +This module centralizes access to kernels that return virtual-balance history. +Production paths should use reserve-only kernels in ``reclamm_reserves``. +""" + +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state, + _jax_calc_reclamm_reserves_zero_fees_full_state, +) + +__all__ = [ + "_jax_calc_reclamm_reserves_zero_fees_full_state", + "_jax_calc_reclamm_reserves_with_dynamic_inputs_full_state", +] diff --git a/tests/pools/reCLAMM/test_reclamm_e2e.py b/tests/pools/reCLAMM/test_reclamm_e2e.py index 25edc98..80eeef6 100644 --- a/tests/pools/reCLAMM/test_reclamm_e2e.py +++ b/tests/pools/reCLAMM/test_reclamm_e2e.py @@ -30,8 +30,8 @@ initialise_reclamm_reserves, _jax_calc_reclamm_reserves_zero_fees, _jax_calc_reclamm_reserves_with_fees, - _jax_calc_reclamm_reserves_zero_fees_full_state, ) +from tests.pools.reCLAMM.helpers import _jax_calc_reclamm_reserves_zero_fees_full_state ALL_SIG_VARIATIONS_2 = jnp.array([[1, -1], [-1, 1]]) diff --git a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py index 4a31f1f..bfbd21b 100644 --- a/tests/pools/reCLAMM/test_reclamm_fee_revenue.py +++ b/tests/pools/reCLAMM/test_reclamm_fee_revenue.py @@ -355,6 +355,7 @@ def test_pool_method_with_dynamic_inputs(self): gas_cost=arb_thresh_array, arb_fees=arb_fees_array, lp_supply=jnp.ones((1,)), + reclamm_price_ratio_updates=jnp.array([[0.0, 0.0, 0.0, jnp.nan]]), ) reserves, fee_revenue = pool.calculate_reserves_and_fee_revenue_with_dynamic_inputs( diff --git a/tests/pools/reCLAMM/test_reclamm_math.py b/tests/pools/reCLAMM/test_reclamm_math.py index 6f3870d..b1ce8ee 100644 --- a/tests/pools/reCLAMM/test_reclamm_math.py +++ b/tests/pools/reCLAMM/test_reclamm_math.py @@ -729,9 +729,9 @@ def test_arc_length_single_step_exact(self): def test_arc_length_constant_through_scan(self): """Through the scan, per-step Δs should be approximately constant.""" - from quantammsim.pools.reCLAMM.reclamm_reserves import ( + from quantammsim.pools.reCLAMM.reclamm_reserves import calibrate_arc_length_speed + from tests.pools.reCLAMM.helpers import ( _jax_calc_reclamm_reserves_zero_fees_full_state, - calibrate_arc_length_speed, ) Ra, Rb, Va, Vb, Q = _centered_pool(P=2.0, price_ratio=4.0) diff --git a/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py b/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py new file mode 100644 index 0000000..5ee98bc --- /dev/null +++ b/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py @@ -0,0 +1,221 @@ +"""Tests for manual reCLAMM price-ratio schedule updates.""" + +import numpy as np +import numpy.testing as npt +import jax.numpy as jnp +import pytest + +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + compute_price_ratio, + initialise_reclamm_reserves, + _jax_calc_reclamm_reserves_with_dynamic_inputs, + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs, +) +from tests.pools.reCLAMM.helpers import ( + _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state, +) + + +DEFAULT_INITIAL_POOL_VALUE = 1_000_000.0 +DEFAULT_INITIAL_PRICES = jnp.array([2500.0, 1.0], dtype=jnp.float64) +DEFAULT_PRICE_RATIO = 4.0 +DEFAULT_DAILY_PRICE_SHIFT_BASE = 1.0 - 1.0 / 124000.0 +DEFAULT_SECONDS_PER_STEP = 60.0 +ALL_SIG_VARIATIONS_2 = tuple(map(tuple, [[1, -1], [-1, 1]])) + + +def _init_pool(price_ratio=DEFAULT_PRICE_RATIO): + reserves, Va, Vb = initialise_reclamm_reserves( + DEFAULT_INITIAL_POOL_VALUE, + DEFAULT_INITIAL_PRICES, + price_ratio, + ) + return reserves, Va, Vb + + +def _flat_prices(n_steps): + return jnp.stack( + [jnp.full((n_steps,), DEFAULT_INITIAL_PRICES[0]), jnp.ones((n_steps,))], + axis=1, + ) + + +def _empty_schedule(n_steps): + schedule = np.zeros((n_steps, 4), dtype=np.float64) + schedule[:, 3] = np.nan + return jnp.asarray(schedule) + + +def _single_event_schedule( + n_steps, + start_step, + end_step, + target_price_ratio, + start_price_ratio_override=np.nan, +): + schedule = np.zeros((n_steps, 4), dtype=np.float64) + schedule[:, 3] = np.nan + schedule[start_step, 0] = 1.0 + schedule[start_step, 1] = target_price_ratio + schedule[start_step, 2] = float(end_step) + schedule[start_step, 3] = start_price_ratio_override + return jnp.asarray(schedule) + + +class TestReclammPriceRatioUpdates: + def test_schedule_off_matches_baseline_dynamic_kernel(self): + reserves, Va, Vb = _init_pool() + n_steps = 8 + prices = _flat_prices(n_steps) + fees = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.zeros((n_steps,), dtype=jnp.float64) + + baseline = _jax_calc_reclamm_reserves_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + with_schedule = _jax_calc_reclamm_reserves_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=_empty_schedule(n_steps), + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + npt.assert_allclose(with_schedule, baseline, rtol=1e-10, atol=1e-10) + + def test_single_schedule_reaches_target_ratio_at_end_step(self): + reserves, Va, Vb = _init_pool() + n_steps = 8 + prices = _flat_prices(n_steps) + fees = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.zeros((n_steps,), dtype=jnp.float64) + + end_step = 4 + schedule = _single_event_schedule( + n_steps, + start_step=1, + end_step=end_step, + target_price_ratio=9.0, + start_price_ratio_override=DEFAULT_PRICE_RATIO, + ) + reserves_out, Va_history, Vb_history = ( + _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.0, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + ) + + ratio_at_end = float( + compute_price_ratio( + reserves_out[end_step, 0], + reserves_out[end_step, 1], + Va_history[end_step], + Vb_history[end_step], + ) + ) + assert ratio_at_end == pytest.approx(9.0, rel=1e-5, abs=1e-5) + + def test_replacement_event_supersedes_active_event(self): + reserves, Va, Vb = _init_pool() + n_steps = 9 + prices = _flat_prices(n_steps) + fees = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.zeros((n_steps,), dtype=jnp.float64) + + schedule = np.zeros((n_steps, 4), dtype=np.float64) + schedule[:, 3] = np.nan + # Event 1: interpolate toward 8.0 until step 6. + schedule[1] = np.array([1.0, 8.0, 6.0, DEFAULT_PRICE_RATIO], dtype=np.float64) + # Event 2 replaces at step 3 and targets 2.0 by step 4. + schedule[3] = np.array([1.0, 2.0, 4.0, np.nan], dtype=np.float64) + + reserves_out, Va_history, Vb_history = ( + _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.0, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=jnp.asarray(schedule), + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + ) + + ratio_after_replacement = float( + compute_price_ratio( + reserves_out[4, 0], + reserves_out[4, 1], + Va_history[4], + Vb_history[4], + ) + ) + assert ratio_after_replacement == pytest.approx(2.0, rel=1e-4, abs=1e-4) + + def test_dynamic_fee_revenue_path_with_schedule(self): + reserves, Va, Vb = _init_pool() + n_steps = 10 + prices = _flat_prices(n_steps) + fees = jnp.full((n_steps,), 0.003, dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.zeros((n_steps,), dtype=jnp.float64) + + schedule = _single_event_schedule( + n_steps, + start_step=2, + end_step=6, + target_price_ratio=5.5, + start_price_ratio_override=DEFAULT_PRICE_RATIO, + ) + reserves_out, fee_revenue = ( + _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.2, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + ) + assert reserves_out.shape == (n_steps, 2) + assert fee_revenue.shape == (n_steps,) + assert jnp.all(fee_revenue >= 0.0) diff --git a/tests/unit/test_jax_runner_utils.py b/tests/unit/test_jax_runner_utils.py index 620ecb8..2e014f5 100644 --- a/tests/unit/test_jax_runner_utils.py +++ b/tests/unit/test_jax_runner_utils.py @@ -289,6 +289,7 @@ def test_dynamic_input_flags_reflect_present_frames(self): assert flags["has_dynamic_gas_cost"] is True assert flags["has_dynamic_arb_fees"] is False assert flags["has_lp_supply"] is False + assert flags["has_reclamm_price_ratio_updates"] is False def test_prepare_dynamic_inputs_preserves_fixed_hot_path_structure(self): """Normalization should return fixed bundles plus static dispatch flags.""" @@ -333,18 +334,203 @@ def test_prepare_dynamic_inputs_preserves_fixed_hot_path_structure(self): assert flags["has_dynamic_gas_cost"] is True assert flags["has_dynamic_arb_fees"] is True assert flags["has_lp_supply"] is True + assert flags["has_reclamm_price_ratio_updates"] is False assert train_inputs.trades.shape == (2, 3) assert train_inputs.fees.shape == (2,) assert train_inputs.gas_cost.shape == (2,) assert train_inputs.arb_fees.shape == (2,) assert train_inputs.lp_supply.shape == (2,) + assert train_inputs.reclamm_price_ratio_updates.shape == (1, 4) assert test_inputs.trades.shape == (2, 3) assert test_inputs.fees.shape == (2,) assert test_inputs.gas_cost.shape == (2,) assert test_inputs.arb_fees.shape == (2,) assert test_inputs.lp_supply.shape == (2,) + assert test_inputs.reclamm_price_ratio_updates.shape == (1, 4) np.testing.assert_allclose(np.asarray(train_inputs.fees), np.array([0.003, 0.003])) + def test_prepare_dynamic_inputs_normalizes_reclamm_price_ratio_updates(self): + """Manual reCLAMM update schedules should map to per-step event rows.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "arb_frequency": 1, + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:05:00", + "endTestDateString": "2023-01-01 00:06:00", + } + start_unix = pd.Timestamp(run_fingerprint["startDateString"]).value // 10**6 + updates = pd.DataFrame( + { + "unix": [start_unix + 30_000, start_unix + 40_000], + "end_unix": [start_unix + 150_000, start_unix + 220_000], + "price_ratio": [4.0, 6.0], + } + ) + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=updates + ), + ) + flags = prepared["dynamic_input_flags"] + schedule = np.asarray( + prepared["train_dynamic_inputs"].reclamm_price_ratio_updates + ) + + assert flags["has_reclamm_price_ratio_updates"] is True + assert flags["use_dynamic_inputs"] is True + assert schedule.shape == (5, 4) + # Same-step collision should keep the later event. + assert schedule[1, 0] == pytest.approx(1.0) + assert schedule[1, 1] == pytest.approx(6.0) + assert schedule[1, 2] == pytest.approx(4.0) + assert np.isnan(schedule[1, 3]) + assert np.all(schedule[[0, 2, 3, 4], 0] == 0.0) + + def test_prepare_dynamic_inputs_accepts_reclamm_updates_as_list(self): + """List payloads should be accepted for reCLAMM update schedules.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "arb_frequency": 1, + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:03:00", + "endTestDateString": "2023-01-01 00:04:00", + } + start_unix = pd.Timestamp(run_fingerprint["startDateString"]).value // 10**6 + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=[ + { + "unix": start_unix, + "end_unix": start_unix + 120_000, + "price_ratio": 5.0, + "start_price_ratio": 4.25, + } + ] + ), + ) + schedule = np.asarray( + prepared["train_dynamic_inputs"].reclamm_price_ratio_updates + ) + assert schedule.shape == (3, 4) + assert schedule[0, 0] == pytest.approx(1.0) + assert schedule[0, 1] == pytest.approx(5.0) + assert schedule[0, 2] == pytest.approx(2.0) + assert schedule[0, 3] == pytest.approx(4.25) + + def test_prepare_dynamic_inputs_accepts_reclamm_updates_as_csv_path(self, tmp_path): + """CSV payloads should be accepted for reCLAMM update schedules.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "arb_frequency": 1, + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:03:00", + "endTestDateString": "2023-01-01 00:04:00", + } + start_unix = pd.Timestamp(run_fingerprint["startDateString"]).value // 10**6 + csv_path = tmp_path / "reclamm_updates.csv" + pd.DataFrame( + [ + { + "unix": start_unix + 60_000, + "end_unix": start_unix + 180_000, + "price_ratio": 6.0, + "start_price_ratio": 4.1, + } + ] + ).to_csv(csv_path, index=False) + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=str(csv_path) + ), + ) + schedule = np.asarray( + prepared["train_dynamic_inputs"].reclamm_price_ratio_updates + ) + assert schedule.shape == (3, 4) + assert schedule[1, 0] == pytest.approx(1.0) + assert schedule[1, 1] == pytest.approx(6.0) + assert schedule[1, 2] == pytest.approx(2.0) + assert schedule[1, 3] == pytest.approx(4.1) + + def test_prepare_dynamic_inputs_rejects_prewindow_reclamm_events_without_start_ratio(self): + """In-progress pre-window events must provide start_price_ratio.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "arb_frequency": 1, + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:05:00", + "endTestDateString": "2023-01-01 00:06:00", + } + start_unix = pd.Timestamp(run_fingerprint["startDateString"]).value // 10**6 + + with pytest.raises(ValueError, match="start_price_ratio"): + prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=pd.DataFrame( + { + "unix": [start_unix - 120_000], + "end_unix": [start_unix + 120_000], + "price_ratio": [4.5], + } + ) + ), + ) + + def test_prepare_dynamic_inputs_disables_reclamm_schedule_flag_when_window_has_no_events(self): + """Schedule flag should be disabled when no event lands in train/test windows.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "arb_frequency": 1, + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:05:00", + "endTestDateString": "2023-01-01 00:10:00", + } + start_unix = pd.Timestamp(run_fingerprint["startDateString"]).value // 10**6 + updates = pd.DataFrame( + { + "unix": [start_unix - 300_000], + "end_unix": [start_unix - 60_000], + "price_ratio": [5.0], + "start_price_ratio": [4.0], + } + ) + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=updates + ), + do_test_period=True, + ) + + flags = prepared["dynamic_input_flags"] + assert flags["has_reclamm_price_ratio_updates"] is False + assert flags["use_dynamic_inputs"] is False + assert prepared["train_dynamic_inputs"].reclamm_price_ratio_updates.shape == (1, 4) + assert prepared["test_dynamic_inputs"].reclamm_price_ratio_updates.shape == (1, 4) + def test_prepare_dynamic_inputs_uses_correct_test_period_values(self): """Test-period arrays should use values effective from the test window onward.""" from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames @@ -413,6 +599,7 @@ def test_resolve_dynamic_input_flags_promotes_explicit_bundle(self): "has_dynamic_gas_cost": False, "has_dynamic_arb_fees": False, "has_lp_supply": False, + "has_reclamm_price_ratio_updates": False, }, ) @@ -450,6 +637,7 @@ def test_resolve_dynamic_input_components_prefers_dynamic_values(self): gas_cost=jnp.array([3.0]), arb_fees=jnp.array([0.0003]), lp_supply=jnp.array([1500.0]), + reclamm_price_ratio_updates=jnp.array([[1.0, 4.0, 3.0, jnp.nan]]), ) flags = { "use_dynamic_inputs": True, @@ -458,6 +646,7 @@ def test_resolve_dynamic_input_components_prefers_dynamic_values(self): "has_dynamic_gas_cost": True, "has_dynamic_arb_fees": True, "has_lp_supply": True, + "has_reclamm_price_ratio_updates": True, } resolved = resolve_dynamic_input_components( @@ -471,6 +660,11 @@ def test_resolve_dynamic_input_components_prefers_dynamic_values(self): np.testing.assert_allclose(np.asarray(resolved["gas_cost"]), np.array([3.0])) np.testing.assert_allclose(np.asarray(resolved["arb_fees"]), np.array([0.0003])) np.testing.assert_allclose(np.asarray(resolved["lp_supply"]), np.array([1500.0])) + np.testing.assert_allclose( + np.asarray(resolved["reclamm_price_ratio_updates"]), + np.array([[1.0, 4.0, 3.0, np.nan]]), + equal_nan=True, + ) def test_materialize_dynamic_inputs_leaves_trades_optional(self): """No-trade paths should not expand placeholder trades into the scan inputs.""" @@ -488,6 +682,7 @@ def test_materialize_dynamic_inputs_leaves_trades_optional(self): "has_dynamic_gas_cost": False, "has_dynamic_arb_fees": False, "has_lp_supply": False, + "has_reclamm_price_ratio_updates": False, }, static_dict={"fees": 0.003, "gas_cost": 2.5, "arb_fees": 0.0001}, scan_len=4, @@ -499,6 +694,18 @@ def test_materialize_dynamic_inputs_leaves_trades_optional(self): np.testing.assert_allclose(np.asarray(materialized.gas_cost), np.full(4, 2.5)) np.testing.assert_allclose(np.asarray(materialized.arb_fees), np.full(4, 0.0001)) np.testing.assert_allclose(np.asarray(materialized.lp_supply), np.ones(4)) + np.testing.assert_allclose( + np.asarray(materialized.reclamm_price_ratio_updates), + np.array( + [ + [0.0, 0.0, 0.0, np.nan], + [0.0, 0.0, 0.0, np.nan], + [0.0, 0.0, 0.0, np.nan], + [0.0, 0.0, 0.0, np.nan], + ] + ), + equal_nan=True, + ) def test_materialize_dynamic_inputs_requires_trades_when_enabled(self): """Trade-enabled scans should fail fast if no trade path is available.""" @@ -517,6 +724,7 @@ def test_materialize_dynamic_inputs_requires_trades_when_enabled(self): "has_dynamic_gas_cost": False, "has_dynamic_arb_fees": False, "has_lp_supply": False, + "has_reclamm_price_ratio_updates": False, }, static_dict={"fees": 0.003, "gas_cost": 0.0, "arb_fees": 0.0}, scan_len=2, diff --git a/tests/unit/test_jax_runners_comprehensive.py b/tests/unit/test_jax_runners_comprehensive.py index 995e6d8..5269022 100644 --- a/tests/unit/test_jax_runners_comprehensive.py +++ b/tests/unit/test_jax_runners_comprehensive.py @@ -606,6 +606,115 @@ def test_provided_coarse_weights_respect_scalar_and_dynamic_gas(self, defaulted_ atol=1e-6, ) + def test_reclamm_schedule_only_dynamic_input_changes_path(self, defaulted_run_fingerprint): + """Schedule-only reCLAMM dynamic inputs should alter reserve trajectory.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["rule"] = "reclamm" + fp["do_arb"] = True + fp["fees"] = 0.0 + fp["gas_cost"] = 0.0 + fp["arb_fees"] = 0.0 + fp["reclamm_interpolation_method"] = "geometric" + params = { + "price_ratio": jnp.array(4.0), + "centeredness_margin": jnp.array(0.2), + "daily_price_shift_base": jnp.array(1.0 - 1.0 / 124000.0), + } + + start_unix = pd.Timestamp(fp["startDateString"]).value // 10**6 + schedule_df = pd.DataFrame( + { + "unix": [start_unix + 2 * 60_000], + "end_unix": [start_unix + 8 * 60_000], + "price_ratio": [7.5], + } + ) + + baseline = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + ) + with_schedule = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=schedule_df + ), + ) + + assert not np.allclose( + np.asarray(baseline["reserves"]), + np.asarray(with_schedule["reserves"]), + ) + + def test_reclamm_schedule_test_period_does_not_leak_into_train(self, defaulted_run_fingerprint): + """Test-only schedule updates should affect test path only.""" + fp = deepcopy(defaulted_run_fingerprint) + fp["rule"] = "reclamm" + fp["do_arb"] = True + fp["fees"] = 0.0 + fp["gas_cost"] = 0.0 + fp["arb_fees"] = 0.0 + params = { + "price_ratio": jnp.array(4.0), + "centeredness_margin": jnp.array(0.2), + "daily_price_shift_base": jnp.array(1.0 - 1.0 / 124000.0), + } + + test_start_unix = pd.Timestamp(fp["endDateString"]).value // 10**6 + baseline_updates = pd.DataFrame( + { + "unix": [test_start_unix + 60_000], + "end_unix": [test_start_unix + 6 * 60_000], + # Keep baseline on the initial ratio so it is a no-op schedule. + "price_ratio": [4.0], + "start_price_ratio": [4.0], + } + ) + test_only_updates = pd.DataFrame( + { + "unix": [test_start_unix + 60_000], + "end_unix": [test_start_unix + 6 * 60_000], + "price_ratio": [8.0], + } + ) + + train_base, test_base = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + do_test_period=True, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=baseline_updates + ), + ) + train_sched, test_sched = do_run_on_historic_data( + fp, + params=params, + root=TEST_DATA_DIR, + verbose=False, + do_test_period=True, + dynamic_input_frames=DynamicInputFrames( + reclamm_price_ratio_updates=test_only_updates + ), + ) + + np.testing.assert_allclose( + np.asarray(train_sched["value"]), + np.asarray(train_base["value"]), + rtol=1e-6, + atol=1e-6, + ) + assert not np.allclose( + np.asarray(test_sched["value"]), + np.asarray(test_base["value"]), + ) + # ============================================================================ # Validation and Early Stopping Tests From 955dca3055868178d04f3d47caf2bac1002921c0 Mon Sep 17 00:00:00 2001 From: christian harrington Date: Thu, 5 Mar 2026 15:51:50 +0000 Subject: [PATCH 2/3] add fixes for linear interpolation bug and tests --- quantammsim/pools/reCLAMM/reclamm_reserves.py | 27 +++-- .../test_reclamm_price_ratio_updates.py | 107 +++++++++++++++++- 2 files changed, 126 insertions(+), 8 deletions(-) diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index b098eb1..5f1ac85 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -791,6 +791,9 @@ def _apply_schedule_state(_): event_has, jnp.maximum(event_end_step, step_idx), active_end_step ) next_active_enabled = jnp.where(event_has, True, active_enabled) + next_active_enabled = jnp.logical_and( + next_active_enabled, step_idx <= next_active_end_step + ) schedule_duration = next_active_end_step - next_active_start_step schedule_progress = jnp.where( @@ -798,16 +801,22 @@ def _apply_schedule_state(_): 1.0, jnp.clip((step_idx - next_active_start_step) / schedule_duration, 0.0, 1.0), ) - scheduled_price_ratio = ( - next_active_start_ratio - + (next_active_target_ratio - next_active_start_ratio) * schedule_progress + safe_start_ratio = jnp.maximum(next_active_start_ratio, 1.0 + 1e-12) + safe_target_ratio = jnp.maximum(next_active_target_ratio, 1.0 + 1e-12) + scheduled_price_ratio = safe_start_ratio * ( + safe_target_ratio / safe_start_ratio + ) ** schedule_progress + scheduled_price_ratio = jnp.where( + next_active_enabled, scheduled_price_ratio, current_price_ratio ) Va_scheduled, Vb_scheduled = apply_target_price_ratio_to_virtual_balances( Ra, Rb, Va, Vb, scheduled_price_ratio ) + next_Va = jnp.where(next_active_enabled, Va_scheduled, Va) + next_Vb = jnp.where(next_active_enabled, Vb_scheduled, Vb) return ( - Va_scheduled, - Vb_scheduled, + next_Va, + next_Vb, next_active_start_ratio, next_active_target_ratio, next_active_start_step, @@ -816,6 +825,9 @@ def _apply_schedule_state(_): ) def _skip_schedule_state(_): + retained_active_enabled = jnp.logical_and( + active_enabled, step_idx <= active_end_step + ) return ( Va, Vb, @@ -823,10 +835,11 @@ def _skip_schedule_state(_): active_target_ratio, active_start_step, active_end_step, - active_enabled, + retained_active_enabled, ) - schedule_active = jnp.logical_or(event_has, active_enabled) + active_not_expired = jnp.logical_and(active_enabled, step_idx <= active_end_step) + schedule_active = jnp.logical_or(event_has, active_not_expired) ( Va, Vb, diff --git a/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py b/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py index 5ee98bc..db24aca 100644 --- a/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py +++ b/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py @@ -21,7 +21,7 @@ DEFAULT_PRICE_RATIO = 4.0 DEFAULT_DAILY_PRICE_SHIFT_BASE = 1.0 - 1.0 / 124000.0 DEFAULT_SECONDS_PER_STEP = 60.0 -ALL_SIG_VARIATIONS_2 = tuple(map(tuple, [[1, -1], [-1, 1]])) +ALL_SIG_VARIATIONS_2 = jnp.array([[1, -1], [-1, 1]]) def _init_pool(price_ratio=DEFAULT_PRICE_RATIO): @@ -143,6 +143,111 @@ def test_single_schedule_reaches_target_ratio_at_end_step(self): ) assert ratio_at_end == pytest.approx(9.0, rel=1e-5, abs=1e-5) + def test_schedule_interpolates_geometrically_in_ratio_space(self): + reserves, Va, Vb = _init_pool() + n_steps = 8 + prices = _flat_prices(n_steps) + fees = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.zeros((n_steps,), dtype=jnp.float64) + + start_step = 1 + end_step = 5 + start_ratio = 4.0 + target_ratio = 16.0 + schedule = _single_event_schedule( + n_steps, + start_step=start_step, + end_step=end_step, + target_price_ratio=target_ratio, + start_price_ratio_override=start_ratio, + ) + + reserves_out, Va_history, Vb_history = ( + _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.0, + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + ) + + mid_step = 3 + ratio_at_mid = float( + compute_price_ratio( + reserves_out[mid_step, 0], + reserves_out[mid_step, 1], + Va_history[mid_step], + Vb_history[mid_step], + ) + ) + progress = (mid_step - start_step) / (end_step - start_step) + expected_geometric = start_ratio * (target_ratio / start_ratio) ** progress + assert ratio_at_mid == pytest.approx(expected_geometric, rel=1e-5, abs=1e-5) + + def test_schedule_stops_applying_after_end_step(self): + reserves, Va, Vb = _init_pool() + n_steps = 10 + prices = jnp.stack( + [jnp.linspace(DEFAULT_INITIAL_PRICES[0], 5000.0, n_steps), jnp.ones((n_steps,))], + axis=1, + ) + fees = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_thresh = jnp.zeros((n_steps,), dtype=jnp.float64) + arb_fees = jnp.zeros((n_steps,), dtype=jnp.float64) + + end_step = 3 + schedule = _single_event_schedule( + n_steps, + start_step=1, + end_step=end_step, + target_price_ratio=9.0, + start_price_ratio_override=DEFAULT_PRICE_RATIO, + ) + reserves_out, Va_history, Vb_history = ( + _jax_calc_reclamm_reserves_with_dynamic_inputs_full_state( + reserves, + Va, + Vb, + prices, + centeredness_margin=0.0, # disable thermostat to isolate schedule behavior + daily_price_shift_base=DEFAULT_DAILY_PRICE_SHIFT_BASE, + seconds_per_step=DEFAULT_SECONDS_PER_STEP, + fees=fees, + arb_thresh=arb_thresh, + arb_fees=arb_fees, + price_ratio_updates=schedule, + all_sig_variations=ALL_SIG_VARIATIONS_2, + ) + ) + + # Reserves continue evolving under changing market prices... + assert not np.allclose( + np.asarray(reserves_out[end_step + 1]), + np.asarray(reserves_out[end_step + 2]), + ) + # ...but virtual balances should be frozen once the schedule has ended. + npt.assert_allclose( + np.asarray(Va_history[end_step + 1 :]), + np.full((n_steps - (end_step + 1),), float(Va_history[end_step])), + rtol=1e-9, + atol=1e-9, + ) + npt.assert_allclose( + np.asarray(Vb_history[end_step + 1 :]), + np.full((n_steps - (end_step + 1),), float(Vb_history[end_step])), + rtol=1e-9, + atol=1e-9, + ) + def test_replacement_event_supersedes_active_event(self): reserves, Va, Vb = _init_pool() n_steps = 9 From 8371fc43ef9dbc9c6882b31802e5408b95d0b6a3 Mon Sep 17 00:00:00 2001 From: MatthewWilletts <16085750+MatthewWilletts@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:10:30 +0000 Subject: [PATCH 3/3] fix: use contract's centeredness-preserving formula for price ratio updates Replace the ad-hoc "keep overvalued, solve undervalued" virtual balance recalculation with the closed-form quadratic from ReClammMath.sol computeVirtualBalancesUpdatingPriceRatio. The old code silently drove centeredness to 1.0 for off-center pools. Add parametrized test mirroring the Foundry fuzz test testCalculateVirtualBalancesUpdatingPriceRatio__Fuzz, asserting that centeredness is preserved and the target price ratio is achieved. --- quantammsim/pools/reCLAMM/reclamm_reserves.py | 44 +++++++++++-------- .../test_reclamm_price_ratio_updates.py | 44 +++++++++++++++++++ 2 files changed, 69 insertions(+), 19 deletions(-) diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 5f1ac85..3825f11 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -557,32 +557,38 @@ def initialise_reclamm_reserves(initial_pool_value, initial_prices, price_ratio) # --------------------------------------------------------------------------- def apply_target_price_ratio_to_virtual_balances(Ra, Rb, Va, Vb, target_price_ratio): - """Retarget virtual balances to a desired price ratio while preserving orientation. + """Retarget virtual balances to a desired price ratio while preserving centeredness. - The overvalued-side virtual balance is preserved (subject to floor), and the - undervalued-side virtual balance is solved from the reCLAMM ratio constraint. + Uses the closed-form quadratic solution from ReClammMath.sol + ``computeVirtualBalancesUpdatingPriceRatio``: + + Vu = Ru * (1 + C + sqrt(1 + C*(C + 4*Q0 - 2))) / (2*(Q0 - 1)) + Vo = Vu * lastVo / lastVu + + where Q0 = sqrt(price_ratio), C = centeredness, Ru is the real balance of + the undervalued token. The overvalued virtual balance is then scaled + proportionally so that Va/Vb is preserved, which keeps centeredness constant. """ safe_ratio = jnp.maximum(target_price_ratio, 1.0 + 1e-12) - sqrt_ratio = jnp.sqrt(safe_ratio) - fourth_root_ratio = jnp.sqrt(sqrt_ratio) + Q0 = jnp.sqrt(safe_ratio) # sqrt(price_ratio) centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) + C = centeredness - # Above center => B overvalued, so keep Vb and solve Va. - v_over_b_floor = Rb / jnp.maximum(fourth_root_ratio - 1.0, 1e-30) - Vb_kept = jnp.maximum(Vb, v_over_b_floor) - Va_from_b = Ra * (Vb_kept + Rb) / jnp.maximum( - (sqrt_ratio - 1.0) * Vb_kept - Rb, 1e-30 - ) + # Closed-form quadratic solution for the undervalued virtual balance. + discriminant = jnp.maximum(1.0 + C * (C + 4.0 * Q0 - 2.0), 0.0) + numerator_factor = 1.0 + C + jnp.sqrt(discriminant) + denominator = 2.0 * jnp.maximum(Q0 - 1.0, 1e-30) - # Below center => A overvalued, so keep Va and solve Vb. - v_over_a_floor = Ra / jnp.maximum(fourth_root_ratio - 1.0, 1e-30) - Va_kept = jnp.maximum(Va, v_over_a_floor) - Vb_from_a = Rb * (Va_kept + Ra) / jnp.maximum( - (sqrt_ratio - 1.0) * Va_kept - Ra, 1e-30 - ) + # Above center: A is undervalued (Ra abundant), B is overvalued. + Vu_above = Ra * numerator_factor / denominator # new Va + Vo_above = Vu_above * Vb / jnp.maximum(Va, 1e-30) # new Vb, scaled + + # Below center: B is undervalued (Rb abundant), A is overvalued. + Vu_below = Rb * numerator_factor / denominator # new Vb + Vo_below = Vu_below * Va / jnp.maximum(Vb, 1e-30) # new Va, scaled - Va_new = jnp.where(is_above, Va_from_b, Va_kept) - Vb_new = jnp.where(is_above, Vb_kept, Vb_from_a) + Va_new = jnp.where(is_above, Vu_above, Vo_below) + Vb_new = jnp.where(is_above, Vo_above, Vu_below) # When centeredness is degenerate (e.g. both sides zero), preserve current virtuals. invalid_centeredness = ~jnp.isfinite(centeredness) diff --git a/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py b/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py index db24aca..955d16d 100644 --- a/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py +++ b/tests/pools/reCLAMM/test_reclamm_price_ratio_updates.py @@ -6,6 +6,8 @@ import pytest from quantammsim.pools.reCLAMM.reclamm_reserves import ( + apply_target_price_ratio_to_virtual_balances, + compute_centeredness, compute_price_ratio, initialise_reclamm_reserves, _jax_calc_reclamm_reserves_with_dynamic_inputs, @@ -290,6 +292,48 @@ def test_replacement_event_supersedes_active_event(self): ) assert ratio_after_replacement == pytest.approx(2.0, rel=1e-4, abs=1e-4) + @pytest.mark.parametrize( + "Ra, Rb, Va, Vb, target_price_ratio", + [ + # Centered pool, widen ratio + (500_000.0, 500_000.0, 100_000.0, 100_000.0, 9.0), + # Centered pool, narrow ratio + (500_000.0, 500_000.0, 100_000.0, 100_000.0, 2.0), + # Above center (Ra abundant) + (800_000.0, 200_000.0, 100_000.0, 100_000.0, 16.0), + # Below center (Rb abundant) + (200_000.0, 800_000.0, 100_000.0, 100_000.0, 16.0), + # Asymmetric virtuals + (300_000.0, 600_000.0, 50_000.0, 200_000.0, 5.0), + # Large ratio change + (500_000.0, 500_000.0, 100_000.0, 100_000.0, 100.0), + ], + ) + def test_price_ratio_update_preserves_centeredness( + self, Ra, Rb, Va, Vb, target_price_ratio + ): + """Mirrors Foundry fuzz test testCalculateVirtualBalancesUpdatingPriceRatio__Fuzz. + + Asserts that apply_target_price_ratio_to_virtual_balances preserves + centeredness and achieves the target price ratio. + """ + Ra, Rb = jnp.float64(Ra), jnp.float64(Rb) + Va, Vb = jnp.float64(Va), jnp.float64(Vb) + + old_centeredness, _ = compute_centeredness(Ra, Rb, Va, Vb) + Va_new, Vb_new = apply_target_price_ratio_to_virtual_balances( + Ra, Rb, Va, Vb, target_price_ratio + ) + new_centeredness, _ = compute_centeredness(Ra, Rb, Va_new, Vb_new) + new_price_ratio = float(compute_price_ratio(Ra, Rb, Va_new, Vb_new)) + + assert float(new_centeredness) == pytest.approx( + float(old_centeredness), rel=1e-6, abs=1e-10 + ), "Centeredness should be preserved" + assert new_price_ratio == pytest.approx( + target_price_ratio, rel=1e-4 + ), "Price ratio should match target" + def test_dynamic_fee_revenue_path_with_schedule(self): reserves, Va, Vb = _init_pool() n_steps = 10