Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion apps/base_rag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from typing import Any, Callable

import dotenv
from leann.api import LeannBuilder, LeannChat
Expand Down Expand Up @@ -55,10 +55,12 @@ def __init__(
name: str,
description: str,
default_index_name: str,
example_queries: list[str] | None = None,
):
self.name = name
self.description = description
self.default_index_name = default_index_name
self.example_queries = example_queries or []
self.parser = self._create_parser()

def _create_parser(self) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -282,6 +284,50 @@ def get_llm_config(self, args) -> dict[str, Any]:

return config

def _foreach_source(
self,
sources: list,
args,
load: Callable[[Any, int], list | None],
*,
source_label: str = "source",
start_total: int = 0,
) -> tuple[list, int]:
"""Process sources with max_items tracking and error handling.

Args:
sources: List of source paths/identifiers to iterate.
args: Parsed argparse namespace (must have ``max_items``).
load: Callable ``(source, max_count) -> list | None`` that loads
documents for a single source. Return None/empty to skip.
source_label: Label used in progress messages.
start_total: Starting count of already-processed documents.

Returns:
``(all_documents, total_processed)``.
"""
all_docs = []
total = start_total
for i, source in enumerate(sources):
print(f"\nProcessing {source_label} {i + 1}/{len(sources)}: {source}")
try:
max_count = -1
if args.max_items > 0:
remaining = args.max_items - total
if remaining <= 0:
print(f"Reached max_items limit ({args.max_items})")
break
max_count = remaining
docs = load(source, max_count)
if docs:
all_docs.extend(docs)
total += len(docs)
print(f"Processed {len(docs)} items from this {source_label}")
except Exception as e:
print(f"Error processing {source}: {e}")
continue
return all_docs, total

async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
"""Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys)."""
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
Expand Down Expand Up @@ -411,3 +457,25 @@ async def run(self):
await self.run_single_query(args, index_path, args.query)
else:
await self.run_interactive_chat(args, index_path)

def _print_header(self):
"""Print a header with name, example queries, and usage hint.

Override in subclasses to add platform warnings or extra help text.
"""
print(f"\n{self.name} RAG Example")
print("=" * 50)
if self.example_queries:
print("\nExample queries you can try:")
for q in self.example_queries:
print(f"- '{q}'")
print("\nOr run without --query for interactive mode\n")

@classmethod
def main(cls):
    """Script entry point: build the app, show its header, run it.

    Instantiates ``cls`` with no arguments, prints the header, and then
    drives the app's ``run()`` coroutine to completion on a new event loop.
    """
    from asyncio import run as run_event_loop

    instance = cls()
    instance._print_header()
    run_event_loop(instance.run())
62 changes: 19 additions & 43 deletions apps/browser_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@ def __init__(self):
name="Browser History",
description="Process and query Chrome browser history with LEANN",
default_index_name="google_history_index",
example_queries=[
"What websites did I visit about machine learning?",
"Find my search history about programming",
"What YouTube videos did I watch recently?",
"Show me websites about travel planning",
],
)

def _print_header(self):
    """Print the inherited header, then a reminder to close Chrome first."""
    super()._print_header()
    # end="\n\n" produces the same blank line as a trailing "\n" in the text.
    print("Note: Make sure Chrome is closed before running", end="\n\n")

def _add_specific_arguments(self, parser):
"""Add browser-specific arguments."""
browser_group = parser.add_argument_group("Browser Parameters")
Expand Down Expand Up @@ -111,35 +121,14 @@ async def load_data(self, args) -> list[dict[str, Any]]:
reader = ChromeHistoryReader()

# Process each profile
all_documents = []
total_processed = 0

for i, profile_dir in enumerate(profile_dirs):
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")

try:
# Apply max_items limit per profile
max_per_profile = -1
if args.max_items > 0:
remaining = args.max_items - total_processed
if remaining <= 0:
break
max_per_profile = remaining

# Load history
documents = reader.load_data(
chrome_profile_path=str(profile_dir),
max_count=max_per_profile,
)

if documents:
all_documents.extend(documents)
total_processed += len(documents)
print(f"Processed {len(documents)} history entries from this profile")

except Exception as e:
print(f"Error processing {profile_dir}: {e}")
continue
all_documents, _ = self._foreach_source(
profile_dirs,
args,
load=lambda src, mc: reader.load_data(
chrome_profile_path=str(src), max_count=mc
),
source_label="profile",
)

if not all_documents:
print("No browser history found to process!")
Expand All @@ -156,17 +145,4 @@ async def load_data(self, args) -> list[dict[str, Any]]:


if __name__ == "__main__":
import asyncio

# Example queries for browser history RAG
print("\n🌐 Browser History RAG Example")
print("=" * 50)
print("\nExample queries you can try:")
print("- 'What websites did I visit about machine learning?'")
print("- 'Find my search history about programming'")
print("- 'What YouTube videos did I watch recently?'")
print("- 'Show me websites about travel planning'")
print("\nNote: Make sure Chrome is closed before running\n")

