diff --git a/.gitignore b/.gitignore
index b74d3e5..8e4c246 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 dataset/sqlite_tables.db
 *__pycache__
 .env
+venv
 dist/
 *.egg-info/
diff --git a/llmsql/_cli/evaluate.py b/llmsql/_cli/evaluate.py
index e69de29..942e4fb 100644
--- a/llmsql/_cli/evaluate.py
+++ b/llmsql/_cli/evaluate.py
@@ -0,0 +1,102 @@
+import argparse
+from typing import Any
+
+from llmsql._cli.subparsers import SubCommand
+from llmsql.config.config import DEFAULT_LLMSQL_VERSION, DEFAULT_WORKDIR_PATH
+from llmsql.evaluation.evaluate import evaluate
+
+
+class Evaluate(SubCommand):
+    """Command for LLM evaluation"""
+
+    def __init__(
+        self,
+        subparsers: argparse._SubParsersAction,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        self._parser = subparsers.add_parser(
+            "evaluate",
+            help="Evaluate predictions against the LLMSQL benchmark",
+            formatter_class=argparse.RawDescriptionHelpFormatter,
+        )
+
+        self._add_args()
+        self._parser.set_defaults(func=self._execute)
+
+
+    def _add_args(self) -> None:
+        """Add evaluation-specific arguments to the parser."""
+        self._parser.add_argument(
+            "--outputs",
+            type=str,
+            required=True,
+            help="Path to the .json file containing the model's generated queries.",
+        )
+
+        self._parser.add_argument(
+            "--version",
+            type=str,
+            default=DEFAULT_LLMSQL_VERSION,
+            choices=["1.0", "2.0"],
+            help=f"LLMSQL benchmark version (default:{DEFAULT_LLMSQL_VERSION})",
+        )
+
+        self._parser.add_argument(
+            "--workdir-path",
+            default=DEFAULT_WORKDIR_PATH,
+            help=f"Directory for benchmark files (default: {DEFAULT_WORKDIR_PATH})",
+        )
+
+        self._parser.add_argument(
+            "--questions-path",
+            type=str,
+            default=None,
+            help="Manual path to benchmark questions JSON file.",
+        )
+
+        self._parser.add_argument(
+            "--db-path",
+            type=str,
+            default=None,
+            help="Path to SQLite benchmark database.",
+        )
+
+        self._parser.add_argument(
+            "--show-mismatches",
+            action="store_true",
+            default=False,
+            help="Print SQL mismatches during evaluation",
+        )
+
+        self._parser.add_argument(
+            "--max-mismatches",
+            type=int,
+            default=5,
+            help="Maximum mismatches to display",
+        )
+
+        self._parser.add_argument(
+            "--save-report",
+            type=str,
+            default=None,
+            help="Path to save evaluation report JSON.",
+        )
+
+    @staticmethod
+    def _execute(args: argparse.Namespace) -> None:
+        """Execute the evaluate function with parsed arguments."""
+        try:
+            evaluate(
+                outputs=args.outputs,
+                version=args.version,
+                workdir_path=args.workdir_path,
+                questions_path=args.questions_path,
+                db_path=args.db_path,
+                save_report=args.save_report,
+                show_mismatches=args.show_mismatches,
+                max_mismatches=args.max_mismatches,
+            )
+        except Exception as e:
+            raise SystemExit(f"Error during evaluation: {e}")
+
diff --git a/llmsql/_cli/llmsql_cli.py b/llmsql/_cli/llmsql_cli.py
index 156caee..8e2409c 100644
--- a/llmsql/_cli/llmsql_cli.py
+++ b/llmsql/_cli/llmsql_cli.py
@@ -2,7 +2,8 @@
 import textwrap
 
+from llmsql._cli.evaluate import Evaluate
 from llmsql._cli.inference import Inference
 
 
 class ParserCLI:
     """Main CLI parser that manages all subcommands."""
@@ -52,6 +52,7 @@ def __init__(self) -> None:
         )
 
         Inference(self._subparsers)
+        Evaluate(self._subparsers)
 
     def parse_args(self) -> argparse.Namespace:
         """Parse CLI arguments."""
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index 4f2d543..0415a44 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -138,3 +138,40 @@ async def test_help_shows_without_crashing(monkeypatch, capsys):
 
     captured = capsys.readouterr()
     assert "usage:" in captured.err.lower() or "usage:" in captured.out.lower()
+
+
+def test_evaluate_command_called(monkeypatch):
+    """Ensure the evaluate command is correctly invoked with arguments."""
+    calls = []
+
+    def mock_evaluate(**kwargs):
+        # Plain recorder: evaluate() is synchronous, so AsyncMock would
+        # leave an un-awaited coroutine behind; this needs no imports.
+        calls.append(kwargs)
+        return {}
+
+    monkeypatch.setattr(
+        "llmsql._cli.evaluate.evaluate",
+        mock_evaluate,
+    )
+
+    test_args = [
+        "llmsql",
+        "evaluate",
+        "--outputs",
+        "dummy_file.jsonl",
+        "--show-mismatches",
+        "--max-mismatches",
+        "10",
+    ]
+    monkeypatch.setattr(sys, "argv", test_args)
+
+    cli = ParserCLI()
+    args = cli.parse_args()
+    cli.execute(args)
+
+    assert len(calls) == 1
+    call_kwargs = calls[0]
+    assert call_kwargs["outputs"] == "dummy_file.jsonl"
+    assert call_kwargs["show_mismatches"] is True
+    assert call_kwargs["max_mismatches"] == 10