-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathformatter.go
More file actions
442 lines (367 loc) · 11.3 KB
/
formatter.go
File metadata and controls
442 lines (367 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
package sqlformatter
import (
"fmt"
"regexp"
"strings"
)
// Formatter holds the options that control how SQL statements are
// pretty-printed. NewFormatter returns one with the default settings.
type Formatter struct {
	// IndentSize is the number of spaces used for one level of indentation.
	IndentSize int

	// KeywordUpper selects upper-case keywords when true and lower-case
	// keywords when false.
	KeywordUpper bool
}
// NewFormatter returns a Formatter preconfigured with a two-space indent
// and upper-case keywords.
func NewFormatter() *Formatter {
	f := new(Formatter)
	f.IndentSize = 2
	f.KeywordUpper = true
	return f
}
// Format pretty-prints a single SQL statement. Whitespace is normalized,
// keywords are re-cased according to f.KeywordUpper, and SELECT, INSERT,
// UPDATE and DELETE statements are broken across indented lines.
// An empty or whitespace-only input yields an error.
func (f *Formatter) Format(sql string) (string, error) {
	if len(strings.TrimSpace(sql)) == 0 {
		return "", fmt.Errorf("SQL statement cannot be empty")
	}
	return f.formatSQL(f.cleanSQL(sql)), nil
}
// whitespaceRunRe matches a run of whitespace (Go's \s is [\t\n\f\r ]).
// It is compiled once at package level rather than on every cleanSQL call.
var whitespaceRunRe = regexp.MustCompile(`\s+`)

// cleanSQL normalizes whitespace in sql: every run of whitespace becomes a
// single space and leading/trailing whitespace is removed, leaving the
// statement on one line.
//
// NOTE(review): the collapse also applies inside quoted string literals, so
// 'a  b' becomes 'a b'. It is kept because the clause-splitting regexes use
// `.`, which does not match '\n', and so rely on the statement being a
// single line; making this quote-aware requires updating those patterns too.
func (f *Formatter) cleanSQL(sql string) string {
	return strings.TrimSpace(whitespaceRunRe.ReplaceAllString(sql, " "))
}
// formatSQLTokenRe scans a statement for the keywords formatSQL normalizes.
// Quoted sections ('...', "...", `...`) are matched as whole tokens first, so
// keyword-like text inside string literals or quoted identifiers is never
// re-cased; an unterminated quote runs to the end of the input. Multi-word
// keywords are listed before their single-word prefixes.
var formatSQLTokenRe = regexp.MustCompile(`(?i)'[^']*(?:'|$)|"[^"]*(?:"|$)|` + "`[^`]*(?:`|$)" +
	`|\b(?:GROUP BY|ORDER BY|UNION ALL|INNER JOIN|LEFT JOIN|RIGHT JOIN|FULL JOIN|` +
	`SELECT|FROM|WHERE|HAVING|LIMIT|INSERT|INTO|VALUES|UPDATE|SET|DELETE|` +
	`JOIN|UNION|CASE|WHEN|THEN|ELSE|END)\b`)

// formatSQL re-cases the keywords of an already-cleaned statement according
// to f.KeywordUpper, then dispatches on the leading keyword to the matching
// statement formatter. Statements that do not start with SELECT, INSERT,
// UPDATE or DELETE are returned with only their keywords re-cased.
func (f *Formatter) formatSQL(sql string) string {
	sql = formatSQLTokenRe.ReplaceAllStringFunc(sql, func(tok string) string {
		switch tok[0] {
		case '\'', '"', '`':
			return tok // quoted literal or identifier: leave its contents alone
		}
		// Canonicalize first so the output is always the plain keyword
		// spelling, whatever case the input used.
		return f.keyword(strings.ToUpper(tok))
	})

	upper := strings.ToUpper(strings.TrimSpace(sql))
	switch {
	case strings.HasPrefix(upper, "SELECT"):
		return f.formatSelectStatement(sql)
	case strings.HasPrefix(upper, "INSERT"):
		return f.formatInsertStatement(sql)
	case strings.HasPrefix(upper, "UPDATE"):
		return f.formatUpdateStatement(sql)
	case strings.HasPrefix(upper, "DELETE"):
		return f.formatDeleteStatement(sql)
	}
	return sql
}
// selectClauseOrder lists the SELECT clauses in output order. Each entry is
// both the key used in the map returned by splitSelectSQL and the keyword
// printed as that clause's header.
var selectClauseOrder = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "HAVING", "ORDER BY", "LIMIT"}

