-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRouter.py
More file actions
132 lines (100 loc) · 4.86 KB
/
Router.py
File metadata and controls
132 lines (100 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import json
import os
import sys
from enum import Enum

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel, Field

load_dotenv()
# ── Constants ────────────────────────────────────────────────────────────────
ROUTER_MODEL = "gpt-4o-mini"  # cheapest; only does classification
OUTPUT_FILE = "router_output.json"  # consumed by next script in pipeline
# Every label the router can emit. When adding a task type, extend this list
# AND the TaskType enum AND the task list in SYSTEM_PROMPT so they stay in sync.
TASK_TYPES = ["math_qa", "email_gen", "email_eval"]
# ── Pydantic schema (enforces structured output) ─────────────────────────────
class TaskType(str, Enum):
    """Closed set of task labels the router is allowed to emit.

    The ``str`` mixin makes each member compare equal to its raw label
    (``TaskType.math_qa == "math_qa"``), so it serializes as a plain string.
    """

    # Mathematical problems: calculations, word problems, algebra, arithmetic.
    math_qa = "math_qa"
    # Writing, drafting, composing, or improving an email.
    email_gen = "email_gen"
    # Evaluating or reviewing an email.
    email_eval = "email_eval"
class RouterOutput(BaseModel):
    """Validated classification parsed from the router model's JSON reply.

    Pydantic rejects a missing field, a ``task_type`` that is not a TaskType
    label, or a confidence outside [0.0, 1.0], instead of letting the bad
    value reach downstream scripts.
    """

    task_type: TaskType
    # Model's self-reported certainty. The bounds enforce the 0.0–1.0 range
    # the system prompt asks for (previously only documented in a comment).
    confidence: float = Field(ge=0.0, le=1.0)
    reasoning: str  # one-line explanation (useful for debugging)
    clean_query: str  # stripped / normalized version of user input
# ── System prompt ─────────────────────────────────────────────────────────────
# Instructions sent to the router model. Both the task list and the
# "task_type" placeholder must name every TaskType member. A label missing
# from the placeholder (as email_eval previously was) can steer the model away
# from emitting it. The word "JSON" must also appear here, because the API's
# JSON response mode requires it in the messages.
SYSTEM_PROMPT = """
You are a task classifier for a prompt-routing system.
Classify the user's input into exactly one of these task types:
- math_qa : any mathematical problem, calculation, word problem, algebra, arithmetic
- email_gen : requests to write, draft, compose, or improve an email of any kind
- email_eval : requests to evaluate, review, or critique an existing email
Respond ONLY with a valid JSON object matching this schema exactly:
{
"task_type" : "<math_qa | email_gen | email_eval>",
"confidence" : <float between 0.0 and 1.0>,
"reasoning" : "<one sentence explaining why>",
"clean_query": "<user query, lightly cleaned — fix typos, remove filler words>"
}
Do not include any text outside the JSON object.
""".strip()
# ── Core classifier function ──────────────────────────────────────────────────
def classify(user_query: str) -> RouterOutput:
    """
    Send user_query to gpt-4o-mini and return a validated RouterOutput.

    The API key is read from the OPENAI_API_KEY environment variable.

    Raises:
        ValueError: if the model returns no content, or content that does not
            parse as JSON matching the RouterOutput schema.
    """
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.chat.completions.create(
        model=ROUTER_MODEL,
        temperature=0.0,  # deterministic — classification, not generation
        max_tokens=200,
        # JSON mode: constrains the reply to a JSON object (no Markdown fences
        # or surrounding prose). The API requires the word "JSON" somewhere in
        # the messages; SYSTEM_PROMPT contains it.
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_query},
        ],
    )
    content = response.choices[0].message.content
    if content is None:
        # e.g. a refusal or empty completion; calling .strip() would crash.
        raise ValueError("Router returned no content.")
    raw_json = content.strip()
    try:
        return RouterOutput.model_validate_json(raw_json)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError, covering both
        # malformed JSON and schema violations.
        raise ValueError(
            f"Router returned invalid JSON.\nRaw response:\n{raw_json}\nError: {e}"
        ) from e
# ── Output writer (for downstream scripts) ────────────────────────────────────
def save_output(result: "RouterOutput", query: str, path: "str | None" = None) -> None:
    """
    Write structured output to router_output.json (or ``path`` if given).
    This file is the handoff contract between router.py and the next script.

    Args:
        result: validated classification; its fields are written at the top
            level of the JSON object.
        query: the user's query, stored under "original_query".
        path: destination file. Defaults to OUTPUT_FILE when None.

    The JSON is written to a sibling ".tmp" file and then moved into place with
    os.replace, so a reader does not observe a partially written file.
    """
    if path is None:
        path = OUTPUT_FILE
    payload = {
        "original_query": query,
        # mode="json" emits TaskType as its plain string value explicitly.
        **result.model_dump(mode="json"),
    }
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp_path, path)
    print(f"\n[router] Output saved → {path}")
# ── CLI entry point ───────────────────────────────────────────────────────────
def main():
    """
    Classify one query taken from the command line or an interactive prompt.

    Usage:
        python router.py "solve 2x + 3 = 9"   # query from argv
        python router.py                      # prompts for a query

    Saves the result with save_output(), pretty-prints it, and returns the
    RouterOutput. Exits with status 1 (message on stderr) if the query is
    empty/whitespace-only or the router response cannot be validated.
    """
    if len(sys.argv) > 1:
        # Accept query as CLI argument: python router.py "solve 2x + 3 = 9"
        user_query = " ".join(sys.argv[1:]).strip()
    else:
        # Interactive mode
        print("=== Task Router ===")
        user_query = input("Enter your query: ").strip()

    # Validate both input paths, so a whitespace-only argv query is rejected
    # instead of being sent to the model.
    if not user_query:
        print("[error] Empty query. Exiting.", file=sys.stderr)
        sys.exit(1)

    print(f"\n[router] Classifying: '{user_query}'")
    try:
        result = classify(user_query)
    except ValueError as e:
        # classify() raises ValueError for missing or invalid model output;
        # report it cleanly at the CLI boundary instead of a raw traceback.
        print(f"[error] {e}", file=sys.stderr)
        sys.exit(1)
    save_output(result, user_query)

    # ── Pretty print for dev visibility ──
    print("\n── Classification Result ──────────────────")
    print(f" Task Type : {result.task_type.value}")
    print(f" Confidence : {result.confidence:.0%}")
    print(f" Reasoning : {result.reasoning}")
    print(f" Clean Query: {result.clean_query}")
    print("───────────────────────────────────────────\n")
    return result


if __name__ == "__main__":
    main()