
TNAD: Tensor Network-Augmented Decoding

Quantum-inspired inference framework for improving logical coherence in Large Language Model reasoning

Python 3.9+ · PyTorch · License: MIT


Overview

TNAD (Tensor Network-Augmented Decoding) is a novel inference framework that leverages quantum-inspired tensor network methods to enhance the logical coherence of Large Language Model (LLM) outputs. By maintaining a Matrix Product State (MPS) representation of token sequences during generation, TNAD can detect and penalize incoherent reasoning paths in real-time, leading to more reliable and logically consistent responses.

Key Innovation

Traditional LLM decoding methods (greedy, beam search, sampling) optimize for local token probability without considering global logical structure. TNAD introduces Fidelity-Guided Beam Search (FGBS), which balances:

  • Fluency (standard LLM probability): How natural the text sounds
  • Coherence (quantum-inspired fidelity score): How logically consistent the reasoning is

This is achieved through the Coherence Fidelity Score (CFS), computed from the Schmidt spectrum of the MPS representation, which quantifies structural integrity of the generated sequence.

Mathematical Foundation

Standard Beam Search:

Score(S) = log P(S)

Fidelity-Guided Beam Search:

Score(S) = α · log P(S) + (1-α) · log F(S)

where:

  • P(S): LLM probability (fluency)
  • F(S): Coherence Fidelity Score (structural integrity)
  • α ∈ [0,1]: Balance parameter

The CFS is derived from quantum purity measures:

Given Schmidt values λ = [λ₁, λ₂, ..., λ_χ]:
Purity: P = Σᵢ λᵢ⁴
CFS: F = 1 / P

High F → uniform spectrum → high entanglement → coherent state
Low F → peaked spectrum → low entanglement → decoherent state
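
For concreteness, here is a minimal NumPy sketch of these quantities (illustrative only; the packaged compute_cfs documented below handles normalization and numerical stability):

import numpy as np

def cfs(schmidt_values):
    """Coherence Fidelity Score F = 1 / P, with purity P = sum(lambda_i^4)."""
    lam = np.asarray(schmidt_values, dtype=float)
    lam = lam / np.linalg.norm(lam)      # enforce sum(lambda_i^2) = 1
    return 1.0 / np.sum(lam ** 4)

print(cfs([0.5, 0.5, 0.5, 0.5]))       # uniform spectrum -> 4.0 (maximum for chi = 4)
print(cfs([0.99, 0.1, 0.05, 0.01]))    # peaked spectrum  -> ~1.03

# FGBS score of a candidate sequence S (toy log-probability value)
alpha, log_p = 0.5, -12.3
score = alpha * log_p + (1 - alpha) * np.log(cfs([0.5, 0.5, 0.5, 0.5]))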

Features

  • Quantum-Inspired Coherence Scoring: Real-time structural monitoring via Matrix Product States
  • Fidelity-Guided Beam Search: Novel decoding algorithm balancing fluency and coherence
  • Memory Efficient: Optimized MPS implementation with bond dimension control
  • GPU Accelerated: CUDA and Apple-silicon (MPS backend) support, with 8-bit and 4-bit quantization
  • Production Ready: Comprehensive test suite, type hints, and detailed documentation
  • Research Grade: Extensive experiment framework for benchmarking and analysis

Installation

Requirements

  • Python 3.9+
  • PyTorch 2.0+
  • CUDA 11.8+ (optional, for GPU acceleration)
  • 16GB+ RAM recommended (8GB minimum with quantization)

Quick Install

# Clone repository
git clone https://github.com/ksupasate/TNAD-TensorNetwork-Decoding.git
cd TNAD-TensorNetwork-Decoding

# Install dependencies
pip install -r requirements.txt

# Install package
pip install -e .

Dependencies Installation Script

For a guided installation with dependency checks:

bash install_dependencies.sh

Verify Installation

python test_setup.py

Quick Start

Basic Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from tnad import FidelityGuidedBeamSearcher

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# Initialize FGBS
searcher = FidelityGuidedBeamSearcher(
    model=model,
    tokenizer=tokenizer,
    beam_width=5,      # Number of parallel beams
    alpha=0.5,         # Balance: 0=pure coherence, 1=pure LLM
    bond_dim=16,       # MPS bond dimension (controls coherence tracking)
)

# Generate coherent text
prompt = "If all cats are animals, and some animals can fly, can all cats fly? Let's think step by step."
result = searcher.generate(prompt, max_length=100)

print("Generated text:", result['text'])
print("Coherence score:", result['log_cfs'])
print("LLM probability:", result['log_prob'])

Memory-Efficient Usage (8-bit Quantization)

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 8-bit quantization config
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Use smaller bond dimension for memory efficiency
searcher = FidelityGuidedBeamSearcher(
    model=model,
    tokenizer=tokenizer,
    beam_width=3,      # Reduced for memory
    alpha=0.5,
    bond_dim=8,        # Smaller bond dimension
)

Comparing with Baseline

# Generate with both FGBS and standard beam search
comparison = searcher.compare_with_baseline(
    prompt="Solve: If x + 2 = 5, then x = ?",
    max_length=100,
)

print("FGBS output:", comparison['fgbs']['text'])
print("Baseline output:", comparison['baseline']['text'])
print("CFS improvement:", comparison['cfs_comparison']['cfs_improvement'])

Configuration

TNAD uses YAML configuration files for reproducible experiments. See configs/default.yaml for the full configuration template.

Key Parameters

FGBS Algorithm:

  • beam_width (3-10): Number of parallel beams. Higher = better quality, slower
  • alpha (0.0-1.0): Fluency vs coherence balance
    • 1.0 = pure LLM (standard beam search)
    • 0.5 = balanced (recommended)
    • 0.3 = prioritize coherence (good for reasoning)
  • bond_dim (8-32): MPS bond dimension
    • Smaller: faster but limited logical tracking
    • Larger: slower but better coherence monitoring
  • top_k (20-50): Number of top tokens to consider per beam

Model Configuration:

  • load_in_8bit: Enable 8-bit quantization for memory efficiency
  • torch_dtype: Precision (float16/bfloat16 for efficiency)
  • device: Target device (auto/cuda/cpu/mps)

Generation:

  • max_length: Maximum generation length (tokens)
  • min_length: Minimum length before allowing EOS
  • temperature: Sampling temperature (1.0 = no scaling)

Example Configurations

Memory-Optimized (for 8GB GPU):

fgbs:
  beam_width: 3
  alpha: 0.5
  bond_dim: 8
  top_k: 30

model:
  load_in_8bit: true
  torch_dtype: "float16"

generation:
  max_length: 256

Quality-Optimized (for 24GB+ GPU):

fgbs:
  beam_width: 10
  alpha: 0.5
  bond_dim: 32
  top_k: 50

model:
  load_in_8bit: false
  torch_dtype: "bfloat16"

generation:
  max_length: 512

Experiments

Running GSM8K Benchmark

GSM8K is a dataset of grade school math word problems requiring multi-step reasoning.

# Run with default configuration
python experiments/run_gsm8k.py --config configs/default.yaml

# Override specific parameters
python experiments/run_gsm8k.py \
    --config configs/default.yaml \
    --alpha 0.5 \
    --bond_dim 16 \
    --beam_width 5 \
    --num_examples 100

# Use custom model
python experiments/run_gsm8k.py \
    --config configs/default.yaml \
    --model "meta-llama/Llama-3.1-8B-Instruct"

Running StrategyQA Benchmark

StrategyQA tests multi-hop reasoning in which the decomposition into intermediate steps is implicit in the question.

python experiments/run_strategyqa.py --config configs/default.yaml

Ablation Studies

Run comprehensive ablation studies across hyperparameters:

python experiments/run_ablations.py --config configs/default.yaml

This will sweep over:

  • Alpha values: [0.0, 0.3, 0.5, 0.7, 1.0]
  • Bond dimensions: [4, 8, 16, 32]
  • Beam widths: [1, 3, 5, 10]
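
For a quick manual sweep without the full script, a sketch along these lines works, reusing the model, tokenizer, and FidelityGuidedBeamSearcher setup from the Quick Start (the prompt and max_length here are placeholders):

from itertools import product

from tnad import FidelityGuidedBeamSearcher

results = []
for alpha, bond_dim in product([0.0, 0.3, 0.5, 0.7, 1.0], [4, 8, 16, 32]):
    searcher = FidelityGuidedBeamSearcher(
        model=model, tokenizer=tokenizer,
        beam_width=5, alpha=alpha, bond_dim=bond_dim,
    )
    out = searcher.generate("Solve: If x + 2 = 5, then x = ?", max_length=64)
    results.append({"alpha": alpha, "bond_dim": bond_dim,
                    "log_cfs": out["log_cfs"], "log_prob": out["log_prob"]})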

Baseline Comparisons

python experiments/baselines.py \
    --methods greedy beam_search self_consistency \
    --num_examples 100

Reproducing Paper Results

python experiments/reproduce_paper_results.py

Jupyter Notebooks

Interactive demonstrations and tutorials are available in the notebooks/ directory:

  • demo.ipynb: Quick introduction and basic usage
  • tutorial_comprehensive.ipynb: In-depth tutorial covering all features
  • performance_benchmark.ipynb: Performance analysis and profiling
  • tnad_colab.ipynb: Google Colab compatible notebook

# Launch Jupyter
jupyter notebook notebooks/

API Reference

Core Components

FidelityGuidedBeamSearcher

Main FGBS implementation for LLM generation.

searcher = FidelityGuidedBeamSearcher(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    beam_width: int = 5,
    alpha: float = 0.5,
    bond_dim: int = 16,
    top_k: int = 50,
    temperature: float = 1.0,
    device: Optional[str] = None,
    normalize_embeddings: bool = True,
)

Methods:

  • generate(prompt, max_length, min_length, return_details, show_progress): Generate text
  • compare_with_baseline(prompt, max_length): Compare FGBS vs standard beam search

MPSSequence

Matrix Product State representation of token sequences.

mps = MPSSequence(
    bond_dim: int,
    embedding_dim: int,
    device: Optional[str] = None,
    normalize_embeddings: bool = True,
)

Methods:

  • add_token(token_embedding): Add token to MPS chain
  • get_schmidt_values(cut_position): Extract Schmidt spectrum
  • copy(): Create deep copy for beam branching
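
A minimal usage sketch, assuming the constructor and methods documented here (random vectors stand in for real token embeddings, and embedding_dim is an arbitrary example value):

import torch
from tnad import MPSSequence, compute_cfs

mps = MPSSequence(bond_dim=16, embedding_dim=4096)
for _ in range(10):
    mps.add_token(torch.randn(4096))               # stand-in for a token embedding

schmidt = mps.get_schmidt_values(cut_position=5)   # spectrum at a mid-chain cut
print(compute_cfs(schmidt))

branch = mps.copy()                                # deep copy before branching a beam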

compute_cfs()

Compute Coherence Fidelity Score from Schmidt values.

from tnad import compute_cfs

cfs = compute_cfs(
    schmidt_values: Union[np.ndarray, torch.Tensor],
    normalize: bool = True,
    eps: float = 1e-10,
    return_purity: bool = False,
)
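
For example, a uniform spectrum over χ = 16 values gives F = 16, while a peaked spectrum gives F close to 1 (a sketch assuming return_purity=True yields a (cfs, purity) pair, per the signature above):

import numpy as np
from tnad import compute_cfs

f_uniform, purity = compute_cfs(np.full(16, 0.25), return_purity=True)   # F = 16
f_peaked = compute_cfs(np.array([0.99, 0.1, 0.05, 0.01]))                # F ~ 1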

Utility Functions

from tnad.utils import (
    log_normalize,      # Log-space normalization
    safe_divide,        # Numerically stable division
    normalize_schmidt_values,  # Schmidt value normalization
    compute_purity,     # Quantum purity calculation
    get_device,         # Auto device selection
)

Project Structure

TNAD-TensorNetwork-Decoding/
├── tnad/                       # Core package
│   ├── __init__.py            # Package exports
│   ├── fgbs_searcher.py       # FGBS implementation
│   ├── mps_manager.py         # MPS representation
│   ├── coherence_score.py     # CFS computation
│   └── utils.py               # Utility functions
├── experiments/               # Experiment scripts
│   ├── run_gsm8k.py          # GSM8K benchmark
│   ├── run_strategyqa.py     # StrategyQA benchmark
│   ├── run_ablations.py      # Ablation studies
│   ├── baselines.py          # Baseline methods
│   ├── aggregate_results.py  # Results aggregation
│   └── visualize_results.py  # Plotting utilities
├── tests/                     # Test suite
│   ├── test_fgbs_integration.py
│   ├── test_coherence_score.py
│   └── test_mps_manager.py
├── notebooks/                 # Jupyter tutorials
│   ├── demo.ipynb
│   ├── tutorial_comprehensive.ipynb
│   └── performance_benchmark.ipynb
├── configs/                   # Configuration files
│   ├── default.yaml
│   ├── memory_optimized.yaml
│   └── full_publication.yaml
├── data/                      # Sample datasets
├── results/                   # Experiment outputs
├── requirements.txt           # Python dependencies
├── setup.py                   # Package setup
└── README.md                  # This file

Performance Optimization

Memory Efficiency

TNAD includes several optimizations for memory-constrained environments:

  1. Quantization: 8-bit and 4-bit model loading
  2. Optimized MPS: Reduced memory allocations, efficient copying
  3. Garbage Collection: Aggressive cleanup during beam search
  4. Cache Management: LRU cache for Schmidt values

# Memory monitoring
from experiments.check_gpu_memory import monitor_memory

with monitor_memory():
    result = searcher.generate(prompt, max_length=100)

Speed Optimization

  1. Batch Embeddings: Pre-compute embeddings for top-k tokens
  2. Efficient Matrix Operations: Optimized @ operator usage
  3. Caching: Schmidt value caching with configurable size
  4. PyTorch Optimizations: Gradient-free inference, mixed precision

# For maximum speed (sacrifices some coherence tracking)
searcher = FidelityGuidedBeamSearcher(
    model=model,
    tokenizer=tokenizer,
    beam_width=3,      # Smaller beam
    bond_dim=8,        # Smaller bond dimension
    top_k=20,          # Fewer candidates
)

Testing

Run the full test suite:

# All tests with coverage
pytest tests/ --cov=tnad --cov-report=html

# Specific test files
pytest tests/test_fgbs_integration.py -v
pytest tests/test_coherence_score.py -v
pytest tests/test_mps_manager.py -v

Test Coverage

The test suite includes:

  • Unit tests for all core components
  • Integration tests for FGBS pipeline
  • Numerical stability tests
  • Edge case handling
  • Memory leak detection
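
As an illustration of the numerical checks (a sketch, not the repository's actual test code): a uniform Schmidt spectrum of size χ should yield F = χ exactly.

import numpy as np
import pytest

from tnad import compute_cfs

@pytest.mark.parametrize("chi", [2, 4, 8, 16])
def test_uniform_spectrum_cfs_equals_chi(chi):
    lam = np.full(chi, 1.0 / np.sqrt(chi))   # uniform spectrum, sum(lambda_i^2) = 1
    assert compute_cfs(lam) == pytest.approx(chi)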

Troubleshooting

Common Issues

Out of Memory (OOM):

# Solution 1: Enable 8-bit quantization
load_in_8bit: true

# Solution 2: Reduce beam width and bond dimension
beam_width: 3
bond_dim: 8

# Solution 3: Reduce generation length
max_length: 128

Slow Generation:

# Reduce top_k to consider fewer tokens
top_k: 20

# Reduce beam width
beam_width: 3

# Use smaller model
model: "microsoft/phi-2"

MPS (Apple Silicon) Errors:

# Some operations may not be supported on MPS
# The code automatically falls back to CPU for unsupported ops

# Force CPU for stability:
device: "cpu"

Quantization Not Working:

# Ensure bitsandbytes is installed
pip install "bitsandbytes>=0.41.0"

# Check GPU compatibility (CUDA 11.1+)
python -c "import torch; print(torch.cuda.is_available())"

Debug Mode

Enable detailed logging:

from tnad.utils import setup_logger

setup_logger(log_level="DEBUG", log_file="debug.log")

Citation

If you use TNAD in your research, please cite:

@software{tnad2025,
  title={TNAD: Tensor Network-Augmented Decoding for Coherent LLM Reasoning},
  author={Supasate Vorathammathorn},
  year={2025},
  url={https://github.com/ksupasate/TNAD-TensorNetwork-Decoding}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.


Acknowledgments

  • Inspired by tensor network methods from quantum many-body physics
  • Built on HuggingFace Transformers and PyTorch
  • Uses Matrix Product State (MPS) formalism from quantum information theory

Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

Development Setup

# Install development dependencies
pip install -e ".[dev]"

# Run tests before committing
pytest tests/ --cov=tnad

# Format code
black tnad/ tests/ experiments/

# Type checking
mypy tnad/

Contact

For questions and feedback:


Advanced Features (v1.0)

Encoder-Decoder Models

Full support for T5, BART, mBART, and mT5 architectures:

from transformers import T5ForConditionalGeneration, T5Tokenizer
from tnad import EncoderDecoderFGBS

# Load T5 model
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Create encoder-decoder FGBS
searcher = EncoderDecoderFGBS(
    model=model,
    tokenizer=tokenizer,
    beam_width=5,
    alpha=0.5,
    bond_dim=16,
)

# Summarization
result = searcher.generate(
    "summarize: Large language models have shown remarkable capabilities...",
    max_length=100
)
print(result['text'])

Multi-GPU Distributed Beam Search

Scale FGBS across multiple GPUs for higher throughput:

from transformers import AutoModelForCausalLM, AutoTokenizer

from tnad import DistributedFGBS, setup_distributed, cleanup_distributed

# Setup distributed environment
rank, world_size = setup_distributed()

# Load model and tokenizer on each GPU
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    device_map=f"cuda:{rank}"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Create distributed searcher
searcher = DistributedFGBS(
    model=model,
    tokenizer=tokenizer,
    beam_width=16,  # Total beams across all GPUs
    alpha=0.5,
    bond_dim=16,
    rank=rank,
    world_size=world_size,
)

# Generate (same prompt on all GPUs)
result = searcher.generate("Solve: x + 2 = 5", max_length=100)

# Only rank 0 gets final result
if rank == 0:
    print(result['text'])

cleanup_distributed()

Performance: Near-linear scaling up to 8 GPUs.

vLLM Integration

Production-ready deployment with vLLM for 5-10x higher throughput:

from tnad import vLLMFGBS

# Initialize vLLM with FGBS
fgbs = vLLMFGBS(
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    alpha=0.5,
    bond_dim=16,
    tensor_parallel_size=2,  # Use 2 GPUs
    gpu_memory_utilization=0.9,
)

# Single generation
result = fgbs.generate("What is quantum computing?", max_tokens=100)
print(result['text'])
print(f"CFS: {result['cfs']:.4f}")

# Batch generation (high throughput)
prompts = ["Question 1", "Question 2", "Question 3"]
results = fgbs.generate_batch(prompts, max_tokens=100)

Note: vLLM requires Linux and CUDA. Install with pip install vllm.

Streaming Generation

Real-time token-by-token generation for interactive applications:

import math

from tnad import StreamingFGBS

searcher = StreamingFGBS(
    model=model,
    tokenizer=tokenizer,
    beam_width=5,
    alpha=0.5,
    bond_dim=16,
)

# Synchronous streaming
for token_info in searcher.generate_stream("Once upon a time", max_length=100):
    print(token_info.token, end='', flush=True)
    print(f" [CFS: {math.exp(token_info.log_cfs):.2f}]", end='')

# Async streaming (for web servers; assumes an open `websocket` connection)
async for token in searcher.generate_stream_async(prompt, max_length=100):
    await websocket.send_text(token.token)

# With callbacks
def on_token(token):
    print(f"Generated: {token.token} (CFS={math.exp(token.log_cfs):.2f})")

result = searcher.generate_with_callbacks(
    prompt="Explain quantum computing",
    on_token=on_token,
    coherence_threshold=0.5,  # Early stopping if CFS drops
)

Fine-Tuning with Coherence Rewards

Train models to generate intrinsically more coherent text:

from tnad import CoherenceFilteredDataset, CoherenceRewardTrainer
from transformers import TrainingArguments

# Filter dataset by coherence
filtered_dataset = CoherenceFilteredDataset(
    data=train_data,
    tokenizer=tokenizer,
    embedding_layer=model.get_input_embeddings(),
    min_cfs=0.8,  # Keep only examples with CFS ≥ 0.8
    bond_dim=16,
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./coherence_finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
)

# Create trainer with coherence weighting
trainer = CoherenceRewardTrainer(
    model=model,
    train_dataset=filtered_dataset,
    args=training_args,
    coherence_weight=0.5,  # 50% weight on coherence
)

trainer.train()

Reinforcement Learning:

from tnad import CoherenceRLTrainer

# RL trainer with PPO
rl_trainer = CoherenceRLTrainer(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    bond_dim=16,
    learning_rate=1e-5,
)

# Train with coherence rewards
rl_trainer.train(prompts=training_prompts, num_epochs=3)

Web Demo (Gradio)

Interactive web interface for testing FGBS:

# Launch demo
python -m tnad.web_demo --model meta-llama/Llama-3.1-8B-Instruct --share

# With authentication
python -m tnad.web_demo --model gpt2 --auth username:password

# Custom port
python -m tnad.web_demo --model gpt2 --server-port 7860

Features:

  • Real-time parameter tuning (α, χ, beam width)
  • Coherence trajectory visualization
  • Baseline comparison
  • Export results to JSON/CSV

API Server (FastAPI)

Production REST API with WebSocket streaming:

# Start server
python -m tnad.api_server --model meta-llama/Llama-3.1-8B-Instruct --port 8000

# With 8-bit quantization
python -m tnad.api_server --model meta-llama/Llama-3.1-8B-Instruct --load-in-8bit

# Multiple workers
python -m tnad.api_server --model gpt2 --workers 4

API Endpoints:

# Single generation
curl -X POST "http://localhost:8000/generate" \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "Solve: x + 2 = 5",
    "max_length": 100,
    "alpha": 0.5,
    "return_details": true
  }'

# Batch generation
curl -X POST "http://localhost:8000/generate/batch" \
  -H "Content-Type: application/json" \
  -d '{
    "prompts": ["Question 1", "Question 2"],
    "max_length": 100
  }'

# Health check
curl "http://localhost:8000/health"

# Metrics
curl "http://localhost:8000/metrics"
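
The same endpoints can be called from Python with the requests library (a sketch mirroring the curl examples above):

import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Solve: x + 2 = 5", "max_length": 100,
          "alpha": 0.5, "return_details": True},
)
resp.raise_for_status()
print(resp.json())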

Server-Sent Events (SSE):

const eventSource = new EventSource(
  'http://localhost:8000/stream?prompt=Hello&max_length=50'
);
eventSource.onmessage = (event) => {
  const data = JSON.parse(event.data);
  console.log(data.token);
};

WebSocket:

const ws = new WebSocket('ws://localhost:8000/ws/generate');
ws.onopen = () => {
  ws.send(JSON.stringify({prompt: 'Hello', max_length: 50}));
};
ws.onmessage = (event) => {
  const data = JSON.parse(event.data);
  console.log(data.token);
};

Complete Examples

See examples/advanced_features_demo.py for comprehensive examples of all features:

# Run all examples
python examples/advanced_features_demo.py --example all

# Run specific example
python examples/advanced_features_demo.py --example streaming
python examples/advanced_features_demo.py --example encoder-decoder
python examples/advanced_features_demo.py --example vllm

Roadmap (v2.0)

  • Support for vision-language models (CLIP, LLaVA)
  • Beam search with constrained decoding
  • Multi-modal coherence scoring
  • Federated learning support
  • Model compression techniques

Built with ❤️ for more reliable and coherent AI reasoning