// formatSelectStatement lays out a SELECT statement with every clause keyword
// on its own line and the clause body on the next line, indented one level.
// The column list is further split one column per line and the FROM clause
// one join per line.
//
// If no column list can be extracted (e.g. "SELECT" alone, or "SELECT*FROM t"
// where no whitespace follows the keyword), the statement is returned
// unchanged instead of emitting only the clauses that did match, which would
// silently drop text.
func (f *Formatter) formatSelectStatement(sql string) string {
	parts := f.splitSelectSQL(sql)
	if parts["SELECT"] == "" {
		return sql
	}

	indent := f.getIndent(1)
	var result strings.Builder
	for _, clause := range selectClauseOrder {
		body := parts[clause]
		if body == "" {
			continue
		}
		switch clause {
		case "SELECT":
			body = f.formatSelectColumns(body)
		case "FROM":
			body = f.formatFromClause(body)
		}
		if result.Len() > 0 {
			result.WriteString("\n")
		}
		result.WriteString(f.keyword(clause) + "\n" + indent + body)
	}
	return result.String()
}
// selectClausePatterns extracts the body of each SELECT clause (capture
// group 1). A body ends at the next clause keyword or at the end of the
// statement. The \b after each terminating keyword keeps identifiers that
// merely start with a keyword (e.g. "fromage", "limits", "havingcol") from
// ending a clause early and being lost. Compiled once at package level.
var selectClausePatterns = []struct {
	name string
	re   *regexp.Regexp
}{
	{"SELECT", regexp.MustCompile(`(?i)\bSELECT\s+(.*?)(?:\s+(?:FROM|WHERE|GROUP\s+BY|HAVING|ORDER\s+BY|LIMIT)\b|\s*$)`)},
	{"FROM", regexp.MustCompile(`(?i)\bFROM\s+(.*?)(?:\s+(?:WHERE|GROUP\s+BY|HAVING|ORDER\s+BY|LIMIT)\b|\s*$)`)},
	{"WHERE", regexp.MustCompile(`(?i)\bWHERE\s+(.*?)(?:\s+(?:GROUP\s+BY|HAVING|ORDER\s+BY|LIMIT)\b|\s*$)`)},
	{"GROUP BY", regexp.MustCompile(`(?i)\bGROUP\s+BY\s+(.*?)(?:\s+(?:HAVING|ORDER\s+BY|LIMIT)\b|\s*$)`)},
	{"HAVING", regexp.MustCompile(`(?i)\bHAVING\s+(.*?)(?:\s+(?:ORDER\s+BY|LIMIT)\b|\s*$)`)},
	{"ORDER BY", regexp.MustCompile(`(?i)\bORDER\s+BY\s+(.*?)(?:\s+LIMIT\b|\s*$)`)},
	{"LIMIT", regexp.MustCompile(`(?i)\bLIMIT\s+(.*?)(?:\s*$)`)},
}

// splitSelectSQL splits a single-line SELECT statement into its clauses.
// The result maps a clause name ("SELECT", "FROM", "WHERE", "GROUP BY",
// "HAVING", "ORDER BY", "LIMIT") to its trimmed body; clauses that are not
// found are absent. Each clause is located by its first occurrence, so
// nested subqueries are not understood.
func (f *Formatter) splitSelectSQL(sql string) map[string]string {
	parts := make(map[string]string, len(selectClausePatterns))
	for _, p := range selectClausePatterns {
		if m := p.re.FindStringSubmatch(sql); len(m) > 1 {
			parts[p.name] = strings.TrimSpace(m[1])
		}
	}
	return parts
}
// formatSelectColumns puts each item of a SELECT column list on its own line.
// Items are split on top-level commas (commas inside parentheses are kept),
// trimmed, and joined with ",\n" plus one indentation level. An input that
// yields no items is returned as is.
func (f *Formatter) formatSelectColumns(selectPart string) string {
	cols := f.splitColumns(selectPart)
	if len(cols) == 0 {
		return selectPart
	}
	trimmed := make([]string, 0, len(cols))
	for _, c := range cols {
		trimmed = append(trimmed, strings.TrimSpace(c))
	}
	return strings.Join(trimmed, ",\n"+f.getIndent(1))
}
// joinKeywordRe finds join operators in a FROM clause. Quoted sections
// ('...', "...", `...`) are matched as whole tokens so the caller can skip
// them; join keywords include the optional NATURAL / INNER / CROSS /
// LEFT|RIGHT|FULL [OUTER] qualifiers, so "LEFT OUTER JOIN" is one keyword.
var joinKeywordRe = regexp.MustCompile(`(?i)'[^']*(?:'|$)|"[^"]*(?:"|$)|` + "`[^`]*(?:`|$)" +
	`|\b(?:(?:NATURAL\s+)?(?:INNER\s+|CROSS\s+|(?:LEFT|RIGHT|FULL)\s+(?:OUTER\s+)?)?JOIN)\b`)

