forked from Yuyang-Song/QUITE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_module.py
More file actions
321 lines (265 loc) · 12.9 KB
/
test_module.py
File metadata and controls
321 lines (265 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""
Test Module
Simple unit tests for the Knowledge Base Tool, the DBMS Explain and Syntax Tools, the Equivalence Check Tool, and DBMS/LLM connectivity.
For meaningful DBMS tool tests, make sure self.test_sql (loaded from dataset/queries/test_sql.sql) is a valid SQL query against the database specified in the .env file.
"""
import os
import sys
import unittest
import asyncio
from pathlib import Path
# Setup project paths first: put the directory containing this file on
# sys.path so that the `src.` imports below can be resolved.
_current_file = Path(__file__).resolve()
_project_root = _current_file.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))
from src.utils.path_config import PROJECT_ROOT, setup_python_path, load_project_env
# Project helpers from src.utils.path_config. Presumably they extend sys.path
# and load the .env configuration that the tests later read via os.getenv
# (TODO confirm against path_config).
setup_python_path()
load_project_env()
# Kept after the setup calls above, which presumably affect how these resolve.
from src.Rewrite_Middleware.middleware import Knowledge_Base_Tool, DBMS_EXPLAIN_Tool, DBMS_Syntax_Tool, Equivalence_Check_Tool, DBMS
from src.utils.data_distribution import get_statistics_list, get_available_databases
from src.utils.get_data_statistics import get_data_statistics
from src.utils.llm_client import GPT
# Module-level side effect: report the resolved project root on import.
print(f"Project root: {PROJECT_ROOT}")
class TestRewriteMiddleware(unittest.TestCase):
    """
    Rewrite Middleware tests.

    Smoke tests for the Rewrite Middleware tools (Knowledge Base, DBMS
    EXPLAIN, DBMS syntax check, equivalence check), plus connectivity checks
    for the DBMS and for the LLM endpoints configured through environment
    variables. Most tests talk to live services, so they behave as
    integration tests rather than isolated unit tests.
    """

    # Query used by the DBMS tests when dataset/queries/test_sql.sql cannot be
    # read. It targets the part/partsupp/supplier tables, so it is only
    # meaningful against a database that has that schema.
    _FALLBACK_SQL = (
        "select p_brand, p_type, p_size, count(distinct ps_suppkey) as supplier_cnt "
        "from partsupp, part where p_partkey = ps_partkey and p_brand <> 'Brand#43' "
        "and p_type not like 'PROMO PLATED%' and p_size in (18, 8, 33, 17, 27, 6, 1, 50) "
        "and ps_suppkey not in ( select s_suppkey from supplier "
        "where s_comment like '%Customer%Complaints%' ) "
        "group by p_brand, p_type, p_size order by supplier_cnt desc, p_brand, p_type, p_size;"
    )

    # (label used in log messages, environment-variable prefix) per LLM agent.
    # Each prefix P is read as P_MODEL, P_MODEL_URL and P_MODEL_API_KEY.
    _AGENTS = (
        ("Reasoning Agent", "REASONING"),
        ("Decision Agent", "DECISION"),
        ("Assistant Agent", "ASSISTANT"),
    )

    def setUp(self):
        """Prepare the SQL inputs, rewrite suggestions and a DBMS handle for each test."""
        self.suggestion_sql = """
WITH min_supply AS (
SELECT ps.ps_partkey, MIN(ps.ps_supplycost) AS min_supplycost
FROM partsupp ps
JOIN supplier s ON ps.ps_suppkey = s.s_suppkey
JOIN nation n ON s.s_nationkey = n.n_nationkey
JOIN region r ON n.n_regionkey = r.r_regionkey
WHERE r.r_name = 'EUROPE'
GROUP BY ps.ps_partkey
)
SELECT s.s_acctbal, s.s_name, n.n_name, p.p_partkey, p.p_mfgr, s.s_address, s.s_phone, s.s_comment
FROM part p
JOIN partsupp ps ON p.p_partkey = ps.ps_partkey
JOIN supplier s ON ps.ps_suppkey = s.s_suppkey
JOIN nation n ON s.s_nationkey = n.n_nationkey
JOIN min_supply ms ON p.p_partkey = ms.ps_partkey AND ps.ps_supplycost = ms.min_supplycost
WHERE p.p_size = 6
AND p.p_type LIKE '%NICKEL'
ORDER BY s.s_acctbal DESC, n.n_name, s.s_name, p.p_partkey
LIMIT 100;
"""
        self.origin_suggestion_list = [
            {
                "group": "Subquery_optimization",
                "origin_suggestion": "Ensure the CTE precomputes the minimum supply cost for parts from European suppliers, thus avoiding repetitive execution."
            },
            {
                "group": "Join_optimization",
                "origin_suggestion": "Simplify the main query by leveraging this CTE and reducing redundant joins, particularly eliminating unnecessary joins to the `region` table in the main query."
            },
            {
                "group": "Predication_simplification",
                "origin_suggestion": "Replace the correlated subquery that finds the minimum `ps_supplycost` with a more efficient Common Table Expression (CTE)."
            }
        ]
        # Load the test SQL from file. Fall back to _FALLBACK_SQL so the DBMS
        # tests never run against an empty query.
        test_sql_path = PROJECT_ROOT / "dataset" / "queries" / "test_sql.sql"
        try:
            with open(test_sql_path, 'r', encoding='utf-8') as sql_file:
                self.test_sql = sql_file.read().strip()
            print(f"Loaded test SQL from: {test_sql_path}")
        except (OSError, UnicodeDecodeError):
            self.test_sql = self._FALLBACK_SQL
            print(f"Using fallback SQL, {test_sql_path} not found")
        # Initialize DBMS instance
        self.dbms = DBMS()

    def test_dbms_connection(self):
        """Test: DBMS Connection - self.dbms.connect() must not raise."""
        print("\n" + "="*50)
        print("🧪 Testing DBMS Connection")
        print("="*50)
        try:
            self.dbms.connect()
            print("✅ DBMS Connection Test PASSED!")
        except Exception as e:
            self.fail(f"❌ DBMS Connection Test FAILED: {str(e)}")

    def _assert_agent_responds(self, label, agent, prompt):
        """Send *prompt* to *agent*; fail the test if the call raises or returns None."""
        try:
            response = agent.get_LLM_response(prompt)
            self.assertIsNotNone(response, "LLM response is None")
            print(f"✅ {label} Connection Test PASSED!")
        except Exception as e:
            self.fail(f"❌ {label} Connection Test FAILED: {str(e)}")

    def test_llm_connection(self):
        """Test: LLM Connection.

        Builds a GPT client per agent from its environment variables and checks
        that each returns a non-None response to a trivial prompt. The test
        stops at the first agent that fails.
        """
        print("\n" + "="*50)
        print("🧪 Testing LLM Connection")
        print("="*50)
        test_prompt = "What is the capital of France?"
        for label, prefix in self._AGENTS:
            print(f"The {label.lower()} configuration is: {os.getenv(f'{prefix}_MODEL')}, base URL: {os.getenv(f'{prefix}_MODEL_URL')}")
        # Build every client before sending any request, so a constructor
        # error surfaces before network traffic starts.
        agents = [
            (label, GPT(
                api_key=os.getenv(f"{prefix}_MODEL_API_KEY"),
                model=os.getenv(f"{prefix}_MODEL"),
                base_url=os.getenv(f"{prefix}_MODEL_URL"),
            ))
            for label, prefix in self._AGENTS
        ]
        for label, agent in agents:
            self._assert_agent_responds(label, agent, test_prompt)
        print("🧪 All LLM Connection Tests PASSED!")

    def test_dbms_data_distribution(self):
        """Test: DBMS Data Distribution.

        Uses the per-database statistics when the current database is listed by
        get_available_databases(). Otherwise it uses get_data_statistics().
        """
        print("\n" + "="*50)
        print("🧪 Testing DBMS Data Distribution")
        print("="*50)
        # obtain database name and statistics
        db_name = self.dbms.db_name
        if db_name in get_available_databases():
            print(f"Database {db_name} found, retrieving statistics...")
            data_statistics = get_statistics_list(db_name)
        else:
            print(f"Database {db_name} not found, retrieving statistics from default database...")
            data_statistics = get_data_statistics()
        print(f"Current database: {db_name}")
        print(f"Data statistics: {data_statistics}")
        # Without this the test could never fail on a missing result.
        self.assertIsNotNone(data_statistics, "No data statistics were returned")

    def test_dbms_explain_tool(self):
        """Test: DBMS Explain Tool - EXPLAIN of self.test_sql must yield a list."""
        print("\n" + "="*50)
        print("🧪 Testing DBMS Explain Tool")
        print("="*50)
        try:
            result = asyncio.run(DBMS_EXPLAIN_Tool(self.dbms, self.test_sql))
            # Basic assertions
            self.assertIsNotNone(result)
            self.assertIsInstance(result, list)
            print("✅ Test PASSED!")
            print(f"📋 Query plans generated: {len(result)} items")
        except Exception as e:
            self.fail(f"❌ Test FAILED: {str(e)}")

    def test_knowledge_base_tool(self):
        """Test: Knowledge Base Tool - must return a non-None result for the sample suggestions."""
        print("\n" + "="*50)
        print("🧪 Testing Knowledge Base Tool")
        print("="*50)
        try:
            output = asyncio.run(Knowledge_Base_Tool(self.suggestion_sql, self.origin_suggestion_list))
            # Basic assertions
            self.assertIsNotNone(output)
            print("✅ Test PASSED!")
        except Exception as e:
            self.fail(f"❌ Test FAILED: {str(e)}")

    def test_dbms_syntax_tool(self):
        """Test: DBMS Syntax Tool - DBMS_Syntax_Tool must return a truthy result for self.test_sql."""
        print("\n" + "="*50)
        print("🧪 Testing DBMS Syntax Tool")
        print("="*50)
        try:
            result = asyncio.run(DBMS_Syntax_Tool(self.dbms, self.test_sql))
            # Basic assertions
            self.assertTrue(result, "Syntax check failed")
            print("✅ Test PASSED!")
        except Exception as e:
            self.fail(f"❌ Test FAILED: {str(e)}")

    def test_equivalence_check_tool(self):
        """Test: Equivalence Check Tool.

        Runs three SQL pairs (expected EQ, NEQ and unknown) against the Calcite
        schema file and requires a non-None result for each pair.
        """
        from unittest import mock  # local: only this test needs it

        print("\n" + "="*50)
        print("🧪 Testing Equivalence Check Tool")
        print("="*50)
        lib_dir = str(PROJECT_ROOT / "src" / "Rewrite_Middleware" / "Hybrid_SQL_Corrector")
        # Scope the override to this test. patch.dict restores the previous
        # environment on cleanup, so the change does not leak into other tests.
        # NOTE(review): a running process's dynamic loader does not re-read
        # LD_LIBRARY_PATH, so this presumably targets a child process started
        # by Equivalence_Check_Tool (confirm).
        env_override = mock.patch.dict(os.environ, {'LD_LIBRARY_PATH': lib_dir})
        env_override.start()
        self.addCleanup(env_override.stop)

        schema_path = PROJECT_ROOT / "dataset" / "schemas" / "calcite_schemas.sql"
        with open(schema_path, 'r', encoding='utf-8') as f:
            schema_content = f.read()
        if not schema_content.strip():
            raise ValueError(f"Schema file {schema_path} is empty or not found.")
        print(f"Schema content loaded from {schema_path}")

        # (original_sql, rewritten_sql) pairs; the comments give the expected outcome.
        sql_pairs = [
            # EQ case
            ("SELECT * FROM (VALUES (1,2)) WHERE FALSE",
             "SELECT * FROM (SELECT NULL AS EXPR$0, NULL AS EXPR$1) AS t WHERE 1 = 0"),
            # NEQ case
            ("SELECT * FROM (VALUES (1,2)) WHERE FALSE",
             "SELECT * FROM (VALUES (1,2)) WHERE TRUE"),
            # Unknown case
            ("SELECT 2, EMP.DEPTNO, EMP.JOB FROM EMP AS EMP UNION ALL SELECT 1, EMP0.DEPTNO, EMP0.JOB FROM EMP AS EMP0",
             "SELECT 2, EMP1.DEPTNO, EMP1.JOB FROM EMP AS EMP1 UNION ALL SELECT 1, EMP2.DEPTNO, EMP2.JOB FROM EMP AS EMP2"),
        ]
        try:
            result = []
            for original_sql, rewritten_sql in sql_pairs:
                outcome = asyncio.run(Equivalence_Check_Tool(original_sql, rewritten_sql, schema_path))
                # Check each tool output, not the accumulating list (which is never None).
                self.assertIsNotNone(outcome)
                result.append(outcome)
            print("✅ Test PASSED!")
            print(f"📋 Equivalence results: {result}")
        except Exception as e:
            self.fail(f"❌ Test FAILED: {str(e)}")
def main():
    """Print a start-up banner, then run every test in this module via unittest.

    unittest.main() parses sys.argv and, by default, exits the interpreter
    once the run completes.
    """
    separator = "=" * 60
    for banner_line in (separator, "🚀 Starting SQL Optimization Tests", separator):
        print(banner_line)
    unittest.main(verbosity=1)


if __name__ == "__main__":
    main()