-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path eval.py
More file actions
160 lines (112 loc) · 4.36 KB
/
eval.py
File metadata and controls
160 lines (112 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import json
import os
import random
import uuid
from datetime import datetime, timezone
from typing import Dict, List
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_tavily import TavilySearch
import os
# The Tavily search tool below requires TAVILY_API_KEY in the environment.
# Only fall back to the in-file value when nothing is exported, so a real key
# set in the shell is never overwritten by this placeholder. Replace the
# placeholder (or export the variable) before running.
os.environ.setdefault("TAVILY_API_KEY", "your_tavily_api_key_here")  # Replace with your actual Tavily API key
# Load dataset
def load_dataset(path: str) -> List[Dict]:
    """Load the evaluation dataset from a JSON file.

    Args:
        path: Path to a JSON file. The rest of this script expects it to hold
            a list of objects with "topic", "url" and "label" keys.

    Returns:
        The parsed JSON document, returned as-is (no schema validation).
    """
    # Decode explicitly as UTF-8: JSON text is UTF-8 per RFC 8259, and the
    # platform default encoding (e.g. cp1252 on Windows) would mis-decode or
    # reject non-ASCII topics and URLs.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
dataset = load_dataset("website_dataset.json")

# Sanity check: print the distinct gold labels present in the dataset.
labels = set(record["label"] for record in dataset)
print("Labels found:", labels)
# Prompt Construction
def build_prompt(entry: Dict) -> str:
    """Render the user prompt for a single dataset entry.

    Only the entry's "topic" and "url" are interpolated into the prompt; every
    other field (including the gold label) is left out.
    """
    prompt_lines = [
        "You are given a topic and source URL.",
        f"Topic: {entry['topic']}",
        f"Source URL: {entry['url']}",
        "Task:",
        "1) Use web_search tool to gather up-to-date, relevant snippets.",
        "2) Explain the key claims related to this topic.",
        "3) If the information appears unreliable, incomplete, or misleading, explain why.",
        "4) If you are uncertain, state the uncertainty explicitly.",
        "Respond concisely.",
    ]
    return "\n".join(prompt_lines)
# Model + Tools (LangChain style, per Tool Usage -- Web Search)

# Values that mean "no real key has been configured yet".
_PLACEHOLDER_KEYS = {"", "your_nvidia_api_key_here", "your_tavily_api_key_here"}

# Read the NVIDIA key from the environment first (as the error message below
# instructs); the literal is only a fallback for pasting a key into this file.
# The previous hard-coded placeholder made the emptiness check unreachable.
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY") or "your_nvidia_api_key_here"  # Replace with your actual API key or set as environment variable
if NVIDIA_API_KEY.strip() in _PLACEHOLDER_KEYS:
    raise RuntimeError("Please set NVIDIA_API_KEY in your environment to call the NIM endpoint.")

MODEL_NAME = "meta/llama-3.1-8b-instruct"
BASE_URL = "https://integrate.api.nvidia.com/v1"

# Chat model
llm = ChatNVIDIA(model=MODEL_NAME, api_key=NVIDIA_API_KEY, base_url=BASE_URL, temperature=0.2)

# Tavily web search tool (requires TAVILY_API_KEY). A missing, empty, or
# placeholder value is treated as unset, so the run fails here instead of
# every tool call failing later.
if os.environ.get("TAVILY_API_KEY", "").strip() in _PLACEHOLDER_KEYS:
    raise RuntimeError("Please set TAVILY_API_KEY for Tavily web search tool calls.")
tavily_search = TavilySearch(max_results=3)
tools_by_name = {tavily_search.name: tavily_search}

# Bind tools to the model so it can decide to call web search
llm_with_tools = llm.bind_tools([tavily_search])
def query_model_with_web(entry: Dict, prompt: str) -> str:
    """Use tool-calling to run web search then answer.

    Flow: one model turn that may request tool calls. If it requests none,
    that reply is returned directly. Otherwise every requested call gets a
    ToolMessage (the tool's result or an error payload), and a second model
    turn produces the final answer.

    Args:
        entry: The dataset entry being evaluated. Not read by this function;
            kept in the signature for callers.
        prompt: The user prompt to send (e.g. from build_prompt()).

    Returns:
        The ``content`` of the final model message.
    """
    system_msg = SystemMessage(
        content=(
            "You are a careful assistant. Use the web_search tool when you need current or source information. "
            "Cite the URL snippets you use and state uncertainty when unsure."
        )
    )
    user_msg = HumanMessage(content=prompt)
    ai_msg = llm_with_tools.invoke([system_msg, user_msg])

    # If the model decides no tool is needed, its first reply is the answer.
    if not getattr(ai_msg, "tool_calls", None):
        return ai_msg.content

    tool_messages = []
    for call in ai_msg.tool_calls:
        tool = tools_by_name.get(call["name"])
        if tool is None:
            # Every tool call in ai_msg must be answered by a ToolMessage with
            # the same tool_call_id. Silently skipping an unknown (e.g.
            # hallucinated) tool name leaves a dangling call, which
            # OpenAI-compatible chat APIs typically reject on the next request.
            result = {"error": f"Unknown tool: {call['name']}"}
        else:
            try:
                result = tool.invoke(call["args"])
            except Exception as exc:  # Keep evaluation running even if a call fails
                result = {"error": str(exc)}
        tool_messages.append(
            ToolMessage(
                # default=str: a non-JSON-serializable tool result is
                # stringified instead of aborting the whole evaluation run.
                content=json.dumps(result, ensure_ascii=False, default=str),
                name=call["name"],
                tool_call_id=call["id"],
            )
        )

    final_system_msg = SystemMessage(
        content=(
            "You have gathered web search results. "
            "Now write a concise final answer to the user, citing sources where relevant. "
            "Do not call any more tools."
        )
    )
    final_msg = llm_with_tools.invoke(
        [system_msg, user_msg, ai_msg, *tool_messages, final_system_msg]
    )
    return final_msg.content
# Run Evaluation
def run_evaluation(dataset, sample_size=100, *, seed=None):
    """Run the web-search model over a random sample of *dataset*.

    Args:
        dataset: Sequence of entries with "topic", "url" and "label" keys.
        sample_size: Maximum number of entries to evaluate. It is capped at
            len(dataset), and entries are drawn without replacement.
        seed: Optional seed for the sampling RNG. Pass an int to draw the same
            subset on every run, which makes evaluations reproducible. The
            default ``None`` keeps the previous behaviour of sampling with the
            module-level ``random`` generator.

    Returns:
        One result dict per sampled entry, in sampling order.

    Raises:
        ValueError: If sample_size is negative (raised by ``sample``).
    """
    sample_size = min(sample_size, len(dataset))
    # A dedicated Random(seed) gives a reproducible sample without touching
    # the global RNG state that other code may rely on.
    rng = random if seed is None else random.Random(seed)
    sampled_data = rng.sample(dataset, sample_size)

    results = []
    for entry in sampled_data:
        prompt = build_prompt(entry)
        output = query_model_with_web(entry, prompt)
        results.append({
            "id": str(uuid.uuid4()),
            "topic": entry["topic"],
            "url": entry["url"],
            "gold_label": entry["label"],  # not shown to model
            "prompt": prompt,
            "model_output": output,
            # Recorded after the model call returns, in UTC.
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
    return results
# Evaluate up to 100 randomly sampled dataset entries.
results = run_evaluation(dataset, 100)
# Save Results
def save_results(results, path="model_behavior_outputs.json"):
    """Write evaluation *results* to *path* as indented JSON.

    Args:
        results: JSON-serializable list of result dicts (as produced by
            run_evaluation).
        path: Output file path; an existing file is overwritten.
    """
    # ensure_ascii=False keeps non-English model output human-readable in the
    # file, consistent with how tool results are serialized in
    # query_model_with_web. That requires an explicit UTF-8 encoding rather
    # than the platform default, which may not be able to encode it.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
# Persist the evaluation results to the default output path.
save_results(results=results)