// formatFromClause formats the body of a FROM clause (the text after FROM).
// The leading table reference stays on the first line. Every join starts a
// new line indented one level, with its join keyword re-cased via f.keyword
// followed by the rest of that join (table, ON/USING condition) trimmed.
func (f *Formatter) formatFromClause(fromPart string) string {
	var joins [][]int
	for _, loc := range joinKeywordRe.FindAllStringIndex(fromPart, -1) {
		switch fromPart[loc[0]] {
		case '\'', '"', '`':
			continue // quoted literal or identifier, not a join keyword
		}
		joins = append(joins, loc)
	}
	if len(joins) == 0 {
		return strings.TrimSpace(fromPart)
	}

	indent := f.getIndent(1)
	var result strings.Builder
	result.WriteString(strings.TrimSpace(fromPart[:joins[0][0]])) // leading table
	for i, loc := range joins {
		end := len(fromPart)
		if i+1 < len(joins) {
			end = joins[i+1][0]
		}
		result.WriteString("\n" + indent)
		result.WriteString(f.keyword(fromPart[loc[0]:loc[1]]) + " " + strings.TrimSpace(fromPart[loc[1]:end]))
	}
	return result.String()
}
// splitColumns splits a comma-separated list on its top-level commas.
// Commas nested inside parentheses (function calls, subqueries) or inside
// quoted sections ('...', "...", `...`) do not split; parentheses inside
// quotes do not affect nesting. Items are returned untrimmed; an empty
// input yields no items. A doubled quote such as 'it''s' closes and
// immediately reopens the literal, so it is handled naturally.
func (f *Formatter) splitColumns(columnsStr string) []string {
	var (
		columns []string
		current strings.Builder
		depth   int
		quote   rune // opening quote of the literal being scanned, or 0
	)
	for _, r := range columnsStr {
		switch {
		case quote != 0:
			if r == quote {
				quote = 0
			}
		case r == '\'' || r == '"' || r == '`':
			quote = r
		case r == '(':
			depth++
		case r == ')':
			depth--
		case r == ',' && depth == 0:
			columns = append(columns, current.String())
			current.Reset()
			continue
		}
		current.WriteRune(r)
	}
	if current.Len() > 0 {
		columns = append(columns, current.String())
	}
	return columns
}
// insertValuesRe matches a complete "INSERT INTO table (cols) VALUES (...)"
// statement. Group 1 is the table, group 2 the column list, group 3 all text
// between the parentheses that follow VALUES (greedy, so nested calls such as
// NOW() and multi-row lists stay intact), group 4 an optional trailing ';'.
// Anchoring at both ends guarantees no part of the statement is dropped.
var insertValuesRe = regexp.MustCompile(`(?i)^\s*INSERT\s+INTO\s+(\S+)\s*\(([^)]+)\)\s+VALUES\s*\((.*)\)\s*(;?)\s*$`)

