Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: tests

on:
push:
pull_request:

jobs:
pytest:
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest

- name: Run tests
run: |
pytest -q
4 changes: 4 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
testpaths = tests
pythonpath = .
addopts = -q
37 changes: 37 additions & 0 deletions sample_data/eval_cases.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
[
{
"id": "E-001",
"title": "Uncapped liability + consequential",
"text": "The Vendor shall be liable for all damages including consequential, indirect, special damages. Liability is uncapped.",
"expected_risk_level": "High",
"expected_categories": ["Liability"]
},
{
"id": "E-002",
"title": "Termination for convenience no cure",
"text": "Either party may terminate for convenience at any time upon written notice. No cure period is provided.",
"expected_risk_level": "Medium",
"expected_categories": ["Termination"]
},
{
"id": "E-003",
"title": "Neutral baseline",
"text": "This agreement describes services and payment terms. Standard confidentiality applies. No special liability clauses are stated.",
"expected_risk_level": "Low",
"expected_categories": []
},
{
"id": "E-004",
"title": "Both liability + termination",
"text": "Either party may terminate for convenience at any time upon notice. Vendor is liable for all damages including consequential. Liability is uncapped.",
"expected_risk_level": "High",
"expected_categories": ["Liability", "Termination"]
},
{
"id": "E-005",
"title": "Cure period present (should reduce termination signal)",
"text": "Either party may terminate upon material breach if not cured within 30 days after notice. Termination for convenience is not permitted.",
"expected_risk_level": "Low",
"expected_categories": []
}
]
163 changes: 141 additions & 22 deletions src/evaluation_stub.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,152 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple

from src.model_adapter import analyze_contract


@dataclass
class EvalCase:
id: str
title: str
text: str
expected_risk_level: str
expected_categories: List[str]


def _normalize_level(level: str) -> str:
return str(level).strip().title()


def _predicted_categories(result_dict: Dict[str, Any]) -> Set[str]:
findings = result_dict.get("findings", [])
cats = {f.get("category", "").strip() for f in findings if f.get("category")}
return {c for c in cats if c}


def run_evaluation(cases: List[EvalCase]) -> Dict[str, Any]:
"""
Runs evaluation cases through the adapter, compares predicted vs expected.

Returns a dict:
- risk_level_accuracy
- category_precision/recall/f1 (micro)
- mismatches[]
"""
if not cases:
return {
"risk_level_accuracy": 0.0,
"category_precision": 0.0,
"category_recall": 0.0,
"category_f1": 0.0,
"mismatches": [],
"n": 0,
}

level_hits = 0
tp = fp = fn = 0
mismatches: List[Dict[str, Any]] = []

for c in cases:
result = analyze_contract(c.text, title=c.title, source_type="paste").model_dump()

from .model_adapter import analyze_contract
pred_level = _normalize_level(result["summary"]["risk_level"])
exp_level = _normalize_level(c.expected_risk_level)

if pred_level == exp_level:
level_hits += 1

SAMPLES_DIR = Path("sample_data")
pred_cats = _predicted_categories(result)
exp_cats = set(c.expected_categories)

# micro counts
tp += len(pred_cats & exp_cats)
fp += len(pred_cats - exp_cats)
fn += len(exp_cats - pred_cats)

if pred_level != exp_level or pred_cats != exp_cats:
mismatches.append(
{
"id": c.id,
"title": c.title,
"pred_risk_level": pred_level,
"exp_risk_level": exp_level,
"pred_categories": sorted(list(pred_cats)),
"exp_categories": sorted(list(exp_cats)),
"run_id": result.get("run_id"),
}
)

accuracy = level_hits / len(cases)

precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0

return {
"risk_level_accuracy": round(accuracy, 3),
"category_precision": round(precision, 3),
"category_recall": round(recall, 3),
"category_f1": round(f1, 3),
"mismatches": mismatches,
"n": len(cases),
}


def load_cases(path: str | Path) -> List[EvalCase]:
p = Path(path)
payload = json.loads(p.read_text(encoding="utf-8"))

cases: List[EvalCase] = []
for row in payload:
cases.append(
EvalCase(
id=row["id"],
title=row.get("title", row["id"]),
text=row["text"],
expected_risk_level=row["expected_risk_level"],
expected_categories=row.get("expected_categories", []),
)
)
return cases


def main() -> None:
samples = sorted(SAMPLES_DIR.glob("*.txt"))
if not samples:
print("No sample contracts found in sample_data/. Add .txt files to run evaluation.")
return

scores = []
for fp in samples:
contract_text = fp.read_text(encoding="utf-8", errors="ignore")
result = analyze_contract(contract_text=contract_text, title=fp.stem, source_type="sample_data")

score = result.summary.overall_risk_score
level = result.summary.risk_level
findings_count = len(result.findings)

scores.append(score)
print(f"{fp.name}: score={score}, level={level}, findings={findings_count}")

if scores:
print("\nSummary")
print(f"count={len(scores)} min={min(scores)} max={max(scores)} avg={sum(scores)/len(scores):.2f}")
import argparse

parser = argparse.ArgumentParser(description="Offline evaluation harness (stub).")
parser.add_argument(
"path",
nargs="?",
default="sample_data/eval_cases.json",
help="Path to evaluation cases JSON (default: sample_data/eval_cases.json)",
)
args = parser.parse_args()

cases = load_cases(args.path)
metrics = run_evaluation(cases)

print("\n=== Evaluation Summary ===")
print(f"Cases: {metrics['n']}")
print(f"Risk-level accuracy: {metrics['risk_level_accuracy']}")
print(
"Category micro P/R/F1: "
f"{metrics['category_precision']} / {metrics['category_recall']} / {metrics['category_f1']}"
)

if metrics["mismatches"]:
print("\n=== Mismatches ===")
for m in metrics["mismatches"]:
print(
f"- {m['id']} | {m['title']} | "
f"level {m['pred_risk_level']} vs {m['exp_risk_level']} | "
f"cats {m['pred_categories']} vs {m['exp_categories']}"
)
else:
print("\nNo mismatches 🎉")


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions tests/test_evaluation_harness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from src.evaluation_stub import load_cases, run_evaluation

def test_eval_harness_loads_and_runs():
cases = load_cases("sample_data/eval_cases.json")
metrics = run_evaluation(cases)
assert "risk_level_accuracy" in metrics
assert metrics["n"] > 0
13 changes: 13 additions & 0 deletions tests/test_evidence_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from src.model_adapter import analyze_contract

def test_all_findings_have_evidence():
text = "Vendor is liable for all damages including consequential. Liability is uncapped."
result = analyze_contract(text, title="Test", source_type="paste").model_dump()
findings = result.get("findings", [])
assert len(findings) > 0

for f in findings:
ev = f.get("evidence", [])
assert isinstance(ev, list)
assert len(ev) >= 1
assert ev[0].get("snippet")
32 changes: 32 additions & 0 deletions tests/test_scoring_monotonic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from src.schemas import Finding, Evidence
from src.scoring import compute_score

def test_high_severity_scores_higher_than_low():
low = [
Finding(
finding_id="T-LOW",
category="General",
risk_statement="Low risk",
severity="Low",
confidence=1.0,
evidence=[Evidence(clause_ref="X", snippet="...")],
recommendation="N/A",
)
]

high = [
Finding(
finding_id="T-HIGH",
category="Liability",
risk_statement="High risk",
severity="High",
confidence=1.0,
evidence=[Evidence(clause_ref="Y", snippet="...")],
recommendation="N/A",
)
]

s_low, _ = compute_score(low)
s_high, _ = compute_score(high)

assert s_high > s_low