rag = BrowserRAG()
asyncio.run(rag.run())
BrowserRAG.main()
134 changes: 134 additions & 0 deletions apps/chat_export_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Shared base for chat export RAG apps (ChatGPT, Claude, etc.).
Unifies: find export files → load with reader → chunk → index.
"""
import sys
from pathlib import Path
from typing import Any, Callable

sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks


class ChatExportRAG(BaseRAGExample):
    """Generic RAG app for chat export data (ChatGPT, Claude, etc.).

    The pipeline is: find export files -> load them with a reader -> split
    the conversations into text chunks, which ``load_data`` returns.

    No method overrides needed — just provide constructor args.
    """

    def __init__(
        self,
        name: str,
        description: str,
        default_index_name: str,
        reader_factory: Callable[[bool], Any],
        export_keyword: str,
        file_extensions: list[str],
        default_export_dir: str,
        example_queries: list[str],
        export_setup_instructions: list[str],
    ):
        """Configure the app for one chat-export platform.

        Args:
            name: Display name, used in CLI help and progress messages.
            description: Passed through to the base class.
            default_index_name: Passed through to the base class.
            reader_factory: Called with the ``concatenate`` flag; must return
                an object exposing ``load_data(...)``.
            export_keyword: Prefix of the reader's path keyword argument; the
                reader is called with ``<export_keyword>_export_path=...``.
            file_extensions: File-name endings (including the dot) that
                identify export files, e.g. ``[".json"]``.
            default_export_dir: Default value of ``--export-path``.
            example_queries: Passed through to the base class.
            export_setup_instructions: Lines printed when the export path
                does not exist.
        """
        # These attributes are assigned before super().__init__() because the
        # base constructor builds the argument parser.
        # NOTE(review): presumably the parser setup reads them (directly or
        # via _add_specific_arguments) — confirm against BaseRAGExample.
        self._reader_factory = reader_factory
        self._export_keyword = export_keyword
        self._file_extensions = file_extensions
        self._default_export_dir = default_export_dir
        self._export_setup_instructions = export_setup_instructions

        self.max_items_default = -1
        self.embedding_model_default = "sentence-transformers/all-MiniLM-L6-v2"

        super().__init__(
            name=name,
            description=description,
            default_index_name=default_index_name,
            example_queries=example_queries,
        )

    def _add_specific_arguments(self, parser):
        """Register the export-specific command-line options on ``parser``."""
        group = parser.add_argument_group(f"{self.name} Parameters")
        group.add_argument(
            "--export-path",
            type=str,
            default=self._default_export_dir,
            help=f"Path to {self.name} export file or directory (default: {self._default_export_dir})",
        )
        # With store_true and default=True this option is True whether or not
        # it is passed; --separate-messages is the switch that actually
        # changes behaviour. Left as-is for CLI compatibility.
        group.add_argument(
            "--concatenate-conversations",
            action="store_true",
            default=True,
            help="Concatenate messages within conversations for better context (default: True)",
        )
        group.add_argument(
            "--separate-messages",
            action="store_true",
            help="Process each message as a separate document (overrides --concatenate-conversations)",
        )
        group.add_argument(
            "--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
        )
        group.add_argument(
            "--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
        )

    def _find_exports(self, export_path: Path) -> list[Path]:
        """Return the export files designated by ``export_path``.

        A file is an export when its name ends with one of the configured
        extensions, compared case-insensitively. The same rule is applied
        whether ``export_path`` is a single file or a directory; directories
        are scanned non-recursively.

        Args:
            export_path: An export file, or a directory containing exports.

        Returns:
            Matching regular files, sorted by path so that the processing
            order (and therefore which files are consumed under a
            ``max_items`` limit) is deterministic. Empty when nothing matches
            or the path is neither a file nor a directory.
        """
        wanted = tuple(ext.lower() for ext in self._file_extensions)

        def is_export(candidate: Path) -> bool:
            return candidate.name.lower().endswith(wanted)

        if export_path.is_file():
            return [export_path] if is_export(export_path) else []
        if export_path.is_dir():
            return sorted(
                entry
                for entry in export_path.iterdir()
                if entry.is_file() and is_export(entry)
            )
        return []

    async def load_data(self, args) -> list[dict[str, Any]]:
        """Load the chat exports and split them into text chunks.

        Args:
            args: Parsed arguments; reads ``export_path``,
                ``concatenate_conversations``, ``separate_messages``,
                ``chunk_size`` and ``chunk_overlap`` (``max_items`` is read
                by ``_foreach_source``).

        Returns:
            The chunks produced by ``create_text_chunks``, or an empty list
            when the path is missing, holds no export files, or yields no
            conversations.
        """
        export_path = Path(args.export_path)

        if not export_path.exists():
            print(f"{self.name} export path not found: {export_path}")
            print("Please ensure you have exported your data and placed it in the correct location.")
            for line in self._export_setup_instructions:
                print(line)
            return []

        export_files = self._find_exports(export_path)

        if not export_files:
            exts = ", ".join(self._file_extensions)
            print(f"No {self.name} export files ({exts}) found in: {export_path}")
            return []

        print(f"Found {len(export_files)} {self.name} export files")

        # --separate-messages wins over --concatenate-conversations.
        concatenate = args.concatenate_conversations and not args.separate_messages
        reader = self._reader_factory(concatenate)

        # Each reader names its path argument after its platform, so the
        # keyword has to be assembled at run time.
        path_keyword = f"{self._export_keyword}_export_path"

        all_documents, _ = self._foreach_source(
            export_files,
            args,
            load=lambda src, mc: reader.load_data(
                **{
                    path_keyword: str(src),
                    "max_count": mc,
                    "include_metadata": True,
                }
            ),
            source_label="export file",
        )

        if not all_documents:
            print("No conversations found to process!")
            print("\nTroubleshooting:")
            print("- Ensure the export file is a valid export")
            return []

        print(f"\nTotal conversations processed: {len(all_documents)}")
        print("Now starting to split into text chunks... this may take some time")

        all_texts = create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
        return all_texts
Loading
Loading