// formatInsertStatement lays out an INSERT ... VALUES statement as
//
//	INSERT INTO table
//	  (col1, col2)
//	VALUES
//	  (val1, val2)
//
// Any other INSERT shape (INSERT ... SELECT, ON CONFLICT / ON DUPLICATE KEY /
// RETURNING tails, no column list) falls back to formatKeywords, which only
// re-cases keywords, rather than risk dropping clauses.
func (f *Formatter) formatInsertStatement(sql string) string {
	m := insertValuesRe.FindStringSubmatch(sql)
	if m == nil {
		return f.formatKeywords(sql)
	}
	tableName := strings.TrimSpace(m[1])
	columns := strings.TrimSpace(m[2])
	values := strings.TrimSpace(m[3])
	indent := f.getIndent(1)

	var result strings.Builder
	result.WriteString(f.keyword("INSERT INTO") + " " + tableName)
	result.WriteString("\n" + indent + "(" + f.formatColumnList(columns) + ")")
	result.WriteString("\n" + f.keyword("VALUES"))
	result.WriteString("\n" + indent + "(" + f.formatValueList(values) + ")")
	result.WriteString(m[4]) // trailing ';', if the input had one
	return result.String()
}
// formatUpdateStatement lays out an UPDATE statement: the target stays on the
// UPDATE line; the SET and WHERE keywords each start a new line followed by
// their body indented one level (SET assignments one per line).
//
// When no target can be extracted (e.g. "UPDATE" alone, or a target that is
// not separated from the keyword by whitespace), the statement is returned
// unchanged rather than emitting only SET/WHERE and dropping the rest.
func (f *Formatter) formatUpdateStatement(sql string) string {
	parts := f.splitUpdateSQL(sql)
	updatePart := parts["UPDATE"]
	if updatePart == "" {
		return sql
	}

	indent := f.getIndent(1)
	var result strings.Builder
	result.WriteString(f.keyword("UPDATE") + " " + updatePart)
	if setPart := parts["SET"]; setPart != "" {
		result.WriteString("\n" + f.keyword("SET"))
		result.WriteString("\n" + indent + f.formatSetClause(setPart))
	}
	if wherePart := parts["WHERE"]; wherePart != "" {
		result.WriteString("\n" + f.keyword("WHERE"))
		result.WriteString("\n" + indent + wherePart)
	}
	return result.String()
}
// formatDeleteStatement lays out a DELETE statement as
//
//	DELETE FROM target
//	WHERE
//	  condition
//
// with the WHERE lines present only when there is a WHERE clause.
// Statements without a "DELETE FROM" head (e.g. MySQL's multi-table
// "DELETE t1 FROM t1 JOIN t2 ... WHERE ...") are returned unchanged, because
// emitting only the WHERE clause would drop everything before it.
func (f *Formatter) formatDeleteStatement(sql string) string {
	parts := f.splitDeleteSQL(sql)
	fromPart := parts["FROM"]
	if fromPart == "" {
		return sql
	}

	var result strings.Builder
	result.WriteString(f.keyword("DELETE FROM") + " " + fromPart)
	if wherePart := parts["WHERE"]; wherePart != "" {
		result.WriteString("\n" + f.keyword("WHERE"))
		result.WriteString("\n" + f.getIndent(1) + wherePart)
	}
	return result.String()
}
// formatColumnList normalizes the inside of a parenthesized column list:
// items are split on top-level commas, trimmed, and rejoined with ", ".
// A list with at most one item is returned exactly as given.
func (f *Formatter) formatColumnList(columns string) string {
	items := f.splitColumns(columns)
	if len(items) < 2 {
		return columns
	}
	for i := range items {
		items[i] = strings.TrimSpace(items[i])
	}
	return strings.Join(items, ", ")
}
// formatValueList normalizes the inside of a VALUES tuple the same way
// column lists are handled: split on top-level commas (using the
// paren-aware splitColumns), trim each value, and rejoin with ", ".
// Zero or one value returns the input untouched.
func (f *Formatter) formatValueList(values string) string {
	pieces := f.splitColumns(values)
	if len(pieces) <= 1 {
		return values
	}
	out := make([]string, len(pieces))
	for i, p := range pieces {
		out[i] = strings.TrimSpace(p)
	}
	return strings.Join(out, ", ")
}
// formatSetClause puts each assignment of a SET clause on its own line:
// assignments are split on top-level commas, trimmed, and joined with ",\n"
// plus one indentation level. A single assignment is returned unchanged.
func (f *Formatter) formatSetClause(setPart string) string {
	assignments := f.splitColumns(setPart)
	if len(assignments) < 2 {
		return setPart
	}
	lines := make([]string, len(assignments))
	for i := range assignments {
		lines[i] = strings.TrimSpace(assignments[i])
	}
	return strings.Join(lines, ",\n"+f.getIndent(1))
}
// updateClausePatterns extracts the body of each UPDATE clause (capture
// group 1). The \b after each terminating keyword keeps identifiers such as
// "settings" or "whereabouts" from ending a clause early and being dropped.
// Compiled once at package level.
var updateClausePatterns = []struct {
	name string
	re   *regexp.Regexp
}{
	{"UPDATE", regexp.MustCompile(`(?i)\bUPDATE\s+(.*?)(?:\s+SET\b|\s*$)`)},
	{"SET", regexp.MustCompile(`(?i)\bSET\s+(.*?)(?:\s+WHERE\b|\s*$)`)},
	{"WHERE", regexp.MustCompile(`(?i)\bWHERE\s+(.*?)(?:\s*$)`)},
}

