 import argparse
 import asyncio
+from datetime import datetime
 import json
 import logging
 import os
 import re
 import sys
+import shutil
+import subprocess
+from multiprocessing.pool import ThreadPool
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
+from interwhen.utils.zebralogic_helper import SYSTEM_PROMPT_VANILLA, USER_PROMPT_TEMPLATE, get_zebralogic_dataset, extract_last_json, zebra_correctness
+
 from interwhen import stream_completion
+from verina_utils import *
 
 # ============== MODEL CONFIGURATION ==============
 MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507"
 # Multi-process vLLM configuration
 VLLM_PORTS = [8000, 8001, 8002]  # 3 instances with tensor-parallel-size 2 each
 REQUEST_COUNTER = {"main": 0, "critic": 0}  # Track request count for round-robin load balancing
-
+# Verina paths
+_SCRIPT_DIR = Path(__file__).parent.resolve()
+VERINA_ROOT = (_SCRIPT_DIR / "../../../verina").resolve()
+VERINA_DATASETS_PATH = VERINA_ROOT / "datasets" / "verina"
+LEAN_PLAYGROUND_DIR = VERINA_ROOT / "lean-playground"
 
 logger = logging.getLogger(__name__)
 
+# Save the real stderr so tqdm always works even if suppress_output is active
+_real_stderr = sys.stderr
+
 
 @contextmanager
 def suppress_output():
@@ -306,13 +321,56 @@ def evaluate_mcq_answer(answer, options, ground_truth):
         return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})"
     return False, sol, f"Solution '{sol}' not found in options or ground truth"
 
+# --------------------- ZebraLogic helpers ---------------------
+
+def evaluate_zebralogic_answer(answer, example):
+    """Evaluate a zebralogic answer against ground truth using zebra_correctness."""
+    candidate = extract_last_json(answer)
+    if not candidate:
+        return False, None, "No valid JSON solution found"
+    correct, skipped, missing, total = zebra_correctness(example, candidate)
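+    # A sample counts as correct only when every cell matches (correct == total)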
+    is_correct = correct == total
+    msg = f"Correct={correct}/{total}, skipped={skipped}, missing={missing}"
+    return is_correct, candidate, msg
+
+
+def build_zebralogic_prompt(example):
+    system_prompt = SYSTEM_PROMPT_VANILLA
+    user_prompt = USER_PROMPT_TEMPLATE.format(problem_text=example['puzzle_clean'])
+    return system_prompt, user_prompt
+
+# --------------------- Verina helpers ---------------------
+def evaluate_verina_answer(output: str, data: BenchmarkData, task_idx: int) -> Tuple[bool, str, str]:
+    """Evaluate Verina code generation output - wrapper for best-of-k interface"""
+    generated_code = extract_code_from_response(output)
+
+    if not generated_code.strip():
+        return False, "", "No code extracted from response"
+
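+    # evaluate_generated_code comes from verina_utils; it presumably compiles the
+    # Lean code (in LEAN_PLAYGROUND_DIR) and runs the benchmark's unit tests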
+    compiles, all_tests_pass, compile_output, test_results = evaluate_generated_code(data, generated_code, task_idx)
+
+    num_tests = len(data.tests) if data.tests else 0
+    num_passed = sum(1 for v in test_results.values() if v == "pass")
+
+    if compiles and all_tests_pass:
+        return True, generated_code, f"Code compiles and all {num_tests} tests pass"
+    elif compiles:
+        return False, generated_code, f"Compilation succeeded but {num_tests - num_passed}/{num_tests} tests failed"
+    else:
+        error_preview = compile_output[:300] if compile_output else "Unknown error"
+        return False, generated_code, f"Compilation failed: {error_preview}"
+
 
 def build_full_prompt(task, example, nums=None):
     if task == "game24":
         prompt = build_game24_prompt(nums)
         return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
     if task == "maze":
         system_prompt, user_prompt = build_maze_prompt(example)
+    elif task == "zebralogic":
+        system_prompt, user_prompt = build_zebralogic_prompt(example)
+    elif task == "verina":
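+        # build_verina_prompt returns a complete prompt string, so skip the chat-template wrapping below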
+        return build_verina_prompt(example)
     else:
         system_prompt, user_prompt = build_spatialmap_prompt(example)
     return (
@@ -329,6 +387,10 @@ def load_dataset_for_task(task):
         return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val")
     if task == "spatialmap":
         return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val")
+    if task == "zebralogic":
+        return get_zebralogic_dataset()
+    if task == "verina":
+        return load_verina_dataset()
     raise ValueError(f"Unsupported task: {task}")
 
 
@@ -451,6 +513,30 @@ def build_game24_critic_prompt(nums, reasoning_output):
451513"""
452514
453515
+def build_zebralogic_critic_prompt(task_description, reasoning_output):
+    """Build critic prompt to evaluate ZebraLogic solution and provide reasoning."""
+    return f"""You are an expert logic puzzle verifier. Evaluate the following ZebraLogic solution.
+
+Task:
+{task_description}
+
+Student's reasoning and answer:
+{reasoning_output}
+
+Verify:
+1. Does the solution assign exactly one value per feature per house?
+2. Are all constraints/clues satisfied?
+3. Is the JSON output well-formed and complete?
+
+Respond in the following format:
+VERDICT: CORRECT or INCORRECT
+REASONING: Your detailed explanation
+
+If CORRECT, briefly explain why.
+If INCORRECT, explain what went wrong and suggest corrections.
+"""
+
+
 def build_mcq_critic_prompt(task, task_description, reasoning_output):
     """Build critic prompt to evaluate MCQ solution and provide reasoning."""
     task_name = "Maze" if task == "maze" else "Spatial Reasoning"
@@ -472,6 +558,52 @@ def build_mcq_critic_prompt(task, task_description, reasoning_output):
 If INCORRECT, explain what went wrong and suggest the correct approach.
 """
 
+def build_verina_critic_prompt(data: BenchmarkData, reasoning_output: str) -> str:
+    """Build critic prompt to evaluate Verina Lean code generation and provide reasoning."""
+    signature = data.signature
+    func_name = signature.get("name", "solution")
+    return_type = signature.get("return_type", "Bool")
+    param_list = render_param_list(signature)
+
+    precond = data.lean_data.get("precond", "True").strip()
+    postcond = data.lean_data.get("postcond", "").strip()
+
+    return f"""You are an expert Lean 4 code verifier. Evaluate the following code generation attempt.
+
+## Task Description
+{data.description}
+
+## Function Signature
+```lean4
+def {func_name} {param_list} (h_precond : {func_name}_precond ...) : {return_type}
+```
+
+## Precondition
+```lean4
+{precond}
+```
+
+## Postcondition
+```lean4
+{postcond}
+```
+
+## Student's Reasoning and Generated Code
+{reasoning_output}
+
+Verify:
+1. Is the generated code syntactically valid Lean 4?
+2. Does it match the expected function signature and return type ({return_type})?
+3. Does the logic appear to satisfy the postcondition given the precondition?
+4. Are there any obvious bugs, infinite loops, or incorrect base cases?
+
+Respond in the following format:
+VERDICT: CORRECT or INCORRECT
+REASONING: Your detailed explanation
+
+If CORRECT, briefly explain why.
+If INCORRECT, explain what went wrong and suggest how to fix it.
+"""
 
 def batch_evaluate_with_critic(outputs_df, task, example, critic_llm_server, tokenizer, nums=None, quiet=True):
     """Batch-evaluate outputs using the vLLM API across multiple instances; outputs_df must have columns 'output' and 'seed_idx'."""
@@ -518,6 +650,11 @@ async def _run_parallel():
         output_text = row["output"]
         if task == "game24":
             critic_prompt = build_game24_critic_prompt(nums, output_text)
+        elif task == "zebralogic":
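+            # build_zebralogic_prompt returns (system, user); only the user half (the task text) goes to the critic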
+            _, task_desc = build_zebralogic_prompt(example)
+            critic_prompt = build_zebralogic_critic_prompt(task_desc, output_text)
+        elif task == "verina":
+            critic_prompt = build_verina_critic_prompt(example, output_text)
         else:
             if task == "maze":
                 _, task_desc = build_maze_prompt(example)
@@ -649,7 +786,7 @@ def run_k_samples_with_critic(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets")
-    parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap"],
+    parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap", "zebralogic", "verina"],
                         help="Task to run")
     parser.add_argument("--k", type=int, default=1, help="Number of samples per example")
     parser.add_argument("--num_examples", "-n", type=int, default=100,
@@ -670,6 +807,7 @@ def run_k_samples_with_critic(
     parser.add_argument("--seed", type=int, default=42, help="Base random seed")
     parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation")
     parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
+    parser.add_argument("--processes", "-p", type=int, default=1, help="Number of examples to process in parallel (default: 1, sequential)")
     parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging")
     args = parser.parse_args()
 
@@ -724,13 +862,27 @@ def run_k_samples_with_critic(
     total_tokens_all_samples = 0
     results = []
 
-    for idx in tqdm(indices, desc="Processing examples", unit="example"):
+    def process_example(idx):
+        """Process a single example: generate k samples, evaluate, return result dict."""
         example = dataset[int(idx)]
         if args.task == "game24":
             nums = example["numbers"]
             prompt = build_full_prompt(args.task, example, nums=nums)
             eval_fn = lambda output: evaluate_game24_answer(output, nums)
             options = None
+
+        elif args.task == "zebralogic":
+            prompt = build_full_prompt(args.task, example)
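+            # ex=example binds the current example at definition time, keeping the lambda self-contained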
+            eval_fn = lambda output, ex=example: evaluate_zebralogic_answer(output, ex)
+            options = None
+        elif args.task == "verina":
+            # For verina, example is a BenchmarkData object
+            prompt = build_full_prompt(args.task, example)
+            current_idx = int(idx)
+            current_data = example
+            eval_fn = lambda output, data=current_data, task_idx=current_idx: evaluate_verina_answer(output, data, task_idx)
+            options = None
+
         else:
             prompt = build_full_prompt(args.task, example)
             gt = str(example.get("ground_truth", "")).strip()
@@ -781,17 +933,7 @@ def run_k_samples_with_critic(
 
         save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"])
 
-        total_examples += 1
-        if any_correct:
-            total_correct += 1
-        total_correct_samples += correct_samples
-        total_samples += len(sample_results)
-        critic_correct_samples += critic_correct_samples_example
-        critic_total_samples += len(sample_results)
-        total_tokens += best_result.tokens
-        total_tokens_all_samples += sum(r.tokens for r in sample_results)
-
-        results.append({
+        return {
795937 "idx" : int (idx ),
796938 "best_idx" : best_idx ,
797939 "any_correct" : any_correct ,
@@ -806,10 +948,31 @@ def run_k_samples_with_critic(
806948 "all_critic_correct" : [r .critic_correct for r in sample_results ],
807949 "all_critic_feedback" : [r .critic_feedback for r in sample_results ],
808950 "options" : options ,
809- })
810-
811- #logger.info(f"Best sample: {best_idx} | Correct in K: {any_correct}")
812- #logger.info(f"Best message: {best_result.message}")
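+            # Internal aggregation fields, prefixed with "_"; stripped out again in the consumer loop below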
951+ "_any_correct" : any_correct ,
952+ "_correct_samples" : correct_samples ,
953+ "_critic_correct_samples" : critic_correct_samples_example ,
954+ "_n_samples" : len (sample_results ),
955+ "_best_tokens" : best_result .tokens ,
956+ "_all_tokens_sum" : sum (r .tokens for r in sample_results ),
957+ }
958+
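+    # ThreadPool (threads, not processes): each example spends most of its time
+    # waiting on vLLM HTTP responses, so the GIL should not be a bottleneck here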
+    with ThreadPool(processes=args.processes) as pool:
+        for result in tqdm(pool.imap_unordered(process_example, indices), total=len(indices), desc="Processing examples", unit="example", file=_real_stderr):
+            total_examples += 1
+            if result["_any_correct"]:
+                total_correct += 1
+            total_correct_samples += result["_correct_samples"]
+            total_samples += result["_n_samples"]
+            critic_correct_samples += result["_critic_correct_samples"]
+            critic_total_samples += result["_n_samples"]
+            total_tokens += result["_best_tokens"]
+            total_tokens_all_samples += result["_all_tokens_sum"]
+
+            # Remove internal keys before appending
+            for k in list(result.keys()):
+                if k.startswith("_"):
+                    del result[k]
+            results.append(result)
 
     accuracy = total_correct / total_examples if total_examples else 0
     avg_best_tokens = total_tokens / total_examples if total_examples else 0