11from __future__ import annotations
2+ import time
23import copy
4+ import os
5+ import tarfile
6+ import io
7+ import zstandard as zstd
38from dataclasses import dataclass
49from tabulate import tabulate
510from multiprocessing import Pool
6- from z3 import Solver , Context , main_ctx , Extract , BitVecVal , sat , BitVec , Distinct
11+ from z3 import (
12+ Solver ,
13+ Context ,
14+ main_ctx ,
15+ Extract ,
16+ BitVecVal ,
17+ sat ,
18+ BitVec ,
19+ Distinct ,
20+ )
721
822try :
923 from .success_progress import SuccessProgress
@@ -299,7 +313,8 @@ def synthesize_gadget_with_symbolic(
299313 input_state : VectorState ,
300314 target_pairs : list [tuple [int , int ]],
301315 max_solutions : int | None = None ,
302- ) -> list [tuple [PermutationGadget , VectorState ]]:
316+ solver_callback : callable | None = None ,
317+ ) -> tuple [list [tuple [PermutationGadget , VectorState ]], float , float ]:
303318 """
304319 Synthesize gadgets using symbolic immediates in Z3.
305320
@@ -309,7 +324,8 @@ def synthesize_gadget_with_symbolic(
309324 The symbolic values are represented using SymbolicPlaceholder for pickling.
310325 The pickling is required for multiprocessing.
311326
312- Returns list of (gadget, output_state) tuples. The output state is computed
327+ Returns (results, construction_time, solver_time) where results is
328+ list of (gadget, output_state) tuples. The output state is computed
313329 directly from the satisfying model, avoiding a redundant Z3 solve. The output
314330 state is in canonical form: for each pair, the lower element index goes to
315331 the top vector and the higher index goes to bottom, reflecting the
@@ -318,7 +334,9 @@ def synthesize_gadget_with_symbolic(
318334 Args:
319335 max_solutions: Optional cap on the number of solutions returned.
320336 When ``None`` (the default) all solutions are enumerated.
337+ solver_callback: Optional callback receiving the Solver instance.
321338 """
339+ start_construction = time .perf_counter ()
322340 ctx = main_ctx ()
323341 solver = Solver (ctx = ctx )
324342
@@ -400,14 +418,19 @@ def maybe_resolve_symbolic_vars(instructions: list[InstructionSpec]):
400418 # have the same pair_id, losing elements in the process.
401419 solver .add (Distinct (* output_lanes ))
402420
421+ if solver_callback :
422+ solver_callback (solver )
423+
403424 # Collect symbolic variable terms for enumeration
404425 terms = list (symbolic_vars .values ())
405-
426+ construction_time = time .perf_counter () - start_construction
427+ solver_start = time .perf_counter ()
406428 # If no symbolic variables, single check suffices
407429 if not terms :
408430 result = solver .check ()
409431 if result != sat :
410- return []
432+ solver_time = time .perf_counter () - solver_start
433+ return [], construction_time , solver_time
411434 model = solver .model ()
412435 gadget , output_state = self ._extract_solution_from_model (
413436 model ,
@@ -418,7 +441,8 @@ def maybe_resolve_symbolic_vars(instructions: list[InstructionSpec]):
418441 top_instructions_template ,
419442 bottom_instructions_template ,
420443 )
421- return [(gadget , output_state )]
444+ solver_time = time .perf_counter () - solver_start
445+ return [(gadget , output_state )], construction_time , solver_time
422446
423447 # Enumerate all solutions over symbolic variables
424448 results = []
@@ -433,7 +457,8 @@ def maybe_resolve_symbolic_vars(instructions: list[InstructionSpec]):
433457 bottom_instructions_template ,
434458 )
435459 results .append ((gadget , output_state ))
436- return results
460+ solver_time = time .perf_counter () - solver_start
461+ return results , construction_time , solver_time
437462
438463 def _extract_solution_from_model (
439464 self ,
@@ -827,22 +852,54 @@ def _validate_gadgets(
827852 )
828853
829854 try :
830- with Pool () as pool :
831- # Use imap_unordered for streaming results and progress updates
832- for (
833- gadget_results ,
834- job_input_state ,
835- job_metadata ,
836- ) in pool .imap_unordered (_validate_gadget_worker , jobs ):
837- success_inc = 0
838- for gadget , output_state in gadget_results :
839- validated_gadgets .append (
840- (gadget , job_input_state , output_state , job_metadata )
841- )
842- success_inc = 1
843-
844- if progress :
845- progress .update (task_id , advance = 1 , success = success_inc )
855+ pool = Pool ()
856+ total_construct_time = 0.0
857+ total_solve_time = 0.0
858+
859+ # Use imap_unordered for streaming results and progress updates
860+ for (
861+ gadget_results ,
862+ job_input_state ,
863+ job_metadata ,
864+ construct_time ,
865+ solver_time ,
866+ ) in pool .imap_unordered (_validate_gadget_worker , jobs ):
867+ total_construct_time += construct_time
868+ total_solve_time += solver_time
869+
870+ success_inc = 0
871+ for gadget , output_state in gadget_results :
872+ validated_gadgets .append (
873+ (gadget , job_input_state , output_state , job_metadata )
874+ )
875+ success_inc = 1
876+
877+ if progress :
878+ progress .update (task_id , advance = 1 , success = success_inc )
879+
880+ pool .close ()
881+ pool .join ()
882+
883+ if jobs :
884+ print (
885+ f"TOTAL construction time: { total_construct_time :.2f} s, TOTAL solver time: { total_solve_time :.2f} s"
886+ )
887+
888+ # Phase 2.5: Compress tar files if smt2_dump_dir is set
889+ if jobs and "smt2_dump_dir" in jobs [0 ][6 ]:
890+ smt2_dump_dir = jobs [0 ][6 ]["smt2_dump_dir" ]
891+ stage_idx = jobs [0 ][6 ]["stage_idx" ]
892+ for filename in os .listdir (smt2_dump_dir ):
893+ if filename .startswith (f"stage{ stage_idx } _" ) and filename .endswith (
894+ ".tar"
895+ ):
896+ tar_path = os .path .join (smt2_dump_dir , filename )
897+ zst_path = tar_path + ".zst"
898+ cctx = zstd .ZstdCompressor ()
899+ with open (tar_path , "rb" ) as f_in :
900+ with open (zst_path , "wb" ) as f_out :
901+ cctx .copy_stream (f_in , f_out )
902+ os .remove (tar_path )
846903 finally :
847904 if progress :
848905 progress .stop ()
@@ -979,10 +1036,17 @@ def _enumerate_dual_input_instructions(
9791036class BitonicSuperVectorizer :
9801037 """Super-optimizer for bitonic sorting networks using Z3-based gadget synthesis."""
9811038
982- def __init__ (self , num_vecs : int , prim_type : primitive_type , vm : vector_machine ):
1039+ def __init__ (
1040+ self ,
1041+ num_vecs : int ,
1042+ prim_type : primitive_type ,
1043+ vm : vector_machine ,
1044+ smt2_dump_dir : str | None = None ,
1045+ ):
9831046 self .num_vecs = num_vecs
9841047 self .prim_type = prim_type
9851048 self .vm = vm
1049+ self .smt2_dump_dir = smt2_dump_dir
9861050
9871051 # Calculate total elements and elements per vector
9881052 self .elements_per_vector = width_dict [vm ] // int (prim_type .value [0 ])
@@ -1104,7 +1168,11 @@ def _build_tree_recursive(
11041168 metadata = {
11051169 "input_state" : input_state ,
11061170 "parent_path" : parent_path ,
1171+ "stage_idx" : stage_idx ,
11071172 }
1173+ if self .smt2_dump_dir :
1174+ metadata ["smt2_dump_dir" ] = self .smt2_dump_dir
1175+
11081176 if max_solutions_per_gadget is not None :
11091177 metadata ["max_solutions" ] = max_solutions_per_gadget
11101178
@@ -1296,13 +1364,17 @@ def node_to_dict(node: SolutionNode) -> dict:
12961364 print (f"Exported { len (roots )} solution trees to { output_path } " )
12971365
12981366
1367+ _worker_tar = None
1368+ _worker_job_count = 0
1369+
1370+
12991371def _validate_gadget_worker (job ):
13001372 """Worker function for parallel gadget validation.
13011373
1302- Returns (gadget_results, input_state, metadata) where gadget_results is
1303- a list of (gadget, output_state) tuples. The output_state is computed
1304- directly during validation.
1374+ Returns (gadget_results, input_state, metadata, construction_time, solver_time)
1375+ where gadget_results is a list of (gadget, output_state) tuples.
13051376 """
1377+ global _worker_tar , _worker_job_count
13061378 top_seq , bottom_seq , input_state , target_pairs , vm , prim_type , metadata = job
13071379
13081380 # Create clones of sequences to avoid modifying the ones in the main process
@@ -1314,12 +1386,42 @@ def _validate_gadget_worker(job):
13141386 # Default to 1 for backward compatibility; callers opt in to more via
13151387 # build_solution_tree(max_solutions_per_gadget=N).
13161388 max_solutions = metadata .get ("max_solutions" , 1 )
1317- gadget_results = synthesizer .synthesize_gadget_with_symbolic (
1318- top_seq_clone ,
1319- bottom_seq_clone ,
1320- input_state ,
1321- target_pairs ,
1322- max_solutions = max_solutions ,
1389+
1390+ smt2_dump_dir = metadata .get ("smt2_dump_dir" )
1391+ stage_idx = metadata .get ("stage_idx" )
1392+
1393+ def dump_smt2_to_tar (solver ):
1394+ global _worker_tar , _worker_job_count
1395+ if smt2_dump_dir is None :
1396+ return
1397+
1398+ if _worker_tar is None :
1399+ pid = os .getpid ()
1400+ # Write to an uncompressed tar file first; we'll compress it in the main process
1401+ tar_filename = f"stage{ stage_idx } _pid_{ pid } .tar"
1402+ tar_path = os .path .join (smt2_dump_dir , tar_filename )
1403+ _worker_tar = tarfile .open (tar_path , mode = "a" )
1404+
1405+ _worker_job_count += 1
1406+ smt2_text = "(reset)\n " + solver .sexpr () + "\n (check-sat)\n "
1407+ smt2_bytes = smt2_text .encode ("utf-8" )
1408+
1409+ tar_info = tarfile .TarInfo (name = f"job_{ _worker_job_count } .smt2" )
1410+ tar_info .size = len (smt2_bytes )
1411+ _worker_tar .addfile (tar_info , io .BytesIO (smt2_bytes ))
1412+ # Ensure it's written to disk
1413+ _worker_tar .fileobj .flush ()
1414+
1415+ gadget_results , construction_time , solver_time = (
1416+ synthesizer .synthesize_gadget_with_symbolic (
1417+ top_seq_clone ,
1418+ bottom_seq_clone ,
1419+ input_state ,
1420+ target_pairs ,
1421+ max_solutions = max_solutions ,
1422+ solver_callback = dump_smt2_to_tar ,
1423+ )
13231424 )
13241425
1325- return gadget_results , input_state , metadata
1426+ metadata ["worker_pid" ] = os .getpid ()
1427+ return gadget_results , input_state , metadata , construction_time , solver_time
0 commit comments