Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
dataset/sqlite_tables.db
*__pycache__
.env
venv
dist/

*.egg-info/
Expand Down
102 changes: 102 additions & 0 deletions llmsql/_cli/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import argparse
import sys
from typing import Any

from llmsql._cli.subparsers import SubCommand
from llmsql.config.config import DEFAULT_LLMSQL_VERSION, DEFAULT_WORKDIR_PATH
from llmsql.evaluation.evaluate import evaluate


class Evaluate(SubCommand):
    """CLI subcommand that evaluates model predictions against the LLMSQL benchmark."""

    def __init__(
        self,
        subparsers: argparse._SubParsersAction,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Register the ``evaluate`` subparser on *subparsers* and wire its arguments.

        Args:
            subparsers: The action returned by ``ArgumentParser.add_subparsers()``.
            *args: Unused; accepted for SubCommand interface compatibility.
            **kwargs: Unused; accepted for SubCommand interface compatibility.
        """
        self._parser = subparsers.add_parser(
            "evaluate",
            help="Evaluate predictions against the LLMSQL benchmark",
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        self._add_args()
        self._parser.set_defaults(func=self._execute)

    def _add_args(self) -> None:
        """Add evaluation-specific arguments to the parser."""
        self._parser.add_argument(
            "--outputs",
            type=str,
            required=True,
            help="Path to the .json file containing the model's generated queries.",
        )

        self._parser.add_argument(
            "--version",
            type=str,
            default=DEFAULT_LLMSQL_VERSION,
            choices=["1.0", "2.0"],
            help=f"LLMSQL benchmark version (default: {DEFAULT_LLMSQL_VERSION})",
        )

        self._parser.add_argument(
            "--workdir-path",
            default=DEFAULT_WORKDIR_PATH,
            help=f"Directory for benchmark files (default: {DEFAULT_WORKDIR_PATH})",
        )

        self._parser.add_argument(
            "--questions-path",
            type=str,
            default=None,
            help="Manual path to benchmark questions JSON file.",
        )

        self._parser.add_argument(
            "--db-path",
            type=str,
            default=None,
            help="Path to SQLite benchmark database.",
        )

        # BUG FIX: the original used action="store_true" together with
        # default=True, which made the flag dead code — the value was always
        # True and could never be turned off.  BooleanOptionalAction keeps
        # --show-mismatches working and the default True, while also
        # generating a --no-show-mismatches switch to disable it.
        self._parser.add_argument(
            "--show-mismatches",
            action=argparse.BooleanOptionalAction,
            default=True,
            help="Print SQL mismatches during evaluation",
        )

        self._parser.add_argument(
            "--max-mismatches",
            type=int,
            default=5,
            help="Maximum mismatches to display",
        )

        self._parser.add_argument(
            "--save-report",
            type=str,
            default=None,
            help="Path to save evaluation report JSON.",
        )

    @staticmethod
    def _execute(args: argparse.Namespace) -> None:
        """Run ``evaluate`` with the parsed CLI arguments.

        Any failure is reported on stderr and the process exits with a
        non-zero status so shells and CI pipelines can detect the error
        (the original printed to stdout and exited 0 on failure).
        """
        try:
            evaluate(
                outputs=args.outputs,
                version=args.version,
                workdir_path=args.workdir_path,
                questions_path=args.questions_path,
                db_path=args.db_path,
                save_report=args.save_report,
                show_mismatches=args.show_mismatches,
                max_mismatches=args.max_mismatches,
            )
        except Exception as e:  # CLI boundary: surface the error, fail loudly
            print(f"Error during evaluation: {e}", file=sys.stderr)
            raise SystemExit(1) from e

3 changes: 2 additions & 1 deletion llmsql/_cli/llmsql_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import textwrap

from llmsql._cli.inference import Inference

from llmsql._cli.evaluate import Evaluate

class ParserCLI:
"""Main CLI parser that manages all subcommands."""
Expand Down Expand Up @@ -52,6 +52,7 @@ def __init__(self) -> None:
)

Inference(self._subparsers)
Evaluate(self._subparsers)

def parse_args(self) -> argparse.Namespace:
"""Parse CLI arguments."""
Expand Down
40 changes: 40 additions & 0 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,43 @@ async def test_help_shows_without_crashing(monkeypatch, capsys):

captured = capsys.readouterr()
assert "usage:" in captured.err.lower() or "usage:" in captured.out.lower()



@pytest.mark.asyncio
async def test_evaluate_command_called(monkeypatch):
    """
    Ensure the evaluate command is correctly invoked with arguments.
    """
    # evaluate() is called synchronously by Evaluate._execute, so a plain
    # MagicMock is the correct test double.  The original used AsyncMock,
    # whose return value is an un-awaited coroutine and raises a
    # "coroutine was never awaited" RuntimeWarning.
    from unittest.mock import MagicMock

    mock_evaluate = MagicMock(return_value={})

    monkeypatch.setattr(
        "llmsql._cli.evaluate.evaluate",
        mock_evaluate,
    )

    test_args = [
        "llmsql",
        "evaluate",
        "--outputs",
        "dummy_file.jsonl",
        "--show-mismatches",
        "--max-mismatches",
        "10",
    ]

    monkeypatch.setattr(sys, "argv", test_args)

    cli = ParserCLI()
    args = cli.parse_args()
    cli.execute(args)

    mock_evaluate.assert_called_once()

    # Verify the parsed CLI values are forwarded to evaluate() unchanged.
    call_kwargs = mock_evaluate.call_args.kwargs
    assert call_kwargs["outputs"] == "dummy_file.jsonl"
    assert call_kwargs["show_mismatches"] is True
    assert call_kwargs["max_mismatches"] == 10