Commit f931963

Merge branch 'prat' of https://github.com/prateekiiest/interwhen into prat
2 parents: 9802544 + a0fbce4

5 files changed: 286,489 additions & 19 deletions

File tree

examples/TTSwithVerification/bestofk_baseline.py (182 additions & 19 deletions)
@@ -1,13 +1,18 @@
 import argparse
 import asyncio
+from datetime import datetime
 import json
 import logging
 import os
 import re
 import sys
+import shutil
+import subprocess
+from multiprocessing.pool import ThreadPool
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import pandas as pd
@@ -16,17 +21,27 @@
 from tqdm import tqdm
 from transformers import AutoTokenizer

+from interwhen.utils.zebralogic_helper import SYSTEM_PROMPT_VANILLA, USER_PROMPT_TEMPLATE, get_zebralogic_dataset, extract_last_json, zebra_correctness
+
 from interwhen import stream_completion
+from verina_utils import *

 # ============== MODEL CONFIGURATION ==============
 MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507"
 # Multi-process vLLM configuration
 VLLM_PORTS = [8000, 8001, 8002]  # 3 instances with tensor-parallel-size 2 each
 REQUEST_COUNTER = {"main": 0, "critic": 0}  # Track request count for round-robin load balancing
-
+# Verina paths
+_SCRIPT_DIR = Path(__file__).parent.resolve()
+VERINA_ROOT = (_SCRIPT_DIR / "../../../verina").resolve()
+VERINA_DATASETS_PATH = VERINA_ROOT / "datasets" / "verina"
+LEAN_PLAYGROUND_DIR = VERINA_ROOT / "lean-playground"

 logger = logging.getLogger(__name__)

+# Save the real stderr so tqdm always works even if suppress_output is active
+_real_stderr = sys.stderr
+

 @contextmanager
 def suppress_output():
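A reading note on the configuration above: REQUEST_COUNTER presumably feeds a round-robin over VLLM_PORTS when requests are dispatched. A minimal sketch of that idea, where pick_port is a hypothetical helper name; the actual dispatch code lives elsewhere in the file and is not part of this diff:

def pick_port(role: str) -> int:
    # Rotate "main"/"critic" requests across the three vLLM instances.
    port = VLLM_PORTS[REQUEST_COUNTER[role] % len(VLLM_PORTS)]
    REQUEST_COUNTER[role] += 1
    return port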
@@ -306,13 +321,56 @@ def evaluate_mcq_answer(answer, options, ground_truth):
         return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})"
     return False, sol, f"Solution '{sol}' not found in options or ground truth"

+# --------------------- ZebraLogic helpers ---------------------
+
+def evaluate_zebralogic_answer(answer, example):
+    """Evaluate a zebralogic answer against ground truth using zebra_correctness."""
+    candidate = extract_last_json(answer)
+    if not candidate:
+        return False, None, "No valid JSON solution found"
+    correct, skipped, missing, total = zebra_correctness(example, candidate)
+    is_correct = correct == total
+    msg = f"Correct={correct}/{total}, skipped={skipped}, missing={missing}"
+    return is_correct, candidate, msg
+
+
+def build_zebralogic_prompt(example):
+    system_prompt = SYSTEM_PROMPT_VANILLA
+    user_prompt = USER_PROMPT_TEMPLATE.format(problem_text=example['puzzle_clean'])
+    return system_prompt, user_prompt
+
+# verina helpers
+def evaluate_verina_answer(output: str, data: BenchmarkData, task_idx: int) -> Tuple[bool, str, str]:
+    """Evaluate Verina code generation output - wrapper for best-of-k interface"""
+    generated_code = extract_code_from_response(output)
+
+    if not generated_code.strip():
+        return False, "", "No code extracted from response"
+
+    compiles, all_tests_pass, compile_output, test_results = evaluate_generated_code(data, generated_code, task_idx)
+
+    num_tests = len(data.tests) if data.tests else 0
+    num_passed = sum(1 for v in test_results.values() if v == "pass")
+
+    if compiles and all_tests_pass:
+        return True, generated_code, f"Code compiles and all {num_tests} tests pass"
+    elif compiles:
+        return False, generated_code, f"Compilation succeeded but {num_tests - num_passed}/{num_tests} tests failed"
+    else:
+        error_preview = compile_output[:300] if compile_output else "Unknown error"
+        return False, generated_code, f"Compilation failed: {error_preview}"
+

 def build_full_prompt(task, example, nums=None):
     if task == "game24":
         prompt = build_game24_prompt(nums)
         return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
     if task == "maze":
         system_prompt, user_prompt = build_maze_prompt(example)
+    elif task == 'zebralogic':
+        system_prompt, user_prompt = build_zebralogic_prompt(example)
+    elif task == "verina":
+        return build_verina_prompt(example)
     else:
         system_prompt, user_prompt = build_spatialmap_prompt(example)
     return (
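Both new evaluators keep the (is_correct, extracted_solution, message) triple the game24 and MCQ paths already return, so the best-of-k loop stays task-agnostic. A minimal usage sketch, assuming the file is importable as bestofk_baseline (an assumed module name) and the ZebraLogic dataset loads as above; the output string is a made-up completion:

from bestofk_baseline import evaluate_zebralogic_answer, load_dataset_for_task

dataset = load_dataset_for_task("zebralogic")
example = dataset[0]
output = 'Reasoning... final answer: {"House 1": {"Name": "Alice"}}'
# Returns the shared triple; msg summarizes cell-level correctness.
is_correct, candidate, msg = evaluate_zebralogic_answer(output, example)
print(is_correct, msg)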
@@ -329,6 +387,10 @@ def load_dataset_for_task(task):
         return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val")
     if task == "spatialmap":
         return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val")
+    if task == "zebralogic":
+        return get_zebralogic_dataset()
+    if task == "verina":
+        return load_verina_dataset()
     raise ValueError(f"Unsupported task: {task}")

@@ -451,6 +513,30 @@ def build_game24_critic_prompt(nums, reasoning_output):
 """


+def build_zebralogic_critic_prompt(task_description, reasoning_output):
+    """Build critic prompt to evaluate ZebraLogic solution and provide reasoning."""
+    return f"""You are an expert logic puzzle verifier. Evaluate the following ZebraLogic solution.
+
+Task:
+{task_description}
+
+Student's reasoning and answer:
+{reasoning_output}
+
+Verify:
+1. Does the solution assign exactly one value per feature per house?
+2. Are all constraints/clues satisfied?
+3. Is the JSON output well-formed and complete?
+
+Respond in the following format:
+VERDICT: CORRECT or INCORRECT
+REASONING: Your detailed explanation
+
+If CORRECT, briefly explain why.
+If INCORRECT, explain what went wrong and suggest corrections.
+"""
+
+
 def build_mcq_critic_prompt(task, task_description, reasoning_output):
     """Build critic prompt to evaluate MCQ solution and provide reasoning."""
     task_name = "Maze" if task == "maze" else "Spatial Reasoning"
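The critic prompts pin the reply to a VERDICT/REASONING contract (the Verina critic prompt below uses the same one), which keeps downstream parsing to a couple of regexes. A sketch of the kind of parser this format permits; parse_critic_reply is hypothetical and the repo's actual extraction code is outside this diff:

import re

def parse_critic_reply(text: str):
    # The VERDICT line decides correctness; REASONING captures the explanation.
    verdict = re.search(r"VERDICT:\s*(CORRECT|INCORRECT)", text)
    reasoning = re.search(r"REASONING:\s*(.*)", text, re.DOTALL)
    return (
        bool(verdict) and verdict.group(1) == "CORRECT",
        reasoning.group(1).strip() if reasoning else "",
    )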
@@ -472,6 +558,52 @@ def build_mcq_critic_prompt(task, task_description, reasoning_output):
 If INCORRECT, explain what went wrong and suggest the correct approach.
 """

+def build_verina_critic_prompt(data: BenchmarkData, reasoning_output: str) -> str:
+    """Build critic prompt to evaluate Verina Lean code generation and provide reasoning."""
+    signature = data.signature
+    func_name = signature.get("name", "solution")
+    return_type = signature.get("return_type", "Bool")
+    param_list = render_param_list(signature)
+
+    precond = data.lean_data.get("precond", "True").strip()
+    postcond = data.lean_data.get("postcond", "").strip()
+
+    return f"""You are an expert Lean 4 code verifier. Evaluate the following code generation attempt.
+
+## Task Description
+{data.description}
+
+## Function Signature
+```lean4
+def {func_name} {param_list} (h_precond : {func_name}_precond ...) : {return_type}
+```
+
+## Precondition
+```lean4
+{precond}
+```
+
+## Postcondition
+```lean4
+{postcond}
+```
+
+## Student's Reasoning and Generated Code
+{reasoning_output}
+
+Verify:
+1. Is the generated code syntactically valid Lean 4?
+2. Does it match the expected function signature and return type ({return_type})?
+3. Does the logic appear to satisfy the postcondition given the precondition?
+4. Are there any obvious bugs, infinite loops, or incorrect base cases?
+
+Respond in the following format:
+VERDICT: CORRECT or INCORRECT
+REASONING: Your detailed explanation
+
+If CORRECT, briefly explain why.
+If INCORRECT, explain what went wrong and suggest how to fix it.
+"""

 def batch_evaluate_with_critic(outputs_df, task, example, critic_llm_server, tokenizer, nums=None, quiet=True):
     """Batch evaluate outputs using vLLM API across multiple instances. Outputs_df should have columns: 'output', 'seed_idx'"""
@@ -518,6 +650,11 @@ async def _run_parallel():
             output_text = row["output"]
             if task == "game24":
                 critic_prompt = build_game24_critic_prompt(nums, output_text)
+            elif task == "zebralogic":
+                _, task_desc = build_zebralogic_prompt(example)
+                critic_prompt = build_zebralogic_critic_prompt(task_desc, output_text)
+            elif task == "verina":
+                critic_prompt = build_verina_critic_prompt(example, output_text)
             else:
                 if task == "maze":
                     _, task_desc = build_maze_prompt(example)
@@ -649,7 +786,7 @@ def run_k_samples_with_critic(

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets")
-    parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap"],
+    parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap", "zebralogic", "verina"],
                         help="Task to run")
     parser.add_argument("--k", type=int, default=1, help="Number of samples per example")
     parser.add_argument("--num_examples", "-n", type=int, default=100,
@@ -670,6 +807,7 @@
     parser.add_argument("--seed", type=int, default=42, help="Base random seed")
     parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
+    parser.add_argument("--processes", "-p", type=int, default=1, help="Number of examples to process in parallel (default: 1, sequential)")
     parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging")
     args = parser.parse_args()

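Taken together with the new --processes flag, a plausible invocation of the extended script is `python bestofk_baseline.py --task verina --k 4 -n 100 -p 8` (an assumed command line; flag meanings follow the help strings above).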
@@ -724,13 +862,27 @@
     total_tokens_all_samples = 0
     results = []

-    for idx in tqdm(indices, desc="Processing examples", unit="example"):
+    def process_example(idx):
+        """Process a single example: generate k samples, evaluate, return result dict."""
         example = dataset[int(idx)]
         if args.task == "game24":
             nums = example["numbers"]
             prompt = build_full_prompt(args.task, example, nums=nums)
             eval_fn = lambda output: evaluate_game24_answer(output, nums)
             options = None
+
+        elif args.task == "zebralogic":
+            prompt = build_full_prompt(args.task, example)
+            eval_fn = lambda output, ex=example: evaluate_zebralogic_answer(output, ex)
+            options = None
+        elif args.task == "verina":
+            # For verina, example is a BenchmarkData object
+            prompt = build_full_prompt(args.task, example)
+            current_idx = int(idx)
+            current_data = example
+            eval_fn = lambda output, data=current_data, task_idx=current_idx: evaluate_verina_answer(output, data, task_idx)
+            options = None
+
         else:
             prompt = build_full_prompt(args.task, example)
             gt = str(example.get("ground_truth", "")).strip()
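Note the `ex=example` and `data=current_data` default arguments in the new lambdas: defaults are evaluated when a lambda is defined, so each eval_fn pins the record it was built for, which matters once process_example runs on several threads. A tiny self-contained illustration of the idiom, independent of this repo:

# Each default binds its own i at definition time: prints [0, 10, 20].
makers = [lambda x=i: x * 10 for i in range(3)]
print([f() for f in makers])

# Plain closures share the final i: prints [20, 20, 20].
late = [lambda: i * 10 for i in range(3)]
print([f() for f in late])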
@@ -781,17 +933,7 @@

         save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"])

-        total_examples += 1
-        if any_correct:
-            total_correct += 1
-        total_correct_samples += correct_samples
-        total_samples += len(sample_results)
-        critic_correct_samples += critic_correct_samples_example
-        critic_total_samples += len(sample_results)
-        total_tokens += best_result.tokens
-        total_tokens_all_samples += sum(r.tokens for r in sample_results)
-
-        results.append({
+        return {
             "idx": int(idx),
             "best_idx": best_idx,
             "any_correct": any_correct,
@@ -806,10 +948,31 @@
             "all_critic_correct": [r.critic_correct for r in sample_results],
             "all_critic_feedback": [r.critic_feedback for r in sample_results],
             "options": options,
-        })
-
-        #logger.info(f"Best sample: {best_idx} | Correct in K: {any_correct}")
-        #logger.info(f"Best message: {best_result.message}")
+            "_any_correct": any_correct,
+            "_correct_samples": correct_samples,
+            "_critic_correct_samples": critic_correct_samples_example,
+            "_n_samples": len(sample_results),
+            "_best_tokens": best_result.tokens,
+            "_all_tokens_sum": sum(r.tokens for r in sample_results),
+        }
+
+    with ThreadPool(processes=args.processes) as pool:
+        for result in tqdm(pool.imap_unordered(process_example, indices), total=len(indices), desc="Processing examples", unit="example", file=_real_stderr):
+            total_examples += 1
+            if result["_any_correct"]:
+                total_correct += 1
+            total_correct_samples += result["_correct_samples"]
+            total_samples += result["_n_samples"]
+            critic_correct_samples += result["_critic_correct_samples"]
+            critic_total_samples += result["_n_samples"]
+            total_tokens += result["_best_tokens"]
+            total_tokens_all_samples += result["_all_tokens_sum"]
+
+            # Remove internal keys before appending
+            for k in list(result.keys()):
+                if k.startswith("_"):
+                    del result[k]
+            results.append(result)

     accuracy = total_correct / total_examples if total_examples else 0
     avg_best_tokens = total_tokens / total_examples if total_examples else 0
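The ThreadPool plus imap_unordered combination streams finished examples back as they complete, so the progress bar advances even when examples take very different times; the trade-off is that results no longer follows input order. A self-contained sketch of the pattern, assuming nothing from the repo (process and the "_" stats key are stand-ins):

from multiprocessing.pool import ThreadPool

def process(idx: int) -> dict:
    # Stand-in for the real per-example work; "_"-prefixed keys carry stats.
    return {"idx": idx, "_n_samples": 1}

results, total_samples = [], 0
with ThreadPool(processes=4) as pool:
    for res in pool.imap_unordered(process, range(10)):
        total_samples += res.pop("_n_samples")  # aggregate, then drop internal keys
        results.append(res)
print(total_samples, len(results))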
