Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 63 additions & 14 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,72 @@ jobs:
with:
python-version: "3.11"

- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[dev]"
pip install ".[dev]"

- name: Run tests
run: pytest -v

publish:
build-wheels:
needs: [check-version, test]
if: needs.check-version.outputs.should_release == 'true'
strategy:
matrix:
include:
- os: ubuntu-latest
target: x86_64
- os: ubuntu-latest
target: aarch64
- os: macos-latest
target: x86_64
- os: macos-latest
target: aarch64
- os: windows-latest
target: x86_64
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4

- name: Build wheels
uses: PyO3/maturin-action@v1
with:
target: ${{ matrix.target }}
args: --release --out dist
manylinux: auto

- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.os }}-${{ matrix.target }}
path: dist

build-sdist:
needs: [check-version, test]
if: needs.check-version.outputs.should_release == 'true'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Build sdist
uses: PyO3/maturin-action@v1
with:
command: sdist
args: --out dist

- name: Upload sdist
uses: actions/upload-artifact@v4
with:
name: wheels-sdist
path: dist

publish:
needs: [check-version, build-wheels, build-sdist]
if: needs.check-version.outputs.should_release == 'true'
runs-on: ubuntu-latest
permissions:
id-token: write
Expand All @@ -59,24 +114,18 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Create version tag
run: |
VERSION=${{ needs.check-version.outputs.version }}
git tag "v$VERSION"
git push origin "v$VERSION"

- name: Install build dependencies
run: |
python -m pip install --upgrade pip
pip install build hatchling

- name: Build wheel and sdist
run: python -m build
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
pattern: wheels-*
merge-multiple: true
path: dist

- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[dev]"
pip install ".[dev]"

- name: Run ruff (lint)
run: ruff check .
Expand Down
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ examples/.cache/
.vscode/

