Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 45 additions & 24 deletions pkg/tools/builtin/sql_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@ import (
// Error is non-empty when the query failed (DB error, DML rejection, or
// validation failure before execution).
//
// Columns / Rows / Truncated are populated on successful execution so adopters
// can flow result sets out of the sub-agent (e.g. into a side-by-side data
// grid) without re-running the query against their own DB handle. On failure
// these fields are zero-valued; partial rows from a mid-iteration error are
// still surfaced, mirroring the SQLResult returned to the model.
// Columns / Rows / RowCount / ExecutionMs / Truncated are populated on
// successful execution so adopters can flow result sets out of the sub-agent
// (e.g. into a side-by-side data grid, or a "INSERT executed · N rows" banner)
// without re-running the query against their own DB handle. RowCount carries
// rows returned for reads and rows affected for DML/DDL, mirroring
// SQLResult.RowCount. On early validation failure these fields are zero-valued;
// partial rows from a mid-iteration error are still surfaced, mirroring the
// SQLResult returned to the model.
type SQLQueryEvent struct {
SessionKey string
Query string
Error string // empty on success
Columns []string // populated on success; nil on early validation failure
Rows []map[string]any // populated on success; nil on early validation failure
Truncated bool // true when the MaxRows safety cap clipped output
SessionKey string
Query string
Error string // empty on success
Columns []string // populated on success; nil on early validation failure
Rows []map[string]any // populated on success; nil on early validation failure
RowCount int // rows returned (read) or affected (DML/DDL); mirrors SQLResult.RowCount
ExecutionMs int64 // wall-clock execution time; zero on early validation failure
Truncated bool // true when the MaxRows safety cap clipped output
}

// SQLResult is the structured envelope returned from execute_sql back to the
Expand Down Expand Up @@ -299,7 +304,7 @@ func (t *CallSQLAgentTool) WithAllowSelectStar(allow bool) *CallSQLAgentTool {
}

const callSQLAgentDefaultName = "call_sql_agent"
const callSQLAgentDescription = "Translate natural language business questions into SQL and query the Database directly. It automatically determines tables, runs queries, and returns structured data."
const callSQLAgentDescription = "Answer a natural-language question about the database. A SQL sub-agent determines the tables, runs the query, and returns a concise written answer summarizing the results — not raw table rows. Use it to look things up and reason about the data, not to dump or export full result sets verbatim."

type sqlAgentArgs struct {
Query string `json:"query" description:"The natural language query or task for the SQL database Sub-Agent to perform."`
Expand Down Expand Up @@ -342,17 +347,31 @@ func (t *CallSQLAgentTool) Execute(ctx context.Context, argsJSON string) (tools.
}

if t.selfConsistency > 1 {
out, err := t.runCandidates(ctx, args.Query, t.selfConsistency)
out, structured, err := t.runCandidates(ctx, args.Query, t.selfConsistency)
if err != nil {
return tools.Result{}, err
}
return tools.Text(out), nil
return sqlAgentResult(out, structured), nil
}
cand := t.runOnce(ctx, args.Query, 0)
if cand.finalResp == "" {
return tools.Text("Process finished but no verbal response was given. Check logs."), nil
}
return tools.Text(cand.finalResp), nil
return sqlAgentResult(cand.finalResp, cand.lastResult), nil
}

// sqlAgentResult packages the sub-agent's natural-language answer as the
// model-visible Text and attaches the last successful SQLResult on
// Structured for host integrations (parity with the SQLQueryEvent hook).
// The model only ever reads Text; Structured reaches hosts via OnToolResult.
// A nil result is left off Structured rather than stored as a typed-nil
// *SQLResult, so adopters' `if res.Structured != nil` checks stay honest.
func sqlAgentResult(text string, structured *SQLResult) tools.Result {
res := tools.Text(text)
if structured != nil {
res.Structured = structured
}
return res
}

// runOnce executes a single sub-agent run and returns its candidate record.
Expand Down Expand Up @@ -444,7 +463,7 @@ func hitlBlockedReport(hitlEvent agent.StreamEvent, innerSummary string) string

// runCandidates fans out n parallel sub-agent runs, clusters their execution
// results, and returns the answer from the winning cluster.
func (t *CallSQLAgentTool) runCandidates(ctx context.Context, query string, n int) (string, error) {
func (t *CallSQLAgentTool) runCandidates(ctx context.Context, query string, n int) (string, *SQLResult, error) {
cands := make([]sqlCandidate, n)
var wg sync.WaitGroup
for i := range n {
Expand All @@ -458,9 +477,9 @@ func (t *CallSQLAgentTool) runCandidates(ctx context.Context, query string, n in

winner := pickByMajority(cands)
if winner == nil || winner.finalResp == "" {
return "Process finished but no verbal response was given. Check logs.", nil
return "Process finished but no verbal response was given. Check logs.", nil, nil
}
return winner.finalResp, nil
return winner.finalResp, winner.lastResult, nil
}

// roleLine returns the first sentence of the sub-agent system prompt,
Expand Down Expand Up @@ -685,12 +704,14 @@ func (t *executeSQLTool) makeEmitFunc(ctx context.Context) func(SQLResult) (tool
return func(res SQLResult) (tools.Result, error) {
if t.onSQL != nil {
t.onSQL(ctx, SQLQueryEvent{
SessionKey: t.sessionKey,
Query: res.SQL,
Error: res.Error,
Columns: res.Columns,
Rows: res.Rows,
Truncated: res.Truncated,
SessionKey: t.sessionKey,
Query: res.SQL,
Error: res.Error,
Columns: res.Columns,
Rows: res.Rows,
RowCount: res.RowCount,
ExecutionMs: res.ExecutionMs,
Truncated: res.Truncated,
})
}
b, err := json.Marshal(res)
Expand Down
29 changes: 29 additions & 0 deletions pkg/tools/builtin/sql_agent_builders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,35 @@ func TestExecuteSQLTool_RejectsMutationsWhenDisabled(t *testing.T) {
}
}

func TestExecuteSQLTool_EmitSurfacesRowCountAndExecutionMs(t *testing.T) {
var got SQLQueryEvent
exec := &executeSQLTool{
sessionKey: "sess_1",
onSQL: func(_ context.Context, ev SQLQueryEvent) { got = ev },
}
emit := exec.makeEmitFunc(context.Background())

res := SQLResult{
SQL: "UPDATE customers SET active = false WHERE id = 1",
RowCount: 3,
ExecutionMs: 42,
}
out, err := emit(res)
if err != nil {
t.Fatalf("emit returned error: %v", err)
}
if got.RowCount != 3 {
t.Fatalf("SQLQueryEvent.RowCount: got %d, want 3", got.RowCount)
}
if got.ExecutionMs != 42 {
t.Fatalf("SQLQueryEvent.ExecutionMs: got %d, want 42", got.ExecutionMs)
}
// The model-facing payload must keep carrying the same numbers.
if !strings.Contains(out.Text, `"row_count":3`) || !strings.Contains(out.Text, `"execution_ms":42`) {
t.Fatalf("marshalled result should retain row_count/execution_ms, got %s", out.Text)
}
}

func TestExecuteSQLTool_DescriptionReflectsMutationFlag(t *testing.T) {
off := (&executeSQLTool{allowMutations: false}).Descriptor().Description
on := (&executeSQLTool{allowMutations: true}).Descriptor().Description
Expand Down
50 changes: 50 additions & 0 deletions pkg/tools/builtin/sql_agent_result_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package builtin

import (
"strings"
"testing"
)

// TestSQLAgentResult_AttachesStructuredWhenPresent verifies the sub-agent's
// natural-language answer is the model-visible Text and the captured
// SQLResult rides on Structured for host integrations.
func TestSQLAgentResult_AttachesStructuredWhenPresent(t *testing.T) {
sr := &SQLResult{Columns: []string{"id"}, Rows: []map[string]any{{"id": 1}}, RowCount: 1}
res := sqlAgentResult("found 1 row", sr)

if res.Text != "found 1 row" {
t.Fatalf("Text: got %q, want %q", res.Text, "found 1 row")
}
got, ok := res.Structured.(*SQLResult)
if !ok {
t.Fatalf("Structured: got %T, want *SQLResult", res.Structured)
}
if got != sr {
t.Fatalf("Structured: got %p, want %p", got, sr)
}
}

// TestSQLAgentResult_NilResultLeavesStructuredNil guards the typed-nil trap:
// a nil *SQLResult must NOT be boxed into the `any` field (which would make
// `res.Structured != nil` true for adopters and break their guards).
func TestSQLAgentResult_NilResultLeavesStructuredNil(t *testing.T) {
res := sqlAgentResult("no query succeeded", nil)

if res.Structured != nil {
t.Fatalf("Structured: got %#v (type %T), want untyped nil", res.Structured, res.Structured)
}
}

// TestCallSQLAgentTool_DescriptionDoesNotClaimRawRows locks in the honest
// contract: the tool returns a written answer, not raw/exportable rows. A
// misleading "returns structured data" claim is what drove a parent agent to
// loop trying to coax verbatim rows out of it for a CSV export.
func TestCallSQLAgentTool_DescriptionDoesNotClaimRawRows(t *testing.T) {
desc := NewCallSQLAgentTool(nil, "", nil, nil).Descriptor().Description
if strings.Contains(desc, "returns structured data") {
t.Fatalf("description still makes the misleading structured-data claim: %q", desc)
}
if !strings.Contains(desc, "not raw table rows") {
t.Fatalf("description should state it does not return raw rows: %q", desc)
}
}
Loading