diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 27332ab..5c1df7d 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -1,47 +1,124 @@
-name: Publish to PyPI
+name: Build and Publish
 
 on:
   release:
     types: [published]
 
 jobs:
-  build:
+  # Build wheels on Linux
+  build-linux:
+    name: Build Linux wheels
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        target: [x86_64, aarch64]
     steps:
       - uses: actions/checkout@v4
 
-      - name: Set up Python
-        uses: actions/setup-python@v5
+      - name: Build wheels
+        uses: PyO3/maturin-action@v1
         with:
-          python-version: "3.11"
+          target: ${{ matrix.target }}
+          args: --release --out dist
+          manylinux: auto
 
-      - name: Install build dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install build
+      - name: Upload wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: wheels-linux-${{ matrix.target }}
+          path: dist/*.whl
+
+  # Build wheels on macOS (x86_64)
+  build-macos-x86:
+    name: Build macOS x86_64 wheels
+    runs-on: macos-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Build wheels
+        uses: PyO3/maturin-action@v1
+        with:
+          target: x86_64-apple-darwin
+          args: --release --out dist
+
+      - name: Upload wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: wheels-macos-x86_64
+          path: dist/*.whl
+
+  # Build wheels on macOS (ARM64)
+  build-macos-arm:
+    name: Build macOS ARM64 wheels
+    runs-on: macos-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Build wheels
+        uses: PyO3/maturin-action@v1
+        with:
+          target: aarch64-apple-darwin
+          args: --release --out dist
 
-      - name: Build package
-        run: python -m build
+      - name: Upload wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: wheels-macos-arm64
+          path: dist/*.whl
+
+  # Build wheels on Windows
+  build-windows:
+    name: Build Windows wheels
+    runs-on: windows-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Build wheels
+        uses: PyO3/maturin-action@v1
+        with:
+          target: x64
+          args: --release --out dist
+
+      - name: Upload wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: wheels-windows
+          path: dist/*.whl
+
+  # Build source distribution
+  build-sdist:
+    name: Build source distribution
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Build sdist
+        uses: PyO3/maturin-action@v1
+        with:
+          command: sdist
+          args: --out dist
 
-      - name: Upload build artifacts
+      - name: Upload sdist
         uses: actions/upload-artifact@v4
         with:
-          name: dist
-          path: dist/
+          name: sdist
+          path: dist/*.tar.gz
 
+  # Publish all artifacts to PyPI
   publish:
-    needs: build
+    name: Publish to PyPI
+    needs: [build-linux, build-macos-x86, build-macos-arm, build-windows, build-sdist]
     runs-on: ubuntu-latest
     environment: pypi
     permissions:
       id-token: write  # Required for trusted publishing
     steps:
-      - name: Download build artifacts
+      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
-          name: dist
-          path: dist/
+          path: dist
+          merge-multiple: true
 
      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/.github/workflows/rust-test.yml b/.github/workflows/rust-test.yml
new file mode 100644
index 0000000..477a5ee
--- /dev/null
+++ b/.github/workflows/rust-test.yml
@@ -0,0 +1,128 @@
+name: Rust Backend Tests
+
+on:
+  push:
+    branches: [main]
+    paths:
+      - 'rust/**'
+      - 'diff_diff/**'
+      - 'tests/**'
+      - 'pyproject.toml'
+      - '.github/workflows/rust-test.yml'
+  pull_request:
+    branches: [main]
+    paths:
+      - 'rust/**'
+      - 'diff_diff/**'
+      - 'tests/**'
+      - 'pyproject.toml'
+      - '.github/workflows/rust-test.yml'
+
+env:
+  CARGO_TERM_COLOR: always
+
+jobs:
+  # Run Rust unit tests
+  rust-tests:
+    name: Rust Unit Tests
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install Rust toolchain
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: Install OpenBLAS
+        run: sudo apt-get update && sudo apt-get install -y libopenblas-dev
+
+      - name: Run Rust tests
+        working-directory: rust
+        run: cargo test --verbose
+
+  # Build and test with Python on multiple platforms
+  python-tests:
+    name: Python Tests (${{ matrix.os }})
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+        # Windows excluded due to Intel MKL build complexity
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install OpenBLAS (Ubuntu)
+        if: matrix.os == 'ubuntu-latest'
+        run: sudo apt-get update && sudo apt-get install -y libopenblas-dev
+
+      - name: Install OpenBLAS (macOS)
+        if: matrix.os == 'macos-latest'
+        run: brew install openblas
+
+      - name: Set OpenBLAS paths (macOS)
+        if: matrix.os == 'macos-latest'
+        run: |
+          echo "OPENBLAS_DIR=$(brew --prefix openblas)" >> $GITHUB_ENV
+          echo "PKG_CONFIG_PATH=$(brew --prefix openblas)/lib/pkgconfig" >> $GITHUB_ENV
+
+      - name: Install Rust toolchain
+        uses: dtolnay/rust-toolchain@stable
+
+      - name: Install test dependencies
+        run: pip install pytest numpy pandas scipy
+
+      - name: Build and install with maturin
+        run: |
+          pip install maturin
+          maturin build --release -o dist
+          echo "=== Built wheels ==="
+          ls -la dist/
+          # --no-index ensures we install from local wheel, not PyPI
+          pip install --no-index --find-links=dist diff-diff
+
+      - name: Verify Rust backend is available
+        # Run from /tmp to avoid source directory shadowing installed package
+        working-directory: /tmp
+        run: |
+          python -c "import diff_diff; print('Location:', diff_diff.__file__)"
+          python -c "from diff_diff import HAS_RUST_BACKEND; print('HAS_RUST_BACKEND:', HAS_RUST_BACKEND); assert HAS_RUST_BACKEND, 'Rust backend not available'"
+
+      - name: Copy tests to isolated location
+        run: cp -r tests /tmp/tests
+
+      - name: Run Rust backend tests
+        working-directory: /tmp
+        run: pytest tests/test_rust_backend.py -v
+
+      - name: Run tests with Rust backend
+        working-directory: /tmp
+        run: DIFF_DIFF_BACKEND=rust pytest tests/ -x -q
+
+  # Test pure Python fallback (without Rust extension)
+  python-fallback:
+    name: Pure Python Fallback
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: pip install numpy pandas scipy pytest
+
+      - name: Verify pure Python mode
+        run: |
+          # Use PYTHONPATH to import directly (skips maturin build)
+          PYTHONPATH=. python -c "from diff_diff import HAS_RUST_BACKEND; print(f'HAS_RUST_BACKEND: {HAS_RUST_BACKEND}'); assert not HAS_RUST_BACKEND"
+
+      - name: Run tests in pure Python mode
+        run: PYTHONPATH=. DIFF_DIFF_BACKEND=python pytest tests/ -x -q --ignore=tests/test_rust_backend.py
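Note on the "Verify Rust backend is available" step above: the same check can be run locally. A minimal sketch, assuming the wheel has been installed as in the workflow (run it from outside the repo root so the installed package, not the source tree, is imported):

```python
# Mirrors the CI verification step: confirm the installed package is the
# built wheel and that the compiled extension was picked up.
import diff_diff
from diff_diff import HAS_RUST_BACKEND

print("Location:", diff_diff.__file__)
assert HAS_RUST_BACKEND, "Rust backend not available"
```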
diff --git a/.gitignore b/.gitignore
index 032d0df..272a10c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,3 +56,13 @@ Thumbs.db
 # Benchmarks - generated data and results (can be regenerated)
 benchmarks/data/synthetic/*.csv
 benchmarks/results/
+
+# Rust build artifacts
+rust/target/
+Cargo.lock
+*.so
+*.pyd
+*.dylib
+
+# Maturin build artifacts
+target/
diff --git a/CLAUDE.md b/CLAUDE.md
index 504ee71..40ddad6 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -31,6 +31,28 @@ ruff check diff_diff tests
 mypy diff_diff
 ```
 
+### Rust Backend Commands
+
+```bash
+# Build Rust backend for development (requires Rust toolchain)
+maturin develop
+
+# Build with release optimizations
+maturin develop --release
+
+# Run Rust unit tests
+cd rust && cargo test
+
+# Force pure Python mode (disable Rust backend)
+DIFF_DIFF_BACKEND=python pytest
+
+# Force Rust mode (fail if Rust not available)
+DIFF_DIFF_BACKEND=rust pytest
+
+# Run Rust backend equivalence tests
+pytest tests/test_rust_backend.py -v
+```
+
 ## Architecture
 
 ### Module Structure
 
@@ -81,6 +103,20 @@ mypy diff_diff
 - Single optimization point for all estimators (reduces code duplication)
 - Cluster-robust SEs use pandas groupby instead of O(n × clusters) loop
 
+- **`diff_diff/_backend.py`** - Backend detection and configuration (v2.0.0):
+  - Detects optional Rust backend availability
+  - Handles `DIFF_DIFF_BACKEND` environment variable ('auto', 'python', 'rust')
+  - Exports `HAS_RUST_BACKEND` flag and Rust function references
+  - Other modules import from here to avoid circular imports with `__init__.py`
+
+- **`rust/`** - Optional Rust backend for accelerated computation (v2.0.0):
+  - **`rust/src/lib.rs`** - PyO3 module definition, exports Python bindings
+  - **`rust/src/bootstrap.rs`** - Parallel bootstrap weight generation (Rademacher, Mammen, Webb)
+  - **`rust/src/linalg.rs`** - OLS solver and cluster-robust variance estimation
+  - **`rust/src/weights.rs`** - Synthetic control weights and simplex projection
+  - Uses ndarray-linalg with OpenBLAS (Linux/macOS) or Intel MKL (Windows)
+  - Provides 4-8x speedup for SyntheticDiD, minimal benefit for other estimators
+
 - **`diff_diff/results.py`** - Dataclass containers for estimation results:
   - `DiDResults`, `MultiPeriodDiDResults`, `SyntheticDiDResults`, `PeriodEffect`
   - Each provides `summary()`, `to_dict()`, `to_dataframe()` methods
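One subtlety behind the `DIFF_DIFF_BACKEND` commands documented above: `_backend.py` reads the variable at import time, so it must be set before the first `diff_diff` import. A minimal sketch:

```python
# Force pure Python mode; this must happen before diff_diff is first
# imported, since diff_diff._backend reads DIFF_DIFF_BACKEND at import time.
import os
os.environ["DIFF_DIFF_BACKEND"] = "python"  # or "rust", or unset for auto

from diff_diff import HAS_RUST_BACKEND
assert not HAS_RUST_BACKEND  # Rust disabled even if the extension is installed
```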
diff --git a/TODO.md b/TODO.md
index c80ddfd..64bdd6a 100644
--- a/TODO.md
+++ b/TODO.md
@@ -102,9 +102,19 @@ From code review (PR #32):
 
 ---
 
+## Rust Backend Optimizations
+
+Deferred from PR #58 code review (can be done post-merge):
+
+- [ ] **Matrix inversion efficiency** (`rust/src/linalg.rs:180-194`): Use Cholesky factorization for symmetric positive-definite matrices instead of a column-by-column solve
+- [ ] **Reduce bootstrap allocations** (`rust/src/bootstrap.rs`): Currently uses `Vec<Vec<f64>>` → flatten → `Array2`, which allocates twice. Should allocate directly into ndarray.
+- [ ] **Consider static BLAS linking** (`rust/Cargo.toml`): Currently requires system BLAS libraries. Consider `openblas-static` or `intel-mkl-static` features for easier distribution.
+
+---
+
 ## Performance Optimizations
 
-No major performance issues identified. Potential future optimizations:
+Potential future optimizations:
 
 - [ ] JIT compilation for bootstrap loops (numba)
 - [ ] Parallel bootstrap iterations
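For the first TODO item, the numerical idea is standard. A minimal NumPy/SciPy sketch of Cholesky-based inversion of the symmetric positive-definite X'X matrix, illustrative only (the actual change would live in `rust/src/linalg.rs`):

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def xtx_inverse_cholesky(X: np.ndarray) -> np.ndarray:
    """Invert X'X via one Cholesky factorization instead of a generic
    column-by-column solve. Requires X to have full column rank."""
    XtX = X.T @ X
    factor = cho_factor(XtX)
    return cho_solve(factor, np.eye(XtX.shape[1]))
```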
diff --git a/benchmarks/compare_results.py b/benchmarks/compare_results.py
index 8f754f5..5411cd9 100644
--- a/benchmarks/compare_results.py
+++ b/benchmarks/compare_results.py
@@ -36,6 +36,11 @@ class ComparisonResult:
     r_time_std: float = 0.0
     n_replications: int = 1
     scale: str = "small"
+    # Optional three-way comparison fields
+    python_pure_time: Optional[float] = None
+    python_rust_time: Optional[float] = None
+    python_pure_time_std: float = 0.0
+    python_rust_time_std: float = 0.0
 
     def __str__(self) -> str:
         status = "PASS" if self.passed else "FAIL"
@@ -68,6 +73,8 @@ def compare_estimates(
     atol: float = 1e-4,
     se_rtol: float = 0.10,
     scale: str = "small",
+    python_pure_results: Optional[Dict[str, Any]] = None,
+    python_rust_results: Optional[Dict[str, Any]] = None,
 ) -> ComparisonResult:
     """
     Compare Python and R estimates for numerical equivalence.
@@ -146,6 +153,24 @@
     elif not se_ok and ci_overlap:
         notes.append(f"SE differs ({se_rel_diff:.1%}) but CI overlap - methodological difference")
 
+    # Extract three-way timing data if provided
+    python_pure_time = None
+    python_rust_time = None
+    python_pure_time_std = 0.0
+    python_rust_time_std = 0.0
+
+    if python_pure_results:
+        pure_timing = python_pure_results.get("timing", {})
+        pure_stats = pure_timing.get("stats", {})
+        python_pure_time = pure_stats.get("mean", pure_timing.get("total_seconds", 0))
+        python_pure_time_std = pure_stats.get("std", 0)
+
+    if python_rust_results:
+        rust_timing = python_rust_results.get("timing", {})
+        rust_stats = rust_timing.get("stats", {})
+        python_rust_time = rust_stats.get("mean", rust_timing.get("total_seconds", 0))
+        python_rust_time_std = rust_stats.get("std", 0)
+
     return ComparisonResult(
         estimator=estimator,
         python_att=py_att,
@@ -165,6 +190,10 @@
         r_time_std=r_time_std,
         n_replications=n_reps,
         scale=scale,
+        python_pure_time=python_pure_time,
+        python_rust_time=python_rust_time,
+        python_pure_time_std=python_pure_time_std,
+        python_rust_time_std=python_rust_time_std,
     )
 
 
@@ -288,10 +317,41 @@ def generate_comparison_report(
     lines.append("=" * 70)
     lines.append("")
 
-    # Check if we have multi-replication data
+    # Check if we have three-way comparison data
+    has_three_way = any(comp.python_pure_time is not None for comp in comparisons)
     has_std = any(comp.n_replications > 1 for comp in comparisons)
 
-    if has_std:
+    if has_three_way:
+        # Three-way comparison table: R vs Python (pure) vs Python (rust)
+        lines.append("Three-Way Performance Comparison")
+        lines.append("")
+        lines.append(f"{'Estimator':<18} {'Scale':<6} {'R (s)':<10} {'Py-Pure (s)':<12} {'Py-Rust (s)':<12} {'Rust/R':<10} {'Rust/Pure':<10}")
+        lines.append("-" * 90)
+        for comp in comparisons:
+            r_time = comp.r_time
+            pure_time = comp.python_pure_time if comp.python_pure_time else "-"
+            rust_time = comp.python_rust_time if comp.python_rust_time else comp.python_time
+
+            # Format times
+            r_str = f"{r_time:.3f}" if r_time else "-"
+            pure_str = f"{pure_time:.3f}" if isinstance(pure_time, (int, float)) else pure_time
+            rust_str = f"{rust_time:.3f}" if rust_time else "-"
+
+            # Calculate speedups
+            if rust_time and r_time and r_time > 0:
+                rust_vs_r = f"{r_time/rust_time:.1f}x"
+            else:
+                rust_vs_r = "-"
+
+            if rust_time and comp.python_pure_time and comp.python_pure_time > 0:
+                rust_vs_pure = f"{comp.python_pure_time/rust_time:.1f}x"
+            else:
+                rust_vs_pure = "-"
+
+            lines.append(
+                f"{comp.estimator:<18} {comp.scale:<6} {r_str:<10} {pure_str:<12} {rust_str:<12} {rust_vs_r:<10} {rust_vs_pure:<10}"
+            )
+    elif has_std:
         lines.append(f"{'Estimator':<20} {'Scale':<8} {'Python (s)':<18} {'R (s)':<18} {'Speedup':<10}")
         lines.append("-" * 80)
         for comp in comparisons:
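For reference, the `.get()` chains in `compare_estimates` imply the following shape for the optional per-backend result dicts. A sketch with placeholder values:

```python
# "stats" is present when a run was replicated; "total_seconds" is the
# single-run fallback used when it is not.
python_rust_results = {
    "timing": {
        "total_seconds": 0.41,
        "stats": {"mean": 0.40, "std": 0.02},
    },
}
```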
diff --git a/benchmarks/python/benchmark_basic.py b/benchmarks/python/benchmark_basic.py
index c577408..e173d6f 100644
--- a/benchmarks/python/benchmark_basic.py
+++ b/benchmarks/python/benchmark_basic.py
@@ -8,16 +8,31 @@
 import argparse
 import json
+import os
 import sys
 from pathlib import Path
 
+# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff
+# This ensures the backend configuration is respected by all modules
+def _get_backend_from_args():
+    """Parse --backend argument without importing diff_diff."""
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
+    args, _ = parser.parse_known_args()
+    return args.backend
+
+_requested_backend = _get_backend_from_args()
+if _requested_backend in ("python", "rust"):
+    os.environ["DIFF_DIFF_BACKEND"] = _requested_backend
+
+# NOW import diff_diff and other dependencies (will see the env var)
 import numpy as np
 import pandas as pd
 
 # Add parent to path for imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
-from diff_diff import DifferenceInDifferences
+from diff_diff import DifferenceInDifferences, HAS_RUST_BACKEND
 from benchmarks.python.utils import Timer
 
@@ -32,12 +47,25 @@ def parse_args():
         "--type", default="twfe", choices=["basic", "twfe"],
         help="Estimator type (basic or twfe, default: twfe)"
     )
+    parser.add_argument(
+        "--backend", default="auto", choices=["auto", "python", "rust"],
+        help="Backend to use: auto (default), python (pure Python), rust (Rust backend)"
+    )
     return parser.parse_args()
 
 
+def get_actual_backend() -> str:
+    """Return the actual backend being used based on HAS_RUST_BACKEND."""
+    return "rust" if HAS_RUST_BACKEND else "python"
+
+
 def main():
     args = parse_args()
 
+    # Get actual backend (already configured via env var before imports)
+    actual_backend = get_actual_backend()
+    print(f"Using backend: {actual_backend}")
+
     # Load data
     print(f"Loading data from: {args.data}")
     data = pd.read_csv(args.data)
@@ -64,6 +92,7 @@ def main():
     # Build output
     output = {
         "estimator": "diff_diff.DifferenceInDifferences",
+        "backend": actual_backend,
         "cluster": args.cluster,
         # Treatment effect
         "att": float(att),
diff --git a/benchmarks/python/benchmark_callaway.py b/benchmarks/python/benchmark_callaway.py
index 02b9824..ba99908 100644
--- a/benchmarks/python/benchmark_callaway.py
+++ b/benchmarks/python/benchmark_callaway.py
@@ -8,16 +8,31 @@
 import argparse
 import json
+import os
 import sys
 from pathlib import Path
 
+# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff
+# This ensures the backend configuration is respected by all modules
+def _get_backend_from_args():
+    """Parse --backend argument without importing diff_diff."""
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
+    args, _ = parser.parse_known_args()
+    return args.backend
+
+_requested_backend = _get_backend_from_args()
+if _requested_backend in ("python", "rust"):
+    os.environ["DIFF_DIFF_BACKEND"] = _requested_backend
+
+# NOW import diff_diff and other dependencies (will see the env var)
 import numpy as np
 import pandas as pd
 
 # Add parent to path for imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
-from diff_diff import CallawaySantAnna
+from diff_diff import CallawaySantAnna, HAS_RUST_BACKEND
 from benchmarks.python.utils import BenchmarkResult, Timer
 
@@ -39,12 +54,25 @@ def parse_args():
         choices=["never_treated", "not_yet_treated"],
         help="Control group definition",
     )
+    parser.add_argument(
+        "--backend", default="auto", choices=["auto", "python", "rust"],
+        help="Backend to use: auto (default), python (pure Python), rust (Rust backend)"
+    )
     return parser.parse_args()
 
 
+def get_actual_backend() -> str:
+    """Return the actual backend being used based on HAS_RUST_BACKEND."""
+    return "rust" if HAS_RUST_BACKEND else "python"
+
+
 def main():
     args = parse_args()
 
+    # Get actual backend (already configured via env var before imports)
+    actual_backend = get_actual_backend()
+    print(f"Using backend: {actual_backend}")
+
     # Load data
     print(f"Loading data from: {args.data}")
     df = pd.read_csv(args.data)
@@ -113,6 +141,7 @@ def main():
     # Build output
     output = {
         "estimator": "diff_diff.CallawaySantAnna",
+        "backend": actual_backend,
         "method": args.method,
         "control_group": args.control_group,
         # Overall ATT
diff --git a/benchmarks/python/benchmark_synthdid.py b/benchmarks/python/benchmark_synthdid.py
index 414fc6b..4c1323b 100644
--- a/benchmarks/python/benchmark_synthdid.py
+++ b/benchmarks/python/benchmark_synthdid.py
@@ -8,16 +8,31 @@
 import argparse
 import json
+import os
 import sys
 from pathlib import Path
 
+# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff
+# This ensures the backend configuration is respected by all modules
+def _get_backend_from_args():
+    """Parse --backend argument without importing diff_diff."""
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
+    args, _ = parser.parse_known_args()
+    return args.backend
+
+_requested_backend = _get_backend_from_args()
+if _requested_backend in ("python", "rust"):
+    os.environ["DIFF_DIFF_BACKEND"] = _requested_backend
+
+# NOW import diff_diff and other dependencies (will see the env var)
 import numpy as np
 import pandas as pd
 
 # Add parent to path for imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 
-from diff_diff import SyntheticDiD
+from diff_diff import SyntheticDiD, HAS_RUST_BACKEND
 from benchmarks.python.utils import Timer
 
@@ -33,12 +48,25 @@
         choices=["bootstrap", "placebo"],
         help="Variance estimation method (default: placebo to match R)"
     )
+    parser.add_argument(
+        "--backend", default="auto", choices=["auto", "python", "rust"],
+        help="Backend to use: auto (default), python (pure Python), rust (Rust backend)"
+    )
     return parser.parse_args()
 
 
+def get_actual_backend() -> str:
+    """Return the actual backend being used based on HAS_RUST_BACKEND."""
+    return "rust" if HAS_RUST_BACKEND else "python"
+
+
 def main():
     args = parse_args()
 
+    # Get actual backend (already configured via env var before imports)
+    actual_backend = get_actual_backend()
+    print(f"Using backend: {actual_backend}")
+
     # Load data
     print(f"Loading data from: {args.data}")
     data = pd.read_csv(args.data)
@@ -74,6 +102,7 @@ def main():
     # Build output
     output = {
         "estimator": "diff_diff.SyntheticDiD",
+        "backend": actual_backend,
         # Point estimate and SE
         "att": float(results.att),
         "se": float(results.se),
diff --git a/benchmarks/run_benchmarks.py b/benchmarks/run_benchmarks.py
index 7ee1a5a..284d2f6 100644
--- a/benchmarks/run_benchmarks.py
+++ b/benchmarks/run_benchmarks.py
@@ -64,6 +64,11 @@
         "basic": {"n_units": 10000, "n_periods": 10},
         "sdid": {"n_control": 8000, "n_treated": 2000, "n_pre": 30, "n_post": 20},
     },
+    "20k": {
+        "staggered": {"n_units": 20000, "n_periods": 18, "n_cohorts": 7},
+        "basic": {"n_units": 20000, "n_periods": 12},
+        "sdid": {"n_control": 16000, "n_treated": 4000, "n_pre": 35, "n_post": 25},
+    },
 }
 
 # Timeout configurations (seconds) by scale
@@ -72,6 +77,7 @@
     "1k": {"python": 300, "r": 1800},
     "5k": {"python": 600, "r": 3600},
     "10k": {"python": 1200, "r": 7200},
+    "20k": {"python": 2400, "r": 14400},
 }
 
 
@@ -150,6 +156,7 @@ def run_python_benchmark(
     output_path: Path,
     extra_args: Optional[List[str]] = None,
     timeout: Optional[int] = None,
+    backend: str = "auto",
 ) -> Dict[str, Any]:
     """
     Execute Python benchmark script and return results.
@@ -166,6 +173,8 @@
         Additional command line arguments.
     timeout : int, optional
         Timeout in seconds.
+    backend : str
+        Backend to use: 'auto', 'python', or 'rust'.
 
     Returns
     -------
@@ -179,6 +188,7 @@
         str(py_script),
         "--data", str(data_path),
         "--output", str(output_path),
+        "--backend", backend,
     ]
     if extra_args:
         cmd.extend(extra_args)
@@ -290,48 +300,63 @@ def run_callaway_benchmark(
     name: str = "callaway",
     scale: str = "small",
     n_replications: int = 1,
+    backends: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run Callaway-Sant'Anna benchmarks (Python and R) with replications."""
     print(f"\n{'='*60}")
     print(f"CALLAWAY-SANT'ANNA BENCHMARK ({scale})")
     print(f"{'='*60}")
 
+    if backends is None:
+        backends = ["python", "rust"]
+
     timeouts = TIMEOUT_CONFIGS.get(scale, TIMEOUT_CONFIGS["small"])
 
     results = {
         "name": name,
         "scale": scale,
         "n_replications": n_replications,
-        "python": None,
+        "python_pure": None,
+        "python_rust": None,
         "r": None,
         "comparison": None,
     }
 
-    # Python benchmark with replications
-    print(f"\nRunning Python (diff_diff.CallawaySantAnna) - {n_replications} replications...")
-    py_output = RESULTS_DIR / "accuracy" / f"python_{name}_{scale}.json"
-    py_output.parent.mkdir(parents=True, exist_ok=True)
-
-    py_timings = []
-    py_result = None
-    for rep in range(n_replications):
-        try:
-            py_result = run_python_benchmark(
-                "benchmark_callaway.py", data_path, py_output,
-                timeout=timeouts["python"]
-            )
-            py_timings.append(py_result["timing"]["total_seconds"])
-            if rep == 0:
-                print(f"  ATT: {py_result['overall_att']:.4f}")
-                print(f"  SE: {py_result['overall_se']:.4f}")
-            print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
-        except Exception as e:
-            print(f"  Rep {rep+1} failed: {e}")
-
-    if py_result and py_timings:
-        timing_stats = compute_timing_stats(py_timings)
-        py_result["timing"] = timing_stats
-        results["python"] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+    # Run Python benchmark for each backend
+    for backend in backends:
+        # Map backend name to label (python -> pure, rust -> rust)
+        backend_label = f"python_{'pure' if backend == 'python' else backend}"
+        print(f"\nRunning Python (diff_diff.CallawaySantAnna, backend={backend}) - {n_replications} replications...")
+        py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
+        py_output.parent.mkdir(parents=True, exist_ok=True)
+
+        py_timings = []
+        py_result = None
+        for rep in range(n_replications):
+            try:
+                py_result = run_python_benchmark(
+                    "benchmark_callaway.py", data_path, py_output,
+                    timeout=timeouts["python"],
+                    backend=backend,
+                )
+                py_timings.append(py_result["timing"]["total_seconds"])
+                if rep == 0:
+                    print(f"  ATT: {py_result['overall_att']:.4f}")
+                    print(f"  SE: {py_result['overall_se']:.4f}")
+                print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
+            except Exception as e:
+                print(f"  Rep {rep+1} failed: {e}")
+
+        if py_result and py_timings:
+            timing_stats = compute_timing_stats(py_timings)
+            py_result["timing"] = timing_stats
+            results[backend_label] = py_result
+            print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+
+    # For backward compatibility, also store as "python" (use rust if available)
+    if results.get("python_rust"):
+        results["python"] = results["python_rust"]
+    elif results.get("python_pure"):
+        results["python"] = results["python_pure"]
 
     # R benchmark with replications
     print(f"\nRunning R (did::att_gt) - {n_replications} replications...")
@@ -361,20 +386,35 @@
     # Compare results
     if results["python"] and results["r"]:
-        print("\nComparison:")
+        print("\nComparison (Python vs R):")
         comparison = compare_estimates(
-            results["python"], results["r"], "CallawaySantAnna", scale=scale
+            results["python"], results["r"], "CallawaySantAnna", scale=scale,
+            python_pure_results=results.get("python_pure"),
+            python_rust_results=results.get("python_rust"),
         )
         results["comparison"] = comparison
         print(f"  ATT diff: {comparison.att_diff:.2e}")
         print(f"  SE rel diff: {comparison.se_rel_diff:.1%}")
         print(f"  Status: {'PASS' if comparison.passed else 'FAIL'}")
 
-        # Compute speedup from timing stats
-        py_mean = results["python"]["timing"]["stats"]["mean"]
-        r_mean = results["r"]["timing"]["stats"]["mean"]
-        speedup = r_mean / py_mean if py_mean > 0 else float('inf')
-        print(f"  Speed: Python is {speedup:.1f}x faster")
+        # Print timing comparison table
+        print("\nTiming Comparison:")
+        print(f"  {'Backend':<15} {'Time (s)':<12} {'vs R':<12} {'vs Pure Python':<15}")
+        print(f"  {'-'*54}")
+
+        r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
+        pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+        rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+
+        if r_mean:
+            print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
+        if pure_mean:
+            r_speedup = f"{r_mean/pure_mean:.2f}x" if r_mean else "-"
+            print(f"  {'Python (pure)':<15} {pure_mean:<12.3f} {r_speedup:<12} {'1.00x':<15}")
+        if rust_mean:
+            r_speedup = f"{r_mean/rust_mean:.2f}x" if r_mean else "-"
+            pure_speedup = f"{pure_mean/rust_mean:.2f}x" if pure_mean else "-"
+            print(f"  {'Python (rust)':<15} {rust_mean:<12.3f} {r_speedup:<12} {pure_speedup:<15}")
 
     return results
 
@@ -384,51 +424,66 @@ def run_synthdid_benchmark(
     name: str = "synthdid",
     scale: str = "small",
     n_replications: int = 1,
+    backends: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run Synthetic DiD benchmarks (Python and R) with replications."""
     print(f"\n{'='*60}")
     print(f"SYNTHETIC DID BENCHMARK ({scale})")
     print(f"{'='*60}")
 
+    if backends is None:
+        backends = ["python", "rust"]
+
     timeouts = TIMEOUT_CONFIGS.get(scale, TIMEOUT_CONFIGS["small"])
 
     results = {
         "name": name,
         "scale": scale,
         "n_replications": n_replications,
-        "python": None,
+        "python_pure": None,
+        "python_rust": None,
         "r": None,
         "comparison": None,
     }
 
-    # Python benchmark with replications
-    print(f"\nRunning Python (diff_diff.SyntheticDiD) - {n_replications} replications...")
-    py_output = RESULTS_DIR / "accuracy" / f"python_{name}_{scale}.json"
-    py_output.parent.mkdir(parents=True, exist_ok=True)
-
-    py_timings = []
-    py_result = None
-    for rep in range(n_replications):
-        try:
-            py_result = run_python_benchmark(
-                "benchmark_synthdid.py",
-                data_path,
-                py_output,
-                extra_args=["--n-bootstrap", "50"],
-                timeout=timeouts["python"]
-            )
-            py_timings.append(py_result["timing"]["total_seconds"])
-            if rep == 0:
-                print(f"  ATT: {py_result['att']:.4f}")
-                print(f"  SE: {py_result['se']:.4f}")
-            print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
-        except Exception as e:
-            print(f"  Rep {rep+1} failed: {e}")
-
-    if py_result and py_timings:
-        timing_stats = compute_timing_stats(py_timings)
-        py_result["timing"] = timing_stats
-        results["python"] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+    # Run Python benchmark for each backend
+    for backend in backends:
+        # Map backend name to label (python -> pure, rust -> rust)
+        backend_label = f"python_{'pure' if backend == 'python' else backend}"
+        print(f"\nRunning Python (diff_diff.SyntheticDiD, backend={backend}) - {n_replications} replications...")
+        py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
+        py_output.parent.mkdir(parents=True, exist_ok=True)
+
+        py_timings = []
+        py_result = None
+        for rep in range(n_replications):
+            try:
+                py_result = run_python_benchmark(
+                    "benchmark_synthdid.py",
+                    data_path,
+                    py_output,
+                    extra_args=["--n-bootstrap", "50"],
+                    timeout=timeouts["python"],
+                    backend=backend,
+                )
+                py_timings.append(py_result["timing"]["total_seconds"])
+                if rep == 0:
+                    print(f"  ATT: {py_result['att']:.4f}")
+                    print(f"  SE: {py_result['se']:.4f}")
+                print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
+            except Exception as e:
+                print(f"  Rep {rep+1} failed: {e}")
+
+        if py_result and py_timings:
+            timing_stats = compute_timing_stats(py_timings)
+            py_result["timing"] = timing_stats
+            results[backend_label] = py_result
+            print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+
+    # For backward compatibility, also store as "python" (use rust if available)
+    if results.get("python_rust"):
+        results["python"] = results["python_rust"]
+    elif results.get("python_pure"):
+        results["python"] = results["python_pure"]
 
     # R benchmark with replications
     print(f"\nRunning R (synthdid::synthdid_estimate) - {n_replications} replications...")
@@ -458,18 +513,35 @@
     # Compare results
     if results["python"] and results["r"]:
-        print("\nComparison:")
-        comparison = compare_estimates(results["python"], results["r"], "SyntheticDiD", scale=scale)
+        print("\nComparison (Python vs R):")
+        comparison = compare_estimates(
+            results["python"], results["r"], "SyntheticDiD", scale=scale,
+            python_pure_results=results.get("python_pure"),
+            python_rust_results=results.get("python_rust"),
+        )
         results["comparison"] = comparison
         print(f"  ATT diff: {comparison.att_diff:.2e}")
         print(f"  SE rel diff: {comparison.se_rel_diff:.1%}")
         print(f"  Status: {'PASS' if comparison.passed else 'FAIL'}")
 
-        # Compute speedup from timing stats
-        py_mean = results["python"]["timing"]["stats"]["mean"]
-        r_mean = results["r"]["timing"]["stats"]["mean"]
-        speedup = r_mean / py_mean if py_mean > 0 else float('inf')
-        print(f"  Speed: Python is {speedup:.1f}x faster")
+        # Print timing comparison table
+        print("\nTiming Comparison:")
+        print(f"  {'Backend':<15} {'Time (s)':<12} {'vs R':<12} {'vs Pure Python':<15}")
+        print(f"  {'-'*54}")
+
+        r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
+        pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+        rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+
+        if r_mean:
+            print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
+        if pure_mean:
+            r_speedup = f"{r_mean/pure_mean:.2f}x" if r_mean else "-"
+            print(f"  {'Python (pure)':<15} {pure_mean:<12.3f} {r_speedup:<12} {'1.00x':<15}")
+        if rust_mean:
+            r_speedup = f"{r_mean/rust_mean:.2f}x" if r_mean else "-"
+            pure_speedup = f"{pure_mean/rust_mean:.2f}x" if pure_mean else "-"
+            print(f"  {'Python (rust)':<15} {rust_mean:<12.3f} {r_speedup:<12} {pure_speedup:<15}")
 
     return results
 
@@ -479,49 +551,64 @@ def run_basic_did_benchmark(
     name: str = "basic",
     scale: str = "small",
     n_replications: int = 1,
+    backends: Optional[List[str]] = None,
 ) -> Dict[str, Any]:
     """Run basic DiD / TWFE benchmarks (Python and R) with replications."""
     print(f"\n{'='*60}")
     print(f"BASIC DID / TWFE BENCHMARK ({scale})")
     print(f"{'='*60}")
 
+    if backends is None:
+        backends = ["python", "rust"]
+
     timeouts = TIMEOUT_CONFIGS.get(scale, TIMEOUT_CONFIGS["small"])
 
    results = {
         "name": name,
         "scale": scale,
         "n_replications": n_replications,
-        "python": None,
+        "python_pure": None,
+        "python_rust": None,
         "r": None,
         "comparison": None,
     }
 
-    # Python benchmark with replications
-    print(f"\nRunning Python (diff_diff.TwoWayFixedEffects) - {n_replications} replications...")
-    py_output = RESULTS_DIR / "accuracy" / f"python_{name}_{scale}.json"
-    py_output.parent.mkdir(parents=True, exist_ok=True)
-
-    py_timings = []
-    py_result = None
-    for rep in range(n_replications):
-        try:
-            py_result = run_python_benchmark(
-                "benchmark_basic.py", data_path, py_output,
-                extra_args=["--type", "twfe"],
-                timeout=timeouts["python"]
-            )
-            py_timings.append(py_result["timing"]["total_seconds"])
-            if rep == 0:
-                print(f"  ATT: {py_result['att']:.4f}")
-                print(f"  SE: {py_result['se']:.4f}")
-            print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
-        except Exception as e:
-            print(f"  Rep {rep+1} failed: {e}")
-
-    if py_result and py_timings:
-        timing_stats = compute_timing_stats(py_timings)
-        py_result["timing"] = timing_stats
-        results["python"] = py_result
-        print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+    # Run Python benchmark for each backend
+    for backend in backends:
+        # Map backend name to label (python -> pure, rust -> rust)
+        backend_label = f"python_{'pure' if backend == 'python' else backend}"
+        print(f"\nRunning Python (diff_diff.DifferenceInDifferences, backend={backend}) - {n_replications} replications...")
+        py_output = RESULTS_DIR / "accuracy" / f"{backend_label}_{name}_{scale}.json"
+        py_output.parent.mkdir(parents=True, exist_ok=True)
+
+        py_timings = []
+        py_result = None
+        for rep in range(n_replications):
+            try:
+                py_result = run_python_benchmark(
+                    "benchmark_basic.py", data_path, py_output,
+                    extra_args=["--type", "twfe"],
+                    timeout=timeouts["python"],
+                    backend=backend,
+                )
+                py_timings.append(py_result["timing"]["total_seconds"])
+                if rep == 0:
+                    print(f"  ATT: {py_result['att']:.4f}")
+                    print(f"  SE: {py_result['se']:.4f}")
+                print(f"  Rep {rep+1}/{n_replications}: {py_timings[-1]:.3f}s")
+            except Exception as e:
+                print(f"  Rep {rep+1} failed: {e}")
+
+        if py_result and py_timings:
+            timing_stats = compute_timing_stats(py_timings)
+            py_result["timing"] = timing_stats
+            results[backend_label] = py_result
+            print(f"  Mean time: {timing_stats['stats']['mean']:.3f}s ± {timing_stats['stats']['std']:.3f}s")
+
+    # For backward compatibility, also store as "python" (use rust if available)
+    if results.get("python_rust"):
+        results["python"] = results["python_rust"]
+    elif results.get("python_pure"):
+        results["python"] = results["python_pure"]
 
     # R benchmark with replications
     print(f"\nRunning R (fixest::feols) - {n_replications} replications...")
@@ -552,18 +639,35 @@
     # Compare results
     if results["python"] and results["r"]:
-        print("\nComparison:")
-        comparison = compare_estimates(results["python"], results["r"], "BasicDiD/TWFE", scale=scale)
+        print("\nComparison (Python vs R):")
+        comparison = compare_estimates(
+            results["python"], results["r"], "BasicDiD/TWFE", scale=scale,
+            python_pure_results=results.get("python_pure"),
+            python_rust_results=results.get("python_rust"),
+        )
         results["comparison"] = comparison
         print(f"  ATT diff: {comparison.att_diff:.2e}")
         print(f"  SE rel diff: {comparison.se_rel_diff:.1%}")
         print(f"  Status: {'PASS' if comparison.passed else 'FAIL'}")
 
-        # Compute speedup from timing stats
-        py_mean = results["python"]["timing"]["stats"]["mean"]
-        r_mean = results["r"]["timing"]["stats"]["mean"]
-        speedup = r_mean / py_mean if py_mean > 0 else float('inf')
-        print(f"  Speed: Python is {speedup:.1f}x faster")
+        # Print timing comparison table
+        print("\nTiming Comparison:")
+        print(f"  {'Backend':<15} {'Time (s)':<12} {'vs R':<12} {'vs Pure Python':<15}")
+        print(f"  {'-'*54}")
+
+        r_mean = results["r"]["timing"]["stats"]["mean"] if results["r"] else None
+        pure_mean = results["python_pure"]["timing"]["stats"]["mean"] if results.get("python_pure") else None
+        rust_mean = results["python_rust"]["timing"]["stats"]["mean"] if results.get("python_rust") else None
+
+        if r_mean:
+            print(f"  {'R':<15} {r_mean:<12.3f} {'1.00x':<12} {'-':<15}")
+        if pure_mean:
+            r_speedup = f"{r_mean/pure_mean:.2f}x" if r_mean else "-"
+            print(f"  {'Python (pure)':<15} {pure_mean:<12.3f} {r_speedup:<12} {'1.00x':<15}")
+        if rust_mean:
+            r_speedup = f"{r_mean/rust_mean:.2f}x" if r_mean else "-"
+            pure_speedup = f"{pure_mean/rust_mean:.2f}x" if pure_mean else "-"
+            print(f"  {'Python (rust)':<15} {rust_mean:<12.3f} {r_speedup:<12} {pure_speedup:<15}")
 
     return results
 
@@ -601,7 +705,7 @@ def main():
     )
     parser.add_argument(
         "--scale",
-        choices=["small", "1k", "5k", "10k", "all"],
+        choices=["small", "1k", "5k", "10k", "20k", "all"],
         default="small",
         help="Dataset scale to use (default: small). Use 'all' for all scales.",
     )
@@ -609,7 +713,7 @@
 
     # Determine which scales to run
     if args.scale == "all":
-        scales = ["small", "1k", "5k", "10k"]
+        scales = ["small", "1k", "5k", "10k", "20k"]
     else:
         scales = [args.scale]
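The three `run_*_benchmark` functions all return the same dict layout; a sketch of its keys (values elided to `None` here):

```python
# Illustrative only; keys taken from the code above.
results = {
    "name": "callaway",
    "scale": "1k",
    "n_replications": 3,
    "python_pure": None,   # pure-Python benchmark output
    "python_rust": None,   # Rust-backend benchmark output
    "python": None,        # backward-compat alias: rust if present, else pure
    "r": None,
    "comparison": None,    # ComparisonResult from compare_estimates
}
```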
""" +# Import backend detection from dedicated module (avoids circular imports) +from diff_diff._backend import ( + HAS_RUST_BACKEND, + _rust_bootstrap_weights, + _rust_compute_robust_vcov, + _rust_project_simplex, + _rust_solve_ols, + _rust_synthetic_weights, +) + from diff_diff.bacon import ( BaconDecomposition, BaconDecompositionResults, @@ -103,7 +113,7 @@ plot_sensitivity, ) -__version__ = "1.4.0" +__version__ = "2.0.0" __all__ = [ # Estimators "DifferenceInDifferences", @@ -187,4 +197,6 @@ "compute_pretrends_power", "compute_mdv", "plot_pretrends_power", + # Rust backend + "HAS_RUST_BACKEND", ] diff --git a/diff_diff/_backend.py b/diff_diff/_backend.py new file mode 100644 index 0000000..302b611 --- /dev/null +++ b/diff_diff/_backend.py @@ -0,0 +1,64 @@ +""" +Backend detection and configuration for diff-diff. + +This module handles: +1. Detection of optional Rust backend +2. Environment variable configuration (DIFF_DIFF_BACKEND) +3. Exports HAS_RUST_BACKEND and Rust function references + +Other modules should import from here to avoid circular imports with __init__.py. +""" + +import os + +# Check for backend override via environment variable +# DIFF_DIFF_BACKEND can be: 'auto' (default), 'python', or 'rust' +_backend_env = os.environ.get('DIFF_DIFF_BACKEND', 'auto').lower() + +# Try to import Rust backend for accelerated operations +try: + from diff_diff._rust_backend import ( + generate_bootstrap_weights_batch as _rust_bootstrap_weights, + compute_synthetic_weights as _rust_synthetic_weights, + project_simplex as _rust_project_simplex, + solve_ols as _rust_solve_ols, + compute_robust_vcov as _rust_compute_robust_vcov, + ) + _rust_available = True +except ImportError: + _rust_available = False + _rust_bootstrap_weights = None + _rust_synthetic_weights = None + _rust_project_simplex = None + _rust_solve_ols = None + _rust_compute_robust_vcov = None + +# Determine final backend based on environment variable and availability +if _backend_env == 'python': + # Force pure Python mode - disable Rust even if available + HAS_RUST_BACKEND = False + _rust_bootstrap_weights = None + _rust_synthetic_weights = None + _rust_project_simplex = None + _rust_solve_ols = None + _rust_compute_robust_vcov = None +elif _backend_env == 'rust': + # Force Rust mode - fail if not available + if not _rust_available: + raise ImportError( + "DIFF_DIFF_BACKEND=rust but Rust backend is not available. " + "Install with: pip install diff-diff[rust]" + ) + HAS_RUST_BACKEND = True +else: + # Auto mode - use Rust if available + HAS_RUST_BACKEND = _rust_available + +__all__ = [ + 'HAS_RUST_BACKEND', + '_rust_bootstrap_weights', + '_rust_synthetic_weights', + '_rust_project_simplex', + '_rust_solve_ols', + '_rust_compute_robust_vcov', +] diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index 3c7175d..51fa232 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -1,15 +1,17 @@ """ Unified linear algebra backend for diff-diff. -This module provides optimized OLS and variance estimation that can be -swapped to a compiled backend (Rust/C++) for maximum performance. +This module provides optimized OLS and variance estimation with an optional +Rust backend for maximum performance. The key optimizations are: 1. scipy.linalg.lstsq with 'gelsy' driver (QR-based, faster than SVD) 2. Vectorized cluster-robust SE via groupby (eliminates O(n*clusters) loop) 3. Single interface for all estimators (reduces code duplication) +4. 
diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py
index 3c7175d..51fa232 100644
--- a/diff_diff/linalg.py
+++ b/diff_diff/linalg.py
@@ -1,15 +1,17 @@
 """
 Unified linear algebra backend for diff-diff.
 
-This module provides optimized OLS and variance estimation that can be
-swapped to a compiled backend (Rust/C++) for maximum performance.
+This module provides optimized OLS and variance estimation with an optional
+Rust backend for maximum performance.
 
 The key optimizations are:
 1. scipy.linalg.lstsq with 'gelsy' driver (QR-based, faster than SVD)
 2. Vectorized cluster-robust SE via groupby (eliminates O(n*clusters) loop)
 3. Single interface for all estimators (reduces code duplication)
+4. Optional Rust backend for additional speedup (when available)
 
-Future: This module can be extended with a Rust backend for additional speedup.
+The Rust backend is automatically used when available, with transparent
+fallback to NumPy/SciPy implementations.
 """
 
 from typing import Optional, Tuple, Union
@@ -18,6 +20,13 @@
 import pandas as pd
 from scipy.linalg import lstsq as scipy_lstsq
 
+# Import Rust backend if available (from _backend to avoid circular imports)
+from diff_diff._backend import (
+    HAS_RUST_BACKEND,
+    _rust_compute_robust_vcov,
+    _rust_solve_ols,
+)
+
 
 def solve_ols(
     X: np.ndarray,
@@ -119,6 +128,87 @@ def solve_ols(
             "Clean your data or set check_finite=False to skip this check."
         )
 
+    # Use Rust backend if available
+    # Note: Fall back to NumPy if check_finite=False since Rust's LAPACK
+    # doesn't support non-finite values
+    if HAS_RUST_BACKEND and check_finite:
+        # Ensure contiguous arrays for Rust
+        X = np.ascontiguousarray(X, dtype=np.float64)
+        y = np.ascontiguousarray(y, dtype=np.float64)
+
+        # Convert cluster_ids to int64 for Rust (if provided)
+        cluster_ids_int = None
+        if cluster_ids is not None:
+            cluster_ids_int = pd.factorize(cluster_ids)[0].astype(np.int64)
+
+        try:
+            coefficients, residuals, vcov = _rust_solve_ols(
+                X, y, cluster_ids_int, return_vcov
+            )
+        except ValueError as e:
+            # Translate Rust LAPACK errors to consistent Python error messages
+            error_msg = str(e)
+            if "Matrix inversion failed" in error_msg or "Least squares failed" in error_msg:
+                raise ValueError(
+                    "Design matrix is rank-deficient (singular X'X matrix). "
+                    "This indicates perfect multicollinearity. Check your fixed effects "
+                    "and covariates for linear dependencies."
+                ) from e
+            raise
+
+        if return_fitted:
+            fitted = X @ coefficients
+            return coefficients, residuals, fitted, vcov
+        else:
+            return coefficients, residuals, vcov
+
+    # Fallback to NumPy/SciPy implementation
+    return _solve_ols_numpy(
+        X, y, cluster_ids=cluster_ids, return_vcov=return_vcov, return_fitted=return_fitted
+    )
+
+
+def _solve_ols_numpy(
+    X: np.ndarray,
+    y: np.ndarray,
+    *,
+    cluster_ids: Optional[np.ndarray] = None,
+    return_vcov: bool = True,
+    return_fitted: bool = False,
+) -> Union[
+    Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]],
+    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
+]:
+    """
+    NumPy/SciPy fallback implementation of solve_ols.
+
+    Uses scipy.linalg.lstsq with the 'gelsy' driver (QR with column pivoting)
+    for fast and stable least squares solving.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Design matrix of shape (n, k).
+    y : np.ndarray
+        Response vector of shape (n,).
+    cluster_ids : np.ndarray, optional
+        Cluster identifiers for cluster-robust SEs.
+    return_vcov : bool
+        Whether to compute variance-covariance matrix.
+    return_fitted : bool
+        Whether to return fitted values.
+
+    Returns
+    -------
+    coefficients : np.ndarray
+        OLS coefficients of shape (k,).
+    residuals : np.ndarray
+        Residuals of shape (n,).
+    fitted : np.ndarray, optional
+        Fitted values if return_fitted=True.
+    vcov : np.ndarray, optional
+        Variance-covariance matrix if return_vcov=True.
+    """
     # Solve OLS using scipy's optimized solver
     # 'gelsy' uses QR with column pivoting, faster than default 'gelsd' (SVD)
     # Note: gelsy doesn't reliably report rank, so we don't check for deficiency
@@ -131,7 +221,7 @@
     # Compute variance-covariance matrix if requested
     vcov = None
     if return_vcov:
-        vcov = compute_robust_vcov(X, residuals, cluster_ids)
+        vcov = _compute_robust_vcov_numpy(X, residuals, cluster_ids)
 
     if return_fitted:
         return coefficients, residuals, fitted, vcov
@@ -176,6 +266,63 @@ def compute_robust_vcov(
     The cluster-robust computation is vectorized using pandas groupby, which
     is much faster than a Python loop over clusters.
     """
+    # Use Rust backend if available
+    if HAS_RUST_BACKEND:
+        X = np.ascontiguousarray(X, dtype=np.float64)
+        residuals = np.ascontiguousarray(residuals, dtype=np.float64)
+
+        cluster_ids_int = None
+        if cluster_ids is not None:
+            cluster_ids_int = pd.factorize(cluster_ids)[0].astype(np.int64)
+
+        try:
+            return _rust_compute_robust_vcov(X, residuals, cluster_ids_int)
+        except ValueError as e:
+            # Translate Rust LAPACK errors to consistent Python error messages
+            error_msg = str(e)
+            if "Matrix inversion failed" in error_msg:
+                raise ValueError(
+                    "Design matrix is rank-deficient (singular X'X matrix). "
+                    "This indicates perfect multicollinearity. Check your fixed effects "
+                    "and covariates for linear dependencies."
+                ) from e
+            raise
+
+    # Fallback to NumPy implementation
+    return _compute_robust_vcov_numpy(X, residuals, cluster_ids)
+
+
+def _compute_robust_vcov_numpy(
+    X: np.ndarray,
+    residuals: np.ndarray,
+    cluster_ids: Optional[np.ndarray] = None,
+) -> np.ndarray:
+    """
+    NumPy fallback implementation of compute_robust_vcov.
+
+    Computes the HC1 (heteroskedasticity-robust) or cluster-robust
+    variance-covariance matrix using the sandwich estimator.
+
+    Parameters
+    ----------
+    X : np.ndarray
+        Design matrix of shape (n, k).
+    residuals : np.ndarray
+        OLS residuals of shape (n,).
+    cluster_ids : np.ndarray, optional
+        Cluster identifiers. If None, uses HC1. If provided, uses
+        cluster-robust with G/(G-1) small-sample adjustment.
+
+    Returns
+    -------
+    vcov : np.ndarray
+        Variance-covariance matrix of shape (k, k).
+
+    Notes
+    -----
+    Uses vectorized groupby aggregation for cluster-robust SEs to avoid
+    the O(n * G) loop that would be required with explicit iteration.
+    """
     n, k = X.shape
     XtX = X.T @ X
+ """ + # Use Rust backend if available (parallel + fast RNG) + if HAS_RUST_BACKEND: + # Get seed from the NumPy RNG for reproducibility + seed = rng.integers(0, 2**63 - 1) + return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed) + + # Fallback to NumPy implementation + return _generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng) + + +def _generate_bootstrap_weights_batch_numpy( + n_bootstrap: int, + n_units: int, + weight_type: str, + rng: np.random.Generator, +) -> np.ndarray: + """ + NumPy fallback implementation of _generate_bootstrap_weights_batch. + + Generates multiplier bootstrap weights for wild cluster bootstrap. + All weight distributions satisfy E[w] = 0, E[w^2] = 1. + + Parameters + ---------- + n_bootstrap : int + Number of bootstrap iterations. + n_units : int + Number of units (clusters) to generate weights for. + weight_type : str + Type of weights: "rademacher" (+-1), "mammen" (2-point), + or "webb" (6-point). + rng : np.random.Generator + Random number generator for reproducibility. + Returns ------- np.ndarray diff --git a/diff_diff/utils.py b/diff_diff/utils.py index a012d43..600d9c2 100644 --- a/diff_diff/utils.py +++ b/diff_diff/utils.py @@ -13,6 +13,13 @@ from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg from diff_diff.linalg import solve_ols as _solve_ols_linalg +# Import Rust backend if available (from _backend to avoid circular imports) +from diff_diff._backend import ( + HAS_RUST_BACKEND, + _rust_project_simplex, + _rust_synthetic_weights, +) + # Numerical constants for optimization algorithms _OPTIMIZATION_MAX_ITER = 1000 # Maximum iterations for weight optimization _OPTIMIZATION_TOL = 1e-8 # Convergence tolerance for optimization @@ -1033,6 +1040,37 @@ def compute_synthetic_weights( if n_control == 1: return np.asarray([1.0]) + # Use Rust backend if available + if HAS_RUST_BACKEND: + Y_control = np.ascontiguousarray(Y_control, dtype=np.float64) + Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64) + weights = _rust_synthetic_weights( + Y_control, Y_treated, lambda_reg, + _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL + ) + else: + # Fallback to NumPy implementation + weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg) + + # Set small weights to zero for interpretability + weights[weights < min_weight] = 0 + if np.sum(weights) > 0: + weights = weights / np.sum(weights) + else: + # Fallback to uniform if all weights are zeroed + weights = np.ones(n_control) / n_control + + return np.asarray(weights) + + +def _compute_synthetic_weights_numpy( + Y_control: np.ndarray, + Y_treated: np.ndarray, + lambda_reg: float = 0.0, +) -> np.ndarray: + """NumPy fallback implementation of compute_synthetic_weights.""" + n_pre, n_control = Y_control.shape + # Initialize with uniform weights weights = np.ones(n_control) / n_control @@ -1065,15 +1103,7 @@ def compute_synthetic_weights( if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL: break - # Set small weights to zero for interpretability - weights[weights < min_weight] = 0 - if np.sum(weights) > 0: - weights = weights / np.sum(weights) - else: - # Fallback to uniform if all weights are zeroed - weights = np.ones(n_control) / n_control - - return np.asarray(weights) + return weights def _project_simplex(v: np.ndarray) -> np.ndarray: diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index cf9f7d0..8880210 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -2,7 +2,8 @@ Benchmarks: Validation Against R 
diff --git a/diff_diff/utils.py b/diff_diff/utils.py
index a012d43..600d9c2 100644
--- a/diff_diff/utils.py
+++ b/diff_diff/utils.py
@@ -13,6 +13,13 @@
 from diff_diff.linalg import compute_robust_vcov as _compute_robust_vcov_linalg
 from diff_diff.linalg import solve_ols as _solve_ols_linalg
 
+# Import Rust backend if available (from _backend to avoid circular imports)
+from diff_diff._backend import (
+    HAS_RUST_BACKEND,
+    _rust_project_simplex,
+    _rust_synthetic_weights,
+)
+
 # Numerical constants for optimization algorithms
 _OPTIMIZATION_MAX_ITER = 1000  # Maximum iterations for weight optimization
 _OPTIMIZATION_TOL = 1e-8  # Convergence tolerance for optimization
@@ -1033,6 +1040,37 @@ def compute_synthetic_weights(
     if n_control == 1:
         return np.asarray([1.0])
 
+    # Use Rust backend if available
+    if HAS_RUST_BACKEND:
+        Y_control = np.ascontiguousarray(Y_control, dtype=np.float64)
+        Y_treated = np.ascontiguousarray(Y_treated, dtype=np.float64)
+        weights = _rust_synthetic_weights(
+            Y_control, Y_treated, lambda_reg,
+            _OPTIMIZATION_MAX_ITER, _OPTIMIZATION_TOL
+        )
+    else:
+        # Fallback to NumPy implementation
+        weights = _compute_synthetic_weights_numpy(Y_control, Y_treated, lambda_reg)
+
+    # Set small weights to zero for interpretability
+    weights[weights < min_weight] = 0
+    if np.sum(weights) > 0:
+        weights = weights / np.sum(weights)
+    else:
+        # Fallback to uniform if all weights are zeroed
+        weights = np.ones(n_control) / n_control
+
+    return np.asarray(weights)
+
+
+def _compute_synthetic_weights_numpy(
+    Y_control: np.ndarray,
+    Y_treated: np.ndarray,
+    lambda_reg: float = 0.0,
+) -> np.ndarray:
+    """NumPy fallback implementation of compute_synthetic_weights."""
+    n_pre, n_control = Y_control.shape
+
     # Initialize with uniform weights
     weights = np.ones(n_control) / n_control
 
@@ -1065,15 +1103,7 @@
         if np.linalg.norm(weights - weights_old) < _OPTIMIZATION_TOL:
             break
 
-    # Set small weights to zero for interpretability
-    weights[weights < min_weight] = 0
-    if np.sum(weights) > 0:
-        weights = weights / np.sum(weights)
-    else:
-        # Fallback to uniform if all weights are zeroed
-        weights = np.ones(n_control) / n_control
-
-    return np.asarray(weights)
+    return weights
 
 
 def _project_simplex(v: np.ndarray) -> np.ndarray:
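For context on `_project_simplex`: the standard sort-based Euclidean projection onto the probability simplex looks like the sketch below (the library's implementation may differ in details):

```python
import numpy as np

def project_simplex(v: np.ndarray) -> np.ndarray:
    # Find tau such that sum(max(v - tau, 0)) == 1 (Duchi et al. 2008)
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    rho = np.nonzero(u * np.arange(1, v.size + 1) > (css - 1))[0][-1]
    tau = (css[rho] - 1) / (rho + 1)
    return np.maximum(v - tau, 0.0)
```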
+ For **BasicDiD** and **CallawaySantAnna**, the Rust backend provides minimal + additional speedup since these estimators primarily use OLS and variance + computations that are already highly optimized in NumPy/SciPy via BLAS/LAPACK. -Summary by Scale -~~~~~~~~~~~~~~~~ +Three-Way Performance Summary +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -**Small Scale** (400-1,600 observations): +**BasicDiD/TWFE Results:** .. list-table:: :header-rows: 1 - :widths: 30 25 25 20 + :widths: 12 15 18 18 12 12 - * - Estimator - - Python (s) + * - Scale - R (s) - - Speedup - * - BasicDiD/TWFE - - 0.002 ± 0.000 - - 0.035 ± 0.001 - - **17.9x** - * - CallawaySantAnna - - 0.005 ± 0.000 - - 0.070 ± 0.001 - - **14.0x** + - Python Pure (s) + - Python Rust (s) + - Rust/R + - Rust/Pure + * - small + - 0.035 + - 0.002 + - 0.002 + - **18x** + - 1.1x + * - 1k + - 0.037 + - 0.003 + - 0.003 + - **14x** + - 1.1x + * - 5k + - 0.038 + - 0.008 + - 0.006 + - **7x** + - 1.4x + * - 10k + - 0.041 + - 0.010 + - 0.011 + - **4x** + - 0.9x + * - 20k + - 0.050 + - 0.026 + - 0.025 + - **2x** + - 1.1x -**1K Scale** (6,000-10,000 observations): +**CallawaySantAnna Results:** .. list-table:: :header-rows: 1 - :widths: 30 25 25 20 + :widths: 12 15 18 18 12 12 - * - Estimator - - Python (s) + * - Scale - R (s) - - Speedup - * - BasicDiD/TWFE - - 0.003 ± 0.001 - - 0.035 ± 0.001 - - **12.5x** - * - CallawaySantAnna - - 0.012 ± 0.000 - - 0.113 ± 0.002 - - **9.6x** + - Python Pure (s) + - Python Rust (s) + - Rust/R + - Rust/Pure + * - small + - 0.071 + - 0.005 + - 0.005 + - **14.1x** + - 1.0x + * - 1k + - 0.114 + - 0.012 + - 0.012 + - **9.4x** + - 1.0x + * - 5k + - 0.341 + - 0.055 + - 0.056 + - **6.1x** + - 1.0x + * - 10k + - 0.726 + - 0.156 + - 0.155 + - **4.7x** + - 1.0x + * - 20k + - 1.464 + - 0.404 + - 0.411 + - **3.6x** + - 1.0x -**5K Scale** (40,000-60,000 observations): +**SyntheticDiD Results:** .. list-table:: :header-rows: 1 - :widths: 30 25 25 20 + :widths: 12 15 18 18 12 12 - * - Estimator - - Python (s) + * - Scale - R (s) - - Speedup - * - BasicDiD/TWFE - - 0.006 ± 0.003 - - 0.038 ± 0.002 - - **6.1x** - * - CallawaySantAnna - - 0.055 ± 0.001 - - 0.339 ± 0.002 - - **6.2x** + - Python Pure (s) + - Python Rust (s) + - Rust/R + - Rust/Pure + * - small + - 8.18 + - 0.015 + - 0.004 + - **2234x** + - **4.0x** + * - 1k + - 110.4 + - 0.068 + - 0.100 + - **1104x** + - 0.7x + * - 5k + - 511.1 + - 3.017 + - 0.688 + - **743x** + - **4.4x** + * - 10k + - 1462.7 + - 19.56 + - 2.59 + - **565x** + - **7.6x** + +.. note:: -**10K Scale** (100,000-150,000 observations): + **SyntheticDiD Performance**: diff-diff achieves **565x to 2234x speedup** over + R's synthdid package. At 10k scale, R takes ~24 minutes while Python Rust + completes in 2.6 seconds. The Rust backend provides **4-8x additional speedup** + over pure Python for SyntheticDiD due to optimized simplex projection and + synthetic weight computation. ATT estimates differ slightly due to different + weight optimization algorithms (projected gradient descent vs Frank-Wolfe), + but confidence intervals overlap. + +Dataset Sizes +~~~~~~~~~~~~~ .. 
+Dataset Sizes
+~~~~~~~~~~~~~
+
+Entries are units × time periods for each estimator; the last column gives
+the resulting range of observation counts.
+
 .. list-table::
    :header-rows: 1
-   :widths: 30 25 25 20
-
-   * - Estimator
-     - Python (s)
-     - R (s)
-     - Speedup
-   * - BasicDiD/TWFE
-     - 0.010 ± 0.000
-     - 0.041 ± 0.001
-     - **4.1x**
-   * - CallawaySantAnna
-     - 0.155 ± 0.002
-     - 0.730 ± 0.004
-     - **4.7x**
+   :widths: 12 22 22 22 22
+
+   * - Scale
+     - BasicDiD
+     - CallawaySantAnna
+     - SyntheticDiD
+     - Observations
+   * - small
+     - 100 × 4
+     - 200 × 8
+     - 50 × 20
+     - 400 - 1,600
+   * - 1k
+     - 1,000 × 6
+     - 1,000 × 10
+     - 1,000 × 30
+     - 6,000 - 30,000
+   * - 5k
+     - 5,000 × 8
+     - 5,000 × 12
+     - 5,000 × 40
+     - 40,000 - 200,000
+   * - 10k
+     - 10,000 × 10
+     - 10,000 × 15
+     - 10,000 × 50
+     - 100,000 - 500,000
+   * - 20k
+     - 20,000 × 12
+     - 20,000 × 18
+     - N/A
+     - 240,000 - 360,000
 
 Key Observations
 ~~~~~~~~~~~~~~~~
 
-1. **diff-diff is faster than R at all scales**: Following v1.4.0 optimizations,
-   diff-diff now outperforms R packages across all dataset sizes for BasicDiD/TWFE
-   and CallawaySantAnna estimators.
+1. **diff-diff is dramatically faster than R**:
 
-2. **BasicDiD/TWFE**: diff-diff is 4-18x faster than R's ``fixest::feols``.
-   The speedup is greatest at small scales (17.9x) and remains substantial
-   at large scales (4.1x at 10K observations).
+   - **BasicDiD/TWFE**: 2-18x faster than R
+   - **CallawaySantAnna**: 3.6-14x faster than R
+   - **SyntheticDiD**: 565-2234x faster than R (R takes ~24 minutes at 10k scale!)
 
-3. **CallawaySantAnna**: diff-diff is 5-14x faster than R's ``did::att_gt``
-   using analytical SEs. At small scales (14x speedup), pure Python overhead
-   is minimal; at larger scales the gap narrows but remains substantial (4.7x).
+2. **Rust backend benefit depends on the estimator**:
 
-4. **Scaling behavior**: Both estimators show sub-linear scaling in diff-diff.
-   At 10K scale (150K observations for CallawaySantAnna), estimation completes
-   in ~150ms with analytical SEs.
+   - **SyntheticDiD**: Rust provides **4-8x speedup** over pure Python due to
+     optimized simplex projection and synthetic weight computation
+   - **BasicDiD/CallawaySantAnna**: Rust provides minimal benefit (~1x) since
+     these estimators use OLS/variance computations already optimized in NumPy/SciPy
+
+3. **When to use the Rust backend**:
+
+   - **SyntheticDiD**: Recommended - provides significant speedup (4-8x)
+   - **Bootstrap inference**: The Rust weight generator parallelizes
+     replications across threads with rayon
+   - **BasicDiD/CallawaySantAnna**: Optional - pure Python is equally fast
+
+4. **Scaling behavior**: Both Python implementations show excellent scaling.
+   At 10K scale (500K observations for SyntheticDiD), Rust completes in
+   ~2.6 seconds vs ~20 seconds for pure Python vs ~24 minutes for R.
+
+5. **No Rust required for most use cases**: Users without Rust/maturin can
+   install diff-diff and get full functionality with excellent performance.
+   For BasicDiD and CallawaySantAnna, pure Python achieves the same speed as Rust.
+   Only SyntheticDiD benefits significantly from the Rust backend.
 
 Performance Optimization Details
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The v1.4.0 performance improvements came from:
+The performance improvements come from:
 
 1. **Unified ``linalg.py`` backend**: Single optimized OLS/SE implementation
    using scipy's gelsy LAPACK driver (QR-based, faster than SVD)
@@ -313,6 +436,9 @@ The v1.4.0 performance improvements came from:
 4. **Vectorized bootstrap** (CallawaySantAnna): Matrix operations instead of
    nested loops, batch weight generation
 
+5. **Optional Rust backend** (v2.0.0): PyO3-based Rust extension for compute-intensive
+   operations (OLS, robust variance, bootstrap weights, simplex projection)
+
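+As a sketch of item 1, the gelsy driver is selected through SciPy's standard
+least-squares interface (illustration only; the real ``linalg.py`` wrapper
+also handles residuals and variance estimation):
+
+.. code-block:: python
+
+   import numpy as np
+   from scipy.linalg import lstsq
+
+   rng = np.random.default_rng(0)
+   X = rng.standard_normal((100, 5))
+   y = rng.standard_normal(100)
+
+   # gelsy (pivoted QR) is typically faster than the SVD-based default driver
+   beta, _, rank, _ = lstsq(X, y, lapack_driver="gelsy")
+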
 Why is diff-diff Fast?
 ~~~~~~~~~~~~~~~~~~~~~~
@@ -320,6 +446,7 @@ Why is diff-diff Fast?
 2. **Vectorized operations**: NumPy/pandas for matrix operations and aggregations
 3. **Efficient memory access**: Pre-computed structures avoid repeated data reshaping
 4. **Pure Python overhead minimized**: Hot paths use compiled NumPy/scipy routines
+5. **Optional Rust acceleration**: Native code for bootstrap weight generation
+   and the synthetic-control optimization loop
 
 Real-World Data Validation
 --------------------------
@@ -412,22 +539,22 @@ Running Benchmarks
 
    # Run all benchmarks at small scale
    python benchmarks/run_benchmarks.py --all
 
-   # Run all benchmarks at all scales with 10 replications
-   python benchmarks/run_benchmarks.py --all --scale all --replications 10
+   # Run all benchmarks at all scales with 3 replications
+   python benchmarks/run_benchmarks.py --all --scale all --replications 3
 
    # Run specific estimator at specific scale
-   python benchmarks/run_benchmarks.py --estimator callaway --scale 1k --replications 10
-   python benchmarks/run_benchmarks.py --estimator synthdid --scale small --replications 5
-   python benchmarks/run_benchmarks.py --estimator basic --scale 5k --replications 10
+   python benchmarks/run_benchmarks.py --estimator callaway --scale 1k --replications 3
+   python benchmarks/run_benchmarks.py --estimator synthdid --scale small --replications 3
+   python benchmarks/run_benchmarks.py --estimator basic --scale 20k --replications 3
 
-   # Available scales: small, 1k, 5k, 10k, all
+   # Available scales: small, 1k, 5k, 10k, 20k, all
    # Default: small (backward compatible)
 
-   # Generate synthetic data only (use seed for reproducibility)
-   python benchmarks/run_benchmarks.py --generate-data-only --scale all --seed 20260111
+   # Generate synthetic data only
+   python benchmarks/run_benchmarks.py --generate-data-only --scale all
 
-The benchmarks in this documentation were run with seed 20260111 (date-based:
-2026-01-11) for reproducibility.
+When the compiled backend is available, the benchmarks run both the pure
+Python and Rust implementations automatically, producing a three-way
+comparison table (R vs Python Pure vs Python Rust).
 
 Output
 ~~~~~~
diff --git a/pyproject.toml b/pyproject.toml
index a32a686..2dbb55c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,10 @@
 [build-system]
-requires = ["setuptools>=61.0", "wheel"]
-build-backend = "setuptools.build_meta"
+requires = ["maturin>=1.4,<2.0"]
+build-backend = "maturin"
 
 [project]
 name = "diff-diff"
-version = "1.4.0"
+version = "2.0.0"
 description = "A library for Difference-in-Differences causal inference analysis"
 readme = "README.md"
 license = "MIT"
@@ -55,8 +55,17 @@ Documentation = "https://diff-diff.readthedocs.io"
 Repository = "https://github.com/igerber/diff-diff"
 Issues = "https://github.com/igerber/diff-diff/issues"
 
-[tool.setuptools.packages.find]
-include = ["diff_diff*"]
+[tool.maturin]
+# Build the Rust extension module
+features = ["extension-module"]
+# Python source is in the root directory
+python-source = "."
+# Module name for the compiled extension
+module-name = "diff_diff._rust_backend"
+# Path to Rust Cargo.toml
+manifest-path = "rust/Cargo.toml"
+# Include Python packages
+python-packages = ["diff_diff"]
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
new file mode 100644
index 0000000..b0bab7f
--- /dev/null
+++ b/rust/Cargo.toml
@@ -0,0 +1,36 @@
+[package]
+name = "diff_diff_rust"
+version = "2.0.0"
+edition = "2021"
+description = "Rust backend for diff-diff DiD library"
+license = "MIT"
+
+[lib]
+name = "diff_diff_rust"
+# cdylib for Python extension, rlib for running tests
+crate-type = ["cdylib", "rlib"]
+
+[features]
+default = []
+# extension-module is only needed for cdylib builds, not for cargo test
+extension-module = ["pyo3/extension-module"]
+
+[dependencies]
+pyo3 = "0.20"
+numpy = "0.20"
+ndarray = { version = "0.15", features = ["rayon"] }
+rand = "0.8"
+rand_xoshiro = "0.6"
+rayon = "1.8"
+
+# Platform-specific BLAS backends for linear algebra
+[target.'cfg(not(target_os = "windows"))'.dependencies]
+ndarray-linalg = { version = "0.16", features = ["openblas-system"] }
+
+[target.'cfg(target_os = "windows")'.dependencies]
+ndarray-linalg = { version = "0.16", features = ["intel-mkl-system"] }
+
+[profile.release]
+lto = true
+codegen-units = 1
+opt-level = 3
diff --git a/rust/src/bootstrap.rs b/rust/src/bootstrap.rs
new file mode 100644
index 0000000..21acca6
--- /dev/null
+++ b/rust/src/bootstrap.rs
@@ -0,0 +1,223 @@
+//! Bootstrap weight generation for multiplier bootstrap inference.
+//!
+//! This module provides efficient generation of bootstrap weights
+//! using various distributions (Rademacher, Mammen, Webb).
+
+use ndarray::Array2;
+use numpy::{IntoPyArray, PyArray2};
+use pyo3::prelude::*;
+use rand::prelude::*;
+use rand_xoshiro::Xoshiro256PlusPlus;
+use rayon::prelude::*;
+
+/// Generate a batch of bootstrap weights.
+///
+/// Generates (n_bootstrap, n_units) matrix of bootstrap weights
+/// for multiplier bootstrap inference.
+///
+/// # Arguments
+/// * `n_bootstrap` - Number of bootstrap iterations
+/// * `n_units` - Number of units (clusters)
+/// * `weight_type` - Type of weights: "rademacher", "mammen", or "webb"
+/// * `seed` - Random seed for reproducibility
+///
+/// # Returns
+/// (n_bootstrap, n_units) array of bootstrap weights
+#[pyfunction]
+#[pyo3(signature = (n_bootstrap, n_units, weight_type, seed))]
+pub fn generate_bootstrap_weights_batch<'py>(
+    py: Python<'py>,
+    n_bootstrap: usize,
+    n_units: usize,
+    weight_type: &str,
+    seed: u64,
+) -> PyResult<&'py PyArray2<f64>> {
+    let weights = match weight_type.to_lowercase().as_str() {
+        "rademacher" => generate_rademacher_batch(n_bootstrap, n_units, seed),
+        "mammen" => generate_mammen_batch(n_bootstrap, n_units, seed),
+        "webb" => generate_webb_batch(n_bootstrap, n_units, seed),
+        _ => {
+            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                "Unknown weight type: {}. Expected 'rademacher', 'mammen', or 'webb'",
+                weight_type
+            )))
+        }
+    };
+
+    Ok(weights.into_pyarray(py))
+}
+
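+// Multiplier-bootstrap validity requires E[w] = 0 and E[w^2] = 1 for the
+// weight distribution; the generators below satisfy these moments, and
+// Mammen/Webb additionally control higher moments. Each replication row i
+// seeds its own Xoshiro256++ stream with seed + i, so results are
+// reproducible regardless of how rayon schedules the parallel rows.
+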
+/// Generate Rademacher weights: ±1 with equal probability.
+///
+/// E[w] = 0, Var[w] = 1
+fn generate_rademacher_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<f64> {
+    // Generate weights in parallel using rayon
+    let rows: Vec<Vec<f64>> = (0..n_bootstrap)
+        .into_par_iter()
+        .map(|i| {
+            let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(i as u64));
+            (0..n_units)
+                .map(|_| if rng.gen::<bool>() { 1.0 } else { -1.0 })
+                .collect()
+        })
+        .collect();
+
+    // Convert to ndarray
+    let flat: Vec<f64> = rows.into_iter().flatten().collect();
+    Array2::from_shape_vec((n_bootstrap, n_units), flat).unwrap()
+}
+
+/// Generate Mammen weights with two-point distribution.
+///
+/// w = -(√5 - 1)/2 with probability (√5 + 1)/(2√5)
+/// w = (√5 + 1)/2 with probability (√5 - 1)/(2√5)
+///
+/// E[w] = 0, E[w²] = 1, E[w³] = 1
+fn generate_mammen_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<f64> {
+    let sqrt5 = 5.0_f64.sqrt();
+
+    // Two-point distribution values
+    let val_neg = -(sqrt5 - 1.0) / 2.0; // ≈ -0.618
+    let val_pos = (sqrt5 + 1.0) / 2.0; // ≈ 1.618
+
+    // Probability of negative value
+    let prob_neg = (sqrt5 + 1.0) / (2.0 * sqrt5); // ≈ 0.724
+
+    let rows: Vec<Vec<f64>> = (0..n_bootstrap)
+        .into_par_iter()
+        .map(|i| {
+            let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(i as u64));
+            (0..n_units)
+                .map(|_| {
+                    if rng.gen::<f64>() < prob_neg {
+                        val_neg
+                    } else {
+                        val_pos
+                    }
+                })
+                .collect()
+        })
+        .collect();
+
+    let flat: Vec<f64> = rows.into_iter().flatten().collect();
+    Array2::from_shape_vec((n_bootstrap, n_units), flat).unwrap()
+}
+
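+// Moment check for the two-point distribution above, with
+// a = -(√5-1)/2, b = (√5+1)/2 and p = P(w = a) = (√5+1)/(2√5):
+//   E[w]  = p·a  + (1-p)·b  = 0
+//   E[w²] = p·a² + (1-p)·b² = 1
+//   E[w³] = p·a³ + (1-p)·b³ = 1
+// (numerically: 0.7236·(-0.618) + 0.2764·1.618 ≈ 0, and so on)
+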
+/// Generate Webb 6-point distribution weights.
+///
+/// Six-point distribution with moments:
+/// E[w] = 0, E[w²] = 1, E[w³] = 0
+///
+/// Values: ±√(3/2), ±1, ±√(1/2), each with probability 1/6
+fn generate_webb_batch(n_bootstrap: usize, n_units: usize, seed: u64) -> Array2<f64> {
+    // Webb 6-point values in descending magnitude: ±√(3/2), ±√(2/2), ±√(1/2)
+    let val1 = (3.0_f64 / 2.0).sqrt(); // √(3/2) ≈ 1.225
+    let val2 = 1.0_f64; // √(2/2) = 1.0
+    let val3 = (1.0_f64 / 2.0).sqrt(); // √(1/2) ≈ 0.707
+
+    // Equal probability for each of 6 values: 1/6 each
+    let prob = 1.0 / 6.0;
+
+    let rows: Vec<Vec<f64>> = (0..n_bootstrap)
+        .into_par_iter()
+        .map(|i| {
+            let mut rng = Xoshiro256PlusPlus::seed_from_u64(seed.wrapping_add(i as u64));
+            (0..n_units)
+                .map(|_| {
+                    let u = rng.gen::<f64>();
+                    if u < prob {
+                        -val1
+                    } else if u < 2.0 * prob {
+                        -val2
+                    } else if u < 3.0 * prob {
+                        -val3
+                    } else if u < 4.0 * prob {
+                        val3
+                    } else if u < 5.0 * prob {
+                        val2
+                    } else {
+                        val1
+                    }
+                })
+                .collect()
+        })
+        .collect();
+
+    let flat: Vec<f64> = rows.into_iter().flatten().collect();
+    Array2::from_shape_vec((n_bootstrap, n_units), flat).unwrap()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_rademacher_shape() {
+        let weights = generate_rademacher_batch(100, 50, 42);
+        assert_eq!(weights.shape(), &[100, 50]);
+    }
+
+    #[test]
+    fn test_rademacher_values() {
+        let weights = generate_rademacher_batch(10, 100, 42);
+
+        for w in weights.iter() {
+            assert!(*w == 1.0 || *w == -1.0, "Rademacher weight should be ±1");
+        }
+    }
+
+    #[test]
+    fn test_rademacher_mean_approx_zero() {
+        let weights = generate_rademacher_batch(1000, 1, 42);
+        let mean: f64 = weights.iter().sum::<f64>() / weights.len() as f64;
+
+        // With 1000 samples, mean should be close to 0
+        assert!(
+            mean.abs() < 0.1,
+            "Rademacher mean should be close to 0, got {}",
+            mean
+        );
+    }
+
+    #[test]
+    fn test_mammen_shape() {
+        let weights = generate_mammen_batch(100, 50, 42);
+        assert_eq!(weights.shape(), &[100, 50]);
+    }
+
+    #[test]
+    fn test_mammen_mean_approx_zero() {
+        let weights = generate_mammen_batch(1000, 1, 42);
+        let mean: f64 = weights.iter().sum::<f64>() / weights.len() as f64;
+
+        assert!(
+            mean.abs() < 0.1,
+            "Mammen mean should be close to 0, got {}",
+            mean
+        );
+    }
+
+    #[test]
+    fn test_webb_shape() {
+        let weights = generate_webb_batch(100, 50, 42);
+        assert_eq!(weights.shape(), &[100, 50]);
+    }
+
+    #[test]
+    fn test_reproducibility() {
+        let weights1 = generate_rademacher_batch(100, 50, 42);
+        let weights2 = generate_rademacher_batch(100, 50, 42);
+
+        // Same seed should produce same results
+        assert_eq!(weights1, weights2);
+    }
+
+    #[test]
+    fn test_different_seeds() {
+        let weights1 = generate_rademacher_batch(100, 50, 42);
+        let weights2 = generate_rademacher_batch(100, 50, 43);
+
+        // Different seeds should produce different results
+        assert_ne!(weights1, weights2);
+    }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
new file mode 100644
index 0000000..3ce7afb
--- /dev/null
+++ b/rust/src/lib.rs
@@ -0,0 +1,33 @@
+//! Rust backend for diff-diff DiD library.
+//!
+//! This module provides optimized implementations of computationally
+//! intensive operations used in difference-in-differences analysis.
+
+use pyo3::prelude::*;
+
+mod bootstrap;
+mod linalg;
+mod weights;
+
+/// A Python module implemented in Rust for diff-diff acceleration.
+#[pymodule]
+fn _rust_backend(_py: Python, m: &PyModule) -> PyResult<()> {
+    // Bootstrap weight generation
+    m.add_function(wrap_pyfunction!(
+        bootstrap::generate_bootstrap_weights_batch,
+        m
+    )?)?;
+
+    // Synthetic control weights
+    m.add_function(wrap_pyfunction!(weights::compute_synthetic_weights, m)?)?;
+    m.add_function(wrap_pyfunction!(weights::project_simplex, m)?)?;
+
+    // Linear algebra operations
+    m.add_function(wrap_pyfunction!(linalg::solve_ols, m)?)?;
+    m.add_function(wrap_pyfunction!(linalg::compute_robust_vcov, m)?)?;
+
+    // Version info
+    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
+
+    Ok(())
+}
diff --git a/rust/src/linalg.rs b/rust/src/linalg.rs
new file mode 100644
index 0000000..08c7b37
--- /dev/null
+++ b/rust/src/linalg.rs
@@ -0,0 +1,229 @@
+//! Linear algebra operations for OLS estimation and robust variance computation.
+//!
+//! This module provides optimized implementations of:
+//! - OLS solving using LAPACK
+//! - HC1 (heteroskedasticity-consistent) variance-covariance estimation
+//! - Cluster-robust variance-covariance estimation
+
+use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
+use ndarray_linalg::{LeastSquaresSvd, Solve};
+use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
+use pyo3::prelude::*;
+use std::collections::HashMap;
+
+/// Solve OLS regression: β = (X'X)^{-1} X'y
+///
+/// # Arguments
+/// * `x` - Design matrix (n, k)
+/// * `y` - Response vector (n,)
+/// * `cluster_ids` - Optional cluster identifiers (n,) as integers
+/// * `return_vcov` - Whether to compute and return variance-covariance matrix
+///
+/// # Returns
+/// Tuple of (coefficients, residuals, vcov) where vcov is None if return_vcov=False
+#[pyfunction]
+#[pyo3(signature = (x, y, cluster_ids=None, return_vcov=true))]
+pub fn solve_ols<'py>(
+    py: Python<'py>,
+    x: PyReadonlyArray2<'py, f64>,
+    y: PyReadonlyArray1<'py, f64>,
+    cluster_ids: Option<PyReadonlyArray1<'py, i64>>,
+    return_vcov: bool,
+) -> PyResult<(
+    &'py PyArray1<f64>,
+    &'py PyArray1<f64>,
+    Option<&'py PyArray2<f64>>,
+)> {
+    let x_arr = x.as_array();
+    let y_arr = y.as_array();
+
+    // Solve least squares using SVD (more stable than normal equations)
+    let x_owned = x_arr.to_owned();
+    let y_owned = y_arr.to_owned();
+
+    let result = x_owned.least_squares(&y_owned).map_err(|e| {
+        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
+            "Least squares failed: {}",
+            e
+        ))
+    })?;
+
+    let coefficients = result.solution;
+
+    // Compute fitted values and residuals
+    let fitted = x_arr.dot(&coefficients);
+    let residuals = &y_arr - &fitted;
+
+    // Compute variance-covariance if requested
+    let vcov = if return_vcov {
+        let cluster_arr = cluster_ids.as_ref().map(|c| c.as_array().to_owned());
+        let vcov_arr =
+            compute_robust_vcov_internal(&x_arr, &residuals.view(), cluster_arr.as_ref())?;
+        Some(vcov_arr.into_pyarray(py))
+    } else {
+        None
+    };
+
+    Ok((
+        coefficients.into_pyarray(py),
+        residuals.into_pyarray(py),
+        vcov,
+    ))
+}
+
+/// Compute HC1 or cluster-robust variance-covariance matrix.
+///
+/// # Arguments
+/// * `x` - Design matrix (n, k)
+/// * `residuals` - OLS residuals (n,)
+/// * `cluster_ids` - Optional cluster identifiers (n,) as integers
+///
+/// # Returns
+/// Variance-covariance matrix (k, k)
+#[pyfunction]
+#[pyo3(signature = (x, residuals, cluster_ids=None))]
+pub fn compute_robust_vcov<'py>(
+    py: Python<'py>,
+    x: PyReadonlyArray2<'py, f64>,
+    residuals: PyReadonlyArray1<'py, f64>,
+    cluster_ids: Option<PyReadonlyArray1<'py, i64>>,
+) -> PyResult<&'py PyArray2<f64>> {
+    let x_arr = x.as_array();
+    let residuals_arr = residuals.as_array();
+    let cluster_arr = cluster_ids.as_ref().map(|c| c.as_array().to_owned());
+
+    let vcov = compute_robust_vcov_internal(&x_arr, &residuals_arr, cluster_arr.as_ref())?;
+    Ok(vcov.into_pyarray(py))
+}
+
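+// Both estimators implemented below are sandwich forms:
+//   HC1:     V = n/(n-k) · (X'X)⁻¹ (Σ_i e_i² x_i x_i') (X'X)⁻¹
+//   Cluster: V = c · (X'X)⁻¹ (Σ_g s_g s_g') (X'X)⁻¹,
+//            where s_g = Σ_{i∈g} x_i e_i and c = G/(G-1) · (n-1)/(n-k)
+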
+/// Internal implementation of robust variance-covariance computation.
+fn compute_robust_vcov_internal(
+    x: &ArrayView2<f64>,
+    residuals: &ArrayView1<f64>,
+    cluster_ids: Option<&Array1<i64>>,
+) -> PyResult<Array2<f64>> {
+    let n = x.nrows();
+    let k = x.ncols();
+
+    // Compute X'X
+    let xtx = x.t().dot(x);
+
+    // Compute (X'X)^{-1} by solving against the identity (see invert_symmetric)
+    let xtx_inv = invert_symmetric(&xtx)?;
+
+    match cluster_ids {
+        None => {
+            // HC1 variance: (X'X)^{-1} X' diag(e²) X (X'X)^{-1} × n/(n-k)
+            let u_squared: Array1<f64> = residuals.mapv(|r| r * r);
+
+            // Compute X' diag(e²) X efficiently
+            // meat = Σᵢ eᵢ² xᵢ xᵢ'
+            let mut meat = Array2::<f64>::zeros((k, k));
+            for i in 0..n {
+                let xi = x.row(i);
+                let e2 = u_squared[i];
+                for j in 0..k {
+                    for l in 0..k {
+                        meat[[j, l]] += e2 * xi[j] * xi[l];
+                    }
+                }
+            }
+
+            // HC1 adjustment factor
+            let adjustment = n as f64 / (n - k) as f64;
+
+            // Sandwich: (X'X)^{-1} meat (X'X)^{-1}
+            let temp = xtx_inv.dot(&meat);
+            let vcov = temp.dot(&xtx_inv) * adjustment;
+
+            Ok(vcov)
+        }
+        Some(clusters) => {
+            // Cluster-robust variance
+            // Group observations by cluster and sum scores within clusters
+            let n_obs = n;
+
+            // Compute scores: X * e (element-wise, each row multiplied by residual)
+            let mut scores = Array2::<f64>::zeros((n, k));
+            for i in 0..n {
+                let e = residuals[i];
+                for j in 0..k {
+                    scores[[i, j]] = x[[i, j]] * e;
+                }
+            }
+
+            // Aggregate scores by cluster using HashMap
+            let mut cluster_sums: HashMap<i64, Array1<f64>> = HashMap::new();
+            for i in 0..n_obs {
+                let cluster = clusters[i];
+                let row = scores.row(i).to_owned();
+                cluster_sums
+                    .entry(cluster)
+                    .and_modify(|sum| *sum = &*sum + &row)
+                    .or_insert(row);
+            }
+
+            let n_clusters = cluster_sums.len();
+
+            if n_clusters < 2 {
+                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                    "Need at least 2 clusters for cluster-robust SEs, got {}",
+                    n_clusters
+                )));
+            }
+
+            // Build cluster scores matrix (G, k)
+            let mut cluster_scores = Array2::<f64>::zeros((n_clusters, k));
+            for (idx, (_cluster_id, sum)) in cluster_sums.iter().enumerate() {
+                cluster_scores.row_mut(idx).assign(sum);
+            }
+
+            // Compute meat: Σ_g (X_g' e_g)(X_g' e_g)'
+            let meat = cluster_scores.t().dot(&cluster_scores);
+
+            // Adjustment factors
+            // G/(G-1) * (n-1)/(n-k) - matches NumPy implementation
+            let g = n_clusters as f64;
+            let adjustment = (g / (g - 1.0)) * ((n_obs - 1) as f64 / (n_obs - k) as f64);
+
+            // Sandwich estimator
+            let temp = xtx_inv.dot(&meat);
+            let vcov = temp.dot(&xtx_inv) * adjustment;
+
+            Ok(vcov)
+        }
+    }
+}
+
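+// (X'X)⁻¹ is obtained by solving X'X · col_i = e_i for each column of the
+// identity. For the small k typical of DiD design matrices, the O(k³) cost
+// is negligible relative to forming X'X itself.
+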
+/// Invert a symmetric positive-definite matrix.
+fn invert_symmetric(a: &Array2<f64>) -> PyResult<Array2<f64>> {
+    let n = a.nrows();
+    let mut result = Array2::<f64>::zeros((n, n));
+
+    // Solve A * x_i = e_i for each column of the identity matrix
+    for i in 0..n {
+        let mut e_i = Array1::<f64>::zeros(n);
+        e_i[i] = 1.0;
+
+        let col = a.solve(&e_i).map_err(|e| {
+            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
+                "Matrix inversion failed: {}",
+                e
+            ))
+        })?;
+
+        result.column_mut(i).assign(&col);
+    }
+
+    Ok(result)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use ndarray::array;
+
+    #[test]
+    fn test_invert_symmetric() {
+        let a = array![[4.0, 2.0], [2.0, 3.0]];
+        let a_inv = invert_symmetric(&a).unwrap();
+
+        // A * A^{-1} should be identity
+        let identity = a.dot(&a_inv);
+        assert!((identity[[0, 0]] - 1.0).abs() < 1e-10);
+        assert!((identity[[1, 1]] - 1.0).abs() < 1e-10);
+        assert!((identity[[0, 1]]).abs() < 1e-10);
+        assert!((identity[[1, 0]]).abs() < 1e-10);
+    }
+}
diff --git a/rust/src/weights.rs b/rust/src/weights.rs
new file mode 100644
index 0000000..62c1256
--- /dev/null
+++ b/rust/src/weights.rs
@@ -0,0 +1,220 @@
+//! Synthetic control weight computation via projected gradient descent.
+//!
+//! This module provides optimized implementations of:
+//! - Synthetic control weight optimization
+//! - Simplex projection
+
+use ndarray::{Array1, ArrayView1, ArrayView2};
+use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
+use pyo3::prelude::*;
+
+/// Maximum number of optimization iterations.
+const MAX_ITER: usize = 1000;
+
+/// Default convergence tolerance (matches Python's _OPTIMIZATION_TOL).
+const DEFAULT_TOL: f64 = 1e-8;
+
+/// Default step size for gradient descent.
+const DEFAULT_STEP_SIZE: f64 = 0.1;
+
+/// Compute synthetic control weights via projected gradient descent.
+///
+/// Solves: min_w ||Y_treated - Y_control @ w||² + lambda * ||w||²
+/// subject to: w >= 0, sum(w) = 1
+///
+/// # Arguments
+/// * `y_control` - Control unit outcomes matrix (n_pre, n_control)
+/// * `y_treated` - Treated unit outcomes (n_pre,)
+/// * `lambda_reg` - L2 regularization parameter
+/// * `max_iter` - Maximum number of iterations (default: 1000)
+/// * `tol` - Convergence tolerance (default: 1e-8)
+///
+/// # Returns
+/// Optimal weights (n_control,) that sum to 1
+#[pyfunction]
+#[pyo3(signature = (y_control, y_treated, lambda_reg=0.0, max_iter=None, tol=None))]
+pub fn compute_synthetic_weights<'py>(
+    py: Python<'py>,
+    y_control: PyReadonlyArray2<'py, f64>,
+    y_treated: PyReadonlyArray1<'py, f64>,
+    lambda_reg: f64,
+    max_iter: Option<usize>,
+    tol: Option<f64>,
+) -> PyResult<&'py PyArray1<f64>> {
+    let y_control_arr = y_control.as_array();
+    let y_treated_arr = y_treated.as_array();
+
+    let weights = compute_synthetic_weights_internal(
+        &y_control_arr,
+        &y_treated_arr,
+        lambda_reg,
+        max_iter,
+        tol,
+    )?;
+
+    Ok(weights.into_pyarray(py))
+}
+
+/// Internal implementation of synthetic weight computation.
+fn compute_synthetic_weights_internal(
+    y_control: &ArrayView2<f64>,
+    y_treated: &ArrayView1<f64>,
+    lambda_reg: f64,
+    max_iter: Option<usize>,
+    tol: Option<f64>,
+) -> PyResult<Array1<f64>> {
+    let n_control = y_control.ncols();
+    let max_iter = max_iter.unwrap_or(MAX_ITER);
+    let tol = tol.unwrap_or(DEFAULT_TOL);
+
+    // Precompute Hessian: H = Y_control' @ Y_control + lambda * I
+    let h = {
+        let ytc = y_control.t().dot(y_control);
+        let mut h = ytc;
+        // Add regularization to diagonal
+        for i in 0..n_control {
+            h[[i, i]] += lambda_reg;
+        }
+        h
+    };
+
+    // Precompute linear term: f = Y_control' @ Y_treated
+    let f = y_control.t().dot(y_treated);
+
+    // Initialize with uniform weights
+    let mut weights = Array1::from_elem(n_control, 1.0 / n_control as f64);
+
+    // Projected gradient descent
+    let step_size = DEFAULT_STEP_SIZE;
+    let mut prev_weights = weights.clone();
+
+    for _ in 0..max_iter {
+        // Gradient: grad = H @ weights - f
+        let grad = h.dot(&weights) - &f;
+
+        // Gradient step
+        weights = &weights - step_size * &grad;
+
+        // Project onto simplex
+        weights = project_simplex_internal(&weights.view());
+
+        // Check convergence
+        let diff: f64 = weights
+            .iter()
+            .zip(prev_weights.iter())
+            .map(|(a, b)| (a - b).powi(2))
+            .sum();
+        if diff.sqrt() < tol {
+            break;
+        }
+
+        prev_weights.assign(&weights);
+    }
+
+    Ok(weights)
+}
+
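+// The loop above solves the quadratic program
+//   min_w  ½ w'Hw - f'w,   H = Yc'Yc + λI,   f = Yc'y_t,
+// over the probability simplex: each iteration takes a gradient step with
+// ∇ = Hw - f, then projects back onto {w : w ≥ 0, Σw = 1}. The fixed step
+// size of 0.1 assumes a moderately conditioned Hessian; callers can raise
+// `max_iter` or tighten `tol` if convergence is slow.
+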
"Efficient Projections onto the ℓ1-Ball for Learning in High Dimensions" +/// +/// # Arguments +/// * `v` - Input vector (n,) +/// +/// # Returns +/// Projected vector (n,) satisfying: w >= 0, sum(w) = 1 +#[pyfunction] +pub fn project_simplex<'py>( + py: Python<'py>, + v: PyReadonlyArray1<'py, f64>, +) -> PyResult<&'py PyArray1> { + let v_arr = v.as_array(); + let result = project_simplex_internal(&v_arr); + Ok(result.into_pyarray(py)) +} + +/// Internal implementation of simplex projection. +/// +/// Algorithm: +/// 1. Sort v in descending order +/// 2. Find the largest k such that u_k + (1 - sum_{j=1}^k u_j) / k > 0 +/// 3. Set theta = (sum_{j=1}^k u_j - 1) / k +/// 4. Return max(v - theta, 0) +fn project_simplex_internal(v: &ArrayView1) -> Array1 { + let n = v.len(); + + // Sort in descending order + let mut u: Vec = v.iter().cloned().collect(); + u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); + + // Find rho: largest index where u[rho] + (1 - cumsum[rho]) / (rho + 1) > 0 + let mut cumsum = 0.0; + let mut rho = 0; + for i in 0..n { + cumsum += u[i]; + if u[i] + (1.0 - cumsum) / (i + 1) as f64 > 0.0 { + rho = i; + } + } + + // Compute threshold + let cumsum_rho: f64 = u.iter().take(rho + 1).sum(); + let theta = (cumsum_rho - 1.0) / (rho + 1) as f64; + + // Project: max(v - theta, 0) + v.mapv(|x| (x - theta).max(0.0)) +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::array; + + #[test] + fn test_project_simplex_already_on_simplex() { + let v = array![0.3, 0.5, 0.2]; + let result = project_simplex_internal(&v.view()); + + // Already on simplex, should be unchanged + let sum: f64 = result.sum(); + assert!((sum - 1.0).abs() < 1e-10); + assert!(result.iter().all(|&x| x >= 0.0)); + } + + #[test] + fn test_project_simplex_uniform() { + let v = array![1.0, 1.0, 1.0, 1.0]; + let result = project_simplex_internal(&v.view()); + + // Should project to uniform distribution + let sum: f64 = result.sum(); + assert!((sum - 1.0).abs() < 1e-10); + for &x in result.iter() { + assert!((x - 0.25).abs() < 1e-10); + } + } + + #[test] + fn test_project_simplex_negative() { + let v = array![-1.0, 2.0, 0.5]; + let result = project_simplex_internal(&v.view()); + + // Should be on simplex + let sum: f64 = result.sum(); + assert!((sum - 1.0).abs() < 1e-10); + assert!(result.iter().all(|&x| x >= -1e-10)); + } + + #[test] + fn test_compute_weights_sum_to_one() { + let y_control = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]; + let y_treated = array![2.0, 5.0, 8.0]; + + let weights = + compute_synthetic_weights_internal(&y_control.view(), &y_treated.view(), 0.0, None, None) + .unwrap(); + + let sum: f64 = weights.sum(); + assert!((sum - 1.0).abs() < 1e-6, "Weights should sum to 1, got {}", sum); + assert!( + weights.iter().all(|&w| w >= -1e-10), + "Weights should be non-negative" + ); + } +} diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py new file mode 100644 index 0000000..2592336 --- /dev/null +++ b/tests/test_rust_backend.py @@ -0,0 +1,514 @@ +""" +Tests for the Rust backend. + +These tests verify that: +1. The Rust backend produces results matching the NumPy implementations +2. Basic functionality works correctly +3. Edge cases are handled properly + +Tests are skipped if the Rust backend is not available. 
+""" + +import numpy as np +import pytest + +from diff_diff import HAS_RUST_BACKEND + + +@pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") +class TestRustBackend: + """Test suite for Rust backend functions.""" + + def test_rust_backend_available(self): + """Verify Rust backend is available when this test runs.""" + assert HAS_RUST_BACKEND + + # ========================================================================= + # Bootstrap Weight Tests + # ========================================================================= + + def test_bootstrap_weights_shape(self): + """Test bootstrap weights have correct shape.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + n_bootstrap, n_units = 100, 50 + weights = generate_bootstrap_weights_batch(n_bootstrap, n_units, "rademacher", 42) + assert weights.shape == (n_bootstrap, n_units) + + def test_rademacher_weights_values(self): + """Test Rademacher weights are +-1.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + weights = generate_bootstrap_weights_batch(100, 50, "rademacher", 42) + unique_vals = np.unique(weights) + assert len(unique_vals) == 2 + assert set(unique_vals) == {-1.0, 1.0} + + def test_rademacher_weights_mean_zero(self): + """Test Rademacher weights have approximately zero mean.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + weights = generate_bootstrap_weights_batch(10000, 1, "rademacher", 42) + mean = weights.mean() + assert abs(mean) < 0.05, f"Rademacher mean should be ~0, got {mean}" + + def test_mammen_weights_mean_zero(self): + """Test Mammen weights have approximately zero mean.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + weights = generate_bootstrap_weights_batch(10000, 1, "mammen", 42) + mean = weights.mean() + assert abs(mean) < 0.05, f"Mammen mean should be ~0, got {mean}" + + def test_webb_weights_mean_zero(self): + """Test Webb weights have approximately zero mean.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + weights = generate_bootstrap_weights_batch(10000, 1, "webb", 42) + mean = weights.mean() + assert abs(mean) < 0.1, f"Webb mean should be ~0, got {mean}" + + def test_bootstrap_reproducibility(self): + """Test bootstrap weights are reproducible with same seed.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + weights1 = generate_bootstrap_weights_batch(100, 50, "rademacher", 42) + weights2 = generate_bootstrap_weights_batch(100, 50, "rademacher", 42) + np.testing.assert_array_equal(weights1, weights2) + + def test_bootstrap_different_seeds(self): + """Test different seeds produce different weights.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch + + weights1 = generate_bootstrap_weights_batch(100, 50, "rademacher", 42) + weights2 = generate_bootstrap_weights_batch(100, 50, "rademacher", 43) + assert not np.array_equal(weights1, weights2) + + # ========================================================================= + # Synthetic Weight Tests + # ========================================================================= + + def test_synthetic_weights_sum_to_one(self): + """Test synthetic weights sum to 1.""" + from diff_diff._rust_backend import compute_synthetic_weights + + np.random.seed(42) + Y_control = np.random.randn(10, 5) + Y_treated = np.random.randn(10) + + weights = compute_synthetic_weights(Y_control, Y_treated, 0.0, 1000, 1e-8) + assert abs(weights.sum() - 1.0) < 1e-6, 
f"Weights should sum to 1, got {weights.sum()}" + + def test_synthetic_weights_non_negative(self): + """Test synthetic weights are non-negative.""" + from diff_diff._rust_backend import compute_synthetic_weights + + np.random.seed(42) + Y_control = np.random.randn(10, 5) + Y_treated = np.random.randn(10) + + weights = compute_synthetic_weights(Y_control, Y_treated, 0.0, 1000, 1e-8) + assert np.all(weights >= -1e-10), "Weights should be non-negative" + + def test_synthetic_weights_shape(self): + """Test synthetic weights have correct shape.""" + from diff_diff._rust_backend import compute_synthetic_weights + + np.random.seed(42) + n_control = 8 + Y_control = np.random.randn(10, n_control) + Y_treated = np.random.randn(10) + + weights = compute_synthetic_weights(Y_control, Y_treated, 0.0, 1000, 1e-8) + assert weights.shape == (n_control,) + + # ========================================================================= + # Simplex Projection Tests + # ========================================================================= + + def test_project_simplex_sum(self): + """Test projected vector sums to 1.""" + from diff_diff._rust_backend import project_simplex + + v = np.array([0.5, 0.3, 0.2, 0.4]) + projected = project_simplex(v) + assert abs(projected.sum() - 1.0) < 1e-10 + + def test_project_simplex_non_negative(self): + """Test projected vector is non-negative.""" + from diff_diff._rust_backend import project_simplex + + v = np.array([-0.5, 0.3, 1.2, 0.4]) + projected = project_simplex(v) + assert np.all(projected >= -1e-10) + + def test_project_simplex_already_on_simplex(self): + """Test projecting a vector already on simplex.""" + from diff_diff._rust_backend import project_simplex + + v = np.array([0.3, 0.5, 0.2]) + projected = project_simplex(v) + np.testing.assert_array_almost_equal(projected, v) + + # ========================================================================= + # OLS Tests + # ========================================================================= + + def test_solve_ols_shape(self): + """Test OLS returns correct shapes.""" + from diff_diff._rust_backend import solve_ols + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + y = np.random.randn(n) + + coeffs, residuals, vcov = solve_ols(X, y, None, True) + + assert coeffs.shape == (k,) + assert residuals.shape == (n,) + assert vcov.shape == (k, k) + + def test_solve_ols_coefficients(self): + """Test OLS coefficients match scipy.""" + from diff_diff._rust_backend import solve_ols + from scipy.linalg import lstsq + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + y = np.random.randn(n) + + coeffs_rust, _, _ = solve_ols(X, y, None, True) + coeffs_scipy = lstsq(X, y)[0] + + np.testing.assert_array_almost_equal(coeffs_rust, coeffs_scipy, decimal=10) + + def test_solve_ols_residuals(self): + """Test OLS residuals are correct.""" + from diff_diff._rust_backend import solve_ols + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + y = np.random.randn(n) + + coeffs, residuals, _ = solve_ols(X, y, None, True) + expected_residuals = y - X @ coeffs + + np.testing.assert_array_almost_equal(residuals, expected_residuals, decimal=10) + + # ========================================================================= + # Robust VCoV Tests + # ========================================================================= + + def test_robust_vcov_shape(self): + """Test robust VCoV has correct shape.""" + from diff_diff._rust_backend import compute_robust_vcov + + np.random.seed(42) + n, k = 100, 5 + X 
= np.random.randn(n, k) + residuals = np.random.randn(n) + + vcov = compute_robust_vcov(X, residuals, None) + assert vcov.shape == (k, k) + + def test_robust_vcov_symmetric(self): + """Test robust VCoV is symmetric.""" + from diff_diff._rust_backend import compute_robust_vcov + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + residuals = np.random.randn(n) + + vcov = compute_robust_vcov(X, residuals, None) + np.testing.assert_array_almost_equal(vcov, vcov.T) + + def test_robust_vcov_positive_diagonal(self): + """Test robust VCoV has positive diagonal.""" + from diff_diff._rust_backend import compute_robust_vcov + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + residuals = np.random.randn(n) + + vcov = compute_robust_vcov(X, residuals, None) + assert np.all(np.diag(vcov) > 0), "Diagonal should be positive" + + def test_cluster_robust_vcov(self): + """Test cluster-robust VCoV.""" + from diff_diff._rust_backend import compute_robust_vcov + + np.random.seed(42) + n, k = 100, 5 + n_clusters = 10 + X = np.random.randn(n, k) + residuals = np.random.randn(n) + cluster_ids = np.repeat(np.arange(n_clusters), n // n_clusters) + + vcov = compute_robust_vcov(X, residuals, cluster_ids) + assert vcov.shape == (k, k) + assert np.all(np.diag(vcov) > 0) + + +@pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") +class TestRustVsNumpy: + """Tests comparing Rust and NumPy implementations for numerical equivalence.""" + + # ========================================================================= + # OLS Solver Equivalence + # ========================================================================= + + def test_solve_ols_coefficients_match(self): + """Test Rust and NumPy OLS coefficients match.""" + from diff_diff._rust_backend import solve_ols as rust_fn + from diff_diff.linalg import _solve_ols_numpy as numpy_fn + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + y = np.random.randn(n) + + rust_coeffs, rust_resid, rust_vcov = rust_fn(X, y, None, True) + numpy_coeffs, numpy_resid, numpy_vcov = numpy_fn(X, y, cluster_ids=None) + + np.testing.assert_array_almost_equal( + rust_coeffs, numpy_coeffs, decimal=8, + err_msg="OLS coefficients should match" + ) + np.testing.assert_array_almost_equal( + rust_resid, numpy_resid, decimal=8, + err_msg="OLS residuals should match" + ) + + def test_solve_ols_with_clusters_match(self): + """Test Rust and NumPy OLS with cluster SEs match.""" + from diff_diff._rust_backend import solve_ols as rust_fn + from diff_diff.linalg import _solve_ols_numpy as numpy_fn + + np.random.seed(42) + n, k = 100, 5 + n_clusters = 10 + X = np.random.randn(n, k) + y = np.random.randn(n) + cluster_ids = np.repeat(np.arange(n_clusters), n // n_clusters) + + rust_coeffs, _, rust_vcov = rust_fn(X, y, cluster_ids, True) + numpy_coeffs, _, numpy_vcov = numpy_fn(X, y, cluster_ids=cluster_ids) + + np.testing.assert_array_almost_equal( + rust_coeffs, numpy_coeffs, decimal=8, + err_msg="Clustered OLS coefficients should match" + ) + # VCoV may differ slightly due to implementation details + np.testing.assert_array_almost_equal( + rust_vcov, numpy_vcov, decimal=5, + err_msg="Clustered OLS VCoV should match" + ) + + # ========================================================================= + # Robust VCoV Equivalence + # ========================================================================= + + def test_robust_vcov_hc1_match(self): + """Test Rust and NumPy HC1 robust VCoV match.""" + from diff_diff._rust_backend import 
compute_robust_vcov as rust_fn + from diff_diff.linalg import _compute_robust_vcov_numpy as numpy_fn + + np.random.seed(42) + n, k = 100, 5 + X = np.random.randn(n, k) + residuals = np.random.randn(n) + + rust_vcov = rust_fn(X, residuals, None) + numpy_vcov = numpy_fn(X, residuals, None) + + np.testing.assert_array_almost_equal( + rust_vcov, numpy_vcov, decimal=8, + err_msg="HC1 robust VCoV should match" + ) + + def test_robust_vcov_clustered_match(self): + """Test Rust and NumPy cluster-robust VCoV match.""" + from diff_diff._rust_backend import compute_robust_vcov as rust_fn + from diff_diff.linalg import _compute_robust_vcov_numpy as numpy_fn + + np.random.seed(42) + n, k = 100, 5 + n_clusters = 10 + X = np.random.randn(n, k) + residuals = np.random.randn(n) + cluster_ids = np.repeat(np.arange(n_clusters), n // n_clusters) + + rust_vcov = rust_fn(X, residuals, cluster_ids) + numpy_vcov = numpy_fn(X, residuals, cluster_ids) + + np.testing.assert_array_almost_equal( + rust_vcov, numpy_vcov, decimal=6, + err_msg="Cluster-robust VCoV should match" + ) + + # ========================================================================= + # Bootstrap Weights Equivalence (Statistical Properties) + # ========================================================================= + + def test_bootstrap_weights_rademacher_properties(self): + """Test Rust Rademacher weights have correct statistical properties.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch as rust_fn + + # Generate large sample for statistical tests + n_bootstrap, n_units = 10000, 100 + weights = rust_fn(n_bootstrap, n_units, "rademacher", 42) + + # Rademacher: values are +-1, mean ~0, variance ~1 + unique_vals = np.unique(weights) + assert set(unique_vals) == {-1.0, 1.0}, "Rademacher weights should be +-1" + + mean = weights.mean() + assert abs(mean) < 0.02, f"Rademacher mean should be ~0, got {mean}" + + var = weights.var() + assert abs(var - 1.0) < 0.02, f"Rademacher variance should be ~1, got {var}" + + def test_bootstrap_weights_mammen_properties(self): + """Test Rust Mammen weights have correct statistical properties.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch as rust_fn + + n_bootstrap, n_units = 10000, 100 + weights = rust_fn(n_bootstrap, n_units, "mammen", 42) + + # Mammen: E[w] = 0, E[w^2] = 1, E[w^3] = 1 + mean = weights.mean() + assert abs(mean) < 0.02, f"Mammen mean should be ~0, got {mean}" + + second_moment = (weights ** 2).mean() + assert abs(second_moment - 1.0) < 0.02, f"Mammen E[w^2] should be ~1, got {second_moment}" + + third_moment = (weights ** 3).mean() + assert abs(third_moment - 1.0) < 0.1, f"Mammen E[w^3] should be ~1, got {third_moment}" + + def test_bootstrap_weights_webb_properties(self): + """Test Rust Webb weights have correct statistical properties.""" + from diff_diff._rust_backend import generate_bootstrap_weights_batch as rust_fn + + n_bootstrap, n_units = 10000, 100 + weights = rust_fn(n_bootstrap, n_units, "webb", 42) + + # Webb: 6-point distribution with E[w] = 0 + mean = weights.mean() + assert abs(mean) < 0.1, f"Webb mean should be ~0, got {mean}" + + # Should have 6 unique values + unique_vals = np.unique(weights.flatten()) + assert len(unique_vals) == 6, f"Webb should have 6 unique values, got {len(unique_vals)}" + + # ========================================================================= + # Synthetic Weights Equivalence + # ========================================================================= + + def 
test_synthetic_weights_match(self): + """Test Rust and NumPy synthetic weights produce similar results.""" + from diff_diff._rust_backend import compute_synthetic_weights as rust_fn + from diff_diff.utils import _compute_synthetic_weights_numpy as numpy_fn + + np.random.seed(42) + Y_control = np.random.randn(10, 5) + Y_treated = np.random.randn(10) + + rust_weights = rust_fn(Y_control, Y_treated, 0.0, 1000, 1e-8) + numpy_weights = numpy_fn(Y_control, Y_treated, 0.0) + + # Both should be valid simplex weights + assert abs(rust_weights.sum() - 1.0) < 1e-6, "Rust weights should sum to 1" + assert abs(numpy_weights.sum() - 1.0) < 1e-6, "NumPy weights should sum to 1" + assert np.all(rust_weights >= -1e-6), "Rust weights should be non-negative" + assert np.all(numpy_weights >= -1e-6), "NumPy weights should be non-negative" + + # Reconstruction error should be similar + rust_error = np.linalg.norm(Y_treated - Y_control @ rust_weights) + numpy_error = np.linalg.norm(Y_treated - Y_control @ numpy_weights) + assert abs(rust_error - numpy_error) < 0.5, \ + f"Reconstruction errors should be similar: rust={rust_error:.4f}, numpy={numpy_error:.4f}" + + def test_synthetic_weights_with_regularization(self): + """Test Rust synthetic weights with L2 regularization.""" + from diff_diff._rust_backend import compute_synthetic_weights as rust_fn + from diff_diff.utils import _compute_synthetic_weights_numpy as numpy_fn + + np.random.seed(42) + Y_control = np.random.randn(15, 8) + Y_treated = np.random.randn(15) + lambda_reg = 0.1 + + rust_weights = rust_fn(Y_control, Y_treated, lambda_reg, 1000, 1e-8) + numpy_weights = numpy_fn(Y_control, Y_treated, lambda_reg) + + # Both should be valid simplex weights + assert abs(rust_weights.sum() - 1.0) < 1e-6 + assert abs(numpy_weights.sum() - 1.0) < 1e-6 + + # With regularization, weights should be more spread out (higher entropy) + rust_entropy = -np.sum(rust_weights * np.log(rust_weights + 1e-10)) + numpy_entropy = -np.sum(numpy_weights * np.log(numpy_weights + 1e-10)) + assert rust_entropy > 0.5, "Regularized weights should have positive entropy" + assert numpy_entropy > 0.5, "Regularized weights should have positive entropy" + + def test_simplex_projection_match(self): + """Test Rust and NumPy simplex projection match exactly.""" + from diff_diff._rust_backend import project_simplex as rust_fn + from diff_diff.utils import _project_simplex as numpy_fn + + # Test various input vectors + test_vectors = [ + np.array([0.5, -0.3, 1.2, 0.4, -0.1]), + np.array([1.0, 1.0, 1.0, 1.0]), # uniform + np.array([0.25, 0.25, 0.25, 0.25]), # already on simplex + np.array([-1.0, -2.0, 5.0]), # one dominant + np.array([0.1, 0.2, 0.3, 0.4]), # near simplex + ] + + for v in test_vectors: + rust_proj = rust_fn(v) + numpy_proj = numpy_fn(v) + + np.testing.assert_array_almost_equal( + rust_proj, numpy_proj, decimal=10, + err_msg=f"Simplex projection mismatch for input {v}" + ) + + +class TestFallbackWhenNoRust: + """Test that pure Python fallback works when Rust is unavailable.""" + + def test_has_rust_backend_is_bool(self): + """HAS_RUST_BACKEND should be a boolean.""" + assert isinstance(HAS_RUST_BACKEND, bool) + + def test_imports_work_without_rust(self): + """Core imports should work regardless of Rust availability.""" + from diff_diff import ( + CallawaySantAnna, + DifferenceInDifferences, + SyntheticDiD, + ) + + assert CallawaySantAnna is not None + assert DifferenceInDifferences is not None + assert SyntheticDiD is not None + + def test_linalg_works_without_rust(self): + """linalg 
functions should work with NumPy fallback.""" + from diff_diff.linalg import compute_robust_vcov, solve_ols + + np.random.seed(42) + n, k = 50, 3 + X = np.random.randn(n, k) + y = np.random.randn(n) + + coeffs, residuals, vcov = solve_ols(X, y) + assert coeffs.shape == (k,) + assert residuals.shape == (n,) + assert vcov.shape == (k, k)
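
+
+    # Hedged addition: a minimal fallback check for compute_robust_vcov,
+    # assuming the public wrapper accepts (X, residuals, cluster_ids=None)
+    # like the Rust backend function of the same name.
+    def test_robust_vcov_fallback_shape(self):
+        """Sketch of the pure-NumPy robust-variance path."""
+        from diff_diff.linalg import compute_robust_vcov
+
+        np.random.seed(0)
+        n, k = 50, 3
+        X = np.random.randn(n, k)
+        residuals = np.random.randn(n)
+
+        # The HC1 matrix should be symmetric (k, k) with positive variances
+        vcov = compute_robust_vcov(X, residuals)
+        assert vcov.shape == (k, k)
+        np.testing.assert_array_almost_equal(vcov, vcov.T)
+        assert np.all(np.diag(vcov) > 0)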