# Benchmark results (generated, not committed)
benchmarks/results/*.json
benchmarks/results/*.json

target/

# Rust/maturin build artifacts
*.so
*.dylib
*.pyd
71 changes: 59 additions & 12 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ pytest # Tests (must pass)

### Setup
```bash
pip install -e . # Install package in editable mode
pip install -e ".[dev]" # Install with dev dependencies
pip install maturin # Required for building Rust extension
maturin develop --release # Build Rust extension (needs Rust toolchain)
pip install -e ".[dev]" # Install with dev dependencies
```

### Testing
Expand Down Expand Up @@ -88,25 +89,34 @@ python examples/colbert_nanobeir.py

**Local build (for testing):**
```bash
pip install build
python -m build
twine check dist/*
maturin build --release # Build wheel with Rust extension
```

## Architecture

### Core Components

**`muvera/muvera.py`** - Main `Muvera` class implementing Fixed Dimensional Encoding (FDE)
- Three encoding paths: single document, uniform batch, variable-length batch
- Two encoding paths: single document, variable-length batch
- Document encoding uses AVERAGE aggregation within partitions
- Query encoding uses SUM aggregation within partitions
- Optional final dimensionality reduction via Count Sketch
- Hot-path methods (`_aggregate_single`, `_scatter_add`, `_fill_empty_batch`) delegate to Rust kernels when available

**`muvera/helper.py`** - Low-level utilities (not public API)
- Gray code manipulation for partition indexing
- Random projection matrices (SimHash, AMS Sketch, Count Sketch)
- Vectorized batch partition indexing
- `partition_index_gray` and `partition_indices_gray_batch` delegate to Rust when available

**`src/`** - Rust extension module (`muvera._rust_kernels`) via PyO3/maturin
- `gray_code.rs` — Gray code append and binary conversion
- `partition.rs` — Single and batch Gray-code partition indexing
- `scatter.rs` — Scatter-add kernel for batch aggregation
- `fill_empty.rs` — Single-point-cloud aggregation + batch empty partition filling
- `lib.rs` — PyO3 module definition exposing 5 functions

**`muvera/_rust_kernels.pyi`** - Type stubs for the Rust extension module

### Algorithm Flow

Expand All @@ -120,24 +130,58 @@ twine check dist/*
6. **Repetitions**: Repeat steps 1-5 with different random seeds, concatenating results
7. **Final Projection** (optional): Apply Count Sketch to reduce final dimension

### Rust Acceleration

Performance-critical inner loops are implemented in Rust via PyO3, with automatic fallback to pure Python:

```python
# muvera/__init__.py
try:
import muvera._rust_kernels
_RUST_AVAILABLE = True
except ImportError:
_RUST_AVAILABLE = False
```

**Accelerated functions:**
| Rust function | Python fallback | Speedup |
|---|---|---|
| `aggregate_single` | `Muvera._aggregate_single_python` | 8-17x (single doc) |
| `scatter_add_partitions` | `Muvera._scatter_add` (np.add.at loop) | 1-2.5x (batch) |
| `fill_empty_partitions_batch` | `Muvera._fill_empty_batch` (Python loop) | 1-2.5x (batch) |
| `partition_index_gray` | `helper._partition_index_gray_python` | part of aggregate |
| `partition_indices_gray_batch` | `helper._partition_indices_gray_batch_python` | part of batch |

**What is NOT in Rust** (intentionally kept in NumPy for seed compatibility):
- `simhash_matrix_from_seed`, `ams_projection_matrix_from_seed` — depend on `np.random.default_rng`
- `count_sketch_vector_from_seed` — same reason
- `Muvera.__init__`, public API signatures — 100% unchanged

### Batch Processing

The library supports three input formats:
The library supports two input formats:
- **Single**: `(num_vectors, dimension)` - processes one point cloud
- **Uniform batch**: `(batch_size, num_vectors, dimension)` - all point clouds have same length
- **Variable-length batch**: `list[np.ndarray]` - each point cloud has different length (recommended for real-world data)

Variable-length batch processing flattens all point clouds, processes them together, then aggregates per-document using `np.add.at()` for efficient scatter-add operations.
Variable-length batch processing flattens all point clouds, processes them together, then aggregates per-document using Rust `scatter_add_partitions` (or `np.add.at()` fallback).

## Code Conventions

### Python
- NumPy-style docstrings (configured in pyproject.toml)
- Type hints required (Python 3.9+ syntax with `|` for unions)
- Line length: 100 characters
- Use `np.float32` for all embeddings (memory efficiency)
- Use `np.uint32` for partition indices
- Random number generation via `np.random.default_rng(seed)` for reproducibility

### Rust
- Edition 2021
- Dependencies: `pyo3` 0.23, `numpy` 0.23 (Rust crate, not Python package), `ndarray` 0.16
- All three crates are version-locked together (upgrade all at once)
- Use `f32` for all floating-point data, `u32` for partition indices, `i32` for counts, `i64` for boundaries
- PyO3 functions accept `PyReadonlyArray*` for input arrays and `&Bound<PyArray*>` for in-place mutation

## Testing

### Test Organization
Expand All @@ -146,6 +190,7 @@ Variable-length batch processing flattens all point clouds, processes them toget
- **`test_muvera.py`**: Core Muvera class tests (shapes, validation, reproducibility)
- **`test_reference.py`**: Validation against reference implementation (sionic-ai/muvera-py)
- **`test_real_colbert.py`**: Real-world ColBERT embedding tests using NanoBEIR fixtures
- **`test_rust_equivalence.py`**: Numerical equivalence tests between Rust kernels and Python fallbacks (skipped if Rust extension is unavailable)

### Real Data Testing

Expand Down Expand Up @@ -173,15 +218,17 @@ Output dimension: `num_repetitions * 2^num_simhash_projections * projection_dime
**`.github/workflows/test.yml`** - Continuous Integration
- Triggers: Push to main, all pull requests
- Tests across Python 3.9-3.13
- Installs Rust toolchain via `dtolnay/rust-toolchain@stable`
- Builds Rust extension via `pip install ".[dev]"` (maturin build backend)
- Runs ruff (lint + format check), mypy (type checking), pytest
- Tests example scripts

**`.github/workflows/publish.yml`** - PyPI Publishing
- Triggers: Push to main
- Checks if `v{version}` tag already exists; skips release if it does
- Runs full test suite
- Creates git tag, builds wheel/sdist, publishes to PyPI via OIDC
- Creates GitHub Release with release notes
- Builds cross-platform wheels via `PyO3/maturin-action@v1` (Linux x86_64/aarch64, macOS x86_64/aarch64, Windows x86_64)
- Builds sdist separately
- Creates git tag, publishes to PyPI via OIDC, creates GitHub Release

### Deployment Policy

Expand Down
Loading