Quantum-inspired inference framework for improving logical coherence in Large Language Model reasoning
TNAD (Tensor Network-Augmented Decoding) is a novel inference framework that leverages quantum-inspired tensor network methods to enhance the logical coherence of Large Language Model (LLM) outputs. By maintaining a Matrix Product State (MPS) representation of token sequences during generation, TNAD can detect and penalize incoherent reasoning paths in real-time, leading to more reliable and logically consistent responses.
Traditional LLM decoding methods (greedy, beam search, sampling) optimize for local token probability without considering global logical structure. TNAD introduces Fidelity-Guided Beam Search (FGBS), which balances:
- Fluency (standard LLM probability): How natural the text sounds
- Coherence (quantum-inspired fidelity score): How logically consistent the reasoning is
This is achieved through the Coherence Fidelity Score (CFS), computed from the Schmidt spectrum of the MPS representation, which quantifies the structural integrity of the generated sequence.
Standard Beam Search:
Score(S) = log P(S)
Fidelity-Guided Beam Search:
Score(S) = α · log P(S) + (1-α) · log F(S)
where:
- P(S): LLM probability (fluency)
- F(S): Coherence Fidelity Score (structural integrity)
- α ∈ [0,1]: balance parameter
The CFS is derived from quantum purity measures:
Given Schmidt values λ = [λ₁, λ₂, ..., λ_χ]:
Purity: P = Σᵢ λᵢ⁴
CFS: F = 1 / P
High F → uniform spectrum → high entanglement → coherent state
Low F → peaked spectrum → low entanglement → decoherent state
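For concreteness, here is a minimal NumPy sketch of these formulas and of the combined FGBS score. It is illustrative only, not the library's internal implementation (the packaged version is exposed as tnad.compute_cfs):

import numpy as np

def coherence_fidelity_score(schmidt_values):
    """Illustrative CFS from a Schmidt spectrum (sketch, not the library code)."""
    lam = np.asarray(schmidt_values, dtype=np.float64)
    lam = lam / np.linalg.norm(lam)   # normalize so that sum_i lambda_i^2 = 1
    purity = np.sum(lam ** 4)         # P = sum_i lambda_i^4
    return 1.0 / purity               # F = 1 / P

def fgbs_score(log_prob, schmidt_values, alpha=0.5):
    """Score(S) = alpha * log P(S) + (1 - alpha) * log F(S)."""
    return alpha * log_prob + (1.0 - alpha) * np.log(coherence_fidelity_score(schmidt_values))

# A uniform spectrum (coherent) scores higher than a peaked one (decoherent)
uniform = np.ones(16) / 4.0               # 16 equal Schmidt values, already normalized
peaked = np.array([0.99] + [0.01] * 15)
print(coherence_fidelity_score(uniform))  # 16.0 (the maximum for chi = 16)
print(coherence_fidelity_score(peaked))   # ~1.0

With bond dimension χ, F ranges from 1 (fully peaked spectrum) to χ (perfectly uniform spectrum).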
- Quantum-Inspired Coherence Scoring: Real-time structural monitoring via Matrix Product States
- Fidelity-Guided Beam Search: Novel decoding algorithm balancing fluency and coherence
- Memory Efficient: Optimized MPS implementation with bond dimension control
- GPU Accelerated: Full CUDA and Apple-Silicon MPS support with 8-bit and 4-bit quantization
- Production Ready: Comprehensive test suite, type hints, and detailed documentation
- Research Grade: Extensive experiment framework for benchmarking and analysis
- Python 3.9+
- PyTorch 2.0+
- CUDA 11.8+ (optional, for GPU acceleration)
- 16GB+ RAM recommended (8GB minimum with quantization)
# Clone repository
git clone https://github.com/ksupasate/TNAD-TensorNetwork-Decoding.git
cd TNAD-TensorNetwork-Decoding
# Install dependencies
pip install -r requirements.txt
# Install package
pip install -e .

For a guided installation with dependency checks:
bash install_dependencies.sh

Verify the installation:

python test_setup.py

Basic usage:

from transformers import AutoModelForCausalLM, AutoTokenizer
from tnad import FidelityGuidedBeamSearcher
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# Initialize FGBS
searcher = FidelityGuidedBeamSearcher(
model=model,
tokenizer=tokenizer,
beam_width=5, # Number of parallel beams
alpha=0.5, # Balance: 0=pure coherence, 1=pure LLM
bond_dim=16, # MPS bond dimension (controls coherence tracking)
)
# Generate coherent text
prompt = "If all cats are animals, and some animals can fly, can all cats fly? Let's think step by step."
result = searcher.generate(prompt, max_length=100)
print("Generated text:", result['text'])
print("Coherence score:", result['log_cfs'])
print("LLM probability:", result['log_prob'])from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# 8-bit quantization config
quant_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
quantization_config=quant_config,
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Use smaller bond dimension for memory efficiency
searcher = FidelityGuidedBeamSearcher(
model=model,
tokenizer=tokenizer,
beam_width=3, # Reduced for memory
alpha=0.5,
bond_dim=8, # Smaller bond dimension
)

# Generate with both FGBS and standard beam search
comparison = searcher.compare_with_baseline(
prompt="Solve: If x + 2 = 5, then x = ?",
max_length=100,
)
print("FGBS output:", comparison['fgbs']['text'])
print("Baseline output:", comparison['baseline']['text'])
print("CFS improvement:", comparison['cfs_comparison']['cfs_improvement'])TNAD uses YAML configuration files for reproducible experiments. See configs/default.yaml for the full configuration template.
FGBS Algorithm:
- beam_width (3-10): Number of parallel beams. Higher = better quality, slower
- alpha (0.0-1.0): Fluency vs coherence balance
  - 1.0 = pure LLM (standard beam search)
  - 0.5 = balanced (recommended)
  - 0.3 = prioritize coherence (good for reasoning)
- bond_dim (8-32): MPS bond dimension
  - Smaller: faster but limited logical tracking
  - Larger: slower but better coherence monitoring
- top_k (20-50): Number of top tokens to consider per beam
Model Configuration:
- load_in_8bit: Enable 8-bit quantization for memory efficiency
- torch_dtype: Precision (float16/bfloat16 for efficiency)
- device: Target device (auto/cuda/cpu/mps)
Generation:
- max_length: Maximum generation length (tokens)
- min_length: Minimum length before allowing EOS
- temperature: Sampling temperature (1.0 = no scaling)
Memory-Optimized (for 8GB GPU):
fgbs:
  beam_width: 3
  alpha: 0.5
  bond_dim: 8
  top_k: 30
model:
  load_in_8bit: true
  torch_dtype: "float16"
generation:
  max_length: 256

Quality-Optimized (for 24GB+ GPU):
fgbs:
  beam_width: 10
  alpha: 0.5
  bond_dim: 32
  top_k: 50
model:
  load_in_8bit: false
  torch_dtype: "bfloat16"
generation:
  max_length: 512

GSM8K is a dataset of grade school math word problems requiring multi-step reasoning.
# Run with default configuration
python experiments/run_gsm8k.py --config configs/default.yaml
# Override specific parameters
python experiments/run_gsm8k.py \
--config configs/default.yaml \
--alpha 0.5 \
--bond_dim 16 \
--beam_width 5 \
--num_examples 100
# Use custom model
python experiments/run_gsm8k.py \
--config configs/default.yaml \
--model "meta-llama/Llama-3.1-8B-Instruct"StrategyQA tests multi-hop reasoning with implicit decomposition.
python experiments/run_strategyqa.py --config configs/default.yaml

Run comprehensive ablation studies across hyperparameters:
python experiments/run_ablations.py --config configs/default.yaml

This will sweep over:
- Alpha values: [0.0, 0.3, 0.5, 0.7, 1.0]
- Bond dimensions: [4, 8, 16, 32]
- Beam widths: [1, 3, 5, 10]
python experiments/baselines.py \
--methods greedy beam_search self_consistency \
--num_examples 100

To reproduce the paper's results:

python experiments/reproduce_paper_results.py

Interactive demonstrations and tutorials are available in the notebooks/ directory:
- demo.ipynb: Quick introduction and basic usage
- tutorial_comprehensive.ipynb: In-depth tutorial covering all features
- performance_benchmark.ipynb: Performance analysis and profiling
- tnad_colab.ipynb: Google Colab compatible notebook
# Launch Jupyter
jupyter notebook notebooks/

FidelityGuidedBeamSearcher: Main FGBS implementation for LLM generation.
searcher = FidelityGuidedBeamSearcher(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
beam_width: int = 5,
alpha: float = 0.5,
bond_dim: int = 16,
top_k: int = 50,
temperature: float = 1.0,
device: Optional[str] = None,
normalize_embeddings: bool = True,
)

Methods:
- generate(prompt, max_length, min_length, return_details, show_progress): Generate text
- compare_with_baseline(prompt, max_length): Compare FGBS vs standard beam search
MPSSequence: Matrix Product State representation of token sequences.
mps = MPSSequence(
bond_dim: int,
embedding_dim: int,
device: Optional[str] = None,
normalize_embeddings: bool = True,
)

Methods:
- add_token(token_embedding): Add token to MPS chain
- get_schmidt_values(cut_position): Extract Schmidt spectrum
- copy(): Create deep copy for beam branching
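A minimal usage sketch of the MPSSequence API listed above. The top-level import path and the input to add_token (a 1-D embedding of size embedding_dim) are assumptions based on the signatures shown here, not a verified recipe:

import torch
from tnad import MPSSequence  # assumed top-level export

# Track a short sequence of token embeddings in an MPS (illustrative values only)
mps = MPSSequence(bond_dim=16, embedding_dim=64)
for _ in range(10):
    token_embedding = torch.randn(64)  # stand-in for an LLM token embedding
    mps.add_token(token_embedding)

schmidt = mps.get_schmidt_values(cut_position=5)  # Schmidt spectrum across the middle cut
branch = mps.copy()                               # independent copy for branching a beam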
compute_cfs: Compute the Coherence Fidelity Score from Schmidt values.
from tnad import compute_cfs
cfs = compute_cfs(
schmidt_values: Union[np.ndarray, torch.Tensor],
normalize: bool = True,
eps: float = 1e-10,
return_purity: bool = False,
)

from tnad.utils import (
log_normalize, # Log-space normalization
safe_divide, # Numerically stable division
normalize_schmidt_values, # Schmidt value normalization
compute_purity, # Quantum purity calculation
get_device, # Auto device selection
)

TNAD-TensorNetwork-Decoding/
├── tnad/ # Core package
│ ├── __init__.py # Package exports
│ ├── fgbs_searcher.py # FGBS implementation
│ ├── mps_manager.py # MPS representation
│ ├── coherence_score.py # CFS computation
│ └── utils.py # Utility functions
├── experiments/ # Experiment scripts
│ ├── run_gsm8k.py # GSM8K benchmark
│ ├── run_strategyqa.py # StrategyQA benchmark
│ ├── run_ablations.py # Ablation studies
│ ├── baselines.py # Baseline methods
│ ├── aggregate_results.py # Results aggregation
│ └── visualize_results.py # Plotting utilities
├── tests/ # Test suite
│ ├── test_fgbs_integration.py
│ ├── test_coherence_score.py
│ └── test_mps_manager.py
├── notebooks/ # Jupyter tutorials
│ ├── demo.ipynb
│ ├── tutorial_comprehensive.ipynb
│ └── performance_benchmark.ipynb
├── configs/ # Configuration files
│ ├── default.yaml
│ ├── memory_optimized.yaml
│ └── full_publication.yaml
├── data/ # Sample datasets
├── results/ # Experiment outputs
├── requirements.txt # Python dependencies
├── setup.py # Package setup
└── README.md # This file
TNAD includes several optimizations for memory-constrained environments:
- Quantization: 8-bit and 4-bit model loading (4-bit loading sketched below)
- Optimized MPS: Reduced memory allocations, efficient copying
- Garbage Collection: Aggressive cleanup during beam search
- Cache Management: LRU cache for Schmidt values
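The Quick Start shows 8-bit loading; 4-bit loading follows the same pattern. The sketch below uses standard BitsAndBytesConfig options from Transformers; the model name and FGBS settings are placeholders rather than a tested configuration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tnad import FidelityGuidedBeamSearcher

# 4-bit NF4 quantization via bitsandbytes (sketch; tune to your hardware)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quant_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
searcher = FidelityGuidedBeamSearcher(
    model=model,
    tokenizer=tokenizer,
    beam_width=3,   # small beam to conserve memory
    bond_dim=8,
)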
# Memory monitoring
from experiments.check_gpu_memory import monitor_memory
with monitor_memory():
    result = searcher.generate(prompt, max_length=100)

Speed optimizations:

- Batch Embeddings: Pre-compute embeddings for top-k tokens
- Efficient Matrix Operations: Optimized @ operator usage
- Caching: Schmidt value caching with configurable size
- PyTorch Optimizations: Gradient-free inference, mixed precision
# For maximum speed (sacrifices some coherence tracking)
searcher = FidelityGuidedBeamSearcher(
model=model,
tokenizer=tokenizer,
beam_width=3, # Smaller beam
bond_dim=8, # Smaller bond dimension
top_k=20, # Fewer candidates
)

Run the full test suite:
# All tests with coverage
pytest tests/ --cov=tnad --cov-report=html
# Specific test files
pytest tests/test_fgbs_integration.py -v
pytest tests/test_coherence_score.py -v
pytest tests/test_mps_manager.py -v

The test suite includes:
- Unit tests for all core components
- Integration tests for FGBS pipeline
- Numerical stability tests
- Edge case handling
- Memory leak detection
Out of Memory (OOM):
# Solution 1: Enable 8-bit quantization
load_in_8bit: true
# Solution 2: Reduce beam width and bond dimension
beam_width: 3
bond_dim: 8
# Solution 3: Reduce generation length
max_length: 128

Slow Generation:
# Reduce top_k to consider fewer tokens
top_k: 20
# Reduce beam width
beam_width: 3
# Use smaller model
model: "microsoft/phi-2"MPS (Apple Silicon) Errors:
# Some operations may not be supported on MPS
# The code automatically falls back to CPU for unsupported ops
# Force CPU for stability:
device: "cpu"Quantization Not Working:
# Ensure bitsandbytes is installed
pip install bitsandbytes>=0.41.0
# Check GPU compatibility (CUDA 11.1+)
python -c "import torch; print(torch.cuda.is_available())"Enable detailed logging:
from tnad.utils import setup_logger
setup_logger(log_level="DEBUG", log_file="debug.log")

If you use TNAD in your research, please cite:
@software{tnad2025,
title={TNAD: Tensor Network-Augmented Decoding for Coherent LLM Reasoning},
author={Supasate Vorathammathorn},
year={2025},
url={https://github.com/ksupasate/TNAD-TensorNetwork-Decoding}
}

This project is licensed under the MIT License - see the LICENSE file for details.
- Inspired by tensor network methods from quantum many-body physics
- Built on HuggingFace Transformers and PyTorch
- Uses Matrix Product State (MPS) formalism from quantum information theory
Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
# Install development dependencies
pip install -e ".[dev]"
# Run tests before committing
pytest tests/ --cov=tnad
# Format code
black tnad/ tests/ experiments/
# Type checking
mypy tnad/

For questions and feedback:
- Open an issue on GitHub
- Email: ksupasate@gmail.com
Full support for T5, BART, mBART, and mT5 architectures:
from transformers import T5ForConditionalGeneration, T5Tokenizer
from tnad import EncoderDecoderFGBS
# Load T5 model
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
# Create encoder-decoder FGBS
searcher = EncoderDecoderFGBS(
model=model,
tokenizer=tokenizer,
beam_width=5,
alpha=0.5,
bond_dim=16,
)
# Summarization
result = searcher.generate(
"summarize: Large language models have shown remarkable capabilities...",
max_length=100
)
print(result['text'])

Scale FGBS across multiple GPUs for higher throughput:
from tnad import DistributedFGBS, setup_distributed, cleanup_distributed
# Setup distributed environment
rank, world_size = setup_distributed()
# Load model on each GPU
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
device_map=f"cuda:{rank}"
)
# Create distributed searcher
searcher = DistributedFGBS(
model=model,
tokenizer=tokenizer,
beam_width=16, # Total beams across all GPUs
alpha=0.5,
bond_dim=16,
rank=rank,
world_size=world_size,
)
# Generate (same prompt on all GPUs)
result = searcher.generate("Solve: x + 2 = 5", max_length=100)
# Only rank 0 gets final result
if rank == 0:
    print(result['text'])
cleanup_distributed()

Performance: Near-linear scaling up to 8 GPUs.
Production-ready deployment with vLLM for 5-10x higher throughput:
from tnad import vLLMFGBS
# Initialize vLLM with FGBS
fgbs = vLLMFGBS(
model_name="meta-llama/Llama-3.1-8B-Instruct",
alpha=0.5,
bond_dim=16,
tensor_parallel_size=2, # Use 2 GPUs
gpu_memory_utilization=0.9,
)
# Single generation
result = fgbs.generate("What is quantum computing?", max_tokens=100)
print(result['text'])
print(f"CFS: {result['cfs']:.4f}")
# Batch generation (high throughput)
prompts = ["Question 1", "Question 2", "Question 3"]
results = fgbs.generate_batch(prompts, max_tokens=100)

Note: vLLM requires Linux and CUDA. Install with pip install vllm.
Real-time token-by-token generation for interactive applications:
import math

from tnad import StreamingFGBS
searcher = StreamingFGBS(
model=model,
tokenizer=tokenizer,
beam_width=5,
alpha=0.5,
bond_dim=16,
)
# Synchronous streaming
for token_info in searcher.generate_stream("Once upon a time", max_length=100):
    print(token_info.token, end='', flush=True)
    print(f" [CFS: {math.exp(token_info.log_cfs):.2f}]", end='')
# Async streaming (for web servers)
async for token in searcher.generate_stream_async(prompt, max_length=100):
    await websocket.send_text(token.token)
# With callbacks
def on_token(token):
print(f"Generated: {token.token} (CFS={math.exp(token.log_cfs):.2f})")
result = searcher.generate_with_callbacks(
prompt="Explain quantum computing",
on_token=on_token,
coherence_threshold=0.5, # Early stopping if CFS drops
)

Train models to generate intrinsically more coherent text:
from tnad import CoherenceFilteredDataset, CoherenceRewardTrainer
from transformers import TrainingArguments
# Filter dataset by coherence
filtered_dataset = CoherenceFilteredDataset(
data=train_data,
tokenizer=tokenizer,
embedding_layer=model.get_input_embeddings(),
min_cfs=0.8, # Coherence threshold: keep only sufficiently coherent examples
bond_dim=16,
)
# Training arguments
training_args = TrainingArguments(
output_dir="./coherence_finetuned",
num_train_epochs=3,
per_device_train_batch_size=4,
)
# Create trainer with coherence weighting
trainer = CoherenceRewardTrainer(
model=model,
train_dataset=filtered_dataset,
args=training_args,
coherence_weight=0.5, # 50% weight on coherence
)
trainer.train()

Reinforcement Learning:
from tnad import CoherenceRLTrainer
# RL trainer with PPO
rl_trainer = CoherenceRLTrainer(
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
bond_dim=16,
learning_rate=1e-5,
)
# Train with coherence rewards
rl_trainer.train(prompts=training_prompts, num_epochs=3)

Interactive web interface for testing FGBS:
# Launch demo
python -m tnad.web_demo --model meta-llama/Llama-3.1-8B-Instruct --share
# With authentication
python -m tnad.web_demo --model gpt2 --auth username:password
# Custom port
python -m tnad.web_demo --model gpt2 --server-port 7860

Features:
- Real-time parameter tuning (α, χ, beam width)
- Coherence trajectory visualization
- Baseline comparison
- Export results to JSON/CSV
Production REST API with WebSocket streaming:
# Start server
python -m tnad.api_server --model meta-llama/Llama-3.1-8B-Instruct --port 8000
# With 8-bit quantization
python -m tnad.api_server --model meta-llama/Llama-3.1-8B-Instruct --load-in-8bit
# Multiple workers
python -m tnad.api_server --model gpt2 --workers 4

API Endpoints:
# Single generation
curl -X POST "http://localhost:8000/generate" \
-H "Content-Type: application/json" \
-d '{
"prompt": "Solve: x + 2 = 5",
"max_length": 100,
"alpha": 0.5,
"return_details": true
}'
# Batch generation
curl -X POST "http://localhost:8000/generate/batch" \
-H "Content-Type: application/json" \
-d '{
"prompts": ["Question 1", "Question 2"],
"max_length": 100
}'
# Health check
curl "http://localhost:8000/health"
# Metrics
curl "http://localhost:8000/metrics"Server-Sent Events (SSE):
const eventSource = new EventSource(
'http://localhost:8000/stream?prompt=Hello&max_length=50'
);
eventSource.onmessage = (event) => {
const data = JSON.parse(event.data);
console.log(data.token);
};

WebSocket:
const ws = new WebSocket('ws://localhost:8000/ws/generate');
ws.onopen = () => {
ws.send(JSON.stringify({prompt: 'Hello', max_length: 50}));
};
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
console.log(data.token);
};

See examples/advanced_features_demo.py for comprehensive examples of all features:
# Run all examples
python examples/advanced_features_demo.py --example all
# Run specific example
python examples/advanced_features_demo.py --example streaming
python examples/advanced_features_demo.py --example encoder-decoder
python examples/advanced_features_demo.py --example vllm

Roadmap:

- Support for vision-language models (CLIP, LLaVA)
- Beam search with constrained decoding
- Multi-modal coherence scoring
- Federated learning support
- Model compression techniques
Built with ❤️ for more reliable and coherent AI reasoning