// splitUpdateSQL splits a single-line UPDATE statement into its "UPDATE"
// (target), "SET" and "WHERE" bodies, trimmed. Clauses that are not found
// are absent from the map. Each clause is located by its first occurrence.
func (f *Formatter) splitUpdateSQL(sql string) map[string]string {
	parts := make(map[string]string, len(updateClausePatterns))
	for _, p := range updateClausePatterns {
		if m := p.re.FindStringSubmatch(sql); len(m) > 1 {
			parts[p.name] = strings.TrimSpace(m[1])
		}
	}
	return parts
}
// deleteClausePatterns extracts the body of each DELETE clause (capture
// group 1). The \b after WHERE keeps identifiers such as "wherever" from
// ending the FROM body early and being dropped. Compiled once at package
// level.
var deleteClausePatterns = []struct {
	name string
	re   *regexp.Regexp
}{
	{"FROM", regexp.MustCompile(`(?i)\bDELETE\s+FROM\s+(.*?)(?:\s+WHERE\b|\s*$)`)},
	{"WHERE", regexp.MustCompile(`(?i)\bWHERE\s+(.*?)(?:\s*$)`)},
}

// splitDeleteSQL splits a single-line DELETE statement into its "FROM"
// (the text after DELETE FROM) and "WHERE" bodies, trimmed. Clauses that
// are not found are absent from the map.
func (f *Formatter) splitDeleteSQL(sql string) map[string]string {
	parts := make(map[string]string, len(deleteClausePatterns))
	for _, p := range deleteClausePatterns {
		if m := p.re.FindStringSubmatch(sql); len(m) > 1 {
			parts[p.name] = strings.TrimSpace(m[1])
		}
	}
	return parts
}
// fallbackKeywordRe finds the keywords formatKeywords re-cases. Quoted
// sections ('...', "...", `...`) are matched as whole tokens first so that
// keyword-like text inside literals or quoted identifiers is left alone.
var fallbackKeywordRe = regexp.MustCompile(`(?i)'[^']*(?:'|$)|"[^"]*(?:"|$)|` + "`[^`]*(?:`|$)" +
	`|\b(?:INSERT INTO|VALUES|UPDATE|SET|DELETE FROM|WHERE)\b`)

// formatKeywords is the fallback used when a statement does not fit a
// recognized layout: it only re-cases the keywords INSERT INTO, VALUES,
// UPDATE, SET, DELETE FROM and WHERE (per f.KeywordUpper) and leaves all
// other text, including quoted sections, byte-for-byte unchanged.
func (f *Formatter) formatKeywords(sql string) string {
	return fallbackKeywordRe.ReplaceAllStringFunc(sql, func(tok string) string {
		switch tok[0] {
		case '\'', '"', '`':
			return tok // quoted literal or identifier: never rewrite
		}
		return f.keyword(strings.ToUpper(tok))
	})
}
// keyword returns word re-cased according to f.KeywordUpper: all upper case
// when it is true, all lower case otherwise.
func (f *Formatter) keyword(word string) string {
	toCase := strings.ToLower
	if f.KeywordUpper {
		toCase = strings.ToUpper
	}
	return toCase(word)
}
// getIndent returns the indentation for the given nesting level:
// level*f.IndentSize spaces. A non-positive total (e.g. a negative
// IndentSize set on the exported field) yields "" instead of letting
// strings.Repeat panic inside the formatter.
func (f *Formatter) getIndent(level int) string {
	n := level * f.IndentSize
	if n <= 0 {
		return ""
	}
	return strings.Repeat(" ", n)
}