Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions quantammsim/core_simulator/dynamic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class DynamicInputFrames:
gas_cost: Optional[Any] = None
arb_fees: Optional[Any] = None
lp_supply: Optional[Any] = None
reclamm_price_ratio_updates: Optional[Any] = None


class DynamicInputArrays(NamedTuple):
Expand All @@ -23,6 +24,7 @@ class DynamicInputArrays(NamedTuple):
gas_cost: jnp.ndarray
arb_fees: jnp.ndarray
lp_supply: jnp.ndarray
reclamm_price_ratio_updates: jnp.ndarray


def default_dynamic_input_flags() -> dict:
Expand All @@ -34,6 +36,7 @@ def default_dynamic_input_flags() -> dict:
"has_dynamic_gas_cost": False,
"has_dynamic_arb_fees": False,
"has_lp_supply": False,
"has_reclamm_price_ratio_updates": False,
}


Expand All @@ -49,6 +52,9 @@ def dynamic_input_flags_from_frames(dynamic_input_frames: Optional[DynamicInputF
"has_dynamic_gas_cost": dynamic_input_frames.gas_cost is not None,
"has_dynamic_arb_fees": dynamic_input_frames.arb_fees is not None,
"has_lp_supply": dynamic_input_frames.lp_supply is not None,
"has_reclamm_price_ratio_updates": (
dynamic_input_frames.reclamm_price_ratio_updates is not None
),
}
flags["use_dynamic_inputs"] = any(flags.values())
return flags
Expand All @@ -59,11 +65,9 @@ def resolve_dynamic_input_flags(
dynamic_input_flags: Optional[dict] = None,
) -> dict:
"""Return a safe dispatch flag set for the provided hot-path bundle."""
flags = (
default_dynamic_input_flags()
if dynamic_input_flags is None
else dict(dynamic_input_flags)
)
flags = default_dynamic_input_flags()
if dynamic_input_flags is not None:
flags.update(dict(dynamic_input_flags))
if dynamic_inputs is not None:
flags["use_dynamic_inputs"] = True
return flags
Expand All @@ -77,6 +81,10 @@ def empty_dynamic_input_arrays() -> DynamicInputArrays:
gas_cost=jnp.zeros((1,), dtype=jnp.float64),
arb_fees=jnp.zeros((1,), dtype=jnp.float64),
lp_supply=jnp.ones((1,), dtype=jnp.float64),
# Columns: has_event, target_price_ratio, end_step, start_price_ratio_override
reclamm_price_ratio_updates=jnp.array(
[[0.0, 0.0, 0.0, jnp.nan]], dtype=jnp.float64
),
)


Expand Down Expand Up @@ -109,6 +117,11 @@ def resolve_dynamic_input_components(
if dynamic_input_flags["has_lp_supply"]
else jnp.ones((1,), dtype=jnp.float64)
),
"reclamm_price_ratio_updates": (
arrays.reclamm_price_ratio_updates
if dynamic_input_flags["has_reclamm_price_ratio_updates"]
else empty_dynamic_input_arrays().reclamm_price_ratio_updates
),
}


Expand Down Expand Up @@ -148,6 +161,7 @@ def materialize_dynamic_inputs(
"has_dynamic_gas_cost": True,
"has_dynamic_arb_fees": True,
"has_lp_supply": True,
"has_reclamm_price_ratio_updates": True,
}
else:
flags = resolve_dynamic_input_flags(dynamic_inputs, dynamic_input_flags)
Expand All @@ -174,4 +188,10 @@ def materialize_dynamic_inputs(
lp_supply=_broadcast_dynamic_input_leaf(
"lp_supply", resolved["lp_supply"], scan_len, dtype
),
reclamm_price_ratio_updates=_broadcast_dynamic_input_leaf(
"reclamm_price_ratio_updates",
resolved["reclamm_price_ratio_updates"],
scan_len,
dtype,
),
)
3 changes: 3 additions & 0 deletions quantammsim/core_simulator/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,9 @@ def forward_pass_nograd(
gas_cost=stop_gradient(dynamic_inputs.gas_cost),
arb_fees=stop_gradient(dynamic_inputs.arb_fees),
lp_supply=stop_gradient(dynamic_inputs.lp_supply),
reclamm_price_ratio_updates=stop_gradient(
dynamic_inputs.reclamm_price_ratio_updates
),
)
return forward_pass(
params,
Expand Down
1 change: 1 addition & 0 deletions quantammsim/hooks/dynamic_fee_base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def calculate_reserves_with_fees(
gas_cost=jnp.asarray(run_fingerprint["gas_cost"], dtype=jnp.float64),
arb_fees=jnp.asarray(run_fingerprint["arb_fees"], dtype=jnp.float64),
lp_supply=empty_inputs.lp_supply,
reclamm_price_ratio_updates=empty_inputs.reclamm_price_ratio_updates,
)

return self.calculate_reserves_with_dynamic_inputs(
Expand Down
2 changes: 2 additions & 0 deletions quantammsim/pools/reCLAMM/reclamm.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs(
fees=materialized_inputs.fees,
arb_thresh=materialized_inputs.gas_cost,
arb_fees=materialized_inputs.arb_fees,
price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates,
all_sig_variations=jnp.array(
run_fingerprint["all_sig_variations"]
),
Expand Down Expand Up @@ -387,6 +388,7 @@ def calculate_reserves_with_dynamic_inputs(
fees=materialized_inputs.fees,
arb_thresh=materialized_inputs.gas_cost,
arb_fees=materialized_inputs.arb_fees,
price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates,
all_sig_variations=jnp.array(
run_fingerprint["all_sig_variations"]
),
Expand Down
Loading