Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/neo4j_graph_rag/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,4 @@ neo4j_graph_rag/
│ └── rag_dag.png
└── data/
└── README.md Dataset download and conversion instructions
```
```
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/neo4j_graph_rag/data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ with open("tmdb_5000_credits.json", "w") as f:
json.dump(credits.to_dict(orient="records"), f)
```

Run this script once from inside the `data/` folder, then proceed with `python run.py --mode ingest`.
Run this script once from inside the `data/` folder, then proceed with `python run.py --mode ingest`.
12 changes: 7 additions & 5 deletions examples/LLM_Workflows/neo4j_graph_rag/data/data_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
# under the License.


import pandas as pd, json

import json

import pandas as pd

movies = pd.read_csv("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_movies.csv")
credits = pd.read_csv("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_credits.csv")

with open("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_movies.json", "w") as f:
json.dump(movies.to_dict(orient="records"), f)

with open("examples/LLM_Workflows/neo4j_graph_rag/data/tmdb_5000_credits.json", "w") as f:
json.dump(credits.to_dict(orient="records"), f)
json.dump(credits.to_dict(orient="records"), f)
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/neo4j_graph_rag/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ services:

volumes:
neo4j_data:
neo4j_logs:
neo4j_logs:
4 changes: 2 additions & 2 deletions examples/LLM_Workflows/neo4j_graph_rag/embed_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def movie_embeddings(
model=EMBEDDING_MODEL,
input=[item["text"] for item in batch],
)
for item, emb_obj in zip(batch, response.data):
for item, emb_obj in zip(batch, response.data, strict=True):
results.append({"id": item["id"], "embedding": emb_obj.embedding})

logger.info("Embedded batch %d-%d of %d", i, min(i + BATCH_SIZE, total), total)
Expand Down Expand Up @@ -165,4 +165,4 @@ def embedding_summary(
"dimensions": EMBEDDING_DIMENSIONS,
}
logger.info("Embedding complete: %s", summary)
return summary
return summary
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def answer(prompt_messages: list[dict], openai_api_key: str) -> str:
)
result = response.choices[0].message.content
logger.info("Generated answer (%d chars)", len(result))
return result
return result
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/neo4j_graph_rag/graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ def schema_to_prompt() -> str:
prop_str = f" with properties: {', '.join(props)}" if props else ""
lines.append(f" (:{src})-[:{rel}]->(:{dest}){prop_str}")

return "\n".join(lines)
return "\n".join(lines)
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/neo4j_graph_rag/ingest_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,4 @@ def ingestion_summary(
"person_edges": write_person_nodes_and_edges,
}
logger.info("Ingestion complete: %s", summary)
return summary
return summary
2 changes: 1 addition & 1 deletion examples/LLM_Workflows/neo4j_graph_rag/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# specific language governing permissions and limitations
# under the License.

sf-hamilton>=1.73.0
neo4j>=5.18.0
openai>=1.30.0
pandas>=2.0.0
python-dotenv>=1.0.0
sf-hamilton>=1.73.0
tqdm>=4.66.0
69 changes: 34 additions & 35 deletions examples/LLM_Workflows/neo4j_graph_rag/retrieval_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@
import logging

import openai
from neo4j import Driver

from embed_module import EMBEDDING_MODEL, VECTOR_INDEX_NAME
from graph_schema import schema_to_prompt
from neo4j import Driver

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -127,6 +126,7 @@
# 1. Classify query intent
# ---------------------------------------------------------------------------


def query_intent(user_query: str, openai_api_key: str) -> str:
"""
Classify the user query into one of four retrieval strategies:
Expand Down Expand Up @@ -175,6 +175,7 @@ def query_intent(user_query: str, openai_api_key: str) -> str:
# 2. Entity extraction
# ---------------------------------------------------------------------------


def entity_extraction(
user_query: str,
openai_api_key: str,
Expand Down Expand Up @@ -248,6 +249,7 @@ def entity_extraction(
# 3. Entity resolution — look up canonical forms in Neo4j
# ---------------------------------------------------------------------------


def _resolve_persons(names: list[str], session) -> dict[str, str]:
"""Fuzzy-match person names against the graph, return {input: canonical}."""
resolved = {}
Expand Down Expand Up @@ -404,24 +406,16 @@ def entity_resolution(

with neo4j_driver.session() as session:
if entity_extraction.get("persons"):
resolved["persons"] = _resolve_persons(
entity_extraction["persons"], session
)
resolved["persons"] = _resolve_persons(entity_extraction["persons"], session)

if entity_extraction.get("movies"):
resolved["movies"] = _resolve_movies(
entity_extraction["movies"], session
)
resolved["movies"] = _resolve_movies(entity_extraction["movies"], session)

if entity_extraction.get("genres"):
resolved["genres"] = _resolve_genres(
entity_extraction["genres"], session
)
resolved["genres"] = _resolve_genres(entity_extraction["genres"], session)

if entity_extraction.get("companies"):
resolved["companies"] = _resolve_companies(
entity_extraction["companies"], session
)
resolved["companies"] = _resolve_companies(entity_extraction["companies"], session)

# Pass through numeric/date filters unchanged
for key in ("year_after", "year_before", "rating_above", "rating_below"):
Expand All @@ -436,6 +430,7 @@ def entity_resolution(
# 4. Vector path
# ---------------------------------------------------------------------------


def query_embedding(
user_query: str,
openai_api_key: str,
Expand Down Expand Up @@ -499,6 +494,7 @@ def vector_results(
# 5. Cypher generation using resolved entities
# ---------------------------------------------------------------------------


def _build_entity_context(resolved: dict) -> str:
"""
Build a plain-English summary of resolved entities for the Cypher
Expand All @@ -511,32 +507,32 @@ def _build_entity_context(resolved: dict) -> str:

persons = resolved.get("persons", {})
if persons:
for original, canonical in persons.items():
for _original, canonical in persons.items():
lines.append(f' Person: "{canonical}"')

movies = resolved.get("movies", {})
if movies:
for original, canonical in movies.items():
for _original, canonical in movies.items():
lines.append(f' Movie title: "{canonical}"')

genres = resolved.get("genres", {})
if genres:
for original, canonical in genres.items():
for _original, canonical in genres.items():
lines.append(f' Genre: "{canonical}"')

companies = resolved.get("companies", {})
if companies:
for original, canonical in companies.items():
for _original, canonical in companies.items():
lines.append(f' ProductionCompany: "{canonical}"')

if "year_after" in resolved:
lines.append(f' Date filter: m.release_date > \'{resolved["year_after"]}-01-01\'')
lines.append(f" Date filter: m.release_date > '{resolved['year_after']}-01-01'")
if "year_before" in resolved:
lines.append(f' Date filter: m.release_date < \'{resolved["year_before"]}-12-31\'')
lines.append(f" Date filter: m.release_date < '{resolved['year_before']}-12-31'")
if "rating_above" in resolved:
lines.append(f' Rating filter: m.vote_average > {resolved["rating_above"]}')
lines.append(f" Rating filter: m.vote_average > {resolved['rating_above']}")
if "rating_below" in resolved:
lines.append(f' Rating filter: m.vote_average < {resolved["rating_below"]}')
lines.append(f" Rating filter: m.vote_average < {resolved['rating_below']}")

return "\n".join(lines)

Expand Down Expand Up @@ -656,6 +652,7 @@ def cypher_results(
# 6. Enrich vector results with graph traversal
# ---------------------------------------------------------------------------


def _enrich_movie(movie_id: int, driver: Driver) -> dict | None:
"""Pull directors, cast, genres, companies for a movie node."""
cypher = """
Expand Down Expand Up @@ -684,6 +681,7 @@ def _enrich_movie(movie_id: int, driver: Driver) -> dict | None:
# 7. Merge results
# ---------------------------------------------------------------------------


def merged_results(
vector_results: list[dict],
cypher_results: list[dict],
Expand Down Expand Up @@ -722,6 +720,7 @@ def merged_results(
# 8. Format context
# ---------------------------------------------------------------------------


def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
"""
Format merged results into plain-text context for the generation DAG.
Expand All @@ -734,26 +733,26 @@ def retrieved_context(merged_results: list[dict], query_intent: str) -> str:
return "No relevant information found in the knowledge graph for this query."

FIELD_LABELS = {
"movie": "Movie",
"director": "Director",
"actor": "Actor",
"genre": "Genre",
"company": "Production company",
"film_count": "Films",
"movie_count": "Count",
"movie": "Movie",
"director": "Director",
"actor": "Actor",
"genre": "Genre",
"company": "Production company",
"film_count": "Films",
"movie_count": "Count",
"action_movie_count": "Action movies",
"avg_rating": "Avg rating",
"average_rating": "Avg rating",
"vote_average": "Rating",
"release_date": "Released",
"avg_rating": "Avg rating",
"average_rating": "Avg rating",
"vote_average": "Rating",
"release_date": "Released",
}

lines = []
i = 0

for row in merged_results:
i += 1
source = row.get("_source", "unknown")
_source = row.get("_source", "unknown")

if "directors" in row:
# Enriched movie record from vector path
Expand Down Expand Up @@ -786,4 +785,4 @@ def retrieved_context(merged_results: list[dict], query_intent: str) -> str:

context = "\n".join(lines)
logger.info("Formatted context: %d chars from %d results", len(context), len(merged_results))
return context
return context
28 changes: 14 additions & 14 deletions examples/LLM_Workflows/neo4j_graph_rag/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@
import os
from pathlib import Path

from dotenv import load_dotenv
from hamilton import driver
from neo4j import GraphDatabase

import embed_module
import generation_module
import ingest_module
import retrieval_module
from dotenv import load_dotenv
from graph_schema import CONSTRAINTS
from neo4j import GraphDatabase

from hamilton import driver

load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
Expand Down Expand Up @@ -124,9 +124,11 @@ def run_ingest(visualise: bool = False):
)
s = result["ingestion_summary"]
logger.info(
"Ingestion complete — movies: %d, genre edges: %d, "
"company edges: %d, person edges: %d",
s["movies"], s["genre_edges"], s["company_edges"], s["person_edges"],
"Ingestion complete — movies: %d, genre edges: %d, company edges: %d, person edges: %d",
s["movies"],
s["genre_edges"],
s["company_edges"],
s["person_edges"],
)
drv.close()

Expand Down Expand Up @@ -158,7 +160,9 @@ def run_embed(visualise: bool = False):
s = result["embedding_summary"]
logger.info(
"Embedding complete — %d embeddings written, index: %s, model: %s",
s["embeddings_written"], s["vector_index"], s["model"],
s["embeddings_written"],
s["vector_index"],
s["model"],
)
drv.close()

Expand All @@ -172,11 +176,7 @@ def run_query(question: str, visualise: bool = False):
drv = make_neo4j_driver()
openai_api_key = get_env("OPENAI_API_KEY")

rag_driver = (
driver.Builder()
.with_modules(retrieval_module, generation_module)
.build()
)
rag_driver = driver.Builder().with_modules(retrieval_module, generation_module).build()

if visualise:
rag_driver.display_all_functions("rag_dag.png")
Expand Down Expand Up @@ -239,4 +239,4 @@ def main():


if __name__ == "__main__":
main()
main()
11 changes: 6 additions & 5 deletions hamilton/plugins/h_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def run_before_graph_execution(
# log config and inputs as `param` which creates columns in the UI to filter runs
# `log_param()` accepts `value: Any` and will stringify complex objects
for value_sets in [self.config, inputs]:
if value_sets is None:
continue
for node_name, value in value_sets.items():
self.client.log_param(self.run_id, key=node_name, value=value)

Expand Down Expand Up @@ -325,11 +327,10 @@ def run_after_node_execution(
def run_after_graph_execution(self, success: bool, *args, **kwargs):
"""End the MLFlow run"""
# `status` is an enum value of mlflow.entities.RunStatus
if success:
self.client.set_terminated(self.run_id, status="FINISHED")
else:
self.client.set_terminated(self.run_id, status="FAILED")
mlflow.end_run()
status = "FINISHED" if success else "FAILED"

self.client.set_terminated(self.run_id, status=status)
mlflow.end_run(status=status)

def run_before_node_execution(self, *args, **kwargs):
"""Placeholder required to subclass NodeExecutionHook"""
Loading
Loading