diff --git a/.codex b/.codex new file mode 100644 index 0000000..e69de29 diff --git a/docs/joint_calibration_analysis.md b/docs/joint_calibration_analysis.md new file mode 100644 index 0000000..dee2760 --- /dev/null +++ b/docs/joint_calibration_analysis.md @@ -0,0 +1,230 @@ +# Joint Calibration Analysis: MLP Capacity vs Identification + +## Training results (2026-03-10) + +### R2 progression across model architectures + +| Attempt | Architecture | Noise params | Total params | Median R2 | Joint loss | +|---|---|---|---|---|---| +| Structural MoE (numpyro) | 3 archetypes x 8 coeffs | 24 | ~40 | -0.70 | - | +| Linear joint | SharedLinearNoiseHead | 63 | 63 | -0.15 | 9.62 | +| MLP noise joint | MLPNoiseHead(hidden=16) | 255 | 255 | 0.01 | 9.39 | +| Full MLP | MLPHead(cad,16) + MLPNoiseHead(16) | 255 | 377 | -0.02 | 8.59 | +| Option C (per-pool) | PerPoolNoiseHead | 37x8=296 | 37x9=333 | 0.61 | 1.25 (median) | + +The direction is clear: more capacity on the noise side helps substantially +(-0.70 -> -0.15 -> 0.01). Adding cadence capacity (MLP noise -> full MLP) +improved joint loss (9.39 -> 8.59) but not per-pool R2 (-0.02), and didn't +converge within 500 iterations. + +### Convergence concern + +The full MLP explicitly failed to converge (scipy `success=False`). The MLP +noise model converged but only reduced loss from 12.70 to 9.39 — a 26% +reduction vs the linear baseline's 99.5% reduction (2011.77 -> 9.62). This +suggests the MLPs are undertraining. + +Current optimizer settings: +- L-BFGS-B with maxiter=500 +- ftol=1e-10, gtol=1e-8 +- maxcor=10 (L-BFGS memory, scipy default) +- alpha=0.01 for all heads (L2 regularization on weights) +- hidden=16 for all MLPs +- He init for W1, W2=0, b2=pooled OLS / mean of Option C + +### Why the MLPs may not be converging + +1. **maxiter=500 is low for 255-377 params.** L-BFGS-B typically needs + O(1000-5000) iterations for MLP-scale problems. 
The linear model with + 63 params converges easily in 500; the MLP with 377 params does not. + +2. **maxcor=10 may be too small.** The default L-BFGS memory of 10 past + gradients may not provide a good enough Hessian approximation for 377 + parameters. Increasing to 20-50 can help. + +3. **Regularization alpha=0.01 may be wrong.** With 37 pools and 255 noise + params, the model is overparameterized (255/37 ≈ 7 params per pool). + alpha=0.01 might be too weak (overfitting some pools, underfitting + others) or too strong (preventing the MLP from expressing the necessary + nonlinearity). This is the most important hyperparameter to sweep. + +4. **W2=0 initialization creates a flat starting surface.** Since the MLP + starts as a constant function (output = b2 everywhere), L-BFGS-B must + first learn to differentiate between pools. With W2=0 the initial + gradient with respect to W1 is exactly zero (backprop scales it by W2), + so the first steps can only move W2 and b2; W1 starts to adapt once W2 + is nonzero. Early iterations may therefore be slow compared to the + linear model, which starts from an OLS warm-start. + +5. **Dead ReLU units.** With He init and k_attr=6 features, some hidden + units may have all-negative pre-activations across the 37 pool + attribute vectors, making them permanently dead with zero gradient. + +6. **Per-pool loss weighting.** All observations contribute equally. + USDC/WETH (1757 obs) dominates RDNT/WETH (89 obs) by 20x. The + optimizer may be fitting a few high-obs pools at the expense of many + low-obs ones. + +## Diagnosis: identification vs convergence + +Two distinct problems: + +1. **Convergence problem** (addressable via hyperparameters): + The MLP isn't reaching its minimum. Fix: more iterations, better + hyperparameters, multiple restarts. + +2. **Identification problem** (addressable via architecture): + Even at the minimum, the shared mapping can't match per-pool R2. + 37 pools is tiny for a nonlinear model. Cadence is idiosyncratic. + Fix: DeltaHead (per-pool residuals with shrinkage), better features.
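The dead-ReLU hypothesis (item 5 in the list of convergence suspects) is cheap to check before running any sweep. The sketch below draws a He-normal `W1` for a `hidden=16`, `k_attr=6` layer and counts units whose pre-activation is non-positive on every pool. It assumes zero hidden biases and standardized attributes, neither of which is stated in this document, and `count_dead_relu_units` is a hypothetical helper, not part of quantammsim.

```python
import numpy as np


def count_dead_relu_units(x_attr, hidden=16, seed=0):
    """Count hidden units that are inactive for every row of x_attr.

    Assumptions (not confirmed by the source): W1 ~ N(0, 2 / fan_in)
    (He normal) and a zero hidden bias b1.
    """
    rng = np.random.default_rng(seed)
    k_attr = x_attr.shape[1]
    w1 = rng.normal(0.0, np.sqrt(2.0 / k_attr), size=(k_attr, hidden))
    b1 = np.zeros(hidden)
    pre_act = x_attr @ w1 + b1               # shape (n_pools, hidden)
    dead = np.all(pre_act <= 0.0, axis=0)    # ReLU outputs 0 on all pools
    return int(dead.sum()), dead


# Synthetic stand-in for the 37 x 6 standardized pool attribute matrix.
x_demo = np.random.default_rng(1).standard_normal((37, 6))
n_dead, dead_mask = count_dead_relu_units(x_demo)
print(f"{n_dead} of {dead_mask.size} hidden units are dead at init")
```

Running this over several seeds on the real `x_attr` would give a rough idea of how much of the hidden layer is usable at the start of training.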
+ +These are **independent** problems that compound. We should fix convergence +first (hyperparameter sweep) to understand the true capacity of the current +architecture before adding structural complexity. + +## Hyperparameter sweep design + +### Parameters to sweep + +| Parameter | Current | Sweep values | Rationale | +|---|---|---|---| +| maxiter | 500 | 500, 2000, 5000 | Primary convergence bottleneck | +| alpha (noise) | 0.01 | 0.0001, 0.001, 0.01, 0.1 | Controls overfitting vs underfitting | +| alpha (cadence) | 0.01 | 0.001, 0.01, 0.1 | Separate from noise reg | +| hidden | 16 | 8, 16, 32 | Capacity vs overfitting | +| maxcor | 10 | 10, 30 | L-BFGS Hessian quality | +| loss_type | l2 | l2, huber | Outlier robustness | + +### Sweep strategy + +Full grid is 3 x 4 x 3 x 3 x 2 x 2 = 432 runs. Too many. + +**Phase 1: Fix convergence (1D sweeps)** +- Sweep maxiter = [500, 2000, 5000] with defaults. Cheapest diagnostic. +- If 5000 converges, use that going forward. + +**Phase 2: Regularization (most important)** +- alpha_noise x alpha_cad grid: 4 x 3 = 12 runs at converged maxiter. +- Evaluate both joint loss AND per-pool median R2. + +**Phase 3: Architecture** +- hidden = [8, 16, 32] at best alpha settings: 3 runs. +- loss_type = [l2, huber] at best settings: 2 runs. +- maxcor = [10, 30] at best settings: 2 runs. + +Total: ~22 runs, each ~2-5 min = ~1-2 hours. 
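The run counts in the sweep strategy can be reproduced mechanically, which is handy if the sweep values change. The sketch below enumerates the full grid from the parameter table and counts the phased plan. The dictionary keys (`alpha_noise`, `alpha_cad`, and so on) are illustrative names, not the actual argument names of the training script, and the helpers are hypothetical.

```python
from itertools import product

# Sweep values copied from the parameter table; key names are illustrative.
SWEEP = {
    "maxiter": [500, 2000, 5000],
    "alpha_noise": [0.0001, 0.001, 0.01, 0.1],
    "alpha_cad": [0.001, 0.01, 0.1],
    "hidden": [8, 16, 32],
    "maxcor": [10, 30],
    "loss_type": ["l2", "huber"],
}

# Current settings, used as the baseline for one-dimensional sweeps.
DEFAULTS = {
    "maxiter": 500,
    "alpha_noise": 0.01,
    "alpha_cad": 0.01,
    "hidden": 16,
    "maxcor": 10,
    "loss_type": "l2",
}


def full_grid(sweep):
    """Every combination of every parameter (the exhaustive grid)."""
    keys = list(sweep)
    return [dict(zip(keys, values)) for values in product(*sweep.values())]


def phased_run_count(sweep):
    """Number of runs in the three-phase plan."""
    phase1 = len(sweep["maxiter"])
    phase2 = len(sweep["alpha_noise"]) * len(sweep["alpha_cad"])
    phase3 = len(sweep["hidden"]) + len(sweep["loss_type"]) + len(sweep["maxcor"])
    return phase1 + phase2 + phase3


def phase2_configs(sweep, defaults, maxiter):
    """Regularization grid at a fixed (converged) maxiter."""
    return [
        {**defaults, "maxiter": maxiter, "alpha_noise": a_n, "alpha_cad": a_c}
        for a_n, a_c in product(sweep["alpha_noise"], sweep["alpha_cad"])
    ]


print(len(full_grid(SWEEP)))                                # 432
print(phased_run_count(SWEEP))                              # 22
print(len(phase2_configs(SWEEP, DEFAULTS, maxiter=5000)))   # 12
```

At 2-5 minutes per run, 22 runs come to roughly 44-110 minutes, consistent with the 1-2 hour estimate.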
+ +### Metrics to track per run + +- Joint loss (final) +- Joint loss (init) — sanity check +- Converged (bool) +- Number of L-BFGS iterations used +- Per-pool median R2 +- Per-pool mean R2 +- Per-pool R2 distribution (10th, 25th, 50th, 75th, 90th percentiles) +- Wall time + +### What success looks like + +- Converged = True for the full MLP +- Joint loss < 8.0 (below current 8.59) +- Per-pool median R2 > 0.3 (closing the gap toward Option C's 0.61) +- The R2 improvement should be spread across pools, not concentrated + +## Features / data that would help + +### Missing pool attributes (from docs) + +Current features (k_attr=6 after chain dummy removal): +log_fee, mean_log_tvl, log_mcap_product, has_stable, same_asset_type, +weight_imbalance. + +These describe what the pool IS but not the market around it. Cadence is +driven by arbitrage frequency, which depends on: + +| Missing feature | Why it matters | Source | Effort | +|---|---|---|---| +| Block time | Directly limits minimum cadence. Arb=0.25s vs Main=12s | Static per chain | Trivial | +| Mean pair volatility | Pool-level (not obs-level) vol predicts arb intensity | Binance minute data (loaded) | Small | +| CEX daily volume | More CEX vol = more arb opportunities | Binance API | Medium | +| Competing DEX pools | More pools for same pair = faster arb | Balancer subgraph | Medium | +| Pool routing share | Dominant pool gets arbitraged first | DEX aggregator data | Hard | +| Mean daily swap count | Direct proxy for pool activity | Panel data | Small | + +The pair-intrinsic formula bias (1.26-2.22x) documented in +noise_calibration_review.md is the largest unexplained variance source. +It varies with pair liquidity characteristics in ways that the current +token classification doesn't capture. CEX volume/depth would help. 
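Block time is the cheapest feature in the table to add, since it is a static lookup per chain. A minimal sketch, assuming the attribute matrix is a NumPy array with one row per pool and that a per-pool chain label is available. Only the Arbitrum and mainnet block times below come from the table; any other chain must be filled in from a trusted source, and `append_log_block_time` is a hypothetical helper, not an existing quantammsim function.

```python
import numpy as np

# Block times in seconds. Only these two values appear in the table above;
# add other chains from a trusted source before using this on the full panel.
BLOCK_TIME_S = {"arbitrum": 0.25, "mainnet": 12.0}


def append_log_block_time(x_attr, chains, block_time_s=BLOCK_TIME_S):
    """Return x_attr with log(block time) appended as a final column."""
    unknown = sorted({c for c in chains if c not in block_time_s})
    if unknown:
        # Fail loudly rather than silently imputing a block time.
        raise KeyError(f"no block time configured for chains: {unknown}")
    log_bt = np.log([block_time_s[c] for c in chains])
    return np.column_stack([x_attr, log_bt])


x_attr = np.zeros((2, 6))  # placeholder for the k_attr=6 features
x_aug = append_log_block_time(x_attr, ["arbitrum", "mainnet"])
print(x_aug.shape)  # (2, 7)
```

Taking the log keeps the feature on a scale comparable to a log-cadence target; whether it should also be standardized depends on how `x_attr` is built.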
+ +### Observation-level features (x_obs, K_OBS=8) + +Current: [1, log_tvl_lag1, log_sigma, tvl*sigma, tvl*fee, sigma*fee, +dow_sin, dow_cos] + +Missing: +- Rolling CEX volume (daily): high volume days have more noise/organic flow +- Gas price that day (mainnet): affects whether arbs execute +- Market regime (rolling momentum): trending vs mean-reverting +- Number of swaps that day: direct activity measure + +### Time-varying dynamics + +Panel spans 2021-2026. MEV dynamics changed dramatically: +- Flashbots launched mid-2021 +- L2s matured 2023-2024 +- EIP-4844 (March 2024) dropped L2 gas costs + +The current model assumes constant cadence per pool over this period. + +## Structural improvements (post-sweep) + +### DeltaHead (per-pool residuals with shrinkage) + +Most important structural change. For cadence: +``` +log_cadence_i = f(x_attr_i) + delta_i +regularization: alpha_shared * ||W||^2 + alpha_delta * sum(delta_i^2) +``` + +At alpha_delta=0: pure per-pool (Option C) +At alpha_delta=inf: pure shared (current joint) +Cross-validate alpha_delta. + +For new pools: predict f(x_attr_new) with delta=0. + +This is essentially a mixed-effects model fitted end-to-end through the +grid interpolation loss. + +### Per-pool loss weighting + +Weight each pool's contribution by 1/sqrt(n_obs_i) to reduce the imbalance +in pool-level influence. Currently USDC/WETH (1757 obs) has 20x the +influence of any Sonic pool (89 obs); 1/sqrt(n_obs_i) weights shrink that +gap to about 4.4x, while 1/n_obs_i weights would equalize it exactly. + +### Hybrid: per-pool cadence + shared noise + +Cadence is idiosyncratic (LOO R2 = 0.24 at best). Noise structure is +more regular (hierarchical model R2 = 0.71 on total volume). Natural split: +- Cadence: per-pool (Option C) +- Noise: shared MLP (generalizable) +- Gas: fixed to chain values + +### Sensitivity analysis (the decision point) + +Before investing more in mapping improvement: does reCLAMM optimal +concentration change materially when cadence varies +/-50%?
This is +recommendation #1 in calibration_results.md, noise_calibration_review.md, +and joint_calibration_design.md. Still not done. + +If the optimum is robust, the current pipeline (Option C + Ridge LOO) is +already sufficient and further mapping improvement is nice-to-have. + +## Priority order + +1. **Hyperparameter sweep** — fix convergence before changing architecture +2. **DeltaHead** — if R2 gap persists post-sweep, this is the minimal + structural change +3. **Per-pool loss weighting** — simple fix, helps all joint models +4. **Add block_time and mean_pair_volatility** — high-signal, low-effort + features +5. **Sensitivity analysis** — the real decision point for whether any of + this matters for the downstream task diff --git a/experiments/fetch_competitor_tvl.py b/experiments/fetch_competitor_tvl.py new file mode 100644 index 0000000..7ad3cb3 --- /dev/null +++ b/experiments/fetch_competitor_tvl.py @@ -0,0 +1,513 @@ +"""Fetch competitor TVL for each token pair from DeFi Llama. + +For each of our 36 calibration pools, finds all other DEX pools trading +the same token pair, sums their daily TVL, and saves as a time series. 
+ +K_i(t) = sum_{j != i} TVL_j(t) for all pools trading pool i's pair + +Output: results/competitor_tvl/competitor_tvl.npz + - pool_ids: list of our pool IDs + - date_list: array of dates as strings + - competitor_tvl: (n_dates, n_pools) array of daily competitor TVL in USD + - k_eff: (n_dates, n_pools) array of direct plus multi-hop competitor TVL + +Usage: + python experiments/fetch_competitor_tvl.py + python experiments/fetch_competitor_tvl.py --cache-dir results/competitor_tvl +""" + +import argparse +import json +import os +import pickle +import sys +import time +from collections import defaultdict + +import numpy as np +import pandas as pd + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + +# Map Balancer token names to DeFi Llama symbol conventions +# DeFi Llama symbols are typically uppercase, no dots, no "W" prefix inconsistencies +SYMBOL_MAP = { + "WETH": "WETH", + "WBTC": "WBTC", + "wstETH": "WSTETH", + "waEthLidowstETH": "WSTETH", + "waEthLidoWETH": "WETH", + "waGnowstETH": "WSTETH", + "waGnoGNO": "GNO", + "waBasUSDC": "USDC", + "waBasWETH": "WETH", + "sDAI": "DAI", + "scUSD": "USDC", + "stS": "S", + "JitoSOL": "JITOSOL", + # Common DeFi Llama variants + "USDC.e": "USDC", + "USDT.e": "USDT", + "WETH.e": "WETH", + "WBTC.e": "WBTC", +} + + +def _normalize_symbol(token): + """Normalize Balancer token name to DeFi Llama symbol.""" + return SYMBOL_MAP.get(token, token.upper()) + + +def _fetch_json(url, retries=5, delay=3.0): + """Fetch JSON from URL with exponential backoff.""" + import urllib.request + for attempt in range(retries): + try: + req = urllib.request.Request(url) + req.add_header("User-Agent", "quantammsim/1.0") + with urllib.request.urlopen(req, timeout=30) as resp: + return json.loads(resp.read().decode()) + except Exception as e: + wait = delay * (2 ** attempt) # 3, 6, 12, 24s (no sleep after the final attempt) + if attempt < retries - 1: + print(f" Retry {attempt+1} (wait {wait:.0f}s): {e}") + time.sleep(wait) + else: + raise + + +def 
fetch_all_pools(local_path=None): + """Load DeFi Llama yield pools from local file or API.""" + if local_path and os.path.exists(local_path): + print(f"Loading DeFi Llama pools from {local_path}...") + with open(local_path) as f: + data = json.load(f) + else: + print("Fetching DeFi Llama pool list from API...") + data = _fetch_json("https://yields.llama.fi/pools") + pools = data.get("data", []) if isinstance(data, dict) else data + print(f" {len(pools)} pools") + return pools + + +# Map Balancer chain names to DeFi Llama chain names +CHAIN_MAP = { + "mainnet": "Ethereum", + "ethereum": "Ethereum", + "arbitrum": "Arbitrum", + "polygon": "Polygon", + "gnosis": "Gnosis", + "base": "Base", + "optimism": "Optimism", + "avalanche": "Avalanche", + "sonic": "Sonic", +} + + +def match_pools(our_pools, llama_pools): + """Match our token pairs to DeFi Llama pools. + + Returns dict: pool_id -> {pair_key, chain, llama_pools_same_chain, + llama_pools_all_chains, tokens} + """ + # Normalize DeFi Llama token symbols to match our convention + LLAMA_NORMALIZE = { + "ETH": "WETH", + "BTC": "WBTC", + "STETH": "WSTETH", + } + + # Index llama pools by (pair, chain) + pair_chain_to_llama = defaultdict(list) + pair_to_llama = defaultdict(list) + for p in llama_pools: + symbol = p.get("symbol", "") + if not symbol or "-" not in symbol: + continue + tokens = symbol.split("-") + if len(tokens) != 2: + continue + normed = [LLAMA_NORMALIZE.get(t.upper(), t.upper()) for t in tokens] + pair_key = tuple(sorted(normed)) + chain = p.get("chain", "") + pair_to_llama[pair_key].append(p) + pair_chain_to_llama[(pair_key, chain)].append(p) + + from quantammsim.calibration.pool_data import _parse_tokens + + matches = {} + for pid, entry in our_pools.items(): + toks = _parse_tokens(entry["tokens"]) + tok_a = _normalize_symbol(toks[0]) + tok_b = _normalize_symbol(toks[1]) if len(toks) > 1 else tok_a + pair_key = tuple(sorted([tok_a, tok_b])) + + our_chain = entry.get("chain", "mainnet") + llama_chain = 
CHAIN_MAP.get(our_chain.lower(), our_chain) + + matches[pid] = { + "pair_key": pair_key, + "chain": llama_chain, + "llama_pools_same_chain": pair_chain_to_llama.get( + (pair_key, llama_chain), []), + "llama_pools_all_chains": pair_to_llama.get(pair_key, []), + "tokens": (tok_a, tok_b), + } + + return matches, pair_chain_to_llama + + +def fetch_pool_history(pool_id): + """Fetch daily TVL history for a DeFi Llama pool.""" + url = f"https://yields.llama.fi/chart/{pool_id}" + data = _fetch_json(url) + points = data.get("data", []) + if not points: + return None + + dates = [] + tvls = [] + for p in points: + ts = p.get("timestamp", "")[:10] + tvl = p.get("tvlUsd", 0) + if ts and tvl is not None: + dates.append(pd.Timestamp(ts)) + tvls.append(float(tvl)) + + return pd.Series(tvls, index=pd.DatetimeIndex(dates), name="tvl") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--cache-dir", default="results/competitor_tvl") + parser.add_argument("--max-pools-per-pair", type=int, default=30, + help="Max DeFi Llama pools to fetch per pair") + parser.add_argument("--min-tvl", type=float, default=10000, + help="Skip pools with current TVL below this") + parser.add_argument("--pools-json", default=None, + help="Local DeFi Llama pools.json (skip API fetch)") + args = parser.parse_args() + + os.makedirs(args.cache_dir, exist_ok=True) + + # Load our pools + with open(os.path.join(CACHE_DIR, "stage1.pkl"), "rb") as f: + stage1 = pickle.load(f) + matched_clean = stage1["matched_clean"] + pool_ids = sorted(matched_clean.keys()) + print(f"Our pools: {len(pool_ids)}") + + # Fetch DeFi Llama pools + llama_pools = fetch_all_pools(args.pools_json) + + # Match + matches, pair_chain_to_llama = match_pools(matched_clean, llama_pools) + + # Summary + print(f"\nPair matching (same-chain / all-chains):") + seen = set() + for pid in pool_ids: + m = matches[pid] + pair = m["pair_key"] + chain = 
m["chain"] + key = (pair, chain) + if key in seen: + continue + seen.add(key) + toks = matched_clean[pid].get("tokens", "?") + n_same = len(m["llama_pools_same_chain"]) + n_all = len(m["llama_pools_all_chains"]) + tvl_same = sum(p.get("tvlUsd", 0) for p in m["llama_pools_same_chain"]) + tvl_all = sum(p.get("tvlUsd", 0) for p in m["llama_pools_all_chains"]) + print(f" {'/'.join(pair):>20s} {chain:>10s}" + f" same={n_same:>3d} (${tvl_same/1e6:>7.1f}M)" + f" all={n_all:>3d} (${tvl_all/1e6:>7.1f}M)" + f" [{toks}]") + + # Flag zero-match pairs — likely symbol mapping issues + zero_matches = [] + for pid in pool_ids: + m = matches[pid] + if len(m["llama_pools_same_chain"]) == 0 and len(m["llama_pools_all_chains"]) == 0: + toks = matched_clean[pid].get("tokens", "?") + zero_matches.append((pid[:16], toks, m["pair_key"], m["chain"])) + if zero_matches: + print(f"\n WARNING: {len(zero_matches)} pools with zero DeFi Llama matches" + f" (check SYMBOL_MAP):") + for pid, toks, pair, chain in zero_matches: + print(f" {pid} {toks:>20s} → {'/'.join(pair)} ({chain})") + + # Fetch historical TVL for each (pair, chain) combination + print(f"\nFetching historical TVL (same-chain)...") + # Key: (pair, chain) -> pd.Series of daily total TVL + pair_chain_histories = {} + + fetched = set() + for pid in pool_ids: + m = matches[pid] + pair = m["pair_key"] + chain = m["chain"] + key = (pair, chain) + if key in fetched: + continue + fetched.add(key) + + # Same-chain pools, sorted by TVL + llama = sorted(m["llama_pools_same_chain"], + key=lambda p: p.get("tvlUsd", 0), reverse=True) + llama = [p for p in llama if p.get("tvlUsd", 0) >= args.min_tvl] + llama = llama[:args.max_pools_per_pair] + + cache_name = f"{'_'.join(pair)}_{chain}_history.pkl" + pair_cache = os.path.join(args.cache_dir, cache_name) + + if not llama: + print(f" {'/'.join(pair)} ({chain}): no qualifying pools") + continue + + if os.path.exists(pair_cache): + with open(pair_cache, "rb") as f: + pair_chain_histories[key] = 
pickle.load(f) + print(f" {'/'.join(pair)} ({chain}): loaded from cache" + f" ({len(pair_chain_histories[key])} days)") + continue + + print(f" {'/'.join(pair)} ({chain}): fetching {len(llama)} pools...", + end="", flush=True) + pool_series = [] + for lp in llama: + lid = lp["pool"] + try: + hist = fetch_pool_history(lid) + if hist is not None and len(hist) > 10: + pool_series.append(hist) + except Exception as e: + print(f"\n Skip {lid}: {e}", end="") + time.sleep(3.0) # rate limit — DeFi Llama allows ~1 req/3s + print(f" got {len(pool_series)} histories") + + if pool_series: + df = pd.concat(pool_series, axis=1).sort_index() + pair_chain_histories[key] = df.sum(axis=1) # skipna=True: pre-launch = 0 + + with open(pair_cache, "wb") as f: + pickle.dump(pair_chain_histories[key], f) + + # --- Network conductance: fetch hub-pair TVL for multi-hop K --- + HUB_TOKENS = ["WETH", "WSTETH", "USDC", "USDT", "DAI", "WBTC"] + + # Identify all (token, hub) pairs we need across all pools + hub_pairs_needed = set() + for pid in pool_ids: + m = matches[pid] + tok_a, tok_b = m["tokens"] + chain = m["chain"] + for hub in HUB_TOKENS: + if hub in (tok_a, tok_b): + continue + # Need L(tok_a, hub) and L(hub, tok_b) on same chain + pair_ah = tuple(sorted([tok_a, hub])) + pair_hb = tuple(sorted([hub, tok_b])) + hub_pairs_needed.add((pair_ah, chain)) + hub_pairs_needed.add((pair_hb, chain)) + + # Remove pairs we already have + hub_pairs_to_fetch = hub_pairs_needed - fetched + print(f"\nFetching hub-pair TVL for network conductance...") + print(f" {len(hub_pairs_needed)} hub pairs needed," + f" {len(hub_pairs_to_fetch)} to fetch") + + for pair, chain in sorted(hub_pairs_to_fetch): + key = (pair, chain) + cache_name = f"{'_'.join(pair)}_{chain}_history.pkl" + pair_cache = os.path.join(args.cache_dir, cache_name) + + if os.path.exists(pair_cache): + with open(pair_cache, "rb") as f: + pair_chain_histories[key] = pickle.load(f) + continue + + # Find matching DeFi Llama pools + llama = 
pair_chain_to_llama.get((pair, chain), []) + llama = sorted(llama, key=lambda p: p.get("tvlUsd", 0), reverse=True) + llama = [p for p in llama if p.get("tvlUsd", 0) >= args.min_tvl] + llama = llama[:args.max_pools_per_pair] + + if not llama: + continue + + print(f" {'/'.join(pair)} ({chain}): fetching {len(llama)} pools...", + end="", flush=True) + pool_series = [] + for lp in llama: + lid = lp["pool"] + try: + hist = fetch_pool_history(lid) + if hist is not None and len(hist) > 10: + pool_series.append(hist) + except Exception as e: + print(f"\n Skip {lid}: {e}", end="") + time.sleep(3.0) + print(f" got {len(pool_series)} histories") + + if pool_series: + df = pd.concat(pool_series, axis=1).sort_index() + pair_chain_histories[key] = df.sum(axis=1) # skipna=True: pre-launch = 0 + with open(pair_cache, "wb") as f: + pickle.dump(pair_chain_histories[key], f) + + # Build per-pool competitor TVL arrays aligned to our panel dates + print(f"\nBuilding competitor TVL arrays...") + + # Common date grid + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + n_pools = len(pool_ids) + + competitor_tvl = np.full((n_dates, n_pools), np.nan) + + for j, pid in enumerate(pool_ids): + m = matches[pid] + pair = m["pair_key"] + chain = m["chain"] + key = (pair, chain) + if key not in pair_chain_histories: + continue + + hist = pair_chain_histories[key] + panel = matched_clean[pid]["panel"] + panel_dates = panel["date"].values + # Own TVL for self-exclusion: K_i = total_pair_tvl - own_tvl + own_tvl = np.exp(panel["log_tvl_lag1"].values.astype(float)) + + for k, date in enumerate(panel_dates): + t = date_to_idx[date] + day = pd.Timestamp(date).normalize() + if day in hist.index: + total = hist.loc[day] + own = own_tvl[k] if k < len(own_tvl) else 0 + # Competitor TVL = total pair TVL - own TVL (floor at 0) + 
competitor_tvl[t, j] = max(total - own, 0) + + # Forward-fill then back-fill gaps + for j in range(n_pools): + col = competitor_tvl[:, j] + mask = np.isfinite(col) + if mask.any() and not mask.all(): + s = pd.Series(col, index=date_list).ffill().bfill() + competitor_tvl[:, j] = s.values + + valid = np.isfinite(competitor_tvl) + n_valid = valid.sum() + n_total = n_dates * n_pools + print(f" Coverage: {n_valid}/{n_total} ({100*n_valid/n_total:.0f}%)") + + # Warn about pools with surprisingly low coverage despite having pair data + for j, pid in enumerate(pool_ids): + m = matches[pid] + key = (m["pair_key"], m["chain"]) + if key in pair_chain_histories and len(pair_chain_histories[key]) > 100: + col = competitor_tvl[:, j] + cov = np.isfinite(col).sum() / n_dates + if cov < 0.5: + print(f" WARNING: {pid[:16]} has pair data but only" + f" {cov*100:.0f}% coverage — possible date mismatch") + + # --- Compute K_eff = K_direct + multi-hop contributions --- + print(f"\nComputing network K_eff (direct + multi-hop)...") + # Compute multi-hop contribution over ALL dates in the grid + # (not just panel dates — so forward-fill works correctly) + k_eff = np.full((n_dates, n_pools), np.nan) + + def _get_pair_tvl_on_date(pair_key, chain, day): + """Get total TVL for a pair on a given date.""" + key = (pair_key, chain) + if key not in pair_chain_histories: + return 0.0 + hist = pair_chain_histories[key] + if day in hist.index: + return float(hist.loc[day]) + return 0.0 + + for j, pid in enumerate(pool_ids): + m = matches[pid] + tok_a, tok_b = m["tokens"] + chain = m["chain"] + + for t, date in enumerate(date_list): + day = pd.Timestamp(date).normalize() + + # Direct (from already-computed competitor_tvl) + direct = competitor_tvl[t, j] if np.isfinite(competitor_tvl[t, j]) else 0.0 + + # Multi-hop through hub tokens + multihop = 0.0 + for hub in HUB_TOKENS: + if hub in (tok_a, tok_b): + continue + pair_ah = tuple(sorted([tok_a, hub])) + pair_hb = tuple(sorted([hub, tok_b])) + L_ah = 
_get_pair_tvl_on_date(pair_ah, chain, day) + L_hb = _get_pair_tvl_on_date(pair_hb, chain, day) + if L_ah > 0 and L_hb > 0: + multihop += L_ah * L_hb / (L_ah + L_hb) + + total = direct + multihop + if total > 0: + k_eff[t, j] = total + + # Forward-fill / back-fill K_eff gaps + for j in range(n_pools): + col = k_eff[:, j] + mask = np.isfinite(col) & (col > 0) + if mask.any() and not mask.all(): + s = pd.Series(col, index=date_list).ffill().bfill() + k_eff[:, j] = s.values + + # Per-pool stats + print(f"\n {'Pool':>16s} {'Tokens':>20s} {'Pair':>20s}" + f" {'K_direct med':>14s} {'K_eff med':>14s} {'Multi/Dir':>10s}") + for j, pid in enumerate(pool_ids): + toks = matched_clean[pid].get("tokens", "?") + pair = matches[pid]["pair_key"] + # Use only panel dates for display (not forward-filled grid dates) + panel_dates = matched_clean[pid]["panel"]["date"].values + panel_t = [date_to_idx[d] for d in panel_dates if d in date_to_idx] + if panel_t: + d_vals = competitor_tvl[panel_t, j] + e_vals = k_eff[panel_t, j] + valid_d = d_vals[np.isfinite(d_vals)] + valid_e = e_vals[np.isfinite(e_vals)] + else: + valid_d = valid_e = np.array([]) + med_d = np.median(valid_d) if len(valid_d) > 0 else 0 + med_e = np.median(valid_e) if len(valid_e) > 0 else 0 + ratio = med_e / med_d if med_d > 0 else float("inf") + print(f" {pid[:16]} {toks:>20s} {'/'.join(pair):>20s}" + f" ${med_d:>13,.0f} ${med_e:>13,.0f} {ratio:>9.1f}x") + + # Save + out_path = os.path.join(args.cache_dir, "competitor_tvl.npz") + np.savez(out_path, + pool_ids=pool_ids, + date_list=np.array([str(d) for d in date_list]), + competitor_tvl=competitor_tvl, + k_eff=k_eff) + print(f"\nSaved: {out_path}") + + # Also save raw pair-chain histories for inspection + pair_path = os.path.join(args.cache_dir, "pair_chain_histories.pkl") + with open(pair_path, "wb") as f: + pickle.dump(pair_chain_histories, f) + print(f"Saved: {pair_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_cross_pool_diagnostics.py 
b/experiments/run_cross_pool_diagnostics.py new file mode 100644 index 0000000..ae9e2ed --- /dev/null +++ b/experiments/run_cross_pool_diagnostics.py @@ -0,0 +1,423 @@ +"""Diagnostic experiments for cross-pool noise calibration. + +Runs cheap experiments to bound the value of learned cross-pool aggregation: +1. Lambda_token sweep — is the LOO failure due to overfitting token effects? +2. Leave-one-in — how much pool-specific data closes the gap? +3. Naive AR baseline — is the model barely beating lag-1? +4. Pool connectivity — which pools are predictable at all? + +Uses cached stage1 data from run_token_factored_calibration.py. +""" + +import os +import pickle +import sys + +import numpy as np +import pandas as pd + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) +JOINT_MAXITER = 3000 # reduced for sweep speed + + +def load_stage1(): + """Load cached stage 1 (matched_clean + option_c_clean).""" + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache. 
Run run_token_factored_calibration.py first.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + print(f"Loaded {len(data['matched_clean'])} pools from cache") + return data["matched_clean"], data["option_c_clean"] + + +# ---- Diagnostic 1: Lambda_token sweep ---- + + +def run_lambda_token_sweep(matched_clean, option_c_clean): + """LOO with varying lambda_token to test whether overfitting is the problem.""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.pool_data import K_OBS_REDUCED, build_x_obs, _parse_tokens + import jax.numpy as jnp + + pool_ids = sorted(matched_clean.keys()) + lambda_tokens = [0.1, 0.5, 1.0, 5.0, 10.0] + + print("\n" + "=" * 70) + print("Diagnostic 1: Lambda_token sweep (LOO)") + print("=" * 70) + print(f" lambda_delta=1.0 fixed, sweeping lambda_token") + print(f" maxiter={JOINT_MAXITER}") + + all_results = {} + + for lt in lambda_tokens: + print(f"\n--- lambda_token={lt} ---") + loo_r2s = [] + + for hold_out_pid in pool_ids: + train_matched = {p: matched_clean[p] for p in pool_ids if p != hold_out_pid} + train_oc = {p: option_c_clean[p] for p in pool_ids if p != hold_out_pid} + + if len(train_matched) < 3: + continue + + jdata, enc = prepare_token_factored_data(train_matched) + + gas_values = [] + for pid in jdata.pool_ids: + chain = train_matched[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + + noise_head = TokenFactoredNoiseHead( + k_obs=K_OBS_REDUCED, + lambda_delta=1.0, + lambda_token=lt, + **enc, + ) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", 
np.array(gas_values)), + noise_head, + ) + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=train_oc) + + # Predict for held-out pool + n_train = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + (_, _), (_, _), (ns, ne) = model._head_slices(n_train, k_attr) + noise_params = result["params_flat"][ns:ne] + + ho_entry = matched_clean[hold_out_pid] + toks = _parse_tokens(ho_entry["tokens"]) + ho_pred = noise_head.predict_new_pool( + noise_params, toks[0], toks[1], + ho_entry["chain"], ho_entry["fee"], + n_pools=n_train, + ) + + # Evaluate + ho_panel = ho_entry["panel"] + x_obs_ho = build_x_obs(ho_panel, reduced=True) + y_obs_ho = ho_panel["log_volume"].values.astype(float) + + oc_ho = option_c_clean[hold_out_pid] + v_arb_all = np.array(interpolate_pool_daily( + ho_entry["coeffs"], + jnp.float64(oc_ho["log_cadence"]), + jnp.float64(np.exp(oc_ho["log_gas"])), + )) + v_arb = v_arb_all[ho_entry["day_indices"]] + v_noise = np.exp(x_obs_ho @ ho_pred["noise_coeffs"][:K_OBS_REDUCED]) + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y_obs_ho) ** 2) + ss_tot = np.sum((y_obs_ho - y_obs_ho.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + loo_r2s.append(r2) + + tag = "OK" if r2 > 0 else "NEG" + print(f" {hold_out_pid[:16]} R²={r2:.3f} [{tag}]") + + median_r2 = np.median(loo_r2s) + wins = sum(1 for r2, pid in zip(loo_r2s, pool_ids) + if r2 > option_c_clean[pid].get("r2", 0)) + all_results[lt] = { + "median_r2": median_r2, + "mean_r2": np.mean(loo_r2s), + "r2s": loo_r2s, + "n_negative": sum(1 for r in loo_r2s if r < 0), + } + print(f" lambda_token={lt}: median R²={median_r2:.4f}, " + f"mean={np.mean(loo_r2s):.4f}, " + f"n_negative={sum(1 for r in loo_r2s if r < 0)}") + + # Summary table + print(f"\n{'='*60}") + print(f"{'lambda_token':>12} {'median_R²':>10} {'mean_R²':>10} {'n_neg':>6}") + print("-" * 42) + for lt in lambda_tokens: + r = all_results[lt] + print(f"{lt:>12.1f} {r['median_r2']:>10.4f} {r['mean_r2']:>10.4f} " + 
f"{r['n_negative']:>6}") + + return all_results + + +# ---- Diagnostic 2: Leave-one-in ---- + + +def run_leave_one_in(matched_clean, option_c_clean, n_days_in=30): + """LOO but give held-out pool n_days_in days of data for adaptation.""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.pool_data import K_OBS_REDUCED, build_x_obs, _parse_tokens + import jax.numpy as jnp + + pool_ids = sorted(matched_clean.keys()) + + print("\n" + "=" * 70) + print(f"Diagnostic 2: Leave-one-in ({n_days_in} days of held-out data)") + print("=" * 70) + + results = [] + + for hold_out_pid in pool_ids: + ho_entry = matched_clean[hold_out_pid] + ho_panel = ho_entry["panel"] + n_obs = len(ho_panel) + + if n_obs <= n_days_in + 10: + print(f" {hold_out_pid[:16]} — too few obs ({n_obs}), skipping") + continue + + # Split: first n_days_in for training, rest for evaluation + train_panel = ho_panel.iloc[:n_days_in].copy() + eval_panel = ho_panel.iloc[n_days_in:].copy() + train_day_indices = ho_entry["day_indices"][:n_days_in] + eval_day_indices = ho_entry["day_indices"][n_days_in:] + + # Build training matched: all other pools + truncated held-out pool + train_matched = {} + for p in pool_ids: + if p != hold_out_pid: + train_matched[p] = matched_clean[p] + + # Add truncated held-out pool + ho_train_entry = dict(ho_entry) + ho_train_entry["panel"] = train_panel.reset_index(drop=True) + ho_train_entry["day_indices"] = train_day_indices + train_matched[hold_out_pid] = ho_train_entry + + train_oc = dict(option_c_clean) # all pools including held-out + + # Fit with held-out pool included (gets its own delta from 30 days) + jdata, enc = 
prepare_token_factored_data(train_matched) + + gas_values = [] + for pid in jdata.pool_ids: + chain = train_matched[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + + noise_head = TokenFactoredNoiseHead( + k_obs=K_OBS_REDUCED, + lambda_delta=1.0, + lambda_token=0.1, + **enc, + ) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", np.array(gas_values)), + noise_head, + ) + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=train_oc) + + # Find held-out pool's index in training set and extract noise_coeffs + ho_idx = jdata.pool_ids.index(hold_out_pid) + noise_coeffs = result["noise_coeffs"][ho_idx] + + # Evaluate on held-out days + x_obs_eval = build_x_obs(eval_panel, reduced=True) + y_obs_eval = eval_panel["log_volume"].values.astype(float) + + oc_ho = option_c_clean[hold_out_pid] + v_arb_all = np.array(interpolate_pool_daily( + ho_entry["coeffs"], + jnp.float64(oc_ho["log_cadence"]), + jnp.float64(np.exp(oc_ho["log_gas"])), + )) + v_arb = v_arb_all[eval_day_indices] + v_noise = np.exp(x_obs_eval @ noise_coeffs[:K_OBS_REDUCED]) + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y_obs_eval) ** 2) + ss_tot = np.sum((y_obs_eval - y_obs_eval.mean()) ** 2) + r2_in = 1 - ss_res / max(ss_tot, 1e-10) + + # Also compute Option C R² on eval days for comparison + v_noise_c = np.exp(x_obs_eval @ oc_ho["noise_coeffs"][:K_OBS_REDUCED]) + log_pred_c = np.log(np.maximum(v_arb + v_noise_c, 1e-6)) + ss_res_c = np.sum((log_pred_c - y_obs_eval) ** 2) + r2_c_eval = 1 - ss_res_c / max(ss_tot, 1e-10) + + results.append({ + "pool_id": hold_out_pid, + "r2_leave_one_in": r2_in, + "r2_option_c_eval": r2_c_eval, + "n_train_days": n_days_in, + "n_eval_days": len(eval_panel), + "tokens": ho_entry["tokens"], + }) + + print(f" {hold_out_pid[:16]} ({ho_entry['tokens']:<14}) " + f"R²_in={r2_in:.3f} R²_C_eval={r2_c_eval:.3f} " + f"n_eval={len(eval_panel)}") + + if 
results: + r2s_in = [r["r2_leave_one_in"] for r in results] + r2s_c = [r["r2_option_c_eval"] for r in results] + print(f"\n Leave-one-in ({n_days_in}d): median R²={np.median(r2s_in):.4f}") + print(f" Option C (eval days): median R²={np.median(r2s_c):.4f}") + print(f" Recall: zero-shot LOO: median R²=0.362") + + return results + + +# ---- Diagnostic 3: Naive AR baseline ---- + + +def run_naive_ar_baseline(matched_clean): + """Compute R² of vol_tomorrow = vol_today (no model, no cross-pool).""" + print("\n" + "=" * 70) + print("Diagnostic 3: Naive autoregressive baseline (lag-1 copy)") + print("=" * 70) + + pool_r2s = [] + for pid in sorted(matched_clean.keys()): + panel = matched_clean[pid]["panel"] + y = panel["log_volume"].values.astype(float) + + if len(y) < 3: + continue + + # Predict day t from day t-1 + y_true = y[1:] + y_pred = y[:-1] + + ss_res = np.sum((y_pred - y_true) ** 2) + ss_tot = np.sum((y_true - y_true.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + pool_r2s.append(r2) + + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²_AR1={r2:.3f} n_obs={len(y)}") + + print(f"\n Naive AR1: median R²={np.median(pool_r2s):.4f}, " + f"mean={np.mean(pool_r2s):.4f}") + print(f" Recall: zero-shot LOO = 0.362, Option C in-sample = 0.589") + + return pool_r2s + + +# ---- Diagnostic 4: Pool connectivity analysis ---- + + +def run_connectivity_analysis(matched_clean, option_c_clean): + """Analyze token overlap and partition LOO R² by connectivity.""" + from quantammsim.calibration.pool_data import _parse_tokens, _canonicalize_token + + print("\n" + "=" * 70) + print("Diagnostic 4: Pool connectivity analysis") + print("=" * 70) + + pool_ids = sorted(matched_clean.keys()) + + # Build canonical token sets per pool + pool_tokens = {} + for pid in pool_ids: + toks = _parse_tokens(matched_clean[pid]["tokens"]) + canon = {_canonicalize_token(t) for t in toks[:2]} + pool_tokens[pid] = canon + + # Count: for each pool, how many other pools share at least 1 
token? + # And how many share both tokens? + print(f"\n{'Pool':<18} {'Tokens':<16} {'1+ shared':>10} {'2 shared':>10} " + f"{'R²_C':>8}") + print("-" * 66) + + connectivity = [] + for pid in pool_ids: + my_toks = pool_tokens[pid] + n_one_shared = 0 + n_both_shared = 0 + for other in pool_ids: + if other == pid: + continue + overlap = len(my_toks & pool_tokens[other]) + if overlap >= 1: + n_one_shared += 1 + if overlap >= 2: + n_both_shared += 1 + + oc = option_c_clean[pid] + r2_c = 1 - oc["loss"] / max( + np.var(matched_clean[pid]["panel"]["log_volume"].values) * + len(matched_clean[pid]["panel"]) / + max(len(matched_clean[pid]["panel"]) - 1, 1), + 1e-10, + ) + + connectivity.append({ + "pool_id": pid, + "tokens": matched_clean[pid]["tokens"], + "n_one_shared": n_one_shared, + "n_both_shared": n_both_shared, + }) + + print(f" {pid[:16]} {matched_clean[pid]['tokens']:<16} " + f"{n_one_shared:>10} {n_both_shared:>10}") + + # Partition: well-connected (1+ shared ≥ 3) vs isolated + well_connected = [c for c in connectivity if c["n_one_shared"] >= 3] + isolated = [c for c in connectivity if c["n_one_shared"] < 3] + + print(f"\n Well-connected (≥3 pools share a token): {len(well_connected)}") + print(f" Isolated (<3 pools share a token): {len(isolated)}") + + if isolated: + print(f"\n Isolated pools:") + for c in isolated: + print(f" {c['pool_id'][:16]} {c['tokens']}") + + return connectivity + + +# ---- Main ---- + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Cross-Pool Calibration Diagnostics") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + # Run all diagnostics + ar_results = run_naive_ar_baseline(matched_clean) + connectivity = run_connectivity_analysis(matched_clean, option_c_clean) + leave_one_in = run_leave_one_in(matched_clean, option_c_clean, n_days_in=30) + lambda_sweep = run_lambda_token_sweep(matched_clean, option_c_clean) + + # Final summary + print("\n" + "=" * 70) + print("SUMMARY") + 
print("=" * 70) + print(f" Naive AR1 baseline: median R² = {np.median(ar_results):.4f}") + if leave_one_in: + r2s_in = [r["r2_leave_one_in"] for r in leave_one_in] + print(f" Leave-one-in (30 days): median R² = {np.median(r2s_in):.4f}") + print(f" Zero-shot LOO (current): median R² = 0.362") + print(f" Option C in-sample: median R² = 0.589") + print(f"\n Lambda_token sweep:") + for lt, r in sorted(lambda_sweep.items()): + print(f" lambda_token={lt:>5.1f}: median R² = {r['median_r2']:.4f} " + f"(n_neg={r['n_negative']})") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_cross_pool_linear.py b/experiments/run_cross_pool_linear.py new file mode 100644 index 0000000..acd5c48 --- /dev/null +++ b/experiments/run_cross_pool_linear.py @@ -0,0 +1,457 @@ +"""Cross-pool linear volume prediction baselines. + +1. Ridge cross-pool regression: log_vol_i_t = W_i @ log_vol_{-i, t-1} + - In-sample: fit full 36x36 W, evaluate on training data + - LOO: hold out pool i, fit W on 35 pools, predict pool i using + token-overlap-weighted average of learned rows (transfer via similarity) + - LOO with burn-in: use 30 days of pool i to learn its row of W directly + +2. Zero-parameter peer-mean: predicted_vol_i_t = mean(log_vol_{j,t-1}) + for peers sharing a canonical token with pool i +""" + +import os +import pickle +import sys + +import numpy as np +from sklearn.linear_model import RidgeCV + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache. 
Run run_token_factored_calibration.py first.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + print(f"Loaded {len(data['matched_clean'])} pools from cache") + return data["matched_clean"], data["option_c_clean"] + + +def build_volume_matrix(matched_clean): + """Build (n_dates, n_pools) aligned volume matrix. + + Returns vol_matrix, date_list, pool_ids. + Dates are the union of all pools' observed dates. + Pool/date cells with no observation are filled with NaN. + """ + pool_ids = sorted(matched_clean.keys()) + + # Collect all (pool, date) -> log_volume + pool_date_vol = {} + all_dates = set() + for pid in pool_ids: + panel = matched_clean[pid]["panel"] + dates = panel["date"].values + vols = panel["log_volume"].values.astype(float) + pool_date_vol[pid] = dict(zip(dates, vols)) + all_dates.update(dates) + + date_list = sorted(all_dates) + n_dates = len(date_list) + n_pools = len(pool_ids) + + vol_matrix = np.full((n_dates, n_pools), np.nan) + for j, pid in enumerate(pool_ids): + dv = pool_date_vol[pid] + for t, date in enumerate(date_list): + if date in dv: + vol_matrix[t, j] = dv[date] + + return vol_matrix, date_list, pool_ids + + +def build_token_overlap(matched_clean, pool_ids): + """Build (n_pools, n_pools) token overlap matrix (0, 1, or 2).""" + from quantammsim.calibration.pool_data import _parse_tokens, _canonicalize_token + + n = len(pool_ids) + overlap = np.zeros((n, n), dtype=np.int32) + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + + for i in range(n): + for j in range(n): + overlap[i, j] = len(pool_tokens[i] & pool_tokens[j]) + + return overlap + + +def r2_score(y_true, y_pred): + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - y_true.mean()) ** 2) + return 1 - ss_res / max(ss_tot, 1e-10) + + +# ---- 1. 
Ridge cross-pool regression ---- + + +def run_ridge_cross_pool(matched_clean): + """Full cross-pool ridge: predict each pool from all others' lag-1.""" + vol_matrix, date_list, pool_ids = build_volume_matrix(matched_clean) + n_dates, n_pools = vol_matrix.shape + + # Check if fully-observed rows exist + X_lag = vol_matrix[:-1, :] + Y_cur = vol_matrix[1:, :] + valid = ~np.any(np.isnan(X_lag), axis=1) & ~np.any(np.isnan(Y_cur), axis=1) + n_valid = int(valid.sum()) + + # Peers only + print("\n" + "=" * 70) + print("1a. Ridge cross-pool regression (in-sample, peers only)") + print("=" * 70) + print(f" {n_pools} pools, {n_valid} fully-observed day pairs " + f"(of {n_dates-1} total)") + print(" Using per-pool valid rows with NaN imputation.") + r2_peers, _, _, _ = _run_ridge_per_pool_valid( + vol_matrix, pool_ids, matched_clean, include_own_lag=False) + + # Peers + own lag + print("\n" + "=" * 70) + print("1a+. Ridge cross-pool regression (in-sample, peers + own lag)") + print("=" * 70) + r2_both, _, _, _ = _run_ridge_per_pool_valid( + vol_matrix, pool_ids, matched_clean, include_own_lag=True) + + return r2_peers, r2_both, vol_matrix, date_list, pool_ids + + +def _run_ridge_per_pool_valid(vol_matrix, pool_ids, matched_clean, + include_own_lag=False): + """Fallback: per-pool ridge using only rows where pool i AND predictors have data.""" + n_dates, n_pools = vol_matrix.shape + tag = " + own_lag" if include_own_lag else "" + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + X_lag = vol_matrix[:-1, :] + y_cur = vol_matrix[1:, i] + own_lag = X_lag[:, i] # pool i's own lag + + # Valid: pool i has data today AND own lag exists (if used) + valid_y = ~np.isnan(y_cur) + if include_own_lag: + valid_y = valid_y & ~np.isnan(own_lag) + + X_others = np.delete(X_lag, i, axis=1) + + # For each predictor, fill NaN with that predictor's mean (simple imputation) + X_filled = X_others.copy() + for j in range(X_filled.shape[1]): + col = X_filled[:, j] + col_mean = np.nanmean(col) + 
col[np.isnan(col)] = col_mean + X_filled[:, j] = col + + if include_own_lag: + X_full = np.column_stack([X_filled, own_lag[:, None]]) + else: + X_full = X_filled + + X_i = X_full[valid_y] + y_i = y_cur[valid_y] + + if len(y_i) < 10: + pool_r2s.append(np.nan) + continue + + model = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model.fit(X_i, y_i) + y_pred = model.predict(X_i) + r2 = r2_score(y_i, y_pred) + pool_r2s.append(r2) + + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n_obs={len(y_i)} alpha={model.alpha_:.1f}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n In-sample ridge{tag}: median R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + + return pool_r2s, vol_matrix, None, pool_ids + + +def run_ridge_loo(matched_clean): + """LOO cross-pool ridge with token-overlap transfer.""" + print("\n" + "=" * 70) + print("1b. Ridge cross-pool LOO (transfer via token overlap)") + print("=" * 70) + + vol_matrix, date_list, pool_ids = build_volume_matrix(matched_clean) + n_dates, n_pools = vol_matrix.shape + overlap = build_token_overlap(matched_clean, pool_ids) + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + # Training pools: all except i + train_idx = [j for j in range(n_pools) if j != i] + n_train = len(train_idx) + + # Build training data: for each training pool k, predict from others' lag + # Use per-pool valid rows with NaN imputation + X_lag_all = vol_matrix[:-1, :] + Y_cur_all = vol_matrix[1:, :] + + # Fit a ridge model for each training pool + train_models = {} + train_weights = {} # weight vectors (excluding self) + for k_pos, k in enumerate(train_idx): + # Predictors: all pools except k (including pool i's historical data!) 
+ pred_idx = [j for j in range(n_pools) if j != k] + X_k = X_lag_all[:, pred_idx].copy() + y_k = Y_cur_all[:, k] + + valid = ~np.isnan(y_k) + for c in range(X_k.shape[1]): + col = X_k[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_k[:, c] = col + + X_k = X_k[valid] + y_k = y_k[valid] + + if len(y_k) < 10: + continue + + model = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model.fit(X_k, y_k) + train_models[k] = model + + # Store full weight vector (n_pools-1,) with mapping to pool indices + w = np.zeros(n_pools) + for widx, pidx in enumerate(pred_idx): + w[pidx] = model.coef_[widx] + w_intercept = model.intercept_ + train_weights[k] = (w, w_intercept) + + if not train_weights: + pool_r2s.append(np.nan) + continue + + # Transfer to held-out pool i: weighted average of training pools' weight vectors + # Weight by token overlap with pool i + w_transfer = np.zeros(n_pools) + intercept_transfer = 0.0 + total_sim = 0.0 + for k in train_weights: + sim = overlap[i, k] + if sim == 0: + sim = 0.1 # small weight for unrelated pools + w_k, b_k = train_weights[k] + w_transfer += sim * w_k + intercept_transfer += sim * b_k + total_sim += sim + + w_transfer /= total_sim + intercept_transfer /= total_sim + + # Zero out pool i's own weight (shouldn't predict from self) + w_transfer[i] = 0.0 + + # Predict pool i + X_lag_i = vol_matrix[:-1, :].copy() + y_true_i = vol_matrix[1:, i] + valid = ~np.isnan(y_true_i) + + # Impute NaN predictors + for c in range(X_lag_i.shape[1]): + col = X_lag_i[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_lag_i[:, c] = col + + y_pred_i = X_lag_i[valid] @ w_transfer + intercept_transfer + y_true_i = y_true_i[valid] + + r2 = r2_score(y_true_i, y_pred_i) + pool_r2s.append(r2) + + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n_eval={len(y_true_i)}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n LOO ridge (overlap transfer): median 
R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + print(f" Recall: AR1={0.397:.3f}, zero-shot token-factored={0.362:.3f}") + + return pool_r2s + + +def run_ridge_loo_burnin(matched_clean, n_burnin=30): + """LOO with burn-in: learn pool i's weight row from n_burnin days.""" + print("\n" + "=" * 70) + print(f"1c. Ridge cross-pool LOO with {n_burnin}-day burn-in") + print("=" * 70) + + vol_matrix, date_list, pool_ids = build_volume_matrix(matched_clean) + n_dates, n_pools = vol_matrix.shape + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + # Pool i's data + y_all = vol_matrix[:, i] + valid_days = ~np.isnan(y_all) + valid_indices = np.where(valid_days)[0] + + if len(valid_indices) < n_burnin + 10: + print(f" {pid[:16]} — too few obs, skipping") + pool_r2s.append(np.nan) + continue + + # Split: first n_burnin valid days for training, rest for eval + burn_indices = valid_indices[:n_burnin] + eval_indices = valid_indices[n_burnin:] + + # Training: predict pool i from all others' lag using burn-in days + # Need (day, day-1) pairs where day is in burn_indices and day >= 1 + burn_pairs = burn_indices[burn_indices >= 1] + + pred_idx = [j for j in range(n_pools) if j != i] + X_burn = vol_matrix[burn_pairs - 1][:, pred_idx].copy() + y_burn = vol_matrix[burn_pairs, i] + + # Impute NaN + for c in range(X_burn.shape[1]): + col = X_burn[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_burn[:, c] = col + + model = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model.fit(X_burn, y_burn) + + # Evaluate on remaining days + eval_pairs = eval_indices[eval_indices >= 1] + X_eval = vol_matrix[eval_pairs - 1][:, pred_idx].copy() + y_eval = vol_matrix[eval_pairs, i] + + for c in range(X_eval.shape[1]): + col = X_eval[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_eval[:, c] = col + + y_pred = model.predict(X_eval) + r2 = r2_score(y_eval, y_pred) + pool_r2s.append(r2) + + print(f" {pid[:16]} 
({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n_burn={len(burn_pairs)} n_eval={len(eval_pairs)} " + f"alpha={model.alpha_:.1f}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n LOO ridge ({n_burnin}d burn-in): " + f"median R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + + return pool_r2s + + +# ---- 2. Zero-parameter peer-mean ---- + + +def run_peer_mean_baseline(matched_clean): + """Predict pool i's volume as mean of token-peer lagged volumes.""" + from quantammsim.calibration.pool_data import _parse_tokens, _canonicalize_token + + print("\n" + "=" * 70) + print("2. Zero-parameter peer-mean baseline") + print("=" * 70) + + vol_matrix, date_list, pool_ids = build_volume_matrix(matched_clean) + n_dates, n_pools = vol_matrix.shape + overlap = build_token_overlap(matched_clean, pool_ids) + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + # Peers: pools sharing at least 1 token + peers = [j for j in range(n_pools) if j != i and overlap[i, j] >= 1] + + if not peers: + pool_r2s.append(np.nan) + continue + + y_true = vol_matrix[1:, i] + valid = ~np.isnan(y_true) + + # Peer mean at t-1 + peer_lag = vol_matrix[:-1, :][:, peers] + peer_mean = np.nanmean(peer_lag, axis=1) + + y_pred = peer_mean[valid] + y_true = y_true[valid] + + # Remove any remaining NaN + both_valid = ~np.isnan(y_pred) + y_pred = y_pred[both_valid] + y_true = y_true[both_valid] + + r2 = r2_score(y_true, y_pred) + pool_r2s.append(r2) + + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n_peers={len(peers)} n_obs={len(y_true)}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n Peer-mean baseline: median R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + print(f" Recall: AR1={0.397:.3f}, zero-shot token-factored={0.362:.3f}") + + return pool_r2s + + +# ---- Main ---- + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Cross-Pool Linear Volume 
Prediction Baselines") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + # 1a. In-sample ridge (peers only + peers + own lag) + r2_peers, r2_both, vol_matrix, date_list, pool_ids = run_ridge_cross_pool(matched_clean) + + # 1b. LOO ridge with token-overlap transfer + loo_r2s = run_ridge_loo(matched_clean) + + # 1c. LOO ridge with 30-day burn-in + burnin_r2s = run_ridge_loo_burnin(matched_clean, n_burnin=30) + + # 2. Zero-parameter peer mean + peer_r2s = run_peer_mean_baseline(matched_clean) + + # Summary + def safe_median(xs): + v = [x for x in xs if not np.isnan(x)] + return np.median(v) if v else float("nan") + + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f" Ridge in-sample (peers): median R² = {safe_median(r2_peers):.4f}") + print(f" Ridge in-sample (+own): median R² = {safe_median(r2_both):.4f}") + print(f" Ridge LOO (overlap xfer): median R² = {safe_median(loo_r2s):.4f}") + print(f" Ridge LOO (30d burn-in): median R² = {safe_median(burnin_r2s):.4f}") + print(f" Peer-mean (0 params): median R² = {safe_median(peer_r2s):.4f}") + print(f" ---") + print(f" Naive AR1: median R² = 0.397") + print(f" Token-factored LOO: median R² = 0.362") + print(f" Option C in-sample: median R² = 0.589") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_cross_pool_noise_linear.py b/experiments/run_cross_pool_noise_linear.py new file mode 100644 index 0000000..08ead10 --- /dev/null +++ b/experiments/run_cross_pool_noise_linear.py @@ -0,0 +1,392 @@ +"""Cross-pool linear prediction of NOISE residuals. + +Decomposes total volume into V_arb (from grid + Option C cadence/gas) +and noise residual, then tests whether peer pools' lagged noise +residuals predict this pool's noise residual. + +1. Ridge in-sample: noise_resid_i_t = W_i @ noise_resid_{-i, t-1} [+ own_lag] +2. Ridge LOO with overlap transfer +3. 
Ridge LOO with 30-day burn-in +""" + +import os +import pickle +import sys + +import numpy as np +from sklearn.linear_model import RidgeCV + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + print(f"Loaded {len(data['matched_clean'])} pools from cache") + return data["matched_clean"], data["option_c_clean"] + + +def build_noise_residual_matrix(matched_clean, option_c_clean): + """Build (n_dates, n_pools) noise residual matrix. + + noise_resid_i_t = log_volume_i_t - log(V_arb_i_t) + + V_arb computed from grid interpolation at Option C cadence/gas. + Returns residual matrix (NaN where missing), date list, pool ids. + """ + import jax.numpy as jnp + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + # Collect all dates + all_dates = set() + for pid in pool_ids: + panel = matched_clean[pid]["panel"] + all_dates.update(panel["date"].values) + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + # Build matrices + vol_matrix = np.full((n_dates, n_pools), np.nan) + resid_matrix = np.full((n_dates, n_pools), np.nan) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + oc = option_c_clean[pid] + panel = entry["panel"] + coeffs = entry["coeffs"] + day_indices = entry["day_indices"] + + # Compute V_arb from grid + v_arb_all = np.array(interpolate_pool_daily( + coeffs, + jnp.float64(oc["log_cadence"]), + jnp.float64(np.exp(oc["log_gas"])), + )) + v_arb = v_arb_all[day_indices] + log_v_arb = np.log(np.maximum(v_arb, 1e-6)) + + # Fill matrices + dates = panel["date"].values + log_vols = 
panel["log_volume"].values.astype(float) + + for k, date in enumerate(dates): + t = date_to_idx[date] + vol_matrix[t, j] = log_vols[k] + resid_matrix[t, j] = log_vols[k] - log_v_arb[k] + + print(f" Built noise residual matrix: {n_dates} dates x {n_pools} pools") + print(f" Residual stats: mean={np.nanmean(resid_matrix):.3f}, " + f"std={np.nanstd(resid_matrix):.3f}") + + return vol_matrix, resid_matrix, date_list, pool_ids + + +def build_token_overlap(matched_clean, pool_ids): + from quantammsim.calibration.pool_data import _parse_tokens, _canonicalize_token + n = len(pool_ids) + overlap = np.zeros((n, n), dtype=np.int32) + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + for i in range(n): + for j in range(n): + overlap[i, j] = len(pool_tokens[i] & pool_tokens[j]) + return overlap + + +def r2_score(y_true, y_pred): + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - y_true.mean()) ** 2) + return 1 - ss_res / max(ss_tot, 1e-10) + + +# ---- In-sample ridge on noise residuals ---- + + +def run_ridge_insample(resid_matrix, pool_ids, matched_clean): + """Per-pool ridge: predict noise_resid from peers' lagged residuals.""" + n_dates, n_pools = resid_matrix.shape + + for include_own in [False, True]: + tag = "peers + own_lag" if include_own else "peers only" + print(f"\n{'='*70}") + print(f"Ridge in-sample on noise residuals ({tag})") + print(f"{'='*70}") + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + X_lag = resid_matrix[:-1, :] + y_cur = resid_matrix[1:, i] + own_lag = X_lag[:, i] + + valid = ~np.isnan(y_cur) + if include_own: + valid = valid & ~np.isnan(own_lag) + + X_others = np.delete(X_lag, i, axis=1) + X_filled = X_others.copy() + for c in range(X_filled.shape[1]): + col = X_filled[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_filled[:, c] = col + + if include_own: + X_full = 
np.column_stack([X_filled, own_lag[:, None]]) + else: + X_full = X_filled + + X_i = X_full[valid] + y_i = y_cur[valid] + + if len(y_i) < 10: + pool_r2s.append(np.nan) + continue + + model = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model.fit(X_i, y_i) + y_pred = model.predict(X_i) + r2 = r2_score(y_i, y_pred) + pool_r2s.append(r2) + + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n={len(y_i)} alpha={model.alpha_:.1f}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n In-sample ridge ({tag}): median R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + + return pool_r2s + + +def run_ar1_noise_baseline(resid_matrix, pool_ids, matched_clean): + """Naive AR1 on noise residuals: resid_tomorrow = resid_today.""" + print(f"\n{'='*70}") + print("AR1 baseline on noise residuals") + print(f"{'='*70}") + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + y = resid_matrix[:, i] + valid = ~np.isnan(y[:-1]) & ~np.isnan(y[1:]) + y_true = y[1:][valid] + y_pred = y[:-1][valid] + + if len(y_true) < 3: + pool_r2s.append(np.nan) + continue + + r2 = r2_score(y_true, y_pred) + pool_r2s.append(r2) + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n={len(y_true)}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n AR1 noise residual: median R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + return pool_r2s + + +# ---- LOO with overlap transfer ---- + + +def run_ridge_loo(resid_matrix, pool_ids, matched_clean): + """LOO on noise residuals with token-overlap weight transfer.""" + print(f"\n{'='*70}") + print("Ridge LOO on noise residuals (overlap transfer, peers + own_lag)") + print(f"{'='*70}") + + n_dates, n_pools = resid_matrix.shape + overlap = build_token_overlap(matched_clean, pool_ids) + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + train_idx = [j for j in range(n_pools) if j != i] + + # Fit ridge for each training pool + train_weights = {} 
+ for k in train_idx: + pred_idx = [j for j in range(n_pools) if j != k] + X_lag = resid_matrix[:-1, pred_idx].copy() + own_lag_k = resid_matrix[:-1, k] + y_k = resid_matrix[1:, k] + + valid = ~np.isnan(y_k) & ~np.isnan(own_lag_k) + for c in range(X_lag.shape[1]): + col = X_lag[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_lag[:, c] = col + + X_k = np.column_stack([X_lag[valid], own_lag_k[valid, None]]) + y_k = y_k[valid] + + if len(y_k) < 10: + continue + + model = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model.fit(X_k, y_k) + + # Store weights mapped to pool indices + w = np.zeros(n_pools + 1) # +1 for own_lag + for widx, pidx in enumerate(pred_idx): + w[pidx] = model.coef_[widx] + w[-1] = model.coef_[-1] # own_lag weight + train_weights[k] = (w, model.intercept_) + + if not train_weights: + pool_r2s.append(np.nan) + continue + + # Transfer: overlap-weighted average of training pool weight vectors + w_transfer = np.zeros(n_pools + 1) + b_transfer = 0.0 + total_sim = 0.0 + for k in train_weights: + sim = max(overlap[i, k], 0.1) + w_k, b_k = train_weights[k] + w_transfer += sim * w_k + b_transfer += sim * b_k + total_sim += sim + w_transfer /= total_sim + b_transfer /= total_sim + w_transfer[i] = 0.0 # no self-prediction from peers + + # Predict held-out pool + X_lag_all = resid_matrix[:-1, :].copy() + own_lag_i = resid_matrix[:-1, i] + y_true = resid_matrix[1:, i] + valid = ~np.isnan(y_true) & ~np.isnan(own_lag_i) + + for c in range(X_lag_all.shape[1]): + col = X_lag_all[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_lag_all[:, c] = col + + X_i = np.column_stack([X_lag_all[valid], own_lag_i[valid, None]]) + y_pred = X_i @ w_transfer + b_transfer + y_true = y_true[valid] + + r2 = r2_score(y_true, y_pred) + pool_r2s.append(r2) + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n={len(y_true)}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n LOO 
ridge (overlap transfer): median R²={np.median(valid_r2s):.4f}, " + f"mean={np.mean(valid_r2s):.4f}") + return pool_r2s + + +# ---- LOO with burn-in ---- + + +def run_ridge_burnin(resid_matrix, pool_ids, matched_clean, n_burnin=30): + """LOO with burn-in: learn pool i's weights from first n_burnin days.""" + print(f"\n{'='*70}") + print(f"Ridge LOO on noise residuals ({n_burnin}d burn-in, peers + own_lag)") + print(f"{'='*70}") + + n_dates, n_pools = resid_matrix.shape + + pool_r2s = [] + for i, pid in enumerate(pool_ids): + y_all = resid_matrix[:, i] + own_lag_all = np.full(n_dates, np.nan) + own_lag_all[1:] = y_all[:-1] + + valid_days = ~np.isnan(y_all) & ~np.isnan(own_lag_all) + valid_indices = np.where(valid_days)[0] + + if len(valid_indices) < n_burnin + 10: + pool_r2s.append(np.nan) + continue + + burn_idx = valid_indices[:n_burnin] + eval_idx = valid_indices[n_burnin:] + + pred_idx = [j for j in range(n_pools) if j != i] + + def build_X(indices): + X_peers = resid_matrix[indices - 1][:, pred_idx].copy() + for c in range(X_peers.shape[1]): + col = X_peers[:, c] + m = np.nanmean(col) + col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_peers[:, c] = col + own = y_all[indices - 1] + return np.column_stack([X_peers, own[:, None]]) + + X_burn = build_X(burn_idx) + y_burn = y_all[burn_idx] + + model = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model.fit(X_burn, y_burn) + + X_eval = build_X(eval_idx) + y_eval = y_all[eval_idx] + y_pred = model.predict(X_eval) + + r2 = r2_score(y_eval, y_pred) + pool_r2s.append(r2) + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"R²={r2:.3f} n_eval={len(eval_idx)} alpha={model.alpha_:.1f}") + + valid_r2s = [r for r in pool_r2s if not np.isnan(r)] + print(f"\n LOO ridge ({n_burnin}d burn-in): " + f"median R²={np.median(valid_r2s):.4f}, mean={np.mean(valid_r2s):.4f}") + return pool_r2s + + +# ---- Main ---- + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Cross-Pool Linear 
Prediction of Noise Residuals") + print(" (total volume - grid arb volume)") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + vol_matrix, resid_matrix, date_list, pool_ids = build_noise_residual_matrix( + matched_clean, option_c_clean) + + ar1_r2s = run_ar1_noise_baseline(resid_matrix, pool_ids, matched_clean) + insample_r2s = run_ridge_insample(resid_matrix, pool_ids, matched_clean) + loo_r2s = run_ridge_loo(resid_matrix, pool_ids, matched_clean) + burnin_r2s = run_ridge_burnin(resid_matrix, pool_ids, matched_clean, n_burnin=30) + + def safe_median(xs): + v = [x for x in xs if x is not None and not np.isnan(x)] + return np.median(v) if v else float("nan") + + print("\n" + "=" * 70) + print("SUMMARY (noise residuals)") + print("=" * 70) + print(f" AR1 noise residual: median R² = {safe_median(ar1_r2s):.4f}") + print(f" Ridge in-sample (+own): median R² = {safe_median(insample_r2s):.4f}") + print(f" Ridge LOO (overlap xfer): median R² = {safe_median(loo_r2s):.4f}") + print(f" Ridge LOO (30d burn-in): median R² = {safe_median(burnin_r2s):.4f}") + print(f" ---") + print(f" (total vol) AR1: median R² = 0.397") + print(f" (total vol) Ridge +own: median R² = 0.599") + print(f" Option C in-sample: median R² = 0.589") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_deconfounder_noise.py b/experiments/run_deconfounder_noise.py new file mode 100644 index 0000000..cd761b9 --- /dev/null +++ b/experiments/run_deconfounder_noise.py @@ -0,0 +1,555 @@ +"""Causal noise volume estimation: TVL decomposition + deconfounder sensitivity. 
+ +Two identification strategies for the causal effect of TVL on noise volume: + +**Primary: TVL decomposition (IV-style)** + Decomposes Δlog(TVL) into: + - Price-driven: Δlog(TVL) - Δlog(shares) — market price moves, more + exogenous to pool-specific trading activity (conditional on BTC/token + market features already in the model) + - Flow-driven: Δlog(shares) — LP deposits/withdrawals, endogenous + (LPs deposit when they expect fees → correlated with noise) + If b_tvl estimated from price-driven variation ≈ observational b_tvl, + the coefficient is likely causal. + +**Secondary: Deconfounder sensitivity analysis (Wang & Blei 2019)** + Fit a factor model (PPCA) on covariates only (no outcome), extract + latent factors Z_hat, include them in the outcome model. + This is a sensitivity analysis: if b_tvl shifts substantially when + conditioning on Z_hat, there's evidence of unobserved confounding. + If it's stable, confounding through the covariate structure is small. + + NB: The deconfounder has known theoretical limitations (D'Amour 2019). + Wang & Blei (2020, arXiv:2003.04948) respond that D'Amour's + counterexamples violate the required assumptions (pinpointability). + The theory holds under its assumptions, but the key assumption (no + unobserved single-cause confounders) is domain-specific and + uncheckable. Results should be interpreted as sensitivity bounds. + +**Diagnostics:** + - Variance decomposition of log_tvl: between-pool vs within-pool + - Within-pool simple regression: Δlog(V_obs) on Δlog(TVL) per pool + - These test whether the observational b_tvl reflects cross-sectional + or temporal variation + +Usage: + python experiments/run_deconfounder_noise.py + python experiments/run_deconfounder_noise.py --n-factors 1 2 3 5 +""" + +import argparse +import os +import time + +import jax.numpy as jnp +import numpy as np + + +# ---- Factor model ---- + + +def fit_ppca(X, n_components): + """Probabilistic PCA. 
Returns Z_hat and the model.""" + from sklearn.decomposition import PCA + pca = PCA(n_components=n_components) + Z_hat = pca.fit_transform(X) + print(f" PPCA({n_components}): explained var = " + f"{pca.explained_variance_ratio_.sum():.3f} " + f"per-component: {np.round(pca.explained_variance_ratio_, 3)}") + return Z_hat, pca + + +def build_augmented_data(data, Z_hat): + """Augment covariate matrix with standardized substitute confounders.""" + x_orig = data["x"] + n_z = Z_hat.shape[1] + + z_mean = Z_hat.mean(axis=0) + z_std = Z_hat.std(axis=0) + z_std[z_std < 1e-6] = 1.0 + Z_std = ((Z_hat - z_mean) / z_std).astype(np.float32) + + x_aug = np.concatenate([x_orig, Z_std], axis=1) + data_aug = dict(data) + data_aug["x"] = x_aug + data_aug["n_feat"] = x_aug.shape[1] + data_aug["feat_names"] = data["feat_names"] + [f"Z_{k}" for k in range(n_z)] + data_aug["x_mean"] = np.concatenate([ + data["x_mean"], z_mean.astype(np.float32)]) + data_aug["x_std"] = np.concatenate([ + data["x_std"], z_std.astype(np.float32)]) + return data_aug + + +def _tvl_col_index(feat_names): + """Find TVL column index from feature names (robust to reordering).""" + return feat_names.index("xobs_1") + + +def _intercept_col_index(feat_names): + """Find intercept column index.""" + return feat_names.index("xobs_0") + + +# ---- TVL decomposition ---- + + +def decompose_tvl(matched_clean, pool_ids, sample_pools, sample_days, + date_to_idx, n_dates, n_pools): + """Decompose Δlog(TVL) into price-driven and flow-driven components. + + Uses log_tvl (not log_tvl_lag1) for the decomposition to avoid + mixing lags. total_shares is assumed to be contemporaneous with TVL. + + flow = Δlog(shares) — LP deposits/withdrawals + price = Δlog(tvl) - Δlog(shares) — price changes + + Returns per-sample arrays and a validity mask. 
+ """ + log_shares = np.full((n_dates, n_pools), np.nan) + log_tvl = np.full((n_dates, n_pools), np.nan) + + for j, pid in enumerate(pool_ids): + panel = matched_clean[pid]["panel"] + dates = panel["date"].values + + has_shares = ("total_shares" in panel.columns and + panel["total_shares"].notna().any()) + if not has_shares: + continue + + shares = panel["total_shares"].values.astype(float) + shares = np.maximum(shares, 1e-10) + + # Use log_tvl (not lag) for contemporaneous decomposition + if "log_tvl" in panel.columns: + tvl_vals = panel["log_tvl"].values.astype(float) + else: + tvl_vals = panel["log_tvl_lag1"].values.astype(float) + + for k, date in enumerate(dates): + t = date_to_idx.get(date) + if t is not None: + log_shares[t, j] = np.log(shares[k]) + log_tvl[t, j] = tvl_vals[k] + + n_samples = len(sample_pools) + tvl_flow = np.full(n_samples, np.nan, dtype=np.float32) + tvl_price = np.full(n_samples, np.nan, dtype=np.float32) + + for s in range(n_samples): + i = sample_pools[s] + t = sample_days[s] + if t >= 1: + d_log_shares = log_shares[t, i] - log_shares[t - 1, i] + d_log_tvl = log_tvl[t, i] - log_tvl[t - 1, i] + if np.isfinite(d_log_shares) and np.isfinite(d_log_tvl): + tvl_flow[s] = d_log_shares + tvl_price[s] = d_log_tvl - d_log_shares + + valid = np.isfinite(tvl_flow) & np.isfinite(tvl_price) + return tvl_flow, tvl_price, valid + + +def run_tvl_decomposition_analysis(data, matched_clean, tvl_flow, tvl_price, valid): + """Primary identification: compare b_tvl from price-driven vs all TVL.""" + from sklearn.linear_model import RidgeCV + + y = data["y_total"] + x = data["x"] + tvl_idx = _tvl_col_index(data["feat_names"]) + + print(f"\n Valid samples (have LP shares data): {valid.sum()}/{len(valid)}") + if valid.sum() < 100: + print(" Insufficient LP shares data for TVL decomposition.") + return None + + x_valid = x[valid] + y_valid = y[valid] + tvl_price_valid = tvl_price[valid] + tvl_flow_valid = tvl_flow[valid] + + print(f" Price component: 
mean={tvl_price_valid.mean():.4f}," + f" std={tvl_price_valid.std():.4f}") + print(f" Flow component: mean={tvl_flow_valid.mean():.4f}," + f" std={tvl_flow_valid.std():.4f}") + + # Observational b_tvl (Ridge on all features) + ridge_obs = RidgeCV(alphas=np.logspace(-2, 4, 50)) + ridge_obs.fit(x_valid, y_valid) + b_tvl_obs = ridge_obs.coef_[tvl_idx] + + # Replace TVL column with price-driven component only + x_price = x_valid.copy() + ps = tvl_price_valid.std() + x_price[:, tvl_idx] = (tvl_price_valid - tvl_price_valid.mean()) / max(ps, 1e-6) + ridge_price = RidgeCV(alphas=np.logspace(-2, 4, 50)) + ridge_price.fit(x_price, y_valid) + b_tvl_price = ridge_price.coef_[tvl_idx] + + # Replace TVL column with flow-driven component only + x_flow = x_valid.copy() + fs = tvl_flow_valid.std() + x_flow[:, tvl_idx] = (tvl_flow_valid - tvl_flow_valid.mean()) / max(fs, 1e-6) + ridge_flow = RidgeCV(alphas=np.logspace(-2, 4, 50)) + ridge_flow.fit(x_flow, y_valid) + b_tvl_flow = ridge_flow.coef_[tvl_idx] + + print(f"\n b_tvl estimates (Ridge, all 22 features):") + print(f" All TVL variation: {b_tvl_obs:+.4f}") + print(f" Price-driven only: {b_tvl_price:+.4f}" + f" (more exogenous)") + print(f" Flow-driven only: {b_tvl_flow:+.4f}" + f" (endogenous)") + + if abs(b_tvl_price - b_tvl_obs) < 0.3 * abs(b_tvl_obs): + print(f"\n → Price-driven ≈ observational: confounding small.") + else: + print(f"\n → Price-driven ≠ observational: potential confounding.") + + return {"obs": b_tvl_obs, "price": b_tvl_price, "flow": b_tvl_flow} + + +# ---- Variance decomposition ---- + + +def run_variance_decomposition(matched_clean, pool_ids): + """Decompose log_tvl variance into between-pool and within-pool.""" + all_tvls = [] + pool_labels = [] + + for j, pid in enumerate(pool_ids): + panel = matched_clean[pid]["panel"] + tvls = panel["log_tvl_lag1"].values.astype(float) + valid = np.isfinite(tvls) + all_tvls.extend(tvls[valid]) + pool_labels.extend([j] * valid.sum()) + + all_tvls = np.array(all_tvls) + 
pool_labels = np.array(pool_labels) + + total_var = np.var(all_tvls) + grand_mean = all_tvls.mean() + + # Sample-weighted decomposition (law of total variance), so that + # between + within == total even when pools have unequal sample counts. + groups = [all_tvls[pool_labels == j] for j in range(len(pool_ids))] + groups = [g for g in groups if g.size > 0] + weights = np.array([g.size for g in groups]) / all_tvls.size + pool_means = np.array([g.mean() for g in groups]) + between_var = np.sum(weights * (pool_means - grand_mean) ** 2) + within_var = np.sum(weights * np.array([g.var() for g in groups])) + + print(f" Total variance: {total_var:.4f}") + print(f" Between-pool: {between_var:.4f} ({between_var/total_var*100:.1f}%)") + print(f" Within-pool (wtd): {within_var:.4f} ({within_var/total_var*100:.1f}%)") + print(f" Pool mean range: {pool_means.min():.1f} to {pool_means.max():.1f}") + + return {"total": total_var, "between": between_var, "within": within_var} + + +# ---- Within-pool simple regression ---- + + +def run_within_pool_regressions(matched_clean, pool_ids): + """Per-pool: Δlog(V_obs) on Δlog(TVL), no other covariates.""" + print(f"\n {'Pool':16s} {'Tokens':16s} {'b_tvl':>8s} {'R²':>6s}" + f" {'n':>5s} {'ΔTVL_std':>8s}") + print(f" {'-'*65}") + + b_tvls = [] + for pid in pool_ids: + panel = matched_clean[pid]["panel"] + log_vol = panel["log_volume"].values.astype(float) + log_tvl = panel["log_tvl_lag1"].values.astype(float) + + d_vol = np.diff(log_vol) + d_tvl = np.diff(log_tvl) + + valid = np.isfinite(d_vol) & np.isfinite(d_tvl) + if valid.sum() < 10: + continue + + dv = d_vol[valid] + dt = d_tvl[valid] + + # OLS: Δlog_vol = a + b * Δlog_tvl + X = np.column_stack([np.ones(len(dt)), dt]) + sol, _, _, _ = np.linalg.lstsq(X, dv, rcond=None) + b = sol[1] + pred = X @ sol + ss_res = np.sum((dv - pred) ** 2) + ss_tot = np.sum((dv - dv.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + + b_tvls.append(b) + tokens = matched_clean[pid]["tokens"] + print(f" {pid[:16]:16s} {tokens[:16]:16s} {b:+8.3f} {r2:6.3f}" + f" {valid.sum():5d} {dt.std():8.4f}") + + if b_tvls: + print(f"\n Median within-pool b_tvl: {np.median(b_tvls):+.4f}") + print(f" Mean: {np.mean(b_tvls):+.4f}") + print(f" Std across pools: {np.std(b_tvls):.4f}") + return b_tvls + + +# 
---- Lagged-average TVL analysis ---- + + +def run_lagged_average_analysis(matched_clean, pool_ids): + """Test TVL→noise at different timescales. + + If the daily Δ elasticity is ~0 but the level elasticity is ~2.5, + the effect may operate on longer timescales. Test by regressing + log volume on trailing-average log TVL over windows of 1, 7, 14, 30, + 60, and 90 days, demeaned within each pool. + If b_tvl grows with window size, the relationship is real but slow. + """ + windows = [1, 7, 14, 30, 60, 90] + + print(f"\n Window Median b_tvl Mean b_tvl Pools w/ data") + print(f" {'-'*55}") + + for w in windows: + b_tvls = [] + n_pools_used = 0 + for pid in pool_ids: + panel = matched_clean[pid]["panel"] + log_vol = panel["log_volume"].values.astype(float) + log_tvl = panel["log_tvl_lag1"].values.astype(float) + + if len(log_vol) < w + 10: + continue + + # Rolling mean TVL over window w + if w == 1: + tvl_avg = log_tvl + else: + # Simple trailing average over rows t-w .. t-1 (row t itself is + # excluded); left as NaN if any value in the window is missing + tvl_avg = np.full_like(log_tvl, np.nan) + for t in range(w, len(log_tvl)): + vals = log_tvl[t - w:t] + if np.all(np.isfinite(vals)): + tvl_avg[t] = np.mean(vals) + + # Within-pool: demean both series + valid = np.isfinite(log_vol) & np.isfinite(tvl_avg) + if valid.sum() < 15: + continue + + vol = log_vol[valid] + tvl = tvl_avg[valid] + vol_dm = vol - vol.mean() + tvl_dm = tvl - tvl.mean() + + # OLS: demeaned_vol = b * demeaned_tvl + if np.var(tvl_dm) < 1e-10: + continue + b = np.sum(vol_dm * tvl_dm) / np.sum(tvl_dm ** 2) + b_tvls.append(b) + n_pools_used += 1 + + if b_tvls: + print(f" {w:5d}d {np.median(b_tvls):+11.4f} {np.mean(b_tvls):+10.4f}" + f" {n_pools_used:>13d}") + + return windows + + +# ---- Deconfounder sensitivity ---- + + +def run_deconfounder(data, n_factors_list, args): + """Secondary: deconfounder sensitivity analysis across n_factors.""" + from experiments.run_linear_market_noise import make_loss_fn, train + from sklearn.linear_model import RidgeCV + + X = data["x"] + tvl_idx = _tvl_col_index(data["feat_names"]) + intercept_idx = 
_intercept_col_index(data["feat_names"]) + results = {} + + for n_f in n_factors_list: + print(f"\n --- n_factors={n_f} ---") + Z_hat, _ = fit_ppca(X, n_f) + data_aug = build_augmented_data(data, Z_hat) + + n_feat = data_aug["n_feat"] + n_pools = data_aug["n_pools"] + + # Ridge warm-start + ridge = RidgeCV(alphas=np.logspace(-2, 4, 50)) + ridge.fit(data_aug["x"], data_aug["y_total"]) + sol = ridge.coef_.copy() + sol[intercept_idx] += ridge.intercept_ + + params = { + "log_cadence": jnp.array(data_aug["init_log_cadences"]), + "noise_coeffs": jnp.array(sol.astype(np.float32)), + } + + grad_fn = make_loss_fn(data_aug["pool_coeffs"], data_aug["pool_gas"], n_pools) + params = train(params, data_aug, grad_fn, args.epochs, args.lr, + args.l2_alpha, args.huber_delta, verbose=False) + + nc = np.array(params["noise_coeffs"]) + b_tvl = nc[tvl_idx] + n_orig = data["n_feat"] + z_coeffs = nc[n_orig:n_orig + n_f] + + print(f" b_tvl = {b_tvl:+.4f} " + f" Z coeffs: {np.round(z_coeffs, 3)}") + + results[n_f] = { + "b_tvl": float(b_tvl), + "z_coeffs": z_coeffs.tolist(), + "explained_var": float(Z_hat.var(axis=0).sum() / X.var(axis=0).sum()), + } + + return results + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--n-factors", type=int, nargs="+", default=[1, 2, 3, 5], + help="Number of latent factors to sweep") + parser.add_argument("--epochs", type=int, default=2000) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--l2-alpha", type=float, default=1e-3) + parser.add_argument("--huber-delta", type=float, default=1.0) + parser.add_argument("--trend-windows", type=int, nargs="+", default=[7]) + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Causal Noise Volume Estimation") + print(" 1. Variance decomposition (between vs within pool)") + print(" 2. Within-pool simple regressions") + print(" 3. 
TVL decomposition (price vs flow)") + print(" 4. Deconfounder sensitivity (PPCA factors)") + print(f" n_factors sweep: {args.n_factors}") + print("=" * 70) + + from experiments.run_linear_market_noise import load_stage1, build_data + + matched_clean, option_c_clean = load_stage1() + + print("\nBuilding data...") + t0 = time.time() + data = build_data( + matched_clean, option_c_clean, + trend_windows=tuple(args.trend_windows), + include_market=True, include_cross_pool=True, + ) + pool_ids = data["pool_ids"] + n_pools = data["n_pools"] + print(f" {len(data['pool_idx'])} samples, {n_pools} pools," + f" {data['n_feat']} features, {time.time() - t0:.1f}s") + + # ---- 1. Variance decomposition ---- + print(f"\n{'='*70}") + print("1. Variance Decomposition of log_tvl_lag1") + print(f"{'='*70}") + var_results = run_variance_decomposition(matched_clean, pool_ids) + + # ---- 2. Within-pool simple regressions ---- + print(f"\n{'='*70}") + print("2. Within-pool: Δlog(V_obs) ~ Δlog(TVL) (no other covariates)") + print(f"{'='*70}") + within_b_tvls = run_within_pool_regressions(matched_clean, pool_ids) + + # ---- 2b. Lagged-average TVL (timescale test) ---- + print(f"\n{'='*70}") + print("2b. Lagged-Average TVL: Does b_tvl grow with averaging window?") + print(f" (Tests whether the TVL→noise effect is slow-moving)") + print(f"{'='*70}") + run_lagged_average_analysis(matched_clean, pool_ids) + + # ---- 3. TVL decomposition ---- + print(f"\n{'='*70}") + print("3. 
TVL Decomposition: Price-driven vs Flow-driven") + print(f"{'='*70}") + + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + tvl_flow, tvl_price, valid = decompose_tvl( + matched_clean, pool_ids, data["pool_idx"], data["day_idx"], + date_to_idx, len(date_list), n_pools, + ) + tvl_results = run_tvl_decomposition_analysis( + data, matched_clean, tvl_flow, tvl_price, valid) + + # ---- 4. Deconfounder sensitivity ---- + print(f"\n{'='*70}") + print("4. Deconfounder Sensitivity Analysis") + print(f" (Wang & Blei 2019; D'Amour 2019; Wang & Blei 2020)") + print(f"{'='*70}") + deconf_results = run_deconfounder(data, args.n_factors, args) + + # ---- Summary ---- + print(f"\n{'='*70}") + print("SUMMARY: b_tvl across identification strategies") + print(f"{'='*70}") + + # Observational baseline from artifact + obs_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", "model.npz", + ) + if os.path.exists(obs_path): + obs_nc = np.load(obs_path)["noise_coeffs"] + tvl_idx = _tvl_col_index(data["feat_names"]) + if obs_nc.ndim == 2: + b_obs = float(np.median(obs_nc[:, tvl_idx])) + print(f" Observational (per-pool median): {b_obs:+.4f}") + else: + b_obs = float(obs_nc[tvl_idx]) + print(f" Observational (shared): {b_obs:+.4f}") + + between_pct = var_results["between"] / var_results["total"] * 100 + print(f"\n Variance decomposition: {between_pct:.0f}% between-pool," + f" {100-between_pct:.0f}% within-pool") + + if within_b_tvls: + print(f" Within-pool Δ regressions: " + f"median={np.median(within_b_tvls):+.4f}" + f" (mean={np.mean(within_b_tvls):+.4f})") + + if tvl_results: + print(f"\n TVL decomposition (Ridge, 22 features):") + print(f" All variation: {tvl_results['obs']:+.4f}") + print(f" Price-driven (exogenous): {tvl_results['price']:+.4f}") + print(f" Flow-driven (endogenous): 
{tvl_results['flow']:+.4f}") + + print(f"\n Deconfounder sensitivity (shared, learnable cadence):") + print(f" {'n_factors':>10s} {'b_tvl':>8s}") + for n_f, r in sorted(deconf_results.items()): + print(f" {n_f:>10d} {r['b_tvl']:+8.4f}") + + # Stability + b_tvls_d = [r["b_tvl"] for r in deconf_results.values()] + rng = max(b_tvls_d) - min(b_tvls_d) + mn = np.mean(b_tvls_d) + stable = rng < 0.3 * abs(mn) + print(f"\n Deconfounder: {'STABLE' if stable else 'VARIES'}" + f" (range {rng:.3f}, mean {mn:+.3f})") + + print(f"\n Interpretation:") + if tvl_results and abs(tvl_results['price']) < 0.5: + print(f" Daily b_tvl (Δ regression, price-driven) is near zero.") + print(f" This does NOT mean the long-run effect is zero:") + print(f" - Noise may respond slowly to TVL (routing updates,") + print(f" aggregator discovery, ecosystem integration)") + print(f" - The lagged-average analysis above tests this") + print(f" - The per-pool b_tvl of ~1.0 captures medium-frequency") + print(f" within-pool variation and is the best working estimate") + print(f" - Changing reClAMM concentration is a structural change") + print(f" (like being a different pool), not a daily TVL shock") + print(f" → Use per-pool b_tvl (~1.0) for counterfactuals, with") + print(f" sensitivity analysis across [0.5, 1.0, 2.0]") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_deepsets_noise.py b/experiments/run_deepsets_noise.py new file mode 100644 index 0000000..a639256 --- /dev/null +++ b/experiments/run_deepsets_noise.py @@ -0,0 +1,595 @@ +"""DeepSets noise volume prediction with V_arb decomposition. + +Predicts V_noise via a shared encoder-decoder over peer pools. +V_arb is precomputed from grids at Option C cadence/gas. 
+Loss: mean((log(V_arb + V_noise_predicted) - log_volume)^2) + +Usage: + python experiments/run_deepsets_noise.py # default hparams + python experiments/run_deepsets_noise.py --tune 50 # Optuna, 50 trials +""" + +import argparse +import os +import pickle +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + +# Default hyperparameters +DEFAULTS = dict( + hidden=16, + d_embed=8, + lr=3e-4, + l2_alpha=1e-3, + n_epochs=1000, + include_own_lag=True, +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +# ---- Data construction ---- + + +def build_data(matched_clean, option_c_clean, exclude_pool_idx=None): + """Build training arrays with V_arb decomposition. 
+ + For each (pool i, day t) sample: + - peer_vols: other pools' log_volume at t-1 + - v_arb: precomputed arb volume for pool i at day t + - local_features: [log_tvl_lag1, dow_sin, dow_cos] + - own_lag: pool i's log_volume at t-1 + - y: log_volume at day t + """ + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import ( + build_pool_attributes, _parse_tokens, _canonicalize_token, + ) + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + # ---- Collect all dates, build volume + V_arb matrices ---- + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + vol_matrix = np.full((n_dates, n_pools), np.nan) + v_arb_matrix = np.full((n_dates, n_pools), np.nan) + tvl_matrix = np.full((n_dates, n_pools), np.nan) + weekday_matrix = np.full(n_dates, np.nan) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + oc = option_c_clean[pid] + panel = entry["panel"] + + # V_arb from grid + v_arb_all = np.array(interpolate_pool_daily( + entry["coeffs"], + jnp.float64(oc["log_cadence"]), + jnp.float64(np.exp(oc["log_gas"])), + )) + v_arb_day = v_arb_all[entry["day_indices"]] + + dates = panel["date"].values + log_vols = panel["log_volume"].values.astype(float) + tvl_vals = panel["log_tvl_lag1"].values.astype(float) + + for k, date in enumerate(dates): + t = date_to_idx[date] + vol_matrix[t, j] = log_vols[k] + v_arb_matrix[t, j] = v_arb_day[k] + tvl_matrix[t, j] = tvl_vals[k] + + # Weekdays + for t, date in enumerate(date_list): + dt = pd.Timestamp(date) + weekday_matrix[t] = dt.weekday() + + # ---- Pool attributes ---- + X_attr, attr_names, _ = build_pool_attributes(matched_clean) + attr_mean = np.mean(X_attr, axis=0) + attr_std = np.std(X_attr, axis=0) + attr_std[attr_std < 1e-6] = 1.0 + X_attr_norm = ((X_attr - 
attr_mean) / attr_std).astype(np.float32) + k_attr = X_attr_norm.shape[1] + + # ---- Token overlap ---- + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + + n_peers = n_pools - 1 + peer_attrs = np.zeros((n_pools, n_peers, k_attr), dtype=np.float32) + peer_overlap = np.zeros((n_pools, n_peers), dtype=np.float32) + peer_col_idx = np.zeros((n_pools, n_peers), dtype=np.int32) + + for i in range(n_pools): + peers = [j for j in range(n_pools) if j != i] + for p, j in enumerate(peers): + peer_attrs[i, p] = X_attr_norm[j] + peer_overlap[i, p] = len(pool_tokens[i] & pool_tokens[j]) + peer_col_idx[i, p] = j + + target_attrs = X_attr_norm + + # ---- Standardize volumes for encoder input ---- + vol_mean = float(np.nanmean(vol_matrix)) + vol_std = float(np.nanstd(vol_matrix)) + + # ---- Build samples ---- + sample_pools, sample_days = [], [] + for i in range(n_pools): + if i == exclude_pool_idx: + continue + for t in range(1, n_dates): + if (np.isnan(vol_matrix[t, i]) or np.isnan(vol_matrix[t - 1, i]) + or np.isnan(v_arb_matrix[t, i]) or np.isnan(tvl_matrix[t, i])): + continue + sample_pools.append(i) + sample_days.append(t) + + sample_pools = np.array(sample_pools, dtype=np.int32) + sample_days = np.array(sample_days, dtype=np.int32) + n_samples = len(sample_pools) + + peer_vols_arr = np.zeros((n_samples, n_peers), dtype=np.float32) + peer_mask_arr = np.zeros((n_samples, n_peers), dtype=np.float32) + own_lag_arr = np.zeros(n_samples, dtype=np.float32) + v_arb_arr = np.zeros(n_samples, dtype=np.float32) + local_arr = np.zeros((n_samples, 3), dtype=np.float32) # tvl, dow_sin, dow_cos + y_arr = np.zeros(n_samples, dtype=np.float32) + + for s in range(n_samples): + i = sample_pools[s] + t = sample_days[s] + cols = peer_col_idx[i] + + pvols_raw = vol_matrix[t - 1, cols] + valid = ~np.isnan(pvols_raw) + pvols_norm = (pvols_raw - vol_mean) / vol_std + 
peer_vols_arr[s] = np.where(valid, pvols_norm, 0.0) + peer_mask_arr[s] = valid.astype(np.float32) + + own_lag_arr[s] = (vol_matrix[t - 1, i] - vol_mean) / vol_std + v_arb_arr[s] = v_arb_matrix[t, i] + y_arr[s] = vol_matrix[t, i] # raw log_volume (not standardized) + + wd = weekday_matrix[t] + local_arr[s, 0] = tvl_matrix[t, i] + local_arr[s, 1] = np.sin(2 * np.pi * wd / 7) + local_arr[s, 2] = np.cos(2 * np.pi * wd / 7) + + # Standardize local features + local_mean = np.mean(local_arr, axis=0) + local_std = np.std(local_arr, axis=0) + local_std[local_std < 1e-6] = 1.0 + local_arr = ((local_arr - local_mean) / local_std).astype(np.float32) + + return { + "peer_attrs": jnp.array(peer_attrs), + "target_attrs": jnp.array(target_attrs), + "peer_overlap": jnp.array(peer_overlap), + "peer_vols": jnp.array(peer_vols_arr), + "peer_mask": jnp.array(peer_mask_arr), + "own_lag": jnp.array(own_lag_arr), + "v_arb": jnp.array(v_arb_arr), + "local": jnp.array(local_arr), + "y": jnp.array(y_arr), + "pool_idx": jnp.array(sample_pools), + "day_idx": sample_days, + "n_pools": n_pools, + "n_peers": n_peers, + "k_attr": k_attr, + "k_local": 3, + "pool_ids": pool_ids, + } + + +# ---- Model ---- + + +def init_params(key, k_attr, k_local, hidden, d_embed, include_own_lag): + k1, k2, k3, k4 = jax.random.split(key, 4) + enc_in = 2 * k_attr + 2 # peer_attr + target_attr + peer_vol + overlap + dec_in = d_embed + k_attr + k_local + (1 if include_own_lag else 0) + + return { + "enc_W1": jax.random.normal(k1, (enc_in, hidden)) * np.sqrt(2.0 / enc_in), + "enc_b1": jnp.zeros(hidden), + "enc_W2": jax.random.normal(k2, (hidden, d_embed)) * np.sqrt(2.0 / hidden), + "enc_b2": jnp.zeros(d_embed), + "dec_W1": jax.random.normal(k3, (dec_in, hidden)) * np.sqrt(2.0 / dec_in), + "dec_b1": jnp.zeros(hidden), + "dec_W2": jax.random.normal(k4, (hidden, 1)) * 0.01, + "dec_b2": jnp.zeros(1), + } + + +def forward(params, peer_attrs_all, target_attrs_all, peer_overlap_all, + peer_vols, peer_mask, own_lag, 
local_feat, pool_idx, + include_own_lag=True): + """Returns log_v_noise per sample.""" + batch = peer_vols.shape[0] + + pa = peer_attrs_all[pool_idx] + ta = target_attrs_all[pool_idx] + ov = peer_overlap_all[pool_idx] + ta_broad = jnp.broadcast_to(ta[:, None, :], pa.shape) + + enc_in = jnp.concatenate([ + pa, ta_broad, + peer_vols[:, :, None], + ov[:, :, None], + ], axis=-1) + + flat = enc_in.reshape(-1, enc_in.shape[-1]) + h = jnp.maximum(flat @ params["enc_W1"] + params["enc_b1"], 0.0) + h = h @ params["enc_W2"] + params["enc_b2"] + h = h.reshape(batch, peer_vols.shape[1], -1) + + h_masked = h * peer_mask[:, :, None] + n_valid = jnp.maximum(jnp.sum(peer_mask, axis=1, keepdims=True), 1.0) + summary = jnp.sum(h_masked, axis=1) / n_valid + + dec_parts = [summary, ta, local_feat] + if include_own_lag: + dec_parts.append(own_lag[:, None]) + dec_in = jnp.concatenate(dec_parts, axis=-1) + + h_dec = jnp.maximum(dec_in @ params["dec_W1"] + params["dec_b1"], 0.0) + log_v_noise = (h_dec @ params["dec_W2"] + params["dec_b2"])[:, 0] + return log_v_noise + + +def loss_fn(params, static, peer_vols, peer_mask, own_lag, local_feat, + pool_idx, v_arb, y, l2_alpha, include_own_lag): + """Log-space V_arb + V_noise loss matching the calibration pipeline.""" + log_v_noise = forward( + params, static["peer_attrs"], static["target_attrs"], + static["peer_overlap"], peer_vols, peer_mask, own_lag, local_feat, + pool_idx, include_own_lag, + ) + v_noise = jnp.exp(log_v_noise) + log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + mse = jnp.mean((log_v_pred - y) ** 2) + # L2 penalty on weight matrices only (biases are not regularized) + reg = sum(jnp.sum(v ** 2) for k, v in params.items() if "W" in k) + return mse + l2_alpha * reg + + +from functools import partial # noqa: E402 (kept next to its only use) + + +# include_own_lag drives Python control flow in forward(), so it must be a +# static (non-traced) argument under jit; l2_alpha stays traced. +@partial(jax.jit, static_argnames=("include_own_lag",)) +def _loss_and_grad(params, static, peer_vols, peer_mask, own_lag, local_feat, + pool_idx, v_arb, y, l2_alpha, include_own_lag): + return jax.value_and_grad(loss_fn)( + params, static, peer_vols, peer_mask, own_lag, local_feat, + pool_idx, v_arb, y, l2_alpha, 
include_own_lag, + ) + + +# ---- Training ---- + + +def train(params, data, hparams, verbose=True): + """Full-batch Adam.""" + static = {k: data[k] for k in ["peer_attrs", "target_attrs", "peer_overlap"]} + include_own_lag = hparams["include_own_lag"] + lr = hparams["lr"] + l2_alpha = hparams["l2_alpha"] + n_epochs = hparams["n_epochs"] + + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = {k: jnp.zeros_like(v) for k, v in params.items()} + + for epoch in range(n_epochs): + loss_val, grads = _loss_and_grad( + params, static, data["peer_vols"], data["peer_mask"], + data["own_lag"], data["local"], data["pool_idx"], + data["v_arb"], data["y"], l2_alpha, include_own_lag, + ) + + for k in params: + m[k] = 0.9 * m[k] + 0.1 * grads[k] + v[k] = 0.999 * v[k] + 0.001 * grads[k] ** 2 + m_hat = m[k] / (1.0 - 0.9 ** (epoch + 1)) + v_hat = v[k] / (1.0 - 0.999 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / (jnp.sqrt(v_hat) + 1e-8) + + if verbose and (epoch % 200 == 0 or epoch == n_epochs - 1): + print(f" epoch {epoch:4d} loss={float(loss_val):.6f}") + + return params, float(loss_val) + + +# ---- Evaluation ---- + + +def per_pool_r2(params, data, hparams): + """Per-pool R² on the log(V_arb + V_noise) prediction.""" + static = {k: data[k] for k in ["peer_attrs", "target_attrs", "peer_overlap"]} + log_v_noise = np.array(forward( + params, static["peer_attrs"], static["target_attrs"], + static["peer_overlap"], data["peer_vols"], data["peer_mask"], + data["own_lag"], data["local"], data["pool_idx"], + hparams["include_own_lag"], + )) + v_noise = np.exp(log_v_noise) + v_arb = np.array(data["v_arb"]) + log_v_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + y = np.array(data["y"]) + pool_idx = np.array(data["pool_idx"]) + + r2s = {} + for i in range(data["n_pools"]): + mask = pool_idx == i + if mask.sum() < 2: + continue + yi = y[mask] + pi = log_v_pred[mask] + ss_res = np.sum((yi - pi) ** 2) + ss_tot = np.sum((yi - yi.mean()) ** 2) + r2s[i] = 1 - ss_res / max(ss_tot, 
1e-10) + return r2s + + +def subset_data(data, mask): + """Subset data arrays by boolean mask.""" + jmask = jnp.array(mask) + static_keys = ["peer_attrs", "target_attrs", "peer_overlap", + "n_pools", "n_peers", "k_attr", "k_local", "pool_ids"] + out = {k: data[k] for k in static_keys} + for k in ["peer_vols", "peer_mask", "own_lag", "v_arb", "local", "y", "pool_idx"]: + out[k] = data[k][jmask] + out["day_idx"] = np.array(data["day_idx"])[mask] + return out + + +# ---- Experiments ---- + + +def run_single(matched_clean, option_c_clean, hparams, split_frac=0.7): + """Train with temporal split, report in-sample and eval R².""" + data = build_data(matched_clean, option_c_clean) + n_params = sum(v.size for v in init_params( + jax.random.PRNGKey(0), data["k_attr"], data["k_local"], + hparams["hidden"], hparams["d_embed"], hparams["include_own_lag"], + ).values()) + + print(f" {data['peer_vols'].shape[0]} samples, {data['n_pools']} pools, " + f"{n_params} params") + + day_idx = np.array(data["day_idx"]) + split_day = int(day_idx.max() * split_frac) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = subset_data(data, train_mask) + eval_data = subset_data(data, eval_mask) + + print(f" Train: {int(train_mask.sum())} samples, " + f"Eval: {int(eval_mask.sum())} samples") + + params = init_params( + jax.random.PRNGKey(42), data["k_attr"], data["k_local"], + hparams["hidden"], hparams["d_embed"], hparams["include_own_lag"], + ) + t0 = time.time() + params, final_loss = train(params, train_data, hparams) + print(f" Training: {time.time() - t0:.1f}s, final loss={final_loss:.6f}") + + r2_train = per_pool_r2(params, train_data, hparams) + r2_eval = per_pool_r2(params, eval_data, hparams) + + pool_ids = data["pool_ids"] + for i, pid in enumerate(pool_ids): + r_tr = r2_train.get(i, float("nan")) + r_ev = r2_eval.get(i, float("nan")) + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"train={r_tr:.3f} eval={r_ev:.3f}") + + vals_train = [v 
for v in r2_train.values() if np.isfinite(v)] + vals_eval = [v for v in r2_eval.values() if np.isfinite(v)] + med_train = np.median(vals_train) if vals_train else float("nan") + med_eval = np.median(vals_eval) if vals_eval else float("nan") + + print(f"\n Train: median R²={med_train:.4f}") + print(f" Eval: median R²={med_eval:.4f}") + print(f" (Option C in-sample: 0.589)") + + return med_eval, params, data + + +def run_loo(matched_clean, option_c_clean, hparams): + """Leave-one-pool-out: train on the other pools, evaluate on the held-out pool.""" + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + loo_r2s = [] + + print(f"\n{'='*70}") + print("LOO DeepSets Noise") + print(f"{'='*70}") + + # Use fewer epochs for LOO + loo_hparams = dict(hparams, n_epochs=min(hparams["n_epochs"], 500)) + + # Build the full dataset once; each fold only selects its held-out rows + full_data = build_data(matched_clean, option_c_clean) + full_pool_idx = np.array(full_data["pool_idx"]) + + for hold_out_idx in range(n_pools): + hold_out_pid = pool_ids[hold_out_idx] + + # Skip pools with too few samples before paying for a training run + ho_mask = full_pool_idx == hold_out_idx + if ho_mask.sum() < 2: + loo_r2s.append(float("nan")) + continue + + train_data = build_data(matched_clean, option_c_clean, + exclude_pool_idx=hold_out_idx) + + params = init_params( + jax.random.PRNGKey(42), train_data["k_attr"], train_data["k_local"], + loo_hparams["hidden"], loo_hparams["d_embed"], + loo_hparams["include_own_lag"], + ) + params, _ = train(params, train_data, loo_hparams, verbose=False) + + # Eval on held-out pool + eval_data = subset_data(full_data, ho_mask) + r2s = per_pool_r2(params, eval_data, loo_hparams) + r2 = r2s.get(hold_out_idx, float("nan")) + loo_r2s.append(r2) + + tag = "OK" if r2 > 0 else "NEG" + print(f" {hold_out_pid[:16]} ({matched_clean[hold_out_pid]['tokens']:<14}) " + f"R²={r2:.3f} [{tag}]") + + valid = [r for r in loo_r2s if np.isfinite(r)] + med = np.median(valid) if valid else float("nan") + mean = np.mean(valid) if valid else float("nan") + print(f"\n LOO: median R²={med:.4f}, " + f"mean={mean:.4f}, " + f"n_neg={sum(1 for r in valid if r < 0)}") + return loo_r2s + + +# ---- Optuna ---- + + +def run_optuna(matched_clean, option_c_clean, 
n_trials): + """Hyperparameter optimization with Optuna.""" + import optuna + + # Precompute data once (shared across trials) + data = build_data(matched_clean, option_c_clean) + day_idx = np.array(data["day_idx"]) + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = subset_data(data, train_mask) + eval_data = subset_data(data, eval_mask) + + def objective(trial): + hp = { + "hidden": trial.suggest_categorical("hidden", [8, 16, 32]), + "d_embed": trial.suggest_categorical("d_embed", [4, 8, 16]), + "lr": trial.suggest_float("lr", 1e-4, 1e-2, log=True), + "l2_alpha": trial.suggest_float("l2_alpha", 1e-5, 1e-1, log=True), + "n_epochs": trial.suggest_categorical("n_epochs", [500, 1000, 2000]), + "include_own_lag": trial.suggest_categorical("include_own_lag", [True, False]), + } + + params = init_params( + jax.random.PRNGKey(42), data["k_attr"], data["k_local"], + hp["hidden"], hp["d_embed"], hp["include_own_lag"], + ) + params, final_loss = train(params, train_data, hp, verbose=False) + + r2s = per_pool_r2(params, eval_data, hp) + vals = [v for v in r2s.values() if np.isfinite(v)] + med_r2 = float(np.median(vals)) if vals else -10.0 + + # Report train R² too for diagnostics + r2s_tr = per_pool_r2(params, train_data, hp) + vals_tr = [v for v in r2s_tr.values() if np.isfinite(v)] + med_tr = float(np.median(vals_tr)) if vals_tr else -10.0 + + trial.set_user_attr("train_median_r2", med_tr) + trial.set_user_attr("final_loss", final_loss) + + print(f" Trial {trial.number}: eval={med_r2:.4f} train={med_tr:.4f} " + f"h={hp['hidden']} d={hp['d_embed']} lr={hp['lr']:.1e} " + f"alpha={hp['l2_alpha']:.1e} epochs={hp['n_epochs']} " + f"own_lag={hp['include_own_lag']}") + + return med_r2 + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + print(f"\n{'='*70}") + print("Optuna Results") + print(f"{'='*70}") + print(f" Best trial: {study.best_trial.number}") + 
print(f" Best eval median R²: {study.best_value:.4f}") + print(f" Best params: {study.best_params}") + print(f" Train median R²: {study.best_trial.user_attrs['train_median_r2']:.4f}") + + # Show top 5 + print(f"\n Top 5 trials:") + trials = sorted(study.trials, key=lambda t: t.value if t.value is not None else -999, + reverse=True) + for t in trials[:5]: + if t.value is not None: + print(f" #{t.number}: eval={t.value:.4f} " + f"train={t.user_attrs.get('train_median_r2', float('nan')):.4f} " + f"{t.params}") + + return study + + +# ---- Main ---- + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--tune", type=int, default=0, + help="Run Optuna with N trials") + parser.add_argument("--loo", action="store_true", + help="Run LOO evaluation") + parser.add_argument("--hidden", type=int, default=DEFAULTS["hidden"]) + parser.add_argument("--d-embed", type=int, default=DEFAULTS["d_embed"]) + parser.add_argument("--lr", type=float, default=DEFAULTS["lr"]) + parser.add_argument("--l2-alpha", type=float, default=DEFAULTS["l2_alpha"]) + parser.add_argument("--epochs", type=int, default=DEFAULTS["n_epochs"]) + parser.add_argument("--no-own-lag", action="store_true") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + hparams = { + "hidden": args.hidden, + "d_embed": args.d_embed, + "lr": args.lr, + "l2_alpha": args.l2_alpha, + "n_epochs": args.epochs, + "include_own_lag": not args.no_own_lag, + } + + print("=" * 70) + print("DeepSets Noise Volume Prediction (V_arb decomposition)") + print(f" {hparams}") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + if args.tune > 0: + run_optuna(matched_clean, option_c_clean, args.tune) + else: + print(f"\n{'='*70}") + print("Temporal split (70/30)") + print(f"{'='*70}") + med_eval, params, data = run_single(matched_clean, option_c_clean, hparams) + + if args.loo: + run_loo(matched_clean, option_c_clean, hparams) + + +if __name__ == "__main__": + main() diff --git
a/experiments/run_deepsets_v2.py b/experiments/run_deepsets_v2.py new file mode 100644 index 0000000..bc3661b --- /dev/null +++ b/experiments/run_deepsets_v2.py @@ -0,0 +1,1656 @@ +"""DeepSets v2: full feature menu with Optuna feature selection. + +Trains on total log_volume, evaluates on both total volume and noise +residual (log_vol - log_V_arb). V_arb precomputed from Option C fits. + +Feature menu: + Peer (encoder) — always: peer_attr, target_attr, vol_lag1, overlap + optional: vol_lag2, vol_change, tvl, volatility + relational: same_chain, log_tvl_ratio, log_fee_ratio + Local (decoder) — always: target_attr, own_vol_lag1, dow_sin, dow_cos + optional: own_vol_lag2, own_vol_change, own_tvl, own_volatility + +Model variants: + encoder_type: "mlp" (2-layer ReLU) or "linear" (single affine) + no_peers: decoder-only ablation (zero peer summary) + huber_delta: Huber loss transition point (default 1.0) + Per-pool loss weighting (equal weight per pool regardless of sample count) + +Usage: + python experiments/run_deepsets_v2.py # defaults + python experiments/run_deepsets_v2.py --tune 50 # Optuna + python experiments/run_deepsets_v2.py --no-peers # decoder-only + python experiments/run_deepsets_v2.py --encoder-type linear + python experiments/run_deepsets_v2.py --loo # LOO eval +""" + +import argparse +import os +import pickle +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +# ---- Data construction ---- + + +def build_all_features(matched_clean, option_c_clean): + """Build all possible feature matrices. 
Called once.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import ( + build_pool_attributes, _parse_tokens, _canonicalize_token, + build_x_obs, build_cross_pool_x_obs, + ) + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + # Collect dates + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + # Daily matrices: (n_dates, n_pools) + vol_matrix = np.full((n_dates, n_pools), np.nan) + tvl_matrix = np.full((n_dates, n_pools), np.nan) + volatility_matrix = np.full((n_dates, n_pools), np.nan) + v_arb_matrix = np.full((n_dates, n_pools), np.nan) + weekday_arr = np.zeros(n_dates) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + oc = option_c_clean[pid] + panel = entry["panel"] + + v_arb_all = np.array(interpolate_pool_daily( + entry["coeffs"], + jnp.float64(oc["log_cadence"]), + jnp.float64(np.exp(oc["log_gas"])), + )) + v_arb = v_arb_all[entry["day_indices"]] + + dates = panel["date"].values + for k, date in enumerate(dates): + t = date_to_idx[date] + vol_matrix[t, j] = panel["log_volume"].values[k] + tvl_matrix[t, j] = panel["log_tvl_lag1"].values[k] + volatility_matrix[t, j] = panel["volatility"].values[k] + v_arb_matrix[t, j] = v_arb[k] + + # Per-pool coeffs, gas, and day mapping for learnable cadence + pool_coeffs = [] + pool_gas = [] + init_log_cadences = np.zeros(n_pools, dtype=np.float32) + common_to_grid = np.full((n_pools, n_dates), 0, dtype=np.int32) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + oc = option_c_clean[pid] + pool_coeffs.append(entry["coeffs"]) + pool_gas.append(jnp.float64(np.exp(oc["log_gas"]))) + init_log_cadences[j] = oc["log_cadence"] + dates_j = entry["panel"]["date"].values + for k, date in enumerate(dates_j): + common_to_grid[j, 
date_to_idx[date]] = entry["day_indices"][k] + + for t, date in enumerate(date_list): + weekday_arr[t] = pd.Timestamp(date).weekday() + + # Pool attributes (static) + X_attr, attr_names, _ = build_pool_attributes(matched_clean) + attr_mean = np.mean(X_attr, axis=0) + attr_std = np.std(X_attr, axis=0) + attr_std[attr_std < 1e-6] = 1.0 + X_attr_norm = ((X_attr - attr_mean) / attr_std).astype(np.float32) + k_attr = X_attr_norm.shape[1] + + # Raw per-pool values for relational features + fee_idx = attr_names.index("log_fee") + tvl_idx = attr_names.index("mean_log_tvl") + raw_log_fee = X_attr[:, fee_idx] + raw_mean_log_tvl = X_attr[:, tvl_idx] + pool_chains = [matched_clean[pid]["chain"] for pid in pool_ids] + + # Token overlap + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + + n_peers = n_pools - 1 + peer_attrs = np.zeros((n_pools, n_peers, k_attr), dtype=np.float32) + peer_overlap = np.zeros((n_pools, n_peers), dtype=np.float32) + peer_col_idx = np.zeros((n_pools, n_peers), dtype=np.int32) + rel_same_chain = np.zeros((n_pools, n_peers), dtype=np.float32) + rel_log_tvl_ratio = np.zeros((n_pools, n_peers), dtype=np.float32) + rel_log_fee_ratio = np.zeros((n_pools, n_peers), dtype=np.float32) + + for i in range(n_pools): + peers = [j for j in range(n_pools) if j != i] + for p, j in enumerate(peers): + peer_attrs[i, p] = X_attr_norm[j] + peer_overlap[i, p] = len(pool_tokens[i] & pool_tokens[j]) + peer_col_idx[i, p] = j + rel_same_chain[i, p] = float(pool_chains[i] == pool_chains[j]) + rel_log_tvl_ratio[i, p] = abs(raw_mean_log_tvl[i] - raw_mean_log_tvl[j]) + rel_log_fee_ratio[i, p] = abs(raw_log_fee[i] - raw_log_fee[j]) + + # Standardize ratio features (new arrays, no in-place mutation) + def _standardize(arr): + mu = np.mean(arr) + sigma = max(np.std(arr), 1e-6) + return ((arr - mu) / sigma).astype(np.float32) + + rel_log_tvl_ratio = 
_standardize(rel_log_tvl_ratio) + rel_log_fee_ratio = _standardize(rel_log_fee_ratio) + + # Cross-pool peer maps: which pools share tokens / chain + from collections import defaultdict + pool_tokens_ordered = {} + token_to_pools = defaultdict(set) + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + ordered = [_canonicalize_token(t) for t in toks[:2]] + pool_tokens_ordered[i] = ordered + for tok in ordered: + token_to_pools[tok].add(i) + + token_a_peers = {} + token_b_peers = {} + chain_peer_map = {} + for i in range(n_pools): + toks = pool_tokens_ordered[i] + token_a_peers[i] = sorted(token_to_pools[toks[0]] - {i}) + token_b_peers[i] = sorted(token_to_pools[toks[1]] - {i}) if len(toks) > 1 else [] + chain_peer_map[i] = [j for j in range(n_pools) if j != i and pool_chains[j] == pool_chains[i]] + + # Per-pool log_fee for interaction features + pool_log_fee = raw_log_fee.copy() + fee_mean = float(np.mean(pool_log_fee)) + fee_std = max(float(np.std(pool_log_fee)), 1e-6) + + # Standardization stats for volumes + vol_mean = float(np.nanmean(vol_matrix)) + vol_std = float(np.nanstd(vol_matrix)) + tvl_mean = float(np.nanmean(tvl_matrix)) + tvl_std = float(np.nanstd(tvl_matrix)) + vola_mean = float(np.nanmean(volatility_matrix)) + vola_std = float(np.nanstd(volatility_matrix)) + + # Build x_obs per pool, mapped to common date grid + # x_obs_reduced: (n_dates, n_pools, 4), x_obs_cross: (n_dates, n_pools, 7) + from quantammsim.calibration.pool_data import K_OBS_REDUCED, K_OBS_CROSS + x_obs_reduced_grid = np.full((n_dates, n_pools, K_OBS_REDUCED), np.nan) + x_obs_cross_grid = np.full((n_dates, n_pools, K_OBS_CROSS), np.nan) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + panel = entry["panel"] + dates_j = panel["date"].values + + # Reduced x_obs (4 features) + xr = build_x_obs(panel, reduced=True) # (n_obs, 4) + for k, date in enumerate(dates_j): + x_obs_reduced_grid[date_to_idx[date], j] = xr[k] + + # Cross-pool 
x_obs (7 features) — drops first day + xc = build_cross_pool_x_obs(panel, matched_clean, pid) # (n_obs-1, 7) + for k, date in enumerate(dates_j[1:]): + x_obs_cross_grid[date_to_idx[date], j] = xc[k] + + # Build samples: require t >= 2 (for lag-2), valid vol at t, t-1, t-2 + sample_pools, sample_days = [], [] + for i in range(n_pools): + for t in range(2, n_dates): + if (np.isnan(vol_matrix[t, i]) or np.isnan(vol_matrix[t - 1, i]) + or np.isnan(vol_matrix[t - 2, i])): + continue + sample_pools.append(i) + sample_days.append(t) + + sample_pools = np.array(sample_pools, dtype=np.int32) + sample_days = np.array(sample_days, dtype=np.int32) + n_samples = len(sample_pools) + + def _norm_vol(x): + return (x - vol_mean) / vol_std + + def _norm_tvl(x): + return (x - tvl_mean) / tvl_std + + def _norm_vola(x): + return (x - vola_mean) / vola_std + + # Per-sample arrays + # Peer features (per peer) + pf_vol_lag1 = np.zeros((n_samples, n_peers), dtype=np.float32) + pf_vol_lag2 = np.zeros((n_samples, n_peers), dtype=np.float32) + pf_vol_change = np.zeros((n_samples, n_peers), dtype=np.float32) + pf_tvl = np.zeros((n_samples, n_peers), dtype=np.float32) + pf_volatility = np.zeros((n_samples, n_peers), dtype=np.float32) + peer_mask = np.zeros((n_samples, n_peers), dtype=np.float32) + + # Local features + lf_own_vol_lag1 = np.zeros(n_samples, dtype=np.float32) + lf_own_vol_lag2 = np.zeros(n_samples, dtype=np.float32) + lf_own_vol_change = np.zeros(n_samples, dtype=np.float32) + lf_own_tvl = np.zeros(n_samples, dtype=np.float32) + lf_own_volatility = np.zeros(n_samples, dtype=np.float32) + lf_dow_sin = np.zeros(n_samples, dtype=np.float32) + lf_dow_cos = np.zeros(n_samples, dtype=np.float32) + # Interaction features (from calibration pipeline's x_obs) + lf_tvl_x_vola = np.zeros(n_samples, dtype=np.float32) + lf_tvl_x_fee = np.zeros(n_samples, dtype=np.float32) + lf_vola_x_fee = np.zeros(n_samples, dtype=np.float32) + # Cross-pool volume aggregates + lf_cross_vol_tok_a = 
np.zeros(n_samples, dtype=np.float32) + lf_cross_vol_tok_b = np.zeros(n_samples, dtype=np.float32) + lf_cross_vol_chain = np.zeros(n_samples, dtype=np.float32) + lf_market_vol = np.zeros(n_samples, dtype=np.float32) + # Cross-pool momentum (peer volume changes) + lf_cross_mom_tok_a = np.zeros(n_samples, dtype=np.float32) + lf_cross_mom_tok_b = np.zeros(n_samples, dtype=np.float32) + lf_cross_mom_chain = np.zeros(n_samples, dtype=np.float32) + + # Targets + y_total = np.zeros(n_samples, dtype=np.float32) + v_arb_samples = np.zeros(n_samples, dtype=np.float32) + + for s in range(n_samples): + i = sample_pools[s] + t = sample_days[s] + cols = peer_col_idx[i] + + # Peer features at t-1 + pvols1 = vol_matrix[t - 1, cols] + pvols2 = vol_matrix[t - 2, cols] + valid = ~np.isnan(pvols1) + peer_mask[s] = valid.astype(np.float32) + + pf_vol_lag1[s] = np.where(valid, _norm_vol(pvols1), 0.0) + pf_vol_lag2[s] = np.where(valid & ~np.isnan(pvols2), _norm_vol(pvols2), 0.0) + pf_vol_change[s] = np.where( + valid & ~np.isnan(pvols2), + _norm_vol(pvols1) - _norm_vol(pvols2), 0.0) + + ptvl = tvl_matrix[t - 1, cols] + pf_tvl[s] = np.where(valid & ~np.isnan(ptvl), _norm_tvl(ptvl), 0.0) + + pvola = volatility_matrix[t - 1, cols] + pf_volatility[s] = np.where(valid & ~np.isnan(pvola), _norm_vola(pvola), 0.0) + + # Local features + lf_own_vol_lag1[s] = _norm_vol(vol_matrix[t - 1, i]) + lf_own_vol_lag2[s] = _norm_vol(vol_matrix[t - 2, i]) + lf_own_vol_change[s] = lf_own_vol_lag1[s] - lf_own_vol_lag2[s] + + tvl_val = tvl_matrix[t, i] + lf_own_tvl[s] = _norm_tvl(tvl_val) if np.isfinite(tvl_val) else 0.0 + + vola_val = volatility_matrix[t, i] + lf_own_volatility[s] = _norm_vola(vola_val) if np.isfinite(vola_val) else 0.0 + + wd = weekday_arr[t] + lf_dow_sin[s] = np.sin(2 * np.pi * wd / 7) + lf_dow_cos[s] = np.cos(2 * np.pi * wd / 7) + + # Interaction features (products of already-standardized inputs; not re-standardized) + norm_fee_i = (raw_log_fee[i] - fee_mean) / fee_std + lf_tvl_x_vola[s] = lf_own_tvl[s] *
lf_own_volatility[s] + lf_tvl_x_fee[s] = lf_own_tvl[s] * norm_fee_i + lf_vola_x_fee[s] = lf_own_volatility[s] * norm_fee_i + + # Cross-pool volume aggregates at t-1 + def _peer_vol_mean(peer_list, t_lag): + if not peer_list: + return vol_mean # global fallback + vals = vol_matrix[t_lag, peer_list] + valid = vals[~np.isnan(vals)] + return float(np.mean(valid)) if len(valid) > 0 else vol_mean + + def _peer_vol_change_mean(peer_list, t_lag): + if not peer_list: + return 0.0 + v1 = vol_matrix[t_lag, peer_list] + v2 = vol_matrix[t_lag - 1, peer_list] + valid = ~np.isnan(v1) & ~np.isnan(v2) + if valid.sum() == 0: + return 0.0 + return float(np.mean(v1[valid] - v2[valid])) + + lf_cross_vol_tok_a[s] = _norm_vol(_peer_vol_mean(token_a_peers[i], t - 1)) + lf_cross_vol_tok_b[s] = _norm_vol(_peer_vol_mean(token_b_peers[i], t - 1)) + lf_cross_vol_chain[s] = _norm_vol(_peer_vol_mean(chain_peer_map[i], t - 1)) + lf_market_vol[s] = _norm_vol(float(np.nanmean(vol_matrix[t - 1, :]))) + + # Cross-pool momentum: mean volume change of peers (t-1 vs t-2) + lf_cross_mom_tok_a[s] = _peer_vol_change_mean(token_a_peers[i], t - 1) + lf_cross_mom_tok_b[s] = _peer_vol_change_mean(token_b_peers[i], t - 1) + lf_cross_mom_chain[s] = _peer_vol_change_mean(chain_peer_map[i], t - 1) + + y_total[s] = vol_matrix[t, i] + v_arb_val = v_arb_matrix[t, i] + v_arb_samples[s] = v_arb_val if np.isfinite(v_arb_val) else 1e-6 + + # Per-sample grid day indices for learnable cadence + sample_grid_days = common_to_grid[sample_pools, sample_days] + + # Per-sample x_obs arrays + x_obs_reduced = np.zeros((n_samples, K_OBS_REDUCED), dtype=np.float32) + x_obs_cross = np.zeros((n_samples, K_OBS_CROSS), dtype=np.float32) + for s in range(n_samples): + xr = x_obs_reduced_grid[sample_days[s], sample_pools[s]] + if np.all(np.isfinite(xr)): + x_obs_reduced[s] = xr + xc = x_obs_cross_grid[sample_days[s], sample_pools[s]] + if np.all(np.isfinite(xc)): + x_obs_cross[s] = xc + + # Standardize momentum features (raw volume 
differences) + for arr in [lf_cross_mom_tok_a, lf_cross_mom_tok_b, lf_cross_mom_chain]: + mu = np.mean(arr) + sigma = max(np.std(arr), 1e-6) + arr[:] = ((arr - mu) / sigma).astype(np.float32) + + return { + # Static per-pool + "peer_attrs": peer_attrs, # (n_pools, n_peers, k_attr) + "target_attrs": X_attr_norm, # (n_pools, k_attr) + "peer_overlap": peer_overlap, # (n_pools, n_peers) + "rel_same_chain": rel_same_chain, # (n_pools, n_peers) + "rel_log_tvl_ratio": rel_log_tvl_ratio, # (n_pools, n_peers) + "rel_log_fee_ratio": rel_log_fee_ratio, # (n_pools, n_peers) + # Per-sample peer features + "pf_vol_lag1": pf_vol_lag1, + "pf_vol_lag2": pf_vol_lag2, + "pf_vol_change": pf_vol_change, + "pf_tvl": pf_tvl, + "pf_volatility": pf_volatility, + "peer_mask": peer_mask, + # Per-sample local features + "lf_own_vol_lag1": lf_own_vol_lag1, + "lf_own_vol_lag2": lf_own_vol_lag2, + "lf_own_vol_change": lf_own_vol_change, + "lf_own_tvl": lf_own_tvl, + "lf_own_volatility": lf_own_volatility, + "lf_dow_sin": lf_dow_sin, + "lf_dow_cos": lf_dow_cos, + # Interaction features + "lf_tvl_x_vola": lf_tvl_x_vola, + "lf_tvl_x_fee": lf_tvl_x_fee, + "lf_vola_x_fee": lf_vola_x_fee, + # Cross-pool volume aggregates + "lf_cross_vol_tok_a": lf_cross_vol_tok_a, + "lf_cross_vol_tok_b": lf_cross_vol_tok_b, + "lf_cross_vol_chain": lf_cross_vol_chain, + "lf_market_vol": lf_market_vol, + # Cross-pool momentum + "lf_cross_mom_tok_a": lf_cross_mom_tok_a, + "lf_cross_mom_tok_b": lf_cross_mom_tok_b, + "lf_cross_mom_chain": lf_cross_mom_chain, + # Targets + "y_total": y_total, + "y_residual": (y_total - np.log(np.maximum(v_arb_samples, 1e-6))).astype(np.float32), + "v_arb": v_arb_samples, + # Cadence learning (per-pool, not subject to _subset) + "pool_coeffs": pool_coeffs, # list of PoolCoeffsDaily + "pool_gas": pool_gas, # list of jnp scalars + "init_log_cadences": init_log_cadences, # (n_pools,) + "sample_grid_days": sample_grid_days, # (n_samples,) + "x_obs_reduced": x_obs_reduced, # (n_samples, 4) + 
"x_obs_cross": x_obs_cross, # (n_samples, 7) + # Indices + "pool_idx": sample_pools, + "day_idx": sample_days, + # Meta + "n_pools": n_pools, + "n_peers": n_peers, + "k_attr": k_attr, + "pool_ids": pool_ids, + "vol_mean": vol_mean, + "vol_std": vol_std, + "fee_attr_idx": fee_idx, + "tvl_attr_idx": tvl_idx, + } + + +def assemble_inputs(data, feat_cfg): + """Assemble encoder/decoder inputs based on feature config. + + Returns dict with JAX arrays ready for training. + """ + # ---- Peer encoder input: (n_samples, n_peers, n_feat) ---- + pool_idx = data["pool_idx"] + pa = data["peer_attrs"][pool_idx] # (n_samples, n_peers, k_attr) + ta = data["target_attrs"][pool_idx] # (n_samples, k_attr) + + if feat_cfg.get("minimal_encoder"): + # 7-feature encoder: peer_fee, peer_tvl, target_fee, target_tvl, + # vol_lag1, overlap, same_chain — prevents pool identification + fi = data["fee_attr_idx"] + ti = data["tvl_attr_idx"] + pa_min = np.stack([pa[:, :, fi], pa[:, :, ti]], axis=-1) + ta_min = np.stack([ta[:, fi], ta[:, ti]], axis=-1) + ta_min_broad = np.broadcast_to( + ta_min[:, None, :], (pa_min.shape[0], pa_min.shape[1], 2)) + peer_parts = [ + pa_min, ta_min_broad, + data["pf_vol_lag1"][:, :, None], + data["peer_overlap"][pool_idx][:, :, None], + data["rel_same_chain"][pool_idx][:, :, None], + ] + else: + ta_broad = np.broadcast_to(ta[:, None, :], pa.shape) + peer_parts = [ + pa, ta_broad, + data["pf_vol_lag1"][:, :, None], + data["peer_overlap"][pool_idx][:, :, None], + ] + # Relational features (optional via feat_cfg, default on) + if feat_cfg.get("rel_same_chain", True): + peer_parts.append(data["rel_same_chain"][pool_idx][:, :, None]) + + # Optional temporal peer features (both modes) + if feat_cfg.get("peer_vol_lag2"): + peer_parts.append(data["pf_vol_lag2"][:, :, None]) + if feat_cfg.get("peer_vol_change"): + peer_parts.append(data["pf_vol_change"][:, :, None]) + if feat_cfg.get("peer_tvl"): + peer_parts.append(data["pf_tvl"][:, :, None]) + if 
feat_cfg.get("peer_volatility"): + peer_parts.append(data["pf_volatility"][:, :, None]) + + # Relational ratio features (both modes) + if feat_cfg.get("rel_tvl_ratio", True): + peer_parts.append(data["rel_log_tvl_ratio"][pool_idx][:, :, None]) + if feat_cfg.get("rel_fee_ratio", True): + peer_parts.append(data["rel_log_fee_ratio"][pool_idx][:, :, None]) + + peer_input = np.concatenate(peer_parts, axis=-1).astype(np.float32) + + # ---- Local decoder input: (n_samples, n_feat) ---- + # Always: target_attr, own_vol_lag1, dow_sin, dow_cos + local_parts = [ + ta, + data["lf_own_vol_lag1"][:, None], + data["lf_dow_sin"][:, None], + data["lf_dow_cos"][:, None], + ] + + if feat_cfg.get("own_vol_lag2"): + local_parts.append(data["lf_own_vol_lag2"][:, None]) + if feat_cfg.get("own_vol_change"): + local_parts.append(data["lf_own_vol_change"][:, None]) + if feat_cfg.get("own_tvl"): + local_parts.append(data["lf_own_tvl"][:, None]) + if feat_cfg.get("own_volatility"): + local_parts.append(data["lf_own_volatility"][:, None]) + + # Interaction features (tvl×vola, tvl×fee, vola×fee) + if feat_cfg.get("interactions"): + local_parts.append(data["lf_tvl_x_vola"][:, None]) + local_parts.append(data["lf_tvl_x_fee"][:, None]) + local_parts.append(data["lf_vola_x_fee"][:, None]) + + # Cross-pool volume aggregates (token-peer, chain-peer, market) + if feat_cfg.get("cross_pool_vol"): + local_parts.append(data["lf_cross_vol_tok_a"][:, None]) + local_parts.append(data["lf_cross_vol_tok_b"][:, None]) + local_parts.append(data["lf_cross_vol_chain"][:, None]) + local_parts.append(data["lf_market_vol"][:, None]) + + # Cross-pool momentum (peer volume changes) + if feat_cfg.get("cross_pool_momentum"): + local_parts.append(data["lf_cross_mom_tok_a"][:, None]) + local_parts.append(data["lf_cross_mom_tok_b"][:, None]) + local_parts.append(data["lf_cross_mom_chain"][:, None]) + + # Option C x_obs covariates (none / reduced=4 / cross=7) + x_obs_mode = feat_cfg.get("x_obs_mode", "none") + if x_obs_mode 
== "reduced" and "x_obs_reduced" in data: + local_parts.append(data["x_obs_reduced"]) + elif x_obs_mode == "cross" and "x_obs_cross" in data: + local_parts.append(data["x_obs_cross"]) + + local_input = np.concatenate(local_parts, axis=-1).astype(np.float32) + + result = { + "peer_input": jnp.array(peer_input), + "local_input": jnp.array(local_input), + "peer_mask": jnp.array(data["peer_mask"]), + "y": jnp.array(data["y_residual"] if feat_cfg.get("target_residual") else data["y_total"]), + "y_total": jnp.array(data["y_total"]), + "v_arb": jnp.array(data["v_arb"]), + "pool_idx": jnp.array(pool_idx), + "n_pools": data["n_pools"], + "n_peer_feat": peer_input.shape[-1], + "n_local_feat": local_input.shape[-1], + } + # Cadence learning arrays + if "sample_grid_days" in data: + result["sample_grid_days"] = jnp.array(data["sample_grid_days"]) + return result + + +# ---- Model ---- + + +def init_params(key, n_peer_feat, n_local_feat, hidden, d_embed, + encoder_type="mlp"): + """Initialize model parameters. + + encoder_type: "mlp" (2-layer ReLU) or "linear" (single affine). + Presence of "enc_W2" in params dict distinguishes the two at forward time. 
+ """ + k1, k2, k3, k4 = jax.random.split(key, 4) + dec_in = d_embed + n_local_feat + params = {} + + if encoder_type == "mlp": + params["enc_W1"] = jax.random.normal(k1, (n_peer_feat, hidden)) * np.sqrt(2.0 / n_peer_feat) + params["enc_b1"] = jnp.zeros(hidden) + params["enc_W2"] = jax.random.normal(k2, (hidden, d_embed)) * np.sqrt(2.0 / hidden) + params["enc_b2"] = jnp.zeros(d_embed) + else: # linear + params["enc_W1"] = jax.random.normal(k1, (n_peer_feat, d_embed)) * np.sqrt(2.0 / n_peer_feat) + params["enc_b1"] = jnp.zeros(d_embed) + + params["dec_W1"] = jax.random.normal(k3, (dec_in, hidden)) * np.sqrt(2.0 / dec_in) + params["dec_b1"] = jnp.zeros(hidden) + params["dec_W2"] = jax.random.normal(k4, (hidden, 1)) * 0.01 + params["dec_b2"] = jnp.zeros(1) + return params + + +def warm_start_decoder(params, inputs, d_embed): + """Set decoder output layer via OLS through hidden activations. + + Fits y ~ h(local_input) with zero peer summary, so the decoder + starts predicting in the right volume range (~10-17 log scale). 
+ """ + local = np.array(inputs["local_input"]) + y = np.array(inputs["y"]) + n = local.shape[0] + + # Simulate decoder input with zero peer summary + dec_in = np.concatenate( + [np.zeros((n, d_embed), dtype=np.float32), local], axis=1) + + # Forward through first decoder layer with current (random) weights + h = np.maximum( + dec_in @ np.array(params["dec_W1"]) + np.array(params["dec_b1"]), 0.0) + + # OLS: y ≈ h @ W2 + b2 + h_bias = np.concatenate([h, np.ones((n, 1), dtype=np.float32)], axis=1) + sol, _, _, _ = np.linalg.lstsq(h_bias, y[:, None], rcond=None) + + params["dec_W2"] = jnp.array(sol[:-1].astype(np.float32)) + params["dec_b2"] = jnp.array(sol[-1:].astype(np.float32)) + return params + + +def forward(params, peer_input, peer_mask, local_input, no_peers=False): + """Returns predicted log_volume (total) per sample.""" + batch, n_peers, _ = peer_input.shape + + if no_peers: + d_embed = params["dec_W1"].shape[0] - local_input.shape[-1] + summary = jnp.zeros((batch, d_embed)) + else: + flat = peer_input.reshape(-1, peer_input.shape[-1]) + if "enc_W2" in params: + # MLP encoder: 2-layer with ReLU + h = jnp.maximum(flat @ params["enc_W1"] + params["enc_b1"], 0.0) + h = h @ params["enc_W2"] + params["enc_b2"] + else: + # Linear encoder: single affine + h = flat @ params["enc_W1"] + params["enc_b1"] + h = h.reshape(batch, n_peers, -1) + + h_masked = h * peer_mask[:, :, None] + n_valid = jnp.maximum(jnp.sum(peer_mask, axis=1, keepdims=True), 1.0) + summary = jnp.sum(h_masked, axis=1) / n_valid + + dec_in = jnp.concatenate([summary, local_input], axis=-1) + h_dec = jnp.maximum(dec_in @ params["dec_W1"] + params["dec_b1"], 0.0) + return (h_dec @ params["dec_W2"] + params["dec_b2"])[:, 0] + + +def loss_fn(params, peer_input, peer_mask, local_input, y, l2_alpha, + pool_idx, n_pools, huber_delta, no_peers): + """Huber loss with per-pool weighting + L2 reg.""" + pred = forward(params, peer_input, peer_mask, local_input, no_peers) + residuals = pred - y + abs_r = 
jnp.abs(residuals) + huber_vals = jnp.where(abs_r <= huber_delta, 0.5 * residuals ** 2, + huber_delta * (abs_r - 0.5 * huber_delta)) + + # Per-pool mean loss, then average across active pools (handles LOO gaps) + pool_counts = jnp.zeros(n_pools).at[pool_idx].add(jnp.ones_like(pool_idx, dtype=jnp.float32)) + active = (pool_counts > 0).astype(jnp.float32) + n_active = jnp.maximum(jnp.sum(active), 1.0) + pool_counts = jnp.maximum(pool_counts, 1.0) + pool_sums = jnp.zeros(n_pools).at[pool_idx].add(huber_vals) + data_loss = jnp.sum((pool_sums / pool_counts) * active) / n_active + + reg = sum(jnp.sum(v ** 2) for k, v in params.items() if "W" in k) + return data_loss + l2_alpha * reg + + +# n_pools (arg 7): static for jnp.zeros shape; no_peers (arg 9): static for if/else in forward +_grad_fn = jax.jit(jax.value_and_grad(loss_fn), static_argnums=(7, 9)) + + +# ---- Learnable cadence ---- + + +def make_cadence_loss_fn(pool_coeffs, pool_gas, n_pools, no_peers): + """Build a loss function with per-pool PCHIP coefficients closed over. + + The returned function is JIT-compiled. The Python loop over pools is + unrolled at trace time, so each pool's coefficients are constants. + + The neural net predicts log(V_noise). V_arb comes from PCHIP at the + current learnable log_cadence. Loss is Huber on log(V_arb + V_noise) + vs log(V_obs). 
+ """ + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + def loss_fn_cadence(params, peer_input, peer_mask, local_input, y_total, + sample_grid_days, pool_idx, l2_alpha, huber_delta): + # Neural net predicts log(V_noise) + log_v_noise = forward(params, peer_input, peer_mask, local_input, no_peers) + + # Compute V_arb per sample via PCHIP (loop unrolled at trace time) + log_cadence = params["log_cadence"] + n_samples = y_total.shape[0] + v_arb = jnp.zeros(n_samples) + + for i in range(n_pools): + v_arb_all = interpolate_pool_daily( + pool_coeffs[i], log_cadence[i], pool_gas[i]) + # Index into this pool's daily V_arb; clip for safety on other pools' samples + safe_days = jnp.clip(sample_grid_days, 0, v_arb_all.shape[0] - 1) + v_arb = jnp.where(pool_idx == i, v_arb_all[safe_days], v_arb) + + # Combine: log(V_arb + V_noise) via numerically stable logaddexp + log_v_arb = jnp.log(jnp.maximum(v_arb, 1e-10)) + log_v_total = jnp.logaddexp(log_v_arb, log_v_noise) + + # Huber loss with per-pool weighting + residuals = log_v_total - y_total + abs_r = jnp.abs(residuals) + huber_vals = jnp.where(abs_r <= huber_delta, 0.5 * residuals ** 2, + huber_delta * (abs_r - 0.5 * huber_delta)) + + pool_counts = jnp.zeros(n_pools).at[pool_idx].add( + jnp.ones_like(pool_idx, dtype=jnp.float32)) + active = (pool_counts > 0).astype(jnp.float32) + n_active = jnp.maximum(jnp.sum(active), 1.0) + pool_counts = jnp.maximum(pool_counts, 1.0) + pool_sums = jnp.zeros(n_pools).at[pool_idx].add(huber_vals) + data_loss = jnp.sum((pool_sums / pool_counts) * active) / n_active + + reg = sum(jnp.sum(v ** 2) for k, v in params.items() if "W" in k) + return data_loss + l2_alpha * reg + + grad_fn = jax.jit(jax.value_and_grad(loss_fn_cadence)) + return grad_fn + + +# ---- Training ---- + + +def train(params, inputs, n_epochs, lr, l2_alpha, huber_delta=1.0, + no_peers=False, verbose=True, grad_fn_override=None): + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = {k: 
jnp.zeros_like(v) for k, v in params.items()} + final_loss = float("inf") + + n_pools = int(inputs["n_pools"]) + pool_idx = inputs["pool_idx"] + use_cadence = grad_fn_override is not None + + for epoch in range(n_epochs): + if use_cadence: + loss_val, grads = grad_fn_override( + params, inputs["peer_input"], inputs["peer_mask"], + inputs["local_input"], inputs["y_total"], + inputs["sample_grid_days"], pool_idx, l2_alpha, huber_delta, + ) + else: + loss_val, grads = _grad_fn( + params, inputs["peer_input"], inputs["peer_mask"], + inputs["local_input"], inputs["y"], l2_alpha, + pool_idx, n_pools, huber_delta, no_peers, + ) + final_loss = float(loss_val) + + for k in params: + m[k] = 0.9 * m[k] + 0.1 * grads[k] + v[k] = 0.999 * v[k] + 0.001 * grads[k] ** 2 + m_hat = m[k] / (1.0 - 0.9 ** (epoch + 1)) + v_hat = v[k] / (1.0 - 0.999 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / (jnp.sqrt(v_hat) + 1e-8) + + if verbose and (epoch % 200 == 0 or epoch == n_epochs - 1): + if use_cadence: + cads = np.exp(np.array(params["log_cadence"])) + # Quick decomposition check: forward pass + V_arb at current cadences + _lvn = np.array(forward( + params, inputs["peer_input"], inputs["peer_mask"], + inputs["local_input"], no_peers)) + _vn = np.exp(_lvn) + _vo = np.exp(np.array(inputs["y_total"])) + # Approximate arb fraction (V_obs - V_noise proxy, avoids a PCHIP call); left unclipped so the count below can fire + _arb_proxy = 1.0 - _vn / _vo + _n_pathological = np.sum(_arb_proxy < -0.5) # noise > 1.5x observed + _n_bound = np.sum((cads <= 1.01) | (cads >= 59.9)) + print(f" epoch {epoch:4d} loss={final_loss:.6f}" + f" cad=[{cads.min():.1f}-{np.median(cads):.1f}-{cads.max():.1f}]" + f" |logVn|={np.mean(np.abs(_lvn)):.1f}" + f" bound={_n_bound}") + else: + print(f" epoch {epoch:4d} loss={final_loss:.6f}") + + return params, final_loss + + +# ---- Evaluation ---- + + +def evaluate(params, inputs, data, label="", no_peers=False, + target_residual=False): + """Per-pool R² on total volume and noise
residual.""" + pred = np.array(forward( + params, inputs["peer_input"], inputs["peer_mask"], + inputs["local_input"], no_peers=no_peers, + )) + y = np.array(inputs["y"]) + v_arb = np.array(inputs["v_arb"]) + pool_idx = np.array(inputs["pool_idx"]) + + log_v_arb = np.log(np.maximum(v_arb, 1e-6)) + + if target_residual: + # Model predicts noise residual directly + resid_true = y + resid_pred = pred + y_total = y + log_v_arb # reconstruct total for total R² + pred_total = pred + log_v_arb + else: + # Model predicts total log_volume + y_total = y + pred_total = pred + resid_true = y - log_v_arb + resid_pred = pred - log_v_arb + + r2_total = {} + r2_resid = {} + pool_ids = data.get("pool_ids", []) + + for i in range(data["n_pools"]): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total[mask] + ss_res_t = np.sum((yt - pt) ** 2) + ss_tot_t = np.sum((yt - yt.mean()) ** 2) + r2_total[i] = 1 - ss_res_t / max(ss_tot_t, 1e-10) + + rt = resid_true[mask] + rp = resid_pred[mask] + ss_res_r = np.sum((rt - rp) ** 2) + ss_tot_r = np.sum((rt - rt.mean()) ** 2) + r2_resid[i] = 1 - ss_res_r / max(ss_tot_r, 1e-10) + + def _med(d): + v = [x for x in d.values() if np.isfinite(x)] + return np.median(v) if v else float("nan") + + if label: + print(f"\n {label}:") + for i in range(data["n_pools"]): + if i in r2_total and i < len(pool_ids): + pid = pool_ids[i] + print(f" {pid[:16]} total={r2_total[i]:.3f} resid={r2_resid[i]:.3f}") + + med_total = _med(r2_total) + med_resid = _med(r2_resid) + print(f" Median R² total={med_total:.4f} resid={med_resid:.4f}") + return med_total, med_resid, r2_total, r2_resid + + +def _compute_cadence_decomposition(params, inputs, data, no_peers=False): + """Compute V_arb, V_noise, and predictions for cadence mode. 
Returns numpy arrays.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + log_v_noise = np.array(forward( + params, inputs["peer_input"], inputs["peer_mask"], + inputs["local_input"], no_peers=no_peers, + )) + y_total = np.array(inputs["y_total"]) + pool_idx = np.array(inputs["pool_idx"]) + sample_grid_days = np.array(inputs["sample_grid_days"]) + + pool_coeffs = data["pool_coeffs"] + pool_gas = data["pool_gas"] + log_cadence = np.array(params["log_cadence"]) + n_pools = data["n_pools"] + + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + pool_coeffs[i], jnp.float64(log_cadence[i]), pool_gas[i])) + v_arb[mask] = v_arb_all[sample_grid_days[mask]] + + v_obs = np.exp(y_total) + v_noise = np.exp(log_v_noise) + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + return { + "v_arb": v_arb, "v_noise": v_noise, "v_obs": v_obs, + "log_v_noise": log_v_noise, "log_v_arb": log_v_arb, + "pred_total": pred_total, "y_total": y_total, + "pool_idx": pool_idx, "log_cadence": log_cadence, + } + + +def evaluate_cadence(params, inputs, data, label="", no_peers=False): + """Evaluate with learned cadence: per-pool R², decomposition diagnostics.""" + dec = _compute_cadence_decomposition(params, inputs, data, no_peers) + pool_ids = data.get("pool_ids", []) + init_cads = data["init_log_cadences"] + n_pools = data["n_pools"] + + if label: + print(f"\n {label}:") + print(f" {'Pool'[:16]:16s} {'R²':>6s} {'Cad init':>8s} {'→learn':>7s}" + f" {'Arb%':>6s} {'Noise%':>7s} {'logVn μ':>7s} {'logVn σ':>7s} {'Flag':>5s}") + print(f" {'-'*80}") + + r2_total = {} + pool_diag = [] + for i in range(n_pools): + mask = dec["pool_idx"] == i + if mask.sum() < 2: + continue + yt = dec["y_total"][mask] + pt = dec["pred_total"][mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + 
r2_total[i] = 1 - ss_res / max(ss_tot, 1e-10) + + pid = pool_ids[i] if i < len(pool_ids) else f"pool_{i}" + cad_init = np.exp(init_cads[i]) + cad_learned = np.exp(dec["log_cadence"][i]) + + va = dec["v_arb"][mask] + vo = dec["v_obs"][mask] + vn = dec["v_noise"][mask] + lvn = dec["log_v_noise"][mask] + + arb_pct = np.median(va / vo) * 100 + noise_pct = np.median(vn / vo) * 100 + lvn_mu = np.mean(lvn) + lvn_std = np.std(lvn) + + # Flags + flags = [] + if arb_pct > 150: + flags.append("A") # arb dominates + if cad_learned <= 1.01 or cad_learned >= 59.9: + flags.append("B") # cadence at bound + if r2_total[i] < 0: + flags.append("X") # negative R² + + flag_str = "".join(flags) if flags else "" + pool_diag.append({ + "idx": i, "pid": pid, "r2": r2_total[i], + "cad_init": cad_init, "cad_learned": cad_learned, + "arb_pct": arb_pct, "noise_pct": noise_pct, + "lvn_mu": lvn_mu, "lvn_std": lvn_std, "flags": flag_str, + }) + + print(f" {pid[:16]:16s} {r2_total[i]:6.3f} {cad_init:7.1f}m {cad_learned:6.1f}m" + f" {arb_pct:6.0f}% {noise_pct:6.0f}% {lvn_mu:7.1f} {lvn_std:7.2f}" + f" {flag_str:>5s}") + + # ── Summary statistics ── + vals = [x for x in r2_total.values() if np.isfinite(x)] + med_r2 = np.median(vals) if vals else float("nan") + cads = np.exp(dec["log_cadence"]) + + n_pathological = sum(1 for d in pool_diag if d["arb_pct"] > 150) + n_at_bound = sum(1 for d in pool_diag + if d["cad_learned"] <= 1.01 or d["cad_learned"] >= 59.9) + n_negative_r2 = sum(1 for d in pool_diag if d["r2"] < 0) + healthy = [d for d in pool_diag if d["arb_pct"] <= 150 and d["r2"] > 0] + med_r2_healthy = (np.median([d["r2"] for d in healthy]) + if healthy else float("nan")) + + print(f"\n ── Summary ──") + print(f" Median R² total: {med_r2:.4f} (healthy only: {med_r2_healthy:.4f})") + print(f" Cadence range: {cads.min():.1f} - {np.median(cads):.1f}" + f" - {cads.max():.1f} min") + print(f" Decomposition: {len(pool_diag) - n_pathological}/{len(pool_diag)}" + f" healthy (arb≤150%), {n_pathological} 
pathological") + print(f" Cadence at bounds: {n_at_bound}/{len(pool_diag)}" + f" (≤1min or ≥60min)") + print(f" Negative R²: {n_negative_r2}/{len(pool_diag)}") + print(f" Flags: A=arb>150%, B=cadence at bound, X=negative R²") + + return med_r2, r2_total, pool_diag + + +def print_cadence_comparison(train_diag, eval_diag): + """Print train vs eval diagnostic comparison.""" + train_map = {d["pid"]: d for d in train_diag} + eval_map = {d["pid"]: d for d in eval_diag} + all_pids = sorted(set(train_map) | set(eval_map)) + + print(f"\n ── Train vs Eval Gap ──") + print(f" {'Pool'[:16]:16s} {'R² trn':>7s} {'R² eval':>7s} {'Gap':>6s}" + f" {'ArbTrn%':>7s} {'ArbEval%':>8s}") + print(f" {'-'*55}") + + gaps = [] + for pid in all_pids: + td = train_map.get(pid) + ed = eval_map.get(pid) + if td is None or ed is None: + continue + gap = td["r2"] - ed["r2"] + gaps.append(gap) + flag = " ***" if abs(gap) > 0.5 else "" + print(f" {pid[:16]:16s} {td['r2']:7.3f} {ed['r2']:7.3f} {gap:+6.3f}" + f" {td['arb_pct']:6.0f}% {ed['arb_pct']:7.0f}%{flag}") + + if gaps: + print(f" Median gap: {np.median(gaps):+.3f} " + f"Mean gap: {np.mean(gaps):+.3f} " + f"Max gap: {max(gaps):+.3f}") + + +# Keys indexed by sample (shape[0] == n_samples) +_SAMPLE_KEYS = { + "pf_vol_lag1", "pf_vol_lag2", "pf_vol_change", "pf_tvl", "pf_volatility", + "peer_mask", "lf_own_vol_lag1", "lf_own_vol_lag2", "lf_own_vol_change", + "lf_own_tvl", "lf_own_volatility", "lf_dow_sin", "lf_dow_cos", + "lf_tvl_x_vola", "lf_tvl_x_fee", "lf_vola_x_fee", + "lf_cross_vol_tok_a", "lf_cross_vol_tok_b", "lf_cross_vol_chain", + "lf_market_vol", "lf_cross_mom_tok_a", "lf_cross_mom_tok_b", + "lf_cross_mom_chain", + "y_total", "y_residual", "v_arb", "sample_grid_days", + "x_obs_reduced", "x_obs_cross", + "pool_idx", "day_idx", +} + + +def _subset(d, mask): + """Subset sample-indexed arrays by boolean mask.""" + out = {} + for k, v in d.items(): + if k in _SAMPLE_KEYS and isinstance(v, np.ndarray): + out[k] = v[mask] + else: + out[k] = v + 
return out + + +# ---- Temporal split ---- + + +def run_temporal(data, feat_cfg, hparams, split_frac=0.7): + """Train on first split_frac of days, eval on rest.""" + day_idx = data["day_idx"] + split_day = int(day_idx.max() * split_frac) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + + train_data = _subset(data, train_mask) + eval_data = _subset(data, eval_mask) + + train_inputs = assemble_inputs(train_data, feat_cfg) + eval_inputs = assemble_inputs(eval_data, feat_cfg) + + encoder_type = hparams.get("encoder_type", "mlp") + no_peers = hparams.get("no_peers", False) + huber_delta = hparams.get("huber_delta", 1.0) + target_residual = feat_cfg.get("target_residual", False) + learn_cadence = hparams.get("learn_cadence", False) + + n_pf = train_inputs["n_peer_feat"] + n_lf = train_inputs["n_local_feat"] + n_params = sum(v.size for v in init_params( + jax.random.PRNGKey(0), n_pf, n_lf, hparams["hidden"], hparams["d_embed"], + encoder_type=encoder_type, + ).values()) + + print(f" Train: {int(train_mask.sum())}, Eval: {int(eval_mask.sum())}, " + f"peer_feat={n_pf}, local_feat={n_lf}, params={n_params}") + if learn_cadence: + print(f" learn_cadence=True (joint cadence+noise optimization)") + if target_residual: + print(f" target=residual (log_vol - log_V_arb)") + if encoder_type != "mlp": + print(f" encoder_type={encoder_type}") + if no_peers: + print(f" no_peers=True (decoder-only ablation)") + if huber_delta != 1.0: + print(f" huber_delta={huber_delta}") + + params = init_params( + jax.random.PRNGKey(42), n_pf, n_lf, hparams["hidden"], hparams["d_embed"], + encoder_type=encoder_type, + ) + + if learn_cadence: + # Add learnable cadence, initialized from Option C + params["log_cadence"] = jnp.array(data["init_log_cadences"]) + init_cads = np.exp(data["init_log_cadences"]) + print(f" Init cadence: {init_cads.min():.1f}-{np.median(init_cads):.1f}" + f"-{init_cads.max():.1f} min") + + # Warm-start decoder to predict noise residual (log_vol - 
log_V_arb) + # using the Option C V_arb as the initial target + ws_inputs = dict(train_inputs) + ws_inputs["y"] = train_inputs["y_total"] - jnp.log( + jnp.maximum(train_inputs["v_arb"], 1e-6)) + params = warm_start_decoder(params, ws_inputs, hparams["d_embed"]) + + grad_fn = make_cadence_loss_fn( + data["pool_coeffs"], data["pool_gas"], + data["n_pools"], no_peers) + + print(" Compiling cadence loss (may take a moment)...") + t0 = time.time() + params, _ = train( + params, train_inputs, hparams["n_epochs"], hparams["lr"], + hparams["l2_alpha"], huber_delta=huber_delta, no_peers=no_peers, + grad_fn_override=grad_fn, + ) + print(f" Training: {time.time() - t0:.1f}s") + + print("\n --- Train ---") + _, _, train_diag = evaluate_cadence( + params, train_inputs, data, no_peers=no_peers) + print("\n --- Eval ---") + _, _, eval_diag = evaluate_cadence( + params, eval_inputs, data, no_peers=no_peers) + print_cadence_comparison(train_diag, eval_diag) + else: + params = warm_start_decoder(params, train_inputs, hparams["d_embed"]) + t0 = time.time() + params, _ = train( + params, train_inputs, hparams["n_epochs"], hparams["lr"], + hparams["l2_alpha"], huber_delta=huber_delta, no_peers=no_peers, + ) + print(f" Training: {time.time() - t0:.1f}s") + + eval_kw = dict(no_peers=no_peers, target_residual=target_residual) + print("\n --- Train ---") + evaluate(params, train_inputs, data, **eval_kw) + print("\n --- Eval ---") + _, med_resid_eval, _, _ = evaluate(params, eval_inputs, data, **eval_kw) + + return params + + +# ---- LOO cross-validation ---- + + +def run_loo(data, feat_cfg, hparams): + """Leave-one-pool-out: train on N-1 pools, evaluate on held-out pool. + + Tests cross-pool generalization — can the shared encoder+decoder predict + volume for a pool it has never optimized on? The held-out pool's volume + is still observable as peer features for training pools. + + Note: normalization stats are computed on the full dataset. 
With N=36 + the leakage from including one held-out pool is ~3% on mean/std. + """ + n_pools = data["n_pools"] + pool_idx = data["pool_idx"] + pool_ids = data.get("pool_ids", []) + + encoder_type = hparams.get("encoder_type", "mlp") + no_peers = hparams.get("no_peers", False) + huber_delta = hparams.get("huber_delta", 1.0) + target_residual = feat_cfg.get("target_residual", False) + d_embed = hparams["d_embed"] + + r2_total_all = {} + r2_resid_all = {} + + for held_out in range(n_pools): + pid = pool_ids[held_out] if held_out < len(pool_ids) else f"pool_{held_out}" + + train_mask = pool_idx != held_out + eval_mask = pool_idx == held_out + n_eval = int(eval_mask.sum()) + + if n_eval < 2: + print(f" [{held_out:2d}] {pid[:16]}: skipped ({n_eval} samples)") + continue + + train_data = _subset(data, train_mask) + eval_data = _subset(data, eval_mask) + + train_inputs = assemble_inputs(train_data, feat_cfg) + eval_inputs = assemble_inputs(eval_data, feat_cfg) + + params = init_params( + jax.random.PRNGKey(42), + train_inputs["n_peer_feat"], train_inputs["n_local_feat"], + hparams["hidden"], d_embed, + encoder_type=encoder_type, + ) + params = warm_start_decoder(params, train_inputs, d_embed) + + params, _ = train( + params, train_inputs, hparams["n_epochs"], + hparams["lr"], hparams["l2_alpha"], + huber_delta=huber_delta, no_peers=no_peers, + verbose=False, + ) + + # Evaluate on held-out pool + pred = np.array(forward( + params, eval_inputs["peer_input"], eval_inputs["peer_mask"], + eval_inputs["local_input"], no_peers=no_peers, + )) + y = np.array(eval_inputs["y"]) + v_arb = np.array(eval_inputs["v_arb"]) + log_v_arb = np.log(np.maximum(v_arb, 1e-6)) + + if target_residual: + resid_true, resid_pred = y, pred + y_total = y + log_v_arb + pred_total = pred + log_v_arb + else: + y_total, pred_total = y, pred + resid_true = y - log_v_arb + resid_pred = pred - log_v_arb + + ss_res = np.sum((y_total - pred_total) ** 2) + ss_tot = np.sum((y_total - y_total.mean()) ** 2) + r2_t = 
1 - ss_res / max(ss_tot, 1e-10) + + ss_res_r = np.sum((resid_true - resid_pred) ** 2) + ss_tot_r = np.sum((resid_true - resid_true.mean()) ** 2) + r2_r = 1 - ss_res_r / max(ss_tot_r, 1e-10) + + r2_total_all[held_out] = r2_t + r2_resid_all[held_out] = r2_r + + print(f" [{held_out:2d}] {pid[:16]}: total={r2_t:.3f} resid={r2_r:.3f} (n={n_eval})") + + def _med(d): + v = [x for x in d.values() if np.isfinite(x)] + return np.median(v) if v else float("nan") + + med_total = _med(r2_total_all) + med_resid = _med(r2_resid_all) + print(f"\n LOO Median R² total={med_total:.4f} resid={med_resid:.4f}") + print(f" ({len(r2_total_all)} pools evaluated)") + + return med_total, med_resid + + +# ---- Optuna ---- + + +_FEAT_KEYS = [ + "peer_vol_lag2", "peer_vol_change", "peer_tvl", "peer_volatility", + "own_vol_lag2", "own_vol_change", "own_tvl", "own_volatility", + "rel_same_chain", "rel_tvl_ratio", "rel_fee_ratio", "minimal_encoder", + "interactions", "cross_pool_vol", "cross_pool_momentum", +] + + +def run_optuna(data, n_trials, target_residual=False): + import optuna + + day_idx = data["day_idx"] + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + + train_data = _subset(data, train_mask) + eval_data = _subset(data, eval_mask) + + def objective(trial): + feat_cfg = { + "peer_vol_lag2": trial.suggest_categorical("peer_vol_lag2", [True, False]), + "peer_vol_change": trial.suggest_categorical("peer_vol_change", [True, False]), + "peer_tvl": trial.suggest_categorical("peer_tvl", [True, False]), + "peer_volatility": trial.suggest_categorical("peer_volatility", [True, False]), + "own_vol_lag2": trial.suggest_categorical("own_vol_lag2", [True, False]), + "own_vol_change": trial.suggest_categorical("own_vol_change", [True, False]), + "own_tvl": trial.suggest_categorical("own_tvl", [True, False]), + "own_volatility": trial.suggest_categorical("own_volatility", [True, False]), + "rel_same_chain": 
trial.suggest_categorical("rel_same_chain", [True, False]), + "rel_tvl_ratio": trial.suggest_categorical("rel_tvl_ratio", [True, False]), + "rel_fee_ratio": trial.suggest_categorical("rel_fee_ratio", [True, False]), + "minimal_encoder": trial.suggest_categorical("minimal_encoder", [True, False]), + "interactions": trial.suggest_categorical("interactions", [True, False]), + "cross_pool_vol": trial.suggest_categorical("cross_pool_vol", [True, False]), + "cross_pool_momentum": trial.suggest_categorical("cross_pool_momentum", [True, False]), + "target_residual": target_residual, + } + hparams = { + "hidden": trial.suggest_categorical("hidden", [16, 32, 64, 128]), + "d_embed": trial.suggest_categorical("d_embed", [4, 8, 16, 32]), + "lr": trial.suggest_float("lr", 1e-4, 1e-2, log=True), + "l2_alpha": trial.suggest_float("l2_alpha", 1e-5, 1e-1, log=True), + "n_epochs": trial.suggest_categorical("n_epochs", [500, 1000, 2000]), + "encoder_type": trial.suggest_categorical("encoder_type", ["mlp", "linear"]), + "huber_delta": trial.suggest_categorical("huber_delta", [0.5, 1.0, 1.5, 2.0]), + "no_peers": trial.suggest_categorical("no_peers", [True, False]), + } + + train_inputs = assemble_inputs(train_data, feat_cfg) + eval_inputs = assemble_inputs(eval_data, feat_cfg) + + params = init_params( + jax.random.PRNGKey(42), + train_inputs["n_peer_feat"], train_inputs["n_local_feat"], + hparams["hidden"], hparams["d_embed"], + encoder_type=hparams["encoder_type"], + ) + params = warm_start_decoder(params, train_inputs, hparams["d_embed"]) + params, _ = train( + params, train_inputs, hparams["n_epochs"], + hparams["lr"], hparams["l2_alpha"], + huber_delta=hparams["huber_delta"], + no_peers=hparams["no_peers"], + verbose=False, + ) + + # Eval R² on noise residual + _tgt_resid = feat_cfg.get("target_residual", False) + pred = np.array(forward( + params, eval_inputs["peer_input"], eval_inputs["peer_mask"], + eval_inputs["local_input"], no_peers=hparams["no_peers"], + )) + y = 
np.array(eval_inputs["y"]) + v_arb = np.array(eval_inputs["v_arb"]) + log_v_arb = np.log(np.maximum(v_arb, 1e-6)) + pool_idx = np.array(eval_data["pool_idx"]) + + r2_resids = [] + r2_totals = [] + for i in range(data["n_pools"]): + mask = pool_idx == i + if mask.sum() < 2: + continue + yi = y[mask] + pi = pred[mask] + lva = log_v_arb[mask] + + if _tgt_resid: + resid_true, resid_pred = yi, pi + yt, pt = yi + lva, pi + lva + else: + yt, pt = yi, pi + resid_true, resid_pred = yi - lva, pi - lva + + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2_totals.append(1 - ss_res / max(ss_tot, 1e-10)) + + ss_res_r = np.sum((resid_true - resid_pred) ** 2) + ss_tot_r = np.sum((resid_true - resid_true.mean()) ** 2) + r2_resids.append(1 - ss_res_r / max(ss_tot_r, 1e-10)) + + med_resid = float(np.median(r2_resids)) if r2_resids else -10.0 + med_total = float(np.median(r2_totals)) if r2_totals else -10.0 + + trial.set_user_attr("med_total_r2", med_total) + n_feat = sum(1 for k in _FEAT_KEYS if feat_cfg.get(k)) + print(f" Trial {trial.number}: resid={med_resid:.4f} total={med_total:.4f} " + f"enc={hparams['encoder_type']} h={hparams['hidden']} d={hparams['d_embed']} " + f"hub={hparams['huber_delta']} " + f"{'no_peers ' if hparams['no_peers'] else ''}" + f"lr={hparams['lr']:.1e} a={hparams['l2_alpha']:.1e} " + f"ep={hparams['n_epochs']} feat={n_feat}/15" + f"{' minimal' if feat_cfg.get('minimal_encoder') else ''}") + + return med_resid + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + print(f"\n{'='*70}") + print("Optuna Results") + print(f"{'='*70}") + print(f" Best eval noise resid R²: {study.best_value:.4f}") + print(f" Best total R²: {study.best_trial.user_attrs['med_total_r2']:.4f}") + print(f" Best params:") + for k, v in sorted(study.best_params.items()): + print(f" {k}: {v}") + + print(f"\n Top 10:") + trials = sorted(study.trials, key=lambda t: t.value if t.value else -999, + reverse=True) + 
for t in trials[:10]: + if t.value is not None: + feats = sum(1 for k in _FEAT_KEYS if t.params.get(k)) + print(f" #{t.number}: resid={t.value:.4f} " + f"total={t.user_attrs.get('med_total_r2', float('nan')):.4f} " + f"enc={t.params['encoder_type']} " + f"h={t.params['hidden']} d={t.params['d_embed']} " + f"feat={feats}/15") + + return study + + +def run_optuna_cadence(data, n_trials): + """Optuna sweep with learnable cadence. Optimizes median eval total R².""" + import optuna + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + day_idx = data["day_idx"] + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + + train_data = _subset(data, train_mask) + eval_data = _subset(data, eval_mask) + + pool_coeffs = data["pool_coeffs"] + pool_gas = data["pool_gas"] + n_pools = data["n_pools"] + + # Pre-build grad_fn closures for (no_peers=True, no_peers=False) + # to avoid recompiling on every trial with the same no_peers setting + _grad_fn_cache = {} + + def _get_grad_fn(no_peers): + if no_peers not in _grad_fn_cache: + _grad_fn_cache[no_peers] = make_cadence_loss_fn( + pool_coeffs, pool_gas, n_pools, no_peers) + return _grad_fn_cache[no_peers] + + def objective(trial): + feat_cfg = { + "peer_vol_lag2": trial.suggest_categorical("peer_vol_lag2", [True, False]), + "peer_vol_change": trial.suggest_categorical("peer_vol_change", [True, False]), + "peer_tvl": trial.suggest_categorical("peer_tvl", [True, False]), + "peer_volatility": trial.suggest_categorical("peer_volatility", [True, False]), + "own_vol_lag2": trial.suggest_categorical("own_vol_lag2", [True, False]), + "own_vol_change": trial.suggest_categorical("own_vol_change", [True, False]), + "own_tvl": trial.suggest_categorical("own_tvl", [True, False]), + "own_volatility": trial.suggest_categorical("own_volatility", [True, False]), + "rel_same_chain": trial.suggest_categorical("rel_same_chain", [True, False]), + "rel_tvl_ratio": 
trial.suggest_categorical("rel_tvl_ratio", [True, False]), + "rel_fee_ratio": trial.suggest_categorical("rel_fee_ratio", [True, False]), + "minimal_encoder": trial.suggest_categorical("minimal_encoder", [True, False]), + "interactions": trial.suggest_categorical("interactions", [True, False]), + "cross_pool_vol": trial.suggest_categorical("cross_pool_vol", [True, False]), + "cross_pool_momentum": trial.suggest_categorical("cross_pool_momentum", [True, False]), + "x_obs_mode": trial.suggest_categorical("x_obs_mode", ["none", "reduced", "cross"]), + } + hparams = { + "hidden": trial.suggest_categorical("hidden", [16, 32, 64]), + "d_embed": trial.suggest_categorical("d_embed", [4, 8, 16]), + "lr": trial.suggest_float("lr", 3e-4, 3e-3, log=True), + "l2_alpha": trial.suggest_float("l2_alpha", 1e-5, 1e-2, log=True), + "n_epochs": trial.suggest_categorical("n_epochs", [500, 1000, 2000]), + "encoder_type": trial.suggest_categorical("encoder_type", ["mlp", "linear"]), + "huber_delta": trial.suggest_categorical("huber_delta", [0.5, 1.0, 1.5]), + "no_peers": trial.suggest_categorical("no_peers", [True, False]), + } + + no_peers = hparams["no_peers"] + train_inputs = assemble_inputs(train_data, feat_cfg) + eval_inputs = assemble_inputs(eval_data, feat_cfg) + + params = init_params( + jax.random.PRNGKey(42), + train_inputs["n_peer_feat"], train_inputs["n_local_feat"], + hparams["hidden"], hparams["d_embed"], + encoder_type=hparams["encoder_type"], + ) + # Learnable cadence from Option C init + params["log_cadence"] = jnp.array(data["init_log_cadences"]) + + # Warm-start decoder on noise residual + ws_inputs = dict(train_inputs) + ws_inputs["y"] = train_inputs["y_total"] - jnp.log( + jnp.maximum(train_inputs["v_arb"], 1e-6)) + params = warm_start_decoder(params, ws_inputs, hparams["d_embed"]) + + grad_fn = _get_grad_fn(no_peers) + params, _ = train( + params, train_inputs, hparams["n_epochs"], + hparams["lr"], hparams["l2_alpha"], + huber_delta=hparams["huber_delta"], 
no_peers=no_peers, + verbose=False, grad_fn_override=grad_fn, + ) + + # Eval: compute V_arb at learned cadences, combine with net + log_v_noise = np.array(forward( + params, eval_inputs["peer_input"], eval_inputs["peer_mask"], + eval_inputs["local_input"], no_peers=no_peers, + )) + y_total = np.array(eval_inputs["y_total"]) + pool_idx = np.array(eval_data["pool_idx"]) + sample_grid_days = np.array(eval_inputs["sample_grid_days"]) + log_cadence = np.array(params["log_cadence"]) + + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + pool_coeffs[i], jnp.float64(log_cadence[i]), pool_gas[i])) + v_arb[mask] = v_arb_all[sample_grid_days[mask]] + + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + r2_totals = [] + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2_totals.append(1 - ss_res / max(ss_tot, 1e-10)) + + med_total = float(np.median(r2_totals)) if r2_totals else -10.0 + + cads = np.exp(log_cadence) + trial.set_user_attr("med_total_r2", med_total) + trial.set_user_attr("cad_median", float(np.median(cads))) + n_feat = sum(1 for k in _FEAT_KEYS if feat_cfg.get(k)) + print(f" Trial {trial.number}: total={med_total:.4f} " + f"enc={hparams['encoder_type']} h={hparams['hidden']} d={hparams['d_embed']} " + f"hub={hparams['huber_delta']} " + f"{'no_peers ' if no_peers else ''}" + f"lr={hparams['lr']:.1e} a={hparams['l2_alpha']:.1e} " + f"ep={hparams['n_epochs']} feat={n_feat}/15" + f" cad=[{cads.min():.0f}-{np.median(cads):.0f}-{cads.max():.0f}]") + + return med_total + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + print(f"\n{'='*70}") + print("Optuna Results (learn_cadence)") + print(f"{'='*70}") + 
print(f" Best eval total R²: {study.best_value:.4f}") + print(f" Best params:") + for k, v in sorted(study.best_params.items()): + print(f" {k}: {v}") + + print(f"\n Top 10:") + trials = sorted(study.trials, key=lambda t: t.value if t.value is not None else -999, + reverse=True) + for t in trials[:10]: + if t.value is not None: + feats = sum(1 for k in _FEAT_KEYS if t.params.get(k)) + cad_med = t.user_attrs.get("cad_median", float("nan")) + print(f" #{t.number}: total={t.value:.4f} " + f"enc={t.params['encoder_type']} " + f"h={t.params['hidden']} d={t.params['d_embed']} " + f"feat={feats}/15 cad_med={cad_med:.0f}") + + return study + + +# ---- Main ---- + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--tune", type=int, default=0) + parser.add_argument("--loo", action="store_true") + # Architecture + parser.add_argument("--hidden", type=int, default=16) + parser.add_argument("--d-embed", type=int, default=8) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--l2-alpha", type=float, default=1e-3) + parser.add_argument("--epochs", type=int, default=1000) + parser.add_argument("--encoder-type", choices=["mlp", "linear"], default="mlp") + parser.add_argument("--huber-delta", type=float, default=1.0) + parser.add_argument("--no-peers", action="store_true", + help="Decoder-only ablation (zero peer summary)") + parser.add_argument("--learn-cadence", action="store_true", + help="Jointly optimize per-pool arb cadence via PCHIP") + parser.add_argument("--x-obs", choices=["none", "reduced", "cross"], + default="none", + help="Append Option C x_obs covariates to decoder: " + "none, reduced (4: intercept,tvl,dow), " + "cross (7: +peer volumes)") + parser.add_argument("--minimal-encoder", action="store_true", + help="7-feature encoder (fee, tvl, overlap, same_chain) " + "instead of full attributes") + parser.add_argument("--target-residual", action="store_true", + help="Train on noise residual (log_vol - log_V_arb) " + "instead of total 
log_volume") + # Feature flags + parser.add_argument("--peer-vol-lag2", action="store_true") + parser.add_argument("--peer-vol-change", action="store_true") + parser.add_argument("--peer-tvl", action="store_true") + parser.add_argument("--peer-volatility", action="store_true") + parser.add_argument("--own-vol-lag2", action="store_true") + parser.add_argument("--own-vol-change", action="store_true") + parser.add_argument("--own-tvl", action="store_true") + parser.add_argument("--own-volatility", action="store_true") + parser.add_argument("--interactions", action="store_true", + help="tvl×vola, tvl×fee, vola×fee interaction terms") + parser.add_argument("--cross-pool-vol", action="store_true", + help="Token-peer, chain-peer, market volume aggregates") + parser.add_argument("--cross-pool-momentum", action="store_true", + help="Peer volume change momentum features") + parser.add_argument("--all-features", action="store_true", + help="Enable all optional features") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + feat_cfg = { + "peer_vol_lag2": args.peer_vol_lag2 or args.all_features, + "peer_vol_change": args.peer_vol_change or args.all_features, + "peer_tvl": args.peer_tvl or args.all_features, + "peer_volatility": args.peer_volatility or args.all_features, + "own_vol_lag2": args.own_vol_lag2 or args.all_features, + "own_vol_change": args.own_vol_change or args.all_features, + "own_tvl": args.own_tvl or args.all_features, + "own_volatility": args.own_volatility or args.all_features, + # Relational features: always on for CLI, searchable in Optuna + "rel_same_chain": True, + "rel_tvl_ratio": True, + "rel_fee_ratio": True, + "interactions": args.interactions or args.all_features, + "cross_pool_vol": args.cross_pool_vol or args.all_features, + "cross_pool_momentum": args.cross_pool_momentum or args.all_features, + "minimal_encoder": args.minimal_encoder, + "target_residual": args.target_residual, + "x_obs_mode": args.x_obs, + } + hparams 
= { + "hidden": args.hidden, + "d_embed": args.d_embed, + "lr": args.lr, + "l2_alpha": args.l2_alpha, + "n_epochs": args.epochs, + "encoder_type": args.encoder_type, + "huber_delta": args.huber_delta, + "no_peers": args.no_peers, + "learn_cadence": args.learn_cadence, + } + + print("=" * 70) + print("DeepSets v2: Total Volume Target + Noise Residual Eval") + feat_on = [k for k, v in feat_cfg.items() if v] + print(f" Optional features: {feat_on or 'none'}") + print(f" Architecture: {hparams}") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + print("\nBuilding features...") + t0 = time.time() + data = build_all_features(matched_clean, option_c_clean) + print(f" {len(data['pool_idx'])} samples, {data['n_pools']} pools, " + f"{time.time() - t0:.1f}s") + + if args.loo: + print(f"\n{'='*70}") + print("Leave-One-Pool-Out Cross-Validation") + print(f"{'='*70}") + run_loo(data, feat_cfg, hparams) + elif args.tune > 0 and args.learn_cadence: + run_optuna_cadence(data, args.tune) + elif args.tune > 0: + run_optuna(data, args.tune, target_residual=args.target_residual) + else: + print(f"\n{'='*70}") + print("Temporal split (70/30)") + print(f"{'='*70}") + run_temporal(data, feat_cfg, hparams) + + print(f"\n Baselines for comparison:") + print(f" Option C on residual: median R² = 0.060") + print(f" Ridge+own on residual: median R² = 0.098") + print(f" Constant zero: median R² = -0.083") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_deepsets_volume.py b/experiments/run_deepsets_volume.py new file mode 100644 index 0000000..dc174d2 --- /dev/null +++ b/experiments/run_deepsets_volume.py @@ -0,0 +1,467 @@ +"""DeepSets cross-pool volume prediction. + +Architecture: + For pool i at day t: + For each peer j != i with valid data at t-1: + h_j = Encoder(attr_j, attr_i, vol_j_{t-1}, overlap_ij) + peer_summary = masked_mean(h_j) + pred_i_t = Decoder(peer_summary, attr_i, own_vol_{t-1}) + +Evaluation: + 1. In-sample R² (all data) + 2. 
Temporal split (70/30) + 3. LOO (hold out one pool, retrain, evaluate) +""" + +import os +import pickle +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) +HIDDEN = 8 +D_EMBED = 4 +LR = 1e-3 +N_EPOCHS = 500 +N_EPOCHS_LOO = 200 +L2_ALPHA = 0.001 + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +def build_volume_matrix(matched_clean): + """Build (n_dates, n_pools) volume matrix. NaN where missing.""" + pool_ids = sorted(matched_clean.keys()) + pool_date_vol = {} + all_dates = set() + for pid in pool_ids: + panel = matched_clean[pid]["panel"] + dates = panel["date"].values + vols = panel["log_volume"].values.astype(float) + pool_date_vol[pid] = dict(zip(dates, vols)) + all_dates.update(dates) + + date_list = sorted(all_dates) + n_dates = len(date_list) + n_pools = len(pool_ids) + vol_matrix = np.full((n_dates, n_pools), np.nan) + for j, pid in enumerate(pool_ids): + dv = pool_date_vol[pid] + for t, date in enumerate(date_list): + if date in dv: + vol_matrix[t, j] = dv[date] + return vol_matrix, date_list, pool_ids + + +def build_data(matched_clean, exclude_pool_idx=None): + """Build all arrays for DeepSets training. + + If exclude_pool_idx is set, that pool is excluded from training + samples but kept as a peer (its volume data is still available). 
+ """ + from quantammsim.calibration.pool_data import ( + build_pool_attributes, _parse_tokens, _canonicalize_token, + ) + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + vol_matrix, date_list, _ = build_volume_matrix(matched_clean) + X_attr, attr_names, _ = build_pool_attributes(matched_clean) + + # Standardize attributes + attr_mean = np.mean(X_attr, axis=0) + attr_std = np.std(X_attr, axis=0) + attr_std[attr_std < 1e-6] = 1.0 + X_attr_norm = ((X_attr - attr_mean) / attr_std).astype(np.float32) + + # Standardize volumes + vol_mean = float(np.nanmean(vol_matrix)) + vol_std = float(np.nanstd(vol_matrix)) + vol_norm = ((vol_matrix - vol_mean) / vol_std).astype(np.float32) + + # Token overlap + k_attr = X_attr_norm.shape[1] + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + + # Per-pool peer structures + n_peers = n_pools - 1 + peer_attrs = np.zeros((n_pools, n_peers, k_attr), dtype=np.float32) + peer_overlap = np.zeros((n_pools, n_peers), dtype=np.float32) + peer_col_idx = np.zeros((n_pools, n_peers), dtype=np.int32) + + for i in range(n_pools): + peers = [j for j in range(n_pools) if j != i] + for p, j in enumerate(peers): + peer_attrs[i, p] = X_attr_norm[j] + peer_overlap[i, p] = len(pool_tokens[i] & pool_tokens[j]) + peer_col_idx[i, p] = j + + target_attrs = X_attr_norm + + # Build samples + sample_pools, sample_days = [], [] + for i in range(n_pools): + if i == exclude_pool_idx: + continue + for t in range(1, len(date_list)): + if np.isnan(vol_matrix[t, i]) or np.isnan(vol_matrix[t - 1, i]): + continue + sample_pools.append(i) + sample_days.append(t) + + sample_pools = np.array(sample_pools, dtype=np.int32) + sample_days = np.array(sample_days, dtype=np.int32) + n_samples = len(sample_pools) + + # Vectorized: gather peer volumes and masks + peer_vols = np.zeros((n_samples, n_peers), dtype=np.float32) + peer_mask = 
np.zeros((n_samples, n_peers), dtype=np.float32) + own_lag = np.zeros(n_samples, dtype=np.float32) + y = np.zeros(n_samples, dtype=np.float32) + + for s in range(n_samples): + i = sample_pools[s] + t = sample_days[s] + cols = peer_col_idx[i] + pvols = vol_norm[t - 1, cols] + valid = ~np.isnan(pvols) + peer_vols[s] = np.where(valid, pvols, 0.0) + peer_mask[s] = valid.astype(np.float32) + own_lag[s] = vol_norm[t - 1, i] + y[s] = vol_norm[t, i] + + return { + "peer_attrs": jnp.array(peer_attrs), + "target_attrs": jnp.array(target_attrs), + "peer_overlap": jnp.array(peer_overlap), + "peer_vols": jnp.array(peer_vols), + "peer_mask": jnp.array(peer_mask), + "own_lag": jnp.array(own_lag), + "y": jnp.array(y), + "pool_idx": jnp.array(sample_pools), + "day_idx": sample_days, + "n_pools": n_pools, + "n_peers": n_peers, + "k_attr": k_attr, + "pool_ids": pool_ids, + "vol_mean": vol_mean, + "vol_std": vol_std, + } + + +# ---- Model ---- + + +def init_params(key, k_attr, hidden=HIDDEN, d=D_EMBED): + k1, k2, k3, k4 = jax.random.split(key, 4) + enc_in = 2 * k_attr + 2 # peer_attr + target_attr + peer_vol + overlap + dec_in = d + k_attr + 1 # summary + target_attr + own_lag + return { + "enc_W1": jax.random.normal(k1, (enc_in, hidden)) * np.sqrt(2.0 / enc_in), + "enc_b1": jnp.zeros(hidden), + "enc_W2": jax.random.normal(k2, (hidden, d)) * np.sqrt(2.0 / hidden), + "enc_b2": jnp.zeros(d), + "dec_W1": jax.random.normal(k3, (dec_in, hidden)) * np.sqrt(2.0 / dec_in), + "dec_b1": jnp.zeros(hidden), + "dec_W2": jax.random.normal(k4, (hidden, 1)) * 0.01, + "dec_b2": jnp.zeros(1), + } + + +def forward(params, peer_attrs_all, target_attrs_all, peer_overlap_all, + peer_vols, peer_mask, own_lag, pool_idx): + """Batched DeepSets forward pass.""" + batch = peer_vols.shape[0] + n_peers = peer_vols.shape[1] + + pa = peer_attrs_all[pool_idx] # (batch, n_peers, k_attr) + ta = target_attrs_all[pool_idx] # (batch, k_attr) + ov = peer_overlap_all[pool_idx] # (batch, n_peers) + + ta_broad = 
jnp.broadcast_to(ta[:, None, :], pa.shape) + + enc_in = jnp.concatenate([ + pa, ta_broad, + peer_vols[:, :, None], + ov[:, :, None], + ], axis=-1) + + # Encoder MLP + flat = enc_in.reshape(-1, enc_in.shape[-1]) + h = jnp.maximum(flat @ params["enc_W1"] + params["enc_b1"], 0.0) + h = h @ params["enc_W2"] + params["enc_b2"] + h = h.reshape(batch, n_peers, -1) + + # Masked mean + h_masked = h * peer_mask[:, :, None] + n_valid = jnp.maximum(jnp.sum(peer_mask, axis=1, keepdims=True), 1.0) + summary = jnp.sum(h_masked, axis=1) / n_valid + + # Decoder MLP + dec_in = jnp.concatenate([summary, ta, own_lag[:, None]], axis=-1) + h_dec = jnp.maximum(dec_in @ params["dec_W1"] + params["dec_b1"], 0.0) + return (h_dec @ params["dec_W2"] + params["dec_b2"])[:, 0] + + +def loss_fn(params, static, peer_vols, peer_mask, own_lag, pool_idx, y, alpha): + pred = forward(params, static["peer_attrs"], static["target_attrs"], + static["peer_overlap"], peer_vols, peer_mask, own_lag, pool_idx) + mse = jnp.mean((pred - y) ** 2) + reg = sum(jnp.sum(v ** 2) for k, v in params.items() if "W" in k) + return mse + alpha * reg + + +grad_fn = jax.jit(jax.value_and_grad(loss_fn)) + + +# ---- Training ---- + + +def train(params, data, n_epochs=N_EPOCHS, lr=LR, alpha=L2_ALPHA, verbose=True): + """Full-batch Adam training.""" + static = { + "peer_attrs": data["peer_attrs"], + "target_attrs": data["target_attrs"], + "peer_overlap": data["peer_overlap"], + } + + # Adam state + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = {k: jnp.zeros_like(v) for k, v in params.items()} + + for epoch in range(n_epochs): + loss_val, grads = grad_fn( + params, static, data["peer_vols"], data["peer_mask"], + data["own_lag"], data["pool_idx"], data["y"], alpha, + ) + + # Adam update + for k in params: + m[k] = 0.9 * m[k] + 0.1 * grads[k] + v[k] = 0.999 * v[k] + 0.001 * grads[k] ** 2 + m_hat = m[k] / (1.0 - 0.9 ** (epoch + 1)) + v_hat = v[k] / (1.0 - 0.999 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / 
(jnp.sqrt(v_hat) + 1e-8) + + if verbose and (epoch % 100 == 0 or epoch == n_epochs - 1): + print(f" epoch {epoch:4d} loss={float(loss_val):.6f}") + + return params + + +# ---- Evaluation ---- + + +def per_pool_r2(params, data): + """Compute per-pool R² from trained model.""" + static = { + "peer_attrs": data["peer_attrs"], + "target_attrs": data["target_attrs"], + "peer_overlap": data["peer_overlap"], + } + pred = np.array(forward( + params, static["peer_attrs"], static["target_attrs"], + static["peer_overlap"], data["peer_vols"], data["peer_mask"], + data["own_lag"], data["pool_idx"], + )) + y = np.array(data["y"]) + pool_idx = np.array(data["pool_idx"]) + + r2s = {} + for i in range(data["n_pools"]): + mask = pool_idx == i + if mask.sum() < 2: + continue + yi = y[mask] + pi = pred[mask] + ss_res = np.sum((yi - pi) ** 2) + ss_tot = np.sum((yi - yi.mean()) ** 2) + r2s[i] = 1 - ss_res / max(ss_tot, 1e-10) + return r2s + + +# ---- Main experiments ---- + + +def run_insample(matched_clean): + print("\n" + "=" * 70) + print("1. In-sample DeepSets") + print("=" * 70) + + data = build_data(matched_clean) + n_params = sum(v.size for v in init_params(jax.random.PRNGKey(0), data["k_attr"]).values()) + print(f" {data['peer_vols'].shape[0]} samples, {data['n_pools']} pools, " + f"{data['k_attr']} attrs, {n_params} params") + + params = init_params(jax.random.PRNGKey(42), data["k_attr"]) + t0 = time.time() + params = train(params, data) + print(f" Training: {time.time() - t0:.1f}s") + + r2s = per_pool_r2(params, data) + pool_ids = data["pool_ids"] + for i, pid in enumerate(pool_ids): + if i in r2s: + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) R²={r2s[i]:.3f}") + + vals = list(r2s.values()) + print(f"\n In-sample: median R²={np.median(vals):.4f}, mean={np.mean(vals):.4f}") + return params, data, r2s + + +def run_temporal_split(matched_clean, split_frac=0.7): + print("\n" + "=" * 70) + print(f"2. 
Temporal split ({int(split_frac*100)}/{int((1-split_frac)*100)})") + print("=" * 70) + + data_all = build_data(matched_clean) + day_idx = np.array(data_all["day_idx"]) + max_day = day_idx.max() + split_day = int(max_day * split_frac) + + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + + def subset(data, mask): + jmask = jnp.array(mask) + return { + **{k: data[k] for k in ["peer_attrs", "target_attrs", "peer_overlap", + "n_pools", "n_peers", "k_attr", "pool_ids", + "vol_mean", "vol_std"]}, + "peer_vols": data["peer_vols"][jmask], + "peer_mask": data["peer_mask"][jmask], + "own_lag": data["own_lag"][jmask], + "y": data["y"][jmask], + "pool_idx": data["pool_idx"][jmask], + "day_idx": data_all["day_idx"][mask], + } + + train_data = subset(data_all, train_mask) + eval_data = subset(data_all, eval_mask) + + print(f" Train: {int(train_mask.sum())} samples, Eval: {int(eval_mask.sum())} samples") + + params = init_params(jax.random.PRNGKey(42), data_all["k_attr"]) + params = train(params, train_data) + + r2s_train = per_pool_r2(params, train_data) + r2s_eval = per_pool_r2(params, eval_data) + + pool_ids = data_all["pool_ids"] + for i, pid in enumerate(pool_ids): + r_tr = r2s_train.get(i, float("nan")) + r_ev = r2s_eval.get(i, float("nan")) + print(f" {pid[:16]} ({matched_clean[pid]['tokens']:<14}) " + f"train={r_tr:.3f} eval={r_ev:.3f}") + + vals_eval = [v for v in r2s_eval.values() if np.isfinite(v)] + print(f"\n Temporal eval: median R²={np.median(vals_eval):.4f}, " + f"mean={np.mean(vals_eval):.4f}") + return r2s_eval + + +def run_loo(matched_clean): + print("\n" + "=" * 70) + print("3. 
LOO DeepSets") + print("=" * 70) + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + loo_r2s = [] + + for hold_out_idx in range(n_pools): + hold_out_pid = pool_ids[hold_out_idx] + + # Build training data excluding held-out pool's samples + # (but keeping its volume data for peers) + train_data = build_data(matched_clean, exclude_pool_idx=hold_out_idx) + + params = init_params(jax.random.PRNGKey(42), train_data["k_attr"]) + params = train(params, train_data, n_epochs=N_EPOCHS_LOO, verbose=False) + + # Build eval data: only held-out pool's samples + eval_data = build_data(matched_clean) + ho_mask = np.array(eval_data["pool_idx"]) == hold_out_idx + if ho_mask.sum() < 2: + loo_r2s.append(float("nan")) + continue + + jmask = jnp.array(ho_mask) + eval_sub = { + **{k: eval_data[k] for k in ["peer_attrs", "target_attrs", "peer_overlap", + "n_pools", "n_peers", "k_attr", "pool_ids", + "vol_mean", "vol_std"]}, + "peer_vols": eval_data["peer_vols"][jmask], + "peer_mask": eval_data["peer_mask"][jmask], + "own_lag": eval_data["own_lag"][jmask], + "y": eval_data["y"][jmask], + "pool_idx": eval_data["pool_idx"][jmask], + "day_idx": np.array(eval_data["day_idx"])[ho_mask], + } + + r2s = per_pool_r2(params, eval_sub) + r2 = r2s.get(hold_out_idx, float("nan")) + loo_r2s.append(r2) + + tag = "OK" if r2 > 0 else "NEG" + print(f" {hold_out_pid[:16]} ({matched_clean[hold_out_pid]['tokens']:<14}) " + f"R²={r2:.3f} [{tag}]") + + valid = [r for r in loo_r2s if np.isfinite(r)] + print(f"\n LOO DeepSets: median R²={np.median(valid):.4f}, " + f"mean={np.mean(valid):.4f}, " + f"n_neg={sum(1 for r in valid if r < 0)}") + return loo_r2s + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("DeepSets Cross-Pool Volume Prediction") + print(f" hidden={HIDDEN}, d={D_EMBED}, lr={LR}, " + f"alpha={L2_ALPHA}, epochs={N_EPOCHS}") + print("=" * 70) + + matched_clean, _ = load_stage1() + + params, data, r2_insample = run_insample(matched_clean) + 
r2_temporal = run_temporal_split(matched_clean) + r2_loo = run_loo(matched_clean) + + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + vals_in = list(r2_insample.values()) + vals_temp = [v for v in r2_temporal.values() if np.isfinite(v)] + vals_loo = [r for r in r2_loo if np.isfinite(r)] + print(f" DeepSets in-sample: median R² = {np.median(vals_in):.4f}") + print(f" DeepSets temporal (30%): median R² = {np.median(vals_temp):.4f}") + print(f" DeepSets LOO: median R² = {np.median(vals_loo):.4f}") + print(f" ---") + print(f" Ridge in-sample: median R² = 0.441") + print(f" Naive AR1: median R² = 0.397") + print(f" Token-factored LOO: median R² = 0.362") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_hybrid_noise.py b/experiments/run_hybrid_noise.py new file mode 100644 index 0000000..d815a5c --- /dev/null +++ b/experiments/run_hybrid_noise.py @@ -0,0 +1,814 @@ +"""Hybrid noise model: DeepSets peer encoder + linear noise model. + +Architecture: + peer_effect = DeepSets_encoder(peer_data, current_pool_attrs) → scalar + log(V_noise) = [x_obs, market, peer_effect, peer_effect×tvl, ...] @ coeffs + V_total = V_arb(cadence) + exp(log_v_noise) + +The encoder learns how to aggregate peer information. The linear model +learns how that aggregate (plus market/pool features) drives noise volume. +Cadence is learnable per-pool via PCHIP. 
+ +Usage: + python experiments/run_hybrid_noise.py + python experiments/run_hybrid_noise.py --encoder-hidden 16 --epochs 2000 + python experiments/run_hybrid_noise.py --n-peer-outputs 3 # multi-dim peer effect +""" + +import argparse +import os +import pickle +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +def build_data(matched_clean, option_c_clean, trend_windows=(7, 14, 30)): + """Build all features: x_obs, market, peer encoder inputs.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import ( + build_x_obs, build_cross_pool_x_obs, build_pool_attributes, + _parse_tokens, _canonicalize_token, K_OBS_CROSS, + ) + from quantammsim.calibration.market_features import ( + build_pool_market_features, pool_market_features_to_matrix, + ) + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + # Common date grid + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + # Volume matrix and per-pool metadata + vol_matrix = np.full((n_dates, n_pools), np.nan) + pool_coeffs = [] + pool_gas = [] + init_log_cadences = np.zeros(n_pools, dtype=np.float32) + common_to_grid = np.full((n_pools, n_dates), 0, dtype=np.int32) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + oc = option_c_clean[pid] + pool_coeffs.append(entry["coeffs"]) + pool_gas.append(jnp.float64(np.exp(oc["log_gas"]))) 
+ init_log_cadences[j] = oc["log_cadence"] + dates = entry["panel"]["date"].values + log_vols = entry["panel"]["log_volume"].values.astype(float) + for k, date in enumerate(dates): + t = date_to_idx[date] + vol_matrix[t, j] = log_vols[k] + common_to_grid[j, t] = entry["day_indices"][k] + + # Pool attributes (static, normalized) + X_attr, attr_names, _ = build_pool_attributes(matched_clean) + attr_mean = np.mean(X_attr, axis=0) + attr_std = np.std(X_attr, axis=0) + attr_std[attr_std < 1e-6] = 1.0 + X_attr_norm = ((X_attr - attr_mean) / attr_std).astype(np.float32) + k_attr = X_attr_norm.shape[1] + + # Token overlap matrix + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + + overlap = np.zeros((n_pools, n_pools), dtype=np.float32) + for i in range(n_pools): + for j in range(n_pools): + if i != j: + overlap[i, j] = len(pool_tokens[i] & pool_tokens[j]) + + # Peer index mapping: for pool i, peers are all j != i + n_peers = n_pools - 1 + peer_idx = np.zeros((n_pools, n_peers), dtype=np.int32) + peer_overlap = np.zeros((n_pools, n_peers), dtype=np.float32) + for i in range(n_pools): + peers = [j for j in range(n_pools) if j != i] + peer_idx[i] = peers + peer_overlap[i] = overlap[i, peers] + + # Build samples + sample_pools, sample_days = [], [] + for i in range(n_pools): + for t in range(1, n_dates): + if np.isnan(vol_matrix[t, i]) or np.isnan(vol_matrix[t - 1, i]): + continue + sample_pools.append(i) + sample_days.append(t) + sample_pools = np.array(sample_pools, dtype=np.int32) + sample_days = np.array(sample_days, dtype=np.int32) + n_samples = len(sample_pools) + + # ---- x_obs (cross-pool, 7 features) ---- + x_obs_grid = np.full((n_dates, n_pools, K_OBS_CROSS), np.nan) + for j, pid in enumerate(pool_ids): + panel = matched_clean[pid]["panel"] + xc = build_cross_pool_x_obs(panel, matched_clean, pid) + dates_j = panel["date"].values + for k, date in 
enumerate(dates_j[1:]): + x_obs_grid[date_to_idx[date], j] = xc[k] + + x_obs = np.zeros((n_samples, K_OBS_CROSS), dtype=np.float32) + for s in range(n_samples): + xval = x_obs_grid[sample_days[s], sample_pools[s]] + if np.all(np.isfinite(xval)): + x_obs[s] = xval + + # ---- Market features ---- + print(" Building market features...") + pool_feat = build_pool_market_features( + matched_clean, trend_windows=list(trend_windows)) + x_market, market_names = pool_market_features_to_matrix( + pool_feat, matched_clean, date_to_idx, pool_ids, + sample_pools, sample_days) + print(f" Market features: {len(market_names)} columns") + + # ---- Peer encoder inputs: (n_samples, n_peers, n_peer_feat) ---- + # Per peer: [peer_attrs, target_attrs, peer_vol_lag1, overlap] + # peer_vol_lag1 is the peer's volume at t-1 + vol_mean = float(np.nanmean(vol_matrix)) + vol_std = max(float(np.nanstd(vol_matrix)), 1e-6) + + # Static peer features (per pool) + peer_attrs = np.zeros((n_pools, n_peers, k_attr), dtype=np.float32) + for i in range(n_pools): + peer_attrs[i] = X_attr_norm[peer_idx[i]] + + # Per-sample peer features + peer_vol_lag1 = np.zeros((n_samples, n_peers), dtype=np.float32) + peer_mask = np.zeros((n_samples, n_peers), dtype=np.float32) + + for s in range(n_samples): + i = sample_pools[s] + t = sample_days[s] + cols = peer_idx[i] + pvols = vol_matrix[t - 1, cols] + valid = ~np.isnan(pvols) + peer_mask[s] = valid.astype(np.float32) + peer_vol_lag1[s] = np.where(valid, (pvols - vol_mean) / vol_std, 0.0) + + # Assemble peer encoder input: (n_samples, n_peers, n_peer_feat) + # [peer_attrs(k_attr), target_attrs(k_attr), vol_lag1(1), overlap(1)] + target_attrs_broad = np.broadcast_to( + X_attr_norm[sample_pools][:, None, :], + (n_samples, n_peers, k_attr)) + peer_input = np.concatenate([ + peer_attrs[sample_pools], # (n_samples, n_peers, k_attr) + target_attrs_broad, # (n_samples, n_peers, k_attr) + peer_vol_lag1[:, :, None], # (n_samples, n_peers, 1) + peer_overlap[sample_pools][:, 
:, None], # (n_samples, n_peers, 1) + ], axis=-1).astype(np.float32) + n_peer_feat = peer_input.shape[-1] + + # Combine linear features (x_obs + market) + x_base = np.concatenate([x_obs, x_market], axis=1).astype(np.float32) + base_names = [f"xobs_{i}" for i in range(K_OBS_CROSS)] + market_names + + # Standardize base features (except intercept) + x_mean = np.mean(x_base, axis=0) + x_std_arr = np.std(x_base, axis=0) + x_std_arr[x_std_arr < 1e-6] = 1.0 + x_mean[0] = 0.0 + x_std_arr[0] = 1.0 + x_base = ((x_base - x_mean) / x_std_arr).astype(np.float32) + + # Build named column index for interaction construction + col_idx = {name: i for i, name in enumerate(base_names)} + + # Interaction terms (products of standardized features) + interactions = [] + interaction_names = [] + + def _add_interaction(name_a, name_b): + if name_a in col_idx and name_b in col_idx: + interactions.append( + x_base[:, col_idx[name_a]] * x_base[:, col_idx[name_b]]) + interaction_names.append(f"{name_a}×{name_b}") + + # TVL × volatility: deep pools respond differently to market stress + _add_interaction("xobs_1", "btc_realized_vol_7d") # tvl × btc vol + _add_interaction("xobs_1", "tok_a_realized_vol_7d") # tvl × tok_a vol + _add_interaction("xobs_1", "pair_realized_vol_7d") # tvl × pair vol + + # Cross-token volatility interaction: both tokens moving = pair activity + _add_interaction("tok_a_realized_vol_7d", "tok_b_realized_vol_7d") + + if interactions: + x_interactions = np.column_stack(interactions).astype(np.float32) + x_linear = np.concatenate([x_base, x_interactions], axis=1) + linear_names = base_names + interaction_names + else: + x_linear = x_base + linear_names = base_names + + # Track which columns are tvl and btc_vol for peer_effect interactions in loss + tvl_col = col_idx.get("xobs_1", 1) + btc_vol_col = col_idx.get("btc_realized_vol_7d") + + # Targets and indices + y_total = np.array([vol_matrix[sample_days[s], sample_pools[s]] + for s in range(n_samples)], dtype=np.float32) + 
sample_grid_days = common_to_grid[sample_pools, sample_days] + + return { + "x_linear": x_linear, # (n_samples, n_linear_feat) + "peer_input": peer_input, # (n_samples, n_peers, n_peer_feat) + "peer_mask": peer_mask, # (n_samples, n_peers) + "y_total": y_total, + "pool_idx": sample_pools, + "day_idx": sample_days, + "sample_grid_days": sample_grid_days, + "pool_coeffs": pool_coeffs, + "pool_gas": pool_gas, + "init_log_cadences": init_log_cadences, + "n_pools": n_pools, + "n_peers": n_peers, + "n_linear_feat": x_linear.shape[1], + "n_peer_feat": n_peer_feat, + "pool_ids": pool_ids, + "linear_names": linear_names, + "tvl_col": tvl_col, + "btc_vol_col": btc_vol_col, + } + + +# ---- Model ---- + +_SAMPLE_KEYS = { + "x_linear", "peer_input", "peer_mask", "y_total", + "pool_idx", "day_idx", "sample_grid_days", +} + + +def _subset(d, mask): + out = {} + for k, v in d.items(): + if k in _SAMPLE_KEYS and isinstance(v, np.ndarray): + out[k] = v[mask] + else: + out[k] = v + return out + + +def init_params(key, n_peer_feat, n_linear_feat, encoder_hidden, + n_peer_outputs, n_pools, init_log_cadences, + encoder_depth=1): + """Initialize all parameters. + + Encoder: peer_input → hidden (× depth) → n_peer_outputs (per peer, mean-pooled) + Linear: [x_linear, peer_outputs, peer×tvl, peer×btc_vol] @ coeffs + + encoder_depth: number of hidden layers (1-4). + The depth is not stored in params; forward_encoder infers it by counting the enc_W* keys. 
+ """ + keys = jax.random.split(key, encoder_depth + 2) + + n_peer_linear = n_peer_outputs * 3 + n_total_linear = n_linear_feat + n_peer_linear + + params = {} + + # First layer: input → hidden + params["enc_W1"] = jax.random.normal(keys[0], (n_peer_feat, encoder_hidden)) * np.sqrt(2.0 / n_peer_feat) + params["enc_b1"] = jnp.zeros(encoder_hidden) + + # Hidden layers 2..depth: hidden → hidden + for d in range(2, encoder_depth + 1): + params[f"enc_W{d}"] = jax.random.normal(keys[d - 1], (encoder_hidden, encoder_hidden)) * np.sqrt(2.0 / encoder_hidden) + params[f"enc_b{d}"] = jnp.zeros(encoder_hidden) + + # Output layer: hidden → n_peer_outputs + out_idx = encoder_depth + 1 + params[f"enc_W{out_idx}"] = jax.random.normal(keys[-1], (encoder_hidden, n_peer_outputs)) * 0.01 + params[f"enc_b{out_idx}"] = jnp.zeros(n_peer_outputs) + + params["noise_coeffs"] = jnp.zeros(n_total_linear) + params["log_cadence"] = jnp.array(init_log_cadences) + + return params, n_total_linear + + +def forward_encoder(params, peer_input, peer_mask): + """DeepSets encoder: per-peer MLP → masked mean → scalar(s). + + Depth determined by counting enc_W* keys. + Returns (n_samples, n_peer_outputs). 
+ """ + batch, n_peers, _ = peer_input.shape + flat = peer_input.reshape(-1, peer_input.shape[-1]) + + # Count layers: enc_W1, enc_W2, ..., enc_W{depth+1} + n_layers = sum(1 for k in params if k.startswith("enc_W")) + + # Hidden layers with ReLU + h = flat + for i in range(1, n_layers): + h = jnp.maximum(h @ params[f"enc_W{i}"] + params[f"enc_b{i}"], 0.0) + + # Output layer (no activation) + h = h @ params[f"enc_W{n_layers}"] + params[f"enc_b{n_layers}"] + + h = h.reshape(batch, n_peers, -1) + h_masked = h * peer_mask[:, :, None] + n_valid = jnp.maximum(jnp.sum(peer_mask, axis=1, keepdims=True), 1.0) + return jnp.sum(h_masked, axis=1) / n_valid + + +def make_loss_fn(pool_coeffs, pool_gas, n_pools, tvl_col, btc_vol_col): + """Loss with learnable cadence + encoder + linear noise model.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + def loss_fn(params, x_linear, peer_input, peer_mask, y_total, + sample_grid_days, pool_idx, l2_alpha, huber_delta): + + # Encoder → peer effect scalar(s) + peer_effect = forward_encoder(params, peer_input, peer_mask) + + # Build full linear input: [x_linear, peer, peer×tvl, peer×btc_vol] + tvl = x_linear[:, tvl_col:tvl_col + 1] + btc_vol = x_linear[:, btc_vol_col:btc_vol_col + 1] + peer_x_tvl = peer_effect * tvl + peer_x_btcvol = peer_effect * btc_vol + + x_full = jnp.concatenate( + [x_linear, peer_effect, peer_x_tvl, peer_x_btcvol], axis=1) + log_v_noise = x_full @ params["noise_coeffs"] + + # V_arb from PCHIP + log_cadence = params["log_cadence"] + n_samples = y_total.shape[0] + v_arb = jnp.zeros(n_samples) + for i in range(n_pools): + v_arb_all = interpolate_pool_daily( + pool_coeffs[i], log_cadence[i], pool_gas[i]) + safe_days = jnp.clip(sample_grid_days, 0, v_arb_all.shape[0] - 1) + v_arb = jnp.where(pool_idx == i, v_arb_all[safe_days], v_arb) + + log_v_arb = jnp.log(jnp.maximum(v_arb, 1e-10)) + log_v_total = jnp.logaddexp(log_v_arb, log_v_noise) + + # Huber loss with per-pool weighting + 
residuals = log_v_total - y_total + abs_r = jnp.abs(residuals) + huber_vals = jnp.where(abs_r <= huber_delta, 0.5 * residuals ** 2, + huber_delta * (abs_r - 0.5 * huber_delta)) + + pool_counts = jnp.zeros(n_pools).at[pool_idx].add( + jnp.ones_like(pool_idx, dtype=jnp.float32)) + active = (pool_counts > 0).astype(jnp.float32) + n_active = jnp.maximum(jnp.sum(active), 1.0) + pool_counts = jnp.maximum(pool_counts, 1.0) + pool_sums = jnp.zeros(n_pools).at[pool_idx].add(huber_vals) + data_loss = jnp.sum((pool_sums / pool_counts) * active) / n_active + + # L2 on all encoder weights + noise coeffs + reg = l2_alpha * ( + sum(jnp.sum(v ** 2) for k, v in params.items() if k.startswith("enc_W")) + + jnp.sum(params["noise_coeffs"] ** 2) + ) + return data_loss + reg + + return jax.jit(jax.value_and_grad(loss_fn)) + + +def train(params, data, grad_fn, n_epochs, lr, l2_alpha, huber_delta, + verbose=True): + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = {k: jnp.zeros_like(v) for k, v in params.items()} + + xl = jnp.array(data["x_linear"]) + pi = jnp.array(data["peer_input"]) + pm = jnp.array(data["peer_mask"]) + yt = jnp.array(data["y_total"]) + sgd = jnp.array(data["sample_grid_days"]) + pidx = jnp.array(data["pool_idx"]) + + for epoch in range(n_epochs): + loss_val, grads = grad_fn( + params, xl, pi, pm, yt, sgd, pidx, l2_alpha, huber_delta) + loss_f = float(loss_val) + + for k in params: + m[k] = 0.9 * m[k] + 0.1 * grads[k] + v[k] = 0.999 * v[k] + 0.001 * grads[k] ** 2 + m_hat = m[k] / (1.0 - 0.9 ** (epoch + 1)) + v_hat = v[k] / (1.0 - 0.999 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / (jnp.sqrt(v_hat) + 1e-8) + + if verbose and (epoch % 200 == 0 or epoch == n_epochs - 1): + cads = np.exp(np.array(params["log_cadence"])) + pe = np.array(forward_encoder( + params, jnp.array(data["peer_input"][:100]), + jnp.array(data["peer_mask"][:100]))) + print(f" epoch {epoch:4d} loss={loss_f:.6f}" + f" cad=[{cads.min():.1f}-{np.median(cads):.1f}-{cads.max():.1f}]" + 
f" peer_eff=[{pe.min():.2f},{pe.mean():.2f},{pe.max():.2f}]") + + return params + + +def evaluate(params, data, label=""): + """Evaluate decomposition.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + x_linear = np.array(data["x_linear"]) + peer_input = data["peer_input"] + peer_mask = data["peer_mask"] + y_total = np.array(data["y_total"]) + pool_idx = np.array(data["pool_idx"]) + sgd = np.array(data["sample_grid_days"]) + log_cadence = np.array(params["log_cadence"]) + init_cads = data["init_log_cadences"] + pool_ids = data["pool_ids"] + n_pools = data["n_pools"] + + # Encoder + peer_effect = np.array(forward_encoder( + params, jnp.array(peer_input), jnp.array(peer_mask))) + + # Build full linear input (must match loss_fn construction) + tvl_col = data["tvl_col"] + btc_vol_col = data["btc_vol_col"] + tvl = x_linear[:, tvl_col:tvl_col + 1] + btc_vol = x_linear[:, btc_vol_col:btc_vol_col + 1] + peer_x_tvl = peer_effect * tvl + peer_x_btcvol = peer_effect * btc_vol + x_full = np.concatenate( + [x_linear, peer_effect, peer_x_tvl, peer_x_btcvol], axis=1) + + noise_coeffs = np.array(params["noise_coeffs"]) + log_v_noise = x_full @ noise_coeffs + v_noise = np.exp(log_v_noise) + + # V_arb + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + v_arb[mask] = v_arb_all[sgd[mask]] + + v_obs = np.exp(y_total) + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + if label: + print(f"\n {label}:") + print(f" {'Pool'[:16]:16s} {'R²':>6s} {'Cad':>5s} → {'learn':>5s}" + f" {'Arb%':>6s} {'Noise%':>7s} {'PeerEff':>8s} {'Flag':>5s}") + print(f" {'-'*65}") + + r2s = {} + pool_diag = [] + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total[mask] + 
ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2s[i] = 1 - ss_res / max(ss_tot, 1e-10) + + pid = pool_ids[i] + ci = np.exp(init_cads[i]) + cl = np.exp(log_cadence[i]) + arb_pct = np.median(v_arb[mask] / v_obs[mask]) * 100 + noise_pct = np.median(v_noise[mask] / v_obs[mask]) * 100 + pe_mean = np.mean(peer_effect[mask]) + + flags = [] + if arb_pct > 150: + flags.append("A") + if cl <= 1.01 or cl >= 59.9: + flags.append("B") + if r2s[i] < 0: + flags.append("X") + flag_str = "".join(flags) + + pool_diag.append({ + "pid": pid, "r2": r2s[i], "cad_init": ci, "cad_learned": cl, + "arb_pct": arb_pct, "noise_pct": noise_pct, + "peer_effect": pe_mean, "flags": flag_str, + }) + + print(f" {pid[:16]:16s} {r2s[i]:6.3f} {ci:5.1f} → {cl:5.1f}" + f" {arb_pct:6.0f}% {noise_pct:6.0f}% {pe_mean:+8.3f} {flag_str:>5s}") + + vals = [x for x in r2s.values() if np.isfinite(x)] + med = np.median(vals) if vals else float("nan") + healthy = [d for d in pool_diag if d["arb_pct"] <= 150 and d["r2"] > 0] + med_h = np.median([d["r2"] for d in healthy]) if healthy else float("nan") + n_path = sum(1 for d in pool_diag if d["arb_pct"] > 150) + n_bound = sum(1 for d in pool_diag + if d["cad_learned"] <= 1.01 or d["cad_learned"] >= 59.9) + + # Print coefficient analysis + nc = np.array(params["noise_coeffs"]) + n_linear = data["n_linear_feat"] + # n_peer_outputs: noise_coeffs holds 3 peer blocks (peer, peer×tvl, peer×btc_vol). + # Reading enc_W2 would be wrong for encoder_depth > 1, where enc_W2 is hidden → hidden. + n_po = (nc.shape[0] - n_linear) // 3 + + print(f"\n Median R²: {med:.4f} (healthy: {med_h:.4f})") + print(f" Healthy: {len(pool_diag) - n_path}/{len(pool_diag)}," + f" at bounds: {n_bound}") + + print(f"\n Linear coefficients:") + for j, name in enumerate(data["linear_names"]): + print(f" {name:30s} {nc[j]:+8.4f}") + for j in range(n_po): + print(f" {'peer_effect_' + str(j):30s} {nc[n_linear + j]:+8.4f}") + for j in range(n_po): + print(f" {'peer_eff_' + str(j) + '×tvl':30s} {nc[n_linear + n_po + j]:+8.4f}") + for j in range(n_po): + print(f" {'peer_eff_' + str(j) + '×btc_vol':30s} {nc[n_linear + 2*n_po + j]:+8.4f}") + 
+ return med, r2s, pool_diag + + +def run_optuna(matched_clean, option_c_clean, n_trials): + """Optuna sweep over encoder + linear model hyperparameters.""" + import optuna + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + # Build data for each trend_windows config (cache to avoid rebuilding) + data_cache = {} + + def _get_data(trend_key): + if trend_key not in data_cache: + data_cache[trend_key] = build_data( + matched_clean, option_c_clean, + trend_windows=trend_key) + return data_cache[trend_key] + + # Cache grad_fn per (n_pools, tvl_col, btc_vol_col) — these are stable + grad_fn_cache = {} + + def objective(trial): + trend_w = trial.suggest_categorical("trend_window", [7, 14, 30]) + data = _get_data((trend_w,)) + n_pools = data["n_pools"] + + encoder_hidden = trial.suggest_categorical("encoder_hidden", [16, 32, 64, 128]) + encoder_depth = trial.suggest_categorical("encoder_depth", [1, 2, 3, 4]) + n_peer_outputs = trial.suggest_categorical("n_peer_outputs", [1, 2, 4]) + lr = trial.suggest_float("lr", 3e-4, 3e-3, log=True) + l2_alpha = trial.suggest_float("l2_alpha", 1e-4, 1e-2, log=True) + huber_delta = trial.suggest_categorical("huber_delta", [0.5, 1.0, 1.5]) + n_epochs = trial.suggest_categorical("n_epochs", [1000, 2000, 3000]) + + # Temporal split + day_idx = data["day_idx"] + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = _subset(data, train_mask) + eval_data = _subset(data, eval_mask) + + # Init + params, n_total_linear = init_params( + jax.random.PRNGKey(42), + data["n_peer_feat"], data["n_linear_feat"], + encoder_hidden, n_peer_outputs, + n_pools, data["init_log_cadences"], + encoder_depth=encoder_depth, + ) + + # OLS warm-start + x_trn = data["x_linear"][train_mask] + y_trn = data["y_total"][train_mask] + n_peer_cols = n_peer_outputs * 3 + x_trn_padded = np.concatenate([ + x_trn, np.zeros((x_trn.shape[0], n_peer_cols), dtype=np.float32) + ], 
axis=1) + sol, _, _, _ = np.linalg.lstsq(x_trn_padded, y_trn, rcond=None) + params["noise_coeffs"] = jnp.array(sol.astype(np.float32)) + + # Build grad_fn (cache by config) + cache_key = (n_pools, data["tvl_col"], data["btc_vol_col"]) + if cache_key not in grad_fn_cache: + grad_fn_cache[cache_key] = make_loss_fn( + data["pool_coeffs"], data["pool_gas"], n_pools, + data["tvl_col"], data["btc_vol_col"]) + grad_fn = grad_fn_cache[cache_key] + + # Train + params = train(params, train_data, grad_fn, n_epochs, lr, + l2_alpha, huber_delta, verbose=False) + + # Eval: compute total R² + x_linear = np.array(eval_data["x_linear"]) + peer_input = eval_data["peer_input"] + peer_mask = eval_data["peer_mask"] + y_total = np.array(eval_data["y_total"]) + pool_idx = np.array(eval_data["pool_idx"]) + sgd = np.array(eval_data["sample_grid_days"]) + log_cadence = np.array(params["log_cadence"]) + + peer_effect = np.array(forward_encoder( + params, jnp.array(peer_input), jnp.array(peer_mask))) + + tvl_col = data["tvl_col"] + btc_vol_col = data["btc_vol_col"] + tvl = x_linear[:, tvl_col:tvl_col + 1] + btc_vol = x_linear[:, btc_vol_col:btc_vol_col + 1] + x_full = np.concatenate([ + x_linear, peer_effect, peer_effect * tvl, peer_effect * btc_vol + ], axis=1) + + noise_coeffs = np.array(params["noise_coeffs"]) + log_v_noise = x_full @ noise_coeffs + + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + v_arb[mask] = v_arb_all[sgd[mask]] + + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + r2s = [] + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2s.append(1 - ss_res / max(ss_tot, 1e-10)) + + 
med_total = float(np.median(r2s)) if r2s else -10.0 + + cads = np.exp(log_cadence) + print(f" Trial {trial.number}: total={med_total:.4f}" + f" enc_h={encoder_hidden} d={encoder_depth} n_po={n_peer_outputs}" + f" hub={huber_delta} lr={lr:.1e} l2={l2_alpha:.1e}" + f" ep={n_epochs} tw={trend_w}" + f" cad=[{cads.min():.0f}-{np.median(cads):.0f}-{cads.max():.0f}]") + + return med_total + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + print(f"\n{'='*70}") + print("Optuna Results (hybrid)") + print(f"{'='*70}") + print(f" Best eval total R²: {study.best_value:.4f}") + print(f" Best params:") + for k, v in sorted(study.best_params.items()): + print(f" {k}: {v}") + + print(f"\n Top 10:") + trials = sorted(study.trials, key=lambda t: t.value if t.value is not None else -999, + reverse=True) + for t in trials[:10]: + if t.value is not None: + print(f" #{t.number}: total={t.value:.4f}" + f" enc_h={t.params['encoder_hidden']}" + f" d={t.params['encoder_depth']}" + f" n_po={t.params['n_peer_outputs']}" + f" ep={t.params['n_epochs']}" + f" tw={t.params['trend_window']}") + + return study + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--tune", type=int, default=0, + help="Number of Optuna trials (0 = single run)") + parser.add_argument("--epochs", type=int, default=2000) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--l2-alpha", type=float, default=1e-3) + parser.add_argument("--huber-delta", type=float, default=1.0) + parser.add_argument("--encoder-hidden", type=int, default=16) + parser.add_argument("--encoder-depth", type=int, default=1) + parser.add_argument("--n-peer-outputs", type=int, default=1) + parser.add_argument("--trend-windows", type=int, nargs="+", default=[7]) + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Hybrid: DeepSets Peer Encoder + Linear Noise Model") + print(f" 
encoder_hidden={args.encoder_hidden}," + f" n_peer_outputs={args.n_peer_outputs}") + print(f" epochs={args.epochs}, lr={args.lr}, l2={args.l2_alpha}") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + if args.tune > 0: + run_optuna(matched_clean, option_c_clean, args.tune) + return + + print("\nBuilding data...") + t0 = time.time() + data = build_data(matched_clean, option_c_clean, + trend_windows=tuple(args.trend_windows)) + n_pools = data["n_pools"] + print(f" {len(data['pool_idx'])} samples, {n_pools} pools") + print(f" Linear features: {data['n_linear_feat']}") + print(f" Peer encoder input: {data['n_peer_feat']} per peer," + f" {data['n_peers']} peers") + print(f" Build time: {time.time() - t0:.1f}s") + + # Temporal split + day_idx = data["day_idx"] + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + + train_data = _subset(data, train_mask) + eval_data = _subset(data, eval_mask) + + # Init + params, n_total_linear = init_params( + jax.random.PRNGKey(42), + data["n_peer_feat"], data["n_linear_feat"], + args.encoder_hidden, args.n_peer_outputs, + n_pools, data["init_log_cadences"], + encoder_depth=args.encoder_depth, + ) + + # Warm-start linear coeffs via OLS (peer_effect = 0 initially) + x_trn = data["x_linear"][train_mask] + y_trn = data["y_total"][train_mask] + # Pad with zeros for peer_effect columns (raw + ×tvl + ×btc_vol) + n_peer_cols = args.n_peer_outputs * 3 + x_trn_padded = np.concatenate([ + x_trn, + np.zeros((x_trn.shape[0], n_peer_cols), dtype=np.float32) + ], axis=1) + sol, _, _, _ = np.linalg.lstsq(x_trn_padded, y_trn, rcond=None) + params["noise_coeffs"] = jnp.array(sol.astype(np.float32)) + + n_enc_params = (args.encoder_hidden * data["n_peer_feat"] + + args.encoder_hidden + + args.encoder_hidden * args.n_peer_outputs + + args.n_peer_outputs) + print(f"\n Params: {n_total_linear} linear + {n_enc_params} encoder" + f" + {n_pools} cadences = {n_total_linear + 
n_enc_params + n_pools}") + print(f" Init cadence: {np.exp(data['init_log_cadences']).min():.1f}" + f"-{np.median(np.exp(data['init_log_cadences'])):.1f}" + f"-{np.exp(data['init_log_cadences']).max():.1f} min") + + # Train + grad_fn = make_loss_fn(data["pool_coeffs"], data["pool_gas"], n_pools, + data["tvl_col"], data["btc_vol_col"]) + + print("\n Compiling...") + t0 = time.time() + params = train(params, train_data, grad_fn, args.epochs, args.lr, + args.l2_alpha, args.huber_delta) + print(f" Training: {time.time() - t0:.1f}s") + + # Evaluate + print("\n --- Train ---") + evaluate(params, train_data) + print("\n --- Eval ---") + evaluate(params, eval_data) + + print(f"\n Baselines (eval, total volume R²):") + print(f" V_arb only: median R² = -0.33") + print(f" Linear shared: median R² = 0.39") + print(f" Linear+intercept: median R² = 0.39") + print(f" DeepSets: median R² = 0.43") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_linear_market_noise.py b/experiments/run_linear_market_noise.py new file mode 100644 index 0000000..56524ce --- /dev/null +++ b/experiments/run_linear_market_noise.py @@ -0,0 +1,627 @@ +"""Linear noise model with market features and learnable cadence. + +V_total = V_arb(cadence) + exp(x @ coeffs) + +where x includes: + - Option C x_obs (intercept, log_tvl_lag1, dow_sin, dow_cos) + - Cross-pool lagged volumes (token-A, token-B, chain peers) + - Market features (BTC price/vol/trend, token prices/vol/trend) + +Cadence is per-pool, optimized jointly with noise coefficients via Adam +through the differentiable PCHIP grid. 
+ +Usage: + python experiments/run_linear_market_noise.py + python experiments/run_linear_market_noise.py --trend-windows 7 14 30 + python experiments/run_linear_market_noise.py --no-market # x_obs only +""" + +import argparse +import os +import pickle +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +def build_data(matched_clean, option_c_clean, trend_windows=(7, 14, 30), + include_market=True, include_cross_pool=True): + """Build feature matrix and targets.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import ( + build_x_obs, build_cross_pool_x_obs, K_OBS_REDUCED, K_OBS_CROSS, + ) + from quantammsim.calibration.market_features import ( + build_pool_market_features, pool_market_features_to_matrix, + ) + + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + # Common date grid + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + # Per-pool: V_arb, volumes, coeffs, gas, grid day mapping + vol_matrix = np.full((n_dates, n_pools), np.nan) + pool_coeffs = [] + pool_gas = [] + init_log_cadences = np.zeros(n_pools, dtype=np.float32) + common_to_grid = np.full((n_pools, n_dates), 0, dtype=np.int32) + + for j, pid in enumerate(pool_ids): + entry = matched_clean[pid] + oc = option_c_clean[pid] + panel = entry["panel"] + + pool_coeffs.append(entry["coeffs"]) + 
pool_gas.append(jnp.float64(np.exp(oc["log_gas"]))) + init_log_cadences[j] = oc["log_cadence"] + + dates = panel["date"].values + log_vols = panel["log_volume"].values.astype(float) + for k, date in enumerate(dates): + t = date_to_idx[date] + vol_matrix[t, j] = log_vols[k] + common_to_grid[j, t] = entry["day_indices"][k] + + # Build samples: require t >= 1 (for lag) + sample_pools, sample_days = [], [] + for i in range(n_pools): + for t in range(1, n_dates): + if np.isnan(vol_matrix[t, i]) or np.isnan(vol_matrix[t - 1, i]): + continue + sample_pools.append(i) + sample_days.append(t) + sample_pools = np.array(sample_pools, dtype=np.int32) + sample_days = np.array(sample_days, dtype=np.int32) + n_samples = len(sample_pools) + + # x_obs: reduced (4) or cross-pool (7) + if include_cross_pool: + k_obs = K_OBS_CROSS + x_obs_grid = np.full((n_dates, n_pools, k_obs), np.nan) + for j, pid in enumerate(pool_ids): + panel = matched_clean[pid]["panel"] + xc = build_cross_pool_x_obs(panel, matched_clean, pid) # (n_obs-1, 7) + dates = panel["date"].values + for k, date in enumerate(dates[1:]): + x_obs_grid[date_to_idx[date], j] = xc[k] + else: + k_obs = K_OBS_REDUCED + x_obs_grid = np.full((n_dates, n_pools, k_obs), np.nan) + for j, pid in enumerate(pool_ids): + panel = matched_clean[pid]["panel"] + xr = build_x_obs(panel, reduced=True) + dates = panel["date"].values + for k, date in enumerate(dates): + x_obs_grid[date_to_idx[date], j] = xr[k] + + # Per-sample x_obs + x_obs = np.zeros((n_samples, k_obs), dtype=np.float32) + for s in range(n_samples): + xval = x_obs_grid[sample_days[s], sample_pools[s]] + if np.all(np.isfinite(xval)): + x_obs[s] = xval + + # Market features + if include_market: + print(" Building market features...") + pool_feat = build_pool_market_features( + matched_clean, trend_windows=list(trend_windows)) + x_market, market_names = pool_market_features_to_matrix( + pool_feat, matched_clean, date_to_idx, pool_ids, + sample_pools, sample_days) + print(f" Market 
features: {len(market_names)} columns") + else: + x_market = np.zeros((n_samples, 0), dtype=np.float32) + market_names = [] + + # Combine base features + x_base = np.concatenate([x_obs, x_market], axis=1).astype(np.float32) + base_names = [f"xobs_{i}" for i in range(k_obs)] + market_names + + # Feature-appropriate scaling: + # - intercept: untouched + # - log_tvl, btc_log_price: raw log scale (absolute level carries info) + # - dow_sin/cos: already [-1,1], no scaling + # - returns, trends, vol_zscore: already comparable, no scaling + # - realized_vol, pair_vol: small positive, light centering + # - cross-pool volumes: z-score (different scales across pool groups) + x_mean = np.zeros(x_base.shape[1], dtype=np.float32) + x_std = np.ones(x_base.shape[1], dtype=np.float32) + + for i, name in enumerate(base_names): + if name == "xobs_0": + # Intercept: leave as-is + pass + elif name in ("xobs_1", "btc_log_price"): + # Log levels: leave in raw log scale (range ~10-20) + pass + elif name in ("xobs_2", "xobs_3"): + # dow_sin, dow_cos: already [-1,1] + pass + elif "volume_zscore" in name: + # Already z-scored by construction + pass + elif "log_return" in name or "trend_" in name: + # Returns and trends: small, centered around 0, comparable + pass + elif "realized_vol" in name or "pair_realized" in name: + # Volatilities: small positive, center but don't squeeze + x_mean[i] = float(np.mean(x_base[:, i])) + # Use std but don't over-compress — floor at 0.01 + x_std[i] = max(float(np.std(x_base[:, i])), 0.01) + elif name.startswith("xobs_") and int(name.split("_")[1]) >= 4: + # Cross-pool volumes (xobs_4,5,6): z-score (different scales) + x_mean[i] = float(np.mean(x_base[:, i])) + x_std[i] = max(float(np.std(x_base[:, i])), 1e-6) + # else: leave untouched + + x_base = ((x_base - x_mean) / x_std).astype(np.float32) + + # Interaction terms (products of standardized features) + col_idx = {name: i for i, name in enumerate(base_names)} + interactions = [] + interaction_names = [] + 
+ def _add_interaction(name_a, name_b): + if name_a in col_idx and name_b in col_idx: + interactions.append( + x_base[:, col_idx[name_a]] * x_base[:, col_idx[name_b]]) + interaction_names.append(f"{name_a}×{name_b}") + + _add_interaction("xobs_1", "btc_realized_vol_7d") # tvl × btc vol + _add_interaction("xobs_1", "tok_a_realized_vol_7d") # tvl × tok_a vol + _add_interaction("xobs_1", "pair_realized_vol_7d") # tvl × pair vol + _add_interaction("tok_a_realized_vol_7d", "tok_b_realized_vol_7d") # cross-token vol + + if interactions: + x_interactions = np.column_stack(interactions).astype(np.float32) + x_all = np.concatenate([x_base, x_interactions], axis=1) + feat_names = base_names + interaction_names + # Extend x_mean/x_std for interaction columns (already standardized → 0/1) + x_mean = np.concatenate([x_mean, np.zeros(len(interactions))]) + x_std = np.concatenate([x_std, np.ones(len(interactions))]) + else: + x_all = x_base + feat_names = base_names + + # Targets + y_total = np.array([vol_matrix[sample_days[s], sample_pools[s]] + for s in range(n_samples)], dtype=np.float32) + sample_grid_days = common_to_grid[sample_pools, sample_days] + + return { + "x": x_all, # (n_samples, n_feat) + "y_total": y_total, # (n_samples,) + "pool_idx": sample_pools, # (n_samples,) + "day_idx": sample_days, # (n_samples,) + "sample_grid_days": sample_grid_days, # (n_samples,) + "pool_coeffs": pool_coeffs, + "pool_gas": pool_gas, + "init_log_cadences": init_log_cadences, + "n_pools": n_pools, + "n_feat": x_all.shape[1], + "pool_ids": pool_ids, + "feat_names": feat_names, + "x_mean": x_mean, + "x_std": x_std, + } + + +def make_loss_fn(pool_coeffs, pool_gas, n_pools): + """Loss function with learnable cadence + linear noise model. + + Supports both shared coefficients (noise_coeffs shape: (n_feat,)) and + per-pool coefficients (noise_coeffs shape: (n_pools, n_feat)). + Detected at trace time from the array shape. 
+ """ + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + def loss_fn(params, x, y_total, sample_grid_days, pool_idx, + l2_alpha, huber_delta): + log_cadence = params["log_cadence"] + noise_coeffs = params["noise_coeffs"] + + # Per-pool or shared coefficients + if noise_coeffs.ndim == 2: + # Per-pool: (n_pools, n_feat) — gather each sample's pool coeffs + per_sample_coeffs = noise_coeffs[pool_idx] # (n_samples, n_feat) + log_v_noise = jnp.sum(x * per_sample_coeffs, axis=1) + else: + # Shared: (n_feat,) + log_v_noise = x @ noise_coeffs + + if "pool_intercepts" in params: + log_v_noise = log_v_noise + params["pool_intercepts"][pool_idx] + + # V_arb from PCHIP at learned cadence + n_samples = y_total.shape[0] + v_arb = jnp.zeros(n_samples) + for i in range(n_pools): + v_arb_all = interpolate_pool_daily( + pool_coeffs[i], log_cadence[i], pool_gas[i]) + safe_days = jnp.clip(sample_grid_days, 0, v_arb_all.shape[0] - 1) + v_arb = jnp.where(pool_idx == i, v_arb_all[safe_days], v_arb) + + log_v_arb = jnp.log(jnp.maximum(v_arb, 1e-10)) + log_v_total = jnp.logaddexp(log_v_arb, log_v_noise) + + # Huber loss with per-pool weighting + residuals = log_v_total - y_total + abs_r = jnp.abs(residuals) + huber_vals = jnp.where(abs_r <= huber_delta, 0.5 * residuals ** 2, + huber_delta * (abs_r - 0.5 * huber_delta)) + + pool_counts = jnp.zeros(n_pools).at[pool_idx].add( + jnp.ones_like(pool_idx, dtype=jnp.float32)) + active = (pool_counts > 0).astype(jnp.float32) + n_active = jnp.maximum(jnp.sum(active), 1.0) + pool_counts = jnp.maximum(pool_counts, 1.0) + pool_sums = jnp.zeros(n_pools).at[pool_idx].add(huber_vals) + data_loss = jnp.sum((pool_sums / pool_counts) * active) / n_active + + reg = l2_alpha * jnp.sum(noise_coeffs ** 2) + return data_loss + reg + + return jax.jit(jax.value_and_grad(loss_fn)) + + +def train(params, data, grad_fn, n_epochs, lr, l2_alpha, huber_delta, + verbose=True): + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = 
{k: jnp.zeros_like(v) for k, v in params.items()} + + x = jnp.array(data["x"]) + y = jnp.array(data["y_total"]) + sgd = jnp.array(data["sample_grid_days"]) + pidx = jnp.array(data["pool_idx"]) + + for epoch in range(n_epochs): + loss_val, grads = grad_fn( + params, x, y, sgd, pidx, l2_alpha, huber_delta) + loss_f = float(loss_val) + + for k in params: + m[k] = 0.9 * m[k] + 0.1 * grads[k] + v[k] = 0.999 * v[k] + 0.001 * grads[k] ** 2 + m_hat = m[k] / (1.0 - 0.9 ** (epoch + 1)) + v_hat = v[k] / (1.0 - 0.999 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / (jnp.sqrt(v_hat) + 1e-8) + + if verbose and (epoch % 200 == 0 or epoch == n_epochs - 1): + cads = np.exp(np.array(params["log_cadence"])) + nc = np.array(params["noise_coeffs"]) + print(f" epoch {epoch:4d} loss={loss_f:.6f}" + f" cad=[{cads.min():.1f}-{np.median(cads):.1f}-{cads.max():.1f}]" + f" |coeffs|={np.mean(np.abs(nc)):.3f}") + + return params + + +def evaluate(params, data, label=""): + """Evaluate decomposition quality.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + x = np.array(data["x"]) + y_total = np.array(data["y_total"]) + pool_idx = np.array(data["pool_idx"]) + sgd = np.array(data["sample_grid_days"]) + log_cadence = np.array(params["log_cadence"]) + noise_coeffs = np.array(params["noise_coeffs"]) + init_cads = data["init_log_cadences"] + pool_ids = data["pool_ids"] + n_pools = data["n_pools"] + + if noise_coeffs.ndim == 2: + # Per-pool: (n_pools, n_feat) + per_sample_coeffs = noise_coeffs[pool_idx] + log_v_noise = np.sum(x * per_sample_coeffs, axis=1) + else: + log_v_noise = x @ noise_coeffs + if "pool_intercepts" in params: + pool_intercepts = np.array(params["pool_intercepts"]) + log_v_noise = log_v_noise + pool_intercepts[pool_idx] + v_noise = np.exp(log_v_noise) + + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], 
jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + v_arb[mask] = v_arb_all[sgd[mask]] + + v_obs = np.exp(y_total) + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + if label: + print(f"\n {label}:") + print(f" {'Pool'[:16]:16s} {'R²':>6s} {'Cad':>5s} {'→':>2s} {'learn':>5s}" + f" {'Arb%':>6s} {'Noise%':>7s} {'Flag':>5s}") + print(f" {'-'*60}") + + r2s = {} + pool_diag = [] + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2s[i] = 1 - ss_res / max(ss_tot, 1e-10) + + pid = pool_ids[i] + ci = np.exp(init_cads[i]) + cl = np.exp(log_cadence[i]) + arb_pct = np.median(v_arb[mask] / v_obs[mask]) * 100 + noise_pct = np.median(v_noise[mask] / v_obs[mask]) * 100 + + flags = [] + if arb_pct > 150: + flags.append("A") + if cl <= 1.01 or cl >= 59.9: + flags.append("B") + if r2s[i] < 0: + flags.append("X") + flag_str = "".join(flags) + + pool_diag.append({ + "pid": pid, "r2": r2s[i], "cad_init": ci, "cad_learned": cl, + "arb_pct": arb_pct, "noise_pct": noise_pct, "flags": flag_str, + }) + + print(f" {pid[:16]:16s} {r2s[i]:6.3f} {ci:5.1f} → {cl:5.1f}" + f" {arb_pct:6.0f}% {noise_pct:6.0f}% {flag_str:>5s}") + + vals = [x for x in r2s.values() if np.isfinite(x)] + med = np.median(vals) if vals else float("nan") + healthy = [d for d in pool_diag if d["arb_pct"] <= 150 and d["r2"] > 0] + med_h = np.median([d["r2"] for d in healthy]) if healthy else float("nan") + n_path = sum(1 for d in pool_diag if d["arb_pct"] > 150) + n_bound = sum(1 for d in pool_diag + if d["cad_learned"] <= 1.01 or d["cad_learned"] >= 59.9) + + print(f"\n Median R²: {med:.4f} (healthy: {med_h:.4f})") + print(f" Healthy: {len(pool_diag) - n_path}/{len(pool_diag)}," + f" at bounds: {n_bound}") + + return med, r2s, pool_diag + + +def main(): + parser = argparse.ArgumentParser() + 
parser.add_argument("--epochs", type=int, default=2000) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--l2-alpha", type=float, default=1e-3) + parser.add_argument("--huber-delta", type=float, default=1.0) + parser.add_argument("--trend-windows", type=int, nargs="+", default=[7]) + parser.add_argument("--no-market", action="store_true", + help="x_obs only, no market features") + parser.add_argument("--no-cross-pool", action="store_true", + help="Reduced x_obs (4) instead of cross-pool (7)") + parser.add_argument("--pool-intercepts", action="store_true", + help="Per-pool intercept (shared slopes + per-pool bias)") + parser.add_argument("--per-pool", action="store_true", + help="Per-pool noise coefficients (Option A)") + parser.add_argument("--no-split", action="store_true", + help="Train on all data (no temporal holdout)") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Linear Noise Model + Learnable Cadence") + mode = "per-pool" if args.per_pool else ( + "shared+intercepts" if args.pool_intercepts else "shared") + print(f" mode={mode}, market={not args.no_market}," + f" cross_pool={not args.no_cross_pool}") + print(f" trend_windows={args.trend_windows}") + print(f" epochs={args.epochs}, lr={args.lr}, l2={args.l2_alpha}") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + print("\nBuilding data...") + t0 = time.time() + data = build_data( + matched_clean, option_c_clean, + trend_windows=tuple(args.trend_windows), + include_market=not args.no_market, + include_cross_pool=not args.no_cross_pool, + ) + print(f" {len(data['pool_idx'])} samples, {data['n_pools']} pools," + f" {data['n_feat']} features, {time.time() - t0:.1f}s") + print(f" Features: {data['feat_names']}") + + # Split + day_idx = data["day_idx"] + n_samples = len(day_idx) + if args.no_split: + train_mask = np.ones(n_samples, dtype=bool) + eval_mask = None + else: + split_day = int(day_idx.max() 
* 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + + train_data = {k: v[train_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + if eval_mask is not None: + eval_data = {k: v[eval_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + else: + eval_data = None + + # Init params + n_feat = data["n_feat"] + n_pools = data["n_pools"] + x_trn = data["x"][train_mask] + y_trn = data["y_total"][train_mask] + pool_idx_trn = data["pool_idx"][train_mask] + + if args.per_pool: + # Per-pool coefficients: (n_pools, n_feat) + # Warm-start each pool via per-pool Ridge (not OLS — avoids blowup + # on pools with few samples or near-singular features) + from sklearn.linear_model import RidgeCV + coeffs_init = np.zeros((n_pools, n_feat), dtype=np.float32) + # Shared Ridge as fallback + ridge_shared = RidgeCV(alphas=np.logspace(-2, 4, 50)) + ridge_shared.fit(x_trn, y_trn) + for i in range(n_pools): + mask_i = pool_idx_trn == i + if mask_i.sum() >= 20: + ridge_i = RidgeCV(alphas=np.logspace(-2, 4, 50)) + ridge_i.fit(x_trn[mask_i], y_trn[mask_i]) + coeffs_init[i] = ridge_i.coef_ + coeffs_init[i, 0] += ridge_i.intercept_ # fold intercept into xobs_0 + else: + coeffs_init[i] = ridge_shared.coef_ + coeffs_init[i, 0] += ridge_shared.intercept_ + params = { + "log_cadence": jnp.array(data["init_log_cadences"]), + "noise_coeffs": jnp.array(coeffs_init), + } + print(f"\n Per-pool coefficients: {n_pools} × {n_feat} = {n_pools * n_feat} params") + print(f" Ridge warm-start |coeffs|={np.mean(np.abs(coeffs_init)):.3f}") + else: + params = { + "log_cadence": jnp.array(data["init_log_cadences"]), + "noise_coeffs": jnp.zeros(n_feat), + } + # Warm-start noise_coeffs via OLS on train + sol, _, _, _ = np.linalg.lstsq(x_trn, y_trn, rcond=None) + params["noise_coeffs"] = jnp.array(sol.astype(np.float32)) + + if args.pool_intercepts and not args.per_pool: + # Init per-pool intercepts 
from OLS residuals + ols_pred = x_trn @ sol + ols_resid = y_trn - ols_pred + pool_idx_trn = data["pool_idx"][train_mask] + intercepts = np.zeros(n_pools, dtype=np.float32) + for i in range(n_pools): + mask_i = pool_idx_trn == i + if mask_i.sum() > 0: + intercepts[i] = np.mean(ols_resid[mask_i]) + params["pool_intercepts"] = jnp.array(intercepts) + print(f" Per-pool intercepts: {n_pools} pools" + f" (range {intercepts.min():.2f} to {intercepts.max():.2f})") + + print(f"\n Init cadence: {np.exp(data['init_log_cadences']).min():.1f}" + f"-{np.median(np.exp(data['init_log_cadences'])):.1f}" + f"-{np.exp(data['init_log_cadences']).max():.1f} min") + total_params = sum(v.size for v in params.values()) + print(f" Total params: {total_params}") + + # Build loss and train + grad_fn = make_loss_fn(data["pool_coeffs"], data["pool_gas"], data["n_pools"]) + + print("\n Compiling...") + t0 = time.time() + params = train(params, train_data, grad_fn, args.epochs, args.lr, + args.l2_alpha, args.huber_delta) + print(f" Training: {time.time() - t0:.1f}s") + + # Print learned coefficients + nc = np.array(params["noise_coeffs"]) + if nc.ndim == 2: + # Per-pool: print median coefficient across pools + print(f"\n Per-pool noise coefficients — median across {n_pools} pools:") + for i, name in enumerate(data["feat_names"]): + vals = nc[:, i] + print(f" {name:30s} med={np.median(vals):+7.3f}" + f" [{vals.min():+7.3f}, {vals.max():+7.3f}]") + else: + print(f"\n Noise coefficients ({len(nc)}):") + for i, name in enumerate(data["feat_names"]): + print(f" {name:30s} {nc[i]:+8.4f}") + + # Evaluate + if eval_data is not None: + print("\n --- Train ---") + evaluate(params, train_data) + print("\n --- Eval ---") + evaluate(params, eval_data) + else: + print("\n --- All data ---") + evaluate(params, train_data) + + # Save artifact + artifact_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", + ) + os.makedirs(artifact_dir, exist_ok=True) + artifact 
= { + "noise_coeffs": np.array(params["noise_coeffs"]), + "log_cadence": np.array(params["log_cadence"]), + "init_log_cadences": data["init_log_cadences"], + "feat_names": data["feat_names"], + "pool_ids": data["pool_ids"], + "n_pools": data["n_pools"], + "n_feat": data["n_feat"], + "x_mean": data["x_mean"], + "x_std": data["x_std"], + "hparams": { + "epochs": args.epochs, "lr": args.lr, + "l2_alpha": args.l2_alpha, "huber_delta": args.huber_delta, + "trend_windows": args.trend_windows, + "per_pool": args.per_pool, + "pool_intercepts": args.pool_intercepts, + }, + } + if "pool_intercepts" in params: + artifact["pool_intercepts"] = np.array(params["pool_intercepts"]) + artifact_path = os.path.join(artifact_dir, "model.npz") + np.savez(artifact_path, **{k: v for k, v in artifact.items() + if isinstance(v, np.ndarray)}) + # Save non-array metadata separately + import json + meta_path = os.path.join(artifact_dir, "meta.json") + meta = {k: v for k, v in artifact.items() if not isinstance(v, np.ndarray)} + with open(meta_path, "w") as f: + json.dump(meta, f, indent=2, default=str) + print(f"\n Saved artifact: {artifact_path}") + print(f" Saved metadata: {meta_path}") + + # Baselines + print(f"\n Baselines (eval, total volume R²):") + print(f" V_arb only: median R² = -0.33") + print(f" Naive lag: median R² = 0.01") + print(f" DeepSets best: median R² = 0.43") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_mlp_noise.py b/experiments/run_mlp_noise.py new file mode 100644 index 0000000..35d5418 --- /dev/null +++ b/experiments/run_mlp_noise.py @@ -0,0 +1,623 @@ +"""MLP noise model with Binance market features and learnable cadence. + +No cross-pool DEX dependency — only uses this pool's TVL + public market +data (Binance prices/volumes for BTC and the pool's tokens). 
+ +Architecture: + log(V_noise) = MLP(x_market) + V_total = V_arb(cadence) + exp(log_v_noise) + +where x_market = [log_tvl, dow_sin, dow_cos, btc_features, tok_a_features, +tok_b_features, pair_vol, interactions]. + +Cadence is per-pool, learned jointly via Adam through PCHIP. + +Usage: + python experiments/run_mlp_noise.py + python experiments/run_mlp_noise.py --hidden 64 32 --epochs 3000 + python experiments/run_mlp_noise.py --per-pool --hidden 32 +""" + +import argparse +import os +import time + +import jax +import jax.numpy as jnp +import numpy as np + + +# ---- Model ---- + + +def init_mlp_params(key, n_input, hidden_sizes, n_pools, init_log_cadences, + per_pool=False): + """Initialize MLP parameters. + + MLP: input → hidden1 → ... → hiddenN → 1 (with ReLU activations). + If per_pool: separate output bias per pool. + """ + params = {} + keys = jax.random.split(key, len(hidden_sizes) + 2) + + # Hidden layers + in_dim = n_input + for i, h in enumerate(hidden_sizes): + params[f"W{i}"] = jax.random.normal(keys[i], (in_dim, h)) * np.sqrt(2.0 / in_dim) + params[f"b{i}"] = jnp.zeros(h) + in_dim = h + + # Output layer → scalar + params["W_out"] = jax.random.normal(keys[-2], (in_dim, 1)) * 0.01 + params["b_out"] = jnp.zeros(1) + + if per_pool: + params["pool_bias"] = jnp.zeros(n_pools) + + params["log_cadence"] = jnp.array(init_log_cadences) + return params + + +def forward_mlp(params, x, pool_idx=None): + """MLP forward pass. 
Returns (n_samples,) log_v_noise.""" + h = x + i = 0 + while f"W{i}" in params: + h = jnp.maximum(h @ params[f"W{i}"] + params[f"b{i}"], 0.0) + i += 1 + out = (h @ params["W_out"] + params["b_out"])[:, 0] + + if "pool_bias" in params and pool_idx is not None: + out = out + params["pool_bias"][pool_idx] + + return out + + +def make_loss_fn(pool_coeffs, pool_gas, n_pools): + """Loss with learnable cadence + MLP noise model.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + def loss_fn(params, x, y_total, sample_grid_days, pool_idx, + l2_alpha, huber_delta): + log_v_noise = forward_mlp(params, x, pool_idx) + + log_cadence = params["log_cadence"] + n_samples = y_total.shape[0] + v_arb = jnp.zeros(n_samples) + for i in range(n_pools): + v_arb_all = interpolate_pool_daily( + pool_coeffs[i], log_cadence[i], pool_gas[i]) + safe_days = jnp.clip(sample_grid_days, 0, v_arb_all.shape[0] - 1) + v_arb = jnp.where(pool_idx == i, v_arb_all[safe_days], v_arb) + + log_v_arb = jnp.log(jnp.maximum(v_arb, 1e-10)) + log_v_total = jnp.logaddexp(log_v_arb, log_v_noise) + + residuals = log_v_total - y_total + abs_r = jnp.abs(residuals) + huber_vals = jnp.where(abs_r <= huber_delta, 0.5 * residuals ** 2, + huber_delta * (abs_r - 0.5 * huber_delta)) + + pool_counts = jnp.zeros(n_pools).at[pool_idx].add( + jnp.ones_like(pool_idx, dtype=jnp.float32)) + active = (pool_counts > 0).astype(jnp.float32) + n_active = jnp.maximum(jnp.sum(active), 1.0) + pool_counts = jnp.maximum(pool_counts, 1.0) + pool_sums = jnp.zeros(n_pools).at[pool_idx].add(huber_vals) + data_loss = jnp.sum((pool_sums / pool_counts) * active) / n_active + + reg = l2_alpha * sum(jnp.sum(v ** 2) for k, v in params.items() + if k.startswith("W")) + return data_loss + reg + + return jax.jit(jax.value_and_grad(loss_fn)) + + +def train(params, data, grad_fn, n_epochs, lr, l2_alpha, huber_delta, + verbose=True, use_cosine=False, warmup_steps=100): + x = jnp.array(data["x"]) + y = 
jnp.array(data["y_total"]) + sgd = jnp.array(data["sample_grid_days"]) + pidx = jnp.array(data["pool_idx"]) + + if use_cosine: + import optax + schedule = optax.warmup_cosine_decay_schedule( + init_value=lr * 0.01, + peak_value=lr, + warmup_steps=warmup_steps, + decay_steps=n_epochs, + end_value=lr * 0.01, + ) + optimizer = optax.adam(learning_rate=schedule) + opt_state = optimizer.init(params) + + for epoch in range(n_epochs): + loss_val, grads = grad_fn( + params, x, y, sgd, pidx, l2_alpha, huber_delta) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + + if verbose and (epoch % 500 == 0 or epoch == n_epochs - 1): + cads = np.exp(np.array(params["log_cadence"])) + cur_lr = float(schedule(epoch)) + print(f" epoch {epoch:5d} loss={float(loss_val):.6f}" + f" lr={cur_lr:.2e}" + f" cad=[{cads.min():.1f}-{np.median(cads):.1f}-{cads.max():.1f}]") + else: + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = {k: jnp.zeros_like(v) for k, v in params.items()} + + for epoch in range(n_epochs): + loss_val, grads = grad_fn( + params, x, y, sgd, pidx, l2_alpha, huber_delta) + + for k in params: + m[k] = 0.9 * m[k] + 0.1 * grads[k] + v[k] = 0.999 * v[k] + 0.001 * grads[k] ** 2 + m_hat = m[k] / (1.0 - 0.9 ** (epoch + 1)) + v_hat = v[k] / (1.0 - 0.999 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / (jnp.sqrt(v_hat) + 1e-8) + + if verbose and (epoch % 500 == 0 or epoch == n_epochs - 1): + cads = np.exp(np.array(params["log_cadence"])) + print(f" epoch {epoch:5d} loss={float(loss_val):.6f}" + f" cad=[{cads.min():.1f}-{np.median(cads):.1f}-{cads.max():.1f}]") + + return params + + +def evaluate(params, data, label=""): + """Evaluate decomposition.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + x = np.array(data["x"]) + y_total = np.array(data["y_total"]) + pool_idx = np.array(data["pool_idx"]) + sgd = np.array(data["sample_grid_days"]) + log_cadence = 
np.array(params["log_cadence"]) + init_cads = data["init_log_cadences"] + pool_ids = data["pool_ids"] + n_pools = data["n_pools"] + + log_v_noise = np.array(forward_mlp( + params, jnp.array(x), + jnp.array(pool_idx) if "pool_bias" in params else None)) + v_noise = np.exp(log_v_noise) + + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + # Clip day indices the same way loss_fn does, so eval matches training + safe_days = np.clip(sgd[mask], 0, len(v_arb_all) - 1) + v_arb[mask] = v_arb_all[safe_days] + + v_obs = np.exp(y_total) + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + if label: + print(f"\n {label}:") + print(f" {'Pool'[:16]:16s} {'R²':>6s} {'Cad':>5s} → {'learn':>5s}" + f" {'Arb%':>6s} {'Noise%':>7s} {'Flag':>5s}") + print(f" {'-'*55}") + + r2s = {} + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2s[i] = 1 - ss_res / max(ss_tot, 1e-10) + + pid = pool_ids[i] + ci = np.exp(init_cads[i]) + cl = np.exp(log_cadence[i]) + arb_pct = np.median(v_arb[mask] / v_obs[mask]) * 100 + noise_pct = np.median(v_noise[mask] / v_obs[mask]) * 100 + flags = [] + if arb_pct > 150: flags.append("A") + if cl <= 1.01 or cl >= 59.9: flags.append("B") + if r2s[i] < 0: flags.append("X") + print(f" {pid[:16]:16s} {r2s[i]:6.3f} {ci:5.1f} → {cl:5.1f}" + f" {arb_pct:6.0f}% {noise_pct:6.0f}% {''.join(flags):>5s}") + + vals = [x for x in r2s.values() if np.isfinite(x)] + med = np.median(vals) if vals else float("nan") + healthy = [r for r in r2s.values() if r > 0 and np.isfinite(r)] + med_h = np.median(healthy) if healthy else float("nan") + print(f"\n Median R²: {med:.4f} (healthy: {med_h:.4f})") + return med, r2s + + +def run_optuna(data, n_trials): + """Optuna sweep over MLP architecture and training hyperparameters.""" + 
import optuna + + day_idx = data["day_idx"] + n_samples = len(day_idx) + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = {k: v[train_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + eval_data = {k: v[eval_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + + n_pools = data["n_pools"] + n_feat = data["n_feat"] + + def objective(trial): + # Architecture + n_layers = trial.suggest_int("n_layers", 1, 7) + first_hidden = trial.suggest_categorical("first_hidden", [8, 16, 32, 64, 128, 256]) + # Bottleneck: each layer is bottleneck_ratio times the previous width (min 2) + bottleneck_ratio = trial.suggest_categorical("bottleneck_ratio", [0.5, 0.75, 1.0]) + hidden = [] + h = first_hidden + for _ in range(n_layers): + hidden.append(h) + h = max(int(h * bottleneck_ratio), 2) + + # Training + lr = trial.suggest_float("lr", 1e-4, 5e-2, log=True) + l2_alpha = trial.suggest_float("l2_alpha", 1e-5, 5e-1, log=True) + huber_delta = trial.suggest_categorical("huber_delta", [0.5, 1.0, 1.5, 2.0]) + n_epochs = trial.suggest_categorical("n_epochs", [2000, 5000, 10000, 20000]) + use_cosine = trial.suggest_categorical("use_cosine", [True, False]) + per_pool = trial.suggest_categorical("per_pool", [True, False]) + + params = init_mlp_params( + jax.random.PRNGKey(42), n_feat, hidden, n_pools, + data["init_log_cadences"], per_pool=per_pool) + + # OLS warm-start + x_trn = jnp.array(train_data["x"]) + y_trn = np.array(train_data["y_total"]) + h_act = np.array(x_trn) + i = 0 + while f"W{i}" in params: + h_act = np.maximum( + h_act @ np.array(params[f"W{i}"]) + np.array(params[f"b{i}"]), 0.0) + i += 1 + h_bias = np.concatenate([h_act, np.ones((h_act.shape[0], 1))], axis=1) + sol, _, _, _ = np.linalg.lstsq(h_bias, y_trn[:, None], rcond=None) + params["W_out"] = jnp.array(sol[:-1].astype(np.float32)) + params["b_out"] = 
jnp.array(sol[-1:].astype(np.float32)) + + grad_fn = make_loss_fn(data["pool_coeffs"], data["pool_gas"], n_pools) + params = train(params, train_data, grad_fn, n_epochs, lr, + l2_alpha, huber_delta, verbose=False, + use_cosine=use_cosine) + + # Eval + x_eval = np.array(eval_data["x"]) + y_eval = np.array(eval_data["y_total"]) + pool_idx_eval = np.array(eval_data["pool_idx"]) + sgd_eval = np.array(eval_data["sample_grid_days"]) + log_cadence = np.array(params["log_cadence"]) + + log_v_noise = np.array(forward_mlp( + params, jnp.array(x_eval), + jnp.array(pool_idx_eval) if per_pool else None)) + + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + v_arb = np.zeros(len(y_eval)) + for i in range(n_pools): + mask = pool_idx_eval == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + v_arb[mask] = v_arb_all[sgd_eval[mask]] + + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total = np.logaddexp(log_v_arb, log_v_noise) + + r2s = [] + for i in range(n_pools): + mask = pool_idx_eval == i + if mask.sum() < 2: + continue + yt = y_eval[mask] + pt = pred_total[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2s.append(1 - ss_res / max(ss_tot, 1e-10)) + + med_r2 = float(np.median(r2s)) if r2s else -10.0 + arch_str = "×".join(str(h) for h in hidden) + print(f" Trial {trial.number}: eval={med_r2:.4f}" + f" arch=[{arch_str}]" + f" {'cosine' if use_cosine else 'const'}" + f" {'per_pool' if per_pool else 'shared'}" + f" lr={lr:.1e} l2={l2_alpha:.1e}" + f" hub={huber_delta} ep={n_epochs}") + + # Save every trial's model + trial_dir = os.path.join("results", "mlp_noise", "trials", f"trial_{trial.number:04d}") + os.makedirs(trial_dir, exist_ok=True) + save_dict = {k: np.array(v) for k, v in params.items()} + save_dict["x_mean"] = data.get("x_mean", np.zeros(n_feat)) + save_dict["x_std"] = data.get("x_std", 
np.ones(n_feat)) + np.savez(os.path.join(trial_dir, "model.npz"), **save_dict) + import json as _json + _meta = { + "pool_ids": data["pool_ids"], + "feat_names": data["feat_names"], + "n_feat": n_feat, + "hidden": hidden, + "per_pool": per_pool, + "eval_r2": med_r2, + "hparams": { + "hidden": hidden, "lr": lr, "l2_alpha": l2_alpha, + "huber_delta": huber_delta, "n_epochs": n_epochs, + "use_cosine": use_cosine, "per_pool": per_pool, + "bottleneck_ratio": bottleneck_ratio, + "first_hidden": first_hidden, "n_layers": n_layers, + }, + } + with open(os.path.join(trial_dir, "meta.json"), "w") as _f: + _json.dump(_meta, _f, indent=2) + + return med_r2 + + study = optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + print(f"\n{'='*70}") + print(f"Optuna Results (MLP noise)") + print(f"{'='*70}") + print(f" Best eval R²: {study.best_value:.4f}") + print(f" Best params:") + for k, v in sorted(study.best_params.items()): + print(f" {k}: {v}") + + # Sort by value; compare against None so a legitimate 0.0 is not treated as missing + trials = sorted(study.trials, + key=lambda t: t.value if t.value is not None else -999, + reverse=True) + print(f"\n Top 10:") + for t in trials[:10]: + if t.value is not None: + n_l = t.params["n_layers"] + fh = t.params["first_hidden"] + h = fh + arch = [] + for _ in range(n_l): + arch.append(h) + h = max(int(h * t.params.get("bottleneck_ratio", 0.5)), 2) + print(f" #{t.number}: eval={t.value:.4f}" + f" arch={arch}" + f" ep={t.params['n_epochs']}" + f" {'cos' if t.params['use_cosine'] else 'cst'}" + f" {'pp' if t.params['per_pool'] else 'sh'}") + + # Copy best trial to top-level artifact + best_trial = study.best_trial + best_trial_dir = os.path.join("results", "mlp_noise", "trials", + f"trial_{best_trial.number:04d}") + save_dir = "results/mlp_noise" + if os.path.exists(os.path.join(best_trial_dir, "model.npz")): + import shutil + shutil.copy2(os.path.join(best_trial_dir, "model.npz"), + os.path.join(save_dir, "model.npz")) + shutil.copy2(os.path.join(best_trial_dir, "meta.json"), + os.path.join(save_dir, 
"meta.json")) + print(f"\n Copied best trial ({best_trial.number}) to: {save_dir}") + + # TVL response check on best model + import json as _json + art = dict(np.load(os.path.join(save_dir, "model.npz"), allow_pickle=True)) + with open(os.path.join(save_dir, "meta.json")) as _f: + meta = _json.load(_f) + best_params = {k: jnp.array(art[k]) for k in art + if k.startswith("W") or k.startswith("b") + or k == "log_cadence" or k == "pool_bias"} + per_pool = meta.get("per_pool", False) + pool_idx_probe = jnp.array([0]) if per_pool else None + + print(f"\n TVL Response Check (best model, trial {best_trial.number}):") + x_probe = np.zeros((1, n_feat), dtype=np.float32) + x_probe[0, 0] = 1.0 + x_probe[0, 4] = 10.5 # typical btc_log_price + prev_noise = None + for tvl in [1e4, 1e5, 5e5, 1e6, 5e6, 1e7, 5e7, 1e8, 5e8]: + x_probe[0, 1] = np.log(tvl) + out = np.array(forward_mlp(best_params, jnp.array(x_probe), + pool_idx_probe)) + noise = np.exp(out[0]) + if prev_noise is not None and prev_noise > 0: + ratio = noise / prev_noise + print(f" TVL=${tvl:>11,.0f} noise=${noise:>12,.0f}/day" + f" ({ratio:.2f}x prev)") + else: + print(f" TVL=${tvl:>11,.0f} noise=${noise:>12,.0f}/day") + prev_noise = noise + + return study + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--hidden", type=int, nargs="+", default=[32], + help="Hidden layer sizes (e.g. 
--hidden 64 32)") + parser.add_argument("--epochs", type=int, default=2000) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--l2-alpha", type=float, default=1e-3) + parser.add_argument("--huber-delta", type=float, default=1.0) + parser.add_argument("--cosine", action="store_true", + help="Use optax Adam with cosine LR decay") + parser.add_argument("--tune", type=int, default=0, + help="Optuna sweep (0 = single run)") + parser.add_argument("--trend-windows", type=int, nargs="+", default=[7]) + parser.add_argument("--per-pool", action="store_true", + help="Per-pool output bias") + parser.add_argument("--pool-attrs", action="store_true", + help="Append static pool attributes to input") + parser.add_argument("--no-split", action="store_true", + help="Train on all data") + parser.add_argument("--save-artifact", default=None, + help="Save model to this directory") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("MLP Noise Model (Binance features only, no cross-pool DEX)") + print(f" hidden={args.hidden}, per_pool={args.per_pool}") + print(f" epochs={args.epochs}, lr={args.lr}, l2={args.l2_alpha}") + print("=" * 70) + + # Build data WITHOUT cross-pool features + from experiments.run_linear_market_noise import load_stage1, build_data + + matched_clean, option_c_clean = load_stage1() + + print("\nBuilding data...") + t0 = time.time() + data = build_data( + matched_clean, option_c_clean, + trend_windows=tuple(args.trend_windows), + include_market=True, + include_cross_pool=False, # No DEX peer features + ) + n_pools = data["n_pools"] + n_feat = data["n_feat"] + print(f" {len(data['pool_idx'])} samples, {n_pools} pools," + f" {n_feat} features, {time.time() - t0:.1f}s") + + # Append pool attributes if requested + if args.pool_attrs: + from quantammsim.calibration.pool_data import build_pool_attributes + X_attr, attr_names, _ = build_pool_attributes(matched_clean) + # Standardize + 
attr_mean = X_attr.mean(axis=0) + attr_std = X_attr.std(axis=0) + attr_std[attr_std < 1e-6] = 1.0 + X_attr_norm = ((X_attr - attr_mean) / attr_std).astype(np.float32) + + # Broadcast to per-sample: each sample gets its pool's attributes + pool_idx = data["pool_idx"] + x_attr_samples = X_attr_norm[pool_idx] + data["x"] = np.concatenate([data["x"], x_attr_samples], axis=1) + data["n_feat"] = data["x"].shape[1] + data["feat_names"] = data["feat_names"] + attr_names + n_feat = data["n_feat"] + print(f" + {len(attr_names)} pool attributes → {n_feat} total features") + + print(f" Features: {data['feat_names']}") + + if args.tune > 0: + run_optuna(data, args.tune) + return + + # Split + if args.no_split: + train_data = data + eval_data = None + else: + day_idx = data["day_idx"] + n_samples = len(day_idx) + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = {k: v[train_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + eval_data = {k: v[eval_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + + # Init + params = init_mlp_params( + jax.random.PRNGKey(42), n_feat, args.hidden, n_pools, + data["init_log_cadences"], per_pool=args.per_pool) + + n_params = sum(v.size for v in params.values()) + print(f"\n Total params: {n_params}" + f" (MLP: {n_params - n_pools - (n_pools if args.per_pool else 0)}," + f" cadence: {n_pools}" + f"{',' + str(n_pools) + ' pool biases' if args.per_pool else ''})") + + # Warm-start output layer via OLS through hidden activations + x_trn = jnp.array(train_data["x"]) + y_trn = np.array(train_data["y_total"]) + h = np.array(x_trn) + i = 0 + while f"W{i}" in params: + h = np.maximum(h @ np.array(params[f"W{i}"]) + np.array(params[f"b{i}"]), 0.0) + i += 1 + h_bias = np.concatenate([h, np.ones((h.shape[0], 1))], axis=1) + sol, _, _, _ = np.linalg.lstsq(h_bias, y_trn[:, None], rcond=None) 
+ params["W_out"] = jnp.array(sol[:-1].astype(np.float32)) + params["b_out"] = jnp.array(sol[-1:].astype(np.float32)) + print(f" OLS warm-start on hidden activations") + + # Train + grad_fn = make_loss_fn(data["pool_coeffs"], data["pool_gas"], n_pools) + + print(f"\n Compiling + training...") + t0 = time.time() + params = train(params, train_data, grad_fn, args.epochs, args.lr, + args.l2_alpha, args.huber_delta, + use_cosine=args.cosine) + print(f" Training: {time.time() - t0:.1f}s") + + # Evaluate + if eval_data is not None: + print("\n --- Train ---") + evaluate(params, train_data) + print("\n --- Eval ---") + evaluate(params, eval_data) + else: + print("\n --- All data ---") + evaluate(params, train_data) + + print(f"\n Baselines:") + print(f" Linear (no cross-pool): median R² ≈ 0.48") + print(f" Linear (with cross-pool): median R² ≈ 0.53") + print(f" Per-pool linear: median R² ≈ 0.61") + + # Save artifact + if args.save_artifact: + import json + os.makedirs(args.save_artifact, exist_ok=True) + # Save params as npz + save_dict = {k: np.array(v) for k, v in params.items()} + save_dict["x_mean"] = data["x_mean"] if "x_mean" in data else np.zeros(n_feat) + save_dict["x_std"] = data["x_std"] if "x_std" in data else np.ones(n_feat) + np.savez(os.path.join(args.save_artifact, "model.npz"), **save_dict) + # Save meta + meta = { + "pool_ids": data["pool_ids"], + "feat_names": data["feat_names"], + "n_feat": n_feat, + "hidden": args.hidden, + "per_pool": args.per_pool, + "hparams": { + "hidden": args.hidden, + "lr": args.lr, + "l2_alpha": args.l2_alpha, + "huber_delta": args.huber_delta, + "epochs": args.epochs, + "trend_windows": args.trend_windows, + "use_cosine": args.cosine, + "per_pool": args.per_pool, + }, + } + with open(os.path.join(args.save_artifact, "meta.json"), "w") as f: + json.dump(meta, f, indent=2) + print(f"\n Saved artifact to: {args.save_artifact}") + + +if __name__ == "__main__": + main() diff --git a/experiments/run_mm_noise.py 
b/experiments/run_mm_noise.py new file mode 100644 index 0000000..45323a6 --- /dev/null +++ b/experiments/run_mm_noise.py @@ -0,0 +1,865 @@ +"""Michaelis-Menten noise model with market features. + +Replaces the linear TVL term with a Michaelis-Menten saturation curve +while keeping all market features for temporal fit: + + log(V_noise) = log_alpha_i + x_market @ gamma + + log(TVL) - log(K_i + TVL) + + V_total = V_arb(cadence_i) + exp(log_V_noise) + Loss = Huber(log(V_total) - log(V_obs)) + +The TVL feature (xobs_1) is removed from x_market and handled +structurally via the MM saturation term. All other features (dow, +BTC, token, pair vol, interactions) remain as shared linear covariates. + +Parameters: + log_alpha_i : per-pool intercept + log_K_i : per-pool half-saturation TVL + gamma : shared coefficients on non-TVL features + log_cadence_i: per-pool arb frequency (via PCHIP) + +Usage: + python experiments/run_mm_noise.py + python experiments/run_mm_noise.py --epochs 5000 --lr 3e-4 + python experiments/run_mm_noise.py --per-pool-gamma # per-pool market coeffs +""" + +import argparse +import json +import os +import pickle +import sys +import time + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +COMPETITOR_TVL_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "competitor_tvl", "competitor_tvl.npz", +) + + +def build_mm_data(matched_clean, option_c_clean, trend_windows=(7,), + include_cross_pool=False, competitor_tvl_path=None): + """Build data with MM structure: separate TVL from market features. + + Loads observed competitor TVL from DeFi Llama for K. 
+ """ + from experiments.run_linear_market_noise import build_data + from quantammsim.calibration.pool_data import _parse_tokens + + # Get full feature matrix from linear model's pipeline + data = build_data( + matched_clean, option_c_clean, + trend_windows=trend_windows, + include_market=True, + include_cross_pool=include_cross_pool, + ) + + # Remove TVL column and TVL interaction terms — TVL handled by MM + feat_names = data["feat_names"] + x_full = data["x"] + tvl_col = feat_names.index("xobs_1") + tvl_interaction_cols = [i for i, name in enumerate(feat_names) + if name.startswith("xobs_1\u00d7")] + remove_cols = {tvl_col} | set(tvl_interaction_cols) + keep_cols = [i for i in range(len(feat_names)) if i not in remove_cols] + x_market = x_full[:, keep_cols].astype(np.float32) + market_names = [feat_names[i] for i in keep_cols] + + pool_ids = data["pool_ids"] + n_pools = data["n_pools"] + + # Common date grid + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + date_to_idx = {d: i for i, d in enumerate(date_list)} + n_dates = len(date_list) + + # Rebuild raw log_tvl from panel + tvl_grid = np.full((n_dates, n_pools), np.nan) + for j, pid in enumerate(pool_ids): + panel = matched_clean[pid]["panel"] + dates = panel["date"].values + log_tvls = panel["log_tvl_lag1"].values.astype(float) + for k, date in enumerate(dates): + tvl_grid[date_to_idx[date], j] = log_tvls[k] + + pool_idx = data["pool_idx"] + day_idx = data["day_idx"] + n_samples = len(pool_idx) + log_tvl = np.array([tvl_grid[day_idx[s], pool_idx[s]] + for s in range(n_samples)], dtype=np.float32) + + # Load observed competitor TVL (K) + comp_path = competitor_tvl_path or COMPETITOR_TVL_PATH + if os.path.exists(comp_path): + comp_data = np.load(comp_path, allow_pickle=True) + comp_pool_ids = list(comp_data["pool_ids"]) + comp_dates = list(comp_data["date_list"]) + # Use K_eff (network conductance) if available, else direct 
competitor TVL + if "k_eff" in comp_data: + comp_tvl_matrix = comp_data["k_eff"] + print(f" Using network K_eff (direct + multi-hop)") + else: + comp_tvl_matrix = comp_data["competitor_tvl"] + print(f" Using direct competitor TVL only") + + # Build date index for competitor data (normalize to YYYY-MM-DD) + comp_date_to_idx = {} + for ci, d in enumerate(comp_dates): + comp_date_to_idx[str(d)[:10]] = ci + + # Map competitor TVL to our (n_dates, n_pools) grid + comp_tvl_grid = np.full((n_dates, n_pools), np.nan) + for j, pid in enumerate(pool_ids): + if pid not in comp_pool_ids: + continue + cj = comp_pool_ids.index(pid) + for t, date in enumerate(date_list): + date_str = str(pd.Timestamp(date))[:10] + if date_str in comp_date_to_idx: + ci = comp_date_to_idx[date_str] + val = comp_tvl_matrix[ci, cj] + if np.isfinite(val) and val > 0: + comp_tvl_grid[t, j] = val + + # Forward-fill / back-fill gaps per pool + for j in range(n_pools): + col = comp_tvl_grid[:, j] + mask = np.isfinite(col) + if mask.any() and not mask.all(): + s = pd.Series(col, index=date_list).ffill().bfill() + comp_tvl_grid[:, j] = s.values + + # Flag pools with no competitor data + has_comp = np.zeros(n_pools, dtype=bool) + for j in range(n_pools): + has_comp[j] = np.isfinite(comp_tvl_grid[:, j]).any() + + n_with = has_comp.sum() + print(f" Competitor TVL: {n_with}/{n_pools} pools with data") + + # Per-sample log(competitor_tvl), floor at $1 + raw_comp = np.array([ + comp_tvl_grid[day_idx[s], pool_idx[s]] + for s in range(n_samples)], dtype=np.float64) + + # For pools without data, impute with median of pools that have data + valid_comp = raw_comp[np.isfinite(raw_comp) & (raw_comp > 0)] + fallback_val = float(np.median(valid_comp)) if len(valid_comp) > 0 else 1e6 + raw_comp = np.where(np.isfinite(raw_comp) & (raw_comp > 0), + raw_comp, fallback_val) + log_comp_tvl = np.log(np.maximum(raw_comp, 1.0)).astype(np.float32) + print(f" Fallback comp TVL for missing pools: ${fallback_val:,.0f}") + for j in 
range(n_pools): + if not has_comp[j]: + print(f" No competitor data: {pool_ids[j][:16]}" + f" ({matched_clean[pool_ids[j]].get('tokens', '?')})") + else: + print(f" WARNING: no competitor TVL file at {comp_path}") + log_comp_tvl = np.full(n_samples, np.log(1e6), dtype=np.float32) + has_comp = np.zeros(n_pools, dtype=bool) + + # Token info + pool_tokens = [] + for pid in pool_ids: + toks = _parse_tokens(matched_clean[pid]["tokens"]) + tok_a = toks[0] + tok_b = toks[1] if len(toks) > 1 else toks[0] + pool_tokens.append((tok_a, tok_b)) + + removed_names = [feat_names[i] for i in sorted(remove_cols)] + print(f" Removed: {removed_names}") + print(f" Market features ({len(market_names)}): {market_names}") + + return { + "x_market": x_market, + "log_tvl": log_tvl, + "log_comp_tvl": log_comp_tvl, + "has_comp": has_comp, + "y_total": data["y_total"], + "pool_idx": pool_idx, + "day_idx": day_idx, + "sample_grid_days": data["sample_grid_days"], + "pool_coeffs": data["pool_coeffs"], + "pool_gas": data["pool_gas"], + "init_log_cadences": data["init_log_cadences"], + "n_pools": n_pools, + "n_market_feat": x_market.shape[1], + "pool_ids": pool_ids, + "pool_tokens": pool_tokens, + "market_names": market_names, + "x_mean": data["x_mean"], + "x_std": data["x_std"], + } + + +# ---- Model ---- + +def forward_mm(params, x_market, log_tvl, pool_idx, log_comp_tvl=None): + """MM forward pass → log(V_noise) per sample. 
+ + K modes (checked in order): + - Observed, scaled: log_comp_tvl provided and params has "k_scale" (2,) + K = exp(k_scale[0] + k_scale[1] * log_comp_tvl) + - Observed, direct: log_comp_tvl provided, neither "k_scale" nor "log_K" in params + K = exp(log_comp_tvl) + - Per-pool: params contains "log_K" (n_pools,) + - Shared k_params: params contains "k_params" (3,) [legacy]; only + k_params[0] is used, and only when log_comp_tvl is None + - Fallback: K = exp(14.5) + """ + log_alpha = params["log_alpha"] + gamma = params["gamma"] + + alpha_i = log_alpha[pool_idx] + tvl = jnp.exp(log_tvl) + + # K + if log_comp_tvl is not None and "k_scale" in params: + # Observed competitor TVL with learned scale/offset + k_s = params["k_scale"] + log_K = k_s[0] + k_s[1] * log_comp_tvl + K = jnp.exp(log_K) + elif log_comp_tvl is not None and "k_scale" not in params and "log_K" not in params: + # Observed competitor TVL, used directly as K + K = jnp.exp(log_comp_tvl) + elif "log_K" in params: + K = jnp.exp(params["log_K"][pool_idx]) + elif "k_params" in params: + # Legacy Binance-volume mode (kept for loading old models) + K = jnp.exp(params["k_params"][0]) + else: + K = jnp.exp(jnp.array(14.5)) # fallback + + # Market features: shared or per-pool gamma + if gamma.ndim == 2: + per_sample_gamma = gamma[pool_idx] + market_term = jnp.sum(x_market * per_sample_gamma, axis=1) + else: + market_term = x_market @ gamma + + log_saturation = log_tvl - jnp.log(K + tvl) + return alpha_i + market_term + log_saturation + + +def make_loss_fn(pool_coeffs, pool_gas, n_pools): + """Loss with PCHIP arb + MM noise.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + def loss_fn(params, x_market, log_tvl, log_comp_tvl, y_total, + sample_grid_days, pool_idx, l2_alpha, huber_delta): + log_cadence = params["log_cadence"] + + # V_arb from PCHIP + n_samples = x_market.shape[0] + log_v_arb = jnp.zeros(n_samples) + for i in range(n_pools): + v_arb_all = interpolate_pool_daily( + pool_coeffs[i], jnp.float64(log_cadence[i]), pool_gas[i]) + safe_days = jnp.clip(sample_grid_days, 0, v_arb_all.shape[0] - 1) + log_v_arb = jnp.where( + pool_idx == i, + 
jnp.log(jnp.maximum(v_arb_all[safe_days], 1e-10)), + log_v_arb) + + # V_noise from MM + log_v_noise = forward_mm( + params, x_market, log_tvl, pool_idx, + log_comp_tvl=log_comp_tvl) + + # V_total + log_v_total = jnp.logaddexp(log_v_arb, log_v_noise) + + # Huber + residual = log_v_total - y_total + abs_r = jnp.abs(residual) + huber = jnp.where( + abs_r <= huber_delta, + 0.5 * residual ** 2, + huber_delta * (abs_r - 0.5 * huber_delta)) + + # Per-pool equal weighting + pool_counts = jnp.zeros(n_pools).at[pool_idx].add( + jnp.ones_like(pool_idx, dtype=jnp.float32)) + active = (pool_counts > 0).astype(jnp.float32) + n_active = jnp.maximum(jnp.sum(active), 1.0) + pool_counts = jnp.maximum(pool_counts, 1.0) + pool_sums = jnp.zeros(n_pools).at[pool_idx].add(huber) + mean_loss = jnp.sum((pool_sums / pool_counts) * active) / n_active + + # L2 on gamma and log_alpha + reg = l2_alpha * ( + jnp.mean(params["gamma"] ** 2) + + jnp.mean(params["log_alpha"] ** 2) + ) + + return mean_loss + reg + + return jax.jit(jax.value_and_grad(loss_fn)) + + +# ---- Training ---- + +def train(params, data, grad_fn, n_epochs, lr, l2_alpha, huber_delta, + verbose=True): + """Adam training loop.""" + m = {k: jnp.zeros_like(v) for k, v in params.items()} + v = {k: jnp.zeros_like(v) for k, v in params.items()} + b1, b2, eps = 0.9, 0.999, 1e-8 + + x_market = jnp.array(data["x_market"]) + log_tvl = jnp.array(data["log_tvl"]) + log_comp_tvl = jnp.array(data["log_comp_tvl"]) + y_total = jnp.array(data["y_total"]) + sgd = jnp.array(data["sample_grid_days"]) + pidx = jnp.array(data["pool_idx"]) + + for epoch in range(n_epochs): + loss, grads = grad_fn( + params, x_market, log_tvl, log_comp_tvl, + y_total, sgd, pidx, l2_alpha, huber_delta) + + for k in params: + g = grads[k] + m[k] = b1 * m[k] + (1 - b1) * g + v[k] = b2 * v[k] + (1 - b2) * g ** 2 + m_hat = m[k] / (1 - b1 ** (epoch + 1)) + v_hat = v[k] / (1 - b2 ** (epoch + 1)) + params[k] = params[k] - lr * m_hat / (jnp.sqrt(v_hat) + eps) + + if verbose and 
(epoch % 200 == 0 or epoch == n_epochs - 1): + ev = evaluate(params, data) + if "k_scale" in params: + ks = np.array(params["k_scale"]) + k_str = f" k_s=[{ks[0]:.2f},{ks[1]:.3f}]" + elif "k_params" in params: + k_p = np.array(params["k_params"]) + k_str = f" k=[{k_p[0]:.2f},{k_p[1]:.3f},{k_p[2]:.3f}]" + else: + k_str = "" + K_med = float(np.median(list(ev["K_values"].values()))) + cad = np.exp(np.array(params["log_cadence"])) + print(f" epoch {epoch:5d} loss={float(loss):.4f}" + f" R²={ev['median_r2']:.3f}" + f" K_med=${K_med/1e6:.1f}M{k_str}" + f" cad=[{cad.min():.0f},{np.median(cad):.0f},{cad.max():.0f}]") + + return params + + +# ---- Evaluation ---- + +def evaluate(params, data): + """Per-pool R² and diagnostics.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + n_pools = data["n_pools"] + pool_idx = np.array(data["pool_idx"]) + sgd = np.array(data["sample_grid_days"]) + y = np.array(data["y_total"]) + log_cadence = np.array(params["log_cadence"]) + + # V_arb + v_arb = np.zeros(len(y)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + safe = np.clip(sgd[mask], 0, len(v_arb_all) - 1) + v_arb[mask] = v_arb_all[safe] + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + + log_v_noise = np.array(forward_mm( + params, jnp.array(data["x_market"]), + jnp.array(data["log_tvl"]), + jnp.array(data["pool_idx"]), + log_comp_tvl=jnp.array(data["log_comp_tvl"]))) + + log_v_total = np.logaddexp(log_v_arb, log_v_noise) + v_noise = np.exp(log_v_noise) + v_total = np.exp(log_v_total) + + r2s = {} + noise_shares = {} + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y[mask] + pt = log_v_total[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2s[data["pool_ids"][i]] = 1 - ss_res / max(ss_tot, 1e-10) + 
noise_shares[data["pool_ids"][i]] = float(np.median( + v_noise[mask] / v_total[mask])) + + # Per-pool median K + K_values = {} + if "k_scale" in params: + ks = np.array(params["k_scale"]) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + K_values[data["pool_ids"][i]] = 0 + continue + lc = data["log_comp_tvl"][mask] + log_K_i = ks[0] + ks[1] * lc + K_values[data["pool_ids"][i]] = float(np.exp(np.median(log_K_i))) + elif "log_K" in params: + for i in range(n_pools): + K_values[data["pool_ids"][i]] = float(np.exp(params["log_K"][i])) + elif "k_params" in params: + k_p = np.array(params["k_params"]) + for i in range(n_pools): + K_values[data["pool_ids"][i]] = float(np.exp(k_p[0])) + else: + # Observed K: compute from log_comp_tvl directly + for i in range(n_pools): + mask = pool_idx == i + if mask.any(): + K_values[data["pool_ids"][i]] = float( + np.exp(np.median(data["log_comp_tvl"][mask]))) + else: + K_values[data["pool_ids"][i]] = 1e6 + + return { + "r2s": r2s, + "noise_shares": noise_shares, + "K_values": K_values, + "median_r2": float(np.median(list(r2s.values()))), + } + + +def tvl_response_check(params, data): + """Print predicted noise at various TVL levels.""" + n_pools = data["n_pools"] + pool_idx = np.array(data["pool_idx"]) + + # Median market features per pool + print(f"\n TVL Response Check (per-pool median market features):") + print(f" {'Pool':>20s} {'K ($M)':>10s} {'TVL=100K':>10s}" + f" {'TVL=1M':>10s} {'TVL=10M':>10s} {'TVL=100M':>10s}" + f" {'TVL=1B':>10s} {'ε@1M':>6s} {'ε@100M':>6s}") + + tvl_test = [1e5, 1e6, 1e7, 1e8, 1e9] + + for i in range(n_pools): + pid = data["pool_ids"][i] + toks = data["pool_tokens"][i] + label = f"{toks[0]}/{toks[1]}" + mask = pool_idx == i + if mask.sum() == 0: + continue + + # Per-pool K (median) + if "k_scale" in params: + ks = np.array(params["k_scale"]) + lc = data["log_comp_tvl"][mask] + K_i = float(np.exp(np.median(ks[0] + ks[1] * lc))) + elif "log_K" in params: + K_i = 
float(np.exp(params["log_K"][i])) + elif "k_params" in params: + K_i = float(np.exp(np.array(params["k_params"])[0])) + else: + # Observed K directly from competitor TVL + K_i = float(np.exp(np.median(data["log_comp_tvl"][mask]))) + x_med = np.median(data["x_market"][mask], axis=0) + + gamma = np.array(params["gamma"]) + if gamma.ndim == 2: + market_term = float(x_med @ gamma[i]) + else: + market_term = float(x_med @ gamma) + log_alpha_i = float(params["log_alpha"][i]) + + vols = [] + for tvl in tvl_test: + log_sat = np.log(tvl) - np.log(K_i + tvl) + log_v = log_alpha_i + market_term + log_sat + vols.append(np.exp(log_v)) + + # Elasticity at 1M and 100M + eps_1m = K_i / (K_i + 1e6) + eps_100m = K_i / (K_i + 1e8) + + print(f" {label:>20s} ${K_i/1e6:>9.1f}" + + "".join(f" ${v:>9,.0f}" for v in vols) + + f" {eps_1m:>6.3f} {eps_100m:>6.3f}") + + +# ---- Main ---- + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--epochs", type=int, default=3000) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--l2-alpha", type=float, default=1e-3) + parser.add_argument("--huber-delta", type=float, default=1.0) + parser.add_argument("--init-log-K", type=float, default=17.0, + help="Initial log(K) ~ log($24M)") + parser.add_argument("--shared-K", action="store_true", + help="Predict K from Binance volumes (3 shared params)") + parser.add_argument("--observed-K", action="store_true", + help="Use observed competitor TVL from DeFi Llama as K") + parser.add_argument("--per-pool-gamma", action="store_true", + help="Per-pool market feature coefficients") + parser.add_argument("--no-split", action="store_true") + parser.add_argument("--trend-windows", type=int, nargs="+", default=[7]) + parser.add_argument("--include-cross-pool", action="store_true") + parser.add_argument("--tune", type=int, default=0, + help="Optuna sweep (0 = single run)") + 
parser.add_argument("--save-artifact", default="results/mm_noise") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Michaelis-Menten Noise Model + Market Features") + print(f" epochs={args.epochs}, lr={args.lr}, l2={args.l2_alpha}") + print(f" init log(K)={args.init_log_K} (K=${np.exp(args.init_log_K):,.0f})") + print(f" per_pool_gamma={args.per_pool_gamma}") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + + print("\nBuilding data...") + t0 = time.time() + data = build_mm_data(matched_clean, option_c_clean, + trend_windows=tuple(args.trend_windows), + include_cross_pool=args.include_cross_pool) + n_pools = data["n_pools"] + n_market = data["n_market_feat"] + n_samples = len(data["pool_idx"]) + print(f" {n_samples} samples, {n_pools} pools," + f" {n_market} market features, {time.time() - t0:.1f}s") + + if args.tune > 0: + run_optuna(data, args.tune) + return + + # Pool summary + pool_idx = data["pool_idx"] + for i, (pid, toks) in enumerate( + zip(data["pool_ids"], data["pool_tokens"])): + mask = pool_idx == i + n = mask.sum() + if n > 0: + med_tvl = np.exp(np.median(data["log_tvl"][mask])) + print(f" {pid[:16]} {toks[0]:>8s}/{toks[1]:<8s}" + f" {n:>4d} days TVL=${med_tvl:>12,.0f}") + + # Split + if args.no_split: + train_data = data + eval_data = None + else: + day_idx = data["day_idx"] + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = {k: v[train_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + eval_data = {k: v[eval_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + print(f"\n Split: {train_mask.sum()} train, {eval_mask.sum()} eval") + + # Init + if args.per_pool_gamma: + gamma_init = jnp.zeros((n_pools, n_market)) + else: + gamma_init = jnp.zeros(n_market) + + params = { + "log_alpha": jnp.zeros(n_pools), + 
"gamma": gamma_init, + "log_cadence": jnp.array(data["init_log_cadences"]), + } + if args.observed_K: + # K = competitor_tvl directly. No learned params for K. + # log_comp_tvl is passed as data, not as a parameter. + pass + elif args.shared_K: + params["k_params"] = jnp.array([args.init_log_K, 0.0, 0.0]) + else: + params["log_K"] = jnp.full(n_pools, args.init_log_K) + n_params = sum(v.size for v in params.values()) + print(f"\n Parameters: {n_params}" + f" (α: {n_pools}, K: {sum(params[k].size for k in ('log_K', 'k_params') if k in params)}," + f" γ: {gamma_init.size}, cadence: {n_pools})") + + # Warm-start gamma via Ridge (numpy, no sklearn) + print(" Warm-starting γ via Ridge on y_total...") + + def _ridge(X, y, alpha=1.0): + """Ridge regression: (X'X + αI)^-1 X'y.""" + XtX = X.T @ X + alpha * np.eye(X.shape[1]) + Xty = X.T @ y + return np.linalg.solve(XtX, Xty) + + x_trn = data["x_market"] if args.no_split else train_data["x_market"] + y_trn = data["y_total"] if args.no_split else train_data["y_total"] + if args.per_pool_gamma: + pidx = data["pool_idx"] if args.no_split else train_data["pool_idx"] + for i in range(n_pools): + mask = pidx == i + if mask.sum() < 5: + continue + # Add intercept column for warm-start + X_i = np.concatenate([x_trn[mask], np.ones((mask.sum(), 1))], 1) + w = _ridge(X_i, y_trn[mask]) + params["gamma"] = params["gamma"].at[i].set( + jnp.array(w[:-1].astype(np.float32))) + params["log_alpha"] = params["log_alpha"].at[i].set(float(w[-1])) + else: + X_all = np.concatenate([x_trn, np.ones((len(y_trn), 1))], 1) + w = _ridge(X_all, y_trn) + params["gamma"] = jnp.array(w[:-1].astype(np.float32)) + + # Loss + grad_fn = make_loss_fn(data["pool_coeffs"], data["pool_gas"], n_pools) + + print(f"\nTraining ({args.epochs} epochs)...") + t0 = time.time() + params = train(params, train_data, grad_fn, args.epochs, args.lr, + args.l2_alpha, args.huber_delta) + print(f" Training time: {time.time() - t0:.1f}s") + + # Evaluate + print("\n" + "=" * 70) + print("Results (train)") + print("=" * 70) + train_eval = 
evaluate(params, train_data) + print(f" Median R²: {train_eval['median_r2']:.4f}") + + if "k_scale" in params: + ks = np.array(params["k_scale"]) + print(f" Observed K: offset={ks[0]:.3f}, slope={ks[1]:.3f}") + elif "k_params" in params: + k_p = np.array(params["k_params"]) + print(f" k_params: k_0={k_p[0]:.2f}, k_min={k_p[1]:.4f}, k_max={k_p[2]:.4f}") + elif "log_K" in params: + K_med = float(np.exp(np.median(np.array(params["log_K"])))) + print(f" Per-pool K: median=${K_med/1e6:.1f}M") + else: + K_med = float(np.median(list(train_eval["K_values"].values()))) + print(f" Observed K (fixed): median=${K_med/1e6:.1f}M") + + print(f"\n {'Pool':>16s} {'Tokens':>16s} {'R²':>6s}" + f" {'Noise%':>7s} {'K ($M)':>10s}") + for pid in data["pool_ids"]: + i = data["pool_ids"].index(pid) + toks = data["pool_tokens"][i] + r2 = train_eval["r2s"].get(pid, float("nan")) + ns = train_eval["noise_shares"].get(pid, float("nan")) + K = train_eval["K_values"][pid] + print(f" {pid[:16]} {toks[0]:>8s}/{toks[1]:<6s}" + f" {r2:>6.3f} {ns*100:>6.1f}% ${K/1e6:>9.1f}") + + if eval_data is not None: + print("\n" + "=" * 70) + print("Results (eval)") + print("=" * 70) + eval_result = evaluate(params, eval_data) + print(f" Median R²: {eval_result['median_r2']:.4f}") + + # TVL response + tvl_response_check(params, data) + + # Gamma coefficients + gamma = np.array(params["gamma"]) + if gamma.ndim == 1: + print(f"\n Shared γ coefficients:") + for j, name in enumerate(data["market_names"]): + print(f" {name:>30s}: {gamma[j]:>8.4f}") + + # Save + if args.save_artifact: + os.makedirs(args.save_artifact, exist_ok=True) + save_dict = {k: np.array(v) for k, v in params.items()} + np.savez(os.path.join(args.save_artifact, "model.npz"), **save_dict) + meta = { + "model": "michaelis_menten", + "pool_ids": data["pool_ids"], + "pool_tokens": data["pool_tokens"], + "market_names": data["market_names"], + "n_pools": n_pools, + "n_market_feat": n_market, + "per_pool_gamma": args.per_pool_gamma, + "hparams": { + 
"epochs": args.epochs, "lr": args.lr, + "l2_alpha": args.l2_alpha, "huber_delta": args.huber_delta, + "init_log_K": args.init_log_K, + }, + } + with open(os.path.join(args.save_artifact, "meta.json"), "w") as f: + json.dump(meta, f, indent=2) + print(f"\n Saved: {args.save_artifact}/") + + +def run_optuna(data, n_trials): + """Optuna hyperparameter sweep for MM noise model.""" + import optuna + optuna.logging.set_verbosity(optuna.logging.WARNING) + + n_pools = data["n_pools"] + n_market = data["n_market_feat"] + n_samples = len(data["pool_idx"]) + + # 70/30 temporal split + day_idx = data["day_idx"] + split_day = int(day_idx.max() * 0.7) + train_mask = day_idx <= split_day + eval_mask = day_idx > split_day + train_data = {k: v[train_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + eval_data = {k: v[eval_mask] if isinstance(v, np.ndarray) + and v.shape[0] == n_samples else v + for k, v in data.items()} + print(f" Optuna split: {train_mask.sum()} train, {eval_mask.sum()} eval") + + def _ridge(X, y, alpha=1.0): + XtX = X.T @ X + alpha * np.eye(X.shape[1]) + return np.linalg.solve(XtX, X.T @ y) + + def objective(trial): + lr = trial.suggest_float("lr", 1e-4, 3e-2, log=True) + l2_alpha = trial.suggest_float("l2_alpha", 1e-5, 1e-1, log=True) + huber_delta = trial.suggest_categorical("huber_delta", [0.5, 1.0, 1.5]) + init_log_K = trial.suggest_float("init_log_K", 14.0, 20.0) + n_epochs = trial.suggest_categorical("n_epochs", [2000, 3000, 5000]) + per_pool_gamma = trial.suggest_categorical("per_pool_gamma", [True, False]) + if per_pool_gamma: + gamma_init = jnp.zeros((n_pools, n_market)) + else: + gamma_init = jnp.zeros(n_market) + + params = { + "log_alpha": jnp.zeros(n_pools), + "k_params": jnp.array([init_log_K, 0.0, 0.0]), + "gamma": gamma_init, + "log_cadence": jnp.array(data["init_log_cadences"]), + } + + # Warm-start gamma + x_trn = train_data["x_market"] + y_trn = train_data["y_total"] + if per_pool_gamma: + pidx 
= train_data["pool_idx"] + for i in range(n_pools): + mask_i = pidx == i + if mask_i.sum() < 5: + continue + X_i = np.concatenate([x_trn[mask_i], + np.ones((mask_i.sum(), 1))], 1) + w = _ridge(X_i, y_trn[mask_i]) + params["gamma"] = params["gamma"].at[i].set( + jnp.array(w[:-1].astype(np.float32))) + params["log_alpha"] = params["log_alpha"].at[i].set( + float(w[-1])) + else: + X_all = np.concatenate([x_trn, np.ones((len(y_trn), 1))], 1) + w = _ridge(X_all, y_trn) + params["gamma"] = jnp.array(w[:-1].astype(np.float32)) + + grad_fn = make_loss_fn(data["pool_coeffs"], data["pool_gas"], n_pools) + params = train(params, train_data, grad_fn, n_epochs, lr, + l2_alpha, huber_delta, verbose=False) + + # Eval + eval_result = evaluate(params, eval_data) + med_r2 = eval_result["median_r2"] + + K_med = float(np.median([v for v in eval_result["K_values"].values()])) + k_p = np.array(params["k_params"]) + pp_str = "pp" if per_pool_gamma else "sh" + print(f" Trial {trial.number}: eval={med_r2:.4f}" + f" K_med=${K_med/1e6:.1f}M" + f" k=[{k_p[0]:.1f},{k_p[1]:.3f},{k_p[2]:.3f}]" + f" {pp_str} ep={n_epochs} lr={lr:.1e} l2={l2_alpha:.1e}" + f" hub={huber_delta}") + + # Save every trial + trial_dir = os.path.join("results", "mm_noise", "trials", + f"trial_{trial.number:04d}") + os.makedirs(trial_dir, exist_ok=True) + save_dict = {k: np.array(v) for k, v in params.items()} + np.savez(os.path.join(trial_dir, "model.npz"), **save_dict) + meta = { + "pool_ids": data["pool_ids"], + "pool_tokens": data["pool_tokens"], + "market_names": data["market_names"], + "n_pools": n_pools, + "n_market_feat": n_market, + "per_pool_gamma": per_pool_gamma, + "eval_r2": med_r2, + "hparams": { + "lr": lr, "l2_alpha": l2_alpha, "huber_delta": huber_delta, + "init_log_K": init_log_K, "n_epochs": n_epochs, + "per_pool_gamma": per_pool_gamma, + }, + } + with open(os.path.join(trial_dir, "meta.json"), "w") as f: + json.dump(meta, f, indent=2) + + return med_r2 + + study = 
optuna.create_study(direction="maximize") + study.optimize(objective, n_trials=n_trials) + + print(f"\n{'='*70}") + print(f"Optuna Results (MM noise)") + print(f"{'='*70}") + print(f" Best eval R²: {study.best_value:.4f}") + print(f" Best params:") + for k, v in sorted(study.best_params.items()): + print(f" {k}: {v}") + + trials = sorted(study.trials, key=lambda t: t.value if t.value is not None else -999, + reverse=True) + print(f"\n Top 10:") + for t in trials[:10]: + if t.value is not None: + print(f" #{t.number}: eval={t.value:.4f} {t.params}") + + # Copy best to top-level + best_dir = os.path.join("results", "mm_noise", "trials", + f"trial_{study.best_trial.number:04d}") + if os.path.exists(os.path.join(best_dir, "model.npz")): + import shutil + for fn in ("model.npz", "meta.json"): + shutil.copy2(os.path.join(best_dir, fn), + os.path.join("results", "mm_noise", fn)) + print(f"\n Copied best trial ({study.best_trial.number})" + f" to results/mm_noise/") + + return study + + +if __name__ == "__main__": + main() diff --git a/experiments/run_model_comparison.py b/experiments/run_model_comparison.py new file mode 100644 index 0000000..398ec48 --- /dev/null +++ b/experiments/run_model_comparison.py @@ -0,0 +1,338 @@ +"""Compare linear vs MLP noise models across TVL levels. + +Evaluates both noise models for a given pool over the same date range, +sweeping initial TVL. Uses real price data, the PCHIP arb grid, and +both noise models to predict daily volume decomposition. + +Produces a plot: predicted daily noise volume vs TVL for each model, +with the real observed volume overlaid where available. 
+ +Usage: + python experiments/run_model_comparison.py + python experiments/run_model_comparison.py --tvl-range 1e5 1e6 5e6 7e6 20e6 50e6 +""" + +import argparse +import json +import os +import pickle +import time + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import jax.numpy as jnp + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) +LINEAR_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", +) +MLP_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "mlp_noise", +) + + +def load_linear_model(artifact_dir, pool_id): + """Load linear noise model for a pool.""" + art = np.load(os.path.join(artifact_dir, "model.npz")) + with open(os.path.join(artifact_dir, "meta.json")) as f: + meta = json.load(f) + pool_ids = meta["pool_ids"] + idx = next((i for i, p in enumerate(pool_ids) + if p.startswith(pool_id) or pool_id.startswith(p)), -1) + nc = art["noise_coeffs"] + coeffs = nc[idx] if nc.ndim == 2 and idx >= 0 else (nc if nc.ndim == 1 else np.median(nc, axis=0)) + return { + "coeffs": coeffs, + "log_cadence": art["log_cadence"][idx] if idx >= 0 else np.log(10.0), + "x_mean": art["x_mean"], + "x_std": art["x_std"], + "feat_names": meta["feat_names"], + "type": "linear", + } + + +def load_mlp_model(artifact_dir, pool_id): + """Load MLP noise model.""" + art = dict(np.load(os.path.join(artifact_dir, "model.npz"), allow_pickle=True)) + with open(os.path.join(artifact_dir, "meta.json")) as f: + meta = json.load(f) + pool_ids = meta["pool_ids"] + idx = next((i for i, p in enumerate(pool_ids) + if p.startswith(pool_id) or pool_id.startswith(p)), -1) + + # Extract MLP params + params = {} + for k in art: + if k.startswith("W") or k.startswith("b") or k == "log_cadence" or k == "pool_bias": + params[k] = art[k] + + return { + "params": params, + 
"log_cadence": art["log_cadence"][idx] if idx >= 0 else np.log(10.0), + "pool_idx": idx, + "x_mean": art["x_mean"], + "x_std": art["x_std"], + "feat_names": meta["feat_names"], + "hidden": meta["hidden"], + "per_pool": meta.get("per_pool", False), + "type": "mlp", + } + + +def predict_noise_linear(model, x_daily, tvl_values): + """Predict noise volume at multiple TVL levels using linear model.""" + tvl_col = 1 # xobs_1 + results = {} + for tvl in tvl_values: + x = x_daily.copy() + x[:, tvl_col] = (np.log(tvl) - model["x_mean"][tvl_col]) / model["x_std"][tvl_col] + # Update TVL interaction terms + for i, name in enumerate(model["feat_names"]): + if name.startswith("xobs_1\u00d7"): + paired = name.split("\u00d7")[1] + if paired in model["feat_names"]: + j = model["feat_names"].index(paired) + x[:, i] = x[:, tvl_col] * x_daily[:, j] + log_noise = x @ model["coeffs"] + results[tvl] = np.exp(log_noise) + return results + + +def predict_noise_mlp(model, x_daily, tvl_values): + """Predict noise volume at multiple TVL levels using MLP model.""" + from experiments.run_mlp_noise import forward_mlp + tvl_col = 1 + params = model["params"] + pool_idx_arr = (jnp.full(x_daily.shape[0], model["pool_idx"]) + if model["per_pool"] and model["pool_idx"] >= 0 else None) + results = {} + for tvl in tvl_values: + x = x_daily.copy() + x[:, tvl_col] = (np.log(tvl) - model["x_mean"][tvl_col]) / model["x_std"][tvl_col] + for i, name in enumerate(model["feat_names"]): + if name.startswith("xobs_1\u00d7"): + paired = name.split("\u00d7")[1] + if paired in model["feat_names"]: + j = model["feat_names"].index(paired) + x[:, i] = x[:, tvl_col] * x_daily[:, j] + log_noise = np.array(forward_mlp(params, jnp.array(x), pool_idx_arr)) + results[tvl] = np.exp(log_noise) + return results + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--pool", default="0x9d1fcf346ea1b0") + 
parser.add_argument("--tvl-range", type=float, nargs="+", + default=[100_000, 500_000, 1_000_000, 5_000_000, + 7_000_000, 20_000_000, 50_000_000]) + parser.add_argument("--linear-dir", default=LINEAR_DIR) + parser.add_argument("--mlp-dir", default=MLP_DIR) + parser.add_argument("--output-dir", default="results/model_comparison") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + os.makedirs(args.output_dir, exist_ok=True) + + # Load data + with open(os.path.join(CACHE_DIR, "stage1.pkl"), "rb") as f: + data = pickle.load(f) + mc = data["matched_clean"] + oc = data["option_c_clean"] + + pid = args.pool + entry = mc[pid] + panel = entry["panel"] + dates = pd.to_datetime(panel["date"]) + vol_obs = np.exp(panel["log_volume"].values.astype(float)) + tvl_obs = np.exp(panel["log_tvl_lag1"].values.astype(float)) + + print(f"Pool: {pid} ({entry['tokens']}, {entry['chain']})") + print(f"{len(dates)} days: {dates.min().date()} → {dates.max().date()}") + + # Build feature matrix (same for both models) + from experiments.run_linear_market_noise import build_data + data_full = build_data(mc, oc, trend_windows=(7,), + include_market=True, include_cross_pool=False) + pool_ids = data_full["pool_ids"] + pool_i = pool_ids.index(pid) + pool_mask = data_full["pool_idx"] == pool_i + x_pool = data_full["x"][pool_mask] + day_idx = data_full["day_idx"][pool_mask] + sgd = data_full["sample_grid_days"][pool_mask] + + all_dates = set() + for p in pool_ids: + all_dates.update(mc[p]["panel"]["date"].values) + date_list = sorted(all_dates) + sample_dates = np.array([pd.Timestamp(date_list[d]) for d in day_idx]) + + n_days = len(sample_dates) + print(f"Feature samples: {n_days}") + + # Load models + print("\nLoading models...") + linear_model = load_linear_model(args.linear_dir, pid) + print(f" Linear: {len(linear_model['coeffs'])} coefficients," + f" cadence={np.exp(linear_model['log_cadence']):.1f}min") + + has_mlp = os.path.exists(os.path.join(args.mlp_dir, 
"model.npz")) + if has_mlp: + mlp_model = load_mlp_model(args.mlp_dir, pid) + print(f" MLP: hidden={mlp_model['hidden']}," + f" cadence={np.exp(mlp_model['log_cadence']):.1f}min") + else: + print(f" MLP: no artifact at {args.mlp_dir}") + mlp_model = None + + # V_arb from PCHIP (same for both — uses linear model's cadence) + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + cadence = float(np.exp(linear_model["log_cadence"])) + gas = float(np.exp(oc[pid]["log_gas"])) + v_arb_all = np.array(interpolate_pool_daily( + entry["coeffs"], jnp.float64(np.log(cadence)), jnp.float64(gas))) + v_arb = v_arb_all[sgd] + + # Predict at each TVL + print(f"\nPredicting noise at {len(args.tvl_range)} TVL levels...") + linear_noise = predict_noise_linear(linear_model, x_pool, args.tvl_range) + mlp_noise = predict_noise_mlp(mlp_model, x_pool, args.tvl_range) if mlp_model else {} + + # Real observed volume for comparison + tvl_for_samples = np.zeros(n_days) + vol_for_samples = np.zeros(n_days) + for i, sd in enumerate(sample_dates): + matches = np.where(dates == sd)[0] + if len(matches) > 0: + tvl_for_samples[i] = tvl_obs[matches[0]] + vol_for_samples[i] = vol_obs[matches[0]] + + # ---- Plot 1: Median noise volume vs TVL ---- + fig, axes = plt.subplots(1, 2, figsize=(14, 6)) + + tvls = np.array(args.tvl_range) + lin_medians = np.array([np.median(linear_noise[t]) for t in tvls]) + lin_q25 = np.array([np.percentile(linear_noise[t], 25) for t in tvls]) + lin_q75 = np.array([np.percentile(linear_noise[t], 75) for t in tvls]) + + ax = axes[0] + ax.fill_between(tvls / 1e6, lin_q25 / 1e6, lin_q75 / 1e6, + alpha=0.2, color="steelblue") + ax.plot(tvls / 1e6, lin_medians / 1e6, "o-", color="steelblue", + linewidth=2, label="Linear noise (median)") + + if mlp_noise: + mlp_medians = np.array([np.median(mlp_noise[t]) for t in tvls]) + mlp_q25 = np.array([np.percentile(mlp_noise[t], 25) for t in tvls]) + mlp_q75 = np.array([np.percentile(mlp_noise[t], 75) for t in tvls]) 
+ ax.fill_between(tvls / 1e6, mlp_q25 / 1e6, mlp_q75 / 1e6, + alpha=0.2, color="coral") + ax.plot(tvls / 1e6, mlp_medians / 1e6, "s-", color="coral", + linewidth=2, label="MLP noise (median)") + + # Add real observed volume at real TVL + valid = tvl_for_samples > 100 + ax.scatter(tvl_for_samples[valid] / 1e6, vol_for_samples[valid] / 1e6, + c="black", s=3, alpha=0.2, label="Observed total vol", zorder=1) + + ax.set_xlabel("Effective TVL ($M)") + ax.set_ylabel("Daily volume ($M)") + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_title(f"{entry['tokens']} — Noise Volume vs TVL") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # ---- Plot 2: Noise/TVL ratio vs TVL ---- + ax = axes[1] + ax.plot(tvls / 1e6, lin_medians / tvls * 100, "o-", color="steelblue", + linewidth=2, label="Linear noise/TVL") + if mlp_noise: + ax.plot(tvls / 1e6, mlp_medians / tvls * 100, "s-", color="coral", + linewidth=2, label="MLP noise/TVL") + + # Real vol/TVL + ax.scatter(tvl_for_samples[valid] / 1e6, + vol_for_samples[valid] / tvl_for_samples[valid] * 100, + c="black", s=3, alpha=0.2, label="Observed vol/TVL") + + ax.set_xlabel("Effective TVL ($M)") + ax.set_ylabel("Noise / TVL (%)") + ax.set_xscale("log") + ax.set_title("Noise as Fraction of TVL") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + fig.suptitle(f"Linear vs MLP Noise Model — {entry['tokens']}", fontsize=12) + fig.tight_layout() + out = os.path.join(args.output_dir, f"{pid[:16]}_model_comparison.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"\nSaved: {out}") + + # ---- Plot 3: Time series at selected TVLs ---- + fig, axes = plt.subplots(len(args.tvl_range), 1, + figsize=(14, 3 * len(args.tvl_range)), + sharex=True) + if len(args.tvl_range) == 1: + axes = [axes] + + for k, tvl in enumerate(args.tvl_range): + ax = axes[k] + v_total_lin = v_arb + linear_noise[tvl] + ax.plot(sample_dates, v_total_lin / 1e6, "b-", linewidth=0.6, + alpha=0.7, label="Linear (arb+noise)") + if 
mlp_noise: + v_total_mlp = v_arb + mlp_noise[tvl] + ax.plot(sample_dates, v_total_mlp / 1e6, "r-", linewidth=0.6, + alpha=0.7, label="MLP (arb+noise)") + ax.plot(sample_dates, vol_for_samples / 1e6, "k-", linewidth=0.5, + alpha=0.3, label="Observed (at real TVL)") + ax.set_ylabel(f"$M/day\nTVL=${tvl/1e6:.1f}M") + ax.set_yscale("log") + if k == 0: + ax.legend(fontsize=7, loc="upper right") + ax.grid(True, alpha=0.3) + + axes[-1].set_xlabel("Date") + fig.suptitle(f"Volume Time Series at Different TVLs — {entry['tokens']}", fontsize=11) + fig.tight_layout() + out2 = os.path.join(args.output_dir, f"{pid[:16]}_tvl_sweep_timeseries.png") + fig.savefig(out2, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {out2}") + + # Summary table + print(f"\n{'='*70}") + print(f"Summary: Median daily noise volume by TVL") + print(f"{'='*70}") + print(f"{'TVL':>14s} {'Linear':>12s} {'Lin/TVL':>8s}", end="") + if mlp_noise: + print(f" {'MLP':>12s} {'MLP/TVL':>8s} {'MLP/Lin':>8s}") + else: + print() + + for tvl in args.tvl_range: + lin = np.median(linear_noise[tvl]) + print(f"${tvl:>13,.0f} ${lin:>11,.0f} {lin/tvl*100:>7.1f}%", end="") + if mlp_noise: + mlp = np.median(mlp_noise[tvl]) + ratio = mlp / lin if lin > 0 else 0 + print(f" ${mlp:>11,.0f} {mlp/tvl*100:>7.1f}% {ratio:>7.2f}x") + else: + print() + + +if __name__ == "__main__": + main() diff --git a/experiments/run_residual_comparison.py b/experiments/run_residual_comparison.py new file mode 100644 index 0000000..201c403 --- /dev/null +++ b/experiments/run_residual_comparison.py @@ -0,0 +1,235 @@ +"""Apples-to-apples R² comparison on noise residuals. + +Target for all methods: r_it = log(V_total_it) - log(V_arb_it) + +Methods: + 1. Option C: log(1 + exp(x_obs @ noise_coeffs) / V_arb) + 2. AR1 on residuals: r_{i, t-1} + 3. Ridge on residuals (peers only, in-sample) + 4. Ridge on residuals (peers + own lag, in-sample) + 5. Constant zero (predict r=0, i.e. 
V_total = V_arb) +""" + +import os +import pickle +import sys + +import numpy as np +from sklearn.linear_model import RidgeCV + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_stage1(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + return data["matched_clean"], data["option_c_clean"] + + +def r2_score(y_true, y_pred): + ss_res = np.sum((y_true - y_pred) ** 2) + ss_tot = np.sum((y_true - y_true.mean()) ** 2) + return 1 - ss_res / max(ss_tot, 1e-10) + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + import jax.numpy as jnp + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import ( + K_OBS_REDUCED, build_x_obs, _parse_tokens, _canonicalize_token, + ) + + print("=" * 70) + print("Apples-to-Apples: All methods on noise residual target") + print(" target = log(V_total) - log(V_arb)") + print("=" * 70) + + matched_clean, option_c_clean = load_stage1() + pool_ids = sorted(matched_clean.keys()) + n_pools = len(pool_ids) + + # ---- Build aligned data per pool ---- + # For each pool: residual, Option C prediction of residual, dates + pool_data = {} + all_dates = set() + + for pid in pool_ids: + entry = matched_clean[pid] + oc = option_c_clean[pid] + panel = entry["panel"] + + # V_arb + v_arb_all = np.array(interpolate_pool_daily( + entry["coeffs"], + jnp.float64(oc["log_cadence"]), + jnp.float64(np.exp(oc["log_gas"])), + )) + v_arb = v_arb_all[entry["day_indices"]] + log_v_arb = np.log(np.maximum(v_arb, 1e-6)) + + # Observed + log_vol = panel["log_volume"].values.astype(float) + dates = panel["date"].values + + # Noise residual target + resid = log_vol - log_v_arb + + # Option C noise prediction (in residual space) + x_obs = build_x_obs(panel, reduced=True) + 
noise_coeffs = oc["noise_coeffs"][:K_OBS_REDUCED] + v_noise_oc = np.exp(x_obs @ noise_coeffs) + resid_pred_oc = np.log(np.maximum(1.0 + v_noise_oc / np.maximum(v_arb, 1e-6), 1e-10)) + + pool_data[pid] = { + "dates": dates, + "resid": resid, + "resid_pred_oc": resid_pred_oc, + "log_vol": log_vol, + "v_arb": v_arb, + } + all_dates.update(dates) + + # ---- Build residual matrix for cross-pool methods ---- + date_list = sorted(all_dates) + n_dates = len(date_list) + date_to_idx = {d: i for i, d in enumerate(date_list)} + + resid_matrix = np.full((n_dates, n_pools), np.nan) + for j, pid in enumerate(pool_ids): + pd = pool_data[pid] + for k, date in enumerate(pd["dates"]): + resid_matrix[date_to_idx[date], j] = pd["resid"][k] + + # ---- Token overlap for peer identification ---- + pool_tokens = {} + for i, pid in enumerate(pool_ids): + toks = _parse_tokens(matched_clean[pid]["tokens"]) + pool_tokens[i] = {_canonicalize_token(t) for t in toks[:2]} + + # ---- Compute R² for each method, per pool ---- + results = {m: [] for m in [ + "option_c", "ar1", "ridge_peers", "ridge_peers_own", + "constant_zero", "peer_mean", + ]} + + print(f"\n{'Pool':<18} {'Tokens':<14} {'OptC':>7} {'AR1':>7} " + f"{'R_peer':>7} {'R_p+own':>7} {'zero':>7} {'pmean':>7} {'n':>5}") + print("-" * 90) + + for i, pid in enumerate(pool_ids): + pd = pool_data[pid] + resid = pd["resid"] + n_obs = len(resid) + + # --- Option C --- + r2_oc = r2_score(resid, pd["resid_pred_oc"]) + + # --- Constant zero (V_total = V_arb) --- + r2_zero = r2_score(resid, np.zeros_like(resid)) + + # --- AR1 on residuals --- + if n_obs >= 3: + r2_ar1 = r2_score(resid[1:], resid[:-1]) + else: + r2_ar1 = np.nan + + # --- Ridge peers only (in-sample) --- + X_lag = resid_matrix[:-1, :] + y_cur = resid_matrix[1:, i] + own_lag = X_lag[:, i] + valid = ~np.isnan(y_cur) + + X_others = np.delete(X_lag, i, axis=1) + X_filled = X_others.copy() + for c in range(X_filled.shape[1]): + col = X_filled[:, c] + m = np.nanmean(col) + 
col[np.isnan(col)] = m if np.isfinite(m) else 0.0 + X_filled[:, c] = col + + X_peers = X_filled[valid] + y_i = y_cur[valid] + + if len(y_i) >= 10: + model_p = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model_p.fit(X_peers, y_i) + r2_rp = r2_score(y_i, model_p.predict(X_peers)) + else: + r2_rp = np.nan + + # --- Ridge peers + own lag (in-sample) --- + valid_own = valid & ~np.isnan(own_lag) + X_both = np.column_stack([X_filled, own_lag[:, None]]) + X_both_v = X_both[valid_own] + y_both = y_cur[valid_own] + + if len(y_both) >= 10: + model_po = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model_po.fit(X_both_v, y_both) + r2_rpo = r2_score(y_both, model_po.predict(X_both_v)) + else: + r2_rpo = np.nan + + # --- Peer mean (zero parameter) --- + peers = [j for j in range(n_pools) if j != i + and len(pool_tokens[i] & pool_tokens[j]) >= 1] + if peers: + peer_lag = resid_matrix[:-1, :][:, peers] + peer_mean = np.nanmean(peer_lag, axis=1) + y_pm = y_cur[valid] + pm_pred = peer_mean[valid] + pm_valid = ~np.isnan(pm_pred) + if pm_valid.sum() >= 3: + r2_pm = r2_score(y_pm[pm_valid], pm_pred[pm_valid]) + else: + r2_pm = np.nan + else: + r2_pm = np.nan + + results["option_c"].append(r2_oc) + results["ar1"].append(r2_ar1) + results["ridge_peers"].append(r2_rp) + results["ridge_peers_own"].append(r2_rpo) + results["constant_zero"].append(r2_zero) + results["peer_mean"].append(r2_pm) + + tokens = matched_clean[pid]["tokens"] + print(f" {pid[:16]} {tokens:<14} {r2_oc:>7.3f} {r2_ar1:>7.3f} " + f"{r2_rp:>7.3f} {r2_rpo:>7.3f} {r2_zero:>7.3f} " + f"{r2_pm:>7.3f} {n_obs:>5}") + + # ---- Summary ---- + def safe_median(xs): + v = [x for x in xs if np.isfinite(x)] + return np.median(v) if v else float("nan") + + print(f"\n{'='*70}") + print("SUMMARY — all on noise residual target") + print(f"{'='*70}") + for name, label in [ + ("option_c", "Option C (per-pool fitted)"), + ("ar1", "AR1 on residuals"), + ("ridge_peers", "Ridge peers only (in-sample)"), + ("ridge_peers_own", "Ridge peers + own lag 
(in-sample)"), + ("peer_mean", "Peer mean (0 params)"), + ("constant_zero", "Constant zero (V_total=V_arb)"), + ]: + vals = results[name] + med = safe_median(vals) + mean = np.nanmean([x for x in vals if np.isfinite(x)]) + n_neg = sum(1 for x in vals if np.isfinite(x) and x < 0) + print(f" {label:<35} median R² = {med:>7.4f} " + f"mean = {mean:>7.4f} n_neg = {n_neg}") + + +if __name__ == "__main__": + main() diff --git a/experiments/scan_lp_events.py b/experiments/scan_lp_events.py new file mode 100644 index 0000000..7194b84 --- /dev/null +++ b/experiments/scan_lp_events.py @@ -0,0 +1,522 @@ +"""Scan all pools for large LP deposit/withdrawal events and estimate TVL→noise elasticity. + +Identifies "semi-exogenous" LP flow events — large share changes that represent +genuine deposit/withdrawal decisions, not pool creation or dust. + +Filters: + - |Δlog(shares)| > threshold (default 20%) + - Pool must have been active for at least --min-age days before the event + - Pre-event TVL must be above --min-tvl (filters out pool creation events + where initial TVL is dust) + - Enough pre/post data to estimate volume change + +For each event, computes the volume response and implied elasticity. 
+ +Usage: + python experiments/scan_lp_events.py + python experiments/scan_lp_events.py --threshold 0.1 --window 7 + python experiments/scan_lp_events.py --use-api # fetch fresh snapshots +""" + +import argparse +import os +import time + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + + +def load_panel_data(use_api=False): + """Load pool panel data from calibration cache or API.""" + import pickle + + pools = {} + + # Stage1 calibration pools + cache_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", "stage1.pkl", + ) + if os.path.exists(cache_path): + with open(cache_path, "rb") as f: + data = pickle.load(f) + for pid, entry in data["matched_clean"].items(): + panel = entry["panel"].copy() + panel["pool_id"] = pid + panel["chain"] = entry["chain"] + panel["tokens"] = entry["tokens"] + pools[pid] = panel + + # Noise calibration panel (broader set) + noise_panel_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", + ) + if os.path.exists(noise_panel_path): + panel_all = pd.read_parquet(noise_panel_path) + for pid in panel_all["pool_id"].unique(): + if pid[:16] not in pools: # don't duplicate + pp = panel_all[panel_all["pool_id"] == pid].copy() + if len(pp) >= 30: + pools[pid[:16]] = pp + + # Top50 snapshots (even broader) + snap_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_top50", "snapshots", + ) + if os.path.exists(snap_dir): + import glob + for f in glob.glob(os.path.join(snap_dir, "*.parquet")): + pid = os.path.basename(f).replace(".parquet", "") + if pid[:16] not in pools: + try: + df = pd.read_parquet(f) + if len(df) >= 30 and "total_shares" in df.columns: + df["pool_id"] = pid + pools[pid[:16]] = df + except Exception: + pass + + if use_api: + print(" Fetching fresh snapshots from Balancer API...") + from 
quantammsim.noise_calibration import ( + fetch_pool_snapshots, BALANCER_API_CHAINS, + ) + for pid_short, panel in list(pools.items()): + if "chain" in panel.columns: + chain = panel["chain"].iloc[0] + else: + chain = "MAINNET" + full_pid = panel["pool_id"].iloc[0] if "pool_id" in panel.columns else pid_short + try: + fresh = fetch_pool_snapshots(full_pid, chain) + if len(fresh) > len(panel): + fresh["pool_id"] = full_pid + fresh["chain"] = chain + if "tokens" in panel.columns: + fresh["tokens"] = panel["tokens"].iloc[0] + pools[pid_short] = fresh + time.sleep(0.3) + except Exception: + pass + + print(f" Loaded {len(pools)} pools") + return pools + + +def find_lp_events(panel, threshold=0.2, min_age_days=30, min_tvl=10_000): + """Find large LP deposit/withdrawal events in a single pool's panel. + + Returns list of event dicts. + """ + dates = pd.to_datetime(panel["date"]) + + # Need shares and TVL + if "total_shares" not in panel.columns: + return [] + shares = panel["total_shares"].values.astype(float) + if np.all(shares <= 0) or np.all(np.isnan(shares)): + return [] + + # TVL + if "total_liquidity_usd" in panel.columns: + tvl = panel["total_liquidity_usd"].values.astype(float) + elif "log_tvl" in panel.columns: + tvl = np.exp(panel["log_tvl"].values.astype(float)) + elif "log_tvl_lag1" in panel.columns: + tvl = np.exp(panel["log_tvl_lag1"].values.astype(float)) + else: + return [] + + # Volume + if "volume_usd" in panel.columns: + vol = panel["volume_usd"].values.astype(float) + elif "log_volume" in panel.columns: + vol = np.exp(panel["log_volume"].values.astype(float)) + else: + return [] + + log_shares = np.log(np.maximum(shares, 1e-10)) + d_log_shares = np.diff(log_shares) + + events = [] + for i in range(len(d_log_shares)): + if abs(d_log_shares[i]) < np.log(1 + threshold): + continue + + # Check min age: pool must have been active for min_age_days + days_active = (dates.iloc[i + 1] - dates.iloc[0]).days + if days_active < min_age_days: + continue + + # Check 
min TVL before event + if tvl[i] < min_tvl: + continue + + # Check shares aren't near-zero before (not pool creation) + if shares[i] < 1: + continue + + pct_change = (np.exp(d_log_shares[i]) - 1) * 100 + event_type = "deposit" if d_log_shares[i] > 0 else "withdrawal" + + events.append({ + "date": dates.iloc[i + 1], + "idx": i + 1, + "type": event_type, + "d_log_shares": float(d_log_shares[i]), + "pct_change": float(pct_change), + "shares_before": float(shares[i]), + "shares_after": float(shares[i + 1]), + "tvl_before": float(tvl[i]), + "tvl_after": float(tvl[i + 1]), + "vol_on_day": float(vol[i + 1]), + }) + + return events + + +def compute_event_elasticity(panel, event, window=7): + """Compute volume response around an LP event. + + Compares median volume in [event-window, event) vs [event+1, event+window+1). + """ + dates = pd.to_datetime(panel["date"]) + idx = event["idx"] + + if "volume_usd" in panel.columns: + vol = panel["volume_usd"].values.astype(float) + elif "log_volume" in panel.columns: + vol = np.exp(panel["log_volume"].values.astype(float)) + else: + return None + + if "total_liquidity_usd" in panel.columns: + tvl = panel["total_liquidity_usd"].values.astype(float) + elif "log_tvl" in panel.columns: + tvl = np.exp(panel["log_tvl"].values.astype(float)) + elif "log_tvl_lag1" in panel.columns: + tvl = np.exp(panel["log_tvl_lag1"].values.astype(float)) + else: + return None + + pre_start = max(0, idx - window) + post_end = min(len(vol), idx + 1 + window) + + if idx - pre_start < 3 or post_end - (idx + 1) < 3: + return None + + vol_pre = np.median(vol[pre_start:idx]) + vol_post = np.median(vol[idx + 1:post_end]) + tvl_pre = np.median(tvl[pre_start:idx]) + tvl_post = np.median(tvl[idx + 1:post_end]) + + if vol_pre <= 0 or tvl_pre <= 0 or tvl_post <= 0: + return None + + vol_ratio = vol_post / vol_pre + tvl_ratio = tvl_post / tvl_pre + + if abs(np.log(tvl_ratio)) < 0.05: # TVL didn't actually change much + return None + + elasticity = np.log(vol_ratio) / 
np.log(tvl_ratio) + + return { + "vol_pre": vol_pre, + "vol_post": vol_post, + "tvl_pre": tvl_pre, + "tvl_post": tvl_post, + "vol_ratio": vol_ratio, + "tvl_ratio": tvl_ratio, + "elasticity": elasticity, + } + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--threshold", type=float, default=0.2, + help="Min |share change| to count as event (0.2 = 20%%)") + parser.add_argument("--window", type=int, default=7, + help="Days before/after event for volume comparison") + parser.add_argument("--min-age", type=int, default=30, + help="Min days pool must be active before event") + parser.add_argument("--min-tvl", type=float, default=10_000, + help="Min TVL before event (filters pool creation)") + parser.add_argument("--use-api", action="store_true", + help="Fetch fresh snapshots from Balancer API") + parser.add_argument("--output-dir", default="results/lp_events", + help="Output directory for CSV and plots") + args = parser.parse_args() + + print("=" * 70) + print("LP Event Scanner: Semi-Exogenous TVL Shocks") + print(f" threshold={args.threshold:.0%}, window={args.window}d," + f" min_age={args.min_age}d, min_tvl=${args.min_tvl:,.0f}") + print("=" * 70) + + pools = load_panel_data(use_api=args.use_api) + + all_events = [] + print(f"\nScanning {len(pools)} pools for LP events...") + + for pid_short, panel in pools.items(): + tokens = (panel["tokens"].iloc[0] if "tokens" in panel.columns + else "?") + chain = (panel["chain"].iloc[0] if "chain" in panel.columns + else "?") + + events = find_lp_events( + panel, threshold=args.threshold, + min_age_days=args.min_age, min_tvl=args.min_tvl) + + for ev in events: + result = compute_event_elasticity(panel, ev, window=args.window) + ev["pool_id"] = pid_short + ev["tokens"] = tokens + ev["chain"] = chain + ev["result"] = result + all_events.append(ev) + + # Sort by absolute share change + all_events.sort(key=lambda e: 
abs(e["d_log_shares"]), reverse=True) + + print(f"\nFound {len(all_events)} LP events across {len(pools)} pools") + events_with_elasticity = [e for e in all_events if e["result"] is not None] + print(f" {len(events_with_elasticity)} with computable elasticity") + + # Print event table + print(f"\n{'Date':12s} {'Pool':16s} {'Tokens':18s} {'Type':10s}" + f" {'Δshares':>8s} {'TVL before':>12s} {'TVL after':>12s}" + f" {'VolPre':>10s} {'VolPost':>10s} {'Elast':>7s}") + print("-" * 120) + + for ev in all_events: + r = ev["result"] + if r: + elast_str = f"{r['elasticity']:+7.2f}" + vol_pre_str = f"${r['vol_pre']:>9,.0f}" + vol_post_str = f"${r['vol_post']:>9,.0f}" + else: + elast_str = " n/a" + vol_pre_str = " n/a" + vol_post_str = " n/a" + + print(f"{str(ev['date'].date()):12s} {ev['pool_id'][:16]:16s}" + f" {str(ev['tokens'])[:18]:18s} {ev['type']:10s}" + f" {ev['pct_change']:+7.0f}%" + f" ${ev['tvl_before']:>11,.0f} ${ev['tvl_after']:>11,.0f}" + f" {vol_pre_str} {vol_post_str} {elast_str}") + + # Summary statistics + if not events_with_elasticity: + print("No events with computable elasticity.") + return + + deposits = [e for e in events_with_elasticity if e["type"] == "deposit"] + withdrawals = [e for e in events_with_elasticity if e["type"] == "withdrawal"] + + all_elast = [e["result"]["elasticity"] for e in events_with_elasticity] + dep_elast = [e["result"]["elasticity"] for e in deposits] + wth_elast = [e["result"]["elasticity"] for e in withdrawals] + clean = [e for e in events_with_elasticity + if -1 < e["result"]["elasticity"] < 5] + clean_elast = [e["result"]["elasticity"] for e in clean] + + print(f"\n{'='*70}") + print("Summary: Implied TVL→Volume Elasticity") + print(f"{'='*70}") + print(f" All events ({len(all_elast)}):" + f" median={np.median(all_elast):+.2f}" + f" mean={np.mean(all_elast):+.2f}" + f" std={np.std(all_elast):.2f}") + if dep_elast: + print(f" Deposits ({len(dep_elast)}):" + f" median={np.median(dep_elast):+.2f}" + f" 
mean={np.mean(dep_elast):+.2f}") + if wth_elast: + print(f" Withdrawals ({len(wth_elast)}):" + f" median={np.median(wth_elast):+.2f}" + f" mean={np.mean(wth_elast):+.2f}") + if clean_elast: + print(f"\n Clean events (elasticity in [-1, 5], n={len(clean_elast)}):") + print(f" median={np.median(clean_elast):+.2f}" + f" mean={np.mean(clean_elast):+.2f}" + f" [Q25={np.percentile(clean_elast, 25):+.2f}," + f" Q75={np.percentile(clean_elast, 75):+.2f}]") + + print(f"\n For comparison:") + print(f" Per-pool observational b_tvl: ~1.0") + print(f" Shared observational b_tvl: ~2.5") + print(f" Daily Δ within-pool: ~0.1") + + # ---- Save CSV ---- + out_dir = args.output_dir + os.makedirs(out_dir, exist_ok=True) + + rows = [] + for ev in all_events: + r = ev.get("result") or {} + rows.append({ + "date": ev["date"], + "pool_id": ev["pool_id"], + "tokens": str(ev["tokens"]), + "chain": str(ev["chain"]), + "type": ev["type"], + "pct_change": ev["pct_change"], + "tvl_before": ev["tvl_before"], + "tvl_after": ev["tvl_after"], + "shares_before": ev["shares_before"], + "shares_after": ev["shares_after"], + "vol_pre": r.get("vol_pre"), + "vol_post": r.get("vol_post"), + "tvl_ratio": r.get("tvl_ratio"), + "vol_ratio": r.get("vol_ratio"), + "elasticity": r.get("elasticity"), + }) + df = pd.DataFrame(rows) + csv_path = os.path.join(out_dir, "lp_events.csv") + df.to_csv(csv_path, index=False) + print(f"\n Saved: {csv_path} ({len(df)} events)") + + # ---- Plots ---- + # 1. 
Elasticity histogram (clean events, deposits vs withdrawals) + fig, axes = plt.subplots(1, 3, figsize=(16, 5)) + + ax = axes[0] + ax.hist(clean_elast, bins=40, color="steelblue", alpha=0.7, edgecolor="white") + ax.axvline(np.median(clean_elast), color="red", linestyle="--", linewidth=2, + label=f"median={np.median(clean_elast):+.2f}") + ax.axvline(1.0, color="gray", linestyle=":", alpha=0.5, label="elasticity=1") + ax.set_xlabel("Elasticity (Δlog vol / Δlog TVL)") + ax.set_ylabel("Count") + ax.set_title(f"All clean events (n={len(clean_elast)})") + ax.legend(fontsize=8) + + ax = axes[1] + dep_clean = [e["result"]["elasticity"] for e in clean if e["type"] == "deposit"] + wth_clean = [e["result"]["elasticity"] for e in clean if e["type"] == "withdrawal"] + ax.hist(dep_clean, bins=30, color="green", alpha=0.6, label=f"deposits (n={len(dep_clean)})", edgecolor="white") + ax.hist(wth_clean, bins=30, color="coral", alpha=0.6, label=f"withdrawals (n={len(wth_clean)})", edgecolor="white") + ax.axvline(1.0, color="gray", linestyle=":", alpha=0.5) + ax.set_xlabel("Elasticity") + ax.set_title("Deposits vs Withdrawals") + ax.legend(fontsize=8) + + # 2. 
Elasticity vs event size (|Δlog shares|) + ax = axes[2] + sizes = [abs(e["d_log_shares"]) for e in clean] + elasts = [e["result"]["elasticity"] for e in clean] + colors = ["green" if e["type"] == "deposit" else "coral" for e in clean] + ax.scatter(sizes, elasts, c=colors, alpha=0.4, s=15, edgecolors="none") + ax.axhline(1.0, color="gray", linestyle=":", alpha=0.5) + ax.axhline(np.median(elasts), color="red", linestyle="--", alpha=0.7, + label=f"median={np.median(elasts):+.2f}") + ax.set_xlabel("|Δlog(shares)| (event size)") + ax.set_ylabel("Elasticity") + ax.set_title("Elasticity vs Event Size") + ax.legend(fontsize=8) + + fig.suptitle(f"LP Event Elasticity Analysis — {len(clean)} clean events" + f" from {len(pools)} pools", fontsize=11) + fig.tight_layout() + p1 = os.path.join(out_dir, "elasticity_histograms.png") + fig.savefig(p1, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {p1}") + + # 3. Elasticity vs pre-event TVL (does pool size affect elasticity?) + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + ax = axes[0] + tvl_pre = [e["result"]["tvl_pre"] for e in clean] + ax.scatter(tvl_pre, elasts, c=colors, alpha=0.4, s=15, edgecolors="none") + ax.set_xscale("log") + ax.axhline(1.0, color="gray", linestyle=":", alpha=0.5) + ax.set_xlabel("Pre-event TVL (USD)") + ax.set_ylabel("Elasticity") + ax.set_title("Elasticity vs Pool Size") + + # Bin by TVL quartile and show median elasticity per bin + tvl_arr = np.array(tvl_pre) + el_arr = np.array(elasts) + for q_lo, q_hi in [(0, 25), (25, 50), (50, 75), (75, 100)]: + lo = np.percentile(tvl_arr, q_lo) + hi = np.percentile(tvl_arr, q_hi) + mask = (tvl_arr >= lo) & (tvl_arr < hi + 1) + if mask.sum() > 5: + med_tvl = np.median(tvl_arr[mask]) + med_el = np.median(el_arr[mask]) + ax.plot(med_tvl, med_el, "rs", markersize=10, zorder=5) + ax.annotate(f"{med_el:.2f}", (med_tvl, med_el), + textcoords="offset points", xytext=(8, 5), fontsize=7) + + # 4.
log(vol_post/vol_pre) vs log(tvl_post/tvl_pre) scatter + ax = axes[1] + log_tvl_ratio = [np.log(e["result"]["tvl_ratio"]) for e in clean] + log_vol_ratio = [np.log(e["result"]["vol_ratio"]) for e in clean] + ax.scatter(log_tvl_ratio, log_vol_ratio, c=colors, alpha=0.4, s=15, + edgecolors="none") + + # OLS fit line + x_fit = np.array(log_tvl_ratio) + y_fit = np.array(log_vol_ratio) + slope, intercept = np.polyfit(x_fit, y_fit, 1) + x_line = np.linspace(x_fit.min(), x_fit.max(), 100) + ax.plot(x_line, slope * x_line + intercept, "r-", linewidth=2, + label=f"OLS slope={slope:.2f}") + ax.plot(x_line, x_line, "k--", alpha=0.3, label="1:1 line") + ax.set_xlabel("Δlog(TVL)") + ax.set_ylabel("Δlog(Volume)") + ax.set_title("Volume Response to TVL Shocks") + ax.legend(fontsize=8) + + fig.suptitle("TVL→Volume Elasticity: Event Study", fontsize=11) + fig.tight_layout() + p2 = os.path.join(out_dir, "elasticity_scatter.png") + fig.savefig(p2, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {p2}") + + # 5. 
Elasticity by chain + fig, ax = plt.subplots(figsize=(10, 5)) + chain_data = {} + for e in clean: + ch = str(e["chain"]) + if ch not in chain_data: + chain_data[ch] = [] + chain_data[ch].append(e["result"]["elasticity"]) + chains_sorted = sorted(chain_data.keys(), + key=lambda c: len(chain_data[c]), reverse=True) + chains_plot = [c for c in chains_sorted if len(chain_data[c]) >= 5] + if chains_plot: + positions = range(len(chains_plot)) + bp = ax.boxplot([chain_data[c] for c in chains_plot], + positions=positions, widths=0.6, patch_artist=True) + for patch in bp["boxes"]: + patch.set_facecolor("steelblue") + patch.set_alpha(0.6) + ax.set_xticks(positions) + ax.set_xticklabels([f"{c}\n(n={len(chain_data[c])})" for c in chains_plot], + fontsize=8) + ax.axhline(1.0, color="gray", linestyle=":", alpha=0.5) + ax.axhline(np.median(clean_elast), color="red", linestyle="--", alpha=0.5, + label=f"overall median={np.median(clean_elast):.2f}") + ax.set_ylabel("Elasticity") + ax.set_title("Elasticity by Chain") + ax.set_ylim(-2, 5) + ax.legend(fontsize=8) + fig.tight_layout() + p3 = os.path.join(out_dir, "elasticity_by_chain.png") + fig.savefig(p3, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {p3}") + + +if __name__ == "__main__": + main() diff --git a/experiments/tune_reclamm_calibrated_noise.py b/experiments/tune_reclamm_calibrated_noise.py new file mode 100644 index 0000000..55ca99b --- /dev/null +++ b/experiments/tune_reclamm_calibrated_noise.py @@ -0,0 +1,375 @@ +"""Optuna tuning of reClAMM pool parameters with calibrated noise models. + +Main noise model modes (--noise-model also accepts mm_observed): + --noise-model calibrated (legacy 8-covariate model) + --noise-model market_linear (new per-pool model with market features) + +The market_linear model uses precomputed daily arrays from the per-pool +calibrated noise model artifact (results/linear_market_noise/).
It evaluates: + + log(V_noise) = base_t + tvl_coeff_t * log(effective_TVL) + +where base_t absorbs all non-TVL terms (market regime, token volatility, +pair volatility, day-of-week, cross-pool volumes) and tvl_coeff_t is the +effective TVL coefficient including interaction terms. + +Pool: 0x9d1fcf346ea1b0 = AAVE/WETH Mainnet + +Usage: + cd + source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim_reclamm_public + + # New market_linear model (default) + python experiments/tune_reclamm_calibrated_noise.py + + # Legacy 8-covariate model + python experiments/tune_reclamm_calibrated_noise.py --noise-model calibrated + + # All three objectives + python experiments/tune_reclamm_calibrated_noise.py --all-objectives + + # More trials + python experiments/tune_reclamm_calibrated_noise.py --n-trials 200 +""" + +import argparse +import json +import math +import numpy as np +from pathlib import Path +from quantammsim.runners.jax_runners import train_on_historic_data + +POOL_ID = "0x9d1fcf346ea1b0" # AAVE/WETH Mainnet + +# --- Legacy 8-covariate noise coefficients --- +NOISE_COEFFS_LEGACY = [ + -0.453, # c_0: intercept + 0.025, # c_1: log(TVL) + -0.060, # c_2: log(sigma) + 0.310, # c_3: log(TVL) * log(sigma) + -0.149, # c_4: log(TVL) * fee + 0.359, # c_5: log(sigma) * fee + 0.061, # c_6: dow_sin + 0.060, # c_7: dow_cos +] +LEGACY_LOG_CADENCE = 2.68 +LEGACY_ARB_FREQUENCY = max(1, round(math.exp(LEGACY_LOG_CADENCE))) # ~15 min + +PARAMETER_CONFIG = { + "price_ratio": {"low": 1.01, "high": 200.0, "log_scale": True, "scalar": True}, + "centeredness_margin": {"low": 0.01, "high": 0.99, "scalar": True}, + "shift_exponent": {"low": 1e-5, "high": 125.0, "log_scale": True, "scalar": True}, +} + +OBJECTIVES = ["daily_log_sharpe", "returns_over_hodl", "fee_revenue_over_value"] + + +def _build_market_linear_arrays(args): + """Precompute noise arrays from the per-pool market noise model artifact.""" + from quantammsim.calibration.noise_model_arrays import build_simulator_arrays + + 
# Parse dates — strip time component for the array builder + start = args.start_date.split(" ")[0] + end = args.end_test_date.split(" ")[0] + + print(f" Building market_linear noise arrays for {POOL_ID}...") + print(f" Date range: {start} → {end}") + arrays = build_simulator_arrays( + token_a="AAVE", + token_b="ETH", + start_date=start, + end_date=end, + artifact_dir=args.artifact_dir, + pool_id=POOL_ID, + ) + print(f" {arrays['n_days']} days, {arrays['n_minutes']} minutes") + print(f" noise_base range: [{arrays['noise_base'].min():.2f}," + f" {arrays['noise_base'].max():.2f}]") + print(f" noise_tvl_coeff range: [{arrays['noise_tvl_coeff'].min():.4f}," + f" {arrays['noise_tvl_coeff'].max():.4f}]") + + # Save arrays to disk (fingerprint can't hold numpy arrays — it gets JSON-serialized) + import os + cache_dir = os.path.join(args.artifact_dir, "_sim_arrays") + os.makedirs(cache_dir, exist_ok=True) + arrays_path = os.path.join(cache_dir, f"{POOL_ID}_{start}_{end}.npz") + np.savez(arrays_path, + noise_base=arrays["noise_base"], + noise_tvl_coeff=arrays["noise_tvl_coeff"], + tvl_mean=arrays["tvl_mean"], + tvl_std=arrays["tvl_std"]) + print(f" Saved arrays: {arrays_path}") + + # Get learned cadence from artifact + from quantammsim.calibration.noise_model_arrays import load_artifact, _find_pool_index + art, meta = load_artifact(args.artifact_dir) + pool_idx = _find_pool_index(POOL_ID, meta["pool_ids"]) + if pool_idx >= 0: + learned_cadence = float(np.exp(art["log_cadence"][pool_idx])) + print(f" Learned cadence: {learned_cadence:.1f} min") + else: + learned_cadence = 5.0 + print(f" Pool not in calibration set, using default cadence: {learned_cadence}") + + return arrays_path, max(1, round(learned_cadence)) + + +def _build_mm_observed_arrays(args): + """Precompute noise arrays from the MM model + DeFi Llama competitor TVL.""" + from quantammsim.calibration.noise_model_arrays import ( + build_mm_simulator_arrays, load_artifact, _find_pool_index, + ) + + start = 
args.start_date.split(" ")[0] + end = args.end_test_date.split(" ")[0] + + print(f" Building mm_observed noise arrays for {POOL_ID}...") + print(f" Date range: {start} → {end}") + arrays = build_mm_simulator_arrays( + token_a="AAVE", + token_b="ETH", + start_date=start, + end_date=end, + mm_artifact_dir=args.artifact_dir, + competitor_tvl_path=args.competitor_tvl_path, + pool_id=POOL_ID, + ) + print(f" {arrays['n_days']} days, {arrays['n_minutes']} minutes") + print(f" noise_base range: [{arrays['noise_base'].min():.2f}," + f" {arrays['noise_base'].max():.2f}]") + print(f" competitor_tvl range: [${arrays['competitor_tvl'].min():.0f}, ${arrays['competitor_tvl'].max():.0f}]") + + # Save arrays to disk + import os + cache_dir = os.path.join(args.artifact_dir, "_sim_arrays") + os.makedirs(cache_dir, exist_ok=True) + arrays_path = os.path.join(cache_dir, f"{POOL_ID}_{start}_{end}_mm.npz") + np.savez(arrays_path, + noise_base=arrays["noise_base"], + competitor_tvl=arrays["competitor_tvl"]) + print(f" Saved arrays: {arrays_path}") + + # Get cadence from MM model artifact + art, meta = load_artifact(args.artifact_dir) + pool_idx = _find_pool_index(POOL_ID, meta["pool_ids"]) + if pool_idx >= 0 and "log_cadence" in art: + learned_cadence = float(np.exp(art["log_cadence"][pool_idx])) + print(f" Learned cadence: {learned_cadence:.1f} min") + else: + learned_cadence = 5.0 + print(f" Using default cadence: {learned_cadence}") + + return arrays_path, max(1, round(learned_cadence)) + + +def _build_opt_settings(args): + """Build optimisation_settings for optuna, bfgs, or cma_es.""" + if args.method == "bfgs": + return { + "method": "bfgs", + "n_parameter_sets": args.n_parameter_sets, + **({"val_fraction": args.val_fraction} if args.val_fraction is not None else {}), + "bfgs_settings": { + "maxiter": args.bfgs_maxiter, + "tol": args.bfgs_tol, + "n_evaluation_points": args.bfgs_eval_points, + "compute_dtype": "float64", + }, + } + elif args.method == "cma_es": + return { + "method": "cma_es", +
"n_parameter_sets": args.n_parameter_sets, + **({"val_fraction": args.val_fraction} if args.val_fraction is not None else {}), + "cma_es_settings": { + "population_size": args.cma_pop_size, + "n_generations": args.cma_generations, + "sigma0": args.cma_sigma0, + "tol": 1e-8, + "n_evaluation_points": args.cma_eval_points, + "compute_dtype": "float32", + }, + } + else: + return { + "method": "optuna", + "n_parameter_sets": 1, + **({"val_fraction": args.val_fraction} if args.val_fraction is not None else {}), + "optuna_settings": { + "make_scalar": True, + "expand_around": False, + "n_trials": args.n_trials, + "multi_objective": False, + "parameter_config": PARAMETER_CONFIG, + **({"overfitting_penalty": args.overfitting_penalty} + if args.overfitting_penalty is not None else {}), + **({"min_train_returns_over_hodl": args.min_train_ret} + if args.min_train_ret is not None else {}), + }, + } + + +def build_fingerprint(objective, args, noise_arrays_path=None, arb_freq=None): + """Build run fingerprint with calibrated noise model.""" + if args.noise_model == "mm_observed" and noise_arrays_path is not None: + noise_block = { + "noise_trader_ratio": 0.0, + "noise_model": "mm_observed", + "noise_arrays_path": noise_arrays_path, + } + freq = arb_freq or 5 + elif args.noise_model == "market_linear" and noise_arrays_path is not None: + _arr = np.load(noise_arrays_path) + noise_block = { + "noise_trader_ratio": 0.0, + "noise_model": "market_linear", + "noise_arrays_path": noise_arrays_path, + "reclamm_noise_params": { + "tvl_mean": float(_arr["tvl_mean"]), + "tvl_std": float(_arr["tvl_std"]), + }, + } + freq = arb_freq or 5 + else: + noise_block = { + "noise_trader_ratio": 0.0, + "noise_model": "calibrated", + "reclamm_noise_params": { + f"c_{i}": NOISE_COEFFS_LEGACY[i] for i in range(8) + }, + } + freq = LEGACY_ARB_FREQUENCY + + return { + "rule": "reclamm", + "tokens": ["AAVE", "ETH"], + "startDateString": args.start_date, + "endDateString": args.end_date, + 
"endTestDateString": args.end_test_date, + "initial_pool_value": args.initial_pool_value, + "do_arb": True, + "arb_frequency": freq, + "fees": args.fees, + "gas_cost": args.gas_cost, + "arb_fees": 0.0, + "protocol_fee_split": 0.25, + **noise_block, + "return_val": objective, + "reclamm_interpolation_method": args.interpolation, + "reclamm_centeredness_scaling": args.centeredness_scaling, + "reclamm_learn_arc_length_speed": False, + "reclamm_use_shift_exponent": True, + **({"bout_offset": args.bout_offset} if args.bout_offset is not None else {}), + "optimisation_settings": _build_opt_settings(args), + } + + +def run_single(objective, args, noise_arrays_path=None, arb_freq=None): + """Run Optuna tuning for a single objective.""" + print(f"\n{'='*60}") + print(f" Objective: {objective}") + print(f" Noise model: {args.noise_model}") + print(f" Method: {args.method}") + print(f" Pool: AAVE/WETH Mainnet ({POOL_ID})") + print(f" Train: {args.start_date} → {args.end_date}") + print(f" Test: {args.end_date} → {args.end_test_date}") + if arb_freq: + print(f" Arb frequency: {arb_freq} min (learned)") + print(f"{'='*60}\n") + + fp = build_fingerprint(objective, args, noise_arrays_path, arb_freq) + result = train_on_historic_data(fp, verbose=True) + + if result is not None: + print(f"\n=== Result ({objective}) ===") + for k, v in result.items(): + print(f" {k}: {v}") + + return result + + +def main(): + parser = argparse.ArgumentParser( + description="Tune reClAMM params with calibrated 8-covariate noise model" + ) + parser.add_argument("--method", default="optuna", + choices=["optuna", "bfgs", "cma_es"], + help="Optimisation method") + parser.add_argument("--n-trials", type=int, default=50, + help="Optuna trials (ignored for bfgs)") + parser.add_argument("--n-parameter-sets", type=int, default=1, + help="Number of parameter sets for bfgs") + parser.add_argument("--bfgs-maxiter", type=int, default=100) + parser.add_argument("--bfgs-tol", type=float, default=1e-6) + 
parser.add_argument("--bfgs-eval-points", type=int, default=20, + help="Number of evaluation points for bfgs") + # CMA-ES + parser.add_argument("--cma-generations", type=int, default=300) + parser.add_argument("--cma-sigma0", type=float, default=0.5, + help="Initial step size for CMA-ES") + parser.add_argument("--cma-pop-size", type=int, default=None, + help="Population size (None = auto)") + parser.add_argument("--cma-eval-points", type=int, default=20) + parser.add_argument("--min-train-ret", type=float, default=-0.5, + help="Reject trials with IS returns_over_hodl below this") + parser.add_argument("--noise-model", default="market_linear", + choices=["calibrated", "market_linear", "mm_observed"], + help="Noise model variant") + parser.add_argument("--artifact-dir", + default="results/linear_market_noise", + help="Artifact dir for market_linear or mm_observed model") + parser.add_argument("--competitor-tvl-path", + default="results/competitor_tvl/competitor_tvl.npz", + help="Path to competitor TVL data (mm_observed only)") + parser.add_argument("--initial-pool-value", type=float, default=20_000_000.0, + help="Initial pool TVL in USD (default: 20M)") + parser.add_argument("--fees", type=float, default=0.0025, + help="Pool fee rate (default: 0.0025 matching calibration)") + parser.add_argument("--gas-cost", type=float, default=1.0) + parser.add_argument("--objective", default="fee_revenue_over_value", + choices=OBJECTIVES) + parser.add_argument("--all-objectives", action="store_true", + help="Run all three objectives sequentially") + parser.add_argument("--interpolation", default="geometric", + choices=["geometric", "constant_arc_length"]) + parser.add_argument("--centeredness-scaling", action="store_true") + parser.add_argument("--start-date", default="2024-06-01 00:00:00") + parser.add_argument("--end-date", default="2025-06-01 00:00:00", + help="End of training / start of test") + parser.add_argument("--end-test-date", default="2026-03-01 00:00:00", + help="End 
of test (latest available data)") + parser.add_argument("--bout-offset", type=int, default=None) + parser.add_argument("--val-fraction", type=float, default=None) + parser.add_argument("--overfitting-penalty", type=float, default=None) + parser.add_argument("--output", type=str, default=None, + help="Save results to JSON file") + args = parser.parse_args() + + if args.all_objectives: + objectives = OBJECTIVES + else: + objectives = [args.objective] + + # Precompute noise arrays once + noise_arrays_path = None + arb_freq = None + if args.noise_model == "market_linear": + noise_arrays_path, arb_freq = _build_market_linear_arrays(args) + elif args.noise_model == "mm_observed": + noise_arrays_path, arb_freq = _build_mm_observed_arrays(args) + + all_results = {} + for obj in objectives: + result = run_single(obj, args, noise_arrays_path, arb_freq) + all_results[obj] = result + + if args.output: + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(all_results, f, indent=2, default=str) + print(f"\nResults saved to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/tune_reclamm_params.py b/experiments/tune_reclamm_params.py index 4969923..e2448bf 100644 --- a/experiments/tune_reclamm_params.py +++ b/experiments/tune_reclamm_params.py @@ -13,9 +13,16 @@ # More trials, custom fees python experiments/tune_reclamm_params.py --n-trials 200 --fees 0.005 + + # With calibrated 8-covariate noise model and arb frequency from calibration + python experiments/tune_reclamm_params.py --noise-model calibrated \ + --noise-params-json results/mlp_calibration/option_c_reduced.json \ + --noise-pool-id 0x9d1fcf346ea1b0 """ import argparse +import json +import math from quantammsim.runners.jax_runners import train_on_historic_data PARAMETER_CONFIG = { @@ -39,6 +46,15 @@ def main(): choices=["geometric", "constant_arc_length"]) parser.add_argument("--centeredness-scaling", 
action="store_true") parser.add_argument("--noise-trader-ratio", type=float, default=0.0) + parser.add_argument("--noise-model", default=None, + choices=["ratio", "loglinear", "calibrated", "arb_only"], + help="Noise volume model (default: ratio via noise-trader-ratio)") + parser.add_argument("--noise-params-json", default=None, + help="JSON file with per-pool calibration results") + parser.add_argument("--noise-pool-id", default=None, + help="Pool ID to load noise params for (from --noise-params-json)") + parser.add_argument("--arb-frequency", type=int, default=None, + help="Arb frequency in minutes (default: from calibrated cadence or 1)") parser.add_argument("--start-date", default="2024-06-01 00:00:00") parser.add_argument("--end-date", default="2025-01-01 00:00:00", help="End of training / start of test") @@ -49,6 +65,8 @@ def main(): help="Validation holdout fraction (default: 0.2, use 0 to disable)") parser.add_argument("--overfitting-penalty", type=float, default=None, help="Overfitting penalty weight (default: 0.2)") + parser.add_argument("--n-eval-points", type=int, default=None, + help="Number of evaluation sub-windows (default: 20, use 1 for full-window)") args = parser.parse_args() learn_speed = args.interpolation == "constant_arc_length" @@ -56,19 +74,72 @@ def main(): if learn_speed: param_config.update(ARC_LENGTH_SPEED_CONFIG) + # --- Noise model setup --- + pool_tokens = ["AAVE", "ETH"] # default + noise_fp = {"noise_trader_ratio": args.noise_trader_ratio} + if args.noise_model: + noise_fp["noise_model"] = args.noise_model + if args.noise_params_json and args.noise_pool_id: + with open(args.noise_params_json) as f: + all_results = json.load(f) + # Support both {"option_c_reduced": {pid: ...}} and {pid: ...} formats + pool_results = all_results + for key in all_results: + if isinstance(all_results[key], dict) and args.noise_pool_id in all_results[key]: + pool_results = all_results[key] + break + pool_data = pool_results[args.noise_pool_id] + coeffs 
= pool_data["noise_coeffs"] + if len(coeffs) == 8: + # Full 8-covariate model: [intercept, log_tvl, log_sigma, + # tvl*sigma, tvl*fee, sigma*fee, dow_sin, dow_cos] + noise_fp["reclamm_noise_params"] = { + f"c_{i}": c for i, c in enumerate(coeffs) + } + elif len(coeffs) == 4: + # Reduced 4-covariate model: [intercept, log_tvl, dow_sin, dow_cos] + # Map to c_0, c_1, c_6, c_7 (sigma/fee terms stay at 0) + noise_fp["reclamm_noise_params"] = { + "c_0": coeffs[0], "c_1": coeffs[1], + "c_6": coeffs[2], "c_7": coeffs[3], + } + else: + raise ValueError(f"Expected 4 or 8 noise_coeffs, got {len(coeffs)}") + # Derive arb_frequency from calibrated cadence if not explicitly set + if args.arb_frequency is None: + log_cad = pool_data["log_cadence"] + args.arb_frequency = max(1, round(math.exp(log_cad))) + print(f" arb_frequency={args.arb_frequency} " + f"(from log_cadence={log_cad:.2f}, " + f"cadence={math.exp(log_cad):.1f} min)") + # Override fees and gas cost with the pool's calibrated values when present + if "fee" in pool_data: + args.fees = pool_data["fee"] + if "gas_usd" in pool_data: + args.gas_cost = pool_data["gas_usd"] + # Pick up token pair from calibration + # Map on-chain names (WETH, WBTC) to data-file names (ETH, BTC) + _TOKEN_MAP = {"WETH": "ETH", "WBTC": "BTC"} + if "tokens" in pool_data: + pool_tokens = [ + _TOKEN_MAP.get(t, t) for t in pool_data["tokens"].split(",") + ] + print(f" tokens={pool_tokens}, fee={args.fees}, gas={args.gas_cost}") + fp = { "rule": "reclamm", - "tokens": ["AAVE", "ETH"], + "tokens": pool_tokens, "startDateString": args.start_date, "endDateString": args.end_date, "endTestDateString": args.end_test_date, "initial_pool_value": 1_000_000.0, "do_arb": True, + **({"arb_frequency": args.arb_frequency} if args.arb_frequency is not None else {}), "fees": args.fees, "gas_cost": args.gas_cost, "arb_fees": 0.0, "protocol_fee_split": 0.5, - "noise_trader_ratio": args.noise_trader_ratio, + **noise_fp, "return_val": args.objective, "reclamm_interpolation_method": 
args.interpolation, "reclamm_centeredness_scaling": args.centeredness_scaling, @@ -86,6 +157,7 @@ def main(): "multi_objective": False, "parameter_config": param_config, **({"overfitting_penalty": args.overfitting_penalty} if args.overfitting_penalty is not None else {}), + **({"n_evaluation_points": args.n_eval_points} if args.n_eval_points is not None else {}), }, }, } diff --git a/experiments/validate_tvl_counterfactual.py b/experiments/validate_tvl_counterfactual.py new file mode 100644 index 0000000..d4982b5 --- /dev/null +++ b/experiments/validate_tvl_counterfactual.py @@ -0,0 +1,236 @@ +"""Validate the noise model's TVL counterfactual predictions. + +Uses the AAVE/WETH reClAMM pool's natural experiment (70x TVL increase +from LP deposit in Jan 2026) to check whether the model's combined +V_arb + V_noise prediction matches observed volume changes. + +Tests whether the full model (PCHIP arb grid + per-pool linear noise +with b_tvl on standardized features) produces the right total volume +response, even though the noise-specific elasticity (~0.42 raw) is +lower than the event study's total elasticity (~0.9). + +Also evaluates counterfactual noise volumes at specified TVL levels. 
+ +Usage: + python experiments/validate_tvl_counterfactual.py + python experiments/validate_tvl_counterfactual.py --pool 0x9d1fcf346ea1b0 + python experiments/validate_tvl_counterfactual.py --counterfactual-tvl 1e6 5e6 20e6 50e6 +""" + +import argparse +import json +import os +import pickle + +import numpy as np +import pandas as pd + +import jax.numpy as jnp +from quantammsim.calibration.grid_interpolation import interpolate_pool_daily +from quantammsim.calibration.noise_model_arrays import load_artifact + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) +ARTIFACT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", +) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--pool", default="0x9d1fcf346ea1b0", + help="Pool ID prefix") + parser.add_argument("--artifact-dir", default=ARTIFACT_DIR) + parser.add_argument("--pre-cutoff", default="2026-01-10", + help="Date before which = pre-deposit") + parser.add_argument("--post-cutoff", default="2026-01-20", + help="Date after which = post-deposit") + parser.add_argument("--counterfactual-tvl", type=float, nargs="+", + default=[70_000, 500_000, 5_000_000, 20_000_000], + help="TVL values for counterfactual evaluation") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + # ---- Load model artifact ---- + art, meta = load_artifact(args.artifact_dir) + nc = art["noise_coeffs"] + log_cad = art["log_cadence"] + x_mean = art["x_mean"] + x_std = art["x_std"] + pool_ids = meta["pool_ids"] + feat_names = meta["feat_names"] + per_pool = nc.ndim == 2 + + # ---- Load pool data ---- + with open(os.path.join(CACHE_DIR, "stage1.pkl"), "rb") as f: + data = pickle.load(f) + mc = data["matched_clean"] + oc = data["option_c_clean"] + + pid = args.pool + if pid not in pool_ids: + # 
Try prefix match + matches = [p for p in pool_ids if p.startswith(pid) or pid.startswith(p)] + if matches: + pid = matches[0] + else: + print(f"Pool {args.pool} not found in calibration set") + return + + idx = pool_ids.index(pid) + coeffs = nc[idx] if per_pool else nc + cadence = float(np.exp(log_cad[idx])) + gas = float(np.exp(oc[pid]["log_gas"])) + tvl_col = feat_names.index("xobs_1") + + print("=" * 70) + print("TVL Counterfactual Validation") + print(f" Pool: {pid} ({mc[pid]['tokens']}, {mc[pid]['chain']})") + print(f" Learned cadence: {cadence:.1f} min") + print(f" Gas: ${gas:.2f}") + print(f" b_tvl (standardized): {coeffs[tvl_col]:.4f}") + print(f" TVL standardization: mean={x_mean[tvl_col]:.2f}," + f" std={x_std[tvl_col]:.2f}") + print(f" Raw noise elasticity: {coeffs[tvl_col]/x_std[tvl_col]:.4f}") + print("=" * 70) + + # ---- V_arb from PCHIP ---- + entry = mc[pid] + v_arb_all = np.array(interpolate_pool_daily( + entry["coeffs"], jnp.float64(np.log(cadence)), jnp.float64(gas))) + day_indices = entry["day_indices"] + v_arb = v_arb_all[day_indices] + + panel = entry["panel"] + log_vol = panel["log_volume"].values.astype(float) + log_tvl = panel["log_tvl_lag1"].values.astype(float) + vol_obs = np.exp(log_vol) + tvl = np.exp(log_tvl) + dates = pd.to_datetime(panel["date"]) + + pre_mask = dates < args.pre_cutoff + post_mask = dates >= args.post_cutoff + + # ---- Build full feature vectors ---- + from experiments.run_linear_market_noise import build_data + data_full = build_data(mc, oc, trend_windows=(7,), + include_market=True, include_cross_pool=True) + x_full = data_full["x"] + pool_idx_full = data_full["pool_idx"] + day_idx_full = data_full["day_idx"] + + pool_i = pool_ids.index(pid) + pool_mask = pool_idx_full == pool_i + + all_dates = set() + for p in pool_ids: + all_dates.update(mc[p]["panel"]["date"].values) + date_list = sorted(all_dates) + + sample_dates = np.array([pd.Timestamp(date_list[d]) + for d in day_idx_full[pool_mask]]) + sample_x = 
x_full[pool_mask] + sgd = data_full["sample_grid_days"][pool_mask] + v_arb_samples = v_arb_all[sgd] + + # Per-sample noise prediction + if per_pool: + log_v_noise = sample_x @ coeffs + else: + log_v_noise = sample_x @ coeffs + v_noise = np.exp(log_v_noise) + + sample_pre = sample_dates < pd.Timestamp(args.pre_cutoff) + sample_post = sample_dates >= pd.Timestamp(args.post_cutoff) + + # ---- Pre/post comparison ---- + print(f"\n=== Pre-deposit (before {args.pre_cutoff}) ===") + print(f" Median TVL: ${np.median(tvl[pre_mask]):>14,.0f}") + print(f" Median V_obs: ${np.median(vol_obs[pre_mask]):>14,.0f}") + print(f" Median V_arb: ${np.median(v_arb[pre_mask]):>14,.0f} (PCHIP)") + print(f" Median V_noise: ${np.median(v_noise[sample_pre]):>14,.0f} (model)") + v_total_pre = v_arb_samples[sample_pre] + v_noise[sample_pre] + print(f" Median V_total: ${np.median(v_total_pre):>14,.0f} (V_arb + V_noise)") + + print(f"\n=== Post-deposit (after {args.post_cutoff}) ===") + print(f" Median TVL: ${np.median(tvl[post_mask]):>14,.0f}") + print(f" Median V_obs: ${np.median(vol_obs[post_mask]):>14,.0f}") + print(f" Median V_arb: ${np.median(v_arb[post_mask]):>14,.0f} (PCHIP)") + print(f" Median V_noise: ${np.median(v_noise[sample_post]):>14,.0f} (model)") + v_total_post = v_arb_samples[sample_post] + v_noise[sample_post] + print(f" Median V_total: ${np.median(v_total_post):>14,.0f} (V_arb + V_noise)") + + # ---- Ratios ---- + tvl_ratio = np.median(tvl[post_mask]) / np.median(tvl[pre_mask]) + vol_ratio = np.median(vol_obs[post_mask]) / np.median(vol_obs[pre_mask]) + varb_ratio = np.median(v_arb[post_mask]) / np.median(v_arb[pre_mask]) + vnoise_ratio = np.median(v_noise[sample_post]) / np.median(v_noise[sample_pre]) + vtotal_ratio = np.median(v_total_post) / np.median(v_total_pre) + + print(f"\n=== Ratios (post / pre) ===") + print(f" TVL: {tvl_ratio:>8.1f}x") + print(f" V_obs: {vol_ratio:>8.1f}x (ground truth)") + print(f" V_arb: {varb_ratio:>8.1f}x (PCHIP grid)") + print(f" V_noise: 
{vnoise_ratio:>8.1f}x (noise model)") + print(f" V_total: {vtotal_ratio:>8.1f}x (V_arb + V_noise)") + print(f" Gap: {vtotal_ratio/vol_ratio:>8.2f}x (pred/obs)") + + # ---- Decomposition shares ---- + print(f"\n=== Decomposition shares ===") + arb_share_pre = np.median(v_arb[pre_mask]) / np.median(vol_obs[pre_mask]) * 100 + noise_share_pre = np.median(v_noise[sample_pre]) / np.median(vol_obs[pre_mask]) * 100 + arb_share_post = np.median(v_arb[post_mask]) / np.median(vol_obs[post_mask]) * 100 + noise_share_post = np.median(v_noise[sample_post]) / np.median(vol_obs[post_mask]) * 100 + + print(f" Pre: arb={arb_share_pre:.0f}% noise={noise_share_pre:.0f}%") + print(f" Post: arb={arb_share_post:.0f}% noise={noise_share_post:.0f}%") + + # ---- Counterfactual evaluation ---- + print(f"\n=== Counterfactual noise volumes ===") + print(f" (Using median pre-deposit market features, varying TVL only)") + print(f" {'TVL':>14s} {'V_noise/day':>12s} {'V_noise/min':>12s}" + f" {'Ratio vs 70K':>12s}") + print(f" {'-'*55}") + + x_base = np.median(sample_x[sample_pre], axis=0).copy() + baseline_tvl = 70_000 + x_baseline = x_base.copy() + x_baseline[tvl_col] = (np.log(baseline_tvl) - x_mean[tvl_col]) / x_std[tvl_col] + for i, name in enumerate(feat_names): + if name.startswith("xobs_1" + "\u00d7"): + paired_name = name.split("\u00d7")[1] + if paired_name in feat_names: + paired_idx = feat_names.index(paired_name) + x_baseline[i] = x_baseline[tvl_col] * x_base[paired_idx] + vn_baseline = np.exp(x_baseline @ coeffs) + + for cf_tvl in args.counterfactual_tvl: + x_cf = x_base.copy() + std_log_tvl = (np.log(cf_tvl) - x_mean[tvl_col]) / x_std[tvl_col] + x_cf[tvl_col] = std_log_tvl + for i, name in enumerate(feat_names): + if name.startswith("xobs_1" + "\u00d7"): + paired_name = name.split("\u00d7")[1] + if paired_name in feat_names: + paired_idx = feat_names.index(paired_name) + x_cf[i] = std_log_tvl * x_base[paired_idx] + + vn = np.exp(x_cf @ coeffs) + ratio = vn / vn_baseline + print(f" 
${cf_tvl:>13,.0f} ${vn:>11,.0f} ${vn/1440:>11,.0f}" + f" {ratio:>11.1f}x") + + print(f"\n Key finding: model predicts {vtotal_ratio:.1f}x total volume" + f" increase vs {vol_ratio:.1f}x observed (predicted/observed = {vtotal_ratio/vol_ratio:.0%}).") + print(f" V_arb ({varb_ratio:.0f}x) carries most of the response;" + f" V_noise ({vnoise_ratio:.1f}x) is secondary but adds up.") + + +if __name__ == "__main__": + main() diff --git a/experiments/verify_vol_volume_slope.py b/experiments/verify_vol_volume_slope.py new file mode 100644 index 0000000..1758bd7 --- /dev/null +++ b/experiments/verify_vol_volume_slope.py @@ -0,0 +1,372 @@ +"""Verify: is the volatility-volume slope identical across fee tiers? + +The claim (from noise_calibration_review.md): "the relationship between +price volatility and swap volume is identical across fee tiers (slope 0.91 +for both low-fee and high-fee pools)." + +This script tests the claim by: +1. Loading all pool panel data +2. Splitting pools by fee tier (low vs high) +3. Regressing log(volume) on log(volatility) within each group +4. Comparing slopes + +If the slopes are similar, it means volatility drives organic volume +identically regardless of fee, supporting the arb/noise decomposition +(since arb intensity differs across fee tiers but noise doesn't). 
+ +Usage: + python experiments/verify_vol_volume_slope.py +""" + +import os +import pickle +import sys + +import numpy as np +import pandas as pd + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def main(): + path = os.path.join(CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + print("ERROR: no stage1 cache.") + sys.exit(1) + with open(path, "rb") as f: + data = pickle.load(f) + matched_clean = data["matched_clean"] + + # Collect per-observation data + rows = [] + for pid, entry in matched_clean.items(): + panel = entry["panel"] + fee = entry.get("fee", np.exp(panel["log_fee"].values[0])) + chain = entry.get("chain", "unknown") + tokens = entry.get("tokens", "?") + + log_vol = panel["log_volume"].values.astype(float) + vol_raw = panel["volatility"].values.astype(float) + log_tvl = panel["log_tvl_lag1"].values.astype(float) + + for i in range(len(log_vol)): + if vol_raw[i] > 1e-10 and np.isfinite(log_vol[i]): + rows.append({ + "pool_id": pid, + "tokens": tokens, + "chain": chain, + "fee": fee, + "log_fee": np.log(fee), + "log_volume": log_vol[i], + "log_sigma": np.log(max(vol_raw[i], 1e-10)), + "log_tvl": log_tvl[i] if np.isfinite(log_tvl[i]) else np.nan, + }) + + df = pd.DataFrame(rows).dropna() + print(f"Loaded {len(df)} observations from {df['pool_id'].nunique()} pools") + + # Fee tier split + fees = df.groupby("pool_id")["fee"].first() + median_fee = fees.median() + print(f"\nFee distribution:") + print(f" min={fees.min():.5f} median={median_fee:.5f} max={fees.max():.5f}") + print(f" Unique fees: {sorted(fees.unique())}") + + low_fee_pools = set(fees[fees <= median_fee].index) + high_fee_pools = set(fees[fees > median_fee].index) + print(f" Low-fee pools (≤{median_fee:.4f}): {len(low_fee_pools)}") + print(f" High-fee pools (>{median_fee:.4f}): {len(high_fee_pools)}") + + df_low = df[df["pool_id"].isin(low_fee_pools)] + df_high = df[df["pool_id"].isin(high_fee_pools)] + 
+ # OLS: log_volume ~ intercept + log_sigma + def ols_slope(x, y): + X = np.column_stack([np.ones(len(x)), x]) + beta = np.linalg.lstsq(X, y, rcond=None)[0] + y_hat = X @ beta + ss_res = np.sum((y - y_hat) ** 2) + ss_tot = np.sum((y - y.mean()) ** 2) + r2 = 1 - ss_res / ss_tot + # Standard error of slope + n = len(x) + se = np.sqrt(ss_res / (n - 2) / np.sum((x - x.mean()) ** 2)) + return beta[1], beta[0], r2, se + + print(f"\n{'='*60}") + print("OLS: log(volume) ~ intercept + log(sigma)") + print(f"{'='*60}") + + # All pools + slope, intercept, r2, se = ols_slope(df["log_sigma"].values, df["log_volume"].values) + print(f"\n All pools ({len(df)} obs):") + print(f" slope = {slope:.4f} ± {1.96*se:.4f} (95% CI)") + print(f" R² = {r2:.4f}") + + # Low fee + slope_l, int_l, r2_l, se_l = ols_slope(df_low["log_sigma"].values, df_low["log_volume"].values) + print(f"\n Low-fee pools ({len(df_low)} obs, {len(low_fee_pools)} pools):") + print(f" slope = {slope_l:.4f} ± {1.96*se_l:.4f}") + print(f" R² = {r2_l:.4f}") + + # High fee + slope_h, int_h, r2_h, se_h = ols_slope(df_high["log_sigma"].values, df_high["log_volume"].values) + print(f"\n High-fee pools ({len(df_high)} obs, {len(high_fee_pools)} pools):") + print(f" slope = {slope_h:.4f} ± {1.96*se_h:.4f}") + print(f" R² = {r2_h:.4f}") + + print(f"\n Difference: {abs(slope_l - slope_h):.4f}") + print(f" Ratio: {slope_l/slope_h:.3f}") + + # Per-pool slopes + print(f"\n{'='*60}") + print("Per-pool OLS slopes") + print(f"{'='*60}") + pool_slopes = [] + print(f"\n {'Pool':>16s} {'Tokens':>20s} {'Fee':>8s} {'Slope':>8s}" + f" {'R²':>6s} {'N':>5s}") + for pid in sorted(matched_clean.keys()): + pool_df = df[df["pool_id"] == pid] + if len(pool_df) < 20: + continue + s, _, r, se = ols_slope(pool_df["log_sigma"].values, pool_df["log_volume"].values) + fee = pool_df["fee"].iloc[0] + tokens = pool_df["tokens"].iloc[0] + pool_slopes.append({"pool_id": pid, "tokens": tokens, "fee": fee, + "slope": s, "r2": r, "n": len(pool_df)}) + print(f" 
{pid[:16]} {tokens:>20s} {fee:>8.5f} {s:>8.4f}" + f" {r:>6.3f} {len(pool_df):>5d}") + + ps = pd.DataFrame(pool_slopes) + if len(ps) > 0: + low_slopes = ps[ps["fee"] <= median_fee]["slope"] + high_slopes = ps[ps["fee"] > median_fee]["slope"] + print(f"\n Per-pool slope summary:") + print(f" Low-fee: median={low_slopes.median():.4f}," + f" mean={low_slopes.mean():.4f} (n={len(low_slopes)})") + print(f" High-fee: median={high_slopes.median():.4f}," + f" mean={high_slopes.mean():.4f} (n={len(high_slopes)})") + + # Also try with TVL control: log_volume ~ log_sigma + log_tvl + print(f"\n{'='*60}") + print("OLS: log(volume) ~ intercept + log(sigma) + log(tvl)") + print(f"{'='*60}") + + def ols_multi(df_sub): + x = np.column_stack([ + np.ones(len(df_sub)), + df_sub["log_sigma"].values, + df_sub["log_tvl"].values, + ]) + y = df_sub["log_volume"].values + beta = np.linalg.lstsq(x, y, rcond=None)[0] + y_hat = x @ beta + ss_res = np.sum((y - y_hat) ** 2) + ss_tot = np.sum((y - y.mean()) ** 2) + return beta, 1 - ss_res / ss_tot + + beta_all, r2_all = ols_multi(df) + print(f"\n All: σ_slope={beta_all[1]:.4f}, tvl_slope={beta_all[2]:.4f}, R²={r2_all:.4f}") + + beta_l, r2_l = ols_multi(df_low) + print(f" Low: σ_slope={beta_l[1]:.4f}, tvl_slope={beta_l[2]:.4f}, R²={r2_l:.4f}") + + beta_h, r2_h = ols_multi(df_high) + print(f" High: σ_slope={beta_h[1]:.4f}, tvl_slope={beta_h[2]:.4f}, R²={r2_h:.4f}") + + print(f"\n σ slope difference (low-high): {beta_l[1] - beta_h[1]:.4f}") + + + # ---- Volume/TVL vs TVL (cross-pool) ---- + print(f"\n{'='*60}") + print("Volume/TVL vs TVL (cross-pool)") + print(f"{'='*60}") + + # Per-pool median volume and TVL + pool_stats = [] + for pid in sorted(matched_clean.keys()): + pool_df = df[df["pool_id"] == pid] + if len(pool_df) < 20: + continue + med_vol = np.exp(np.median(pool_df["log_volume"].values)) + med_tvl = np.exp(np.median(pool_df["log_tvl"].values)) + tokens = pool_df["tokens"].iloc[0] + fee = pool_df["fee"].iloc[0] + vol_tvl = med_vol / med_tvl 
+ pool_stats.append({ + "pool_id": pid, "tokens": tokens, "fee": fee, + "med_vol": med_vol, "med_tvl": med_tvl, + "vol_tvl_pct": vol_tvl * 100, + "log_med_tvl": np.log(med_tvl), + }) + + ps = pd.DataFrame(pool_stats) + print(f"\n {'Pool':>16s} {'Tokens':>20s} {'TVL':>14s} {'Vol/day':>14s} {'Vol/TVL':>8s}") + for _, row in ps.sort_values("med_tvl").iterrows(): + print(f" {row['pool_id'][:16]} {row['tokens']:>20s}" + f" ${row['med_tvl']:>13,.0f} ${row['med_vol']:>13,.0f}" + f" {row['vol_tvl_pct']:>7.1f}%") + + # OLS: log(vol/tvl) ~ log(tvl) + log_vol_tvl = np.log(ps["med_vol"].values / ps["med_tvl"].values) + log_tvl_vals = ps["log_med_tvl"].values + slope_vt, int_vt, r2_vt, se_vt = ols_slope(log_tvl_vals, log_vol_tvl) + print(f"\n OLS: log(Vol/TVL) ~ log(TVL)") + print(f" slope = {slope_vt:.4f} ± {1.96*se_vt:.4f}") + print(f" R² = {r2_vt:.4f}") + print(f" (slope < 0 means Vol/TVL declines with TVL)") + + # Equivalent: log(Vol) ~ α + β*log(TVL), β < 1 means sublinear + slope_v, int_v, r2_v, se_v = ols_slope(log_tvl_vals, np.log(ps["med_vol"].values)) + print(f"\n OLS: log(Vol) ~ log(TVL)") + print(f" slope = {slope_v:.4f} ± {1.96*se_v:.4f}") + print(f" R² = {r2_v:.4f}") + print(f" (slope < 1 means sublinear = Vol/TVL declines)") + + # ---- TVL elasticity by TVL quartile (MM signature) ---- + print(f"\n{'='*60}") + print("TVL Elasticity by TVL Quartile") + print(f"{'='*60}") + + # Use observation-level data, not pool medians — more power + # Within-quartile regression: log(vol) ~ log(tvl) for pools in each bin + ps_sorted = ps.sort_values("med_tvl") + n_q = len(ps_sorted) // 4 + quartiles = [] + for q in range(4): + start = q * n_q + end = (q + 1) * n_q if q < 3 else len(ps_sorted) + q_pools = set(ps_sorted.iloc[start:end]["pool_id"]) + q_df = df[df["pool_id"].isin(q_pools)] + if len(q_df) < 20: + continue + s, intercept, r2, se = ols_slope(q_df["log_tvl"].values, + q_df["log_volume"].values) + tvl_lo = np.exp(q_df["log_tvl"].min()) + tvl_hi = 
np.exp(q_df["log_tvl"].max()) + tvl_med = np.exp(q_df["log_tvl"].median()) + quartiles.append({ + "q": q + 1, "n_pools": len(q_pools), "n_obs": len(q_df), + "tvl_lo": tvl_lo, "tvl_hi": tvl_hi, "tvl_med": tvl_med, + "slope": s, "se": se, "r2": r2, + }) + print(f"\n Q{q+1}: TVL ${tvl_lo:,.0f} – ${tvl_hi:,.0f}" + f" (median ${tvl_med:,.0f})") + print(f" {len(q_pools)} pools, {len(q_df)} obs") + print(f" slope = {s:.4f} ± {1.96*se:.4f} R² = {r2:.4f}") + + if len(quartiles) >= 2: + print(f"\n Summary:") + print(f" {'Quartile':>10s} {'Med TVL':>14s} {'Slope':>8s} {'95% CI':>16s}") + for q in quartiles: + ci = f"[{q['slope']-1.96*q['se']:.3f}, {q['slope']+1.96*q['se']:.3f}]" + print(f" Q{q['q']:>9d} ${q['tvl_med']:>13,.0f} {q['slope']:>8.4f} {ci:>16s}") + + slope_q1 = quartiles[0]["slope"] + slope_q4 = quartiles[-1]["slope"] + print(f"\n Q1→Q4 slope change: {slope_q4 - slope_q1:+.4f}") + print(f" (Negative = elasticity declines with TVL = MM signature)") + + # Also try: rolling window across pools sorted by TVL + print(f"\n Rolling 10-pool window:") + print(f" {'Window':>8s} {'Med TVL':>14s} {'Slope':>8s} {'R²':>6s}") + window = 10 + for start_i in range(0, len(ps_sorted) - window + 1, 3): + w_pools = set(ps_sorted.iloc[start_i:start_i + window]["pool_id"]) + w_df = df[df["pool_id"].isin(w_pools)] + if len(w_df) < 30: + continue + s, _, r2, se = ols_slope(w_df["log_tvl"].values, + w_df["log_volume"].values) + tvl_med = np.exp(w_df["log_tvl"].median()) + print(f" {start_i:>3d}-{start_i+window:>3d} ${tvl_med:>13,.0f} {s:>8.4f} {r2:>6.3f}") + + # Plot + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(1, 3, figsize=(16, 5)) + + # Panel 1: Vol/TVL vs TVL + ax = axes[0] + ax.scatter(ps["med_tvl"] / 1e6, ps["vol_tvl_pct"], + s=30, alpha=0.7, c="steelblue") + for _, row in ps.iterrows(): + ax.annotate(row["tokens"].split(",")[0], + (row["med_tvl"] / 1e6, row["vol_tvl_pct"]), + fontsize=5, alpha=0.6) + 
ax.set_xscale("log") + ax.set_xlabel("Median TVL ($M)") + ax.set_ylabel("Median Vol/TVL (%)") + ax.set_title(f"Volume/TVL Declines with TVL\n" + f"log(Vol/TVL) ~ {slope_vt:.2f}·log(TVL), R²={r2_vt:.2f}") + ax.grid(True, alpha=0.3) + + # Fit line + tvl_fit = np.logspace(np.log10(ps["med_tvl"].min()), + np.log10(ps["med_tvl"].max()), 100) + vol_tvl_fit = np.exp(int_vt + slope_vt * np.log(tvl_fit)) * 100 + ax.plot(tvl_fit / 1e6, vol_tvl_fit, "r--", linewidth=1, alpha=0.7) + + # Panel 2: Vol vs TVL (log-log) + ax = axes[1] + ax.scatter(ps["med_tvl"] / 1e6, ps["med_vol"] / 1e6, + s=30, alpha=0.7, c="coral") + for _, row in ps.iterrows(): + ax.annotate(row["tokens"].split(",")[0], + (row["med_tvl"] / 1e6, row["med_vol"] / 1e6), + fontsize=5, alpha=0.6) + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel("Median TVL ($M)") + ax.set_ylabel("Median Daily Volume ($M)") + ax.set_title(f"Volume vs TVL (cross-pool)\n" + f"log(Vol) ~ {slope_v:.2f}·log(TVL), R²={r2_v:.2f}") + ax.grid(True, alpha=0.3) + + # Fit line + linear reference + vol_fit = np.exp(int_v + slope_v * np.log(tvl_fit)) + ax.plot(tvl_fit / 1e6, vol_fit / 1e6, "r--", linewidth=1, + alpha=0.7, label=f"slope={slope_v:.2f}") + # Linear reference (slope=1) + vol_linear = np.exp(int_v + 1.0 * np.log(tvl_fit)) + ax.plot(tvl_fit / 1e6, vol_linear / 1e6, "k:", linewidth=0.5, + alpha=0.3, label="slope=1 (linear)") + ax.legend(fontsize=8) + + # Panel 3: by fee tier + ax = axes[2] + for _, row in ps.iterrows(): + color = "steelblue" if row["fee"] <= median_fee else "coral" + ax.scatter(row["med_tvl"] / 1e6, row["vol_tvl_pct"], + s=30, alpha=0.7, c=color) + ax.annotate(row["tokens"].split(",")[0], + (row["med_tvl"] / 1e6, row["vol_tvl_pct"]), + fontsize=5, alpha=0.6) + ax.set_xscale("log") + ax.set_xlabel("Median TVL ($M)") + ax.set_ylabel("Median Vol/TVL (%)") + ax.set_title("Vol/TVL by Fee Tier\n" + f"blue=low fee (≤{median_fee:.4f}), red=high fee") + ax.grid(True, alpha=0.3) + + fig.suptitle("Cross-Pool Evidence 
for Volume Saturation", fontsize=13) + fig.tight_layout() + out = os.path.join(os.path.dirname(os.path.dirname(__file__)), + "results", "mm_noise", "plots", + "cross_pool_vol_tvl.png") + os.makedirs(os.path.dirname(out), exist_ok=True) + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"\n Saved: {out}") + except Exception as e: + print(f" Plot failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/quantammsim/calibration/__init__.py b/quantammsim/calibration/__init__.py index a888758..48a556a 100644 --- a/quantammsim/calibration/__init__.py +++ b/quantammsim/calibration/__init__.py @@ -39,3 +39,15 @@ build_x_obs, match_grids_to_panel, ) +from quantammsim.calibration.calibration_model import CalibrationModel +from quantammsim.calibration.heads import ( + FixedHead, + Head, + LinearHead, + MLPHead, + MLPNoiseHead, + PerPoolHead, + PerPoolNoiseHead, + SharedLinearNoiseHead, +) +from quantammsim.calibration.loss import _compute_loss_huber diff --git a/quantammsim/calibration/calibration_model.py b/quantammsim/calibration/calibration_model.py new file mode 100644 index 0000000..86471ec --- /dev/null +++ b/quantammsim/calibration/calibration_model.py @@ -0,0 +1,358 @@ +"""Composable CalibrationModel with pluggable Head components. + +The CalibrationModel coordinates three heads (cadence, gas, noise) and +provides: + - Parameter packing/unpacking across all heads + - Per-pool JIT-compiled loss closures (same pattern as existing code) + - Joint loss aggregation with head regularization + - scipy L-BFGS-B fitting for both per-pool and joint modes + - Prediction for new pools + +All heads are concatenated in order [cadence | gas | noise] in the flat +parameter vector. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Callable, Dict, Optional + +import jax +import jax.numpy as jnp +import numpy as np +import scipy.optimize + +from quantammsim.calibration.grid_interpolation import interpolate_pool_daily +from quantammsim.calibration.heads import Head +from quantammsim.calibration.loss import K_OBS, noise_volume + + +@dataclass +class CalibrationModel: + """Composable calibration model with pluggable heads. + + Coordinates cadence_head, gas_head, and noise_head to build a single + flat parameter vector and produce per-pool JIT-compiled loss functions. + """ + + cadence_head: Head + gas_head: Head + noise_head: Head + loss_type: str = "l2" + huber_delta: float = 1.5 + + # ── Parameter geometry ───────────────────────────────────────────── + + def n_params(self, n_pools: int, k_attr: int) -> int: + """Total parameter count across all heads.""" + return ( + self.cadence_head.n_params(n_pools, k_attr) + + self.gas_head.n_params(n_pools, k_attr) + + self.noise_head.n_params(n_pools, k_attr) + ) + + def _head_slices(self, n_pools: int, k_attr: int): + """Return (start, end) index pairs for each head's param slice.""" + n_cad = self.cadence_head.n_params(n_pools, k_attr) + n_gas = self.gas_head.n_params(n_pools, k_attr) + n_noise = self.noise_head.n_params(n_pools, k_attr) + cad_end = n_cad + gas_end = cad_end + n_gas + noise_end = gas_end + n_noise + return (0, cad_end), (cad_end, gas_end), (gas_end, noise_end) + + # ── Initialization ───────────────────────────────────────────────── + + def pack_init(self, jdata, warm_start=None) -> np.ndarray: + """Concatenate head inits into a single flat NumPy vector.""" + cad_init = self.cadence_head.init(jdata, warm_start) + gas_init = self.gas_head.init(jdata, warm_start) + noise_init = self.noise_head.init(jdata, warm_start) + return np.concatenate([cad_init, gas_init, noise_init]) + + # ── Bounds 
───────────────────────────────────────────────────────── + + def make_bounds(self, n_pools: int, k_attr: int) -> list: + """Concatenate per-head scipy bounds.""" + return ( + self.cadence_head.make_bounds(n_pools, k_attr) + + self.gas_head.make_bounds(n_pools, k_attr) + + self.noise_head.make_bounds(n_pools, k_attr) + ) + + # ── Loss functions ───────────────────────────────────────────────── + + def _compute_loss(self, residuals: jnp.ndarray) -> jnp.ndarray: + """Compute loss from residuals based on loss_type.""" + if self.loss_type == "huber": + delta = self.huber_delta + abs_r = jnp.abs(residuals) + huber = jnp.where( + abs_r <= delta, + 0.5 * residuals ** 2, + delta * (abs_r - 0.5 * delta), + ) + return jnp.mean(huber) + return jnp.mean(residuals ** 2) + + def make_pool_loss_fn( + self, + pool_idx: int, + pool_data_i: dict, + x_attr_i: jnp.ndarray, + n_pools: int, + k_attr: int, + ) -> Callable: + """Create a JIT-compiled loss function for a single pool. + + Closes over pool-specific data. Takes params_flat as sole argument. + Returns scalar loss (no regularization — that's added at aggregate level). 
+ """ + coeffs = pool_data_i["coeffs"] + x_obs = pool_data_i["x_obs"] + y_obs = pool_data_i["y_obs"] + day_indices = pool_data_i["day_indices"] + + (cad_s, cad_e), (gas_s, gas_e), (noise_s, noise_e) = \ + self._head_slices(n_pools, k_attr) + + cad_head = self.cadence_head + gas_head = self.gas_head + noise_head = self.noise_head + compute_loss = self._compute_loss + i = pool_idx + + @jax.jit + def pool_loss_fn(params_flat): + cad_slice = params_flat[cad_s:cad_e] + gas_slice = params_flat[gas_s:gas_e] + noise_slice = params_flat[noise_s:noise_e] + + log_cad = cad_head.predict(cad_slice, i, x_attr_i) + log_gas = gas_head.predict(gas_slice, i, x_attr_i) + noise_c = noise_head.predict(noise_slice, i, x_attr_i) + + v_arb_all = interpolate_pool_daily( + coeffs, log_cad, jnp.exp(log_gas) + ) + v_arb = v_arb_all[day_indices] + v_noise = jnp.exp(x_obs @ noise_c) + log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + + return compute_loss(log_v_pred - y_obs) + + return pool_loss_fn + + def make_joint_loss_fn(self, jdata) -> Callable: + """Create the joint loss function over all pools. + + Returns loss_fn(params_flat) -> scalar. Also attaches helper + attributes for the scipy wrapper (_pool_val_and_grad_fns, etc.). 
+ """ + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + + pool_loss_fns = [] + pool_val_and_grad_fns = [] + for i in range(n_pools): + fn = self.make_pool_loss_fn( + i, jdata.pool_data[i], jdata.x_attr[i], n_pools, k_attr + ) + pool_loss_fns.append(fn) + pool_val_and_grad_fns.append(jax.value_and_grad(fn)) + + (cad_s, cad_e), (gas_s, gas_e), (noise_s, noise_e) = \ + self._head_slices(n_pools, k_attr) + + cad_head = self.cadence_head + gas_head = self.gas_head + noise_head = self.noise_head + + def loss_fn(params_flat): + total = sum(fn(params_flat) for fn in pool_loss_fns) + data_loss = total / n_pools + + reg = cad_head.regularization(params_flat[cad_s:cad_e]) + reg = reg + gas_head.regularization(params_flat[gas_s:gas_e]) + reg = reg + noise_head.regularization(params_flat[noise_s:noise_e]) + + return data_loss + reg + + # Attach for the scipy wrapper + loss_fn._pool_loss_fns = pool_loss_fns + loss_fn._pool_val_and_grad_fns = pool_val_and_grad_fns + loss_fn._n_pools = n_pools + loss_fn._head_slices = (cad_s, cad_e), (gas_s, gas_e), (noise_s, noise_e) + loss_fn._cad_head = cad_head + loss_fn._gas_head = gas_head + loss_fn._noise_head = noise_head + + return loss_fn + + # ── Fitting ──────────────────────────────────────────────────────── + + def fit( + self, + jdata, + maxiter: int = 500, + warm_start: Optional[Dict[str, dict]] = None, + ) -> dict: + """Fit the model on joint data via L-BFGS-B. + + Returns a result dict with fitted parameters and diagnostics. 
+ """ + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + + loss_fn = self.make_joint_loss_fn(jdata) + init = self.pack_init(jdata, warm_start) + bounds = self.make_bounds(n_pools, k_attr) + + pool_vg_fns = loss_fn._pool_val_and_grad_fns + (cad_s, cad_e), (gas_s, gas_e), (noise_s, noise_e) = \ + loss_fn._head_slices + + cad_head = self.cadence_head + gas_head = self.gas_head + noise_head = self.noise_head + + def scipy_wrapper(params_np): + params_j = jnp.array(params_np) + + total_val = 0.0 + total_grad = jnp.zeros_like(params_j) + for vg_fn in pool_vg_fns: + v, g = vg_fn(params_j) + total_val += float(v) + total_grad = total_grad + g + + data_loss = total_val / n_pools + data_grad = total_grad / n_pools + + # Regularization — compute value and gradient + reg_val = 0.0 + reg_grad = jnp.zeros_like(params_j) + + # Cadence head regularization + cad_slice = params_j[cad_s:cad_e] + if cad_e > cad_s: + cad_reg_fn = lambda p: cad_head.regularization(p) + cr = float(cad_reg_fn(cad_slice)) + if cr != 0.0: + cad_rg = jax.grad(cad_reg_fn)(cad_slice) + reg_val += cr + reg_grad = reg_grad.at[cad_s:cad_e].set(cad_rg) + + # Gas head regularization + gas_slice = params_j[gas_s:gas_e] + if gas_e > gas_s: + gas_reg_fn = lambda p: gas_head.regularization(p) + gr = float(gas_reg_fn(gas_slice)) + if gr != 0.0: + gas_rg = jax.grad(gas_reg_fn)(gas_slice) + reg_val += gr + reg_grad = reg_grad.at[gas_s:gas_e].set(gas_rg) + + # Noise head regularization + noise_slice = params_j[noise_s:noise_e] + if noise_e > noise_s: + noise_reg_fn = lambda p: noise_head.regularization(p) + nr = float(noise_reg_fn(noise_slice)) + if nr != 0.0: + noise_rg = jax.grad(noise_reg_fn)(noise_slice) + reg_val += nr + reg_grad = reg_grad.at[noise_s:noise_e].set(noise_rg) + + val = data_loss + reg_val + grad = data_grad + reg_grad + return val, np.array(grad, dtype=np.float64) + + init_np = np.array(init, dtype=np.float64) + init_loss = float(loss_fn(jnp.array(init_np))) + + result = 
scipy.optimize.minimize( + scipy_wrapper, + init_np, + method="L-BFGS-B", + jac=True, + bounds=bounds, + options={"maxiter": maxiter, "ftol": 1e-10, "gtol": 1e-8}, + ) + + fitted = jnp.array(result.x) + (cad_s, cad_e), (gas_s, gas_e), (noise_s, noise_e) = \ + self._head_slices(n_pools, k_attr) + + # Compute data_loss and reg_loss at optimum + fitted_j = jnp.array(result.x) + data_loss_val = sum( + float(fn(fitted_j)) for fn in loss_fn._pool_loss_fns + ) / n_pools + reg_loss_val = float(result.fun) - data_loss_val + + out = { + "init_loss": init_loss, + "loss": float(result.fun), + "data_loss": data_loss_val, + "reg_loss": reg_loss_val, + "converged": result.success, + "params_flat": np.array(result.x), + } + + # Unpack each head's result + out.update(self.cadence_head.unpack_result( + np.array(fitted[cad_s:cad_e]), n_pools, k_attr)) + out.update(self.gas_head.unpack_result( + np.array(fitted[gas_s:gas_e]), n_pools, k_attr)) + out.update(self.noise_head.unpack_result( + np.array(fitted[noise_s:noise_e]), n_pools, k_attr)) + + out["pool_ids"] = jdata.pool_ids + out["attr_names"] = jdata.attr_names + out["k_attr"] = k_attr + out["n_pools"] = n_pools + + return out + + # ── Prediction ───────────────────────────────────────────────────── + + def predict_new_pool( + self, + result: dict, + x_attr: np.ndarray, + ) -> dict: + """Predict simulator settings for a new pool. + + Delegates to each head's predict_new. Heads that can't + generalize (PerPoolHead, FixedHead) will raise ValueError. 
+ """ + n_pools = result["n_pools"] + k_attr = result["k_attr"] + params = result["params_flat"] + + (cad_s, cad_e), (gas_s, gas_e), (noise_s, noise_e) = \ + self._head_slices(n_pools, k_attr) + + log_cadence = self.cadence_head.predict_new( + params[cad_s:cad_e], x_attr + ) + log_gas = self.gas_head.predict_new( + params[gas_s:gas_e], x_attr + ) + + out = { + "log_cadence": float(log_cadence), + "log_gas": float(log_gas), + "cadence_minutes": float(np.exp(log_cadence)), + "gas_usd": float(np.exp(log_gas)), + } + + try: + noise_coeffs = self.noise_head.predict_new( + params[noise_s:noise_e], x_attr + ) + out["noise_coeffs"] = np.array(noise_coeffs) + except ValueError: + pass # PerPoolNoiseHead can't generalize + + return out diff --git a/quantammsim/calibration/heads.py b/quantammsim/calibration/heads.py new file mode 100644 index 0000000..b051315 --- /dev/null +++ b/quantammsim/calibration/heads.py @@ -0,0 +1,917 @@ +"""Pluggable Head components for the composable CalibrationModel. + +Each Head encapsulates a specific parameterization strategy (per-pool, +fixed, linear) for one of the three model components: cadence, gas, or noise. + +Heads define how many parameters they need, how to predict from a parameter +slice, and how to compute regularization. The CalibrationModel concatenates +head parameter slices into a single flat vector for scipy L-BFGS-B. 
+""" + +from __future__ import annotations + +from typing import Optional, Protocol, runtime_checkable + +import jax.numpy as jnp +import numpy as np + +from quantammsim.calibration.loss import K_OBS +from quantammsim.calibration.pool_data import ( + D_TOKEN, K_OBS_REDUCED, _canonicalize_token, _classify_token, + _load_token_mcaps, +) + + +# --------------------------------------------------------------------------- +# Protocol +# --------------------------------------------------------------------------- + + +@runtime_checkable +class Head(Protocol): + """Protocol that all head implementations must satisfy.""" + + name: str + + def n_params(self, n_pools: int, k_attr: int) -> int: + """Number of scalar parameters this head contributes.""" + ... + + def predict( + self, + params_slice: jnp.ndarray, + pool_idx: int, + x_attr_i: jnp.ndarray, + ) -> jnp.ndarray: + """Predict value(s) for *pool_idx* given its attribute vector. + + Called inside a JIT-compiled per-pool closure, so this must be + JAX-traceable. Returns a scalar for cadence/gas heads, or a + (K_OBS,) vector for noise heads. + """ + ... + + def regularization(self, params_slice: jnp.ndarray) -> jnp.ndarray: + """Scalar regularization penalty added to the joint loss.""" + ... + + def init( + self, + jdata, + warm_start: Optional[dict] = None, + ) -> np.ndarray: + """Return initial NumPy parameter vector (flat).""" + ... + + def predict_new( + self, + params_slice: np.ndarray, + x_attr: np.ndarray, + ) -> np.ndarray: + """Predict for a *new* pool not seen during training (NumPy).""" + ... + + def unpack_result( + self, + params_slice: np.ndarray, + n_pools: int, + k_attr: int, + ) -> dict: + """Convert the optimized parameter slice to human-readable dict.""" + ... + + def make_bounds(self, n_pools: int, k_attr: int) -> list: + """Scipy (lo, hi) bounds for each parameter.""" + ... 
+ + +# --------------------------------------------------------------------------- +# PerPoolHead — one free scalar per pool (Option C cadence / gas) +# --------------------------------------------------------------------------- + + +class PerPoolHead: + """One free scalar parameter per pool. + + Used for Option C per-pool cadence or gas. + """ + + def __init__(self, name: str, default: float = 0.0): + self.name = name + self._default = default + + def n_params(self, n_pools: int, k_attr: int) -> int: + return n_pools + + def predict( + self, + params_slice: jnp.ndarray, + pool_idx: int, + x_attr_i: jnp.ndarray, + ) -> jnp.ndarray: + return params_slice[pool_idx] + + def regularization(self, params_slice: jnp.ndarray) -> jnp.ndarray: + return jnp.float32(0.0) + + def init(self, jdata, warm_start=None) -> np.ndarray: + n_pools = len(jdata.pool_data) + if warm_start is not None: + vals = [] + for pid in jdata.pool_ids: + if pid in warm_start and self.name in warm_start[pid]: + vals.append(warm_start[pid][self.name]) + else: + vals.append(self._default) + return np.array(vals, dtype=np.float64) + return np.full(n_pools, self._default, dtype=np.float64) + + def predict_new(self, params_slice, x_attr): + raise ValueError( + f"PerPoolHead('{self.name}') cannot predict for unseen pools" + ) + + def unpack_result(self, params_slice, n_pools, k_attr): + return {f"{self.name}_per_pool": np.array(params_slice)} + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * n_pools + + +# --------------------------------------------------------------------------- +# FixedHead — zero parameters, returns pre-set values +# --------------------------------------------------------------------------- + + +class FixedHead: + """Zero-parameter head that returns pre-set per-pool values. + + Used when gas is fixed to known chain-level costs. 
+ """ + + def __init__(self, name: str, values: np.ndarray): + self.name = name + self._values = np.asarray(values, dtype=np.float64) + self._values_jax = jnp.array(self._values) + + def n_params(self, n_pools: int, k_attr: int) -> int: + return 0 + + def predict(self, params_slice, pool_idx, x_attr_i): + return self._values_jax[pool_idx] + + def regularization(self, params_slice): + return jnp.float32(0.0) + + def init(self, jdata, warm_start=None): + return np.array([], dtype=np.float64) + + def predict_new(self, params_slice, x_attr): + raise ValueError( + f"FixedHead('{self.name}') cannot predict for unseen pools — " + "values are pool-specific" + ) + + def unpack_result(self, params_slice, n_pools, k_attr): + return {f"{self.name}_fixed": np.array(self._values)} + + def make_bounds(self, n_pools, k_attr): + return [] + + +# --------------------------------------------------------------------------- +# LinearHead — bias + x_attr @ W (Option A cadence / gas) +# --------------------------------------------------------------------------- + + +class LinearHead: + """Linear mapping from pool attributes: bias + x_attr @ W. + + L2 regularization on W (not bias) with strength ``alpha``. 
+ """ + + def __init__(self, name: str, alpha: float = 0.01, + output_lo: float = None, output_hi: float = None): + self.name = name + self.alpha = alpha + self.output_lo = output_lo + self.output_hi = output_hi + + def n_params(self, n_pools: int, k_attr: int) -> int: + return 1 + k_attr # bias + W + + def predict(self, params_slice, pool_idx, x_attr_i): + bias = params_slice[0] + W = params_slice[1:] + out = bias + jnp.dot(x_attr_i, W) + if self.output_lo is not None or self.output_hi is not None: + out = jnp.clip(out, self.output_lo, self.output_hi) + return out + + def regularization(self, params_slice): + W = params_slice[1:] + return self.alpha * jnp.sum(W ** 2) + + def init(self, jdata, warm_start=None): + k_attr = jdata.x_attr.shape[1] + n_pools = len(jdata.pool_data) + + if warm_start is not None: + # Fit linear regression from per-pool values + vals = [] + for pid in jdata.pool_ids: + if pid in warm_start and self.name in warm_start[pid]: + vals.append(warm_start[pid][self.name]) + else: + vals.append(self._default_bias()) + y = np.array(vals) + X_aug = np.column_stack([np.ones(n_pools), np.array(jdata.x_attr)]) + params, _, _, _ = np.linalg.lstsq(X_aug, y, rcond=None) + return params.astype(np.float64) + + init = np.zeros(1 + k_attr, dtype=np.float64) + init[0] = self._default_bias() + return init + + def _default_bias(self): + if "cad" in self.name: + return np.log(12.0) + elif "gas" in self.name: + return np.log(1.0) + return 0.0 + + def predict_new(self, params_slice, x_attr): + bias = params_slice[0] + W = params_slice[1:] + return bias + np.dot(x_attr, W) + + def unpack_result(self, params_slice, n_pools, k_attr): + return { + f"bias_{self.name}": float(params_slice[0]), + f"W_{self.name}": np.array(params_slice[1:]), + } + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * (1 + k_attr) + + +# --------------------------------------------------------------------------- +# PerPoolNoiseHead — K_OBS free coefficients per pool +# 
--------------------------------------------------------------------------- + + +class PerPoolNoiseHead: + """Per-pool noise coefficients: each pool has k_obs free parameters. + + Used for Option C noise or Option A with per-pool noise. + """ + + def __init__(self, alpha: float = 0.0, k_obs: int = None): + self.name = "noise" + self.alpha = alpha + self.k_obs = k_obs if k_obs is not None else K_OBS + + def n_params(self, n_pools: int, k_attr: int) -> int: + return n_pools * self.k_obs + + def predict(self, params_slice, pool_idx, x_attr_i): + start = pool_idx * self.k_obs + return params_slice[start:start + self.k_obs] + + def regularization(self, params_slice): + if self.alpha == 0.0: + return jnp.float32(0.0) + return self.alpha * jnp.sum(params_slice ** 2) + + def init(self, jdata, warm_start=None): + n_pools = len(jdata.pool_data) + + if warm_start is not None: + noise_all = np.zeros((n_pools, self.k_obs), dtype=np.float64) + for i, pid in enumerate(jdata.pool_ids): + if pid in warm_start and "noise_coeffs" in warm_start[pid]: + noise_all[i] = warm_start[pid]["noise_coeffs"] + return noise_all.ravel() + + noise_all = np.zeros((n_pools, self.k_obs), dtype=np.float64) + for i, pd in enumerate(jdata.pool_data): + x_obs_np = np.array(pd["x_obs"]) + y_obs_np = np.array(pd["y_obs"]) + c, _, _, _ = np.linalg.lstsq(x_obs_np, y_obs_np, rcond=None) + noise_all[i] = c + return noise_all.ravel() + + def predict_new(self, params_slice, x_attr): + raise ValueError( + "PerPoolNoiseHead cannot predict noise for unseen pools" + ) + + def unpack_result(self, params_slice, n_pools, k_attr): + return { + "noise_coeffs": np.array(params_slice).reshape(n_pools, self.k_obs), + } + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * (n_pools * self.k_obs) + + +# --------------------------------------------------------------------------- +# SharedLinearNoiseHead — bias_noise + x_attr @ W_noise +# 
--------------------------------------------------------------------------- + + +class SharedLinearNoiseHead: + """Shared linear mapping for noise: bias_noise + x_attr @ W_noise. + + Output is (k_obs,) noise coefficients, predicted from pool attributes. + L2 regularization on W_noise (not bias_noise). + """ + + def __init__(self, alpha: float = 0.01, k_obs: int = None): + self.name = "noise" + self.alpha = alpha + self.k_obs = k_obs if k_obs is not None else K_OBS + + def n_params(self, n_pools: int, k_attr: int) -> int: + return (1 + k_attr) * self.k_obs + + def predict(self, params_slice, pool_idx, x_attr_i): + k_attr = x_attr_i.shape[0] + W_full = params_slice.reshape(1 + k_attr, self.k_obs) + bias_noise = W_full[0] + W_noise = W_full[1:] + return bias_noise + jnp.dot(x_attr_i, W_noise) + + def regularization(self, params_slice): + W_full = params_slice.reshape(-1, self.k_obs) + W_noise = W_full[1:] + return self.alpha * jnp.sum(W_noise ** 2) + + def init(self, jdata, warm_start=None): + k_attr = jdata.x_attr.shape[1] + n_pools = len(jdata.pool_data) + + if warm_start is not None: + noise_all = np.zeros((n_pools, self.k_obs), dtype=np.float64) + for i, pid in enumerate(jdata.pool_ids): + if pid in warm_start and "noise_coeffs" in warm_start[pid]: + noise_all[i] = warm_start[pid]["noise_coeffs"] + X_aug = np.column_stack([np.ones(n_pools), np.array(jdata.x_attr)]) + params, _, _, _ = np.linalg.lstsq(X_aug, noise_all, rcond=None) + return params.ravel().astype(np.float64) + + all_x = np.vstack([np.array(pd["x_obs"]) for pd in jdata.pool_data]) + all_y = np.concatenate([np.array(pd["y_obs"]) for pd in jdata.pool_data]) + c, _, _, _ = np.linalg.lstsq(all_x, all_y, rcond=None) + params = np.zeros((1 + k_attr, self.k_obs), dtype=np.float64) + params[0, :] = c + return params.ravel() + + def predict_new(self, params_slice, x_attr): + k_attr = len(x_attr) + W_full = np.array(params_slice).reshape(1 + k_attr, self.k_obs) + bias_noise = W_full[0] + W_noise = W_full[1:] + 
return bias_noise + x_attr @ W_noise + + def unpack_result(self, params_slice, n_pools, k_attr): + W_full = np.array(params_slice).reshape(1 + k_attr, self.k_obs) + return { + "bias_noise": W_full[0], + "W_noise": W_full[1:], + } + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * ((1 + k_attr) * self.k_obs) + + +# --------------------------------------------------------------------------- +# MLPHead — x_attr → Dense(hidden, relu) → Dense(1) +# --------------------------------------------------------------------------- + + +class MLPHead: + """Two-layer MLP mapping from pool attributes to a scalar. + + Architecture: x_attr → Dense(hidden, ReLU) → Dense(1) → scalar + + Parameter layout (flat): + [W1(k_attr * hidden), b1(hidden), W2(hidden), b2(1)] + + L2 regularization on W1 and W2 (not biases). + + Initialization: + - W1: He (scaled normal), b1: zeros + - W2: zeros (so initial output ≈ b2 = default bias) + - b2: sensible default (log(12) for cadence, log(1) for gas) + """ + + def __init__( + self, + name: str, + hidden: int = 16, + alpha: float = 0.01, + seed: int = 0, + output_lo: float = None, + output_hi: float = None, + ): + self.name = name + self.hidden = hidden + self.alpha = alpha + self._seed = seed + self.output_lo = output_lo + self.output_hi = output_hi + + def n_params(self, n_pools: int, k_attr: int) -> int: + h = self.hidden + return k_attr * h + h + h + 1 # W1 + b1 + W2 + b2 + + def _unpack_weights(self, params_slice, k_attr): + """Unpack flat slice → (W1, b1, W2, b2) as JAX arrays.""" + h = self.hidden + idx = 0 + W1 = params_slice[idx:idx + k_attr * h].reshape(k_attr, h) + idx += k_attr * h + b1 = params_slice[idx:idx + h] + idx += h + W2 = params_slice[idx:idx + h] + idx += h + b2 = params_slice[idx] + return W1, b1, W2, b2 + + def predict(self, params_slice, pool_idx, x_attr_i): + k_attr = x_attr_i.shape[0] + W1, b1, W2, b2 = self._unpack_weights(params_slice, k_attr) + hidden = jnp.maximum(x_attr_i @ W1 + b1, 0.0) # ReLU + out 
= hidden @ W2 + b2 + if self.output_lo is not None or self.output_hi is not None: + out = jnp.clip(out, self.output_lo, self.output_hi) + return out + + def regularization(self, params_slice): + # Regularize W1 and W2, not the biases. + # Layout: [W1(k_attr*h), b1(h), W2(h), b2(1)]. k_attr is not passed in, + # so recover it from the slice length: total = k_attr*h + 2*h + 1. + h = self.hidden + total = params_slice.shape[0] + k_attr = (total - 2 * h - 1) // h + W1 = params_slice[:k_attr * h] + # b1 = params_slice[k_attr*h : k_attr*h + h] # skip + W2 = params_slice[k_attr * h + h:k_attr * h + 2 * h] + # b2 = params_slice[-1] # skip + return self.alpha * (jnp.sum(W1 ** 2) + jnp.sum(W2 ** 2)) + + def init(self, jdata, warm_start=None): + k_attr = jdata.x_attr.shape[1] + n_pools = len(jdata.pool_data) + h = self.hidden + rng = np.random.RandomState(self._seed) + + # He initialization for W1 + std = np.sqrt(2.0 / k_attr) + W1 = rng.randn(k_attr, h).astype(np.float64) * std + b1 = np.zeros(h, dtype=np.float64) + + b2 = np.array([self._default_bias()], dtype=np.float64) + W2 = np.zeros(h, dtype=np.float64) + + if warm_start is not None: + vals = [] + for pid in jdata.pool_ids: + if pid in warm_start and self.name in warm_start[pid]: + vals.append(warm_start[pid][self.name]) + else: + vals.append(self._default_bias()) + y = np.array(vals) + b2 = np.array([np.mean(y)], dtype=np.float64) + + # Warm-start W2 by least-squares through hidden activations + # so the MLP init approximates the per-pool warm-start values + x_attr = np.array(jdata.x_attr) + H = np.maximum(x_attr @ W1 + b1, 0.0) # (n_pools, h) + residuals = y - b2[0] # what W2 needs to produce + W2, _, _, _ = np.linalg.lstsq(H, residuals, rcond=None) + + return 
np.concatenate([W1.ravel(), b1, W2, b2]) + + def _default_bias(self): + if "cad" in self.name: + return np.log(12.0) + elif "gas" in self.name: + return np.log(1.0) + return 0.0 + + def predict_new(self, params_slice, x_attr): + k_attr = len(x_attr) + W1, b1, W2, b2 = self._unpack_weights( + np.asarray(params_slice), k_attr + ) + hidden = np.maximum(x_attr @ W1 + b1, 0.0) + return float(hidden @ W2 + b2) + + def unpack_result(self, params_slice, n_pools, k_attr): + params_np = np.array(params_slice) + W1, b1, W2, b2 = self._unpack_weights(params_np, k_attr) + return { + f"mlp_{self.name}_W1": np.array(W1), + f"mlp_{self.name}_b1": np.array(b1), + f"mlp_{self.name}_W2": np.array(W2), + f"mlp_{self.name}_b2": float(b2), + } + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * self.n_params(n_pools, k_attr) + + +# --------------------------------------------------------------------------- +# TokenFactoredNoiseHead — additive token + chain + fee composition +# --------------------------------------------------------------------------- + + +class TokenFactoredNoiseHead: + """Noise coefficients from additive token + chain + fee composition. + + noise_coeffs_i = u[token_a_i] + u[token_b_i] + alpha[chain_i] + + beta_fee * log(fee_i) + delta_i + + Token effects u_t are regularized toward x_token_t @ Gamma (population + prediction from token covariates). Per-pool deltas are L2-regularized, + controlling the shrinkage between per-pool and population estimates. 
+ + Parameter layout (flat): + [u (n_tokens * k_obs), + Gamma (d_token * k_obs), + alpha (n_chains * k_obs), + beta_fee (k_obs), + delta (n_pools * k_obs)] + """ + + def __init__( + self, + token_a_idx: np.ndarray, + token_b_idx: np.ndarray, + chain_idx: np.ndarray, + log_fees: np.ndarray, + x_token: np.ndarray, + n_tokens: int, + n_chains: int, + token_index: dict, + chain_index: dict, + k_obs: int = K_OBS_REDUCED, + lambda_delta: float = 1.0, + lambda_token: float = 0.1, + lambda_chain: float = 0.1, + lambda_fee: float = 0.01, + mcap_path: str = None, + ): + self.name = "noise" + self.token_a_idx = np.asarray(token_a_idx, dtype=np.int32) + self.token_b_idx = np.asarray(token_b_idx, dtype=np.int32) + self.chain_idx = np.asarray(chain_idx, dtype=np.int32) + self.log_fees = np.asarray(log_fees, dtype=np.float64) + self.x_token = np.asarray(x_token, dtype=np.float64) + self.n_tokens = n_tokens + self.n_chains = n_chains + self.d_token = x_token.shape[1] + self.k_obs = k_obs + self.token_index = dict(token_index) + self.chain_index = dict(chain_index) + self.lambda_delta = lambda_delta + self.lambda_token = lambda_token + self.lambda_chain = lambda_chain + self.lambda_fee = lambda_fee + self._mcap_path = mcap_path + # Pre-convert to JAX for predict() + self._token_a_jax = jnp.array(self.token_a_idx) + self._token_b_jax = jnp.array(self.token_b_idx) + self._chain_jax = jnp.array(self.chain_idx) + self._log_fees_jax = jnp.array(self.log_fees) + self._x_token_jax = jnp.array(self.x_token) + + def n_params(self, n_pools: int, k_attr: int) -> int: + k = self.k_obs + return (self.n_tokens * k # u + + self.d_token * k # Gamma + + self.n_chains * k # alpha + + k # beta_fee + + n_pools * k) # delta + + def _unpack(self, params_slice, n_pools): + k = self.k_obs + idx = 0 + u = params_slice[idx:idx + self.n_tokens * k].reshape(self.n_tokens, k) + idx += self.n_tokens * k + Gamma = params_slice[idx:idx + self.d_token * k].reshape(self.d_token, k) + idx += self.d_token * k + alpha 
= params_slice[idx:idx + self.n_chains * k].reshape(self.n_chains, k) + idx += self.n_chains * k + beta_fee = params_slice[idx:idx + k] + idx += k + delta = params_slice[idx:idx + n_pools * k].reshape(n_pools, k) + return u, Gamma, alpha, beta_fee, delta + + def _infer_n_pools(self, params_slice): + k = self.k_obs + n_shared = self.n_tokens * k + self.d_token * k + self.n_chains * k + k + return (params_slice.shape[0] - n_shared) // k + + def predict(self, params_slice, pool_idx, x_attr_i): + n_pools = self._infer_n_pools(params_slice) + u, Gamma, alpha, beta_fee, delta = self._unpack(params_slice, n_pools) + ta = self._token_a_jax[pool_idx] + tb = self._token_b_jax[pool_idx] + ch = self._chain_jax[pool_idx] + lf = self._log_fees_jax[pool_idx] + return u[ta] + u[tb] + alpha[ch] + beta_fee * lf + delta[pool_idx] + + def regularization(self, params_slice): + n_pools = self._infer_n_pools(params_slice) + u, Gamma, alpha, beta_fee, delta = self._unpack(params_slice, n_pools) + u_pred = self._x_token_jax @ Gamma + reg_token = self.lambda_token * jnp.sum((u - u_pred) ** 2) + reg_chain = self.lambda_chain * jnp.sum(alpha ** 2) + reg_fee = self.lambda_fee * jnp.sum(beta_fee ** 2) + reg_delta = self.lambda_delta * jnp.sum(delta ** 2) + return reg_token + reg_chain + reg_fee + reg_delta + + def init(self, jdata, warm_start=None): + n_pools = len(jdata.pool_data) + k = self.k_obs + + if warm_start is not None: + # Collect per-pool noise_coeffs from warm_start + noise_all = np.zeros((n_pools, k), dtype=np.float64) + for i, pid in enumerate(jdata.pool_ids): + if pid in warm_start and "noise_coeffs" in warm_start[pid]: + nc = np.asarray(warm_start[pid]["noise_coeffs"]) + n_copy = min(len(nc), k) + noise_all[i, :n_copy] = nc[:n_copy] + + # Solve: u[ta_i] + u[tb_i] + alpha[ch_i] + beta_fee * lf_i ≈ noise_all[i] + n_cols = self.n_tokens + self.n_chains + 1 + A = np.zeros((n_pools, n_cols), dtype=np.float64) + for i in range(n_pools): + A[i, self.token_a_idx[i]] = 1.0 + A[i, 
self.token_b_idx[i]] += 1.0 + A[i, self.n_tokens + self.chain_idx[i]] = 1.0 + A[i, -1] = self.log_fees[i] + + lam_reg = 0.1 + AtA = A.T @ A + lam_reg * np.eye(n_cols) + u_init = np.zeros((self.n_tokens, k)) + alpha_init = np.zeros((self.n_chains, k)) + beta_fee_init = np.zeros(k) + + for j in range(k): + sol = np.linalg.solve(AtA, A.T @ noise_all[:, j]) + u_init[:, j] = sol[:self.n_tokens] + alpha_init[:, j] = sol[self.n_tokens:self.n_tokens + self.n_chains] + beta_fee_init[j] = sol[-1] + + # Delta = residuals + predicted = np.zeros_like(noise_all) + for i in range(n_pools): + predicted[i] = (u_init[self.token_a_idx[i]] + + u_init[self.token_b_idx[i]] + + alpha_init[self.chain_idx[i]] + + beta_fee_init * self.log_fees[i]) + delta_init = noise_all - predicted + + # Gamma from post-hoc regression of u on x_token + Gamma_init, _, _, _ = np.linalg.lstsq( + self.x_token, u_init, rcond=None + ) + else: + # Cold start: pooled OLS for baseline, then decompose + all_x = np.vstack([np.array(pd["x_obs"]) for pd in jdata.pool_data]) + all_y = np.concatenate([np.array(pd["y_obs"]) for pd in jdata.pool_data]) + pooled_coeffs, _, _, _ = np.linalg.lstsq(all_x, all_y, rcond=None) + pooled_coeffs = pooled_coeffs[:k] + + u_init = np.tile(pooled_coeffs / 2.0, (self.n_tokens, 1)) + Gamma_init, _, _, _ = np.linalg.lstsq( + self.x_token, u_init, rcond=None + ) + alpha_init = np.zeros((self.n_chains, k)) + beta_fee_init = np.zeros(k) + delta_init = np.zeros((n_pools, k)) + + return np.concatenate([ + u_init.ravel(), + Gamma_init.ravel(), + alpha_init.ravel(), + beta_fee_init, + delta_init.ravel(), + ]).astype(np.float64) + + def predict_new(self, params_slice, x_attr): + raise ValueError( + "TokenFactoredNoiseHead.predict_new() requires token identifiers. " + "Use predict_new_pool(params_slice, token_a, token_b, chain, fee, n_pools) instead." 
+ ) + + def predict_new_pool( + self, params_slice, token_a, token_b, chain, fee, n_pools, + ) -> dict: + """Predict noise coefficients for a new pool from token composition. + + Seen tokens use learned u_t. Unseen tokens fall back to x_t @ Gamma. + Unseen chains use alpha = zeros. No delta for new pools. + Input token names are canonicalized before lookup. + """ + params_np = np.asarray(params_slice) + u, Gamma, alpha, beta_fee, delta = self._unpack(params_np, n_pools) + u, Gamma, alpha, beta_fee = ( + np.array(u), np.array(Gamma), np.array(alpha), np.array(beta_fee) + ) + mcaps = _load_token_mcaps(self._mcap_path) + + # Canonicalize input tokens + token_a = _canonicalize_token(token_a) + token_b = _canonicalize_token(token_b) + + def _get_token_effect(token): + if token in self.token_index: + return u[self.token_index[token]] + x_t = np.zeros(self.d_token) + x_t[0] = 1.0 + cls = _classify_token(token, mcaps) + x_t[1] = cls["log_mcap"] + x_t[2] = cls["is_stable"] + x_t[3] = cls["is_eth_derivative"] + x_t[4] = cls["is_L1_native"] + return x_t @ Gamma + + u_a = _get_token_effect(token_a) + u_b = _get_token_effect(token_b) + + if chain in self.chain_index: + alpha_c = alpha[self.chain_index[chain]] + else: + alpha_c = np.zeros(self.k_obs) + + fee_effect = beta_fee * np.log(fee) + noise_coeffs = u_a + u_b + alpha_c + fee_effect + + return { + "noise_coeffs": noise_coeffs, + "components": { + "token_a": u_a, + "token_b": u_b, + "chain": alpha_c, + "fee": fee_effect, + }, + } + + def unpack_result(self, params_slice, n_pools, k_attr): + params_np = np.asarray(params_slice) + u, Gamma, alpha, beta_fee, delta = self._unpack(params_np, n_pools) + u, Gamma, alpha, beta_fee, delta = ( + np.array(u), np.array(Gamma), np.array(alpha), + np.array(beta_fee), np.array(delta), + ) + # Reconstruct per-pool noise_coeffs + noise_coeffs = np.zeros((n_pools, self.k_obs)) + for i in range(n_pools): + noise_coeffs[i] = (u[self.token_a_idx[i]] + u[self.token_b_idx[i]] + + 
alpha[self.chain_idx[i]] + + beta_fee * self.log_fees[i] + + delta[i]) + return { + "token_effects": u, + "Gamma": Gamma, + "chain_effects": alpha, + "beta_fee": beta_fee, + "noise_deltas": delta, + "noise_coeffs": noise_coeffs, + } + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * self.n_params(n_pools, k_attr) + + +# --------------------------------------------------------------------------- +# MLPNoiseHead — x_attr → Dense(hidden, relu) → Dense(K_OBS) +# --------------------------------------------------------------------------- + + +class MLPNoiseHead: + """Two-layer MLP mapping from pool attributes to noise coefficients. + + Architecture: x_attr → Dense(hidden, ReLU) → Dense(k_obs) + + Parameter layout (flat): + [W1(k_attr * hidden), b1(hidden), W2(hidden * k_obs), b2(k_obs)] + + L2 regularization on W1 and W2 (not biases). + + Initialization: + - W1: He (scaled normal), b1: zeros + - W2: zeros (so initial output = b2 = pooled OLS noise coefficients) + - b2: pooled OLS noise from training data + """ + + def __init__( + self, + hidden: int = 16, + alpha: float = 0.01, + seed: int = 0, + k_obs: int = None, + ): + self.name = "noise" + self.hidden = hidden + self.alpha = alpha + self._seed = seed + self.k_obs = k_obs if k_obs is not None else K_OBS + + def n_params(self, n_pools: int, k_attr: int) -> int: + h = self.hidden + return k_attr * h + h + h * self.k_obs + self.k_obs + + def _unpack_weights(self, params_slice, k_attr): + """Unpack flat slice → (W1, b1, W2, b2).""" + h = self.hidden + ko = self.k_obs + idx = 0 + W1 = params_slice[idx:idx + k_attr * h].reshape(k_attr, h) + idx += k_attr * h + b1 = params_slice[idx:idx + h] + idx += h + W2 = params_slice[idx:idx + h * ko].reshape(h, ko) + idx += h * ko + b2 = params_slice[idx:idx + ko] + return W1, b1, W2, b2 + + def predict(self, params_slice, pool_idx, x_attr_i): + k_attr = x_attr_i.shape[0] + W1, b1, W2, b2 = self._unpack_weights(params_slice, k_attr) + hidden = jnp.maximum(x_attr_i @ 
W1 + b1, 0.0) # ReLU + return hidden @ W2 + b2 # (k_obs,) + + def regularization(self, params_slice): + h = self.hidden + ko = self.k_obs + total = params_slice.shape[0] + # Solve for k_attr: total = k*h + h + h*ko + ko + k_attr = (total - h - h * ko - ko) // h + W1 = params_slice[:k_attr * h] + W2 = params_slice[k_attr * h + h:k_attr * h + h + h * ko] + return self.alpha * (jnp.sum(W1 ** 2) + jnp.sum(W2 ** 2)) + + def init(self, jdata, warm_start=None): + k_attr = jdata.x_attr.shape[1] + n_pools = len(jdata.pool_data) + h = self.hidden + ko = self.k_obs + rng = np.random.RandomState(self._seed) + + # He initialization for W1 + std = np.sqrt(2.0 / k_attr) + W1 = rng.randn(k_attr, h).astype(np.float64) * std + b1 = np.zeros(h, dtype=np.float64) + + W2 = np.zeros((h, ko), dtype=np.float64) + + if warm_start is not None: + noise_all = np.zeros((n_pools, ko), dtype=np.float64) + for i, pid in enumerate(jdata.pool_ids): + if pid in warm_start and "noise_coeffs" in warm_start[pid]: + noise_all[i] = warm_start[pid]["noise_coeffs"] + b2 = np.mean(noise_all, axis=0) + + # Warm-start W2 by least-squares through hidden activations + x_attr = np.array(jdata.x_attr) + H = np.maximum(x_attr @ W1 + b1, 0.0) # (n_pools, h) + residuals = noise_all - b2 # (n_pools, ko) + W2, _, _, _ = np.linalg.lstsq(H, residuals, rcond=None) + else: + # Pooled OLS noise as b2 + all_x = np.vstack([np.array(pd["x_obs"]) for pd in jdata.pool_data]) + all_y = np.concatenate([np.array(pd["y_obs"]) for pd in jdata.pool_data]) + b2, _, _, _ = np.linalg.lstsq(all_x, all_y, rcond=None) + + return np.concatenate([W1.ravel(), b1, W2.ravel(), b2]) + + def predict_new(self, params_slice, x_attr): + k_attr = len(x_attr) + W1, b1, W2, b2 = self._unpack_weights(np.asarray(params_slice), k_attr) + hidden = np.maximum(x_attr @ W1 + b1, 0.0) + return hidden @ W2 + b2 # (k_obs,) + + def unpack_result(self, params_slice, n_pools, k_attr): + params_np = np.array(params_slice) + W1, b1, W2, b2 = 
self._unpack_weights(params_np, k_attr) + return { + "mlp_noise_W1": np.array(W1), + "mlp_noise_b1": np.array(b1), + "mlp_noise_W2": np.array(W2), + "mlp_noise_b2": np.array(b2), + } + + def make_bounds(self, n_pools, k_attr): + return [(None, None)] * self.n_params(n_pools, k_attr) diff --git a/quantammsim/calibration/joint_fit.py b/quantammsim/calibration/joint_fit.py index 4f891c6..8eaa7b8 100644 --- a/quantammsim/calibration/joint_fit.py +++ b/quantammsim/calibration/joint_fit.py @@ -38,17 +38,23 @@ class JointData(NamedTuple): def prepare_joint_data( matched: Dict[str, dict], drop_chain_dummies: bool = False, + fix_gas_to_chain: bool = False, + reduced_x_obs: bool = False, ) -> JointData: """Build batched JAX arrays from matched pool data. Args: matched: dict from match_grids_to_panel drop_chain_dummies: if True, remove chain_* columns from attributes - (reduces feature count for small n) + fix_gas_to_chain: if True, store fixed_log_gas per pool from CHAIN_GAS_USD + reduced_x_obs: if True, use 4-column reduced x_obs + (removes sigma/fee terms to avoid identification problems) Returns: JointData with per-pool JAX arrays and shared attribute matrix. 
""" + from quantammsim.calibration.loss import CHAIN_GAS_USD + X_attr, attr_names, pool_ids = build_pool_attributes(matched) if drop_chain_dummies: @@ -61,15 +67,21 @@ def prepare_joint_data( for pid in pool_ids: entry = matched[pid] panel = entry["panel"] - x_obs = build_x_obs(panel) + x_obs = build_x_obs(panel, reduced=reduced_x_obs) y_obs = panel["log_volume"].values.astype(float) - pool_data.append({ + d = { "coeffs": entry["coeffs"], "x_obs": jnp.array(x_obs), "y_obs": jnp.array(y_obs), "day_indices": jnp.array(entry["day_indices"]), - }) + } + if fix_gas_to_chain: + chain = entry["chain"] + gas_usd = CHAIN_GAS_USD.get(chain, 1.0) + d["fixed_log_gas"] = jnp.float64(np.log(max(gas_usd, 1e-6))) + + pool_data.append(d) return JointData( pool_data=pool_data, @@ -79,6 +91,66 @@ def prepare_joint_data( ) +def prepare_token_factored_data( + matched: Dict[str, dict], + reduced_x_obs: bool = True, + fix_gas_to_chain: bool = True, + canonicalize: bool = True, + cross_pool: bool = False, +) -> tuple: + """Prepare JointData + token encoding for TokenFactoredNoiseHead. + + Args: + matched: dict from match_grids_to_panel + reduced_x_obs: if True, use 4-column reduced x_obs + fix_gas_to_chain: if True, fix gas to chain-level costs + canonicalize: if True, canonicalize token names before building index + cross_pool: if True, use cross-pool lag features (K_OBS_CROSS=7) + + Returns (jdata, token_encoding) where token_encoding is the dict from + encode_tokens() containing token/chain structure for constructing the head. 
+ """ + from quantammsim.calibration.pool_data import encode_tokens + + if cross_pool: + from quantammsim.calibration.pool_data import ( + K_OBS_CROSS, build_cross_pool_x_obs, + ) + + jdata = prepare_joint_data( + matched, + fix_gas_to_chain=fix_gas_to_chain, + reduced_x_obs=reduced_x_obs, + ) + + if cross_pool: + # Replace x_obs with cross-pool version for each pool + pool_ids = jdata.pool_ids + new_pool_data = [] + for i, pid in enumerate(pool_ids): + entry = matched[pid] + x_obs_cross = build_cross_pool_x_obs( + entry["panel"], matched, pid, canonicalize=canonicalize, + ) + # x_obs_cross has n_obs-1 rows (first day dropped); + # trim y_obs and day_indices to match + d = dict(jdata.pool_data[i]) + d["x_obs"] = jnp.array(x_obs_cross) + d["y_obs"] = d["y_obs"][1:] + d["day_indices"] = d["day_indices"][1:] + new_pool_data.append(d) + jdata = JointData( + pool_data=new_pool_data, + x_attr=jdata.x_attr, + pool_ids=jdata.pool_ids, + attr_names=jdata.attr_names, + ) + + token_encoding = encode_tokens(matched, canonicalize=canonicalize) + + return jdata, token_encoding + + def pack_joint_params( bias_cad: float, bias_gas: float, @@ -102,38 +174,66 @@ def pack_joint_params( ]) +def pack_joint_params_fixed_gas( + bias_cad: float, + W_cad: jnp.ndarray, + noise_params: jnp.ndarray, +) -> jnp.ndarray: + """Pack joint params with gas excluded. + + Layout: [bias_cad, W_cad(k_attr), noise_params...] + """ + return jnp.concatenate([ + jnp.array([bias_cad]), + W_cad.ravel(), + noise_params.ravel(), + ]) + + def unpack_joint_params( flat: jnp.ndarray, config: dict ) -> dict: """Unpack flat array to structured params. 
config must have: k_attr, n_pools, mode + config may have: fix_gas (bool) — if True, no bias_gas/W_gas in flat array """ k_attr = config["k_attr"] mode = config["mode"] + fix_gas = config.get("fix_gas", False) - bias_cad = flat[0] - bias_gas = flat[1] - W_cad = flat[2:2 + k_attr] - W_gas = flat[2 + k_attr:2 + 2 * k_attr] - rest = flat[2 + 2 * k_attr:] + if fix_gas: + bias_cad = flat[0] + W_cad = flat[1:1 + k_attr] + rest = flat[1 + k_attr:] + else: + bias_cad = flat[0] + bias_gas = flat[1] + W_cad = flat[2:2 + k_attr] + W_gas = flat[2 + k_attr:2 + 2 * k_attr] + rest = flat[2 + 2 * k_attr:] if mode == "per_pool_noise": n_pools = config["n_pools"] noise_coeffs = rest.reshape(n_pools, K_OBS) + if fix_gas: + return {"bias_cad": bias_cad, "W_cad": W_cad, + "noise_coeffs": noise_coeffs} return { "bias_cad": bias_cad, "bias_gas": bias_gas, "W_cad": W_cad, "W_gas": W_gas, "noise_coeffs": noise_coeffs, } else: # shared_noise - # noise_params: (1 + k_attr, K_OBS) — row 0 is bias W_noise_full = rest.reshape(1 + k_attr, K_OBS) + if fix_gas: + return {"bias_cad": bias_cad, "W_cad": W_cad, + "bias_noise": W_noise_full[0], "W_noise": W_noise_full[1:]} return { "bias_cad": bias_cad, "bias_gas": bias_gas, "W_cad": W_cad, "W_gas": W_gas, - "bias_noise": W_noise_full[0], # (K_OBS,) - "W_noise": W_noise_full[1:], # (k_attr, K_OBS) + "bias_noise": W_noise_full[0], + "W_noise": W_noise_full[1:], } @@ -147,30 +247,53 @@ def _make_pool_loss_fn( Closes over pool-specific data; takes only params_flat as input. Each pool gets its own small JIT'd computation graph. + + If config["fix_gas"] is True, gas comes from pool_data_i["fixed_log_gas"] + instead of being predicted from attributes. 
""" coeffs = pool_data_i["coeffs"] x_obs = pool_data_i["x_obs"] y_obs = pool_data_i["y_obs"] day_indices = pool_data_i["day_indices"] mode = config["mode"] + fix_gas = config.get("fix_gas", False) i = pool_idx - @jax.jit - def pool_loss_fn(params_flat): - params = unpack_joint_params(params_flat, config) - log_cad = params["bias_cad"] + jnp.dot(x_attr_i, params["W_cad"]) - log_gas = params["bias_gas"] + jnp.dot(x_attr_i, params["W_gas"]) + if fix_gas: + fixed_log_gas = pool_data_i["fixed_log_gas"] - if mode == "per_pool_noise": - noise_c = params["noise_coeffs"][i] - else: - noise_c = params["bias_noise"] + jnp.dot(x_attr_i, params["W_noise"]) + @jax.jit + def pool_loss_fn(params_flat): + params = unpack_joint_params(params_flat, config) + log_cad = params["bias_cad"] + jnp.dot(x_attr_i, params["W_cad"]) + + if mode == "per_pool_noise": + noise_c = params["noise_coeffs"][i] + else: + noise_c = params["bias_noise"] + jnp.dot(x_attr_i, params["W_noise"]) - v_arb_all = interpolate_pool_daily(coeffs, log_cad, jnp.exp(log_gas)) - v_arb = v_arb_all[day_indices] - v_noise = jnp.exp(x_obs @ noise_c) - log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) - return jnp.mean((log_v_pred - y_obs) ** 2) + v_arb_all = interpolate_pool_daily(coeffs, log_cad, jnp.exp(fixed_log_gas)) + v_arb = v_arb_all[day_indices] + v_noise = jnp.exp(x_obs @ noise_c) + log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + return jnp.mean((log_v_pred - y_obs) ** 2) + else: + @jax.jit + def pool_loss_fn(params_flat): + params = unpack_joint_params(params_flat, config) + log_cad = params["bias_cad"] + jnp.dot(x_attr_i, params["W_cad"]) + log_gas = params["bias_gas"] + jnp.dot(x_attr_i, params["W_gas"]) + + if mode == "per_pool_noise": + noise_c = params["noise_coeffs"][i] + else: + noise_c = params["bias_noise"] + jnp.dot(x_attr_i, params["W_noise"]) + + v_arb_all = interpolate_pool_daily(coeffs, log_cad, jnp.exp(log_gas)) + v_arb = v_arb_all[day_indices] + v_noise = jnp.exp(x_obs @ noise_c) 
+ log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + return jnp.mean((log_v_pred - y_obs) ** 2) return pool_loss_fn @@ -180,6 +303,7 @@ def make_joint_loss_fn( mode: str = "per_pool_noise", alpha_cad: float = 0.01, alpha_gas: float = 0.01, + fix_gas: bool = False, ): """Create per-pool JIT'd loss functions and a Python-level aggregator. @@ -190,20 +314,22 @@ def make_joint_loss_fn( Loss averages over pools (not observations), giving equal weight to each pool regardless of observation count. - L2 regularization is applied to W_cad and W_gas only (not biases). + L2 regularization is applied to W_cad (and W_gas if not fixed). Args: jdata: JointData from prepare_joint_data mode: "per_pool_noise" or "shared_noise" alpha_cad: L2 regularization on W_cad - alpha_gas: L2 regularization on W_gas + alpha_gas: L2 regularization on W_gas (ignored if fix_gas=True) + fix_gas: if True, gas is fixed per pool (no W_gas in params) Returns: loss_fn(params_flat) -> scalar loss """ n_pools = len(jdata.pool_data) k_attr = jdata.x_attr.shape[1] - config = {"k_attr": k_attr, "n_pools": n_pools, "mode": mode} + config = {"k_attr": k_attr, "n_pools": n_pools, "mode": mode, + "fix_gas": fix_gas} # Build per-pool JIT'd loss functions pool_loss_fns = [] @@ -218,8 +344,9 @@ def loss_fn(params_flat): data_loss = total / n_pools params = unpack_joint_params(params_flat, config) - reg = alpha_cad * jnp.sum(params["W_cad"] ** 2) + \ - alpha_gas * jnp.sum(params["W_gas"] ** 2) + reg = alpha_cad * jnp.sum(params["W_cad"] ** 2) + if not fix_gas: + reg = reg + alpha_gas * jnp.sum(params["W_gas"] ** 2) return data_loss + reg # Attach per-pool functions for the value_and_grad wrapper @@ -236,15 +363,12 @@ def make_initial_joint_params( jdata: JointData, mode: str = "per_pool_noise", init_from_option_c: Optional[Dict[str, dict]] = None, + fix_gas: bool = False, ) -> jnp.ndarray: """Create initial parameter vector. 
- If init_from_option_c is provided, warm-start from Option C per-pool fits: - - bias_cad, W_cad from OLS on per-pool fitted log_cadence - - bias_gas, W_gas from OLS on per-pool fitted log_gas - - noise_coeffs from per-pool fits - - Otherwise, use defaults: cadence=12min, gas=$1 for all pools. + If init_from_option_c is provided, warm-start from Option C per-pool fits. + If fix_gas is True, excludes bias_gas and W_gas from the parameter vector. """ n_pools = len(jdata.pool_data) k_attr = jdata.x_attr.shape[1] @@ -252,7 +376,6 @@ def make_initial_joint_params( if init_from_option_c is not None: pool_ids = jdata.pool_ids - # Filter out pools with NaN losses from warm start valid = {p: init_from_option_c[p] for p in pool_ids if p in init_from_option_c and np.isfinite(init_from_option_c[p].get("loss", float("nan")))} @@ -269,32 +392,31 @@ def make_initial_joint_params( } log_cads = np.array([valid[p]["log_cadence"] for p in pool_ids]) - log_gases = np.array([valid[p]["log_gas"] for p in pool_ids]) noise_all = np.array([valid[p]["noise_coeffs"] for p in pool_ids]) - # OLS with intercept: X_aug = [1, x_attr]; solve for [bias, W] X_aug = np.column_stack([np.ones(n_pools), x_attr_np]) cad_params, _, _, _ = np.linalg.lstsq(X_aug, log_cads, rcond=None) - gas_params, _, _, _ = np.linalg.lstsq(X_aug, log_gases, rcond=None) bias_cad, W_cad = cad_params[0], cad_params[1:] - bias_gas, W_gas = gas_params[0], gas_params[1:] + + if not fix_gas: + log_gases = np.array([valid[p]["log_gas"] for p in pool_ids]) + gas_params, _, _, _ = np.linalg.lstsq(X_aug, log_gases, rcond=None) + bias_gas, W_gas = gas_params[0], gas_params[1:] if mode == "per_pool_noise": noise_params = noise_all else: - # OLS with intercept for noise mapping noise_aug, _, _, _ = np.linalg.lstsq(X_aug, noise_all, rcond=None) - # noise_aug: (1+k_attr, K_OBS) — row 0 is bias noise_params = noise_aug else: - # Default: all pools get cadence=12min, gas=$1 bias_cad = np.log(12.0) - bias_gas = np.log(1.0) # = 0.0 W_cad = 
np.zeros(k_attr) - W_gas = np.zeros(k_attr) + + if not fix_gas: + bias_gas = np.log(1.0) + W_gas = np.zeros(k_attr) if mode == "per_pool_noise": - # Initialize noise via OLS per pool noise_params = np.zeros((n_pools, K_OBS)) for i, pd in enumerate(jdata.pool_data): x_obs_np = np.array(pd["x_obs"]) @@ -302,29 +424,40 @@ def make_initial_joint_params( c, _, _, _ = np.linalg.lstsq(x_obs_np, y_obs_np, rcond=None) noise_params[i] = c else: - # Initialize shared noise from pooled OLS all_x = np.vstack([np.array(pd["x_obs"]) for pd in jdata.pool_data]) all_y = np.concatenate([np.array(pd["y_obs"]) for pd in jdata.pool_data]) c, _, _, _ = np.linalg.lstsq(all_x, all_y, rcond=None) - # (1+k_attr, K_OBS): bias row + zero weight rows noise_params = np.zeros((1 + k_attr, K_OBS)) noise_params[0, :] = c - return pack_joint_params( - float(bias_cad), - float(bias_gas), - jnp.array(W_cad), - jnp.array(W_gas), - jnp.array(noise_params), - ) + if fix_gas: + return pack_joint_params_fixed_gas( + float(bias_cad), + jnp.array(W_cad), + jnp.array(noise_params), + ) + else: + return pack_joint_params( + float(bias_cad), + float(bias_gas), + jnp.array(W_cad), + jnp.array(W_gas), + jnp.array(noise_params), + ) -def _make_bounds(k_attr, n_pools, mode): +def _make_bounds(k_attr, n_pools, mode, fix_gas=False): """Build scipy bounds for joint params.""" - # bias_cad, bias_gas: unbounded - bounds = [(None, None)] * 2 - # W_cad, W_gas: unbounded - bounds += [(None, None)] * (2 * k_attr) + if fix_gas: + # bias_cad only + bounds = [(None, None)] * 1 + # W_cad only + bounds += [(None, None)] * k_attr + else: + # bias_cad, bias_gas + bounds = [(None, None)] * 2 + # W_cad, W_gas + bounds += [(None, None)] * (2 * k_attr) if mode == "per_pool_noise": bounds += [(None, None)] * (n_pools * K_OBS) @@ -342,6 +475,7 @@ def fit_joint( alpha_cad: float = 0.01, alpha_gas: float = 0.01, drop_chain_dummies: bool = False, + fix_gas_to_chain: bool = False, ) -> dict: """Joint end-to-end optimization across all 
pools. @@ -349,38 +483,43 @@ def fit_joint( matched: dict from match_grids_to_panel mode: "per_pool_noise" or "shared_noise" init_from_option_c: Optional Option C results for warm start. - Pools with NaN losses are silently excluded from warm start. maxiter: max L-BFGS-B iterations alpha_cad: L2 regularization on W_cad (not bias) - alpha_gas: L2 regularization on W_gas (not bias) + alpha_gas: L2 regularization on W_gas (not bias, ignored if fix_gas) drop_chain_dummies: if True, remove chain_* columns from attributes + fix_gas_to_chain: if True, gas is fixed to known chain-level costs Returns dict with fitted params and diagnostics. """ - jdata = prepare_joint_data(matched, drop_chain_dummies=drop_chain_dummies) + jdata = prepare_joint_data(matched, drop_chain_dummies=drop_chain_dummies, + fix_gas_to_chain=fix_gas_to_chain) loss_fn = make_joint_loss_fn(jdata, mode=mode, - alpha_cad=alpha_cad, alpha_gas=alpha_gas) + alpha_cad=alpha_cad, alpha_gas=alpha_gas, + fix_gas=fix_gas_to_chain) init = make_initial_joint_params(jdata, mode=mode, - init_from_option_c=init_from_option_c) + init_from_option_c=init_from_option_c, + fix_gas=fix_gas_to_chain) n_pools = len(jdata.pool_data) k_attr = jdata.x_attr.shape[1] - config = {"k_attr": k_attr, "n_pools": n_pools, "mode": mode} - bounds = _make_bounds(k_attr, n_pools, mode) + config = {"k_attr": k_attr, "n_pools": n_pools, "mode": mode, + "fix_gas": fix_gas_to_chain} + bounds = _make_bounds(k_attr, n_pools, mode, fix_gas=fix_gas_to_chain) - # Per-pool value_and_grad — each pool has its own small JIT graph pool_vg_fns = loss_fn._pool_val_and_grad_fns - # Indices for W_cad and W_gas in the flat param vector (for reg gradient) - w_cad_start = 2 - w_cad_end = 2 + k_attr - w_gas_start = 2 + k_attr - w_gas_end = 2 + 2 * k_attr + if fix_gas_to_chain: + w_cad_start = 1 + w_cad_end = 1 + k_attr + else: + w_cad_start = 2 + w_cad_end = 2 + k_attr + w_gas_start = 2 + k_attr + w_gas_end = 2 + 2 * k_attr def scipy_wrapper(params_np): params_j 
= jnp.array(params_np) - # Sum per-pool losses and gradients total_val = 0.0 total_grad = jnp.zeros_like(params_j) for vg_fn in pool_vg_fns: @@ -391,15 +530,15 @@ def scipy_wrapper(params_np): data_loss = total_val / n_pools data_grad = total_grad / n_pools - # Regularization on W_cad and W_gas (not biases) - reg = (alpha_cad * float(jnp.sum(params_j[w_cad_start:w_cad_end] ** 2)) + - alpha_gas * float(jnp.sum(params_j[w_gas_start:w_gas_end] ** 2))) - + reg = alpha_cad * float(jnp.sum(params_j[w_cad_start:w_cad_end] ** 2)) reg_grad = jnp.zeros_like(params_j) reg_grad = reg_grad.at[w_cad_start:w_cad_end].set( 2 * alpha_cad * params_j[w_cad_start:w_cad_end]) - reg_grad = reg_grad.at[w_gas_start:w_gas_end].set( - 2 * alpha_gas * params_j[w_gas_start:w_gas_end]) + + if not fix_gas_to_chain: + reg += alpha_gas * float(jnp.sum(params_j[w_gas_start:w_gas_end] ** 2)) + reg_grad = reg_grad.at[w_gas_start:w_gas_end].set( + 2 * alpha_gas * params_j[w_gas_start:w_gas_end]) val = data_loss + reg grad = data_grad + reg_grad @@ -422,17 +561,30 @@ def scipy_wrapper(params_np): out = { "init_loss": init_loss, "bias_cad": float(params["bias_cad"]), - "bias_gas": float(params["bias_gas"]), "W_cad": np.array(params["W_cad"]), - "W_gas": np.array(params["W_gas"]), "loss": float(result.fun), "converged": result.success, "mode": mode, "k_attr": k_attr, "pool_ids": jdata.pool_ids, "attr_names": jdata.attr_names, + "fix_gas": fix_gas_to_chain, } + if fix_gas_to_chain: + # Store per-pool fixed gas values for downstream use + from quantammsim.calibration.loss import CHAIN_GAS_USD + gas_per_pool = [] + for pid in jdata.pool_ids: + chain = matched[pid]["chain"] + gas_per_pool.append(CHAIN_GAS_USD.get(chain, 1.0)) + out["gas_per_pool"] = np.array(gas_per_pool) + out["bias_gas"] = 0.0 + out["W_gas"] = np.zeros(k_attr) + else: + out["bias_gas"] = float(params["bias_gas"]) + out["W_gas"] = np.array(params["W_gas"]) + if mode == "per_pool_noise": out["noise_coeffs"] = 
np.array(params["noise_coeffs"]) else: diff --git a/quantammsim/calibration/loss.py b/quantammsim/calibration/loss.py index 13b2e65..cdc8c1e 100644 --- a/quantammsim/calibration/loss.py +++ b/quantammsim/calibration/loss.py @@ -16,6 +16,17 @@ K_OBS = 8 # observation-level covariates +# Known chain gas costs (USD) — used when fixing gas to chain-level values. +# These are effective per-transaction costs, not per-gas-unit. +CHAIN_GAS_USD = { + "MAINNET": 1.0, + "POLYGON": 0.005, + "GNOSIS": 0.001, + "ARBITRUM": 0.01, + "BASE": 0.005, + "SONIC": 0.005, +} + def noise_volume( noise_coeffs: jnp.ndarray, x_obs: jnp.ndarray @@ -41,6 +52,35 @@ def unpack_params( return flat[0], flat[1], flat[2:] +def pack_params_fixed_gas( + log_cadence: float, noise_coeffs: jnp.ndarray +) -> jnp.ndarray: + """Pack into flat array with gas excluded: [log_cadence, noise_coeffs...].""" + return jnp.concatenate([ + jnp.array([log_cadence]), + jnp.asarray(noise_coeffs), + ]) + + +def unpack_params_fixed_gas( + flat: jnp.ndarray, +) -> Tuple[float, jnp.ndarray]: + """Unpack flat array to (log_cadence, noise_coeffs). Gas not included.""" + return flat[0], flat[1:] + + +def _compute_loss_huber( + residuals: jnp.ndarray, + delta: float = 1.5, +) -> jnp.ndarray: + """Huber loss: 0.5*r^2 for |r|<=delta, delta*(|r|-0.5*delta) otherwise.""" + abs_r = jnp.abs(residuals) + return jnp.mean( + jnp.where(abs_r <= delta, 0.5 * residuals ** 2, + delta * (abs_r - 0.5 * delta)) + ) + + def pool_loss( params_flat: jnp.ndarray, coeffs: PoolCoeffsDaily, @@ -72,3 +112,28 @@ def pool_loss( # Log-space L2 loss log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) return jnp.mean((log_v_pred - y_obs) ** 2) + + +def pool_loss_fixed_gas( + params_flat: jnp.ndarray, + fixed_log_gas: float, + coeffs: PoolCoeffsDaily, + x_obs: jnp.ndarray, + y_obs: jnp.ndarray, + day_indices: jnp.ndarray, +) -> jnp.ndarray: + """Per-pool loss with gas fixed to a known chain-level value. 
+ + Args: + params_flat: [log_cadence, noise_coeffs...] — no log_gas + fixed_log_gas: log(gas_usd) held constant (not optimized) + coeffs, x_obs, y_obs, day_indices: as in pool_loss + """ + log_cadence, noise_coeffs = unpack_params_fixed_gas(params_flat) + + v_arb_all = interpolate_pool_daily(coeffs, log_cadence, jnp.exp(fixed_log_gas)) + v_arb = v_arb_all[day_indices] + v_noise = noise_volume(noise_coeffs, x_obs) + + log_v_pred = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + return jnp.mean((log_v_pred - y_obs) ** 2) diff --git a/quantammsim/calibration/market_features.py b/quantammsim/calibration/market_features.py new file mode 100644 index 0000000..f2dd610 --- /dev/null +++ b/quantammsim/calibration/market_features.py @@ -0,0 +1,308 @@ +"""Market-level and token-level features for the noise volume model. + +Derives daily features from Binance minute-level price data and pool metadata. +Features are grounded in market microstructure — what mechanistically drives +organic (non-arb) trading volume: + +Market regime: + - BTC log price, log return — crypto market regime proxy + - BTC trend (rolling mean log return) — bull/bear at various horizons + +Token-level (per pool token): + - Token log price, daily log return + - Token realized volatility — higher vol → more hedging/speculative flow + - Token Binance volume — proxy for overall token trading interest + - Token trend (rolling mean log return) + +All features are computed daily and aligned to the panel date grid. 
+""" + +import os +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +DATA_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "data", +) + +# Map wrapped/derivative tokens to their Binance underlying +TOKEN_MAP = { + "WETH": "ETH", + "WBTC": "BTC", + "wstETH": "ETH", + "waEthLidowstETH": "ETH", + "waEthLidoWETH": "ETH", + "waGnowstETH": "ETH", + "waGnoGNO": "GNO", + "waBasUSDC": "USDC", + "waBasWETH": "ETH", + "sDAI": "DAI", + "scUSD": "USDC", + "stS": "S", + "JitoSOL": "SOL", + "USDT": "USDC", # treat as $1 stablecoin +} + + +def _load_binance_daily(symbol: str) -> Optional[pd.DataFrame]: + """Load Binance minute data and resample to daily close and USD volume; returns None if no parquet exists.""" + mapped = TOKEN_MAP.get(symbol, symbol) + path = os.path.join(DATA_DIR, f"{mapped}_USD.parquet") + if not os.path.exists(path): + return None + + df = pd.read_parquet(path, columns=["date", "close", "Volume USD"]) + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date").sort_index() + + daily = df.resample("1D").agg({ + "close": "last", + "Volume USD": "sum", + }).dropna(subset=["close"]) + + daily.columns = ["close", "volume_usd"] + return daily + + +def _compute_token_features( + daily: pd.DataFrame, + trend_windows: List[int] = (7, 14, 30), + is_market: bool = False, +) -> pd.DataFrame: + """Compute daily features from a token's OHLCV. + + For market-level tokens (BTC): includes log_price as regime proxy. + For pool tokens: only returns/vol/trends (comparable across tokens). + + Volume is normalised as z-score within each token: today's log-volume + relative to a 30-day trailing mean/std. This captures "is this token + unusually active today" without the cross-token scale problem. + + Returns DataFrame indexed by date. 
+ """ + out = pd.DataFrame(index=daily.index) + log_price = np.log(daily["close"].clip(lower=1e-10)) + out["log_return"] = log_price.diff() + + if is_market: + # BTC log_price is a market regime proxy (same for all pools) + out["log_price"] = log_price + + # Realized volatility: std of log returns over trailing 7 days + out["realized_vol_7d"] = out["log_return"].rolling(7, min_periods=3).std() + + # Volume: z-score relative to trailing 30d mean/std of log-volume + # Captures "unusually active day for this token" — comparable across tokens + log_vol = np.log(daily["volume_usd"].clip(lower=1.0)) + vol_mean_30d = log_vol.rolling(30, min_periods=10).mean() + vol_std_30d = log_vol.rolling(30, min_periods=10).std().clip(lower=0.1) + out["volume_zscore"] = (log_vol - vol_mean_30d) / vol_std_30d + + # Trend: rolling mean log return at various horizons + for w in trend_windows: + out[f"trend_{w}d"] = out["log_return"].rolling(w, min_periods=max(w // 2, 2)).mean() + + return out + + +def build_btc_daily_features( + trend_windows: List[int] = (7, 14, 30), +) -> pd.DataFrame: + """BTC daily features as market regime proxy. + + Returns DataFrame indexed by date with columns prefixed 'btc_'. + """ + daily = _load_binance_daily("BTC") + if daily is None: + raise FileNotFoundError("BTC_USD.parquet not found") + + feat = _compute_token_features(daily, trend_windows, is_market=True) + feat.columns = [f"btc_{c}" for c in feat.columns] + return feat + + +def build_token_daily_features( + symbol: str, + trend_windows: List[int] = (7, 14, 30), +) -> Optional[pd.DataFrame]: + """Daily features for a single token. Returns None if no data.""" + daily = _load_binance_daily(symbol) + if daily is None: + return None + return _compute_token_features(daily, trend_windows) + + +def _compute_pair_volatility( + symbol_a: str, + symbol_b: str, +) -> Optional[pd.DataFrame]: + """Compute realized volatility of the A/B price ratio. 
+ + vol(log(price_A/price_B)) = vol(log(price_A) - log(price_B)) + Symmetric: A/B and B/A give identical volatility. + + Returns DataFrame indexed by date with 'pair_realized_vol_7d'. + """ + daily_a = _load_binance_daily(symbol_a) + daily_b = _load_binance_daily(symbol_b) + if daily_a is None or daily_b is None: + return None + + # Align on common dates + log_a = np.log(daily_a["close"].clip(lower=1e-10)) + log_b = np.log(daily_b["close"].clip(lower=1e-10)) + common = log_a.index.intersection(log_b.index) + if len(common) < 10: + return None + + log_ratio = log_a.loc[common] - log_b.loc[common] + log_ratio_return = log_ratio.diff() + + out = pd.DataFrame(index=common) + out["pair_realized_vol_7d"] = log_ratio_return.rolling(7, min_periods=3).std() + return out + + +def build_pool_market_features( + matched_clean: Dict[str, dict], + trend_windows: List[int] = (7, 14, 30), +) -> Dict[str, pd.DataFrame]: + """Build per-pool market feature DataFrames. + + For each pool, produces a DataFrame aligned to the pool's panel dates with: + - BTC features (market regime) + - Token A features (vs USD) + - Token B features (vs USD) + - Pair volatility (A/B ratio) + + Returns dict: pool_id -> DataFrame with all features. 
+ """ + from quantammsim.calibration.pool_data import _parse_tokens + + # Load BTC features once + btc_feat = build_btc_daily_features(trend_windows) + + # Cache token features and pair volatilities + token_cache = {} + pair_vol_cache = {} + + pool_features = {} + for pid, entry in matched_clean.items(): + panel = entry["panel"] + dates = pd.to_datetime(panel["date"]) + + # Parse tokens + toks = _parse_tokens(entry["tokens"]) + tok_a, tok_b = toks[0], toks[1] if len(toks) > 1 else toks[0] + + # Get token features + for tok in [tok_a, tok_b]: + mapped = TOKEN_MAP.get(tok, tok) + if mapped not in token_cache: + token_cache[mapped] = build_token_daily_features(mapped, trend_windows) + + feat_a = token_cache.get(TOKEN_MAP.get(tok_a, tok_a)) + feat_b = token_cache.get(TOKEN_MAP.get(tok_b, tok_b)) + + # Pair volatility (cache by sorted token pair to avoid duplicates) + mapped_a = TOKEN_MAP.get(tok_a, tok_a) + mapped_b = TOKEN_MAP.get(tok_b, tok_b) + pair_key = tuple(sorted([mapped_a, mapped_b])) + if pair_key not in pair_vol_cache: + pair_vol_cache[pair_key] = _compute_pair_volatility(mapped_a, mapped_b) + pair_vol = pair_vol_cache[pair_key] + + # Build per-date feature vectors + rows = [] + for date in dates: + day = pd.Timestamp(date).normalize() + row = {} + + # BTC features + if day in btc_feat.index: + for col in btc_feat.columns: + row[col] = btc_feat.loc[day, col] + + # Token A features + if feat_a is not None and day in feat_a.index: + for col in feat_a.columns: + row[f"tok_a_{col}"] = feat_a.loc[day, col] + + # Token B features + if feat_b is not None and day in feat_b.index: + for col in feat_b.columns: + row[f"tok_b_{col}"] = feat_b.loc[day, col] + + # Pair volatility + if pair_vol is not None and day in pair_vol.index: + row["pair_realized_vol_7d"] = pair_vol.loc[day, "pair_realized_vol_7d"] + + rows.append(row) + + df = pd.DataFrame(rows, index=dates) + pool_features[pid] = df + + return pool_features + + +def pool_market_features_to_matrix( + pool_features: 
Dict[str, pd.DataFrame], + matched_clean: Dict[str, dict], + date_to_idx: Dict, + pool_ids: List[str], + sample_pools: np.ndarray, + sample_days: np.ndarray, +) -> Tuple[np.ndarray, List[str]]: + """Convert per-pool market features to a (n_samples, n_feat) matrix. + + Aligns features to the common date grid and sample indices. + NaN-fills missing values, then imputes with column mean. + + Returns (feature_matrix, feature_names). + """ + # Get feature columns from first pool + first_pid = pool_ids[0] + feat_cols = sorted(pool_features[first_pid].columns) + n_feat = len(feat_cols) + n_pools = len(pool_ids) + + # Collect all dates + n_dates = max(date_to_idx.values()) + 1 + + # Build (n_dates, n_pools, n_feat) grid + feat_grid = np.full((n_dates, n_pools, n_feat), np.nan, dtype=np.float32) + + for j, pid in enumerate(pool_ids): + if pid not in pool_features: + continue + pf = pool_features[pid] + panel_dates = matched_clean[pid]["panel"]["date"].values + for k, date in enumerate(panel_dates): + t = date_to_idx.get(date) + if t is None: + continue + for f, col in enumerate(feat_cols): + if col in pf.columns and k < len(pf): + val = pf.iloc[k][col] if col in pf.columns else np.nan + if np.isfinite(val): + feat_grid[t, j, f] = val + + # Extract per-sample + n_samples = len(sample_pools) + X = np.zeros((n_samples, n_feat), dtype=np.float32) + for s in range(n_samples): + X[s] = feat_grid[sample_days[s], sample_pools[s]] + + # Impute NaN with column mean + for f in range(n_feat): + col = X[:, f] + mask = np.isnan(col) + if mask.all(): + col[:] = 0.0 + elif mask.any(): + col[mask] = np.nanmean(col) + + return X, feat_cols diff --git a/quantammsim/calibration/noise_model_arrays.py b/quantammsim/calibration/noise_model_arrays.py new file mode 100644 index 0000000..7b73dc8 --- /dev/null +++ b/quantammsim/calibration/noise_model_arrays.py @@ -0,0 +1,460 @@ +"""Precompute noise_base and noise_tvl_coeff arrays for the simulator. 
+ +Builds daily feature vectors from Binance price data only — no panel/API +dependency. Works for any date range covered by Binance parquets. + +Produces the two arrays needed by reclamm_market_linear_noise_volume(): + + log(V_daily_noise) = noise_base_t + noise_tvl_coeff_t * log(effective_TVL) + +Usage: + from quantammsim.calibration.noise_model_arrays import build_simulator_arrays + + arrays = build_simulator_arrays( + token_a="AAVE", token_b="ETH", + start_date="2024-06-01", + end_date="2026-03-01", + artifact_dir="results/linear_market_noise", + pool_id="0x9d1fcf346ea1b0", # for per-pool coeffs + ) +""" + +import json +import os +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + + +def load_artifact(artifact_dir: str) -> Tuple[dict, dict]: + """Load model.npz + meta.json from artifact directory.""" + art = np.load(os.path.join(artifact_dir, "model.npz"), allow_pickle=True) + with open(os.path.join(artifact_dir, "meta.json")) as f: + meta = json.load(f) + return dict(art), meta + + +def _find_pool_index(pool_id: str, pool_ids: list) -> int: + """Match pool_id (full or prefix) to calibration pool list.""" + for i, cid in enumerate(pool_ids): + if pool_id.startswith(cid) or cid.startswith(pool_id): + return i + return -1 + + +def _identify_tvl_columns(feat_names: list) -> Tuple[int, list]: + """Identify which feature columns involve TVL. 
+ + Returns: + tvl_col: index of the pure log_tvl feature (xobs_1) + tvl_interaction_cols: list of (col_idx, paired_col_idx) + """ + tvl_col = None + tvl_interaction_cols = [] + + for i, name in enumerate(feat_names): + if name == "xobs_1": + tvl_col = i + elif "xobs_1\u00d7" in name: + paired_name = name.split("\u00d7")[1] + for j, n2 in enumerate(feat_names): + if n2 == paired_name: + tvl_interaction_cols.append((i, j)) + break + + if tvl_col is None: + raise ValueError("xobs_1 (log_tvl) not found in feature names") + + return tvl_col, tvl_interaction_cols + + +def build_daily_features_from_binance( + token_a: str, + token_b: str, + start_date: str, + end_date: str, + feat_names: List[str], + x_mean: np.ndarray, + x_std: np.ndarray, + trend_windows: tuple = (7,), +) -> Tuple[np.ndarray, list]: + """Build daily feature matrix from Binance data only. + + No panel or API dependency. Features: + - xobs_0 (intercept), xobs_1 (log_tvl — filled with 0, handled at runtime), + xobs_2/3 (dow_sin/cos) + - BTC: log_price, log_return, realized_vol_7d, trend, volume_zscore + - Token A/B: log_return, realized_vol_7d, trend, volume_zscore + - Pair realized_vol_7d + - Interaction terms + """ + from quantammsim.calibration.market_features import ( + build_btc_daily_features, + build_token_daily_features, + _compute_pair_volatility, + TOKEN_MAP, + ) + + start = pd.Timestamp(start_date) + end = pd.Timestamp(end_date) + + # Generate complete daily date range + date_range = pd.date_range(start, end, freq="D") + n_days = len(date_range) + + # BTC features + btc_feat = build_btc_daily_features(list(trend_windows)) + + # Token features + mapped_a = TOKEN_MAP.get(token_a, token_a) + mapped_b = TOKEN_MAP.get(token_b, token_b) + feat_a = build_token_daily_features(mapped_a, list(trend_windows)) + feat_b = build_token_daily_features(mapped_b, list(trend_windows)) + + # Pair volatility + pair_vol = _compute_pair_volatility(token_a, token_b) + + # Identify which features are x_obs vs market + 
# x_obs features: xobs_0 (intercept), xobs_1 (tvl), xobs_2 (dow_sin), xobs_3 (dow_cos) + # Remaining xobs_4,5,6 are cross-pool — skip if not in feat_names + n_xobs = sum(1 for f in feat_names if f.startswith("xobs_")) + + # Build market feature column list (everything after x_obs, before interactions) + market_names = [f for f in feat_names + if not f.startswith("xobs_") and "\u00d7" not in f] + + # Build per-day feature vectors + x_base_cols = n_xobs + len(market_names) + x_base = np.zeros((n_days, x_base_cols), dtype=np.float32) + + for k, day in enumerate(date_range): + day_norm = day.normalize() + + # x_obs + x_base[k, 0] = 1.0 # intercept + # x_base[k, 1] = 0.0 # log_tvl — placeholder, handled at runtime + weekday = day.weekday() + if n_xobs > 2: + x_base[k, 2] = np.sin(2 * np.pi * weekday / 7) + if n_xobs > 3: + x_base[k, 3] = np.cos(2 * np.pi * weekday / 7) + # xobs_4,5,6 (cross-pool) left as 0 if present + + # Market features + col = n_xobs + for mname in market_names: + val = 0.0 + if mname.startswith("btc_") and btc_feat is not None: + bcol = mname[4:] # strip "btc_" + if day_norm in btc_feat.index and bcol in btc_feat.columns: + v = btc_feat.loc[day_norm, bcol] + if np.isfinite(v): + val = v + elif mname.startswith("tok_a_") and feat_a is not None: + acol = mname[6:] + if day_norm in feat_a.index and acol in feat_a.columns: + v = feat_a.loc[day_norm, acol] + if np.isfinite(v): + val = v + elif mname.startswith("tok_b_") and feat_b is not None: + bcol = mname[6:] + if day_norm in feat_b.index and bcol in feat_b.columns: + v = feat_b.loc[day_norm, bcol] + if np.isfinite(v): + val = v + elif mname == "pair_realized_vol_7d" and pair_vol is not None: + if day_norm in pair_vol.index: + v = pair_vol.loc[day_norm, "pair_realized_vol_7d"] + if np.isfinite(v): + val = v + x_base[k, col] = val + col += 1 + + # Standardize base features + x_base = ((x_base - x_mean[:x_base_cols]) / x_std[:x_base_cols]).astype(np.float32) + + # Interaction terms + base_feat_names = 
feat_names[:x_base_cols] + col_idx = {name: i for i, name in enumerate(base_feat_names)} + + interactions = [] + for fname in feat_names[x_base_cols:]: + if "\u00d7" in fname: + parts = fname.split("\u00d7") + if parts[0] in col_idx and parts[1] in col_idx: + interactions.append( + x_base[:, col_idx[parts[0]]] * x_base[:, col_idx[parts[1]]]) + else: + interactions.append(np.zeros(n_days, dtype=np.float32)) + else: + interactions.append(np.zeros(n_days, dtype=np.float32)) + + if interactions: + x_all = np.concatenate( + [x_base, np.column_stack(interactions)], axis=1).astype(np.float32) + else: + x_all = x_base + + return x_all, date_range.tolist() + + +def build_simulator_arrays( + token_a: str, + token_b: str, + start_date: str, + end_date: str, + artifact_dir: str = "results/linear_market_noise", + pool_id: Optional[str] = None, +) -> Dict: + """Build noise_base and noise_tvl_coeff arrays for the simulator. + + No panel dependency — uses Binance data only. + + Parameters + ---------- + token_a, token_b : str + Token symbols (e.g. "AAVE", "ETH"). Mapped to Binance symbols + internally (WETH→ETH, wstETH→ETH, etc.) + start_date, end_date : str + Date range (inclusive). + artifact_dir : str + Directory containing model.npz and meta.json. + pool_id : str, optional + Pool ID for per-pool coefficients. If None or not found, + uses median coefficients. + + Returns + ------- + dict with noise_base, noise_tvl_coeff, tvl_mean, tvl_std, dates, etc. 
+ """ + art, meta = load_artifact(artifact_dir) + noise_coeffs = art["noise_coeffs"] + feat_names = meta["feat_names"] + pool_ids = meta["pool_ids"] + x_mean = art["x_mean"] + x_std = art["x_std"] + per_pool = noise_coeffs.ndim == 2 + trend_windows = tuple(meta["hparams"]["trend_windows"]) + + # Find pool coefficients + pool_idx = -1 + if pool_id is not None: + pool_idx = _find_pool_index(pool_id, pool_ids) + + if pool_idx >= 0 and per_pool: + coeffs = noise_coeffs[pool_idx] + print(f" Using per-pool coefficients (pool idx {pool_idx})") + elif per_pool: + coeffs = np.median(noise_coeffs, axis=0) + print(f" Pool not found, using median coefficients") + else: + coeffs = noise_coeffs + + # Build daily features from Binance + print(f" Building features from Binance data: {token_a}/{token_b}," + f" {start_date} → {end_date}") + x_daily, dates = build_daily_features_from_binance( + token_a, token_b, start_date, end_date, + feat_names, x_mean, x_std, trend_windows, + ) + n_days = len(dates) + print(f" {n_days} days, {len(feat_names)} features") + + # Decompose into base (non-TVL) and tvl_coeff + tvl_col, tvl_interactions = _identify_tvl_columns(feat_names) + + tvl_coeff_daily = np.full(n_days, coeffs[tvl_col], dtype=np.float64) + for inter_col, paired_col in tvl_interactions: + tvl_coeff_daily += coeffs[inter_col] * x_daily[:, paired_col] + + tvl_related = {tvl_col} | {ic for ic, _ in tvl_interactions} + base_daily = np.zeros(n_days, dtype=np.float64) + for j in range(len(feat_names)): + if j not in tvl_related: + base_daily += coeffs[j] * x_daily[:, j] + + # Expand to minute resolution + n_minutes = n_days * 1440 + noise_base = np.repeat(base_daily, 1440) + noise_tvl_coeff = np.repeat(tvl_coeff_daily, 1440) + + return { + "noise_base": noise_base, + "noise_tvl_coeff": noise_tvl_coeff, + "tvl_mean": float(x_mean[tvl_col]), + "tvl_std": float(x_std[tvl_col]), + "dates": dates, + "pool_index": pool_idx, + "n_days": n_days, + "n_minutes": n_minutes, + "coeffs": coeffs, + 
"tvl_col": tvl_col, + } + + +def build_mm_simulator_arrays( + token_a: str, + token_b: str, + start_date: str, + end_date: str, + mm_artifact_dir: str = "results/mm_noise", + competitor_tvl_path: str = "results/competitor_tvl/competitor_tvl.npz", + pool_id: Optional[str] = None, +) -> Dict: + """Build noise_base and competitor_tvl arrays for the MM simulator. + + The MM noise model evaluates:: + + V_noise = exp(noise_base_t) * TVL / (K_t + TVL) + + where noise_base_t = alpha_i + gamma_i @ x_market_t absorbs all + non-TVL terms, and K_t = competitor_tvl_t is observed from DeFi + Llama (network conductance model: direct + multi-hop). + + Parameters + ---------- + token_a, token_b : str + Token symbols. + start_date, end_date : str + Date range. + mm_artifact_dir : str + Directory with MM model.npz and meta.json. + competitor_tvl_path : str + Path to competitor_tvl.npz from fetch_competitor_tvl.py. + pool_id : str, optional + Pool ID for per-pool alpha/gamma. + + Returns + ------- + dict with noise_base, competitor_tvl (minute arrays), dates, etc. 
+ """ + # Load MM model + art, meta = load_artifact(mm_artifact_dir) + pool_ids = meta["pool_ids"] + market_names = meta["market_names"] + n_market = meta["n_market_feat"] + per_pool_gamma = meta.get("per_pool_gamma", False) + + pool_idx = -1 + if pool_id is not None: + pool_idx = _find_pool_index(pool_id, pool_ids) + + log_alpha = art["log_alpha"] + gamma = art["gamma"] + + if pool_idx >= 0: + alpha_i = float(log_alpha[pool_idx]) + gamma_i = gamma[pool_idx] if per_pool_gamma else gamma + print(f" MM model: pool idx {pool_idx}, alpha={alpha_i:.3f}") + else: + alpha_i = float(np.median(log_alpha)) + gamma_i = np.median(gamma, axis=0) if per_pool_gamma else gamma + print(f" MM model: pool not found, using median alpha={alpha_i:.3f}") + + # Build daily market features from Binance + # The MM model uses the same features as the linear model minus TVL + # We need x_mean/x_std from the linear model artifact for standardization + linear_art_dir = os.path.join( + os.path.dirname(os.path.dirname(mm_artifact_dir)), + "results", "linear_market_noise") + if os.path.exists(os.path.join(linear_art_dir, "model.npz")): + lin_art, lin_meta = load_artifact(linear_art_dir) + x_mean = lin_art["x_mean"] + x_std = lin_art["x_std"] + feat_names = lin_meta["feat_names"] + else: + # Fallback: try to get from MM artifact + x_mean = art.get("x_mean", np.zeros(n_market)) + x_std = art.get("x_std", np.ones(n_market)) + feat_names = market_names + + trend_windows = (7,) + + print(f" Building features from Binance: {token_a}/{token_b}," + f" {start_date} → {end_date}") + x_daily, dates = build_daily_features_from_binance( + token_a, token_b, start_date, end_date, + feat_names, x_mean, x_std, trend_windows, + ) + n_days = len(dates) + + # Extract market features (exclude TVL and TVL interactions) + tvl_col = None + tvl_interaction_cols = set() + for i, name in enumerate(feat_names): + if name == "xobs_1": + tvl_col = i + elif name.startswith("xobs_1\u00d7"): + tvl_interaction_cols.add(i) + + 
keep_cols = [i for i in range(len(feat_names)) + if i != tvl_col and i not in tvl_interaction_cols] + + # Map market_names to x_daily columns + x_market_daily = np.zeros((n_days, n_market), dtype=np.float32) + for mi, mname in enumerate(market_names): + # Find mname in feat_names + for fi, fname in enumerate(feat_names): + if fname == mname and fi in keep_cols: + col_in_daily = fi + x_market_daily[:, mi] = x_daily[:, col_in_daily] + break + + # Compute noise_base = alpha_i + gamma_i @ x_market + noise_base_daily = alpha_i + x_market_daily @ gamma_i + noise_base_daily = noise_base_daily.astype(np.float64) + + # Load competitor TVL (K) + print(f" Loading competitor TVL from {competitor_tvl_path}") + comp_data = np.load(competitor_tvl_path, allow_pickle=True) + comp_pool_ids = list(comp_data["pool_ids"]) + comp_dates = list(comp_data["date_list"]) + k_eff = comp_data["k_eff"] # (n_comp_dates, n_comp_pools) + + # Find pool in competitor data + comp_pool_idx = -1 + if pool_id is not None: + comp_pool_idx = _find_pool_index(pool_id, comp_pool_ids) + + if comp_pool_idx < 0: + print(f" WARNING: pool not in competitor TVL data, using K=$10M") + K_daily = np.full(n_days, 10e6, dtype=np.float64) + else: + # Build date index for competitor data + comp_date_to_idx = {} + for ci, d in enumerate(comp_dates): + comp_date_to_idx[str(d)[:10]] = ci + + K_daily = np.full(n_days, np.nan, dtype=np.float64) + for k, day in enumerate(dates): + ds = str(pd.Timestamp(day))[:10] + if ds in comp_date_to_idx: + ci = comp_date_to_idx[ds] + val = k_eff[ci, comp_pool_idx] + if np.isfinite(val) and val > 0: + K_daily[k] = val + + # Forward-fill / back-fill + s = pd.Series(K_daily).ffill().bfill() + K_daily = s.values.astype(np.float64) + + # Floor + K_daily = np.maximum(K_daily, 1.0) + med_K = np.median(K_daily[np.isfinite(K_daily)]) + print(f" K (competitor TVL): median=${med_K:,.0f}," + f" range=[${K_daily.min():,.0f}, ${K_daily.max():,.0f}]") + + # Expand to minute resolution + n_minutes = 
n_days * 1440 + noise_base = np.repeat(noise_base_daily, 1440) + competitor_tvl_array = np.repeat(K_daily, 1440) + + return { + "noise_base": noise_base, + "competitor_tvl": competitor_tvl_array, + "dates": dates, + "pool_index": pool_idx, + "n_days": n_days, + "n_minutes": n_minutes, + } diff --git a/quantammsim/calibration/per_pool_fit.py b/quantammsim/calibration/per_pool_fit.py index 037a124..be724a9 100644 --- a/quantammsim/calibration/per_pool_fit.py +++ b/quantammsim/calibration/per_pool_fit.py @@ -2,6 +2,9 @@ Fits (log_cadence, log_gas, noise_coeffs) per pool by minimizing the log-space L2 loss using scipy.optimize.minimize with JAX gradients. + +Supports fixed-gas mode where gas is set to the known chain-level cost, +leaving only (log_cadence, noise_coeffs) to be optimized. """ from typing import Dict, Optional @@ -12,7 +15,13 @@ import scipy.optimize from quantammsim.calibration.grid_interpolation import PoolCoeffsDaily -from quantammsim.calibration.loss import K_OBS, pack_params, pool_loss +from quantammsim.calibration.loss import ( + CHAIN_GAS_USD, + K_OBS, + pack_params, + pool_loss, + pool_loss_fixed_gas, +) from quantammsim.calibration.pool_data import build_x_obs @@ -22,14 +31,25 @@ def make_initial_guess(x_obs: np.ndarray, y_obs: np.ndarray) -> np.ndarray: OLS: noise_coeffs = lstsq(x_obs, y_obs) — assumes all volume is noise. This overestimates noise but gives a reasonable starting point. 
""" + k_obs = x_obs.shape[1] noise_coeffs, _, _, _ = np.linalg.lstsq(x_obs, y_obs, rcond=None) - init = np.zeros(2 + K_OBS) + init = np.zeros(2 + k_obs) init[0] = np.log(12.0) # log_cadence init[1] = np.log(1.0) # log_gas (= 0.0) init[2:] = noise_coeffs return init +def make_initial_guess_fixed_gas(x_obs: np.ndarray, y_obs: np.ndarray) -> np.ndarray: + """Initial params for fixed-gas mode: cadence=12min, noise_coeffs from OLS.""" + k_obs = x_obs.shape[1] + noise_coeffs, _, _, _ = np.linalg.lstsq(x_obs, y_obs, rcond=None) + init = np.zeros(1 + k_obs) + init[0] = np.log(12.0) # log_cadence + init[1:] = noise_coeffs + return init + + def fit_single_pool( coeffs: PoolCoeffsDaily, x_obs: np.ndarray, @@ -37,74 +57,128 @@ def fit_single_pool( day_indices: np.ndarray, init: Optional[np.ndarray] = None, bounds: Optional[dict] = None, + fixed_gas_usd: Optional[float] = None, ) -> dict: - """Fit (log_cadence, log_gas, noise_coeffs) for one pool via L-BFGS-B. + """Fit one pool via L-BFGS-B. + + If fixed_gas_usd is given, gas is held constant at that value and only + (log_cadence, noise_coeffs) are optimized. Otherwise fits all three. Returns dict with fitted params, loss, and convergence status. 
""" - if init is None: - init = make_initial_guess(x_obs, y_obs) + # Convert to JAX arrays + x_obs_j = jnp.array(x_obs) + y_obs_j = jnp.array(y_obs) + day_idx_j = jnp.array(day_indices) - # Default bounds if bounds is None: bounds = {} log_cad_bounds = bounds.get("log_cadence", (np.log(1.0), np.log(60.0))) - log_gas_bounds = bounds.get("log_gas", (np.log(0.001), np.log(50.0))) noise_bounds = bounds.get("noise_coeffs", (-20.0, 20.0)) - scipy_bounds = [ - log_cad_bounds, - log_gas_bounds, - ] + [(noise_bounds[0], noise_bounds[1])] * K_OBS - - # Convert to JAX arrays - x_obs_j = jnp.array(x_obs) - y_obs_j = jnp.array(y_obs) - day_idx_j = jnp.array(day_indices) + if fixed_gas_usd is not None: + # Fixed-gas mode + fixed_log_gas = jnp.float64(np.log(max(fixed_gas_usd, 1e-6))) + + if init is None: + init = make_initial_guess_fixed_gas(x_obs, y_obs) + + k_obs = x_obs.shape[1] + scipy_bounds = [log_cad_bounds] + [(noise_bounds[0], noise_bounds[1])] * k_obs + + @jax.jit + def loss_and_grad(params_flat): + loss = pool_loss_fixed_gas( + params_flat, fixed_log_gas, coeffs, x_obs_j, y_obs_j, day_idx_j) + grad = jax.grad(pool_loss_fixed_gas, argnums=0)( + params_flat, fixed_log_gas, coeffs, x_obs_j, y_obs_j, day_idx_j) + return loss, grad + + def scipy_wrapper(params_np): + params_j = jnp.array(params_np) + loss, grad = loss_and_grad(params_j) + return float(loss), np.array(grad, dtype=np.float64) + + result = scipy.optimize.minimize( + scipy_wrapper, init, method="L-BFGS-B", jac=True, + bounds=scipy_bounds, + options={"maxiter": 500, "ftol": 1e-10, "gtol": 1e-8}, + ) - # Value and gradient function - @jax.jit - def loss_and_grad(params_flat): - loss = pool_loss(params_flat, coeffs, x_obs_j, y_obs_j, day_idx_j) - grad = jax.grad(pool_loss, argnums=0)( - params_flat, coeffs, x_obs_j, y_obs_j, day_idx_j + log_cadence = float(result.x[0]) + noise_coeffs = np.array(result.x[1:]) + log_gas = float(fixed_log_gas) + + return { + "log_cadence": log_cadence, + "log_gas": log_gas, + 
"noise_coeffs": noise_coeffs, + "loss": float(result.fun), + "converged": result.success, + "cadence_minutes": float(np.exp(log_cadence)), + "gas_usd": fixed_gas_usd, + "gas_fixed": True, + } + + else: + # Free-gas mode (original) + if init is None: + init = make_initial_guess(x_obs, y_obs) + + k_obs = x_obs.shape[1] + log_gas_bounds = bounds.get("log_gas", (np.log(0.001), np.log(50.0))) + scipy_bounds = [ + log_cad_bounds, log_gas_bounds, + ] + [(noise_bounds[0], noise_bounds[1])] * k_obs + + @jax.jit + def loss_and_grad(params_flat): + loss = pool_loss(params_flat, coeffs, x_obs_j, y_obs_j, day_idx_j) + grad = jax.grad(pool_loss, argnums=0)( + params_flat, coeffs, x_obs_j, y_obs_j, day_idx_j) + return loss, grad + + def scipy_wrapper(params_np): + params_j = jnp.array(params_np) + loss, grad = loss_and_grad(params_j) + return float(loss), np.array(grad, dtype=np.float64) + + result = scipy.optimize.minimize( + scipy_wrapper, init, method="L-BFGS-B", jac=True, + bounds=scipy_bounds, + options={"maxiter": 500, "ftol": 1e-10, "gtol": 1e-8}, ) - return loss, grad - - def scipy_wrapper(params_np): - params_j = jnp.array(params_np) - loss, grad = loss_and_grad(params_j) - return float(loss), np.array(grad, dtype=np.float64) - - result = scipy.optimize.minimize( - scipy_wrapper, - init, - method="L-BFGS-B", - jac=True, - bounds=scipy_bounds, - options={"maxiter": 500, "ftol": 1e-10, "gtol": 1e-8}, - ) - - log_cadence = float(result.x[0]) - log_gas = float(result.x[1]) - noise_coeffs = np.array(result.x[2:]) - - return { - "log_cadence": log_cadence, - "log_gas": log_gas, - "noise_coeffs": noise_coeffs, - "loss": float(result.fun), - "converged": result.success, - "cadence_minutes": float(np.exp(log_cadence)), - "gas_usd": float(np.exp(log_gas)), - } + + log_cadence = float(result.x[0]) + log_gas = float(result.x[1]) + noise_coeffs = np.array(result.x[2:]) + + return { + "log_cadence": log_cadence, + "log_gas": log_gas, + "noise_coeffs": noise_coeffs, + "loss": 
float(result.fun), + "converged": result.success, + "cadence_minutes": float(np.exp(log_cadence)), + "gas_usd": float(np.exp(log_gas)), + "gas_fixed": False, + } def fit_all_pools( matched: Dict[str, dict], n_workers: int = 1, + fix_gas_to_chain: bool = False, + reduced: bool = False, ) -> Dict[str, dict]: - """Fit all matched pools. Returns prefix -> fit_result with metadata.""" + """Fit all matched pools. Returns prefix -> fit_result with metadata. + + If fix_gas_to_chain is True, gas is fixed to the known chain-level cost + from CHAIN_GAS_USD, and only (log_cadence, noise_coeffs) are optimized. + + If reduced is True, uses the 4-covariate x_obs (intercept, log_tvl_lag1, + dow_sin, dow_cos) instead of the full 8-covariate set. + """ results = {} for prefix, entry in matched.items(): @@ -112,10 +186,16 @@ def fit_all_pools( coeffs = entry["coeffs"] day_indices = entry["day_indices"] - x_obs = build_x_obs(panel) + x_obs = build_x_obs(panel, reduced=reduced) y_obs = panel["log_volume"].values.astype(float) - result = fit_single_pool(coeffs, x_obs, y_obs, day_indices) + fixed_gas = None + if fix_gas_to_chain: + chain = entry["chain"] + fixed_gas = CHAIN_GAS_USD.get(chain, 1.0) + + result = fit_single_pool( + coeffs, x_obs, y_obs, day_indices, fixed_gas_usd=fixed_gas) # Add metadata result["chain"] = entry["chain"] diff --git a/quantammsim/calibration/pool_data.py b/quantammsim/calibration/pool_data.py index 039880b..3b05bbb 100644 --- a/quantammsim/calibration/pool_data.py +++ b/quantammsim/calibration/pool_data.py @@ -6,7 +6,7 @@ import json import os -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -18,6 +18,10 @@ ) K_OBS = 8 # observation-level covariates +K_OBS_REDUCED = 4 # [intercept, log_tvl_lag1, dow_sin, dow_cos] +K_OBS_CROSS = 7 # [intercept, log_tvl_lag1, dow_sin, dow_cos, + # cross_vol_token_a_{t-1}, cross_vol_token_b_{t-1}, + # cross_vol_chain_{t-1}] # Default path for cached 
token market caps _MCAP_PATH = os.path.join( @@ -25,6 +29,11 @@ "local_data", "noise_calibration", "token_mcaps.json", ) +# Default path for Binance minute parquets +_BINANCE_DATA_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "data", +) + # Asset type classification (fallback if not in mcap JSON) _STABLECOINS = { "USDC", "USDT", "DAI", "WXDAI", "xDAI", "GHO", "LUSD", "crvUSD", @@ -42,6 +51,22 @@ "waBasWETH", "waGnoGNO", "waGnowstETH", } +# Balancer token → Binance parquet symbol mapping. +# Matches build_pool_grids.py TOKEN_MAP. +TOKEN_MAP = { + "WBTC": "BTC", "WETH": "ETH", "cbBTC": "BTC", + "wstETH": "ETH", "stETH": "ETH", "rETH": "ETH", "cbETH": "ETH", + "waEthLidoWETH": "ETH", "waEthLidowstETH": "ETH", + "waBasWETH": "ETH", "waGnowstETH": "ETH", + "waGnoGNO": "GNO", "osGNO": "GNO", + "wS": "S", "stS": "S", + "JitoSOL": "SOL", + "wPOL": "POL", "WMATIC": "POL", "MATIC": "POL", + "USDC.e": "USDC", "USDbC": "USDC", "waBasUSDC": "USDC", + "DAI": "USDC", "WXDAI": "USDC", "sDAI": "USDC", + "USDT": "USDC", "DOLA": "USDC", "scUSD": "USDC", +} + def _load_token_mcaps(path: str = None) -> dict: """Load cached token market caps. Returns {} if file missing.""" @@ -71,6 +96,153 @@ def _parse_tokens(tokens_str: str) -> List[str]: return [t.strip() for t in tokens_str.split(",")] +def _resolve_binance_symbol(token: str) -> str: + """Map Balancer token name to Binance parquet symbol.""" + return TOKEN_MAP.get(token, token) + + +def _load_binance_minute(symbol: str, data_dir: str = None) -> Optional[pd.DataFrame]: + """Load Binance minute close prices. 
Returns DataFrame with unix index.""" + if data_dir is None: + data_dir = _BINANCE_DATA_DIR + path = os.path.join(data_dir, f"{symbol}_USD.parquet") + if not os.path.exists(path): + return None + df = pd.read_parquet(path, columns=["unix", "close"]) + if df.index.name != "unix": + df = df.set_index("unix") + return df + + +def compute_binance_pair_volatility( + token_a: str, token_b: str, data_dir: str = None, +) -> Optional[pd.Series]: + """Compute daily annualized realized volatility from Binance minute data. + + Resamples minute data to hourly, computes hourly log returns of the pair + ratio, then daily std × sqrt(24 × 365). Matches the Balancer hourly + pipeline's annualization convention. + + Args: + token_a, token_b: Balancer token symbols (e.g. "WETH", "USDC") + data_dir: directory containing {SYMBOL}_USD.parquet files + + Returns: + pd.Series with datetime.date index → annualized volatility, + or None if both tokens are stablecoins / same underlying / missing data. + """ + sym_a = _resolve_binance_symbol(token_a) + sym_b = _resolve_binance_symbol(token_b) + + is_stable_a = token_a in _STABLECOINS or sym_a == "USDC" + is_stable_b = token_b in _STABLECOINS or sym_b == "USDC" + + if is_stable_a and is_stable_b: + return None # caller should use constant 0.01 + + if sym_a == sym_b: + return None # same underlying (e.g. 
wstETH/WETH) + + # Load minute data and compute pair ratio + if is_stable_b: + df = _load_binance_minute(sym_a, data_dir) + if df is None: + return None + ratio = df["close"] + elif is_stable_a: + df = _load_binance_minute(sym_b, data_dir) + if df is None: + return None + ratio = 1.0 / df["close"] + else: + df_a = _load_binance_minute(sym_a, data_dir) + df_b = _load_binance_minute(sym_b, data_dir) + if df_a is None or df_b is None: + return None + merged = df_a.join(df_b, lsuffix="_a", rsuffix="_b", how="inner") + ratio = merged["close_a"] / merged["close_b"] + + # Resample to hourly (last close per hour) + ratio_df = pd.DataFrame({"ratio": ratio}) + ratio_df.index = pd.to_datetime(ratio_df.index, unit="ms", utc=True) + hourly = ratio_df.resample("1h").last().dropna() + + # Hourly log returns + hourly["log_return"] = np.log(hourly["ratio"] / hourly["ratio"].shift(1)) + hourly = hourly.dropna() + + # Daily std → annualized + hourly["date"] = hourly.index.date + daily_vol = hourly.groupby("date")["log_return"].std() + annualized = daily_vol * np.sqrt(24 * 365) + + # Clean + annualized = annualized.replace([np.inf, -np.inf], np.nan).dropna() + annualized = annualized[annualized > 0] + + return annualized + + +def replace_panel_volatility_with_binance( + panel: pd.DataFrame, data_dir: str = None, +) -> pd.DataFrame: + """Replace panel 'volatility' column with Binance-derived daily values. + + For each pool, computes daily realized volatility from Binance minute data. + Pools without Binance data keep their existing (possibly fallback) values. + Stablecoin-stablecoin and same-underlying pairs get vol=0.01. + + Returns a copy of the panel with updated volatility. 
+ """ + panel = panel.copy() + panel["date"] = pd.to_datetime(panel["date"]) + + # Cache: (sym_a, sym_b) → vol_series to avoid reloading + _vol_cache: Dict[tuple, Optional[pd.Series]] = {} + + n_replaced = 0 + n_pools = 0 + + for pool_id, grp in panel.groupby("pool_id"): + tokens_str = grp.iloc[0]["tokens"] + toks = _parse_tokens(tokens_str) + if len(toks) < 2: + continue + + sym_a = _resolve_binance_symbol(toks[0]) + sym_b = _resolve_binance_symbol(toks[1]) + cache_key = (min(sym_a, sym_b), max(sym_a, sym_b)) + + if cache_key not in _vol_cache: + _vol_cache[cache_key] = compute_binance_pair_volatility( + toks[0], toks[1], data_dir) + + vol_series = _vol_cache[cache_key] + + if vol_series is None: + # Stablecoins or same underlying → low constant vol + is_stable_a = toks[0] in _STABLECOINS or sym_a == "USDC" + is_stable_b = toks[1] in _STABLECOINS or sym_b == "USDC" + if (is_stable_a and is_stable_b) or sym_a == sym_b: + panel.loc[grp.index, "volatility"] = 0.01 + n_pools += 1 + continue + + # Vectorized date matching + panel_dates = pd.to_datetime(grp["date"]).dt.date + vol_dict = vol_series.to_dict() + new_vol = panel_dates.map(vol_dict) + has_vol = new_vol.notna() + if has_vol.any(): + panel.loc[grp.index[has_vol.values], "volatility"] = ( + new_vol[has_vol].values.astype(float)) + n_replaced += has_vol.sum() + n_pools += 1 + + print(f" Binance volatility: {n_pools} pools, {n_replaced} obs replaced") + return panel + + def match_grids_to_panel( grid_dir: str, panel: pd.DataFrame, pools_path: str = None, ) -> Dict[str, dict]: @@ -179,11 +351,146 @@ def match_grids_to_panel( return matched -def build_x_obs(panel_rows: pd.DataFrame) -> np.ndarray: - """Build (n_obs, 8) observation covariate matrix from panel rows. 
+# Token canonicalization — map wrapped/LST variants to base tokens +_CANON_MAP = { + "WETH": "ETH", "waBasWETH": "ETH", "waEthLidoWETH": "ETH", + "waEthLidowstETH": "wstETH", "waGnowstETH": "wstETH", + "waBasUSDC": "USDC", "scUSD": "USDC", "USDC.e": "USDC", + "USDbC": "USDC", "waEthUSDC": "USDC", + "sDAI": "DAI", "WXDAI": "DAI", + "WBTC": "BTC", "cbBTC": "BTC", + "stS": "S", "wS": "S", + "waGnoGNO": "GNO", "osGNO": "GNO", +} + + +def _canonicalize_token(symbol: str) -> str: + """Map wrapped/derivative token to its canonical base symbol.""" + return _CANON_MAP.get(symbol, symbol) + + +# Token classification for token-factored model +_ETH_DERIVATIVES = { + "WETH", "ETH", "wstETH", "stETH", "rETH", "cbETH", + "waEthLidoWETH", "waEthLidowstETH", "waBasWETH", "waGnowstETH", +} +_L1_NATIVE = { + "WETH", "ETH", "WMATIC", "MATIC", "POL", "wPOL", + "WAVAX", "AVAX", "GNO", "S", "wS", "stS", +} + +D_TOKEN = 5 # [intercept, log_mcap, is_stable, is_eth_derivative, is_L1_native] + + +def _classify_token(symbol: str, mcaps: dict) -> dict: + """Classify a token into binary feature flags plus log market cap.""" + return { + "is_stable": 1.0 if symbol in _STABLECOINS else 0.0, + "is_eth_derivative": 1.0 if symbol in _ETH_DERIVATIVES else 0.0, + "is_L1_native": 1.0 if symbol in _L1_NATIVE else 0.0, + "log_mcap": np.log(max(mcaps.get(symbol, {}).get("mcap_usd", 1e6), 1.0)), + } + + +def encode_tokens( + matched: Dict[str, dict], + mcap_path: str = None, + canonicalize: bool = True, +) -> dict: + """Build token index, per-pool token assignments, and token covariate matrix. + + Iterates over pools in sorted key order (same ordering as build_pool_attributes). + + When canonicalize=True (default), wrapped/derivative tokens are mapped to + their canonical base symbol via _CANON_MAP before building the index. + Raw symbols are still used for market cap lookup. 
+ + Returns dict with: + token_index: dict[str, int] — symbol -> integer index (sorted alphabetically) + token_a_idx: np.ndarray (n_pools,) — index of token A for each pool + token_b_idx: np.ndarray (n_pools,) — index of token B for each pool + x_token: np.ndarray (n_tokens, D_TOKEN) — token covariate matrix + chain_idx: np.ndarray (n_pools,) — chain integer index per pool + chain_index: dict[str, int] — chain name -> integer index (sorted) + log_fees: np.ndarray (n_pools,) — log(fee) per pool + n_tokens: int + n_chains: int + """ + mcaps = _load_token_mcaps(mcap_path) + pool_ids = sorted(matched.keys()) + n_pools = len(pool_ids) + + # Collect all tokens and chains; store per-pool canonical pairs + all_tokens = set() + all_chains = set() + pool_canon_toks = [] # (canon_a, canon_b) per pool in sorted order + for pid in pool_ids: + entry = matched[pid] + toks = _parse_tokens(entry["tokens"]) + raw_a, raw_b = toks[0], toks[1] + canon_a = _canonicalize_token(raw_a) if canonicalize else raw_a + canon_b = _canonicalize_token(raw_b) if canonicalize else raw_b + all_tokens.update([canon_a, canon_b]) + all_chains.add(entry["chain"]) + pool_canon_toks.append((canon_a, canon_b)) + + # Build sorted indices + token_list = sorted(all_tokens) + token_index = {t: i for i, t in enumerate(token_list)} + n_tokens = len(token_list) + + chain_list = sorted(all_chains) + chain_index = {c: i for i, c in enumerate(chain_list)} + n_chains = len(chain_list) + + # Build per-pool arrays + token_a_idx = np.zeros(n_pools, dtype=np.int32) + token_b_idx = np.zeros(n_pools, dtype=np.int32) + chain_idx = np.zeros(n_pools, dtype=np.int32) + log_fees = np.zeros(n_pools, dtype=np.float64) + + for i, pid in enumerate(pool_ids): + entry = matched[pid] + canon_a, canon_b = pool_canon_toks[i] + token_a_idx[i] = token_index[canon_a] + token_b_idx[i] = token_index[canon_b] + chain_idx[i] = chain_index[entry["chain"]] + log_fees[i] = np.log(entry["fee"]) + + # Build token covariate matrix: (n_tokens, 
D_TOKEN) + # Columns: [intercept, log_mcap, is_stable, is_eth_derivative, is_L1_native] + x_token = np.zeros((n_tokens, D_TOKEN), dtype=np.float64) + for t, idx in token_index.items(): + cls = _classify_token(t, mcaps) + x_token[idx, 0] = 1.0 # intercept + x_token[idx, 1] = cls["log_mcap"] + x_token[idx, 2] = cls["is_stable"] + x_token[idx, 3] = cls["is_eth_derivative"] + x_token[idx, 4] = cls["is_L1_native"] + + return { + "token_index": token_index, + "token_a_idx": token_a_idx, + "token_b_idx": token_b_idx, + "x_token": x_token, + "chain_idx": chain_idx, + "chain_index": chain_index, + "log_fees": log_fees, + "n_tokens": n_tokens, + "n_chains": n_chains, + } + - Columns: [1, log_tvl_lag1, log_sigma, tvl*sigma, tvl*fee, - sigma*fee, dow_sin, dow_cos] +def build_x_obs(panel_rows: pd.DataFrame, reduced: bool = False) -> np.ndarray: + """Build observation covariate matrix from panel rows. + + Full (reduced=False): (n_obs, 8) + [1, log_tvl_lag1, log_sigma, tvl*sigma, tvl*fee, sigma*fee, dow_sin, dow_cos] + + Reduced (reduced=True): (n_obs, 4) + [1, log_tvl_lag1, dow_sin, dow_cos] + Removes sigma- and fee-dependent terms so the arb channel is the only + path for volatility-driven volume variation. 
Where: log_sigma = log(max(volatility, 1e-6)) @@ -193,12 +500,21 @@ def build_x_obs(panel_rows: pd.DataFrame) -> np.ndarray: weekday: Monday=0, ..., Sunday=6 """ n = len(panel_rows) + weekdays = pd.to_datetime(panel_rows["date"]).dt.weekday.values.astype(float) + + if reduced: + x = np.zeros((n, K_OBS_REDUCED)) + x[:, 0] = 1.0 + x[:, 1] = panel_rows["log_tvl_lag1"].values.astype(float) + x[:, 2] = np.sin(2 * np.pi * weekdays / 7) + x[:, 3] = np.cos(2 * np.pi * weekdays / 7) + return x + x = np.zeros((n, K_OBS)) tvl = panel_rows["log_tvl_lag1"].values.astype(float) sigma = np.log(np.maximum(panel_rows["volatility"].values.astype(float), 1e-6)) fee = panel_rows["log_fee"].values.astype(float) - weekdays = pd.to_datetime(panel_rows["date"]).dt.weekday.values.astype(float) x[:, 0] = 1.0 # intercept x[:, 1] = tvl # log_tvl_lag1 @@ -212,6 +528,141 @@ def build_x_obs(panel_rows: pd.DataFrame) -> np.ndarray: return x +def build_cross_pool_x_obs( + panel_rows: pd.DataFrame, + matched: Dict[str, dict], + pool_id: str, + exclude_pool: Optional[str] = None, + canonicalize: bool = True, +) -> np.ndarray: + """Build x_obs with cross-pool lagged volume features. + + Columns 0-3: same as build_x_obs(reduced=True) + Column 4: mean log_volume at t-1 across pools sharing token A (excl self) + Column 5: mean log_volume at t-1 across pools sharing token B (excl self) + Column 6: mean log_volume at t-1 across pools on same chain (excl self) + + The first observation (day 0) is dropped because there is no lag available. 
+ + Args: + panel_rows: DataFrame for this pool + matched: full matched dict (all pools) + pool_id: this pool's key in matched (prefix) + exclude_pool: optional pool to exclude from peer averages (for LOO) + canonicalize: if True, canonicalize tokens before peer matching + + Returns: + (n_obs - 1, K_OBS_CROSS) array + """ + # Get this pool's tokens and chain + entry = matched[pool_id] + toks = _parse_tokens(entry["tokens"]) + tok_a_raw, tok_b_raw = toks[0], toks[1] + tok_a = _canonicalize_token(tok_a_raw) if canonicalize else tok_a_raw + tok_b = _canonicalize_token(tok_b_raw) if canonicalize else tok_b_raw + this_chain = entry["chain"] + + # Build peer sets: token→set of pool_ids, chain→set of pool_ids + token_peers = {} # canonical_token → set of (prefix, panel_df) + chain_peers = {} # chain → set of (prefix, panel_df) + all_pool_ids = sorted(matched.keys()) + + for pid in all_pool_ids: + if pid == pool_id: + continue # always exclude self + if pid == exclude_pool: + continue + peer_entry = matched[pid] + peer_toks = _parse_tokens(peer_entry["tokens"]) + peer_canonical = set() + for t in peer_toks[:2]: + ct = _canonicalize_token(t) if canonicalize else t + peer_canonical.add(ct) + + for ct in peer_canonical: + if ct not in token_peers: + token_peers[ct] = [] + token_peers[ct].append(pid) + + peer_chain = peer_entry["chain"] + if peer_chain not in chain_peers: + chain_peers[peer_chain] = [] + chain_peers[peer_chain].append(pid) + + # Build (pool_id, date_ordinal) → log_volume lookup from all pools + vol_lookup = {} # (pid, date_ordinal) → log_volume + for pid in all_pool_ids: + if pid == pool_id or pid == exclude_pool: + continue + peer_panel = matched[pid]["panel"] + peer_dates = pd.to_datetime(peer_panel["date"]) + peer_ords = np.array([d.toordinal() for d in peer_dates]) + peer_vols = peer_panel["log_volume"].values.astype(float) + for ord_val, vol_val in zip(peer_ords, peer_vols): + vol_lookup[(pid, int(ord_val))] = vol_val + + # Compute global lagged mean for 
fallback + all_vols = list(vol_lookup.values()) + global_mean_vol = float(np.mean(all_vols)) if all_vols else 0.0 + + # Get this pool's dates + dates = pd.to_datetime(panel_rows["date"]) + date_ords = np.array([d.toordinal() for d in dates]) + n_obs = len(panel_rows) + + def _peer_mean_at_lag(peer_pids, date_ord_prev): + """Mean log_volume of peer pools at date_ord_prev.""" + vals = [] + for pid in peer_pids: + key = (pid, date_ord_prev) + if key in vol_lookup: + vals.append(vol_lookup[key]) + if vals: + return float(np.mean(vals)) + return np.nan + + # Build cross-pool features for each obs (starting from day 1) + cross_vol_a = np.full(n_obs, np.nan) + cross_vol_b = np.full(n_obs, np.nan) + cross_vol_chain = np.full(n_obs, np.nan) + + tok_a_peers = token_peers.get(tok_a, []) + tok_b_peers = token_peers.get(tok_b, []) + chain_peer_list = chain_peers.get(this_chain, []) + + for i in range(1, n_obs): + prev_ord = int(date_ords[i - 1]) + + if tok_a_peers: + cross_vol_a[i] = _peer_mean_at_lag(tok_a_peers, prev_ord) + if tok_b_peers: + cross_vol_b[i] = _peer_mean_at_lag(tok_b_peers, prev_ord) + if chain_peer_list: + cross_vol_chain[i] = _peer_mean_at_lag(chain_peer_list, prev_ord) + + # Drop first day, fill NaN with global mean + cross_vol_a = cross_vol_a[1:] + cross_vol_b = cross_vol_b[1:] + cross_vol_chain = cross_vol_chain[1:] + + cross_vol_a = np.where(np.isnan(cross_vol_a), global_mean_vol, cross_vol_a) + cross_vol_b = np.where(np.isnan(cross_vol_b), global_mean_vol, cross_vol_b) + cross_vol_chain = np.where(np.isnan(cross_vol_chain), global_mean_vol, cross_vol_chain) + + # Build base x_obs (reduced) and drop first row + x_base = build_x_obs(panel_rows, reduced=True) + x_base = x_base[1:] # drop first day + + # Assemble + x = np.zeros((n_obs - 1, K_OBS_CROSS)) + x[:, :4] = x_base + x[:, 4] = cross_vol_a + x[:, 5] = cross_vol_b + x[:, 6] = cross_vol_chain + + return x + + def build_pool_attributes( matched: Dict[str, dict], mcap_path: str = None, diff --git 
a/quantammsim/core_simulator/dynamic_inputs.py b/quantammsim/core_simulator/dynamic_inputs.py index 367363e..1637e2b 100644 --- a/quantammsim/core_simulator/dynamic_inputs.py +++ b/quantammsim/core_simulator/dynamic_inputs.py @@ -14,6 +14,7 @@ class DynamicInputFrames: arb_fees: Optional[Any] = None lp_supply: Optional[Any] = None reclamm_price_ratio_updates: Optional[Any] = None + oracle_prices: Optional[Any] = None class DynamicInputArrays(NamedTuple): @@ -25,6 +26,7 @@ class DynamicInputArrays(NamedTuple): arb_fees: jnp.ndarray lp_supply: jnp.ndarray reclamm_price_ratio_updates: jnp.ndarray + oracle_prices: jnp.ndarray = jnp.ones((1, 1)) def default_dynamic_input_flags() -> dict: @@ -37,6 +39,7 @@ def default_dynamic_input_flags() -> dict: "has_dynamic_arb_fees": False, "has_lp_supply": False, "has_reclamm_price_ratio_updates": False, + "has_oracle_prices": False, } @@ -55,6 +58,7 @@ def dynamic_input_flags_from_frames(dynamic_input_frames: Optional[DynamicInputF "has_reclamm_price_ratio_updates": ( dynamic_input_frames.reclamm_price_ratio_updates is not None ), + "has_oracle_prices": dynamic_input_frames.oracle_prices is not None, } flags["use_dynamic_inputs"] = any(flags.values()) return flags @@ -83,6 +87,7 @@ def empty_dynamic_input_arrays() -> DynamicInputArrays: lp_supply=jnp.ones((1,)), # Columns: has_event, target_price_ratio, end_step, start_price_ratio_override reclamm_price_ratio_updates=jnp.array([[0.0, 0.0, 0.0, jnp.nan]]), + oracle_prices=jnp.ones((1, 1)), ) @@ -117,9 +122,14 @@ def resolve_dynamic_input_components( ), "reclamm_price_ratio_updates": ( arrays.reclamm_price_ratio_updates - if dynamic_input_flags["has_reclamm_price_ratio_updates"] + if dynamic_input_flags.get("has_reclamm_price_ratio_updates", False) else empty_dynamic_input_arrays().reclamm_price_ratio_updates ), + "oracle_prices": ( + arrays.oracle_prices + if dynamic_input_flags.get("has_oracle_prices", False) + else empty_dynamic_input_arrays().oracle_prices + ), } @@ -160,6 
+170,7 @@ def materialize_dynamic_inputs( "has_dynamic_arb_fees": True, "has_lp_supply": True, "has_reclamm_price_ratio_updates": True, + "has_oracle_prices": True, } else: flags = resolve_dynamic_input_flags(dynamic_inputs, dynamic_input_flags) @@ -192,4 +203,10 @@ def materialize_dynamic_inputs( scan_len, dtype, ), + oracle_prices=_broadcast_dynamic_input_leaf( + "oracle_prices", + resolved["oracle_prices"], + scan_len, + dtype, + ), ) diff --git a/quantammsim/core_simulator/forward_pass.py b/quantammsim/core_simulator/forward_pass.py index e99f051..1eea66b 100644 --- a/quantammsim/core_simulator/forward_pass.py +++ b/quantammsim/core_simulator/forward_pass.py @@ -72,6 +72,52 @@ def _resolve_dynamic_inputs(dynamic_inputs, static_dict): return dynamic_inputs, dynamic_input_flags +def _slice_dynamic_input_leaf(values, start_idx, length): + """Slice a window from a dynamic-input leaf unless it is singleton.""" + if values is None: + return None + values = jnp.asarray(values) + if values.ndim == 0 or values.shape[0] <= 1: + return values + slice_sizes = (length,) + values.shape[1:] + start = (start_idx,) + (0,) * (values.ndim - 1) + return dynamic_slice(values, start, slice_sizes) + + +def _slice_dynamic_inputs(dynamic_inputs, start_index, static_dict): + """Align full-period dynamic inputs to the active simulation window.""" + if dynamic_inputs is None: + return None + + bout_length = static_dict["bout_length"] + arb_frequency = static_dict.get("arb_frequency", 1) + window_len = bout_length - 1 + minute_start = start_index[0] - static_dict.get("dynamic_inputs_offset", 0) + + def _slice_minute_leaf(values): + sliced = _slice_dynamic_input_leaf(values, minute_start, window_len) + if sliced is None or arb_frequency == 1: + return sliced + return sliced[::arb_frequency] + + schedule_len = window_len if arb_frequency == 1 else (window_len + arb_frequency - 1) // arb_frequency + schedule_start = minute_start if arb_frequency == 1 else minute_start // arb_frequency + + 
return DynamicInputArrays( + trades=_slice_minute_leaf(dynamic_inputs.trades), + fees=_slice_minute_leaf(dynamic_inputs.fees), + gas_cost=_slice_minute_leaf(dynamic_inputs.gas_cost), + arb_fees=_slice_minute_leaf(dynamic_inputs.arb_fees), + lp_supply=_slice_minute_leaf(dynamic_inputs.lp_supply), + reclamm_price_ratio_updates=_slice_dynamic_input_leaf( + dynamic_inputs.reclamm_price_ratio_updates, + schedule_start, + schedule_len, + ), + oracle_prices=_slice_minute_leaf(dynamic_inputs.oracle_prices), + ) + + def _apply_price_noise(prices, sigma, seed_int): """Apply multiplicative log-normal noise to prices. @@ -1034,6 +1080,8 @@ def forward_pass( dynamic_inputs, dynamic_input_flags = _resolve_dynamic_inputs( dynamic_inputs, static_dict ) + if dynamic_input_flags["use_dynamic_inputs"]: + dynamic_inputs = _slice_dynamic_inputs(dynamic_inputs, start_index, static_dict) fee_revenue = None if dynamic_input_flags["use_dynamic_inputs"]: @@ -1280,6 +1328,7 @@ def forward_pass_nograd( reclamm_price_ratio_updates=stop_gradient( dynamic_inputs.reclamm_price_ratio_updates ), + oracle_prices=stop_gradient(dynamic_inputs.oracle_prices), ) return forward_pass( params, diff --git a/quantammsim/hooks/dynamic_fee_base_hook.py b/quantammsim/hooks/dynamic_fee_base_hook.py index 1ab0266..4cdde1e 100644 --- a/quantammsim/hooks/dynamic_fee_base_hook.py +++ b/quantammsim/hooks/dynamic_fee_base_hook.py @@ -125,6 +125,7 @@ def calculate_reserves_with_fees( arb_fees=jnp.asarray(run_fingerprint["arb_fees"], dtype=jnp.float64), lp_supply=empty_inputs.lp_supply, reclamm_price_ratio_updates=empty_inputs.reclamm_price_ratio_updates, + oracle_prices=empty_inputs.oracle_prices, ) return self.calculate_reserves_with_dynamic_inputs( diff --git a/quantammsim/pools/G3M/balancer/hypersurge_balancer.py b/quantammsim/pools/G3M/balancer/hypersurge_balancer.py new file mode 100644 index 0000000..458dc33 --- /dev/null +++ b/quantammsim/pools/G3M/balancer/hypersurge_balancer.py @@ -0,0 +1,355 @@ +from 
functools import partial +from typing import Any, Dict, Optional + +import numpy as np + +import jax.numpy as jnp +from jax import jit, tree_util +from jax.lax import dynamic_slice + +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs +from quantammsim.pools.hypersurge_utils import ( + HYPERSURGE_PARAM_KEYS, + hypersurge_params_from_params, + run_fingerprint_hypersurge_defaults, +) +from quantammsim.pools.G3M.balancer.balancer import BalancerPool +from quantammsim.pools.G3M.balancer.hypersurge_balancer_reserves import ( + _jax_calc_hypersurge_balancer_reserves, +) + + +def _prepare_dynamic_array(arr, start_index, bout_length, arb_frequency, max_len): + """Slice and decimate a dynamic input array to match the arb-price scan.""" + arr = jnp.asarray(arr) + if arr.ndim == 0: + return jnp.full((max_len,), arr, dtype=arr.dtype) + if arr.shape[0] <= 1: + return jnp.broadcast_to(arr, (max_len,) + arr.shape[1:]) + + start = (start_index[0],) + (0,) * (arr.ndim - 1) + slice_sizes = (bout_length - 1,) + arr.shape[1:] + sliced = dynamic_slice(arr, start, slice_sizes) + if arb_frequency != 1: + sliced = sliced[::arb_frequency] + return sliced + + +class HyperSurgeBalancerPool(BalancerPool): + """Balancer weighted pool with HyperSurge-style state-dependent swap fees.""" + + @staticmethod + def _run_fingerprint_hypersurge_defaults(run_fingerprint: Dict[str, Any]): + return run_fingerprint_hypersurge_defaults(run_fingerprint) + + def _hypersurge_params(self, params: Dict[str, Any], run_fingerprint: Dict[str, Any]): + return hypersurge_params_from_params(params, run_fingerprint) + + def _price_windows( + self, + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray], + ): + bout_length = run_fingerprint["bout_length"] + n_assets = run_fingerprint["n_assets"] + local_prices = dynamic_slice(prices, start_index, (bout_length - 1, n_assets)) + + if additional_oracle_input is 
None: + local_oracle_prices = local_prices + else: + local_oracle_prices = dynamic_slice( + additional_oracle_input, + start_index, + (bout_length - 1, n_assets), + ) + + arb_frequency = run_fingerprint["arb_frequency"] + if arb_frequency != 1: + return local_prices[::arb_frequency], local_oracle_prices[::arb_frequency] + return local_prices, local_oracle_prices + + def _noise_inputs( + self, + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + max_len: int, + ): + noise_model = run_fingerprint.get("noise_model", "ratio") + noise_base = None + noise_tvl_coeff = None + + if noise_model != "market_linear": + return noise_model, noise_base, noise_tvl_coeff + + noise_base = run_fingerprint.get("noise_base_array") + noise_tvl_coeff = run_fingerprint.get("noise_tvl_coeff_array") + if (noise_base is None or noise_tvl_coeff is None) and "noise_arrays_path" in run_fingerprint: + path = run_fingerprint["noise_arrays_path"] + if ( + not hasattr(self, "_market_linear_cache") + or self._market_linear_cache[0] != path + ): + arrays = np.load(path) + self._market_linear_cache = ( + path, + arrays["noise_base"], + arrays["noise_tvl_coeff"], + ) + noise_base = self._market_linear_cache[1] + noise_tvl_coeff = self._market_linear_cache[2] + + if noise_base is None or noise_tvl_coeff is None: + raise ValueError( + "noise_model='market_linear' requires noise_base_array and " + "noise_tvl_coeff_array, or noise_arrays_path." 
+ ) + + noise_base = _prepare_dynamic_array( + jnp.asarray(noise_base), + start_index=start_index, + bout_length=run_fingerprint["bout_length"], + arb_frequency=run_fingerprint["arb_frequency"], + max_len=max_len, + ) + noise_tvl_coeff = _prepare_dynamic_array( + jnp.asarray(noise_tvl_coeff), + start_index=start_index, + bout_length=run_fingerprint["bout_length"], + arb_frequency=run_fingerprint["arb_frequency"], + max_len=max_len, + ) + return noise_model, noise_base, noise_tvl_coeff + + def _run_hypersurge_reserves( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + fees, + gas_cost, + arb_fees, + trades, + do_trades: bool, + lp_supply_array, + additional_oracle_input: Optional[jnp.ndarray] = None, + oracle_prices_override: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + weights = self.calculate_initial_weights(params) + arb_prices, oracle_prices = self._price_windows( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + if oracle_prices_override is not None: + oracle_prices = oracle_prices_override + + initial_pool_value = run_fingerprint["initial_pool_value"] + initial_value_per_token = weights * initial_pool_value + initial_reserves = initial_value_per_token / arb_prices[0] + + noise_model, noise_base, noise_tvl_coeff = self._noise_inputs( + run_fingerprint, + prices, + start_index, + arb_prices.shape[0], + ) + + return _jax_calc_hypersurge_balancer_reserves( + initial_reserves, + weights, + arb_prices, + oracle_prices, + fees=fees, + arb_thresh=gas_cost, + arb_fees=arb_fees, + all_sig_variations=jnp.array(run_fingerprint["all_sig_variations"]), + trades=trades, + do_trades=do_trades, + do_arb=run_fingerprint["do_arb"], + lp_supply_array=lp_supply_array, + hypersurge_params=self._hypersurge_params(params, run_fingerprint), + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + 
noise_model=noise_model, + noise_base_array=noise_base, + noise_tvl_coeff_array=noise_tvl_coeff, + tvl_mean=run_fingerprint.get("noise_tvl_mean", 0.0), + tvl_std=run_fingerprint.get("noise_tvl_std", 1.0), + minutes_per_step=run_fingerprint.get( + "seconds_per_step", + 60.0 * run_fingerprint["arb_frequency"], + ) + / 60.0, + ) + + @partial(jit, static_argnums=2) + def calculate_reserves_with_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=run_fingerprint["fees"], + gas_cost=run_fingerprint["gas_cost"], + arb_fees=run_fingerprint["arb_fees"], + trades=None, + do_trades=False, + lp_supply_array=None, + additional_oracle_input=additional_oracle_input, + ) + + @partial(jit, static_argnums=2) + def calculate_reserves_zero_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=0.0, + gas_cost=run_fingerprint["gas_cost"], + arb_fees=run_fingerprint["arb_fees"], + trades=None, + do_trades=False, + lp_supply_array=None, + additional_oracle_input=additional_oracle_input, + ) + + @partial(jit, static_argnums=2) + def calculate_reserves_with_dynamic_inputs( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + dynamic_inputs, + additional_oracle_input: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + arb_prices, fallback_oracle_prices = self._price_windows( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + 
run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=arb_prices.shape[0], + do_trades=run_fingerprint["do_trades"], + dtype=arb_prices.dtype, + ) + + oracle_prices = fallback_oracle_prices + if materialized_inputs.oracle_prices.shape[-1] == run_fingerprint["n_assets"]: + oracle_prices = materialized_inputs.oracle_prices + + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=materialized_inputs.fees, + gas_cost=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, + trades=materialized_inputs.trades, + do_trades=run_fingerprint["do_trades"], + lp_supply_array=materialized_inputs.lp_supply, + additional_oracle_input=additional_oracle_input, + oracle_prices_override=oracle_prices, + ) + + def init_base_parameters( + self, + initial_values_dict: Dict[str, Any], + run_fingerprint: Dict[str, Any], + n_assets: int, + n_parameter_sets: int = 1, + noise: str = "gaussian", + ) -> Dict[str, Any]: + np.random.seed(0) + + def process_weights(key): + if key not in initial_values_dict: + raise ValueError(f"initial_values_dict must contain {key}") + initial_value = initial_values_dict[key] + if isinstance(initial_value, (np.ndarray, jnp.ndarray, list)): + initial_value = np.array(initial_value) + if initial_value.size == n_assets: + return np.array([initial_value] * n_parameter_sets) + if initial_value.size == 1: + return np.array([[initial_value] * n_assets] * n_parameter_sets) + if initial_value.shape == (n_parameter_sets, n_assets): + return initial_value + raise ValueError( + f"{key} must be a singleton or a vector of length n_assets " + "or a matrix of shape (n_parameter_sets, n_assets)" + ) + return np.array([[initial_value] * n_assets] * n_parameter_sets) + + def process_scalar(key, default): + value = initial_values_dict.get(key, default) + if value is None: + value = default + value = np.asarray(value, dtype=np.float64) + if value.size == 1: + return 
np.array([[float(value.reshape(-1)[0])]] * n_parameter_sets) + if value.shape == (n_parameter_sets,): + return value.reshape(n_parameter_sets, 1) + if value.shape == (n_parameter_sets, 1): + return value + raise ValueError( + f"{key} must be a scalar or a matrix of shape " + "(n_parameter_sets, 1)" + ) + + hypersurge_defaults = self._run_fingerprint_hypersurge_defaults( + run_fingerprint + ) + params = { + "initial_weights_logits": process_weights("initial_weights_logits"), + "subsidary_params": [], + } + for key in HYPERSURGE_PARAM_KEYS: + params[key] = process_scalar(key, hypersurge_defaults[key]) + + return self.add_noise(params, noise, n_parameter_sets) + + def get_initial_values(self, run_fingerprint): + values = { + "initial_weights_logits": run_fingerprint.get( + "initial_weights_logits", 1.0 + ), + } + defaults = self._run_fingerprint_hypersurge_defaults(run_fingerprint) + for key, value in defaults.items(): + values[key] = run_fingerprint.get(f"initial_{key}", value) + return values + + def is_trainable(self): + return True + + +tree_util.register_pytree_node( + HyperSurgeBalancerPool, + HyperSurgeBalancerPool._tree_flatten, + HyperSurgeBalancerPool._tree_unflatten, +) diff --git a/quantammsim/pools/G3M/balancer/hypersurge_balancer_reserves.py b/quantammsim/pools/G3M/balancer/hypersurge_balancer_reserves.py new file mode 100644 index 0000000..a4e326f --- /dev/null +++ b/quantammsim/pools/G3M/balancer/hypersurge_balancer_reserves.py @@ -0,0 +1,427 @@ +from functools import partial + +import jax.numpy as jnp +from jax import jit +from jax.lax import scan +from jax.tree_util import Partial + +from quantammsim.pools.G3M.G3M_trades import ( + _jax_calc_G3M_trade_from_exact_out_given_in, +) +from quantammsim.pools.G3M.optimal_n_pool_arb import ( + precalc_components_of_optimal_trade_across_signatures, + precalc_shared_values_for_all_signatures, + parallelised_optimal_trade_sifter, +) +from quantammsim.pools.hypersurge_utils import ( + _EPS, + 
broadcast_scan_vector as _broadcast_scan_vector, + fee_to_gamma as _fee_to_gamma, + max_pair_deviation as _max_pair_deviation, + oracle_pair_is_valid, + oracle_vector_is_valid, + pair_deviation as _pair_deviation, + ramp_fee as _ramp_fee, + safe_positive as _safe_positive, +) +from quantammsim.pools.noise_trades import ( + calculate_reserves_after_noise_trade, + reclamm_market_linear_noise_volume, +) + +def _hypersurge_fee_for_trade( + reserves, + candidate_trade, + weights, + oracle_prices, + token_in, + token_out, + base_fee, + hypersurge_params, +): + """Select arb/noise fee params from whether a candidate trade worsens peg deviation.""" + pair_has_oracle = oracle_pair_is_valid(oracle_prices, token_in, token_out) + candidate_reserves = _safe_positive(reserves + candidate_trade) + dev_before = _pair_deviation(reserves, weights, oracle_prices, token_in, token_out) + dev_after = _pair_deviation( + candidate_reserves, weights, oracle_prices, token_in, token_out + ) + trade_active = jnp.logical_and( + candidate_trade[token_in] > 0.0, + candidate_trade[token_out] < 0.0, + ) + + worsens = dev_after > dev_before + arb_fee = _ramp_fee( + base_fee, + hypersurge_params[0], + hypersurge_params[1], + hypersurge_params[2], + dev_before, + ) + noise_fee = _ramp_fee( + base_fee, + hypersurge_params[3], + hypersurge_params[4], + hypersurge_params[5], + dev_after, + ) + fee = jnp.where(worsens, noise_fee, arb_fee) + return jnp.where(jnp.logical_and(trade_active, pair_has_oracle), fee, base_fee) + + +def _hypersurge_noise_fee(reserves, weights, oracle_prices, base_fee, hypersurge_params): + deviation = _max_pair_deviation(reserves, weights, oracle_prices) + fee = _ramp_fee( + base_fee, + hypersurge_params[3], + hypersurge_params[4], + hypersurge_params[5], + deviation, + ) + return jnp.where(oracle_vector_is_valid(oracle_prices), fee, base_fee) + + +def _zero_fee_optimal_trade(reserves, weights, prices): + current_value = jnp.sum(reserves * prices) + quoted_prices = current_value 
* weights / jnp.maximum(reserves, _EPS) + price_change_ratio = prices / jnp.maximum(quoted_prices, _EPS) + price_product_change_ratio = jnp.prod(price_change_ratio**weights) + reserves_ratios_from_price_change = ( + price_product_change_ratio / jnp.maximum(price_change_ratio, _EPS) + ) + return reserves * reserves_ratios_from_price_change - reserves + + +def _trade_pair_from_delta(trade): + token_in = jnp.argmax(trade) + token_out = jnp.argmin(trade) + return token_in, token_out + + +def _apply_protocol_fee(reserves_after_trade, trade, fee, protocol_fee_split): + inbound = jnp.maximum(trade, 0.0) + protocol_fee = inbound * fee * protocol_fee_split + return jnp.maximum(reserves_after_trade - protocol_fee, _EPS) + + +def _optimal_arb_trade_with_gamma( + reserves, + weights, + prices, + gamma, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, +): + active_initial_weights, per_asset_ratios, all_other_assets_ratios = ( + precalc_components_of_optimal_trade_across_signatures( + weights, + prices, + gamma, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + ) + ) + return parallelised_optimal_trade_sifter( + reserves, + weights, + prices, + active_initial_weights, + active_trade_directions, + per_asset_ratios, + all_other_assets_ratios, + tokens_to_drop, + gamma, + n, + -1e-15, + ) + +def _broadcast_oracle_prices(oracle_prices, prices): + oracle_prices = jnp.asarray(oracle_prices) + if oracle_prices.ndim == 1: + oracle_prices = oracle_prices.reshape((1, oracle_prices.shape[0])) + if oracle_prices.shape[-1] != prices.shape[-1]: + oracle_prices = prices + elif oracle_prices.shape[0] == 1: + oracle_prices = jnp.broadcast_to(oracle_prices, prices.shape) + return oracle_prices + + +def _hypersurge_scan_step( + carry_list, + input_list, + weights, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + do_trades, + do_arb, + hypersurge_params, + protocol_fee_split, + noise_trader_ratio, + noise_model, + tvl_mean, + 
tvl_std, + minutes_per_step, +): + reserves = carry_list[1] + prev_lp_supply = carry_list[2] + + prices = input_list[0] + oracle_prices = input_list[1] + base_fee = input_list[2] + arb_thresh = input_list[3] + arb_fees = input_list[4] + trade = input_list[5] + lp_supply = input_list[6] + noise_base = input_list[7] + noise_tvl_coeff = input_list[8] + + reserves = jnp.where( + lp_supply != prev_lp_supply, + reserves * lp_supply / jnp.maximum(prev_lp_supply, _EPS), + reserves, + ) + + applied_arb_trade = jnp.zeros_like(reserves) + if do_arb: + preview_trade = _zero_fee_optimal_trade(reserves, weights, prices) + token_in, token_out = _trade_pair_from_delta(preview_trade) + preview_fee = _hypersurge_fee_for_trade( + reserves, + preview_trade, + weights, + oracle_prices, + token_in, + token_out, + base_fee, + hypersurge_params, + ) + preview_trade = _optimal_arb_trade_with_gamma( + reserves, + weights, + prices, + _fee_to_gamma(preview_fee), + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + ) + token_in, token_out = _trade_pair_from_delta(preview_trade) + arb_fee = _hypersurge_fee_for_trade( + reserves, + preview_trade, + weights, + oracle_prices, + token_in, + token_out, + base_fee, + hypersurge_params, + ) + optimal_arb_trade = _optimal_arb_trade_with_gamma( + reserves, + weights, + prices, + _fee_to_gamma(arb_fee), + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + ) + profit_to_arb = -(optimal_arb_trade * prices).sum() - arb_thresh + arb_external_rebalance_cost = ( + 0.5 * arb_fees * (jnp.abs(optimal_arb_trade) * prices).sum() + ) + arb_profitable = profit_to_arb >= arb_external_rebalance_cost + applied_arb_trade = jnp.where( + arb_profitable, + optimal_arb_trade, + applied_arb_trade, + ) + reserves = _apply_protocol_fee( + reserves + applied_arb_trade, + applied_arb_trade, + arb_fee, + protocol_fee_split, + ) + + if do_trades: + token_in = jnp.int32(trade[0]) + token_out = jnp.int32(trade[1]) + amount_in = trade[2] + 
preview_trade = _jax_calc_G3M_trade_from_exact_out_given_in( + reserves, + weights, + token_in, + token_out, + amount_in, + gamma=_fee_to_gamma(base_fee), + ) + trade_fee = _hypersurge_fee_for_trade( + reserves, + preview_trade, + weights, + oracle_prices, + token_in, + token_out, + base_fee, + hypersurge_params, + ) + applied_user_trade = _jax_calc_G3M_trade_from_exact_out_given_in( + reserves, + weights, + token_in, + token_out, + amount_in, + gamma=_fee_to_gamma(trade_fee), + ) + reserves = _apply_protocol_fee( + reserves + applied_user_trade, + applied_user_trade, + trade_fee, + protocol_fee_split, + ) + + if noise_model == "ratio": + noise_fee = _hypersurge_noise_fee( + reserves, weights, oracle_prices, base_fee, hypersurge_params + ) + lp_noise_gamma = _fee_to_gamma(noise_fee * (1.0 - protocol_fee_split)) + noisy_reserves = calculate_reserves_after_noise_trade( + applied_arb_trade, + reserves, + prices, + noise_trader_ratio, + lp_noise_gamma, + ) + reserves = jnp.where(noise_trader_ratio > 0.0, noisy_reserves, reserves) + elif noise_model == "market_linear": + noise_fee = _hypersurge_noise_fee( + reserves, weights, oracle_prices, base_fee, hypersurge_params + ) + pool_value = jnp.sum(reserves * prices) + noise_volume = reclamm_market_linear_noise_volume( + pool_value, + noise_base, + noise_tvl_coeff, + tvl_mean=tvl_mean, + tvl_std=tvl_std, + ) + lp_fee_income = ( + noise_fee * (1.0 - protocol_fee_split) * noise_volume * minutes_per_step + ) + reserves = reserves * (1.0 + lp_fee_income / jnp.maximum(pool_value, 1e-8)) + + return [ + prices, + reserves, + lp_supply, + ], reserves + + +@partial(jit, static_argnames=("do_trades", "do_arb", "noise_model")) +def _jax_calc_hypersurge_balancer_reserves( + initial_reserves, + weights, + prices, + oracle_prices, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + all_sig_variations=None, + trades=None, + do_trades=False, + do_arb=True, + lp_supply_array=None, + hypersurge_params=None, + noise_trader_ratio=0.0, + 
protocol_fee_split=0.0, + noise_model="ratio", + noise_base_array=None, + noise_tvl_coeff_array=None, + tvl_mean=0.0, + tvl_std=1.0, + minutes_per_step=1.0, +): + n_assets = weights.shape[0] + scan_len = prices.shape[0] + + fees = _broadcast_scan_vector(fees, scan_len) + arb_thresh = _broadcast_scan_vector(arb_thresh, scan_len) + arb_fees = _broadcast_scan_vector(arb_fees, scan_len) + oracle_prices = _broadcast_oracle_prices(oracle_prices, prices) + + if trades is None: + if do_trades: + raise ValueError("Trades must be provided when do_trades=True.") + trades = jnp.zeros((scan_len, 3), dtype=prices.dtype) + + if lp_supply_array is None: + lp_supply_array = jnp.ones((scan_len,), dtype=prices.dtype) + else: + lp_supply_array = _broadcast_scan_vector(lp_supply_array, scan_len) + + if hypersurge_params is None: + hypersurge_params = jnp.array([fees[0], 0.0, 1.0, fees[0], 0.0, 1.0]) + else: + hypersurge_params = jnp.asarray(hypersurge_params, dtype=prices.dtype) + + if noise_base_array is None: + noise_base_array = jnp.zeros((scan_len,), dtype=prices.dtype) + else: + noise_base_array = _broadcast_scan_vector(noise_base_array, scan_len) + if noise_tvl_coeff_array is None: + noise_tvl_coeff_array = jnp.zeros((scan_len,), dtype=prices.dtype) + else: + noise_tvl_coeff_array = _broadcast_scan_vector( + noise_tvl_coeff_array, scan_len + ) + + _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( + precalc_shared_values_for_all_signatures(all_sig_variations, n_assets) + ) + + scan_fn = Partial( + _hypersurge_scan_step, + weights=weights, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + leave_one_out_idxs=leave_one_out_idxs, + n=n_assets, + do_trades=do_trades, + do_arb=do_arb, + hypersurge_params=hypersurge_params, + protocol_fee_split=protocol_fee_split, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + tvl_mean=tvl_mean, + tvl_std=tvl_std, + minutes_per_step=minutes_per_step, + ) + + carry_list_init = [ + 
prices[0], + initial_reserves, + lp_supply_array[0], + ] + _, reserves = scan( + scan_fn, + carry_list_init, + [ + prices, + oracle_prices, + fees, + arb_thresh, + arb_fees, + trades, + lp_supply_array, + noise_base_array, + noise_tvl_coeff_array, + ], + ) + + return reserves diff --git a/quantammsim/pools/creator.py b/quantammsim/pools/creator.py index a68d3ea..28aab37 100644 --- a/quantammsim/pools/creator.py +++ b/quantammsim/pools/creator.py @@ -4,6 +4,7 @@ from jax import tree_util from quantammsim.pools.G3M.balancer.balancer import BalancerPool +from quantammsim.pools.G3M.balancer.hypersurge_balancer import HyperSurgeBalancerPool from quantammsim.pools.G3M.quantamm.momentum_pool import MomentumPool from quantammsim.pools.G3M.quantamm.antimomentum_pool import AntiMomentumPool from quantammsim.pools.G3M.quantamm.power_channel_pool import PowerChannelPool @@ -20,6 +21,7 @@ from quantammsim.pools.FM_AMM.cow_pool import CowPool from quantammsim.pools.ECLP.gyroscope import GyroscopePool from quantammsim.pools.reCLAMM.reclamm import ReClammPool +from quantammsim.pools.reCLAMM.reclamm_hypersurge import ReClammHyperSurgePool from quantammsim.pools.base_pool import AbstractPool from quantammsim.hooks.versus_rebalancing import ( CalculateLossVersusRebalancing, @@ -132,6 +134,7 @@ def create_pool(rule): Valid base pool types: - ``"balancer"`` : Standard Balancer constant-weight pool. + - ``"balancer_hypersurge"`` : Balancer pool with HyperSurge dynamic fees. - ``"momentum"`` : Momentum (trend-following) QuantAMM pool. - ``"anti_momentum"`` : Anti-momentum (contrarian) QuantAMM pool. - ``"power_channel"`` : Power-law channel QuantAMM pool. @@ -148,6 +151,7 @@ def create_pool(rule): - ``"hodl"`` : Pure buy-and-hold (no rebalancing) pool. - ``"cow"`` : CoW (Coincidence of Wants) AMM pool. - ``"gyroscope"`` : Gyroscope E-CLP pool. + - ``"reclamm_hypersurge"`` : reCLAMM pool with HyperSurge dynamic fees. 
Available hook prefixes (prepended with ``__`` separator): @@ -203,6 +207,8 @@ def create_pool(rule): # Create base pool instance if base_rule == "balancer": base_pool = BalancerPool() + elif base_rule in ("balancer_hypersurge", "hypersurge_balancer"): + base_pool = HyperSurgeBalancerPool() elif base_rule == "momentum": base_pool = MomentumPool() elif base_rule == "anti_momentum": @@ -231,6 +237,8 @@ def create_pool(rule): base_pool = GyroscopePool() elif base_rule == "reclamm": base_pool = ReClammPool() + elif base_rule in ("reclamm_hypersurge", "hypersurge_reclamm"): + base_pool = ReClammHyperSurgePool() else: raise NotImplementedError(f"Unknown base pool type: {base_rule}") diff --git a/quantammsim/pools/hypersurge_utils.py b/quantammsim/pools/hypersurge_utils.py new file mode 100644 index 0000000..7c1cca7 --- /dev/null +++ b/quantammsim/pools/hypersurge_utils.py @@ -0,0 +1,207 @@ +from typing import Any, Dict + +import numpy as np + +import jax.numpy as jnp + + +HYPERSURGE_PARAM_KEYS = ( + "hypersurge_arb_max_fee", + "hypersurge_arb_threshold", + "hypersurge_arb_cap_deviation", + "hypersurge_noise_max_fee", + "hypersurge_noise_threshold", + "hypersurge_noise_cap_deviation", +) + +_EPS = 1e-18 +_MAX_FEE = 0.999999 + + +def _coalesce(value, default): + return default if value is None else value + + +def _scalar_like(value, default=0.0): + if value is None: + return default + if isinstance(value, (list, tuple)): + return value[0] if value else default + if isinstance(value, (np.ndarray, jnp.ndarray)): + flat = np.asarray(value).reshape(-1) + return flat[0] if flat.size else default + return value + + +def run_fingerprint_hypersurge_defaults(run_fingerprint: Dict[str, Any]): + base_fee = _scalar_like(run_fingerprint.get("fees", 0.0), default=0.0) + + raw_params = run_fingerprint.get("hypersurge_params") + if raw_params is not None: + if isinstance(raw_params, dict): + shared_max = raw_params.get("max_surge_fee", base_fee) + shared_threshold = 
raw_params.get("threshold", 0.0) + shared_cap = raw_params.get("cap_deviation", 1.0) + return { + "hypersurge_arb_max_fee": raw_params.get("arb_max_fee", shared_max), + "hypersurge_arb_threshold": raw_params.get( + "arb_threshold", shared_threshold + ), + "hypersurge_arb_cap_deviation": raw_params.get( + "arb_cap_deviation", shared_cap + ), + "hypersurge_noise_max_fee": raw_params.get( + "noise_max_fee", shared_max + ), + "hypersurge_noise_threshold": raw_params.get( + "noise_threshold", shared_threshold + ), + "hypersurge_noise_cap_deviation": raw_params.get( + "noise_cap_deviation", shared_cap + ), + } + + raw_params = np.asarray(raw_params, dtype=np.float64).reshape(-1) + if raw_params.size != len(HYPERSURGE_PARAM_KEYS): + raise ValueError( + "hypersurge_params must contain exactly six values: " + + ", ".join(HYPERSURGE_PARAM_KEYS) + ) + return dict(zip(HYPERSURGE_PARAM_KEYS, raw_params)) + + shared_max = _coalesce( + run_fingerprint.get("hypersurge_max_surge_fee"), + _coalesce(run_fingerprint.get("hypersurge_max_fee"), base_fee), + ) + shared_threshold = _coalesce( + run_fingerprint.get("hypersurge_threshold"), + 0.0, + ) + shared_cap = _coalesce( + run_fingerprint.get("hypersurge_cap_deviation"), + 1.0, + ) + return { + "hypersurge_arb_max_fee": _coalesce( + run_fingerprint.get("hypersurge_arb_max_fee"), shared_max + ), + "hypersurge_arb_threshold": _coalesce( + run_fingerprint.get("hypersurge_arb_threshold"), shared_threshold + ), + "hypersurge_arb_cap_deviation": _coalesce( + run_fingerprint.get("hypersurge_arb_cap_deviation"), shared_cap + ), + "hypersurge_noise_max_fee": _coalesce( + run_fingerprint.get("hypersurge_noise_max_fee"), shared_max + ), + "hypersurge_noise_threshold": _coalesce( + run_fingerprint.get("hypersurge_noise_threshold"), shared_threshold + ), + "hypersurge_noise_cap_deviation": _coalesce( + run_fingerprint.get("hypersurge_noise_cap_deviation"), shared_cap + ), + } + + +def hypersurge_params_from_params(params: Dict[str, Any], 
run_fingerprint: Dict[str, Any]): + if "hypersurge_params" in params: + return jnp.ravel(params["hypersurge_params"]) + + if all(key in params for key in HYPERSURGE_PARAM_KEYS): + return jnp.asarray( + [jnp.squeeze(params[key]) for key in HYPERSURGE_PARAM_KEYS], + dtype=jnp.float64, + ) + + defaults = run_fingerprint_hypersurge_defaults(run_fingerprint) + return jnp.asarray( + [defaults[key] for key in HYPERSURGE_PARAM_KEYS], + dtype=jnp.float64, + ) + + +def safe_positive(values): + values = jnp.asarray(values) + values = jnp.where(jnp.isfinite(values), values, 1.0) + return jnp.maximum(values, _EPS) + + +def fee_to_gamma(fee): + return jnp.maximum(1.0 - jnp.clip(fee, 0.0, _MAX_FEE), _EPS) + + +def ramp_fee(base_fee, max_fee, threshold, cap, deviation): + max_fee = jnp.maximum(max_fee, base_fee) + threshold = jnp.maximum(threshold, 0.0) + cap = jnp.maximum(cap, threshold + _EPS) + span = jnp.maximum(cap - threshold, _EPS) + ramp = jnp.clip((deviation - threshold) / span, 0.0, 1.0) + fee = base_fee + (max_fee - base_fee) * ramp + fee = jnp.where(deviation <= threshold, base_fee, fee) + return jnp.clip(fee, 0.0, _MAX_FEE) + + +def oracle_pair_is_valid(oracle_prices, token_in, token_out): + oracle_prices = jnp.asarray(oracle_prices) + token_in = jnp.int32(token_in) + token_out = jnp.int32(token_out) + pair_prices = jnp.asarray([oracle_prices[token_in], oracle_prices[token_out]]) + return jnp.all(jnp.isfinite(pair_prices) & (pair_prices > 0.0)) + + +def oracle_vector_is_valid(oracle_prices): + oracle_prices = jnp.asarray(oracle_prices) + return jnp.all(jnp.isfinite(oracle_prices) & (oracle_prices > 0.0)) + + +def pair_pool_price(reserves, weights, token_in, token_out): + token_in = jnp.int32(token_in) + token_out = jnp.int32(token_out) + reserves = safe_positive(reserves) + weights = safe_positive(weights) + numerator = reserves[token_out] * weights[token_in] + denominator = reserves[token_in] * weights[token_out] + return numerator / jnp.maximum(denominator, _EPS) + 
+ +def pair_oracle_price(oracle_prices, token_in, token_out): + token_in = jnp.int32(token_in) + token_out = jnp.int32(token_out) + oracle_prices = jnp.asarray(oracle_prices) + return oracle_prices[token_in] / jnp.maximum(oracle_prices[token_out], _EPS) + + +def pair_deviation(reserves, weights, oracle_prices, token_in, token_out): + pool_price = pair_pool_price(reserves, weights, token_in, token_out) + oracle_price = pair_oracle_price(oracle_prices, token_in, token_out) + ratio = pool_price / jnp.maximum(oracle_price, _EPS) + ratio = jnp.where(jnp.isfinite(ratio), ratio, 1.0) + return jnp.abs(ratio - 1.0) + + +def max_pair_deviation(reserves, weights, oracle_prices): + reserves = safe_positive(reserves) + weights = safe_positive(weights) + oracle_prices = jnp.asarray(oracle_prices) + + pool_prices = (reserves[None, :] * weights[:, None]) / jnp.maximum( + reserves[:, None] * weights[None, :], + _EPS, + ) + oracle_pair_prices = oracle_prices[:, None] / jnp.maximum( + oracle_prices[None, :], + _EPS, + ) + ratios = pool_prices / jnp.maximum(oracle_pair_prices, _EPS) + ratios = jnp.where(jnp.isfinite(ratios), ratios, 1.0) + deviations = jnp.abs(ratios - 1.0) + off_diagonal = ~jnp.eye(reserves.shape[0], dtype=bool) + return jnp.max(jnp.where(off_diagonal, deviations, 0.0)) + + +def broadcast_scan_vector(values, scan_len): + values = jnp.asarray(values) + if values.ndim == 0: + values = values.reshape((1,)) + values = jnp.ravel(values) + return jnp.where(values.size == 1, jnp.full((scan_len,), values[0]), values) diff --git a/quantammsim/pools/noise_trades.py b/quantammsim/pools/noise_trades.py index 3f45792..9c9f00f 100644 --- a/quantammsim/pools/noise_trades.py +++ b/quantammsim/pools/noise_trades.py @@ -272,3 +272,177 @@ def reclamm_loglinear_noise_volume( return jnp.maximum(0.0, daily_vol / 1440.0 - arb_volume_this_period) +@jit +def reclamm_calibrated_noise_volume( + effective_value_usd, + gamma, + volatility, + arb_volume_this_period, + dow_sin, + dow_cos, + 
 noise_params=None, +): + """8-coefficient calibrated noise volume from cross-pool log-linear model. + + Predicts per-minute noise volume using:: + + log(V_daily) = c_0 + c_1*log(TVL) + c_2*log(sigma) + + c_3*log(TVL)*log(sigma) + c_4*log(TVL)*fee + + c_5*log(sigma)*fee + c_6*dow_sin + c_7*dow_cos + V_noise = max(0, V_daily / 1440) + + where sigma is annualised daily realised volatility, fee = 1 - gamma, + and dow_sin/dow_cos encode day-of-week seasonality. + + Parameters + ---------- + effective_value_usd : float + Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. + gamma : float + Fee parameter (1 - fee_rate). + volatility : float + Annualised daily realised volatility of the price ratio. + arb_volume_this_period : float + Arb volume for this time step (USD); unused, since the coefficients predict noise volume directly. + dow_sin : float + sin(2*pi*weekday/7) for the current day. + dow_cos : float + cos(2*pi*weekday/7) for the current day. + noise_params : dict, optional + Calibrated coefficients: c_0 .. c_7. + + Returns + ------- + float + Per-minute noise volume (USD), floored at zero. + """ + if noise_params is None: + noise_params = {} + c_0 = noise_params.get("c_0", 0.0) + c_1 = noise_params.get("c_1", 1.0) + c_2 = noise_params.get("c_2", 0.0) + c_3 = noise_params.get("c_3", 0.0) + c_4 = noise_params.get("c_4", 0.0) + c_5 = noise_params.get("c_5", 0.0) + c_6 = noise_params.get("c_6", 0.0) + c_7 = noise_params.get("c_7", 0.0) + + fee = 1.0 - gamma + log_tvl = jnp.log(jnp.maximum(effective_value_usd, 1.0)) + log_sigma = jnp.log(jnp.maximum(volatility, 1e-10)) + + log_daily_vol = ( + c_0 + + c_1 * log_tvl + + c_2 * log_sigma + + c_3 * log_tvl * log_sigma + + c_4 * log_tvl * fee + + c_5 * log_sigma * fee + + c_6 * dow_sin + + c_7 * dow_cos + ) + daily_vol = jnp.exp(log_daily_vol) + # noise_coeffs predict V_noise directly (not V_total), so no need to + # subtract arb volume; that would double-count the arb subtraction. 
+ return jnp.maximum(0.0, daily_vol / 1440.0) + + +@jit +def reclamm_market_linear_noise_volume( + effective_value_usd, + noise_base, + noise_tvl_coeff, + tvl_mean=0.0, + tvl_std=1.0, +): + """Market-feature linear noise model with precomputed daily coefficients. + + The full model is:: + + log(V_daily_noise) = base_t + tvl_coeff_t * standardized_log_tvl + + where ``base_t`` absorbs all non-TVL terms (intercept, market regime, + token volatility, pair volatility, day-of-week, cross-pool volumes) + and ``tvl_coeff_t`` is the effective TVL coefficient including + interaction terms (tvl×btc_vol, tvl×tok_a_vol, tvl×pair_vol). + + The log(TVL) is standardized using the same mean/std from training + to ensure the coefficient scale matches. + + Both base_t and tvl_coeff_t are precomputed daily from the per-pool + calibrated noise model and passed in as dynamic input arrays. + + Under a counterfactual (varying reCLAMM concentration), only + ``effective_value_usd`` changes; all market/peer features are held + at observed values via the precomputed arrays. + + Parameters + ---------- + effective_value_usd : float + Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. + noise_base : float + Precomputed non-TVL component of log(V_daily_noise) for this step. + noise_tvl_coeff : float + Precomputed effective coefficient on standardized log(TVL) for this step. + tvl_mean : float + Mean of log(TVL) from training data standardization. + tvl_std : float + Std of log(TVL) from training data standardization. + + Returns + ------- + float + Per-minute noise volume (USD), floored at zero. 
+ """ + log_tvl = jnp.log(jnp.maximum(effective_value_usd, 1.0)) + # Clamp standardized TVL to training range [-3, +3] std to prevent + # extreme concentration from wireheading the noise model + standardized_log_tvl = jnp.clip( + (log_tvl - tvl_mean) / tvl_std, -3.0, 3.0) + log_daily_noise = noise_base + noise_tvl_coeff * standardized_log_tvl + daily_noise = jnp.exp(log_daily_noise) + return jnp.maximum(0.0, daily_noise / 1440.0) + + +@jit +def reclamm_mm_observed_noise_volume( + effective_value_usd, + noise_base, + competitor_tvl, +): + """Michaelis-Menten noise model with observed competitor TVL as K. + + Derived from optimal routing (Diamandis et al. 2023):: + + V_noise = exp(base_t) * TVL / (K_t + TVL) + + where K_t is observed total competitor liquidity (direct + multi-hop + network conductance) from DeFi Llama, and base_t absorbs per-pool + intercept + market feature effects. + + The MM form guarantees: + - Elasticity ≈ 1 at low TVL (TVL << K) + - Structural saturation at high TVL (V_noise → exp(base_t)) + - No wireheading: V_noise is bounded regardless of concentration + + Parameters + ---------- + effective_value_usd : float + Effective TVL in USD: (Ra+Va)*pA + (Rb+Vb)*pB. + noise_base : float + Precomputed log(V_max_daily) = alpha_i + gamma_i @ x_market_t. + competitor_tvl : float + Observed competitor TVL (K) for this step, from DeFi Llama + network conductance model. + + Returns + ------- + float + Per-minute noise volume (USD), floored at zero. 
+ """ + tvl = jnp.maximum(effective_value_usd, 1.0) + K = jnp.maximum(competitor_tvl, 1.0) + daily_noise = jnp.exp(noise_base) * tvl / (K + tvl) + return jnp.maximum(0.0, daily_noise / 1440.0) + + diff --git a/quantammsim/pools/reCLAMM/reclamm.py b/quantammsim/pools/reCLAMM/reclamm.py index 2413777..dd30cd4 100644 --- a/quantammsim/pools/reCLAMM/reclamm.py +++ b/quantammsim/pools/reCLAMM/reclamm.py @@ -246,18 +246,114 @@ def _resolve_noise_inputs( if noise_params is not None and type(noise_params) is not dict: noise_params = dict(noise_params) - arb_vol = None - if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): - volatility_array = self.calculate_volatility_array(prices, run_fingerprint) - arb_vol = _prepare_dynamic_array( - volatility_array, - start_index=start_index, - bout_length=bout_length, - arb_frequency=arb_freq, - max_len=arb_len, - ) + noise_arrays = self._prepare_noise_arrays( + prices, + run_fingerprint, + start_index, + bout_length, + arb_freq, + arb_len, + ) + + return lp_prepared, noise_model, noise_params, noise_arrays + + def _prepare_noise_arrays(self, prices, run_fingerprint, start_index, + bout_length, arb_freq, max_len): + """Prepare dynamic input arrays for noise models. 
 + + Returns a dict that always has all six keys; entries the noise_model does not use are None. Populated entries by model: + - "ratio": none + - "tsoukalas_*"/"loglinear": "volatility" + - "calibrated": "volatility", "dow_sin", "dow_cos" + - "market_linear": "noise_base", "noise_tvl_coeff" + - "mm_observed": "noise_base", "competitor_tvl" + """ + noise_model = run_fingerprint.get("noise_model", "ratio") + result = {"volatility": None, "dow_sin": None, "dow_cos": None, + "noise_base": None, "noise_tvl_coeff": None, + "competitor_tvl": None} + + if noise_model == "mm_observed": + # MM model with observed competitor TVL as K + nb = run_fingerprint.get("noise_base_array") + ct = run_fingerprint.get("competitor_tvl_array") + if nb is None and "noise_arrays_path" in run_fingerprint: + path = run_fingerprint["noise_arrays_path"] + if not hasattr(self, "_mm_observed_cache") or self._mm_observed_cache[0] != path: + arrays = np.load(path) + self._mm_observed_cache = ( + path, arrays["noise_base"], arrays["competitor_tvl"]) + nb = self._mm_observed_cache[1] + ct = self._mm_observed_cache[2] + if nb is not None: + result["noise_base"] = _prepare_dynamic_array( + jnp.array(nb), start_index, bout_length, arb_freq, max_len) + if ct is not None: + result["competitor_tvl"] = _prepare_dynamic_array( + jnp.array(ct), start_index, bout_length, arb_freq, max_len) + return result + + if noise_model == "market_linear": + # Use arrays passed directly in run_fingerprint, else load them from path (cached on instance). 
+ nb = run_fingerprint.get("noise_base_array") + ntc = run_fingerprint.get("noise_tvl_coeff_array") + if nb is None and "noise_arrays_path" in run_fingerprint: + path = run_fingerprint["noise_arrays_path"] + if ( + not hasattr(self, "_market_linear_cache") + or self._market_linear_cache[0] != path + ): + arrays = np.load(path) + self._market_linear_cache = ( + path, + arrays["noise_base"], + arrays["noise_tvl_coeff"], + ) + nb = self._market_linear_cache[1] + ntc = self._market_linear_cache[2] + if nb is not None: + result["noise_base"] = _prepare_dynamic_array( + jnp.array(nb), start_index, bout_length, arb_freq, max_len, + ) + if ntc is not None: + result["noise_tvl_coeff"] = _prepare_dynamic_array( + jnp.array(ntc), start_index, bout_length, arb_freq, max_len, + ) + return result + + needs_vol = noise_model in ( + "tsoukalas_sqrt", "tsoukalas_log", "loglinear", "calibrated", + ) + if not needs_vol: + return result - return lp_prepared, noise_model, noise_params, arb_vol + volatility_array = self.calculate_volatility_array( + prices, run_fingerprint, + ) + result["volatility"] = _prepare_dynamic_array( + volatility_array, start_index, bout_length, arb_freq, max_len, + ) + + if noise_model != "calibrated": + return result + + # Day-of-week sin/cos arrays for the calibrated noise model. 
+ import pandas as pd + + start_dt = pd.Timestamp(run_fingerprint["startDateString"]) + n_minutes = prices.shape[0] + day_indices = np.arange(n_minutes) // 1440 + start_weekday = start_dt.weekday() + weekdays = ((start_weekday + day_indices) % 7).astype(np.float64) + dow_sin_full = jnp.array(np.sin(2.0 * np.pi * weekdays / 7.0)) + dow_cos_full = jnp.array(np.cos(2.0 * np.pi * weekdays / 7.0)) + result["dow_sin"] = _prepare_dynamic_array( + dow_sin_full, start_index, bout_length, arb_freq, max_len, + ) + result["dow_cos"] = _prepare_dynamic_array( + dow_cos_full, start_index, bout_length, arb_freq, max_len, + ) + return result @partial(jit, static_argnums=(2,)) def calculate_reserves_with_fees( @@ -271,12 +367,14 @@ def calculate_reserves_with_fees( ) -> jnp.ndarray: s = self._init_pool_state(params, run_fingerprint, prices, start_index) ste_temperature = self._resolve_ste_temperature(run_fingerprint) - lp_prepared, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( - run_fingerprint, - prices, - start_index, - s.arb_prices.shape[0], - lp_supply_array=lp_supply_array, + lp_prepared, noise_model, noise_params, noise_arrays = ( + self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=lp_supply_array, + ) ) if run_fingerprint["do_arb"]: @@ -300,7 +398,12 @@ def calculate_reserves_with_fees( lp_supply_array=lp_prepared, noise_model=noise_model, noise_params=noise_params, - volatility_array=arb_vol, + volatility_array=noise_arrays["volatility"], + dow_sin_array=noise_arrays["dow_sin"], + dow_cos_array=noise_arrays["dow_cos"], + noise_base_array=noise_arrays["noise_base"], + noise_tvl_coeff_array=noise_arrays["noise_tvl_coeff"], + competitor_tvl_array=noise_arrays["competitor_tvl"], ) return jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) @@ -324,12 +427,14 @@ def calculate_reserves_and_fee_revenue_with_fees( """ s = self._init_pool_state(params, run_fingerprint, prices, start_index) 
ste_temperature = self._resolve_ste_temperature(run_fingerprint) - lp_prepared, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( - run_fingerprint, - prices, - start_index, - s.arb_prices.shape[0], - lp_supply_array=lp_supply_array, + lp_prepared, noise_model, noise_params, noise_arrays = ( + self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=lp_supply_array, + ) ) if run_fingerprint["do_arb"]: @@ -353,7 +458,12 @@ def calculate_reserves_and_fee_revenue_with_fees( lp_supply_array=lp_prepared, noise_model=noise_model, noise_params=noise_params, - volatility_array=arb_vol, + volatility_array=noise_arrays["volatility"], + dow_sin_array=noise_arrays["dow_sin"], + dow_cos_array=noise_arrays["dow_cos"], + noise_base_array=noise_arrays["noise_base"], + noise_tvl_coeff_array=noise_arrays["noise_tvl_coeff"], + competitor_tvl_array=noise_arrays["competitor_tvl"], ) return ( jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape), @@ -389,11 +499,11 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( do_trades=False, dtype=s.arb_prices.dtype, ) - _, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + _, noise_model, noise_params, noise_arrays = self._resolve_noise_inputs( run_fingerprint, prices, start_index, - s.arb_prices.shape[0], + max_len, lp_supply_array=None, ) @@ -418,7 +528,12 @@ def calculate_reserves_and_fee_revenue_with_dynamic_inputs( lp_supply_array=materialized_inputs.lp_supply, noise_model=noise_model, noise_params=noise_params, - volatility_array=arb_vol, + volatility_array=noise_arrays["volatility"], + dow_sin_array=noise_arrays["dow_sin"], + dow_cos_array=noise_arrays["dow_cos"], + noise_base_array=noise_arrays["noise_base"], + noise_tvl_coeff_array=noise_arrays["noise_tvl_coeff"], + competitor_tvl_array=noise_arrays["competitor_tvl"], ) @partial(jit, static_argnums=(2,)) @@ -497,11 +612,11 @@ def calculate_reserves_with_dynamic_inputs( 
do_trades=False, dtype=s.arb_prices.dtype, ) - _, noise_model, noise_params, arb_vol = self._resolve_noise_inputs( + _, noise_model, noise_params, noise_arrays = self._resolve_noise_inputs( run_fingerprint, prices, start_index, - s.arb_prices.shape[0], + max_len, lp_supply_array=None, ) @@ -526,7 +641,12 @@ def calculate_reserves_with_dynamic_inputs( lp_supply_array=materialized_inputs.lp_supply, noise_model=noise_model, noise_params=noise_params, - volatility_array=arb_vol, + volatility_array=noise_arrays["volatility"], + dow_sin_array=noise_arrays["dow_sin"], + dow_cos_array=noise_arrays["dow_cos"], + noise_base_array=noise_arrays["noise_base"], + noise_tvl_coeff_array=noise_arrays["noise_tvl_coeff"], + competitor_tvl_array=noise_arrays["competitor_tvl"], ) def init_base_parameters( diff --git a/quantammsim/pools/reCLAMM/reclamm_hypersurge.py b/quantammsim/pools/reCLAMM/reclamm_hypersurge.py new file mode 100644 index 0000000..0e0a52c --- /dev/null +++ b/quantammsim/pools/reCLAMM/reclamm_hypersurge.py @@ -0,0 +1,372 @@ +from functools import partial +from typing import Any, Dict, Optional + +import numpy as np + +import jax.numpy as jnp +from jax import jit, tree_util +from jax.lax import dynamic_slice + +from quantammsim.core_simulator.dynamic_inputs import materialize_dynamic_inputs +from quantammsim.pools.hypersurge_utils import ( + HYPERSURGE_PARAM_KEYS, + hypersurge_params_from_params, + run_fingerprint_hypersurge_defaults, +) +from quantammsim.pools.reCLAMM.reclamm import ReClammPool +from quantammsim.pools.reCLAMM.reclamm_hypersurge_reserves import ( + _jax_calc_reclamm_hypersurge_reserves, + _jax_calc_reclamm_hypersurge_reserves_and_fee_revenue, +) + + +class ReClammHyperSurgePool(ReClammPool): + """reCLAMM pool with HyperSurge state-dependent swap fees.""" + + @staticmethod + def _run_fingerprint_hypersurge_defaults(run_fingerprint: Dict[str, Any]): + return run_fingerprint_hypersurge_defaults(run_fingerprint) + + def _hypersurge_params(self, params: 
Dict[str, Any], run_fingerprint: Dict[str, Any]): + return hypersurge_params_from_params(params, run_fingerprint) + + def _oracle_price_window( + self, + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray], + ): + bout_length = run_fingerprint["bout_length"] + n_assets = run_fingerprint["n_assets"] + if additional_oracle_input is None: + local_oracle_prices = dynamic_slice( + prices, start_index, (bout_length - 1, n_assets) + ) + else: + local_oracle_prices = dynamic_slice( + additional_oracle_input, + start_index, + (bout_length - 1, n_assets), + ) + + arb_frequency = run_fingerprint["arb_frequency"] + if arb_frequency != 1: + return local_oracle_prices[::arb_frequency] + return local_oracle_prices + + def _run_hypersurge_reserves( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + fees, + gas_cost, + arb_fees, + lp_supply_array, + price_ratio_updates, + oracle_prices, + lp_supply_already_prepared: bool = False, + return_fee_revenue: bool = False, + ): + s = self._init_pool_state(params, run_fingerprint, prices, start_index) + ste_temperature = self._resolve_ste_temperature(run_fingerprint) + noise_lp_supply = None if lp_supply_already_prepared else lp_supply_array + lp_prepared, noise_model, noise_params, noise_arrays = self._resolve_noise_inputs( + run_fingerprint, + prices, + start_index, + s.arb_prices.shape[0], + lp_supply_array=noise_lp_supply, + ) + if lp_supply_already_prepared: + lp_prepared = lp_supply_array + + if not run_fingerprint["do_arb"]: + reserves = jnp.broadcast_to(s.initial_reserves, s.arb_prices.shape) + if return_fee_revenue: + return reserves, jnp.zeros(s.arb_prices.shape[0], dtype=s.arb_prices.dtype) + return reserves + + kernel = ( + _jax_calc_reclamm_hypersurge_reserves_and_fee_revenue + if return_fee_revenue + else _jax_calc_reclamm_hypersurge_reserves + ) + return kernel( + 
s.initial_reserves, + s.Va, + s.Vb, + s.arb_prices, + oracle_prices, + self._hypersurge_params(params, run_fingerprint), + s.centeredness_margin, + s.daily_price_shift_base, + s.seconds_per_step, + fees=fees, + arb_thresh=gas_cost, + arb_fees=arb_fees, + price_ratio_updates=price_ratio_updates, + all_sig_variations=jnp.array(run_fingerprint["all_sig_variations"]), + arc_length_speed=s.arc_length_speed, + centeredness_scaling=s.centeredness_scaling, + protocol_fee_split=run_fingerprint.get("protocol_fee_split", 0.0), + ste_temperature=ste_temperature, + noise_trader_ratio=run_fingerprint.get("noise_trader_ratio", 0.0), + lp_supply_array=lp_prepared, + noise_model=noise_model, + noise_params=noise_params, + volatility_array=noise_arrays["volatility"], + dow_sin_array=noise_arrays["dow_sin"], + dow_cos_array=noise_arrays["dow_cos"], + noise_base_array=noise_arrays["noise_base"], + noise_tvl_coeff_array=noise_arrays["noise_tvl_coeff"], + competitor_tvl_array=noise_arrays["competitor_tvl"], + ) + + @partial(jit, static_argnums=(2,)) + def calculate_reserves_with_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + oracle_prices = self._oracle_price_window( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=self._resolve_fees(params, run_fingerprint), + gas_cost=run_fingerprint["gas_cost"], + arb_fees=run_fingerprint["arb_fees"], + lp_supply_array=lp_supply_array, + price_ratio_updates=None, + oracle_prices=oracle_prices, + ) + + @partial(jit, static_argnums=(2,)) + def calculate_reserves_and_fee_revenue_with_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + 
additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, + ): + oracle_prices = self._oracle_price_window( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=self._resolve_fees(params, run_fingerprint), + gas_cost=run_fingerprint["gas_cost"], + arb_fees=run_fingerprint["arb_fees"], + lp_supply_array=lp_supply_array, + price_ratio_updates=None, + oracle_prices=oracle_prices, + return_fee_revenue=True, + ) + + @partial(jit, static_argnums=(2,)) + def _calculate_reserves_zero_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + oracle_prices = self._oracle_price_window( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=0.0, + gas_cost=run_fingerprint["gas_cost"], + arb_fees=run_fingerprint["arb_fees"], + lp_supply_array=lp_supply_array, + price_ratio_updates=None, + oracle_prices=oracle_prices, + ) + + def calculate_reserves_zero_fees( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + additional_oracle_input: Optional[jnp.ndarray] = None, + lp_supply_array: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + return self._calculate_reserves_zero_fees( + params, + run_fingerprint, + prices, + start_index, + additional_oracle_input, + lp_supply_array, + ) + + @partial(jit, static_argnums=(2,)) + def calculate_reserves_with_dynamic_inputs( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + dynamic_inputs, + additional_oracle_input: 
Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + s = self._init_pool_state(params, run_fingerprint, prices, start_index) + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=s.arb_prices.shape[0], + do_trades=False, + dtype=s.arb_prices.dtype, + ) + + oracle_prices = self._oracle_price_window( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + if materialized_inputs.oracle_prices.shape[-1] == run_fingerprint["n_assets"]: + oracle_prices = materialized_inputs.oracle_prices + + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=materialized_inputs.fees, + gas_cost=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, + lp_supply_array=materialized_inputs.lp_supply, + price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates, + oracle_prices=oracle_prices, + lp_supply_already_prepared=True, + ) + + @partial(jit, static_argnums=(2,)) + def calculate_reserves_and_fee_revenue_with_dynamic_inputs( + self, + params: Dict[str, Any], + run_fingerprint: Dict[str, Any], + prices: jnp.ndarray, + start_index: jnp.ndarray, + dynamic_inputs, + additional_oracle_input: Optional[jnp.ndarray] = None, + ): + s = self._init_pool_state(params, run_fingerprint, prices, start_index) + materialized_inputs = materialize_dynamic_inputs( + dynamic_inputs, + run_fingerprint.get("dynamic_input_flags"), + run_fingerprint, + scan_len=s.arb_prices.shape[0], + do_trades=False, + dtype=s.arb_prices.dtype, + ) + + oracle_prices = self._oracle_price_window( + run_fingerprint, + prices, + start_index, + additional_oracle_input, + ) + if materialized_inputs.oracle_prices.shape[-1] == run_fingerprint["n_assets"]: + oracle_prices = materialized_inputs.oracle_prices + + return self._run_hypersurge_reserves( + params, + run_fingerprint, + prices, + start_index, + fees=materialized_inputs.fees, + 
 gas_cost=materialized_inputs.gas_cost, + arb_fees=materialized_inputs.arb_fees, + lp_supply_array=materialized_inputs.lp_supply, + price_ratio_updates=materialized_inputs.reclamm_price_ratio_updates, + oracle_prices=oracle_prices, + lp_supply_already_prepared=True, + return_fee_revenue=True, + ) + + def init_base_parameters( + self, + initial_values_dict: Dict[str, Any], + run_fingerprint: Dict[str, Any], + n_assets: int, + n_parameter_sets: int = 1, + noise: str = "gaussian", + ) -> Dict[str, Any]: + params = super().init_base_parameters( + initial_values_dict, + run_fingerprint, + n_assets, + n_parameter_sets=n_parameter_sets, + noise=noise, + ) + + def process_scalar(key, default): + value = initial_values_dict.get(key, default) + if value is None: + value = default + value = np.asarray(value, dtype=np.float64) + if value.size == 1: + return np.array([[float(value.reshape(-1)[0])]] * n_parameter_sets) + if value.shape == (n_parameter_sets,): + return value.reshape(n_parameter_sets, 1) + if value.shape == (n_parameter_sets, 1): + return value + raise ValueError( + f"{key} must be a scalar, a vector of shape (n_parameter_sets,), or a matrix of shape (n_parameter_sets, 1)" + ) + + defaults = self._run_fingerprint_hypersurge_defaults(run_fingerprint) + hypersurge_params = { + key: process_scalar(key, defaults[key]) for key in HYPERSURGE_PARAM_KEYS + } + hypersurge_params = self.add_noise(hypersurge_params, noise, n_parameter_sets) + params.update(hypersurge_params) + return params + + def get_initial_values(self, run_fingerprint): + values = super().get_initial_values(run_fingerprint) + defaults = self._run_fingerprint_hypersurge_defaults(run_fingerprint) + for key, value in defaults.items(): + values[key] = run_fingerprint.get(f"initial_{key}", value) + return values + + +tree_util.register_pytree_node( + ReClammHyperSurgePool, + ReClammHyperSurgePool._tree_flatten, + ReClammHyperSurgePool._tree_unflatten, +) diff --git a/quantammsim/pools/reCLAMM/reclamm_hypersurge_reserves.py 
b/quantammsim/pools/reCLAMM/reclamm_hypersurge_reserves.py new file mode 100644 index 0000000..d508f5f --- /dev/null +++ b/quantammsim/pools/reCLAMM/reclamm_hypersurge_reserves.py @@ -0,0 +1,918 @@ +from functools import partial + +import jax.numpy as jnp +from jax import jit +from jax.lax import cond, scan +from jax.tree_util import Partial + +from quantammsim.pools.G3M.G3M_trades import ( + _jax_calc_G3M_trade_from_exact_in_given_out, +) +from quantammsim.pools.G3M.optimal_n_pool_arb import ( + parallelised_optimal_trade_sifter, + precalc_components_of_optimal_trade_across_signatures, + precalc_shared_values_for_all_signatures, +) +from quantammsim.pools.hypersurge_utils import ( + _EPS, + broadcast_scan_vector, + fee_to_gamma, + max_pair_deviation, + oracle_pair_is_valid, + oracle_vector_is_valid, + pair_deviation, + ramp_fee, + safe_positive, +) +from quantammsim.pools.noise_trades import ( + calculate_reserves_after_noise_trade, + reclamm_calibrated_noise_volume, + reclamm_loglinear_noise_volume, + reclamm_market_linear_noise_volume, + reclamm_mm_observed_noise_volume, + reclamm_tsoukalas_log_noise_volume, + reclamm_tsoukalas_sqrt_noise_volume, +) +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + _DUST_USD, + _ste_greater_equal, + _ste_less_than, + _ste_select, + apply_target_price_ratio_to_virtual_balances, + compute_centeredness, + compute_invariant, + compute_price_ratio, + compute_virtual_balances_constant_arc_length, + compute_virtual_balances_updating_price_range, +) + + +_WEIGHTS = jnp.array([0.5, 0.5]) + + +def _broadcast_oracle_prices(oracle_prices, prices): + oracle_prices = jnp.asarray(oracle_prices) + if oracle_prices.ndim == 1: + oracle_prices = oracle_prices.reshape((1, oracle_prices.shape[0])) + if oracle_prices.shape[-1] != prices.shape[-1]: + oracle_prices = prices + elif oracle_prices.shape[0] == 1: + oracle_prices = jnp.broadcast_to(oracle_prices, prices.shape) + return oracle_prices + + +def 
_broadcast_schedule_array(price_ratio_updates, prices): + if price_ratio_updates is None: + updates = jnp.zeros((prices.shape[0], 4), dtype=prices.dtype) + return updates.at[:, 3].set(jnp.nan) + + updates = jnp.asarray(price_ratio_updates, dtype=prices.dtype) + if updates.ndim == 1: + updates = jnp.broadcast_to(updates, (prices.shape[0], updates.shape[0])) + elif updates.shape[0] == 1 and prices.shape[0] != 1: + updates = jnp.broadcast_to(updates, (prices.shape[0], updates.shape[1])) + return updates + + +def _zero_fee_optimal_trade(Ra, Rb, Va, Vb, prices): + market_price = prices[0] / prices[1] + L = compute_invariant(Ra, Rb, Va, Vb) + Ea_new = jnp.sqrt(L / market_price) + Eb_new = jnp.sqrt(L * market_price) + return jnp.array([Ea_new - (Ra + Va), Eb_new - (Rb + Vb)]) + + +def _effective_reserves(real_reserves, Va, Vb): + return jnp.array([real_reserves[0] + Va, real_reserves[1] + Vb]) + + +def _reclamm_hypersurge_fee_for_trade( + real_reserves, + Va, + Vb, + candidate_trade, + oracle_prices, + base_fee, + hypersurge_params, +): + token_in = jnp.argmax(candidate_trade) + token_out = jnp.argmin(candidate_trade) + trade_active = jnp.logical_and( + candidate_trade[token_in] > 0.0, + candidate_trade[token_out] < 0.0, + ) + pair_has_oracle = oracle_pair_is_valid(oracle_prices, token_in, token_out) + + effective_before = safe_positive(_effective_reserves(real_reserves, Va, Vb)) + effective_after = safe_positive(effective_before + candidate_trade) + dev_before = pair_deviation( + effective_before, + _WEIGHTS, + oracle_prices, + token_in, + token_out, + ) + dev_after = pair_deviation( + effective_after, + _WEIGHTS, + oracle_prices, + token_in, + token_out, + ) + worsens = dev_after > dev_before + + arb_fee = ramp_fee( + base_fee, + hypersurge_params[0], + hypersurge_params[1], + hypersurge_params[2], + dev_before, + ) + noise_fee = ramp_fee( + base_fee, + hypersurge_params[3], + hypersurge_params[4], + hypersurge_params[5], + dev_after, + ) + fee = jnp.where(worsens, 
noise_fee, arb_fee) + return jnp.where(jnp.logical_and(trade_active, pair_has_oracle), fee, base_fee) + + +def _reclamm_hypersurge_noise_fee( + real_reserves, + Va, + Vb, + oracle_prices, + base_fee, + hypersurge_params, +): + effective = safe_positive(_effective_reserves(real_reserves, Va, Vb)) + deviation = max_pair_deviation(effective, _WEIGHTS, oracle_prices) + fee = ramp_fee( + base_fee, + hypersurge_params[3], + hypersurge_params[4], + hypersurge_params[5], + deviation, + ) + return jnp.where(oracle_vector_is_valid(oracle_prices), fee, base_fee) + + +def _optimal_arb_trade_with_gamma( + reserves, + prices, + gamma, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, +): + active_initial_weights, per_asset_ratios, all_other_assets_ratios = ( + precalc_components_of_optimal_trade_across_signatures( + _WEIGHTS, + prices, + gamma, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + ) + ) + return parallelised_optimal_trade_sifter( + reserves, + _WEIGHTS, + prices, + active_initial_weights, + active_trade_directions, + per_asset_ratios, + all_other_assets_ratios, + tokens_to_drop, + gamma, + n, + 0, + ) + + +def _apply_protocol_fee(reserves_after_trade, trade, fee, protocol_fee_split): + inbound = jnp.maximum(trade, 0.0) + protocol_fee = inbound * fee * protocol_fee_split + return jnp.maximum(reserves_after_trade - protocol_fee, _EPS) + + +def _reclamm_hypersurge_scan_step_with_fee_revenue( + carry_list, + input_list, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + hypersurge_params, + centeredness_margin, + daily_price_shift_base, + seconds_per_step, + arc_length_speed=0.0, + centeredness_scaling=False, + protocol_fee_split=0.0, + ste_temperature=10.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, +): + prev_reserves = carry_list[0] + Va = carry_list[1] + Vb = carry_list[2] + step_idx = carry_list[3] + active_start_ratio = carry_list[4] + active_target_ratio = carry_list[5] + 
active_start_step = carry_list[6] + active_end_step = carry_list[7] + active_enabled = carry_list[8] + prev_lp_supply = carry_list[9] + + prices = input_list[0] + oracle_prices = input_list[1] + base_fee = input_list[2] + arb_thresh = input_list[3] + arb_fees = input_list[4] + price_ratio_update = input_list[5] + lp_supply = input_list[6] + volatility = input_list[7] + dow_sin = input_list[8] + dow_cos = input_list[9] + noise_base = input_list[10] + noise_tvl_coeff = input_list[11] + competitor_tvl = input_list[12] + + scale = lp_supply / prev_lp_supply + lp_supply_change = lp_supply != prev_lp_supply + prev_reserves = jnp.where(lp_supply_change, prev_reserves * scale, prev_reserves) + Va = jnp.where(lp_supply_change, Va * scale, Va) + Vb = jnp.where(lp_supply_change, Vb * scale, Vb) + + Ra = prev_reserves[0] + Rb = prev_reserves[1] + + event_has = price_ratio_update[0] > 0.5 + event_target_ratio = jnp.maximum( + jnp.where(jnp.isfinite(price_ratio_update[1]), price_ratio_update[1], 1.0), + 1.0 + 1e-12, + ) + event_end_step = jnp.where( + jnp.isfinite(price_ratio_update[2]), price_ratio_update[2], step_idx + ) + event_start_override = price_ratio_update[3] + + def _apply_schedule_state(_): + current_price_ratio = compute_price_ratio(Ra, Rb, Va, Vb) + start_ratio_from_event = jnp.where( + jnp.isfinite(event_start_override), + event_start_override, + current_price_ratio, + ) + next_active_start_ratio = jnp.where( + event_has, start_ratio_from_event, active_start_ratio + ) + next_active_target_ratio = jnp.where( + event_has, event_target_ratio, active_target_ratio + ) + next_active_start_step = jnp.where(event_has, step_idx, active_start_step) + next_active_end_step = jnp.where( + event_has, jnp.maximum(event_end_step, step_idx), active_end_step + ) + next_active_enabled = jnp.where(event_has, True, active_enabled) + next_active_enabled = jnp.logical_and( + next_active_enabled, step_idx <= next_active_end_step + ) + + schedule_duration = next_active_end_step - 
next_active_start_step + schedule_progress = jnp.where( + schedule_duration <= 0.0, + 1.0, + jnp.clip((step_idx - next_active_start_step) / schedule_duration, 0.0, 1.0), + ) + safe_start_ratio = jnp.maximum(next_active_start_ratio, 1.0 + 1e-12) + safe_target_ratio = jnp.maximum(next_active_target_ratio, 1.0 + 1e-12) + scheduled_price_ratio = safe_start_ratio * ( + safe_target_ratio / safe_start_ratio + ) ** schedule_progress + scheduled_price_ratio = jnp.where( + next_active_enabled, scheduled_price_ratio, current_price_ratio + ) + Va_scheduled, Vb_scheduled = apply_target_price_ratio_to_virtual_balances( + Ra, Rb, Va, Vb, scheduled_price_ratio + ) + next_Va = jnp.where(next_active_enabled, Va_scheduled, Va) + next_Vb = jnp.where(next_active_enabled, Vb_scheduled, Vb) + return ( + next_Va, + next_Vb, + next_active_start_ratio, + next_active_target_ratio, + next_active_start_step, + next_active_end_step, + next_active_enabled, + ) + + def _skip_schedule_state(_): + retained_active_enabled = jnp.logical_and( + active_enabled, step_idx <= active_end_step + ) + return ( + Va, + Vb, + active_start_ratio, + active_target_ratio, + active_start_step, + active_end_step, + retained_active_enabled, + ) + + active_not_expired = jnp.logical_and(active_enabled, step_idx <= active_end_step) + schedule_active = jnp.logical_or(event_has, active_not_expired) + ( + Va, + Vb, + active_start_ratio, + active_target_ratio, + active_start_step, + active_end_step, + active_enabled, + ) = cond( + schedule_active, + _apply_schedule_state, + _skip_schedule_state, + operand=None, + ) + + centeredness, is_above = compute_centeredness(Ra, Rb, Va, Vb) + sqrt_Q = jnp.sqrt(compute_price_ratio(Ra, Rb, Va, Vb)) + market_price = prices[0] / prices[1] + + speed_multiplier = jnp.where( + centeredness_scaling, + centeredness_margin / jnp.maximum(centeredness, 1e-10), + 1.0, + ) + + Va_geo, Vb_geo = compute_virtual_balances_updating_price_range( + Ra, + Rb, + Va, + Vb, + is_pool_above_center=is_above, + 
daily_price_shift_base=daily_price_shift_base, + seconds_elapsed=seconds_per_step * speed_multiplier, + sqrt_price_ratio=sqrt_Q, + ) + Va_cal, Vb_cal = compute_virtual_balances_constant_arc_length( + Ra, + Rb, + Va, + Vb, + is_pool_above_center=is_above, + arc_length_speed=arc_length_speed * speed_multiplier, + seconds_elapsed=seconds_per_step, + sqrt_price_ratio=sqrt_Q, + market_price=market_price, + ) + use_cal = arc_length_speed > 0.0 + Va_updated = jnp.where(use_cal, Va_cal, Va_geo) + Vb_updated = jnp.where(use_cal, Vb_cal, Vb_geo) + + out_of_range_gate = _ste_less_than( + centeredness, centeredness_margin, ste_temperature + ) + Va = _ste_select(out_of_range_gate, Va_updated, Va) + Vb = _ste_select(out_of_range_gate, Vb_updated, Vb) + + effective_reserves = _effective_reserves(prev_reserves, Va, Vb) + zero_fee_trade = _zero_fee_optimal_trade(Ra, Rb, Va, Vb, prices) + + preview_fee = _reclamm_hypersurge_fee_for_trade( + prev_reserves, + Va, + Vb, + zero_fee_trade, + oracle_prices, + base_fee, + hypersurge_params, + ) + preview_trade = _optimal_arb_trade_with_gamma( + effective_reserves, + prices, + fee_to_gamma(preview_fee), + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + ) + arb_fee = _reclamm_hypersurge_fee_for_trade( + prev_reserves, + Va, + Vb, + preview_trade, + oracle_prices, + base_fee, + hypersurge_params, + ) + arb_gamma = fee_to_gamma(arb_fee) + optimal_arb_trade = _optimal_arb_trade_with_gamma( + effective_reserves, + prices, + arb_gamma, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + ) + + profit_to_arb = -(optimal_arb_trade * prices).sum() - arb_thresh + arb_external_cost = 0.5 * arb_fees * (jnp.abs(optimal_arb_trade) * prices).sum() + trade_gate = _ste_greater_equal( + profit_to_arb, arb_external_cost, ste_temperature + ) + applied_trade = _ste_select( + trade_gate, optimal_arb_trade, jnp.zeros_like(optimal_arb_trade) + ) + + Ra_trade = Ra + applied_trade[0] + Rb_trade = Rb + applied_trade[1] + + 
dust_a = _DUST_USD / prices[0] + dust_b = _DUST_USD / prices[1] + drain_a = jnp.maximum(Ra - dust_a, 0.0) + drain_b = jnp.maximum(Rb - dust_b, 0.0) + edge_a = _jax_calc_G3M_trade_from_exact_in_given_out( + effective_reserves, + _WEIGHTS, + token_in=1, + token_out=0, + amount_out=drain_a, + gamma=arb_gamma, + ) + edge_b = _jax_calc_G3M_trade_from_exact_in_given_out( + effective_reserves, + _WEIGHTS, + token_in=0, + token_out=1, + amount_out=drain_b, + gamma=arb_gamma, + ) + + clamp_a = Ra_trade < 0 + clamp_b = Rb_trade < 0 + final_arb_trade = jnp.where( + clamp_a, + edge_a, + jnp.where(clamp_b, edge_b, applied_trade), + ) + + reserves_after_arb = prev_reserves + final_arb_trade + reserves_after_arb = _apply_protocol_fee( + reserves_after_arb, + final_arb_trade, + arb_fee, + protocol_fee_split, + ) + arb_lp_fee_income = ( + jnp.maximum(final_arb_trade, 0.0) * arb_fee * (1.0 - protocol_fee_split) + ) + lp_fee_revenue_usd = (arb_lp_fee_income * prices).sum() + + noise_fee = _reclamm_hypersurge_noise_fee( + reserves_after_arb, + Va, + Vb, + oracle_prices, + base_fee, + hypersurge_params, + ) + + Ra_new = reserves_after_arb[0] + Rb_new = reserves_after_arb[1] + + if noise_model == "ratio": + lp_noise_gamma = fee_to_gamma(noise_fee * (1.0 - protocol_fee_split)) + noisy_reserves = calculate_reserves_after_noise_trade( + final_arb_trade, + reserves_after_arb, + prices, + noise_trader_ratio, + lp_noise_gamma, + ) + noise_lp_fee_income_usd = ( + noise_trader_ratio + * noise_fee + * (1.0 - protocol_fee_split) + * jnp.sum(jnp.maximum(final_arb_trade, 0.0) * prices) + ) + Ra_new = jnp.where(noise_trader_ratio > 0.0, noisy_reserves[0], Ra_new) + Rb_new = jnp.where(noise_trader_ratio > 0.0, noisy_reserves[1], Rb_new) + lp_fee_revenue_usd = jnp.where( + noise_trader_ratio > 0.0, + lp_fee_revenue_usd + noise_lp_fee_income_usd, + lp_fee_revenue_usd, + ) + elif noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): + arb_volume = 0.5 * jnp.sum(jnp.abs(final_arb_trade) * 
prices) + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + noise_cfg = noise_params if noise_params is not None else {} + if noise_model == "tsoukalas_sqrt": + noise_vol = reclamm_tsoukalas_sqrt_noise_volume( + effective_value, arb_gamma, volatility, arb_volume, noise_cfg + ) + elif noise_model == "tsoukalas_log": + noise_vol = reclamm_tsoukalas_log_noise_volume( + effective_value, arb_gamma, volatility, arb_volume, noise_cfg + ) + else: + noise_vol = reclamm_loglinear_noise_volume( + effective_value, arb_gamma, volatility, arb_volume, noise_cfg + ) + minutes_per_step = seconds_per_step / 60.0 + noise_lp_fee_income_usd = ( + noise_fee * (1.0 - protocol_fee_split) * noise_vol * minutes_per_step + ) + scale = 1.0 + noise_lp_fee_income_usd / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + lp_fee_revenue_usd = lp_fee_revenue_usd + noise_lp_fee_income_usd + elif noise_model == "calibrated": + arb_volume = 0.5 * jnp.sum(jnp.abs(final_arb_trade) * prices) + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + noise_cfg = noise_params if noise_params is not None else {} + noise_vol = reclamm_calibrated_noise_volume( + effective_value, + arb_gamma, + volatility, + arb_volume, + dow_sin, + dow_cos, + noise_cfg, + ) + minutes_per_step = seconds_per_step / 60.0 + noise_lp_fee_income_usd = ( + noise_fee * (1.0 - protocol_fee_split) * noise_vol * minutes_per_step + ) + scale = 1.0 + noise_lp_fee_income_usd / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + lp_fee_revenue_usd = lp_fee_revenue_usd + noise_lp_fee_income_usd + elif noise_model == "market_linear": + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + noise_cfg = noise_params if noise_params is not None else {} + noise_vol = reclamm_market_linear_noise_volume( + effective_value, + noise_base, + noise_tvl_coeff, + 
tvl_mean=noise_cfg.get("tvl_mean", 0.0), + tvl_std=noise_cfg.get("tvl_std", 1.0), + ) + minutes_per_step = seconds_per_step / 60.0 + noise_lp_fee_income_usd = ( + noise_fee * (1.0 - protocol_fee_split) * noise_vol * minutes_per_step + ) + scale = 1.0 + noise_lp_fee_income_usd / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + lp_fee_revenue_usd = lp_fee_revenue_usd + noise_lp_fee_income_usd + elif noise_model == "mm_observed": + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + noise_vol = reclamm_mm_observed_noise_volume( + effective_value, noise_base, competitor_tvl + ) + minutes_per_step = seconds_per_step / 60.0 + noise_lp_fee_income_usd = ( + noise_fee * (1.0 - protocol_fee_split) * noise_vol * minutes_per_step + ) + scale = 1.0 + noise_lp_fee_income_usd / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + lp_fee_revenue_usd = lp_fee_revenue_usd + noise_lp_fee_income_usd + + new_reserves = jnp.array([Ra_new, Rb_new]) + return [ + new_reserves, + Va, + Vb, + step_idx + 1.0, + active_start_ratio, + active_target_ratio, + active_start_step, + active_end_step, + active_enabled, + lp_supply, + ], (new_reserves, lp_fee_revenue_usd) + + +def _reclamm_hypersurge_scan_step( + carry_list, + input_list, + tokens_to_drop, + active_trade_directions, + leave_one_out_idxs, + n, + hypersurge_params, + centeredness_margin, + daily_price_shift_base, + seconds_per_step, + arc_length_speed=0.0, + centeredness_scaling=False, + protocol_fee_split=0.0, + ste_temperature=10.0, + noise_trader_ratio=0.0, + noise_model="ratio", + noise_params=None, +): + new_carry, (new_reserves, _fee_revenue) = ( + _reclamm_hypersurge_scan_step_with_fee_revenue( + carry_list, + input_list, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + leave_one_out_idxs=leave_one_out_idxs, + n=n, + 
hypersurge_params=hypersurge_params, + centeredness_margin=centeredness_margin, + daily_price_shift_base=daily_price_shift_base, + seconds_per_step=seconds_per_step, + arc_length_speed=arc_length_speed, + centeredness_scaling=centeredness_scaling, + protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params, + ) + ) + return new_carry, new_reserves + + +@partial(jit, static_argnames=("noise_model",)) +def _jax_calc_reclamm_hypersurge_reserves( + initial_reserves, + initial_Va, + initial_Vb, + prices, + oracle_prices, + hypersurge_params, + centeredness_margin, + daily_price_shift_base, + seconds_per_step, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + price_ratio_updates=None, + all_sig_variations=None, + arc_length_speed=0.0, + centeredness_scaling=False, + protocol_fee_split=0.0, + ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, + dow_sin_array=None, + dow_cos_array=None, + noise_base_array=None, + noise_tvl_coeff_array=None, + competitor_tvl_array=None, +): + if lp_supply_array is None: + lp_supply_array = jnp.ones((prices.shape[0],), dtype=prices.dtype) + else: + lp_supply_array = broadcast_scan_vector(lp_supply_array, prices.shape[0]) + + fees = broadcast_scan_vector(fees, prices.shape[0]) + arb_thresh = broadcast_scan_vector(arb_thresh, prices.shape[0]) + arb_fees = broadcast_scan_vector(arb_fees, prices.shape[0]) + oracle_prices = _broadcast_oracle_prices(oracle_prices, prices) + price_ratio_updates = _broadcast_schedule_array(price_ratio_updates, prices) + volatility_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if volatility_array is None + else volatility_array, + prices.shape[0], + ) + dow_sin_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if dow_sin_array is None + else dow_sin_array, + 
prices.shape[0], + ) + dow_cos_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if dow_cos_array is None + else dow_cos_array, + prices.shape[0], + ) + noise_base_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if noise_base_array is None + else noise_base_array, + prices.shape[0], + ) + noise_tvl_coeff_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if noise_tvl_coeff_array is None + else noise_tvl_coeff_array, + prices.shape[0], + ) + competitor_tvl_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if competitor_tvl_array is None + else competitor_tvl_array, + prices.shape[0], + ) + + _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( + precalc_shared_values_for_all_signatures(all_sig_variations, 2) + ) + + scan_fn = Partial( + _reclamm_hypersurge_scan_step, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + leave_one_out_idxs=leave_one_out_idxs, + n=2, + hypersurge_params=jnp.asarray(hypersurge_params, dtype=prices.dtype), + centeredness_margin=centeredness_margin, + daily_price_shift_base=daily_price_shift_base, + seconds_per_step=seconds_per_step, + arc_length_speed=arc_length_speed, + centeredness_scaling=centeredness_scaling, + protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, + ) + + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), + jnp.float64(0.0), + jnp.float64(0.0), + jnp.float64(0.0), + jnp.float64(0.0), + jnp.array(False), + lp_supply_array[0], + ] + _, reserves = scan( + scan_fn, + carry_init, + [ + prices, + oracle_prices, + fees, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + volatility_array, + dow_sin_array, + dow_cos_array, + noise_base_array, + noise_tvl_coeff_array, + competitor_tvl_array, + ], 
+ ) + return reserves + + +@partial(jit, static_argnames=("noise_model",)) +def _jax_calc_reclamm_hypersurge_reserves_and_fee_revenue( + initial_reserves, + initial_Va, + initial_Vb, + prices, + oracle_prices, + hypersurge_params, + centeredness_margin, + daily_price_shift_base, + seconds_per_step, + fees=0.003, + arb_thresh=0.0, + arb_fees=0.0, + price_ratio_updates=None, + all_sig_variations=None, + arc_length_speed=0.0, + centeredness_scaling=False, + protocol_fee_split=0.0, + ste_temperature=10.0, + noise_trader_ratio=0.0, + lp_supply_array=None, + noise_model="ratio", + noise_params=None, + volatility_array=None, + dow_sin_array=None, + dow_cos_array=None, + noise_base_array=None, + noise_tvl_coeff_array=None, + competitor_tvl_array=None, +): + if lp_supply_array is None: + lp_supply_array = jnp.ones((prices.shape[0],), dtype=prices.dtype) + else: + lp_supply_array = broadcast_scan_vector(lp_supply_array, prices.shape[0]) + + fees = broadcast_scan_vector(fees, prices.shape[0]) + arb_thresh = broadcast_scan_vector(arb_thresh, prices.shape[0]) + arb_fees = broadcast_scan_vector(arb_fees, prices.shape[0]) + oracle_prices = _broadcast_oracle_prices(oracle_prices, prices) + price_ratio_updates = _broadcast_schedule_array(price_ratio_updates, prices) + volatility_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if volatility_array is None + else volatility_array, + prices.shape[0], + ) + dow_sin_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if dow_sin_array is None + else dow_sin_array, + prices.shape[0], + ) + dow_cos_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if dow_cos_array is None + else dow_cos_array, + prices.shape[0], + ) + noise_base_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if noise_base_array is None + else noise_base_array, + prices.shape[0], + ) + noise_tvl_coeff_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if 
noise_tvl_coeff_array is None + else noise_tvl_coeff_array, + prices.shape[0], + ) + competitor_tvl_array = broadcast_scan_vector( + jnp.zeros((1,), dtype=prices.dtype) + if competitor_tvl_array is None + else competitor_tvl_array, + prices.shape[0], + ) + + _, active_trade_directions, tokens_to_drop, leave_one_out_idxs = ( + precalc_shared_values_for_all_signatures(all_sig_variations, 2) + ) + + scan_fn = Partial( + _reclamm_hypersurge_scan_step_with_fee_revenue, + tokens_to_drop=tokens_to_drop, + active_trade_directions=active_trade_directions, + leave_one_out_idxs=leave_one_out_idxs, + n=2, + hypersurge_params=jnp.asarray(hypersurge_params, dtype=prices.dtype), + centeredness_margin=centeredness_margin, + daily_price_shift_base=daily_price_shift_base, + seconds_per_step=seconds_per_step, + arc_length_speed=arc_length_speed, + centeredness_scaling=centeredness_scaling, + protocol_fee_split=protocol_fee_split, + ste_temperature=ste_temperature, + noise_trader_ratio=noise_trader_ratio, + noise_model=noise_model, + noise_params=noise_params if noise_params is not None else {}, + ) + + carry_init = [ + initial_reserves, + initial_Va, + initial_Vb, + jnp.float64(0.0), + jnp.float64(0.0), + jnp.float64(0.0), + jnp.float64(0.0), + jnp.float64(0.0), + jnp.array(False), + lp_supply_array[0], + ] + _, (reserves, fee_revenue) = scan( + scan_fn, + carry_init, + [ + prices, + oracle_prices, + fees, + arb_thresh, + arb_fees, + price_ratio_updates, + lp_supply_array, + volatility_array, + dow_sin_array, + dow_cos_array, + noise_base_array, + noise_tvl_coeff_array, + competitor_tvl_array, + ], + ) + return reserves, fee_revenue diff --git a/quantammsim/pools/reCLAMM/reclamm_reserves.py b/quantammsim/pools/reCLAMM/reclamm_reserves.py index 37082ff..2e491ea 100644 --- a/quantammsim/pools/reCLAMM/reclamm_reserves.py +++ b/quantammsim/pools/reCLAMM/reclamm_reserves.py @@ -36,6 +36,9 @@ reclamm_tsoukalas_sqrt_noise_volume, reclamm_tsoukalas_log_noise_volume, 
reclamm_loglinear_noise_volume, + reclamm_calibrated_noise_volume, + reclamm_market_linear_noise_volume, + reclamm_mm_observed_noise_volume, ) # Reference balance for initialisation (matches Solidity _INITIALIZATION_MAX_BALANCE_A) @@ -1038,10 +1041,65 @@ def _skip_schedule_state(_): effective_value, gamma, volatility, arb_volume, _np ) - noise_fee_income = (1.0 - gamma) * noise_vol - noise_scale = 1.0 + noise_fee_income / jnp.maximum(real_value, 1e-8) - Ra_new = Ra_new * noise_scale - Rb_new = Rb_new * noise_scale + # Scale effective reserves uniformly to preserve quoted price. + # For a 2-CLP: price ∝ (Ra+Va)/(Rb+Vb), so we must scale + # effective reserves (Ra+Va, Rb+Vb) by the same factor, then + # subtract back the fixed virtual reserves. + minutes_per_step = seconds_per_step / 60.0 + noise_fee_income = (1.0 - gamma) * noise_vol * minutes_per_step + scale = 1.0 + noise_fee_income / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + elif noise_model == "calibrated": + volatility = input_list[9] + dow_sin = input_list[10] + dow_cos = input_list[11] + arb_volume = 0.5 * jnp.sum(jnp.abs(applied_trade) * prices) + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + + _np = noise_params if noise_params is not None else {} + noise_vol = reclamm_calibrated_noise_volume( + effective_value, gamma, volatility, + arb_volume, dow_sin, dow_cos, _np, + ) + + minutes_per_step = seconds_per_step / 60.0 + noise_fee_income = (1.0 - gamma) * noise_vol * minutes_per_step + scale = 1.0 + noise_fee_income / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + elif noise_model == "market_linear": + noise_base = input_list[9] + noise_tvl_coeff = input_list[10] + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + + _np = noise_params if noise_params is not None else {} + noise_vol = reclamm_market_linear_noise_volume( + 
effective_value, noise_base, noise_tvl_coeff, + tvl_mean=_np.get("tvl_mean", 0.0), + tvl_std=_np.get("tvl_std", 1.0), + ) + + minutes_per_step = seconds_per_step / 60.0 + noise_fee_income = (1.0 - gamma) * noise_vol * minutes_per_step + scale = 1.0 + noise_fee_income / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + elif noise_model == "mm_observed": + noise_base = input_list[9] + competitor_tvl = input_list[10] + effective_value = (Ra_new + Va) * prices[0] + (Rb_new + Vb) * prices[1] + + noise_vol = reclamm_mm_observed_noise_volume( + effective_value, noise_base, competitor_tvl, + ) + + minutes_per_step = seconds_per_step / 60.0 + noise_fee_income = (1.0 - gamma) * noise_vol * minutes_per_step + scale = 1.0 + noise_fee_income / jnp.maximum(effective_value, 1e-8) + Ra_new = (Ra_new + Va) * scale - Va + Rb_new = (Rb_new + Vb) * scale - Vb + # else: "arb_only" — no noise trades # Clamp-to-edge: if a real reserve would go negative, apply an # exact-in-given-out edge trade that drains that token to _DUST_USD @@ -1311,6 +1369,11 @@ def _jax_calc_reclamm_reserves_with_fees( noise_model="ratio", noise_params=None, volatility_array=None, + dow_sin_array=None, + dow_cos_array=None, + noise_base_array=None, + noise_tvl_coeff_array=None, + competitor_tvl_array=None, ): """Calculate reClAMM reserves over time with fees. 
@@ -1390,6 +1453,16 @@ def _jax_calc_reclamm_reserves_with_fees( ] if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): scan_inputs.append(volatility_array) + elif noise_model == "calibrated": + scan_inputs.append(volatility_array) + scan_inputs.append(dow_sin_array) + scan_inputs.append(dow_cos_array) + elif noise_model == "market_linear": + scan_inputs.append(noise_base_array) + scan_inputs.append(noise_tvl_coeff_array) + elif noise_model == "mm_observed": + scan_inputs.append(noise_base_array) + scan_inputs.append(competitor_tvl_array) _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves @@ -1420,6 +1493,11 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( noise_model="ratio", noise_params=None, volatility_array=None, + dow_sin_array=None, + dow_cos_array=None, + noise_base_array=None, + noise_tvl_coeff_array=None, + competitor_tvl_array=None, ): """Calculate reClAMM reserves with time-varying fees/arb arrays.""" if lp_supply_array is None: @@ -1508,6 +1586,16 @@ def _jax_calc_reclamm_reserves_with_dynamic_inputs( ] if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): scan_inputs.append(volatility_array) + elif noise_model == "calibrated": + scan_inputs.append(volatility_array) + scan_inputs.append(dow_sin_array) + scan_inputs.append(dow_cos_array) + elif noise_model == "market_linear": + scan_inputs.append(noise_base_array) + scan_inputs.append(noise_tvl_coeff_array) + elif noise_model == "mm_observed": + scan_inputs.append(noise_base_array) + scan_inputs.append(competitor_tvl_array) _, reserves = scan(scan_fn, carry_init, scan_inputs) return reserves @@ -1652,6 +1740,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( noise_model="ratio", noise_params=None, volatility_array=None, + dow_sin_array=None, + dow_cos_array=None, + noise_base_array=None, + noise_tvl_coeff_array=None, + competitor_tvl_array=None, ): """Calculate reClAMM reserves and LP fee revenue over time with fees. 
@@ -1733,6 +1826,16 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_fees( ] if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): scan_inputs.append(volatility_array) + elif noise_model == "calibrated": + scan_inputs.append(volatility_array) + scan_inputs.append(dow_sin_array) + scan_inputs.append(dow_cos_array) + elif noise_model == "market_linear": + scan_inputs.append(noise_base_array) + scan_inputs.append(noise_tvl_coeff_array) + elif noise_model == "mm_observed": + scan_inputs.append(noise_base_array) + scan_inputs.append(competitor_tvl_array) _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue @@ -1763,6 +1866,11 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( noise_model="ratio", noise_params=None, volatility_array=None, + dow_sin_array=None, + dow_cos_array=None, + noise_base_array=None, + noise_tvl_coeff_array=None, + competitor_tvl_array=None, ): """Calculate reClAMM reserves and LP fee revenue with time-varying fees/arb arrays. 
@@ -1857,6 +1965,16 @@ def _jax_calc_reclamm_reserves_and_fee_revenue_with_dynamic_inputs( ] if noise_model in ("tsoukalas_sqrt", "tsoukalas_log", "loglinear"): scan_inputs.append(volatility_array) + elif noise_model == "calibrated": + scan_inputs.append(volatility_array) + scan_inputs.append(dow_sin_array) + scan_inputs.append(dow_cos_array) + elif noise_model == "market_linear": + scan_inputs.append(noise_base_array) + scan_inputs.append(noise_tvl_coeff_array) + elif noise_model == "mm_observed": + scan_inputs.append(noise_base_array) + scan_inputs.append(competitor_tvl_array) _, (reserves, fee_revenue) = scan(scan_fn, carry_init, scan_inputs) return reserves, fee_revenue diff --git a/quantammsim/runners/default_run_fingerprint.py b/quantammsim/runners/default_run_fingerprint.py index 9b2a068..4334b4a 100644 --- a/quantammsim/runners/default_run_fingerprint.py +++ b/quantammsim/runners/default_run_fingerprint.py @@ -84,7 +84,7 @@ "return_val": "daily_log_sharpe", "initial_pool_value": 1000000.0, "fees": 0.0, - "protocol_fee_split": 0.0, # fraction of swap fees diverted from LP reserves to protocol treasury + "protocol_fee_split": 0.25, # fraction of swap fees diverted from LP reserves to protocol treasury "arb_fees": 0.0, "gas_cost": 0.0, "use_alt_lamb": False, @@ -223,6 +223,42 @@ "log_scale": False, "scalar": False, }, + "hypersurge_arb_max_fee": { + "low": 0.0, + "high": 0.20, + "log_scale": False, + "scalar": True, + }, + "hypersurge_arb_threshold": { + "low": 0.0, + "high": 1.0, + "log_scale": False, + "scalar": True, + }, + "hypersurge_arb_cap_deviation": { + "low": 0.0, + "high": 2.0, + "log_scale": False, + "scalar": True, + }, + "hypersurge_noise_max_fee": { + "low": 0.0, + "high": 0.50, + "log_scale": False, + "scalar": True, + }, + "hypersurge_noise_threshold": { + "low": 0.0, + "high": 1.0, + "log_scale": False, + "scalar": True, + }, + "hypersurge_noise_cap_deviation": { + "low": 0.0, + "high": 2.0, + "log_scale": False, + "scalar": True, + }, }, } 
diff --git a/quantammsim/runners/jax_runner_utils.py b/quantammsim/runners/jax_runner_utils.py index aa5b79d..4ebdcce 100644 --- a/quantammsim/runners/jax_runner_utils.py +++ b/quantammsim/runners/jax_runner_utils.py @@ -658,7 +658,7 @@ def __eq__(self, other): # These are excluded when creating static_dict from run_fingerprint _TRAINING_ONLY_FIELDS = frozenset({ "optimisation_settings", # Contains lr, optimizer, etc. - "startDateString", # Data loading dates + # startDateString kept in static dict — needed by calibrated noise model "endDateString", "endTestDateString", "subsidary_pools", # Handled separately @@ -675,6 +675,10 @@ def __eq__(self, other): "initial_raw_width", "initial_raw_exponents", "initial_pre_exp_scaling", + # Noise model arrays — loaded from path at runtime, not hashable + "noise_base_array", + "noise_tvl_coeff_array", + "competitor_tvl_array", }) @@ -1228,6 +1232,7 @@ def _to_dynamic_input_arrays( arb_fees_array, lp_supply_array, reclamm_price_ratio_updates_array, + oracle_prices_array, ) -> DynamicInputArrays: """Normalize optional numpy arrays into the hot-path container.""" empty = empty_dynamic_input_arrays() @@ -1242,6 +1247,11 @@ def _to_dynamic_input_arrays( if reclamm_price_ratio_updates_array is None else jnp.asarray(reclamm_price_ratio_updates_array, dtype=jnp.float64) ), + oracle_prices=( + empty.oracle_prices + if oracle_prices_array is None + else jnp.asarray(oracle_prices_array, dtype=jnp.float64) + ), ) @@ -1402,6 +1412,7 @@ def prepare_dynamic_inputs( arb_fees_df = dynamic_input_frames.arb_fees lp_supply_df = dynamic_input_frames.lp_supply reclamm_price_ratio_updates = dynamic_input_frames.reclamm_price_ratio_updates + oracle_prices_df = dynamic_input_frames.oracle_prices dynamic_input_flags = dynamic_input_flags_from_frames(dynamic_input_frames) if raw_trades is not None: @@ -1517,6 +1528,30 @@ def prepare_dynamic_inputs( else None ) + oracle_prices_array = ( + raw_fee_like_amounts_to_fee_like_array( + oracle_prices_df, + 
run_fingerprint["startDateString"], + run_fingerprint["endDateString"], + names=get_unique_tokens(run_fingerprint), + fill_method="ffill", + ) + if oracle_prices_df is not None + else None + ) + if do_test_period: + test_oracle_prices_array = ( + raw_fee_like_amounts_to_fee_like_array( + oracle_prices_df, + run_fingerprint["endDateString"], + run_fingerprint["endTestDateString"], + names=get_unique_tokens(run_fingerprint), + fill_method="ffill", + ) + if oracle_prices_df is not None + else None + ) + reclamm_price_ratio_updates_array = ( _normalize_reclamm_price_ratio_updates_for_window( reclamm_price_ratio_updates, @@ -1582,6 +1617,7 @@ def prepare_dynamic_inputs( arb_fees_array, lp_supply_array, reclamm_price_ratio_updates_array, + oracle_prices_array, ), "test_dynamic_inputs": _to_dynamic_input_arrays( test_period_trades, @@ -1590,6 +1626,7 @@ def prepare_dynamic_inputs( test_arb_fees_array, test_lp_supply_array, test_reclamm_price_ratio_updates_array, + test_oracle_prices_array, ), "dynamic_input_flags": dynamic_input_flags, } @@ -1601,6 +1638,7 @@ def prepare_dynamic_inputs( arb_fees_array, lp_supply_array, reclamm_price_ratio_updates_array, + oracle_prices_array, ), "dynamic_input_flags": dynamic_input_flags, } diff --git a/quantammsim/runners/jax_runners.py b/quantammsim/runners/jax_runners.py index 178f0a0..fd55955 100644 --- a/quantammsim/runners/jax_runners.py +++ b/quantammsim/runners/jax_runners.py @@ -54,6 +54,7 @@ _calculate_return_value, ) from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, DynamicInputFrames, materialize_dynamic_inputs, ) @@ -128,6 +129,52 @@ _scan_infra_cache = {} +def _concat_dynamic_input_arrays( + train_dynamic_inputs: DynamicInputArrays, + test_dynamic_inputs=None, +): + """Concatenate train/test dynamic-input bundles for continuous evaluation.""" + if train_dynamic_inputs is None: + return None + if test_dynamic_inputs is None: + return train_dynamic_inputs + + def _concat_leaf(train_leaf, test_leaf): 
+ if train_leaf is None: + return test_leaf + if test_leaf is None: + return train_leaf + if hasattr(train_leaf, "shape") and hasattr(test_leaf, "shape"): + if train_leaf.shape[0] <= 1 and test_leaf.shape[0] <= 1: + return train_leaf + if train_leaf.shape[0] <= 1: + return test_leaf + if test_leaf.shape[0] <= 1: + return train_leaf + return jnp.concatenate((train_leaf, test_leaf), axis=0) + + return DynamicInputArrays( + trades=_concat_leaf(train_dynamic_inputs.trades, test_dynamic_inputs.trades), + fees=_concat_leaf(train_dynamic_inputs.fees, test_dynamic_inputs.fees), + gas_cost=_concat_leaf( + train_dynamic_inputs.gas_cost, test_dynamic_inputs.gas_cost + ), + arb_fees=_concat_leaf( + train_dynamic_inputs.arb_fees, test_dynamic_inputs.arb_fees + ), + lp_supply=_concat_leaf( + train_dynamic_inputs.lp_supply, test_dynamic_inputs.lp_supply + ), + reclamm_price_ratio_updates=_concat_leaf( + train_dynamic_inputs.reclamm_price_ratio_updates, + test_dynamic_inputs.reclamm_price_ratio_updates, + ), + oracle_prices=_concat_leaf( + train_dynamic_inputs.oracle_prices, test_dynamic_inputs.oracle_prices + ), + ) + + def _scan_config_key(run_fingerprint, chunk_size, original_bout_length, bout_length_test): """Compute a hash key capturing everything that affects the compiled scan.""" fp_str = _json.dumps(run_fingerprint, sort_keys=True, default=str) @@ -329,6 +376,7 @@ def train_on_historic_data( iterations_per_print=1, force_init=False, price_data=None, + dynamic_input_frames: DynamicInputFrames = None, verbose=True, run_location=None, return_training_metadata=False, @@ -374,6 +422,10 @@ def train_on_historic_data( price_data : array-like or DataFrame, optional Pre-loaded price data. When None, data is loaded from parquet files based on ``run_fingerprint`` date/token settings. + dynamic_input_frames : DynamicInputFrames, optional + Optional minute-level dynamic input bundle. 
This is most useful + for HyperSurge oracle prices during training; the runner windows + the arrays alongside price windows. verbose : bool, optional Print detailed progress information (default True). run_location : str, optional @@ -425,7 +477,7 @@ def train_on_historic_data( try: return _train_on_historic_data_impl( run_fingerprint, root, iterations_per_print, force_init, - price_data, verbose, run_location, return_training_metadata, + price_data, dynamic_input_frames, verbose, run_location, return_training_metadata, warm_start_params, warm_start_weights, ) finally: @@ -434,7 +486,7 @@ def train_on_historic_data( def _train_on_historic_data_impl( run_fingerprint, root, iterations_per_print, force_init, - price_data, verbose, run_location, return_training_metadata, + price_data, dynamic_input_frames, verbose, run_location, return_training_metadata, warm_start_params, warm_start_weights, ): if verbose: @@ -473,6 +525,26 @@ def _train_on_historic_data_impl( ) max_memory_days = data_dict["max_memory_days"] + dynamic_inputs_dict = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=dynamic_input_frames, + do_test_period=True, + ) + dynamic_input_flags = dynamic_inputs_dict["dynamic_input_flags"] + train_dynamic_inputs = ( + dynamic_inputs_dict["train_dynamic_inputs"] + if dynamic_input_flags["use_dynamic_inputs"] + else None + ) + continuous_dynamic_inputs = ( + _concat_dynamic_input_arrays( + train_dynamic_inputs, + dynamic_inputs_dict.get("test_dynamic_inputs"), + ) + if dynamic_input_flags["use_dynamic_inputs"] + else None + ) + # Validation holdout setup # If val_fraction > 0, carve out validation window from end of training val_fraction = run_fingerprint["optimisation_settings"].get("val_fraction", 0.0) @@ -683,22 +755,16 @@ def _train_on_historic_data_impl( overrides={ "n_assets": n_assets, "training_data_kind": run_fingerprint["optimisation_settings"]["training_data_kind"], - "do_trades": False, - "dynamic_input_flags": { - "use_dynamic_inputs": 
False, - "has_trades": False, - "has_dynamic_fees": False, - "has_dynamic_gas_cost": False, - "has_dynamic_arb_fees": False, - "has_lp_supply": False, - "has_reclamm_price_ratio_updates": False, - }, + "do_trades": dynamic_input_flags["has_trades"], + "dynamic_input_flags": dynamic_input_flags, + "dynamic_inputs_offset": data_dict["start_idx"], }, ) partial_training_step = Partial( forward_pass, prices=data_dict["prices"], + dynamic_inputs=train_dynamic_inputs, static_dict=Hashabledict(base_static_dict), pool=pool, ) @@ -714,7 +780,7 @@ def _train_on_historic_data_impl( continuous_static_dict["bout_length"] = original_bout_length + data_dict["bout_length_test"] partial_forward_pass_nograd_batch_continuous = Partial( forward_pass_nograd, - dynamic_inputs=None, + dynamic_inputs=continuous_dynamic_inputs, static_dict=Hashabledict(continuous_static_dict), pool=pool, ) @@ -854,13 +920,15 @@ def init_optimizer(params): run_fingerprint, chunk_size, original_bout_length, _bout_length_test, ) - if config_key in _scan_infra_cache: + use_scan_cache = not dynamic_input_flags["use_dynamic_inputs"] + + if use_scan_cache and config_key in _scan_infra_cache: _run_scan_chunk, scan_body, _run_scan_step = _scan_infra_cache[config_key] else: # Build scan-compatible update (prices as explicit arg, not closure) partial_step_no_prices = Partial( forward_pass, - dynamic_inputs=None, + dynamic_inputs=train_dynamic_inputs, static_dict=Hashabledict(base_static_dict), pool=pool, ) @@ -900,7 +968,12 @@ def init_optimizer(params): swa_freq=swa_freq, n_parameter_sets=n_parameter_sets, ) - _scan_infra_cache[config_key] = (_run_scan_chunk, scan_body, _run_scan_step) + if use_scan_cache: + _scan_infra_cache[config_key] = ( + _run_scan_chunk, + scan_body, + _run_scan_step, + ) # ── Initialize carry (prices & nan_bank in carry, not closures) ── carry = { @@ -1298,7 +1371,9 @@ def _extract_params_at(params_tree, j): return selected_params elif run_fingerprint["optimisation_settings"]["method"] == 
"optuna": - n_evaluation_points = 20 + n_evaluation_points = run_fingerprint["optimisation_settings"].get( + "optuna_settings", {} + ).get("n_evaluation_points", 20) min_spacing = data_dict["bout_length"] // 2 # E run_fingerprint["optimisation_settings"]["n_parameter_sets"] = 1 @@ -1454,6 +1529,17 @@ def objective(trial): initial_reserves=train_outputs["reserves"][0], ) + # Reject catastrophic in-sample configurations + min_train_ret_over_hodl = run_fingerprint["optimisation_settings"][ + "optuna_settings"].get("min_train_returns_over_hodl", None) + if min_train_ret_over_hodl is not None: + if float(train_returns_over_hodl) < min_train_ret_over_hodl: + optuna_manager.logger.info( + f"Training {trial.number}, REJECTED:" + f" ret_over_hodl={train_returns_over_hodl:.4f}" + f" < {min_train_ret_over_hodl}") + return float("-inf") + # Test period evaluation using continuous forward pass # This ensures test metrics reflect continuous simulation from training continuous_outputs = partial_forward_pass_continuous_optuna( @@ -2641,6 +2727,7 @@ def do_run_on_historic_data( "gas_cost": gas_cost if gas_cost is not None else run_fingerprint["gas_cost"], "do_trades": dynamic_inputs_dict["dynamic_input_flags"]["has_trades"], "dynamic_input_flags": dynamic_inputs_dict["dynamic_input_flags"], + "dynamic_inputs_offset": data_dict["start_idx"], # Include date strings for run-time use "startDateString": run_fingerprint["startDateString"], "endDateString": run_fingerprint["endDateString"], @@ -2664,6 +2751,9 @@ def do_run_on_historic_data( reserves_values_test_static_dict = base_static_dict.copy() reserves_values_test_static_dict["return_val"] = "reserves_and_values" reserves_values_test_static_dict["bout_length"] = data_dict["bout_length_test"] + reserves_values_test_static_dict["dynamic_inputs_offset"] = data_dict[ + "start_idx_test" + ] partial_forward_pass_nograd_batch_reserves_values_test = jit( Partial( forward_pass_nograd, @@ -2877,6 +2967,7 @@ def 
do_run_on_historic_data_with_provided_coarse_weights( "gas_cost": gas_cost if gas_cost is not None else run_fingerprint["gas_cost"], "do_trades": dynamic_inputs_dict["dynamic_input_flags"]["has_trades"], "dynamic_input_flags": dynamic_inputs_dict["dynamic_input_flags"], + "dynamic_inputs_offset": data_dict["start_idx"], # Include date strings for run-time use "startDateString": run_fingerprint["startDateString"], "endDateString": run_fingerprint["endDateString"], diff --git a/results/linear_market_noise/_sim_arrays/0x9d1fcf346ea1b0_2024-06-01_2026-03-01.npz b/results/linear_market_noise/_sim_arrays/0x9d1fcf346ea1b0_2024-06-01_2026-03-01.npz new file mode 100644 index 0000000..368146c Binary files /dev/null and b/results/linear_market_noise/_sim_arrays/0x9d1fcf346ea1b0_2024-06-01_2026-03-01.npz differ diff --git a/results/linear_market_noise/_sim_arrays/0x9d1fcf346ea1b0_2025-08-03_2026-02-18.npz b/results/linear_market_noise/_sim_arrays/0x9d1fcf346ea1b0_2025-08-03_2026-02-18.npz new file mode 100644 index 0000000..03413c9 Binary files /dev/null and b/results/linear_market_noise/_sim_arrays/0x9d1fcf346ea1b0_2025-08-03_2026-02-18.npz differ diff --git a/results/linear_market_noise/meta.json b/results/linear_market_noise/meta.json new file mode 100644 index 0000000..35249de --- /dev/null +++ b/results/linear_market_noise/meta.json @@ -0,0 +1,77 @@ +{ + "feat_names": [ + "xobs_0", + "xobs_1", + "xobs_2", + "xobs_3", + "btc_log_price", + "btc_log_return", + "btc_realized_vol_7d", + "btc_trend_7d", + "btc_volume_zscore", + "pair_realized_vol_7d", + "tok_a_log_return", + "tok_a_realized_vol_7d", + "tok_a_trend_7d", + "tok_a_volume_zscore", + "tok_b_log_return", + "tok_b_realized_vol_7d", + "tok_b_trend_7d", + "tok_b_volume_zscore", + "xobs_1\u00d7btc_realized_vol_7d", + "xobs_1\u00d7tok_a_realized_vol_7d", + "xobs_1\u00d7pair_realized_vol_7d", + "tok_a_realized_vol_7d\u00d7tok_b_realized_vol_7d" + ], + "pool_ids": [ + "0x072f14b85add63", + "0x0b09dea16768f0", + 
"0x10f21c9bd8128a", + "0x1535d7ca00323a", + "0x21d4c792ea7e38", + "0x25ca5451cd5a50", + "0x260dbd54d87a10", + "0x272d6be442e30d", + "0x32df62dc3aed2c", + "0x36be1e97ea98ab", + "0x3de27efa2f1aa6", + "0x3e5fa9518ea95c", + "0x4683e340a80492", + "0x4cdabe9e07ca39", + "0x4fbb7870dbe7a7", + "0x571bea0e99e139", + "0x5c6ee304399dbd", + "0x5f1f4e50ba51d7", + "0x711af51a937e01", + "0x713fb5036dc700", + "0x9232a548dd9e81", + "0x92762b42a06dcd", + "0x96646936b91d6b", + "0x9d1fcf346ea1b0", + "0xa6f548df93de92", + "0xa83b8d30f61d75", + "0xb460daa847c45f", + "0xbc2acf5e821c5c", + "0xbda917a67c7d9a", + "0xcc65a812ce382a", + "0xcf354603a9aebd", + "0xcf7b51ce575551", + "0xd1d7fa8871d84d", + "0xdaba3d8ccf79ef", + "0xe99481dc77691d", + "0xf16aee6a71af1a" + ], + "n_pools": 36, + "n_feat": 22, + "hparams": { + "epochs": 2000, + "lr": 0.0003, + "l2_alpha": 0.001, + "huber_delta": 1.0, + "trend_windows": [ + 7 + ], + "per_pool": true, + "pool_intercepts": false + } +} \ No newline at end of file diff --git a/results/linear_market_noise/model.npz b/results/linear_market_noise/model.npz new file mode 100644 index 0000000..f440f04 Binary files /dev/null and b/results/linear_market_noise/model.npz differ diff --git a/scripts/build_pool_grids.py b/scripts/build_pool_grids.py index 9b710c2..ecf3152 100644 --- a/scripts/build_pool_grids.py +++ b/scripts/build_pool_grids.py @@ -215,8 +215,11 @@ def load_panel_and_match(train_days): if "tvl" not in panel.columns and "log_tvl" in panel.columns: panel["tvl"] = np.exp(panel["log_tvl"]) - cutoff = panel["obs_date"].max() - pd.Timedelta(days=train_days) - panel = panel[panel["obs_date"] >= cutoff].copy() + if train_days > 0: + cutoff = panel["obs_date"].max() - pd.Timedelta(days=train_days) + panel = panel[panel["obs_date"] >= cutoff].copy() + else: + panel = panel.copy() # no date filter — use all data binance_tokens = _get_binance_tokens() pools_meta = load_pools_metadata() @@ -342,6 +345,7 @@ def run_arb_sim(tokens, fee, initial_tvl, start, end, 
cadence, gas_cost, "arb_frequency": int(cadence), "chunk_period": 1440, "weight_interpolation_period": 1440, + "max_memory_days": 0, } if pool_type == "RECLAMM" and reclamm_params is not None: @@ -363,7 +367,7 @@ def run_arb_sim(tokens, fee, initial_tvl, start, end, cadence, gas_cost, result = do_run_on_historic_data( fp, params, lp_supply_df=lp_supply_df, verbose=False, - price_data=price_data, + price_data=price_data, preslice_burnin=False, ) reserves = np.array(result["reserves"]) @@ -405,6 +409,9 @@ def _run_cadence_sweep(pool_info, cadence, gas_costs): reclamm_params = pool_info.get("reclamm_params") price_data = pool_info.get("price_data") + if price_data is None: + sorted_tokens = sorted(tokens) + price_data = get_historic_parquet_data(sorted_tokens, ["close"]) daily_rows = [] summary_rows = [] @@ -527,10 +534,51 @@ def main(): print(f"Found {len(pools)} matchable 2-token pools\n") for p in pools: - lp_df, tvl = build_lp_supply_df(p["panel_data"]) + # Preload price data to determine actual date coverage + sorted_tokens = sorted(p["tokens"]) + price_data = get_historic_parquet_data(sorted_tokens, ["close"]) + p["price_data"] = price_data + + if len(price_data) == 0: + p["panel_data"] = p["panel_data"].iloc[:0] # empty + p["lp_supply_df"] = None + p["initial_tvl"] = 0.0 + p["start"], p["end"] = "2000-01-01", "2000-01-01" + continue + + # Clip panel to price data's actual date range + price_dates = pd.to_datetime(price_data.index, unit="ms") + price_start = price_dates.min().normalize() + price_end = price_dates.max().normalize() + panel_data = p["panel_data"] + panel_data = panel_data[ + (panel_data["obs_date"] >= price_start) + & (panel_data["obs_date"] <= price_end) + ].copy() + p["panel_data"] = panel_data + + lp_df, tvl = build_lp_supply_df(panel_data) p["lp_supply_df"] = lp_df p["initial_tvl"] = tvl - p["start"], p["end"] = get_date_range(p["panel_data"]) + + # Start/end must be midnight timestamps that exist in the price data. 
+ # start_and_end_calcs does an exact unix match and assumes alignment. + # If the price data starts mid-day (e.g. COW at noon), advance to + # the next midnight so the sim has a clean day boundary. + if len(panel_data) > 0: + first_midnight = (price_dates.min() + pd.Timedelta(days=1)).normalize() + last_midnight = price_dates.max().normalize() + first_midnight_ms = int(first_midnight.timestamp() * 1000) + last_midnight_ms = int(last_midnight.timestamp() * 1000) + # Verify these timestamps exist in the price data + if first_midnight_ms in price_data.index and last_midnight_ms in price_data.index: + p["start"] = first_midnight.strftime("%Y-%m-%d %H:%M:%S") + p["end"] = last_midnight.strftime("%Y-%m-%d %H:%M:%S") + else: + # Fallback: use panel dates (works when price data covers full range) + p["start"], p["end"] = get_date_range(panel_data) + else: + p["start"], p["end"] = "2000-01-01", "2000-01-01" pools = [p for p in pools if p["panel_data"]["obs_date"].nunique() >= 14] pools.sort(key=lambda p: p["initial_tvl"], reverse=True) @@ -580,9 +628,8 @@ def main(): t0 = time.time() - # Preload price data once per pool (avoids re-reading parquet per run) - sorted_tokens = sorted(tokens) - price_data = get_historic_parquet_data(sorted_tokens, ["close"]) + # Price data was preloaded during panel clipping + price_data = pool["price_data"] pool_info = { "tokens": tokens, @@ -594,7 +641,10 @@ def main(): "weights": pool["weights"], "pool_type": pool_type, "reclamm_params": pool.get("reclamm_params"), - "price_data": price_data, + # Only pass preloaded price_data in single-worker mode. + # For multi-worker, each subprocess loads its own to avoid + # pickling multi-million-row DataFrames across processes. 
+ "price_data": price_data if args.workers <= 1 else None, } all_daily_rows = [] diff --git a/scripts/compare_modelled_vs_real.py b/scripts/compare_modelled_vs_real.py new file mode 100644 index 0000000..e23c7e4 --- /dev/null +++ b/scripts/compare_modelled_vs_real.py @@ -0,0 +1,243 @@ +"""Compare modelled reClAMM noise volume against a real pool's observed volume. + +Plots the modelled noise volume for one pool (at a specified counterfactual TVL) +against the actual observed volume of another pool (or the same pool), to +sanity-check the noise model's predictions. + +Usage: + # reClAMM AAVE/ETH modelled at $7M vs weighted wstETH/AAVE real + python scripts/compare_modelled_vs_real.py \ + --model-pool 0x9d1fcf346ea1b0 --model-tvl 7e6 \ + --real-pool 0x3de27efa2f1aa6 + + # Same but at $20M + python scripts/compare_modelled_vs_real.py \ + --model-pool 0x9d1fcf346ea1b0 --model-tvl 20e6 \ + --real-pool 0x3de27efa2f1aa6 + + # Multiple TVL levels + python scripts/compare_modelled_vs_real.py \ + --model-pool 0x9d1fcf346ea1b0 --model-tvl 1e6 7e6 20e6 50e6 \ + --real-pool 0x3de27efa2f1aa6 +""" + +import argparse +import os +import pickle + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from quantammsim.calibration.noise_model_arrays import ( + build_simulator_arrays, load_artifact, _find_pool_index, +) + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) +ARTIFACT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "noise_comparison", +) + +# Token mapping for pools +POOL_TOKENS = { + "0x9d1fcf346ea1b0": ("AAVE", "ETH"), + "0x3de27efa2f1aa6": ("AAVE", "ETH"), # wstETH/AAVE ≈ same pair + "0x0b09dea16768f0": ("DAI", "ETH"), + "0xa6f548df93de92": ("BTC", "ETH"), + "0x96646936b91d6b": ("USDC", 
"ETH"), +} + + +def load_real_pool(pid, mc): + """Load real observed volume + TVL for a pool.""" + entry = mc[pid] + panel = entry["panel"] + dates = pd.to_datetime(panel["date"]) + vol = np.exp(panel["log_volume"].values.astype(float)) + tvl = np.exp(panel["log_tvl_lag1"].values.astype(float)) + tokens = entry["tokens"] + chain = entry["chain"] + return dates, vol, tvl, tokens, chain + + +def compute_modelled_noise(pid, tvl_value, start_date, end_date, + artifact_dir, token_a, token_b): + """Compute modelled daily noise for a pool at a given TVL.""" + arrays = build_simulator_arrays( + token_a=token_a, token_b=token_b, + start_date=start_date, end_date=end_date, + artifact_dir=artifact_dir, pool_id=pid, + ) + + n_days = arrays["n_days"] + std_lt = (np.log(tvl_value) - arrays["tvl_mean"]) / arrays["tvl_std"] + + noise_base = arrays["noise_base"][::1440][:n_days] + tvl_coeff = arrays["noise_tvl_coeff"][::1440][:n_days] + noise_daily = np.exp(noise_base + tvl_coeff * std_lt) + + dates = pd.to_datetime(arrays["dates"][:n_days]) + return dates, noise_daily + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--model-pool", default="0x9d1fcf346ea1b0", + help="Pool ID for modelled noise") + parser.add_argument("--model-tvl", type=float, nargs="+", + default=[7_000_000], + help="Counterfactual TVL(s) for the modelled pool") + parser.add_argument("--model-tokens", nargs=2, default=None, + help="Token A and B for the modelled pool (auto-detected)") + parser.add_argument("--real-pool", default="0x3de27efa2f1aa6", + help="Pool ID for real observed data") + parser.add_argument("--artifact-dir", default=ARTIFACT_DIR) + parser.add_argument("--output-dir", default=OUTPUT_DIR) + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + os.makedirs(args.output_dir, exist_ok=True) + + # Load calibration data + with open(os.path.join(CACHE_DIR, "stage1.pkl"), 
"rb") as f: + data = pickle.load(f) + mc = data["matched_clean"] + + # Real pool data + real_dates, real_vol, real_tvl, real_tokens, real_chain = load_real_pool( + args.real_pool, mc) + print(f"Real pool: {args.real_pool} ({real_tokens}, {real_chain})") + print(f" {len(real_dates)} days: {real_dates.min().date()} → {real_dates.max().date()}") + print(f" TVL: ${real_tvl.min():,.0f} – ${real_tvl.max():,.0f}") + print(f" Volume: ${real_vol.min():,.0f} – ${real_vol.max():,.0f}") + + # Model tokens + if args.model_tokens: + tok_a, tok_b = args.model_tokens + elif args.model_pool[:16] in POOL_TOKENS: + tok_a, tok_b = POOL_TOKENS[args.model_pool[:16]] + else: + tok_a, tok_b = "ETH", "USDC" + print(f" Warning: unknown pool, using {tok_a}/{tok_b}") + + # Date range from real pool + start = str(real_dates.min().date()) + end = str(real_dates.max().date()) + + # Compute modelled noise at each TVL + model_results = [] + for tvl_val in args.model_tvl: + print(f"\nModelled: {args.model_pool} at ${tvl_val:,.0f} TVL") + m_dates, m_noise = compute_modelled_noise( + args.model_pool, tvl_val, start, end, + args.artifact_dir, tok_a, tok_b) + model_results.append((tvl_val, m_dates, m_noise)) + print(f" Median noise: ${np.median(m_noise):,.0f}/day" + f" ({np.median(m_noise)/tvl_val*100:.2f}% of TVL)") + + # Align dates + common_start = real_dates.min() + common_end = real_dates.max() + for _, md, _ in model_results: + common_start = max(common_start, md.min()) + common_end = min(common_end, md.max()) + + real_mask = (real_dates >= common_start) & (real_dates <= common_end) + + # Colors for different TVL levels + colors = ["#e74c3c", "#3498db", "#2ecc71", "#f39c12", "#9b59b6"] + + # Plot + fig, axes = plt.subplots(2, 1, figsize=(14, 9)) + + # 1. 
Volume comparison + ax = axes[0] + ax.plot(real_dates[real_mask], real_vol[real_mask] / 1e6, + "k-", linewidth=0.8, alpha=0.7, + label=f"{real_tokens} weighted (real," + f" TVL ${np.median(real_tvl[real_mask])/1e6:.0f}M)") + + for i, (tvl_val, m_dates, m_noise) in enumerate(model_results): + m_mask = (m_dates >= common_start) & (m_dates <= common_end) + c = colors[i % len(colors)] + ax.plot(m_dates[m_mask], m_noise[m_mask] / 1e6, + "-", color=c, linewidth=0.8, alpha=0.7, + label=f"reClAMM noise (modelled, TVL ${tvl_val/1e6:.0f}M)") + + ax.set_ylabel("Volume ($M/day)") + ax.set_yscale("log") + ax.set_title(f"Real weighted pool vs Modelled reClAMM noise") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # 2. Vol/TVL comparison + ax = axes[1] + real_vol_tvl = real_vol[real_mask] / real_tvl[real_mask] * 100 + ax.plot(real_dates[real_mask], real_vol_tvl, + "k-", linewidth=0.8, alpha=0.7, + label=f"{real_tokens} weighted real vol/TVL") + ax.axhline(np.median(real_vol_tvl), color="black", linestyle="--", + alpha=0.3, label=f"weighted median: {np.median(real_vol_tvl):.2f}%") + + for i, (tvl_val, m_dates, m_noise) in enumerate(model_results): + m_mask = (m_dates >= common_start) & (m_dates <= common_end) + noise_tvl = m_noise[m_mask] / tvl_val * 100 + c = colors[i % len(colors)] + ax.plot(m_dates[m_mask], noise_tvl, + "-", color=c, linewidth=0.8, alpha=0.7, + label=f"reClAMM noise/TVL (${tvl_val/1e6:.0f}M)") + ax.axhline(np.median(noise_tvl), color=c, linestyle="--", alpha=0.3, + label=f"median: {np.median(noise_tvl):.2f}%") + + ax.set_ylabel("Volume / TVL (%)") + ax.set_xlabel("Date") + ax.set_title("Volume as Fraction of TVL") + ax.legend(fontsize=7, loc="upper right") + ax.grid(True, alpha=0.3) + ymax = min( + max(np.percentile(real_vol_tvl, 95), + max(np.percentile(m_noise[m_mask] / tvl_val * 100, 95) + for tvl_val, m_dates, m_noise in model_results + for m_mask in [(m_dates >= common_start) & (m_dates <= common_end)])) * 1.5, + 50) + ax.set_ylim(0, ymax) + + 
fig.tight_layout() + tvl_str = "_".join(f"{t/1e6:.0f}M" for t in args.model_tvl) + out = os.path.join(args.output_dir, + f"{args.model_pool[:8]}_vs_{args.real_pool[:8]}_{tvl_str}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"\nSaved: {out}") + + # Summary table + print(f"\n{'='*60}") + print(f"Summary") + print(f"{'='*60}") + print(f" Real {real_tokens} weighted:") + print(f" Median TVL: ${np.median(real_tvl[real_mask]):,.0f}") + print(f" Median vol: ${np.median(real_vol[real_mask]):,.0f}/day") + print(f" Median vol/TVL: {np.median(real_vol_tvl):.2f}%") + for tvl_val, m_dates, m_noise in model_results: + m_mask = (m_dates >= common_start) & (m_dates <= common_end) + med_noise = np.median(m_noise[m_mask]) + print(f" Modelled reClAMM at ${tvl_val/1e6:.0f}M:") + print(f" Median noise: ${med_noise:,.0f}/day") + print(f" Median noise/TVL: {med_noise/tvl_val*100:.2f}%") + + +if __name__ == "__main__": + main() diff --git a/scripts/compare_reclamm_geometric_noise_runs.py b/scripts/compare_reclamm_geometric_noise_runs.py new file mode 100644 index 0000000..96230ed --- /dev/null +++ b/scripts/compare_reclamm_geometric_noise_runs.py @@ -0,0 +1,13 @@ +"""Wrapper for the canonical reCLAMM geometric noise comparison script.""" + +from __future__ import annotations + +import runpy +from pathlib import Path + + +if __name__ == "__main__": + runpy.run_path( + str(Path(__file__).with_name("reclamm") / "compare_reclamm_geometric_noise_runs.py"), + run_name="__main__", + ) diff --git a/scripts/compare_reclamm_thermostats.py b/scripts/compare_reclamm_thermostats.py index 8a2c374..7fb1770 100644 --- a/scripts/compare_reclamm_thermostats.py +++ b/scripts/compare_reclamm_thermostats.py @@ -1,19 +1,40 @@ -"""Compare geometric vs constant-arc-length thermostats on historic data. +"""Compare reCLAMM interpolation modes on historic AAVE/ETH data. -Runs AAVE/ETH reClAMM pool simulations with both interpolation methods. 
-Plots: pool value, cumulative LVR, price path, empirical weights, -value difference, LVR ratio, and per-step LVR distribution (∝ Δs²). +Runs the production geometric interpolation against the non-linear +constant-arc-length interpolation on: +1. The original launch-style range (price_ratio ~= 1.50) +2. A much tighter range (price_ratio = 1.10) -Usage: - cd - source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm - python scripts/compare_reclamm_thermostats.py +The aggressive case is deliberate. A local AAVE/ETH sweep showed: +price_ratio 1.15, margin 0.5, shift 0.1 -> about +$10k vs geometric +price_ratio 1.10, margin 0.5, shift 0.1 -> about +$31k vs geometric +price_ratio 1.10, margin 0.6, shift 0.1 -> about +$73k vs geometric + +So the strongest clean demo setting came from tightening the band and +slightly raising the trigger margin, while keeping the launch-style shift +speed rather than pushing shift_exponent higher. """ +import gc +import hashlib +import os +from pathlib import Path + import jax.numpy as jnp import numpy as np +import pandas as pd import matplotlib.pyplot as plt +from matplotlib.colors import Normalize, SymLogNorm, TwoSlopeNorm +from matplotlib.cm import ScalarMappable +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + calibrate_arc_length_speed, + compute_price_ratio, + initialise_reclamm_reserves, +) from quantammsim.runners.jax_runners import do_run_on_historic_data +from quantammsim.utils.data_processing.historic_data_utils import ( + get_historic_parquet_data, +) def to_daily_price_shift_base(daily_price_shift_exponent): @@ -21,61 +42,515 @@ def to_daily_price_shift_base(daily_price_shift_exponent): return 1.0 - daily_price_shift_exponent / 124649.0 +def build_inclusive_sweep(start, stop, step): + """Build a sweep that keeps the requested step and explicitly includes the stop.""" + values = np.arange(start, stop + 1.0e-12, step, dtype=float) + if values.size == 0 or not np.isclose(values[-1], stop): + values 
= np.append(values, float(stop)) + return values + + +def _resolve_repo_root(script_path): + """Locate the repository root from either scripts/ or scripts/reclamm/.""" + script_path = Path(script_path).resolve() + for parent in script_path.parents: + if (parent / "quantammsim").exists() and (parent / "scripts").exists(): + return parent + return script_path.parents[1] + + +RUN_CONSTANT_ARC_LENGTH = True +INTERPOLATION_METHODS = ( + ("geometric", "constant_arc_length") + if RUN_CONSTANT_ARC_LENGTH + else ("geometric",) +) +HEATMAP_PRICE_RATIOS = build_inclusive_sweep(1.01, 3.00, 0.025) +HEATMAP_MARGINS = np.linspace(0.05, 0.90, 39) +HEATMAP_SHIFT_EXPONENTS = build_inclusive_sweep(0.01, 0.50, 0.0125) +HEATMAP_ARC_LENGTH_SPEEDS = np.geomspace(1.0e-6, 5.0e-4, 11) +PRICE_RATIO_TICKS = np.array([1.01, 1.25, 1.50, 2.00, 2.50, 3.00]) +MARGIN_TICKS = np.array([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.90]) +SHIFT_EXPONENT_TICKS = np.array([0.01, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50]) +ARC_LENGTH_SPEED_TICKS = np.array([ + 1.0e-6, + 2.0e-6, + 5.0e-6, + 1.0e-5, + 2.0e-5, + 5.0e-5, + 1.0e-4, + 2.0e-4, + 5.0e-4, +]) +SWEEP_LINE_WIDTH = 0.45 +REFERENCE_LINE_WIDTH = 0.9 +DEFAULT_INITIAL_POOL_VALUE = 1_000_000.0 +TVL_SWEEP_VALUES = ( + 1_000_000.0, + 5_000_000.0, + 20_000_000.0, +) +CENTER_ZERO_HEATMAP_COLOR_NORM = "symlog" +CENTER_ZERO_HEATMAP_COLOR_TAG = "symlog20" +CENTER_ZERO_HEATMAP_SYMLOG_LINTHRESH = 20.0 +FIXED_SLICE_FRACTIONS = (0.125, 0.375, 0.625, 0.875) +FIXED_SLICE_LABELS = ("Q1", "Q2", "Q3", "Q4") +THREE_D_VIEW_ELEVATION = 22.0 +THREE_D_VIEW_AZIMUTH = 140.0 +HEATMAP_FORWARD_CACHE_ENABLED = True +HEATMAP_FORWARD_CACHE_RUN_NAME = "aave_eth_thermostat_heatmaps_market_linear_v2" +HEATMAP_FORWARD_CACHE_ROOT = os.path.join( + "results", + "reclamm_heatmap_forward_cache", +) +HEATMAP_FORWARD_CACHE_FLUSH_EVERY = 360 + +REPO_ROOT = _resolve_repo_root(__file__) +AAVE_WETH_POOL_ID = "0x9d1fcf346ea1b0" +DEFAULT_MARKET_LINEAR_ARTIFACT_DIR = 
"results/linear_market_noise" +DEFAULT_MARKET_LINEAR_NOISE_START_DATE = "2024-06-01" +DEFAULT_MARKET_LINEAR_NOISE_END_DATE = "2026-03-01" +DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH = str( + REPO_ROOT + / "results" + / "linear_market_noise" + / "_sim_arrays" + / ( + f"{AAVE_WETH_POOL_ID}_{DEFAULT_MARKET_LINEAR_NOISE_START_DATE}_" + f"{DEFAULT_MARKET_LINEAR_NOISE_END_DATE}.npz" + ) +) +DEFAULT_NOISE_MODEL = "market_linear" +DEFAULT_GAS_COST = 1.0 +DEFAULT_PROTOCOL_FEE_SPLIT = 0.25 +FIXED_COMPARE_ARB_FREQUENCY = 15 +AAVE_ETH_NOISE_SETTINGS = { + "enable_noise_model": True, + "noise_model": DEFAULT_NOISE_MODEL, + "noise_reference_model": DEFAULT_NOISE_MODEL, + "noise_artifact_dir": DEFAULT_MARKET_LINEAR_ARTIFACT_DIR, + "noise_pool_id": AAVE_WETH_POOL_ID, + "noise_arrays_path": DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, + "arb_frequency": FIXED_COMPARE_ARB_FREQUENCY, + "gas_cost": DEFAULT_GAS_COST, + "protocol_fee_split": DEFAULT_PROTOCOL_FEE_SPLIT, +} +PERSISTED_FORWARD_VALUE_COLUMNS = ( + "cache_key_hash", + "final_value", + "method", + "enable_noise_model", + "noise_model", + "price_ratio", + "centeredness_margin", + "daily_price_shift_exponent", + "initial_pool_value", + "arb_frequency", +) + +GEOMETRIC_ONLY_HEATMAP_METRIC_KEYS = ( + "geometric_vs_launch_geometric_pct", + "noise_geometric_final_value_musd", + "noise_vs_arb_geometric_improvement_pct", +) +CONSTANT_ARC_HEATMAP_METRIC_KEYS = ( + "efficiency_pct", + "launch_geometric_efficiency_pct", + "constant_arc_vs_launch_constant_arc_pct", + "noise_constant_arc_final_value_musd", + "noise_vs_arb_constant_arc_improvement_pct", +) +HEATMAP_METRIC_DEPENDENCIES = { + "efficiency_pct": ("noise_geometric", "noise_constant_arc"), + "launch_geometric_efficiency_pct": ("noise_constant_arc",), + "geometric_vs_launch_geometric_pct": ("noise_geometric",), + "constant_arc_vs_launch_constant_arc_pct": ("noise_constant_arc",), + "noise_geometric_final_value_musd": ("noise_geometric",), + "noise_constant_arc_final_value_musd": 
("noise_constant_arc",), + "noise_vs_arb_geometric_improvement_pct": ("noise_geometric", "arb_geometric"), + "noise_vs_arb_constant_arc_improvement_pct": ( + "noise_constant_arc", + "arb_constant_arc", + ), +} + +_NOISE_SETTINGS_CACHE = {} +_MARKET_LINEAR_NOISE_DATA_CACHE = {} + + +def get_initial_pool_value(cfg): + """Return the configured base pool TVL in USD.""" + return float(cfg.get("initial_pool_value", DEFAULT_INITIAL_POOL_VALUE)) + + +def get_tvl_millions(cfg): + """Return the configured base pool TVL in millions of USD.""" + return get_initial_pool_value(cfg) / 1_000_000.0 + + +def format_tvl_millions_slug(cfg): + """Format the TVL in millions for stable filenames.""" + tvl_millions = get_tvl_millions(cfg) + rounded = round(float(tvl_millions), 6) + if np.isclose(rounded, round(rounded)): + return f"{int(round(rounded))}m" + return f"{rounded:.6f}".rstrip("0").rstrip(".").replace(".", "p") + "m" + + +def format_tvl_millions_label(cfg): + """Format the TVL in millions for plot titles and logs.""" + return f"{get_tvl_millions(cfg):.1f}M" + + +def tvl_artifact_filename(stem, cfg, suffix=None): + """Append a TVL-in-millions suffix to a PNG artifact name.""" + parts = [stem] + if suffix: + parts.append(suffix) + parts.append(f"tvl_{format_tvl_millions_slug(cfg)}") + return "_".join(parts) + ".png" + + +def heatmap_artifact_filename(spec, cfg, suffix=None): + """Build a heatmap filename, including any colour-style tag.""" + stem = f"reclamm_heatmap_{spec['slug']}" + artifact_tag = spec.get("artifact_tag") + if artifact_tag: + stem = f"{stem}_{artifact_tag}" + return tvl_artifact_filename(stem, cfg, suffix=suffix) + + +def three_d_heatmap_artifact_filename(spec, cfg, suffix=None): + """Build a 3D heatmap filename, including any colour-style tag.""" + stem = f"reclamm_heatmap_3d_{spec['slug']}" + artifact_tag = spec.get("artifact_tag") + if artifact_tag: + stem = f"{stem}_{artifact_tag}" + return tvl_artifact_filename(stem, cfg, suffix=suffix) + + +def 
format_heatmap_param_value(value): + """Format a sweep parameter compactly for titles and logs.""" + value = float(value) + if abs(value) >= 1.0: + return f"{value:.2f}".rstrip("0").rstrip(".") + return f"{value:.3f}".rstrip("0").rstrip(".") + + +def configs_for_tvl(base_configs, initial_pool_value): + """Attach a shared initial TVL to each compare configuration.""" + configs = [] + for cfg in base_configs: + updated = dict(cfg) + updated["initial_pool_value"] = float(initial_pool_value) + configs.append(updated) + return configs + + +def _normalize_arb_frequency(value, default=FIXED_COMPARE_ARB_FREQUENCY): + """Return a stable integer arb cadence for thermostat comparisons.""" + if value is None: + if default is None: + return None + value = default + return max(int(round(float(value))), 1) + + +def get_effective_arb_frequency(cfg, noise_cfg=None): + """Resolve the arb cadence used by a thermostat comparison run.""" + del noise_cfg + return _normalize_arb_frequency(FIXED_COMPARE_ARB_FREQUENCY) + + +def _canonical_noise_reference_model(cfg): + """Resolve the only supported thermostat noise parametrisation.""" + noise_model = cfg.get("noise_model", DEFAULT_NOISE_MODEL) or DEFAULT_NOISE_MODEL + reference_model = cfg.get("noise_reference_model") + if reference_model is None: + reference_model = DEFAULT_NOISE_MODEL if noise_model == "arb_only" else noise_model + noise_model = str(noise_model) + reference_model = str(reference_model) + if noise_model not in {DEFAULT_NOISE_MODEL, "arb_only"}: + raise ValueError( + "compare_reclamm_thermostats only supports " + "'market_linear' noise and 'arb_only' baselines." + ) + if reference_model != DEFAULT_NOISE_MODEL: + raise ValueError( + "compare_reclamm_thermostats only supports the " + "'market_linear' noise parametrisation." 
+ ) + return reference_model + + +def normalize_compare_run_cfg(cfg, enable_noise_model=None): + """Canonicalize the compare-run config so non-axis inputs stay fixed.""" + updated = dict(cfg) + updated["price_ratio"] = float(cfg["price_ratio"]) + updated["centeredness_margin"] = float(cfg["centeredness_margin"]) + updated["daily_price_shift_exponent"] = float(cfg["daily_price_shift_exponent"]) + updated["initial_pool_value"] = float(get_initial_pool_value(cfg)) + updated["gas_cost"] = DEFAULT_GAS_COST + updated["protocol_fee_split"] = DEFAULT_PROTOCOL_FEE_SPLIT + updated["arb_fees"] = 0.0 + updated["arb_frequency"] = get_effective_arb_frequency(cfg) + updated["noise_trader_ratio"] = 0.0 + + arc_length_speed = cfg.get("arc_length_speed") + if arc_length_speed is None: + updated.pop("arc_length_speed", None) + else: + updated["arc_length_speed"] = float(arc_length_speed) + + use_noise = ( + bool(cfg.get("enable_noise_model", False)) + if enable_noise_model is None + else bool(enable_noise_model) + ) + updated["enable_noise_model"] = use_noise + + reference_mode = _canonical_noise_reference_model(cfg) + if use_noise: + updated["noise_model"] = reference_mode + updated["noise_reference_model"] = reference_mode + else: + updated["noise_model"] = "arb_only" + updated["noise_reference_model"] = reference_mode + + updated["noise_arrays_path"] = DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + updated.pop("reclamm_noise_params", None) + updated["noise_artifact_dir"] = DEFAULT_MARKET_LINEAR_ARTIFACT_DIR + updated["noise_pool_id"] = AAVE_WETH_POOL_ID + + return updated + + +def make_noise_variant_cfg(cfg, enable_noise_model): + """Return a config with either noise modelling or pure arb-only enabled.""" + return normalize_compare_run_cfg(cfg, enable_noise_model=enable_noise_model) + + +def _hashable_noise_params(params): + """Convert a noise-params dict into a stable cache key fragment.""" + if params is None: + return None + return tuple(sorted((str(k), round(float(v), 12)) for k, v 
in params.items())) + + +def load_shared_market_linear_noise_data( + arrays_path=DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, +): + """Load the market_linear arrays once so compare runs can reuse them.""" + arrays_path = os.path.abspath(os.fspath(arrays_path)) + cached = _MARKET_LINEAR_NOISE_DATA_CACHE.get(arrays_path) + if cached is not None: + return cached + + if not os.path.exists(arrays_path): + raise FileNotFoundError(f"market_linear arrays file not found: {arrays_path}") + + with np.load(arrays_path) as arrays: + required_keys = {"noise_base", "noise_tvl_coeff", "tvl_mean", "tvl_std"} + missing_keys = sorted(required_keys.difference(arrays.files)) + if missing_keys: + raise KeyError( + f"market_linear arrays file {arrays_path} is missing keys: {missing_keys}" + ) + shared = { + "arrays_path": arrays_path, + "noise_base_array": np.asarray(arrays["noise_base"]), + "noise_tvl_coeff_array": np.asarray(arrays["noise_tvl_coeff"]), + "tvl_mean": float(arrays["tvl_mean"]), + "tvl_std": float(arrays["tvl_std"]), + } + _MARKET_LINEAR_NOISE_DATA_CACHE[arrays_path] = shared + return shared + + +def _load_market_linear_noise_stats(arrays_path=DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH): + """Load the exact arrays file used by the market_linear run fingerprint. + + The simulator consumes ``noise_base`` and ``noise_tvl_coeff`` from + ``run_fingerprint["noise_arrays_path"]`` and uses ``tvl_mean``/``tvl_std`` + from the same file for TVL standardization. 
+ """ + shared = load_shared_market_linear_noise_data(arrays_path=arrays_path) + return shared["arrays_path"], shared["tvl_mean"], shared["tvl_std"] + + +def _market_linear_noise_settings(noise_model="market_linear", arb_frequency=None): + """Build the tuned market_linear fingerprint block from the fixed arrays file.""" + arrays_path, tvl_mean, tvl_std = _load_market_linear_noise_stats() + arb_frequency = _normalize_arb_frequency(arb_frequency) + return { + "noise_model": noise_model, + "noise_trader_ratio": 0.0, + "reclamm_noise_params": { + "tvl_mean": tvl_mean, + "tvl_std": tvl_std, + }, + "noise_arrays_path": arrays_path, + "arb_frequency": arb_frequency, + "noise_summary": f"{noise_model} (arb_frequency={arb_frequency})", + "noise_cache_key": ( + noise_model, + arrays_path, + arb_frequency, + round(tvl_mean, 12), + round(tvl_std, 12), + ), + } + +def resolve_reclamm_noise_settings(cfg): + """Resolve the active reCLAMM noise-model fingerprint block for a config.""" + cfg = normalize_compare_run_cfg(cfg) + enable_noise_model = cfg.get("enable_noise_model", False) + requested_mode = cfg.get("noise_model", DEFAULT_NOISE_MODEL) + reference_mode = cfg.get("noise_reference_model", DEFAULT_NOISE_MODEL) + requested_arb_frequency = get_effective_arb_frequency(cfg) + cache_key = ( + tuple(cfg.get("tokens", [])), + cfg.get("start"), + cfg.get("end"), + enable_noise_model, + requested_mode, + reference_mode, + cfg.get("noise_artifact_dir", DEFAULT_MARKET_LINEAR_ARTIFACT_DIR), + cfg.get("noise_pool_id", AAVE_WETH_POOL_ID), + requested_arb_frequency, + round(float(cfg.get("noise_trader_ratio", 0.0)), 12), + _hashable_noise_params(cfg.get("reclamm_noise_params")), + cfg.get("noise_arrays_path"), + ) + if cache_key in _NOISE_SETTINGS_CACHE: + return _NOISE_SETTINGS_CACHE[cache_key] + + if requested_mode == "arb_only": + result = _market_linear_noise_settings( + noise_model="arb_only", + arb_frequency=requested_arb_frequency, + ) + elif requested_mode == DEFAULT_NOISE_MODEL: + 
result = _market_linear_noise_settings( + noise_model=DEFAULT_NOISE_MODEL, + arb_frequency=requested_arb_frequency, + ) + else: + raise ValueError( + "compare_reclamm_thermostats only supports " + "'market_linear' noise and 'arb_only' baselines." + ) + + _NOISE_SETTINGS_CACHE[cache_key] = result + return result + + # Pool configurations to compare CONFIGS = [ { - "name": "AAVE/ETH on-chain (25bps, narrow range)", + "name": "AAVE/ETH launch-style range (25bps, reference)", "tokens": ["AAVE", "ETH"], "start": "2024-06-01 00:00:00", "end": "2025-06-01 00:00:00", "fees": 0.0025, - "price_ratio": 1.5, + "price_ratio": 1.5014, "centeredness_margin": 0.5, "daily_price_shift_exponent": 0.1, + "reason": "Original launch-style parameters.", + **AAVE_ETH_NOISE_SETTINGS, }, { - "name": "AAVE/ETH wide range (25bps)", + "name": "AAVE/ETH aggressive tight range (25bps)", "tokens": ["AAVE", "ETH"], "start": "2024-06-01 00:00:00", "end": "2025-06-01 00:00:00", "fees": 0.0025, - "price_ratio": 4.0, - "centeredness_margin": 0.2, - "daily_price_shift_exponent": 1.0, - }, - { - "name": "AAVE/ETH zero fees (narrow)", - "tokens": ["AAVE", "ETH"], - "start": "2024-06-01 00:00:00", - "end": "2025-06-01 00:00:00", - "fees": 0.0, - "price_ratio": 1.5, - "centeredness_margin": 0.5, + "price_ratio": 1.10, + "centeredness_margin": 0.60, "daily_price_shift_exponent": 0.1, + "reason": ( + "Aggressively tightened and moved to an earlier thermostat trigger. " + "At fixed price_ratio=1.10, the shift_exponent sweep still favored " + "0.1, while margin=0.60 widened the non-linear edge materially." 
+ ), + **AAVE_ETH_NOISE_SETTINGS, }, ] -def make_fingerprint(cfg, interpolation_method, centeredness_scaling=False): +def _attach_market_linear_noise_arrays( + fingerprint, + noise_cfg, + market_linear_noise_data, +): + """Attach preloaded market_linear arrays when the compare flow has them.""" + if market_linear_noise_data is None: + return + expected_path = noise_cfg.get("noise_arrays_path") + if expected_path is None: + return + shared_path = os.path.abspath(os.fspath(market_linear_noise_data["arrays_path"])) + expected_path = os.path.abspath(os.fspath(expected_path)) + if shared_path != expected_path: + raise ValueError( + "Shared market_linear noise arrays path does not match " + f"the resolved compare-run noise path: {shared_path} != {expected_path}" + ) + fingerprint["noise_base_array"] = market_linear_noise_data["noise_base_array"] + fingerprint["noise_tvl_coeff_array"] = market_linear_noise_data["noise_tvl_coeff_array"] + + +def make_fingerprint(cfg, interpolation_method, market_linear_noise_data=None): """Build run fingerprint for a given config and interpolation method.""" - return { + cfg = normalize_compare_run_cfg(cfg) + speed_override = ( + cfg.get("arc_length_speed") + if interpolation_method == "constant_arc_length" + else None + ) + noise_cfg = resolve_reclamm_noise_settings(cfg) + arb_frequency = get_effective_arb_frequency(cfg, noise_cfg) + fingerprint = { "tokens": cfg["tokens"], "rule": "reclamm", "startDateString": cfg["start"], "endDateString": cfg["end"], - "initial_pool_value": 1000000.0, + "initial_pool_value": get_initial_pool_value(cfg), "do_arb": True, "fees": cfg["fees"], - "gas_cost": 0.0, - "arb_fees": 0.0, + "gas_cost": cfg.get( + "gas_cost", + DEFAULT_GAS_COST if cfg.get("enable_noise_model", False) else 0.0, + ), + "arb_fees": cfg.get("arb_fees", 0.0), + "protocol_fee_split": cfg.get( + "protocol_fee_split", + DEFAULT_PROTOCOL_FEE_SPLIT if cfg.get("enable_noise_model", False) else 0.0, + ), + "noise_trader_ratio": 
noise_cfg.get("noise_trader_ratio", 0.0), "reclamm_interpolation_method": interpolation_method, - "reclamm_arc_length_speed": None, # auto-calibrate - "reclamm_centeredness_scaling": centeredness_scaling, + "reclamm_arc_length_speed": speed_override, } + if noise_cfg.get("noise_model") is not None: + fingerprint["noise_model"] = noise_cfg["noise_model"] + if noise_cfg.get("reclamm_noise_params") is not None: + fingerprint["reclamm_noise_params"] = noise_cfg["reclamm_noise_params"] + if noise_cfg.get("noise_arrays_path") is not None: + fingerprint["noise_arrays_path"] = noise_cfg["noise_arrays_path"] + _attach_market_linear_noise_arrays( + fingerprint, + noise_cfg, + market_linear_noise_data, + ) + if arb_frequency is not None: + fingerprint["arb_frequency"] = arb_frequency + return fingerprint def make_params(cfg): """Build pool params from config.""" + cfg = normalize_compare_run_cfg(cfg) return { "price_ratio": jnp.array(cfg["price_ratio"]), "centeredness_margin": jnp.array(cfg["centeredness_margin"]), @@ -85,40 +560,1654 @@ def make_params(cfg): } -def run_comparison(cfg): - """Run all thermostat variants, return results dict.""" +def load_shared_price_data(configs, root=None): + """Load the shared historic price panel once for all compare runs.""" + tokens = sorted({token for cfg in configs for token in cfg["tokens"]}) + return get_historic_parquet_data(tokens, cols=["close"], root=root) + + +def run_comparison( + cfg, + price_data=None, + low_data_mode=False, + market_linear_noise_data=None, +): + """Run both interpolation variants, return results dict.""" params = make_params(cfg) results = {} - for method in ["geometric", "constant_arc_length"]: - fp = make_fingerprint(cfg, method) + for method in INTERPOLATION_METHODS: + fp = make_fingerprint( + cfg, + method, + market_linear_noise_data=market_linear_noise_data, + ) results[method] = do_run_on_historic_data( - run_fingerprint=fp, params=params + run_fingerprint=fp, + params=params, + price_data=price_data, 
+ low_data_mode=low_data_mode, + ) + + return results + + +def _set_padded_ylim(ax, series_list, pad_ratio=0.04): + """Fit the y-axis tightly around the plotted series.""" + flat = [ + np.asarray(series, dtype=float).ravel() + for series in series_list + if np.asarray(series).size > 0 + ] + if not flat: + return + + values = np.concatenate(flat) + values = values[np.isfinite(values)] + if values.size == 0: + return + + ymin = float(values.min()) + ymax = float(values.max()) + if np.isclose(ymin, ymax): + pad = max(abs(ymin) * pad_ratio, 1e-6) + else: + pad = (ymax - ymin) * pad_ratio + ax.set_ylim(ymin - pad, ymax + pad) + + +def _cache_size(cache): + """Count memoized final-value cache entries materialised in memory.""" + return len(cache.get("_final_value_cache", {})) + + +def _comparison_cache_size(cache): + """Count memoized scalar comparison bundles.""" + return len(cache.get("_comparison_cache", {})) + + +def _heatmap_forward_cache_scope_slug(cfg): + """Build a compact cache scope slug for a shared-TVL heatmap run.""" + if cfg is None: + return "unspecified_tvl" + return f"tvl_{format_tvl_millions_slug(cfg)}" + + +def _heatmap_forward_cache_path(cfg): + """Return the parquet path for persisted scalar forward values.""" + if not HEATMAP_FORWARD_CACHE_ENABLED: + return None + return os.path.join( + HEATMAP_FORWARD_CACHE_ROOT, + HEATMAP_FORWARD_CACHE_RUN_NAME, + f"forward_values_{_heatmap_forward_cache_scope_slug(cfg)}.parquet", + ) + + +def _make_method_cache_hash(key): + """Build a compact stable digest for a method cache key.""" + return hashlib.sha256(repr(key).encode("utf-8")).hexdigest() + + +def _build_persistent_final_value_record(cfg, method, cache_key_hash, final_value): + """Build one self-describing parquet row for a cached scalar run result.""" + cfg = normalize_compare_run_cfg(cfg) + noise_cfg = resolve_reclamm_noise_settings(cfg) + return { + "cache_key_hash": str(cache_key_hash), + "final_value": float(final_value), + "method": str(method), + 
"enable_noise_model": bool(cfg.get("enable_noise_model", False)), + "noise_model": noise_cfg.get("noise_model"), + "price_ratio": float(cfg["price_ratio"]), + "centeredness_margin": float(cfg["centeredness_margin"]), + "daily_price_shift_exponent": float(cfg["daily_price_shift_exponent"]), + "initial_pool_value": float(get_initial_pool_value(cfg)), + "arb_frequency": get_effective_arb_frequency(cfg, noise_cfg), + } + + +def _load_persistent_final_value_cache(cache): + """Load persisted scalar forward values from parquet once per sweep cache.""" + if cache.get("_persistent_final_value_cache_loaded"): + return + + disk_cache = {} + next_batch_id = 0 + cache_path = cache.get("_persistent_final_value_cache_path") + if cache_path and os.path.exists(cache_path): + parquet_files = [] + if os.path.isdir(cache_path): + parquet_files = [ + os.path.join(cache_path, filename) + for filename in sorted(os.listdir(cache_path)) + if filename.endswith(".parquet") + ] + batch_ids = [] + for filename in os.listdir(cache_path): + if not (filename.startswith("batch_") and filename.endswith(".parquet")): + continue + token = filename[len("batch_") : -len(".parquet")] + if token.isdigit(): + batch_ids.append(int(token)) + next_batch_id = (max(batch_ids) + 1) if batch_ids else 0 + else: + parquet_files = [cache_path] + + for parquet_file in parquet_files: + frame = pd.read_parquet( + parquet_file, + columns=["cache_key_hash", "final_value"], + ) + if frame.empty: + continue + for row in frame.itertuples(index=False): + cache_key_hash = str(row.cache_key_hash) + final_value = float(row.final_value) + disk_cache[cache_key_hash] = final_value + print( + f"Loaded {len(disk_cache)} persisted heatmap forward values from {cache_path}" ) - # Geometric + centeredness-proportional scaling (scales decay duration) - fp_geo_scaled = make_fingerprint(cfg, "geometric", centeredness_scaling=True) - results["geometric_scaled"] = do_run_on_historic_data( - run_fingerprint=fp_geo_scaled, params=params + 
cache["_persistent_final_value_cache"] = disk_cache + cache["_persistent_final_value_next_batch_id"] = next_batch_id + cache["_persistent_final_value_cache_loaded"] = True + + +def flush_sweep_cache(cache, force=False): + """Persist newly computed scalar forward values to parquet.""" + if not HEATMAP_FORWARD_CACHE_ENABLED: + return + + pending = cache.get("_pending_persistent_final_values") + if not pending: + return + if not force and len(pending) < HEATMAP_FORWARD_CACHE_FLUSH_EVERY: + return + + _load_persistent_final_value_cache(cache) + disk_cache = cache.setdefault("_persistent_final_value_cache", {}) + batch_records = [] + for cache_key_hash, record in pending.items(): + normalized = dict(record) + normalized["cache_key_hash"] = str(cache_key_hash) + normalized["final_value"] = float(normalized["final_value"]) + disk_cache[cache_key_hash] = normalized["final_value"] + batch_records.append(normalized) + + cache_path = cache.get("_persistent_final_value_cache_path") + if cache_path is None: + pending.clear() + return + + if os.path.exists(cache_path) and not os.path.isdir(cache_path): + raise RuntimeError( + f"Persistent cache path {cache_path} already exists as a file. " + "Use a fresh cache namespace for append-only parquet shards." 
+ ) + + os.makedirs(cache_path, exist_ok=True) + batch_records.sort(key=lambda record: record["cache_key_hash"]) + payload = { + column: [record.get(column) for record in batch_records] + for column in PERSISTED_FORWARD_VALUE_COLUMNS + } + payload["final_value"] = np.asarray(payload["final_value"], dtype=np.float64) + frame = pd.DataFrame(payload) + batch_id = int(cache.setdefault("_persistent_final_value_next_batch_id", 0)) + batch_path = os.path.join(cache_path, f"batch_{batch_id:08d}.parquet") + cache["_persistent_final_value_next_batch_id"] = batch_id + 1 + frame.to_parquet(batch_path, index=False, compression="zstd") + print( + f"Persisted {len(pending)} new heatmap forward values to {batch_path} " + f"({len(disk_cache)} total cached values)." ) + pending.clear() + + +def make_sweep_cache( + price_data, + cache_scope_cfg=None, + market_linear_noise_data=None, +): + """Create a shared cache for heatmap and line sweeps.""" + cache = { + "_shared_price_data": price_data, + "_shared_market_linear_noise_data": market_linear_noise_data, + "_final_value_cache": {}, + "_comparison_cache": {}, + "_pending_persistent_final_values": {}, + "_persistent_final_value_cache": {}, + "_persistent_final_value_next_batch_id": 0, + "_persistent_final_value_cache_loaded": False, + "_persistent_final_value_cache_path": _heatmap_forward_cache_path( + cache_scope_cfg + ), + } + return cache + + +def _missing_artifacts(progress_label, filenames): + """Report which plot artifacts still need to be generated.""" + missing = [filename for filename in filenames if not os.path.exists(filename)] + if not missing: + print(f"[{progress_label}] skipping sweep: all artifacts already exist.") + return set() + + existing_count = len(filenames) - len(missing) + if existing_count: + print( + f"[{progress_label}] reusing {existing_count}/{len(filenames)} " + "existing artifacts; generating the missing outputs." 
+ ) + return set(missing) + + +def _speed_cache_key(speed): + """Stable cache token for optional arc-length speed.""" + if speed is None: + return None + return round(float(speed), 12) - # Arc-length + centeredness-proportional scaling (scales speed) - fp_cal_scaled = make_fingerprint(cfg, "constant_arc_length", centeredness_scaling=True) - results["cal_scaled"] = do_run_on_historic_data( - run_fingerprint=fp_cal_scaled, params=params + +def _make_method_cache_key(cfg, method): + """Cache key for a single-method final-value run.""" + cfg = normalize_compare_run_cfg(cfg) + noise_cfg = resolve_reclamm_noise_settings(cfg) + arb_frequency = get_effective_arb_frequency(cfg, noise_cfg) + key = ( + method, + tuple(str(token) for token in cfg["tokens"]), + str(cfg["start"]), + str(cfg["end"]), + round(float(cfg["fees"]), 12), + bool(cfg.get("enable_noise_model", False)), + round(float(cfg["price_ratio"]), 6), + round(float(cfg["centeredness_margin"]), 6), + round(float(cfg["daily_price_shift_exponent"]), 6), + round(get_initial_pool_value(cfg), 2), + noise_cfg.get("noise_cache_key"), + None if arb_frequency is None else int(arb_frequency), + round( + float( + cfg.get( + "gas_cost", + DEFAULT_GAS_COST if cfg.get("enable_noise_model", False) else 0.0, + ) + ), + 6, + ), + round( + float( + cfg.get( + "protocol_fee_split", + DEFAULT_PROTOCOL_FEE_SPLIT if cfg.get("enable_noise_model", False) else 0.0, + ) + ), + 6, + ), ) + if method == "constant_arc_length": + key += (_speed_cache_key(cfg.get("arc_length_speed")),) + return key + + +def _nearest_price_row(price_data, start_ts): + """Select the closest available price row to the requested start timestamp.""" + if len(price_data.index) == 0: + raise ValueError("price_data is empty") + + if isinstance(price_data.index, pd.DatetimeIndex): + target_ts = start_ts + index_tz = getattr(price_data.index, "tz", None) + if index_tz is not None and target_ts.tzinfo is None: + target_ts = target_ts.tz_localize(index_tz) + elif index_tz is 
None and target_ts.tzinfo is not None: + target_ts = target_ts.tz_convert(None) + target_value = int(target_ts.value) + index_values = price_data.index.asi8 + else: + target_value = int(start_ts.timestamp() * 1000.0) + index_values = price_data.index.to_numpy(dtype=np.int64) + + row_idx = int(np.searchsorted(index_values, target_value, side="left")) + if row_idx >= len(index_values): + row_idx = len(index_values) - 1 + elif row_idx > 0 and index_values[row_idx] != target_value: + prev_idx = row_idx - 1 + if abs(int(index_values[prev_idx]) - target_value) <= abs( + int(index_values[row_idx]) - target_value + ): + row_idx = prev_idx + + row = price_data.iloc[row_idx] + if isinstance(row, pd.DataFrame): + row = row.iloc[0] + return row + + +def _make_comparison_cache_key(cfg, launch_final_values): + """Cache key for scalar heatmap metrics at a single parameter point.""" + noise_cfg = make_noise_variant_cfg(cfg, True) + arb_only_cfg = make_noise_variant_cfg(cfg, False) + key = [ + _make_method_cache_key(noise_cfg, "geometric"), + _make_method_cache_key(arb_only_cfg, "geometric"), + round(float(launch_final_values["geometric"]), 6), + ] + if RUN_CONSTANT_ARC_LENGTH: + key.extend( + [ + _make_method_cache_key(noise_cfg, "constant_arc_length"), + _make_method_cache_key(arb_only_cfg, "constant_arc_length"), + round(float(launch_final_values["constant_arc_length"]), 6), + ] + ) + return tuple(key) + + +def _run_method_final_value_cached(cfg, method, cache): + """Memoize final value for a single interpolation method.""" + final_value_cache = cache.setdefault("_final_value_cache", {}) + key = _make_method_cache_key(cfg, method) + if key in final_value_cache: + return final_value_cache[key] + + _load_persistent_final_value_cache(cache) + key_hash = _make_method_cache_hash(key) + persisted_cache = cache.setdefault("_persistent_final_value_cache", {}) + if key_hash in persisted_cache: + final_value_cache[key] = persisted_cache[key_hash] + return final_value_cache[key] + + result 
= do_run_on_historic_data( + run_fingerprint=make_fingerprint( + cfg, + method, + market_linear_noise_data=cache.get("_shared_market_linear_noise_data"), + ), + params=make_params(cfg), + price_data=cache["_shared_price_data"], + low_data_mode=True, + ) + final_value_cache[key] = float(result["final_value"]) + cache.setdefault("_pending_persistent_final_values", {})[key_hash] = ( + _build_persistent_final_value_record( + cfg=cfg, + method=method, + cache_key_hash=key_hash, + final_value=final_value_cache[key], + ) + ) + flush_sweep_cache(cache, force=False) + del result + gc.collect() + return final_value_cache[key] + + +def extract_comparison_metrics_from_final_values( + geo_final, arc_final, launch_final_values +): + """Summarize scalar comparison metrics from final values only.""" + return { + "efficiency_pct": (arc_final / max(abs(geo_final), 1e-12) - 1.0) * 100.0, + "launch_geometric_efficiency_pct": ( + arc_final / max(abs(launch_final_values["geometric"]), 1e-12) - 1.0 + ) + * 100.0, + "geometric_vs_launch_geometric_pct": ( + geo_final / max(abs(launch_final_values["geometric"]), 1e-12) - 1.0 + ) + * 100.0, + "constant_arc_vs_launch_constant_arc_pct": ( + arc_final + / max(abs(launch_final_values["constant_arc_length"]), 1e-12) + - 1.0 + ) + * 100.0, + } + + +def _load_required_heatmap_final_values(cfg, cache, metric_keys): + """Load only the cached final values needed for the requested heatmap metrics.""" + required_sources = set() + for metric_key in metric_keys: + required_sources.update(HEATMAP_METRIC_DEPENDENCIES[metric_key]) + + if not RUN_CONSTANT_ARC_LENGTH and any( + source.endswith("constant_arc") for source in required_sources + ): + raise ValueError( + "Constant-arc heatmap metric requested while RUN_CONSTANT_ARC_LENGTH=False" + ) + + final_values = {} + noise_cfg = None + arb_only_cfg = None + + if any(source.startswith("noise_") for source in required_sources): + noise_cfg = make_noise_variant_cfg(cfg, True) + if any(source.startswith("arb_") 
for source in required_sources): + arb_only_cfg = make_noise_variant_cfg(cfg, False) + + if "noise_geometric" in required_sources: + final_values["noise_geometric"] = _run_method_final_value_cached( + noise_cfg, + "geometric", + cache, + ) + if "noise_constant_arc" in required_sources: + final_values["noise_constant_arc"] = _run_method_final_value_cached( + noise_cfg, + "constant_arc_length", + cache, + ) + if "arb_geometric" in required_sources: + final_values["arb_geometric"] = _run_method_final_value_cached( + arb_only_cfg, + "geometric", + cache, + ) + if "arb_constant_arc" in required_sources: + final_values["arb_constant_arc"] = _run_method_final_value_cached( + arb_only_cfg, + "constant_arc_length", + cache, + ) + return final_values + + +def extract_heatmap_metrics_from_mode_final_values( + metric_keys, + final_values, + launch_final_values, +): + """Collect the requested scalar heatmap metrics from cached final values.""" + metrics = {} + + if "efficiency_pct" in metric_keys: + metrics["efficiency_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(final_values["noise_geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "launch_geometric_efficiency_pct" in metric_keys: + metrics["launch_geometric_efficiency_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(launch_final_values["geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "geometric_vs_launch_geometric_pct" in metric_keys: + metrics["geometric_vs_launch_geometric_pct"] = ( + final_values["noise_geometric"] + / max(abs(launch_final_values["geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "constant_arc_vs_launch_constant_arc_pct" in metric_keys: + metrics["constant_arc_vs_launch_constant_arc_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(launch_final_values["constant_arc_length"]), 1e-12) + - 1.0 + ) * 100.0 + + if "noise_geometric_final_value_musd" in metric_keys: + metrics["noise_geometric_final_value_musd"] = ( + final_values["noise_geometric"] / 1e6 + ) + + if 
"noise_constant_arc_final_value_musd" in metric_keys: + metrics["noise_constant_arc_final_value_musd"] = ( + final_values["noise_constant_arc"] / 1e6 + ) + + if "noise_vs_arb_geometric_improvement_pct" in metric_keys: + metrics["noise_vs_arb_geometric_improvement_pct"] = ( + final_values["noise_geometric"] + / max(abs(final_values["arb_geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "noise_vs_arb_constant_arc_improvement_pct" in metric_keys: + metrics["noise_vs_arb_constant_arc_improvement_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(final_values["arb_constant_arc"]), 1e-12) + - 1.0 + ) * 100.0 + + return metrics + + +def extract_comparison_metrics(results, launch_final_values): + """Summarize scalar heatmap metrics for a pair of runs.""" + geo = results["geometric"] + arc = results["constant_arc_length"] + + geo_final = float(geo["final_value"]) + arc_final = float(arc["final_value"]) + + return extract_comparison_metrics_from_final_values( + geo_final, + arc_final, + launch_final_values=launch_final_values, + ) + + +def run_comparison_cached(cfg, cache, launch_final_values, metric_keys): + """Memoize scalar heatmap metrics across heatmap sweeps.""" + requested_metric_keys = tuple(dict.fromkeys(metric_keys)) + comparison_cache = cache.setdefault("_comparison_cache", {}) + cache_key = _make_comparison_cache_key(cfg, launch_final_values) + cached_metrics = comparison_cache.setdefault(cache_key, {}) + missing_metric_keys = [ + metric_key for metric_key in requested_metric_keys if metric_key not in cached_metrics + ] + if missing_metric_keys: + final_values = _load_required_heatmap_final_values( + cfg, + cache, + missing_metric_keys, + ) + cached_metrics.update( + extract_heatmap_metrics_from_mode_final_values( + missing_metric_keys, + final_values, + launch_final_values=launch_final_values, + ) + ) + return { + metric_key: cached_metrics[metric_key] for metric_key in requested_metric_keys + } + + +def build_heatmap_matrices( + x_values, + y_values, + 
x_key, + y_key, + base_cfg, + metric_keys, + cache, + progress_label, + launch_final_values, +): + """Evaluate multiple metrics over a 2D parameter grid in one pass.""" + data = { + metric_key: np.zeros((len(y_values), len(x_values)), dtype=float) + for metric_key in metric_keys + } + total_points = len(y_values) * len(x_values) + + print( + f"[{progress_label}] start: {len(y_values)} rows x {len(x_values)} cols " + f"= {total_points} parameter points" + ) + + for yi, y_value in enumerate(y_values): + final_cache_before_row = _cache_size(cache) + comparison_cache_before_row = _comparison_cache_size(cache) + for xi, x_value in enumerate(x_values): + cfg = dict(base_cfg) + cfg[x_key] = float(x_value) + cfg[y_key] = float(y_value) + metrics = run_comparison_cached( + cfg, + cache, + launch_final_values=launch_final_values, + metric_keys=metric_keys, + ) + for metric_key in metric_keys: + data[metric_key][yi, xi] = metrics[metric_key] + + completed_points = (yi + 1) * len(x_values) + row_new_final_entries = _cache_size(cache) - final_cache_before_row + row_new_comparisons = ( + _comparison_cache_size(cache) - comparison_cache_before_row + ) + row_pct = completed_points / total_points * 100.0 + flush_sweep_cache(cache, force=True) + print( + f"[{progress_label}] row {yi + 1}/{len(y_values)} complete " + f"({y_key}={float(y_value):.4f}, {completed_points}/{total_points} " + f"points, {row_pct:.1f}%, {row_new_final_entries} new final-value cache entries, " + f"{row_new_comparisons} new comparison bundles)" + ) + + print( + f"[{progress_label}] done: " + + ", ".join( + ( + f"{metric_key} min={float(np.nanmin(data[metric_key])):.4f}, " + f"max={float(np.nanmax(data[metric_key])):.4f}" + ) + for metric_key in metric_keys + ) + + ( + f", final_value_cache_size={_cache_size(cache)}, " + f"comparison_cache_size={_comparison_cache_size(cache)}" + ) + ) + + return data + + +def build_metric_curve( + x_values, + x_key, + base_cfg, + metric_key, + cache, + launch_final_values, +): 
+ """Evaluate one metric over a 1D sweep.""" + data = np.zeros(len(x_values), dtype=float) + for xi, x_value in enumerate(x_values): + cfg = dict(base_cfg) + cfg[x_key] = float(x_value) + metrics = run_comparison_cached( + cfg, + cache, + launch_final_values=launch_final_values, + metric_keys=(metric_key,), + ) + data[xi] = metrics[metric_key] + flush_sweep_cache(cache, force=True) + return data + + +def _compute_axis_edges(values, scale="linear"): + """Convert axis centers to cell edges for pcolormesh.""" + values = np.asarray(values, dtype=float) + if values.size == 1: + if scale == "log": + return np.array([values[0] / np.sqrt(10.0), values[0] * np.sqrt(10.0)]) + pad = max(abs(values[0]) * 0.5, 1.0) + return np.array([values[0] - pad, values[0] + pad]) + + if scale == "log": + log_values = np.log10(values) + edges = np.empty(values.size + 1, dtype=float) + edges[1:-1] = 0.5 * (log_values[:-1] + log_values[1:]) + edges[0] = log_values[0] - 0.5 * (log_values[1] - log_values[0]) + edges[-1] = log_values[-1] + 0.5 * (log_values[-1] - log_values[-2]) + return 10.0 ** edges + + edges = np.empty(values.size + 1, dtype=float) + edges[1:-1] = 0.5 * (values[:-1] + values[1:]) + edges[0] = values[0] - 0.5 * (values[1] - values[0]) + edges[-1] = values[-1] + 0.5 * (values[-1] - values[-2]) + return edges + + +def build_fixed_slice_variants(values): + """Pick four representative quarter-range slices from a sweep grid.""" + values = np.asarray(values, dtype=float) + if values.size < len(FIXED_SLICE_FRACTIONS): + raise ValueError("Need at least four grid points to build fixed slices") + + variants = [] + used_indices = set() + for idx, fraction in enumerate(FIXED_SLICE_FRACTIONS): + target_index = int(round(fraction * (values.size - 1))) + while target_index in used_indices and target_index + 1 < values.size: + target_index += 1 + while target_index in used_indices and target_index - 1 >= 0: + target_index -= 1 + if target_index in used_indices: + raise ValueError("Could not 
build four unique fixed slices from sweep grid") + used_indices.add(target_index) + variants.append( + { + "index": target_index, + "fraction": fraction, + "label": FIXED_SLICE_LABELS[idx], + "slug": f"q{idx + 1}", + "value": float(values[target_index]), + } + ) + return variants + + +def _pair_slice_suffix(pair, slice_variant): + """Build a stable artifact suffix for a pairwise fixed-variable slice.""" + return f"{pair['slug']}_{pair['fixed_slug']}_{slice_variant['slug']}" + + +def _build_heatmap_norm( + data_arrays, + center_zero, + color_norm=None, + symlog_linthresh=None, +): + """Build a color normalizer shared by 2D and 3D heatmaps.""" + finite_parts = [] + for data in data_arrays: + finite = np.asarray(data, dtype=float) + finite = finite[np.isfinite(finite)] + if finite.size: + finite_parts.append(finite) + finite = np.concatenate(finite_parts) if finite_parts else np.array([], dtype=float) + + if center_zero: + if finite.size == 0: + vmax = 1.0 + else: + vmax = max(abs(float(finite.min())), abs(float(finite.max())), 1e-9) + if ( + color_norm == "symlog" + and symlog_linthresh is not None + and vmax > symlog_linthresh + ): + return SymLogNorm( + linthresh=symlog_linthresh, + linscale=1.0, + vmin=-vmax, + vmax=vmax, + base=10.0, + ) + return TwoSlopeNorm(vcenter=0.0, vmin=-vmax, vmax=vmax) + + if finite.size == 0: + vmin, vmax = 0.0, 1.0 + else: + vmin = float(finite.min()) + vmax = float(finite.max()) + if np.isclose(vmin, vmax): + pad = max(abs(vmin) * 0.01, 1e-9) + vmin -= pad + vmax += pad + return Normalize(vmin=vmin, vmax=vmax) + + +def get_pair_heatmap_metric_specs(): + """Return the standard thermostat pairwise heatmap metrics.""" + metric_specs = [ + { + "key": "efficiency_pct", + "title": "Efficiency vs heatmap geometric", + "colorbar_label": "Const Arc - heatmap Geo (% of heatmap geometric final value)", + "slug": "efficiency", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "launch_geometric_efficiency_pct", + "title": "Efficiency vs 
launch-style geometric", + "colorbar_label": "Const Arc - launch Geo (% of launch geometric final value)", + "slug": "launch_geometric_efficiency", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "geometric_vs_launch_geometric_pct", + "title": "Geometric tuning vs launch-style geometric", + "colorbar_label": "Candidate Geo - launch Geo (% of launch geometric final value)", + "slug": "geometric_vs_launch_geometric", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "constant_arc_vs_launch_constant_arc_pct", + "title": "Const arc tuning vs launch-style const arc", + "colorbar_label": "Candidate Const Arc - launch Const Arc (% of launch const arc final value)", + "slug": "constant_arc_vs_launch_constant_arc", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "noise_geometric_final_value_musd", + "title": "Geometric final value with noise model", + "colorbar_label": "Geometric final value with noise model ($M)", + "slug": "noise_geometric_final_value", + "center_zero": False, + "cmap": "viridis", + }, + { + "key": "noise_constant_arc_final_value_musd", + "title": "Const arc final value with noise model", + "colorbar_label": "Const Arc final value with noise model ($M)", + "slug": "noise_constant_arc_final_value", + "center_zero": False, + "cmap": "viridis", + }, + { + "key": "noise_vs_arb_geometric_improvement_pct", + "title": "Noise-model improvement over arb-only (geometric)", + "colorbar_label": "Noise-model Geo - arb-only Geo (% of arb-only final value)", + "slug": "noise_vs_arb_geometric_improvement", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "noise_vs_arb_constant_arc_improvement_pct", + "title": "Noise-model improvement over arb-only (const arc)", + "colorbar_label": "Noise-model Const Arc - arb-only Const Arc (% of arb-only final value)", + "slug": "noise_vs_arb_constant_arc_improvement", + "center_zero": True, + "cmap": "RdYlGn", + }, + ] + for spec in metric_specs: + if spec["center_zero"]: + 
spec["color_norm"] = CENTER_ZERO_HEATMAP_COLOR_NORM + spec["symlog_linthresh"] = CENTER_ZERO_HEATMAP_SYMLOG_LINTHRESH + spec["artifact_tag"] = CENTER_ZERO_HEATMAP_COLOR_TAG + if not RUN_CONSTANT_ARC_LENGTH: + metric_specs = [ + spec + for spec in metric_specs + if spec["key"] in GEOMETRIC_ONLY_HEATMAP_METRIC_KEYS + ] + return metric_specs + + +def get_pair_heatmap_specs(base_cfg): + """Return the three pairwise thermostat heatmap families plus slice settings.""" + fixed_slice_variants = { + "price_ratio": build_fixed_slice_variants(HEATMAP_PRICE_RATIOS), + "centeredness_margin": build_fixed_slice_variants(HEATMAP_MARGINS), + "daily_price_shift_exponent": build_fixed_slice_variants( + HEATMAP_SHIFT_EXPONENTS + ), + } + return [ + { + "slug": "price_ratio_vs_margin", + "x_values": HEATMAP_PRICE_RATIOS, + "y_values": HEATMAP_MARGINS, + "x_key": "price_ratio", + "y_key": "centeredness_margin", + "x_label": "Price ratio", + "y_label": "Centeredness margin", + "xticks": PRICE_RATIO_TICKS, + "yticks": MARGIN_TICKS, + "fixed_key": "daily_price_shift_exponent", + "fixed_label": "Shift exponent", + "fixed_slug": "shift_exp", + "fixed_slices": fixed_slice_variants["daily_price_shift_exponent"], + }, + { + "slug": "shift_exp_vs_margin", + "x_values": HEATMAP_SHIFT_EXPONENTS, + "y_values": HEATMAP_MARGINS, + "x_key": "daily_price_shift_exponent", + "y_key": "centeredness_margin", + "x_label": "Shift exponent", + "y_label": "Centeredness margin", + "xticks": SHIFT_EXPONENT_TICKS, + "yticks": MARGIN_TICKS, + "fixed_key": "price_ratio", + "fixed_label": "Price ratio", + "fixed_slug": "price_ratio", + "fixed_slices": fixed_slice_variants["price_ratio"], + }, + { + "slug": "price_ratio_vs_shift_exp", + "x_values": HEATMAP_PRICE_RATIOS, + "y_values": HEATMAP_SHIFT_EXPONENTS, + "x_key": "price_ratio", + "y_key": "daily_price_shift_exponent", + "x_label": "Price ratio", + "y_label": "Shift exponent", + "xticks": PRICE_RATIO_TICKS, + "yticks": SHIFT_EXPONENT_TICKS, + "fixed_key": 
"centeredness_margin", + "fixed_label": "Centeredness margin", + "fixed_slug": "margin", + "fixed_slices": fixed_slice_variants["centeredness_margin"], + }, + ] + + +def plot_heatmap( + data, + x_values, + y_values, + x_label, + y_label, + title, + colorbar_label, + filename, + xticks=None, + yticks=None, + xscale="linear", + center_zero=True, + cmap=None, + color_norm=None, + symlog_linthresh=None, +): + """Render and save a single heatmap.""" + norm = _build_heatmap_norm( + [data], + center_zero=center_zero, + color_norm=color_norm, + symlog_linthresh=symlog_linthresh, + ) + cmap_name = cmap or ("RdYlGn" if center_zero else "viridis") + + x_edges = _compute_axis_edges(x_values, scale=xscale) + y_edges = _compute_axis_edges(y_values, scale="linear") + + fig, ax = plt.subplots(figsize=(8.5, 6.0)) + im = ax.pcolormesh( + x_edges, + y_edges, + data, + cmap=cmap_name, + norm=norm, + shading="auto", + ) + + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_title(title) + if xscale == "log": + ax.set_xscale("log") + ax.set_xticks(np.asarray(xticks if xticks is not None else x_values, dtype=float)) + ax.set_yticks(np.asarray(yticks if yticks is not None else y_values, dtype=float)) + ax.grid(False) + + cbar = fig.colorbar(im, ax=ax) + cbar.set_label(colorbar_label) + + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def plot_three_variable_heatmap_3d( + price_margin_data, + shift_margin_data, + price_shift_data, + fixed_price_ratio, + fixed_margin, + fixed_shift_exponent, + title, + colorbar_label, + filename, + center_zero=True, + cmap=None, + color_norm=None, + symlog_linthresh=None, +): + """Render orthogonal 3D heatmap surfaces across the three thermostat variables.""" + norm = _build_heatmap_norm( + [price_margin_data, shift_margin_data, price_shift_data], + center_zero=center_zero, + color_norm=color_norm, + symlog_linthresh=symlog_linthresh, + ) + cmap_name = cmap or ("RdYlGn" if center_zero else 
"viridis") + cmap_obj = plt.get_cmap(cmap_name) + + price_margin_x, price_margin_y = np.meshgrid(HEATMAP_PRICE_RATIOS, HEATMAP_MARGINS) + price_margin_z = np.full_like(price_margin_x, fixed_shift_exponent, dtype=float) + + shift_margin_z, shift_margin_y = np.meshgrid( + HEATMAP_SHIFT_EXPONENTS, + HEATMAP_MARGINS, + ) + shift_margin_x = np.full_like(shift_margin_z, fixed_price_ratio, dtype=float) + + price_shift_x, price_shift_z = np.meshgrid( + HEATMAP_PRICE_RATIOS, + HEATMAP_SHIFT_EXPONENTS, + ) + price_shift_y = np.full_like(price_shift_x, fixed_margin, dtype=float) + + fig = plt.figure(figsize=(10.5, 7.2)) + ax = fig.add_subplot(111, projection="3d") + ax.set_facecolor("white") + fig.patch.set_facecolor("white") + + ax.plot_surface( + price_margin_x, + price_margin_y, + price_margin_z, + facecolors=cmap_obj(norm(np.asarray(price_margin_data, dtype=float))), + shade=False, + ) + ax.plot_surface( + shift_margin_x, + shift_margin_y, + shift_margin_z, + facecolors=cmap_obj(norm(np.asarray(shift_margin_data, dtype=float))), + shade=False, + ) + ax.plot_surface( + price_shift_x, + price_shift_y, + price_shift_z, + facecolors=cmap_obj(norm(np.asarray(price_shift_data, dtype=float))), + shade=False, + ) + + ax.set_xlim(float(HEATMAP_PRICE_RATIOS.min()), float(HEATMAP_PRICE_RATIOS.max())) + ax.set_ylim(float(HEATMAP_MARGINS.min()), float(HEATMAP_MARGINS.max())) + ax.set_zlim( + float(HEATMAP_SHIFT_EXPONENTS.min()), + float(HEATMAP_SHIFT_EXPONENTS.max()), + ) + ax.set_xlabel("Price ratio") + ax.set_ylabel("Centeredness margin") + ax.set_zlabel("Shift exponent") + ax.set_xticks(PRICE_RATIO_TICKS) + ax.set_yticks(MARGIN_TICKS[::2]) + ax.set_zticks(SHIFT_EXPONENT_TICKS) + ax.set_title(title) + ax.grid(False) + ax.view_init(elev=THREE_D_VIEW_ELEVATION, azim=THREE_D_VIEW_AZIMUTH) + try: + ax.set_box_aspect( + ( + float(HEATMAP_PRICE_RATIOS.max() - HEATMAP_PRICE_RATIOS.min()), + float(HEATMAP_MARGINS.max() - HEATMAP_MARGINS.min()), + float( + HEATMAP_SHIFT_EXPONENTS.max() - 
HEATMAP_SHIFT_EXPONENTS.min() + ), + ) + ) + except AttributeError: + pass + + sm = ScalarMappable(norm=norm, cmap=cmap_obj) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=ax, fraction=0.03, pad=0.1, shrink=0.82) + cbar.set_label(colorbar_label) + + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def plot_arc_speed_line_chart( + data, + x_values, + y_values, + y_label, + title, + filename, + launch_curve, + launch_auto_speed=None, +): + """Plot thin multi-series efficiency lines over the arc-speed sweep.""" + fig, ax = plt.subplots(figsize=(10.5, 5.75)) + cmap = plt.cm.viridis + colors = cmap(np.linspace(0.0, 1.0, len(y_values))) + plotted_series = [] + + for yi, (y_value, color) in enumerate(zip(y_values, colors)): + series = np.asarray(data[yi], dtype=float) + plotted_series.append(series) + ax.plot( + x_values, + series, + color=color, + linewidth=SWEEP_LINE_WIDTH, + alpha=0.8, + ) + + launch_curve = np.asarray(launch_curve, dtype=float) + plotted_series.append(launch_curve) + ax.plot( + x_values, + launch_curve, + color="black", + linewidth=REFERENCE_LINE_WIDTH, + alpha=0.9, + label="Current launch config", + ) + if launch_auto_speed is not None: + ax.axvline( + float(launch_auto_speed), + color="black", + ls=":", + linewidth=0.8, + alpha=0.7, + label="Launch auto-cal speed", + ) + + ax.axhline(0.0, color="gray", ls="--", linewidth=0.8, alpha=0.5) + ax.set_xscale("log") + ax.set_xticks(ARC_LENGTH_SPEED_TICKS) + ax.set_xlabel("Arc-length speed") + ax.set_ylabel("Efficiency vs geometric (%)") + ax.set_title(title) + _set_padded_ylim(ax, plotted_series, pad_ratio=0.08) + ax.grid(True, alpha=0.25) + ax.legend(fontsize=8) + + sm = ScalarMappable( + norm=Normalize(vmin=float(np.min(y_values)), vmax=float(np.max(y_values))), + cmap=cmap, + ) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=ax) + cbar.set_label(y_label) + + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + 
plt.close(fig) + + +def generate_heatmaps(base_cfg, price_data, launch_final_values, cache=None): + """Generate pairwise heatmaps for thermostat tuning and noise-vs-arb effects.""" + owns_cache = cache is None + if cache is None: + cache = make_sweep_cache(price_data, cache_scope_cfg=base_cfg) + metric_specs = get_pair_heatmap_metric_specs() + pair_specs = get_pair_heatmap_specs(base_cfg) + metric_spec_map = {spec["key"]: spec for spec in metric_specs} + slice_count = len(pair_specs[0]["fixed_slices"]) if pair_specs else 0 + + if RUN_CONSTANT_ARC_LENGTH: + print( + "Using launch-style benchmarks " + f"Geo=${launch_final_values['geometric']:,.0f}, " + f"Const Arc=${launch_final_values['constant_arc_length']:,.0f}, " + f"TVL={format_tvl_millions_label(base_cfg)}." + ) + print( + "Running {count} heatmap pair sweeps sequentially " + "(3 pair grids x {slice_count} fixed-variable quarter slices; " + "cached noise-model runs are reused across the absolute, launch, " + "and arb-only comparison outputs).".format( + count=len(pair_specs) * slice_count, + slice_count=slice_count, + ) + ) + else: + print( + "Using launch-style geometric benchmark " + f"Geo=${launch_final_values['geometric']:,.0f}, " + f"TVL={format_tvl_millions_label(base_cfg)}." + ) + print( + "RUN_CONSTANT_ARC_LENGTH=False, so only geometric heatmaps will be generated " + f"across {len(pair_specs) * slice_count} fixed-variable pair sweeps." 
+ ) + + for pair in pair_specs: + for slice_variant in pair["fixed_slices"]: + pair_suffix = _pair_slice_suffix(pair, slice_variant) + slice_cfg = dict(base_cfg) + slice_cfg[pair["fixed_key"]] = float(slice_variant["value"]) + output_files = { + spec["key"]: heatmap_artifact_filename( + spec, + base_cfg, + suffix=pair_suffix, + ) + for spec in metric_specs + } + missing_files = _missing_artifacts( + pair_suffix, + list(output_files.values()), + ) + if not missing_files: + continue + + missing_metric_keys = [ + spec["key"] + for spec in metric_specs + if output_files[spec["key"]] in missing_files + ] + data_by_metric = build_heatmap_matrices( + x_values=pair["x_values"], + y_values=pair["y_values"], + x_key=pair["x_key"], + y_key=pair["y_key"], + base_cfg=slice_cfg, + metric_keys=missing_metric_keys, + cache=cache, + progress_label=pair_suffix, + launch_final_values=launch_final_values, + ) + print(f"[{pair_suffix}] plotting missing heatmaps...") + for metric_key in missing_metric_keys: + spec = metric_spec_map[metric_key] + plot_heatmap( + data=data_by_metric[metric_key], + x_values=pair["x_values"], + y_values=pair["y_values"], + x_label=pair["x_label"], + y_label=pair["y_label"], + title=( + f"{spec['title']}: {pair['fixed_label']} {slice_variant['label']} " + f"slice fixed at {format_heatmap_param_value(slice_variant['value'])} | " + f"TVL {format_tvl_millions_label(base_cfg)}" + ), + colorbar_label=spec["colorbar_label"], + filename=output_files[metric_key], + xticks=pair["xticks"], + yticks=pair["yticks"], + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + del data_by_metric + gc.collect() + + if owns_cache: + flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print("Released heatmap metric cache.") + + +def generate_three_variable_3d_heatmaps( + base_cfg, + price_data, + launch_final_values, + cache=None, +): + """Render 3D thermostat heatmaps 
from the three pairwise quarter slices.""" + owns_cache = cache is None + if cache is None: + cache = make_sweep_cache(price_data, cache_scope_cfg=base_cfg) + + metric_specs = get_pair_heatmap_metric_specs() + metric_spec_map = {spec["key"]: spec for spec in metric_specs} + pair_specs = get_pair_heatmap_specs(base_cfg) + pair_by_fixed_key = {pair["fixed_key"]: pair for pair in pair_specs} + price_margin_pair = pair_by_fixed_key["daily_price_shift_exponent"] + shift_margin_pair = pair_by_fixed_key["price_ratio"] + price_shift_pair = pair_by_fixed_key["centeredness_margin"] + slice_count = len(price_margin_pair["fixed_slices"]) + + def build_pair_slice_data(pair, slice_variant, metric_keys): + pair_cfg = dict(base_cfg) + pair_cfg[pair["fixed_key"]] = float(slice_variant["value"]) + return build_heatmap_matrices( + x_values=pair["x_values"], + y_values=pair["y_values"], + x_key=pair["x_key"], + y_key=pair["y_key"], + base_cfg=pair_cfg, + metric_keys=metric_keys, + cache=cache, + progress_label=f"3d_{_pair_slice_suffix(pair, slice_variant)}", + launch_final_values=launch_final_values, + ) + + print( + "\nGenerating 3D thermostat heatmaps " + f"({slice_count} quarter-slice variants, TVL={format_tvl_millions_label(base_cfg)})..." 
+ ) + + for slice_idx in range(slice_count): + shift_slice = price_margin_pair["fixed_slices"][slice_idx] + price_slice = shift_margin_pair["fixed_slices"][slice_idx] + margin_slice = price_shift_pair["fixed_slices"][slice_idx] + slice_slug = shift_slice["slug"] + slice_label = shift_slice["label"] + + output_files = { + spec["key"]: three_d_heatmap_artifact_filename( + spec, + base_cfg, + suffix=f"slice_{slice_slug}", + ) + for spec in metric_specs + } + missing_files = _missing_artifacts( + f"3d_slice_{slice_slug}", + list(output_files.values()), + ) + if not missing_files: + continue + + missing_metric_keys = [ + spec["key"] + for spec in metric_specs + if output_files[spec["key"]] in missing_files + ] + price_margin_data = build_pair_slice_data( + price_margin_pair, + shift_slice, + missing_metric_keys, + ) + shift_margin_data = build_pair_slice_data( + shift_margin_pair, + price_slice, + missing_metric_keys, + ) + price_shift_data = build_pair_slice_data( + price_shift_pair, + margin_slice, + missing_metric_keys, + ) + + for metric_key in missing_metric_keys: + spec = metric_spec_map[metric_key] + plot_three_variable_heatmap_3d( + price_margin_data=price_margin_data[metric_key], + shift_margin_data=shift_margin_data[metric_key], + price_shift_data=price_shift_data[metric_key], + fixed_price_ratio=float(price_slice["value"]), + fixed_margin=float(margin_slice["value"]), + fixed_shift_exponent=float(shift_slice["value"]), + title=( + f"{spec['title']} 3D {slice_label} slice | TVL {format_tvl_millions_label(base_cfg)}\n" + f"price_ratio={format_heatmap_param_value(price_slice['value'])}, " + f"margin={format_heatmap_param_value(margin_slice['value'])}, " + f"shift_exp={format_heatmap_param_value(shift_slice['value'])}" + ), + colorbar_label=spec["colorbar_label"], + filename=output_files[metric_key], + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + + del 
price_margin_data, shift_margin_data, price_shift_data + gc.collect() + + if owns_cache: + flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print("Released 3D heatmap cache.") + + +def compute_auto_calibrated_arc_length_speed(cfg, price_data): + """Compute the launch/reference auto-calibrated speed for a config.""" + start_ts = pd.Timestamp(cfg["start"]) + row = _nearest_price_row(price_data, start_ts) + + if isinstance(price_data.columns, pd.MultiIndex): + initial_price_values = [ + float(row[(token, "close")]) + for token in cfg["tokens"] + ] + else: + initial_price_values = [ + float(row[f"close_{token}"]) + for token in cfg["tokens"] + ] + + initial_prices = jnp.array(initial_price_values, dtype=jnp.float64) + initial_reserves, Va, Vb = initialise_reclamm_reserves( + get_initial_pool_value(cfg), + initial_prices, + float(cfg["price_ratio"]), + ) + market_price_0 = float(initial_prices[0] / initial_prices[1]) + sqrt_Q = jnp.sqrt( + compute_price_ratio( + initial_reserves[0], + initial_reserves[1], + Va, + Vb, + ) + ) + return float( + calibrate_arc_length_speed( + initial_reserves[0], + initial_reserves[1], + Va, + Vb, + to_daily_price_shift_base(float(cfg["daily_price_shift_exponent"])), + 60.0, + sqrt_Q, + market_price_0, + centeredness_margin=float(cfg["centeredness_margin"]), + ) + ) + + +def generate_arc_speed_efficiency_artifacts( + base_cfg, + launch_cfg, + price_data, + launch_final_values, + cache=None, +): + """Generate arc-speed heatmaps plus the existing efficiency line charts.""" + if not RUN_CONSTANT_ARC_LENGTH: + print("\nSkipping arc-speed heatmaps because RUN_CONSTANT_ARC_LENGTH=False.") + return + owns_cache = cache is None + if cache is None: + cache = make_sweep_cache(price_data, cache_scope_cfg=base_cfg) + launch_auto_speed = compute_auto_calibrated_arc_length_speed(launch_cfg, price_data) + heatmap_metric_specs = [ + { + "key": "efficiency_pct", + "title": "Efficiency vs geometric", + "colorbar_label": "Const Arc - 
heatmap Geo (% of heatmap geometric final value)", + "slug": "efficiency", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "noise_constant_arc_final_value_musd", + "title": "Const arc final value with noise model", + "colorbar_label": "Const Arc final value with noise model ($M)", + "slug": "noise_constant_arc_final_value", + "center_zero": False, + "cmap": "viridis", + }, + { + "key": "noise_vs_arb_constant_arc_improvement_pct", + "title": "Noise-model improvement over arb-only (const arc)", + "colorbar_label": "Noise-model Const Arc - arb-only Const Arc (% of arb-only final value)", + "slug": "noise_vs_arb_constant_arc_improvement", + "center_zero": True, + "cmap": "RdYlGn", + }, + ] + for spec in heatmap_metric_specs: + if spec["center_zero"]: + spec["color_norm"] = CENTER_ZERO_HEATMAP_COLOR_NORM + spec["symlog_linthresh"] = CENTER_ZERO_HEATMAP_SYMLOG_LINTHRESH + spec["artifact_tag"] = CENTER_ZERO_HEATMAP_COLOR_TAG + pair_specs = [ + { + "slug": "arc_speed_vs_price_ratio", + "x_values": HEATMAP_ARC_LENGTH_SPEEDS, + "y_values": HEATMAP_PRICE_RATIOS, + "x_key": "arc_length_speed", + "y_key": "price_ratio", + "x_label": "Arc-length speed", + "y_label": "Price ratio", + "title_suffix": ( + f"margin fixed at {base_cfg['centeredness_margin']:.2f}, " + f"shift_exp fixed at {base_cfg['daily_price_shift_exponent']:.2f}" + ), + "xticks": ARC_LENGTH_SPEED_TICKS, + "yticks": PRICE_RATIO_TICKS, + }, + { + "slug": "arc_speed_vs_margin", + "x_values": HEATMAP_ARC_LENGTH_SPEEDS, + "y_values": HEATMAP_MARGINS, + "x_key": "arc_length_speed", + "y_key": "centeredness_margin", + "x_label": "Arc-length speed", + "y_label": "Centeredness margin", + "title_suffix": ( + f"price_ratio fixed at {base_cfg['price_ratio']:.2f}, " + f"shift_exp fixed at {base_cfg['daily_price_shift_exponent']:.2f}" + ), + "xticks": ARC_LENGTH_SPEED_TICKS, + "yticks": MARGIN_TICKS + }, + { + "slug": "arc_speed_vs_shift_exp", + "x_values": HEATMAP_ARC_LENGTH_SPEEDS, + "y_values": 
HEATMAP_SHIFT_EXPONENTS, + "x_key": "arc_length_speed", + "y_key": "daily_price_shift_exponent", + "x_label": "Arc-length speed", + "y_label": "Shift exponent", + "title_suffix": ( + f"price_ratio fixed at {base_cfg['price_ratio']:.2f}, " + f"margin fixed at {base_cfg['centeredness_margin']:.2f}" + ), + "xticks": ARC_LENGTH_SPEED_TICKS, + "yticks": SHIFT_EXPONENT_TICKS, + }, + ] + metric_spec_map = {spec["key"]: spec for spec in heatmap_metric_specs} + + print( + "\nGenerating arc-speed heatmaps and line charts " + f"(launch auto-cal speed={launch_auto_speed:.3e}, TVL={format_tvl_millions_label(base_cfg)})..." + ) + + for pair in pair_specs: + heatmap_files = { + spec["key"]: heatmap_artifact_filename( + spec, + base_cfg, + suffix=pair["slug"], + ) + for spec in heatmap_metric_specs + } + line_filename = tvl_artifact_filename( + "reclamm_line_efficiency", + base_cfg, + suffix=pair["slug"], + ) + missing_files = _missing_artifacts( + pair["slug"], + list(heatmap_files.values()) + [line_filename], + ) + if not missing_files: + continue + + missing_metric_keys = [ + spec["key"] + for spec in heatmap_metric_specs + if heatmap_files[spec["key"]] in missing_files + ] + if line_filename in missing_files and "efficiency_pct" not in missing_metric_keys: + missing_metric_keys.append("efficiency_pct") + + data_by_metric = build_heatmap_matrices( + x_values=pair["x_values"], + y_values=pair["y_values"], + x_key=pair["x_key"], + y_key=pair["y_key"], + base_cfg=base_cfg, + metric_keys=missing_metric_keys, + cache=cache, + progress_label=pair["slug"], + launch_final_values=launch_final_values, + ) + for metric_key in missing_metric_keys: + if metric_key not in heatmap_files: + continue + if heatmap_files[metric_key] not in missing_files: + continue + spec = metric_spec_map[metric_key] + plot_heatmap( + data=data_by_metric[metric_key], + x_values=pair["x_values"], + y_values=pair["y_values"], + x_label=pair["x_label"], + y_label=pair["y_label"], + title=( + f"{spec['title']}: 
{pair['title_suffix']} | " + f"TVL {format_tvl_millions_label(base_cfg)}" + ), + colorbar_label=spec["colorbar_label"], + filename=heatmap_files[metric_key], + xticks=pair["xticks"], + yticks=pair["yticks"], + xscale="log", + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + + if line_filename in missing_files: + efficiency_data = data_by_metric["efficiency_pct"] + launch_curve = build_metric_curve( + x_values=pair["x_values"], + x_key=pair["x_key"], + base_cfg=launch_cfg, + metric_key="efficiency_pct", + cache=cache, + launch_final_values=launch_final_values, + ) + plot_arc_speed_line_chart( + data=efficiency_data, + x_values=pair["x_values"], + y_values=pair["y_values"], + y_label=pair["y_label"], + title=( + "Arc-speed efficiency sweep: " + f"{pair['title_suffix']} | TVL {format_tvl_millions_label(base_cfg)}" + ), + filename=line_filename, + launch_curve=launch_curve, + launch_auto_speed=launch_auto_speed, + ) + del data_by_metric + gc.collect() + + if owns_cache: + flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print("Released arc-speed sweep cache.") + + +def get_launch_final_values( + all_results, + launch_cfg, + price_data, + market_linear_noise_data=None, +): + """Reuse launch-style runs when available; otherwise run them once.""" + for cfg, results in all_results: + if cfg["name"] == launch_cfg["name"]: + launch_final_values = { + "geometric": float(results["geometric"]["final_value"]), + } + if "constant_arc_length" in results: + launch_final_values["constant_arc_length"] = float( + results["constant_arc_length"]["final_value"] + ) + return launch_final_values + + print("\nRunning launch-style benchmarks for heatmaps...") + launch_results = run_comparison( + launch_cfg, + price_data=price_data, + low_data_mode=True, + market_linear_noise_data=market_linear_noise_data, + ) + launch_final_values = { + "geometric": 
float(launch_results["geometric"]["final_value"]), + } + if "constant_arc_length" in launch_results: + launch_final_values["constant_arc_length"] = float( + launch_results["constant_arc_length"]["final_value"] + ) + del launch_results + gc.collect() + return launch_final_values - return results def print_comparison(cfg, results): """Print text summary table.""" - methods = [ - ("Geometric", results["geometric"]), - ("Geo+Scaled", results["geometric_scaled"]), - ("Const Arc", results["constant_arc_length"]), - ("Arc+Scaled", results["cal_scaled"]), - ] + methods = [("Geometric", results["geometric"])] + has_constant_arc = "constant_arc_length" in results + if has_constant_arc: + methods.append(("Const Arc", results["constant_arc_length"])) + noise_cfg = resolve_reclamm_noise_settings(cfg) hodl_value = float((methods[0][1]["reserves"][0] * methods[0][1]["prices"][-1]).sum()) @@ -128,6 +2217,18 @@ def print_comparison(cfg, results): f"margin={cfg['centeredness_margin']}, " f"shift_exp={cfg['daily_price_shift_exponent']}, " f"fees={cfg['fees']}") + print( + f" base_tvl=${get_initial_pool_value(cfg):,.0f} " + f"(TVL {format_tvl_millions_label(cfg)})" + ) + print(f" note={cfg['reason']}") + print( + f" noise={noise_cfg['noise_summary']}, " + f"gas={cfg.get('gas_cost', 0.0)}, " + f"protocol_fee_split={cfg.get('protocol_fee_split', 0.0)}" + ) + if not has_constant_arc: + print(" constant_arc=disabled") print("-" * 105) header = " {:20s}".format("") for name, _ in methods: @@ -158,18 +2259,27 @@ def print_comparison(cfg, results): vs = (float(r["final_value"]) / hodl_value - 1) * 100 row += f" {vs:>13.2f}%" print(row) + + if has_constant_arc: + geo_final = float(results["geometric"]["final_value"]) + arc_final = float(results["constant_arc_length"]["final_value"]) + geo_lvr = hodl_value - geo_final + arc_lvr = hodl_value - arc_final + print(f" {'Const Arc - Geo':20s} ${arc_final - geo_final:>13,.0f}") + print(f" {'LVR saved vs Geo':20s} ${geo_lvr - arc_lvr:>13,.0f}") 
print("=" * 105) + def plot_comparison(cfg, results, fig_idx): - """Plot 4-panel comparison for one config.""" - # Method name → (result dict, color, linestyle) + """Plot comparison diagnostics for one config.""" + tvl_label = format_tvl_millions_label(cfg) variants = { "Geometric": (results["geometric"], "C0", "-"), - "Geo+Scaled": (results["geometric_scaled"], "C1", "-"), - "Const arc-len": (results["constant_arc_length"], "C2", "--"), - "Arc+Scaled": (results["cal_scaled"], "C3", "--"), } + has_constant_arc = "constant_arc_length" in results + if has_constant_arc: + variants["Const arc-len"] = (results["constant_arc_length"], "C2", "--") geo = results["geometric"] geo_prices = np.array(geo["prices"]) @@ -181,22 +2291,21 @@ def plot_comparison(cfg, results, fig_idx): price_ratio_traj = geo_prices[:n_steps, 0] / geo_prices[:n_steps, 1] fig, axes = plt.subplots(2, 2, figsize=(14, 10)) - fig.suptitle(cfg["name"], fontsize=13, fontweight="bold") + fig.suptitle(f"{cfg['name']} — TVL {tvl_label}", fontsize=13, fontweight="bold") - # (0,0) Pool value over time ax = axes[0, 0] + plotted_values = [] for name, (r, color, ls) in variants.items(): vals = np.array(r["value"]) + plotted_values.append(vals / 1e6) ax.plot(t_days, vals / 1e6, color=color, ls=ls, label=name, alpha=0.9) - ax.plot(t_days, np.array(hodl_traj) / 1e6, color="gray", ls=":", - alpha=0.5, label="HODL") + _set_padded_ylim(ax, plotted_values, pad_ratio=0.03) ax.set_xlabel("Days") ax.set_ylabel("Pool value ($M)") ax.set_title("Pool value") ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - # (0,1) Cumulative LVR ax = axes[0, 1] for name, (r, color, ls) in variants.items(): vals = np.array(r["value"]) @@ -208,7 +2317,6 @@ def plot_comparison(cfg, results, fig_idx): ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - # (1,0) Price ratio ax = axes[1, 0] ax.plot(t_days, price_ratio_traj, color="C4", alpha=0.7) ax.set_xlabel("Days") @@ -216,7 +2324,6 @@ def plot_comparison(cfg, results, fig_idx): ax.set_title("Price 
path") ax.grid(True, alpha=0.3) - # (1,1) Empirical weights ax = axes[1, 1] for name, (r, color, ls) in variants.items(): w = np.array(r["weights"]) @@ -230,19 +2337,21 @@ def plot_comparison(cfg, results, fig_idx): ax.grid(True, alpha=0.3) plt.tight_layout() - fname = f"reclamm_thermostat_comparison_{fig_idx}.png" + fname = tvl_artifact_filename("reclamm_thermostat_comparison", cfg, suffix=str(fig_idx)) plt.savefig(fname, dpi=150) print(f"Saved {fname}") plt.close(fig) - # Second figure: diagnostics + if not has_constant_arc: + print("Skipping constant-arc comparison diagnostics because RUN_CONSTANT_ARC_LENGTH=False.") + return + geo_values = np.array(geo["value"]) geo_lvr = np.array(hodl_traj) - geo_values fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5)) - fig2.suptitle(f"{cfg['name']} — diagnostics", fontsize=13, fontweight="bold") + fig2.suptitle(f"{cfg['name']} — diagnostics — TVL {tvl_label}", fontsize=13, fontweight="bold") - # (left) Value difference vs geometric ax = axes2[0] for name, (r, color, ls) in variants.items(): if name == "Geometric": @@ -257,7 +2366,6 @@ def plot_comparison(cfg, results, fig_idx): ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - # (middle) LVR ratio over time ax = axes2[1] mask = np.abs(geo_lvr) > 100 if mask.any(): @@ -279,7 +2387,6 @@ def plot_comparison(cfg, results, fig_idx): ax.set_title("Relative LVR") ax.grid(True, alpha=0.3) - # (right) Per-step LVR histogram ax = axes2[2] all_pos = [] for name, (r, color, ls) in variants.items(): @@ -306,74 +2413,197 @@ def plot_comparison(cfg, results, fig_idx): ax.grid(True, alpha=0.3) plt.tight_layout() - fname2 = f"reclamm_thermostat_diff_{fig_idx}.png" + fname2 = tvl_artifact_filename("reclamm_thermostat_diff", cfg, suffix=str(fig_idx)) plt.savefig(fname2, dpi=150) print(f"Saved {fname2}") plt.close(fig2) + arc_values = np.array(results["constant_arc_length"]["value"]) + n_eff = min(len(geo_values), len(arc_values)) + t_eff = np.arange(n_eff) / (60 * 24) + efficiency_pct = ( + 
(arc_values[:n_eff] - geo_values[:n_eff]) + / np.maximum(np.abs(geo_values[:n_eff]), 1e-12) + * 100.0 + ) + + fig3, ax3 = plt.subplots(1, 1, figsize=(10, 4.5)) + fig3.suptitle(f"{cfg['name']} — efficiency — TVL {tvl_label}", fontsize=13, fontweight="bold") + ax3.plot( + t_eff, + efficiency_pct, + color="C2", + linewidth=1.8, + label="(Const Arc - Geo) / Geo", + ) + ax3.axhline(0.0, color="gray", ls="--", alpha=0.6) + _set_padded_ylim(ax3, [efficiency_pct], pad_ratio=0.08) + ax3.set_xlabel("Days") + ax3.set_ylabel("Efficiency vs geometric (%)") + ax3.set_title("Efficiency") + ax3.legend(fontsize=8) + ax3.grid(True, alpha=0.3) + + plt.tight_layout() + fname3 = tvl_artifact_filename("reclamm_thermostat_efficiency", cfg, suffix=str(fig_idx)) + plt.savefig(fname3, dpi=150) + print(f"Saved {fname3}") + plt.close(fig3) + + if __name__ == "__main__": - all_results = [] - for i, cfg in enumerate(CONFIGS): - print(f"\n>>> Running {cfg['name']}...") - try: - results = run_comparison(cfg) - print_comparison(cfg, results) - plot_comparison(cfg, results, i) - all_results.append((cfg, results)) - except Exception as e: - print(f" FAILED: {e}") - import traceback - traceback.print_exc() - - # Summary overlay: all configs on one figure (pool value normalised) - if len(all_results) > 1: - fig, axes = plt.subplots(1, 2, figsize=(16, 5)) - fig.suptitle("Cross-config comparison (normalised)", fontsize=13, - fontweight="bold") - - method_keys = [ - ("geometric", "geo", "-"), - ("geometric_scaled", "geo+s", "-."), - ("constant_arc_length", "arc", "--"), - ("cal_scaled", "arc+s", ":"), - ] + shared_price_data = load_shared_price_data(CONFIGS) + shared_market_linear_noise_data = load_shared_market_linear_noise_data() + + for initial_pool_value in TVL_SWEEP_VALUES: + tvl_configs = configs_for_tvl(CONFIGS, initial_pool_value) + tvl_label = format_tvl_millions_label(tvl_configs[0]) + print(f"\n=== TVL sweep: {tvl_label} ===") + + all_results = [] + for i, cfg in enumerate(tvl_configs): + 
print(f"\n>>> Running {cfg['name']} at TVL {tvl_label}...") + try: + results = run_comparison( + cfg, + price_data=shared_price_data, + market_linear_noise_data=shared_market_linear_noise_data, + ) + print_comparison(cfg, results) + plot_comparison(cfg, results, i) + all_results.append((cfg, results)) + except Exception as e: + print(f" FAILED: {e}") + import traceback + + traceback.print_exc() + + if len(all_results) > 1: + if RUN_CONSTANT_ARC_LENGTH: + fig, axes = plt.subplots(1, 2, figsize=(16, 5)) + fig.suptitle( + f"Cross-config comparison (normalised) — TVL {tvl_label}", + fontsize=13, + fontweight="bold", + ) + + method_keys = [ + ("geometric", "geo", "-"), + ("constant_arc_length", "arc", "--"), + ] + + for i, (cfg, results) in enumerate(all_results): + geo_v = np.array(results["geometric"]["value"]) + t = np.arange(len(geo_v)) / (60 * 24) + short_name = cfg["name"].split("(")[0].strip() + + for j, (key, suffix, ls) in enumerate(method_keys): + v = np.array(results[key]["value"]) + color_idx = i * len(method_keys) + j + + axes[0].plot( + t, + v / v[0], + ls=ls, + alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}", + ) + + if key != "geometric": + pct_diff = (v - geo_v) / geo_v * 100 + axes[1].plot( + t, + pct_diff, + ls=ls, + alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}", + ) - for i, (cfg, results) in enumerate(all_results): - geo_v = np.array(results["geometric"]["value"]) - t = np.arange(len(geo_v)) / (60 * 24) - short_name = cfg["name"].split("(")[0].strip() - - for j, (key, suffix, ls) in enumerate(method_keys): - v = np.array(results[key]["value"]) - color_idx = i * len(method_keys) + j - - # (left) Normalised pool value - axes[0].plot(t, v / v[0], ls=ls, alpha=0.8, - label=f"{short_name} {suffix}", - color=f"C{color_idx % 10}") - - # (right) Value difference vs geometric (skip geo itself) - if key != "geometric": - pct_diff = (v - geo_v) / geo_v * 100 - axes[1].plot(t, pct_diff, ls=ls, alpha=0.8, - 
label=f"{short_name} {suffix}", - color=f"C{color_idx % 10}") - - axes[0].set_xlabel("Days") - axes[0].set_ylabel("Normalised pool value") - axes[0].set_title("Pool value (V/V0)") - axes[0].legend(fontsize=6, ncol=2) - axes[0].grid(True, alpha=0.3) - - axes[1].set_xlabel("Days") - axes[1].set_ylabel("(Method - Geo) / Geo (%)") - axes[1].set_title("Relative value difference vs Geometric") - axes[1].axhline(0, color="gray", ls="--", alpha=0.5) - axes[1].legend(fontsize=6, ncol=2) - axes[1].grid(True, alpha=0.3) - - plt.tight_layout() - plt.savefig("reclamm_thermostat_summary.png", dpi=150) - print("\nSaved reclamm_thermostat_summary.png") - plt.close(fig) + axes[0].set_xlabel("Days") + axes[0].set_ylabel("Normalised pool value") + axes[0].set_title("Pool value (V/V0)") + axes[0].legend(fontsize=6, ncol=2) + axes[0].grid(True, alpha=0.3) + + axes[1].set_xlabel("Days") + axes[1].set_ylabel("Efficiency vs geometric (%)") + axes[1].set_title("Efficiency vs Geometric") + axes[1].axhline(0, color="gray", ls="--", alpha=0.5) + axes[1].legend(fontsize=6, ncol=2) + axes[1].grid(True, alpha=0.3) + else: + fig, ax = plt.subplots(1, 1, figsize=(9, 5)) + fig.suptitle( + f"Cross-config comparison (normalised geometric) — TVL {tvl_label}", + fontsize=13, + fontweight="bold", + ) + + for i, (cfg, results) in enumerate(all_results): + geo_v = np.array(results["geometric"]["value"]) + t = np.arange(len(geo_v)) / (60 * 24) + short_name = cfg["name"].split("(")[0].strip() + ax.plot( + t, + geo_v / geo_v[0], + ls="-", + alpha=0.8, + label=f"{short_name} geo", + color=f"C{i % 10}", + ) + + ax.set_xlabel("Days") + ax.set_ylabel("Normalised pool value") + ax.set_title("Geometric pool value (V/V0)") + ax.legend(fontsize=6, ncol=2) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + summary_name = tvl_artifact_filename( + "reclamm_thermostat_summary", + tvl_configs[0], + ) + plt.savefig(summary_name, dpi=150) + print(f"\nSaved {summary_name}") + plt.close(fig) + + launch_final_values = 
get_launch_final_values( + all_results, + launch_cfg=tvl_configs[0], + price_data=shared_price_data, + market_linear_noise_data=shared_market_linear_noise_data, + ) + shared_sweep_cache = make_sweep_cache( + shared_price_data, + cache_scope_cfg=tvl_configs[1], + market_linear_noise_data=shared_market_linear_noise_data, + ) + + print(f"\nGenerating thermostat heatmaps for TVL {tvl_label}...") + generate_heatmaps( + dict(tvl_configs[1]), + shared_price_data, + launch_final_values=launch_final_values, + cache=shared_sweep_cache, + ) + + generate_arc_speed_efficiency_artifacts( + dict(tvl_configs[1]), + launch_cfg=dict(tvl_configs[0]), + price_data=shared_price_data, + launch_final_values=launch_final_values, + cache=shared_sweep_cache, + ) + generate_three_variable_3d_heatmaps( + dict(tvl_configs[1]), + price_data=shared_price_data, + launch_final_values=launch_final_values, + cache=shared_sweep_cache, + ) + flush_sweep_cache(shared_sweep_cache, force=True) + shared_sweep_cache.clear() + gc.collect() + print(f"Released shared sweep cache for TVL {tvl_label}.") diff --git a/scripts/demo_run_1.py b/scripts/demo_run_1.py new file mode 100644 index 0000000..a8b8396 --- /dev/null +++ b/scripts/demo_run_1.py @@ -0,0 +1,196 @@ +import jax.numpy as jnp +from quantammsim.core_simulator.param_utils import ( + memory_days_to_logit_lamb, +) +from quantammsim.runners.jax_runners import do_run_on_historic_data +import itertools +from pathlib import Path +import pandas as pd +import numpy as np +from datetime import datetime +import gc +import warnings + +warnings.filterwarnings("ignore") +from jax import config + +# Default fingerprint used as base for all pools +DEFAULT_FINGERPRINT = { + "startDateString": "2021-01-01 00:00:00", + "endDateString": "2024-06-01 00:00:00", + "endTestDateString": "2024-11-30 00:00:00", + "chunk_period": 60, + "weight_interpolation_period": 60, + "fees": 0.0, + "gas_cost": 0.0, + "use_alt_lamb": False, +} + +EXAMPLE_CONFIGS = { + "reclamm_1": { + 
"fingerprint": { + "arb_fees": 0.0, + "arb_frequency": 15, + "do_arb": True, + "endDateString": "2025-06-01 00:00:00", + "fees": 0.0025, + "gas_cost": 1.0, + "initial_pool_value": 1000000.0, + "noise_trader_ratio": 0.0, + "protocol_fee_split": 0.25, + "reclamm_arc_length_speed": None, + "reclamm_interpolation_method": "geometric", + "rule": "reclamm", + "startDateString": "2024-06-01 00:00:00", + "tokens": [ + "AAVE", + "ETH" + ] + }, + "params": { + "centeredness_margin": 0.3184210526315789, + "daily_price_shift_base": 0.9999984155508669, + "price_ratio": 1.3349999999999989 + } + }, + "reclamm_2":{ + "fingerprint": { + "arb_fees": 0.0, + "arb_frequency": 15, + "do_arb": True, + "endDateString": "2025-06-01 00:00:00", + "fees": 0.0025, + "gas_cost": 1.0, + "initial_pool_value": 1000000.0, + "noise_model": "calibrated", + "noise_trader_ratio": 0.0, + "protocol_fee_split": 0.25, + "reclamm_arc_length_speed": None, + "reclamm_interpolation_method": "geometric", + "reclamm_noise_params": { + "c_0": -0.453, + "c_1": 0.025, + "c_2": -0.06, + "c_3": 0.31, + "c_4": -0.149, + "c_5": 0.359, + "c_6": 0.061, + "c_7": 0.06 + }, + "rule": "reclamm", + "startDateString": "2024-06-01 00:00:00", + "tokens": [ + "AAVE", + "ETH" + ] + }, + "params": { + "centeredness_margin": 0.3184210526315789, + "daily_price_shift_base": 0.9999984155508669, + "price_ratio": 1.3349999999999989 + } + } +} + + +if __name__ == "__main__": + + import matplotlib.pyplot as plt + import numpy as np + from quantammsim.core_simulator.param_utils import ( + generate_params_combinations, + jax_logit_lamb_to_lamb, + lamb_to_memory_days, + lamb_to_memory_days_clipped, + calc_lamb, + ) + from quantammsim.pools.G3M.quantamm.update_rule_estimators.estimator_primitives import ( + squareplus, + inverse_squareplus, + inverse_squareplus_np, + ) + + for name, config in EXAMPLE_CONFIGS.items(): + print(name) + if 'reclamm' not in name: + continue + print(f"\nRunning {name}...") + result = do_run_on_historic_data( + 
run_fingerprint=config["fingerprint"], + params=config["params"], + ) + print("-" * 80) + print(f"Pool Type: {config['fingerprint']['rule']}") + print(f"Tokens: {', '.join(config['fingerprint']['tokens'])}") + print(f"Fees: {config['fingerprint'].get('fees', 0.0)}") + if "arb_quality" in config["fingerprint"]: + print(f"Arb Quality: {config['fingerprint']['arb_quality']}") + print(f"Initial Pool Value: ${result['value'][0]:.2f}") + print(f"Final Pool Value: ${result['final_value']:.2f}") + print(f"Return: {(result['final_value']/result['value'][0]-1)*100:.2f}%") + print( + f"Return over hodl: {(result['final_value']/(result['reserves'][0]*result['prices'][-1]).sum()-1)*100:.2f}%" + ) + print("-" * 80) + # memory_days = lamb_to_memory_days(jax_logit_lamb_to_lamb(config["params"]["logit_lamb"]), config["fingerprint"]["chunk_period"]) + # print("memory days: ", memory_days) + if "logit_lamb" in config["params"]: + memory_days = lamb_to_memory_days_clipped( + calc_lamb(config["params"]), + chunk_period=config["fingerprint"]["chunk_period"], + max_memory_days=365, + ) + print(f"{'memory days':<20} {str(memory_days)}") + lamb = calc_lamb(config["params"]) + print( + f"{'lamb':<20} {jnp.array_str(lamb, precision=16, suppress_small=False)}" + ) + if "log_k" in config["params"]: + k = 2 ** config["params"]["log_k"] * memory_days + k_str = " ".join(f"{x:.16e}" for x in k) + print(f"{'k':<20} [{k_str}]") + k_per_day_str = " ".join( + f"{x:.16e}" for x in 2 ** config["params"]["log_k"] + ) + print(f"{'k per day':<20} [{k_per_day_str}]") + if "raw_exponents" in config["params"]: + exponents = squareplus(config["params"]["raw_exponents"]) + exp_str = " ".join(f"{x:.16f}" for x in exponents) + print(f"{'exponents':<20} [{exp_str}]") + if "raw_width" in config["params"]: + width = 2 ** config["params"]["raw_width"] + width_str = " ".join(f"{x:.16e}" for x in width) + print(f"{'width':<20} [{width_str}]") + if "log_amplitude" in config["params"]: + memory_days = 
lamb_to_memory_days_clipped( + calc_lamb(config["params"]), + chunk_period=config["fingerprint"]["chunk_period"], + max_memory_days=365, + ) + amplitude = (2 ** config["params"]["log_amplitude"]) * memory_days + amp_str = " ".join(f"{x:.16e}" for x in amplitude) + print(f"{'amplitude':<20} [{amp_str}]") + if "logit_pre_exp_scaling" in config["params"]: + pre_exp_scaling = jnp.exp(config["params"]["logit_pre_exp_scaling"]) / ( + 1 + jnp.exp(config["params"]["logit_pre_exp_scaling"]) + ) + pes_str = " ".join(f"{x:.16f}" for x in pre_exp_scaling) + print(f"{'pre_exp_scaling':<20} [{pes_str}]") + if "raw_pre_exp_scaling" in config["params"]: + pre_exp_scaling = 2 ** config["params"]["raw_pre_exp_scaling"] + pes_str = " ".join(f"{x:.16f}" for x in pre_exp_scaling) + print(f"{'pre_exp_scaling':<20} [{pes_str}]") + + print("-" * 80) + print("final readouts") + if result.get("readouts") is not None: + for readout in result["readouts"]: + print(f"{readout}: { jnp.array_str(result['readouts'][readout][-1], precision=16, suppress_small=False)}") + print("-" * 80) + print("final weights") + print(f"{jnp.array_str(result['weights'][-1], precision=16, suppress_small=False)}") + print("-" * 80) + print("final prices") + print(f"{jnp.array_str(result['prices'][-1], precision=16, suppress_small=False)}") + print("=" * 80) + \ No newline at end of file diff --git a/scripts/demo_run_reclamm.py b/scripts/demo_run_reclamm.py index 3ea21ec..132f512 100644 --- a/scripts/demo_run_reclamm.py +++ b/scripts/demo_run_reclamm.py @@ -36,105 +36,127 @@ def balancer_fingerprint(tokens, start, end, fees): } +def reclamm_fingerprint(tokens, start, end, fees, interpolation_method="geometric"): + """Build a reCLAMM fingerprint for a demo scenario.""" + return { + "tokens": tokens, + "rule": "reclamm", + "startDateString": start, + "endDateString": end, + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": fees, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + 
"weight_interpolation_period": 60, + "reclamm_interpolation_method": interpolation_method, + "reclamm_arc_length_speed": None, + } + + +def reclamm_params(price_ratio, centeredness_margin, daily_price_shift_exponent): + """Build reCLAMM params from a concise config.""" + return { + "price_ratio": jnp.array(price_ratio), + "centeredness_margin": jnp.array(centeredness_margin), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(daily_price_shift_exponent) + ), + } + + +def _apply_active_noise_settings(fp): + """Enable the active AAVE/ETH reCLAMM noise model for demo runs.""" + if fp.get("rule") != "reclamm" or list(fp.get("tokens", [])) != ["AAVE", "ETH"]: + return fp, "disabled" + + from compare_reclamm_thermostats import ( + AAVE_ETH_NOISE_SETTINGS, + resolve_reclamm_noise_settings, + ) + + cfg = { + "tokens": fp["tokens"], + "start": fp["startDateString"], + "end": fp["endDateString"], + "enable_noise_model": True, + "noise_model": AAVE_ETH_NOISE_SETTINGS["noise_model"], + "noise_artifact_dir": AAVE_ETH_NOISE_SETTINGS["noise_artifact_dir"], + "noise_pool_id": AAVE_ETH_NOISE_SETTINGS["noise_pool_id"], + "gas_cost": fp.get("gas_cost", AAVE_ETH_NOISE_SETTINGS["gas_cost"]), + "protocol_fee_split": fp.get( + "protocol_fee_split", + AAVE_ETH_NOISE_SETTINGS["protocol_fee_split"], + ), + "arb_frequency": fp.get("arb_frequency"), + "noise_trader_ratio": fp.get("noise_trader_ratio", 0.0), + "reclamm_noise_params": fp.get("reclamm_noise_params"), + "noise_arrays_path": fp.get("noise_arrays_path"), + } + noise_cfg = resolve_reclamm_noise_settings(cfg) + + updated = dict(fp) + updated["gas_cost"] = cfg["gas_cost"] + updated["protocol_fee_split"] = cfg["protocol_fee_split"] + updated["noise_trader_ratio"] = noise_cfg.get("noise_trader_ratio", 0.0) + for key in ("noise_model", "reclamm_noise_params", "noise_arrays_path", "arb_frequency"): + if noise_cfg.get(key) is not None: + updated[key] = noise_cfg[key] + return updated, noise_cfg["noise_summary"] + + 
SCENARIOS = [ { - "name": "AAVE/ETH on-chain (25bps)", + "name": "AAVE/ETH launch-style range (25bps, geometric)", "reclamm": { - "fingerprint": { - "tokens": ["AAVE", "ETH"], - "rule": "reclamm", - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.0025, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(1.5), - "centeredness_margin": jnp.array(0.5), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(0.1) - ), - }, + "fingerprint": reclamm_fingerprint( + ["AAVE", "ETH"], + "2024-06-01 00:00:00", + "2025-06-01 00:00:00", + 0.0025, + interpolation_method="geometric", + ), + "params": reclamm_params(1.5014, 0.5, 0.1), }, }, { - "name": "AAVE/ETH zero fees", + "name": "AAVE/ETH tighter launch-style range (25bps, geometric)", "reclamm": { - "fingerprint": { - "tokens": ["AAVE", "ETH"], - "rule": "reclamm", - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.0, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(1.5), - "centeredness_margin": jnp.array(0.5), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(0.1) - ), - }, + "fingerprint": reclamm_fingerprint( + ["AAVE", "ETH"], + "2024-06-01 00:00:00", + "2025-06-01 00:00:00", + 0.0025, + interpolation_method="geometric", + ), + "params": reclamm_params(1.15, 0.5, 0.1), }, }, { - "name": "AAVE/ETH wide range (25bps)", + "name": "AAVE/ETH tighter launch-style range (25bps, constant arc)", "reclamm": { - "fingerprint": { - "tokens": ["AAVE", "ETH"], - "rule": "reclamm", - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 
0.0025, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(4.0), - "centeredness_margin": jnp.array(0.2), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(1.0) - ), - }, + "fingerprint": reclamm_fingerprint( + ["AAVE", "ETH"], + "2024-06-01 00:00:00", + "2025-06-01 00:00:00", + 0.0025, + interpolation_method="constant_arc_length", + ), + "params": reclamm_params(1.15, 0.5, 0.1), }, }, { "name": "BTC/ETH (10bps)", "reclamm": { - "fingerprint": { - "tokens": ["BTC", "ETH"], - "rule": "reclamm", - "startDateString": "2024-01-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.001, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(2.0), - "centeredness_margin": jnp.array(0.3), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(0.5) - ), - }, + "fingerprint": reclamm_fingerprint( + ["BTC", "ETH"], + "2024-01-01 00:00:00", + "2025-06-01 00:00:00", + 0.001, + interpolation_method="geometric", + ), + "params": reclamm_params(2.0, 0.3, 0.5), }, }, ] @@ -143,7 +165,7 @@ def balancer_fingerprint(tokens, start, end, fees): def run_scenario(scenario): """Run a reClAMM config and its Balancer 50/50 baseline, print comparison.""" rc = scenario["reclamm"] - fp = rc["fingerprint"] + fp, noise_summary = _apply_active_noise_settings(dict(rc["fingerprint"])) # Run reClAMM reclamm_result = do_run_on_historic_data( @@ -173,7 +195,14 @@ def run_scenario(scenario): print("=" * 80) print(f" {scenario['name']}") - print(f" Tokens: {', '.join(fp['tokens'])} | Fees: {fp['fees']}") + print( + f" Tokens: {', '.join(fp['tokens'])} | Fees: {fp['fees']} | " + f"Interpolation: {fp.get('reclamm_interpolation_method', 'geometric')}" + ) + print( + f" Noise: {noise_summary} | Gas: {fp.get('gas_cost', 0.0)} | " + 
f"Protocol fee split: {fp.get('protocol_fee_split', 0.0)}" + ) print("-" * 80) print(f" {'':30s} {'reClAMM':>14s} {'Balancer 50/50':>14s}") print(f" {'Initial value':30s} ${rc_init:>13,.0f} ${bal_init:>13,.0f}") diff --git a/scripts/hypersurge_demo_train.py b/scripts/hypersurge_demo_train.py new file mode 100644 index 0000000..3710c97 --- /dev/null +++ b/scripts/hypersurge_demo_train.py @@ -0,0 +1,173 @@ +"""Demo training entry points for HyperSurge pool variants. + +Compared with [demo_train.py](./demo_train.py), HyperSurge training needs: +1. A HyperSurge-enabled pool rule. +2. HyperSurge fee-curve initial values in the run fingerprint. +3. For reCLAMM, the usual range-shape parameters as well. +4. Optionally, a separate oracle-price frame. If omitted, both pool + implementations fall back to the traded price series as the reference. + +`price_data` and `oracle_prices` intentionally have different shapes: +- `price_data`: parquet-like price frame with `close_` columns and a + unix-ms or datetime index. +- `oracle_prices`: minute-level frame with a `unix` column plus one column per + token in pool order. 
+""" + +from copy import deepcopy + +import numpy as np +import pandas as pd + +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames +from quantammsim.runners.jax_runners import train_on_historic_data + + +DEFAULT_FINGERPRINT = { + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-03-01 00:00:00", + "endTestDateString": "2023-04-01 00:00:00", + "chunk_period": 1440, + "weight_interpolation_period": 1440, + "bout_offset": 24 * 60 * 7, + "initial_pool_value": 1_000_000.0, + "fees": 0.003, + "gas_cost": 0.0, + "arb_fees": 0.0, + "arb_frequency": 1, + "do_arb": True, + "return_val": "daily_log_sharpe", + "optimisation_settings": { + "method": "gradient_descent", + "base_lr": 0.05, + "optimiser": "adam", + "batch_size": 8, + "n_iterations": 250, + "n_parameter_sets": 4, + "training_data_kind": "historic", + "sample_method": "uniform", + "initial_random_key": 0, + "n_cycles": 1, + "val_fraction": 0.0, + "early_stopping": False, + }, +} + + +HYPERSURGE_INITIALS = { + "initial_hypersurge_arb_max_fee": 0.02, + "initial_hypersurge_arb_threshold": 0.10, + "initial_hypersurge_arb_cap_deviation": 0.50, + "initial_hypersurge_noise_max_fee": 0.10, + "initial_hypersurge_noise_threshold": 0.10, + "initial_hypersurge_noise_cap_deviation": 0.50, +} + + +def oracle_prices_from_price_data(price_data: pd.DataFrame, tokens): + """Build a DynamicInputFrames-compatible oracle frame from price data.""" + index = pd.Index(price_data.index) + if np.issubdtype(index.dtype, np.datetime64): + unix = index.view("int64") // 10**6 + else: + unix = pd.to_numeric(index, errors="raise").astype("int64") + + data = {"unix": unix} + for token in tokens: + column = f"close_{token}" + if column not in price_data.columns: + raise ValueError( + f"price_data must contain {column} to derive oracle prices" + ) + data[token] = price_data[column].to_numpy() + return pd.DataFrame(data) + + +def _dynamic_input_frames(oracle_prices): + if oracle_prices is None: + return None + 
return DynamicInputFrames(oracle_prices=oracle_prices) + + +def _common_training_fingerprint(rule, tokens): + fingerprint = deepcopy(DEFAULT_FINGERPRINT) + fingerprint["rule"] = rule + fingerprint["tokens"] = list(tokens) + fingerprint.update(HYPERSURGE_INITIALS) + return fingerprint + + +def train_hypersurge_balancer( + tokens=("BTC", "ETH"), + *, + root=None, + price_data=None, + oracle_prices=None, + use_price_data_as_oracle=False, + verbose=True, +): + """Train the HyperSurge Balancer pool.""" + fingerprint = _common_training_fingerprint("balancer_hypersurge", tokens) + if use_price_data_as_oracle: + if price_data is None: + raise ValueError("price_data is required when use_price_data_as_oracle=True") + oracle_prices = oracle_prices_from_price_data(price_data, tokens) + + return train_on_historic_data( + run_fingerprint=fingerprint, + root=root, + price_data=price_data, + dynamic_input_frames=_dynamic_input_frames(oracle_prices), + verbose=verbose, + return_training_metadata=True, + force_init=True, + ) + + +def train_hypersurge_reclamm( + tokens=("BTC", "ETH"), + *, + root=None, + price_data=None, + oracle_prices=None, + use_price_data_as_oracle=False, + verbose=True, +): + """Train the HyperSurge reCLAMM pool.""" + fingerprint = _common_training_fingerprint("reclamm_hypersurge", tokens) + fingerprint.update( + { + "initial_price_ratio": 2.0, + "initial_centeredness_margin": 0.25, + "initial_daily_price_shift_base": 1.0 - 1.0 / 124000.0, + "reclamm_interpolation_method": "geometric", + } + ) + if use_price_data_as_oracle: + if price_data is None: + raise ValueError("price_data is required when use_price_data_as_oracle=True") + oracle_prices = oracle_prices_from_price_data(price_data, tokens) + + return train_on_historic_data( + run_fingerprint=fingerprint, + root=root, + price_data=price_data, + dynamic_input_frames=_dynamic_input_frames(oracle_prices), + verbose=verbose, + return_training_metadata=True, + force_init=True, + ) + + +if __name__ == "__main__": + 
examples = [ + ("balancer_hypersurge", train_hypersurge_balancer), + ("reclamm_hypersurge", train_hypersurge_reclamm), + ] + for name, train_fn in examples: + print(f"\nTraining {name}...") + params, metadata = train_fn(verbose=True) + print( + f"{name}: objective={metadata['final_objective']:.4f}, " + f"epochs={metadata['epochs_trained']}" + ) diff --git a/scripts/plot_calibrated_vs_real.py b/scripts/plot_calibrated_vs_real.py new file mode 100644 index 0000000..475422b --- /dev/null +++ b/scripts/plot_calibrated_vs_real.py @@ -0,0 +1,241 @@ +"""Plot predicted vs real volume using saved per-pool calibrated noise model. + +Loads the artifact from experiments/run_linear_market_noise.py (--per-pool), +rebuilds the features, evaluates V_arb(learned cadence) + V_noise(x @ coeffs_i), +and generates stacked area plots showing the arb/noise decomposition per pool. + +Usage: + # First train and save: + python experiments/run_linear_market_noise.py --per-pool --no-split --epochs 2000 + + # Then plot: + python scripts/plot_calibrated_vs_real.py + python scripts/plot_calibrated_vs_real.py --artifact results/linear_market_noise/model.npz +""" + +import argparse +import json +import os +import sys +import time + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +ARTIFACT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "calibrated_vs_real", +) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--artifact", default=os.path.join(ARTIFACT_DIR, "model.npz")) + parser.add_argument("--meta", default=os.path.join(ARTIFACT_DIR, "meta.json")) + parser.add_argument("--output-dir", default=OUTPUT_DIR) + args = 
parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + os.makedirs(args.output_dir, exist_ok=True) + + # ---- Load artifact ---- + print(f"Loading artifact: {args.artifact}") + art = np.load(args.artifact, allow_pickle=True) + noise_coeffs = art["noise_coeffs"] + log_cadence = art["log_cadence"] + init_log_cadences = art["init_log_cadences"] + + with open(args.meta) as f: + meta = json.load(f) + feat_names = meta["feat_names"] + pool_ids = meta["pool_ids"] + n_pools = meta["n_pools"] + hparams = meta["hparams"] + per_pool = noise_coeffs.ndim == 2 + + print(f" {n_pools} pools, {len(feat_names)} features, per_pool={per_pool}") + print(f" hparams: {hparams}") + + # ---- Rebuild data (features only, no training) ---- + import jax.numpy as jnp + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from experiments.run_linear_market_noise import load_stage1, build_data + + matched_clean, option_c_clean = load_stage1() + + print("\nRebuilding features...") + t0 = time.time() + data = build_data( + matched_clean, option_c_clean, + trend_windows=tuple(hparams["trend_windows"]), + include_market=True, include_cross_pool=True, + ) + print(f" {len(data['pool_idx'])} samples, {time.time() - t0:.1f}s") + + x = data["x"] + y_total = data["y_total"] + pool_idx = data["pool_idx"] + day_idx = data["day_idx"] + sgd = data["sample_grid_days"] + + # ---- Compute predictions ---- + if per_pool: + per_sample_coeffs = noise_coeffs[pool_idx] + log_v_noise = np.sum(x * per_sample_coeffs, axis=1) + else: + log_v_noise = x @ noise_coeffs + + if "pool_intercepts" in art: + log_v_noise = log_v_noise + art["pool_intercepts"][pool_idx] + + v_noise = np.exp(log_v_noise) + + v_arb = np.zeros(len(y_total)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + v_arb[mask] = v_arb_all[sgd[mask]] + + v_obs 
= np.exp(y_total) + log_v_arb = np.log(np.maximum(v_arb, 1e-10)) + pred_total_log = np.logaddexp(log_v_arb, log_v_noise) + + # ---- Reconstruct dates ---- + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + + # ---- Plot: stacked area per pool ---- + per_page = 9 + n_pages = (n_pools + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, n_pools) + n_this = end - start + + ncols = 3 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4 * nrows)) + if nrows == 1: + axes = axes.reshape(1, -1) + + for idx, i in enumerate(range(start, end)): + ax = axes[idx // ncols][idx % ncols] + mask = pool_idx == i + if mask.sum() < 5: + ax.set_visible(False) + continue + + days = day_idx[mask] + dates = [pd.Timestamp(date_list[d]) for d in days] + vo = v_obs[mask] + va = v_arb[mask] + vn = v_noise[mask] + + ax.fill_between(dates, 0, va, alpha=0.3, color="steelblue", + label="V_arb") + ax.fill_between(dates, va, va + vn, alpha=0.3, color="coral", + label="V_noise") + ax.plot(dates, vo, "k-", linewidth=0.8, alpha=0.7, label="V_obs") + ax.plot(dates, va + vn, "--", color="darkred", linewidth=0.8, + alpha=0.7, label="V_pred") + + ax.set_yscale("log") + ax.set_ylabel("USD/day", fontsize=7) + ax.tick_params(labelsize=6) + ax.tick_params(axis="x", rotation=30) + + yt = y_total[mask] + pt = pred_total_log[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + + pid = pool_ids[i] + tokens = matched_clean[pid]["tokens"] + chain = matched_clean[pid]["chain"] + ci = np.exp(init_log_cadences[i]) + cl = np.exp(log_cadence[i]) + arb_share = np.median(va / vo) * 100 + noise_share = np.median(vn / vo) * 100 + b_tvl = noise_coeffs[i, 1] if per_pool else noise_coeffs[1] + + ax.set_title( + f"{tokens} ({chain})\n" + f"R\u00b2={r2:.3f} 
cad={ci:.0f}\u2192{cl:.0f}min " + f"arb={arb_share:.0f}% noise={noise_share:.0f}% " + f"b_tvl={b_tvl:.2f}", + fontsize=7) + ax.legend(fontsize=6, loc="upper right") + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle( + f"Per-pool calibrated noise model \u2014 V_arb + V_noise " + f"(page {page+1}/{n_pages})", fontsize=10) + fig.tight_layout() + out = os.path.join(args.output_dir, f"calibrated_page{page+1}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + # ---- Summary ---- + summary = [] + for i in range(n_pools): + mask = pool_idx == i + if mask.sum() < 2: + continue + yt = y_total[mask] + pt = pred_total_log[mask] + ss_res = np.sum((yt - pt) ** 2) + ss_tot = np.sum((yt - yt.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + + pid = pool_ids[i] + va = v_arb[mask] + vn = v_noise[mask] + vo = v_obs[mask] + + summary.append({ + "pool_id": pid, + "tokens": matched_clean[pid]["tokens"], + "chain": matched_clean[pid]["chain"], + "n_obs": int(mask.sum()), + "R2": r2, + "cadence_init": float(np.exp(init_log_cadences[i])), + "cadence_learned": float(np.exp(log_cadence[i])), + "median_arb_pct": float(np.median(va / vo) * 100), + "median_noise_pct": float(np.median(vn / vo) * 100), + "b_tvl": float(noise_coeffs[i, 1] if per_pool else noise_coeffs[1]), + }) + + summary_df = pd.DataFrame(summary) + csv_path = os.path.join(args.output_dir, "summary.csv") + summary_df.to_csv(csv_path, index=False) + print(f"\n Summary: {csv_path}") + print(f" Median R\u00b2: {summary_df['R2'].median():.4f}") + print(f" Median arb: {summary_df['median_arb_pct'].median():.0f}%") + print(f" b_tvl: [{summary_df['b_tvl'].min():.2f}," + f" {summary_df['b_tvl'].max():.2f}]," + f" median={summary_df['b_tvl'].median():.2f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_mm_noise_fit.py b/scripts/plot_mm_noise_fit.py new file mode 100644 index 0000000..6d1b331 --- 
/dev/null +++ b/scripts/plot_mm_noise_fit.py @@ -0,0 +1,494 @@ +"""Plot MM noise model fit: per-pool time series + TVL response curves. + +Produces two types of plots: +1. Per-pool 6-panel time series (like plot_model_vs_real_reclamm.py): + TVL, volume decomposition, V_noise, fee revenue, vol/TVL, pred/obs +2. Cross-pool TVL response curves showing MM saturation + +Usage: + python scripts/plot_mm_noise_fit.py + python scripts/plot_mm_noise_fit.py --pool 0x9d1fcf346ea1b0 + python scripts/plot_mm_noise_fit.py --all-pools +""" + +import argparse +import json +import os +import pickle +import sys + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +import jax.numpy as jnp + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def load_model(artifact_dir): + """Load MM model artifact.""" + art = dict(np.load(os.path.join(artifact_dir, "model.npz"), + allow_pickle=True)) + with open(os.path.join(artifact_dir, "meta.json")) as f: + meta = json.load(f) + params = {} + for k in art: + params[k] = jnp.array(art[k]) + return params, meta + + +def get_pool_K(params, decomp, pool_i): + """Get median K for a pool, handling all K modes.""" + mask = decomp["pool_idx"] == pool_i + if "k_scale" in params: + ks = np.array(params["k_scale"]) + if not mask.any(): + return float(np.exp(ks[0])) + lc = decomp.get("log_comp_tvl", np.zeros(mask.shape[0]))[mask] + log_K = ks[0] + ks[1] * lc + return float(np.exp(np.median(log_K))) + elif "log_K" in params: + return float(np.exp(params["log_K"][pool_i])) + elif "k_params" in params: + k_p = np.array(params["k_params"]) + return float(np.exp(k_p[0])) + elif "log_comp_tvl" in decomp and mask.any(): + # Observed K directly from competitor TVL + return float(np.exp(np.median(decomp["log_comp_tvl"][mask]))) + return float(np.exp(14.5)) + + +def compute_decomposition(params, meta, matched_clean, 
option_c_clean): + """Compute V_arb, V_noise, V_total for all pools.""" + from experiments.run_mm_noise import build_mm_data, forward_mm + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + + data = build_mm_data(matched_clean, option_c_clean, + trend_windows=(7,), + include_cross_pool=False) + + pool_ids = data["pool_ids"] + n_pools = data["n_pools"] + pool_idx = np.array(data["pool_idx"]) + sgd = np.array(data["sample_grid_days"]) + day_idx = np.array(data["day_idx"]) + y = np.array(data["y_total"]) + log_tvl = np.array(data["log_tvl"]) + + log_cadence = np.array(params["log_cadence"]) + + # V_arb + v_arb = np.zeros(len(y)) + for i in range(n_pools): + mask = pool_idx == i + if not mask.any(): + continue + v_arb_all = np.array(interpolate_pool_daily( + data["pool_coeffs"][i], jnp.float64(log_cadence[i]), + data["pool_gas"][i])) + safe = np.clip(sgd[mask], 0, len(v_arb_all) - 1) + v_arb[mask] = v_arb_all[safe] + + # V_noise + log_v_noise = np.array(forward_mm( + params, jnp.array(data["x_market"]), + jnp.array(data["log_tvl"]), + jnp.array(data["pool_idx"]), + log_comp_tvl=jnp.array(data["log_comp_tvl"]))) + v_noise = np.exp(log_v_noise) + v_total = v_arb + v_noise + v_obs = np.exp(y) + + # Reconstruct dates + all_dates = set() + for pid in pool_ids: + all_dates.update(matched_clean[pid]["panel"]["date"].values) + date_list = sorted(all_dates) + + dates = np.array([pd.Timestamp(date_list[d]) for d in day_idx]) + + return { + "pool_ids": pool_ids, + "pool_tokens": data["pool_tokens"], + "pool_idx": pool_idx, + "dates": dates, + "v_arb": v_arb, + "v_noise": v_noise, + "v_total": v_total, + "v_obs": v_obs, + "log_tvl": log_tvl, + "log_comp_tvl": data["log_comp_tvl"], + "tvl": np.exp(log_tvl), + } + + +def plot_pool_timeseries(decomp, params, pool_i, output_dir): + """6-panel time series for a single pool.""" + pid = decomp["pool_ids"][pool_i] + toks = decomp["pool_tokens"][pool_i] + label = f"{toks[0]}/{toks[1]}" + + mask = 
decomp["pool_idx"] == pool_i + if mask.sum() < 10: + print(f" Skipping {pid[:16]} ({label}): {mask.sum()} samples") + return + + dates = decomp["dates"][mask] + v_arb = decomp["v_arb"][mask] + v_noise = decomp["v_noise"][mask] + v_total = decomp["v_total"][mask] + v_obs = decomp["v_obs"][mask] + tvl = decomp["tvl"][mask] + + K_i = get_pool_K(params, decomp, pool_i) + + # R² + log_pred = np.log(np.maximum(v_total, 1e-10)) + log_obs = np.log(np.maximum(v_obs, 1e-10)) + ss_res = np.sum((log_pred - log_obs) ** 2) + ss_tot = np.sum((log_obs - log_obs.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + + fig, axes = plt.subplots(6, 1, figsize=(14, 18), sharex=True) + + # 1. TVL + ax = axes[0] + ax.plot(dates, tvl / 1e6, "k-", linewidth=0.7) + ax.axhline(K_i / 1e6, color="red", linestyle="--", alpha=0.5, + label=f"K = ${K_i/1e6:.1f}M") + ax.set_ylabel("TVL ($M)") + ax.set_yscale("log") + ax.legend(fontsize=8) + ax.set_title(f"{label} — TVL (K={K_i/1e6:.1f}M)") + ax.grid(True, alpha=0.3) + + # 2. Volume decomposition + ax = axes[1] + ax.fill_between(dates, 0, v_arb / 1e6, alpha=0.4, color="steelblue", + label="V_arb") + ax.fill_between(dates, v_arb / 1e6, v_total / 1e6, alpha=0.4, + color="coral", label="V_noise (MM)") + ax.plot(dates, v_obs / 1e6, "k-", linewidth=0.5, alpha=0.7, + label="V_obs") + ax.plot(dates, v_total / 1e6, "r--", linewidth=0.5, alpha=0.7, + label="V_pred") + ax.set_ylabel("Volume ($M/day)") + ax.set_yscale("log") + ax.legend(fontsize=7) + ax.set_title(f"Volume Decomposition (R²={r2:.3f})") + ax.grid(True, alpha=0.3) + + # 3. 
V_noise only + ax = axes[2] + ax.fill_between(dates, 0, v_noise / 1e6, alpha=0.4, color="coral") + ax.plot(dates, v_noise / 1e6, "r-", linewidth=0.5) + noise_med = np.median(v_noise) + ax.axhline(noise_med / 1e6, color="red", linestyle=":", alpha=0.5, + label=f"median=${noise_med:,.0f}") + ax.set_ylabel("V_noise ($M/day)") + ax.set_yscale("log") + ax.legend(fontsize=8) + ax.set_title("Noise Volume (MM model)") + ax.grid(True, alpha=0.3) + + # 4. Fee revenue (assuming 0.25% fee, 25% protocol take) + fee_rate = 0.0025 + protocol_take = 0.25 + fee_arb = v_arb * fee_rate * (1 - protocol_take) + fee_noise = v_noise * fee_rate * (1 - protocol_take) + fee_obs = v_obs * fee_rate * (1 - protocol_take) + ax = axes[3] + ax.fill_between(dates, 0, fee_arb, alpha=0.4, color="steelblue", + label="Arb fees") + ax.fill_between(dates, fee_arb, fee_arb + fee_noise, alpha=0.4, + color="coral", label="Noise fees") + ax.plot(dates, fee_obs, "k-", linewidth=0.5, alpha=0.7, + label="Obs fees (approx)") + ax.set_ylabel("Fee revenue ($/day)") + ax.legend(fontsize=7) + ax.set_title("Fee Revenue (0.25% fee, 75% LP)") + ax.grid(True, alpha=0.3) + + # 5. Vol/TVL + ax = axes[4] + vol_tvl_obs = v_obs / tvl + vol_tvl_pred = v_total / tvl + ax.plot(dates, vol_tvl_obs * 100, "k-", linewidth=0.5, alpha=0.5, + label="Observed") + ax.plot(dates, vol_tvl_pred * 100, "r-", linewidth=0.5, alpha=0.5, + label="Predicted") + ax.axhline(np.median(vol_tvl_obs) * 100, color="black", linestyle=":", + alpha=0.3) + ax.axhline(np.median(vol_tvl_pred) * 100, color="red", linestyle=":", + alpha=0.3) + ax.set_ylabel("Vol/TVL (%)") + ax.legend(fontsize=8) + ax.set_title("Volume as % of TVL") + ax.grid(True, alpha=0.3) + + # 6. 
Pred/Obs ratio + ax = axes[5] + ratio = v_total / np.maximum(v_obs, 1) + ax.plot(dates, ratio, "b-", linewidth=0.5, alpha=0.5) + ax.axhline(1.0, color="black", linestyle="-", alpha=0.3) + med_ratio = np.median(ratio) + ax.axhline(med_ratio, color="blue", linestyle=":", alpha=0.5, + label=f"median={med_ratio:.2f}") + ax.set_ylabel("Pred / Obs") + ax.set_yscale("log") + ax.set_ylim(0.05, 20) + ax.legend(fontsize=8) + ax.set_title("Prediction Ratio") + ax.grid(True, alpha=0.3) + ax.set_xlabel("Date") + + fig.suptitle(f"MM Noise Model — {label} ({pid[:16]})", fontsize=13) + fig.tight_layout() + out = os.path.join(output_dir, f"{pid[:16]}_mm_fit.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_tvl_response(params, meta, decomp, output_dir): + """Cross-pool TVL response curves showing MM saturation.""" + pool_ids = meta["pool_ids"] + pool_tokens = meta["pool_tokens"] + n_pools = len(pool_ids) + + tvl_range = np.logspace(4, 10, 200) # $10K to $10B + + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + + pool_idx_arr = np.array(decomp["pool_idx"]) + interesting = [] + for i in range(n_pools): + n = (pool_idx_arr == i).sum() + if n > 0: + interesting.append(i) + + colors = plt.cm.tab20(np.linspace(0, 1, len(interesting))) + + # Panel 1: Noise volume vs TVL (absolute) + ax = axes[0] + for ci, i in enumerate(interesting): + K_i = get_pool_K(params, decomp, i) + + mask = pool_idx_arr == i + actual_noise = np.median(decomp["v_noise"][mask]) + actual_tvl = np.median(decomp["tvl"][mask]) + # Scale: noise(TVL) = actual_noise * [TVL/(K+TVL)] / [actual_TVL/(K+actual_TVL)] + mm_actual = actual_tvl / (K_i + actual_tvl) + mm_curve = tvl_range / (K_i + tvl_range) + noise_curve = actual_noise * mm_curve / mm_actual + + label = f"{pool_tokens[i][0]}/{pool_tokens[i][1]}" + ax.plot(tvl_range / 1e6, noise_curve / 1e6, color=colors[ci], + linewidth=1.0, alpha=0.7, label=label) + # Mark actual TVL + ax.scatter([actual_tvl / 1e6], 
[actual_noise / 1e6], + color=colors[ci], s=20, zorder=5) + + ax.set_xlabel("TVL ($M)") + ax.set_ylabel("Daily Noise Volume ($M)") + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_title("Noise Volume vs TVL (MM saturation)") + ax.legend(fontsize=5, ncol=3, loc="best") + ax.grid(True, alpha=0.3) + + # Panel 2: Noise/TVL ratio vs TVL + ax = axes[1] + for ci, i in enumerate(interesting): + K_i = get_pool_K(params, decomp, i) + mask = pool_idx_arr == i + actual_noise = np.median(decomp["v_noise"][mask]) + actual_tvl = np.median(decomp["tvl"][mask]) + mm_actual = actual_tvl / (K_i + actual_tvl) + mm_curve = tvl_range / (K_i + tvl_range) + noise_curve = actual_noise * mm_curve / mm_actual + ratio_curve = noise_curve / tvl_range * 100 + + label = f"{pool_tokens[i][0]}/{pool_tokens[i][1]}" + ax.plot(tvl_range / 1e6, ratio_curve, color=colors[ci], + linewidth=1.0, alpha=0.7, label=label) + + ax.set_xlabel("TVL ($M)") + ax.set_ylabel("Noise / TVL (%)") + ax.set_xscale("log") + ax.set_title("Noise as Fraction of TVL") + ax.legend(fontsize=5, ncol=3, loc="best") + ax.grid(True, alpha=0.3) + + # Panel 3: Elasticity vs TVL + ax = axes[2] + for ci, i in enumerate(interesting): + K_i = get_pool_K(params, decomp, i) + eps_curve = K_i / (K_i + tvl_range) + + label = f"{pool_tokens[i][0]}/{pool_tokens[i][1]}" + ax.plot(tvl_range / 1e6, eps_curve, color=colors[ci], + linewidth=1.0, alpha=0.7, label=label) + # Mark actual TVL + actual_tvl = np.median(decomp["tvl"][pool_idx_arr == i]) + eps_actual = K_i / (K_i + actual_tvl) + ax.scatter([actual_tvl / 1e6], [eps_actual], + color=colors[ci], s=20, zorder=5) + + ax.axhline(0.5, color="gray", linestyle="--", alpha=0.3, label="ε=0.5") + ax.set_xlabel("TVL ($M)") + ax.set_ylabel("Elasticity ε(TVL)") + ax.set_xscale("log") + ax.set_ylim(0, 1.05) + ax.set_title("TVL Elasticity (K/(K+TVL))") + ax.legend(fontsize=5, ncol=3, loc="best") + ax.grid(True, alpha=0.3) + + fig.suptitle("Michaelis-Menten Noise Model — TVL Response", fontsize=13) + 
fig.tight_layout() + out = os.path.join(output_dir, "mm_tvl_response.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_K_distribution(params, meta, decomp, output_dir): + """Plot K values across pools with data quality indicator.""" + pool_ids = meta["pool_ids"] + pool_tokens = meta["pool_tokens"] + n_pools = len(pool_ids) + pool_idx_arr = np.array(decomp["pool_idx"]) + + K_vals = [] + labels = [] + n_days = [] + tvl_ranges = [] + for i in range(n_pools): + mask = pool_idx_arr == i + n = mask.sum() + K_i = get_pool_K(params, decomp, i) + K_vals.append(K_i) + tok = pool_tokens[i] + labels.append(f"{tok[0]}/{tok[1]}") + n_days.append(n) + if n > 0: + tvl = decomp["tvl"][mask] + tvl_ranges.append(np.log10(tvl.max() / max(tvl.min(), 1))) + else: + tvl_ranges.append(0) + + K_vals = np.array(K_vals) + n_days = np.array(n_days) + tvl_ranges = np.array(tvl_ranges) + + # Sort by K + order = np.argsort(K_vals) + + fig, axes = plt.subplots(1, 2, figsize=(16, 8)) + + # Panel 1: K values as horizontal bar + ax = axes[0] + y_pos = np.arange(n_pools) + colors = plt.cm.viridis(tvl_ranges[order] / max(tvl_ranges.max(), 1)) + ax.barh(y_pos, K_vals[order] / 1e6, color=colors, height=0.7) + ax.set_yticks(y_pos) + ax.set_yticklabels([labels[i] for i in order], fontsize=7) + ax.set_xlabel("K ($M)") + ax.set_title("Half-Saturation TVL by Pool\n(color = log10 TVL range)") + ax.axvline(np.median(K_vals) / 1e6, color="red", linestyle="--", + alpha=0.5, label=f"median=${np.median(K_vals)/1e6:.1f}M") + ax.legend() + ax.grid(True, alpha=0.3, axis="x") + + # Panel 2: K vs data quality (n_days and TVL range) + ax = axes[1] + valid = n_days > 0 + sc = ax.scatter(n_days[valid], K_vals[valid] / 1e6, + c=tvl_ranges[valid], cmap="viridis", + s=50, alpha=0.7) + for i in range(n_pools): + if n_days[i] > 0: + ax.annotate(labels[i], (n_days[i], K_vals[i] / 1e6), + fontsize=5, alpha=0.7) + ax.set_xlabel("Number of training days") + 
ax.set_ylabel("K ($M)") + ax.set_title("K vs Data Quantity\n(color = log10 TVL range)") + plt.colorbar(sc, ax=ax, label="log10(TVL_max/TVL_min)") + ax.grid(True, alpha=0.3) + + fig.suptitle("Michaelis-Menten K Distribution", fontsize=13) + fig.tight_layout() + out = os.path.join(output_dir, "mm_K_distribution.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--pool", default=None, + help="Plot single pool (prefix match)") + parser.add_argument("--all-pools", action="store_true") + parser.add_argument("--artifact-dir", default="results/mm_noise") + parser.add_argument("--output-dir", default="results/mm_noise/plots") + parser.add_argument("--top-n", type=int, default=None, + help="Plot top N pools by sample count (default: all)") + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + os.makedirs(args.output_dir, exist_ok=True) + + # Load + print("Loading model...") + params, meta = load_model(args.artifact_dir) + + print("Loading data...") + with open(os.path.join(CACHE_DIR, "stage1.pkl"), "rb") as f: + stage1 = pickle.load(f) + matched_clean = stage1["matched_clean"] + option_c_clean = stage1["option_c_clean"] + + print("Computing decomposition...") + decomp = compute_decomposition(params, meta, matched_clean, option_c_clean) + + pool_ids = decomp["pool_ids"] + pool_idx = decomp["pool_idx"] + + # Which pools to plot + if args.pool: + targets = [i for i, pid in enumerate(pool_ids) + if pid.startswith(args.pool)] + elif args.top_n is not None: + counts = [(pool_idx == i).sum() for i in range(len(pool_ids))] + targets = sorted(range(len(pool_ids)), key=lambda i: -counts[i]) + targets = targets[:args.top_n] + else: + # Default: all pools + targets = list(range(len(pool_ids))) + + # Per-pool time series + print(f"\nPlotting {len(targets)} 
pools...") + for i in targets: + plot_pool_timeseries(decomp, params, i, args.output_dir) + + # TVL response curves + print("\nPlotting TVL response...") + plot_tvl_response(params, meta, decomp, args.output_dir) + + # K distribution + print("\nPlotting K distribution...") + plot_K_distribution(params, meta, decomp, args.output_dir) + + print(f"\nDone. Plots in {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_model_vs_real_reclamm.py b/scripts/plot_model_vs_real_reclamm.py new file mode 100644 index 0000000..c0c3da3 --- /dev/null +++ b/scripts/plot_model_vs_real_reclamm.py @@ -0,0 +1,262 @@ +"""Plot full model (V_arb + V_noise) vs real observed volume for a pool. + +Uses the pool's actual historical TVL path, evaluates V_arb from the PCHIP +grid at the learned cadence, and V_noise from the per-pool linear model. +Compares against observed total volume. + +Usage: + python scripts/plot_model_vs_real_reclamm.py + python scripts/plot_model_vs_real_reclamm.py --pool 0x3de27efa2f1aa6 + python scripts/plot_model_vs_real_reclamm.py --pool 0x9d1fcf346ea1b0 +""" + +import argparse +import os +import pickle +import sys + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +import jax.numpy as jnp +from quantammsim.calibration.grid_interpolation import interpolate_pool_daily +from quantammsim.calibration.noise_model_arrays import load_artifact + + +CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) +ARTIFACT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "linear_market_noise", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "model_vs_real", +) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + 
formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--pool", default="0x9d1fcf346ea1b0", + help="Pool ID prefix") + parser.add_argument("--artifact-dir", default=ARTIFACT_DIR) + parser.add_argument("--output-dir", default=OUTPUT_DIR) + args = parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + os.makedirs(args.output_dir, exist_ok=True) + + # Load data + with open(os.path.join(CACHE_DIR, "stage1.pkl"), "rb") as f: + data = pickle.load(f) + mc = data["matched_clean"] + oc = data["option_c_clean"] + + pid = next(p for p in mc if p.startswith(args.pool)) # resolve ID prefix to the full pool ID + entry = mc[pid] + panel = entry["panel"] + dates = pd.to_datetime(panel["date"]) + vol_obs = np.exp(panel["log_volume"].values.astype(float)) + tvl = np.exp(panel["log_tvl_lag1"].values.astype(float)) + + # Load noise model + art, meta = load_artifact(args.artifact_dir) + pool_ids = meta["pool_ids"] + idx = pool_ids.index(pid) + coeffs = art["noise_coeffs"][idx] + cadence = float(np.exp(art["log_cadence"][idx])) + gas = float(np.exp(oc[pid]["log_gas"])) + + print(f"Pool: {pid} ({entry['tokens']}, {entry['chain']})") + print(f"Cadence: {cadence:.1f} min, Gas: ${gas}") + print(f"{len(dates)} days: {dates.min().date()} → {dates.max().date()}") + print(f"TVL: ${tvl.min():,.0f} – ${tvl.max():,.0f}") + + # V_arb from PCHIP + v_arb_all = np.array(interpolate_pool_daily( + entry["coeffs"], jnp.float64(np.log(cadence)), jnp.float64(gas))) + + # V_noise from model at actual TVL + from experiments.run_linear_market_noise import build_data + data_full = build_data(mc, oc, trend_windows=(7,), + include_market=True, include_cross_pool=False) + x_full = data_full["x"] + pool_idx_full = data_full["pool_idx"] + pool_mask = pool_idx_full == idx + sample_x = x_full[pool_mask] + sgd = data_full["sample_grid_days"][pool_mask] + day_idx = data_full["day_idx"][pool_mask] + + log_v_noise = sample_x @ coeffs + v_noise = np.exp(log_v_noise) + v_arb_samples = v_arb_all[sgd] + v_total_pred = v_arb_samples + v_noise + + # Align dates + 
all_dates = set() + for p in pool_ids: + all_dates.update(mc[p]["panel"]["date"].values) + date_list = sorted(all_dates) + sample_dates = np.array([pd.Timestamp(date_list[d]) for d in day_idx]) + + # Match TVL and obs volume + tvl_samples = np.zeros(len(sample_dates)) + vol_obs_samples = np.zeros(len(sample_dates)) + for i, sd in enumerate(sample_dates): + matches = np.where(dates == sd)[0] + if len(matches) > 0: + tvl_samples[i] = tvl[matches[0]] + vol_obs_samples[i] = vol_obs[matches[0]] + + valid = tvl_samples > 100 + sd = sample_dates[valid] + vo = vol_obs_samples[valid] + va = v_arb_samples[valid] + vn = v_noise[valid] + vt = v_total_pred[valid] + tv = tvl_samples[valid] + + # R² + log_obs = np.log(np.maximum(vo, 1)) + log_pred = np.log(np.maximum(vt, 1)) + ss_res = np.sum((log_obs - log_pred) ** 2) + ss_tot = np.sum((log_obs - log_obs.mean()) ** 2) + r2 = 1 - ss_res / max(ss_tot, 1e-10) + + print(f"\nR² (log): {r2:.3f}") + print(f"Median obs: ${np.median(vo):,.0f}, pred: ${np.median(vt):,.0f}") + print(f"Median V_arb: ${np.median(va):,.0f}, V_noise: ${np.median(vn):,.0f}") + + # Fee rate + fee_rate = float(panel["swap_fee"].iloc[0]) if "swap_fee" in panel.columns else 0.003 + + # Plot + fig, axes = plt.subplots(6, 1, figsize=(14, 20), sharex=True) + + # 1. TVL + ax = axes[0] + ax.plot(sd, tv, "b-", linewidth=1) + ax.set_ylabel("TVL (USD)") + ax.set_yscale("log") + ax.set_title(f"{entry['tokens']} ({entry['chain']}) — " + f"Model (V_arb + V_noise) vs Observed " + f"[R\u00b2={r2:.3f}, cadence={cadence:.0f}min, fee={fee_rate:.4f}]") + ax.grid(True, alpha=0.3) + + # 2. 
Volume: stacked arb + noise vs observed + ax = axes[1] + ax.fill_between(sd, 0, va, alpha=0.3, color="steelblue", label="V_arb (PCHIP)") + ax.fill_between(sd, va, va + vn, alpha=0.3, color="coral", label="V_noise (model)") + ax.plot(sd, vo, "k-", linewidth=0.8, alpha=0.7, label="V_obs (actual)") + ax.plot(sd, vt, "r--", linewidth=0.8, alpha=0.5, label="V_pred = V_arb + V_noise") + ax.set_ylabel("Volume (USD/day)") + ax.set_yscale("log") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # 3. V_noise only + ax = axes[2] + ax.fill_between(sd, 0, vn, alpha=0.4, color="coral") + ax.plot(sd, vn, "r-", linewidth=0.8, alpha=0.7, label="V_noise (model)") + ax.axhline(np.median(vn), color="red", linestyle="--", alpha=0.5, + label=f"median: ${np.median(vn):,.0f}") + ax.set_ylabel("Noise volume (USD/day)") + ax.set_yscale("log") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # 4. Fee revenue: observed vs predicted + ax = axes[3] + fee_obs = vo * fee_rate + fee_pred = vt * fee_rate + fee_noise_only = vn * fee_rate + fee_arb = va * fee_rate + ax.fill_between(sd, 0, fee_arb, alpha=0.3, color="steelblue", label="Arb fees") + ax.fill_between(sd, fee_arb, fee_pred, alpha=0.3, color="coral", label="Noise fees") + ax.plot(sd, fee_obs, "k-", linewidth=0.8, alpha=0.7, label="Observed fees") + ax.plot(sd, fee_pred, "r--", linewidth=0.8, alpha=0.5, label="Predicted total fees") + ax.plot(sd, fee_noise_only, "m-", linewidth=0.8, alpha=0.6, + label=f"Noise fees only (med=${np.median(fee_noise_only):,.0f})") + ax.set_ylabel("Fee revenue (USD/day)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + # 5. 
Vol/TVL + ax = axes[4] + vol_tvl_obs = vo / tv * 100 + vol_tvl_pred = vt / tv * 100 + ax.plot(sd, vol_tvl_obs, "k-", linewidth=0.8, alpha=0.7, label="Observed") + ax.plot(sd, vol_tvl_pred, "r--", linewidth=0.8, alpha=0.5, label="Predicted") + ax.axhline(np.median(vol_tvl_obs), color="black", linestyle=":", + alpha=0.3, label=f"obs median: {np.median(vol_tvl_obs):.1f}%") + ax.axhline(np.median(vol_tvl_pred), color="red", linestyle=":", + alpha=0.3, label=f"pred median: {np.median(vol_tvl_pred):.1f}%") + ax.set_ylabel("Vol / TVL (%)") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + ax.set_ylim(0, min(np.percentile(vol_tvl_obs, 95) * 2, 200)) + + # 6. Pred/Obs ratio + ax = axes[5] + ratio = vt / np.maximum(vo, 1) + ax.plot(sd, ratio, "g-", linewidth=0.8, alpha=0.7) + ax.axhline(1.0, color="black", linestyle="--", alpha=0.5, label="perfect") + ax.axhline(np.median(ratio), color="red", linestyle="--", alpha=0.5, + label=f"median: {np.median(ratio):.2f}") + ax.set_ylabel("Pred / Obs") + ax.set_xlabel("Date") + ax.set_yscale("log") + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + ax.set_ylim(0.01, 100) + + fig.tight_layout() + out = os.path.join(args.output_dir, f"{pid[:16]}_model_vs_real.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"\nSaved: {out}") + + # Fee summary + fee_obs_total = np.sum(vo * fee_rate) + fee_pred_total = np.sum(vt * fee_rate) + fee_noise_total = np.sum(vn * fee_rate) + fee_arb_total = np.sum(va * fee_rate) + print(f"\nFee revenue (cumulative, fee={fee_rate:.4f}):") + print(f" Observed: ${fee_obs_total:,.0f}") + print(f" Predicted: ${fee_pred_total:,.0f}" + f" (arb: ${fee_arb_total:,.0f}, noise: ${fee_noise_total:,.0f})") + + # Pre/post deposit stats (for reClAMM AAVE/ETH) + pre = sd < pd.Timestamp("2026-01-10") + post = sd >= pd.Timestamp("2026-01-20") + if pre.sum() > 5 and post.sum() > 5: + print(f"\nPre-deposit (before Jan 10):") + print(f" TVL: ${np.median(tv[pre]):,.0f}") + print(f" V_obs: 
${np.median(vo[pre]):,.0f}," + f" V_pred: ${np.median(vt[pre]):,.0f}") + print(f" V_arb: ${np.median(va[pre]):,.0f}," + f" V_noise: ${np.median(vn[pre]):,.0f}") + print(f" Fees obs: ${np.median(vo[pre])*fee_rate:,.0f}/day," + f" pred: ${np.median(vt[pre])*fee_rate:,.0f}/day") + print(f" Pred/Obs: {np.median(vt[pre] / vo[pre]):.2f}") + print(f"Post-deposit (after Jan 20):") + print(f" TVL: ${np.median(tv[post]):,.0f}") + print(f" V_obs: ${np.median(vo[post]):,.0f}," + f" V_pred: ${np.median(vt[post]):,.0f}") + print(f" V_arb: ${np.median(va[post]):,.0f}," + f" V_noise: ${np.median(vn[post]):,.0f}") + print(f" Fees obs: ${np.median(vo[post])*fee_rate:,.0f}/day," + f" pred: ${np.median(vt[post])*fee_rate:,.0f}/day") + print(f" Pred/Obs: {np.median(vt[post] / vo[post]):.2f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_reclamm_optuna_result.py b/scripts/plot_reclamm_optuna_result.py index 35242da..b90ba40 100644 --- a/scripts/plot_reclamm_optuna_result.py +++ b/scripts/plot_reclamm_optuna_result.py @@ -1,15 +1,20 @@ #!/usr/bin/env python3 """Plot reClAMM pool performance from Optuna tuning results. -Reads the SGD-compatible JSON output of tune_reclamm_params.py (or any Optuna -run), extracts the best trial's pool params, re-runs a forward pass over the -full train+test window, and produces a value-over-time plot with on-chain -baselines and cumulative fee revenue. +Reads SGD-compatible JSON output(s) of tune_reclamm_params.py, extracts the +best trial's pool params, re-runs a forward pass over the full train+test +window, and produces a value-over-time plot with on-chain baselines and +cumulative fee revenue. 
Usage: + # Single result python scripts/plot_reclamm_optuna_result.py results/run_.json - python scripts/plot_reclamm_optuna_result.py results/run_.json --output my_plot.png - python scripts/plot_reclamm_optuna_result.py results/run_.json --top-k 3 + + # Multiple results (comparison across objectives / noise models) + python scripts/plot_reclamm_optuna_result.py results/run_*.json + + # Top-3 trials from each result + python scripts/plot_reclamm_optuna_result.py results/run_*.json --top-k 3 """ import argparse @@ -36,12 +41,20 @@ BG = "#162536" TEXT_COLOR = "#E6CE97" +# Extended palette for multi-file comparison COLORS = [ - "#3498db", "#2ecc71", "#e74c3c", # top-k - "#f39c12", # on-chain launch - "#9b59b6", # on-chain current + "#3498db", "#2ecc71", "#e74c3c", "#f39c12", "#9b59b6", + "#1abc9c", "#e67e22", "#2980b9", "#c0392b", "#8e44ad", + "#27ae60", "#d35400", "#16a085", "#f1c40f", "#7f8c8d", ] +# Short labels for objectives +_OBJ_SHORT = { + "daily_log_sharpe": "sharpe", + "returns_over_hodl": "ret/hodl", + "fee_revenue_over_value": "fee_rev", +} + def _plot_order(configs): """Yield (name, meta, color_idx) with baselines first, optimized trials last.""" @@ -60,9 +73,10 @@ def parse_args(): description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter, ) - p.add_argument("results_json", help="Path to run_.json from Optuna") + p.add_argument("results_json", nargs="+", + help="Path(s) to run_.json from Optuna") p.add_argument("--top-k", type=int, default=1, - help="Plot top K trials by objective (default 1)") + help="Plot top K trials per result file (default 1)") p.add_argument("--output", default=None, help="Output PNG path (default: auto-generated)") p.add_argument("--no-onchain", action="store_true", @@ -100,6 +114,18 @@ def extract_pool_params(trial, config): return params +def _noise_model_label(config): + """Short label describing the noise model in the config.""" + nm = config.get("noise_model", "ratio") + if nm != "calibrated": + ntr = 
config.get("noise_trader_ratio", 0.0) + return f"{nm}(ntr={ntr})" + nc = config.get("reclamm_noise_params", {}) + n_coeffs = len(nc) + arb_freq = config.get("arb_frequency", 1) + return f"cal-{n_coeffs}cov(af={arb_freq})" + + def run_full_period(params, config, fees_override=None): """Run forward pass over the full train+test window.""" fees = fees_override if fees_override is not None else config["fees"] @@ -120,19 +146,28 @@ def run_full_period(params, config, fees_override=None): "reclamm_centeredness_scaling": config.get("reclamm_centeredness_scaling", False), "reclamm_learn_arc_length_speed": config.get("reclamm_learn_arc_length_speed", False), } + # Forward noise model settings + if "noise_model" in config: + fp["noise_model"] = config["noise_model"] + if "reclamm_noise_params" in config: + fp["reclamm_noise_params"] = config["reclamm_noise_params"] + if "noise_arrays_path" in config: + fp["noise_arrays_path"] = config["noise_arrays_path"] + if "arb_frequency" in config: + fp["arb_frequency"] = config["arb_frequency"] jax_params = {k: jnp.array(v) for k, v in params.items()} return do_run_on_historic_data(run_fingerprint=fp, params=jax_params) -def plot_results(configs, time_series, hodl_values, config, args): +def plot_results(configs, time_series, hodl_values, ref_config, args): """Two-panel plot: value-over-time + cumulative fee revenue.""" - train_end_str = config["endDateString"] + train_end_str = ref_config["endDateString"] train_end_dt = datetime.strptime(train_end_str, "%Y-%m-%d %H:%M:%S") first_out = next(iter(time_series.values())) n_minutes = len(first_out["value"]) dates = pd.date_range( - start=datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S"), + start=datetime.strptime(ref_config["startDateString"], "%Y-%m-%d %H:%M:%S"), periods=n_minutes, freq="1min", ) step = 1440 @@ -157,7 +192,7 @@ def plot_results(configs, time_series, hodl_values, config, args): vals = np.array(out["value"][::step]) / 1e6 label = f"{name}" if 
"test_objective" in meta: - obj_name = config.get("return_val", "objective") + obj_name = meta.get("obj_name", "objective") label += f" (OOS {obj_name}={meta['test_objective']:.4f})" is_optimized = "On-Chain" not in name ax_val.plot(dates_daily[:len(vals)], vals, @@ -171,21 +206,19 @@ def plot_results(configs, time_series, hodl_values, config, args): ax_val.axvline(x=train_end_dt, color="white", linestyle=":", alpha=0.5, linewidth=1.5) ylims = ax_val.get_ylim() - ax_val.text(train_end_dt - pd.Timedelta(days=10), ylims[1] * 0.97, "Train", + ax_val.text(train_end_dt - pd.Timedelta(days=5), ylims[1] * 0.97, "Train", color="white", alpha=0.6, fontsize=11, ha="right", va="top") - ax_val.text(train_end_dt + pd.Timedelta(days=10), ylims[1] * 0.97, "Test", + ax_val.text(train_end_dt + pd.Timedelta(days=5), ylims[1] * 0.97, "Test", color="white", alpha=0.6, fontsize=11, ha="left", va="top") _style_axis(ax_val) ax_val.set_ylabel("Pool Value ($M USD)", color=TEXT_COLOR, fontsize=12) - tokens_str = "/".join(config["tokens"]) - obj_name = config.get("return_val", "objective") - ntr = config.get("noise_trader_ratio", 0.0) + tokens_str = "/".join(ref_config["tokens"]) ax_val.set_title( - f"reClAMM Optuna-Optimized ({obj_name}, noise={ntr}) — {tokens_str}", + f"reClAMM Optuna Comparison — {tokens_str}", color=TEXT_COLOR, fontsize=13, pad=15, ) - ax_val.legend(loc="upper left", fontsize=9, facecolor=BG, + ax_val.legend(loc="upper left", fontsize=8, facecolor=BG, edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) # ── Panel 2: Cumulative fee revenue ─────────────────────────────── @@ -208,7 +241,7 @@ def plot_results(configs, time_series, hodl_values, config, args): _style_axis(ax_fee) ax_fee.set_ylabel("Cumulative Fee Revenue ($K)", color=TEXT_COLOR, fontsize=12) ax_fee.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) - ax_fee.legend(loc="upper left", fontsize=9, facecolor=BG, + ax_fee.legend(loc="upper left", fontsize=8, facecolor=BG, edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) else: 
ax_val.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) @@ -222,11 +255,11 @@ def plot_results(configs, time_series, hodl_values, config, args): plt.close() -def plot_test_only(configs, time_series, hodl_values, config, args): +def plot_test_only(configs, time_series, hodl_values, ref_config, args): """Test-period plot with all curves normalised to start at 1.0.""" - train_end_str = config["endDateString"] + train_end_str = ref_config["endDateString"] train_end_dt = datetime.strptime(train_end_str, "%Y-%m-%d %H:%M:%S") - start_dt = datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S") + start_dt = datetime.strptime(ref_config["startDateString"], "%Y-%m-%d %H:%M:%S") first_out = next(iter(time_series.values())) n_minutes = len(first_out["value"]) @@ -262,14 +295,12 @@ def plot_test_only(configs, time_series, hodl_values, config, args): ax.axhline(1.0, color="white", linestyle=":", alpha=0.3, linewidth=1) _style_axis(ax) - tokens_str = "/".join(config["tokens"]) - obj_name = config.get("return_val", "objective") - ntr = config.get("noise_trader_ratio", 0.0) - ax.set_title(f"Test Period Only (normalised) — {obj_name}, noise={ntr} — {tokens_str}", + tokens_str = "/".join(ref_config["tokens"]) + ax.set_title(f"Test Period Only (normalised) — {tokens_str}", color=TEXT_COLOR, fontsize=13, pad=15) ax.set_ylabel("Normalised Value", color=TEXT_COLOR, fontsize=12) ax.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) - ax.legend(loc="best", fontsize=9, facecolor=BG, + ax.legend(loc="best", fontsize=8, facecolor=BG, edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) fig.patch.set_facecolor(BG) @@ -281,10 +312,10 @@ def plot_test_only(configs, time_series, hodl_values, config, args): plt.close() -def plot_weights(configs, time_series, config, args): +def plot_weights(configs, time_series, ref_config, args): """Effective weight (value fraction) of token 0 over time.""" - start_dt = datetime.strptime(config["startDateString"], "%Y-%m-%d %H:%M:%S") - train_end_dt = 
datetime.strptime(config["endDateString"], "%Y-%m-%d %H:%M:%S") + start_dt = datetime.strptime(ref_config["startDateString"], "%Y-%m-%d %H:%M:%S") + train_end_dt = datetime.strptime(ref_config["endDateString"], "%Y-%m-%d %H:%M:%S") first_out = next(iter(time_series.values())) n_minutes = len(first_out["value"]) @@ -292,7 +323,7 @@ def plot_weights(configs, time_series, config, args): step = 1440 dates_daily = dates[::step] - token_name = config["tokens"][0] + token_name = ref_config["tokens"][0] fig, ax = plt.subplots(1, 1, figsize=(14, 5)) @@ -310,18 +341,18 @@ def plot_weights(configs, time_series, config, args): ax.axhline(0.5, color="white", linestyle="--", alpha=0.3, linewidth=1) ax.axvline(x=train_end_dt, color="white", linestyle=":", alpha=0.5, linewidth=1.5) ylims = ax.get_ylim() - ax.text(train_end_dt - pd.Timedelta(days=10), ylims[1] * 0.97, "Train", + ax.text(train_end_dt - pd.Timedelta(days=5), ylims[1] * 0.97, "Train", color="white", alpha=0.6, fontsize=11, ha="right", va="top") - ax.text(train_end_dt + pd.Timedelta(days=10), ylims[1] * 0.97, "Test", + ax.text(train_end_dt + pd.Timedelta(days=5), ylims[1] * 0.97, "Test", color="white", alpha=0.6, fontsize=11, ha="left", va="top") _style_axis(ax) - tokens_str = "/".join(config["tokens"]) + tokens_str = "/".join(ref_config["tokens"]) ax.set_title(f"Effective {token_name} Weight — {tokens_str}", color=TEXT_COLOR, fontsize=13, pad=15) ax.set_ylabel(f"{token_name} weight (value fraction)", color=TEXT_COLOR, fontsize=12) ax.set_xlabel("Date", color=TEXT_COLOR, fontsize=12) - ax.legend(loc="best", fontsize=9, facecolor=BG, + ax.legend(loc="best", fontsize=8, facecolor=BG, edgecolor=TEXT_COLOR, labelcolor=TEXT_COLOR) fig.patch.set_facecolor(BG) @@ -346,60 +377,79 @@ def _style_axis(ax): def main(): args = parse_args() - config, trials = load_results(args.results_json) - if args.end_test_date: - config["endTestDateString"] = args.end_test_date - if args.noise_trader_ratio is not None: - 
config["noise_trader_ratio"] = args.noise_trader_ratio - tokens = config["tokens"] - obj_name = config.get("return_val", "objective") - - # Sort trials by penalised objective - trials_sorted = sorted(trials, key=lambda t: t.get("objective", 0), reverse=True) - top_trials = trials_sorted[:args.top_k] - - print("=" * 80) - print(f"reClAMM Optuna Result Plotter — objective: {obj_name}") - print("=" * 80) - print(f" Results: {args.results_json}") + + # ── Load all result files ───────────────────────────────────────── + all_loaded = [] + for path in args.results_json: + config, trials = load_results(path) + if args.end_test_date: + config["endTestDateString"] = args.end_test_date + if args.noise_trader_ratio is not None: + config["noise_trader_ratio"] = args.noise_trader_ratio + all_loaded.append((path, config, trials)) + + # Use first file's config as reference for dates/tokens + ref_config = all_loaded[0][1] + tokens = ref_config["tokens"] + + print("=" * 100) + print(f"reClAMM Optuna Result Plotter — {len(all_loaded)} result file(s)") + print("=" * 100) print(f" Tokens: {'/'.join(tokens)}") - print(f" Train: {config['startDateString']} → {config['endDateString']}") - print(f" Test: {config['endDateString']} → {config['endTestDateString']}") - print(f" Fees: {config['fees']}, Gas: {config.get('gas_cost', 1.0)}") - print(f" Trials: {len(trials)} total, plotting top {len(top_trials)}") + print(f" Train: {ref_config['startDateString']} → {ref_config['endDateString']}") + print(f" Test: {ref_config['endDateString']} → {ref_config['endTestDateString']}") + # ── Build configs dict from all files ───────────────────────────── configs = {} - for i, trial in enumerate(top_trials): - params = extract_pool_params(trial, config) - name = f"#{trial.get('optuna_trial_number', i)} (rank {i+1})" - configs[name] = { - "params": params, - "objective": trial.get("objective", 0), - "train_objective": trial.get("train_objective", 0), - "test_objective": trial.get("test_objective", 0), - 
"train_sharpe": trial.get("train_sharpe", 0), - "validation_sharpe": trial.get("validation_sharpe", 0), - } - print(f"\n {name}:") - print(f" {obj_name}: train={trial.get('train_objective', 0):.4f} " - f"test={trial.get('test_objective', 0):.4f} " - f"penalised={trial.get('objective', 0):.4f}") - print(f" sharpe: train={trial.get('train_sharpe', 0):+.4f} " - f"val={trial.get('validation_sharpe', 0):+.4f}") - for k, v in params.items(): - print(f" {k}: {v:.6g}") + for path, config, trials in all_loaded: + obj_name = config.get("return_val", "objective") + obj_short = _OBJ_SHORT.get(obj_name, obj_name) + noise_label = _noise_model_label(config) + + trials_sorted = sorted(trials, key=lambda t: t.get("objective", 0), reverse=True) + top_trials = trials_sorted[:args.top_k] + + for i, trial in enumerate(top_trials): + params = extract_pool_params(trial, config) + rank_suffix = f" r{i+1}" if args.top_k > 1 else "" + name = f"{obj_short} {noise_label}{rank_suffix}" + configs[name] = { + "params": params, + "config": config, # per-file config for noise model + "objective": trial.get("objective", 0), + "train_objective": trial.get("train_objective", 0), + "test_objective": trial.get("test_objective", 0), + "train_sharpe": trial.get("train_sharpe", 0), + "validation_sharpe": trial.get("validation_sharpe", 0), + "obj_name": obj_name, + } + print(f"\n {name}:") + print(f" {obj_name}: train={trial.get('train_objective', 0):.4f} " + f"test={trial.get('test_objective', 0):.4f} " + f"penalised={trial.get('objective', 0):.4f}") + print(f" sharpe: train={trial.get('train_sharpe', 0):+.4f} " + f"val={trial.get('validation_sharpe', 0):+.4f}") + for k, v in params.items(): + print(f" {k}: {v:.6g}") if not args.no_onchain: - configs["On-Chain (launch)"] = {"params": dict(ONCHAIN_LAUNCH_PARAMS)} - configs["On-Chain (current)"] = {"params": dict(ONCHAIN_CURRENT_PARAMS)} + configs["On-Chain (launch)"] = { + "params": dict(ONCHAIN_LAUNCH_PARAMS), + "config": ref_config, + } + 
configs["On-Chain (current)"] = { + "params": dict(ONCHAIN_CURRENT_PARAMS), + "config": ref_config, + } # ── Full-period runs ────────────────────────────────────────────── - print(f"\n--- Running full-period simulations ({config['startDateString']} → " - f"{config['endTestDateString']}) ---") + print(f"\n--- Running full-period simulations ({ref_config['startDateString']} → " + f"{ref_config['endTestDateString']}) ---") time_series = {} for name, cfg in configs.items(): print(f" {name}...", end=" ", flush=True) - out = run_full_period(cfg["params"], config) + run_config = cfg.get("config", ref_config) + out = run_full_period(cfg["params"], run_config) time_series[name] = out fv = float(out["final_value"]) fr = out.get("fee_revenue") @@ -415,28 +465,29 @@ def main(): ) # ── Plots ───────────────────────────────────────────────────────── - plot_results(configs, time_series, hodl_values, config, args) - plot_test_only(configs, time_series, hodl_values, config, args) - plot_weights(configs, time_series, config, args) + plot_results(configs, time_series, hodl_values, ref_config, args) + plot_test_only(configs, time_series, hodl_values, ref_config, args) + plot_weights(configs, time_series, ref_config, args) # ── Summary table ───────────────────────────────────────────────── - print(f"\n{'=' * 120}") - print(f"SUMMARY — {'/'.join(tokens)} — {obj_name}") - print(f"{'=' * 120}") - hdr = (f"{'Config':<28s} {'Train '+obj_name:>20s} {'Test '+obj_name:>20s} " + print(f"\n{'=' * 130}") + print(f"SUMMARY — {'/'.join(tokens)}") + print(f"{'=' * 130}") + hdr = (f"{'Config':<35s} {'Objective':>12s} {'Train':>10s} {'Test':>10s} " f"{'Train SR':>10s} {'Val SR':>10s} " f"{'PR':>7s} {'Margin':>7s} {'ShiftExp':>10s} {'Full RoH':>10s}") print(hdr) - print("-" * 120) + print("-" * 130) for name, cfg in configs.items(): cp = cfg["params"] fv = float(time_series[name]["final_value"]) full_roh = fv / float(hodl_values[-1]) - 1 print( - f"{name:<28s} " - f"{cfg.get('train_objective', 
float('nan')):>20.4f} " - f"{cfg.get('test_objective', float('nan')):>20.4f} " + f"{name:<35s} " + f"{cfg.get('obj_name', ''):>12s} " + f"{cfg.get('train_objective', float('nan')):>10.4f} " + f"{cfg.get('test_objective', float('nan')):>10.4f} " f"{cfg.get('train_sharpe', float('nan')):>+10.4f} " f"{cfg.get('validation_sharpe', float('nan')):>+10.4f} " f"{cp.get('price_ratio', float('nan')):>7.3f} " @@ -444,7 +495,7 @@ def main(): f"{cp.get('shift_exponent', float('nan')):>10.4g} " f"{full_roh * 100:>+9.2f}%" ) - print("=" * 120) + print("=" * 130) if __name__ == "__main__": diff --git a/scripts/reclamm/compare_reclamm_geometric_noise_runs.py b/scripts/reclamm/compare_reclamm_geometric_noise_runs.py new file mode 100644 index 0000000..240482b --- /dev/null +++ b/scripts/reclamm/compare_reclamm_geometric_noise_runs.py @@ -0,0 +1,679 @@ +"""Compare two geometric reCLAMM runs against matched arb-only baselines. + +This script reuses the same AAVE/ETH reCLAMM fingerprint and parameter wiring as +``compare_reclamm_thermostats.py``, but runs only the geometric interpolation +mode and plots: +1. Share price / TVL over time in absolute USD terms +2. Pool weights over time +3. Estimated gross swap volume over time +4. Noise-model improvement over arb-only over time + +Because these runs use no LP supply changes, share price and TVL are the same +series here. 
+ +Usage: + python scripts/reclamm/compare_reclamm_geometric_noise_runs.py + python scripts/reclamm/compare_reclamm_geometric_noise_runs.py \ + --adjacent-csv scripts/results/...csv --adjacent-row-index 0 +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +from pathlib import Path +from typing import Mapping, Optional, Sequence + +import numpy as np +import pandas as pd + + +DEFAULT_SOURCE_HEATMAP_DESCRIPTION = ( + "market_linear price_ratio_vs_margin heatmap with shift_exp fixed at 0.10" +) +DEFAULT_RUN_SPECS = [ + { + "name": "Green cell near price_ratio 1.31", + "price_ratio": 1.31, + "centeredness_margin": 0.6763157894736842, + "daily_price_shift_exponent": 0.10, + "tvl_usd": 1_000_000.0, + "color": "C0", + "reason": ( + "Geometric noise-model run taken from the positive cell in the " + "market_linear price_ratio-vs-margin heatmap at price_ratio=1.31 and the " + "lower adjacent centeredness row." + ), + }, + { + "name": "Red cell near price_ratio 1.31", + "price_ratio": 1.31, + "centeredness_margin": 0.7210526315789474, + "daily_price_shift_exponent": 0.10, + "tvl_usd": 1_000_000.0, + "color": "C1", + "reason": ( + "Geometric noise-model run taken from the negative cell directly " + "above the green cell in the market_linear price_ratio-vs-margin heatmap " + "at price_ratio=1.31." + ), + }, +] +DEFAULT_OUTPUT_FILE = "reclamm_geometric_noise_pair_compare.png" +VARIANT_STYLES = { + "noise": {"linestyle": "-", "alpha": 0.95, "linewidth": 2.2}, + "arb": {"linestyle": "--", "alpha": 0.85, "linewidth": 2.0}, +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Compare two geometric reCLAMM runs plus matched arb-only baselines. " + "Optionally source the run pair from an adjacent-heatmap CSV row." 
+ ) + ) + parser.add_argument( + "--adjacent-csv", + default=None, + help="Optional adjacent-pairs CSV generated by find_adjacent_heatmap_pairs.py.", + ) + parser.add_argument( + "--adjacent-row-index", + type=int, + default=0, + help="Which row from --adjacent-csv to use. Defaults to the largest-diff row.", + ) + parser.add_argument( + "--output-file", + default=None, + help="Optional PNG output path override.", + ) + return parser.parse_args() + + +def load_runtime_dependencies(): + """Load heavy runtime dependencies only when an actual simulation is needed.""" + thermostat_path = Path(__file__).with_name("compare_reclamm_thermostats.py") + spec = importlib.util.spec_from_file_location( + "reclamm_compare_reclamm_thermostats_runtime", + thermostat_path, + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Could not load compare module from {thermostat_path}") + thermostat_compare = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(thermostat_compare) + except ModuleNotFoundError as exc: # pragma: no cover - depends on local runtime deps + if exc.name == "jax": + raise RuntimeError( + "compare_reclamm_geometric_noise_runs.py requires JAX to execute " + "the simulator runs, but JAX is not available in this environment." 
+ ) from exc + raise + + from quantammsim.runners.jax_runners import do_run_on_historic_data + + return thermostat_compare, do_run_on_historic_data + + +def build_default_base_config(thermostat_compare): + """Return the aggressive AAVE/ETH geometric base config used by the compare script.""" + return dict(thermostat_compare.CONFIGS[1]) + + +def load_adjacent_csv_row(csv_path: Path, row_index: int = 0) -> Mapping[str, object]: + """Load one row from an adjacent-pairs CSV.""" + frame = pd.read_csv(csv_path) + if frame.empty: + raise ValueError(f"Adjacent CSV is empty: {csv_path}") + if not 0 <= row_index < len(frame): + raise IndexError( + f"adjacent-row-index {row_index} is out of range for {len(frame)} rows" + ) + return frame.iloc[row_index].to_dict() + + +def build_run_specs_from_adjacent_row( + row: Mapping[str, object], + csv_path: Optional[Path] = None, + row_index: int = 0, +): + """Convert one adjacent-pairs CSV row into the two run specs this script compares.""" + metric_key = str(row.get("metric_key", "unknown_metric")) + metric_unit = str(row.get("metric_unit", "value")) + pair_slug = str(row.get("pair_slug", "unknown_pair")) + slice_slug = str(row.get("slice_slug", "unknown_slice")) + adjacency_axis = str(row.get("adjacency_axis", "unknown_axis")) + diff_abs = float(row["heatmap_value_diff_abs"]) + source_noise_profile = str(row.get("source_noise_profile", "unknown")) + + source_description = ( + f"{pair_slug} {slice_slug} adjacent pair from {metric_key} " + f"({adjacency_axis}, abs diff={diff_abs:.6f} {metric_unit}, " + f"noise_profile={source_noise_profile})" + ) + if csv_path is not None: + source_description = ( + f"{csv_path.name} row {row_index} | {source_description}" + ) + + run_specs = [] + for prefix, color in (("1", "C0"), ("2", "C1")): + heatmap_value = float(row[f"{prefix}_heatmap_value"]) + spec = { + "name": f"Top diff row cell {prefix}", + "price_ratio": float(row[f"{prefix}_price_ratio"]), + "centeredness_margin": 
float(row[f"{prefix}_centeredness_margin"]), + "daily_price_shift_exponent": float(row[f"{prefix}_daily_price_shift_exponent"]), + "tvl_usd": float(row[f"{prefix}_tvl_usd"]), + "color": color, + "source_noise_profile": source_noise_profile, + "reason": ( + f"Run derived from adjacent heatmap CSV row {row_index}, cell {prefix}. " + f"Source={pair_slug}/{slice_slug}, adjacency={adjacency_axis}, " + f"heatmap_value={heatmap_value:.6f} {metric_unit}, " + f"noise_profile={source_noise_profile}." + ), + } + run_specs.append(spec) + return source_description, run_specs + + +def default_output_file_for_adjacent_csv(csv_path: Path, row_index: int = 0) -> Path: + """Build a deterministic output PNG path for an adjacent-pairs CSV selection.""" + stem = f"{csv_path.stem}_row_{row_index}_geometric_noise_compare" + return csv_path.with_name(stem + ".png") + + +def build_run_config(spec, base_config): + """Build the base geometric reCLAMM config for one highlighted heatmap cell.""" + cfg = dict(base_config) + cfg.update( + { + "name": spec["name"], + "price_ratio": float(spec["price_ratio"]), + "centeredness_margin": float(spec["centeredness_margin"]), + "daily_price_shift_exponent": float(spec["daily_price_shift_exponent"]), + "initial_pool_value": float(spec["tvl_usd"]), + "reason": spec.get( + "reason", + "Standalone geometric noise-model comparison run.", + ), + "enable_noise_model": True, + } + ) + source_noise_profile = spec.get("source_noise_profile") + if source_noise_profile == "market_linear": + cfg["noise_model"] = "market_linear" + elif source_noise_profile not in (None, "", "unknown"): + raise ValueError( + "compare_reclamm_geometric_noise_runs.py only supports the current " + f"market_linear source profile, got {source_noise_profile!r}. " + "Regenerate the adjacent-pairs CSV with the current heatmap cache." 
+ ) + return cfg + + +def build_run_variants(spec, base_config, thermostat_compare): + """Build matched noise-model and arb-only variants for one highlighted cell.""" + noise_cfg = thermostat_compare.make_noise_variant_cfg( + build_run_config(spec, base_config=base_config), + True, + ) + noise_cfg["variant_key"] = "noise" + noise_cfg["variant_label"] = "noise-model" + + arb_cfg = thermostat_compare.make_noise_variant_cfg(noise_cfg, False) + arb_cfg["variant_key"] = "arb" + arb_cfg["variant_label"] = "arb-only" + arb_cfg["name"] = noise_cfg["name"] + arb_cfg["reason"] = ( + f"{noise_cfg['reason']} Matched arb-only baseline with noise disabled." + ) + return {"spec": spec, "noise": noise_cfg, "arb": arb_cfg} + + +def build_run_label(cfg, thermostat_compare): + """Build a compact legend label for a run.""" + return ( + f"{cfg['name']} ({cfg.get('variant_label', 'run')}) | PR {cfg['price_ratio']:.4g}, " + f"M {cfg['centeredness_margin']:.3g}, " + f"Shift {cfg['daily_price_shift_exponent']:.3g}, " + f"TVL {thermostat_compare.format_tvl_millions_label(cfg)}" + ) + + +def build_time_index(run_fingerprint, periods, step_minutes=1): + """Return a DatetimeIndex for a result series with the given cadence.""" + return pd.date_range( + start=pd.Timestamp(run_fingerprint["startDateString"]), + periods=int(periods), + freq=f"{max(int(step_minutes), 1)}min", + ) + + +def infer_series_step_minutes(run_fingerprint, series_length, minute_length): + """Infer the cadence of a result series from its length.""" + if minute_length and series_length and series_length != minute_length: + ratio = float(minute_length) / float(series_length) + rounded_ratio = max(int(round(ratio)), 1) + if abs(ratio - rounded_ratio) < 1.0e-9: + return rounded_ratio + return max(int(run_fingerprint.get("arb_frequency", 1)), 1) + + +def build_daily_weight_series( + weights, + tokens, + run_fingerprint, + minute_length, +): + """Build a daily weight frame from a weight array at inferred cadence.""" + weights = 
np.asarray(weights, dtype=float) + weight_step_minutes = infer_series_step_minutes( + run_fingerprint, + len(weights), + minute_length, + ) + weight_dates = build_time_index( + run_fingerprint, + len(weights), + step_minutes=weight_step_minutes, + ) + weight_frame = pd.DataFrame(weights, index=weight_dates, columns=tokens) + return weight_frame.resample("1D").last() + + +def estimate_gross_volume_usd(cfg, result, default_protocol_fee_split=0.25): + """Recover gross traded USD volume from LP fee revenue.""" + fee_revenue = result.get("fee_revenue") + if fee_revenue is None: + return np.zeros(len(np.asarray(result["value"])), dtype=float) + + fee_revenue = np.asarray(fee_revenue, dtype=float) + lp_fee_rate_share = float(cfg["fees"]) * ( + 1.0 - float(cfg.get("protocol_fee_split", default_protocol_fee_split)) + ) + if lp_fee_rate_share <= 0.0: + return np.zeros_like(fee_revenue) + return fee_revenue / lp_fee_rate_share + + +def build_daily_run_series(cfg, run_fingerprint, result, default_protocol_fee_split=0.25): + """Build daily TVL/share-price, weight, and volume series for plotting.""" + value = np.asarray(result["value"], dtype=float) + value_dates = build_time_index(run_fingerprint, len(value), step_minutes=1) + value_series = pd.Series(value, index=value_dates) + daily_value = value_series.resample("1D").last() + + zero_fee_weights = np.asarray(result["weights"], dtype=float) + daily_zero_fee_weights = build_daily_weight_series( + zero_fee_weights, + cfg["tokens"], + run_fingerprint, + len(value), + ) + + reserves = np.asarray(result["reserves"], dtype=float) + prices = np.asarray(result["prices"], dtype=float) + reserve_value = reserves * prices + reserve_value_totals = np.maximum( + reserve_value.sum(axis=1, keepdims=True), + 1.0e-12, + ) + actual_reserve_value_weights = reserve_value / reserve_value_totals + daily_actual_reserve_value_weights = build_daily_weight_series( + actual_reserve_value_weights, + cfg["tokens"], + run_fingerprint, + len(value), + ) + + 
gross_volume_usd = estimate_gross_volume_usd( + cfg, + result, + default_protocol_fee_split=default_protocol_fee_split, + ) + volume_step_minutes = ( + 1 + if len(gross_volume_usd) == len(value) + else infer_series_step_minutes(run_fingerprint, len(gross_volume_usd), len(value)) + ) + volume_dates = build_time_index( + run_fingerprint, + len(gross_volume_usd), + step_minutes=volume_step_minutes, + ) + daily_volume = pd.Series(gross_volume_usd, index=volume_dates).resample("1D").sum() + + return { + "daily_value": daily_value, + "daily_zero_fee_weights": daily_zero_fee_weights, + "daily_actual_reserve_value_weights": daily_actual_reserve_value_weights, + "daily_volume": daily_volume, + } + + +def _terminal_json_default(value): + """Serialize NumPy/JAX-backed values for readable terminal logging.""" + if isinstance(value, Path): + return str(value) + if hasattr(value, "tolist"): + return value.tolist() + return str(value) + + +def print_run_inputs_to_terminal(cfg, run_fingerprint, update_params): + """Print the full run fingerprint and update params for a triggered run.""" + print( + f"Run inputs for {cfg['name']} ({cfg.get('variant_label', 'run')}):" + ) + print( + json.dumps( + { + "run_fingerprint": run_fingerprint, + "update_params": update_params, + }, + indent=2, + sort_keys=True, + default=_terminal_json_default, + ) + ) + + +def run_single_config(cfg, price_data, thermostat_compare, do_run_on_historic_data): + """Run one geometric noise-model configuration.""" + run_fingerprint = thermostat_compare.make_fingerprint(cfg, "geometric") + update_params = thermostat_compare.make_params(cfg) + print_run_inputs_to_terminal(cfg, run_fingerprint, update_params) + result = do_run_on_historic_data( + run_fingerprint=run_fingerprint, + params=update_params, + price_data=price_data, + ) + return { + "config": cfg, + "fingerprint": run_fingerprint, + "noise_summary": thermostat_compare.resolve_reclamm_noise_settings(cfg)[ + "noise_summary" + ], + "result": result, + 
"series": build_daily_run_series( + cfg, + run_fingerprint, + result, + default_protocol_fee_split=thermostat_compare.DEFAULT_PROTOCOL_FEE_SPLIT, + ), + } + + +def plot_pair_results( + run_pairs, + thermostat_compare, + source_heatmap_description, + output_file, +): + """Plot paired noise-model/arb-only outputs for the two highlighted cells.""" + import matplotlib.pyplot as plt + + fig, axes = plt.subplots( + 5, + 1, + figsize=(14, 16), + sharex=True, + gridspec_kw={"height_ratios": [2.2, 1.5, 1.5, 1.4, 1.6]}, + ) + ( + ax_value, + ax_zero_fee_weights, + ax_actual_reserve_weights, + ax_volume, + ax_improvement, + ) = axes + + for pair in run_pairs: + spec = pair["spec"] + color = spec["color"] + noise_output = pair["noise"] + arb_output = pair["arb"] + + for variant_key, output in (("noise", noise_output), ("arb", arb_output)): + cfg = output["config"] + label = build_run_label(cfg, thermostat_compare=thermostat_compare) + series = output["series"] + style = VARIANT_STYLES[variant_key] + + ax_value.plot( + series["daily_value"].index, + series["daily_value"].to_numpy(dtype=float) / 1e6, + color=color, + linestyle=style["linestyle"], + linewidth=style["linewidth"], + alpha=style["alpha"], + label=label, + ) + + for token_idx, token in enumerate(cfg["tokens"]): + token_linestyle = style["linestyle"] if token_idx == 0 else ":" + ax_zero_fee_weights.plot( + series["daily_zero_fee_weights"].index, + series["daily_zero_fee_weights"][token].to_numpy(dtype=float), + color=color, + linestyle=token_linestyle, + linewidth=1.8 if token_idx == 0 else 1.6, + alpha=style["alpha"], + label=f"{cfg['name']} {cfg['variant_label']} {token}", + ) + ax_actual_reserve_weights.plot( + series["daily_actual_reserve_value_weights"].index, + series["daily_actual_reserve_value_weights"][token].to_numpy( + dtype=float + ), + color=color, + linestyle=token_linestyle, + linewidth=1.8 if token_idx == 0 else 1.6, + alpha=style["alpha"], + label=f"{cfg['name']} {cfg['variant_label']} {token}", + ) + 
+ ax_volume.plot( + series["daily_volume"].index, + series["daily_volume"].to_numpy(dtype=float) / 1e6, + color=color, + linestyle=style["linestyle"], + linewidth=style["linewidth"], + alpha=style["alpha"], + label=f"{cfg['name']} {cfg['variant_label']}", + ) + + noise_value, arb_value = noise_output["series"]["daily_value"].align( + arb_output["series"]["daily_value"], + join="inner", + ) + improvement_pct = (noise_value - arb_value) / arb_value * 100.0 + ax_improvement.plot( + improvement_pct.index, + improvement_pct.to_numpy(dtype=float), + color=color, + linewidth=2.2, + label=noise_output["config"]["name"], + ) + + shared_noise_summary = run_pairs[0]["noise"]["noise_summary"] + fig.suptitle( + "reCLAMM geometric noise-model vs arb-only comparison", + fontsize=14, + fontweight="bold", + ) + fig.text( + 0.5, + 0.965, + ( + "Compare-script AAVE/ETH fingerprint | " + f"Cell source: {source_heatmap_description} | " + f"Noise: {shared_noise_summary} | " + "Share price equals TVL here because LP supply is fixed at 1.0" + ), + ha="center", + va="top", + fontsize=10, + ) + + ax_value.set_ylabel("Share price / TVL ($M)") + ax_value.set_title("Absolute share price / TVL") + ax_value.grid(True, alpha=0.3) + ax_value.legend(fontsize=8) + + ax_zero_fee_weights.set_ylabel("Weight") + ax_zero_fee_weights.set_title("Reported zero-fee empirical weights") + ax_zero_fee_weights.set_ylim(-0.02, 1.02) + ax_zero_fee_weights.grid(True, alpha=0.3) + ax_zero_fee_weights.legend(fontsize=8, ncol=2) + + ax_actual_reserve_weights.set_ylabel("Weight") + ax_actual_reserve_weights.set_title("Actual reserve value weights") + ax_actual_reserve_weights.set_ylim(-0.02, 1.02) + ax_actual_reserve_weights.grid(True, alpha=0.3) + ax_actual_reserve_weights.legend(fontsize=8, ncol=2) + + ax_volume.set_ylabel("Daily volume ($M)") + ax_volume.set_title("Estimated gross swap volume") + ax_volume.grid(True, alpha=0.3) + ax_volume.legend(fontsize=8) + + ax_improvement.axhline(0.0, color="black", 
linewidth=0.9, alpha=0.55) + ax_improvement.set_ylabel("Noise vs arb (%)") + ax_improvement.set_title("Daily TVL improvement: (noise - arb) / arb") + ax_improvement.set_xlabel("Date") + ax_improvement.grid(True, alpha=0.3) + ax_improvement.legend(fontsize=8) + + output_file = Path(output_file) + output_file.parent.mkdir(parents=True, exist_ok=True) + plt.tight_layout(rect=(0.0, 0.0, 1.0, 0.945)) + plt.savefig(output_file, dpi=180) + print(f"Saved {output_file}") + plt.close(fig) + + +def print_final_heatmap_summary(run_pairs, source_heatmap_description): + """Print the final values and heatmap-equivalent improvement for each pair.""" + print(f"\nSource heatmap selection: {source_heatmap_description}") + print("Final values represented by the heatmap metric:") + for pair in run_pairs: + spec = pair["spec"] + noise_final = float(pair["noise"]["series"]["daily_value"].iloc[-1]) + arb_final = float(pair["arb"]["series"]["daily_value"].iloc[-1]) + improvement_pct = (noise_final - arb_final) / arb_final * 100.0 + print( + f" {spec['name']}: " + f"noise=${noise_final:,.2f}, " + f"arb=${arb_final:,.2f}, " + f"heatmap_improvement={improvement_pct:.6f}%" + ) + + +def run_pair_comparison( + run_specs: Sequence[Mapping[str, object]], + source_heatmap_description: str, + output_file, +): + """Run both configs and render the comparison figure.""" + thermostat_compare, do_run_on_historic_data = load_runtime_dependencies() + base_config = build_default_base_config(thermostat_compare) + run_pairs = [ + build_run_variants( + spec, + base_config=base_config, + thermostat_compare=thermostat_compare, + ) + for spec in run_specs + ] + run_configs = [ + cfg + for pair in run_pairs + for variant_key, cfg in pair.items() + if variant_key in ("noise", "arb") + ] + price_data = thermostat_compare.load_shared_price_data(run_configs) + + completed_pairs = [] + for pair in run_pairs: + completed_pair = {"spec": pair["spec"]} + for variant_key in ("noise", "arb"): + cfg = pair[variant_key] + 
print( + f"Running {cfg['name']} ({cfg['variant_label']}) | " + f"price_ratio={cfg['price_ratio']}, " + f"margin={cfg['centeredness_margin']}, " + f"shift_exp={cfg['daily_price_shift_exponent']}, " + f"TVL={thermostat_compare.format_tvl_millions_label(cfg)}" + ) + completed_pair[variant_key] = run_single_config( + cfg, + price_data, + thermostat_compare=thermostat_compare, + do_run_on_historic_data=do_run_on_historic_data, + ) + completed_pairs.append(completed_pair) + + plot_pair_results( + completed_pairs, + thermostat_compare=thermostat_compare, + source_heatmap_description=source_heatmap_description, + output_file=output_file, + ) + print_final_heatmap_summary( + completed_pairs, + source_heatmap_description=source_heatmap_description, + ) + return output_file + + +def run_adjacent_csv_row_comparison( + csv_path, + row_index: int = 0, + output_file=None, +): + """Run the standard geometric comparison using one adjacent-pairs CSV row.""" + csv_path = Path(csv_path) + row = load_adjacent_csv_row(csv_path, row_index=row_index) + source_heatmap_description, run_specs = build_run_specs_from_adjacent_row( + row, + csv_path=csv_path, + row_index=row_index, + ) + resolved_output_file = ( + Path(output_file) + if output_file is not None + else default_output_file_for_adjacent_csv(csv_path, row_index=row_index) + ) + return run_pair_comparison( + run_specs=run_specs, + source_heatmap_description=source_heatmap_description, + output_file=resolved_output_file, + ) + + +def main(cli_args: Optional[argparse.Namespace] = None): + """Entry point for CLI execution.""" + args = cli_args or parse_args() + if args.adjacent_csv: + return run_adjacent_csv_row_comparison( + args.adjacent_csv, + row_index=args.adjacent_row_index, + output_file=args.output_file, + ) + + output_file = Path(args.output_file) if args.output_file else Path(DEFAULT_OUTPUT_FILE) + return run_pair_comparison( + run_specs=DEFAULT_RUN_SPECS, + source_heatmap_description=DEFAULT_SOURCE_HEATMAP_DESCRIPTION, + 
output_file=output_file, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/reclamm/compare_reclamm_thermostats.py b/scripts/reclamm/compare_reclamm_thermostats.py index 8a2c374..f705c66 100644 --- a/scripts/reclamm/compare_reclamm_thermostats.py +++ b/scripts/reclamm/compare_reclamm_thermostats.py @@ -1,19 +1,40 @@ -"""Compare geometric vs constant-arc-length thermostats on historic data. +"""Compare reCLAMM interpolation modes on historic AAVE/ETH data. -Runs AAVE/ETH reClAMM pool simulations with both interpolation methods. -Plots: pool value, cumulative LVR, price path, empirical weights, -value difference, LVR ratio, and per-step LVR distribution (∝ Δs²). +Runs the production geometric interpolation against the non-linear +constant-arc-length interpolation on: +1. The original launch-style range (price_ratio ~= 1.50) +2. A much tighter range (price_ratio = 1.10) -Usage: - cd - source ~/miniconda3/etc/profile.d/conda.sh && conda activate qsim-reclamm - python scripts/compare_reclamm_thermostats.py +The aggressive case is deliberate. A local AAVE/ETH sweep showed: +price_ratio 1.15, margin 0.5, shift 0.1 -> about +$10k vs geometric +price_ratio 1.10, margin 0.5, shift 0.1 -> about +$31k vs geometric +price_ratio 1.10, margin 0.6, shift 0.1 -> about +$73k vs geometric + +So the strongest clean demo setting came from tightening the band and +slightly raising the trigger margin, while keeping the launch-style shift +speed rather than pushing shift_exponent higher. 
""" +import gc +import hashlib +import os +from pathlib import Path + import jax.numpy as jnp import numpy as np +import pandas as pd import matplotlib.pyplot as plt +from matplotlib.colors import Normalize, SymLogNorm, TwoSlopeNorm +from matplotlib.cm import ScalarMappable +from quantammsim.pools.reCLAMM.reclamm_reserves import ( + calibrate_arc_length_speed, + compute_price_ratio, + initialise_reclamm_reserves, +) from quantammsim.runners.jax_runners import do_run_on_historic_data +from quantammsim.utils.data_processing.historic_data_utils import ( + get_historic_parquet_data, +) def to_daily_price_shift_base(daily_price_shift_exponent): @@ -21,61 +42,515 @@ def to_daily_price_shift_base(daily_price_shift_exponent): return 1.0 - daily_price_shift_exponent / 124649.0 +def build_inclusive_sweep(start, stop, step): + """Build a sweep that keeps the requested step and explicitly includes the stop.""" + values = np.arange(start, stop + 1.0e-12, step, dtype=float) + if values.size == 0 or not np.isclose(values[-1], stop): + values = np.append(values, float(stop)) + return values + + +def _resolve_repo_root(script_path): + """Locate the repository root from either scripts/ or scripts/reclamm/.""" + script_path = Path(script_path).resolve() + for parent in script_path.parents: + if (parent / "quantammsim").exists() and (parent / "scripts").exists(): + return parent + return script_path.parents[1] + + +RUN_CONSTANT_ARC_LENGTH = True +INTERPOLATION_METHODS = ( + ("geometric", "constant_arc_length") + if RUN_CONSTANT_ARC_LENGTH + else ("geometric",) +) +HEATMAP_PRICE_RATIOS = build_inclusive_sweep(1.01, 3.00, 0.025) +HEATMAP_MARGINS = np.linspace(0.05, 0.90, 39) +HEATMAP_SHIFT_EXPONENTS = build_inclusive_sweep(0.01, 0.50, 0.0125) +HEATMAP_ARC_LENGTH_SPEEDS = np.geomspace(1.0e-6, 5.0e-4, 11) +PRICE_RATIO_TICKS = np.array([1.01, 1.25, 1.50, 2.00, 2.50, 3.00]) +MARGIN_TICKS = np.array([0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.90]) +SHIFT_EXPONENT_TICKS = 
np.array([0.01, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50]) +ARC_LENGTH_SPEED_TICKS = np.array([ + 1.0e-6, + 2.0e-6, + 5.0e-6, + 1.0e-5, + 2.0e-5, + 5.0e-5, + 1.0e-4, + 2.0e-4, + 5.0e-4, +]) +SWEEP_LINE_WIDTH = 0.45 +REFERENCE_LINE_WIDTH = 0.9 +DEFAULT_INITIAL_POOL_VALUE = 1_000_000.0 +TVL_SWEEP_VALUES = ( + 1_000_000.0, + 5_000_000.0, + 20_000_000.0, +) +CENTER_ZERO_HEATMAP_COLOR_NORM = "symlog" +CENTER_ZERO_HEATMAP_COLOR_TAG = "symlog20" +CENTER_ZERO_HEATMAP_SYMLOG_LINTHRESH = 20.0 +FIXED_SLICE_FRACTIONS = (0.125, 0.375, 0.625, 0.875) +FIXED_SLICE_LABELS = ("Q1", "Q2", "Q3", "Q4") +THREE_D_VIEW_ELEVATION = 22.0 +THREE_D_VIEW_AZIMUTH = 140.0 +HEATMAP_FORWARD_CACHE_ENABLED = True +HEATMAP_FORWARD_CACHE_RUN_NAME = "aave_eth_thermostat_heatmaps_market_linear_v2" +HEATMAP_FORWARD_CACHE_ROOT = os.path.join( + "results", + "reclamm_heatmap_forward_cache", +) +HEATMAP_FORWARD_CACHE_FLUSH_EVERY = 32 + +REPO_ROOT = _resolve_repo_root(__file__) +AAVE_WETH_POOL_ID = "0x9d1fcf346ea1b0" +DEFAULT_MARKET_LINEAR_ARTIFACT_DIR = "results/linear_market_noise" +DEFAULT_MARKET_LINEAR_NOISE_START_DATE = "2024-06-01" +DEFAULT_MARKET_LINEAR_NOISE_END_DATE = "2026-03-01" +DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH = str( + REPO_ROOT + / "results" + / "linear_market_noise" + / "_sim_arrays" + / ( + f"{AAVE_WETH_POOL_ID}_{DEFAULT_MARKET_LINEAR_NOISE_START_DATE}_" + f"{DEFAULT_MARKET_LINEAR_NOISE_END_DATE}.npz" + ) +) +DEFAULT_NOISE_MODEL = "market_linear" +DEFAULT_GAS_COST = 1.0 +DEFAULT_PROTOCOL_FEE_SPLIT = 0.25 +FIXED_COMPARE_ARB_FREQUENCY = 15 +AAVE_ETH_NOISE_SETTINGS = { + "enable_noise_model": True, + "noise_model": DEFAULT_NOISE_MODEL, + "noise_reference_model": DEFAULT_NOISE_MODEL, + "noise_artifact_dir": DEFAULT_MARKET_LINEAR_ARTIFACT_DIR, + "noise_pool_id": AAVE_WETH_POOL_ID, + "noise_arrays_path": DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, + "arb_frequency": FIXED_COMPARE_ARB_FREQUENCY, + "gas_cost": DEFAULT_GAS_COST, + "protocol_fee_split": DEFAULT_PROTOCOL_FEE_SPLIT, +} 
+PERSISTED_FORWARD_VALUE_COLUMNS = ( + "cache_key_hash", + "final_value", + "method", + "enable_noise_model", + "noise_model", + "price_ratio", + "centeredness_margin", + "daily_price_shift_exponent", + "initial_pool_value", + "arb_frequency", +) + +GEOMETRIC_ONLY_HEATMAP_METRIC_KEYS = ( + "geometric_vs_launch_geometric_pct", + "noise_geometric_final_value_musd", + "noise_vs_arb_geometric_improvement_pct", +) +CONSTANT_ARC_HEATMAP_METRIC_KEYS = ( + "efficiency_pct", + "launch_geometric_efficiency_pct", + "constant_arc_vs_launch_constant_arc_pct", + "noise_constant_arc_final_value_musd", + "noise_vs_arb_constant_arc_improvement_pct", +) +HEATMAP_METRIC_DEPENDENCIES = { + "efficiency_pct": ("noise_geometric", "noise_constant_arc"), + "launch_geometric_efficiency_pct": ("noise_constant_arc",), + "geometric_vs_launch_geometric_pct": ("noise_geometric",), + "constant_arc_vs_launch_constant_arc_pct": ("noise_constant_arc",), + "noise_geometric_final_value_musd": ("noise_geometric",), + "noise_constant_arc_final_value_musd": ("noise_constant_arc",), + "noise_vs_arb_geometric_improvement_pct": ("noise_geometric", "arb_geometric"), + "noise_vs_arb_constant_arc_improvement_pct": ( + "noise_constant_arc", + "arb_constant_arc", + ), +} + +_NOISE_SETTINGS_CACHE = {} +_MARKET_LINEAR_NOISE_DATA_CACHE = {} + + +def get_initial_pool_value(cfg): + """Return the configured base pool TVL in USD.""" + return float(cfg.get("initial_pool_value", DEFAULT_INITIAL_POOL_VALUE)) + + +def get_tvl_millions(cfg): + """Return the configured base pool TVL in millions of USD.""" + return get_initial_pool_value(cfg) / 1_000_000.0 + + +def format_tvl_millions_slug(cfg): + """Format the TVL in millions for stable filenames.""" + tvl_millions = get_tvl_millions(cfg) + rounded = round(float(tvl_millions), 6) + if np.isclose(rounded, round(rounded)): + return f"{int(round(rounded))}m" + return f"{rounded:.6f}".rstrip("0").rstrip(".").replace(".", "p") + "m" + + +def format_tvl_millions_label(cfg): + 
"""Format the TVL in millions for plot titles and logs.""" + return f"{get_tvl_millions(cfg):.1f}M" + + +def tvl_artifact_filename(stem, cfg, suffix=None): + """Append a TVL-in-millions suffix to a PNG artifact name.""" + parts = [stem] + if suffix: + parts.append(suffix) + parts.append(f"tvl_{format_tvl_millions_slug(cfg)}") + return "_".join(parts) + ".png" + + +def heatmap_artifact_filename(spec, cfg, suffix=None): + """Build a heatmap filename, including any colour-style tag.""" + stem = f"reclamm_heatmap_{spec['slug']}" + artifact_tag = spec.get("artifact_tag") + if artifact_tag: + stem = f"{stem}_{artifact_tag}" + return tvl_artifact_filename(stem, cfg, suffix=suffix) + + +def three_d_heatmap_artifact_filename(spec, cfg, suffix=None): + """Build a 3D heatmap filename, including any colour-style tag.""" + stem = f"reclamm_heatmap_3d_{spec['slug']}" + artifact_tag = spec.get("artifact_tag") + if artifact_tag: + stem = f"{stem}_{artifact_tag}" + return tvl_artifact_filename(stem, cfg, suffix=suffix) + + +def format_heatmap_param_value(value): + """Format a sweep parameter compactly for titles and logs.""" + value = float(value) + if abs(value) >= 1.0: + return f"{value:.2f}".rstrip("0").rstrip(".") + return f"{value:.3f}".rstrip("0").rstrip(".") + + +def configs_for_tvl(base_configs, initial_pool_value): + """Attach a shared initial TVL to each compare configuration.""" + configs = [] + for cfg in base_configs: + updated = dict(cfg) + updated["initial_pool_value"] = float(initial_pool_value) + configs.append(updated) + return configs + + +def _normalize_arb_frequency(value, default=FIXED_COMPARE_ARB_FREQUENCY): + """Return a stable integer arb cadence for thermostat comparisons.""" + if value is None: + if default is None: + return None + value = default + return max(int(round(float(value))), 1) + + +def get_effective_arb_frequency(cfg, noise_cfg=None): + """Resolve the arb cadence used by a thermostat comparison run.""" + del noise_cfg + return 
_normalize_arb_frequency(FIXED_COMPARE_ARB_FREQUENCY) + + +def _canonical_noise_reference_model(cfg): + """Resolve the only supported thermostat noise parametrisation.""" + noise_model = cfg.get("noise_model", DEFAULT_NOISE_MODEL) or DEFAULT_NOISE_MODEL + reference_model = cfg.get("noise_reference_model") + if reference_model is None: + reference_model = DEFAULT_NOISE_MODEL if noise_model == "arb_only" else noise_model + noise_model = str(noise_model) + reference_model = str(reference_model) + if noise_model not in {DEFAULT_NOISE_MODEL, "arb_only"}: + raise ValueError( + "compare_reclamm_thermostats only supports " + "'market_linear' noise and 'arb_only' baselines." + ) + if reference_model != DEFAULT_NOISE_MODEL: + raise ValueError( + "compare_reclamm_thermostats only supports the " + "'market_linear' noise parametrisation." + ) + return reference_model + + +def normalize_compare_run_cfg(cfg, enable_noise_model=None): + """Canonicalize the compare-run config so non-axis inputs stay fixed.""" + updated = dict(cfg) + updated["price_ratio"] = float(cfg["price_ratio"]) + updated["centeredness_margin"] = float(cfg["centeredness_margin"]) + updated["daily_price_shift_exponent"] = float(cfg["daily_price_shift_exponent"]) + updated["initial_pool_value"] = float(get_initial_pool_value(cfg)) + updated["gas_cost"] = DEFAULT_GAS_COST + updated["protocol_fee_split"] = DEFAULT_PROTOCOL_FEE_SPLIT + updated["arb_fees"] = 0.0 + updated["arb_frequency"] = get_effective_arb_frequency(cfg) + updated["noise_trader_ratio"] = 0.0 + + arc_length_speed = cfg.get("arc_length_speed") + if arc_length_speed is None: + updated.pop("arc_length_speed", None) + else: + updated["arc_length_speed"] = float(arc_length_speed) + + use_noise = ( + bool(cfg.get("enable_noise_model", False)) + if enable_noise_model is None + else bool(enable_noise_model) + ) + updated["enable_noise_model"] = use_noise + + reference_mode = _canonical_noise_reference_model(cfg) + if use_noise: + updated["noise_model"] = 
reference_mode + updated["noise_reference_model"] = reference_mode + else: + updated["noise_model"] = "arb_only" + updated["noise_reference_model"] = reference_mode + + updated["noise_arrays_path"] = DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + updated.pop("reclamm_noise_params", None) + updated["noise_artifact_dir"] = DEFAULT_MARKET_LINEAR_ARTIFACT_DIR + updated["noise_pool_id"] = AAVE_WETH_POOL_ID + + return updated + + +def make_noise_variant_cfg(cfg, enable_noise_model): + """Return a config with either noise modelling or pure arb-only enabled.""" + return normalize_compare_run_cfg(cfg, enable_noise_model=enable_noise_model) + + +def _hashable_noise_params(params): + """Convert a noise-params dict into a stable cache key fragment.""" + if params is None: + return None + return tuple(sorted((str(k), round(float(v), 12)) for k, v in params.items())) + + +def load_shared_market_linear_noise_data( + arrays_path=DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, +): + """Load the market_linear arrays once so compare runs can reuse them.""" + arrays_path = os.path.abspath(os.fspath(arrays_path)) + cached = _MARKET_LINEAR_NOISE_DATA_CACHE.get(arrays_path) + if cached is not None: + return cached + + if not os.path.exists(arrays_path): + raise FileNotFoundError(f"market_linear arrays file not found: {arrays_path}") + + with np.load(arrays_path) as arrays: + required_keys = {"noise_base", "noise_tvl_coeff", "tvl_mean", "tvl_std"} + missing_keys = sorted(required_keys.difference(arrays.files)) + if missing_keys: + raise KeyError( + f"market_linear arrays file {arrays_path} is missing keys: {missing_keys}" + ) + shared = { + "arrays_path": arrays_path, + "noise_base_array": np.asarray(arrays["noise_base"]), + "noise_tvl_coeff_array": np.asarray(arrays["noise_tvl_coeff"]), + "tvl_mean": float(arrays["tvl_mean"]), + "tvl_std": float(arrays["tvl_std"]), + } + _MARKET_LINEAR_NOISE_DATA_CACHE[arrays_path] = shared + return shared + + +def 
_load_market_linear_noise_stats(arrays_path=DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH): + """Load the exact arrays file used by the market_linear run fingerprint. + + The simulator consumes ``noise_base`` and ``noise_tvl_coeff`` from + ``run_fingerprint["noise_arrays_path"]`` and uses ``tvl_mean``/``tvl_std`` + from the same file for TVL standardization. + """ + shared = load_shared_market_linear_noise_data(arrays_path=arrays_path) + return shared["arrays_path"], shared["tvl_mean"], shared["tvl_std"] + + +def _market_linear_noise_settings(noise_model="market_linear", arb_frequency=None): + """Build the tuned market_linear fingerprint block from the fixed arrays file.""" + arrays_path, tvl_mean, tvl_std = _load_market_linear_noise_stats() + arb_frequency = _normalize_arb_frequency(arb_frequency) + return { + "noise_model": noise_model, + "noise_trader_ratio": 0.0, + "reclamm_noise_params": { + "tvl_mean": tvl_mean, + "tvl_std": tvl_std, + }, + "noise_arrays_path": arrays_path, + "arb_frequency": arb_frequency, + "noise_summary": f"{noise_model} (arb_frequency={arb_frequency})", + "noise_cache_key": ( + noise_model, + arrays_path, + arb_frequency, + round(tvl_mean, 12), + round(tvl_std, 12), + ), + } + +def resolve_reclamm_noise_settings(cfg): + """Resolve the active reCLAMM noise-model fingerprint block for a config.""" + cfg = normalize_compare_run_cfg(cfg) + enable_noise_model = cfg.get("enable_noise_model", False) + requested_mode = cfg.get("noise_model", DEFAULT_NOISE_MODEL) + reference_mode = cfg.get("noise_reference_model", DEFAULT_NOISE_MODEL) + requested_arb_frequency = get_effective_arb_frequency(cfg) + cache_key = ( + tuple(cfg.get("tokens", [])), + cfg.get("start"), + cfg.get("end"), + enable_noise_model, + requested_mode, + reference_mode, + cfg.get("noise_artifact_dir", DEFAULT_MARKET_LINEAR_ARTIFACT_DIR), + cfg.get("noise_pool_id", AAVE_WETH_POOL_ID), + requested_arb_frequency, + round(float(cfg.get("noise_trader_ratio", 0.0)), 12), + 
_hashable_noise_params(cfg.get("reclamm_noise_params")), + cfg.get("noise_arrays_path"), + ) + if cache_key in _NOISE_SETTINGS_CACHE: + return _NOISE_SETTINGS_CACHE[cache_key] + + if requested_mode == "arb_only": + result = _market_linear_noise_settings( + noise_model="arb_only", + arb_frequency=requested_arb_frequency, + ) + elif requested_mode == DEFAULT_NOISE_MODEL: + result = _market_linear_noise_settings( + noise_model=DEFAULT_NOISE_MODEL, + arb_frequency=requested_arb_frequency, + ) + else: + raise ValueError( + "compare_reclamm_thermostats only supports " + "'market_linear' noise and 'arb_only' baselines." + ) + + _NOISE_SETTINGS_CACHE[cache_key] = result + return result + + # Pool configurations to compare CONFIGS = [ { - "name": "AAVE/ETH on-chain (25bps, narrow range)", + "name": "AAVE/ETH launch-style range (25bps, reference)", "tokens": ["AAVE", "ETH"], "start": "2024-06-01 00:00:00", "end": "2025-06-01 00:00:00", "fees": 0.0025, - "price_ratio": 1.5, + "price_ratio": 1.5014, "centeredness_margin": 0.5, "daily_price_shift_exponent": 0.1, + "reason": "Original launch-style parameters.", + **AAVE_ETH_NOISE_SETTINGS, }, { - "name": "AAVE/ETH wide range (25bps)", + "name": "AAVE/ETH aggressive tight range (25bps)", "tokens": ["AAVE", "ETH"], "start": "2024-06-01 00:00:00", "end": "2025-06-01 00:00:00", "fees": 0.0025, - "price_ratio": 4.0, - "centeredness_margin": 0.2, - "daily_price_shift_exponent": 1.0, - }, - { - "name": "AAVE/ETH zero fees (narrow)", - "tokens": ["AAVE", "ETH"], - "start": "2024-06-01 00:00:00", - "end": "2025-06-01 00:00:00", - "fees": 0.0, - "price_ratio": 1.5, - "centeredness_margin": 0.5, + "price_ratio": 1.10, + "centeredness_margin": 0.60, "daily_price_shift_exponent": 0.1, + "reason": ( + "Aggressively tightened and moved to an earlier thermostat trigger. " + "At fixed price_ratio=1.10, the shift_exponent sweep still favored " + "0.1, while margin=0.60 widened the non-linear edge materially." 
+ ), + **AAVE_ETH_NOISE_SETTINGS, }, ] -def make_fingerprint(cfg, interpolation_method, centeredness_scaling=False): +def _attach_market_linear_noise_arrays( + fingerprint, + noise_cfg, + market_linear_noise_data, +): + """Attach preloaded market_linear arrays when the compare flow has them.""" + if market_linear_noise_data is None: + return + expected_path = noise_cfg.get("noise_arrays_path") + if expected_path is None: + return + shared_path = os.path.abspath(os.fspath(market_linear_noise_data["arrays_path"])) + expected_path = os.path.abspath(os.fspath(expected_path)) + if shared_path != expected_path: + raise ValueError( + "Shared market_linear noise arrays path does not match " + f"the resolved compare-run noise path: {shared_path} != {expected_path}" + ) + fingerprint["noise_base_array"] = market_linear_noise_data["noise_base_array"] + fingerprint["noise_tvl_coeff_array"] = market_linear_noise_data["noise_tvl_coeff_array"] + + +def make_fingerprint(cfg, interpolation_method, market_linear_noise_data=None): """Build run fingerprint for a given config and interpolation method.""" - return { + cfg = normalize_compare_run_cfg(cfg) + speed_override = ( + cfg.get("arc_length_speed") + if interpolation_method == "constant_arc_length" + else None + ) + noise_cfg = resolve_reclamm_noise_settings(cfg) + arb_frequency = get_effective_arb_frequency(cfg, noise_cfg) + fingerprint = { "tokens": cfg["tokens"], "rule": "reclamm", "startDateString": cfg["start"], "endDateString": cfg["end"], - "initial_pool_value": 1000000.0, + "initial_pool_value": get_initial_pool_value(cfg), "do_arb": True, "fees": cfg["fees"], - "gas_cost": 0.0, - "arb_fees": 0.0, + "gas_cost": cfg.get( + "gas_cost", + DEFAULT_GAS_COST if cfg.get("enable_noise_model", False) else 0.0, + ), + "arb_fees": cfg.get("arb_fees", 0.0), + "protocol_fee_split": cfg.get( + "protocol_fee_split", + DEFAULT_PROTOCOL_FEE_SPLIT if cfg.get("enable_noise_model", False) else 0.0, + ), + "noise_trader_ratio": 
noise_cfg.get("noise_trader_ratio", 0.0), "reclamm_interpolation_method": interpolation_method, - "reclamm_arc_length_speed": None, # auto-calibrate - "reclamm_centeredness_scaling": centeredness_scaling, + "reclamm_arc_length_speed": speed_override, } + if noise_cfg.get("noise_model") is not None: + fingerprint["noise_model"] = noise_cfg["noise_model"] + if noise_cfg.get("reclamm_noise_params") is not None: + fingerprint["reclamm_noise_params"] = noise_cfg["reclamm_noise_params"] + if noise_cfg.get("noise_arrays_path") is not None: + fingerprint["noise_arrays_path"] = noise_cfg["noise_arrays_path"] + _attach_market_linear_noise_arrays( + fingerprint, + noise_cfg, + market_linear_noise_data, + ) + if arb_frequency is not None: + fingerprint["arb_frequency"] = arb_frequency + return fingerprint def make_params(cfg): """Build pool params from config.""" + cfg = normalize_compare_run_cfg(cfg) return { "price_ratio": jnp.array(cfg["price_ratio"]), "centeredness_margin": jnp.array(cfg["centeredness_margin"]), @@ -85,40 +560,1654 @@ def make_params(cfg): } -def run_comparison(cfg): - """Run all thermostat variants, return results dict.""" +def load_shared_price_data(configs, root=None): + """Load the shared historic price panel once for all compare runs.""" + tokens = sorted({token for cfg in configs for token in cfg["tokens"]}) + return get_historic_parquet_data(tokens, cols=["close"], root=root) + + +def run_comparison( + cfg, + price_data=None, + low_data_mode=False, + market_linear_noise_data=None, +): + """Run both interpolation variants, return results dict.""" params = make_params(cfg) results = {} - for method in ["geometric", "constant_arc_length"]: - fp = make_fingerprint(cfg, method) + for method in INTERPOLATION_METHODS: + fp = make_fingerprint( + cfg, + method, + market_linear_noise_data=market_linear_noise_data, + ) results[method] = do_run_on_historic_data( - run_fingerprint=fp, params=params + run_fingerprint=fp, + params=params, + price_data=price_data, 
+ low_data_mode=low_data_mode, + ) + + return results + + +def _set_padded_ylim(ax, series_list, pad_ratio=0.04): + """Fit the y-axis tightly around the plotted series.""" + flat = [ + np.asarray(series, dtype=float).ravel() + for series in series_list + if np.asarray(series).size > 0 + ] + if not flat: + return + + values = np.concatenate(flat) + values = values[np.isfinite(values)] + if values.size == 0: + return + + ymin = float(values.min()) + ymax = float(values.max()) + if np.isclose(ymin, ymax): + pad = max(abs(ymin) * pad_ratio, 1e-6) + else: + pad = (ymax - ymin) * pad_ratio + ax.set_ylim(ymin - pad, ymax + pad) + + +def _cache_size(cache): + """Count memoized final-value cache entries materialised in memory.""" + return len(cache.get("_final_value_cache", {})) + + +def _comparison_cache_size(cache): + """Count memoized scalar comparison bundles.""" + return len(cache.get("_comparison_cache", {})) + + +def _heatmap_forward_cache_scope_slug(cfg): + """Build a compact cache scope slug for a shared-TVL heatmap run.""" + if cfg is None: + return "unspecified_tvl" + return f"tvl_{format_tvl_millions_slug(cfg)}" + + +def _heatmap_forward_cache_path(cfg): + """Return the parquet path for persisted scalar forward values.""" + if not HEATMAP_FORWARD_CACHE_ENABLED: + return None + return os.path.join( + HEATMAP_FORWARD_CACHE_ROOT, + HEATMAP_FORWARD_CACHE_RUN_NAME, + f"forward_values_{_heatmap_forward_cache_scope_slug(cfg)}.parquet", + ) + + +def _make_method_cache_hash(key): + """Build a compact stable digest for a method cache key.""" + return hashlib.sha256(repr(key).encode("utf-8")).hexdigest() + + +def _build_persistent_final_value_record(cfg, method, cache_key_hash, final_value): + """Build one self-describing parquet row for a cached scalar run result.""" + cfg = normalize_compare_run_cfg(cfg) + noise_cfg = resolve_reclamm_noise_settings(cfg) + return { + "cache_key_hash": str(cache_key_hash), + "final_value": float(final_value), + "method": str(method), + 
"enable_noise_model": bool(cfg.get("enable_noise_model", False)), + "noise_model": noise_cfg.get("noise_model"), + "price_ratio": float(cfg["price_ratio"]), + "centeredness_margin": float(cfg["centeredness_margin"]), + "daily_price_shift_exponent": float(cfg["daily_price_shift_exponent"]), + "initial_pool_value": float(get_initial_pool_value(cfg)), + "arb_frequency": get_effective_arb_frequency(cfg, noise_cfg), + } + + +def _load_persistent_final_value_cache(cache): + """Load persisted scalar forward values from parquet once per sweep cache.""" + if cache.get("_persistent_final_value_cache_loaded"): + return + + disk_cache = {} + next_batch_id = 0 + cache_path = cache.get("_persistent_final_value_cache_path") + if cache_path and os.path.exists(cache_path): + parquet_files = [] + if os.path.isdir(cache_path): + parquet_files = [ + os.path.join(cache_path, filename) + for filename in sorted(os.listdir(cache_path)) + if filename.endswith(".parquet") + ] + batch_ids = [] + for filename in os.listdir(cache_path): + if not (filename.startswith("batch_") and filename.endswith(".parquet")): + continue + token = filename[len("batch_") : -len(".parquet")] + if token.isdigit(): + batch_ids.append(int(token)) + next_batch_id = (max(batch_ids) + 1) if batch_ids else 0 + else: + parquet_files = [cache_path] + + for parquet_file in parquet_files: + frame = pd.read_parquet( + parquet_file, + columns=["cache_key_hash", "final_value"], + ) + if frame.empty: + continue + for row in frame.itertuples(index=False): + cache_key_hash = str(row.cache_key_hash) + final_value = float(row.final_value) + disk_cache[cache_key_hash] = final_value + print( + f"Loaded {len(disk_cache)} persisted heatmap forward values from {cache_path}" ) - # Geometric + centeredness-proportional scaling (scales decay duration) - fp_geo_scaled = make_fingerprint(cfg, "geometric", centeredness_scaling=True) - results["geometric_scaled"] = do_run_on_historic_data( - run_fingerprint=fp_geo_scaled, params=params + 
cache["_persistent_final_value_cache"] = disk_cache + cache["_persistent_final_value_next_batch_id"] = next_batch_id + cache["_persistent_final_value_cache_loaded"] = True + + +def flush_sweep_cache(cache, force=False): + """Persist newly computed scalar forward values to parquet.""" + if not HEATMAP_FORWARD_CACHE_ENABLED: + return + + pending = cache.get("_pending_persistent_final_values") + if not pending: + return + if not force and len(pending) < HEATMAP_FORWARD_CACHE_FLUSH_EVERY: + return + + _load_persistent_final_value_cache(cache) + disk_cache = cache.setdefault("_persistent_final_value_cache", {}) + batch_records = [] + for cache_key_hash, record in pending.items(): + normalized = dict(record) + normalized["cache_key_hash"] = str(cache_key_hash) + normalized["final_value"] = float(normalized["final_value"]) + disk_cache[cache_key_hash] = normalized["final_value"] + batch_records.append(normalized) + + cache_path = cache.get("_persistent_final_value_cache_path") + if cache_path is None: + pending.clear() + return + + if os.path.exists(cache_path) and not os.path.isdir(cache_path): + raise RuntimeError( + f"Persistent cache path {cache_path} already exists as a file. " + "Use a fresh cache namespace for append-only parquet shards." 
+ ) + + os.makedirs(cache_path, exist_ok=True) + batch_records.sort(key=lambda record: record["cache_key_hash"]) + payload = { + column: [record.get(column) for record in batch_records] + for column in PERSISTED_FORWARD_VALUE_COLUMNS + } + payload["final_value"] = np.asarray(payload["final_value"], dtype=np.float64) + frame = pd.DataFrame(payload) + batch_id = int(cache.setdefault("_persistent_final_value_next_batch_id", 0)) + batch_path = os.path.join(cache_path, f"batch_{batch_id:08d}.parquet") + cache["_persistent_final_value_next_batch_id"] = batch_id + 1 + frame.to_parquet(batch_path, index=False, compression="zstd") + print( + f"Persisted {len(pending)} new heatmap forward values to {batch_path} " + f"({len(disk_cache)} total cached values)." ) + pending.clear() + + +def make_sweep_cache( + price_data, + cache_scope_cfg=None, + market_linear_noise_data=None, +): + """Create a shared cache for heatmap and line sweeps.""" + cache = { + "_shared_price_data": price_data, + "_shared_market_linear_noise_data": market_linear_noise_data, + "_final_value_cache": {}, + "_comparison_cache": {}, + "_pending_persistent_final_values": {}, + "_persistent_final_value_cache": {}, + "_persistent_final_value_next_batch_id": 0, + "_persistent_final_value_cache_loaded": False, + "_persistent_final_value_cache_path": _heatmap_forward_cache_path( + cache_scope_cfg + ), + } + return cache + + +def _missing_artifacts(progress_label, filenames): + """Report which plot artifacts still need to be generated.""" + missing = [filename for filename in filenames if not os.path.exists(filename)] + if not missing: + print(f"[{progress_label}] skipping sweep: all artifacts already exist.") + return set() + + existing_count = len(filenames) - len(missing) + if existing_count: + print( + f"[{progress_label}] reusing {existing_count}/{len(filenames)} " + "existing artifacts; generating the missing outputs." 
+ ) + return set(missing) + + +def _speed_cache_key(speed): + """Stable cache token for optional arc-length speed.""" + if speed is None: + return None + return round(float(speed), 12) - # Arc-length + centeredness-proportional scaling (scales speed) - fp_cal_scaled = make_fingerprint(cfg, "constant_arc_length", centeredness_scaling=True) - results["cal_scaled"] = do_run_on_historic_data( - run_fingerprint=fp_cal_scaled, params=params + +def _make_method_cache_key(cfg, method): + """Cache key for a single-method final-value run.""" + cfg = normalize_compare_run_cfg(cfg) + noise_cfg = resolve_reclamm_noise_settings(cfg) + arb_frequency = get_effective_arb_frequency(cfg, noise_cfg) + key = ( + method, + tuple(str(token) for token in cfg["tokens"]), + str(cfg["start"]), + str(cfg["end"]), + round(float(cfg["fees"]), 12), + bool(cfg.get("enable_noise_model", False)), + round(float(cfg["price_ratio"]), 6), + round(float(cfg["centeredness_margin"]), 6), + round(float(cfg["daily_price_shift_exponent"]), 6), + round(get_initial_pool_value(cfg), 2), + noise_cfg.get("noise_cache_key"), + None if arb_frequency is None else int(arb_frequency), + round( + float( + cfg.get( + "gas_cost", + DEFAULT_GAS_COST if cfg.get("enable_noise_model", False) else 0.0, + ) + ), + 6, + ), + round( + float( + cfg.get( + "protocol_fee_split", + DEFAULT_PROTOCOL_FEE_SPLIT if cfg.get("enable_noise_model", False) else 0.0, + ) + ), + 6, + ), ) + if method == "constant_arc_length": + key += (_speed_cache_key(cfg.get("arc_length_speed")),) + return key + + +def _nearest_price_row(price_data, start_ts): + """Select the closest available price row to the requested start timestamp.""" + if len(price_data.index) == 0: + raise ValueError("price_data is empty") + + if isinstance(price_data.index, pd.DatetimeIndex): + target_ts = start_ts + index_tz = getattr(price_data.index, "tz", None) + if index_tz is not None and target_ts.tzinfo is None: + target_ts = target_ts.tz_localize(index_tz) + elif index_tz is 
None and target_ts.tzinfo is not None: + target_ts = target_ts.tz_convert(None) + target_value = int(target_ts.value) + index_values = price_data.index.asi8 + else: + target_value = int(start_ts.timestamp() * 1000.0) + index_values = price_data.index.to_numpy(dtype=np.int64) + + row_idx = int(np.searchsorted(index_values, target_value, side="left")) + if row_idx >= len(index_values): + row_idx = len(index_values) - 1 + elif row_idx > 0 and index_values[row_idx] != target_value: + prev_idx = row_idx - 1 + if abs(int(index_values[prev_idx]) - target_value) <= abs( + int(index_values[row_idx]) - target_value + ): + row_idx = prev_idx + + row = price_data.iloc[row_idx] + if isinstance(row, pd.DataFrame): + row = row.iloc[0] + return row + + +def _make_comparison_cache_key(cfg, launch_final_values): + """Cache key for scalar heatmap metrics at a single parameter point.""" + noise_cfg = make_noise_variant_cfg(cfg, True) + arb_only_cfg = make_noise_variant_cfg(cfg, False) + key = [ + _make_method_cache_key(noise_cfg, "geometric"), + _make_method_cache_key(arb_only_cfg, "geometric"), + round(float(launch_final_values["geometric"]), 6), + ] + if RUN_CONSTANT_ARC_LENGTH: + key.extend( + [ + _make_method_cache_key(noise_cfg, "constant_arc_length"), + _make_method_cache_key(arb_only_cfg, "constant_arc_length"), + round(float(launch_final_values["constant_arc_length"]), 6), + ] + ) + return tuple(key) + + +def _run_method_final_value_cached(cfg, method, cache): + """Memoize final value for a single interpolation method.""" + final_value_cache = cache.setdefault("_final_value_cache", {}) + key = _make_method_cache_key(cfg, method) + if key in final_value_cache: + return final_value_cache[key] + + _load_persistent_final_value_cache(cache) + key_hash = _make_method_cache_hash(key) + persisted_cache = cache.setdefault("_persistent_final_value_cache", {}) + if key_hash in persisted_cache: + final_value_cache[key] = persisted_cache[key_hash] + return final_value_cache[key] + + result 
= do_run_on_historic_data( + run_fingerprint=make_fingerprint( + cfg, + method, + market_linear_noise_data=cache.get("_shared_market_linear_noise_data"), + ), + params=make_params(cfg), + price_data=cache["_shared_price_data"], + low_data_mode=True, + ) + final_value_cache[key] = float(result["final_value"]) + cache.setdefault("_pending_persistent_final_values", {})[key_hash] = ( + _build_persistent_final_value_record( + cfg=cfg, + method=method, + cache_key_hash=key_hash, + final_value=final_value_cache[key], + ) + ) + flush_sweep_cache(cache, force=False) + del result + gc.collect() + return final_value_cache[key] + + +def extract_comparison_metrics_from_final_values( + geo_final, arc_final, launch_final_values +): + """Summarize scalar comparison metrics from final values only.""" + return { + "efficiency_pct": (arc_final / max(abs(geo_final), 1e-12) - 1.0) * 100.0, + "launch_geometric_efficiency_pct": ( + arc_final / max(abs(launch_final_values["geometric"]), 1e-12) - 1.0 + ) + * 100.0, + "geometric_vs_launch_geometric_pct": ( + geo_final / max(abs(launch_final_values["geometric"]), 1e-12) - 1.0 + ) + * 100.0, + "constant_arc_vs_launch_constant_arc_pct": ( + arc_final + / max(abs(launch_final_values["constant_arc_length"]), 1e-12) + - 1.0 + ) + * 100.0, + } + + +def _load_required_heatmap_final_values(cfg, cache, metric_keys): + """Load only the cached final values needed for the requested heatmap metrics.""" + required_sources = set() + for metric_key in metric_keys: + required_sources.update(HEATMAP_METRIC_DEPENDENCIES[metric_key]) + + if not RUN_CONSTANT_ARC_LENGTH and any( + source.endswith("constant_arc") for source in required_sources + ): + raise ValueError( + "Constant-arc heatmap metric requested while RUN_CONSTANT_ARC_LENGTH=False" + ) + + final_values = {} + noise_cfg = None + arb_only_cfg = None + + if any(source.startswith("noise_") for source in required_sources): + noise_cfg = make_noise_variant_cfg(cfg, True) + if any(source.startswith("arb_") 
for source in required_sources): + arb_only_cfg = make_noise_variant_cfg(cfg, False) + + if "noise_geometric" in required_sources: + final_values["noise_geometric"] = _run_method_final_value_cached( + noise_cfg, + "geometric", + cache, + ) + if "noise_constant_arc" in required_sources: + final_values["noise_constant_arc"] = _run_method_final_value_cached( + noise_cfg, + "constant_arc_length", + cache, + ) + if "arb_geometric" in required_sources: + final_values["arb_geometric"] = _run_method_final_value_cached( + arb_only_cfg, + "geometric", + cache, + ) + if "arb_constant_arc" in required_sources: + final_values["arb_constant_arc"] = _run_method_final_value_cached( + arb_only_cfg, + "constant_arc_length", + cache, + ) + return final_values + + +def extract_heatmap_metrics_from_mode_final_values( + metric_keys, + final_values, + launch_final_values, +): + """Collect the requested scalar heatmap metrics from cached final values.""" + metrics = {} + + if "efficiency_pct" in metric_keys: + metrics["efficiency_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(final_values["noise_geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "launch_geometric_efficiency_pct" in metric_keys: + metrics["launch_geometric_efficiency_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(launch_final_values["geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "geometric_vs_launch_geometric_pct" in metric_keys: + metrics["geometric_vs_launch_geometric_pct"] = ( + final_values["noise_geometric"] + / max(abs(launch_final_values["geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "constant_arc_vs_launch_constant_arc_pct" in metric_keys: + metrics["constant_arc_vs_launch_constant_arc_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(launch_final_values["constant_arc_length"]), 1e-12) + - 1.0 + ) * 100.0 + + if "noise_geometric_final_value_musd" in metric_keys: + metrics["noise_geometric_final_value_musd"] = ( + final_values["noise_geometric"] / 1e6 + ) + + if 
"noise_constant_arc_final_value_musd" in metric_keys: + metrics["noise_constant_arc_final_value_musd"] = ( + final_values["noise_constant_arc"] / 1e6 + ) + + if "noise_vs_arb_geometric_improvement_pct" in metric_keys: + metrics["noise_vs_arb_geometric_improvement_pct"] = ( + final_values["noise_geometric"] + / max(abs(final_values["arb_geometric"]), 1e-12) + - 1.0 + ) * 100.0 + + if "noise_vs_arb_constant_arc_improvement_pct" in metric_keys: + metrics["noise_vs_arb_constant_arc_improvement_pct"] = ( + final_values["noise_constant_arc"] + / max(abs(final_values["arb_constant_arc"]), 1e-12) + - 1.0 + ) * 100.0 + + return metrics + + +def extract_comparison_metrics(results, launch_final_values): + """Summarize scalar heatmap metrics for a pair of runs.""" + geo = results["geometric"] + arc = results["constant_arc_length"] + + geo_final = float(geo["final_value"]) + arc_final = float(arc["final_value"]) + + return extract_comparison_metrics_from_final_values( + geo_final, + arc_final, + launch_final_values=launch_final_values, + ) + + +def run_comparison_cached(cfg, cache, launch_final_values, metric_keys): + """Memoize scalar heatmap metrics across heatmap sweeps.""" + requested_metric_keys = tuple(dict.fromkeys(metric_keys)) + comparison_cache = cache.setdefault("_comparison_cache", {}) + cache_key = _make_comparison_cache_key(cfg, launch_final_values) + cached_metrics = comparison_cache.setdefault(cache_key, {}) + missing_metric_keys = [ + metric_key for metric_key in requested_metric_keys if metric_key not in cached_metrics + ] + if missing_metric_keys: + final_values = _load_required_heatmap_final_values( + cfg, + cache, + missing_metric_keys, + ) + cached_metrics.update( + extract_heatmap_metrics_from_mode_final_values( + missing_metric_keys, + final_values, + launch_final_values=launch_final_values, + ) + ) + return { + metric_key: cached_metrics[metric_key] for metric_key in requested_metric_keys + } + + +def build_heatmap_matrices( + x_values, + y_values, + 
x_key, + y_key, + base_cfg, + metric_keys, + cache, + progress_label, + launch_final_values, +): + """Evaluate multiple metrics over a 2D parameter grid in one pass.""" + data = { + metric_key: np.zeros((len(y_values), len(x_values)), dtype=float) + for metric_key in metric_keys + } + total_points = len(y_values) * len(x_values) + + print( + f"[{progress_label}] start: {len(y_values)} rows x {len(x_values)} cols " + f"= {total_points} parameter points" + ) + + for yi, y_value in enumerate(y_values): + final_cache_before_row = _cache_size(cache) + comparison_cache_before_row = _comparison_cache_size(cache) + for xi, x_value in enumerate(x_values): + cfg = dict(base_cfg) + cfg[x_key] = float(x_value) + cfg[y_key] = float(y_value) + metrics = run_comparison_cached( + cfg, + cache, + launch_final_values=launch_final_values, + metric_keys=metric_keys, + ) + for metric_key in metric_keys: + data[metric_key][yi, xi] = metrics[metric_key] + + completed_points = (yi + 1) * len(x_values) + row_new_final_entries = _cache_size(cache) - final_cache_before_row + row_new_comparisons = ( + _comparison_cache_size(cache) - comparison_cache_before_row + ) + row_pct = completed_points / total_points * 100.0 + flush_sweep_cache(cache, force=True) + print( + f"[{progress_label}] row {yi + 1}/{len(y_values)} complete " + f"({y_key}={float(y_value):.4f}, {completed_points}/{total_points} " + f"points, {row_pct:.1f}%, {row_new_final_entries} new final-value cache entries, " + f"{row_new_comparisons} new comparison bundles)" + ) + + print( + f"[{progress_label}] done: " + + ", ".join( + ( + f"{metric_key} min={float(np.nanmin(data[metric_key])):.4f}, " + f"max={float(np.nanmax(data[metric_key])):.4f}" + ) + for metric_key in metric_keys + ) + + ( + f", final_value_cache_size={_cache_size(cache)}, " + f"comparison_cache_size={_comparison_cache_size(cache)}" + ) + ) + + return data + + +def build_metric_curve( + x_values, + x_key, + base_cfg, + metric_key, + cache, + launch_final_values, +): 
+ """Evaluate one metric over a 1D sweep.""" + data = np.zeros(len(x_values), dtype=float) + for xi, x_value in enumerate(x_values): + cfg = dict(base_cfg) + cfg[x_key] = float(x_value) + metrics = run_comparison_cached( + cfg, + cache, + launch_final_values=launch_final_values, + metric_keys=(metric_key,), + ) + data[xi] = metrics[metric_key] + flush_sweep_cache(cache, force=True) + return data + + +def _compute_axis_edges(values, scale="linear"): + """Convert axis centers to cell edges for pcolormesh.""" + values = np.asarray(values, dtype=float) + if values.size == 1: + if scale == "log": + return np.array([values[0] / np.sqrt(10.0), values[0] * np.sqrt(10.0)]) + pad = max(abs(values[0]) * 0.5, 1.0) + return np.array([values[0] - pad, values[0] + pad]) + + if scale == "log": + log_values = np.log10(values) + edges = np.empty(values.size + 1, dtype=float) + edges[1:-1] = 0.5 * (log_values[:-1] + log_values[1:]) + edges[0] = log_values[0] - 0.5 * (log_values[1] - log_values[0]) + edges[-1] = log_values[-1] + 0.5 * (log_values[-1] - log_values[-2]) + return 10.0 ** edges + + edges = np.empty(values.size + 1, dtype=float) + edges[1:-1] = 0.5 * (values[:-1] + values[1:]) + edges[0] = values[0] - 0.5 * (values[1] - values[0]) + edges[-1] = values[-1] + 0.5 * (values[-1] - values[-2]) + return edges + + +def build_fixed_slice_variants(values): + """Pick four representative quarter-range slices from a sweep grid.""" + values = np.asarray(values, dtype=float) + if values.size < len(FIXED_SLICE_FRACTIONS): + raise ValueError("Need at least four grid points to build fixed slices") + + variants = [] + used_indices = set() + for idx, fraction in enumerate(FIXED_SLICE_FRACTIONS): + target_index = int(round(fraction * (values.size - 1))) + while target_index in used_indices and target_index + 1 < values.size: + target_index += 1 + while target_index in used_indices and target_index - 1 >= 0: + target_index -= 1 + if target_index in used_indices: + raise ValueError("Could not 
build four unique fixed slices from sweep grid") + used_indices.add(target_index) + variants.append( + { + "index": target_index, + "fraction": fraction, + "label": FIXED_SLICE_LABELS[idx], + "slug": f"q{idx + 1}", + "value": float(values[target_index]), + } + ) + return variants + + +def _pair_slice_suffix(pair, slice_variant): + """Build a stable artifact suffix for a pairwise fixed-variable slice.""" + return f"{pair['slug']}_{pair['fixed_slug']}_{slice_variant['slug']}" + + +def _build_heatmap_norm( + data_arrays, + center_zero, + color_norm=None, + symlog_linthresh=None, +): + """Build a color normalizer shared by 2D and 3D heatmaps.""" + finite_parts = [] + for data in data_arrays: + finite = np.asarray(data, dtype=float) + finite = finite[np.isfinite(finite)] + if finite.size: + finite_parts.append(finite) + finite = np.concatenate(finite_parts) if finite_parts else np.array([], dtype=float) + + if center_zero: + if finite.size == 0: + vmax = 1.0 + else: + vmax = max(abs(float(finite.min())), abs(float(finite.max())), 1e-9) + if ( + color_norm == "symlog" + and symlog_linthresh is not None + and vmax > symlog_linthresh + ): + return SymLogNorm( + linthresh=symlog_linthresh, + linscale=1.0, + vmin=-vmax, + vmax=vmax, + base=10.0, + ) + return TwoSlopeNorm(vcenter=0.0, vmin=-vmax, vmax=vmax) + + if finite.size == 0: + vmin, vmax = 0.0, 1.0 + else: + vmin = float(finite.min()) + vmax = float(finite.max()) + if np.isclose(vmin, vmax): + pad = max(abs(vmin) * 0.01, 1e-9) + vmin -= pad + vmax += pad + return Normalize(vmin=vmin, vmax=vmax) + + +def get_pair_heatmap_metric_specs(): + """Return the standard thermostat pairwise heatmap metrics.""" + metric_specs = [ + { + "key": "efficiency_pct", + "title": "Efficiency vs heatmap geometric", + "colorbar_label": "Const Arc - heatmap Geo (% of heatmap geometric final value)", + "slug": "efficiency", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "launch_geometric_efficiency_pct", + "title": "Efficiency vs 
launch-style geometric", + "colorbar_label": "Const Arc - launch Geo (% of launch geometric final value)", + "slug": "launch_geometric_efficiency", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "geometric_vs_launch_geometric_pct", + "title": "Geometric tuning vs launch-style geometric", + "colorbar_label": "Candidate Geo - launch Geo (% of launch geometric final value)", + "slug": "geometric_vs_launch_geometric", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "constant_arc_vs_launch_constant_arc_pct", + "title": "Const arc tuning vs launch-style const arc", + "colorbar_label": "Candidate Const Arc - launch Const Arc (% of launch const arc final value)", + "slug": "constant_arc_vs_launch_constant_arc", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "noise_geometric_final_value_musd", + "title": "Geometric final value with noise model", + "colorbar_label": "Geometric final value with noise model ($M)", + "slug": "noise_geometric_final_value", + "center_zero": False, + "cmap": "viridis", + }, + { + "key": "noise_constant_arc_final_value_musd", + "title": "Const arc final value with noise model", + "colorbar_label": "Const Arc final value with noise model ($M)", + "slug": "noise_constant_arc_final_value", + "center_zero": False, + "cmap": "viridis", + }, + { + "key": "noise_vs_arb_geometric_improvement_pct", + "title": "Noise-model improvement over arb-only (geometric)", + "colorbar_label": "Noise-model Geo - arb-only Geo (% of arb-only final value)", + "slug": "noise_vs_arb_geometric_improvement", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "noise_vs_arb_constant_arc_improvement_pct", + "title": "Noise-model improvement over arb-only (const arc)", + "colorbar_label": "Noise-model Const Arc - arb-only Const Arc (% of arb-only final value)", + "slug": "noise_vs_arb_constant_arc_improvement", + "center_zero": True, + "cmap": "RdYlGn", + }, + ] + for spec in metric_specs: + if spec["center_zero"]: + 
spec["color_norm"] = CENTER_ZERO_HEATMAP_COLOR_NORM + spec["symlog_linthresh"] = CENTER_ZERO_HEATMAP_SYMLOG_LINTHRESH + spec["artifact_tag"] = CENTER_ZERO_HEATMAP_COLOR_TAG + if not RUN_CONSTANT_ARC_LENGTH: + metric_specs = [ + spec + for spec in metric_specs + if spec["key"] in GEOMETRIC_ONLY_HEATMAP_METRIC_KEYS + ] + return metric_specs + + +def get_pair_heatmap_specs(base_cfg): + """Return the three pairwise thermostat heatmap families plus slice settings.""" + fixed_slice_variants = { + "price_ratio": build_fixed_slice_variants(HEATMAP_PRICE_RATIOS), + "centeredness_margin": build_fixed_slice_variants(HEATMAP_MARGINS), + "daily_price_shift_exponent": build_fixed_slice_variants( + HEATMAP_SHIFT_EXPONENTS + ), + } + return [ + { + "slug": "price_ratio_vs_margin", + "x_values": HEATMAP_PRICE_RATIOS, + "y_values": HEATMAP_MARGINS, + "x_key": "price_ratio", + "y_key": "centeredness_margin", + "x_label": "Price ratio", + "y_label": "Centeredness margin", + "xticks": PRICE_RATIO_TICKS, + "yticks": MARGIN_TICKS, + "fixed_key": "daily_price_shift_exponent", + "fixed_label": "Shift exponent", + "fixed_slug": "shift_exp", + "fixed_slices": fixed_slice_variants["daily_price_shift_exponent"], + }, + { + "slug": "shift_exp_vs_margin", + "x_values": HEATMAP_SHIFT_EXPONENTS, + "y_values": HEATMAP_MARGINS, + "x_key": "daily_price_shift_exponent", + "y_key": "centeredness_margin", + "x_label": "Shift exponent", + "y_label": "Centeredness margin", + "xticks": SHIFT_EXPONENT_TICKS, + "yticks": MARGIN_TICKS, + "fixed_key": "price_ratio", + "fixed_label": "Price ratio", + "fixed_slug": "price_ratio", + "fixed_slices": fixed_slice_variants["price_ratio"], + }, + { + "slug": "price_ratio_vs_shift_exp", + "x_values": HEATMAP_PRICE_RATIOS, + "y_values": HEATMAP_SHIFT_EXPONENTS, + "x_key": "price_ratio", + "y_key": "daily_price_shift_exponent", + "x_label": "Price ratio", + "y_label": "Shift exponent", + "xticks": PRICE_RATIO_TICKS, + "yticks": SHIFT_EXPONENT_TICKS, + "fixed_key": 
"centeredness_margin", + "fixed_label": "Centeredness margin", + "fixed_slug": "margin", + "fixed_slices": fixed_slice_variants["centeredness_margin"], + }, + ] + + +def plot_heatmap( + data, + x_values, + y_values, + x_label, + y_label, + title, + colorbar_label, + filename, + xticks=None, + yticks=None, + xscale="linear", + center_zero=True, + cmap=None, + color_norm=None, + symlog_linthresh=None, +): + """Render and save a single heatmap.""" + norm = _build_heatmap_norm( + [data], + center_zero=center_zero, + color_norm=color_norm, + symlog_linthresh=symlog_linthresh, + ) + cmap_name = cmap or ("RdYlGn" if center_zero else "viridis") + + x_edges = _compute_axis_edges(x_values, scale=xscale) + y_edges = _compute_axis_edges(y_values, scale="linear") + + fig, ax = plt.subplots(figsize=(8.5, 6.0)) + im = ax.pcolormesh( + x_edges, + y_edges, + data, + cmap=cmap_name, + norm=norm, + shading="auto", + ) + + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_title(title) + if xscale == "log": + ax.set_xscale("log") + ax.set_xticks(np.asarray(xticks if xticks is not None else x_values, dtype=float)) + ax.set_yticks(np.asarray(yticks if yticks is not None else y_values, dtype=float)) + ax.grid(False) + + cbar = fig.colorbar(im, ax=ax) + cbar.set_label(colorbar_label) + + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def plot_three_variable_heatmap_3d( + price_margin_data, + shift_margin_data, + price_shift_data, + fixed_price_ratio, + fixed_margin, + fixed_shift_exponent, + title, + colorbar_label, + filename, + center_zero=True, + cmap=None, + color_norm=None, + symlog_linthresh=None, +): + """Render orthogonal 3D heatmap surfaces across the three thermostat variables.""" + norm = _build_heatmap_norm( + [price_margin_data, shift_margin_data, price_shift_data], + center_zero=center_zero, + color_norm=color_norm, + symlog_linthresh=symlog_linthresh, + ) + cmap_name = cmap or ("RdYlGn" if center_zero else 
"viridis") + cmap_obj = plt.get_cmap(cmap_name) + + price_margin_x, price_margin_y = np.meshgrid(HEATMAP_PRICE_RATIOS, HEATMAP_MARGINS) + price_margin_z = np.full_like(price_margin_x, fixed_shift_exponent, dtype=float) + + shift_margin_z, shift_margin_y = np.meshgrid( + HEATMAP_SHIFT_EXPONENTS, + HEATMAP_MARGINS, + ) + shift_margin_x = np.full_like(shift_margin_z, fixed_price_ratio, dtype=float) + + price_shift_x, price_shift_z = np.meshgrid( + HEATMAP_PRICE_RATIOS, + HEATMAP_SHIFT_EXPONENTS, + ) + price_shift_y = np.full_like(price_shift_x, fixed_margin, dtype=float) + + fig = plt.figure(figsize=(10.5, 7.2)) + ax = fig.add_subplot(111, projection="3d") + ax.set_facecolor("white") + fig.patch.set_facecolor("white") + + ax.plot_surface( + price_margin_x, + price_margin_y, + price_margin_z, + facecolors=cmap_obj(norm(np.asarray(price_margin_data, dtype=float))), + shade=False, + ) + ax.plot_surface( + shift_margin_x, + shift_margin_y, + shift_margin_z, + facecolors=cmap_obj(norm(np.asarray(shift_margin_data, dtype=float))), + shade=False, + ) + ax.plot_surface( + price_shift_x, + price_shift_y, + price_shift_z, + facecolors=cmap_obj(norm(np.asarray(price_shift_data, dtype=float))), + shade=False, + ) + + ax.set_xlim(float(HEATMAP_PRICE_RATIOS.min()), float(HEATMAP_PRICE_RATIOS.max())) + ax.set_ylim(float(HEATMAP_MARGINS.min()), float(HEATMAP_MARGINS.max())) + ax.set_zlim( + float(HEATMAP_SHIFT_EXPONENTS.min()), + float(HEATMAP_SHIFT_EXPONENTS.max()), + ) + ax.set_xlabel("Price ratio") + ax.set_ylabel("Centeredness margin") + ax.set_zlabel("Shift exponent") + ax.set_xticks(PRICE_RATIO_TICKS) + ax.set_yticks(MARGIN_TICKS[::2]) + ax.set_zticks(SHIFT_EXPONENT_TICKS) + ax.set_title(title) + ax.grid(False) + ax.view_init(elev=THREE_D_VIEW_ELEVATION, azim=THREE_D_VIEW_AZIMUTH) + try: + ax.set_box_aspect( + ( + float(HEATMAP_PRICE_RATIOS.max() - HEATMAP_PRICE_RATIOS.min()), + float(HEATMAP_MARGINS.max() - HEATMAP_MARGINS.min()), + float( + HEATMAP_SHIFT_EXPONENTS.max() - 
HEATMAP_SHIFT_EXPONENTS.min() + ), + ) + ) + except AttributeError: + pass + + sm = ScalarMappable(norm=norm, cmap=cmap_obj) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=ax, fraction=0.03, pad=0.1, shrink=0.82) + cbar.set_label(colorbar_label) + + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def plot_arc_speed_line_chart( + data, + x_values, + y_values, + y_label, + title, + filename, + launch_curve, + launch_auto_speed=None, +): + """Plot thin multi-series efficiency lines over the arc-speed sweep.""" + fig, ax = plt.subplots(figsize=(10.5, 5.75)) + cmap = plt.cm.viridis + colors = cmap(np.linspace(0.0, 1.0, len(y_values))) + plotted_series = [] + + for yi, (y_value, color) in enumerate(zip(y_values, colors)): + series = np.asarray(data[yi], dtype=float) + plotted_series.append(series) + ax.plot( + x_values, + series, + color=color, + linewidth=SWEEP_LINE_WIDTH, + alpha=0.8, + ) + + launch_curve = np.asarray(launch_curve, dtype=float) + plotted_series.append(launch_curve) + ax.plot( + x_values, + launch_curve, + color="black", + linewidth=REFERENCE_LINE_WIDTH, + alpha=0.9, + label="Current launch config", + ) + if launch_auto_speed is not None: + ax.axvline( + float(launch_auto_speed), + color="black", + ls=":", + linewidth=0.8, + alpha=0.7, + label="Launch auto-cal speed", + ) + + ax.axhline(0.0, color="gray", ls="--", linewidth=0.8, alpha=0.5) + ax.set_xscale("log") + ax.set_xticks(ARC_LENGTH_SPEED_TICKS) + ax.set_xlabel("Arc-length speed") + ax.set_ylabel("Efficiency vs geometric (%)") + ax.set_title(title) + _set_padded_ylim(ax, plotted_series, pad_ratio=0.08) + ax.grid(True, alpha=0.25) + ax.legend(fontsize=8) + + sm = ScalarMappable( + norm=Normalize(vmin=float(np.min(y_values)), vmax=float(np.max(y_values))), + cmap=cmap, + ) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=ax) + cbar.set_label(y_label) + + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + 
plt.close(fig) + + +def generate_heatmaps(base_cfg, price_data, launch_final_values, cache=None): + """Generate pairwise heatmaps for thermostat tuning and noise-vs-arb effects.""" + owns_cache = cache is None + if cache is None: + cache = make_sweep_cache(price_data, cache_scope_cfg=base_cfg) + metric_specs = get_pair_heatmap_metric_specs() + pair_specs = get_pair_heatmap_specs(base_cfg) + metric_spec_map = {spec["key"]: spec for spec in metric_specs} + slice_count = len(pair_specs[0]["fixed_slices"]) if pair_specs else 0 + + if RUN_CONSTANT_ARC_LENGTH: + print( + "Using launch-style benchmarks " + f"Geo=${launch_final_values['geometric']:,.0f}, " + f"Const Arc=${launch_final_values['constant_arc_length']:,.0f}, " + f"TVL={format_tvl_millions_label(base_cfg)}." + ) + print( + "Running {count} heatmap pair sweeps sequentially " + "(3 pair grids x {slice_count} fixed-variable quarter slices; " + "cached noise-model runs are reused across the absolute, launch, " + "and arb-only comparison outputs).".format( + count=len(pair_specs) * slice_count, + slice_count=slice_count, + ) + ) + else: + print( + "Using launch-style geometric benchmark " + f"Geo=${launch_final_values['geometric']:,.0f}, " + f"TVL={format_tvl_millions_label(base_cfg)}." + ) + print( + "RUN_CONSTANT_ARC_LENGTH=False, so only geometric heatmaps will be generated " + f"across {len(pair_specs) * slice_count} fixed-variable pair sweeps." 
+ ) + + for pair in pair_specs: + for slice_variant in pair["fixed_slices"]: + pair_suffix = _pair_slice_suffix(pair, slice_variant) + slice_cfg = dict(base_cfg) + slice_cfg[pair["fixed_key"]] = float(slice_variant["value"]) + output_files = { + spec["key"]: heatmap_artifact_filename( + spec, + base_cfg, + suffix=pair_suffix, + ) + for spec in metric_specs + } + missing_files = _missing_artifacts( + pair_suffix, + list(output_files.values()), + ) + if not missing_files: + continue + + missing_metric_keys = [ + spec["key"] + for spec in metric_specs + if output_files[spec["key"]] in missing_files + ] + data_by_metric = build_heatmap_matrices( + x_values=pair["x_values"], + y_values=pair["y_values"], + x_key=pair["x_key"], + y_key=pair["y_key"], + base_cfg=slice_cfg, + metric_keys=missing_metric_keys, + cache=cache, + progress_label=pair_suffix, + launch_final_values=launch_final_values, + ) + print(f"[{pair_suffix}] plotting missing heatmaps...") + for metric_key in missing_metric_keys: + spec = metric_spec_map[metric_key] + plot_heatmap( + data=data_by_metric[metric_key], + x_values=pair["x_values"], + y_values=pair["y_values"], + x_label=pair["x_label"], + y_label=pair["y_label"], + title=( + f"{spec['title']}: {pair['fixed_label']} {slice_variant['label']} " + f"slice fixed at {format_heatmap_param_value(slice_variant['value'])} | " + f"TVL {format_tvl_millions_label(base_cfg)}" + ), + colorbar_label=spec["colorbar_label"], + filename=output_files[metric_key], + xticks=pair["xticks"], + yticks=pair["yticks"], + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + del data_by_metric + gc.collect() + + if owns_cache: + flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print("Released heatmap metric cache.") + + +def generate_three_variable_3d_heatmaps( + base_cfg, + price_data, + launch_final_values, + cache=None, +): + """Render 3D thermostat heatmaps 
from the three pairwise quarter slices.""" + owns_cache = cache is None + if cache is None: + cache = make_sweep_cache(price_data, cache_scope_cfg=base_cfg) + + metric_specs = get_pair_heatmap_metric_specs() + metric_spec_map = {spec["key"]: spec for spec in metric_specs} + pair_specs = get_pair_heatmap_specs(base_cfg) + pair_by_fixed_key = {pair["fixed_key"]: pair for pair in pair_specs} + price_margin_pair = pair_by_fixed_key["daily_price_shift_exponent"] + shift_margin_pair = pair_by_fixed_key["price_ratio"] + price_shift_pair = pair_by_fixed_key["centeredness_margin"] + slice_count = len(price_margin_pair["fixed_slices"]) + + def build_pair_slice_data(pair, slice_variant, metric_keys): + pair_cfg = dict(base_cfg) + pair_cfg[pair["fixed_key"]] = float(slice_variant["value"]) + return build_heatmap_matrices( + x_values=pair["x_values"], + y_values=pair["y_values"], + x_key=pair["x_key"], + y_key=pair["y_key"], + base_cfg=pair_cfg, + metric_keys=metric_keys, + cache=cache, + progress_label=f"3d_{_pair_slice_suffix(pair, slice_variant)}", + launch_final_values=launch_final_values, + ) + + print( + "\nGenerating 3D thermostat heatmaps " + f"({slice_count} quarter-slice variants, TVL={format_tvl_millions_label(base_cfg)})..." 
+ ) + + for slice_idx in range(slice_count): + shift_slice = price_margin_pair["fixed_slices"][slice_idx] + price_slice = shift_margin_pair["fixed_slices"][slice_idx] + margin_slice = price_shift_pair["fixed_slices"][slice_idx] + slice_slug = shift_slice["slug"] + slice_label = shift_slice["label"] + + output_files = { + spec["key"]: three_d_heatmap_artifact_filename( + spec, + base_cfg, + suffix=f"slice_{slice_slug}", + ) + for spec in metric_specs + } + missing_files = _missing_artifacts( + f"3d_slice_{slice_slug}", + list(output_files.values()), + ) + if not missing_files: + continue + + missing_metric_keys = [ + spec["key"] + for spec in metric_specs + if output_files[spec["key"]] in missing_files + ] + price_margin_data = build_pair_slice_data( + price_margin_pair, + shift_slice, + missing_metric_keys, + ) + shift_margin_data = build_pair_slice_data( + shift_margin_pair, + price_slice, + missing_metric_keys, + ) + price_shift_data = build_pair_slice_data( + price_shift_pair, + margin_slice, + missing_metric_keys, + ) + + for metric_key in missing_metric_keys: + spec = metric_spec_map[metric_key] + plot_three_variable_heatmap_3d( + price_margin_data=price_margin_data[metric_key], + shift_margin_data=shift_margin_data[metric_key], + price_shift_data=price_shift_data[metric_key], + fixed_price_ratio=float(price_slice["value"]), + fixed_margin=float(margin_slice["value"]), + fixed_shift_exponent=float(shift_slice["value"]), + title=( + f"{spec['title']} 3D {slice_label} slice | TVL {format_tvl_millions_label(base_cfg)}\n" + f"price_ratio={format_heatmap_param_value(price_slice['value'])}, " + f"margin={format_heatmap_param_value(margin_slice['value'])}, " + f"shift_exp={format_heatmap_param_value(shift_slice['value'])}" + ), + colorbar_label=spec["colorbar_label"], + filename=output_files[metric_key], + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + + del 
price_margin_data, shift_margin_data, price_shift_data + gc.collect() + + if owns_cache: + flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print("Released 3D heatmap cache.") + + +def compute_auto_calibrated_arc_length_speed(cfg, price_data): + """Compute the launch/reference auto-calibrated speed for a config.""" + start_ts = pd.Timestamp(cfg["start"]) + row = _nearest_price_row(price_data, start_ts) + + if isinstance(price_data.columns, pd.MultiIndex): + initial_price_values = [ + float(row[(token, "close")]) + for token in cfg["tokens"] + ] + else: + initial_price_values = [ + float(row[f"close_{token}"]) + for token in cfg["tokens"] + ] + + initial_prices = jnp.array(initial_price_values, dtype=jnp.float64) + initial_reserves, Va, Vb = initialise_reclamm_reserves( + get_initial_pool_value(cfg), + initial_prices, + float(cfg["price_ratio"]), + ) + market_price_0 = float(initial_prices[0] / initial_prices[1]) + sqrt_Q = jnp.sqrt( + compute_price_ratio( + initial_reserves[0], + initial_reserves[1], + Va, + Vb, + ) + ) + return float( + calibrate_arc_length_speed( + initial_reserves[0], + initial_reserves[1], + Va, + Vb, + to_daily_price_shift_base(float(cfg["daily_price_shift_exponent"])), + 60.0, + sqrt_Q, + market_price_0, + centeredness_margin=float(cfg["centeredness_margin"]), + ) + ) + + +def generate_arc_speed_efficiency_artifacts( + base_cfg, + launch_cfg, + price_data, + launch_final_values, + cache=None, +): + """Generate arc-speed heatmaps plus the existing efficiency line charts.""" + if not RUN_CONSTANT_ARC_LENGTH: + print("\nSkipping arc-speed heatmaps because RUN_CONSTANT_ARC_LENGTH=False.") + return + owns_cache = cache is None + if cache is None: + cache = make_sweep_cache(price_data, cache_scope_cfg=base_cfg) + launch_auto_speed = compute_auto_calibrated_arc_length_speed(launch_cfg, price_data) + heatmap_metric_specs = [ + { + "key": "efficiency_pct", + "title": "Efficiency vs geometric", + "colorbar_label": "Const Arc - 
heatmap Geo (% of heatmap geometric final value)", + "slug": "efficiency", + "center_zero": True, + "cmap": "RdYlGn", + }, + { + "key": "noise_constant_arc_final_value_musd", + "title": "Const arc final value with noise model", + "colorbar_label": "Const Arc final value with noise model ($M)", + "slug": "noise_constant_arc_final_value", + "center_zero": False, + "cmap": "viridis", + }, + { + "key": "noise_vs_arb_constant_arc_improvement_pct", + "title": "Noise-model improvement over arb-only (const arc)", + "colorbar_label": "Noise-model Const Arc - arb-only Const Arc (% of arb-only final value)", + "slug": "noise_vs_arb_constant_arc_improvement", + "center_zero": True, + "cmap": "RdYlGn", + }, + ] + for spec in heatmap_metric_specs: + if spec["center_zero"]: + spec["color_norm"] = CENTER_ZERO_HEATMAP_COLOR_NORM + spec["symlog_linthresh"] = CENTER_ZERO_HEATMAP_SYMLOG_LINTHRESH + spec["artifact_tag"] = CENTER_ZERO_HEATMAP_COLOR_TAG + pair_specs = [ + { + "slug": "arc_speed_vs_price_ratio", + "x_values": HEATMAP_ARC_LENGTH_SPEEDS, + "y_values": HEATMAP_PRICE_RATIOS, + "x_key": "arc_length_speed", + "y_key": "price_ratio", + "x_label": "Arc-length speed", + "y_label": "Price ratio", + "title_suffix": ( + f"margin fixed at {base_cfg['centeredness_margin']:.2f}, " + f"shift_exp fixed at {base_cfg['daily_price_shift_exponent']:.2f}" + ), + "xticks": ARC_LENGTH_SPEED_TICKS, + "yticks": PRICE_RATIO_TICKS, + }, + { + "slug": "arc_speed_vs_margin", + "x_values": HEATMAP_ARC_LENGTH_SPEEDS, + "y_values": HEATMAP_MARGINS, + "x_key": "arc_length_speed", + "y_key": "centeredness_margin", + "x_label": "Arc-length speed", + "y_label": "Centeredness margin", + "title_suffix": ( + f"price_ratio fixed at {base_cfg['price_ratio']:.2f}, " + f"shift_exp fixed at {base_cfg['daily_price_shift_exponent']:.2f}" + ), + "xticks": ARC_LENGTH_SPEED_TICKS, + "yticks": MARGIN_TICKS + }, + { + "slug": "arc_speed_vs_shift_exp", + "x_values": HEATMAP_ARC_LENGTH_SPEEDS, + "y_values": 
HEATMAP_SHIFT_EXPONENTS, + "x_key": "arc_length_speed", + "y_key": "daily_price_shift_exponent", + "x_label": "Arc-length speed", + "y_label": "Shift exponent", + "title_suffix": ( + f"price_ratio fixed at {base_cfg['price_ratio']:.2f}, " + f"margin fixed at {base_cfg['centeredness_margin']:.2f}" + ), + "xticks": ARC_LENGTH_SPEED_TICKS, + "yticks": SHIFT_EXPONENT_TICKS, + }, + ] + metric_spec_map = {spec["key"]: spec for spec in heatmap_metric_specs} + + print( + "\nGenerating arc-speed heatmaps and line charts " + f"(launch auto-cal speed={launch_auto_speed:.3e}, TVL={format_tvl_millions_label(base_cfg)})..." + ) + + for pair in pair_specs: + heatmap_files = { + spec["key"]: heatmap_artifact_filename( + spec, + base_cfg, + suffix=pair["slug"], + ) + for spec in heatmap_metric_specs + } + line_filename = tvl_artifact_filename( + "reclamm_line_efficiency", + base_cfg, + suffix=pair["slug"], + ) + missing_files = _missing_artifacts( + pair["slug"], + list(heatmap_files.values()) + [line_filename], + ) + if not missing_files: + continue + + missing_metric_keys = [ + spec["key"] + for spec in heatmap_metric_specs + if heatmap_files[spec["key"]] in missing_files + ] + if line_filename in missing_files and "efficiency_pct" not in missing_metric_keys: + missing_metric_keys.append("efficiency_pct") + + data_by_metric = build_heatmap_matrices( + x_values=pair["x_values"], + y_values=pair["y_values"], + x_key=pair["x_key"], + y_key=pair["y_key"], + base_cfg=base_cfg, + metric_keys=missing_metric_keys, + cache=cache, + progress_label=pair["slug"], + launch_final_values=launch_final_values, + ) + for metric_key in missing_metric_keys: + if metric_key not in heatmap_files: + continue + if heatmap_files[metric_key] not in missing_files: + continue + spec = metric_spec_map[metric_key] + plot_heatmap( + data=data_by_metric[metric_key], + x_values=pair["x_values"], + y_values=pair["y_values"], + x_label=pair["x_label"], + y_label=pair["y_label"], + title=( + f"{spec['title']}: 
{pair['title_suffix']} | " + f"TVL {format_tvl_millions_label(base_cfg)}" + ), + colorbar_label=spec["colorbar_label"], + filename=heatmap_files[metric_key], + xticks=pair["xticks"], + yticks=pair["yticks"], + xscale="log", + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + + if line_filename in missing_files: + efficiency_data = data_by_metric["efficiency_pct"] + launch_curve = build_metric_curve( + x_values=pair["x_values"], + x_key=pair["x_key"], + base_cfg=launch_cfg, + metric_key="efficiency_pct", + cache=cache, + launch_final_values=launch_final_values, + ) + plot_arc_speed_line_chart( + data=efficiency_data, + x_values=pair["x_values"], + y_values=pair["y_values"], + y_label=pair["y_label"], + title=( + "Arc-speed efficiency sweep: " + f"{pair['title_suffix']} | TVL {format_tvl_millions_label(base_cfg)}" + ), + filename=line_filename, + launch_curve=launch_curve, + launch_auto_speed=launch_auto_speed, + ) + del data_by_metric + gc.collect() + + if owns_cache: + flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print("Released arc-speed sweep cache.") + + +def get_launch_final_values( + all_results, + launch_cfg, + price_data, + market_linear_noise_data=None, +): + """Reuse launch-style runs when available; otherwise run them once.""" + for cfg, results in all_results: + if cfg["name"] == launch_cfg["name"]: + launch_final_values = { + "geometric": float(results["geometric"]["final_value"]), + } + if "constant_arc_length" in results: + launch_final_values["constant_arc_length"] = float( + results["constant_arc_length"]["final_value"] + ) + return launch_final_values + + print("\nRunning launch-style benchmarks for heatmaps...") + launch_results = run_comparison( + launch_cfg, + price_data=price_data, + low_data_mode=True, + market_linear_noise_data=market_linear_noise_data, + ) + launch_final_values = { + "geometric": 
float(launch_results["geometric"]["final_value"]), + } + if "constant_arc_length" in launch_results: + launch_final_values["constant_arc_length"] = float( + launch_results["constant_arc_length"]["final_value"] + ) + del launch_results + gc.collect() + return launch_final_values - return results def print_comparison(cfg, results): """Print text summary table.""" - methods = [ - ("Geometric", results["geometric"]), - ("Geo+Scaled", results["geometric_scaled"]), - ("Const Arc", results["constant_arc_length"]), - ("Arc+Scaled", results["cal_scaled"]), - ] + methods = [("Geometric", results["geometric"])] + has_constant_arc = "constant_arc_length" in results + if has_constant_arc: + methods.append(("Const Arc", results["constant_arc_length"])) + noise_cfg = resolve_reclamm_noise_settings(cfg) hodl_value = float((methods[0][1]["reserves"][0] * methods[0][1]["prices"][-1]).sum()) @@ -128,6 +2217,18 @@ def print_comparison(cfg, results): f"margin={cfg['centeredness_margin']}, " f"shift_exp={cfg['daily_price_shift_exponent']}, " f"fees={cfg['fees']}") + print( + f" base_tvl=${get_initial_pool_value(cfg):,.0f} " + f"(TVL {format_tvl_millions_label(cfg)})" + ) + print(f" note={cfg['reason']}") + print( + f" noise={noise_cfg['noise_summary']}, " + f"gas={cfg.get('gas_cost', 0.0)}, " + f"protocol_fee_split={cfg.get('protocol_fee_split', 0.0)}" + ) + if not has_constant_arc: + print(" constant_arc=disabled") print("-" * 105) header = " {:20s}".format("") for name, _ in methods: @@ -158,18 +2259,27 @@ def print_comparison(cfg, results): vs = (float(r["final_value"]) / hodl_value - 1) * 100 row += f" {vs:>13.2f}%" print(row) + + if has_constant_arc: + geo_final = float(results["geometric"]["final_value"]) + arc_final = float(results["constant_arc_length"]["final_value"]) + geo_lvr = hodl_value - geo_final + arc_lvr = hodl_value - arc_final + print(f" {'Const Arc - Geo':20s} ${arc_final - geo_final:>13,.0f}") + print(f" {'LVR saved vs Geo':20s} ${geo_lvr - arc_lvr:>13,.0f}") 
print("=" * 105) + def plot_comparison(cfg, results, fig_idx): - """Plot 4-panel comparison for one config.""" - # Method name → (result dict, color, linestyle) + """Plot comparison diagnostics for one config.""" + tvl_label = format_tvl_millions_label(cfg) variants = { "Geometric": (results["geometric"], "C0", "-"), - "Geo+Scaled": (results["geometric_scaled"], "C1", "-"), - "Const arc-len": (results["constant_arc_length"], "C2", "--"), - "Arc+Scaled": (results["cal_scaled"], "C3", "--"), } + has_constant_arc = "constant_arc_length" in results + if has_constant_arc: + variants["Const arc-len"] = (results["constant_arc_length"], "C2", "--") geo = results["geometric"] geo_prices = np.array(geo["prices"]) @@ -181,22 +2291,21 @@ def plot_comparison(cfg, results, fig_idx): price_ratio_traj = geo_prices[:n_steps, 0] / geo_prices[:n_steps, 1] fig, axes = plt.subplots(2, 2, figsize=(14, 10)) - fig.suptitle(cfg["name"], fontsize=13, fontweight="bold") + fig.suptitle(f"{cfg['name']} — TVL {tvl_label}", fontsize=13, fontweight="bold") - # (0,0) Pool value over time ax = axes[0, 0] + plotted_values = [] for name, (r, color, ls) in variants.items(): vals = np.array(r["value"]) + plotted_values.append(vals / 1e6) ax.plot(t_days, vals / 1e6, color=color, ls=ls, label=name, alpha=0.9) - ax.plot(t_days, np.array(hodl_traj) / 1e6, color="gray", ls=":", - alpha=0.5, label="HODL") + _set_padded_ylim(ax, plotted_values, pad_ratio=0.03) ax.set_xlabel("Days") ax.set_ylabel("Pool value ($M)") ax.set_title("Pool value") ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - # (0,1) Cumulative LVR ax = axes[0, 1] for name, (r, color, ls) in variants.items(): vals = np.array(r["value"]) @@ -208,7 +2317,6 @@ def plot_comparison(cfg, results, fig_idx): ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - # (1,0) Price ratio ax = axes[1, 0] ax.plot(t_days, price_ratio_traj, color="C4", alpha=0.7) ax.set_xlabel("Days") @@ -216,7 +2324,6 @@ def plot_comparison(cfg, results, fig_idx): ax.set_title("Price 
path") ax.grid(True, alpha=0.3) - # (1,1) Empirical weights ax = axes[1, 1] for name, (r, color, ls) in variants.items(): w = np.array(r["weights"]) @@ -230,19 +2337,21 @@ def plot_comparison(cfg, results, fig_idx): ax.grid(True, alpha=0.3) plt.tight_layout() - fname = f"reclamm_thermostat_comparison_{fig_idx}.png" + fname = tvl_artifact_filename("reclamm_thermostat_comparison", cfg, suffix=str(fig_idx)) plt.savefig(fname, dpi=150) print(f"Saved {fname}") plt.close(fig) - # Second figure: diagnostics + if not has_constant_arc: + print("Skipping constant-arc comparison diagnostics because RUN_CONSTANT_ARC_LENGTH=False.") + return + geo_values = np.array(geo["value"]) geo_lvr = np.array(hodl_traj) - geo_values fig2, axes2 = plt.subplots(1, 3, figsize=(18, 5)) - fig2.suptitle(f"{cfg['name']} — diagnostics", fontsize=13, fontweight="bold") + fig2.suptitle(f"{cfg['name']} — diagnostics — TVL {tvl_label}", fontsize=13, fontweight="bold") - # (left) Value difference vs geometric ax = axes2[0] for name, (r, color, ls) in variants.items(): if name == "Geometric": @@ -257,7 +2366,6 @@ def plot_comparison(cfg, results, fig_idx): ax.legend(fontsize=8) ax.grid(True, alpha=0.3) - # (middle) LVR ratio over time ax = axes2[1] mask = np.abs(geo_lvr) > 100 if mask.any(): @@ -279,7 +2387,6 @@ def plot_comparison(cfg, results, fig_idx): ax.set_title("Relative LVR") ax.grid(True, alpha=0.3) - # (right) Per-step LVR histogram ax = axes2[2] all_pos = [] for name, (r, color, ls) in variants.items(): @@ -306,74 +2413,197 @@ def plot_comparison(cfg, results, fig_idx): ax.grid(True, alpha=0.3) plt.tight_layout() - fname2 = f"reclamm_thermostat_diff_{fig_idx}.png" + fname2 = tvl_artifact_filename("reclamm_thermostat_diff", cfg, suffix=str(fig_idx)) plt.savefig(fname2, dpi=150) print(f"Saved {fname2}") plt.close(fig2) + arc_values = np.array(results["constant_arc_length"]["value"]) + n_eff = min(len(geo_values), len(arc_values)) + t_eff = np.arange(n_eff) / (60 * 24) + efficiency_pct = ( + 
(arc_values[:n_eff] - geo_values[:n_eff]) + / np.maximum(np.abs(geo_values[:n_eff]), 1e-12) + * 100.0 + ) + + fig3, ax3 = plt.subplots(1, 1, figsize=(10, 4.5)) + fig3.suptitle(f"{cfg['name']} — efficiency — TVL {tvl_label}", fontsize=13, fontweight="bold") + ax3.plot( + t_eff, + efficiency_pct, + color="C2", + linewidth=1.8, + label="(Const Arc - Geo) / Geo", + ) + ax3.axhline(0.0, color="gray", ls="--", alpha=0.6) + _set_padded_ylim(ax3, [efficiency_pct], pad_ratio=0.08) + ax3.set_xlabel("Days") + ax3.set_ylabel("Efficiency vs geometric (%)") + ax3.set_title("Efficiency") + ax3.legend(fontsize=8) + ax3.grid(True, alpha=0.3) + + plt.tight_layout() + fname3 = tvl_artifact_filename("reclamm_thermostat_efficiency", cfg, suffix=str(fig_idx)) + plt.savefig(fname3, dpi=150) + print(f"Saved {fname3}") + plt.close(fig3) + + if __name__ == "__main__": - all_results = [] - for i, cfg in enumerate(CONFIGS): - print(f"\n>>> Running {cfg['name']}...") - try: - results = run_comparison(cfg) - print_comparison(cfg, results) - plot_comparison(cfg, results, i) - all_results.append((cfg, results)) - except Exception as e: - print(f" FAILED: {e}") - import traceback - traceback.print_exc() - - # Summary overlay: all configs on one figure (pool value normalised) - if len(all_results) > 1: - fig, axes = plt.subplots(1, 2, figsize=(16, 5)) - fig.suptitle("Cross-config comparison (normalised)", fontsize=13, - fontweight="bold") - - method_keys = [ - ("geometric", "geo", "-"), - ("geometric_scaled", "geo+s", "-."), - ("constant_arc_length", "arc", "--"), - ("cal_scaled", "arc+s", ":"), - ] + shared_price_data = load_shared_price_data(CONFIGS) + shared_market_linear_noise_data = load_shared_market_linear_noise_data() + + for initial_pool_value in TVL_SWEEP_VALUES: + tvl_configs = configs_for_tvl(CONFIGS, initial_pool_value) + tvl_label = format_tvl_millions_label(tvl_configs[0]) + print(f"\n=== TVL sweep: {tvl_label} ===") + + all_results = [] + for i, cfg in enumerate(tvl_configs): + 
print(f"\n>>> Running {cfg['name']} at TVL {tvl_label}...") + try: + results = run_comparison( + cfg, + price_data=shared_price_data, + market_linear_noise_data=shared_market_linear_noise_data, + ) + print_comparison(cfg, results) + plot_comparison(cfg, results, i) + all_results.append((cfg, results)) + except Exception as e: + print(f" FAILED: {e}") + import traceback + + traceback.print_exc() + + if len(all_results) > 1: + if RUN_CONSTANT_ARC_LENGTH: + fig, axes = plt.subplots(1, 2, figsize=(16, 5)) + fig.suptitle( + f"Cross-config comparison (normalised) — TVL {tvl_label}", + fontsize=13, + fontweight="bold", + ) + + method_keys = [ + ("geometric", "geo", "-"), + ("constant_arc_length", "arc", "--"), + ] + + for i, (cfg, results) in enumerate(all_results): + geo_v = np.array(results["geometric"]["value"]) + t = np.arange(len(geo_v)) / (60 * 24) + short_name = cfg["name"].split("(")[0].strip() + + for j, (key, suffix, ls) in enumerate(method_keys): + v = np.array(results[key]["value"]) + color_idx = i * len(method_keys) + j + + axes[0].plot( + t, + v / v[0], + ls=ls, + alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}", + ) + + if key != "geometric": + pct_diff = (v - geo_v) / geo_v * 100 + axes[1].plot( + t, + pct_diff, + ls=ls, + alpha=0.8, + label=f"{short_name} {suffix}", + color=f"C{color_idx % 10}", + ) - for i, (cfg, results) in enumerate(all_results): - geo_v = np.array(results["geometric"]["value"]) - t = np.arange(len(geo_v)) / (60 * 24) - short_name = cfg["name"].split("(")[0].strip() - - for j, (key, suffix, ls) in enumerate(method_keys): - v = np.array(results[key]["value"]) - color_idx = i * len(method_keys) + j - - # (left) Normalised pool value - axes[0].plot(t, v / v[0], ls=ls, alpha=0.8, - label=f"{short_name} {suffix}", - color=f"C{color_idx % 10}") - - # (right) Value difference vs geometric (skip geo itself) - if key != "geometric": - pct_diff = (v - geo_v) / geo_v * 100 - axes[1].plot(t, pct_diff, ls=ls, alpha=0.8, - 
label=f"{short_name} {suffix}", - color=f"C{color_idx % 10}") - - axes[0].set_xlabel("Days") - axes[0].set_ylabel("Normalised pool value") - axes[0].set_title("Pool value (V/V0)") - axes[0].legend(fontsize=6, ncol=2) - axes[0].grid(True, alpha=0.3) - - axes[1].set_xlabel("Days") - axes[1].set_ylabel("(Method - Geo) / Geo (%)") - axes[1].set_title("Relative value difference vs Geometric") - axes[1].axhline(0, color="gray", ls="--", alpha=0.5) - axes[1].legend(fontsize=6, ncol=2) - axes[1].grid(True, alpha=0.3) - - plt.tight_layout() - plt.savefig("reclamm_thermostat_summary.png", dpi=150) - print("\nSaved reclamm_thermostat_summary.png") - plt.close(fig) + axes[0].set_xlabel("Days") + axes[0].set_ylabel("Normalised pool value") + axes[0].set_title("Pool value (V/V0)") + axes[0].legend(fontsize=6, ncol=2) + axes[0].grid(True, alpha=0.3) + + axes[1].set_xlabel("Days") + axes[1].set_ylabel("Efficiency vs geometric (%)") + axes[1].set_title("Efficiency vs Geometric") + axes[1].axhline(0, color="gray", ls="--", alpha=0.5) + axes[1].legend(fontsize=6, ncol=2) + axes[1].grid(True, alpha=0.3) + else: + fig, ax = plt.subplots(1, 1, figsize=(9, 5)) + fig.suptitle( + f"Cross-config comparison (normalised geometric) — TVL {tvl_label}", + fontsize=13, + fontweight="bold", + ) + + for i, (cfg, results) in enumerate(all_results): + geo_v = np.array(results["geometric"]["value"]) + t = np.arange(len(geo_v)) / (60 * 24) + short_name = cfg["name"].split("(")[0].strip() + ax.plot( + t, + geo_v / geo_v[0], + ls="-", + alpha=0.8, + label=f"{short_name} geo", + color=f"C{i % 10}", + ) + + ax.set_xlabel("Days") + ax.set_ylabel("Normalised pool value") + ax.set_title("Geometric pool value (V/V0)") + ax.legend(fontsize=6, ncol=2) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + summary_name = tvl_artifact_filename( + "reclamm_thermostat_summary", + tvl_configs[0], + ) + plt.savefig(summary_name, dpi=150) + print(f"\nSaved {summary_name}") + plt.close(fig) + + launch_final_values = 
get_launch_final_values( + all_results, + launch_cfg=tvl_configs[0], + price_data=shared_price_data, + market_linear_noise_data=shared_market_linear_noise_data, + ) + shared_sweep_cache = make_sweep_cache( + shared_price_data, + cache_scope_cfg=tvl_configs[1], + market_linear_noise_data=shared_market_linear_noise_data, + ) + + print(f"\nGenerating thermostat heatmaps for TVL {tvl_label}...") + generate_heatmaps( + dict(tvl_configs[1]), + shared_price_data, + launch_final_values=launch_final_values, + cache=shared_sweep_cache, + ) + + generate_arc_speed_efficiency_artifacts( + dict(tvl_configs[1]), + launch_cfg=dict(tvl_configs[0]), + price_data=shared_price_data, + launch_final_values=launch_final_values, + cache=shared_sweep_cache, + ) + generate_three_variable_3d_heatmaps( + dict(tvl_configs[1]), + price_data=shared_price_data, + launch_final_values=launch_final_values, + cache=shared_sweep_cache, + ) + flush_sweep_cache(shared_sweep_cache, force=True) + shared_sweep_cache.clear() + gc.collect() + print(f"Released shared sweep cache for TVL {tvl_label}.") diff --git a/scripts/reclamm/demo_run_reclamm.py b/scripts/reclamm/demo_run_reclamm.py index 3ea21ec..132f512 100644 --- a/scripts/reclamm/demo_run_reclamm.py +++ b/scripts/reclamm/demo_run_reclamm.py @@ -36,105 +36,127 @@ def balancer_fingerprint(tokens, start, end, fees): } +def reclamm_fingerprint(tokens, start, end, fees, interpolation_method="geometric"): + """Build a reCLAMM fingerprint for a demo scenario.""" + return { + "tokens": tokens, + "rule": "reclamm", + "startDateString": start, + "endDateString": end, + "initial_pool_value": 1000000.0, + "do_arb": True, + "fees": fees, + "gas_cost": 0.0, + "arb_fees": 0.0, + "chunk_period": 60, + "weight_interpolation_period": 60, + "reclamm_interpolation_method": interpolation_method, + "reclamm_arc_length_speed": None, + } + + +def reclamm_params(price_ratio, centeredness_margin, daily_price_shift_exponent): + """Build reCLAMM params from a concise config.""" 
+ return { + "price_ratio": jnp.array(price_ratio), + "centeredness_margin": jnp.array(centeredness_margin), + "daily_price_shift_base": jnp.array( + to_daily_price_shift_base(daily_price_shift_exponent) + ), + } + + +def _apply_active_noise_settings(fp): + """Enable the active AAVE/ETH reCLAMM noise model for demo runs.""" + if fp.get("rule") != "reclamm" or list(fp.get("tokens", [])) != ["AAVE", "ETH"]: + return fp, "disabled" + + from compare_reclamm_thermostats import ( + AAVE_ETH_NOISE_SETTINGS, + resolve_reclamm_noise_settings, + ) + + cfg = { + "tokens": fp["tokens"], + "start": fp["startDateString"], + "end": fp["endDateString"], + "enable_noise_model": True, + "noise_model": AAVE_ETH_NOISE_SETTINGS["noise_model"], + "noise_artifact_dir": AAVE_ETH_NOISE_SETTINGS["noise_artifact_dir"], + "noise_pool_id": AAVE_ETH_NOISE_SETTINGS["noise_pool_id"], + "gas_cost": fp.get("gas_cost", AAVE_ETH_NOISE_SETTINGS["gas_cost"]), + "protocol_fee_split": fp.get( + "protocol_fee_split", + AAVE_ETH_NOISE_SETTINGS["protocol_fee_split"], + ), + "arb_frequency": fp.get("arb_frequency"), + "noise_trader_ratio": fp.get("noise_trader_ratio", 0.0), + "reclamm_noise_params": fp.get("reclamm_noise_params"), + "noise_arrays_path": fp.get("noise_arrays_path"), + } + noise_cfg = resolve_reclamm_noise_settings(cfg) + + updated = dict(fp) + updated["gas_cost"] = cfg["gas_cost"] + updated["protocol_fee_split"] = cfg["protocol_fee_split"] + updated["noise_trader_ratio"] = noise_cfg.get("noise_trader_ratio", 0.0) + for key in ("noise_model", "reclamm_noise_params", "noise_arrays_path", "arb_frequency"): + if noise_cfg.get(key) is not None: + updated[key] = noise_cfg[key] + return updated, noise_cfg["noise_summary"] + + SCENARIOS = [ { - "name": "AAVE/ETH on-chain (25bps)", + "name": "AAVE/ETH launch-style range (25bps, geometric)", "reclamm": { - "fingerprint": { - "tokens": ["AAVE", "ETH"], - "rule": "reclamm", - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-06-01 
00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.0025, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(1.5), - "centeredness_margin": jnp.array(0.5), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(0.1) - ), - }, + "fingerprint": reclamm_fingerprint( + ["AAVE", "ETH"], + "2024-06-01 00:00:00", + "2025-06-01 00:00:00", + 0.0025, + interpolation_method="geometric", + ), + "params": reclamm_params(1.5014, 0.5, 0.1), }, }, { - "name": "AAVE/ETH zero fees", + "name": "AAVE/ETH tighter launch-style range (25bps, geometric)", "reclamm": { - "fingerprint": { - "tokens": ["AAVE", "ETH"], - "rule": "reclamm", - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.0, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(1.5), - "centeredness_margin": jnp.array(0.5), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(0.1) - ), - }, + "fingerprint": reclamm_fingerprint( + ["AAVE", "ETH"], + "2024-06-01 00:00:00", + "2025-06-01 00:00:00", + 0.0025, + interpolation_method="geometric", + ), + "params": reclamm_params(1.15, 0.5, 0.1), }, }, { - "name": "AAVE/ETH wide range (25bps)", + "name": "AAVE/ETH tighter launch-style range (25bps, constant arc)", "reclamm": { - "fingerprint": { - "tokens": ["AAVE", "ETH"], - "rule": "reclamm", - "startDateString": "2024-06-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.0025, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(4.0), - "centeredness_margin": jnp.array(0.2), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(1.0) - 
), - }, + "fingerprint": reclamm_fingerprint( + ["AAVE", "ETH"], + "2024-06-01 00:00:00", + "2025-06-01 00:00:00", + 0.0025, + interpolation_method="constant_arc_length", + ), + "params": reclamm_params(1.15, 0.5, 0.1), }, }, { "name": "BTC/ETH (10bps)", "reclamm": { - "fingerprint": { - "tokens": ["BTC", "ETH"], - "rule": "reclamm", - "startDateString": "2024-01-01 00:00:00", - "endDateString": "2025-06-01 00:00:00", - "initial_pool_value": 1000000.0, - "do_arb": True, - "fees": 0.001, - "gas_cost": 0.0, - "arb_fees": 0.0, - "chunk_period": 60, - "weight_interpolation_period": 60, - }, - "params": { - "price_ratio": jnp.array(2.0), - "centeredness_margin": jnp.array(0.3), - "daily_price_shift_base": jnp.array( - to_daily_price_shift_base(0.5) - ), - }, + "fingerprint": reclamm_fingerprint( + ["BTC", "ETH"], + "2024-01-01 00:00:00", + "2025-06-01 00:00:00", + 0.001, + interpolation_method="geometric", + ), + "params": reclamm_params(2.0, 0.3, 0.5), }, }, ] @@ -143,7 +165,7 @@ def balancer_fingerprint(tokens, start, end, fees): def run_scenario(scenario): """Run a reClAMM config and its Balancer 50/50 baseline, print comparison.""" rc = scenario["reclamm"] - fp = rc["fingerprint"] + fp, noise_summary = _apply_active_noise_settings(dict(rc["fingerprint"])) # Run reClAMM reclamm_result = do_run_on_historic_data( @@ -173,7 +195,14 @@ def run_scenario(scenario): print("=" * 80) print(f" {scenario['name']}") - print(f" Tokens: {', '.join(fp['tokens'])} | Fees: {fp['fees']}") + print( + f" Tokens: {', '.join(fp['tokens'])} | Fees: {fp['fees']} | " + f"Interpolation: {fp.get('reclamm_interpolation_method', 'geometric')}" + ) + print( + f" Noise: {noise_summary} | Gas: {fp.get('gas_cost', 0.0)} | " + f"Protocol fee split: {fp.get('protocol_fee_split', 0.0)}" + ) print("-" * 80) print(f" {'':30s} {'reClAMM':>14s} {'Balancer 50/50':>14s}") print(f" {'Initial value':30s} ${rc_init:>13,.0f} ${bal_init:>13,.0f}") diff --git a/scripts/reclamm/find_adjacent_heatmap_pairs.py 
b/scripts/reclamm/find_adjacent_heatmap_pairs.py new file mode 100644 index 0000000..51a3a8a --- /dev/null +++ b/scripts/reclamm/find_adjacent_heatmap_pairs.py @@ -0,0 +1,1257 @@ +"""Scan cached reCLAMM heatmaps for adjacent cells with large value gaps. + +This script reconstructs heatmap cells from the persisted scalar forward-value +cache written by ``compare_reclamm_thermostats.py``. It does not inspect PNG +pixels or rerun the simulator for cache-backed metrics. +""" + +from __future__ import annotations + +import argparse +import hashlib +import importlib.util +import math +import os +from pathlib import Path +from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple + +import numpy as np +import pandas as pd + + +CACHE_ONLY_METRIC_SPECS = { + "efficiency_pct": { + "sources": ("noise_constant_arc", "noise_geometric"), + "unit": "pct", + "compute": lambda values: ( + values["noise_constant_arc"] / max(abs(values["noise_geometric"]), 1.0e-12) + - 1.0 + ) + * 100.0, + }, + "noise_geometric_final_value_musd": { + "sources": ("noise_geometric",), + "unit": "musd", + "compute": lambda values: values["noise_geometric"] / 1.0e6, + }, + "noise_constant_arc_final_value_musd": { + "sources": ("noise_constant_arc",), + "unit": "musd", + "compute": lambda values: values["noise_constant_arc"] / 1.0e6, + }, + "noise_vs_arb_geometric_improvement_pct": { + "sources": ("noise_geometric", "arb_geometric"), + "unit": "pct", + "compute": lambda values: ( + values["noise_geometric"] / max(abs(values["arb_geometric"]), 1.0e-12) - 1.0 + ) + * 100.0, + }, + "noise_vs_arb_constant_arc_improvement_pct": { + "sources": ("noise_constant_arc", "arb_constant_arc"), + "unit": "pct", + "compute": lambda values: ( + values["noise_constant_arc"] + / max(abs(values["arb_constant_arc"]), 1.0e-12) + - 1.0 + ) + * 100.0, + }, +} + +OUTPUT_COLUMNS = [ + "metric_key", + "metric_unit", + "source_noise_profile", + "pair_slug", + "slice_slug", + "slice_label", + 
"fixed_key", + "fixed_value", + "adjacency_axis", + "heatmap_value_diff_abs", + "heatmap_value_diff_signed_2_minus_1", + "1_price_ratio", + "1_centeredness_margin", + "1_daily_price_shift_exponent", + "1_tvl_usd", + "1_heatmap_value", + "1_x_index", + "1_y_index", + "2_price_ratio", + "2_centeredness_margin", + "2_daily_price_shift_exponent", + "2_tvl_usd", + "2_heatmap_value", + "2_x_index", + "2_y_index", +] + + +def build_inclusive_sweep(start: float, stop: float, step: float) -> np.ndarray: + """Build a sweep that keeps the requested step and explicitly includes the stop.""" + values = np.arange(start, stop + 1.0e-12, step, dtype=float) + if values.size == 0 or not np.isclose(values[-1], stop): + values = np.append(values, float(stop)) + return values + + +def _resolve_repo_root(script_path): + """Locate the repository root from either scripts/ or scripts/reclamm/.""" + script_path = Path(script_path).resolve() + for parent in script_path.parents: + if (parent / "quantammsim").exists() and (parent / "scripts").exists(): + return parent + return script_path.parents[1] + + +REPO_ROOT = _resolve_repo_root(__file__) +DEFAULT_MARKET_LINEAR_NOISE_START_DATE = "2024-06-01" +DEFAULT_MARKET_LINEAR_NOISE_END_DATE = "2026-03-01" +DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH = str( + REPO_ROOT + / "results" + / "linear_market_noise" + / "_sim_arrays" + / ( + "0x9d1fcf346ea1b0_" + f"{DEFAULT_MARKET_LINEAR_NOISE_START_DATE}_{DEFAULT_MARKET_LINEAR_NOISE_END_DATE}.npz" + ) +) + + +class _LightweightCompareContext: + """Small subset of compare_reclamm_thermostats usable without JAX.""" + + RUN_CONSTANT_ARC_LENGTH = True + DEFAULT_INITIAL_POOL_VALUE = 1_000_000.0 + TVL_SWEEP_VALUES = ( + 1_000_000.0, + 5_000_000.0, + 20_000_000.0, + ) + HEATMAP_PRICE_RATIOS = build_inclusive_sweep(1.01, 3.00, 0.025) + HEATMAP_MARGINS = np.linspace(0.05, 0.90, 39) + HEATMAP_SHIFT_EXPONENTS = build_inclusive_sweep(0.01, 0.50, 0.0125) + FIXED_SLICE_FRACTIONS = (0.125, 0.375, 0.625, 0.875) + 
FIXED_SLICE_LABELS = ("Q1", "Q2", "Q3", "Q4") + HEATMAP_FORWARD_CACHE_ENABLED = True + HEATMAP_FORWARD_CACHE_RUN_NAME = "aave_eth_thermostat_heatmaps_market_linear_v2" + HEATMAP_FORWARD_CACHE_ROOT = os.path.join( + "results", + "reclamm_heatmap_forward_cache", + ) + AAVE_WETH_POOL_ID = "0x9d1fcf346ea1b0" + DEFAULT_MARKET_LINEAR_ARTIFACT_DIR = "results/linear_market_noise" + DEFAULT_MARKET_LINEAR_NOISE_START_DATE = DEFAULT_MARKET_LINEAR_NOISE_START_DATE + DEFAULT_MARKET_LINEAR_NOISE_END_DATE = DEFAULT_MARKET_LINEAR_NOISE_END_DATE + DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH = DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + DEFAULT_NOISE_MODEL = "market_linear" + DEFAULT_GAS_COST = 1.0 + DEFAULT_PROTOCOL_FEE_SPLIT = 0.25 + LEGACY_NOISE_COEFFS = [ + -0.453, + 0.025, + -0.060, + 0.310, + -0.149, + 0.359, + 0.061, + 0.060, + ] + LEGACY_LOG_CADENCE = 2.68 + LEGACY_ARB_FREQUENCY = max(1, round(math.exp(LEGACY_LOG_CADENCE))) + FIXED_COMPARE_ARB_FREQUENCY = LEGACY_ARB_FREQUENCY + AAVE_ETH_NOISE_SETTINGS = { + "enable_noise_model": True, + "noise_model": DEFAULT_NOISE_MODEL, + "noise_reference_model": DEFAULT_NOISE_MODEL, + "noise_artifact_dir": DEFAULT_MARKET_LINEAR_ARTIFACT_DIR, + "noise_pool_id": AAVE_WETH_POOL_ID, + "noise_arrays_path": DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, + "arb_frequency": FIXED_COMPARE_ARB_FREQUENCY, + "gas_cost": DEFAULT_GAS_COST, + "protocol_fee_split": DEFAULT_PROTOCOL_FEE_SPLIT, + } + CONFIGS = [ + { + "name": "AAVE/ETH launch-style range (25bps, reference)", + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:00:00", + "end": "2025-06-01 00:00:00", + "fees": 0.0025, + "price_ratio": 1.5014, + "centeredness_margin": 0.5, + "daily_price_shift_exponent": 0.1, + "reason": "Original launch-style parameters.", + **AAVE_ETH_NOISE_SETTINGS, + }, + { + "name": "AAVE/ETH aggressive tight range (25bps)", + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:00:00", + "end": "2025-06-01 00:00:00", + "fees": 0.0025, + "price_ratio": 1.10, + 
"centeredness_margin": 0.60, + "daily_price_shift_exponent": 0.1, + "reason": ( + "Aggressively tightened and moved to an earlier thermostat trigger. " + "At fixed price_ratio=1.10, the shift_exponent sweep still favored " + "0.1, while margin=0.60 widened the non-linear edge materially." + ), + **AAVE_ETH_NOISE_SETTINGS, + }, + ] + + def __init__(self): + self._noise_settings_cache = {} + self.noise_profile = "market_linear" + + @classmethod + def from_compare_module(cls, compare_module): + """Build an analyzer-friendly context from an imported thermostat module.""" + context = cls() + copied_attrs = ( + "RUN_CONSTANT_ARC_LENGTH", + "DEFAULT_INITIAL_POOL_VALUE", + "TVL_SWEEP_VALUES", + "HEATMAP_PRICE_RATIOS", + "HEATMAP_MARGINS", + "HEATMAP_SHIFT_EXPONENTS", + "FIXED_SLICE_FRACTIONS", + "FIXED_SLICE_LABELS", + "HEATMAP_FORWARD_CACHE_ENABLED", + "HEATMAP_FORWARD_CACHE_RUN_NAME", + "HEATMAP_FORWARD_CACHE_ROOT", + "AAVE_WETH_POOL_ID", + "DEFAULT_MARKET_LINEAR_ARTIFACT_DIR", + "DEFAULT_MARKET_LINEAR_NOISE_START_DATE", + "DEFAULT_MARKET_LINEAR_NOISE_END_DATE", + "DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH", + "DEFAULT_NOISE_MODEL", + "DEFAULT_GAS_COST", + "DEFAULT_PROTOCOL_FEE_SPLIT", + "LEGACY_NOISE_COEFFS", + "LEGACY_LOG_CADENCE", + "LEGACY_ARB_FREQUENCY", + "FIXED_COMPARE_ARB_FREQUENCY", + "AAVE_ETH_NOISE_SETTINGS", + "CONFIGS", + ) + for attr_name in copied_attrs: + if hasattr(compare_module, attr_name): + value = getattr(compare_module, attr_name) + if attr_name == "CONFIGS": + value = [dict(cfg) for cfg in value] + elif isinstance(value, dict): + value = dict(value) + elif isinstance(value, np.ndarray): + value = np.asarray(value, dtype=float).copy() + elif isinstance(value, tuple): + value = tuple(value) + elif isinstance(value, list): + value = list(value) + setattr(context, attr_name, value) + context._noise_settings_cache.clear() + return context + + def set_noise_profile(self, profile): + if profile != "market_linear": + raise ValueError(f"Unsupported 
lightweight noise profile: {profile}") + if profile != self.noise_profile: + self.noise_profile = profile + self._noise_settings_cache.clear() + + def get_initial_pool_value(self, cfg): + return float(cfg.get("initial_pool_value", self.DEFAULT_INITIAL_POOL_VALUE)) + + def get_tvl_millions(self, cfg): + return self.get_initial_pool_value(cfg) / 1_000_000.0 + + def format_tvl_millions_slug(self, cfg): + tvl_millions = self.get_tvl_millions(cfg) + rounded = round(float(tvl_millions), 6) + if np.isclose(rounded, round(rounded)): + return f"{int(round(rounded))}m" + return f"{rounded:.6f}".rstrip("0").rstrip(".").replace(".", "p") + "m" + + def format_tvl_millions_label(self, cfg): + return f"{self.get_tvl_millions(cfg):.1f}M" + + def configs_for_tvl(self, base_configs, initial_pool_value): + configs = [] + for cfg in base_configs: + updated = dict(cfg) + updated["initial_pool_value"] = float(initial_pool_value) + configs.append(updated) + return configs + + def _heatmap_forward_cache_scope_slug(self, cfg): + if cfg is None: + return "unspecified_tvl" + return f"tvl_{self.format_tvl_millions_slug(cfg)}" + + def _heatmap_forward_cache_path(self, cfg): + if not self.HEATMAP_FORWARD_CACHE_ENABLED: + return None + return os.path.join( + self.HEATMAP_FORWARD_CACHE_ROOT, + self.HEATMAP_FORWARD_CACHE_RUN_NAME, + f"forward_values_{self._heatmap_forward_cache_scope_slug(cfg)}.parquet", + ) + + def build_fixed_slice_variants(self, values): + values = np.asarray(values, dtype=float) + if values.size < len(self.FIXED_SLICE_FRACTIONS): + raise ValueError("Need at least four grid points to build fixed slices") + + variants = [] + used_indices = set() + for idx, fraction in enumerate(self.FIXED_SLICE_FRACTIONS): + target_index = int(round(fraction * (values.size - 1))) + while target_index in used_indices and target_index + 1 < values.size: + target_index += 1 + while target_index in used_indices and target_index - 1 >= 0: + target_index -= 1 + if target_index in used_indices: + raise 
ValueError( + "Could not build four unique fixed slices from sweep grid" + ) + used_indices.add(target_index) + variants.append( + { + "index": target_index, + "fraction": fraction, + "label": self.FIXED_SLICE_LABELS[idx], + "slug": f"q{idx + 1}", + "value": float(values[target_index]), + } + ) + return variants + + def get_pair_heatmap_specs(self, _base_cfg): + fixed_slice_variants = { + "price_ratio": self.build_fixed_slice_variants(self.HEATMAP_PRICE_RATIOS), + "centeredness_margin": self.build_fixed_slice_variants(self.HEATMAP_MARGINS), + "daily_price_shift_exponent": self.build_fixed_slice_variants( + self.HEATMAP_SHIFT_EXPONENTS + ), + } + return [ + { + "slug": "price_ratio_vs_margin", + "x_values": self.HEATMAP_PRICE_RATIOS, + "y_values": self.HEATMAP_MARGINS, + "x_key": "price_ratio", + "y_key": "centeredness_margin", + "fixed_key": "daily_price_shift_exponent", + "fixed_slices": fixed_slice_variants["daily_price_shift_exponent"], + }, + { + "slug": "shift_exp_vs_margin", + "x_values": self.HEATMAP_SHIFT_EXPONENTS, + "y_values": self.HEATMAP_MARGINS, + "x_key": "daily_price_shift_exponent", + "y_key": "centeredness_margin", + "fixed_key": "price_ratio", + "fixed_slices": fixed_slice_variants["price_ratio"], + }, + { + "slug": "price_ratio_vs_shift_exp", + "x_values": self.HEATMAP_PRICE_RATIOS, + "y_values": self.HEATMAP_SHIFT_EXPONENTS, + "x_key": "price_ratio", + "y_key": "daily_price_shift_exponent", + "fixed_key": "centeredness_margin", + "fixed_slices": fixed_slice_variants["centeredness_margin"], + }, + ] + + @staticmethod + def _hashable_noise_params(params): + if params is None: + return None + return tuple(sorted((str(k), round(float(v), 12)) for k, v in params.items())) + + def _normalize_arb_frequency(self, value, default=None): + if value is None: + if default is None: + default = self.FIXED_COMPARE_ARB_FREQUENCY + value = default + return max(int(round(float(value))), 1) + + def get_effective_arb_frequency(self, cfg, noise_cfg=None): + del 
noise_cfg + return self._normalize_arb_frequency(self.FIXED_COMPARE_ARB_FREQUENCY) + + def _canonical_noise_reference_model(self, cfg): + noise_model = cfg.get("noise_model", self.DEFAULT_NOISE_MODEL) or self.DEFAULT_NOISE_MODEL + reference_model = cfg.get("noise_reference_model") + if reference_model is None: + reference_model = self.DEFAULT_NOISE_MODEL if noise_model == "arb_only" else noise_model + return str(reference_model) + + def normalize_compare_run_cfg(self, cfg, enable_noise_model=None): + updated = dict(cfg) + updated["price_ratio"] = float(cfg["price_ratio"]) + updated["centeredness_margin"] = float(cfg["centeredness_margin"]) + updated["daily_price_shift_exponent"] = float( + cfg["daily_price_shift_exponent"] + ) + updated["initial_pool_value"] = float(self.get_initial_pool_value(cfg)) + updated["gas_cost"] = self.DEFAULT_GAS_COST + updated["protocol_fee_split"] = self.DEFAULT_PROTOCOL_FEE_SPLIT + updated["arb_fees"] = 0.0 + updated["arb_frequency"] = self.get_effective_arb_frequency(cfg) + updated["noise_trader_ratio"] = 0.0 + + arc_length_speed = cfg.get("arc_length_speed") + if arc_length_speed is None: + updated.pop("arc_length_speed", None) + else: + updated["arc_length_speed"] = float(arc_length_speed) + + use_noise = ( + bool(cfg.get("enable_noise_model", False)) + if enable_noise_model is None + else bool(enable_noise_model) + ) + updated["enable_noise_model"] = use_noise + + reference_mode = self._canonical_noise_reference_model(cfg) + if use_noise: + updated["noise_model"] = reference_mode + updated["noise_reference_model"] = reference_mode + else: + updated["noise_model"] = "arb_only" + updated["noise_reference_model"] = reference_mode + + if reference_mode == "market_linear": + updated["noise_arrays_path"] = self.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + updated.pop("reclamm_noise_params", None) + if use_noise or updated["noise_model"] == "arb_only": + updated["noise_artifact_dir"] = self.DEFAULT_MARKET_LINEAR_ARTIFACT_DIR + 
updated["noise_pool_id"] = self.AAVE_WETH_POOL_ID + else: + updated.pop("reclamm_noise_params", None) + updated.pop("noise_arrays_path", None) + updated.pop("noise_artifact_dir", None) + updated.pop("noise_pool_id", None) + + return updated + + def _load_market_linear_noise_stats(self, arrays_path=None): + arrays_path = os.path.abspath( + os.fspath(arrays_path or self.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH) + ) + if not os.path.exists(arrays_path): + raise FileNotFoundError(f"market_linear arrays file not found: {arrays_path}") + + with np.load(arrays_path) as arrays: + required_keys = {"noise_base", "noise_tvl_coeff", "tvl_mean", "tvl_std"} + missing_keys = sorted(required_keys.difference(arrays.files)) + if missing_keys: + raise KeyError( + f"market_linear arrays file {arrays_path} is missing keys: {missing_keys}" + ) + return arrays_path, float(arrays["tvl_mean"]), float(arrays["tvl_std"]) + + def _market_linear_noise_settings(self, noise_model="market_linear", arb_frequency=None): + arrays_path, tvl_mean, tvl_std = self._load_market_linear_noise_stats() + arb_frequency = self._normalize_arb_frequency(arb_frequency) + return { + "noise_model": noise_model, + "noise_trader_ratio": 0.0, + "reclamm_noise_params": { + "tvl_mean": tvl_mean, + "tvl_std": tvl_std, + }, + "noise_arrays_path": arrays_path, + "arb_frequency": arb_frequency, + "noise_summary": f"{noise_model} (arb_frequency={arb_frequency})", + "noise_cache_key": ( + noise_model, + arrays_path, + arb_frequency, + round(tvl_mean, 12), + round(tvl_std, 12), + ), + } + + def _legacy_calibrated_noise_settings(self, arb_frequency=None, noise_model="calibrated"): + arb_frequency = self._normalize_arb_frequency(arb_frequency) + return { + "noise_model": noise_model, + "noise_trader_ratio": 0.0, + "reclamm_noise_params": { + f"c_{i}": self.LEGACY_NOISE_COEFFS[i] + for i in range(len(self.LEGACY_NOISE_COEFFS)) + }, + "arb_frequency": arb_frequency, + "noise_summary": f"{noise_model} 
(arb_frequency={arb_frequency})", + "noise_cache_key": ( + noise_model, + tuple(round(float(c), 12) for c in self.LEGACY_NOISE_COEFFS), + arb_frequency, + ), + } + + def resolve_reclamm_noise_settings(self, cfg): + cfg = self.normalize_compare_run_cfg(cfg) + enable_noise_model = cfg.get("enable_noise_model", False) + requested_mode = cfg.get("noise_model", self.DEFAULT_NOISE_MODEL) + reference_mode = cfg.get("noise_reference_model", self.DEFAULT_NOISE_MODEL) + requested_arb_frequency = self.get_effective_arb_frequency(cfg) + cache_key = ( + tuple(cfg.get("tokens", [])), + cfg.get("start"), + cfg.get("end"), + enable_noise_model, + requested_mode, + reference_mode, + cfg.get("noise_artifact_dir", self.DEFAULT_MARKET_LINEAR_ARTIFACT_DIR), + cfg.get("noise_pool_id", self.AAVE_WETH_POOL_ID), + requested_arb_frequency, + round(float(cfg.get("noise_trader_ratio", 0.0)), 12), + self._hashable_noise_params(cfg.get("reclamm_noise_params")), + cfg.get("noise_arrays_path"), + ) + if cache_key in self._noise_settings_cache: + return self._noise_settings_cache[cache_key] + + if requested_mode == "arb_only": + if reference_mode == "market_linear": + result = self._market_linear_noise_settings( + noise_model="arb_only", + arb_frequency=requested_arb_frequency + ) + elif reference_mode == "calibrated": + result = self._legacy_calibrated_noise_settings( + noise_model="arb_only", + arb_frequency=requested_arb_frequency + ) + else: + arb_frequency = requested_arb_frequency + result = { + "noise_model": "arb_only", + "noise_trader_ratio": cfg.get("noise_trader_ratio", 0.0), + "reclamm_noise_params": cfg.get("reclamm_noise_params"), + "noise_arrays_path": cfg.get("noise_arrays_path"), + "arb_frequency": arb_frequency, + "noise_summary": f"arb_only (arb_frequency={arb_frequency})", + "noise_cache_key": ( + "arb_only", + round(float(cfg.get("noise_trader_ratio", 0.0)), 12), + self._hashable_noise_params(cfg.get("reclamm_noise_params")), + cfg.get("noise_arrays_path"), + arb_frequency, + 
), + } + elif requested_mode == "market_linear": + result = self._market_linear_noise_settings( + noise_model="market_linear", + arb_frequency=requested_arb_frequency, + ) + elif requested_mode == "calibrated": + result = self._legacy_calibrated_noise_settings( + arb_frequency=requested_arb_frequency + ) + else: + arb_frequency = requested_arb_frequency + result = { + "noise_model": requested_mode, + "noise_trader_ratio": cfg.get("noise_trader_ratio", 0.0), + "reclamm_noise_params": cfg.get("reclamm_noise_params"), + "noise_arrays_path": cfg.get("noise_arrays_path"), + "arb_frequency": arb_frequency, + "noise_summary": f"{requested_mode} (arb_frequency={arb_frequency})", + "noise_cache_key": ( + requested_mode, + round(float(cfg.get("noise_trader_ratio", 0.0)), 12), + self._hashable_noise_params(cfg.get("reclamm_noise_params")), + cfg.get("noise_arrays_path"), + arb_frequency, + ), + } + + self._noise_settings_cache[cache_key] = result + return result + + def make_noise_variant_cfg(self, cfg, enable_noise_model): + return self.normalize_compare_run_cfg( + cfg, + enable_noise_model=enable_noise_model, + ) + + @staticmethod + def _make_method_cache_hash(key): + return hashlib.sha256(repr(key).encode("utf-8")).hexdigest() + + def _make_method_cache_key(self, cfg, method): + cfg = self.normalize_compare_run_cfg(cfg) + noise_cfg = self.resolve_reclamm_noise_settings(cfg) + arb_frequency = self.get_effective_arb_frequency(cfg, noise_cfg) + key = ( + method, + bool(cfg.get("enable_noise_model", False)), + round(float(cfg["price_ratio"]), 6), + round(float(cfg["centeredness_margin"]), 6), + round(float(cfg["daily_price_shift_exponent"]), 6), + round(self.get_initial_pool_value(cfg), 2), + noise_cfg.get("noise_cache_key"), + None if arb_frequency is None else int(arb_frequency), + round( + float( + cfg.get( + "gas_cost", + self.DEFAULT_GAS_COST + if cfg.get("enable_noise_model", False) + else 0.0, + ) + ), + 6, + ), + round( + float( + cfg.get( + "protocol_fee_split", + 
self.DEFAULT_PROTOCOL_FEE_SPLIT + if cfg.get("enable_noise_model", False) + else 0.0, + ) + ), + 6, + ), + ) + if method == "constant_arc_length": + speed = cfg.get("arc_length_speed") + key += (None if speed is None else round(float(speed), 12),) + return key + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Identify horizontally and vertically adjacent reCLAMM heatmap cells " + "whose derived metric values differ by at least the requested threshold." + ) + ) + parser.add_argument( + "--metric-key", + default="noise_vs_arb_geometric_improvement_pct", + choices=sorted(CACHE_ONLY_METRIC_SPECS), + help="Heatmap metric to reconstruct from the persisted forward-value cache.", + ) + parser.add_argument( + "--pair-slug", + default="price_ratio_vs_margin", + help="Pair heatmap family slug, or 'all' to scan every pair family.", + ) + parser.add_argument( + "--slice-slug", + default="all", + help="Quarter-slice slug (q1/q2/q3/q4), or 'all' to scan every slice.", + ) + parser.add_argument( + "--min-diff", + type=float, + default=30.0, + help=( + "Minimum absolute difference between adjacent heatmap values. " + "For the default metric this is in percentage points." + ), + ) + parser.add_argument( + "--adjacency-axis", + default="both", + choices=("both", "horizontal", "vertical"), + help=( + "Which adjacency direction to scan. " + "'both' includes horizontal and vertical neighbors." + ), + ) + parser.add_argument( + "--initial-pool-value", + type=float, + default=1_000_000.0, + help="TVL in USD used for the cached heatmap sweep.", + ) + parser.add_argument( + "--config-index", + type=int, + default=1, + help="Which compare_reclamm_thermostats.py base config to use.", + ) + parser.add_argument( + "--cache-path", + default=None, + help="Optional parquet cache override. Defaults to the compare script's TVL cache.", + ) + parser.add_argument( + "--output-csv", + default=None, + help="Optional CSV output path. 
Defaults under scripts/results/.", + ) + parser.add_argument( + "--skip-top-row-geometric-comparison", + action="store_true", + help=( + "Skip the follow-up geometric noise comparison for the top CSV row. " + "By default the script attempts that comparison after writing the CSV." + ), + ) + parser.add_argument( + "--top-row-geometric-comparison-output-file", + default=None, + help="Optional PNG output path override for the top-row geometric comparison.", + ) + parser.add_argument( + "--allow-partial-cache", + action="store_true", + help="Write output even if some heatmap cells are missing from the cache.", + ) + return parser.parse_args() + + +def load_compare_module(module_path: Optional[Path] = None): + compare_path = module_path or Path(__file__).with_name("compare_reclamm_thermostats.py") + spec = importlib.util.spec_from_file_location( + "reclamm_compare_reclamm_thermostats", + compare_path, + ) + if spec is not None and spec.loader is not None: + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + return _LightweightCompareContext.from_compare_module(module) + except ModuleNotFoundError as exc: + if exc.name != "jax": + raise + print( + "compare_reclamm_thermostats.py depends on jax in this environment; " + "using the lightweight cache-key context instead." + ) + return _LightweightCompareContext() + + +def get_metric_spec(metric_key: str) -> Mapping[str, object]: + try: + return CACHE_ONLY_METRIC_SPECS[metric_key] + except KeyError as exc: + raise ValueError( + f"Unsupported metric_key={metric_key!r}. 
" + f"Supported cache-only metrics: {sorted(CACHE_ONLY_METRIC_SPECS)}" + ) from exc + + +def load_cache_lookup(cache_path: Path) -> Dict[str, float]: + frame = pd.read_parquet(cache_path, columns=["cache_key_hash", "final_value"]) + return { + str(row.cache_key_hash): float(row.final_value) + for row in frame.itertuples(index=False) + } + + +def resolve_existing_cache_path(cache_path: Path) -> Path: + candidates = [Path(cache_path)] + if not cache_path.is_absolute(): + candidates.append(Path("scripts") / cache_path) + for candidate in candidates: + if candidate.exists(): + return candidate + return Path(cache_path) + + +def load_geometric_compare_module(module_path: Optional[Path] = None): + compare_path = module_path or Path(__file__).with_name( + "compare_reclamm_geometric_noise_runs.py" + ) + spec = importlib.util.spec_from_file_location( + "reclamm_compare_reclamm_geometric_noise_runs", + compare_path, + ) + if spec is None or spec.loader is None: + raise RuntimeError(f"Could not load geometric compare module from {compare_path}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def run_top_row_geometric_comparison( + csv_path: Path, + output_file: Optional[str] = None, + row_index: int = 0, +): + """Run the paired geometric-vs-arb comparison using the top adjacent CSV row.""" + module = load_geometric_compare_module() + if not hasattr(module, "run_adjacent_csv_row_comparison"): + raise RuntimeError( + "compare_reclamm_geometric_noise_runs.py does not expose " + "run_adjacent_csv_row_comparison" + ) + return module.run_adjacent_csv_row_comparison( + csv_path=csv_path, + row_index=row_index, + output_file=output_file, + ) + + +def resolve_pair_specs(compare_module, base_cfg: Mapping[str, object], pair_slug: str): + pair_specs = compare_module.get_pair_heatmap_specs(base_cfg) + if pair_slug == "all": + return pair_specs + + matched = [pair for pair in pair_specs if pair["slug"] == pair_slug] + if not matched: 
+ available = [pair["slug"] for pair in pair_specs] + raise ValueError( + f"Unknown pair slug {pair_slug!r}. Available pair slugs: {available}" + ) + return matched + + +def resolve_slice_variants(pair_spec: Mapping[str, object], slice_slug: str): + slice_variants = pair_spec["fixed_slices"] + if slice_slug == "all": + return list(slice_variants) + + matched = [variant for variant in slice_variants if variant["slug"] == slice_slug] + if not matched: + available = [variant["slug"] for variant in slice_variants] + raise ValueError( + f"Unknown slice slug {slice_slug!r}. Available slice slugs: {available}" + ) + return matched + + +def build_default_output_path( + compare_module, + base_cfg: Mapping[str, object], + metric_key: str, + pair_slug: str, + slice_slug: str, + min_diff: float, +) -> Path: + output_dir = Path("scripts/results/reclamm_heatmap_adjacency") + output_dir.mkdir(parents=True, exist_ok=True) + diff_token = str(float(min_diff)).rstrip("0").rstrip(".").replace(".", "p") + filename = ( + f"reclamm_adjacent_pairs_{metric_key}_{pair_slug}_{slice_slug}" + f"_mindiff_{diff_token}_tvl_{compare_module.format_tvl_millions_slug(base_cfg)}.csv" + ) + return output_dir / filename + + +def autodetect_lightweight_noise_profile( + compare_module, + base_cfg: Mapping[str, object], + pair_specs: Sequence[Mapping[str, object]], + metric_key: str, + slice_slug: str, + cache_lookup: Mapping[str, float], +): + if not hasattr(compare_module, "set_noise_profile"): + return + del base_cfg, pair_specs, metric_key, slice_slug, cache_lookup + compare_module.set_noise_profile("market_linear") + print("Lightweight noise profile fixed to market_linear.") + + +def _source_variant(compare_module, cfg: Mapping[str, object], source_name: str): + enable_noise_model = source_name.startswith("noise_") + method = "geometric" if source_name.endswith("geometric") else "constant_arc_length" + return compare_module.make_noise_variant_cfg(cfg, enable_noise_model), method + + +def 
_compute_metric_value(metric_key: str, final_values: Mapping[str, float]) -> float: + metric_spec = get_metric_spec(metric_key) + return float(metric_spec["compute"](final_values)) + + +def build_cell_record( + compare_module, + cfg: Mapping[str, object], + pair_spec: Mapping[str, object], + slice_variant: Mapping[str, object], + metric_key: str, + x_index: int, + y_index: int, + cache_lookup: Mapping[str, float], +): + metric_spec = get_metric_spec(metric_key) + final_values = {} + missing_hashes = [] + for source_name in metric_spec["sources"]: + source_cfg, method = _source_variant(compare_module, cfg, source_name) + cache_key = compare_module._make_method_cache_key(source_cfg, method) + cache_key_hash = compare_module._make_method_cache_hash(cache_key) + cached_value = cache_lookup.get(cache_key_hash) + if cached_value is None: + missing_hashes.append(cache_key_hash) + continue + final_values[source_name] = float(cached_value) + + if missing_hashes: + return None, missing_hashes + + return ( + { + "metric_key": metric_key, + "metric_unit": metric_spec["unit"], + "source_noise_profile": str( + getattr(compare_module, "noise_profile", "unknown") + ), + "pair_slug": pair_spec["slug"], + "slice_slug": slice_variant["slug"], + "slice_label": slice_variant["label"], + "fixed_key": pair_spec["fixed_key"], + "fixed_value": float(slice_variant["value"]), + "price_ratio": float(cfg["price_ratio"]), + "centeredness_margin": float(cfg["centeredness_margin"]), + "daily_price_shift_exponent": float(cfg["daily_price_shift_exponent"]), + "tvl_usd": float(compare_module.get_initial_pool_value(cfg)), + "heatmap_value": _compute_metric_value(metric_key, final_values), + "x_index": int(x_index), + "y_index": int(y_index), + }, + [], + ) + + +def build_slice_cell_grid( + compare_module, + base_cfg: Mapping[str, object], + pair_spec: Mapping[str, object], + slice_variant: Mapping[str, object], + metric_key: str, + cache_lookup: Mapping[str, float], +): + records_by_coord: 
Dict[Tuple[int, int], MutableMapping[str, object]] = {} + missing_hashes: List[str] = [] + x_values = pair_spec["x_values"] + y_values = pair_spec["y_values"] + slice_cfg = dict(base_cfg) + slice_cfg[pair_spec["fixed_key"]] = float(slice_variant["value"]) + + for y_index, y_value in enumerate(y_values): + for x_index, x_value in enumerate(x_values): + cfg = dict(slice_cfg) + cfg[pair_spec["x_key"]] = float(x_value) + cfg[pair_spec["y_key"]] = float(y_value) + record, missing_for_cell = build_cell_record( + compare_module=compare_module, + cfg=cfg, + pair_spec=pair_spec, + slice_variant=slice_variant, + metric_key=metric_key, + x_index=x_index, + y_index=y_index, + cache_lookup=cache_lookup, + ) + if record is not None: + records_by_coord[(y_index, x_index)] = record + missing_hashes.extend(missing_for_cell) + + expected_cell_count = len(x_values) * len(y_values) + return { + "records_by_coord": records_by_coord, + "expected_cell_count": expected_cell_count, + "resolved_cell_count": len(records_by_coord), + "missing_hash_count": len(missing_hashes), + "missing_hashes": missing_hashes, + } + + +def build_adjacent_row( + metric_key: str, + metric_unit: str, + axis: str, + first_cell: Mapping[str, object], + second_cell: Mapping[str, object], +) -> Dict[str, object]: + signed_diff = float(second_cell["heatmap_value"]) - float(first_cell["heatmap_value"]) + abs_diff = abs(signed_diff) + return { + "metric_key": metric_key, + "metric_unit": metric_unit, + "source_noise_profile": first_cell.get("source_noise_profile", "unknown"), + "pair_slug": first_cell["pair_slug"], + "slice_slug": first_cell["slice_slug"], + "slice_label": first_cell["slice_label"], + "fixed_key": first_cell["fixed_key"], + "fixed_value": float(first_cell["fixed_value"]), + "adjacency_axis": axis, + "heatmap_value_diff_abs": abs_diff, + "heatmap_value_diff_signed_2_minus_1": signed_diff, + "1_price_ratio": float(first_cell["price_ratio"]), + "1_centeredness_margin": 
float(first_cell["centeredness_margin"]), + "1_daily_price_shift_exponent": float(first_cell["daily_price_shift_exponent"]), + "1_tvl_usd": float(first_cell["tvl_usd"]), + "1_heatmap_value": float(first_cell["heatmap_value"]), + "1_x_index": int(first_cell["x_index"]), + "1_y_index": int(first_cell["y_index"]), + "2_price_ratio": float(second_cell["price_ratio"]), + "2_centeredness_margin": float(second_cell["centeredness_margin"]), + "2_daily_price_shift_exponent": float(second_cell["daily_price_shift_exponent"]), + "2_tvl_usd": float(second_cell["tvl_usd"]), + "2_heatmap_value": float(second_cell["heatmap_value"]), + "2_x_index": int(second_cell["x_index"]), + "2_y_index": int(second_cell["y_index"]), + } + + +def find_adjacent_rows_for_slice( + metric_key: str, + metric_unit: str, + records_by_coord: Mapping[Tuple[int, int], Mapping[str, object]], + x_count: int, + y_count: int, + min_diff: float, + adjacency_axis: str = "both", +) -> List[Dict[str, object]]: + rows = [] + + if adjacency_axis not in {"both", "horizontal", "vertical"}: + raise ValueError( + f"Unsupported adjacency_axis={adjacency_axis!r}; expected both, horizontal, or vertical" + ) + + if adjacency_axis in {"both", "horizontal"}: + for y_index in range(y_count): + for x_index in range(x_count - 1): + first_cell = records_by_coord.get((y_index, x_index)) + second_cell = records_by_coord.get((y_index, x_index + 1)) + if first_cell is None or second_cell is None: + continue + row = build_adjacent_row( + metric_key=metric_key, + metric_unit=metric_unit, + axis="horizontal", + first_cell=first_cell, + second_cell=second_cell, + ) + if row["heatmap_value_diff_abs"] >= min_diff: + rows.append(row) + + if adjacency_axis in {"both", "vertical"}: + for y_index in range(y_count - 1): + for x_index in range(x_count): + first_cell = records_by_coord.get((y_index, x_index)) + second_cell = records_by_coord.get((y_index + 1, x_index)) + if first_cell is None or second_cell is None: + continue + row = 
build_adjacent_row( + metric_key=metric_key, + metric_unit=metric_unit, + axis="vertical", + first_cell=first_cell, + second_cell=second_cell, + ) + if row["heatmap_value_diff_abs"] >= min_diff: + rows.append(row) + + rows.sort( + key=lambda row: ( + -float(row["heatmap_value_diff_abs"]), + str(row["pair_slug"]), + str(row["slice_slug"]), + str(row["adjacency_axis"]), + int(row["1_y_index"]), + int(row["1_x_index"]), + ) + ) + return rows + + +def scan_heatmap_pairs( + compare_module, + base_cfg: Mapping[str, object], + metric_key: str, + pair_specs: Sequence[Mapping[str, object]], + slice_slug: str, + min_diff: float, + adjacency_axis: str, + cache_lookup: Mapping[str, float], +): + metric_spec = get_metric_spec(metric_key) + all_rows: List[Dict[str, object]] = [] + diagnostics = [] + + for pair_spec in pair_specs: + slice_variants = resolve_slice_variants(pair_spec, slice_slug) + for slice_variant in slice_variants: + slice_scan = build_slice_cell_grid( + compare_module=compare_module, + base_cfg=base_cfg, + pair_spec=pair_spec, + slice_variant=slice_variant, + metric_key=metric_key, + cache_lookup=cache_lookup, + ) + diagnostics.append( + { + "pair_slug": pair_spec["slug"], + "slice_slug": slice_variant["slug"], + "resolved_cell_count": slice_scan["resolved_cell_count"], + "expected_cell_count": slice_scan["expected_cell_count"], + "missing_hash_count": slice_scan["missing_hash_count"], + } + ) + all_rows.extend( + find_adjacent_rows_for_slice( + metric_key=metric_key, + metric_unit=metric_spec["unit"], + records_by_coord=slice_scan["records_by_coord"], + x_count=len(pair_spec["x_values"]), + y_count=len(pair_spec["y_values"]), + min_diff=min_diff, + adjacency_axis=adjacency_axis, + ) + ) + + return all_rows, diagnostics + + +def rows_to_frame(rows: Iterable[Mapping[str, object]]) -> pd.DataFrame: + frame = pd.DataFrame(list(rows)) + if frame.empty: + return pd.DataFrame(columns=OUTPUT_COLUMNS) + frame = frame.loc[:, OUTPUT_COLUMNS] + frame.sort_values( + by=[ + 
"heatmap_value_diff_abs", + "pair_slug", + "slice_slug", + "adjacency_axis", + "1_y_index", + "1_x_index", + ], + ascending=[False, True, True, True, True, True], + inplace=True, + ignore_index=True, + ) + return frame + + +def main() -> int: + args = parse_args() + compare_module = load_compare_module() + + if not 0 <= args.config_index < len(compare_module.CONFIGS): + raise ValueError( + f"config-index {args.config_index} is out of range for " + f"{len(compare_module.CONFIGS)} available configs" + ) + + base_cfg = compare_module.configs_for_tvl( + compare_module.CONFIGS, + initial_pool_value=args.initial_pool_value, + )[args.config_index] + pair_specs = resolve_pair_specs(compare_module, base_cfg, args.pair_slug) + + cache_path = ( + Path(args.cache_path) + if args.cache_path is not None + else Path(compare_module._heatmap_forward_cache_path(base_cfg)) + ) + cache_path = resolve_existing_cache_path(cache_path) + if not cache_path.exists(): + raise FileNotFoundError( + f"Cache parquet not found: {cache_path}. " + "Generate the current market_linear/arb_only heatmap cache first." + ) + + cache_lookup = load_cache_lookup(cache_path) + autodetect_lightweight_noise_profile( + compare_module=compare_module, + base_cfg=base_cfg, + pair_specs=pair_specs, + metric_key=args.metric_key, + slice_slug=args.slice_slug, + cache_lookup=cache_lookup, + ) + output_csv = ( + Path(args.output_csv) + if args.output_csv is not None + else build_default_output_path( + compare_module=compare_module, + base_cfg=base_cfg, + metric_key=args.metric_key, + pair_slug=args.pair_slug, + slice_slug=args.slice_slug, + min_diff=args.min_diff, + ) + ) + + print( + f"Loaded {len(cache_lookup):,} cached final values from {cache_path} " + f"for {base_cfg['name']} at TVL {compare_module.format_tvl_millions_label(base_cfg)}." 
+ ) + print( + f"Scanning metric={args.metric_key}, pair_slug={args.pair_slug}, " + f"slice_slug={args.slice_slug}, adjacency_axis={args.adjacency_axis}, " + f"min_diff={args.min_diff} " + f"({get_metric_spec(args.metric_key)['unit']})." + ) + + rows, diagnostics = scan_heatmap_pairs( + compare_module=compare_module, + base_cfg=base_cfg, + metric_key=args.metric_key, + pair_specs=pair_specs, + slice_slug=args.slice_slug, + min_diff=args.min_diff, + adjacency_axis=args.adjacency_axis, + cache_lookup=cache_lookup, + ) + + for diagnostic in diagnostics: + print( + f"[{diagnostic['pair_slug']}:{diagnostic['slice_slug']}] " + f"resolved {diagnostic['resolved_cell_count']}/" + f"{diagnostic['expected_cell_count']} cells " + f"({diagnostic['missing_hash_count']} missing cache hashes)" + ) + + missing_any = any(diagnostic["missing_hash_count"] > 0 for diagnostic in diagnostics) + if missing_any and not args.allow_partial_cache: + raise RuntimeError( + "Cache was incomplete for at least one requested heatmap slice. " + "This cache does not match the current market_linear/arb_only " + "parameterization. Regenerate the heatmap cache, or re-run with " + "--allow-partial-cache to write only the rows that were resolvable." + ) + + frame = rows_to_frame(rows) + output_csv.parent.mkdir(parents=True, exist_ok=True) + frame.to_csv(output_csv, index=False) + print(f"Wrote {len(frame):,} adjacent pairs to {output_csv}") + + if not args.skip_top_row_geometric_comparison: + if frame.empty: + print("Skipping top-row geometric comparison because the CSV is empty.") + else: + print( + "Running geometric noise comparison for the top adjacent-pairs CSV row..." + ) + try: + comparison_output = run_top_row_geometric_comparison( + csv_path=output_csv, + output_file=args.top_row_geometric_comparison_output_file, + row_index=0, + ) + print( + f"Completed top-row geometric comparison using {output_csv} row 0. 
" + f"Output: {comparison_output}" + ) + except Exception as exc: # pragma: no cover - depends on local runtime deps + print( + "Top-row geometric comparison did not run successfully: " + f"{exc}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/reclamm_arc_speed_shift_price_sweep.py b/scripts/reclamm_arc_speed_shift_price_sweep.py new file mode 100644 index 0000000..a18c7b2 --- /dev/null +++ b/scripts/reclamm_arc_speed_shift_price_sweep.py @@ -0,0 +1,825 @@ +"""Sweep reCLAMM arc speed jointly with price ratio and shift exponent. + +This is a focused companion to ``compare_reclamm_thermostats.py``. It reuses +that script's configs, cache, market-linear noise setup, and heatmap metrics, +then evaluates the 3-variable cube: + + daily_price_shift_exponent x price_ratio x arc_length_speed + +The default run is the full compare-grid for the 1M TVL aggressive/tight-range +config, with the launch-style config used as the benchmark where required. +""" + +from __future__ import annotations + +import argparse +import gc +import math +import os +import sys +from pathlib import Path + +# Keep background runs from depending on a writable user-level matplotlib cache. 
+os.environ.setdefault( + "MPLCONFIGDIR", + str(Path(os.environ.get("TMPDIR", "/tmp")) / "matplotlib"), +) + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.cm import ScalarMappable +from matplotlib.colors import LogNorm + + +SCRIPT_DIR = Path(__file__).resolve().parent +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + + +DEFAULT_METRIC_KEYS = ( + "efficiency_pct", + "noise_constant_arc_final_value_musd", +) +DEFAULT_FACET_SHIFT_VALUES = (0.01, 0.05, 0.10, 0.20, 0.35, 0.50) +SUPPORTED_METRIC_KEYS = ( + "constant_arc_vs_launch_constant_arc_pct", + "efficiency_pct", + "geometric_vs_launch_geometric_pct", + "launch_geometric_efficiency_pct", + "noise_constant_arc_final_value_musd", + "noise_geometric_final_value_musd", + "noise_vs_arb_constant_arc_improvement_pct", + "noise_vs_arb_geometric_improvement_pct", +) + + +def load_compare_module(): + """Import the heavy compare module only after CLI parsing.""" + import compare_reclamm_thermostats as compare_module + + return compare_module + + +def parse_csv_floats(value: str) -> tuple[float, ...]: + """Parse a comma-separated float list.""" + values = [] + for token in value.split(","): + token = token.strip() + if token: + values.append(float(token)) + if not values: + raise argparse.ArgumentTypeError("expected at least one float") + return tuple(values) + + +def parse_args() -> argparse.Namespace: + """Parse CLI options for a long-running background sweep.""" + parser = argparse.ArgumentParser( + description=( + "Run the reCLAMM 3-variable arc-speed/price-ratio/shift-exponent " + "sweep using the compare_reclamm_thermostats cache and configs." + ) + ) + parser.add_argument( + "--output-dir", + default="results/reclamm_arc_speed_shift_price_sweep", + help="Directory for cube parquet files and plots.", + ) + parser.add_argument( + "--tvl", + nargs="+", + type=float, + default=None, + help=( + "One or more initial pool values to run. 
Defaults to 1,000,000. " + "Use --all-tvls for the compare script's 1M/5M/20M sweep." + ), + ) + parser.add_argument( + "--all-tvls", + action="store_true", + help="Run every TVL from compare_reclamm_thermostats.TVL_SWEEP_VALUES.", + ) + parser.add_argument( + "--metric", + action="append", + choices=SUPPORTED_METRIC_KEYS, + default=None, + help=( + "Metric to compute. Can be repeated. Defaults to efficiency_pct and " + "noise_constant_arc_final_value_musd." + ), + ) + parser.add_argument( + "--facet-shift-values", + type=parse_csv_floats, + default=DEFAULT_FACET_SHIFT_VALUES, + help=( + "Comma-separated shift exponents to include in the facet overview. " + "Nearest grid values are used." + ), + ) + parser.add_argument( + "--plot-all-shift-slices", + action="store_true", + help="Also save one 2D arc-speed/price-ratio heatmap per shift exponent.", + ) + parser.add_argument( + "--no-orthogonal-3d", + action="store_true", + help="Skip the literal 3D orthogonal-slice plot.", + ) + parser.add_argument( + "--skip-cube-parquet", + action="store_true", + help="Do not save the evaluated parameter cube to parquet.", + ) + return parser.parse_args() + + +def ensure_output_dir(path: str | os.PathLike[str]) -> Path: + """Create and return the output directory.""" + output_dir = Path(path) + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + +def tvl_output_path( + output_dir: Path, + stem: str, + cfg: dict, + suffix: str | None = None, + ext: str = "png", +) -> Path: + """Build a stable output path with the compare script's TVL slug.""" + parts = [stem] + if suffix: + parts.append(suffix) + parts.append(f"tvl_{ct.format_tvl_millions_slug(cfg)}") + return output_dir / ("_".join(parts) + f".{ext}") + + +def metric_spec_map() -> dict[str, dict]: + """Return the compare script's metric specs, keyed by metric name.""" + return {spec["key"]: spec for spec in ct.get_pair_heatmap_metric_specs()} + + +def nearest_indices(values: np.ndarray, targets: tuple[float, ...]) 
-> list[int]: + """Resolve target values to unique nearest indices in a sweep grid.""" + values = np.asarray(values, dtype=float) + indices: list[int] = [] + for target in targets: + idx = int(np.argmin(np.abs(values - float(target)))) + if idx not in indices: + indices.append(idx) + return indices + + +def speed_label(value: float) -> str: + """Format an arc speed for plot tick labels.""" + return f"{float(value):.0e}" + + +def heatmap_value_slug(value: float) -> str: + """Format a sweep value for filenames.""" + return f"{float(value):.6g}".replace("-", "m").replace(".", "p") + + +def build_arc_speed_shift_price_cube( + base_cfg: dict, + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, + metric_keys: tuple[str, ...], + cache: dict, + launch_final_values: dict, +) -> dict[str, np.ndarray]: + """Evaluate metric cubes as shift_exp x price_ratio x arc_length_speed.""" + data = { + metric_key: np.zeros( + (len(shift_exponents), len(price_ratios), len(arc_length_speeds)), + dtype=float, + ) + for metric_key in metric_keys + } + total_points = len(shift_exponents) * len(price_ratios) * len(arc_length_speeds) + print( + "\nStarting 3-variable arc-speed sweep: " + f"{len(shift_exponents)} shift slices x {len(price_ratios)} price ratios x " + f"{len(arc_length_speeds)} arc speeds = {total_points} parameter points." 
+ ) + + for zi, shift_exp in enumerate(shift_exponents): + slice_cfg = dict(base_cfg) + slice_cfg["daily_price_shift_exponent"] = float(shift_exp) + progress_label = ( + "arc_speed_shift_price_" + f"shift_{zi + 1:03d}_of_{len(shift_exponents):03d}_" + f"{float(shift_exp):.4f}" + ) + slice_data = ct.build_heatmap_matrices( + x_values=arc_length_speeds, + y_values=price_ratios, + x_key="arc_length_speed", + y_key="price_ratio", + base_cfg=slice_cfg, + metric_keys=metric_keys, + cache=cache, + progress_label=progress_label, + launch_final_values=launch_final_values, + ) + for metric_key in metric_keys: + data[metric_key][zi, :, :] = slice_data[metric_key] + ct.flush_sweep_cache(cache, force=True) + del slice_data + gc.collect() + + return data + + +def cube_to_frame( + cube: dict[str, np.ndarray], + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, +) -> pd.DataFrame: + """Flatten the 3D cube into a tidy table for offline analysis.""" + records = [] + for zi, shift_exp in enumerate(shift_exponents): + for yi, price_ratio in enumerate(price_ratios): + for xi, arc_speed in enumerate(arc_length_speeds): + record = { + "shift_index": zi, + "price_ratio_index": yi, + "arc_length_speed_index": xi, + "daily_price_shift_exponent": float(shift_exp), + "price_ratio": float(price_ratio), + "arc_length_speed": float(arc_speed), + } + for metric_key, values in cube.items(): + record[metric_key] = float(values[zi, yi, xi]) + records.append(record) + return pd.DataFrame.from_records(records) + + +def save_cube_parquet( + output_dir: Path, + cfg: dict, + cube: dict[str, np.ndarray], + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, +) -> Path: + """Persist the full parameter cube as parquet.""" + output_path = tvl_output_path( + output_dir, + "reclamm_arc_speed_shift_price_cube", + cfg, + ext="parquet", + ) + frame = cube_to_frame( + cube, + arc_length_speeds=arc_length_speeds, + 
price_ratios=price_ratios, + shift_exponents=shift_exponents, + ) + frame.to_parquet(output_path, index=False, compression="zstd") + print(f"Saved {output_path}") + return output_path + + +def plot_shift_slice_facets( + data: np.ndarray, + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, + shift_indices: list[int], + spec: dict, + cfg: dict, + filename: Path, +) -> None: + """Render selected shift-exponent slices as small-multiple heatmaps.""" + selected = [data[idx] for idx in shift_indices] + norm = ct._build_heatmap_norm( + selected, + center_zero=spec["center_zero"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + cmap_name = spec["cmap"] + x_edges = ct._compute_axis_edges(arc_length_speeds, scale="log") + y_edges = ct._compute_axis_edges(price_ratios, scale="linear") + + col_count = min(3, len(shift_indices)) + row_count = int(math.ceil(len(shift_indices) / col_count)) + fig, axes = plt.subplots( + row_count, + col_count, + figsize=(4.3 * col_count, 3.3 * row_count), + squeeze=False, + ) + active_axes = [] + im = None + for plot_idx, shift_idx in enumerate(shift_indices): + ax = axes[plot_idx // col_count][plot_idx % col_count] + active_axes.append(ax) + im = ax.pcolormesh( + x_edges, + y_edges, + data[shift_idx], + cmap=cmap_name, + norm=norm, + shading="auto", + ) + ax.set_xscale("log") + ax.set_xticks(ct.ARC_LENGTH_SPEED_TICKS) + ax.set_yticks(ct.PRICE_RATIO_TICKS) + ax.tick_params(axis="x", labelrotation=35) + ax.set_xlabel("Arc-length speed") + ax.set_ylabel("Price ratio") + ax.set_title( + f"shift_exp={ct.format_heatmap_param_value(shift_exponents[shift_idx])}" + ) + + for plot_idx in range(len(shift_indices), row_count * col_count): + axes[plot_idx // col_count][plot_idx % col_count].set_visible(False) + + fig.suptitle( + f"{spec['title']}: arc_speed x price_ratio by shift_exp | " + f"TVL {ct.format_tvl_millions_label(cfg)}", + y=0.995, + ) + if im is not None: + cbar = 
fig.colorbar(im, ax=active_axes, shrink=0.88) + cbar.set_label(spec["colorbar_label"]) + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def plot_all_shift_slices( + output_dir: Path, + cube: dict[str, np.ndarray], + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, + specs: dict[str, dict], + cfg: dict, + metric_keys: tuple[str, ...], +) -> None: + """Optionally save one 2D heatmap per shift exponent and metric.""" + for metric_key in metric_keys: + spec = specs[metric_key] + for shift_idx, shift_exp in enumerate(shift_exponents): + filename = tvl_output_path( + output_dir, + f"reclamm_arc_speed_price_slice_{spec['slug']}", + cfg, + suffix=f"shift_exp_{heatmap_value_slug(float(shift_exp))}", + ) + ct.plot_heatmap( + data=cube[metric_key][shift_idx], + x_values=arc_length_speeds, + y_values=price_ratios, + x_label="Arc-length speed", + y_label="Price ratio", + title=( + f"{spec['title']}: shift_exp fixed at " + f"{ct.format_heatmap_param_value(float(shift_exp))} | " + f"TVL {ct.format_tvl_millions_label(cfg)}" + ), + colorbar_label=spec["colorbar_label"], + filename=filename, + xticks=ct.ARC_LENGTH_SPEED_TICKS, + yticks=ct.PRICE_RATIO_TICKS, + xscale="log", + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + + +def compute_argmax_over_arc_speed( + data: np.ndarray, + arc_length_speeds: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Return best metric value and the arc speed that produced it.""" + best_idx = np.nanargmax(data, axis=2) + best_values = np.take_along_axis(data, best_idx[:, :, None], axis=2)[:, :, 0] + best_speeds = np.asarray(arc_length_speeds, dtype=float)[best_idx] + return best_values, best_speeds + + +def plot_best_speed_heatmap( + best_speeds: np.ndarray, + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, + 
metric_label: str, + cfg: dict, + filename: Path, +) -> None: + """Render a price_ratio x shift_exp heatmap of the selected arc speed.""" + x_edges = ct._compute_axis_edges(price_ratios, scale="linear") + y_edges = ct._compute_axis_edges(shift_exponents, scale="linear") + fig, ax = plt.subplots(figsize=(9.0, 6.0)) + norm = LogNorm( + vmin=float(np.min(arc_length_speeds)), + vmax=float(np.max(arc_length_speeds)), + ) + im = ax.pcolormesh( + x_edges, + y_edges, + best_speeds, + cmap="viridis", + norm=norm, + shading="auto", + ) + ax.set_xlabel("Price ratio") + ax.set_ylabel("Shift exponent") + ax.set_title( + f"Best arc-length speed by {metric_label} | " + f"TVL {ct.format_tvl_millions_label(cfg)}" + ) + ax.set_xticks(ct.PRICE_RATIO_TICKS) + ax.set_yticks(ct.SHIFT_EXPONENT_TICKS) + cbar = fig.colorbar(im, ax=ax) + cbar.set_label("Best arc-length speed") + speed_ticks = np.asarray(ct.ARC_LENGTH_SPEED_TICKS, dtype=float) + speed_ticks = speed_ticks[ + (speed_ticks >= float(np.min(arc_length_speeds))) + & (speed_ticks <= float(np.max(arc_length_speeds))) + ] + cbar.set_ticks(speed_ticks) + cbar.set_ticklabels([speed_label(value) for value in speed_ticks]) + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def save_best_speed_summary( + output_dir: Path, + cfg: dict, + metric_key: str, + best_values: np.ndarray, + best_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, +) -> Path: + """Persist the price_ratio x shift_exp argmax summary.""" + records = [] + for zi, shift_exp in enumerate(shift_exponents): + for yi, price_ratio in enumerate(price_ratios): + records.append( + { + "daily_price_shift_exponent": float(shift_exp), + "price_ratio": float(price_ratio), + f"best_{metric_key}": float(best_values[zi, yi]), + f"best_arc_length_speed_by_{metric_key}": float( + best_speeds[zi, yi] + ), + } + ) + frame = pd.DataFrame.from_records(records) + output_path = tvl_output_path( + output_dir, + 
f"reclamm_arc_speed_shift_price_best_{metric_key}", + cfg, + ext="parquet", + ) + frame.to_parquet(output_path, index=False, compression="zstd") + print(f"Saved {output_path}") + return output_path + + +def plot_orthogonal_3d_slices( + data: np.ndarray, + arc_length_speeds: np.ndarray, + price_ratios: np.ndarray, + shift_exponents: np.ndarray, + spec: dict, + cfg: dict, + launch_auto_speed: float, + filename: Path, +) -> None: + """Render one literal 3D orthogonal-slice view of the metric cube.""" + shift_idx = int( + np.argmin( + np.abs( + np.asarray(shift_exponents, dtype=float) + - float(cfg["daily_price_shift_exponent"]) + ) + ) + ) + price_idx = int( + np.argmin( + np.abs(np.asarray(price_ratios, dtype=float) - float(cfg["price_ratio"])) + ) + ) + speed_idx = int( + np.argmin( + np.abs( + np.asarray(arc_length_speeds, dtype=float) - float(launch_auto_speed) + ) + ) + ) + + norm = ct._build_heatmap_norm( + [data], + center_zero=spec["center_zero"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + cmap_obj = plt.get_cmap(spec["cmap"]) + log_speeds = np.log10(np.asarray(arc_length_speeds, dtype=float)) + + arc_price_x, arc_price_y = np.meshgrid(log_speeds, price_ratios) + arc_price_z = np.full_like( + arc_price_x, + float(shift_exponents[shift_idx]), + dtype=float, + ) + + arc_shift_x, arc_shift_z = np.meshgrid(log_speeds, shift_exponents) + arc_shift_y = np.full_like( + arc_shift_x, + float(price_ratios[price_idx]), + dtype=float, + ) + + price_shift_y, price_shift_z = np.meshgrid(price_ratios, shift_exponents) + price_shift_x = np.full_like( + price_shift_y, + float(log_speeds[speed_idx]), + dtype=float, + ) + + fig = plt.figure(figsize=(10.5, 7.2)) + ax = fig.add_subplot(111, projection="3d") + ax.plot_surface( + arc_price_x, + arc_price_y, + arc_price_z, + facecolors=cmap_obj(norm(data[shift_idx, :, :])), + shade=False, + ) + ax.plot_surface( + arc_shift_x, + arc_shift_y, + arc_shift_z, + 
facecolors=cmap_obj(norm(data[:, price_idx, :])), + shade=False, + ) + ax.plot_surface( + price_shift_x, + price_shift_y, + price_shift_z, + facecolors=cmap_obj(norm(data[:, :, speed_idx])), + shade=False, + ) + + ax.set_xlabel("Arc-length speed") + ax.set_ylabel("Price ratio") + ax.set_zlabel("Shift exponent") + ax.set_yticks(ct.PRICE_RATIO_TICKS) + ax.set_zticks(ct.SHIFT_EXPONENT_TICKS) + speed_ticks = np.asarray(ct.ARC_LENGTH_SPEED_TICKS, dtype=float) + speed_ticks = speed_ticks[ + (speed_ticks >= float(np.min(arc_length_speeds))) + & (speed_ticks <= float(np.max(arc_length_speeds))) + ] + ax.set_xticks(np.log10(speed_ticks)) + ax.set_xticklabels([speed_label(value) for value in speed_ticks], rotation=20) + ax.set_title( + f"{spec['title']}: orthogonal 3D slices | " + f"TVL {ct.format_tvl_millions_label(cfg)}\n" + f"shift_exp={ct.format_heatmap_param_value(shift_exponents[shift_idx])}, " + f"price_ratio={ct.format_heatmap_param_value(price_ratios[price_idx])}, " + f"arc_speed={speed_label(arc_length_speeds[speed_idx])}" + ) + ax.view_init(elev=ct.THREE_D_VIEW_ELEVATION, azim=ct.THREE_D_VIEW_AZIMUTH) + try: + ax.set_box_aspect((1.45, 1.5, 1.0)) + except AttributeError: + pass + + sm = ScalarMappable(norm=norm, cmap=cmap_obj) + sm.set_array([]) + cbar = fig.colorbar(sm, ax=ax, fraction=0.03, pad=0.1, shrink=0.82) + cbar.set_label(spec["colorbar_label"]) + plt.tight_layout() + plt.savefig(filename, dpi=150) + print(f"Saved {filename}") + plt.close(fig) + + +def run_for_tvl( + initial_pool_value: float, + output_dir: Path, + metric_keys: tuple[str, ...], + facet_shift_values: tuple[float, ...], + plot_all_slices: bool, + plot_orthogonal_3d: bool, + save_cube: bool, + shared_price_data, + shared_market_linear_noise_data, +) -> None: + """Run the 3D sweep for one initial TVL.""" + launch_cfg, base_cfg = ct.configs_for_tvl(ct.CONFIGS, initial_pool_value) + tvl_label = ct.format_tvl_millions_label(base_cfg) + print(f"\n=== Arc-speed/shift/price sweep: TVL {tvl_label} 
===") + + launch_final_values = ct.get_launch_final_values( + [], + launch_cfg=launch_cfg, + price_data=shared_price_data, + market_linear_noise_data=shared_market_linear_noise_data, + ) + cache = ct.make_sweep_cache( + shared_price_data, + cache_scope_cfg=base_cfg, + market_linear_noise_data=shared_market_linear_noise_data, + ) + try: + cube = build_arc_speed_shift_price_cube( + base_cfg=dict(base_cfg), + arc_length_speeds=ct.HEATMAP_ARC_LENGTH_SPEEDS, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + metric_keys=metric_keys, + cache=cache, + launch_final_values=launch_final_values, + ) + if save_cube: + save_cube_parquet( + output_dir, + base_cfg, + cube, + arc_length_speeds=ct.HEATMAP_ARC_LENGTH_SPEEDS, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + ) + + specs = metric_spec_map() + facet_indices = nearest_indices( + ct.HEATMAP_SHIFT_EXPONENTS, + facet_shift_values, + ) + launch_auto_speed = ct.compute_auto_calibrated_arc_length_speed( + launch_cfg, + shared_price_data, + ) + + for metric_key in metric_keys: + spec = specs[metric_key] + facet_path = tvl_output_path( + output_dir, + f"reclamm_arc_speed_shift_price_facets_{spec['slug']}", + base_cfg, + ) + plot_shift_slice_facets( + data=cube[metric_key], + arc_length_speeds=ct.HEATMAP_ARC_LENGTH_SPEEDS, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + shift_indices=facet_indices, + spec=spec, + cfg=base_cfg, + filename=facet_path, + ) + + best_values, best_speeds = compute_argmax_over_arc_speed( + cube[metric_key], + ct.HEATMAP_ARC_LENGTH_SPEEDS, + ) + best_value_path = tvl_output_path( + output_dir, + f"reclamm_arc_speed_shift_price_best_{spec['slug']}", + base_cfg, + ) + ct.plot_heatmap( + data=best_values, + x_values=ct.HEATMAP_PRICE_RATIOS, + y_values=ct.HEATMAP_SHIFT_EXPONENTS, + x_label="Price ratio", + y_label="Shift exponent", + title=( + f"Best {spec['title']} over arc-length speed | " + 
f"TVL {ct.format_tvl_millions_label(base_cfg)}" + ), + colorbar_label=f"Best {spec['colorbar_label']}", + filename=best_value_path, + xticks=ct.PRICE_RATIO_TICKS, + yticks=ct.SHIFT_EXPONENT_TICKS, + center_zero=spec["center_zero"], + cmap=spec["cmap"], + color_norm=spec.get("color_norm"), + symlog_linthresh=spec.get("symlog_linthresh"), + ) + best_speed_path = tvl_output_path( + output_dir, + f"reclamm_arc_speed_shift_price_argmax_speed_by_{spec['slug']}", + base_cfg, + ) + plot_best_speed_heatmap( + best_speeds=best_speeds, + arc_length_speeds=ct.HEATMAP_ARC_LENGTH_SPEEDS, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + metric_label=spec["title"], + cfg=base_cfg, + filename=best_speed_path, + ) + save_best_speed_summary( + output_dir, + base_cfg, + metric_key, + best_values=best_values, + best_speeds=best_speeds, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + ) + + if plot_orthogonal_3d: + orthogonal_path = tvl_output_path( + output_dir, + f"reclamm_arc_speed_shift_price_orthogonal_3d_{spec['slug']}", + base_cfg, + ) + plot_orthogonal_3d_slices( + data=cube[metric_key], + arc_length_speeds=ct.HEATMAP_ARC_LENGTH_SPEEDS, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + spec=spec, + cfg=base_cfg, + launch_auto_speed=launch_auto_speed, + filename=orthogonal_path, + ) + + if plot_all_slices: + plot_all_shift_slices( + output_dir=output_dir, + cube=cube, + arc_length_speeds=ct.HEATMAP_ARC_LENGTH_SPEEDS, + price_ratios=ct.HEATMAP_PRICE_RATIOS, + shift_exponents=ct.HEATMAP_SHIFT_EXPONENTS, + specs=specs, + cfg=base_cfg, + metric_keys=metric_keys, + ) + finally: + ct.flush_sweep_cache(cache, force=True) + cache.clear() + gc.collect() + print(f"Released arc-speed/shift/price cache for TVL {tvl_label}.") + + +def main() -> None: + """Entrypoint.""" + args = parse_args() + global ct + ct = load_compare_module() + output_dir = 
ensure_output_dir(args.output_dir) + if args.all_tvls: + tvl_values = tuple(float(value) for value in ct.TVL_SWEEP_VALUES) + elif args.tvl is not None: + tvl_values = tuple(float(value) for value in args.tvl) + else: + tvl_values = (float(ct.DEFAULT_INITIAL_POOL_VALUE),) + + metric_keys = tuple(args.metric or DEFAULT_METRIC_KEYS) + unsupported = [ + metric_key + for metric_key in metric_keys + if metric_key not in ct.HEATMAP_METRIC_DEPENDENCIES + ] + if unsupported: + raise ValueError(f"Unsupported metric keys: {unsupported}") + + print(f"Writing artifacts to {output_dir}") + print(f"Running metrics: {', '.join(metric_keys)}") + print("Loading shared price data and market-linear noise arrays...") + shared_price_data = ct.load_shared_price_data(ct.CONFIGS) + shared_market_linear_noise_data = ct.load_shared_market_linear_noise_data() + + for initial_pool_value in tvl_values: + run_for_tvl( + initial_pool_value=initial_pool_value, + output_dir=output_dir, + metric_keys=metric_keys, + facet_shift_values=tuple(args.facet_shift_values), + plot_all_slices=bool(args.plot_all_shift_slices), + plot_orthogonal_3d=not bool(args.no_orthogonal_3d), + save_cube=not bool(args.skip_cube_parquet), + shared_price_data=shared_price_data, + shared_market_linear_noise_data=shared_market_linear_noise_data, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_direct_calibration_top50.py b/scripts/run_direct_calibration_top50.py index 1c1770b..d323c0b 100644 --- a/scripts/run_direct_calibration_top50.py +++ b/scripts/run_direct_calibration_top50.py @@ -32,7 +32,7 @@ os.path.dirname(os.path.dirname(__file__)), "results", "direct_calibration_top50", ) -TRAIN_DAYS = 90 +TRAIN_DAYS = 0 # 0 = no filter, use all available data per pool TOP_N = 50 OPTION_C_MAXITER = 500 JOINT_MAXITER = 500 @@ -40,18 +40,28 @@ def load_and_match(): - """Load panel, filter to 90 days, match to grids.""" + """Load panel, match to grids. 
No date filter — each pool uses all data.""" + from quantammsim.calibration.pool_data import ( + match_grids_to_panel, + replace_panel_volatility_with_binance, + ) + panel = pd.read_parquet(PANEL_CACHE) - max_date = panel["date"].max() - if not isinstance(max_date, date): - max_date = pd.Timestamp(max_date).date() - cutoff = max_date - timedelta(days=TRAIN_DAYS) - panel = panel[ - panel["date"].apply( - lambda d: d >= cutoff if isinstance(d, date) - else pd.Timestamp(d).date() >= cutoff - ) - ].copy() + + # Optional date filter (TRAIN_DAYS=0 means no filter) + if TRAIN_DAYS > 0: + max_date = panel["date"].max() + if not isinstance(max_date, date): + max_date = pd.Timestamp(max_date).date() + cutoff = max_date - timedelta(days=TRAIN_DAYS) + panel = panel[ + panel["date"].apply( + lambda d: d >= cutoff if isinstance(d, date) + else pd.Timestamp(d).date() >= cutoff + ) + ].copy() + else: + panel = panel.copy() if "log_tvl_lag1" not in panel.columns: panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) @@ -62,21 +72,27 @@ def load_and_match(): valid = pool_counts[pool_counts >= 10].index panel = panel[panel["pool_id"].isin(valid)].copy() + # Replace Balancer-hourly volatility with Binance-minute volatility + print("Replacing volatility with Binance minute data...") + panel = replace_panel_volatility_with_binance(panel) + + min_date = panel["date"].min() + max_date = panel["date"].max() print(f"Panel: {len(panel)} obs, {panel['pool_id'].nunique()} pools, " - f"{cutoff} to {max_date}") + f"{min_date} to {max_date}") - from quantammsim.calibration.pool_data import match_grids_to_panel matched = match_grids_to_panel(GRID_DIR, panel) print(f"Matched: {len(matched)} pools with grids") return panel, matched -def run_option_c(matched): +def run_option_c(matched, fix_gas_to_chain=False): """Per-pool L-BFGS-B fits.""" from quantammsim.calibration.per_pool_fit import fit_all_pools - print(f"\n--- Option C: per-pool fits ({len(matched)} pools) ---") - results = 
fit_all_pools(matched) + gas_label = " (gas fixed to chain)" if fix_gas_to_chain else "" + print(f"\n--- Option C: per-pool fits ({len(matched)} pools){gas_label} ---") + results = fit_all_pools(matched, fix_gas_to_chain=fix_gas_to_chain) n_converged = sum(1 for r in results.values() if r["converged"]) losses = [r["loss"] for r in results.values()] print(f" Converged: {n_converged}/{len(results)}") @@ -86,7 +102,7 @@ def run_option_c(matched): return results -def run_option_a(matched, option_c_results): +def run_option_a(matched, option_c_results, fix_gas_to_chain=False): """Joint end-to-end optimization, warm-started from Option C. Drops pathological pools (Option C loss > OPTION_C_LOSS_CUTOFF) from the @@ -106,26 +122,29 @@ def run_option_a(matched, option_c_results): r = option_c_results[p] print(f" {p} {r['tokens']:<16} loss={r['loss']:.1f}") + gas_label = ", gas fixed" if fix_gas_to_chain else "" print(f"\n--- Option A: joint fit (per_pool_noise, {len(matched_clean)} pools, " - f"warm-start from C, no chain dummies) ---") + f"warm-start from C, no chain dummies{gas_label}) ---") result_ppn = fit_joint( matched_clean, mode="per_pool_noise", init_from_option_c=good_pools, maxiter=JOINT_MAXITER, drop_chain_dummies=True, + fix_gas_to_chain=fix_gas_to_chain, ) print(f" Loss: {result_ppn['init_loss']:.4f} -> {result_ppn['loss']:.4f}") print(f" Converged: {result_ppn['converged']}") print(f"\n--- Option A: joint fit (shared_noise, {len(matched_clean)} pools, " - f"warm-start from C, no chain dummies) ---") + f"warm-start from C, no chain dummies{gas_label}) ---") result_sn = fit_joint( matched_clean, mode="shared_noise", init_from_option_c=good_pools, maxiter=JOINT_MAXITER, drop_chain_dummies=True, + fix_gas_to_chain=fix_gas_to_chain, ) print(f" Loss: {result_sn['init_loss']:.4f} -> {result_sn['loss']:.4f}") print(f" Converged: {result_sn['converged']}") @@ -170,115 +189,103 @@ def run_option_rf(matched, option_c_results): attr_names = [attr_names_full[i] for i in 
non_chain_mask] k_attr = len(attr_names) - print(f"\n--- Option RF: 2-stage mapping ({n_pools} pools, {k_attr} features) ---") + # Detect if gas was fixed in Option C + gas_fixed = any(good_pools[p].get("gas_fixed", False) for p in pool_ids) + + if gas_fixed: + print(f"\n--- Option RF: 2-stage mapping ({n_pools} pools, {k_attr} features, " + f"cadence only — gas fixed) ---") + else: + print(f"\n--- Option RF: 2-stage mapping ({n_pools} pools, {k_attr} features) ---") print(f" Features: {', '.join(attr_names)}") Y_cad = np.array([good_pools[p]["log_cadence"] for p in pool_ids]) Y_gas = np.array([good_pools[p]["log_gas"] for p in pool_ids]) - Y = np.column_stack([Y_cad, Y_gas]) ss_tot_cad = np.sum((Y_cad - Y_cad.mean()) ** 2) - ss_tot_gas = np.sum((Y_gas - Y_gas.mean()) ** 2) def compute_r2(y_true, y_pred, ss_tot): return 1 - np.sum((y_true - y_pred) ** 2) / max(ss_tot, 1e-10) - # ---- Ridge regression (multi-output via separate fits) ---- + # ---- Ridge regression (cadence only when gas is fixed) ---- alphas = np.logspace(-2, 4, 50) ridge_cad = RidgeCV(alphas=alphas, cv=None) # GCV/LOO built-in - ridge_gas = RidgeCV(alphas=alphas, cv=None) ridge_cad.fit(X_attr, Y_cad) - ridge_gas.fit(X_attr, Y_gas) - Y_ridge_train = np.column_stack([ridge_cad.predict(X_attr), - ridge_gas.predict(X_attr)]) - r2_ridge_cad = compute_r2(Y_cad, Y_ridge_train[:, 0], ss_tot_cad) - r2_ridge_gas = compute_r2(Y_gas, Y_ridge_train[:, 1], ss_tot_gas) + Y_ridge_cad_train = ridge_cad.predict(X_attr) + r2_ridge_cad = compute_r2(Y_cad, Y_ridge_cad_train, ss_tot_cad) - print(f"\n Ridge (alpha_cad={ridge_cad.alpha_:.1f}, alpha_gas={ridge_gas.alpha_:.1f}):") - print(f" In-sample R²: cadence={r2_ridge_cad:.3f}, gas={r2_ridge_gas:.3f}") + print(f"\n Ridge (alpha_cad={ridge_cad.alpha_:.1f}):") + print(f" In-sample R² cadence: {r2_ridge_cad:.3f}") - # Ridge LOO-CV + # Ridge LOO-CV (cadence only) loo = LeaveOneOut() - Y_ridge_loo = np.zeros_like(Y) + Y_ridge_cad_loo = np.zeros_like(Y_cad) for train_idx, 
test_idx in loo.split(X_attr): rc = RidgeCV(alphas=alphas, cv=None).fit(X_attr[train_idx], Y_cad[train_idx]) - rg = RidgeCV(alphas=alphas, cv=None).fit(X_attr[train_idx], Y_gas[train_idx]) - Y_ridge_loo[test_idx, 0] = rc.predict(X_attr[test_idx]) - Y_ridge_loo[test_idx, 1] = rg.predict(X_attr[test_idx]) + Y_ridge_cad_loo[test_idx] = rc.predict(X_attr[test_idx]) - r2_ridge_loo_cad = compute_r2(Y_cad, Y_ridge_loo[:, 0], ss_tot_cad) - r2_ridge_loo_gas = compute_r2(Y_gas, Y_ridge_loo[:, 1], ss_tot_gas) - print(f" LOO-CV R²: cadence={r2_ridge_loo_cad:.3f}, gas={r2_ridge_loo_gas:.3f}") - print(f" LOO-CV MAE: cadence={np.mean(np.abs(np.exp(Y_cad) - np.exp(Y_ridge_loo[:, 0]))):.1f} min, " - f"gas=${np.mean(np.abs(np.exp(Y_gas) - np.exp(Y_ridge_loo[:, 1]))):.2f}") + r2_ridge_loo_cad = compute_r2(Y_cad, Y_ridge_cad_loo, ss_tot_cad) + print(f" LOO-CV R² cadence: {r2_ridge_loo_cad:.3f}") + print(f" LOO-CV MAE cadence: {np.mean(np.abs(np.exp(Y_cad) - np.exp(Y_ridge_cad_loo))):.1f} min") # Ridge coefficients - print(f" Coefficients (cadence | gas):") - print(f" {'intercept':<20} {ridge_cad.intercept_:>7.3f} {ridge_gas.intercept_:>7.3f}") + print(f" Coefficients (cadence):") + print(f" {'intercept':<20} {ridge_cad.intercept_:>7.3f}") for j, name in enumerate(attr_names): - print(f" {name:<20} {ridge_cad.coef_[j]:>7.3f} {ridge_gas.coef_[j]:>7.3f}") + print(f" {name:<20} {ridge_cad.coef_[j]:>7.3f}") - # ---- Random Forest (reduced features) ---- + # ---- Random Forest (cadence only) ---- rf = RandomForestRegressor( n_estimators=200, max_depth=None, - min_samples_leaf=3, # stronger regularization - max_features=min(4, k_attr), # cap at 4 features per split + min_samples_leaf=3, + max_features=min(4, k_attr), random_state=42, n_jobs=-1, ) - rf.fit(X_attr, Y) - Y_rf_train = rf.predict(X_attr) + rf.fit(X_attr, Y_cad) + Y_rf_cad_train = rf.predict(X_attr) - r2_rf_cad = compute_r2(Y_cad, Y_rf_train[:, 0], ss_tot_cad) - r2_rf_gas = compute_r2(Y_gas, Y_rf_train[:, 1], ss_tot_gas) + 
r2_rf_cad = compute_r2(Y_cad, Y_rf_cad_train, ss_tot_cad) print(f"\n Random Forest (min_leaf=3, max_feat=4):") - print(f" In-sample R²: cadence={r2_rf_cad:.3f}, gas={r2_rf_gas:.3f}") + print(f" In-sample R² cadence: {r2_rf_cad:.3f}") # RF LOO-CV - Y_rf_loo = np.zeros_like(Y) + Y_rf_cad_loo = np.zeros_like(Y_cad) for train_idx, test_idx in loo.split(X_attr): rf_loo = RandomForestRegressor( n_estimators=200, max_depth=None, min_samples_leaf=3, max_features=min(4, k_attr), random_state=42, n_jobs=-1, ) - rf_loo.fit(X_attr[train_idx], Y[train_idx]) - Y_rf_loo[test_idx] = rf_loo.predict(X_attr[test_idx]) + rf_loo.fit(X_attr[train_idx], Y_cad[train_idx]) + Y_rf_cad_loo[test_idx] = rf_loo.predict(X_attr[test_idx]) - r2_rf_loo_cad = compute_r2(Y_cad, Y_rf_loo[:, 0], ss_tot_cad) - r2_rf_loo_gas = compute_r2(Y_gas, Y_rf_loo[:, 1], ss_tot_gas) - print(f" LOO-CV R²: cadence={r2_rf_loo_cad:.3f}, gas={r2_rf_loo_gas:.3f}") - print(f" LOO-CV MAE: cadence={np.mean(np.abs(np.exp(Y_cad) - np.exp(Y_rf_loo[:, 0]))):.1f} min, " - f"gas=${np.mean(np.abs(np.exp(Y_gas) - np.exp(Y_rf_loo[:, 1]))):.2f}") + r2_rf_loo_cad = compute_r2(Y_cad, Y_rf_cad_loo, ss_tot_cad) + print(f" LOO-CV R² cadence: {r2_rf_loo_cad:.3f}") + print(f" LOO-CV MAE cadence: {np.mean(np.abs(np.exp(Y_cad) - np.exp(Y_rf_cad_loo))):.1f} min") print(f"\n Feature importances:") for j, name in enumerate(attr_names): print(f" {name:<20} {rf.feature_importances_[j]:.3f}") # ---- Pick best LOO model ---- - ridge_loo_total = r2_ridge_loo_cad + r2_ridge_loo_gas - rf_loo_total = r2_rf_loo_cad + r2_rf_loo_gas - best = "ridge" if ridge_loo_total >= rf_loo_total else "rf" - print(f"\n Best LOO model: {best} (ridge={ridge_loo_total:.3f} vs rf={rf_loo_total:.3f})") + best = "ridge" if r2_ridge_loo_cad >= r2_rf_loo_cad else "rf" + print(f"\n Best LOO model: {best} (ridge={r2_ridge_loo_cad:.3f} vs rf={r2_rf_loo_cad:.3f})") if best == "ridge": - Y_best_train = Y_ridge_train - Y_best_loo = Y_ridge_loo + Y_best_cad_train = Y_ridge_cad_train 
+ Y_best_cad_loo = Y_ridge_cad_loo r2_best_cad = r2_ridge_cad - r2_best_gas = r2_ridge_gas r2_best_loo_cad = r2_ridge_loo_cad - r2_best_loo_gas = r2_ridge_loo_gas else: - Y_best_train = Y_rf_train - Y_best_loo = Y_rf_loo + Y_best_cad_train = Y_rf_cad_train + Y_best_cad_loo = Y_rf_cad_loo r2_best_cad = r2_rf_cad - r2_best_gas = r2_rf_gas r2_best_loo_cad = r2_rf_loo_cad - r2_best_loo_gas = r2_rf_loo_gas - # Build result dict using the best model's predictions + # Build result dict — gas comes from Option C (which used chain-level values) noise_all = np.array([good_pools[p]["noise_coeffs"] for p in pool_ids]) result = { @@ -288,21 +295,20 @@ def compute_r2(y_true, y_pred, ss_tot): "best_model": best, "predictions": {}, "loo_predictions": {}, - "noise_coeffs": noise_all, # from Option C + "noise_coeffs": noise_all, "r2_train_cad": r2_best_cad, - "r2_train_gas": r2_best_gas, "r2_loo_cad": r2_best_loo_cad, - "r2_loo_gas": r2_best_loo_gas, } for i, pid in enumerate(pool_ids): + log_gas_fixed = good_pools[pid]["log_gas"] result["predictions"][pid] = { - "log_cadence": float(Y_best_train[i, 0]), - "log_gas": float(Y_best_train[i, 1]), + "log_cadence": float(Y_best_cad_train[i]), + "log_gas": float(log_gas_fixed), } result["loo_predictions"][pid] = { - "log_cadence": float(Y_best_loo[i, 0]), - "log_gas": float(Y_best_loo[i, 1]), + "log_cadence": float(Y_best_cad_loo[i]), + "log_gas": float(log_gas_fixed), } return result @@ -369,7 +375,12 @@ def r2(v_arb, v_noise, y): ji = joint_pid_to_idx[pid] x_attr = X_attr_joint[ji] log_cad_a = float(joint_result["bias_cad"]) + float(x_attr @ joint_result["W_cad"]) - log_gas_a = float(joint_result["bias_gas"]) + float(x_attr @ joint_result["W_gas"]) + if joint_result.get("fix_gas"): + gas_a = float(joint_result["gas_per_pool"][ji]) + log_gas_a = np.log(max(gas_a, 1e-6)) + else: + log_gas_a = float(joint_result["bias_gas"]) + float(x_attr @ joint_result["W_gas"]) + gas_a = np.exp(log_gas_a) noise_c_a = joint_result["noise_coeffs"][ji] 
v_arb_all_a = np.array(interpolate_pool_daily( @@ -379,7 +390,6 @@ def r2(v_arb, v_noise, y): v_noise_a = np.exp(x_obs @ noise_c_a) r2_a = r2(v_arb_a, v_noise_a, y_obs) cad_a = np.exp(log_cad_a) - gas_a = np.exp(log_gas_a) else: v_arb_a = np.full(len(y_obs), np.nan) v_noise_a = np.full(len(y_obs), np.nan) @@ -800,11 +810,11 @@ def main(): panel, matched = load_and_match() - # Step 1: Option C - option_c = run_option_c(matched) + # Step 1: Option C (gas fixed to chain-level costs) + option_c = run_option_c(matched, fix_gas_to_chain=True) - # Step 2: Option A (linear mapping) - joint_ppn, joint_sn = run_option_a(matched, option_c) + # Step 2: Option A (linear mapping, gas fixed) + joint_ppn, joint_sn = run_option_a(matched, option_c, fix_gas_to_chain=True) # Step 3: Option RF (random forest 2-stage) rf_result = run_option_rf(matched, option_c) diff --git a/scripts/run_mlp_calibration.py b/scripts/run_mlp_calibration.py new file mode 100644 index 0000000..e9b377b --- /dev/null +++ b/scripts/run_mlp_calibration.py @@ -0,0 +1,1012 @@ +"""Run calibration with MLPNoiseHead and compare against linear baselines. + +Steps: + 1. Load panel, match to per-day grids + 2. Option C: per-pool L-BFGS-B fits (baseline, gas fixed to chain) + 3. Linear joint: SharedLinearNoiseHead baseline + 4. MLP noise joint: MLPNoiseHead (new) + 5. Full MLP joint: MLPHead cadence + MLPNoiseHead (new) + 6. Per-pool prediction, R², decomposition for each method + 7. Paginated plots, summary distributions, comparison scatter + 8. 
JSON export +""" + +import json +import os + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# ---- Config ---- +PANEL_CACHE = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", +) +GRID_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "pool_grids_v2", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "mlp_calibration", +) +OPTION_C_LOSS_CUTOFF = 5.0 +OPTION_C_MAXITER = 500 +JOINT_MAXITER = 5000 +MLP_HIDDEN = 16 +TOP_N = 50 +# Best alpha settings from sweep (phase 2) +ALPHA_CAD = 0.001 +ALPHA_NOISE = 0.1 + + +# ---- Data loading ---- + + +def load_and_match(): + """Load panel, match to grids.""" + from quantammsim.calibration.pool_data import ( + match_grids_to_panel, + replace_panel_volatility_with_binance, + ) + + panel = pd.read_parquet(PANEL_CACHE) + + if "log_tvl_lag1" not in panel.columns: + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel.groupby("pool_id").size() + valid = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid)].copy() + + print("Replacing volatility with Binance minute data...") + panel = replace_panel_volatility_with_binance(panel) + + print(f"Panel: {len(panel)} obs, {panel['pool_id'].nunique()} pools, " + f"{panel['date'].min()} to {panel['date'].max()}") + + matched = match_grids_to_panel(GRID_DIR, panel) + print(f"Matched: {len(matched)} pools with grids") + return panel, matched + + +def filter_pathological(matched, option_c): + """Drop pools with high Option C loss.""" + good = {p: r for p, r in option_c.items() if r["loss"] <= OPTION_C_LOSS_CUTOFF} + dropped = set(option_c) - set(good) + matched_clean = {p: matched[p] for 
p in good if p in matched} + if dropped: + print(f" Dropping {len(dropped)} pools (loss > {OPTION_C_LOSS_CUTOFF}):") + for p in sorted(dropped): + print(f" {p} loss={option_c[p]['loss']:.1f}") + return matched_clean, good + + +# ---- Fitting ---- + + +def run_option_c(matched): + """Per-pool fits with gas fixed to chain costs.""" + from quantammsim.calibration.per_pool_fit import fit_all_pools + + print(f"\n--- Option C: per-pool fits ({len(matched)} pools, gas fixed) ---") + results = fit_all_pools(matched, fix_gas_to_chain=True) + + losses = [r["loss"] for r in results.values()] + n_conv = sum(1 for r in results.values() if r["converged"]) + print(f" Converged: {n_conv}/{len(results)}") + print(f" Loss: median={np.median(losses):.4f}, mean={np.mean(losses):.4f}") + return results + + +def run_option_c_reduced(matched): + """Per-pool fits with reduced x_obs (4 covariates) and gas fixed.""" + from quantammsim.calibration.per_pool_fit import fit_all_pools + + print(f"\n--- Option C Reduced: per-pool fits ({len(matched)} pools, " + f"4-covariate x_obs, gas fixed) ---") + results = fit_all_pools(matched, fix_gas_to_chain=True, reduced=True) + + losses = [r["loss"] for r in results.values()] + n_conv = sum(1 for r in results.values() if r["converged"]) + print(f" Converged: {n_conv}/{len(results)}") + print(f" Loss: median={np.median(losses):.4f}, mean={np.mean(losses):.4f}") + return results + + +def _build_gas_values(jdata, matched_clean): + """Build fixed gas values (log-space) from chain data.""" + from quantammsim.calibration.loss import CHAIN_GAS_USD + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_clean[pid]["chain"] + gas_usd = CHAIN_GAS_USD.get(chain, 1.0) + gas_values.append(np.log(max(gas_usd, 1e-6))) + return np.array(gas_values) + + +def run_linear_joint(matched_clean, option_c_clean): + """Joint fit with LinearHead + SharedLinearNoiseHead (baseline).""" + from quantammsim.calibration.calibration_model import CalibrationModel + from 
quantammsim.calibration.heads import FixedHead, LinearHead, SharedLinearNoiseHead + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data( + matched_clean, drop_chain_dummies=True, fix_gas_to_chain=True) + gas_values = _build_gas_values(jdata, matched_clean) + + model = CalibrationModel( + cadence_head=LinearHead("cad", alpha=ALPHA_CAD), + gas_head=FixedHead("gas", gas_values), + noise_head=SharedLinearNoiseHead(alpha=ALPHA_NOISE), + ) + + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + n_p = model.n_params(n_pools, k_attr) + print(f"\n--- Linear baseline: SharedLinearNoiseHead ({n_pools} pools, {n_p} params) ---") + + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=option_c_clean) + print(f" Loss: {result['init_loss']:.4f} -> {result['loss']:.4f}") + print(f" Converged: {result['converged']}") + return result, model, jdata + + +def run_mlp_noise_joint(matched_clean, option_c_clean, hidden=MLP_HIDDEN): + """Joint fit with LinearHead cadence + FixedHead gas + MLPNoiseHead.""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import FixedHead, LinearHead, MLPNoiseHead + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data( + matched_clean, drop_chain_dummies=True, fix_gas_to_chain=True) + gas_values = _build_gas_values(jdata, matched_clean) + + model = CalibrationModel( + cadence_head=LinearHead("cad", alpha=ALPHA_CAD), + gas_head=FixedHead("gas", gas_values), + noise_head=MLPNoiseHead(hidden=hidden, alpha=ALPHA_NOISE), + ) + + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + n_p = model.n_params(n_pools, k_attr) + print(f"\n--- MLP noise: MLPNoiseHead(hidden={hidden}) ({n_pools} pools, {n_p} params) ---") + + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=option_c_clean) + print(f" Loss: {result['init_loss']:.4f} -> {result['loss']:.4f}") + print(f" Converged: 
{result['converged']}") + return result, model, jdata + + +def run_mlp_full_joint(matched_clean, option_c_clean, hidden=MLP_HIDDEN): + """Joint fit with MLPHead cadence + FixedHead gas + MLPNoiseHead.""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import FixedHead, MLPHead, MLPNoiseHead + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data( + matched_clean, drop_chain_dummies=True, fix_gas_to_chain=True) + gas_values = _build_gas_values(jdata, matched_clean) + + model = CalibrationModel( + cadence_head=MLPHead("cad", hidden=hidden, alpha=ALPHA_CAD), + gas_head=FixedHead("gas", gas_values), + noise_head=MLPNoiseHead(hidden=hidden, alpha=ALPHA_NOISE), + ) + + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + n_p = model.n_params(n_pools, k_attr) + print(f"\n--- Full MLP: MLPHead(cad) + MLPNoiseHead ({n_pools} pools, {n_p} params) ---") + + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=option_c_clean) + print(f" Loss: {result['init_loss']:.4f} -> {result['loss']:.4f}") + print(f" Converged: {result['converged']}") + return result, model, jdata + + +def run_two_stage_joint(matched_clean, option_c_clean, hidden=MLP_HIDDEN): + """Two-stage joint fit to identify cadence separately from noise. + + Stage 1: LinearHead(cad) + FixedHead(gas) + PerPoolNoiseHead + Per-pool noise (8 coeffs/pool) can't fully absorb arb's daily + volatility pattern, so cadence is identified. + + Stage 2: FixedHead(cad, stage1_values) + FixedHead(gas) + MLPNoiseHead + Cadence frozen from stage 1, MLP learns shared noise mapping. + + For new-pool prediction: stage 1 linear coefficients give cadence, + stage 2 MLP gives noise. 
+ """ + import jax.numpy as jnp + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, LinearHead, MLPNoiseHead, PerPoolNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data( + matched_clean, drop_chain_dummies=True, fix_gas_to_chain=True) + gas_values = _build_gas_values(jdata, matched_clean) + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + + # ---- Stage 1: fit cadence with per-pool noise ---- + stage1_model = CalibrationModel( + cadence_head=LinearHead("cad", alpha=ALPHA_CAD), + gas_head=FixedHead("gas", gas_values), + noise_head=PerPoolNoiseHead(), + ) + n_p1 = stage1_model.n_params(n_pools, k_attr) + print(f"\n--- Two-stage S1: LinearHead(cad) + PerPoolNoiseHead " + f"({n_pools} pools, {n_p1} params) ---") + + stage1_result = stage1_model.fit( + jdata, maxiter=JOINT_MAXITER, warm_start=option_c_clean) + print(f" Loss: {stage1_result['init_loss']:.4f} -> {stage1_result['loss']:.4f}") + print(f" Converged: {stage1_result['converged']}") + + # Extract per-pool cadences from stage 1 + params1 = jnp.array(stage1_result["params_flat"]) + (cs, ce), _, _ = stage1_model._head_slices(n_pools, k_attr) + cad_slice = params1[cs:ce] + stage1_cadences = np.array([ + float(stage1_model.cadence_head.predict(cad_slice, i, jdata.x_attr[i])) + for i in range(n_pools) + ]) + print(f" Cadence range: {np.exp(stage1_cadences.min()):.1f} - " + f"{np.exp(stage1_cadences.max()):.1f} min") + + # ---- Stage 2: fit MLP noise with frozen cadence ---- + stage2_model = CalibrationModel( + cadence_head=FixedHead("cad", stage1_cadences), + gas_head=FixedHead("gas", gas_values), + noise_head=MLPNoiseHead(hidden=hidden, alpha=ALPHA_NOISE), + ) + n_p2 = stage2_model.n_params(n_pools, k_attr) + print(f"\n--- Two-stage S2: FixedHead(cad) + MLPNoiseHead(hidden={hidden}) " + f"({n_pools} pools, {n_p2} params) ---") + + # Build warm-start for stage 2 
noise from stage 1 per-pool noise + (_, _), (_, _), (ns, ne) = stage1_model._head_slices(n_pools, k_attr) + noise_params1 = np.array(params1[ns:ne]) + stage2_warm = {} + for i, pid in enumerate(jdata.pool_ids): + noise_c = np.array(stage1_model.noise_head.predict( + jnp.array(noise_params1), i, jdata.x_attr[i])) + stage2_warm[pid] = {"noise_coeffs": noise_c} + + stage2_result = stage2_model.fit( + jdata, maxiter=JOINT_MAXITER, warm_start=stage2_warm) + print(f" Loss: {stage2_result['init_loss']:.4f} -> {stage2_result['loss']:.4f}") + print(f" Converged: {stage2_result['converged']}") + + # Build a composite result dict for downstream use + # Cadence comes from stage 1 linear head, noise from stage 2 MLP + result = { + "stage1_result": stage1_result, + "stage2_result": stage2_result, + "loss": stage2_result["loss"], + "init_loss": stage1_result["init_loss"], + "converged": stage1_result["converged"] and stage2_result["converged"], + "n_pools": n_pools, + "k_attr": k_attr, + "pool_ids": jdata.pool_ids, + "attr_names": jdata.attr_names, + } + + return result, stage1_model, stage2_model, jdata + + +def run_reduced_joint(matched_clean, option_c_clean, hidden=MLP_HIDDEN): + """Joint fit with reduced x_obs (4 cols) to avoid noise-cadence confounding. + + Removes sigma- and fee-dependent features from the noise model's x_obs + so the arb channel (grid + cadence) is the only path for volatility-driven + volume variation. See docs/noise_covariate_design.md for theory. + + Uses LinearHead(cad) + FixedHead(gas) + MLPNoiseHead(k_obs=4). + Cadence warm-started from Option C; noise cold-started (OLS on 4-col x_obs). 
+ """ + import jax.numpy as jnp + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import FixedHead, LinearHead, MLPNoiseHead + from quantammsim.calibration.joint_fit import prepare_joint_data + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + jdata = prepare_joint_data( + matched_clean, drop_chain_dummies=True, + fix_gas_to_chain=True, reduced_x_obs=True) + gas_values = _build_gas_values(jdata, matched_clean) + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + + model = CalibrationModel( + cadence_head=LinearHead("cad", alpha=ALPHA_CAD), + gas_head=FixedHead("gas", gas_values), + noise_head=MLPNoiseHead(hidden=hidden, alpha=ALPHA_NOISE, + k_obs=K_OBS_REDUCED), + ) + + n_p = model.n_params(n_pools, k_attr) + print(f"\n--- Reduced x_obs: LinearHead(cad) + MLPNoiseHead(k_obs={K_OBS_REDUCED}, " + f"hidden={hidden}) ({n_pools} pools, {n_p} params) ---") + + # Only warm-start cadence (noise dimension changed 8→4, skip noise warm-start) + warm_cad = {} + for pid in jdata.pool_ids: + if pid in option_c_clean: + warm_cad[pid] = {"cad": option_c_clean[pid]["log_cadence"]} + + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=warm_cad) + print(f" Loss: {result['init_loss']:.4f} -> {result['loss']:.4f}") + print(f" Converged: {result['converged']}") + + return result, model, jdata + + +def _extract_two_stage_per_pool(stage1_model, stage2_model, result, jdata): + """Extract per-pool params from two-stage result.""" + import jax.numpy as jnp + + stage1_result = result["stage1_result"] + stage2_result = result["stage2_result"] + n_pools = result["n_pools"] + k_attr = result["k_attr"] + + params1 = jnp.array(stage1_result["params_flat"]) + params2 = jnp.array(stage2_result["params_flat"]) + + (cs1, ce1), _, (ns1, ne1) = stage1_model._head_slices(n_pools, k_attr) + _, _, (ns2, ne2) = stage2_model._head_slices(n_pools, k_attr) + + cad_slice = params1[cs1:ce1] + noise_slice = 
params2[ns2:ne2] + + per_pool = [] + for i in range(n_pools): + x_attr_i = jdata.x_attr[i] + log_cad = float(stage1_model.cadence_head.predict(cad_slice, i, x_attr_i)) + log_gas = float(stage2_model.gas_head.predict( + jnp.array([]), i, x_attr_i)) # FixedHead ignores params + noise_c = np.array(stage2_model.noise_head.predict(noise_slice, i, x_attr_i)) + per_pool.append({ + "log_cadence": log_cad, + "log_gas": log_gas, + "noise_coeffs": noise_c, + "cadence_minutes": float(np.exp(log_cad)), + "gas_usd": float(np.exp(log_gas)), + }) + return per_pool + + +# ---- Per-pool predictions ---- + + +def _extract_per_pool_params(model, result, jdata): + """Extract per-pool (log_cadence, log_gas, noise_coeffs) from a CalibrationModel result.""" + import jax.numpy as jnp + + params = jnp.array(result["params_flat"]) + n_pools = result["n_pools"] + k_attr = result["k_attr"] + (cs, ce), (gs, ge), (ns, ne) = model._head_slices(n_pools, k_attr) + + cad_slice = params[cs:ce] + gas_slice = params[gs:ge] + noise_slice = params[ns:ne] + + per_pool = [] + for i in range(n_pools): + x_attr_i = jdata.x_attr[i] + log_cad = float(model.cadence_head.predict(cad_slice, i, x_attr_i)) + log_gas = float(model.gas_head.predict(gas_slice, i, x_attr_i)) + noise_c = np.array(model.noise_head.predict(noise_slice, i, x_attr_i)) + per_pool.append({ + "log_cadence": log_cad, + "log_gas": log_gas, + "noise_coeffs": noise_c, + "cadence_minutes": float(np.exp(log_cad)), + "gas_usd": float(np.exp(log_gas)), + }) + return per_pool + + +def compute_per_pool_predictions(matched, option_c_results, + model_results, reduced_models=None): + """Compute V_arb, V_noise, R² per pool for Option C and each joint model. + + model_results: list of (label, per_pool_params, pool_ids) tuples, + where per_pool_params[i] is a dict with log_cadence, log_gas, noise_coeffs. + + reduced_models: list of label strings whose noise_coeffs correspond to + reduced x_obs (4 columns). For those, build_x_obs(reduced=True) is used. 
+ """ + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import build_x_obs + import jax.numpy as jnp + + if reduced_models is None: + reduced_models = [] + + pool_ids = sorted(matched.keys()) + + # Build lookup for each model's per-pool params + model_lookups = [] + for label, per_pool_params, m_pool_ids in model_results: + lookup = {pid: per_pool_params[i] for i, pid in enumerate(m_pool_ids)} + model_lookups.append((label, lookup)) + + def r2(v_arb, v_noise, y): + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y) ** 2) + ss_tot = np.sum((y - y.mean()) ** 2) + return 1 - ss_res / max(ss_tot, 1e-10) + + predictions = {} + for pid in pool_ids: + entry = matched[pid] + panel = entry["panel"] + coeffs = entry["coeffs"] + day_indices = entry["day_indices"] + + x_obs_full = build_x_obs(panel) + x_obs_red = None # lazy-build only if needed + y_obs = panel["log_volume"].values.astype(float) + + p = { + "dates": pd.to_datetime(panel["date"].values), + "y_obs": y_obs, + "actual_vol": np.exp(y_obs), + "chain": entry["chain"], + "tokens": entry["tokens"], + "fee": entry["fee"], + "median_tvl": float(np.exp(panel["log_tvl_lag1"].median())), + "n_obs": len(y_obs), + } + + # Option C + rc = option_c_results[pid] + v_arb_all = np.array(interpolate_pool_daily( + coeffs, jnp.float64(rc["log_cadence"]), + jnp.float64(np.exp(rc["log_gas"])))) + v_arb_c = v_arb_all[day_indices] + v_noise_c = np.exp(x_obs_full @ rc["noise_coeffs"]) + p["v_arb_c"] = v_arb_c + p["v_noise_c"] = v_noise_c + p["r2_c"] = r2(v_arb_c, v_noise_c, y_obs) + p["cadence_c"] = rc["cadence_minutes"] + p["gas_c"] = rc["gas_usd"] + + # Each joint model + for label, lookup in model_lookups: + if pid in lookup: + mp = lookup[pid] + v_arb_all = np.array(interpolate_pool_daily( + coeffs, jnp.float64(mp["log_cadence"]), + jnp.float64(np.exp(mp["log_gas"])))) + v_arb = v_arb_all[day_indices] + + # Use reduced x_obs for 
models that were trained with it + if label in reduced_models: + if x_obs_red is None: + x_obs_red = build_x_obs(panel, reduced=True) + v_noise = np.exp(x_obs_red @ mp["noise_coeffs"]) + else: + v_noise = np.exp(x_obs_full @ mp["noise_coeffs"]) + + p[f"v_arb_{label}"] = v_arb + p[f"v_noise_{label}"] = v_noise + p[f"r2_{label}"] = r2(v_arb, v_noise, y_obs) + p[f"cadence_{label}"] = mp["cadence_minutes"] + p[f"gas_{label}"] = mp["gas_usd"] + else: + n = len(y_obs) + p[f"v_arb_{label}"] = np.full(n, np.nan) + p[f"v_noise_{label}"] = np.full(n, np.nan) + p[f"r2_{label}"] = np.nan + p[f"cadence_{label}"] = np.nan + p[f"gas_{label}"] = np.nan + + predictions[pid] = p + + return predictions + + +# ---- Tables ---- + + +def print_pool_table(predictions, method_labels): + """Print per-pool results ranked by TVL.""" + ranked = sorted(predictions.items(), key=lambda x: -x[1]["median_tvl"]) + + header = f"{'Pool':<24} {'Chain':<10} {'TVL':>12} {'N':>4}" + header += f" {'Cad_C':>6} {'R2_C':>6}" + for label in method_labels: + short = label[:8] + header += f" {'Cad_'+short:>10} {'R2_'+short:>8}" + header += f" {'Arb%_C':>6}" + + print(f"\n{'='*len(header)}") + print(header) + print(f"{'-'*len(header)}") + for pid, p in ranked: + tokens = p["tokens"] + if isinstance(tokens, str): + tok_str = "/".join(t.strip()[:6] for t in tokens.split(",")[:2]) + else: + tok_str = pid[:16] + arb_total = p["v_arb_c"] + p["v_noise_c"] + arb_frac = np.median(p["v_arb_c"] / np.maximum(arb_total, 1.0)) + + line = (f"{tok_str:<24} {p['chain']:<10} ${p['median_tvl']:>10,.0f} " + f"{p['n_obs']:>4}") + line += f" {p['cadence_c']:>5.1f}m {p['r2_c']:>6.3f}" + for label in method_labels: + cad = p[f"cadence_{label}"] + r2v = p[f"r2_{label}"] + if np.isnan(cad): + line += f" {'---':>10} {'---':>8}" + else: + line += f" {cad:>9.1f}m {r2v:>8.3f}" + line += f" {arb_frac:>5.1%}" + print(line) + + +def print_r2_comparison(predictions, method_labels): + """Print aggregate R² comparison.""" + pool_ids = 
sorted(predictions.keys()) + + print(f"\n{'='*70}") + print("R² comparison (per-pool, in-sample)") + print(f"{'='*70}") + + r2_c = [predictions[p]["r2_c"] for p in pool_ids] + print(f" Option C (per-pool): median={np.median(r2_c):.4f} mean={np.mean(r2_c):.4f}") + + for label in method_labels: + r2_vals = [predictions[p][f"r2_{label}"] for p in pool_ids + if np.isfinite(predictions[p][f"r2_{label}"])] + if r2_vals: + print(f" {label:<22} median={np.median(r2_vals):.4f} mean={np.mean(r2_vals):.4f}") + + +def print_loss_comparison(option_c, joint_results): + """Print joint loss comparison.""" + print(f"\n{'='*70}") + print("Joint loss comparison") + print(f"{'='*70}") + + c_losses = [r["loss"] for r in option_c.values()] + print(f" Option C (per-pool): median={np.median(c_losses):.4f} mean={np.mean(c_losses):.4f}") + + for label, result in joint_results: + print(f" {label:<22} loss={result['loss']:.4f} (from {result['init_loss']:.4f})") + + +# ---- Plots ---- + + +def plot_decomposition_pages(predictions, method, method_label, output_dir): + """Paginated V_arb + V_noise stacked area decomposition.""" + ranked = sorted(predictions.items(), key=lambda x: -x[1]["median_tvl"])[:TOP_N] + + per_page = 10 + n_pages = (len(ranked) + per_page - 1) // per_page + + for page in range(n_pages): + start = page * per_page + end = min(start + per_page, len(ranked)) + page_pools = ranked[start:end] + n_this = len(page_pools) + + ncols = 2 + nrows = (n_this + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4.5 * nrows)) + if nrows == 1 and ncols == 1: + axes = np.array([[axes]]) + elif nrows == 1: + axes = axes.reshape(1, -1) + elif ncols == 1: + axes = axes.reshape(-1, 1) + + for idx, (pid, p) in enumerate(page_pools): + ax = axes[idx // ncols][idx % ncols] + dates = p["dates"] + + v_arb_key = f"v_arb_{method}" if method != "c" else "v_arb_c" + v_noise_key = f"v_noise_{method}" if method != "c" else "v_noise_c" + r2_key = f"r2_{method}" if method != "c" else 
"r2_c" + cad_key = f"cadence_{method}" if method != "c" else "cadence_c" + gas_key = f"gas_{method}" if method != "c" else "gas_c" + + v_arb = p[v_arb_key] + v_noise = p[v_noise_key] + r2_val = p[r2_key] + cad = p[cad_key] + gas = p[gas_key] + + if np.any(np.isnan(v_arb)): + ax.text(0.5, 0.5, f"Dropped from {method_label}", fontsize=12, + ha="center", va="center", transform=ax.transAxes, color="gray") + ax.set_title(f"{pid[:16]} — dropped", fontsize=8) + continue + + v_total = v_arb + v_noise + arb_frac = np.median(v_arb / np.maximum(v_total, 1.0)) + actual = p["actual_vol"] + + ax.fill_between(dates, 0, np.maximum(v_arb, 0), + alpha=0.3, color="orangered", label="V_arb (grid)") + ax.fill_between(dates, np.maximum(v_arb, 0), np.maximum(v_total, 0), + alpha=0.3, color="steelblue", label="V_noise") + ax.plot(dates, actual, "k-", linewidth=0.8, alpha=0.7, label="Actual") + ax.plot(dates, np.maximum(v_total, 0), "--", color="purple", + linewidth=0.8, alpha=0.7, label="Predicted total") + + ax.set_yscale("log") + ax.set_ylabel("Daily volume (USD)", fontsize=8) + + tokens = p["tokens"] + if isinstance(tokens, str): + tok_str = "/".join(t.strip()[:8] for t in tokens.split(",")[:2]) + else: + tok_str = pid[:16] + + ax.set_title( + f"{tok_str} ({p['chain']})\n" + f"TVL ${p['median_tvl']:,.0f} | R\u00b2={r2_val:.3f} " + f"cad={cad:.1f}min gas=${gas:.2f} " + f"arb_frac={arb_frac:.1%} n={p['n_obs']}", + fontsize=8, + ) + ax.legend(fontsize=6, loc="upper right") + ax.tick_params(labelsize=7) + ax.tick_params(axis="x", rotation=30) + + for idx in range(n_this, nrows * ncols): + axes[idx // ncols][idx % ncols].set_visible(False) + + fig.suptitle( + f"Calibration decomposition: {method_label}\n" + f"page {page + 1}/{n_pages} (top {min(TOP_N, len(ranked))} by TVL)", + fontsize=11, + ) + fig.tight_layout() + safe_method = method.replace(" ", "_") + out = os.path.join(output_dir, f"{safe_method}_page{page + 1}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + 
print(f" Saved: {out}") + + +def plot_summary_distributions(predictions, method_labels, output_dir): + """Histograms of cadence, R², arb fraction for each method.""" + pool_ids = sorted(predictions.keys()) + methods = ["c"] + method_labels + labels = ["Option C"] + method_labels + n_methods = len(methods) + + fig, axes = plt.subplots(n_methods, 3, figsize=(15, 4 * n_methods)) + if n_methods == 1: + axes = axes.reshape(1, -1) + + for row, (method, label) in enumerate(zip(methods, labels)): + cad_key = f"cadence_{method}" if method != "c" else "cadence_c" + r2_key = f"r2_{method}" if method != "c" else "r2_c" + + cads = [predictions[p][cad_key] for p in pool_ids + if np.isfinite(predictions[p][cad_key])] + r2s = [predictions[p][r2_key] for p in pool_ids + if np.isfinite(predictions[p][r2_key])] + arb_fracs = [] + for p in pool_ids: + v_arb_key = f"v_arb_{method}" if method != "c" else "v_arb_c" + v_noise_key = f"v_noise_{method}" if method != "c" else "v_noise_c" + v_arb = predictions[p][v_arb_key] + v_noise = predictions[p][v_noise_key] + if not np.any(np.isnan(v_arb)): + total = v_arb + v_noise + arb_fracs.append(np.median(v_arb / np.maximum(total, 1.0))) + + ax = axes[row, 0] + if cads: + ax.hist(cads, bins=20, color="orangered", alpha=0.7, edgecolor="white") + ax.axvline(np.median(cads), color="black", linestyle="--", + label=f"Median={np.median(cads):.1f}min") + ax.set_xlabel("Cadence (minutes)") + ax.set_title(f"{label}: Cadence") + ax.legend(fontsize=8) + + ax = axes[row, 1] + if r2s: + ax.hist(r2s, bins=20, color="green", alpha=0.7, edgecolor="white") + ax.axvline(np.median(r2s), color="black", linestyle="--", + label=f"Median={np.median(r2s):.3f}") + ax.set_xlabel("R\u00b2") + ax.set_title(f"{label}: R\u00b2") + ax.legend(fontsize=8) + + ax = axes[row, 2] + if arb_fracs: + ax.hist(arb_fracs, bins=20, color="steelblue", alpha=0.7, edgecolor="white") + ax.axvline(np.median(arb_fracs), color="black", linestyle="--", + label=f"Median={np.median(arb_fracs):.2f}") 
+ ax.set_xlabel("Arb fraction") + ax.set_title(f"{label}: Arb fraction") + ax.legend(fontsize=8) + + fig.tight_layout() + out = os.path.join(output_dir, "summary_distributions.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_r2_scatter(predictions, method_labels, output_dir): + """Scatter: Option C R² vs each joint method R².""" + pool_ids = sorted(predictions.keys()) + n = len(method_labels) + + fig, axes = plt.subplots(1, n, figsize=(6 * n, 5)) + if n == 1: + axes = [axes] + + for ax, label in zip(axes, method_labels): + r2_c = [] + r2_m = [] + for p in pool_ids: + rc = predictions[p]["r2_c"] + rm = predictions[p][f"r2_{label}"] + if np.isfinite(rc) and np.isfinite(rm): + r2_c.append(rc) + r2_m.append(rm) + + ax.scatter(r2_c, r2_m, alpha=0.7, s=30, edgecolors="k", linewidth=0.5) + lo = min(min(r2_c), min(r2_m)) if r2_c else 0 + hi = max(max(r2_c), max(r2_m)) if r2_c else 1 + margin = (hi - lo) * 0.05 + 0.01 + ax.plot([lo - margin, hi + margin], [lo - margin, hi + margin], + "k--", alpha=0.3, linewidth=1) + ax.set_xlabel("Option C R\u00b2") + ax.set_ylabel(f"{label} R\u00b2") + ax.set_title(f"Option C vs {label}") + + # Count wins + wins = sum(1 for c, m in zip(r2_c, r2_m) if m > c) + ax.text(0.05, 0.95, f"{label} wins: {wins}/{len(r2_c)}", + transform=ax.transAxes, fontsize=9, va="top") + + fig.tight_layout() + out = os.path.join(output_dir, "r2_scatter.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_cadence_by_chain(predictions, method_labels, output_dir): + """Cadence distributions by chain for Option C and each method.""" + pool_ids = sorted(predictions.keys()) + chains = sorted(set(predictions[p]["chain"] for p in pool_ids)) + colors = plt.cm.tab10(np.linspace(0, 1, max(len(chains), 1))) + chain_color = {c: colors[i] for i, c in enumerate(chains)} + + methods = ["c"] + method_labels + labels = ["Option C"] + method_labels + n = 
len(methods) + + fig, axes = plt.subplots(1, n, figsize=(6 * n, 5)) + if n == 1: + axes = [axes] + + for ax, method, label in zip(axes, methods, labels): + cad_key = f"cadence_{method}" if method != "c" else "cadence_c" + for chain in chains: + cads = [predictions[p][cad_key] for p in pool_ids + if predictions[p]["chain"] == chain + and np.isfinite(predictions[p][cad_key])] + if cads: + ax.scatter([chain] * len(cads), cads, color=chain_color[chain], + alpha=0.7, s=40, edgecolors="k", linewidth=0.3) + ax.set_ylabel("Cadence (minutes)") + ax.set_title(label) + ax.tick_params(axis="x", rotation=45) + + fig.tight_layout() + out = os.path.join(output_dir, "cadence_by_chain.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +# ---- JSON export ---- + + +def save_results_json(predictions, option_c_results, joint_results, output_dir): + """Save all fitted parameters and diagnostics.""" + out = {"option_c": {}} + for pid, r in option_c_results.items(): + out["option_c"][pid] = { + "log_cadence": r["log_cadence"], + "log_gas": r["log_gas"], + "noise_coeffs": r["noise_coeffs"].tolist(), + "loss": r["loss"], + "converged": bool(r["converged"]), + "cadence_minutes": r["cadence_minutes"], + "gas_usd": r["gas_usd"], + "chain": r.get("chain", ""), + "fee": r.get("fee", 0), + "tokens": r.get("tokens", ""), + } + + for label, result in joint_results: + entry = { + "loss": result["loss"], + "init_loss": result["init_loss"], + "converged": bool(result["converged"]), + "n_pools": result.get("n_pools", 0), + "k_attr": result.get("k_attr", 0), + "pool_ids": result.get("pool_ids", []), + "attr_names": result.get("attr_names", []), + } + # Include any scalar/array results the heads produced + for key in ["bias_cad", "W_cad", "bias_gas", "W_gas", + "bias_noise", "W_noise", "noise_coeffs"]: + if key in result: + val = result[key] + entry[key] = val.tolist() if hasattr(val, "tolist") else val + out[label] = entry + + # Per-pool R² for each 
method + pool_ids = sorted(predictions.keys()) + method_labels = [label for label, _ in joint_results] + per_pool_r2 = {} + for pid in pool_ids: + p = predictions[pid] + row = {"r2_c": p["r2_c"]} + for label in method_labels: + row[f"r2_{label}"] = p[f"r2_{label}"] + per_pool_r2[pid] = row + out["per_pool_r2"] = per_pool_r2 + + path = os.path.join(output_dir, "mlp_calibration_results.json") + with open(path, "w") as f: + json.dump(out, f, indent=2, default=str) + print(f" Saved: {path}") + + +# ---- Main ---- + + +def _save_per_pool_results(results, output_path, label="option_c_reduced"): + """Save per-pool fit results to JSON immediately.""" + out = {} + for pid, r in results.items(): + out[pid] = { + "log_cadence": r["log_cadence"], + "log_gas": r["log_gas"], + "noise_coeffs": r["noise_coeffs"].tolist(), + "loss": r["loss"], + "converged": bool(r["converged"]), + "cadence_minutes": r["cadence_minutes"], + "gas_usd": r["gas_usd"], + "chain": r.get("chain", ""), + "fee": r.get("fee", 0), + "tokens": r.get("tokens", ""), + } + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + json.dump({label: out}, f, indent=2) + print(f" Saved {len(out)} pool results to {output_path}") + + +def main(): + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("MLP Calibration: MLPNoiseHead vs Linear Baseline") + print("=" * 70) + + panel, matched = load_and_match() + + # Step 0: 4-covariate per-pool fits (fast, save immediately) + option_c_reduced = run_option_c_reduced(matched) + _save_per_pool_results( + option_c_reduced, + os.path.join(OUTPUT_DIR, "option_c_reduced.json"), + label="option_c_reduced", + ) + + # Step 1: Option C baseline (8-covariate) + option_c = run_option_c(matched) + + # Step 2: Filter pathological pools + matched_clean, option_c_clean = filter_pathological(matched, option_c) + + # Step 3: Fit each joint model + linear_result, linear_model, jdata = run_linear_joint( + matched_clean, 
option_c_clean) + mlp_noise_result, mlp_noise_model, _ = run_mlp_noise_joint( + matched_clean, option_c_clean) + mlp_full_result, mlp_full_model, _ = run_mlp_full_joint( + matched_clean, option_c_clean) + + # Two-stage: cadence identified with per-pool noise, then MLP noise + two_stage_result, ts_s1_model, ts_s2_model, _ = run_two_stage_joint( + matched_clean, option_c_clean) + + # Reduced x_obs: prune sigma/fee features from noise covariates + reduced_result, reduced_model, jdata_reduced = run_reduced_joint( + matched_clean, option_c_clean) + + # Step 4: Extract per-pool params from each model + linear_pp = _extract_per_pool_params(linear_model, linear_result, jdata) + mlp_noise_pp = _extract_per_pool_params(mlp_noise_model, mlp_noise_result, jdata) + mlp_full_pp = _extract_per_pool_params(mlp_full_model, mlp_full_result, jdata) + two_stage_pp = _extract_two_stage_per_pool( + ts_s1_model, ts_s2_model, two_stage_result, jdata) + reduced_pp = _extract_per_pool_params( + reduced_model, reduced_result, jdata_reduced) + + method_labels = ["linear", "mlp_noise", "mlp_full", "two_stage", "reduced"] + model_results_for_pred = [ + ("linear", linear_pp, jdata.pool_ids), + ("mlp_noise", mlp_noise_pp, jdata.pool_ids), + ("mlp_full", mlp_full_pp, jdata.pool_ids), + ("two_stage", two_stage_pp, jdata.pool_ids), + ("reduced", reduced_pp, jdata_reduced.pool_ids), + ] + + # Step 5: Per-pool predictions + print("\nComputing per-pool predictions...") + predictions = compute_per_pool_predictions( + matched_clean, option_c_clean, model_results_for_pred, + reduced_models=["reduced"]) + + # Step 6: Tables + print_pool_table(predictions, method_labels) + print_r2_comparison(predictions, method_labels) + print_loss_comparison(option_c_clean, [ + ("linear", linear_result), + ("mlp_noise", mlp_noise_result), + ("mlp_full", mlp_full_result), + ("two_stage", two_stage_result), + ("reduced", reduced_result), + ]) + + # Step 7: Plots + print("\nGenerating plots...") + os.makedirs(OUTPUT_DIR, 
exist_ok=True) + + plot_decomposition_pages(predictions, "c", "Option C (per-pool)", OUTPUT_DIR) + plot_decomposition_pages(predictions, "linear", "Linear shared noise", OUTPUT_DIR) + plot_decomposition_pages(predictions, "mlp_noise", "MLP noise (linear cad)", OUTPUT_DIR) + plot_decomposition_pages(predictions, "mlp_full", "Full MLP (MLP cad + MLP noise)", OUTPUT_DIR) + plot_decomposition_pages(predictions, "two_stage", "Two-stage (linear cad -> MLP noise)", OUTPUT_DIR) + plot_decomposition_pages(predictions, "reduced", "Reduced x_obs (k_obs=4)", OUTPUT_DIR) + + plot_summary_distributions(predictions, method_labels, OUTPUT_DIR) + plot_r2_scatter(predictions, method_labels, OUTPUT_DIR) + plot_cadence_by_chain(predictions, method_labels, OUTPUT_DIR) + + # Step 8: JSON export + save_results_json(predictions, option_c_clean, [ + ("linear", linear_result), + ("mlp_noise", mlp_noise_result), + ("mlp_full", mlp_full_result), + ("two_stage", two_stage_result), + ("reduced", reduced_result), + ], OUTPUT_DIR) + + print(f"\n{'='*70}") + print(f"Done. Output in: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_mlp_sweep.py b/scripts/run_mlp_sweep.py new file mode 100644 index 0000000..aa28115 --- /dev/null +++ b/scripts/run_mlp_sweep.py @@ -0,0 +1,429 @@ +"""Hyperparameter sweep for MLP calibration models. + +Sweeps maxiter, alpha, hidden size, maxcor, and loss_type to find settings +where the MLP converges and achieves the best per-pool R2. 
+ +Usage: + python scripts/run_mlp_sweep.py [--phase 1|2|3|all] +""" + +import argparse +import json +import os +import time + +import numpy as np +import pandas as pd + +# ---- Config ---- +PANEL_CACHE = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", +) +GRID_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "pool_grids_v2", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "mlp_sweep", +) +OPTION_C_LOSS_CUTOFF = 5.0 + + +def load_data(): + """Load panel, match grids, run Option C, return clean data.""" + from quantammsim.calibration.per_pool_fit import fit_all_pools + from quantammsim.calibration.pool_data import ( + build_pool_attributes, + build_x_obs, + match_grids_to_panel, + replace_panel_volatility_with_binance, + ) + + panel = pd.read_parquet(PANEL_CACHE) + print("Replacing volatility with Binance minute data...") + panel = replace_panel_volatility_with_binance(panel) + matched = match_grids_to_panel(GRID_DIR, panel) + print(f"Matched: {len(matched)} pools with grids") + + # Option C baseline + print(f"\n--- Option C: per-pool fits ({len(matched)} pools, gas fixed) ---") + option_c = fit_all_pools(matched, fix_gas_to_chain=True) + n_conv = sum(1 for r in option_c.values() if r["converged"]) + losses = [r["loss"] for r in option_c.values()] + print(f" Converged: {n_conv}/{len(option_c)}") + print(f" Loss: median={np.median(losses):.4f}, mean={np.mean(losses):.4f}") + + # Drop pathological pools + dropped = [p for p, r in option_c.items() if r["loss"] > OPTION_C_LOSS_CUTOFF] + matched_clean = {k: v for k, v in matched.items() if k not in dropped} + option_c_clean = {k: v for k, v in option_c.items() if k not in dropped} + if dropped: + print(f" Dropping {len(dropped)} pools (loss > {OPTION_C_LOSS_CUTOFF})") + + return matched_clean, option_c_clean + + +def compute_per_pool_r2(model, result, jdata, matched): + """Compute 
per-pool R2 for a fitted model.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import build_x_obs + import jax.numpy as jnp + + params = jnp.array(result["params_flat"]) + n_pools = result["n_pools"] + k_attr = result["k_attr"] + (cs, ce), (gs, ge), (ns, ne) = model._head_slices(n_pools, k_attr) + + cad_slice = params[cs:ce] + gas_slice = params[gs:ge] + noise_slice = params[ns:ne] + + r2s = [] + for i, pid in enumerate(jdata.pool_ids): + x_attr_i = jdata.x_attr[i] + log_cad = float(model.cadence_head.predict(cad_slice, i, x_attr_i)) + log_gas = float(model.gas_head.predict(gas_slice, i, x_attr_i)) + noise_c = np.array(model.noise_head.predict(noise_slice, i, x_attr_i)) + + entry = matched[pid] + panel = entry["panel"] + coeffs = entry["coeffs"] + day_indices = entry["day_indices"] + x_obs = build_x_obs(panel) + y_obs = panel["log_volume"].values.astype(float) + + v_arb_all = np.array(interpolate_pool_daily( + coeffs, jnp.float64(log_cad), jnp.float64(np.exp(log_gas)))) + v_arb = v_arb_all[day_indices] + v_noise = np.exp(x_obs @ noise_c) + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y_obs) ** 2) + ss_tot = np.sum((y_obs - y_obs.mean()) ** 2) + r2s.append(1 - ss_res / max(ss_tot, 1e-10)) + + return np.array(r2s) + + +def run_single(matched_clean, option_c_clean, config): + """Run a single sweep configuration. 
Returns result dict.""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, LinearHead, MLPHead, MLPNoiseHead, SharedLinearNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_joint_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + + jdata = prepare_joint_data( + matched_clean, drop_chain_dummies=True, fix_gas_to_chain=True) + + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_clean[pid]["chain"] + gas_usd = CHAIN_GAS_USD.get(chain, 1.0) + gas_values.append(np.log(max(gas_usd, 1e-6))) + gas_values = np.array(gas_values) + + # Build model from config + alpha_cad = config.get("alpha_cad", 0.01) + alpha_noise = config.get("alpha_noise", 0.01) + hidden = config.get("hidden", 16) + maxiter = config.get("maxiter", 500) + maxcor = config.get("maxcor", 10) + loss_type = config.get("loss_type", "l2") + cad_type = config.get("cad_type", "linear") # "linear" or "mlp" + noise_type = config.get("noise_type", "mlp") # "mlp" or "linear" + + # Cadence head + if cad_type == "mlp": + cad_head = MLPHead("cad", hidden=hidden, alpha=alpha_cad) + else: + cad_head = LinearHead("cad", alpha=alpha_cad) + + # Noise head + if noise_type == "mlp": + noise_head = MLPNoiseHead(hidden=hidden, alpha=alpha_noise) + else: + noise_head = SharedLinearNoiseHead(alpha=alpha_noise) + + model = CalibrationModel( + cadence_head=cad_head, + gas_head=FixedHead("gas", gas_values), + noise_head=noise_head, + loss_type=loss_type, + ) + + n_pools = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + n_params = model.n_params(n_pools, k_attr) + + # Override maxcor in the fit method by monkey-patching options + import scipy.optimize + _orig_minimize = scipy.optimize.minimize + + def patched_minimize(fun, x0, **kwargs): + opts = kwargs.get("options", {}) + opts["maxcor"] = maxcor + opts["maxiter"] = maxiter + kwargs["options"] = opts + return _orig_minimize(fun, x0, **kwargs) + + 
scipy.optimize.minimize = patched_minimize + try: + t0 = time.time() + result = model.fit(jdata, maxiter=maxiter, warm_start=option_c_clean) + wall_time = time.time() - t0 + finally: + scipy.optimize.minimize = _orig_minimize + + # Compute per-pool R2 + r2s = compute_per_pool_r2(model, result, jdata, matched_clean) + + return { + "config": config, + "n_params": n_params, + "n_pools": n_pools, + "init_loss": result["init_loss"], + "final_loss": result["loss"], + "converged": bool(result["converged"]), + "wall_time_s": round(wall_time, 1), + "r2_median": round(float(np.median(r2s)), 4), + "r2_mean": round(float(np.mean(r2s)), 4), + "r2_p10": round(float(np.percentile(r2s, 10)), 4), + "r2_p25": round(float(np.percentile(r2s, 25)), 4), + "r2_p75": round(float(np.percentile(r2s, 75)), 4), + "r2_p90": round(float(np.percentile(r2s, 90)), 4), + "r2_min": round(float(np.min(r2s)), 4), + "r2_max": round(float(np.max(r2s)), 4), + "n_positive_r2": int(np.sum(r2s > 0)), + } + + +def print_result(res, idx=None): + """Print a single result row.""" + c = res["config"] + prefix = f"[{idx}] " if idx is not None else "" + label = (f"{c.get('cad_type','linear')}_cad + " + f"{c.get('noise_type','mlp')}_noise") + print(f"{prefix}{label} " + f"h={c.get('hidden',16):2d} " + f"a_c={c.get('alpha_cad',0.01):.4f} " + f"a_n={c.get('alpha_noise',0.01):.4f} " + f"maxiter={c.get('maxiter',500):5d} " + f"maxcor={c.get('maxcor',10):2d} " + f"loss={c.get('loss_type','l2'):5s} | " + f"L={res['final_loss']:7.4f} " + f"conv={str(res['converged']):5s} " + f"R2_med={res['r2_median']:+.4f} " + f"R2_mean={res['r2_mean']:+.4f} " + f"R2+={res['n_positive_r2']:2d}/{res['n_pools']} " + f"{res['wall_time_s']:5.1f}s") + + +def run_phase_1(matched_clean, option_c_clean): + """Phase 1: Sweep maxiter to diagnose convergence.""" + print("\n" + "=" * 80) + print("Phase 1: maxiter sweep (MLP noise + linear cadence, then full MLP)") + print("=" * 80) + + configs = [] + for maxiter in [500, 2000, 5000]: + configs.append({ + "cad_type": 
"linear", "noise_type": "mlp", + "maxiter": maxiter, "hidden": 16, + "alpha_cad": 0.01, "alpha_noise": 0.01, + "maxcor": 10, "loss_type": "l2", + "label": f"maxiter={maxiter}", + }) + # Also sweep maxiter for full MLP + for maxiter in [500, 2000, 5000]: + configs.append({ + "cad_type": "mlp", "noise_type": "mlp", + "maxiter": maxiter, "hidden": 16, + "alpha_cad": 0.01, "alpha_noise": 0.01, + "maxcor": 10, "loss_type": "l2", + "label": f"full_mlp_maxiter={maxiter}", + }) + + results = [] + for i, cfg in enumerate(configs): + print(f"\n Running {cfg['label']}...") + res = run_single(matched_clean, option_c_clean, cfg) + print_result(res, i) + results.append(res) + return results + + +def run_phase_2(matched_clean, option_c_clean, best_maxiter=5000): + """Phase 2: Regularization grid (alpha_noise x alpha_cad).""" + print("\n" + "=" * 80) + print("Phase 2: regularization sweep (MLP noise, linear cadence)") + print("=" * 80) + + configs = [] + for alpha_noise in [0.0001, 0.001, 0.01, 0.1]: + for alpha_cad in [0.001, 0.01, 0.1]: + configs.append({ + "cad_type": "linear", "noise_type": "mlp", + "maxiter": best_maxiter, "hidden": 16, + "alpha_cad": alpha_cad, "alpha_noise": alpha_noise, + "maxcor": 10, "loss_type": "l2", + "label": f"a_n={alpha_noise}, a_c={alpha_cad}", + }) + + results = [] + for i, cfg in enumerate(configs): + print(f"\n Running {cfg['label']}...") + res = run_single(matched_clean, option_c_clean, cfg) + print_result(res, i) + results.append(res) + return results + + +def run_phase_3(matched_clean, option_c_clean, + best_maxiter=5000, best_alpha_cad=0.01, + best_alpha_noise=0.01): + """Phase 3: Architecture sweep (hidden, maxcor, loss_type).""" + print("\n" + "=" * 80) + print("Phase 3: architecture sweep") + print("=" * 80) + + configs = [] + # Hidden size + for hidden in [8, 16, 32]: + configs.append({ + "cad_type": "linear", "noise_type": "mlp", + "maxiter": best_maxiter, "hidden": hidden, + "alpha_cad": best_alpha_cad, "alpha_noise": 
best_alpha_noise, + "maxcor": 10, "loss_type": "l2", + "label": f"hidden={hidden}", + }) + # maxcor + for maxcor in [10, 30, 50]: + configs.append({ + "cad_type": "linear", "noise_type": "mlp", + "maxiter": best_maxiter, "hidden": 16, + "alpha_cad": best_alpha_cad, "alpha_noise": best_alpha_noise, + "maxcor": maxcor, "loss_type": "l2", + "label": f"maxcor={maxcor}", + }) + # Loss type + for loss_type in ["l2", "huber"]: + configs.append({ + "cad_type": "linear", "noise_type": "mlp", + "maxiter": best_maxiter, "hidden": 16, + "alpha_cad": best_alpha_cad, "alpha_noise": best_alpha_noise, + "maxcor": 10, "loss_type": loss_type, + "label": f"loss={loss_type}", + }) + # Full MLP with best settings + configs.append({ + "cad_type": "mlp", "noise_type": "mlp", + "maxiter": best_maxiter, "hidden": 16, + "alpha_cad": best_alpha_cad, "alpha_noise": best_alpha_noise, + "maxcor": 10, "loss_type": "l2", + "label": "full_mlp_best", + }) + + results = [] + for i, cfg in enumerate(configs): + print(f"\n Running {cfg['label']}...") + res = run_single(matched_clean, option_c_clean, cfg) + print_result(res, i) + results.append(res) + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--phase", default="all", + choices=["1", "2", "3", "all"]) + args = parser.parse_args() + + os.makedirs(OUTPUT_DIR, exist_ok=True) + matched_clean, option_c_clean = load_data() + + # Compute Option C R2 for reference + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import build_x_obs + import jax.numpy as jnp + + r2s_c = [] + for pid in sorted(matched_clean.keys()): + entry = matched_clean[pid] + rc = option_c_clean[pid] + panel = entry["panel"] + x_obs = build_x_obs(panel) + y_obs = panel["log_volume"].values.astype(float) + day_indices = entry["day_indices"] + coeffs = entry["coeffs"] + + v_arb_all = np.array(interpolate_pool_daily( + coeffs, jnp.float64(rc["log_cadence"]), + 
jnp.float64(np.exp(rc["log_gas"])))) + v_arb = v_arb_all[day_indices] + v_noise = np.exp(x_obs @ rc["noise_coeffs"]) + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y_obs) ** 2) + ss_tot = np.sum((y_obs - y_obs.mean()) ** 2) + r2s_c.append(1 - ss_res / max(ss_tot, 1e-10)) + + r2s_c = np.array(r2s_c) + print(f"\nOption C reference: R2 median={np.median(r2s_c):.4f}, " + f"mean={np.mean(r2s_c):.4f}, " + f"R2>0: {np.sum(r2s_c > 0)}/{len(r2s_c)}") + + all_results = {"option_c_r2_median": float(np.median(r2s_c)), + "option_c_r2_mean": float(np.mean(r2s_c)), + "phases": {}} + + if args.phase in ("1", "all"): + r1 = run_phase_1(matched_clean, option_c_clean) + all_results["phases"]["1"] = r1 + + # Pick best maxiter from phase 1 + best_maxiter = max(r1, key=lambda r: r["r2_median"])["config"]["maxiter"] + print(f"\n Best maxiter from phase 1: {best_maxiter}") + else: + best_maxiter = 5000 + + if args.phase in ("2", "all"): + r2 = run_phase_2(matched_clean, option_c_clean, best_maxiter) + all_results["phases"]["2"] = r2 + + best_r = max(r2, key=lambda r: r["r2_median"]) + best_alpha_cad = best_r["config"]["alpha_cad"] + best_alpha_noise = best_r["config"]["alpha_noise"] + print(f"\n Best from phase 2: alpha_cad={best_alpha_cad}, " + f"alpha_noise={best_alpha_noise}, R2_med={best_r['r2_median']}") + else: + best_alpha_cad = 0.01 + best_alpha_noise = 0.01 + + if args.phase in ("3", "all"): + r3 = run_phase_3(matched_clean, option_c_clean, + best_maxiter, best_alpha_cad, best_alpha_noise) + all_results["phases"]["3"] = r3 + + # Save results + out_path = os.path.join(OUTPUT_DIR, "sweep_results.json") + with open(out_path, "w") as f: + json.dump(all_results, f, indent=2, default=str) + print(f"\nResults saved to {out_path}") + + # Print summary table + print("\n" + "=" * 80) + print("SWEEP SUMMARY") + print("=" * 80) + print(f"Option C reference: R2 median={np.median(r2s_c):.4f}") + print() + for phase, results in 
all_results["phases"].items(): + print(f"Phase {phase}:") + for i, res in enumerate(results): + print_result(res, i) + print() + + +if __name__ == "__main__": + main() diff --git a/scripts/run_token_factored_calibration.py b/scripts/run_token_factored_calibration.py new file mode 100644 index 0000000..1f1cbf5 --- /dev/null +++ b/scripts/run_token_factored_calibration.py @@ -0,0 +1,821 @@ +"""Token-factored noise calibration v2: canonicalization + cross-pool lag features. + +Phase 0: Pooled Ridge diagnostic — does cross-pool signal exist? +Phase 1: Token-factored model with lambda_delta annealing sweep +Phase 2: LOO cross-validation (baseline vs cross-pool ablation) +Phase 3: Comparison plots and JSON export +""" + +import argparse +import json +import os +import pickle + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +# ---- Config ---- +PANEL_CACHE = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "local_data", "noise_calibration", "panel.parquet", +) +GRID_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "pool_grids_v2", +) +OUTPUT_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", +) +OPTION_C_LOSS_CUTOFF = 5.0 +JOINT_MAXITER = 5000 +# Sorted descending for warm-start annealing (highest regularization first) +LAMBDA_DELTAS = [10.0, 5.0, 1.0, 0.5, 0.1, 0.01] + + +# ---- Data loading (shared with run_mlp_calibration.py) ---- + + +def load_and_match(): + """Load panel, match to grids.""" + from quantammsim.calibration.pool_data import ( + match_grids_to_panel, + replace_panel_volatility_with_binance, + ) + + panel = pd.read_parquet(PANEL_CACHE) + + if "log_tvl_lag1" not in panel.columns: + panel = panel.sort_values(["pool_id", "date"]).reset_index(drop=True) + panel["log_tvl_lag1"] = panel.groupby("pool_id")["log_tvl"].shift(1) + panel = 
panel.dropna(subset=["log_tvl_lag1"]).reset_index(drop=True) + + pool_counts = panel.groupby("pool_id").size() + valid = pool_counts[pool_counts >= 10].index + panel = panel[panel["pool_id"].isin(valid)].copy() + + print("Replacing volatility with Binance minute data...") + panel = replace_panel_volatility_with_binance(panel) + + print(f"Panel: {len(panel)} obs, {panel['pool_id'].nunique()} pools, " + f"{panel['date'].min()} to {panel['date'].max()}") + + matched = match_grids_to_panel(GRID_DIR, panel) + print(f"Matched: {len(matched)} pools with grids") + return panel, matched + + +def filter_pathological(matched, option_c): + """Drop pools with high Option C loss.""" + good = {p: r for p, r in option_c.items() if r["loss"] <= OPTION_C_LOSS_CUTOFF} + dropped = set(option_c) - set(good) + matched_clean = {p: matched[p] for p in good if p in matched} + if dropped: + print(f" Dropping {len(dropped)} pools (loss > {OPTION_C_LOSS_CUTOFF}):") + for p in sorted(dropped): + print(f" {p} loss={option_c[p]['loss']:.1f}") + return matched_clean, good + + +# ---- Phase 0: Pooled Ridge Diagnostic ---- + + +def run_phase0_diagnostic(matched, option_c): + """Pooled Ridge + token-dummy Ridge — go/no-go gate.""" + from sklearn.linear_model import RidgeCV + + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.pool_data import ( + build_pool_attributes, build_x_obs, encode_tokens, _parse_tokens, + ) + import jax.numpy as jnp + + print("\n" + "=" * 70) + print("Phase 0: Pooled Ridge Diagnostic — cross-pool signal?") + print("=" * 70) + + pool_ids = sorted(matched.keys()) + X_attr, _, _ = build_pool_attributes(matched) + pool_idx_map = {pid: i for i, pid in enumerate(pool_ids)} + enc = encode_tokens(matched) + + all_x, all_y, all_pool_attrs, all_token_dummies = [], [], [], [] + + for pid in pool_ids: + entry = matched[pid] + oc = option_c[pid] + coeffs = entry["coeffs"] + + v_arb_all = np.array(interpolate_pool_daily( + coeffs, 
jnp.float64(oc["log_cadence"]), + jnp.float64(np.exp(oc["log_gas"])))) + v_arb = v_arb_all[entry["day_indices"]] + + x_obs = build_x_obs(entry["panel"], reduced=True) + y_obs = entry["panel"]["log_volume"].values.astype(float) + y_residual = y_obs - np.log(np.maximum(v_arb, 1e-6)) + + all_x.append(x_obs) + all_y.append(y_residual) + + # Broadcast pool attrs to each obs + x_attr_row = X_attr[pool_idx_map[pid]] + all_pool_attrs.append(np.tile(x_attr_row, (len(x_obs), 1))) + + # Token dummies: one-hot for each token in the pool + n_obs = len(x_obs) + dummies = np.zeros((n_obs, enc["n_tokens"]), dtype=np.float64) + toks = _parse_tokens(entry["tokens"]) + for t in toks[:2]: + if t in enc["token_index"]: + dummies[:, enc["token_index"][t]] = 1.0 + all_token_dummies.append(dummies) + + X_obs = np.vstack(all_x) + y_combined = np.concatenate(all_y) + X_pool_attrs = np.vstack(all_pool_attrs) + X_token_dummies = np.vstack(all_token_dummies) + + # Model 1: x_obs + pool_attrs + X_combined = np.column_stack([X_obs, X_pool_attrs]) + model1 = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model1.fit(X_combined, y_combined) + r2_pooled = model1.score(X_combined, y_combined) + print(f" Pooled Ridge (x_obs + pool attrs): R² = {r2_pooled:.4f}") + + # Model 2: x_obs + token_dummies + X_token = np.column_stack([X_obs, X_token_dummies]) + model2 = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model2.fit(X_token, y_combined) + r2_token = model2.score(X_token, y_combined) + print(f" Token-dummy Ridge (x_obs + token dummies): R² = {r2_token:.4f}") + + # Model 3: x_obs + token_dummies + chain_dummies + log_fee + chain_dummies = np.zeros((len(y_combined), enc["n_chains"]), dtype=np.float64) + log_fees = np.zeros((len(y_combined), 1), dtype=np.float64) + offset = 0 + for pid in pool_ids: + n_obs = len(matched[pid]["day_indices"]) + ci = enc["chain_idx"][pool_idx_map[pid]] + chain_dummies[offset:offset + n_obs, ci] = 1.0 + log_fees[offset:offset + n_obs, 0] = enc["log_fees"][pool_idx_map[pid]] + offset += 
n_obs + + X_full = np.column_stack([X_obs, X_token_dummies, chain_dummies, log_fees]) + model3 = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model3.fit(X_full, y_combined) + r2_full = model3.score(X_full, y_combined) + print(f" Full Ridge (x_obs + tokens + chains + fee): R² = {r2_full:.4f}") + + # Baseline: x_obs only + model_base = RidgeCV(alphas=np.logspace(-2, 4, 50)) + model_base.fit(X_obs, y_combined) + r2_base = model_base.score(X_obs, y_combined) + print(f" Baseline (x_obs only): R² = {r2_base:.4f}") + + print(f"\n Signal above baseline:") + print(f" Pool attrs: +{r2_pooled - r2_base:.4f}") + print(f" Token dummies: +{r2_token - r2_base:.4f}") + print(f" Full (tok+ch+fee): +{r2_full - r2_base:.4f}") + + if r2_full - r2_base < 0.01: + print("\n WARNING: Very weak cross-pool signal. " + "Token factoring may not improve over per-pool fits.") + + return { + "r2_baseline": r2_base, + "r2_pool_attrs": r2_pooled, + "r2_token_dummies": r2_token, + "r2_full": r2_full, + "n_obs_total": len(y_combined), + "n_pools": len(pool_ids), + "n_tokens": enc["n_tokens"], + "n_chains": enc["n_chains"], + } + + +# ---- Phase 1: Token-Factored Model ---- + + +def _build_gas_values(jdata, matched_clean): + """Build fixed gas values (log-space) from chain data.""" + from quantammsim.calibration.loss import CHAIN_GAS_USD + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_clean[pid]["chain"] + gas_usd = CHAIN_GAS_USD.get(chain, 1.0) + gas_values.append(np.log(max(gas_usd, 1e-6))) + return np.array(gas_values) + + +def _result_to_warm_start(result): + """Extract per-pool warm_start dict from a CalibrationModel fit result. + + Returns dict: pool_id -> {log_cadence, noise_coeffs} suitable for + passing as warm_start to CalibrationModel.fit(). 
+ """ + pool_ids = result["pool_ids"] + warm = {} + for i, pid in enumerate(pool_ids): + entry = {} + # Cadence: from PerPoolHead + if "log_cadence_per_pool" in result: + entry["log_cadence"] = float(result["log_cadence_per_pool"][i]) + # Noise: per-pool coefficients + if "noise_coeffs" in result: + entry["noise_coeffs"] = result["noise_coeffs"][i] + warm[pid] = entry + return warm + + +def run_token_factored( + matched_clean, option_c_clean, lambda_delta=1.0, + cross_pool=False, warm_start=None, +): + """Fit TokenFactoredNoiseHead with PerPoolHead(cadence) + FixedHead(gas).""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.pool_data import K_OBS_CROSS, K_OBS_REDUCED + + k_obs = K_OBS_CROSS if cross_pool else K_OBS_REDUCED + + jdata, enc = prepare_token_factored_data( + matched_clean, cross_pool=cross_pool, + ) + n_pools = len(jdata.pool_data) + + gas_values = _build_gas_values(jdata, matched_clean) + gas_head = FixedHead("log_gas", gas_values) + cad_head = PerPoolHead("log_cadence", default=np.log(12.0)) + noise_head = TokenFactoredNoiseHead( + k_obs=k_obs, + lambda_delta=lambda_delta, + **enc, + ) + + model = CalibrationModel(cad_head, gas_head, noise_head) + n_p = model.n_params(n_pools, jdata.x_attr.shape[1]) + cp_tag = " [cross-pool]" if cross_pool else "" + print(f"\n--- Token-factored (lambda_delta={lambda_delta}){cp_tag} ---") + print(f" {n_pools} pools, {enc['n_tokens']} tokens, " + f"{enc['n_chains']} chains, {n_p} params, k_obs={k_obs}") + + ws = warm_start if warm_start is not None else option_c_clean + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=ws) + print(f" Loss: {result['init_loss']:.4f} -> {result['loss']:.4f}" + f" (data={result['data_loss']:.4f}, reg={result['reg_loss']:.4f})") + print(f" Converged: 
{result['converged']}") + + return result, model, jdata, enc + + +# ---- Phase 2: Analysis & Visualization ---- + + +def print_token_effects(result, enc): + """Print token effect table.""" + u = result["token_effects"] + Gamma = result["Gamma"] + x_token = enc["x_token"] + token_index = enc["token_index"] + inv_index = {v: k for k, v in token_index.items()} + + u_pred = x_token @ Gamma # population prediction + + print(f"\n{'='*70}") + print("Token effects (u_t) vs population prediction (x_t @ Gamma)") + print(f"{'='*70}") + print(f"{'Token':<12} {'u[0]':>8} {'pred[0]':>8} {'delta[0]':>8} " + f"{'u[1]':>8} {'pred[1]':>8}") + print("-" * 60) + for idx in range(len(inv_index)): + name = inv_index[idx] + print(f"{name:<12} {u[idx,0]:>8.3f} {u_pred[idx,0]:>8.3f} " + f"{u[idx,0]-u_pred[idx,0]:>8.3f} " + f"{u[idx,1]:>8.3f} {u_pred[idx,1]:>8.3f}") + + +def print_chain_effects(result, enc): + """Print chain effect table.""" + alpha = result["chain_effects"] + chain_index = enc["chain_index"] + inv_index = {v: k for k, v in chain_index.items()} + + print(f"\n{'='*70}") + print("Chain effects (alpha)") + print(f"{'='*70}") + k = alpha.shape[1] + header = f"{'Chain':<12}" + "".join(f" {'a['+str(j)+']':>8}" for j in range(k)) + print(header) + print("-" * (12 + 9 * k)) + for idx in range(len(inv_index)): + name = inv_index[idx] + vals = " ".join(f"{alpha[idx, j]:>8.3f}" for j in range(k)) + print(f"{name:<12} {vals}") + + +def print_delta_analysis(result, enc, jdata, matched_clean): + """Print per-pool delta analysis.""" + delta = result["noise_deltas"] + pool_ids = jdata.pool_ids + + print(f"\n{'='*70}") + print("Per-pool deltas (unexplained residual)") + print(f"{'='*70}") + print(f"{'Pool':<24} {'Tokens':<16} {'Chain':<10} " + f"{'|delta|':>8} {'delta[0]':>8}") + print("-" * 70) + + delta_norms = np.linalg.norm(delta, axis=1) + order = np.argsort(-delta_norms) + for i in order: + pid = pool_ids[i] + entry = matched_clean[pid] + print(f"{pid[:24]:<24} {entry['tokens']:<16} 
{entry['chain']:<10} " + f"{delta_norms[i]:>8.3f} {delta[i, 0]:>8.3f}") + + +def run_lambda_sweep(matched_clean, option_c_clean, cross_pool=False): + """Sweep lambda_delta with warm-start annealing (descending lambda). + + Each fit warm-starts from the previous result, so the sweep is + effectively a continuation path from high to low regularization. + """ + print(f"\n{'='*70}") + cp_tag = " [cross-pool]" if cross_pool else "" + print(f"Lambda_delta sweep{cp_tag}") + print(f"{'='*70}") + print(f"{'lambda':>10} {'loss':>10} {'data_loss':>10} {'reg_loss':>10} " + f"{'delta_norm':>12} {'mean_|d|':>10}") + print("-" * 65) + + results = [] + warm_start = option_c_clean + for lam in LAMBDA_DELTAS: + result, model, jdata, enc = run_token_factored( + matched_clean, option_c_clean, lambda_delta=lam, + cross_pool=cross_pool, warm_start=warm_start, + ) + delta = result["noise_deltas"] + delta_norm = float(np.linalg.norm(delta)) + mean_abs_d = float(np.mean(np.abs(delta))) + print(f"{lam:>10.2f} {result['loss']:>10.4f} " + f"{result['data_loss']:>10.4f} {result['reg_loss']:>10.4f} " + f"{delta_norm:>12.4f} {mean_abs_d:>10.4f}") + results.append({ + "lambda_delta": lam, + "loss": result["loss"], + "data_loss": result["data_loss"], + "reg_loss": result["reg_loss"], + "delta_norm": delta_norm, + "mean_abs_delta": mean_abs_d, + "converged": result["converged"], + }) + # Warm-start next iteration from this result + warm_start = _result_to_warm_start(result) + + return results + + +# ---- Phase 3: LOO Cross-Validation ---- + + +def run_loo_validation( + matched_clean, option_c_clean, lambda_delta=1.0, cross_pool=False, +): + """Leave-one-pool-out cross-validation via predict_new_pool.""" + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.joint_fit import 
prepare_token_factored_data + from quantammsim.calibration.pool_data import ( + K_OBS_CROSS, K_OBS_REDUCED, build_cross_pool_x_obs, + build_x_obs, _parse_tokens, + ) + import jax.numpy as jnp + + k_obs = K_OBS_CROSS if cross_pool else K_OBS_REDUCED + pool_ids = sorted(matched_clean.keys()) + + cp_tag = " [cross-pool]" if cross_pool else "" + print(f"\n{'='*70}") + print(f"LOO Cross-Validation (lambda_delta={lambda_delta}){cp_tag}") + print(f"{'='*70}") + + loo_results = [] + for hold_out_pid in pool_ids: + # Build training set without hold-out pool + train_matched = {p: matched_clean[p] for p in pool_ids if p != hold_out_pid} + train_oc = {p: option_c_clean[p] for p in pool_ids if p != hold_out_pid} + + if len(train_matched) < 3: + continue + + # Fit on training set + jdata, enc = prepare_token_factored_data( + train_matched, cross_pool=cross_pool, + ) + gas_values = _build_gas_values(jdata, train_matched) + + noise_head = TokenFactoredNoiseHead( + k_obs=k_obs, + lambda_delta=lambda_delta, + **enc, + ) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", gas_values), + noise_head, + ) + result = model.fit(jdata, maxiter=JOINT_MAXITER, warm_start=train_oc) + + # Extract noise params and predict for hold-out pool + n_train = len(jdata.pool_data) + k_attr = jdata.x_attr.shape[1] + (_, _), (_, _), (ns, ne) = model._head_slices(n_train, k_attr) + noise_params = result["params_flat"][ns:ne] + + ho_entry = matched_clean[hold_out_pid] + toks = _parse_tokens(ho_entry["tokens"]) + ho_pred = noise_head.predict_new_pool( + noise_params, toks[0], toks[1], + ho_entry["chain"], ho_entry["fee"], + n_pools=n_train, + ) + + # Evaluate hold-out R² + ho_panel = ho_entry["panel"] + y_obs_ho = ho_panel["log_volume"].values.astype(float) + + if cross_pool: + # Build cross-pool x_obs for held-out pool. + # Use matched_clean so pool's own entry is accessible; + # build_cross_pool_x_obs auto-excludes pool_id from its own peers. 
+ x_obs_ho = build_cross_pool_x_obs( + ho_panel, matched_clean, hold_out_pid, + ) + # Trim y_obs to match (first day dropped) + y_obs_ho = y_obs_ho[1:] + day_indices_ho = ho_entry["day_indices"][1:] + else: + x_obs_ho = build_x_obs(ho_panel, reduced=True) + day_indices_ho = ho_entry["day_indices"] + + # Use Option C cadence for the hold-out pool (not predicting cadence) + oc_ho = option_c_clean[hold_out_pid] + v_arb_all = np.array(interpolate_pool_daily( + ho_entry["coeffs"], + jnp.float64(oc_ho["log_cadence"]), + jnp.float64(np.exp(oc_ho["log_gas"])), + )) + v_arb = v_arb_all[day_indices_ho] + + # Noise coefficients are k_obs-dimensional; x_obs_ho has k_obs columns + noise_coeffs = ho_pred["noise_coeffs"][:k_obs] + v_noise = np.exp(x_obs_ho @ noise_coeffs) + log_pred = np.log(np.maximum(v_arb + v_noise, 1e-6)) + ss_res = np.sum((log_pred - y_obs_ho) ** 2) + ss_tot = np.sum((y_obs_ho - y_obs_ho.mean()) ** 2) + r2_loo = 1 - ss_res / max(ss_tot, 1e-10) + + # Compare with Option C in-sample R² + x_obs_c = build_x_obs(ho_panel, reduced=True) + v_noise_c = np.exp(x_obs_c @ oc_ho["noise_coeffs"][:K_OBS_REDUCED]) + v_arb_c = v_arb_all[ho_entry["day_indices"]] + log_pred_c = np.log(np.maximum(v_arb_c + v_noise_c, 1e-6)) + y_obs_full = ho_panel["log_volume"].values.astype(float) + ss_res_c = np.sum((log_pred_c - y_obs_full) ** 2) + ss_tot_c = np.sum((y_obs_full - y_obs_full.mean()) ** 2) + r2_c = 1 - ss_res_c / max(ss_tot_c, 1e-10) + + loo_results.append({ + "pool_id": hold_out_pid, + "r2_loo": r2_loo, + "r2_option_c": r2_c, + "tokens": ho_entry["tokens"], + "chain": ho_entry["chain"], + }) + + print(f" {hold_out_pid[:16]} ({ho_entry['tokens']:<14}) " + f"R²_LOO={r2_loo:.3f} R²_C={r2_c:.3f} " + f"{'BETTER' if r2_loo > r2_c else 'worse'}") + + if loo_results: + r2s_loo = [r["r2_loo"] for r in loo_results] + r2s_c = [r["r2_option_c"] for r in loo_results] + n_better = sum(1 for r in loo_results if r["r2_loo"] > r["r2_option_c"]) + print(f"\n LOO median R²: 
{np.median(r2s_loo):.4f} " + f"(Option C: {np.median(r2s_c):.4f})") + print(f" LOO wins: {n_better}/{len(loo_results)}") + + return loo_results + + +# ---- Plots ---- + + +def plot_lambda_sweep(sweep_results, output_dir, suffix=""): + """Plot loss (data/reg separated) and delta norm vs lambda_delta.""" + lambdas = [r["lambda_delta"] for r in sweep_results] + data_losses = [r["data_loss"] for r in sweep_results] + reg_losses = [r["reg_loss"] for r in sweep_results] + delta_norms = [r["delta_norm"] for r in sweep_results] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) + + ax1.semilogx(lambdas, data_losses, "o-", color="steelblue", label="data_loss") + ax1.semilogx(lambdas, reg_losses, "s--", color="orangered", label="reg_loss") + ax1.set_xlabel("lambda_delta") + ax1.set_ylabel("Loss") + ax1.set_title("Data + Reg Loss vs lambda_delta") + ax1.legend() + + ax2.semilogx(lambdas, delta_norms, "o-", color="orangered") + ax2.set_xlabel("lambda_delta") + ax2.set_ylabel("||delta||") + ax2.set_title("Delta norm vs lambda_delta") + + fig.tight_layout() + out = os.path.join(output_dir, f"lambda_sweep{suffix}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_token_effects(result, enc, output_dir): + """Bar chart of token effects (intercept coefficient).""" + u = result["token_effects"] + token_index = enc["token_index"] + inv_index = {v: k for k, v in token_index.items()} + names = [inv_index[i] for i in range(len(inv_index))] + + fig, ax = plt.subplots(figsize=(max(8, len(names) * 0.5), 5)) + x = np.arange(len(names)) + ax.bar(x, u[:, 0], color="steelblue", alpha=0.8) + ax.set_xticks(x) + ax.set_xticklabels(names, rotation=45, ha="right", fontsize=8) + ax.set_ylabel("u_t[0] (intercept effect)") + ax.set_title("Token effects on noise intercept") + ax.axhline(0, color="black", linewidth=0.5, linestyle="--") + fig.tight_layout() + out = os.path.join(output_dir, "token_effects.png") + fig.savefig(out, dpi=150, 
bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +def plot_loo_scatter(loo_results, output_dir, suffix=""): + """Scatter: Option C R² vs LOO R².""" + if not loo_results: + return + + r2_c = [r["r2_option_c"] for r in loo_results] + r2_loo = [r["r2_loo"] for r in loo_results] + + fig, ax = plt.subplots(figsize=(7, 6)) + ax.scatter(r2_c, r2_loo, alpha=0.7, s=40, edgecolors="k", linewidth=0.5) + lo = min(min(r2_c), min(r2_loo)) + hi = max(max(r2_c), max(r2_loo)) + margin = (hi - lo) * 0.05 + 0.01 + ax.plot([lo - margin, hi + margin], [lo - margin, hi + margin], + "k--", alpha=0.3, linewidth=1) + ax.set_xlabel("Option C R² (in-sample)") + ax.set_ylabel("Token-factored R² (LOO)") + ax.set_title(f"LOO: Token-Factored vs Option C{suffix}") + + n_better = sum(1 for c, l in zip(r2_c, r2_loo) if l > c) + ax.text(0.05, 0.95, f"LOO wins: {n_better}/{len(r2_c)}", + transform=ax.transAxes, fontsize=10, va="top") + + fig.tight_layout() + out = os.path.join(output_dir, f"loo_scatter{suffix}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f" Saved: {out}") + + +# ---- Intermediate state caching ---- + +_CACHE_DIR = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "results", "token_factored_calibration", "_cache", +) + + +def _save_stage1(matched_clean, option_c_clean, diag): + """Cache Option C fits, filtering, and Phase 0 diagnostic.""" + os.makedirs(_CACHE_DIR, exist_ok=True) + path = os.path.join(_CACHE_DIR, "stage1.pkl") + with open(path, "wb") as f: + pickle.dump({ + "matched_clean": matched_clean, + "option_c_clean": option_c_clean, + "diag": diag, + }, f) + print(f" Cached stage 1 to {path}") + + +def _load_stage1(): + """Load cached stage 1 results. 
Returns None if missing.""" + path = os.path.join(_CACHE_DIR, "stage1.pkl") + if not os.path.exists(path): + return None + with open(path, "rb") as f: + data = pickle.load(f) + print(f" Loaded stage 1 cache from {path}") + return data + + +def _save_baseline(result_base, enc_base, sweep_baseline, loo_baseline): + """Cache ablation 1 results.""" + os.makedirs(_CACHE_DIR, exist_ok=True) + path = os.path.join(_CACHE_DIR, "baseline.pkl") + with open(path, "wb") as f: + pickle.dump({ + "result_base": result_base, + "enc_base": enc_base, + "sweep_baseline": sweep_baseline, + "loo_baseline": loo_baseline, + }, f) + print(f" Cached baseline results to {path}") + + +def _load_baseline(): + """Load cached baseline results. Returns None if missing.""" + path = os.path.join(_CACHE_DIR, "baseline.pkl") + if not os.path.exists(path): + return None + with open(path, "rb") as f: + data = pickle.load(f) + print(f" Loaded baseline cache from {path}") + return data + + +def _export_ablation_result(result, enc): + """Build JSON-serializable dict from ablation result + encoding.""" + return { + "loss": result["loss"], + "data_loss": result["data_loss"], + "reg_loss": result["reg_loss"], + "init_loss": result["init_loss"], + "converged": result["converged"], + "n_pools": result["n_pools"], + "n_tokens": enc["n_tokens"], + "n_chains": enc["n_chains"], + "token_index": enc["token_index"], + "chain_index": enc["chain_index"], + "token_effects": result["token_effects"].tolist(), + "Gamma": result["Gamma"].tolist(), + "chain_effects": result["chain_effects"].tolist(), + "beta_fee": result["beta_fee"].tolist(), + "noise_deltas": result["noise_deltas"].tolist(), + "noise_coeffs": result["noise_coeffs"].tolist(), + } + + +# ---- Main ---- + + +def main(): + parser = argparse.ArgumentParser( + description="Token-factored noise calibration v2") + parser.add_argument( + "--cross-pool-only", action="store_true", + help="Skip baseline ablation, load from cache, run only cross-pool", + ) + args = 
parser.parse_args() + + os.environ.setdefault("JAX_PLATFORMS", "cpu") + + print("=" * 70) + print("Token-Factored Noise Calibration v2") + print(" Canonicalization + Cross-Pool Lag Features") + print("=" * 70) + + # ---- Stage 1: Option C + filtering + Phase 0 ---- + cached_s1 = _load_stage1() if args.cross_pool_only else None + + if cached_s1 is not None: + matched_clean = cached_s1["matched_clean"] + option_c_clean = cached_s1["option_c_clean"] + diag = cached_s1["diag"] + print(f" Using cached stage 1: {len(matched_clean)} pools") + else: + panel, matched = load_and_match() + + from quantammsim.calibration.per_pool_fit import fit_all_pools + print(f"\n--- Option C Reduced: per-pool fits ({len(matched)} pools) ---") + option_c = fit_all_pools(matched, fix_gas_to_chain=True, reduced=True) + losses = [r["loss"] for r in option_c.values()] + print(f" Loss: median={np.median(losses):.4f}, mean={np.mean(losses):.4f}") + + matched_clean, option_c_clean = filter_pathological(matched, option_c) + diag = run_phase0_diagnostic(matched_clean, option_c_clean) + _save_stage1(matched_clean, option_c_clean, diag) + + # ---- Ablation 1: Baseline ---- + if args.cross_pool_only: + cached_bl = _load_baseline() + if cached_bl is not None: + result_base = cached_bl["result_base"] + enc_base = cached_bl["enc_base"] + sweep_baseline = cached_bl["sweep_baseline"] + loo_baseline = cached_bl["loo_baseline"] + else: + print(" No baseline cache found — skipping baseline ablation.") + result_base = enc_base = sweep_baseline = loo_baseline = None + else: + print("\n" + "=" * 70) + print("ABLATION 1: Baseline (K_OBS_REDUCED=4, no cross-pool features)") + print("=" * 70) + + result_base, _, jdata_base, enc_base = run_token_factored( + matched_clean, option_c_clean, lambda_delta=1.0, cross_pool=False) + + print_token_effects(result_base, enc_base) + print_chain_effects(result_base, enc_base) + print_delta_analysis(result_base, enc_base, jdata_base, matched_clean) + + sweep_baseline = 
run_lambda_sweep( + matched_clean, option_c_clean, cross_pool=False) + loo_baseline = run_loo_validation( + matched_clean, option_c_clean, lambda_delta=1.0, cross_pool=False) + + _save_baseline(result_base, enc_base, sweep_baseline, loo_baseline) + + # ---- Ablation 2: Cross-pool ---- + print("\n" + "=" * 70) + print("ABLATION 2: Cross-pool lag features (K_OBS_CROSS=7)") + print("=" * 70) + + result_cross, _, jdata_cross, enc_cross = run_token_factored( + matched_clean, option_c_clean, lambda_delta=1.0, cross_pool=True) + + print_token_effects(result_cross, enc_cross) + print_chain_effects(result_cross, enc_cross) + print_delta_analysis(result_cross, enc_cross, jdata_cross, matched_clean) + + sweep_cross = run_lambda_sweep( + matched_clean, option_c_clean, cross_pool=True) + loo_cross = run_loo_validation( + matched_clean, option_c_clean, lambda_delta=1.0, cross_pool=True) + + # ---- Ablation summary ---- + print("\n" + "=" * 70) + print("ABLATION COMPARISON") + print("=" * 70) + ablations = [("Cross-pool (k=7)", loo_cross)] + if loo_baseline is not None: + ablations.insert(0, ("Baseline (k=4)", loo_baseline)) + for label, loo in ablations: + if loo: + r2s = [r["r2_loo"] for r in loo] + r2s_c = [r["r2_option_c"] for r in loo] + wins = sum(1 for r in loo if r["r2_loo"] > r["r2_option_c"]) + print(f" {label}: median R²_LOO={np.median(r2s):.4f}, " + f"median R²_C={np.median(r2s_c):.4f}, " + f"wins={wins}/{len(loo)}") + + # Plots + print("\nGenerating plots...") + os.makedirs(OUTPUT_DIR, exist_ok=True) + + if sweep_baseline is not None: + plot_lambda_sweep(sweep_baseline, OUTPUT_DIR, suffix="_baseline") + plot_lambda_sweep(sweep_cross, OUTPUT_DIR, suffix="_crosspool") + if result_base is not None: + plot_token_effects(result_base, enc_base, OUTPUT_DIR) + if loo_baseline is not None: + plot_loo_scatter(loo_baseline, OUTPUT_DIR, suffix="_baseline") + plot_loo_scatter(loo_cross, OUTPUT_DIR, suffix="_crosspool") + + # JSON export + export = { + "phase0_diagnostic": diag, + 
"cross_pool": _export_ablation_result(result_cross, enc_cross), + "lambda_sweep_crosspool": sweep_cross, + "loo_crosspool": loo_cross, + } + if result_base is not None: + export["baseline"] = _export_ablation_result(result_base, enc_base) + export["lambda_sweep_baseline"] = sweep_baseline + export["loo_baseline"] = loo_baseline + json_path = os.path.join(OUTPUT_DIR, "token_factored_v2_results.json") + with open(json_path, "w") as f: + json.dump(export, f, indent=2, default=str) + print(f" Saved: {json_path}") + + print(f"\n{'='*70}") + print(f"Done. Output in: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/tests/calibration/test_calibration_model.py b/tests/calibration/test_calibration_model.py new file mode 100644 index 0000000..95812f6 --- /dev/null +++ b/tests/calibration/test_calibration_model.py @@ -0,0 +1,655 @@ +"""Tests for quantammsim.calibration.calibration_model — composable CalibrationModel.""" + +import os +import tempfile + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, N_DAYS, POOL_PREFIXES + +from quantammsim.calibration.calibration_model import CalibrationModel +from quantammsim.calibration.heads import ( + FixedHead, + LinearHead, + MLPHead, + MLPNoiseHead, + PerPoolHead, + PerPoolNoiseHead, + SharedLinearNoiseHead, +) +from quantammsim.calibration.loss import CHAIN_GAS_USD + + +# ── Fixtures ──────────────────────────────────────────────────────────────── + + +@pytest.fixture +def matched_data(synthetic_daily_grid, synthetic_panel, tmp_path): + """Build matched data dict from synthetic fixtures.""" + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + return match_grids_to_panel(str(grid_dir), synthetic_panel) + + +@pytest.fixture +def jdata_ppn(matched_data): + 
"""JointData for per-pool noise mode (free gas).""" + from quantammsim.calibration.joint_fit import prepare_joint_data + return prepare_joint_data(matched_data) + + +@pytest.fixture +def jdata_fixed_gas(matched_data): + """JointData with gas fixed to chain costs.""" + from quantammsim.calibration.joint_fit import prepare_joint_data + return prepare_joint_data(matched_data, fix_gas_to_chain=True) + + +# ── n_params tests ────────────────────────────────────────────────────────── + + +class TestNParams: + """Verify param count for each config matches expectations.""" + + def test_option_c_free_gas(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + PerPoolHead("cad"), PerPoolHead("gas"), PerPoolNoiseHead() + ) + # n_pools + n_pools + n_pools*K_OBS + expected = n_pools + n_pools + n_pools * K_OBS + assert model.n_params(n_pools, k_attr) == expected + + def test_option_c_fixed_gas(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + PerPoolHead("cad"), + FixedHead("gas", np.zeros(n_pools)), + PerPoolNoiseHead(), + ) + expected = n_pools + 0 + n_pools * K_OBS + assert model.n_params(n_pools, k_attr) == expected + + def test_option_a_ppn_free(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + expected = (1 + k_attr) + (1 + k_attr) + n_pools * K_OBS + assert model.n_params(n_pools, k_attr) == expected + + def test_option_a_shared_fixed(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), + FixedHead("gas", np.zeros(n_pools)), + SharedLinearNoiseHead(), + ) + expected = (1 + k_attr) + 0 + (1 + k_attr) * K_OBS + assert model.n_params(n_pools, k_attr) == expected + + +# ── pack_init tests 
───────────────────────────────────────────────────────── + + +class TestPackInit: + def test_size_matches_n_params(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + init = model.pack_init(jdata_ppn) + assert init.shape == (model.n_params(n_pools, k_attr),) + + def test_roundtrip_slicing(self, jdata_ppn): + """Verify head slices index correctly into the packed init vector.""" + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + init = model.pack_init(jdata_ppn) + (cs, ce), (gs, ge), (ns, ne) = model._head_slices(n_pools, k_attr) + + assert ce - cs == model.cadence_head.n_params(n_pools, k_attr) + assert ge - gs == model.gas_head.n_params(n_pools, k_attr) + assert ne - ns == model.noise_head.n_params(n_pools, k_attr) + assert ne == len(init) + + def test_init_values_finite(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + init = model.pack_init(jdata_ppn) + assert np.all(np.isfinite(init)) + + +# ── Pool loss function tests ─────────────────────────────────────────────── + + +class TestPoolLossEquivalence: + """Sanity-check CalibrationModel pool loss alongside the existing implementation.""" + + def test_option_a_ppn_loss_matches_joint_fit(self, jdata_ppn): + """Old and new pool losses are finite and non-negative at init.""" + from quantammsim.calibration.joint_fit import ( + _make_pool_loss_fn, + make_initial_joint_params, + ) + + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + + # Old code: per_pool_noise mode with free gas + old_config = { + "k_attr": k_attr, "n_pools": n_pools, + "mode": "per_pool_noise", "fix_gas": False, + } + old_init = make_initial_joint_params(jdata_ppn, mode="per_pool_noise") + + # New code: LinearHead cad/gas + 
PerPoolNoiseHead + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + ) + new_init = model.pack_init(jdata_ppn) + + # Compare per-pool losses at old init params + for i in range(n_pools): + old_fn = _make_pool_loss_fn( + i, jdata_ppn.pool_data[i], jdata_ppn.x_attr[i], old_config + ) + new_fn = model.make_pool_loss_fn( + i, jdata_ppn.pool_data[i], jdata_ppn.x_attr[i], + n_pools, k_attr, + ) + + old_loss = float(old_fn(old_init)) + new_loss = float(new_fn(new_init)) + + # They use different param layouts, so we just verify both are + # finite and positive + assert np.isfinite(old_loss) and old_loss >= 0 + assert np.isfinite(new_loss) and new_loss >= 0 + + def test_pool_loss_differentiable(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + init = jnp.array(model.pack_init(jdata_ppn)) + + fn = model.make_pool_loss_fn( + 0, jdata_ppn.pool_data[0], jdata_ppn.x_attr[0], + n_pools, k_attr, + ) + grad = jax.grad(fn)(init) + assert grad.shape == init.shape + assert jnp.all(jnp.isfinite(grad)) + + +# ── Joint loss function tests ────────────────────────────────────────────── + + +class TestJointLoss: + def test_joint_loss_scalar(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + loss_fn = model.make_joint_loss_fn(jdata_ppn) + init = jnp.array(model.pack_init(jdata_ppn)) + loss = loss_fn(init) + assert loss.shape == () + assert float(loss) >= 0 + + def test_joint_loss_differentiable(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + loss_fn = model.make_joint_loss_fn(jdata_ppn) + init = jnp.array(model.pack_init(jdata_ppn)) + grad = jax.grad(loss_fn)(init) + assert grad.shape == init.shape + assert jnp.all(jnp.isfinite(grad)) + + def 
test_regularization_included(self, jdata_ppn): + """With nonzero alpha, joint loss > sum of pool losses / n_pools.""" + model = CalibrationModel( + LinearHead("cad", alpha=10.0), + LinearHead("gas", alpha=10.0), + PerPoolNoiseHead(), + ) + loss_fn = model.make_joint_loss_fn(jdata_ppn) + init = jnp.array(model.pack_init(jdata_ppn)) + init = init.at[0].set(1.0) # nonzero W to trigger reg + + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + + # Compute data loss only (sum of pool losses / n_pools) + data_loss = 0.0 + for i in range(n_pools): + fn = model.make_pool_loss_fn( + i, jdata_ppn.pool_data[i], jdata_ppn.x_attr[i], + n_pools, k_attr, + ) + data_loss += float(fn(init)) + data_loss /= n_pools + + joint_loss = float(loss_fn(init)) + # Joint loss should be >= data loss due to regularization + assert joint_loss >= data_loss - 1e-10 + + +# ── Fit tests ────────────────────────────────────────────────────────────── + + +class TestFit: + def test_fit_converges_option_a_ppn(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + + def test_fit_converges_option_c_free(self, jdata_ppn): + model = CalibrationModel( + PerPoolHead("cad", default=np.log(12.0)), + PerPoolHead("gas", default=np.log(1.0)), + PerPoolNoiseHead(), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + + def test_fit_fixed_gas(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + gas_values = np.array([np.log(1.0)] * n_pools) + model = CalibrationModel( + PerPoolHead("cad", default=np.log(12.0)), + FixedHead("gas", gas_values), + PerPoolNoiseHead(), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + assert "gas_fixed" in result + + def test_fit_returns_required_keys(self, jdata_ppn): + model = CalibrationModel( 
+ LinearHead("cad", alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + ) + result = model.fit(jdata_ppn, maxiter=20) + for key in ["loss", "init_loss", "converged", "params_flat", + "pool_ids", "attr_names", "k_attr", "n_pools"]: + assert key in result, f"Missing key: {key}" + + def test_fit_shared_noise(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + LinearHead("gas", alpha=0.01), + SharedLinearNoiseHead(alpha=0.01), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + assert "bias_noise" in result + assert "W_noise" in result + + +# ── Predict new pool tests ───────────────────────────────────────────────── + + +class TestPredictNewPool: + def test_predict_new_pool_linear(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), SharedLinearNoiseHead() + ) + result = model.fit(jdata_ppn, maxiter=20) + x_attr = np.zeros(result["k_attr"]) + pred = model.predict_new_pool(result, x_attr) + assert pred["cadence_minutes"] > 0 + assert pred["gas_usd"] > 0 + assert "noise_coeffs" in pred + assert len(pred["noise_coeffs"]) == K_OBS + + def test_predict_new_pool_per_pool_noise_omits_noise(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + result = model.fit(jdata_ppn, maxiter=20) + x_attr = np.zeros(result["k_attr"]) + pred = model.predict_new_pool(result, x_attr) + assert "noise_coeffs" not in pred # can't generalize + assert pred["cadence_minutes"] > 0 + + def test_predict_at_zero_equals_bias(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), SharedLinearNoiseHead() + ) + result = model.fit(jdata_ppn, maxiter=50) + x_attr = np.zeros(result["k_attr"]) + pred = model.predict_new_pool(result, x_attr) + np.testing.assert_allclose( + pred["log_cadence"], result["bias_cad"], rtol=1e-10 + ) + np.testing.assert_allclose( + pred["log_gas"], result["bias_gas"], 
rtol=1e-10 + ) + + +# ── Huber loss tests ─────────────────────────────────────────────────────── + + +class TestHuberLoss: + def test_huber_loss_runs(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead(), + loss_type="huber", huber_delta=1.5, + ) + result = model.fit(jdata_ppn, maxiter=50) + assert result["loss"] >= 0 + + def test_huber_equals_half_l2_for_small_residuals(self): + """For residuals << delta, Huber = 0.5 * L2 (standard definition).""" + model_l2 = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead(), + loss_type="l2", + ) + model_huber = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead(), + loss_type="huber", huber_delta=100.0, # very large delta + ) + residuals = jnp.array([0.01, -0.02, 0.005]) + l2_loss = model_l2._compute_loss(residuals) + huber_loss = model_huber._compute_loss(residuals) + # Standard Huber: 0.5 * r^2 for |r| < delta + np.testing.assert_allclose( + float(huber_loss), 0.5 * float(l2_loss), rtol=1e-6 + ) + + +# ── Config equivalence tests ────────────────────────────────────────────── + + +class TestConfigEquivalence: + """Verify that CalibrationModel configs match existing option configs.""" + + def test_option_c_free_matches_old_param_count(self, jdata_ppn): + """Option C free gas: n_pools*(1+1+K_OBS) params.""" + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + PerPoolHead("cad"), PerPoolHead("gas"), PerPoolNoiseHead() + ) + expected = n_pools * (1 + 1 + K_OBS) + assert model.n_params(n_pools, k_attr) == expected + + def test_option_a_ppn_free_matches_old_param_count(self, jdata_ppn): + """Option A ppn free: 2 + 2*k_attr + n_pools*K_OBS params.""" + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), PerPoolNoiseHead() + ) + expected = 2 + 2 * k_attr + n_pools * K_OBS + assert 
model.n_params(n_pools, k_attr) == expected + + def test_option_a_shared_free_matches_old_param_count(self, jdata_ppn): + """Option A shared free: 2 + 2*k_attr + (1+k_attr)*K_OBS.""" + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + LinearHead("cad"), LinearHead("gas"), SharedLinearNoiseHead() + ) + expected = 2 + 2 * k_attr + (1 + k_attr) * K_OBS + assert model.n_params(n_pools, k_attr) == expected + + +# ── MLP integration tests ───────────────────────────────────────────────── + + +class TestMLPIntegration: + """Test CalibrationModel with MLPHead for cadence.""" + + def test_mlp_cadence_n_params(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + ) + mlp_params = k_attr * 8 + 8 + 8 + 1 + linear_params = 1 + k_attr + noise_params = n_pools * K_OBS + assert model.n_params(n_pools, k_attr) == mlp_params + linear_params + noise_params + + def test_mlp_cadence_fit_converges(self, jdata_ppn): + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + assert np.isfinite(result["loss"]) + + def test_mlp_cadence_and_gas_fit(self, jdata_ppn): + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + MLPHead("gas", hidden=8, alpha=0.01), + PerPoolNoiseHead(), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + + def test_mlp_predict_new_pool(self, jdata_ppn): + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + MLPHead("gas", hidden=8, alpha=0.01), + SharedLinearNoiseHead(alpha=0.01), + ) + result = model.fit(jdata_ppn, maxiter=50) + x_attr = np.zeros(result["k_attr"]) + pred = model.predict_new_pool(result, x_attr) + assert 
pred["cadence_minutes"] > 0 + assert pred["gas_usd"] > 0 + assert "noise_coeffs" in pred + + def test_mlp_loss_differentiable(self, jdata_ppn): + import jax + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + ) + loss_fn = model.make_joint_loss_fn(jdata_ppn) + init = jnp.array(model.pack_init(jdata_ppn)) + grad = jax.grad(loss_fn)(init) + assert jnp.all(jnp.isfinite(grad)) + assert float(jnp.sum(jnp.abs(grad))) > 0 + + def test_mlp_with_huber_loss(self, jdata_ppn): + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + LinearHead("gas", alpha=0.01), + PerPoolNoiseHead(), + loss_type="huber", huber_delta=1.5, + ) + result = model.fit(jdata_ppn, maxiter=50) + assert result["loss"] >= 0 + assert np.isfinite(result["loss"]) + + +# ── MLP noise integration tests ─────────────────────────────────────────── + + +class TestMLPNoiseIntegration: + """Test CalibrationModel with MLPNoiseHead — the key use case.""" + + def test_mlp_noise_fit_converges(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + LinearHead("gas", alpha=0.01), + MLPNoiseHead(hidden=8, alpha=0.01), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + assert np.isfinite(result["loss"]) + + def test_mlp_noise_predict_new_pool(self, jdata_ppn): + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + LinearHead("gas", alpha=0.01), + MLPNoiseHead(hidden=8, alpha=0.01), + ) + result = model.fit(jdata_ppn, maxiter=50) + x_attr = np.zeros(result["k_attr"]) + pred = model.predict_new_pool(result, x_attr) + assert pred["cadence_minutes"] > 0 + assert pred["gas_usd"] > 0 + assert "noise_coeffs" in pred + assert len(pred["noise_coeffs"]) == K_OBS + + def test_mlp_noise_loss_differentiable(self, jdata_ppn): + import jax + model = CalibrationModel( + LinearHead("cad", alpha=0.01), 
+ LinearHead("gas", alpha=0.01), + MLPNoiseHead(hidden=8, alpha=0.01), + ) + loss_fn = model.make_joint_loss_fn(jdata_ppn) + init = jnp.array(model.pack_init(jdata_ppn)) + grad = jax.grad(loss_fn)(init) + assert jnp.all(jnp.isfinite(grad)) + assert float(jnp.sum(jnp.abs(grad))) > 0 + + def test_full_mlp_model(self, jdata_ppn): + """MLP for all three heads — most expressive config.""" + model = CalibrationModel( + MLPHead("cad", hidden=8, alpha=0.01), + MLPHead("gas", hidden=8, alpha=0.01), + MLPNoiseHead(hidden=8, alpha=0.01), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + x_attr = np.zeros(result["k_attr"]) + pred = model.predict_new_pool(result, x_attr) + assert pred["cadence_minutes"] > 0 + assert "noise_coeffs" in pred + + def test_mlp_noise_with_fixed_gas(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + gas_values = np.array([np.log(1.0)] * n_pools) + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + FixedHead("gas", gas_values), + MLPNoiseHead(hidden=8, alpha=0.01), + ) + result = model.fit(jdata_ppn, maxiter=100) + assert result["loss"] <= result["init_loss"] + + def test_mlp_noise_param_count(self, jdata_ppn): + n_pools = len(jdata_ppn.pool_data) + k_attr = jdata_ppn.x_attr.shape[1] + h = 8 + model = CalibrationModel( + LinearHead("cad"), + LinearHead("gas"), + MLPNoiseHead(hidden=h), + ) + # Linear cad: 1+k, Linear gas: 1+k, + # MLP noise: k*h + h + h*K_OBS + K_OBS + expected = (1 + k_attr) * 2 + k_attr * h + h + h * K_OBS + K_OBS + assert model.n_params(n_pools, k_attr) == expected + + +# ── Reduced k_obs=4 integration tests ──────────────────────────────────── + +K_OBS_REDUCED = 4 + + +@pytest.fixture +def jdata_reduced(matched_data): + """JointData with reduced x_obs (4 columns).""" + from quantammsim.calibration.joint_fit import prepare_joint_data + return prepare_joint_data( + matched_data, drop_chain_dummies=True, + fix_gas_to_chain=True, reduced_x_obs=True, + ) + + +class 
TestReducedKObsIntegration: + """CalibrationModel with k_obs=4 noise heads on reduced x_obs data.""" + + def test_reduced_n_params(self, jdata_reduced): + n_pools = len(jdata_reduced.pool_data) + k_attr = jdata_reduced.x_attr.shape[1] + h = 8 + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + FixedHead("gas", np.zeros(n_pools)), + MLPNoiseHead(hidden=h, alpha=0.01, k_obs=K_OBS_REDUCED), + ) + # Linear cad: 1+k, Fixed gas: 0, + # MLP noise: k*h + h + h*4 + 4 + expected = (1 + k_attr) + 0 + k_attr * h + h + h * 4 + 4 + assert model.n_params(n_pools, k_attr) == expected + + def test_reduced_loss_runs(self, jdata_reduced): + n_pools = len(jdata_reduced.pool_data) + gas_values = np.zeros(n_pools) + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + FixedHead("gas", gas_values), + MLPNoiseHead(hidden=8, alpha=0.01, k_obs=K_OBS_REDUCED), + ) + loss_fn = model.make_joint_loss_fn(jdata_reduced) + init = jnp.array(model.pack_init(jdata_reduced)) + loss = float(loss_fn(init)) + assert np.isfinite(loss) and loss >= 0 + + def test_reduced_grad_finite(self, jdata_reduced): + n_pools = len(jdata_reduced.pool_data) + gas_values = np.zeros(n_pools) + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + FixedHead("gas", gas_values), + MLPNoiseHead(hidden=8, alpha=0.01, k_obs=K_OBS_REDUCED), + ) + loss_fn = model.make_joint_loss_fn(jdata_reduced) + init = jnp.array(model.pack_init(jdata_reduced)) + grad = jax.grad(loss_fn)(init) + assert jnp.all(jnp.isfinite(grad)) + + def test_reduced_fit_converges(self, jdata_reduced): + n_pools = len(jdata_reduced.pool_data) + gas_values = np.zeros(n_pools) + model = CalibrationModel( + LinearHead("cad", alpha=0.01), + FixedHead("gas", gas_values), + MLPNoiseHead(hidden=8, alpha=0.01, k_obs=K_OBS_REDUCED), + ) + result = model.fit(jdata_reduced, maxiter=100) + assert result["loss"] <= result["init_loss"] + assert np.isfinite(result["loss"]) diff --git a/tests/calibration/test_heads.py b/tests/calibration/test_heads.py 
new file mode 100644 index 0000000..d5bc644 --- /dev/null +++ b/tests/calibration/test_heads.py @@ -0,0 +1,1131 @@ +"""Tests for quantammsim.calibration.heads — pluggable Head components.""" + +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, POOL_PREFIXES + +from quantammsim.calibration.heads import ( + FixedHead, + Head, + LinearHead, + MLPHead, + MLPNoiseHead, + PerPoolHead, + PerPoolNoiseHead, + SharedLinearNoiseHead, +) +from quantammsim.calibration.pool_data import K_OBS_REDUCED + + +# ── Helpers ───────────────────────────────────────────────────────────────── + +N_POOLS = 2 +K_ATTR = 5 + + +def _make_fake_jdata(): + """Minimal JointData-like object for init() testing.""" + from quantammsim.calibration.joint_fit import JointData + + pool_data = [] + for _ in range(N_POOLS): + n_obs = 14 + x_obs = np.random.randn(n_obs, K_OBS) + x_obs[:, 0] = 1.0 # intercept column + y_obs = np.random.randn(n_obs) * 0.5 + 9.0 + pool_data.append({ + "x_obs": jnp.array(x_obs), + "y_obs": jnp.array(y_obs), + "day_indices": jnp.arange(n_obs) % 10, + }) + + x_attr = jnp.array(np.random.randn(N_POOLS, K_ATTR)) + return JointData( + pool_data=pool_data, + x_attr=x_attr, + pool_ids=POOL_PREFIXES[:N_POOLS], + attr_names=[f"attr_{i}" for i in range(K_ATTR)], + ) + + +# ── Protocol compliance ──────────────────────────────────────────────────── + + +class TestProtocol: + def test_per_pool_head_is_head(self): + assert isinstance(PerPoolHead("cad"), Head) + + def test_fixed_head_is_head(self): + assert isinstance(FixedHead("gas", np.array([1.0, 2.0])), Head) + + def test_linear_head_is_head(self): + assert isinstance(LinearHead("cad"), Head) + + def test_per_pool_noise_head_is_head(self): + assert isinstance(PerPoolNoiseHead(), Head) + + def test_shared_linear_noise_head_is_head(self): + assert isinstance(SharedLinearNoiseHead(), Head) + + def test_mlp_head_is_head(self): + assert isinstance(MLPHead("cad"), Head) + + def 
test_mlp_noise_head_is_head(self): + assert isinstance(MLPNoiseHead(), Head) + + +# ── PerPoolHead ───────────────────────────────────────────────────────────── + + +class TestPerPoolHead: + def test_n_params(self): + h = PerPoolHead("cad") + assert h.n_params(3, 5) == 3 + assert h.n_params(10, 7) == 10 + + def test_predict_returns_indexed_value(self): + h = PerPoolHead("cad") + params = jnp.array([1.0, 2.0, 3.0]) + x_attr_i = jnp.zeros(5) + assert float(h.predict(params, 0, x_attr_i)) == 1.0 + assert float(h.predict(params, 1, x_attr_i)) == 2.0 + assert float(h.predict(params, 2, x_attr_i)) == 3.0 + + def test_regularization_is_zero(self): + h = PerPoolHead("cad") + params = jnp.array([1.0, 2.0, 3.0]) + assert float(h.regularization(params)) == 0.0 + + def test_init_default(self): + h = PerPoolHead("cad", default=np.log(12.0)) + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == (N_POOLS,) + np.testing.assert_allclose(init, np.log(12.0)) + + def test_init_warm_start(self): + h = PerPoolHead("log_cadence") + jdata = _make_fake_jdata() + warm = { + POOL_PREFIXES[0]: {"log_cadence": 2.5}, + POOL_PREFIXES[1]: {"log_cadence": 3.0}, + } + init = h.init(jdata, warm_start=warm) + np.testing.assert_allclose(init, [2.5, 3.0]) + + def test_predict_new_raises(self): + h = PerPoolHead("cad") + with pytest.raises(ValueError, match="cannot predict"): + h.predict_new(np.array([1.0]), np.zeros(5)) + + def test_make_bounds(self): + h = PerPoolHead("cad") + bounds = h.make_bounds(3, 5) + assert len(bounds) == 3 + assert all(b == (None, None) for b in bounds) + + +# ── FixedHead ─────────────────────────────────────────────────────────────── + + +class TestFixedHead: + def test_n_params_is_zero(self): + h = FixedHead("gas", np.array([0.0, -4.6])) + assert h.n_params(2, 5) == 0 + + def test_predict_returns_fixed_value(self): + vals = np.array([0.0, -4.6, 1.5]) + h = FixedHead("gas", vals) + empty_slice = jnp.array([]) + x_attr_i = jnp.zeros(5) + assert 
float(h.predict(empty_slice, 0, x_attr_i)) == 0.0 + np.testing.assert_allclose( + float(h.predict(empty_slice, 1, x_attr_i)), -4.6 + ) + assert float(h.predict(empty_slice, 2, x_attr_i)) == 1.5 + + def test_regularization_is_zero(self): + h = FixedHead("gas", np.array([1.0])) + assert float(h.regularization(jnp.array([]))) == 0.0 + + def test_init_returns_empty(self): + h = FixedHead("gas", np.array([1.0, 2.0])) + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == (0,) + + def test_predict_new_raises(self): + h = FixedHead("gas", np.array([1.0])) + with pytest.raises(ValueError, match="cannot predict"): + h.predict_new(np.array([]), np.zeros(5)) + + def test_make_bounds_empty(self): + h = FixedHead("gas", np.array([1.0])) + assert h.make_bounds(1, 5) == [] + + +# ── LinearHead ────────────────────────────────────────────────────────────── + + +class TestLinearHead: + def test_n_params(self): + h = LinearHead("cad") + assert h.n_params(3, 5) == 6 # 1 + 5 + assert h.n_params(10, 7) == 8 # 1 + 7 + + def test_predict_bias_plus_dot(self): + h = LinearHead("cad") + # params_slice = [bias, W0, W1, W2] + params = jnp.array([2.0, 0.5, -1.0, 0.3]) + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + # expected = 2.0 + (0.5*1.0 + (-1.0)*2.0 + 0.3*3.0) + # = 2.0 + 0.5 - 2.0 + 0.9 = 1.4 + result = float(h.predict(params, 0, x_attr_i)) + np.testing.assert_allclose(result, 1.4) + + def test_predict_ignores_pool_idx(self): + h = LinearHead("cad") + params = jnp.array([2.0, 0.5, -1.0]) + x = jnp.array([1.0, 2.0]) + v0 = float(h.predict(params, 0, x)) + v1 = float(h.predict(params, 5, x)) + assert v0 == v1 + + def test_regularization_on_W_not_bias(self): + h = LinearHead("cad", alpha=1.0) + params = jnp.array([100.0, 3.0, 4.0]) + # reg = 1.0 * (3^2 + 4^2) = 25.0 (bias ignored) + np.testing.assert_allclose(float(h.regularization(params)), 25.0) + + def test_regularization_alpha_scaling(self): + h = LinearHead("cad", alpha=0.5) + params = jnp.array([0.0, 2.0, 0.0]) + # reg 
= 0.5 * 4.0 = 2.0 + np.testing.assert_allclose(float(h.regularization(params)), 2.0) + + def test_init_default_cadence(self): + h = LinearHead("cad") + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == (1 + K_ATTR,) + np.testing.assert_allclose(init[0], np.log(12.0)) + np.testing.assert_allclose(init[1:], 0.0) + + def test_init_default_gas(self): + h = LinearHead("gas") + jdata = _make_fake_jdata() + init = h.init(jdata) + np.testing.assert_allclose(init[0], np.log(1.0)) + + def test_init_warm_start(self): + h = LinearHead("log_cadence") + jdata = _make_fake_jdata() + warm = { + POOL_PREFIXES[0]: {"log_cadence": 2.0}, + POOL_PREFIXES[1]: {"log_cadence": 3.0}, + } + init = h.init(jdata, warm_start=warm) + assert init.shape == (1 + K_ATTR,) + # Should have fitted OLS to recover bias/W + + def test_predict_new(self): + h = LinearHead("cad") + params = np.array([2.0, 0.5, -1.0]) + x_attr = np.array([1.0, 2.0]) + result = h.predict_new(params, x_attr) + np.testing.assert_allclose(result, 2.0 + 0.5 - 2.0) + + def test_unpack_result(self): + h = LinearHead("cad") + params = np.array([2.0, 0.5, -1.0]) + result = h.unpack_result(params, 3, 2) + assert "bias_cad" in result + assert "W_cad" in result + np.testing.assert_allclose(result["bias_cad"], 2.0) + np.testing.assert_allclose(result["W_cad"], [0.5, -1.0]) + + def test_make_bounds(self): + h = LinearHead("cad") + bounds = h.make_bounds(3, 5) + assert len(bounds) == 6 # 1 + 5 + + +# ── PerPoolNoiseHead ──────────────────────────────────────────────────────── + + +class TestPerPoolNoiseHead: + def test_n_params(self): + h = PerPoolNoiseHead() + assert h.n_params(3, 5) == 3 * K_OBS + assert h.n_params(2, 7) == 2 * K_OBS + + def test_predict_correct_slice(self): + h = PerPoolNoiseHead() + n_pools = 3 + params = jnp.arange(n_pools * K_OBS, dtype=float) + x_attr_i = jnp.zeros(5) + + for i in range(n_pools): + result = h.predict(params, i, x_attr_i) + expected = params[i * K_OBS:(i + 1) * K_OBS] + 
np.testing.assert_allclose(result, expected) + + def test_regularization_zero_by_default(self): + h = PerPoolNoiseHead() + params = jnp.ones(16) + assert float(h.regularization(params)) == 0.0 + + def test_regularization_with_alpha(self): + h = PerPoolNoiseHead(alpha=1.0) + params = jnp.array([3.0, 4.0]) + np.testing.assert_allclose(float(h.regularization(params)), 25.0) + + def test_init_from_ols(self): + np.random.seed(42) + h = PerPoolNoiseHead() + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == (N_POOLS * K_OBS,) + assert np.all(np.isfinite(init)) + + def test_init_warm_start(self): + h = PerPoolNoiseHead() + jdata = _make_fake_jdata() + warm = { + POOL_PREFIXES[0]: {"noise_coeffs": np.ones(K_OBS) * 5.0}, + POOL_PREFIXES[1]: {"noise_coeffs": np.ones(K_OBS) * 7.0}, + } + init = h.init(jdata, warm_start=warm) + assert init.shape == (N_POOLS * K_OBS,) + np.testing.assert_allclose(init[:K_OBS], 5.0) + np.testing.assert_allclose(init[K_OBS:], 7.0) + + def test_predict_new_raises(self): + h = PerPoolNoiseHead() + with pytest.raises(ValueError, match="cannot predict"): + h.predict_new(np.zeros(K_OBS * 2), np.zeros(5)) + + def test_unpack_result(self): + h = PerPoolNoiseHead() + params = np.arange(N_POOLS * K_OBS, dtype=float) + result = h.unpack_result(params, N_POOLS, K_ATTR) + assert result["noise_coeffs"].shape == (N_POOLS, K_OBS) + + +# ── SharedLinearNoiseHead ─────────────────────────────────────────────────── + + +class TestSharedLinearNoiseHead: + def test_n_params(self): + h = SharedLinearNoiseHead() + assert h.n_params(3, 5) == (1 + 5) * K_OBS + assert h.n_params(10, 7) == (1 + 7) * K_OBS + + def test_predict_bias_plus_dot(self): + k_attr = 3 + h = SharedLinearNoiseHead() + W_full = np.zeros((1 + k_attr, K_OBS)) + W_full[0, :] = 1.0 # bias_noise = [1, 1, ..., 1] + W_full[1, 0] = 2.0 # first feature maps to first noise coeff + params = jnp.array(W_full.ravel()) + x_attr_i = jnp.array([1.0, 0.0, 0.0]) + result = h.predict(params, 0, 
x_attr_i) + assert result.shape == (K_OBS,) + np.testing.assert_allclose(float(result[0]), 3.0) # 1 + 2*1 + np.testing.assert_allclose(float(result[1]), 1.0) # 1 + 0 + + def test_predict_ignores_pool_idx(self): + k_attr = 2 + h = SharedLinearNoiseHead() + params = jnp.ones((1 + k_attr) * K_OBS) + x = jnp.array([1.0, 2.0]) + r0 = h.predict(params, 0, x) + r5 = h.predict(params, 5, x) + np.testing.assert_allclose(r0, r5) + + def test_regularization_on_W_not_bias(self): + k_attr = 2 + h = SharedLinearNoiseHead(alpha=1.0) + W_full = np.zeros((1 + k_attr, K_OBS)) + W_full[0, :] = 100.0 # bias — not regularized + W_full[1, 0] = 3.0 + W_full[2, 0] = 4.0 + params = jnp.array(W_full.ravel()) + # reg = 1.0 * (9 + 16) = 25.0 + np.testing.assert_allclose(float(h.regularization(params)), 25.0) + + def test_init_default(self): + np.random.seed(42) + h = SharedLinearNoiseHead() + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == ((1 + K_ATTR) * K_OBS,) + assert np.all(np.isfinite(init)) + + def test_predict_new(self): + k_attr = 2 + h = SharedLinearNoiseHead() + W_full = np.zeros((1 + k_attr, K_OBS)) + W_full[0, :] = 5.0 + W_full[1, 0] = 1.0 + params = W_full.ravel() + x_attr = np.array([2.0, 0.0]) + result = h.predict_new(params, x_attr) + assert result.shape == (K_OBS,) + np.testing.assert_allclose(result[0], 7.0) + np.testing.assert_allclose(result[1], 5.0) + + def test_unpack_result(self): + h = SharedLinearNoiseHead() + k_attr = 3 + W_full = np.arange((1 + k_attr) * K_OBS, dtype=float) + result = h.unpack_result(W_full, 2, k_attr) + assert "bias_noise" in result + assert "W_noise" in result + assert result["bias_noise"].shape == (K_OBS,) + assert result["W_noise"].shape == (k_attr, K_OBS) + + +# ── MLPHead ───────────────────────────────────────────────────────────────── + + +class TestMLPHead: + def test_n_params(self): + h = MLPHead("cad", hidden=16) + # k_attr=5: 5*16 + 16 + 16 + 1 = 113 + assert h.n_params(3, 5) == 113 + # k_attr=7: 7*16 + 16 + 16 
+ 1 = 145 + assert h.n_params(3, 7) == 145 + + def test_n_params_custom_hidden(self): + h = MLPHead("cad", hidden=8) + # k_attr=5: 5*8 + 8 + 8 + 1 = 57 + assert h.n_params(3, 5) == 57 + + def test_predict_with_zero_W2_equals_b2(self): + """With W2=0, output should be b2 regardless of input.""" + k_attr = 3 + h = MLPHead("cad", hidden=4) + n_p = h.n_params(1, k_attr) + params = np.zeros(n_p) + params[-1] = 2.5 # b2 + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + result = float(h.predict(jnp.array(params), 0, x_attr_i)) + np.testing.assert_allclose(result, 2.5) + + def test_predict_nonlinear(self): + """MLP should produce different outputs for different inputs.""" + k_attr = 3 + h = MLPHead("cad", hidden=4, seed=42) + jdata = _make_fake_jdata() + # Override x_attr to have k_attr=3 + from quantammsim.calibration.joint_fit import JointData + jdata = JointData( + pool_data=jdata.pool_data, + x_attr=jnp.array(np.random.randn(N_POOLS, k_attr)), + pool_ids=jdata.pool_ids, + attr_names=[f"a{i}" for i in range(k_attr)], + ) + init = jnp.array(h.init(jdata)) + # Set W1 to nonzero so ReLU activations vary + np.random.seed(42) + W1 = np.random.randn(k_attr * 4) * 0.5 + init = init.at[:k_attr * 4].set(jnp.array(W1)) + # Set W2 to nonzero so output varies + init = init.at[k_attr * 4 + 4:k_attr * 4 + 8].set(jnp.ones(4) * 0.1) + + x1 = jnp.array([1.0, 0.0, 0.0]) + x2 = jnp.array([0.0, 1.0, 0.0]) + v1 = float(h.predict(init, 0, x1)) + v2 = float(h.predict(init, 0, x2)) + assert v1 != v2, "MLP should produce different outputs for different inputs" + + def test_predict_ignores_pool_idx(self): + """MLP output depends only on x_attr, not pool_idx.""" + k_attr = 3 + h = MLPHead("cad", hidden=4) + params = jnp.ones(h.n_params(5, k_attr)) * 0.1 + x = jnp.array([1.0, 2.0, 3.0]) + v0 = float(h.predict(params, 0, x)) + v3 = float(h.predict(params, 3, x)) + assert v0 == v3 + + def test_regularization_on_weights_not_biases(self): + k_attr = 2 + h_alpha1 = MLPHead("cad", hidden=2, alpha=1.0) + # 
Layout: W1(2*2=4), b1(2), W2(2), b2(1) = 9 params + params = np.zeros(9) + params[0] = 3.0 # W1[0,0] + params[1] = 4.0 # W1[0,1] + # b1 = 0 (indices 4,5) + params[6] = 1.0 # W2[0] + params[7] = 2.0 # W2[1] + params[8] = 999.0 # b2 — should not be regularized + # reg = 1.0 * (9 + 16 + 1 + 4) = 30.0 + result = float(h_alpha1.regularization(jnp.array(params))) + np.testing.assert_allclose(result, 30.0) + + def test_regularization_alpha_scaling(self): + k_attr = 2 + h = MLPHead("cad", hidden=2, alpha=0.5) + params = np.zeros(9) + params[0] = 2.0 # W1 weight + # reg = 0.5 * 4.0 = 2.0 + np.testing.assert_allclose(float(h.regularization(jnp.array(params))), 2.0) + + def test_init_default_cadence(self): + h = MLPHead("cad", hidden=4) + jdata = _make_fake_jdata() + init = h.init(jdata) + n_p = h.n_params(N_POOLS, K_ATTR) + assert init.shape == (n_p,) + assert np.all(np.isfinite(init)) + # b2 should be log(12) + np.testing.assert_allclose(init[-1], np.log(12.0)) + + def test_init_default_gas(self): + h = MLPHead("gas", hidden=4) + jdata = _make_fake_jdata() + init = h.init(jdata) + # b2 should be log(1) = 0 + np.testing.assert_allclose(init[-1], 0.0) + + def test_init_size(self): + """Init should return correct number of parameters.""" + h = MLPHead("cad", hidden=4) + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == (h.n_params(N_POOLS, K_ATTR),) + assert np.all(np.isfinite(init)) + + def test_init_warm_start(self): + h = MLPHead("log_cadence", hidden=4) + jdata = _make_fake_jdata() + warm = { + POOL_PREFIXES[0]: {"log_cadence": 2.0}, + POOL_PREFIXES[1]: {"log_cadence": 3.0}, + } + init = h.init(jdata, warm_start=warm) + # b2 should be mean of warm-start values + np.testing.assert_allclose(init[-1], 2.5) + + def test_predict_new(self): + k_attr = 3 + h = MLPHead("cad", hidden=4) + n_p = h.n_params(1, k_attr) + params = np.zeros(n_p) + params[-1] = 2.5 # b2 + x_attr = np.array([1.0, 2.0, 3.0]) + result = h.predict_new(params, x_attr) + 
np.testing.assert_allclose(result, 2.5) + + def test_predict_new_matches_predict(self): + """predict_new should give same result as predict for same input.""" + k_attr = 3 + h = MLPHead("cad", hidden=4, seed=42) + n_p = h.n_params(1, k_attr) + np.random.seed(99) + params = np.random.randn(n_p) * 0.1 + x_attr = np.array([0.5, -1.0, 2.0]) + + jax_result = float(h.predict(jnp.array(params), 0, jnp.array(x_attr))) + np_result = h.predict_new(params, x_attr) + np.testing.assert_allclose(jax_result, np_result, rtol=1e-6) + + def test_unpack_result(self): + k_attr = 3 + h = MLPHead("cad", hidden=4) + n_p = h.n_params(1, k_attr) + params = np.arange(n_p, dtype=float) + result = h.unpack_result(params, 2, k_attr) + assert "mlp_cad_W1" in result + assert "mlp_cad_b1" in result + assert "mlp_cad_W2" in result + assert "mlp_cad_b2" in result + assert result["mlp_cad_W1"].shape == (k_attr, 4) + assert result["mlp_cad_b1"].shape == (4,) + assert result["mlp_cad_W2"].shape == (4,) + + def test_make_bounds(self): + h = MLPHead("cad", hidden=4) + bounds = h.make_bounds(3, 5) + assert len(bounds) == h.n_params(3, 5) + + def test_jax_differentiable(self): + """MLP predict should be JAX-differentiable.""" + import jax + k_attr = 3 + h = MLPHead("cad", hidden=4) + n_p = h.n_params(1, k_attr) + np.random.seed(42) + params = jnp.array(np.random.randn(n_p) * 0.1) + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + + def loss(p): + return h.predict(p, 0, x_attr_i) ** 2 + + grad = jax.grad(loss)(params) + assert grad.shape == params.shape + assert jnp.all(jnp.isfinite(grad)) + + +# ── MLPNoiseHead ──────────────────────────────────────────────────────────── + + +class TestMLPNoiseHead: + def test_n_params(self): + h = MLPNoiseHead(hidden=16) + # k_attr=5: 5*16 + 16 + 16*8 + 8 = 80+16+128+8 = 232 + assert h.n_params(3, 5) == 232 + # k_attr=7: 7*16 + 16 + 16*8 + 8 = 112+16+128+8 = 264 + assert h.n_params(3, 7) == 264 + + def test_n_params_custom_hidden(self): + h = MLPNoiseHead(hidden=8) + # 
k_attr=5: 5*8 + 8 + 8*8 + 8 = 40+8+64+8 = 120 + assert h.n_params(3, 5) == 120 + + def test_predict_output_shape(self): + k_attr = 3 + h = MLPNoiseHead(hidden=4) + n_p = h.n_params(1, k_attr) + params = jnp.zeros(n_p) + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + result = h.predict(params, 0, x_attr_i) + assert result.shape == (K_OBS,) + + def test_predict_with_zero_W2_equals_b2(self): + """With W2=0, output should be b2 regardless of input.""" + k_attr = 3 + h = MLPNoiseHead(hidden=4) + n_p = h.n_params(1, k_attr) + params = np.zeros(n_p) + # b2 is the last K_OBS elements + params[-K_OBS:] = np.arange(K_OBS) + 1.0 + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + result = h.predict(jnp.array(params), 0, x_attr_i) + np.testing.assert_allclose(result, np.arange(K_OBS) + 1.0) + + def test_predict_nonlinear(self): + """MLP should produce different outputs for different inputs.""" + k_attr = 3 + h = MLPNoiseHead(hidden=4, seed=42) + n_p = h.n_params(1, k_attr) + np.random.seed(42) + params = jnp.array(np.random.randn(n_p) * 0.1) + x1 = jnp.array([1.0, 0.0, 0.0]) + x2 = jnp.array([0.0, 1.0, 0.0]) + v1 = h.predict(params, 0, x1) + v2 = h.predict(params, 0, x2) + assert not jnp.allclose(v1, v2), "Should produce different outputs" + + def test_predict_ignores_pool_idx(self): + k_attr = 3 + h = MLPNoiseHead(hidden=4) + params = jnp.ones(h.n_params(5, k_attr)) * 0.1 + x = jnp.array([1.0, 2.0, 3.0]) + v0 = h.predict(params, 0, x) + v3 = h.predict(params, 3, x) + np.testing.assert_allclose(v0, v3) + + def test_regularization_on_weights_not_biases(self): + k_attr = 2 + h = MLPNoiseHead(hidden=2, alpha=1.0) + # Layout: W1(2*2=4), b1(2), W2(2*8=16), b2(8) = 30 params + n_p = h.n_params(1, k_attr) + assert n_p == 30 + params = np.zeros(n_p) + params[0] = 3.0 # W1[0,0] + params[1] = 4.0 # W1[0,1] + # b1 at indices 4,5 — not regularized + params[6] = 1.0 # W2[0,0] + params[7] = 2.0 # W2[0,1] + params[-1] = 999.0 # b2[-1] — not regularized + # reg = 1.0 * (9 + 16 + 1 + 4) = 30.0 + result = 
float(h.regularization(jnp.array(params))) + np.testing.assert_allclose(result, 30.0) + + def test_init_default(self): + np.random.seed(42) + h = MLPNoiseHead(hidden=4) + jdata = _make_fake_jdata() + init = h.init(jdata) + n_p = h.n_params(N_POOLS, K_ATTR) + assert init.shape == (n_p,) + assert np.all(np.isfinite(init)) + + def test_init_size(self): + """Init should return correct number of parameters.""" + h = MLPNoiseHead(hidden=4) + jdata = _make_fake_jdata() + init = h.init(jdata) + assert init.shape == (h.n_params(N_POOLS, K_ATTR),) + assert np.all(np.isfinite(init)) + + def test_init_b2_from_ols(self): + """b2 should be pooled OLS noise coefficients.""" + np.random.seed(42) + h = MLPNoiseHead(hidden=4) + jdata = _make_fake_jdata() + init = h.init(jdata) + b2 = init[-K_OBS:] + assert np.all(np.isfinite(b2)) + # Should be nonzero (OLS on random data) + assert np.any(b2 != 0.0) + + def test_init_warm_start(self): + h = MLPNoiseHead(hidden=4) + jdata = _make_fake_jdata() + warm = { + POOL_PREFIXES[0]: {"noise_coeffs": np.ones(K_OBS) * 5.0}, + POOL_PREFIXES[1]: {"noise_coeffs": np.ones(K_OBS) * 7.0}, + } + init = h.init(jdata, warm_start=warm) + b2 = init[-K_OBS:] + # b2 should be mean of warm-start noise: (5+7)/2 = 6 + np.testing.assert_allclose(b2, 6.0) + + def test_predict_new(self): + k_attr = 3 + h = MLPNoiseHead(hidden=4) + n_p = h.n_params(1, k_attr) + params = np.zeros(n_p) + params[-K_OBS:] = np.arange(K_OBS) + 1.0 # b2 + x_attr = np.array([1.0, 2.0, 3.0]) + result = h.predict_new(params, x_attr) + assert result.shape == (K_OBS,) + np.testing.assert_allclose(result, np.arange(K_OBS) + 1.0) + + def test_predict_new_matches_predict(self): + k_attr = 3 + h = MLPNoiseHead(hidden=4, seed=42) + n_p = h.n_params(1, k_attr) + np.random.seed(99) + params = np.random.randn(n_p) * 0.1 + x_attr = np.array([0.5, -1.0, 2.0]) + + jax_result = np.array(h.predict(jnp.array(params), 0, jnp.array(x_attr))) + np_result = h.predict_new(params, x_attr) + 
np.testing.assert_allclose(jax_result, np_result, rtol=1e-6) + + def test_unpack_result(self): + k_attr = 3 + h = MLPNoiseHead(hidden=4) + n_p = h.n_params(1, k_attr) + params = np.arange(n_p, dtype=float) + result = h.unpack_result(params, 2, k_attr) + assert "mlp_noise_W1" in result + assert "mlp_noise_b1" in result + assert "mlp_noise_W2" in result + assert "mlp_noise_b2" in result + assert result["mlp_noise_W1"].shape == (k_attr, 4) + assert result["mlp_noise_b1"].shape == (4,) + assert result["mlp_noise_W2"].shape == (4, K_OBS) + assert result["mlp_noise_b2"].shape == (K_OBS,) + + def test_make_bounds(self): + h = MLPNoiseHead(hidden=4) + bounds = h.make_bounds(3, 5) + assert len(bounds) == h.n_params(3, 5) + + def test_jax_differentiable(self): + import jax + k_attr = 3 + h = MLPNoiseHead(hidden=4) + n_p = h.n_params(1, k_attr) + np.random.seed(42) + params = jnp.array(np.random.randn(n_p) * 0.1) + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + + def loss(p): + return jnp.sum(h.predict(p, 0, x_attr_i) ** 2) + + grad = jax.grad(loss)(params) + assert grad.shape == params.shape + assert jnp.all(jnp.isfinite(grad)) + + +# ── Reduced k_obs=4 tests ───────────────────────────────────────────────── + +K_OBS_REDUCED = 4 + + +def _make_fake_jdata_reduced(): + """JointData-like object with k_obs=4 x_obs for reduced noise testing.""" + from quantammsim.calibration.joint_fit import JointData + + pool_data = [] + for _ in range(N_POOLS): + n_obs = 14 + x_obs = np.random.randn(n_obs, K_OBS_REDUCED) + x_obs[:, 0] = 1.0 # intercept column + y_obs = np.random.randn(n_obs) * 0.5 + 9.0 + pool_data.append({ + "x_obs": jnp.array(x_obs), + "y_obs": jnp.array(y_obs), + "day_indices": jnp.arange(n_obs) % 10, + }) + + x_attr = jnp.array(np.random.randn(N_POOLS, K_ATTR)) + return JointData( + pool_data=pool_data, + x_attr=x_attr, + pool_ids=POOL_PREFIXES[:N_POOLS], + attr_names=[f"attr_{i}" for i in range(K_ATTR)], + ) + + +class TestPerPoolNoiseHeadReduced: + """PerPoolNoiseHead with 
k_obs=4.""" + + def test_n_params(self): + h = PerPoolNoiseHead(k_obs=4) + assert h.n_params(3, 5) == 3 * 4 + + def test_predict_correct_slice(self): + h = PerPoolNoiseHead(k_obs=4) + params = jnp.arange(3 * 4, dtype=float) + x_attr_i = jnp.zeros(5) + for i in range(3): + result = h.predict(params, i, x_attr_i) + expected = params[i * 4:(i + 1) * 4] + np.testing.assert_allclose(result, expected) + + def test_init_ols(self): + np.random.seed(42) + h = PerPoolNoiseHead(k_obs=4) + jdata = _make_fake_jdata_reduced() + init = h.init(jdata) + assert init.shape == (N_POOLS * 4,) + assert np.all(np.isfinite(init)) + + def test_roundtrip(self): + np.random.seed(42) + h = PerPoolNoiseHead(k_obs=4) + jdata = _make_fake_jdata_reduced() + init = h.init(jdata) + result = h.unpack_result(init, N_POOLS, K_ATTR) + assert result["noise_coeffs"].shape == (N_POOLS, 4) + + def test_default_unchanged(self): + h = PerPoolNoiseHead() + assert h.k_obs == K_OBS + assert h.n_params(3, 5) == 3 * K_OBS + + +class TestSharedLinearNoiseHeadReduced: + """SharedLinearNoiseHead with k_obs=4.""" + + def test_n_params(self): + h = SharedLinearNoiseHead(k_obs=4) + assert h.n_params(3, 5) == (1 + 5) * 4 + + def test_predict(self): + k_attr = 3 + h = SharedLinearNoiseHead(k_obs=4) + W_full = np.zeros((1 + k_attr, 4)) + W_full[0, :] = 1.0 + W_full[1, 0] = 2.0 + params = jnp.array(W_full.ravel()) + x_attr_i = jnp.array([1.0, 0.0, 0.0]) + result = h.predict(params, 0, x_attr_i) + assert result.shape == (4,) + np.testing.assert_allclose(float(result[0]), 3.0) + np.testing.assert_allclose(float(result[1]), 1.0) + + def test_init(self): + np.random.seed(42) + h = SharedLinearNoiseHead(k_obs=4) + jdata = _make_fake_jdata_reduced() + init = h.init(jdata) + assert init.shape == ((1 + K_ATTR) * 4,) + assert np.all(np.isfinite(init)) + + def test_default_unchanged(self): + h = SharedLinearNoiseHead() + assert h.k_obs == K_OBS + assert h.n_params(3, 5) == (1 + 5) * K_OBS + + +class TestMLPNoiseHeadReduced: + 
"""MLPNoiseHead with k_obs=4.""" + + def test_n_params(self): + h = MLPNoiseHead(hidden=16, k_obs=4) + # k_attr=5: 5*16 + 16 + 16*4 + 4 = 80+16+64+4 = 164 + assert h.n_params(3, 5) == 164 + + def test_predict(self): + k_attr = 3 + h = MLPNoiseHead(hidden=4, k_obs=4) + n_p = h.n_params(1, k_attr) + params = jnp.zeros(n_p) + x_attr_i = jnp.array([1.0, 2.0, 3.0]) + result = h.predict(params, 0, x_attr_i) + assert result.shape == (4,) + + def test_init(self): + np.random.seed(42) + h = MLPNoiseHead(hidden=4, k_obs=4) + jdata = _make_fake_jdata_reduced() + init = h.init(jdata) + n_p = h.n_params(N_POOLS, K_ATTR) + assert init.shape == (n_p,) + assert np.all(np.isfinite(init)) + + def test_regularization(self): + k_attr = 2 + h = MLPNoiseHead(hidden=2, alpha=1.0, k_obs=4) + # Layout: W1(2*2=4), b1(2), W2(2*4=8), b2(4) = 18 params + n_p = h.n_params(1, k_attr) + assert n_p == 18 + params = np.zeros(n_p) + params[0] = 3.0 # W1[0,0] + params[1] = 4.0 # W1[0,1] + params[6] = 1.0 # W2[0,0] + params[7] = 2.0 # W2[0,1] + params[-1] = 999.0 # b2[-1] — not regularized + # reg = 1.0 * (9 + 16 + 1 + 4) = 30.0 + result = float(h.regularization(jnp.array(params))) + np.testing.assert_allclose(result, 30.0) + + def test_default_unchanged(self): + h = MLPNoiseHead() + assert h.k_obs == K_OBS + assert h.n_params(3, 5) == 232 + + +# ── TokenFactoredNoiseHead ───────────────────────────────────────────────── + + +def _make_token_factored_head(k_obs=K_OBS_REDUCED): + """Build a TokenFactoredNoiseHead from synthetic 2-pool, 3-token data.""" + from quantammsim.calibration.heads import TokenFactoredNoiseHead + + # Pool 0: (BTC=1, ETH=2) on MAINNET=1, fee=0.003 + # Pool 1: (AAVE=0, ETH=2) on ARBITRUM=0, fee=0.01 + token_a_idx = np.array([1, 0], dtype=np.int32) # BTC, AAVE + token_b_idx = np.array([2, 2], dtype=np.int32) # ETH, ETH + chain_idx = np.array([1, 0], dtype=np.int32) # MAINNET, ARBITRUM + log_fees = np.array([np.log(0.003), np.log(0.01)]) + x_token = np.array([ + [1.0, 20.0, 0.0, 
0.0, 0.0], # AAVE: volatile + [1.0, 25.0, 0.0, 0.0, 0.0], # BTC: volatile + [1.0, 26.0, 0.0, 1.0, 1.0], # ETH: eth_derivative + L1_native + ]) + token_index = {"AAVE": 0, "BTC": 1, "ETH": 2} + chain_index = {"ARBITRUM": 0, "MAINNET": 1} + + head = TokenFactoredNoiseHead( + token_a_idx=token_a_idx, + token_b_idx=token_b_idx, + chain_idx=chain_idx, + log_fees=log_fees, + x_token=x_token, + n_tokens=3, + n_chains=2, + token_index=token_index, + chain_index=chain_index, + k_obs=k_obs, + lambda_delta=1.0, + lambda_token=0.1, + lambda_chain=0.1, + lambda_fee=0.01, + ) + return head + + +class TestTokenFactoredNoiseHead: + + def test_is_head(self): + head = _make_token_factored_head() + assert isinstance(head, Head) + + def test_n_params(self): + head = _make_token_factored_head(k_obs=4) + # 3 tokens * 4 + 5 d_token * 4 + 2 chains * 4 + 4 beta_fee + 2 pools * 4 + # = 12 + 20 + 8 + 4 + 8 = 52 + assert head.n_params(2, K_ATTR) == 52 + + def test_predict_returns_k_obs_vector(self): + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + params = jnp.zeros(n_p) + x_attr_i = jnp.zeros(K_ATTR) + result = head.predict(params, 0, x_attr_i) + assert result.shape == (4,) + + def test_predict_additivity(self): + """predict(pool_0) = u[BTC] + u[ETH] + alpha[MAINNET] + + beta_fee * log(0.003) + delta[0]""" + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + + # Build params with known values + params = np.zeros(n_p) + k = 4 # k_obs + # u: (3 tokens, 4) at offset 0 + u_flat = np.array([ + 1.0, 0.0, 0.0, 0.0, # AAVE + 2.0, 0.5, 0.0, 0.0, # BTC + 3.0, 1.0, 0.0, 0.0, # ETH + ]) + params[:12] = u_flat + # Gamma: (5, 4) at offset 12 — skip (doesn't affect predict) + # alpha: (2, 4) at offset 32 + alpha_flat = np.array([ + 0.1, 0.0, 0.0, 0.0, # ARBITRUM + 0.2, 0.0, 0.0, 0.0, # MAINNET + ]) + params[32:40] = alpha_flat + # beta_fee: (4,) at offset 40 + params[40:44] = np.array([0.5, 0.0, 0.0, 0.0]) + # delta: (2, 4) at offset 44 + 
params[44:48] = np.array([0.05, 0.0, 0.0, 0.0]) # pool 0 delta + + result = head.predict(jnp.array(params), 0, jnp.zeros(K_ATTR)) + + # Expected for pool 0: u[BTC] + u[ETH] + alpha[MAINNET] + # + beta_fee * log(0.003) + delta[0] + expected_0 = 2.0 + 3.0 + 0.2 + 0.5 * np.log(0.003) + 0.05 + np.testing.assert_allclose(float(result[0]), expected_0, rtol=1e-5) + + def test_regularization_nonneg_and_finite(self): + head = _make_token_factored_head() + n_p = head.n_params(2, K_ATTR) + np.random.seed(42) + params = jnp.array(np.random.randn(n_p) * 0.1) + reg = float(head.regularization(params)) + assert np.isfinite(reg) + assert reg >= 0.0 + + def test_regularization_zero_when_perfect(self): + """If u = x_token @ Gamma exactly and delta, alpha, beta_fee are all zero, + every penalty term vanishes, so the total regularization is 0.""" + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + params = np.zeros(n_p) + # Set Gamma to something, then set u = x_token @ Gamma + np.random.seed(7) + Gamma = np.random.randn(5, 4) * 0.1 + u = head.x_token @ Gamma # (3, 4) + params[:12] = u.ravel() + params[12:32] = Gamma.ravel() + # alpha, beta_fee, delta all zero + reg = float(head.regularization(jnp.array(params))) + # Total: lambda_token * 0 + lambda_chain * 0 + lambda_fee * 0 + lambda_delta * 0 + np.testing.assert_allclose(reg, 0.0, atol=1e-10) + + def test_init_cold(self): + head = _make_token_factored_head(k_obs=4) + jdata = _make_fake_jdata_reduced() + init = head.init(jdata) + n_p = head.n_params(N_POOLS, K_ATTR) + assert init.shape == (n_p,) + assert np.all(np.isfinite(init)) + + def test_init_warm_start_roundtrip(self): + """init from warm_start → predict ≈ warm_start noise_coeffs.""" + head = _make_token_factored_head(k_obs=4) + jdata = _make_fake_jdata_reduced() + + # Warm start with known noise coefficients per pool + warm = { + POOL_PREFIXES[0]: {"noise_coeffs": np.array([9.0, 0.5, 0.1, -0.2])}, + POOL_PREFIXES[1]: {"noise_coeffs": np.array([8.5, 0.3, 0.2, -0.1])}, +
} + init = head.init(jdata, warm_start=warm) + params = jnp.array(init) + x_attr_dummy = jnp.zeros(K_ATTR) + + for i, pid in enumerate(jdata.pool_ids): + predicted = np.array(head.predict(params, i, x_attr_dummy)) + target = warm[pid]["noise_coeffs"] + # Should approximately recover the warm-start values + # (not exact because the lstsq decomposition is underdetermined + # with 2 pools and 3 tokens) + np.testing.assert_allclose(predicted, target, atol=0.5) + + def test_gradient_finite(self): + """jax.grad of a simple loss at init produces finite gradients.""" + import jax + head = _make_token_factored_head(k_obs=4) + jdata = _make_fake_jdata_reduced() + init = jnp.array(head.init(jdata)) + x_attr_i = jnp.zeros(K_ATTR) + + def loss(p): + c = head.predict(p, 0, x_attr_i) + return jnp.sum(c ** 2) + head.regularization(p) + + grad = jax.grad(loss)(init) + assert grad.shape == init.shape + assert jnp.all(jnp.isfinite(grad)) + + def test_unpack_result_keys(self): + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + np.random.seed(42) + params = np.random.randn(n_p) + result = head.unpack_result(params, 2, K_ATTR) + for key in ["token_effects", "Gamma", "chain_effects", + "beta_fee", "noise_deltas", "noise_coeffs"]: + assert key in result, f"Missing key: {key}" + assert result["token_effects"].shape == (3, 4) + assert result["Gamma"].shape == (5, 4) + assert result["chain_effects"].shape == (2, 4) + assert result["beta_fee"].shape == (4,) + assert result["noise_deltas"].shape == (2, 4) + assert result["noise_coeffs"].shape == (2, 4) + + def test_make_bounds(self): + head = _make_token_factored_head(k_obs=4) + bounds = head.make_bounds(2, K_ATTR) + assert len(bounds) == head.n_params(2, K_ATTR) + assert all(b == (None, None) for b in bounds) + + def test_predict_new_pool_seen_tokens(self): + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + np.random.seed(42) + params = np.random.randn(n_p) * 0.1 + result = 
head.predict_new_pool( + params, "BTC", "AAVE", "MAINNET", 0.003, n_pools=2 + ) + assert "noise_coeffs" in result + assert "components" in result + nc = result["noise_coeffs"] + assert nc.shape == (4,) or len(nc) == 4 + # Should equal u[BTC] + u[AAVE] + alpha[MAINNET] + beta_fee*log(0.003) + # (no delta for new pool) + comps = result["components"] + reconstructed = (comps["token_a"] + comps["token_b"] + + comps["chain"] + comps["fee"]) + np.testing.assert_allclose(nc, reconstructed, rtol=1e-6) + + def test_predict_new_pool_unseen_token(self): + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + np.random.seed(42) + params = np.random.randn(n_p) * 0.1 + # "LINK" is not in token_index → should fall back to Gamma + result = head.predict_new_pool( + params, "LINK", "ETH", "MAINNET", 0.003, n_pools=2 + ) + assert "noise_coeffs" in result + assert result["noise_coeffs"].shape == (4,) or len(result["noise_coeffs"]) == 4 + + def test_predict_new_pool_unseen_chain(self): + head = _make_token_factored_head(k_obs=4) + n_p = head.n_params(2, K_ATTR) + np.random.seed(42) + params = np.random.randn(n_p) * 0.1 + # "BASE" is not in chain_index → alpha = zeros + result = head.predict_new_pool( + params, "BTC", "ETH", "BASE", 0.003, n_pools=2 + ) + np.testing.assert_allclose( + result["components"]["chain"], np.zeros(4) + ) diff --git a/tests/calibration/test_joint_fit.py b/tests/calibration/test_joint_fit.py index 50d7730..1fe02f2 100644 --- a/tests/calibration/test_joint_fit.py +++ b/tests/calibration/test_joint_fit.py @@ -188,3 +188,235 @@ def test_shared_noise_predict(self, matched_data): pred = predict_new_pool_joint(result, x_attr_new) assert "noise_coeffs" in pred assert len(pred["noise_coeffs"]) == K_OBS + + +class TestPrepareTokenFactoredData: + def test_returns_jdata_and_encoding(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_token_factored_data + + jdata, token_enc = prepare_token_factored_data(matched_data) + assert 
hasattr(jdata, "pool_data") + assert hasattr(jdata, "x_attr") + assert "token_index" in token_enc + assert "token_a_idx" in token_enc + + def test_reduced_x_obs_by_default(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + jdata, _ = prepare_token_factored_data(matched_data) + for pd in jdata.pool_data: + assert pd["x_obs"].shape[1] == K_OBS_REDUCED + + def test_encoding_matches_pools(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_token_factored_data + + jdata, enc = prepare_token_factored_data(matched_data) + n_pools = len(jdata.pool_data) + assert len(enc["token_a_idx"]) == n_pools + assert len(enc["token_b_idx"]) == n_pools + assert len(enc["chain_idx"]) == n_pools + assert len(enc["log_fees"]) == n_pools + + +class TestTokenFactoredEndToEnd: + """Full pipeline: TokenFactoredNoiseHead + CalibrationModel + fit.""" + + def test_model_fits_and_converges(self, matched_data): + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + jdata, enc = prepare_token_factored_data(matched_data) + n_pools = len(jdata.pool_data) + + # Gas: fixed to chain defaults + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_data[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + gas_head = FixedHead("log_gas", np.array(gas_values)) + + # Cadence: per-pool + cad_head = PerPoolHead("log_cadence", default=np.log(12.0)) + + # Noise: token-factored + noise_head = TokenFactoredNoiseHead(k_obs=K_OBS_REDUCED, **enc) + + model = CalibrationModel(cad_head, gas_head, noise_head) + result = 
model.fit(jdata, maxiter=50) + + assert result["loss"] <= result["init_loss"] + assert "token_effects" in result + assert "noise_coeffs" in result + assert result["noise_coeffs"].shape == (n_pools, K_OBS_REDUCED) + + def test_model_with_warm_start(self, matched_data): + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.per_pool_fit import fit_all_pools + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + # Run Option C first + option_c = fit_all_pools( + matched_data, fix_gas_to_chain=True, reduced=True + ) + + jdata, enc = prepare_token_factored_data(matched_data) + n_pools = len(jdata.pool_data) + + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_data[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + + noise_head = TokenFactoredNoiseHead(k_obs=K_OBS_REDUCED, **enc) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", np.array(gas_values)), + noise_head, + ) + result = model.fit(jdata, maxiter=100, warm_start=option_c) + assert result["loss"] <= result["init_loss"] + + +class TestDataRegLossSeparation: + """Test that fit() reports data_loss and reg_loss separately.""" + + def test_fit_reports_data_and_reg_loss(self, matched_data): + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + jdata, enc = prepare_token_factored_data(matched_data) + n_pools = 
len(jdata.pool_data) + + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_data[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + + noise_head = TokenFactoredNoiseHead(k_obs=K_OBS_REDUCED, **enc) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", np.array(gas_values)), + noise_head, + ) + result = model.fit(jdata, maxiter=50) + + assert "data_loss" in result + assert "reg_loss" in result + + def test_data_plus_reg_equals_total(self, matched_data): + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + jdata, enc = prepare_token_factored_data(matched_data) + gas_values = [] + for pid in jdata.pool_ids: + chain = matched_data[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + + noise_head = TokenFactoredNoiseHead(k_obs=K_OBS_REDUCED, **enc) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", np.array(gas_values)), + noise_head, + ) + result = model.fit(jdata, maxiter=50) + + np.testing.assert_allclose( + result["data_loss"] + result["reg_loss"], + result["loss"], + rtol=1e-6, + ) + + def test_data_loss_leq_total(self, matched_data): + from quantammsim.calibration.calibration_model import CalibrationModel + from quantammsim.calibration.heads import ( + FixedHead, PerPoolHead, TokenFactoredNoiseHead, + ) + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.pool_data import K_OBS_REDUCED + + jdata, enc = prepare_token_factored_data(matched_data) + gas_values = [] + 
for pid in jdata.pool_ids: + chain = matched_data[pid]["chain"] + gas_values.append(np.log(max(CHAIN_GAS_USD.get(chain, 1.0), 1e-6))) + + noise_head = TokenFactoredNoiseHead(k_obs=K_OBS_REDUCED, **enc) + model = CalibrationModel( + PerPoolHead("log_cadence", default=np.log(12.0)), + FixedHead("log_gas", np.array(gas_values)), + noise_head, + ) + result = model.fit(jdata, maxiter=50) + + assert result["data_loss"] <= result["loss"] + 1e-10 + assert result["reg_loss"] >= -1e-10 + + +class TestPrepareTokenFactoredCrossPool: + """Test prepare_token_factored_data with cross_pool=True.""" + + def test_cross_pool_jdata_shapes_consistent(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_token_factored_data + + jdata, enc = prepare_token_factored_data(matched_data, cross_pool=True) + for pd in jdata.pool_data: + assert pd["x_obs"].shape[0] == pd["y_obs"].shape[0] + assert pd["x_obs"].shape[0] == pd["day_indices"].shape[0] + + def test_cross_pool_x_obs_has_7_cols(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_token_factored_data + from quantammsim.calibration.pool_data import K_OBS_CROSS + + jdata, _ = prepare_token_factored_data(matched_data, cross_pool=True) + for pd in jdata.pool_data: + assert pd["x_obs"].shape[1] == K_OBS_CROSS + + def test_cross_pool_drops_first_obs(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_token_factored_data + + jdata_base, _ = prepare_token_factored_data(matched_data, cross_pool=False) + jdata_cross, _ = prepare_token_factored_data(matched_data, cross_pool=True) + for pd_base, pd_cross in zip(jdata_base.pool_data, jdata_cross.pool_data): + assert pd_cross["y_obs"].shape[0] == pd_base["y_obs"].shape[0] - 1 + + +class TestPrepareJointDataReduced: + """Test prepare_joint_data with reduced_x_obs=True.""" + + def test_reduced_x_obs_shape(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + from quantammsim.calibration.pool_data 
import K_OBS_REDUCED + + jdata = prepare_joint_data(matched_data, reduced_x_obs=True) + for pd in jdata.pool_data: + assert pd["x_obs"].shape[1] == K_OBS_REDUCED + + def test_default_unchanged(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data) + for pd in jdata.pool_data: + assert pd["x_obs"].shape[1] == K_OBS diff --git a/tests/calibration/test_joint_fit_fixed_gas.py b/tests/calibration/test_joint_fit_fixed_gas.py new file mode 100644 index 0000000..7981ff4 --- /dev/null +++ b/tests/calibration/test_joint_fit_fixed_gas.py @@ -0,0 +1,467 @@ +"""Tests for fixed-gas mode in quantammsim.calibration.joint_fit.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, POOL_PREFIXES + + +@pytest.fixture +def matched_data(synthetic_daily_grid, synthetic_panel, tmp_path): + """Build matched data dict from synthetic fixtures.""" + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + return match_grids_to_panel(str(grid_dir), synthetic_panel) + + +class TestPrepareJointDataFixedGas: + """Test prepare_joint_data with fix_gas_to_chain=True.""" + + def test_pool_data_has_fixed_log_gas(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + for pd_i in jdata.pool_data: + assert "fixed_log_gas" in pd_i + + def test_pool_data_no_fixed_log_gas_by_default(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=False) + for pd_i in jdata.pool_data: + assert "fixed_log_gas" not in pd_i + + def test_fixed_log_gas_values_match_chain(self, matched_data): + 
"""fixed_log_gas should be log(CHAIN_GAS_USD[chain]).""" + from quantammsim.calibration.joint_fit import prepare_joint_data + from quantammsim.calibration.loss import CHAIN_GAS_USD + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + + for i, pid in enumerate(jdata.pool_ids): + chain = matched_data[pid]["chain"] + expected_gas = CHAIN_GAS_USD.get(chain, 1.0) + expected_log_gas = np.log(max(expected_gas, 1e-6)) + np.testing.assert_allclose( + float(jdata.pool_data[i]["fixed_log_gas"]), + expected_log_gas, + rtol=1e-6, + ) + + def test_mainnet_gas_is_log_1(self, matched_data): + """MAINNET pools should have fixed_log_gas = log(1.0) = 0.0.""" + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + found = False + for i, pid in enumerate(jdata.pool_ids): + if matched_data[pid]["chain"] == "MAINNET": + np.testing.assert_allclose( + float(jdata.pool_data[i]["fixed_log_gas"]), + 0.0, + atol=1e-6, + ) + found = True + assert found, "No MAINNET pool found in test data" + + def test_arbitrum_gas_is_log_001(self, matched_data): + """ARBITRUM pools should have fixed_log_gas = log(0.01).""" + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + found = False + for i, pid in enumerate(jdata.pool_ids): + if matched_data[pid]["chain"] == "ARBITRUM": + np.testing.assert_allclose( + float(jdata.pool_data[i]["fixed_log_gas"]), + np.log(0.01), + rtol=1e-6, + ) + found = True + assert found, "No ARBITRUM pool found in test data" + + def test_x_attr_has_correct_shape(self, matched_data): + from quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + assert jdata.x_attr.shape[0] == len(jdata.pool_ids) + assert jdata.x_attr.shape[1] == len(jdata.attr_names) + + def test_pool_data_has_obs_arrays(self, matched_data): + from 
quantammsim.calibration.joint_fit import prepare_joint_data + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + for pd_i in jdata.pool_data: + assert pd_i["x_obs"].ndim == 2 + assert pd_i["y_obs"].ndim == 1 + assert pd_i["day_indices"].ndim == 1 + assert pd_i["x_obs"].shape[0] == pd_i["y_obs"].shape[0] + + +class TestPackUnpackJointFixedGas: + """Test packing/unpacking joint params with fix_gas=True.""" + + def test_pack_fixed_gas_shape(self): + from quantammsim.calibration.joint_fit import pack_joint_params_fixed_gas + + k_attr = 6 + n_pools = 2 + noise_params = jnp.zeros((n_pools, K_OBS)) + flat = pack_joint_params_fixed_gas( + 1.0, jnp.zeros(k_attr), noise_params + ) + # Layout: [bias_cad, W_cad(6), noise(2*8)] = 1+6+16 = 23 + assert flat.shape == (1 + k_attr + n_pools * K_OBS,) + + def test_pack_fixed_gas_shorter_than_free(self): + from quantammsim.calibration.joint_fit import ( + pack_joint_params, + pack_joint_params_fixed_gas, + ) + + k_attr = 6 + n_pools = 2 + noise = jnp.zeros((n_pools, K_OBS)) + + free = pack_joint_params(1.0, 2.0, jnp.zeros(k_attr), + jnp.zeros(k_attr), noise) + fixed = pack_joint_params_fixed_gas(1.0, jnp.zeros(k_attr), noise) + # Fixed is shorter by: 1 (bias_gas) + k_attr (W_gas) + assert free.shape[0] - fixed.shape[0] == 1 + k_attr + + def test_unpack_fixed_gas_no_gas_keys(self): + from quantammsim.calibration.joint_fit import ( + pack_joint_params_fixed_gas, + unpack_joint_params, + ) + + k_attr = 6 + n_pools = 2 + noise = jnp.zeros((n_pools, K_OBS)) + flat = pack_joint_params_fixed_gas( + 1.0, jnp.ones(k_attr) * 0.5, noise + ) + + config = {"k_attr": k_attr, "n_pools": n_pools, + "mode": "per_pool_noise", "fix_gas": True} + params = unpack_joint_params(flat, config) + + assert "bias_cad" in params + assert "W_cad" in params + assert "noise_coeffs" in params + assert "bias_gas" not in params + assert "W_gas" not in params + + def test_unpack_roundtrip_per_pool_noise(self): + from quantammsim.calibration.joint_fit 
import ( + pack_joint_params_fixed_gas, + unpack_joint_params, + ) + + k_attr = 4 + n_pools = 3 + bias_cad = 2.5 + W_cad = jnp.array([0.1, -0.2, 0.3, -0.4]) + noise = jnp.arange(n_pools * K_OBS, dtype=float).reshape(n_pools, K_OBS) + + flat = pack_joint_params_fixed_gas(bias_cad, W_cad, noise) + config = {"k_attr": k_attr, "n_pools": n_pools, + "mode": "per_pool_noise", "fix_gas": True} + params = unpack_joint_params(flat, config) + + np.testing.assert_allclose(params["bias_cad"], bias_cad) + np.testing.assert_allclose(params["W_cad"], W_cad) + np.testing.assert_allclose(params["noise_coeffs"], noise) + + def test_unpack_roundtrip_shared_noise(self): + from quantammsim.calibration.joint_fit import ( + pack_joint_params_fixed_gas, + unpack_joint_params, + ) + + k_attr = 4 + bias_cad = 1.5 + W_cad = jnp.array([0.5, -0.5, 0.1, -0.1]) + # shared_noise: (1+k_attr, K_OBS) = (5, 8) + noise = jnp.arange((1 + k_attr) * K_OBS, dtype=float).reshape( + 1 + k_attr, K_OBS + ) + + flat = pack_joint_params_fixed_gas(bias_cad, W_cad, noise) + config = {"k_attr": k_attr, "n_pools": 99, + "mode": "shared_noise", "fix_gas": True} + params = unpack_joint_params(flat, config) + + np.testing.assert_allclose(params["bias_cad"], bias_cad) + np.testing.assert_allclose(params["W_cad"], W_cad) + np.testing.assert_allclose(params["bias_noise"], noise[0]) + np.testing.assert_allclose(params["W_noise"], noise[1:]) + + +class TestJointLossFixedGas: + """Test joint loss function with fix_gas=True.""" + + def _make_loss_fn(self, matched_data, mode="per_pool_noise"): + from quantammsim.calibration.joint_fit import ( + make_initial_joint_params, + make_joint_loss_fn, + prepare_joint_data, + ) + + jdata = prepare_joint_data( + matched_data, fix_gas_to_chain=True + ) + init = make_initial_joint_params( + jdata, mode=mode, fix_gas=True + ) + loss_fn = make_joint_loss_fn( + jdata, mode=mode, fix_gas=True + ) + return loss_fn, init, jdata + + def test_loss_differentiable_and_nonzero_grad(self, 
matched_data): + loss_fn, init, _ = self._make_loss_fn(matched_data) + grad = jax.grad(loss_fn)(init) + assert grad.shape == init.shape + assert jnp.all(jnp.isfinite(grad)) + assert float(jnp.sum(jnp.abs(grad))) > 1e-10 + + def test_no_gas_regularization_but_cad_regularization_works(self, matched_data): + """alpha_gas has no effect, but alpha_cad DOES.""" + from quantammsim.calibration.joint_fit import ( + make_initial_joint_params, + make_joint_loss_fn, + prepare_joint_data, + ) + + jdata = prepare_joint_data(matched_data, fix_gas_to_chain=True) + init = make_initial_joint_params(jdata, mode="per_pool_noise", fix_gas=True) + + # alpha_gas shouldn't matter + loss_fn_a = make_joint_loss_fn( + jdata, mode="per_pool_noise", fix_gas=True, alpha_gas=0.0 + ) + loss_fn_b = make_joint_loss_fn( + jdata, mode="per_pool_noise", fix_gas=True, alpha_gas=100.0 + ) + np.testing.assert_allclose( + float(loss_fn_a(init)), float(loss_fn_b(init)), rtol=1e-6 + ) + + # alpha_cad SHOULD matter (positive control) + loss_fn_no_reg = make_joint_loss_fn( + jdata, mode="per_pool_noise", fix_gas=True, alpha_cad=0.0 + ) + loss_fn_big_reg = make_joint_loss_fn( + jdata, mode="per_pool_noise", fix_gas=True, alpha_cad=100.0 + ) + # With W_cad initialized to non-zero by warm start, these should differ. 
+ # Even with default init (W_cad=0), perturbation test: + init_perturbed = init.at[1].set(1.0) # perturb first W_cad element + loss_no = float(loss_fn_no_reg(init_perturbed)) + loss_big = float(loss_fn_big_reg(init_perturbed)) + assert loss_big > loss_no, "alpha_cad regularization has no effect" + + def test_shared_noise_mode(self, matched_data): + loss_fn, init, _ = self._make_loss_fn( + matched_data, mode="shared_noise" + ) + loss = loss_fn(init) + assert loss.shape == () + assert float(loss) >= 0 + # Verify gradient works for shared_noise too + grad = jax.grad(loss_fn)(init) + assert jnp.all(jnp.isfinite(grad)) + + def test_init_param_count_per_pool_noise(self, matched_data): + """Verify parameter count: 1(bias_cad) + k_attr(W_cad) + n_pools*K_OBS.""" + _, init, jdata = self._make_loss_fn(matched_data) + k_attr = jdata.x_attr.shape[1] + n_pools = len(jdata.pool_data) + expected = 1 + k_attr + n_pools * K_OBS + assert init.shape[0] == expected + + def test_init_param_count_shared_noise(self, matched_data): + """Verify: 1(bias_cad) + k_attr(W_cad) + (1+k_attr)*K_OBS.""" + _, init, jdata = self._make_loss_fn( + matched_data, mode="shared_noise" + ) + k_attr = jdata.x_attr.shape[1] + expected = 1 + k_attr + (1 + k_attr) * K_OBS + assert init.shape[0] == expected + + +class TestFitJointFixedGas: + """Test fit_joint with fix_gas_to_chain=True.""" + + def test_returns_result(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint( + matched_data, mode="per_pool_noise", + fix_gas_to_chain=True, maxiter=20, + ) + assert isinstance(result, dict) + for key in ["bias_cad", "W_cad", "loss", "converged", "fix_gas"]: + assert key in result, f"Missing key: {key}" + + def test_fix_gas_flag_stored(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint( + matched_data, mode="per_pool_noise", + fix_gas_to_chain=True, maxiter=10, + ) + assert result["fix_gas"] is True + + def 
test_gas_per_pool_stored_with_correct_values(self, matched_data): + """gas_per_pool has right length and chain-level values.""" + from quantammsim.calibration.joint_fit import fit_joint + from quantammsim.calibration.loss import CHAIN_GAS_USD + + result = fit_joint( + matched_data, mode="per_pool_noise", + fix_gas_to_chain=True, maxiter=10, + ) + assert "gas_per_pool" in result + assert len(result["gas_per_pool"]) == len(result["pool_ids"]) + for i, pid in enumerate(result["pool_ids"]): + chain = matched_data[pid]["chain"] + expected = CHAIN_GAS_USD.get(chain, 1.0) + assert result["gas_per_pool"][i] == expected + + def test_loss_decreases_substantially(self, matched_data): + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint( + matched_data, mode="per_pool_noise", + fix_gas_to_chain=True, maxiter=50, + ) + # Must decrease, not just by epsilon + assert result["loss"] < result["init_loss"] * 0.999 + + def test_w_gas_and_bias_gas_are_zeros(self, matched_data): + """With fixed gas, W_gas and bias_gas should be zero placeholders.""" + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint( + matched_data, mode="per_pool_noise", + fix_gas_to_chain=True, maxiter=10, + ) + np.testing.assert_allclose(result["W_gas"], 0.0) + assert result["bias_gas"] == 0.0 + + def test_warm_start_from_option_c_runs(self, matched_data): + """Warm start from Option C should run without error and reduce loss.""" + from quantammsim.calibration.joint_fit import fit_joint + from quantammsim.calibration.per_pool_fit import fit_all_pools + + option_c = fit_all_pools(matched_data, fix_gas_to_chain=True) + result_warm = fit_joint( + matched_data, mode="per_pool_noise", + fix_gas_to_chain=True, + init_from_option_c=option_c, + maxiter=50, + ) + # Should at least decrease from its own init + assert result_warm["loss"] <= result_warm["init_loss"] + assert result_warm["loss"] >= 0 + + def test_shared_noise_fixed_gas(self, matched_data): + from 
quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint( + matched_data, mode="shared_noise", + fix_gas_to_chain=True, maxiter=20, + ) + assert "W_noise" in result + assert "bias_noise" in result + assert result["fix_gas"] is True + assert result["loss"] >= 0 + + def test_predict_new_pool_fixed_gas_pinned(self, matched_data): + """predict_new_pool_joint with zero attrs → gas_usd=1.0 exactly.""" + from quantammsim.calibration.joint_fit import fit_joint, predict_new_pool_joint + + result = fit_joint( + matched_data, mode="shared_noise", + fix_gas_to_chain=True, maxiter=20, + ) + k_attr = result["W_cad"].shape[0] + x_attr_new = np.zeros(k_attr) + pred = predict_new_pool_joint(result, x_attr_new) + + assert "cadence_minutes" in pred + assert pred["cadence_minutes"] > 0 + # bias_gas=0, W_gas=zeros → log_gas=0 → gas_usd=exp(0)=1.0 + np.testing.assert_allclose(pred["gas_usd"], 1.0, rtol=1e-6) + # cadence = exp(bias_cad + 0) = exp(bias_cad) + np.testing.assert_allclose( + pred["cadence_minutes"], np.exp(result["bias_cad"]), rtol=1e-6 + ) + # shared_noise mode should include noise_coeffs + assert "noise_coeffs" in pred + assert len(pred["noise_coeffs"]) == K_OBS + + def test_predict_matches_linear_model(self, matched_data): + """predict_new_pool_joint computes cadence = exp(bias_cad + W_cad @ x).""" + from quantammsim.calibration.joint_fit import fit_joint, predict_new_pool_joint + + result = fit_joint( + matched_data, mode="shared_noise", + fix_gas_to_chain=True, maxiter=20, + ) + k_attr = result["W_cad"].shape[0] + x_test = np.random.RandomState(42).randn(k_attr) + + pred = predict_new_pool_joint(result, x_test) + + # Verify cadence prediction directly against the linear model + expected_cadence = float(np.exp( + result["bias_cad"] + result["W_cad"] @ x_test + )) + np.testing.assert_allclose( + pred["cadence_minutes"], expected_cadence, rtol=1e-6, + ) + # Gas is fixed → always exp(0) = 1.0 + np.testing.assert_allclose(pred["gas_usd"], 1.0, rtol=1e-6) + + 
+class TestMakeBoundsFixedGas: + """Test _make_bounds with fix_gas=True.""" + + def test_bounds_count_per_pool_noise(self): + from quantammsim.calibration.joint_fit import _make_bounds + + k_attr = 6 + n_pools = 3 + bounds = _make_bounds(k_attr, n_pools, "per_pool_noise", fix_gas=True) + # 1(bias_cad) + 6(W_cad) + 3*8(noise) = 31 + assert len(bounds) == 1 + k_attr + n_pools * K_OBS + + def test_bounds_count_shared_noise(self): + from quantammsim.calibration.joint_fit import _make_bounds + + k_attr = 6 + n_pools = 3 + bounds = _make_bounds(k_attr, n_pools, "shared_noise", fix_gas=True) + # 1(bias_cad) + 6(W_cad) + (1+6)*8(noise) = 63 + assert len(bounds) == 1 + k_attr + (1 + k_attr) * K_OBS + + def test_bounds_fewer_with_fixed_gas(self): + from quantammsim.calibration.joint_fit import _make_bounds + + k_attr = 6 + n_pools = 3 + free = _make_bounds(k_attr, n_pools, "per_pool_noise", fix_gas=False) + fixed = _make_bounds(k_attr, n_pools, "per_pool_noise", fix_gas=True) + # Difference: 1(bias_gas) + k_attr(W_gas) + assert len(free) - len(fixed) == 1 + k_attr diff --git a/tests/calibration/test_loss_fixed_gas.py b/tests/calibration/test_loss_fixed_gas.py new file mode 100644 index 0000000..50a5122 --- /dev/null +++ b/tests/calibration/test_loss_fixed_gas.py @@ -0,0 +1,364 @@ +"""Tests for fixed-gas extensions in quantammsim.calibration.loss.""" + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, N_DAYS + + +class TestChainGasUSD: + """Test CHAIN_GAS_USD constants are correct and complete.""" + + def test_known_chains(self): + from quantammsim.calibration.loss import CHAIN_GAS_USD + + assert CHAIN_GAS_USD["MAINNET"] == 1.0 + assert CHAIN_GAS_USD["POLYGON"] == 0.005 + assert CHAIN_GAS_USD["GNOSIS"] == 0.001 + assert CHAIN_GAS_USD["ARBITRUM"] == 0.01 + assert CHAIN_GAS_USD["BASE"] == 0.005 + assert CHAIN_GAS_USD["SONIC"] == 0.005 + + def test_all_values_positive(self): + from 
quantammsim.calibration.loss import CHAIN_GAS_USD + + for chain, cost in CHAIN_GAS_USD.items(): + assert cost > 0, f"{chain} gas cost must be positive" + + def test_mainnet_most_expensive(self): + from quantammsim.calibration.loss import CHAIN_GAS_USD + + mainnet = CHAIN_GAS_USD["MAINNET"] + for chain, cost in CHAIN_GAS_USD.items(): + if chain != "MAINNET": + assert cost < mainnet, f"{chain} should be cheaper than MAINNET" + + def test_six_chains(self): + from quantammsim.calibration.loss import CHAIN_GAS_USD + + assert len(CHAIN_GAS_USD) == 6 + + +class TestPackUnpackFixedGas: + """Test pack/unpack for fixed-gas param vectors.""" + + def test_pack_shape(self): + from quantammsim.calibration.loss import pack_params_fixed_gas + + flat = pack_params_fixed_gas(2.5, jnp.zeros(K_OBS)) + assert flat.shape == (1 + K_OBS,) + + def test_pack_shape_is_one_shorter_than_free(self): + from quantammsim.calibration.loss import pack_params, pack_params_fixed_gas + + free = pack_params(2.5, 0.0, jnp.zeros(K_OBS)) + fixed = pack_params_fixed_gas(2.5, jnp.zeros(K_OBS)) + assert free.shape[0] == fixed.shape[0] + 1 + + def test_roundtrip(self): + from quantammsim.calibration.loss import ( + pack_params_fixed_gas, + unpack_params_fixed_gas, + ) + + log_cad = 2.5 + noise_coeffs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + + flat = pack_params_fixed_gas(log_cad, noise_coeffs) + lc, nc = unpack_params_fixed_gas(flat) + + np.testing.assert_allclose(lc, log_cad) + np.testing.assert_allclose(nc, noise_coeffs) + + def test_unpack_log_cadence_position(self): + """log_cadence is the first element.""" + from quantammsim.calibration.loss import pack_params_fixed_gas + + flat = pack_params_fixed_gas(3.14, jnp.ones(K_OBS) * 99.0) + np.testing.assert_allclose(flat[0], 3.14) + + def test_unpack_noise_coeffs_position(self): + """noise_coeffs are elements [1:].""" + from quantammsim.calibration.loss import pack_params_fixed_gas + + nc = jnp.arange(1, K_OBS + 1, dtype=float) + flat = 
pack_params_fixed_gas(0.0, nc) + np.testing.assert_allclose(flat[1:], nc) + + +class TestPoolLossFixedGas: + """Test pool_loss_fixed_gas with pinned numerical values.""" + + def _make_params(self, log_cad=None, noise_coeffs=None): + from quantammsim.calibration.loss import pack_params_fixed_gas + + if log_cad is None: + log_cad = float(jnp.log(jnp.array(12.0))) + if noise_coeffs is None: + noise_coeffs = jnp.zeros(K_OBS).at[0].set(8.0) + return pack_params_fixed_gas(log_cad, noise_coeffs) + + def _make_inputs(self, synthetic_pool_coeffs, synthetic_x_obs): + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = jnp.array(np.arange(n_obs) % n_days) + y_obs = jnp.ones(n_obs) * 9.0 + return jnp.array(synthetic_x_obs), y_obs, day_indices + + def test_pinned_loss_value(self, synthetic_pool_coeffs, synthetic_x_obs): + """Loss at known params must match precomputed value.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + fixed_log_gas = jnp.log(jnp.array(1.0)) + loss = pool_loss_fixed_gas( + params, fixed_log_gas, synthetic_pool_coeffs, x_obs, y_obs, day_indices + ) + # cadence=12, gas=1.0, noise=[8,0..0], y=9.0 → pinned + np.testing.assert_allclose(float(loss), 0.001727, atol=1e-4) + + def test_pinned_loss_at_multiple_gas_values( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Pinned loss values at gas=0.01, 1.0, 5.0.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + # Precomputed with _make_params defaults: log_cad=log(12), noise=[8,0..0] + expected = {0.01: 0.0763, 1.0: 0.00173, 5.0: 0.0313} + for gas_val, exp_loss in expected.items(): + lg = jnp.log(jnp.array(gas_val)) + loss = float(pool_loss_fixed_gas( + params, lg, 
synthetic_pool_coeffs, x_obs, y_obs, day_indices + )) + np.testing.assert_allclose(loss, exp_loss, atol=1e-3, + err_msg=f"gas={gas_val}") + + def test_zero_when_perfect(self, synthetic_pool_coeffs, synthetic_x_obs): + """Construct y_obs = log(V_arb + V_noise) exactly, verify loss ≈ 0.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.loss import ( + noise_volume, + pack_params_fixed_gas, + pool_loss_fixed_gas, + ) + + log_cad = jnp.log(jnp.array(12.0)) + fixed_log_gas = jnp.log(jnp.array(1.0)) + noise_coeffs = jnp.zeros(K_OBS).at[0].set(8.0) + + v_arb_all = interpolate_pool_daily( + synthetic_pool_coeffs, log_cad, jnp.exp(fixed_log_gas) + ) + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = jnp.array(np.arange(n_obs) % n_days) + v_arb = v_arb_all[day_indices] + v_noise = noise_volume(noise_coeffs, jnp.array(synthetic_x_obs)) + y_obs = jnp.log(jnp.maximum(v_arb + v_noise, 1e-6)) + + params = pack_params_fixed_gas(float(log_cad), noise_coeffs) + loss = pool_loss_fixed_gas( + params, fixed_log_gas, synthetic_pool_coeffs, + jnp.array(synthetic_x_obs), y_obs, day_indices, + ) + assert float(loss) < 1e-10 + + def test_matches_free_gas_at_same_value( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Fixed-gas loss should equal free-gas loss when gas matches.""" + from quantammsim.calibration.loss import ( + pack_params, + pack_params_fixed_gas, + pool_loss, + pool_loss_fixed_gas, + ) + + log_cad = float(jnp.log(jnp.array(12.0))) + log_gas = float(jnp.log(jnp.array(1.0))) + noise_coeffs = jnp.zeros(K_OBS).at[0].set(8.0) + + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + free_params = pack_params(log_cad, log_gas, noise_coeffs) + fixed_params = pack_params_fixed_gas(log_cad, noise_coeffs) + + loss_free = pool_loss( + free_params, synthetic_pool_coeffs, x_obs, y_obs, day_indices + ) + loss_fixed = 
pool_loss_fixed_gas( + fixed_params, jnp.array(log_gas), synthetic_pool_coeffs, + x_obs, y_obs, day_indices, + ) + np.testing.assert_allclose(float(loss_free), float(loss_fixed), rtol=1e-6) + + def test_loss_varies_with_gas_within_grid(self, synthetic_pool_coeffs, synthetic_x_obs): + """Gas values within grid range [0, 5] should produce distinct losses.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + # Stay within grid range (gas_costs=[0, 1, 5]) to avoid extrapolation plateau + losses = [] + for gas in [0.01, 0.1, 1.0, 3.0]: + lg = jnp.log(jnp.array(gas)) + loss = float(pool_loss_fixed_gas( + params, lg, synthetic_pool_coeffs, x_obs, y_obs, day_indices + )) + losses.append(loss) + # All 4 within-grid gas values should give distinct losses + assert len(set(f"{l:.8f}" for l in losses)) == 4 + + def test_day_indices_affect_loss(self, synthetic_pool_coeffs, synthetic_x_obs): + """Different day_indices must produce different loss — verifies per-day V_arb is used.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + x_obs = jnp.array(synthetic_x_obs) + y_obs = jnp.ones(n_obs) * 9.0 + fixed_log_gas = jnp.log(jnp.array(1.0)) + + day_idx_all_zero = jnp.zeros(n_obs, dtype=jnp.int32) + day_idx_varying = jnp.array(np.arange(n_obs) % n_days) + + loss_same = pool_loss_fixed_gas( + params, fixed_log_gas, synthetic_pool_coeffs, x_obs, y_obs, day_idx_all_zero + ) + loss_vary = pool_loss_fixed_gas( + params, fixed_log_gas, synthetic_pool_coeffs, x_obs, y_obs, day_idx_varying + ) + assert float(loss_same) != float(loss_vary) + + def test_grad_wrt_params(self, synthetic_pool_coeffs, synthetic_x_obs): + """Gradient w.r.t. 
params_flat has correct shape and is finite.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + fixed_log_gas = jnp.log(jnp.array(1.0)) + + grad = jax.grad(pool_loss_fixed_gas, argnums=0)( + params, fixed_log_gas, synthetic_pool_coeffs, x_obs, y_obs, day_indices, + ) + assert grad.shape == (1 + K_OBS,) + assert jnp.all(jnp.isfinite(grad)) + # Gradient should be nonzero (we're not at the optimum) + assert float(jnp.sum(jnp.abs(grad))) > 1e-10 + + def test_grad_changes_with_gas( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Gradient w.r.t. params should differ at different fixed gas values. + + fixed_log_gas affects V_arb through grid interpolation, which shifts + the loss landscape and thus the gradient. + """ + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + grad_fn = jax.grad(pool_loss_fixed_gas, argnums=0) + grad_low = grad_fn( + params, jnp.log(jnp.array(0.01)), + synthetic_pool_coeffs, x_obs, y_obs, day_indices + ) + grad_high = grad_fn( + params, jnp.log(jnp.array(10.0)), + synthetic_pool_coeffs, x_obs, y_obs, day_indices + ) + # Gradients should differ because V_arb differs + assert not jnp.allclose(grad_low, grad_high, atol=1e-6) + + def test_extreme_negative_noise_finite(self, synthetic_pool_coeffs, synthetic_x_obs): + """Loss should remain finite with very negative noise intercept.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas, pack_params_fixed_gas + + # Very negative noise intercept → V_noise ≈ 0, but V_arb still positive + nc = jnp.zeros(K_OBS).at[0].set(-100.0) + params = pack_params_fixed_gas(float(jnp.log(jnp.array(12.0))), nc) + + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = 
jnp.array(np.arange(n_obs) % n_days) + x_obs = jnp.array(synthetic_x_obs) + y_obs = jnp.ones(n_obs) * 9.0 + fixed_log_gas = jnp.log(jnp.array(1.0)) + + loss = pool_loss_fixed_gas( + params, fixed_log_gas, synthetic_pool_coeffs, x_obs, y_obs, day_indices + ) + assert jnp.isfinite(loss) + # V_noise ≈ 0, so log(V_arb) ≈ 8.5 vs y=9.0 → nonzero loss + assert float(loss) > 0.01 + # Gradient should also be finite + grad = jax.grad(pool_loss_fixed_gas, argnums=0)( + params, fixed_log_gas, synthetic_pool_coeffs, x_obs, y_obs, day_indices + ) + assert jnp.all(jnp.isfinite(grad)) + + def test_boundary_clamp_cadence(self, synthetic_pool_coeffs, synthetic_x_obs): + """Cadence below grid min should clamp — loss at cad=0.5 equals cad=1.0.""" + from quantammsim.calibration.loss import pack_params_fixed_gas, pool_loss_fixed_gas + + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + nc = jnp.zeros(K_OBS).at[0].set(8.0) + fixed_log_gas = jnp.log(jnp.array(1.0)) + + params_below = pack_params_fixed_gas(float(jnp.log(jnp.array(0.5))), nc) + params_at_min = pack_params_fixed_gas(float(jnp.log(jnp.array(1.0))), nc) + + loss_below = pool_loss_fixed_gas( + params_below, fixed_log_gas, synthetic_pool_coeffs, + x_obs, y_obs, day_indices, + ) + loss_at_min = pool_loss_fixed_gas( + params_at_min, fixed_log_gas, synthetic_pool_coeffs, + x_obs, y_obs, day_indices, + ) + np.testing.assert_allclose(float(loss_below), float(loss_at_min), rtol=1e-6) + + def test_boundary_clamp_gas(self, synthetic_pool_coeffs, synthetic_x_obs): + """Gas above grid max should clamp — loss at gas=10 equals gas=5.""" + from quantammsim.calibration.loss import pool_loss_fixed_gas + + params = self._make_params() + x_obs, y_obs, day_indices = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + + loss_above = pool_loss_fixed_gas( + params, jnp.log(jnp.array(10.0)), synthetic_pool_coeffs, + x_obs, y_obs, day_indices, + ) + loss_at_max = pool_loss_fixed_gas( + params, 
jnp.log(jnp.array(5.0)), synthetic_pool_coeffs, + x_obs, y_obs, day_indices, + ) + np.testing.assert_allclose(float(loss_above), float(loss_at_max), rtol=1e-6) + + def test_k_obs_matches_loss_module(self): + """K_OBS in conftest must match K_OBS in loss.py.""" + from quantammsim.calibration.loss import K_OBS as K_OBS_IMPL + assert K_OBS == K_OBS_IMPL diff --git a/tests/calibration/test_per_pool_fit_fixed_gas.py b/tests/calibration/test_per_pool_fit_fixed_gas.py new file mode 100644 index 0000000..df281c7 --- /dev/null +++ b/tests/calibration/test_per_pool_fit_fixed_gas.py @@ -0,0 +1,383 @@ +"""Tests for fixed-gas mode in quantammsim.calibration.per_pool_fit.""" + +import jax.numpy as jnp +import numpy as np +import pytest + +from tests.calibration.conftest import K_OBS, N_DAYS, POOL_IDS_FULL, POOL_PREFIXES + + +class TestInitialGuessFixedGas: + """Test make_initial_guess_fixed_gas.""" + + def test_shape(self, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import make_initial_guess_fixed_gas + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + init = make_initial_guess_fixed_gas(synthetic_x_obs, y_obs) + assert init.shape == (1 + K_OBS,) + + def test_one_shorter_than_free(self, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import ( + make_initial_guess, + make_initial_guess_fixed_gas, + ) + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + free = make_initial_guess(synthetic_x_obs, y_obs) + fixed = make_initial_guess_fixed_gas(synthetic_x_obs, y_obs) + assert free.shape[0] == fixed.shape[0] + 1 + + def test_log_cadence_default(self, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import make_initial_guess_fixed_gas + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + init = make_initial_guess_fixed_gas(synthetic_x_obs, y_obs) + np.testing.assert_allclose(init[0], np.log(12.0), atol=0.01) + + def test_pinned_ols_coefficients(self, synthetic_x_obs): + """OLS on constant 
y=9.0: intercept should dominate, others near zero.""" + from quantammsim.calibration.per_pool_fit import make_initial_guess_fixed_gas + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + init = make_initial_guess_fixed_gas(synthetic_x_obs, y_obs) + # Intercept should be close to 9.0 (y is constant) + np.testing.assert_allclose(init[1], 9.0, atol=0.01) + # Other noise coeffs should be near zero for constant y + assert np.all(np.abs(init[2:]) < 0.1) + + def test_noise_matches_free_gas_noise(self, synthetic_x_obs): + """OLS noise coeffs should be identical for free and fixed-gas init.""" + from quantammsim.calibration.per_pool_fit import ( + make_initial_guess, + make_initial_guess_fixed_gas, + ) + + n_obs = synthetic_x_obs.shape[0] + y_obs = np.ones(n_obs) * 9.0 + free = make_initial_guess(synthetic_x_obs, y_obs) + fixed = make_initial_guess_fixed_gas(synthetic_x_obs, y_obs) + np.testing.assert_allclose(free[2:], fixed[1:]) + + def test_ols_with_heterogeneous_y(self, synthetic_x_obs): + """OLS on heterogeneous y should produce different coeffs than constant y.""" + from quantammsim.calibration.per_pool_fit import make_initial_guess_fixed_gas + + # y correlated with TVL (column 1) + y_het = synthetic_x_obs[:, 1] * 0.5 + 5.0 + np.random.RandomState(42).randn( + synthetic_x_obs.shape[0]) * 0.01 + init_het = make_initial_guess_fixed_gas(synthetic_x_obs, y_het) + init_const = make_initial_guess_fixed_gas( + synthetic_x_obs, np.ones(synthetic_x_obs.shape[0]) * 9.0 + ) + # Noise coefficients should differ substantially + assert np.max(np.abs(init_het[1:] - init_const[1:])) > 0.1 + + +class TestFitSinglePoolFixedGas: + """Test fit_single_pool with fixed_gas_usd.""" + + def _make_inputs(self, synthetic_pool_coeffs, synthetic_x_obs): + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = np.arange(n_obs) % n_days + y_obs = np.ones(n_obs) * 9.0 + return synthetic_x_obs, y_obs, day_indices + + def 
_make_gt_inputs(self, synthetic_pool_coeffs, synthetic_x_obs): + """Make ground-truth y_obs from known params for recovery test.""" + from quantammsim.calibration.grid_interpolation import interpolate_pool_daily + from quantammsim.calibration.loss import noise_volume + + n_obs = synthetic_x_obs.shape[0] + n_days = int(synthetic_pool_coeffs.values.shape[2]) + day_indices = np.arange(n_obs) % n_days + + TRUE_LOG_CAD = float(np.log(10.0)) + TRUE_GAS = 0.5 + TRUE_NC = np.zeros(K_OBS) + TRUE_NC[0] = 7.0 + TRUE_NC[1] = 0.3 + + v_arb = np.array(interpolate_pool_daily( + synthetic_pool_coeffs, jnp.array(TRUE_LOG_CAD), jnp.array(TRUE_GAS) + ))[day_indices] + v_noise = np.array(noise_volume(jnp.array(TRUE_NC), jnp.array(synthetic_x_obs))) + y_obs = np.log(np.maximum(v_arb + v_noise, 1e-6)) + return synthetic_x_obs, y_obs, day_indices, TRUE_LOG_CAD, TRUE_GAS, TRUE_NC + + def test_returns_result_dict(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=1.0 + ) + for key in [ + "log_cadence", "log_gas", "noise_coeffs", "loss", + "converged", "gas_fixed", + ]: + assert key in result, f"Missing key: {key}" + + def test_gas_fixed_flag(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=1.0 + ) + assert result["gas_fixed"] is True + + def test_free_gas_flag(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + 
synthetic_pool_coeffs, x_obs, y_obs, day_idx + ) + assert result["gas_fixed"] is False + + def test_gas_usd_pinned(self, synthetic_pool_coeffs, synthetic_x_obs): + """gas_usd in result must exactly match the fixed value.""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + for gas_val in [0.001, 0.01, 0.5, 1.0, 5.0]: + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, + fixed_gas_usd=gas_val, + ) + assert result["gas_usd"] == gas_val + + def test_log_gas_pinned(self, synthetic_pool_coeffs, synthetic_x_obs): + """log_gas must equal log(fixed_gas_usd).""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=2.5, + ) + np.testing.assert_allclose( + result["log_gas"], np.log(2.5), rtol=1e-6, + ) + + def test_pinned_fit_on_constant_y(self, synthetic_pool_coeffs, synthetic_x_obs): + """Pinned fitted values for fixed_gas=1.0, y=9.0.""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=1.0, + ) + assert result["converged"] + np.testing.assert_allclose(result["loss"], 1.30e-5, atol=5e-5) + np.testing.assert_allclose( + result["noise_coeffs"][0], 8.989, atol=0.05, + ) + assert 5.0 <= result["cadence_minutes"] <= 15.0 + + def test_ground_truth_recovery_fixed_gas( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Fit on ground-truth y with correct gas → converged fit with near-zero loss.""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx, true_lc, true_gas, true_nc = self._make_gt_inputs( + 
synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=true_gas, + ) + assert result["converged"] + assert result["loss"] < 1e-5 + + def test_ground_truth_recovery_free_gas( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Fit on ground-truth y with free gas → near-zero loss.""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx, _, _, _ = self._make_gt_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, + ) + assert result["converged"] + assert result["loss"] < 1e-5 + + def test_loss_decreases_from_init(self, synthetic_pool_coeffs, synthetic_x_obs): + from quantammsim.calibration.loss import pool_loss_fixed_gas + from quantammsim.calibration.per_pool_fit import ( + fit_single_pool, + make_initial_guess_fixed_gas, + ) + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + init = make_initial_guess_fixed_gas(x_obs, y_obs) + fixed_log_gas = jnp.float64(np.log(1.0)) + init_loss = float(pool_loss_fixed_gas( + jnp.array(init), fixed_log_gas, synthetic_pool_coeffs, + jnp.array(x_obs), jnp.array(y_obs), jnp.array(day_idx), + )) + + result = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=1.0, + ) + assert result["loss"] < init_loss * 0.99 # at least 1% improvement + + def test_different_fixed_gas_different_cadence( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Gas values spanning 5000x should produce substantially different cadences.""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx = self._make_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + r_low = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=0.001, + ) + r_high = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=5.0, + ) + # 
5000x gas range → cadences should differ substantially + assert abs(r_low["log_cadence"] - r_high["log_cadence"]) > 0.1 + + def test_fixed_vs_free_gas_on_gt_data( + self, synthetic_pool_coeffs, synthetic_x_obs + ): + """Free gas should achieve loss ≤ fixed gas (more degrees of freedom).""" + from quantammsim.calibration.per_pool_fit import fit_single_pool + + x_obs, y_obs, day_idx, _, _, _ = self._make_gt_inputs( + synthetic_pool_coeffs, synthetic_x_obs + ) + r_free = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, + ) + r_fixed = fit_single_pool( + synthetic_pool_coeffs, x_obs, y_obs, day_idx, fixed_gas_usd=1.0, + ) + # Free gas has strictly more freedom → should do at least as well + assert r_free["loss"] <= r_fixed["loss"] * 1.01 + + +class TestFitAllPoolsFixedGas: + """Test fit_all_pools with fix_gas_to_chain=True.""" + + def _make_matched(self, synthetic_daily_grid, synthetic_panel, tmp_path): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + return match_grids_to_panel(str(grid_dir), synthetic_panel) + + def test_all_gas_fixed( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.per_pool_fit import fit_all_pools + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + results = fit_all_pools(matched, fix_gas_to_chain=True) + for prefix, res in results.items(): + assert res["gas_fixed"] is True + + def test_gas_matches_chain( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """Each pool's gas_usd should match CHAIN_GAS_USD[chain].""" + from quantammsim.calibration.loss import CHAIN_GAS_USD + from quantammsim.calibration.per_pool_fit import fit_all_pools + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + results = 
fit_all_pools(matched, fix_gas_to_chain=True) + for prefix, res in results.items(): + chain = res["chain"] + expected = CHAIN_GAS_USD.get(chain, 1.0) + assert res["gas_usd"] == expected, ( + f"{prefix} ({chain}): gas_usd={res['gas_usd']} != {expected}" + ) + + def test_free_gas_not_fixed( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.per_pool_fit import fit_all_pools + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + results = fit_all_pools(matched, fix_gas_to_chain=False) + for prefix, res in results.items(): + assert res["gas_fixed"] is False + + def test_both_pools_have_results( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.per_pool_fit import fit_all_pools + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + results = fit_all_pools(matched, fix_gas_to_chain=True) + assert len(results) == len(matched) + for prefix in matched: + assert prefix in results + assert results[prefix]["converged"] + + def test_metadata_preserved( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """Each result should carry chain, fee, tokens from the matched data.""" + from quantammsim.calibration.per_pool_fit import fit_all_pools + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + results = fit_all_pools(matched, fix_gas_to_chain=True) + for prefix, res in results.items(): + assert res["chain"] == matched[prefix]["chain"] + assert res["tokens"] == matched[prefix]["tokens"] + assert np.isfinite(res["fee"]) + + def test_mainnet_gas_1_arbitrum_gas_001( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """Pin: MAINNET pool gets gas=1.0, ARBITRUM pool gets gas=0.01.""" + from quantammsim.calibration.per_pool_fit import fit_all_pools + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + results = fit_all_pools(matched, 
fix_gas_to_chain=True) + for prefix, res in results.items(): + if res["chain"] == "MAINNET": + assert res["gas_usd"] == 1.0 + elif res["chain"] == "ARBITRUM": + assert res["gas_usd"] == 0.01 diff --git a/tests/calibration/test_pool_data.py b/tests/calibration/test_pool_data.py index b3a9cf7..4864c19 100644 --- a/tests/calibration/test_pool_data.py +++ b/tests/calibration/test_pool_data.py @@ -12,6 +12,151 @@ ) +class TestEncodeTokens: + """Test encode_tokens: token index, assignments, and covariate matrix.""" + + def _get_matched(self, synthetic_daily_grid, synthetic_panel, tmp_path): + from quantammsim.calibration.pool_data import match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + return match_grids_to_panel(str(grid_dir), synthetic_panel) + + def test_returns_expected_keys( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + expected_keys = { + "token_index", "token_a_idx", "token_b_idx", + "x_token", "chain_idx", "chain_index", + "log_fees", "n_tokens", "n_chains", + } + assert expected_keys.issubset(result.keys()) + + def test_unique_tokens_discovered( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + # Synthetic panel has tokens: BTC, ETH (pool 0) and AAVE, ETH (pool 1) + # Unique tokens: AAVE, BTC, ETH (sorted) + assert result["n_tokens"] == 3 + assert set(result["token_index"].keys()) == {"AAVE", "BTC", "ETH"} + # Indices should be contiguous 0..2 + assert set(result["token_index"].values()) == {0, 1, 2} + + def 
test_x_token_shape_and_intercept( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + x_token = result["x_token"] + assert x_token.shape[0] == result["n_tokens"] # 3 tokens + assert x_token.shape[1] >= 4 # at least intercept + 3 binary flags + # Intercept column is all 1s + np.testing.assert_array_equal(x_token[:, 0], 1.0) + + def test_token_classifications( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + ti = result["token_index"] + x_tok = result["x_token"] + # Column layout: [intercept, log_mcap, is_stable, is_eth_deriv, is_L1_native] + # ETH: is_eth_derivative=1, is_L1_native=1 + assert x_tok[ti["ETH"], 3] == 1.0 # is_eth_derivative + assert x_tok[ti["ETH"], 4] == 1.0 # is_L1_native + # AAVE: none of the binary flags + assert x_tok[ti["AAVE"], 2] == 0.0 # not stable + assert x_tok[ti["AAVE"], 3] == 0.0 # not eth_deriv + assert x_tok[ti["AAVE"], 4] == 0.0 # not L1_native + + def test_pool_token_mapping( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + ti = result["token_index"] + + # Pool 0 (first sorted prefix): tokens = "BTC,ETH" + assert result["token_a_idx"][0] == ti["BTC"] + assert result["token_b_idx"][0] == ti["ETH"] + + # Pool 1 (second sorted prefix): tokens = "AAVE,ETH" + assert result["token_a_idx"][1] == ti["AAVE"] + assert result["token_b_idx"][1] == ti["ETH"] + + def test_chain_index( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from 
quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + assert result["n_chains"] == 2 + assert set(result["chain_index"].keys()) == {"ARBITRUM", "MAINNET"} + assert set(result["chain_index"].values()) == {0, 1} + + def test_chain_idx_mapping( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + pool_ids = sorted(matched.keys()) + ci = result["chain_index"] + # Pool 0 is MAINNET, pool 1 is ARBITRUM + assert result["chain_idx"][0] == ci[matched[pool_ids[0]]["chain"]] + assert result["chain_idx"][1] == ci[matched[pool_ids[1]]["chain"]] + + def test_log_fees( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import encode_tokens + + matched = self._get_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + result = encode_tokens(matched) + pool_ids = sorted(matched.keys()) + for i, pid in enumerate(pool_ids): + expected_fee = matched[pid]["fee"] + np.testing.assert_allclose( + result["log_fees"][i], np.log(expected_fee), rtol=1e-6 + ) + + class TestMatchGridsToPanel: """Test match_grids_to_panel: match grid parquets to panel rows.""" @@ -200,6 +345,62 @@ def test_x_obs_no_nans(self, synthetic_panel): assert not np.any(np.isnan(x)) +class TestBuildXObsReduced: + """Test build_x_obs with reduced=True: 4-column pruned covariates.""" + + def test_reduced_shape(self, synthetic_panel): + from quantammsim.calibration.pool_data import K_OBS_REDUCED, build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0, reduced=True) + assert x.shape == (len(pool0), K_OBS_REDUCED) + assert K_OBS_REDUCED == 4 + + def test_reduced_columns(self, synthetic_panel): + 
from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x_full = build_x_obs(pool0) + x_red = build_x_obs(pool0, reduced=True) + + # col 0: intercept + np.testing.assert_array_equal(x_red[:, 0], 1.0) + # col 1: log_tvl_lag1 (same as full col 1) + np.testing.assert_allclose(x_red[:, 1], x_full[:, 1]) + # col 2: dow_sin (same as full col 6) + np.testing.assert_allclose(x_red[:, 2], x_full[:, 6]) + # col 3: dow_cos (same as full col 7) + np.testing.assert_allclose(x_red[:, 3], x_full[:, 7]) + + def test_reduced_no_sigma(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x_full = build_x_obs(pool0) + x_red = build_x_obs(pool0, reduced=True) + + # Sigma-dependent columns from full (2,3,5) should not appear + sigma_cols = x_full[:, [2, 3, 5]] + for col in range(x_red.shape[1]): + for scol in range(sigma_cols.shape[1]): + if not np.allclose(sigma_cols[:, scol], 0.0): + assert not np.allclose(x_red[:, col], sigma_cols[:, scol]) + + def test_default_unchanged(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0) + assert x.shape == (len(pool0), K_OBS) + + def test_reduced_no_nans(self, synthetic_panel): + from quantammsim.calibration.pool_data import build_x_obs + + pool0 = synthetic_panel[synthetic_panel["pool_id"] == POOL_IDS_FULL[0]] + x = build_x_obs(pool0, reduced=True) + assert not np.any(np.isnan(x)) + + class TestBuildPoolAttributes: """Test build_pool_attributes: pool-level feature matrix.""" @@ -301,3 +502,306 @@ def test_attributes_returns_pool_order( assert isinstance(pool_ids, list) assert len(pool_ids) == len(matched) assert set(pool_ids) == set(matched.keys()) + + +class TestTokenCanonicalization: + """Test _CANON_MAP and canonicalization in 
encode_tokens.""" + + def test_canon_map_exists(self): + from quantammsim.calibration.pool_data import _CANON_MAP + assert isinstance(_CANON_MAP, dict) + + def test_canon_map_expected_mappings(self): + from quantammsim.calibration.pool_data import _CANON_MAP + assert _CANON_MAP["WETH"] == "ETH" + assert _CANON_MAP["waBasWETH"] == "ETH" + assert _CANON_MAP["waEthLidoWETH"] == "ETH" + assert _CANON_MAP["waEthLidowstETH"] == "wstETH" + assert _CANON_MAP["waGnowstETH"] == "wstETH" + assert _CANON_MAP["waBasUSDC"] == "USDC" + assert _CANON_MAP["scUSD"] == "USDC" + assert _CANON_MAP["sDAI"] == "DAI" + assert _CANON_MAP["WBTC"] == "BTC" + assert _CANON_MAP["waGnoGNO"] == "GNO" + assert _CANON_MAP["stS"] == "S" + + def test_canonicalize_passthrough(self): + from quantammsim.calibration.pool_data import _CANON_MAP + for tok in ["AAVE", "BTC", "ETH", "LINK", "ARB"]: + assert tok not in _CANON_MAP + + def test_canonicalize_function(self): + from quantammsim.calibration.pool_data import _canonicalize_token + assert _canonicalize_token("WETH") == "ETH" + assert _canonicalize_token("WBTC") == "BTC" + assert _canonicalize_token("ETH") == "ETH" + assert _canonicalize_token("AAVE") == "AAVE" + + def test_encode_tokens_canonicalize_false( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """With canonicalize=False, same result as v1 (no merging).""" + from quantammsim.calibration.pool_data import encode_tokens, match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + result = encode_tokens(matched, canonicalize=False) + # Synthetic uses BTC, ETH, AAVE — none in canon map, same either way + assert result["n_tokens"] == 3 + assert set(result["token_index"].keys()) == {"AAVE", "BTC", "ETH"} + + def test_encode_tokens_canonicalize_default( + self, synthetic_daily_grid, 
synthetic_panel, tmp_path + ): + """Default canonicalize=True. Synthetic data unaffected (BTC, ETH, AAVE not in map).""" + from quantammsim.calibration.pool_data import encode_tokens, match_grids_to_panel + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + matched = match_grids_to_panel(str(grid_dir), synthetic_panel) + result = encode_tokens(matched) # canonicalize=True by default + assert result["n_tokens"] == 3 + assert set(result["token_index"].keys()) == {"AAVE", "BTC", "ETH"} + + def test_encode_tokens_merges_wrapped_tokens(self): + """Synthetic matched dict with WETH+USDC and waBasWETH+WBTC → ETH, BTC, USDC.""" + from quantammsim.calibration.pool_data import encode_tokens + + # Minimal matched dict: encode_tokens reads only the 'tokens', 'fee', and 'chain' keys + matched = { + "pool_a": { + "tokens": "WETH,USDC", + "fee": 0.003, + "chain": "MAINNET", + }, + "pool_b": { + "tokens": "waBasWETH,WBTC", + "fee": 0.003, + "chain": "BASE", + }, + } + result = encode_tokens(matched, canonicalize=True) + # WETH→ETH, waBasWETH→ETH, WBTC→BTC → unique: {BTC, ETH, USDC} + assert result["n_tokens"] == 3 + assert set(result["token_index"].keys()) == {"BTC", "ETH", "USDC"} + + # Both pools should have ETH as token A (canonicalized) + ti = result["token_index"] + assert result["token_a_idx"][0] == ti["ETH"] # pool_a: WETH→ETH + assert result["token_a_idx"][1] == ti["ETH"] # pool_b: waBasWETH→ETH + + def test_encode_tokens_canon_false_keeps_wrapped(self): + """canonicalize=False keeps WETH and waBasWETH as separate tokens.""" + from quantammsim.calibration.pool_data import encode_tokens + + matched = { + "pool_a": { + "tokens": "WETH,USDC", + "fee": 0.003, + "chain": "MAINNET", + }, + "pool_b": { + "tokens": "waBasWETH,WBTC", + "fee": 0.003, + "chain": "BASE", + }, + } + result = encode_tokens(matched, canonicalize=False) + # No merging: WETH, USDC, waBasWETH, WBTC →
4 unique tokens + assert result["n_tokens"] == 4 + assert "WETH" in result["token_index"] + assert "waBasWETH" in result["token_index"] + + +class TestCrossPoolFeatures: + """Test build_cross_pool_x_obs: cross-pool lagged volume features.""" + + @pytest.fixture + def three_pool_panel(self): + """3 pools sharing tokens, 10 days each.""" + np.random.seed(42) + dates = pd.date_range("2025-12-01", periods=10, freq="D") + rows = [] + pool_configs = [ + ("0xpool_a_full_id_padding_to_66_chars_aaaaaaaaaaaaaaaaaaaaaaaaa", + "MAINNET", "ETH,USDC", 0.003), + ("0xpool_b_full_id_padding_to_66_chars_bbbbbbbbbbbbbbbbbbbbbbbbb", + "MAINNET", "ETH,AAVE", 0.003), + ("0xpool_c_full_id_padding_to_66_chars_ccccccccccccccccccccccccc", + "ARBITRUM", "AAVE,USDC", 0.01), + ] + for full_id, chain, tokens, fee in pool_configs: + for di, date in enumerate(dates): + tvl = 12.0 + 0.05 * np.sin(2 * np.pi * di / 7) + vol = 9.0 + 0.3 * np.random.randn() + rows.append({ + "pool_id": full_id, + "chain": chain, + "date": date, + "log_volume": vol, + "log_tvl": tvl, + "log_tvl_lag1": tvl - 0.01, + "volatility": 0.4, + "log_fee": np.log(fee), + "swap_fee": fee, + "tokens": tokens, + }) + return pd.DataFrame(rows) + + @pytest.fixture + def three_pool_matched(self, three_pool_panel): + """Minimal matched dict for 3 pools (no grid needed for x_obs tests).""" + matched = {} + for full_id in three_pool_panel["pool_id"].unique(): + prefix = full_id[:16] + rows = three_pool_panel[three_pool_panel["pool_id"] == full_id].copy() + rows = rows.reset_index(drop=True) + matched[prefix] = { + "panel": rows, + "pool_id": full_id, + "chain": rows.iloc[0]["chain"], + "fee": float(np.exp(rows.iloc[0]["log_fee"])), + "tokens": rows.iloc[0]["tokens"], + "weights": [0.5, 0.5], + } + return matched + + def test_build_cross_pool_x_obs_shape(self, three_pool_matched): + from quantammsim.calibration.pool_data import ( + K_OBS_CROSS, build_cross_pool_x_obs, + ) + + pid = sorted(three_pool_matched.keys())[0] + entry = 
three_pool_matched[pid] + x = build_cross_pool_x_obs(entry["panel"], three_pool_matched, pid) + # Drops first day → n_obs - 1 rows, K_OBS_CROSS=7 columns + assert x.shape[1] == K_OBS_CROSS + assert x.shape[0] == len(entry["panel"]) - 1 + + def test_first_four_cols_match_reduced(self, three_pool_matched): + from quantammsim.calibration.pool_data import ( + build_cross_pool_x_obs, build_x_obs, + ) + + pid = sorted(three_pool_matched.keys())[0] + entry = three_pool_matched[pid] + x_cross = build_cross_pool_x_obs(entry["panel"], three_pool_matched, pid) + x_reduced = build_x_obs(entry["panel"], reduced=True) + # First 4 columns should match (after dropping first row) + np.testing.assert_allclose(x_cross[:, :4], x_reduced[1:, :4]) + + def test_cross_vol_token_a_excludes_self(self, three_pool_matched): + """Peer average for token A excludes pool i itself.""" + from quantammsim.calibration.pool_data import build_cross_pool_x_obs + + pool_ids = sorted(three_pool_matched.keys()) + pid_a = pool_ids[0] # ETH,USDC + pid_b = pool_ids[1] # ETH,AAVE — shares ETH with pool_a + + x_a = build_cross_pool_x_obs( + three_pool_matched[pid_a]["panel"], + three_pool_matched, pid_a, + ) + # Column 4 = cross_vol_token_a (ETH peers excl self) + # Pool b also has ETH, so pool_a's cross_vol_token_a should use pool_b's volume + panel_b = three_pool_matched[pid_b]["panel"] + log_vol_b_lagged = panel_b["log_volume"].values[:-1] # lag by 1 + np.testing.assert_allclose(x_a[:, 4], log_vol_b_lagged, rtol=1e-6) + + def test_cross_vol_is_lagged(self, three_pool_matched): + """Features at day t use log_volume at day t-1.""" + from quantammsim.calibration.pool_data import build_cross_pool_x_obs + + pool_ids = sorted(three_pool_matched.keys()) + pid = pool_ids[0] + x = build_cross_pool_x_obs( + three_pool_matched[pid]["panel"], + three_pool_matched, pid, + ) + # x has n_obs - 1 rows (first day dropped) + # Row 0 of x corresponds to day 1 and should use day 0 volume + assert x.shape[0] > 0 + + def 
test_cross_vol_nan_free_after_first_day(self, three_pool_matched): + from quantammsim.calibration.pool_data import build_cross_pool_x_obs + + for pid in three_pool_matched: + x = build_cross_pool_x_obs( + three_pool_matched[pid]["panel"], + three_pool_matched, pid, + ) + assert not np.any(np.isnan(x)), f"NaNs in cross-pool x_obs for {pid}" + + def test_exclude_pool_changes_features(self, three_pool_matched): + """exclude_pool removes that pool from peer averages.""" + from quantammsim.calibration.pool_data import build_cross_pool_x_obs + + pool_ids = sorted(three_pool_matched.keys()) + pid = pool_ids[0] # ETH,USDC + + x_normal = build_cross_pool_x_obs( + three_pool_matched[pid]["panel"], + three_pool_matched, pid, + ) + x_excluded = build_cross_pool_x_obs( + three_pool_matched[pid]["panel"], + three_pool_matched, pid, + exclude_pool=pool_ids[1], # exclude the ETH peer + ) + # Chain feature (col 6) may change too; token A col (4) definitely changes + # since pool_b is the only ETH peer + # With only peer excluded, cross_vol_token_a should be NaN→fallback + assert not np.allclose(x_normal[:, 4], x_excluded[:, 4]) + + def test_single_token_pool_fallback(self): + """When a token appears in only one pool, its cross_vol uses global mean.""" + from quantammsim.calibration.pool_data import build_cross_pool_x_obs + + np.random.seed(42) + dates = pd.date_range("2025-12-01", periods=5, freq="D") + rows = [] + # Pool A: LINK,USDC — LINK is unique + for di, date in enumerate(dates): + rows.append({ + "pool_id": "0xsolo_link_pool_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "chain": "MAINNET", "date": date, + "log_volume": 9.0 + 0.1 * di, + "log_tvl_lag1": 12.0, "volatility": 0.4, + "log_fee": np.log(0.003), "tokens": "LINK,USDC", + }) + # Pool B: ETH,USDC + for di, date in enumerate(dates): + rows.append({ + "pool_id": "0xpeer_eth_pool__bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + "chain": "MAINNET", "date": date, + "log_volume": 10.0 + 0.1 * di, + "log_tvl_lag1": 13.0, 
"volatility": 0.4, + "log_fee": np.log(0.003), "tokens": "ETH,USDC", + }) + panel = pd.DataFrame(rows) + + matched = {} + for full_id in panel["pool_id"].unique(): + prefix = full_id[:16] + sub = panel[panel["pool_id"] == full_id].reset_index(drop=True) + matched[prefix] = { + "panel": sub, "pool_id": full_id, + "chain": sub.iloc[0]["chain"], + "fee": float(np.exp(sub.iloc[0]["log_fee"])), + "tokens": sub.iloc[0]["tokens"], "weights": [0.5, 0.5], + } + + pid_link = [p for p in matched if matched[p]["tokens"] == "LINK,USDC"][0] + x = build_cross_pool_x_obs(panel[panel["pool_id"] == matched[pid_link]["pool_id"]].reset_index(drop=True), + matched, pid_link) + # LINK has no peers → col 4 should be a fallback (global mean), not NaN + assert not np.any(np.isnan(x[:, 4])) diff --git a/tests/calibration/test_pool_data_volatility.py b/tests/calibration/test_pool_data_volatility.py new file mode 100644 index 0000000..d6e2dd1 --- /dev/null +++ b/tests/calibration/test_pool_data_volatility.py @@ -0,0 +1,453 @@ +"""Tests for Binance volatility and TOKEN_MAP in quantammsim.calibration.pool_data.""" + +import numpy as np +import pandas as pd +import pytest + +from tests.calibration.conftest import POOL_IDS_FULL + + +class TestTokenMap: + """Test TOKEN_MAP resolves Balancer tokens to Binance symbols correctly.""" + + def test_wrapped_native(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + assert _resolve_binance_symbol("WBTC") == "BTC" + assert _resolve_binance_symbol("WETH") == "ETH" + assert _resolve_binance_symbol("cbBTC") == "BTC" + + def test_lst_to_underlying(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + assert _resolve_binance_symbol("wstETH") == "ETH" + assert _resolve_binance_symbol("stETH") == "ETH" + assert _resolve_binance_symbol("rETH") == "ETH" + assert _resolve_binance_symbol("cbETH") == "ETH" + + def test_vault_tokens(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol 
+ + assert _resolve_binance_symbol("waEthLidoWETH") == "ETH" + assert _resolve_binance_symbol("waEthLidowstETH") == "ETH" + assert _resolve_binance_symbol("waBasWETH") == "ETH" + assert _resolve_binance_symbol("waGnowstETH") == "ETH" + assert _resolve_binance_symbol("waGnoGNO") == "GNO" + + def test_stablecoins_map_to_usdc(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + for stable in ["DAI", "WXDAI", "sDAI", "USDT", "DOLA", "scUSD", + "USDC.e", "USDbC", "waBasUSDC"]: + assert _resolve_binance_symbol(stable) == "USDC", ( + f"{stable} should map to USDC" + ) + + def test_matic_variants(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + assert _resolve_binance_symbol("wPOL") == "POL" + assert _resolve_binance_symbol("WMATIC") == "POL" + assert _resolve_binance_symbol("MATIC") == "POL" + + def test_sonic_variants(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + assert _resolve_binance_symbol("wS") == "S" + assert _resolve_binance_symbol("stS") == "S" + + def test_passthrough_unknown(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + assert _resolve_binance_symbol("AAVE") == "AAVE" + assert _resolve_binance_symbol("LINK") == "LINK" + assert _resolve_binance_symbol("SNX") == "SNX" + + def test_jitosol(self): + from quantammsim.calibration.pool_data import _resolve_binance_symbol + + assert _resolve_binance_symbol("JitoSOL") == "SOL" + + +class TestGetAssetType: + """Test _get_asset_type classification.""" + + def test_stablecoins(self): + from quantammsim.calibration.pool_data import _get_asset_type + + for tok in ["USDC", "USDT", "DAI", "WXDAI", "sDAI", "DOLA", "scUSD"]: + assert _get_asset_type(tok, {}) == 0, f"{tok} should be stable (0)" + + def test_native_lst(self): + from quantammsim.calibration.pool_data import _get_asset_type + + for tok in ["WETH", "ETH", "wstETH", "WBTC", "BTC", "GNO", "S", "wS"]: + assert _get_asset_type(tok, {}) 
== 1, f"{tok} should be native/LST (1)" + + def test_volatile(self): + from quantammsim.calibration.pool_data import _get_asset_type + + for tok in ["AAVE", "LINK", "SNX", "CRV", "COMP"]: + assert _get_asset_type(tok, {}) == 2, f"{tok} should be volatile (2)" + + def test_mcap_override(self): + from quantammsim.calibration.pool_data import _get_asset_type + + mcaps = {"AAVE": {"asset_type": "stable", "mcap_usd": 1e9}} + assert _get_asset_type("AAVE", mcaps) == 0 # overridden to stable + + +class TestComputeBinancePairVolatility: + """Test compute_binance_pair_volatility with synthetic Binance-like data.""" + + @pytest.fixture + def fake_binance_dir(self, tmp_path): + """Create fake Binance minute parquets for ETH and AAVE.""" + np.random.seed(42) + n_minutes = 24 * 60 * 7 # 7 days of minute data + base_ts = int(pd.Timestamp("2025-01-01").timestamp() * 1000) + unix = base_ts + np.arange(n_minutes) * 60_000 + + # ETH: geometric brownian motion starting at 3000 + eth_log_returns = np.random.normal(0, 0.0005, n_minutes) + eth_prices = 3000.0 * np.exp(np.cumsum(eth_log_returns)) + eth_df = pd.DataFrame({"unix": unix, "close": eth_prices}) + eth_df.to_parquet(tmp_path / "ETH_USD.parquet", index=False) + + # AAVE: correlated with ETH but with higher vol + aave_log_returns = 0.6 * eth_log_returns + 0.4 * np.random.normal( + 0, 0.001, n_minutes + ) + aave_prices = 200.0 * np.exp(np.cumsum(aave_log_returns)) + aave_df = pd.DataFrame({"unix": unix, "close": aave_prices}) + aave_df.to_parquet(tmp_path / "AAVE_USD.parquet", index=False) + + # USDC: constant at $1 (stablecoin proxy) + usdc_df = pd.DataFrame({ + "unix": unix, "close": np.ones(n_minutes), + }) + usdc_df.to_parquet(tmp_path / "USDC_USD.parquet", index=False) + + return str(tmp_path) + + def test_pinned_volatility_values(self, fake_binance_dir): + """Exact pinned values with seed(42).""" + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol = compute_binance_pair_volatility("WETH", 
"AAVE", fake_binance_dir) + expected = np.array([ + 0.27103676, 0.23068148, 0.43763073, 0.35174542, + 0.26827274, 0.35256874, 0.27833725, + ]) + np.testing.assert_allclose(vol.values, expected, rtol=1e-4) + + def test_exactly_seven_days(self, fake_binance_dir): + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol = compute_binance_pair_volatility("WETH", "AAVE", fake_binance_dir) + assert len(vol) == 7 + + def test_values_positive(self, fake_binance_dir): + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol = compute_binance_pair_volatility("WETH", "AAVE", fake_binance_dir) + assert (vol > 0).all() + + def test_pinned_median(self, fake_binance_dir): + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol = compute_binance_pair_volatility("WETH", "AAVE", fake_binance_dir) + np.testing.assert_allclose(vol.median(), 0.2783, atol=0.001) + + def test_token_order_invariance(self, fake_binance_dir): + """vol(A,B) should equal vol(B,A) — log returns of reciprocal have same std.""" + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol_ab = compute_binance_pair_volatility("WETH", "AAVE", fake_binance_dir) + vol_ba = compute_binance_pair_volatility("AAVE", "WETH", fake_binance_dir) + np.testing.assert_allclose(vol_ab.values, vol_ba.values, rtol=1e-5) + + def test_stable_vs_volatile_uses_single_asset(self, fake_binance_dir): + """ETH/USDC should use just ETH price — verify against hand-computed ETH vol.""" + import os + + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol_pair = compute_binance_pair_volatility("WETH", "USDC", fake_binance_dir) + + # Hand-compute ETH-only vol for ground-truth comparison + eth = pd.read_parquet(os.path.join(fake_binance_dir, "ETH_USD.parquet")) + eth_ts = pd.DataFrame( + {"ratio": eth["close"].values}, + index=pd.to_datetime(eth["unix"].values, unit="ms", utc=True), + ) + 
hourly = eth_ts.resample("1h").last().dropna() + hourly["log_return"] = np.log(hourly["ratio"] / hourly["ratio"].shift(1)) + hourly = hourly.dropna() + hourly["date"] = hourly.index.date + daily_std = hourly.groupby("date")["log_return"].std() + expected = (daily_std * np.sqrt(24 * 365)).dropna() + expected = expected[expected > 0] + + np.testing.assert_allclose(vol_pair.values, expected.values, rtol=1e-5) + + def test_stable_stable_returns_none(self, fake_binance_dir): + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol = compute_binance_pair_volatility("USDC", "DAI", fake_binance_dir) + assert vol is None + + def test_same_underlying_returns_none(self, fake_binance_dir): + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + # WETH and wstETH both map to ETH + vol = compute_binance_pair_volatility("WETH", "wstETH", fake_binance_dir) + assert vol is None + + def test_missing_data_returns_none(self, fake_binance_dir): + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol = compute_binance_pair_volatility("WETH", "MAGIC", fake_binance_dir) + assert vol is None + + def test_daily_index_type(self, fake_binance_dir): + """Index should be datetime.date objects.""" + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + import datetime + + vol = compute_binance_pair_volatility("WETH", "AAVE", fake_binance_dir) + for d in vol.index: + assert isinstance(d, datetime.date) + + def test_stable_a_volatile_b_uses_reciprocal(self, fake_binance_dir): + """DAI/WETH should use 1/ETH, giving same vol as WETH/USDC (DAI resolves to USDC).""" + from quantammsim.calibration.pool_data import compute_binance_pair_volatility + + vol_forward = compute_binance_pair_volatility("WETH", "USDC", fake_binance_dir) + vol_reverse = compute_binance_pair_volatility("DAI", "WETH", fake_binance_dir) + # Both should be ETH vol (log returns of X and 1/X have same std) + np.testing.assert_allclose( +
vol_forward.values, vol_reverse.values, rtol=1e-5, + ) + + +class TestReplacePanelVolatility: + """Test replace_panel_volatility_with_binance.""" + + @pytest.fixture + def fake_binance_dir(self, tmp_path): + """Minimal fake Binance data for 3 days.""" + np.random.seed(42) + n_minutes = 24 * 60 * 3 + base_ts = int(pd.Timestamp("2025-12-01").timestamp() * 1000) + unix = base_ts + np.arange(n_minutes) * 60_000 + + eth_prices = 3000.0 + np.cumsum(np.random.normal(0, 1.0, n_minutes)) + eth_df = pd.DataFrame({"unix": unix, "close": eth_prices}) + eth_df.to_parquet(tmp_path / "ETH_USD.parquet", index=False) + + btc_prices = 60000.0 + np.cumsum(np.random.normal(0, 5.0, n_minutes)) + btc_df = pd.DataFrame({"unix": unix, "close": btc_prices}) + btc_df.to_parquet(tmp_path / "BTC_USD.parquet", index=False) + + aave_prices = 200.0 + np.cumsum(np.random.normal(0, 0.5, n_minutes)) + aave_df = pd.DataFrame({"unix": unix, "close": aave_prices}) + aave_df.to_parquet(tmp_path / "AAVE_USD.parquet", index=False) + + return str(tmp_path) + + def test_does_not_modify_input(self, synthetic_panel, fake_binance_dir): + from quantammsim.calibration.pool_data import replace_panel_volatility_with_binance + + original_vol = synthetic_panel["volatility"].copy() + replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + pd.testing.assert_series_equal(synthetic_panel["volatility"], original_vol) + + def test_no_nans_introduced(self, synthetic_panel, fake_binance_dir): + """Pools without Binance data should keep original volatility, not NaN.""" + from quantammsim.calibration.pool_data import replace_panel_volatility_with_binance + + result = replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + assert result["volatility"].notna().all() + + def test_volatility_actually_changes(self, synthetic_panel, fake_binance_dir): + """At least some volatility values should differ after replacement.""" + from quantammsim.calibration.pool_data import 
replace_panel_volatility_with_binance + + original_vol = synthetic_panel["volatility"].values.copy() + result = replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + # At least one pool has BTC,ETH or AAVE,ETH — both have Binance data, + # and dates overlap (panel starts 2025-12-01, fake data starts 2025-12-01). + n_changed = (result["volatility"].values != original_vol).sum() + assert n_changed > 0, "No volatility values were replaced" + + def test_replaced_values_are_positive(self, synthetic_panel, fake_binance_dir): + """Replaced volatility values must be positive.""" + from quantammsim.calibration.pool_data import replace_panel_volatility_with_binance + + result = replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + assert (result["volatility"] > 0).all() + + def test_all_columns_preserved(self, synthetic_panel, fake_binance_dir): + """Output should have all original columns.""" + from quantammsim.calibration.pool_data import replace_panel_volatility_with_binance + + result = replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + for col in synthetic_panel.columns: + assert col in result.columns + + def test_row_count_preserved(self, synthetic_panel, fake_binance_dir): + """Output should have the same number of rows.""" + from quantammsim.calibration.pool_data import replace_panel_volatility_with_binance + + result = replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + assert len(result) == len(synthetic_panel) + + def test_replaced_values_match_binance_computation( + self, synthetic_panel, fake_binance_dir + ): + """Replaced vol values must equal compute_binance_pair_volatility exactly.""" + from quantammsim.calibration.pool_data import ( + compute_binance_pair_volatility, + replace_panel_volatility_with_binance, + ) + + result = replace_panel_volatility_with_binance( + synthetic_panel, fake_binance_dir + ) + + # BTC/ETH pool — compute expected vol 
independently + vol_btc_eth = compute_binance_pair_volatility("BTC", "ETH", fake_binance_dir) + assert vol_btc_eth is not None, "BTC/ETH vol should be computable" + vol_dict = vol_btc_eth.to_dict() + + pool0 = result[result["tokens"] == "BTC,ETH"].copy() + pool0_dates = pd.to_datetime(pool0["date"]).dt.date + matched_mask = pool0_dates.isin(vol_dict.keys()).values + matched = pool0[matched_mask] + assert len(matched) > 0, "No date overlap between panel and Binance data" + + for _, row in matched.iterrows(): + d = pd.to_datetime(row["date"]).date() + np.testing.assert_allclose( + row["volatility"], vol_dict[d], rtol=1e-6, + err_msg=f"BTC/ETH vol mismatch on {d}", + ) + + +class TestBuildPoolAttributeValues: + """Test build_pool_attributes returns correct numerical values, not just names.""" + + def _make_matched(self, synthetic_daily_grid, synthetic_panel, tmp_path): + from quantammsim.calibration.pool_data import match_grids_to_panel + from tests.calibration.conftest import POOL_PREFIXES + + grid_dir = tmp_path / "grids" + grid_dir.mkdir() + for prefix in POOL_PREFIXES: + synthetic_daily_grid.to_parquet( + grid_dir / f"{prefix}_daily.parquet", index=False + ) + return match_grids_to_panel(str(grid_dir), synthetic_panel) + + def test_chain_dummy_values( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """MAINNET pool has chain_MAINNET=1, ARBITRUM pool has chain_MAINNET=0.""" + from quantammsim.calibration.pool_data import build_pool_attributes + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + + # ARBITRUM is reference (alphabetically first), MAINNET gets a dummy + chain_idx = attr_names.index("chain_MAINNET") + for i, pid in enumerate(pool_ids): + if matched[pid]["chain"] == "MAINNET": + assert X_attr[i, chain_idx] == 1.0 + else: + assert X_attr[i, chain_idx] == 0.0 + + def test_log_fee_values( + self, synthetic_daily_grid, synthetic_panel, tmp_path 
+ ): + """log_fee should match panel values.""" + from quantammsim.calibration.pool_data import build_pool_attributes + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + + fee_idx = attr_names.index("log_fee") + for i, pid in enumerate(pool_ids): + expected = np.log(matched[pid]["fee"]) + np.testing.assert_allclose(X_attr[i, fee_idx], expected, rtol=1e-3) + + def test_same_asset_type_for_btc_eth( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """BTC,ETH pool: both native/LST → same_asset_type=1.""" + from quantammsim.calibration.pool_data import build_pool_attributes + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + + sat_idx = attr_names.index("same_asset_type") + for i, pid in enumerate(pool_ids): + if matched[pid]["tokens"] == "BTC,ETH": + assert X_attr[i, sat_idx] == 1.0 + elif matched[pid]["tokens"] == "AAVE,ETH": + # AAVE=volatile(2), ETH=native(1) → different + assert X_attr[i, sat_idx] == 0.0 + + def test_pinned_attribute_values( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + """Pinned X_attr values for the two synthetic pools.""" + from quantammsim.calibration.pool_data import build_pool_attributes + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + X_attr, attr_names, pool_ids = build_pool_attributes(matched) + + # Pool 0 (0xaaaa = MAINNET, BTC/ETH, fee=0.003) + p0_idx = pool_ids.index("0xaaaa11112222aa") + np.testing.assert_allclose(X_attr[p0_idx, 0], 1.0) # chain_MAINNET + np.testing.assert_allclose( + X_attr[p0_idx, attr_names.index("log_fee")], np.log(0.003), rtol=1e-3 + ) + + # Pool 1 (0xbbbb = ARBITRUM, AAVE/ETH, fee=0.01) + p1_idx = pool_ids.index("0xbbbb33334444bb") + np.testing.assert_allclose(X_attr[p1_idx, 0], 0.0) # chain_MAINNET=0 + np.testing.assert_allclose( + 
X_attr[p1_idx, attr_names.index("log_fee")], np.log(0.01), rtol=1e-3 + ) + + def test_no_nans( + self, synthetic_daily_grid, synthetic_panel, tmp_path + ): + from quantammsim.calibration.pool_data import build_pool_attributes + + matched = self._make_matched( + synthetic_daily_grid, synthetic_panel, tmp_path + ) + X_attr, _, _ = build_pool_attributes(matched) + assert not np.any(np.isnan(X_attr)) diff --git a/tests/calibration/test_regression_pins.py b/tests/calibration/test_regression_pins.py new file mode 100644 index 0000000..877c23f --- /dev/null +++ b/tests/calibration/test_regression_pins.py @@ -0,0 +1,449 @@ +"""Pinned numerical regression tests for the calibration pipeline. + +These tests pin exact numerical values computed from the synthetic fixtures. +They protect against silent computation errors during refactoring — a test +that checks only shapes/signs would still pass if e.g. an index is off by +one in unpack, or a sign is flipped in regularization. + +All pinned values were computed with: + - Python 3.9, JAX 0.4.30, numpy seed 42 + - Synthetic fixtures from conftest.py (N_DAYS=15, 2 pools) +""" + +import os +import tempfile + +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import pytest + +from tests.calibration.conftest import ( + CADENCES, + GAS_COSTS, + K_OBS, + N_DAYS, + POOL_IDS_FULL, + POOL_PREFIXES, +) + +from quantammsim.calibration.grid_interpolation import ( + interpolate_pool_daily, + precompute_pool_coeffs_daily, +) +from quantammsim.calibration.loss import noise_volume, pack_params, pool_loss +from quantammsim.calibration.per_pool_fit import fit_all_pools, fit_single_pool +from quantammsim.calibration.pool_data import build_x_obs, match_grids_to_panel + + +# ── Helpers ──────────────────────────────────────────────────────────────── + + +@pytest.fixture +def matched_data(synthetic_daily_grid, synthetic_panel): + """Build matched data dict by writing temp parquets for both pools.""" + tmpdir = 
tempfile.mkdtemp() + for prefix in POOL_PREFIXES: + path = os.path.join(tmpdir, f"{prefix}_daily.parquet") + synthetic_daily_grid.to_parquet(path) + matched = match_grids_to_panel(tmpdir, synthetic_panel) + yield matched + import shutil + shutil.rmtree(tmpdir) + + +@pytest.fixture +def pool0_inputs(matched_data): + """x_obs, y_obs, day_indices, coeffs for pool 0.""" + entry = matched_data[POOL_PREFIXES[0]] + panel = entry["panel"] + x_obs = build_x_obs(panel) + y_obs = panel["log_volume"].values.astype(float) + day_indices = np.array(entry["day_indices"]) + return entry["coeffs"], x_obs, y_obs, day_indices + + +def _known_params(): + """Standard test params: cadence=12, gas=$1, noise intercept=8.""" + noise_coeffs = np.zeros(K_OBS) + noise_coeffs[0] = 8.0 + return pack_params(np.log(12.0), np.log(1.0), jnp.array(noise_coeffs)) + + +# ── Grid interpolation pins ─────────────────────────────────────────────── + + +class TestInterpolationPins: + """Verify interpolation exactness at grid knot points.""" + + def test_interpolation_exact_at_all_knots(self, synthetic_pool_coeffs): + """Interpolation at grid knot points must exactly reproduce grid values.""" + coeffs = synthetic_pool_coeffs + for ci, cad in enumerate(CADENCES): + for gi, gas in enumerate(GAS_COSTS): + log_cad = jnp.log(cad) + v_arb = interpolate_pool_daily(coeffs, log_cad, jnp.array(gas)) + grid_vals = coeffs.values[ci, gi, :] + np.testing.assert_allclose( + v_arb, grid_vals, atol=1e-4, + err_msg=f"Mismatch at cad={cad}, gas={gas}", + ) + + def test_interpolation_midpoint_value(self, synthetic_pool_coeffs): + """Pin interpolated value at a known mid-grid point.""" + v_arb = interpolate_pool_daily( + synthetic_pool_coeffs, jnp.log(6.0), jnp.array(0.5) + ) + # Pinned from JAX 0.4.30, seed 42 + assert v_arb.shape == (N_DAYS,) + np.testing.assert_allclose(float(v_arb[0]), 6579.6309, rtol=1e-4) + np.testing.assert_allclose(float(jnp.mean(v_arb)), 6621.3186, rtol=1e-4) + + def 
test_interpolation_monotone_in_cadence(self, synthetic_pool_coeffs): + """V_arb should decrease as cadence increases (at fixed gas).""" + coeffs = synthetic_pool_coeffs + gas = jnp.array(1.0) + cads = [1.0, 6.0, 12.0, 30.0, 60.0] + means = [ + float(jnp.mean(interpolate_pool_daily(coeffs, jnp.log(c), gas))) + for c in cads + ] + for i in range(len(means) - 1): + assert means[i] > means[i + 1], ( + f"V_arb not decreasing: cad={cads[i]}->{cads[i+1]}, " + f"mean={means[i]:.1f}->{means[i+1]:.1f}" + ) + + def test_interpolation_monotone_in_gas(self, synthetic_pool_coeffs): + """V_arb should decrease as gas cost increases (at fixed cadence).""" + coeffs = synthetic_pool_coeffs + log_cad = jnp.log(12.0) + gases = [0.0, 0.5, 1.0, 3.0, 5.0] + means = [ + float(jnp.mean(interpolate_pool_daily(coeffs, log_cad, jnp.array(g)))) + for g in gases + ] + for i in range(len(means) - 1): + assert means[i] > means[i + 1], ( + f"V_arb not decreasing: gas={gases[i]}->{gases[i+1]}, " + f"mean={means[i]:.1f}->{means[i+1]:.1f}" + ) + + def test_interpolation_differentiable(self, synthetic_pool_coeffs): + """Gradient of interpolated V_arb w.r.t. 
log_cadence must be finite.""" + coeffs = synthetic_pool_coeffs + + def f(log_cad): + return jnp.sum(interpolate_pool_daily(coeffs, log_cad, jnp.array(1.0))) + + grad_val = jax.grad(f)(jnp.log(12.0)) + assert jnp.isfinite(grad_val), f"Non-finite gradient: {grad_val}" + # Gradient should be negative (more cadence → less arb) + assert float(grad_val) < 0, f"Expected negative gradient, got {grad_val}" + + +# ── Loss function pins ───────────────────────────────────────────────────── + + +class TestLossPins: + """Pin exact loss values and gradients at known parameter points.""" + + def test_loss_value_pinned(self, synthetic_pool_coeffs, pool0_inputs): + """Pin the exact loss value at known params on synthetic data.""" + coeffs, x_obs, _, day_indices = pool0_inputs + params = _known_params() + y_obs = jnp.ones(x_obs.shape[0]) * 9.0 + day_indices_j = jnp.arange(x_obs.shape[0]) % N_DAYS + + loss = pool_loss(params, coeffs, jnp.array(x_obs), y_obs, day_indices_j) + # Pinned: 0.001726984975292 (JAX 0.4.30, seed 42) + np.testing.assert_allclose(float(loss), 0.001727, rtol=1e-3) + + def test_gradient_pinned(self, synthetic_pool_coeffs, pool0_inputs): + """Pin gradient values at known params.""" + coeffs, x_obs, _, day_indices = pool0_inputs + params = _known_params() + y_obs = jnp.ones(x_obs.shape[0]) * 9.0 + day_indices_j = jnp.arange(x_obs.shape[0]) % N_DAYS + + grad_fn = jax.grad(pool_loss) + grad = grad_fn(params, coeffs, jnp.array(x_obs), y_obs, day_indices_j) + grad_np = np.array(grad) + + # All gradients must be finite + assert np.all(np.isfinite(grad_np)), f"Non-finite gradients: {grad_np}" + + # Pin signs of key gradient components + # grad[0] = d_loss/d_log_cadence (negative: increasing cadence decreases V_arb, + # pushing log(V_arb + V_noise) away from y_obs=9.0) + assert grad_np[0] < 0, f"Expected negative cadence grad, got {grad_np[0]}" + # grad[1] = d_loss/d_log_gas (negative: same effect via gas) + assert grad_np[1] < 0, f"Expected negative gas grad, got 
{grad_np[1]}" + + # Pin magnitudes (rtol=0.05, atol=1e-5 to allow platform variance) + expected_grad = np.array([ + -0.000223, -0.000362, -0.000264, -0.003138, + 0.001071, 0.012854, 0.018232, -0.006220, + 0.013421, -0.016786, + ]) + np.testing.assert_allclose(grad_np, expected_grad, rtol=0.05, atol=1e-5) + + def test_loss_increases_with_bad_params(self, synthetic_pool_coeffs, pool0_inputs): + """Loss with wildly wrong noise intercept >> loss with good params.""" + coeffs, x_obs, _, _ = pool0_inputs + y_obs = jnp.ones(x_obs.shape[0]) * 9.0 + day_indices_j = jnp.arange(x_obs.shape[0]) % N_DAYS + x_obs_j = jnp.array(x_obs) + + params_good = _known_params() + noise_bad = np.zeros(K_OBS) + noise_bad[0] = 20.0 + params_bad = pack_params(np.log(12.0), np.log(1.0), jnp.array(noise_bad)) + + loss_good = float(pool_loss(params_good, coeffs, x_obs_j, y_obs, day_indices_j)) + loss_bad = float(pool_loss(params_bad, coeffs, x_obs_j, y_obs, day_indices_j)) + + assert loss_bad > 100.0, f"Expected loss_bad > 100, got {loss_bad}" + assert loss_bad > loss_good * 1000, "Bad params should be >1000x worse" + + +# ── Noise volume pins ────────────────────────────────────────────────────── + + +class TestNoiseVolumePins: + def test_intercept_only_equals_exp(self, synthetic_x_obs): + """With intercept-only noise coeffs, V_noise = exp(intercept) exactly.""" + coeffs = np.zeros(K_OBS) + coeffs[0] = 8.0 + v_noise = noise_volume(jnp.array(coeffs), jnp.array(synthetic_x_obs)) + # x_obs column 0 is all 1.0 (intercept), so x_obs @ coeffs = 8.0 for all obs + np.testing.assert_allclose(v_noise, np.exp(8.0), rtol=1e-6) + + def test_tvl_coeff_creates_variation(self, synthetic_x_obs): + """With nonzero TVL coeff, V_noise varies across observations.""" + coeffs = np.zeros(K_OBS) + coeffs[0] = 5.0 + coeffs[1] = 1.0 # TVL coefficient + v_noise = noise_volume(jnp.array(coeffs), jnp.array(synthetic_x_obs)) + assert float(jnp.std(v_noise)) > 0, "Expected variation from TVL coeff" + + +# ── Per-pool fit pins 
────────────────────────────────────────────────────── + + +class TestPerPoolFitPins: + """Pin per-pool optimizer convergence on synthetic data.""" + + def test_fit_single_pool_converges(self, pool0_inputs): + """fit_single_pool should converge on synthetic data.""" + coeffs, x_obs, y_obs, day_indices = pool0_inputs + result = fit_single_pool(coeffs, x_obs, y_obs, day_indices) + assert result["converged"], "fit_single_pool did not converge" + + def test_fit_single_pool_loss_pinned(self, pool0_inputs): + """Pin the converged loss value.""" + coeffs, x_obs, y_obs, day_indices = pool0_inputs + result = fit_single_pool(coeffs, x_obs, y_obs, day_indices) + # Pinned: 0.0723 (JAX 0.4.30, seed 42) + np.testing.assert_allclose(result["loss"], 0.0723, rtol=0.05) + + def test_fit_single_pool_cadence_pinned(self, pool0_inputs): + """Pin the converged cadence — should find ~1.27 min on synthetic data.""" + coeffs, x_obs, y_obs, day_indices = pool0_inputs + result = fit_single_pool(coeffs, x_obs, y_obs, day_indices) + # Pinned: 1.266 minutes + np.testing.assert_allclose(result["cadence_minutes"], 1.27, rtol=0.1) + # Cadence must be in valid range + assert 1.0 <= result["cadence_minutes"] <= 60.0 + + def test_fit_single_pool_loss_lower_than_init(self, pool0_inputs): + """Fitted loss must be lower than loss at initial guess.""" + from quantammsim.calibration.per_pool_fit import make_initial_guess + + coeffs, x_obs, y_obs, day_indices = pool0_inputs + init = make_initial_guess(x_obs, y_obs) + init_loss = float( + pool_loss( + jnp.array(init), + coeffs, + jnp.array(x_obs), + jnp.array(y_obs), + jnp.array(day_indices), + ) + ) + result = fit_single_pool(coeffs, x_obs, y_obs, day_indices) + assert result["loss"] < init_loss, ( + f"Fitted loss {result['loss']:.6f} >= init loss {init_loss:.6f}" + ) + + def test_fit_all_pools_returns_all(self, matched_data): + """fit_all_pools returns results for every matched pool.""" + results = fit_all_pools(matched_data) + assert set(results.keys()) 
== set(matched_data.keys()) + for pid, r in results.items(): + assert "loss" in r + assert "log_cadence" in r + assert "noise_coeffs" in r + assert len(r["noise_coeffs"]) == K_OBS + + +# ── Joint fit pins ───────────────────────────────────────────────────────── + + +class TestJointFitPins: + """Pin joint optimization behavior on synthetic data.""" + + def test_joint_ppn_loss_decreases(self, matched_data): + """Joint per_pool_noise loss must decrease from initialization.""" + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=100) + assert result["loss"] < result["init_loss"], ( + f"Loss didn't decrease: {result['loss']:.6f} >= {result['init_loss']:.6f}" + ) + + def test_joint_ppn_loss_pinned(self, matched_data): + """Pin the joint per_pool_noise loss value.""" + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=100) + # Pinned: 0.0406 (JAX 0.4.30, seed 42) + # Use wide tolerance since optimizer path may vary across platforms + assert result["loss"] < 0.10, f"Loss too high: {result['loss']}" + assert result["loss"] < result["init_loss"] + + def test_joint_shared_noise_loss_decreases(self, matched_data): + """Joint shared_noise loss must decrease from initialization.""" + from quantammsim.calibration.joint_fit import fit_joint + + result = fit_joint(matched_data, mode="shared_noise", maxiter=100) + assert result["loss"] < result["init_loss"], ( + f"Loss didn't decrease: {result['loss']:.6f} >= {result['init_loss']:.6f}" + ) + + def test_joint_predict_new_pool_at_zero_attrs(self, matched_data): + """Predict at zero attributes → output equals bias terms.""" + from quantammsim.calibration.joint_fit import fit_joint, predict_new_pool_joint + + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=50) + x_attr = np.zeros(result["k_attr"]) + pred = predict_new_pool_joint(result, x_attr) + + # At zero attributes: 
log_cadence = bias_cad, log_gas = bias_gas + np.testing.assert_allclose( + pred["log_cadence"], result["bias_cad"], rtol=1e-10 + ) + np.testing.assert_allclose( + pred["log_gas"], result["bias_gas"], rtol=1e-10 + ) + assert pred["cadence_minutes"] > 0 + assert pred["gas_usd"] > 0 + + def test_joint_shared_noise_predict_includes_noise(self, matched_data): + """Shared noise mode prediction includes noise_coeffs.""" + from quantammsim.calibration.joint_fit import fit_joint, predict_new_pool_joint + + result = fit_joint(matched_data, mode="shared_noise", maxiter=50) + x_attr = np.zeros(result["k_attr"]) + pred = predict_new_pool_joint(result, x_attr) + + assert "noise_coeffs" in pred, "shared_noise predict should include noise_coeffs" + assert len(pred["noise_coeffs"]) == K_OBS + # At zero attributes: noise_coeffs = bias_noise + np.testing.assert_allclose( + pred["noise_coeffs"], result["bias_noise"], rtol=1e-10 + ) + + def test_joint_ppn_noise_shape(self, matched_data): + """Per-pool noise mode produces (n_pools, K_OBS) noise coefficients.""" + from quantammsim.calibration.joint_fit import fit_joint + + n_pools = len(matched_data) + result = fit_joint(matched_data, mode="per_pool_noise", maxiter=20) + assert result["noise_coeffs"].shape == (n_pools, K_OBS) + + def test_joint_warm_start_from_option_c(self, matched_data): + """Warm start from Option C should produce a viable starting point.""" + from quantammsim.calibration.joint_fit import fit_joint + + option_c = fit_all_pools(matched_data) + result = fit_joint( + matched_data, + mode="per_pool_noise", + maxiter=100, + init_from_option_c=option_c, + ) + # The warm start may have higher init_loss than cold start because + # the linear projection of per-pool params introduces approximation + # error. But the final loss should still decrease from init. 
+ assert result["loss"] < result["init_loss"] + + +# ── Pack/unpack roundtrip pins ───────────────────────────────────────────── + + +class TestPackUnpackPins: + def test_per_pool_loss_pack_roundtrip_exact(self): + """pack → unpack must recover exact values.""" + from quantammsim.calibration.loss import unpack_params + + log_cad = 2.4849 + log_gas = -0.6932 + noise = jnp.array([8.1, -1.2, 3.4, -0.5, 0.7, -2.1, 0.3, 0.9]) + packed = pack_params(log_cad, log_gas, noise) + + lc, lg, nc = unpack_params(packed) + np.testing.assert_allclose(float(lc), log_cad, atol=1e-10) + np.testing.assert_allclose(float(lg), log_gas, atol=1e-10) + np.testing.assert_allclose(nc, noise, atol=1e-10) + + def test_joint_pack_roundtrip_ppn(self): + """Joint per_pool_noise pack → unpack roundtrip.""" + from quantammsim.calibration.joint_fit import ( + pack_joint_params, + unpack_joint_params, + ) + + k_attr = 5 + n_pools = 3 + bias_cad = 2.5 + bias_gas = -0.1 + W_cad = jnp.arange(k_attr, dtype=float) * 0.1 + W_gas = jnp.arange(k_attr, dtype=float) * -0.05 + noise = jnp.ones((n_pools, K_OBS)) * 0.3 + + packed = pack_joint_params(bias_cad, bias_gas, W_cad, W_gas, noise) + config = {"k_attr": k_attr, "n_pools": n_pools, "mode": "per_pool_noise"} + unpacked = unpack_joint_params(packed, config) + + np.testing.assert_allclose(float(unpacked["bias_cad"]), bias_cad, atol=1e-10) + np.testing.assert_allclose(float(unpacked["bias_gas"]), bias_gas, atol=1e-10) + np.testing.assert_allclose(unpacked["W_cad"], W_cad, atol=1e-10) + np.testing.assert_allclose(unpacked["W_gas"], W_gas, atol=1e-10) + np.testing.assert_allclose(unpacked["noise_coeffs"], noise, atol=1e-10) + + def test_joint_pack_roundtrip_shared(self): + """Joint shared_noise pack → unpack roundtrip.""" + from quantammsim.calibration.joint_fit import ( + pack_joint_params, + unpack_joint_params, + ) + + k_attr = 4 + bias_cad = 1.5 + bias_gas = 0.2 + W_cad = jnp.ones(k_attr) * 0.1 + W_gas = jnp.ones(k_attr) * -0.2 + # shared_noise: (1 + k_attr, 
K_OBS) where row 0 is bias_noise + noise = jnp.arange((1 + k_attr) * K_OBS, dtype=float).reshape( + 1 + k_attr, K_OBS + ) + + packed = pack_joint_params(bias_cad, bias_gas, W_cad, W_gas, noise) + config = {"k_attr": k_attr, "n_pools": 2, "mode": "shared_noise"} + unpacked = unpack_joint_params(packed, config) + + np.testing.assert_allclose(float(unpacked["bias_cad"]), bias_cad, atol=1e-10) + np.testing.assert_allclose(unpacked["bias_noise"], noise[0], atol=1e-10) + np.testing.assert_allclose(unpacked["W_noise"], noise[1:], atol=1e-10) diff --git a/tests/pools/G3M/__init__.py b/tests/pools/G3M/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/pools/G3M/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/pools/G3M/test_hypersurge_balancer.py b/tests/pools/G3M/test_hypersurge_balancer.py new file mode 100644 index 0000000..cefa68b --- /dev/null +++ b/tests/pools/G3M/test_hypersurge_balancer.py @@ -0,0 +1,219 @@ +import numpy as np +import numpy.testing as npt + +import jax.numpy as jnp + +from quantammsim.core_simulator.dynamic_inputs import ( + DynamicInputArrays, + empty_dynamic_input_arrays, +) +from quantammsim.pools.G3M.balancer.hypersurge_balancer import ( + HYPERSURGE_PARAM_KEYS, + HyperSurgeBalancerPool, +) +from quantammsim.pools.G3M.balancer.hypersurge_balancer_reserves import ( + _hypersurge_fee_for_trade, + _pair_deviation, +) +from quantammsim.pools.creator import create_pool +from quantammsim.runners.jax_runner_utils import NestedHashabledict + + +ALL_SIG_VARIATIONS_2 = tuple(map(tuple, [[1, -1], [-1, 1]])) + + +def _run_fingerprint(n_steps=4): + return NestedHashabledict( + { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": 1_000_000.0, + "arb_frequency": 1, + "do_arb": True, + "do_trades": False, + "fees": 0.003, + "gas_cost": 0.0, + "arb_fees": 0.0, + "all_sig_variations": ALL_SIG_VARIATIONS_2, + "noise_model": "arb_only", + "noise_trader_ratio": 0.0, + "hypersurge_arb_max_fee": 0.02, + 
"hypersurge_arb_threshold": 0.10, + "hypersurge_arb_cap_deviation": 0.50, + "hypersurge_noise_max_fee": 0.10, + "hypersurge_noise_threshold": 0.10, + "hypersurge_noise_cap_deviation": 0.50, + } + ) + + +def _unbatch_params(params): + return { + key: value if key == "subsidary_params" else value[0] + for key, value in params.items() + } + + +def test_creator_registers_hypersurge_balancer_aliases(): + assert isinstance(create_pool("balancer_hypersurge"), HyperSurgeBalancerPool) + assert isinstance(create_pool("hypersurge_balancer"), HyperSurgeBalancerPool) + + +def test_hypersurge_params_are_trainable_by_default(): + pool = create_pool("balancer_hypersurge") + run_fingerprint = _run_fingerprint() + initial_values = pool.get_initial_values(run_fingerprint) + + params = pool.init_parameters( + initial_values, + run_fingerprint, + n_assets=2, + n_parameter_sets=3, + noise="gaussian", + ) + + assert pool.is_trainable() + for key in HYPERSURGE_PARAM_KEYS: + assert key in params + assert params[key].shape == (3, 1) + assert "initial_weights_logits" in params + + +def test_pair_deviation_is_zero_when_pool_matches_oracle(): + reserves = jnp.array([5000.0, 2500.0]) + weights = jnp.array([0.5, 0.5]) + oracle_prices = jnp.array([100.0, 200.0]) + + deviation = _pair_deviation( + reserves, + weights, + oracle_prices, + token_in=0, + token_out=1, + ) + + npt.assert_allclose(np.asarray(deviation), 0.0, atol=1e-12) + + +def test_fee_uses_noise_params_when_trade_worsens_deviation(): + reserves = jnp.array([5000.0, 2500.0]) + weights = jnp.array([0.5, 0.5]) + oracle_prices = jnp.array([100.0, 200.0]) + hypersurge_params = jnp.array([0.02, 0.10, 0.50, 0.10, 0.10, 0.50]) + + fee = _hypersurge_fee_for_trade( + reserves, + candidate_trade=jnp.array([1000.0, -500.0]), + weights=weights, + oracle_prices=oracle_prices, + token_in=0, + token_out=1, + base_fee=0.003, + hypersurge_params=hypersurge_params, + ) + + assert float(fee) > 0.02 + + +def 
test_fee_uses_arb_params_when_trade_improves_deviation(): + reserves = jnp.array([6000.0, 2000.0]) + weights = jnp.array([0.5, 0.5]) + oracle_prices = jnp.array([100.0, 200.0]) + hypersurge_params = jnp.array([0.02, 0.10, 0.50, 0.10, 0.10, 0.50]) + + fee = _hypersurge_fee_for_trade( + reserves, + candidate_trade=jnp.array([-1000.0, 1000.0]), + weights=weights, + oracle_prices=oracle_prices, + token_in=1, + token_out=0, + base_fee=0.003, + hypersurge_params=hypersurge_params, + ) + + npt.assert_allclose(np.asarray(fee), 0.02, rtol=1e-12) + + +def test_hypersurge_balancer_reserve_scan_returns_positive_reserves(): + pool = create_pool("balancer_hypersurge") + prices = jnp.array( + [ + [100.0, 200.0], + [105.0, 200.0], + [110.0, 200.0], + [115.0, 200.0], + ] + ) + run_fingerprint = _run_fingerprint(n_steps=prices.shape[0]) + params = _unbatch_params( + pool.init_parameters( + pool.get_initial_values(run_fingerprint), + run_fingerprint, + n_assets=2, + n_parameter_sets=1, + noise="gaussian", + ) + ) + + reserves = pool.calculate_reserves_with_fees( + params, + run_fingerprint, + prices, + jnp.array([0, 0]), + additional_oracle_input=prices, + ) + + assert reserves.shape == prices.shape + assert bool(jnp.all(jnp.isfinite(reserves))) + assert bool(jnp.all(reserves > 0.0)) + + +def test_hypersurge_balancer_dynamic_inputs_accept_oracle_prices(): + pool = create_pool("balancer_hypersurge") + prices = jnp.array( + [ + [100.0, 200.0], + [105.0, 200.0], + [110.0, 200.0], + [115.0, 200.0], + ] + ) + empty_inputs = empty_dynamic_input_arrays() + dynamic_inputs = DynamicInputArrays( + trades=None, + fees=jnp.full((prices.shape[0],), 0.003), + gas_cost=jnp.zeros((prices.shape[0],)), + arb_fees=jnp.zeros((prices.shape[0],)), + lp_supply=jnp.ones((prices.shape[0],)), + reclamm_price_ratio_updates=empty_inputs.reclamm_price_ratio_updates, + oracle_prices=prices, + ) + run_fingerprint = _run_fingerprint(n_steps=prices.shape[0]) + run_fingerprint = NestedHashabledict( + { + 
**run_fingerprint, + "dynamic_input_flags": { + "use_dynamic_inputs": True, + "has_trades": False, + "has_dynamic_fees": True, + "has_dynamic_gas_cost": True, + "has_dynamic_arb_fees": True, + "has_lp_supply": True, + "has_reclamm_price_ratio_updates": False, + "has_oracle_prices": True, + }, + } + ) + + reserves = pool.calculate_reserves_with_dynamic_inputs( + {"initial_weights": jnp.array([0.5, 0.5])}, + run_fingerprint, + prices, + jnp.array([0, 0]), + dynamic_inputs, + ) + + assert reserves.shape == prices.shape + assert bool(jnp.all(jnp.isfinite(reserves))) + assert bool(jnp.all(reserves > 0.0)) diff --git a/tests/pools/reCLAMM/test_hypersurge_reclamm.py b/tests/pools/reCLAMM/test_hypersurge_reclamm.py new file mode 100644 index 0000000..92d8622 --- /dev/null +++ b/tests/pools/reCLAMM/test_hypersurge_reclamm.py @@ -0,0 +1,94 @@ +import numpy.testing as npt + +import jax.numpy as jnp + +from quantammsim.pools.G3M.balancer.hypersurge_balancer_reserves import ( + _hypersurge_fee_for_trade, +) +from quantammsim.pools.creator import create_pool +from quantammsim.pools.hypersurge_utils import HYPERSURGE_PARAM_KEYS +from quantammsim.pools.reCLAMM.reclamm_hypersurge import ReClammHyperSurgePool +from quantammsim.runners.jax_runner_utils import NestedHashabledict + + +ALL_SIG_VARIATIONS_2 = tuple(map(tuple, [[1, -1], [-1, 1]])) + + +def _run_fingerprint(n_steps=4): + return NestedHashabledict( + { + "n_assets": 2, + "bout_length": n_steps + 1, + "initial_pool_value": 1_000_000.0, + "arb_frequency": 1, + "do_arb": True, + "fees": 0.003, + "gas_cost": 0.0, + "arb_fees": 0.0, + "all_sig_variations": ALL_SIG_VARIATIONS_2, + "noise_model": "arb_only", + "noise_trader_ratio": 0.0, + "hypersurge_arb_max_fee": 0.02, + "hypersurge_arb_threshold": 0.10, + "hypersurge_arb_cap_deviation": 0.50, + "hypersurge_noise_max_fee": 0.10, + "hypersurge_noise_threshold": 0.10, + "hypersurge_noise_cap_deviation": 0.50, + } + ) + + +def test_creator_registers_hypersurge_reclamm_aliases(): + 
assert isinstance(create_pool("reclamm_hypersurge"), ReClammHyperSurgePool) + assert isinstance(create_pool("hypersurge_reclamm"), ReClammHyperSurgePool) + + +def test_hypersurge_reclamm_params_are_trainable_by_default(): + pool = create_pool("reclamm_hypersurge") + run_fingerprint = _run_fingerprint() + initial_values = pool.get_initial_values(run_fingerprint) + + params = pool.init_parameters( + initial_values, + run_fingerprint, + n_assets=2, + n_parameter_sets=3, + noise="gaussian", + ) + + assert pool.is_trainable() + assert "price_ratio" in params + assert "centeredness_margin" in params + for key in HYPERSURGE_PARAM_KEYS: + assert key in params + assert params[key].shape == (3, 1) + + +def test_hypersurge_fee_falls_back_to_base_fee_when_oracle_invalid(): + reserves = jnp.array([5000.0, 2500.0]) + weights = jnp.array([0.5, 0.5]) + hypersurge_params = jnp.array([0.02, 0.10, 0.50, 0.10, 0.10, 0.50]) + + zero_oracle_fee = _hypersurge_fee_for_trade( + reserves, + candidate_trade=jnp.array([1000.0, -500.0]), + weights=weights, + oracle_prices=jnp.array([0.0, 200.0]), + token_in=0, + token_out=1, + base_fee=0.003, + hypersurge_params=hypersurge_params, + ) + nan_oracle_fee = _hypersurge_fee_for_trade( + reserves, + candidate_trade=jnp.array([1000.0, -500.0]), + weights=weights, + oracle_prices=jnp.array([jnp.nan, 200.0]), + token_in=0, + token_out=1, + base_fee=0.003, + hypersurge_params=hypersurge_params, + ) + + npt.assert_allclose(zero_oracle_fee, 0.003, rtol=1e-12) + npt.assert_allclose(nan_oracle_fee, 0.003, rtol=1e-12) diff --git a/tests/pools/reCLAMM/test_reclamm_reserves.py b/tests/pools/reCLAMM/test_reclamm_reserves.py index 58d9dae..b9664aa 100644 --- a/tests/pools/reCLAMM/test_reclamm_reserves.py +++ b/tests/pools/reCLAMM/test_reclamm_reserves.py @@ -1428,4 +1428,3 @@ def test_lp_supply_e2e_do_run_on_historic_data(self): assert lp_val > base_val, ( f"Doubled LP supply should increase final value: {lp_val} <= {base_val}" ) - diff --git 
a/tests/scripts/test_compare_reclamm_geometric_noise_runs.py b/tests/scripts/test_compare_reclamm_geometric_noise_runs.py new file mode 100644 index 0000000..c211e7e --- /dev/null +++ b/tests/scripts/test_compare_reclamm_geometric_noise_runs.py @@ -0,0 +1,189 @@ +"""Tests for adjacent-row sourcing in compare_reclamm_geometric_noise_runs.py.""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path + +import numpy as np + + +SCRIPT_PATH = ( + Path(__file__).resolve().parents[2] + / "scripts" + / "reclamm" + / "compare_reclamm_geometric_noise_runs.py" +) + + +def load_script_module(): + spec = importlib.util.spec_from_file_location( + "test_compare_reclamm_geometric_noise_runs_module", + SCRIPT_PATH, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def test_build_run_specs_from_adjacent_row_maps_csv_cells_to_two_specs(): + module = load_script_module() + row = { + "metric_key": "noise_vs_arb_geometric_improvement_pct", + "metric_unit": "pct", + "source_noise_profile": "market_linear", + "pair_slug": "price_ratio_vs_margin", + "slice_slug": "q2", + "adjacency_axis": "horizontal", + "heatmap_value_diff_abs": 54.223889391076895, + "1_price_ratio": 1.335, + "1_centeredness_margin": 0.3184210526, + "1_daily_price_shift_exponent": 0.1975, + "1_tvl_usd": 1_000_000.0, + "1_heatmap_value": -53.0543210862, + "2_price_ratio": 1.36, + "2_centeredness_margin": 0.3184210526, + "2_daily_price_shift_exponent": 0.1975, + "2_tvl_usd": 1_000_000.0, + "2_heatmap_value": 1.1695683049, + } + + description, run_specs = module.build_run_specs_from_adjacent_row( + row, + csv_path=Path("adjacent_pairs.csv"), + row_index=0, + ) + + assert "adjacent_pairs.csv row 0" in description + assert "price_ratio_vs_margin q2" in description + assert "horizontal" in description + assert "noise_profile=market_linear" in description + assert len(run_specs) == 2 + assert 
run_specs[0]["name"] == "Top diff row cell 1" + assert run_specs[0]["price_ratio"] == 1.335 + assert run_specs[0]["centeredness_margin"] == 0.3184210526 + assert run_specs[0]["daily_price_shift_exponent"] == 0.1975 + assert run_specs[0]["tvl_usd"] == 1_000_000.0 + assert run_specs[0]["color"] == "C0" + assert run_specs[0]["source_noise_profile"] == "market_linear" + assert "heatmap_value=-53.054321" in run_specs[0]["reason"] + assert run_specs[1]["name"] == "Top diff row cell 2" + assert run_specs[1]["price_ratio"] == 1.36 + assert run_specs[1]["color"] == "C1" + + +def test_default_output_file_for_adjacent_csv_uses_csv_stem_and_row_index(): + module = load_script_module() + output = module.default_output_file_for_adjacent_csv( + Path("scripts/results/reclamm_heatmap_adjacency/example.csv"), + row_index=3, + ) + + assert ( + output.as_posix() + == "scripts/results/reclamm_heatmap_adjacency/example_row_3_geometric_noise_compare.png" + ) + + +def test_build_run_config_rejects_legacy_calibrated_noise_profile(): + module = load_script_module() + base_config = { + "name": "base", + "price_ratio": 1.1, + "centeredness_margin": 0.6, + "daily_price_shift_exponent": 0.1, + "initial_pool_value": 1_000_000.0, + "noise_model": "market_linear", + "reclamm_noise_params": {"foo": 1.0}, + "noise_arrays_path": "path.npz", + } + spec = { + "name": "cell", + "price_ratio": 1.335, + "centeredness_margin": 0.3184210526, + "daily_price_shift_exponent": 0.1975, + "tvl_usd": 1_000_000.0, + "source_noise_profile": "legacy_calibrated", + } + + import pytest + + with pytest.raises(ValueError): + module.build_run_config(spec, base_config=base_config) + + +def test_build_run_variants_canonicalizes_noise_and_arb_only_configs(): + module = load_script_module() + fixed_path = str( + Path(__file__).resolve().parents[2] + / "results" + / "linear_market_noise" + / "_sim_arrays" + / "0x9d1fcf346ea1b0_2024-06-01_2026-03-01.npz" + ) + base_config = { + "name": "base", + "price_ratio": 1.1, + 
"centeredness_margin": 0.6, + "daily_price_shift_exponent": 0.1, + "initial_pool_value": 1_000_000.0, + "noise_model": "market_linear", + "noise_arrays_path": fixed_path, + } + spec = { + "name": "cell", + "price_ratio": 1.335, + "centeredness_margin": 0.3184210526, + "daily_price_shift_exponent": 0.1975, + "tvl_usd": 1_000_000.0, + "source_noise_profile": "market_linear", + } + + class FakeThermostatCompare: + @staticmethod + def make_noise_variant_cfg(cfg, enable_noise_model): + updated = dict(cfg) + updated["enable_noise_model"] = bool(enable_noise_model) + updated["noise_model"] = "market_linear" if enable_noise_model else "arb_only" + updated["noise_arrays_path"] = fixed_path + updated["reclamm_noise_params"] = {"tvl_mean": 1.0, "tvl_std": 2.0} + return updated + + variants = module.build_run_variants(spec, base_config, FakeThermostatCompare) + + assert variants["noise"]["noise_model"] == "market_linear" + assert variants["arb"]["noise_model"] == "arb_only" + assert variants["noise"]["noise_arrays_path"] == fixed_path + assert variants["arb"]["noise_arrays_path"] == fixed_path + assert variants["noise"]["reclamm_noise_params"] == {"tvl_mean": 1.0, "tvl_std": 2.0} + assert variants["arb"]["reclamm_noise_params"] == {"tvl_mean": 1.0, "tvl_std": 2.0} + + +def test_print_run_inputs_to_terminal_includes_fingerprint_and_update_params(capsys): + module = load_script_module() + cfg = { + "name": "cell", + "variant_label": "arb-only", + } + run_fingerprint = { + "tokens": ["AAVE", "ETH"], + "fees": np.float64(0.0025), + "arb_frequency": np.int64(14), + } + update_params = { + "price_ratio": np.array(1.335), + "centeredness_margin": np.array(0.3184210526), + "daily_price_shift_base": np.array(0.99999841596), + } + + module.print_run_inputs_to_terminal(cfg, run_fingerprint, update_params) + + captured = capsys.readouterr().out + assert "Run inputs for cell (arb-only):" in captured + assert '"run_fingerprint"' in captured + assert '"update_params"' in captured + assert 
'"tokens": [' in captured + assert '"AAVE"' in captured + assert '"arb_frequency": 14' in captured + assert '"price_ratio": 1.335' in captured diff --git a/tests/scripts/test_compare_reclamm_thermostats.py b/tests/scripts/test_compare_reclamm_thermostats.py new file mode 100644 index 0000000..caec6ae --- /dev/null +++ b/tests/scripts/test_compare_reclamm_thermostats.py @@ -0,0 +1,1022 @@ +"""Tests for heatmap skip logic in compare_reclamm_thermostats.py.""" + +import importlib.util +from pathlib import Path +import sys +import types + +import numpy as np +import pytest + + +SCRIPT_PATH = ( + Path(__file__).resolve().parents[2] + / "scripts" + / "compare_reclamm_thermostats.py" +) + + +def _load_script_module(): + injected_modules = {} + + def inject_module(name, module): + injected_modules[name] = sys.modules.get(name) + sys.modules[name] = module + + # Minimal stubs so the script can be imported without the full runtime + # stack present in this test environment. + jax_module = types.ModuleType("jax") + jax_module.numpy = np + inject_module("jax", jax_module) + inject_module("jax.numpy", np) + + pandas_module = types.ModuleType("pandas") + pandas_module.Timestamp = lambda value: value + pandas_module.DatetimeIndex = tuple + pandas_module.DataFrame = type("DataFrame", (), {}) + pandas_module.read_parquet = lambda *args, **kwargs: None + inject_module("pandas", pandas_module) + + matplotlib_module = types.ModuleType("matplotlib") + pyplot_module = types.ModuleType("matplotlib.pyplot") + pyplot_module.cm = types.SimpleNamespace(viridis=lambda values: values) + colors_module = types.ModuleType("matplotlib.colors") + colors_module.TwoSlopeNorm = object + colors_module.Normalize = object + colors_module.SymLogNorm = object + cm_module = types.ModuleType("matplotlib.cm") + cm_module.ScalarMappable = object + inject_module("matplotlib", matplotlib_module) + inject_module("matplotlib.pyplot", pyplot_module) + inject_module("matplotlib.colors", colors_module) + 
inject_module("matplotlib.cm", cm_module) + + quantammsim_module = types.ModuleType("quantammsim") + runners_module = types.ModuleType("quantammsim.runners") + jax_runners_module = types.ModuleType("quantammsim.runners.jax_runners") + jax_runners_module.do_run_on_historic_data = lambda **kwargs: { + "final_value": 0.0 + } + runners_module.jax_runners = jax_runners_module + + pools_module = types.ModuleType("quantammsim.pools") + reclamm_pkg_module = types.ModuleType("quantammsim.pools.reCLAMM") + reserves_module = types.ModuleType( + "quantammsim.pools.reCLAMM.reclamm_reserves" + ) + reserves_module.calibrate_arc_length_speed = lambda *args, **kwargs: 0.0 + reserves_module.compute_price_ratio = lambda *args, **kwargs: 1.0 + reserves_module.initialise_reclamm_reserves = ( + lambda *args, **kwargs: (np.array([1.0, 1.0]), 1.0, 1.0) + ) + reclamm_pkg_module.reclamm_reserves = reserves_module + pools_module.reCLAMM = reclamm_pkg_module + + utils_module = types.ModuleType("quantammsim.utils") + data_processing_module = types.ModuleType("quantammsim.utils.data_processing") + historic_utils_module = types.ModuleType( + "quantammsim.utils.data_processing.historic_data_utils" + ) + historic_utils_module.get_historic_parquet_data = lambda *args, **kwargs: None + data_processing_module.historic_data_utils = historic_utils_module + utils_module.data_processing = data_processing_module + + quantammsim_module.runners = runners_module + quantammsim_module.pools = pools_module + quantammsim_module.utils = utils_module + + inject_module("quantammsim", quantammsim_module) + inject_module("quantammsim.runners", runners_module) + inject_module("quantammsim.runners.jax_runners", jax_runners_module) + inject_module("quantammsim.pools", pools_module) + inject_module("quantammsim.pools.reCLAMM", reclamm_pkg_module) + inject_module("quantammsim.pools.reCLAMM.reclamm_reserves", reserves_module) + inject_module("quantammsim.utils", utils_module) + 
inject_module("quantammsim.utils.data_processing", data_processing_module) + inject_module( + "quantammsim.utils.data_processing.historic_data_utils", + historic_utils_module, + ) + + spec = importlib.util.spec_from_file_location( + "test_compare_reclamm_thermostats_module", + SCRIPT_PATH, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + try: + spec.loader.exec_module(module) + return module + finally: + for name, original in injected_modules.items(): + if original is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = original + + +@pytest.fixture +def script_module(): + return _load_script_module() + + +@pytest.fixture +def base_cfg(): + return { + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:00:00", + "end": "2025-06-01 00:00:00", + "fees": 0.0025, + "price_ratio": 1.10, + "centeredness_margin": 0.60, + "daily_price_shift_exponent": 0.1, + "initial_pool_value": 5_000_000.0, + } + + +@pytest.fixture +def launch_final_values(): + return { + "geometric": 1_000_000.0, + "constant_arc_length": 1_010_000.0, + } + + +def test_make_noise_variant_cfg_disables_noise_fields(script_module, base_cfg): + noisy_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + "noise_artifact_dir": "results/linear_market_noise", + "noise_pool_id": "0x9d1fcf346ea1b0", + "gas_cost": 1.0, + "protocol_fee_split": 0.25, + "reclamm_noise_params": {"tvl_mean": 1.0, "tvl_std": 2.0}, + "noise_arrays_path": "results/linear_market_noise/_sim_arrays/aave_eth.npz", + "arb_frequency": 6, + } + + arb_only_cfg = script_module.make_noise_variant_cfg( + noisy_cfg, + enable_noise_model=False, + ) + resolved = script_module.resolve_reclamm_noise_settings(arb_only_cfg) + + assert arb_only_cfg["enable_noise_model"] is False + assert arb_only_cfg["noise_model"] == "arb_only" + assert arb_only_cfg["noise_reference_model"] == "market_linear" + assert arb_only_cfg["gas_cost"] == script_module.DEFAULT_GAS_COST + assert 
arb_only_cfg["protocol_fee_split"] == script_module.DEFAULT_PROTOCOL_FEE_SPLIT + assert arb_only_cfg["arb_frequency"] == script_module.FIXED_COMPARE_ARB_FREQUENCY + assert arb_only_cfg["noise_artifact_dir"] == script_module.DEFAULT_MARKET_LINEAR_ARTIFACT_DIR + assert arb_only_cfg["noise_pool_id"] == script_module.AAVE_WETH_POOL_ID + assert arb_only_cfg["noise_arrays_path"] == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert "reclamm_noise_params" not in arb_only_cfg + assert resolved["noise_model"] == "arb_only" + assert resolved["noise_arrays_path"] == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert set(resolved["reclamm_noise_params"]) == {"tvl_mean", "tvl_std"} + assert resolved["noise_summary"] == ( + f"arb_only (arb_frequency={script_module.FIXED_COMPARE_ARB_FREQUENCY})" + ) + + +def test_make_noise_variant_cfg_defaults_to_fixed_compare_arb_cadence( + script_module, + base_cfg, +): + noisy_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + + arb_only_cfg = script_module.make_noise_variant_cfg( + noisy_cfg, + enable_noise_model=False, + ) + resolved_noise = script_module.resolve_reclamm_noise_settings(noisy_cfg) + + assert arb_only_cfg["arb_frequency"] == script_module.FIXED_COMPARE_ARB_FREQUENCY + assert arb_only_cfg["noise_arrays_path"] == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert resolved_noise["arb_frequency"] == script_module.FIXED_COMPARE_ARB_FREQUENCY + + +def test_make_fingerprint_ignores_non_axis_override_fields(script_module, base_cfg): + canonical_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + noisy_override_cfg = { + **canonical_cfg, + "arb_frequency": 6, + "gas_cost": 7.0, + "protocol_fee_split": 0.9, + "arb_fees": 3.0, + "noise_artifact_dir": "custom/noise/dir", + "noise_pool_id": "override-pool", + "reclamm_noise_params": {"tvl_mean": 999.0}, + "noise_arrays_path": "custom/path.npz", + } + + canonical_fingerprint 
= script_module.make_fingerprint(canonical_cfg, "geometric") + overridden_fingerprint = script_module.make_fingerprint( + noisy_override_cfg, + "geometric", + ) + canonical_key = script_module._make_method_cache_key(canonical_cfg, "geometric") + overridden_key = script_module._make_method_cache_key( + noisy_override_cfg, + "geometric", + ) + + assert overridden_fingerprint == canonical_fingerprint + assert overridden_key == canonical_key + assert overridden_fingerprint["arb_frequency"] == script_module.FIXED_COMPARE_ARB_FREQUENCY + assert overridden_fingerprint["gas_cost"] == script_module.DEFAULT_GAS_COST + assert ( + overridden_fingerprint["noise_arrays_path"] + == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + ) + assert ( + overridden_fingerprint["protocol_fee_split"] + == script_module.DEFAULT_PROTOCOL_FEE_SPLIT + ) + + +def test_arb_only_fingerprint_only_changes_noise_model(script_module, base_cfg): + noisy_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + + noise_fingerprint = script_module.make_fingerprint(noisy_cfg, "geometric") + arb_only_cfg = script_module.make_noise_variant_cfg(noisy_cfg, enable_noise_model=False) + arb_fingerprint = script_module.make_fingerprint(arb_only_cfg, "geometric") + + expected_arb = dict(noise_fingerprint) + expected_arb["noise_model"] = "arb_only" + + assert noise_fingerprint["noise_model"] == "market_linear" + assert noise_fingerprint["noise_arrays_path"] == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert arb_fingerprint == expected_arb + + +def test_make_fingerprint_keeps_path_fallback_when_arrays_not_preloaded( + script_module, + base_cfg, +): + noisy_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + + fingerprint = script_module.make_fingerprint(noisy_cfg, "geometric") + + assert fingerprint["noise_arrays_path"] == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert "noise_base_array" not in fingerprint 
+ assert "noise_tvl_coeff_array" not in fingerprint + + +def test_make_fingerprint_includes_preloaded_market_linear_arrays( + script_module, + base_cfg, +): + noisy_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + shared_noise = { + "arrays_path": script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, + "noise_base_array": np.array([1.0, 2.0]), + "noise_tvl_coeff_array": np.array([3.0, 4.0]), + "tvl_mean": 10.0, + "tvl_std": 5.0, + } + + fingerprint = script_module.make_fingerprint( + noisy_cfg, + "geometric", + market_linear_noise_data=shared_noise, + ) + + assert fingerprint["noise_arrays_path"] == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert np.array_equal(fingerprint["noise_base_array"], shared_noise["noise_base_array"]) + assert np.array_equal( + fingerprint["noise_tvl_coeff_array"], + shared_noise["noise_tvl_coeff_array"], + ) + + +@pytest.mark.parametrize( + ("field", "updated_value"), + [ + ("fees", 0.01), + ("start", "2024-07-01 00:00:00"), + ("end", "2025-07-01 00:00:00"), + ("tokens", ["WBTC", "ETH"]), + ], +) +def test_method_cache_key_includes_run_identity_fields( + script_module, + base_cfg, + field, + updated_value, +): + canonical_cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + updated_cfg = dict(canonical_cfg) + updated_cfg[field] = updated_value + + canonical_key = script_module._make_method_cache_key(canonical_cfg, "geometric") + updated_key = script_module._make_method_cache_key(updated_cfg, "geometric") + + assert updated_key != canonical_key + + +@pytest.mark.parametrize( + "cfg_overrides", + [ + {"enable_noise_model": True, "noise_model": "calibrated"}, + { + "enable_noise_model": False, + "noise_model": "arb_only", + "noise_reference_model": "calibrated", + }, + ], +) +def test_resolve_reclamm_noise_settings_rejects_legacy_modes( + script_module, + base_cfg, + cfg_overrides, +): + cfg = { + **base_cfg, + **cfg_overrides, + } + + with 
pytest.raises(ValueError, match="market_linear"): + script_module.resolve_reclamm_noise_settings(cfg) + + +def test_generate_heatmaps_skips_existing_pairs( + monkeypatch, + script_module, + base_cfg, + launch_final_values, +): + monkeypatch.setattr(script_module.os.path, "exists", lambda filename: True) + monkeypatch.setattr( + script_module, + "build_heatmap_matrices", + lambda **kwargs: pytest.fail("heatmap sweep should have been skipped"), + ) + monkeypatch.setattr( + script_module, + "plot_heatmap", + lambda **kwargs: pytest.fail("plotting should have been skipped"), + ) + + script_module.generate_heatmaps( + base_cfg, + price_data=None, + launch_final_values=launch_final_values, + cache={}, + ) + + +def test_generate_heatmaps_only_renders_missing_artifacts( + monkeypatch, + script_module, + base_cfg, + launch_final_values, +): + pair = script_module.get_pair_heatmap_specs(base_cfg)[0] + slice_variant = pair["fixed_slices"][0] + pair_suffix = script_module._pair_slice_suffix(pair, slice_variant) + missing_file = script_module.tvl_artifact_filename( + "reclamm_heatmap_geometric_vs_launch_geometric_symlog20", + base_cfg, + suffix=pair_suffix, + ) + + def fake_exists(filename): + if filename == missing_file: + return False + return filename.startswith("reclamm_heatmap_") + + build_calls = [] + plotted_files = [] + + def fake_build_heatmap_matrices(**kwargs): + build_calls.append(kwargs) + return { + "geometric_vs_launch_geometric_pct": np.zeros( + (len(kwargs["y_values"]), len(kwargs["x_values"])), + dtype=float, + ) + } + + def fake_plot_heatmap(**kwargs): + plotted_files.append(kwargs["filename"]) + + monkeypatch.setattr(script_module.os.path, "exists", fake_exists) + monkeypatch.setattr( + script_module, + "build_heatmap_matrices", + fake_build_heatmap_matrices, + ) + monkeypatch.setattr(script_module, "plot_heatmap", fake_plot_heatmap) + + script_module.generate_heatmaps( + base_cfg, + price_data=None, + launch_final_values=launch_final_values, + cache={}, + ) 
+ + assert len(build_calls) == 1 + assert build_calls[0]["progress_label"] == pair_suffix + assert build_calls[0]["base_cfg"][pair["fixed_key"]] == pytest.approx( + slice_variant["value"] + ) + assert build_calls[0]["metric_keys"] == ["geometric_vs_launch_geometric_pct"] + assert plotted_files == [missing_file] + + +def test_generate_heatmaps_only_renders_missing_improvement_artifacts( + monkeypatch, + script_module, + base_cfg, + launch_final_values, +): + pair = script_module.get_pair_heatmap_specs(base_cfg)[0] + slice_variant = pair["fixed_slices"][0] + pair_suffix = script_module._pair_slice_suffix(pair, slice_variant) + missing_file = script_module.tvl_artifact_filename( + "reclamm_heatmap_noise_vs_arb_geometric_improvement_symlog20", + base_cfg, + suffix=pair_suffix, + ) + + def fake_exists(filename): + if filename == missing_file: + return False + return filename.startswith("reclamm_heatmap_") + + build_calls = [] + plotted_files = [] + + def fake_build_heatmap_matrices(**kwargs): + build_calls.append(kwargs) + return { + "noise_vs_arb_geometric_improvement_pct": np.zeros( + (len(kwargs["y_values"]), len(kwargs["x_values"])), + dtype=float, + ) + } + + def fake_plot_heatmap(**kwargs): + plotted_files.append(kwargs["filename"]) + + monkeypatch.setattr(script_module.os.path, "exists", fake_exists) + monkeypatch.setattr( + script_module, + "build_heatmap_matrices", + fake_build_heatmap_matrices, + ) + monkeypatch.setattr(script_module, "plot_heatmap", fake_plot_heatmap) + + script_module.generate_heatmaps( + base_cfg, + price_data=None, + launch_final_values=launch_final_values, + cache={}, + ) + + assert len(build_calls) == 1 + assert build_calls[0]["progress_label"] == pair_suffix + assert build_calls[0]["base_cfg"][pair["fixed_key"]] == pytest.approx( + slice_variant["value"] + ) + assert build_calls[0]["metric_keys"] == [ + "noise_vs_arb_geometric_improvement_pct" + ] + assert plotted_files == [missing_file] + + +def 
test_generate_three_variable_3d_heatmaps_only_renders_missing_slice( + monkeypatch, + script_module, + base_cfg, + launch_final_values, +): + missing_file = script_module.tvl_artifact_filename( + "reclamm_heatmap_3d_geometric_vs_launch_geometric_symlog20", + base_cfg, + suffix="slice_q1", + ) + + def fake_exists(filename): + if filename == missing_file: + return False + return filename.startswith("reclamm_heatmap_3d_") + + build_calls = [] + plotted_files = [] + + def fake_build_heatmap_matrices(**kwargs): + build_calls.append(kwargs) + return { + "geometric_vs_launch_geometric_pct": np.zeros( + (len(kwargs["y_values"]), len(kwargs["x_values"])), + dtype=float, + ) + } + + def fake_plot_three_variable_heatmap_3d(**kwargs): + plotted_files.append(kwargs["filename"]) + + monkeypatch.setattr(script_module.os.path, "exists", fake_exists) + monkeypatch.setattr( + script_module, + "build_heatmap_matrices", + fake_build_heatmap_matrices, + ) + monkeypatch.setattr( + script_module, + "plot_three_variable_heatmap_3d", + fake_plot_three_variable_heatmap_3d, + ) + + script_module.generate_three_variable_3d_heatmaps( + base_cfg, + price_data=None, + launch_final_values=launch_final_values, + cache={}, + ) + + assert len(build_calls) == 3 + assert {call["progress_label"] for call in build_calls} == { + "3d_price_ratio_vs_margin_shift_exp_q1", + "3d_shift_exp_vs_margin_price_ratio_q1", + "3d_price_ratio_vs_shift_exp_margin_q1", + } + assert plotted_files == [missing_file] + + +def test_run_method_final_value_cached_reuses_persisted_parquet_value( + monkeypatch, + script_module, + base_cfg, +): + cfg = dict(base_cfg) + cache_key = script_module._make_method_cache_key(cfg, "geometric") + cache_key_hash = script_module._make_method_cache_hash(cache_key) + + class FakeFrame: + empty = False + + def itertuples(self, index=False): + return [ + types.SimpleNamespace( + cache_key_hash=cache_key_hash, + final_value=1_234_567.0, + ) + ] + + monkeypatch.setattr(script_module.os.path, 
"exists", lambda filename: True) + monkeypatch.setattr(script_module.pd, "read_parquet", lambda *args, **kwargs: FakeFrame()) + monkeypatch.setattr( + script_module, + "do_run_on_historic_data", + lambda **kwargs: pytest.fail("persisted forward-value cache should be reused"), + ) + + cache = script_module.make_sweep_cache(price_data=None, cache_scope_cfg=cfg) + value = script_module._run_method_final_value_cached(cfg, "geometric", cache) + + assert value == pytest.approx(1_234_567.0) + + +def test_run_method_final_value_cached_passes_preloaded_market_linear_arrays( + monkeypatch, + script_module, + base_cfg, +): + captured = {} + shared_noise = { + "arrays_path": script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH, + "noise_base_array": np.array([11.0, 12.0]), + "noise_tvl_coeff_array": np.array([21.0, 22.0]), + "tvl_mean": 100.0, + "tvl_std": 25.0, + } + + monkeypatch.setattr( + script_module, + "_load_persistent_final_value_cache", + lambda cache: cache.update( + { + "_persistent_final_value_cache_loaded": True, + "_persistent_final_value_cache": {}, + "_persistent_final_value_next_batch_id": 0, + } + ), + ) + monkeypatch.setattr(script_module, "flush_sweep_cache", lambda *args, **kwargs: None) + + def fake_do_run_on_historic_data(**kwargs): + captured["run_fingerprint"] = kwargs["run_fingerprint"] + return {"final_value": 1_111_111.0} + + monkeypatch.setattr( + script_module, + "do_run_on_historic_data", + fake_do_run_on_historic_data, + ) + + cfg = { + **base_cfg, + "enable_noise_model": True, + "noise_model": "market_linear", + } + cache = script_module.make_sweep_cache( + price_data=None, + cache_scope_cfg=cfg, + market_linear_noise_data=shared_noise, + ) + + value = script_module._run_method_final_value_cached(cfg, "geometric", cache) + + assert value == pytest.approx(1_111_111.0) + assert np.array_equal( + captured["run_fingerprint"]["noise_base_array"], + shared_noise["noise_base_array"], + ) + assert np.array_equal( + 
captured["run_fingerprint"]["noise_tvl_coeff_array"], + shared_noise["noise_tvl_coeff_array"], + ) + assert ( + captured["run_fingerprint"]["noise_arrays_path"] + == script_module.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + ) + + +def test_arc_speed_artifacts_only_build_missing_line_output( + monkeypatch, + script_module, + base_cfg, + launch_final_values, +): + missing_line = script_module.tvl_artifact_filename("reclamm_line_efficiency", base_cfg, suffix="arc_speed_vs_price_ratio") + + def fake_exists(filename): + if filename == missing_line: + return False + return filename.startswith("reclamm_") + + build_calls = [] + curve_calls = [] + plotted_lines = [] + + def fake_build_heatmap_matrices(**kwargs): + build_calls.append(kwargs) + return { + "efficiency_pct": np.zeros( + (len(kwargs["y_values"]), len(kwargs["x_values"])), + dtype=float, + ) + } + + def fake_build_metric_curve(**kwargs): + curve_calls.append(kwargs) + return np.zeros(len(kwargs["x_values"]), dtype=float) + + monkeypatch.setattr(script_module.os.path, "exists", fake_exists) + monkeypatch.setattr(script_module, "RUN_CONSTANT_ARC_LENGTH", True) + monkeypatch.setattr( + script_module, + "compute_auto_calibrated_arc_length_speed", + lambda cfg, price_data: 1.23e-4, + ) + monkeypatch.setattr( + script_module, + "build_heatmap_matrices", + fake_build_heatmap_matrices, + ) + monkeypatch.setattr( + script_module, + "build_metric_curve", + fake_build_metric_curve, + ) + monkeypatch.setattr( + script_module, + "plot_heatmap", + lambda **kwargs: pytest.fail("existing heatmap should not be redrawn"), + ) + monkeypatch.setattr( + script_module, + "plot_arc_speed_line_chart", + lambda **kwargs: plotted_lines.append(kwargs["filename"]), + ) + + script_module.generate_arc_speed_efficiency_artifacts( + base_cfg=base_cfg, + launch_cfg=dict(base_cfg), + price_data=None, + launch_final_values=launch_final_values, + cache={}, + ) + + assert len(build_calls) == 1 + assert build_calls[0]["progress_label"] == 
"arc_speed_vs_price_ratio" + assert len(curve_calls) == 1 + assert curve_calls[0]["x_key"] == "arc_length_speed" + assert plotted_lines == [missing_line] + + +def test_compute_auto_calibrated_arc_length_speed_uses_nearest_datetime_row( + monkeypatch, + script_module, +): + real_pd = pytest.importorskip("pandas") + monkeypatch.setattr(script_module.pd, "Timestamp", real_pd.Timestamp) + monkeypatch.setattr(script_module.pd, "DatetimeIndex", real_pd.DatetimeIndex) + monkeypatch.setattr(script_module.pd, "DataFrame", real_pd.DataFrame) + monkeypatch.setattr( + script_module.pd, + "MultiIndex", + real_pd.MultiIndex, + raising=False, + ) + + price_data = real_pd.DataFrame( + { + "close_AAVE": [10.0, 30.0], + "close_ETH": [1.0, 1.0], + }, + index=real_pd.DatetimeIndex( + ["2024-06-01 00:01:00", "2024-06-01 00:03:00"] + ), + ) + cfg = { + "tokens": ["AAVE", "ETH"], + "start": "2024-06-01 00:02:10", + "price_ratio": 1.10, + "centeredness_margin": 0.60, + "daily_price_shift_exponent": 0.1, + "initial_pool_value": 1_000_000.0, + } + captured = {} + + monkeypatch.setattr( + script_module, + "initialise_reclamm_reserves", + lambda *args, **kwargs: (np.array([1.0, 1.0]), 1.0, 1.0), + ) + monkeypatch.setattr( + script_module, + "compute_price_ratio", + lambda *args, **kwargs: 1.0, + ) + + def fake_calibrate_arc_length_speed(*args, **kwargs): + captured["market_price_0"] = args[7] + return 123.0 + + monkeypatch.setattr( + script_module, + "calibrate_arc_length_speed", + fake_calibrate_arc_length_speed, + ) + + result = script_module.compute_auto_calibrated_arc_length_speed(cfg, price_data) + + assert result == pytest.approx(123.0) + assert captured["market_price_0"] == pytest.approx(30.0) + + +def test_flush_sweep_cache_writes_compact_scalar_parquet(script_module): + captured = {} + + class FakeFrame: + def __init__(self, payload): + captured["payload"] = payload + + def to_parquet(self, path, index=False, compression=None): + captured["path"] = path + captured["index"] = index + 
captured["compression"] = compression + + cache = { + "_pending_persistent_final_values": { + "abc123": { + "cache_key_hash": "abc123", + "final_value": 123.45, + "method": "geometric", + "enable_noise_model": True, + "noise_model": "market_linear", + "price_ratio": 1.1, + "centeredness_margin": 0.6, + "daily_price_shift_exponent": 0.1, + "initial_pool_value": 5_000_000.0, + "arb_frequency": 15, + } + }, + "_persistent_final_value_cache": {}, + "_persistent_final_value_next_batch_id": 0, + "_persistent_final_value_cache_loaded": True, + "_persistent_final_value_cache_path": "results/reclamm_heatmap_forward_cache/test/forward_values_tvl_5m.parquet", + } + + with pytest.MonkeyPatch.context() as mp: # patches are undone on exit, so os.makedirs is not left replaced for later tests + mp.setattr(script_module.pd, "DataFrame", FakeFrame) + mp.setattr( + script_module.os, "makedirs", lambda *args, **kwargs: captured.setdefault("makedirs", args[0]) + ) + script_module.flush_sweep_cache(cache, force=True) + + assert set(captured["payload"].keys()) == set( + script_module.PERSISTED_FORWARD_VALUE_COLUMNS + ) + assert captured["payload"]["cache_key_hash"] == ["abc123"] + assert captured["payload"]["method"] == ["geometric"] + assert captured["payload"]["arb_frequency"] == [15] + assert captured["index"] is False + assert captured["compression"] == "zstd" + assert captured["makedirs"].endswith("forward_values_tvl_5m.parquet") + assert captured["path"].endswith("forward_values_tvl_5m.parquet/batch_00000000.parquet") + assert cache["_pending_persistent_final_values"] == {} + assert cache["_persistent_final_value_cache"] == {"abc123": 123.45} + assert cache["_persistent_final_value_next_batch_id"] == 1 + + +def test_load_persistent_final_value_cache_reads_sharded_parquet_dir( + monkeypatch, + script_module, +): + frames = { + "results/reclamm_heatmap_forward_cache/test/forward_values_tvl_1m.parquet/batch_00000000.parquet": [ + types.SimpleNamespace( + cache_key_hash="first123", + final_value=111.0, + method="geometric", + enable_noise_model=True, + noise_model="market_linear", + price_ratio=1.1, + centeredness_margin=0.6, +
daily_price_shift_exponent=0.1, + initial_pool_value=1_000_000.0, + arb_frequency=15, + ) + ], + "results/reclamm_heatmap_forward_cache/test/forward_values_tvl_1m.parquet/batch_00000001.parquet": [ + types.SimpleNamespace( + cache_key_hash="second456", + final_value=222.0, + method="geometric", + enable_noise_model=False, + noise_model="arb_only", + price_ratio=1.2, + centeredness_margin=0.7, + daily_price_shift_exponent=0.2, + initial_pool_value=1_000_000.0, + arb_frequency=15, + ) + ], + } + + class FakeFrame: + def __init__(self, rows): + self._rows = rows + self.empty = not rows + + def itertuples(self, index=False): + return list(self._rows) + + monkeypatch.setattr(script_module.os.path, "exists", lambda filename: True) + monkeypatch.setattr(script_module.os.path, "isdir", lambda filename: True) + monkeypatch.setattr( + script_module.os, + "listdir", + lambda path: ["batch_00000001.parquet", "batch_00000000.parquet"], + ) + monkeypatch.setattr( + script_module.pd, + "read_parquet", + lambda path, *args, **kwargs: FakeFrame(frames[path]), + ) + + cache = { + "_persistent_final_value_cache_loaded": False, + "_persistent_final_value_cache_path": "results/reclamm_heatmap_forward_cache/test/forward_values_tvl_1m.parquet", + } + + script_module._load_persistent_final_value_cache(cache) + + assert cache["_persistent_final_value_cache"] == { + "first123": 111.0, + "second456": 222.0, + } + assert cache["_persistent_final_value_next_batch_id"] == 2 + + +def test_load_persistent_final_value_cache_supports_legacy_two_column_parquet( + monkeypatch, + script_module, +): + class FakeFrame: + empty = False + + def itertuples(self, index=False): + return [ + types.SimpleNamespace( + cache_key_hash="legacy123", + final_value=999.0, + ) + ] + + monkeypatch.setattr(script_module.os.path, "exists", lambda filename: True) + monkeypatch.setattr(script_module.pd, "read_parquet", lambda *args, **kwargs: FakeFrame()) + + cache = { + "_persistent_final_value_cache_loaded": False, + 
"_persistent_final_value_cache_path": "results/reclamm_heatmap_forward_cache/test/forward_values_tvl_1m.parquet", + } + + script_module._load_persistent_final_value_cache(cache) + + assert cache["_persistent_final_value_cache"] == {"legacy123": 999.0} + assert cache["_persistent_final_value_next_batch_id"] == 0 + + +def test_make_sweep_cache_does_not_eagerly_load_persisted_parquet( + monkeypatch, + script_module, +): + calls = [] + + monkeypatch.setattr( + script_module, + "_load_persistent_final_value_cache", + lambda cache: calls.append(dict(cache)), + ) + + cache = script_module.make_sweep_cache(price_data=None, cache_scope_cfg=None) + + assert calls == [] + assert cache["_persistent_final_value_cache"] == {} + assert cache["_persistent_final_value_next_batch_id"] == 0 + assert cache["_persistent_final_value_cache_loaded"] is False + + +def test_run_comparison_cached_only_uses_geometric_runs_when_constant_arc_disabled( + monkeypatch, + script_module, + base_cfg, + launch_final_values, +): + calls = [] + + monkeypatch.setattr( + script_module, + "_make_comparison_cache_key", + lambda cfg, launch_final_values: ("cache", round(float(cfg["price_ratio"]), 6)), + ) + + def fake_run_method_final_value_cached(cfg, method, cache): + calls.append((cfg.get("enable_noise_model", False), method)) + return { + (True, "geometric"): 1_050_000.0, + (False, "geometric"): 1_000_000.0, + }[(cfg.get("enable_noise_model", False), method)] + + monkeypatch.setattr( + script_module, + "_run_method_final_value_cached", + fake_run_method_final_value_cached, + ) + + metrics = script_module.run_comparison_cached( + base_cfg, + cache={"_comparison_cache": {}, "_final_value_cache": {}, "_shared_price_data": None}, + launch_final_values=launch_final_values, + metric_keys=( + "geometric_vs_launch_geometric_pct", + "noise_vs_arb_geometric_improvement_pct", + ), + ) + + assert calls == [(True, "geometric"), (False, "geometric")] + assert metrics == { + "geometric_vs_launch_geometric_pct": 
pytest.approx(5.0), + "noise_vs_arb_geometric_improvement_pct": pytest.approx(5.0), + } diff --git a/tests/scripts/test_find_adjacent_heatmap_pairs.py b/tests/scripts/test_find_adjacent_heatmap_pairs.py new file mode 100644 index 0000000..6ad7a39 --- /dev/null +++ b/tests/scripts/test_find_adjacent_heatmap_pairs.py @@ -0,0 +1,307 @@ +"""Tests for cache-backed adjacent heatmap pair detection.""" + +from __future__ import annotations + +import importlib.util +from pathlib import Path + +import pytest + + +SCRIPT_PATH = ( + Path(__file__).resolve().parents[2] + / "scripts" + / "reclamm" + / "find_adjacent_heatmap_pairs.py" +) + + +def load_script_module(): + spec = importlib.util.spec_from_file_location( + "test_find_adjacent_heatmap_pairs_module", + SCRIPT_PATH, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +def make_cell(x_index, y_index, heatmap_value, **overrides): + cell = { + "metric_key": "noise_vs_arb_geometric_improvement_pct", + "metric_unit": "pct", + "pair_slug": "price_ratio_vs_margin", + "slice_slug": "q2", + "slice_label": "Q2", + "fixed_key": "daily_price_shift_exponent", + "fixed_value": 0.1975, + "price_ratio": 1.01 + 0.1 * x_index, + "centeredness_margin": 0.05 + 0.1 * y_index, + "daily_price_shift_exponent": 0.1975, + "tvl_usd": 1_000_000.0, + "heatmap_value": float(heatmap_value), + "x_index": int(x_index), + "y_index": int(y_index), + } + cell.update(overrides) + return cell + + +def test_find_adjacent_rows_for_slice_filters_and_sorts_descending(): + module = load_script_module() + records_by_coord = { + (0, 0): make_cell(0, 0, 0.0), + (0, 1): make_cell(1, 0, 35.0), + (0, 2): make_cell(2, 0, -10.0), + (1, 0): make_cell(0, 1, 5.0), + (1, 1): make_cell(1, 1, -40.0), + (1, 2): make_cell(2, 1, -50.0), + } + + rows = module.find_adjacent_rows_for_slice( + metric_key="noise_vs_arb_geometric_improvement_pct", + metric_unit="pct", + 
records_by_coord=records_by_coord, + x_count=3, + y_count=2, + min_diff=30.0, + ) + + assert [row["heatmap_value_diff_abs"] for row in rows] == [75.0, 45.0, 45.0, 40.0, 35.0] + assert rows[0]["adjacency_axis"] == "vertical" + assert rows[0]["1_x_index"] == 1 + assert rows[0]["1_y_index"] == 0 + assert rows[0]["2_x_index"] == 1 + assert rows[0]["2_y_index"] == 1 + + horizontal_rows = module.find_adjacent_rows_for_slice( + metric_key="noise_vs_arb_geometric_improvement_pct", + metric_unit="pct", + records_by_coord=records_by_coord, + x_count=3, + y_count=2, + min_diff=30.0, + adjacency_axis="horizontal", + ) + assert {row["adjacency_axis"] for row in horizontal_rows} == {"horizontal"} + + vertical_rows = module.find_adjacent_rows_for_slice( + metric_key="noise_vs_arb_geometric_improvement_pct", + metric_unit="pct", + records_by_coord=records_by_coord, + x_count=3, + y_count=2, + min_diff=30.0, + adjacency_axis="vertical", + ) + assert {row["adjacency_axis"] for row in vertical_rows} == {"vertical"} + + +def test_build_slice_cell_grid_reconstructs_metric_values_from_cache_hashes(): + module = load_script_module() + + class FakeCompareModule: + @staticmethod + def make_noise_variant_cfg(cfg, enable_noise_model): + updated = dict(cfg) + updated["enable_noise_model"] = bool(enable_noise_model) + return updated + + @staticmethod + def _make_method_cache_key(cfg, method): + return ( + method, + bool(cfg["enable_noise_model"]), + round(float(cfg["price_ratio"]), 6), + round(float(cfg["centeredness_margin"]), 6), + round(float(cfg["daily_price_shift_exponent"]), 6), + round(float(cfg["initial_pool_value"]), 2), + ) + + @staticmethod + def _make_method_cache_hash(key): + return repr(key) + + @staticmethod + def get_initial_pool_value(cfg): + return float(cfg["initial_pool_value"]) + + base_cfg = { + "price_ratio": 1.1, + "centeredness_margin": 0.3, + "daily_price_shift_exponent": 0.2, + "initial_pool_value": 1_000_000.0, + } + pair_spec = { + "slug": "price_ratio_vs_margin", 
+ "x_values": [1.1, 1.2], + "y_values": [0.3, 0.4], + "x_key": "price_ratio", + "y_key": "centeredness_margin", + "fixed_key": "daily_price_shift_exponent", + } + slice_variant = { + "slug": "q2", + "label": "Q2", + "value": 0.2, + } + + heatmap_targets = { + (0, 0): (130.0, 100.0), # +30% + (0, 1): (200.0, 100.0), # +100% + (1, 0): (70.0, 100.0), # -30% + (1, 1): (160.0, 100.0), # +60% + } + cache_lookup = {} + for (y_index, x_index), (noise_geo, arb_geo) in heatmap_targets.items(): + cfg = dict(base_cfg) + cfg["price_ratio"] = pair_spec["x_values"][x_index] + cfg["centeredness_margin"] = pair_spec["y_values"][y_index] + noise_cfg, noise_method = FakeCompareModule.make_noise_variant_cfg(cfg, True), "geometric" + arb_cfg, arb_method = FakeCompareModule.make_noise_variant_cfg(cfg, False), "geometric" + + noise_key = FakeCompareModule._make_method_cache_key(noise_cfg, noise_method) + arb_key = FakeCompareModule._make_method_cache_key(arb_cfg, arb_method) + cache_lookup[FakeCompareModule._make_method_cache_hash(noise_key)] = noise_geo + cache_lookup[FakeCompareModule._make_method_cache_hash(arb_key)] = arb_geo + + slice_scan = module.build_slice_cell_grid( + compare_module=FakeCompareModule, + base_cfg=base_cfg, + pair_spec=pair_spec, + slice_variant=slice_variant, + metric_key="noise_vs_arb_geometric_improvement_pct", + cache_lookup=cache_lookup, + ) + + assert slice_scan["resolved_cell_count"] == 4 + assert slice_scan["missing_hash_count"] == 0 + assert slice_scan["records_by_coord"][(0, 0)]["heatmap_value"] == pytest.approx(30.0) + assert slice_scan["records_by_coord"][(0, 1)]["heatmap_value"] == pytest.approx(100.0) + assert slice_scan["records_by_coord"][(1, 0)]["heatmap_value"] == pytest.approx(-30.0) + assert slice_scan["records_by_coord"][(1, 1)]["heatmap_value"] == pytest.approx(60.0) + + rows = module.find_adjacent_rows_for_slice( + metric_key="noise_vs_arb_geometric_improvement_pct", + metric_unit="pct", + records_by_coord=slice_scan["records_by_coord"], + 
x_count=2, + y_count=2, + min_diff=30.0, + ) + + assert [row["heatmap_value_diff_abs"] for row in rows] == pytest.approx([90.0, 70.0, 60.0, 40.0]) + assert rows[0]["1_heatmap_value"] == pytest.approx(-30.0) + assert rows[0]["2_heatmap_value"] == pytest.approx(60.0) + + +def test_run_top_row_geometric_comparison_dispatches_to_compare_module(monkeypatch): + module = load_script_module() + captured = {} + + class FakeCompareModule: + @staticmethod + def run_adjacent_csv_row_comparison(csv_path, row_index=0, output_file=None): + captured["csv_path"] = csv_path + captured["row_index"] = row_index + captured["output_file"] = output_file + return "fake-output.png" + + monkeypatch.setattr( + module, + "load_geometric_compare_module", + lambda module_path=None: FakeCompareModule, + ) + + output = module.run_top_row_geometric_comparison( + Path("tmp_adjacent.csv"), + output_file="custom.png", + row_index=0, + ) + + assert output == "fake-output.png" + assert captured == { + "csv_path": Path("tmp_adjacent.csv"), + "row_index": 0, + "output_file": "custom.png", + } + + +def test_autodetect_lightweight_noise_profile_keeps_market_linear(): + module = load_script_module() + compare_context = module._LightweightCompareContext() + base_cfg = compare_context.configs_for_tvl(compare_context.CONFIGS, 1_000_000.0)[1] + pair_spec = compare_context.get_pair_heatmap_specs(base_cfg)[0] + + compare_context.set_noise_profile("market_linear") + module.autodetect_lightweight_noise_profile( + compare_module=compare_context, + base_cfg=base_cfg, + pair_specs=[pair_spec], + metric_key="noise_vs_arb_geometric_improvement_pct", + slice_slug="q2", + cache_lookup={}, + ) + + assert compare_context.noise_profile == "market_linear" + + +def test_lightweight_context_rejects_legacy_noise_profile(): + module = load_script_module() + compare_context = module._LightweightCompareContext() + + with pytest.raises(ValueError): + compare_context.set_noise_profile("legacy_calibrated") + + +def 
test_source_variant_sets_explicit_market_and_arb_only_noise_models(): + module = load_script_module() + compare_context = module._LightweightCompareContext() + cfg = compare_context.configs_for_tvl(compare_context.CONFIGS, 1_000_000.0)[1] + + noise_cfg, noise_method = module._source_variant( + compare_context, + cfg, + "noise_geometric", + ) + arb_cfg, arb_method = module._source_variant( + compare_context, + cfg, + "arb_geometric", + ) + + assert noise_method == "geometric" + assert noise_cfg["enable_noise_model"] is True + assert noise_cfg["noise_model"] == "market_linear" + assert noise_cfg["noise_arrays_path"] == compare_context.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + + assert arb_method == "geometric" + assert arb_cfg["enable_noise_model"] is False + assert arb_cfg["noise_model"] == "arb_only" + assert arb_cfg["noise_arrays_path"] == compare_context.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + + resolved_arb = compare_context.resolve_reclamm_noise_settings(arb_cfg) + assert resolved_arb["noise_model"] == "arb_only" + assert resolved_arb["noise_arrays_path"] == compare_context.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert set(resolved_arb["reclamm_noise_params"]) == {"tvl_mean", "tvl_std"} + assert resolved_arb["noise_cache_key"][0] == "arb_only" + + +def test_lightweight_context_defaults_to_fixed_compare_arb_cadence(): + module = load_script_module() + compare_context = module._LightweightCompareContext() + cfg = compare_context.configs_for_tvl(compare_context.CONFIGS, 1_000_000.0)[1] + cfg["arb_frequency"] = 6 + cfg["gas_cost"] = 99.0 + cfg["protocol_fee_split"] = 0.9 + + arb_cfg = compare_context.make_noise_variant_cfg(cfg, False) + resolved_noise = compare_context.resolve_reclamm_noise_settings(cfg) + + assert arb_cfg["arb_frequency"] == compare_context.FIXED_COMPARE_ARB_FREQUENCY + assert arb_cfg["noise_model"] == "arb_only" + assert arb_cfg["noise_arrays_path"] == compare_context.DEFAULT_MARKET_LINEAR_NOISE_ARRAYS_PATH + assert 
resolved_noise["arb_frequency"] == compare_context.FIXED_COMPARE_ARB_FREQUENCY + assert arb_cfg["gas_cost"] == compare_context.DEFAULT_GAS_COST + assert arb_cfg["protocol_fee_split"] == compare_context.DEFAULT_PROTOCOL_FEE_SPLIT diff --git a/tests/unit/test_hypersurge_training_dynamic_inputs.py b/tests/unit/test_hypersurge_training_dynamic_inputs.py new file mode 100644 index 0000000..4a4896b --- /dev/null +++ b/tests/unit/test_hypersurge_training_dynamic_inputs.py @@ -0,0 +1,109 @@ +import numpy as np +import pandas as pd + +from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames +from quantammsim.pools.hypersurge_utils import HYPERSURGE_PARAM_KEYS +from quantammsim.runners.jax_runners import train_on_historic_data + + +def _synthetic_price_data(start="2023-01-01 00:00:00", periods=24 * 60 * 5): + date_index = pd.date_range(start=start, periods=periods, freq="min") + t = np.arange(periods, dtype=np.float64) + price_data = pd.DataFrame( + { + "close_BTC": 20_000.0 * np.exp(0.00008 * t + 0.015 * np.sin(t / 240.0)), + "close_ETH": 1_500.0 * np.exp(0.00005 * t + 0.02 * np.cos(t / 360.0)), + }, + index=(date_index.view("int64") // 10**6), + ) + oracle_prices = pd.DataFrame( + { + "unix": price_data.index.to_numpy(), + "BTC": price_data["close_BTC"].to_numpy(), + "ETH": price_data["close_ETH"].to_numpy(), + } + ) + return price_data, oracle_prices + + +def _training_fingerprint(rule): + return { + "tokens": ["BTC", "ETH"], + "rule": rule, + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-04 00:00:00", + "endTestDateString": "2023-01-05 00:00:00", + "chunk_period": 60, + "weight_interpolation_period": 60, + "bout_offset": 120, + "initial_pool_value": 1_000_000.0, + "fees": 0.003, + "arb_fees": 0.0, + "gas_cost": 0.0, + "arb_frequency": 1, + "do_arb": True, + "return_val": "returns", + "use_fused_reserves": False, + "optimisation_settings": { + "method": "gradient_descent", + "base_lr": 0.01, + "optimiser": "adam", + "batch_size": 1, 
+ "n_iterations": 1, + "n_parameter_sets": 1, + "training_data_kind": "historic", + "sample_method": "uniform", + "initial_random_key": 0, + "n_cycles": 1, + "val_fraction": 0.0, + "early_stopping": False, + "decay_lr_ratio": 0.8, + "decay_lr_plateau": 100, + "min_lr": 1e-6, + }, + } + + +def test_hypersurge_balancer_training_accepts_oracle_frames(): + price_data, oracle_prices = _synthetic_price_data() + params, metadata = train_on_historic_data( + _training_fingerprint("balancer_hypersurge"), + price_data=price_data, + dynamic_input_frames=DynamicInputFrames(oracle_prices=oracle_prices), + verbose=False, + force_init=True, + return_training_metadata=True, + iterations_per_print=999999, + ) + + assert np.isfinite(metadata["final_objective"]) + for key in HYPERSURGE_PARAM_KEYS: + assert key in params + + +def test_hypersurge_reclamm_training_accepts_oracle_frames(): + price_data, oracle_prices = _synthetic_price_data() + fingerprint = _training_fingerprint("reclamm_hypersurge") + fingerprint.update( + { + "initial_price_ratio": 2.0, + "initial_centeredness_margin": 0.25, + "initial_daily_price_shift_base": 1.0 - 1.0 / 124000.0, + "reclamm_interpolation_method": "geometric", + } + ) + + params, metadata = train_on_historic_data( + fingerprint, + price_data=price_data, + dynamic_input_frames=DynamicInputFrames(oracle_prices=oracle_prices), + verbose=False, + force_init=True, + return_training_metadata=True, + iterations_per_print=999999, + ) + + assert np.isfinite(metadata["final_objective"]) + assert "price_ratio" in params + for key in HYPERSURGE_PARAM_KEYS: + assert key in params diff --git a/tests/unit/test_jax_runner_utils.py b/tests/unit/test_jax_runner_utils.py index 2e014f5..82a08c2 100644 --- a/tests/unit/test_jax_runner_utils.py +++ b/tests/unit/test_jax_runner_utils.py @@ -335,20 +335,59 @@ def test_prepare_dynamic_inputs_preserves_fixed_hot_path_structure(self): assert flags["has_dynamic_arb_fees"] is True assert flags["has_lp_supply"] is True assert 
flags["has_reclamm_price_ratio_updates"] is False + assert flags["has_oracle_prices"] is False assert train_inputs.trades.shape == (2, 3) assert train_inputs.fees.shape == (2,) assert train_inputs.gas_cost.shape == (2,) assert train_inputs.arb_fees.shape == (2,) assert train_inputs.lp_supply.shape == (2,) assert train_inputs.reclamm_price_ratio_updates.shape == (1, 4) + assert train_inputs.oracle_prices.shape == (1, 1) assert test_inputs.trades.shape == (2, 3) assert test_inputs.fees.shape == (2,) assert test_inputs.gas_cost.shape == (2,) assert test_inputs.arb_fees.shape == (2,) assert test_inputs.lp_supply.shape == (2,) assert test_inputs.reclamm_price_ratio_updates.shape == (1, 4) + assert test_inputs.oracle_prices.shape == (1, 1) np.testing.assert_allclose(np.asarray(train_inputs.fees), np.array([0.003, 0.003])) + def test_prepare_dynamic_inputs_normalizes_oracle_prices(self): + """Oracle price frames should align to token order and ffill by minute.""" + from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames + from quantammsim.runners.jax_runner_utils import prepare_dynamic_inputs + + run_fingerprint = { + "tokens": ["ETH", "USDC"], + "startDateString": "2023-01-01 00:00:00", + "endDateString": "2023-01-01 00:03:00", + "endTestDateString": "2023-01-01 00:05:00", + } + + oracle_prices = pd.DataFrame( + { + "unix": [1672531200000, 1672531320000], + "ETH": [1200.0, 1210.0], + "USDC": [1.0, 1.0], + } + ) + + prepared = prepare_dynamic_inputs( + run_fingerprint, + dynamic_input_frames=DynamicInputFrames(oracle_prices=oracle_prices), + do_test_period=True, + ) + + assert prepared["dynamic_input_flags"]["has_oracle_prices"] is True + np.testing.assert_allclose( + np.asarray(prepared["train_dynamic_inputs"].oracle_prices), + np.array([[1200.0, 1.0], [1200.0, 1.0], [1210.0, 1.0]]), + ) + np.testing.assert_allclose( + np.asarray(prepared["test_dynamic_inputs"].oracle_prices), + np.array([[1210.0, 1.0], [1210.0, 1.0]]), + ) + def 
test_prepare_dynamic_inputs_normalizes_reclamm_price_ratio_updates(self): """Manual reCLAMM update schedules should map to per-step event rows.""" from quantammsim.core_simulator.dynamic_inputs import DynamicInputFrames