Skip to content

Commit c062e91

Browse files
committed
super-optimizer: allow dumpting .smt2 files, tracking construction time vs. solver time
1 parent 78af7ce commit c062e91

5 files changed

Lines changed: 201 additions & 40 deletions

File tree

vxsort/smallsort/codegen/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dependencies = [
1010
"pytest-cov>=7.0.0",
1111
"tabulate>=0.9.0",
1212
"rich>=14.3.1",
13+
"zstandard>=0.25.0",
1314
]
1415

1516
[dependency-groups]

vxsort/smallsort/codegen/src/bitonic_compiler.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22
from __future__ import annotations
33
import argparse
4+
import tempfile
45

56
# Handle both relative and absolute imports
67
try:
@@ -105,6 +106,7 @@ def generate_bitonic_sorter(
105106
top_k: int | None = None,
106107
output_format: str = "json",
107108
gadget_depth: int = 3,
109+
smt2_dump_dir: str | None = None,
108110
):
109111
"""
110112
Generate bitonic sorter with super-optimized permutation sequences.
@@ -117,6 +119,7 @@ def generate_bitonic_sorter(
117119
top_k: Number of best solutions to keep. If None, all solutions are kept.
118120
output_format: Output format ("json" or "asm")
119121
gadget_depth: Maximum instruction depth per gadget (1-3, default 3)
122+
smt2_dump_dir: Directory to dump SMT2 files if requested
120123
121124
Returns:
122125
List of SolutionNode trees representing different optimized solutions
@@ -128,7 +131,7 @@ def generate_bitonic_sorter(
128131
)
129132

130133
# Create super-vectorizer
131-
super_opt = BitonicSuperVectorizer(num_vecs, type, vm)
134+
super_opt = BitonicSuperVectorizer(num_vecs, type, vm, smt2_dump_dir=smt2_dump_dir)
132135

133136
# Synthesize all stages to build solution tree
134137
print("Synthesizing permutation gadgets...")
@@ -215,13 +218,23 @@ def generate_bitonic_sorter(
215218
choices=[1, 2, 3],
216219
help="Maximum instruction depth per gadget (1-3, default: 3)",
217220
)
221+
parser.add_argument(
222+
"--dump-smt2",
223+
action="store_true",
224+
help="Dump SMT2 files from Z3 into compressed tar files in /tmp",
225+
)
218226

219227
args = parser.parse_args()
220228

221229
# Convert string arguments to Enum members
222230
vm = vector_machine[args.vector_machine]
223231
dtype = primitive_type[args.datatype]
224232

233+
smt2_dump_dir = None
234+
if args.dump_smt2:
235+
smt2_dump_dir = tempfile.mkdtemp(prefix="vxsort_smt2_", dir="/tmp")
236+
print(f"SMT2 dump directory: {smt2_dump_dir}")
237+
225238
generate_bitonic_sorter(
226239
args.num_vecs,
227240
dtype,
@@ -230,4 +243,5 @@ def generate_bitonic_sorter(
230243
top_k=args.top_k,
231244
output_format=args.output_format,
232245
gadget_depth=args.gadget_depth + 1, # +1 because range is exclusive
246+
smt2_dump_dir=smt2_dump_dir,
233247
)

vxsort/smallsort/codegen/src/bitonic_super_optimizer.py

Lines changed: 136 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
from __future__ import annotations
2+
import time
23
import copy
4+
import os
5+
import tarfile
6+
import io
7+
import zstandard as zstd
38
from dataclasses import dataclass
49
from tabulate import tabulate
510
from multiprocessing import Pool
6-
from z3 import Solver, Context, main_ctx, Extract, BitVecVal, sat, BitVec, Distinct
11+
from z3 import (
12+
Solver,
13+
Context,
14+
main_ctx,
15+
Extract,
16+
BitVecVal,
17+
sat,
18+
BitVec,
19+
Distinct,
20+
)
721

822
try:
923
from .success_progress import SuccessProgress
@@ -299,7 +313,8 @@ def synthesize_gadget_with_symbolic(
299313
input_state: VectorState,
300314
target_pairs: list[tuple[int, int]],
301315
max_solutions: int | None = None,
302-
) -> list[tuple[PermutationGadget, VectorState]]:
316+
solver_callback: callable | None = None,
317+
) -> tuple[list[tuple[PermutationGadget, VectorState]], float, float]:
303318
"""
304319
Synthesize gadgets using symbolic immediates in Z3.
305320
@@ -309,7 +324,8 @@ def synthesize_gadget_with_symbolic(
309324
The symbolic values are represented using SymbolicPlaceholder for pickling.
310325
The pickling is required for multiprocessing.
311326
312-
Returns list of (gadget, output_state) tuples. The output state is computed
327+
Returns (results, construction_time, solver_time) where results is
328+
list of (gadget, output_state) tuples. The output state is computed
313329
directly from the satisfying model, avoiding a redundant Z3 solve. The output
314330
state is in canonical form: for each pair, the lower element index goes to
315331
the top vector and the higher index goes to bottom, reflecting the
@@ -318,7 +334,9 @@ def synthesize_gadget_with_symbolic(
318334
Args:
319335
max_solutions: Optional cap on the number of solutions returned.
320336
When ``None`` (the default) all solutions are enumerated.
337+
solver_callback: Optional callback receiving the Solver instance.
321338
"""
339+
start_construction = time.perf_counter()
322340
ctx = main_ctx()
323341
solver = Solver(ctx=ctx)
324342

@@ -400,14 +418,19 @@ def maybe_resolve_symbolic_vars(instructions: list[InstructionSpec]):
400418
# have the same pair_id, losing elements in the process.
401419
solver.add(Distinct(*output_lanes))
402420

421+
if solver_callback:
422+
solver_callback(solver)
423+
403424
# Collect symbolic variable terms for enumeration
404425
terms = list(symbolic_vars.values())
405-
426+
construction_time = time.perf_counter() - start_construction
427+
solver_start = time.perf_counter()
406428
# If no symbolic variables, single check suffices
407429
if not terms:
408430
result = solver.check()
409431
if result != sat:
410-
return []
432+
solver_time = time.perf_counter() - solver_start
433+
return [], construction_time, solver_time
411434
model = solver.model()
412435
gadget, output_state = self._extract_solution_from_model(
413436
model,
@@ -418,7 +441,8 @@ def maybe_resolve_symbolic_vars(instructions: list[InstructionSpec]):
418441
top_instructions_template,
419442
bottom_instructions_template,
420443
)
421-
return [(gadget, output_state)]
444+
solver_time = time.perf_counter() - solver_start
445+
return [(gadget, output_state)], construction_time, solver_time
422446

423447
# Enumerate all solutions over symbolic variables
424448
results = []
@@ -433,7 +457,8 @@ def maybe_resolve_symbolic_vars(instructions: list[InstructionSpec]):
433457
bottom_instructions_template,
434458
)
435459
results.append((gadget, output_state))
436-
return results
460+
solver_time = time.perf_counter() - solver_start
461+
return results, construction_time, solver_time
437462

438463
def _extract_solution_from_model(
439464
self,
@@ -827,22 +852,54 @@ def _validate_gadgets(
827852
)
828853

829854
try:
830-
with Pool() as pool:
831-
# Use imap_unordered for streaming results and progress updates
832-
for (
833-
gadget_results,
834-
job_input_state,
835-
job_metadata,
836-
) in pool.imap_unordered(_validate_gadget_worker, jobs):
837-
success_inc = 0
838-
for gadget, output_state in gadget_results:
839-
validated_gadgets.append(
840-
(gadget, job_input_state, output_state, job_metadata)
841-
)
842-
success_inc = 1
843-
844-
if progress:
845-
progress.update(task_id, advance=1, success=success_inc)
855+
pool = Pool()
856+
total_construct_time = 0.0
857+
total_solve_time = 0.0
858+
859+
# Use imap_unordered for streaming results and progress updates
860+
for (
861+
gadget_results,
862+
job_input_state,
863+
job_metadata,
864+
construct_time,
865+
solver_time,
866+
) in pool.imap_unordered(_validate_gadget_worker, jobs):
867+
total_construct_time += construct_time
868+
total_solve_time += solver_time
869+
870+
success_inc = 0
871+
for gadget, output_state in gadget_results:
872+
validated_gadgets.append(
873+
(gadget, job_input_state, output_state, job_metadata)
874+
)
875+
success_inc = 1
876+
877+
if progress:
878+
progress.update(task_id, advance=1, success=success_inc)
879+
880+
pool.close()
881+
pool.join()
882+
883+
if jobs:
884+
print(
885+
f"TOTAL construction time: {total_construct_time:.2f}s, TOTAL solver time: {total_solve_time:.2f}s"
886+
)
887+
888+
# Phase 2.5: Compress tar files if smt2_dump_dir is set
889+
if jobs and "smt2_dump_dir" in jobs[0][6]:
890+
smt2_dump_dir = jobs[0][6]["smt2_dump_dir"]
891+
stage_idx = jobs[0][6]["stage_idx"]
892+
for filename in os.listdir(smt2_dump_dir):
893+
if filename.startswith(f"stage{stage_idx}_") and filename.endswith(
894+
".tar"
895+
):
896+
tar_path = os.path.join(smt2_dump_dir, filename)
897+
zst_path = tar_path + ".zst"
898+
cctx = zstd.ZstdCompressor()
899+
with open(tar_path, "rb") as f_in:
900+
with open(zst_path, "wb") as f_out:
901+
cctx.copy_stream(f_in, f_out)
902+
os.remove(tar_path)
846903
finally:
847904
if progress:
848905
progress.stop()
@@ -979,10 +1036,17 @@ def _enumerate_dual_input_instructions(
9791036
class BitonicSuperVectorizer:
9801037
"""Super-optimizer for bitonic sorting networks using Z3-based gadget synthesis."""
9811038

982-
def __init__(self, num_vecs: int, prim_type: primitive_type, vm: vector_machine):
1039+
def __init__(
1040+
self,
1041+
num_vecs: int,
1042+
prim_type: primitive_type,
1043+
vm: vector_machine,
1044+
smt2_dump_dir: str | None = None,
1045+
):
9831046
self.num_vecs = num_vecs
9841047
self.prim_type = prim_type
9851048
self.vm = vm
1049+
self.smt2_dump_dir = smt2_dump_dir
9861050

9871051
# Calculate total elements and elements per vector
9881052
self.elements_per_vector = width_dict[vm] // int(prim_type.value[0])
@@ -1104,7 +1168,11 @@ def _build_tree_recursive(
11041168
metadata = {
11051169
"input_state": input_state,
11061170
"parent_path": parent_path,
1171+
"stage_idx": stage_idx,
11071172
}
1173+
if self.smt2_dump_dir:
1174+
metadata["smt2_dump_dir"] = self.smt2_dump_dir
1175+
11081176
if max_solutions_per_gadget is not None:
11091177
metadata["max_solutions"] = max_solutions_per_gadget
11101178

@@ -1296,13 +1364,17 @@ def node_to_dict(node: SolutionNode) -> dict:
12961364
print(f"Exported {len(roots)} solution trees to {output_path}")
12971365

12981366

1367+
_worker_tar = None
1368+
_worker_job_count = 0
1369+
1370+
12991371
def _validate_gadget_worker(job):
13001372
"""Worker function for parallel gadget validation.
13011373
1302-
Returns (gadget_results, input_state, metadata) where gadget_results is
1303-
a list of (gadget, output_state) tuples. The output_state is computed
1304-
directly during validation.
1374+
Returns (gadget_results, input_state, metadata, construction_time, solver_time)
1375+
where gadget_results is a list of (gadget, output_state) tuples.
13051376
"""
1377+
global _worker_tar, _worker_job_count
13061378
top_seq, bottom_seq, input_state, target_pairs, vm, prim_type, metadata = job
13071379

13081380
# Create clones of sequences to avoid modifying the ones in the main process
@@ -1314,12 +1386,42 @@ def _validate_gadget_worker(job):
13141386
# Default to 1 for backward compatibility; callers opt in to more via
13151387
# build_solution_tree(max_solutions_per_gadget=N).
13161388
max_solutions = metadata.get("max_solutions", 1)
1317-
gadget_results = synthesizer.synthesize_gadget_with_symbolic(
1318-
top_seq_clone,
1319-
bottom_seq_clone,
1320-
input_state,
1321-
target_pairs,
1322-
max_solutions=max_solutions,
1389+
1390+
smt2_dump_dir = metadata.get("smt2_dump_dir")
1391+
stage_idx = metadata.get("stage_idx")
1392+
1393+
def dump_smt2_to_tar(solver):
1394+
global _worker_tar, _worker_job_count
1395+
if smt2_dump_dir is None:
1396+
return
1397+
1398+
if _worker_tar is None:
1399+
pid = os.getpid()
1400+
# Write to an uncompressed tar file first; we'll compress it in the main process
1401+
tar_filename = f"stage{stage_idx}_pid_{pid}.tar"
1402+
tar_path = os.path.join(smt2_dump_dir, tar_filename)
1403+
_worker_tar = tarfile.open(tar_path, mode="a")
1404+
1405+
_worker_job_count += 1
1406+
smt2_text = "(reset)\n" + solver.sexpr() + "\n(check-sat)\n"
1407+
smt2_bytes = smt2_text.encode("utf-8")
1408+
1409+
tar_info = tarfile.TarInfo(name=f"job_{_worker_job_count}.smt2")
1410+
tar_info.size = len(smt2_bytes)
1411+
_worker_tar.addfile(tar_info, io.BytesIO(smt2_bytes))
1412+
# Ensure it's written to disk
1413+
_worker_tar.fileobj.flush()
1414+
1415+
gadget_results, construction_time, solver_time = (
1416+
synthesizer.synthesize_gadget_with_symbolic(
1417+
top_seq_clone,
1418+
bottom_seq_clone,
1419+
input_state,
1420+
target_pairs,
1421+
max_solutions=max_solutions,
1422+
solver_callback=dump_smt2_to_tar,
1423+
)
13231424
)
13241425

1325-
return gadget_results, input_state, metadata
1426+
metadata["worker_pid"] = os.getpid()
1427+
return gadget_results, input_state, metadata, construction_time, solver_time

vxsort/smallsort/codegen/tests/test_symbolic_synthesis.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@ def test_symbolic_synthesis():
3636
print(f"Target pairs: {target_pairs}")
3737

3838
# Try with no instructions (should succeed)
39-
# Returns list of (gadget, output_state) tuples
40-
results = synthesizer.synthesize_gadget_with_symbolic(
39+
# Returns (results, construction_time, solver_time)
40+
# where results is list of (gadget, output_state) tuples
41+
results, _, _ = synthesizer.synthesize_gadget_with_symbolic(
4142
[], [], input_state, target_pairs
4243
)
4344

@@ -82,8 +83,9 @@ def test_symbolic_synthesis():
8283
print(f"Target pairs: {target_pairs2}")
8384
print(f"Instruction template: {inst_template.intrinsic_name}")
8485

85-
# Returns list of (gadget, output_state) tuples
86-
results2 = synthesizer.synthesize_gadget_with_symbolic(
86+
# Returns (results, construction_time, solver_time)
87+
# where results is list of (gadget, output_state) tuples
88+
results2, _, _ = synthesizer.synthesize_gadget_with_symbolic(
8789
[inst_template], [], input_state2, target_pairs2
8890
)
8991

@@ -200,7 +202,7 @@ def test_multi_solution_enumeration():
200202

201203
# Top: apply blend(top, bottom, symbolic_imm8)
202204
# Bottom: identity (no instructions) — stays as bottom
203-
results = synthesizer.synthesize_gadget_with_symbolic(
205+
results, _, _ = synthesizer.synthesize_gadget_with_symbolic(
204206
[blend_template], [], input_state, target_pairs
205207
)
206208

0 commit comments

Comments
 (0)