diff --git a/bin/Index/IRGen.cpp b/bin/Index/IRGen.cpp index 5edc88cee..0b85ce873 100644 --- a/bin/Index/IRGen.cpp +++ b/bin/Index/IRGen.cpp @@ -314,6 +314,7 @@ std::optional IRGenerator::Generate( uint32_t func_scope = PushStructure( mx::ir::StructureKind::FUNCTION_SCOPE, EntityIdOf(*body)); func_.body_scope_index = func_scope; + AssociateBlockWithStructure(frame); // FRAME was created before scope; wire it up now. AssociateBlockWithStructure(entry); // ENTER_SCOPE for the function body. @@ -375,6 +376,17 @@ std::optional IRGenerator::Generate( // Insert compensation blocks for gotos that cross scope boundaries. InsertGotoCompensationBlocks(); + // Safety net: any block still without a parent structure (e.g. a + // LABEL block forward-referenced by a goto whose label was never + // defined, or any future synthetic block that escapes the per-site + // associate calls) gets attached to the function body scope so + // `IRBlock::parent_function()` always resolves. + for (uint32_t bi = 0; bi < func_.blocks.size(); ++bi) { + if (func_.blocks[bi].parent_structure_index == UINT32_MAX) { + AssociateBlockWithStructure(bi, func_.body_scope_index); + } + } + // Patch empty blocks before computing dominators. // Empty blocks arise when all paths into a merge/exit block already // terminated (e.g., both if-branches return), or from empty switch cases. @@ -464,6 +476,7 @@ std::optional IRGenerator::GenerateGlobalInit( uint32_t func_scope = PushStructure( mx::ir::StructureKind::FUNCTION_SCOPE, EntityIdOf(var)); func_.body_scope_index = func_scope; + AssociateBlockWithStructure(frame); // FRAME was created before scope; wire it up now. AssociateBlockWithStructure(entry); // ENTER_SCOPE for the function body. @@ -504,6 +517,14 @@ std::optional IRGenerator::GenerateGlobalInit( // Pop FUNCTION_SCOPE. PopStructure(); + // Safety net: associate any block missing a parent structure with the + // function body scope. Mirrors the pass in `Generate()`. 
+ for (uint32_t bi = 0; bi < func_.blocks.size(); ++bi) { + if (func_.blocks[bi].parent_structure_index == UINT32_MAX) { + AssociateBlockWithStructure(bi, func_.body_scope_index); + } + } + // Patch empty blocks. for (uint32_t bi = 0; bi < func_.blocks.size(); ++bi) { auto &block = func_.blocks[bi]; @@ -579,12 +600,17 @@ void IRGenerator::PopStructure() { } void IRGenerator::AssociateBlockWithStructure(uint32_t block_idx) { - if (current_structure_index_ == UINT32_MAX) return; - func_.blocks[block_idx].parent_structure_index = current_structure_index_; + AssociateBlockWithStructure(block_idx, current_structure_index_); +} + +void IRGenerator::AssociateBlockWithStructure(uint32_t block_idx, + uint32_t struct_idx) { + if (struct_idx == UINT32_MAX) return; + func_.blocks[block_idx].parent_structure_index = struct_idx; StructureIR::ChildRef ref; ref.index = block_idx; ref.is_structure = false; - func_.structures[current_structure_index_].children.push_back(ref); + func_.structures[struct_idx].children.push_back(ref); } void IRGenerator::AssociateObjectWithScope(uint32_t obj_idx) { @@ -631,7 +657,12 @@ uint32_t IRGenerator::NewBlock(mx::ir::BlockKind kind) { // The dead block gets an IMPLICIT_UNREACHABLE terminator immediately so that // CurrentBlockTerminated() returns true and dead-code skipping works. void IRGenerator::SwitchToDeadBlock() { - SwitchToBlock(NewBlock(mx::ir::BlockKind::UNREACHABLE)); + uint32_t dead = NewBlock(mx::ir::BlockKind::UNREACHABLE); + // Anchor the dead block to the current scope so its parent_function() + // resolves; otherwise downstream consumers walking up the structure + // chain hit a nullopt parent. + AssociateBlockWithStructure(dead); + SwitchToBlock(dead); InstructionIR term; term.opcode = mx::ir::OpCode::IMPLICIT_UNREACHABLE; EmitTopLevel(std::move(term)); @@ -1495,22 +1526,33 @@ void IRGenerator::EmitSwitchStmt(const pasta::Stmt &s) { } // If no explicit default, add an implicit default that branches to the - // switch exit block. 
Without this, the interpreter errors when no case - // matches (e.g., a switch with gaps in its case values). + // switch exit block. The default needs its own block (not exit_block + // directly) because IRSwitchCaseStructure::target_block() looks up the + // structure's first child block; if we used exit_block, it would belong + // to the outer SWITCH structure (line ~1703 below) and the default's + // SWITCH_CASE structure would have no children, making target_block() + // return {} and the interpreter's `decide_switch` skip past the default + // entirely on a non-matching selector. if (!has_default) { - // Create a structure for the implicit default so serialization succeeds. + uint32_t default_block = NewBlock(mx::ir::BlockKind::SWITCH_DEFAULT); + uint32_t impl_struct = PushStructure( mx::ir::StructureKind::SWITCH_CASE, EntityIdOf(s)); auto &sc_struct = func_.structures[impl_struct]; sc_struct.is_default = true; + AssociateBlockWithStructure(default_block); PopStructure(); InstructionIR::SwitchCaseIR implicit_default; implicit_default.is_default = true; - implicit_default.block_index = exit_block; + implicit_default.block_index = default_block; implicit_default.structure_index = impl_struct; term.switch_cases.push_back(implicit_default); - AddEdge(current_block_index_, exit_block); + AddEdge(current_block_index_, default_block); + // The empty default_block falls through to the switch exit; the + // post-processing pass at the top of Generate() gives empty blocks + // an IMPLICIT_GOTO terminator to their first successor. + AddEdge(default_block, exit_block); } uint32_t term_idx = EmitTopLevel(std::move(term)); @@ -1634,6 +1676,7 @@ void IRGenerator::EmitSwitchStmt(const pasta::Stmt &s) { // Emit the loop condition and back-edge. // After the last case in the body, branch to the condition block. 
uint32_t cond_block = NewBlock(mx::ir::BlockKind::LOOP_CONDITION); + AssociateBlockWithStructure(cond_block); EmitBranch(cond_block); SwitchToBlock(cond_block); @@ -1646,6 +1689,7 @@ void IRGenerator::EmitSwitchStmt(const pasta::Stmt &s) { ? cases[loop_top_ci].block_index : exit_block; uint32_t loop_exit = NewBlock(mx::ir::BlockKind::LOOP_EXIT); + AssociateBlockWithStructure(loop_exit); EmitCondBranch(cond_val, loop_body_block, loop_exit, EntityIdOf(stmt)); SwitchToBlock(loop_exit); } else { @@ -4595,8 +4639,16 @@ void IRGenerator::InsertGotoCompensationBlocks() { // Compensation block needed. - // Create a compensation block. + // Create a compensation block. This pass runs *after* all PushStructure + // / PopStructure activity, so `current_structure_index_` is no longer + // meaningful — anchor the new block explicitly to the common-ancestor + // scope (the structure the compensation logically lives at), or to the + // function body scope as a fallback. Without a parent, the block has no + // path back to its function via `IRBlock::parent_function()`. uint32_t comp_block = NewBlock(mx::ir::BlockKind::COMPENSATION); + uint32_t comp_parent = (common_ancestor != UINT32_MAX) ? common_ancestor + : func_.body_scope_index; + AssociateBlockWithStructure(comp_block, comp_parent); // Redirect the source instruction → comp_block instead of → target. auto &goto_inst = func_.instructions[pg.goto_inst_idx]; diff --git a/bin/Index/IRGen.h b/bin/Index/IRGen.h index 4aff237b3..f9c5386c3 100644 --- a/bin/Index/IRGen.h +++ b/bin/Index/IRGen.h @@ -249,6 +249,11 @@ class IRGenerator { mx::RawEntityId source_eid = mx::kInvalidEntityId); void PopStructure(); void AssociateBlockWithStructure(uint32_t block_idx); + // Explicit-target overload: used by post-emission passes (e.g. goto + // compensation) where `current_structure_index_` is no longer + // meaningful but the synthetic block still needs a parent so + // `IRBlock::parent_function()` resolves. 
+ void AssociateBlockWithStructure(uint32_t block_idx, uint32_t struct_idx); void AssociateObjectWithScope(uint32_t obj_idx); // Emit EXIT_SCOPE for all enclosing scopes up to (but not including) diff --git a/bindings/Python/Interpreter.cpp b/bindings/Python/Interpreter.cpp index 7fb3edf3e..74abc12ba 100644 --- a/bindings/Python/Interpreter.cpp +++ b/bindings/Python/Interpreter.cpp @@ -536,9 +536,11 @@ static PyObject *py_init_state(PyObject *, PyObject *args) { PyObject *global_resolver = (nargs >= 7) ? PyTuple_GetItem(args, 6) : Py_None; PyObject *func_addr_resolver = (nargs >= 8) ? PyTuple_GetItem(args, 7) : Py_None; + PyObject *entity_by_addr_resolver = + (nargs >= 9) ? PyTuple_GetItem(args, 8) : Py_None; return SymbolicInitState(state_obj, second, py_policy, func_obj, args_list, func_resolver, global_resolver, - func_addr_resolver); + func_addr_resolver, entity_by_addr_resolver); } PyErr_SetString(PyExc_TypeError, @@ -657,9 +659,11 @@ static PyObject *py_step(PyObject *, PyObject *args) { PyObject *global_resolver = (nargs >= 6) ? PyTuple_GetItem(args, 5) : Py_None; PyObject *func_addr_resolver = (nargs >= 7) ? PyTuple_GetItem(args, 6) : Py_None; + PyObject *entity_by_addr_resolver = + (nargs >= 8) ? PyTuple_GetItem(args, 7) : Py_None; return SymbolicStep(state_obj, second, py_policy, max_steps, func_resolver, global_resolver, - func_addr_resolver); + func_addr_resolver, entity_by_addr_resolver); } PyErr_SetString(PyExc_TypeError, @@ -770,6 +774,45 @@ static PyObject *py_get_value_at(PyObject *, PyObject *args) { return obj; } +// Resume a state suspended on a symbolic SWITCH selector by entering +// the chosen target block. The driver picks one case (or default) per +// fork and calls this with the cloned snapshot + the IRBlock to enter. +// +// Path-condition constraints (selector ∈ [low, high], etc.) are added on +// the Python side; this entry-point only manipulates the substrate state. 
+// +// resume_switch_case(state, target_block_obj) +static PyObject *py_resume_switch_case(PyObject *, PyObject *args) { + PyObject *state_obj; + PyObject *block_obj; + if (!PyArg_ParseTuple(args, "OO", &state_obj, &block_obj)) { + return nullptr; + } + if (Py_TYPE(state_obj) != &InterpreterStateType) { + PyErr_SetString(PyExc_TypeError, "Expected InterpreterState"); + return nullptr; + } + auto block = from_python(block_obj); + if (!block) { + PyErr_SetString(PyExc_TypeError, + "resume_switch_case: second argument must be IRBlock"); + return nullptr; + } + auto *sw = reinterpret_cast(state_obj); + auto *symbolic = sw->symbolic_state + ? reinterpret_cast *>( + sw->symbolic_state)->data + : nullptr; + if (!symbolic || symbolic->call_stack.empty()) { + PyErr_SetString(PyExc_RuntimeError, + "resume_switch_case: symbolic state has no live call frame"); + return nullptr; + } + symbolic->work_stack.push_back( + {ir::interpret::WorkKind::ENTER_BLOCK, IRInstruction{}, *block}); + Py_RETURN_NONE; +} + static PyObject *py_clone_state(PyObject *, PyObject *args) { PyObject *state_obj; if (!PyArg_ParseTuple(args, "O", &state_obj)) return nullptr; @@ -832,6 +875,8 @@ static PyObject *py_init_state_frame(PyObject *, PyObject *args) { (nargs >= 8) ? PyTuple_GetItem(args, 7) : Py_None; PyObject *func_addr_resolver = (nargs >= 9) ? PyTuple_GetItem(args, 8) : Py_None; + PyObject *entity_by_addr_resolver = + (nargs >= 10) ? PyTuple_GetItem(args, 9) : Py_None; if (Py_TYPE(memory_obj) != &ConcreteMemoryType) { PyErr_SetString(PyExc_TypeError, @@ -842,7 +887,7 @@ static PyObject *py_init_state_frame(PyObject *, PyObject *args) { return SymbolicInitStateFrame(state_obj, memory_obj, py_policy, func_obj, param_addrs, return_addr, func_resolver, global_resolver, - func_addr_resolver); + func_addr_resolver, entity_by_addr_resolver); } // init_state_at: mid-block entry for under-constrained symbolic execution. 
@@ -878,6 +923,8 @@ static PyObject *py_init_state_at(PyObject *, PyObject *args) { (nargs >= 10) ? PyTuple_GetItem(args, 9) : Py_None; PyObject *func_addr_resolver = (nargs >= 11) ? PyTuple_GetItem(args, 10) : Py_None; + PyObject *entity_by_addr_resolver = + (nargs >= 12) ? PyTuple_GetItem(args, 11) : Py_None; if (Py_TYPE(memory_obj) != &ConcreteMemoryType) { PyErr_SetString(PyExc_TypeError, @@ -888,7 +935,7 @@ static PyObject *py_init_state_at(PyObject *, PyObject *args) { return SymbolicInitStateAt(state_obj, memory_obj, py_policy, func_obj, block_obj, param_addrs, return_addr, value_seed, func_resolver, global_resolver, - func_addr_resolver); + func_addr_resolver, entity_by_addr_resolver); } // Module methods. @@ -925,6 +972,11 @@ static PyMethodDef InterpreterMethods[] = { " Symbolic-address sibling of resume_addr: writes an arbitrary " "Python value (typically a z3 expression) into the suspended op's " "address-operand cache slot."}, + {"resume_switch_case", py_resume_switch_case, METH_VARARGS, + "resume_switch_case(state, target_block)\n" + " Resume a state suspended on a symbolic SWITCH selector by " + "entering the chosen target IRBlock. 
Path-condition constraints are " + "the driver's responsibility."}, {"get_value_at", py_get_value_at, METH_VARARGS, "get_value_at(state, eid) -> Python value\n" " Read the cached value at an operand entity-id from the live " diff --git a/bindings/Python/SymbolicInterpreter.cpp b/bindings/Python/SymbolicInterpreter.cpp index 540de4a67..02b3ddeb1 100644 --- a/bindings/Python/SymbolicInterpreter.cpp +++ b/bindings/Python/SymbolicInterpreter.cpp @@ -240,12 +240,14 @@ bool LoadSymbolicStateType(::PyObject *interp_module) { PythonPolicy::PythonPolicy(PyObject *py_policy, ConcreteMemory &memory, FunctionResolver func_resolver, GlobalResolver global_resolver, - FunctionAddressResolver func_addr_resolver) + FunctionAddressResolver func_addr_resolver, + EntityByAddressResolver entity_by_addr_resolver) : py_policy_(py_policy), memory_(memory), func_resolver_(std::move(func_resolver)), global_resolver_(std::move(global_resolver)), - func_addr_resolver_(std::move(func_addr_resolver)) {} + func_addr_resolver_(std::move(func_addr_resolver)), + entity_by_addr_resolver_(std::move(entity_by_addr_resolver)) {} PythonPolicy::~PythonPolicy() { Py_XDECREF(cached_make_const_); @@ -264,6 +266,7 @@ PythonPolicy::~PythonPolicy() { Py_XDECREF(cached_symbolic_load_); Py_XDECREF(cached_symbolic_store_); Py_XDECREF(cached_on_enter_block_); + Py_XDECREF(cached_on_global_initialized_); Py_XDECREF(cached_on_instruction_); Py_XDECREF(pending_exc_type_); Py_XDECREF(pending_exc_value_); @@ -284,6 +287,18 @@ PyObject *PythonPolicy::lookup_method(PyObject *&cache, const char *name) { return (cache != Py_None) ? cache : nullptr; } +// `Py_BuildValue`'s "O" format aborts with `SystemError: NULL object passed +// to Py_BuildValue` if the argument is a real C NULL — `PyObject_CallFunction` +// uses the same builder under the hood. 
A `SharedPyPtr::Get()` *should* +// always be a real PyObject (Py_None at worst, via `make_default()` / +// `default_value()`), but if a code path ever default-constructs a +// `SharedPyPtr` and stuffs it into the value cache we'd silently feed NULL +// to Py_BuildValue and crash deep in C. Substitute Py_None at the call +// boundary so the failure surfaces as a normal Python None instead. +static inline PyObject *or_none(PyObject *p) noexcept { + return p ? p : Py_None; +} + // =========================================================================== // 0. Value extraction / construction // =========================================================================== @@ -315,6 +330,20 @@ int64_t PythonPolicy::extract_int(const SharedPyPtr &val) { return 0; } +std::optional PythonPolicy::try_extract_int_impl( + const SharedPyPtr &val) { + PyObject *obj = val.Get(); + if (obj && PyLong_Check(obj) && !PyBool_Check(obj)) { + int64_t v = PyLong_AsLongLong(obj); + if (v == -1 && PyErr_Occurred()) { + PyErr_Clear(); + return std::nullopt; + } + return v; + } + return std::nullopt; +} + uint64_t PythonPolicy::extract_uint(const SharedPyPtr &val) { PyObject *obj = val.Get(); if (obj && PyLong_Check(obj)) return PyLong_AsUnsignedLongLong(obj); @@ -357,8 +386,8 @@ SharedPyPtr PythonPolicy::make_const(ConstOp op, int64_t signed_val, Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); // Py_NotImplemented } return value_to_shared(concrete_make_const(op, signed_val, unsigned_val)); } @@ -375,14 +404,14 @@ SharedPyPtr PythonPolicy::binary_op(OpCode op, const SharedPyPtr &lhs, const SharedPyPtr &rhs) { if (PyObject *method = lookup_method(cached_binary_op_, "binary_op")) { PyObject *result = PyObject_CallFunction( - method, "iOO", static_cast(op), lhs.Get(), rhs.Get()); + method, "iOO", static_cast(op), or_none(lhs.Get()), or_none(rhs.Get())); if (result && result != 
Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } bool is_float_arith = ir::IsFloatArithmetic(op); bool is_float_cmp = ir::IsFloatComparison(op); @@ -400,14 +429,14 @@ SharedPyPtr PythonPolicy::binary_op(OpCode op, const SharedPyPtr &lhs, SharedPyPtr PythonPolicy::unary_op(OpCode op, const SharedPyPtr &operand) { if (PyObject *method = lookup_method(cached_unary_op_, "unary_op")) { PyObject *result = PyObject_CallFunction( - method, "iO", static_cast(op), operand.Get()); + method, "iO", static_cast(op), or_none(operand.Get())); if (result && result != Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } bool is_float_arith = ir::IsFloatArithmetic(op); bool needs_f32 = is_float_arith && @@ -424,14 +453,14 @@ SharedPyPtr PythonPolicy::compare(OpCode op, const SharedPyPtr &lhs, const SharedPyPtr &rhs) { if (PyObject *method = lookup_method(cached_compare_, "compare")) { PyObject *result = PyObject_CallFunction( - method, "iOO", static_cast(op), lhs.Get(), rhs.Get()); + method, "iOO", static_cast(op), or_none(lhs.Get()), or_none(rhs.Get())); if (result && result != Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } bool needs_f32 = ir::IsFloatComparison(op) && (static_cast(op) % 2 == 1); @@ -443,14 +472,14 @@ SharedPyPtr PythonPolicy::compare(OpCode op, const SharedPyPtr &lhs, SharedPyPtr PythonPolicy::cast(CastOp op, const SharedPyPtr &operand) { if (PyObject *method = lookup_method(cached_cast_, "cast")) { PyObject *result = PyObject_CallFunction( - method, "iO", static_cast(op), operand.Get()); + method, "iO", static_cast(op), 
or_none(operand.Get())); if (result && result != Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } // Casts from f32 require the input to be in f32 form (low 32 bits). bool input_is_f32 = (op == CastOp::F32_TO_F64) || @@ -473,15 +502,15 @@ SharedPyPtr PythonPolicy::ptr_add(const SharedPyPtr &base, int64_t element_size) { if (PyObject *method = lookup_method(cached_ptr_add_, "ptr_add")) { PyObject *result = PyObject_CallFunction( - method, "OOL", base.Get(), index.Get(), + method, "OOL", or_none(base.Get()), or_none(index.Get()), static_cast(element_size)); if (result && result != Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } // ptr_add yields a pointer; preserve the ("ptr", N) tagging so // downstream `extract_address` can recover it. 
value_to_shared loses @@ -498,15 +527,15 @@ SharedPyPtr PythonPolicy::ptr_diff(const SharedPyPtr &lhs, int64_t element_size) { if (PyObject *method = lookup_method(cached_ptr_diff_, "ptr_diff")) { PyObject *result = PyObject_CallFunction( - method, "OOL", lhs.Get(), rhs.Get(), + method, "OOL", or_none(lhs.Get()), or_none(rhs.Get()), static_cast(element_size)); if (result && result != Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } return value_to_shared(concrete_ptr_diff( python_to_value(lhs.Get()), python_to_value(rhs.Get()), element_size)); @@ -516,15 +545,15 @@ SharedPyPtr PythonPolicy::ptr_offset(const SharedPyPtr &base, int64_t byte_offset) { if (PyObject *method = lookup_method(cached_ptr_offset_, "ptr_offset")) { PyObject *result = PyObject_CallFunction( - method, "OL", base.Get(), + method, "OL", or_none(base.Get()), static_cast(byte_offset)); if (result && result != Py_NotImplemented) { SharedPyPtr v(result); Py_DECREF(result); return v; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return make_default(); } + Py_DECREF(result); } // GEP_FIELD's result is a pointer — see the ptr_add comment. 
Value v = concrete_ptr_offset(python_to_value(base.Get()), byte_offset); @@ -574,7 +603,7 @@ SharedPyPtr PythonPolicy::float_intrinsic( std::optional PythonPolicy::is_true(const SharedPyPtr &val) { if (PyObject *method = lookup_method(cached_is_true_, "is_true")) { - PyObject *result = PyObject_CallFunction(method, "O", val.Get()); + PyObject *result = PyObject_CallFunction(method, "O", or_none(val.Get())); if (result && result != Py_NotImplemented) { if (result == Py_None) { Py_DECREF(result); @@ -584,8 +613,8 @@ std::optional PythonPolicy::is_true(const SharedPyPtr &val) { Py_DECREF(result); return truth; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return std::nullopt; } + Py_DECREF(result); } return concrete_is_true(python_to_value(val.Get())); } @@ -609,15 +638,15 @@ bool PythonPolicy::mem_read(PythonScheduler &, const SharedPyPtr &addr, const MemAccessHint &hint, SharedPyPtr &result) { if (PyObject *method = lookup_method(cached_mem_read_, "mem_read")) { PyObject *py_result = PyObject_CallFunction( - method, "OIi", addr.Get(), hint.size_bytes, + method, "OIi", or_none(addr.Get()), hint.size_bytes, static_cast(hint.is_float)); if (py_result && py_result != Py_NotImplemented) { result = SharedPyPtr(py_result); Py_DECREF(py_result); return true; } - Py_XDECREF(py_result); - PyErr_Clear(); + if (!py_result) { capture_exception(); result = make_default(); return true; } + Py_DECREF(py_result); // Py_NotImplemented } // Concrete fallback. @@ -626,7 +655,9 @@ bool PythonPolicy::mem_read(PythonScheduler &, const SharedPyPtr &addr, result = make_default(); return true; } - Value v = concrete_read_from_mem(memory_, *a, hint.size_bytes, hint.is_float); + Value v = IsBigEndian(hint.sub_op) + ? concrete_read_from_mem_be(memory_, *a, hint.size_bytes, hint.is_float) + : concrete_read_from_mem_le(memory_, *a, hint.size_bytes, hint.is_float); result = hint.is_float ? 
float_value_to_shared(v, hint.size_bytes) : value_to_shared(v); return true; @@ -637,21 +668,26 @@ bool PythonPolicy::mem_write(PythonScheduler &, const SharedPyPtr &addr, const MemAccessHint &hint) { if (PyObject *method = lookup_method(cached_mem_write_, "mem_write")) { PyObject *py_result = PyObject_CallFunction( - method, "OOIi", addr.Get(), val.Get(), hint.size_bytes, - static_cast(hint.is_float)); + method, "OOIi", or_none(addr.Get()), or_none(val.Get()), + hint.size_bytes, static_cast(hint.is_float)); if (py_result && py_result != Py_NotImplemented) { Py_DECREF(py_result); return true; } - Py_XDECREF(py_result); - PyErr_Clear(); + if (!py_result) { capture_exception(); return true; } + Py_DECREF(py_result); // Py_NotImplemented } // Concrete fallback. auto a = extract_address(addr); if (!a) return true; - concrete_write_to_mem(memory_, *a, python_to_value(val.Get()), - hint.size_bytes, hint.is_float); + if (IsBigEndian(hint.sub_op)) { + concrete_write_to_mem_be(memory_, *a, python_to_value(or_none(val.Get())), + hint.size_bytes, hint.is_float); + } else { + concrete_write_to_mem_le(memory_, *a, python_to_value(or_none(val.Get())), + hint.size_bytes, hint.is_float); + } return true; } @@ -667,15 +703,15 @@ bool PythonPolicy::exec_symbolic_load_impl(PythonScheduler &, PyObject *method = lookup_method(cached_symbolic_load_, "symbolic_load"); if (!method) return false; PyObject *py_result = PyObject_CallFunction( - method, "OIi", addr.Get(), hint.size_bytes, + method, "OIi", or_none(addr.Get()), hint.size_bytes, static_cast(hint.is_float)); if (py_result && py_result != Py_NotImplemented && py_result != Py_None) { result = SharedPyPtr(py_result); Py_DECREF(py_result); return true; } - Py_XDECREF(py_result); - PyErr_Clear(); + if (!py_result) { capture_exception(); return false; } + Py_DECREF(py_result); // Py_NotImplemented or Py_None return false; } @@ -686,14 +722,14 @@ bool PythonPolicy::exec_symbolic_store_impl(PythonScheduler &, PyObject *method = 
lookup_method(cached_symbolic_store_, "symbolic_store"); if (!method) return false; PyObject *py_result = PyObject_CallFunction( - method, "OOIi", addr.Get(), val.Get(), hint.size_bytes, - static_cast(hint.is_float)); + method, "OOIi", or_none(addr.Get()), or_none(val.Get()), + hint.size_bytes, static_cast(hint.is_float)); if (py_result && py_result != Py_NotImplemented) { Py_DECREF(py_result); return true; } - Py_XDECREF(py_result); - PyErr_Clear(); + if (!py_result) { capture_exception(); return false; } + Py_DECREF(py_result); // Py_NotImplemented return false; } @@ -735,8 +771,8 @@ void PythonPolicy::on_enter_block_impl(const IRBlock &block) { auto block_eid = EntityId(block.id()).Pack(); PyObject *result = PyObject_CallFunction(method, "K", static_cast(block_eid)); - Py_XDECREF(result); - if (PyErr_Occurred()) PyErr_Clear(); + if (!result) { capture_exception(); return; } + Py_DECREF(result); } // Per-instruction observer hook. @@ -744,10 +780,36 @@ void PythonPolicy::on_instruction_impl_inner(const IRInstruction &inst) { PyObject *method = lookup_method(cached_on_instruction_, "on_instruction"); if (!method) return; PyObject *inst_obj = ::mx::to_python(inst); - if (!inst_obj) { PyErr_Clear(); return; } + if (!inst_obj) { capture_exception(); return; } PyObject *result = PyObject_CallFunction(method, "N", inst_obj); - Py_XDECREF(result); - if (PyErr_Occurred()) PyErr_Clear(); + if (!result) { capture_exception(); return; } + Py_DECREF(result); +} + +// Global-initialized observer hook. +void PythonPolicy::on_global_initialized_impl_inner( + const IRFunction &init_func, const SharedPyPtr &addr) { + PyObject *method = lookup_method(cached_on_global_initialized_, + "on_global_initialized"); + if (!method) return; + // The init frame was always constructed with `params = {addr}` where + // `addr` is a concrete literal pointer (compute_global_ptr). 
If it + // ever fails to extract here, that's a substrate bug, not a runtime + // condition — surface it loudly instead of silently passing 0. + auto resolved = extract_address(addr); + if (!resolved) { + PyErr_SetString(PyExc_RuntimeError, + "on_global_initialized: initializer frame's address slot is not " + "concrete (substrate invariant broken)"); + capture_exception(); + return; + } + PyObject *func_obj = ::mx::to_python(init_func); + if (!func_obj) { capture_exception(); return; } + PyObject *result = PyObject_CallFunction(method, "NK", func_obj, + static_cast(*resolved)); + if (!result) { capture_exception(); return; } + Py_DECREF(result); } // =========================================================================== @@ -764,9 +826,9 @@ bool PythonPolicy::resolve_branch(PythonScheduler &, auto true_eid = EntityId(true_block.id()).Pack(); auto false_eid = EntityId(false_block.id()).Pack(); PyObject *inst_obj = ::mx::to_python(branch_inst); - if (!inst_obj) { PyErr_Clear(); inst_obj = Py_None; Py_INCREF(Py_None); } + if (!inst_obj) { capture_exception(); return false; } PyObject *result = PyObject_CallFunction( - method, "NOKK", inst_obj, condition.Get(), true_eid, false_eid); + method, "NOKK", inst_obj, or_none(condition.Get()), true_eid, false_eid); if (result && result != Py_NotImplemented) { if (result == Py_None) { Py_DECREF(result); @@ -781,8 +843,8 @@ bool PythonPolicy::resolve_branch(PythonScheduler &, chosen_block = false_block; return true; } - Py_XDECREF(result); - PyErr_Clear(); + if (!result) { capture_exception(); return false; } + Py_DECREF(result); // Py_NotImplemented } chosen_block = true_block; return true; @@ -792,6 +854,7 @@ bool PythonPolicy::resolve_call(PythonScheduler &, const IRInstruction &call_inst, RawEntityId target_eid, RawEntityId indirect_target_eid, + uint64_t target_addr, const std::vector &arguments, bool is_indirect, CallResolution &resolution) { @@ -800,15 +863,16 @@ bool PythonPolicy::resolve_call(PythonScheduler &, 
PyObject *args_list = PyList_New( static_cast(arguments.size())); for (size_t i = 0; i < arguments.size(); ++i) { - PyObject *arg = arguments[i].Get(); + PyObject *arg = or_none(arguments[i].Get()); Py_INCREF(arg); PyList_SET_ITEM(args_list, static_cast(i), arg); } PyObject *inst_obj = ::mx::to_python(call_inst); - if (!inst_obj) { PyErr_Clear(); inst_obj = Py_None; Py_INCREF(Py_None); } + if (!inst_obj) { capture_exception(); Py_DECREF(args_list); return false; } PyObject *result = PyObject_CallFunction( - method, "NKKOi", inst_obj, target_eid, indirect_target_eid, + method, "NKKKOi", inst_obj, target_eid, indirect_target_eid, + static_cast(target_addr), args_list, static_cast(is_indirect)); Py_DECREF(args_list); @@ -857,7 +921,12 @@ bool PythonPolicy::resolve_call(PythonScheduler &, if (func_resolver_) { for (auto eid : {target_eid, indirect_target_eid}) { if (eid != kInvalidEntityId) { - if (auto ir = func_resolver_(eid)) { + auto ir = func_resolver_(eid); + if (PyErr_Occurred()) { + capture_exception(); + return false; + } + if (ir) { resolution.action = CallAction::INLINE; resolution.return_value = make_default(); resolution.callee_ir = *std::move(ir); @@ -875,7 +944,12 @@ bool PythonPolicy::resolve_global(PythonScheduler &, RawEntityId entity_id, GlobalResolution &resolution) { if (global_resolver_) { - if (auto info = global_resolver_(entity_id)) { + auto info = global_resolver_(entity_id); + if (PyErr_Occurred()) { + capture_exception(); + return false; + } + if (info) { resolution.info = *std::move(info); return true; } @@ -895,8 +969,10 @@ FunctionResolver make_func_resolver(PyObject *obj) { SharedPyPtr fr(obj); return [fr](RawEntityId eid) -> std::optional { SharedPyPtr ret(PyObject_CallFunction(fr.Get(), "(K)", eid)); - if (!ret || ret.Get() == Py_None) { - PyErr_Clear(); + if (!ret) { + return std::nullopt; // exception left pending for caller to capture + } + if (ret.Get() == Py_None) { return std::nullopt; } return from_python(ret.Get()); @@ -908,8 
+984,11 @@ GlobalResolver make_global_resolver(PyObject *obj) { SharedPyPtr gr(obj); return [gr](RawEntityId eid) -> std::optional { SharedPyPtr ret(PyObject_CallFunction(gr.Get(), "(K)", eid)); - if (!ret || ret.Get() == Py_None) { - PyErr_Clear(); + if (!ret) { + // Python exception pending — leave it; resolve_global will capture it. + return std::nullopt; + } + if (ret.Get() == Py_None) { return std::nullopt; } if (!PyTuple_Check(ret.Get()) || PyTuple_Size(ret.Get()) < 3) { @@ -950,19 +1029,41 @@ FunctionAddressResolver make_func_addr_resolver(PyObject *obj) { SharedPyPtr fr(obj); return [fr](RawEntityId eid) -> std::optional { SharedPyPtr ret(PyObject_CallFunction(fr.Get(), "(K)", eid)); - if (!ret || ret.Get() == Py_None) { - PyErr_Clear(); + if (!ret) { + return std::nullopt; // exception left pending for caller to capture + } + if (ret.Get() == Py_None) { return std::nullopt; } uint64_t addr = PyLong_AsUnsignedLongLong(ret.Get()); if (PyErr_Occurred()) { - PyErr_Clear(); + PyErr_Clear(); // type error from PyLong_AsUnsignedLongLong, not a user exception return std::nullopt; } return addr; }; } +EntityByAddressResolver make_entity_by_addr_resolver(PyObject *obj) { + if (!obj || obj == Py_None || !PyCallable_Check(obj)) return {}; + SharedPyPtr fn(obj); + return [fn](uint64_t addr) -> RawEntityId { + SharedPyPtr ret(PyObject_CallFunction(fn.Get(), "(K)", addr)); + if (!ret) { + return kInvalidEntityId; // exception left pending for caller to capture + } + if (ret.Get() == Py_None) { + return kInvalidEntityId; + } + uint64_t eid = PyLong_AsUnsignedLongLong(ret.Get()); + if (PyErr_Occurred()) { + PyErr_Clear(); + return kInvalidEntityId; + } + return static_cast(eid); + }; +} + } // namespace PyObject *SymbolicInitState(PyObject *state_obj, PyObject *memory_obj, @@ -970,7 +1071,8 @@ PyObject *SymbolicInitState(PyObject *state_obj, PyObject *memory_obj, PyObject *args_list, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject 
*func_addr_resolver_obj) { + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj) { auto *sw = reinterpret_cast(state_obj); auto *mw = reinterpret_cast(memory_obj); @@ -991,7 +1093,8 @@ PyObject *SymbolicInitState(PyObject *state_obj, PyObject *memory_obj, PythonPolicy policy(py_policy, *mw->memory, make_func_resolver(func_resolver_obj), make_global_resolver(global_resolver_obj), - make_func_addr_resolver(func_addr_resolver_obj)); + make_func_addr_resolver(func_addr_resolver_obj), + make_entity_by_addr_resolver(entity_by_addr_resolver_obj)); PythonScheduler sched; auto &symbolic = install_fresh_symbolic_state(sw); @@ -1007,7 +1110,8 @@ PyObject *SymbolicInitStateFrame(PyObject *state_obj, PyObject *memory_obj, PyObject *return_addr_obj, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj) { + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj) { auto *sw = reinterpret_cast(state_obj); auto *mw = reinterpret_cast(memory_obj); @@ -1038,7 +1142,8 @@ PyObject *SymbolicInitStateFrame(PyObject *state_obj, PyObject *memory_obj, PythonPolicy policy(py_policy, *mw->memory, make_func_resolver(func_resolver_obj), make_global_resolver(global_resolver_obj), - make_func_addr_resolver(func_addr_resolver_obj)); + make_func_addr_resolver(func_addr_resolver_obj), + make_entity_by_addr_resolver(entity_by_addr_resolver_obj)); PythonScheduler sched; auto &symbolic = install_fresh_symbolic_state(sw); @@ -1056,7 +1161,8 @@ PyObject *SymbolicInitStateAt(PyObject *state_obj, PyObject *memory_obj, PyObject *value_seed_dict, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj) { + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj) { auto *sw = reinterpret_cast(state_obj); auto *mw = reinterpret_cast(memory_obj); @@ -1106,7 +1212,8 @@ PyObject *SymbolicInitStateAt(PyObject *state_obj, PyObject *memory_obj, PythonPolicy policy(py_policy, 
*mw->memory, make_func_resolver(func_resolver_obj), make_global_resolver(global_resolver_obj), - make_func_addr_resolver(func_addr_resolver_obj)); + make_func_addr_resolver(func_addr_resolver_obj), + make_entity_by_addr_resolver(entity_by_addr_resolver_obj)); PythonScheduler sched; auto &symbolic = install_fresh_symbolic_state(sw); @@ -1121,7 +1228,8 @@ PyObject *SymbolicStep(PyObject *state_obj, PyObject *memory_obj, PyObject *py_policy, uint64_t max_steps, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj) { + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj) { auto *sw = reinterpret_cast(state_obj); auto *mw = reinterpret_cast(memory_obj); @@ -1135,7 +1243,8 @@ PyObject *SymbolicStep(PyObject *state_obj, PyObject *memory_obj, PythonPolicy policy(py_policy, *mw->memory, make_func_resolver(func_resolver_obj), make_global_resolver(global_resolver_obj), - make_func_addr_resolver(func_addr_resolver_obj)); + make_func_addr_resolver(func_addr_resolver_obj), + make_entity_by_addr_resolver(entity_by_addr_resolver_obj)); PythonScheduler sched; bool budget_hit = interp_step( @@ -1197,6 +1306,14 @@ PyObject *SymbolicStep(PyObject *state_obj, PyObject *memory_obj, PyBool_FromLong(mc->is_write() ? 
1 : 0)); PyDict_SetItemString(result_dict, "result", result_tuple); Py_XDECREF(result_tuple); + } else if (auto *sc = dynamic_cast *>(first)) { + PyObject *sel_obj = sc->selector().Get(); + if (!sel_obj) sel_obj = Py_None; + PyObject *result_tuple = Py_BuildValue( + "(sOK)", "switch", sel_obj, + static_cast(sc->selector_eid())); + PyDict_SetItemString(result_dict, "result", result_tuple); + Py_XDECREF(result_tuple); } else { PyObject *desc = PyUnicode_FromString(first->describe().c_str()); PyObject *result_tuple = Py_BuildValue("(sO)", "suspended", desc); @@ -1238,6 +1355,70 @@ PyObject *SymbolicStep(PyObject *state_obj, PyObject *memory_obj, PyList_Append(forks_list, dict); Py_DECREF(dict); } + } else if (auto *sc = dynamic_cast *>(cont.get())) { + auto snap = sc->snapshot(); + if (!snap) continue; + // One fork entry per switch suspension; the driver walks + // `cases` + `default_block_eid` and clones the snapshot per + // child. The snapshot is stashed in `state` (already a fresh + // wrapper, not cloned — the driver clones via _interp.clone_state). 
+ PyObject *state_obj = MakeSymbolicStateWrapper(snap->clone()); + PyObject *dict = PyDict_New(); + PyDict_SetItemString(dict, "state", state_obj); + Py_XDECREF(state_obj); + PyObject *kind_str = PyUnicode_FromString("switch"); + PyDict_SetItemString(dict, "kind", kind_str); + Py_DECREF(kind_str); + PyObject *sel_obj = sc->selector().Get(); + if (!sel_obj) sel_obj = Py_None; + Py_INCREF(sel_obj); + PyDict_SetItemString(dict, "selector", sel_obj); + Py_DECREF(sel_obj); + PyObject *eid_obj = PyLong_FromUnsignedLongLong(sc->selector_eid()); + PyDict_SetItemString(dict, "selector_eid", eid_obj); + Py_DECREF(eid_obj); + PyObject *cases_list = PyList_New(0); + for (const auto &c : sc->cases()) { + uint64_t target_eid = EntityId(c.target_block.id()).Pack(); + PyObject *block_obj = ::mx::to_python(c.target_block); + if (!block_obj) { + Py_INCREF(Py_None); + block_obj = Py_None; + } + // (low, high, target_block_eid, target_block_obj). Driver uses + // the eid for events / path-condition records and the IRBlock + // object for `_interp.resume_switch_case`. 
+ PyObject *case_tuple = Py_BuildValue( + "(LLKN)", static_cast(c.low), + static_cast(c.high), + target_eid, block_obj); // 'N' steals block_obj + PyList_Append(cases_list, case_tuple); + Py_DECREF(case_tuple); + } + PyDict_SetItemString(dict, "cases", cases_list); + Py_DECREF(cases_list); + uint64_t default_eid = EntityId(sc->default_block().id()).Pack(); + if (default_eid != 0) { + PyObject *def_eid = PyLong_FromUnsignedLongLong(default_eid); + PyDict_SetItemString(dict, "default_block_eid", def_eid); + Py_DECREF(def_eid); + PyObject *def_block = ::mx::to_python(sc->default_block()); + if (!def_block) { + Py_INCREF(Py_None); + def_block = Py_None; + } + PyDict_SetItemString(dict, "default_block", def_block); + Py_DECREF(def_block); + } else { + Py_INCREF(Py_None); + PyDict_SetItemString(dict, "default_block_eid", Py_None); + Py_DECREF(Py_None); + Py_INCREF(Py_None); + PyDict_SetItemString(dict, "default_block", Py_None); + Py_DECREF(Py_None); + } + PyList_Append(forks_list, dict); + Py_DECREF(dict); } else if (auto *mc = dynamic_cast *>(cont.get())) { auto snap = mc->snapshot(); if (!snap) continue; diff --git a/bindings/Python/SymbolicInterpreter.h b/bindings/Python/SymbolicInterpreter.h index d60cbc718..a30689824 100644 --- a/bindings/Python/SymbolicInterpreter.h +++ b/bindings/Python/SymbolicInterpreter.h @@ -78,6 +78,16 @@ struct PythonScheduler : Scheduler { std::move(snapshot), std::move(cond), cond_eid, tb, fb, std::move(false_val), std::move(true_val))); } + + void on_switch(SharedPyPtr selector, RawEntityId sel_eid, + std::vector cases, + IRBlock default_block, + ref_t snapshot) { + outcome.continuations.emplace_back( + std::make_unique>( + std::move(snapshot), std::move(selector), sel_eid, + std::move(cases), default_block)); + } }; // =========================================================================== @@ -95,13 +105,19 @@ struct PythonScheduler : Scheduler { using FunctionAddressResolver = std::function(RawEntityId)>; +// Reverse-direction 
resolver: virtual address → entity id. +// Called by exec_call to map an indirect callee address back to a +// known declaration. Returns kInvalidEntityId on miss. +using EntityByAddressResolver = std::function; + class PythonPolicy : public Policy { public: PythonPolicy(PyObject *py_policy, ConcreteMemory &memory, FunctionResolver func_resolver = {}, GlobalResolver global_resolver = {}, - FunctionAddressResolver func_addr_resolver = {}); + FunctionAddressResolver func_addr_resolver = {}, + EntityByAddressResolver entity_by_addr_resolver = {}); ~PythonPolicy(); ConcreteMemory &memory(void) { return memory_; } @@ -111,6 +127,10 @@ class PythonPolicy std::optional extract_address(const SharedPyPtr &val); int64_t extract_int(const SharedPyPtr &val); uint64_t extract_uint(const SharedPyPtr &val); + // Concreteness-preserving variant: returns nullopt for non-PyLong + // values so callers (e.g. decide_switch) can fork on symbolic ints + // instead of silently treating them as 0. + std::optional try_extract_int_impl(const SharedPyPtr &val); SharedPyPtr make_literal_int(int64_t v, uint8_t width = 8); SharedPyPtr make_literal_ptr(uint64_t addr); SharedPyPtr make_default(); @@ -225,18 +245,36 @@ class PythonPolicy const IRInstruction &call_inst, RawEntityId target_eid, RawEntityId indirect_target_eid, + uint64_t target_addr, const std::vector &arguments, bool is_indirect, CallResolution &resolution); bool resolve_global(PythonScheduler &sched, RawEntityId entity_id, GlobalResolution &resolution); + // Reverse-direction resolver: address → entity id. + RawEntityId entity_for_address_impl(uint64_t addr) { + if (entity_by_addr_resolver_) { + RawEntityId result = entity_by_addr_resolver_(addr); + if (PyErr_Occurred()) { + capture_exception(); + return kInvalidEntityId; + } + return result; + } + return kInvalidEntityId; + } + // Phase 9: per-function address invention. Falls back to nullopt // (substrate auto-allocates) when no resolver is wired in. 
std::optional address_for_function_impl(PythonScheduler &, RawEntityId eid) { if (func_addr_resolver_) { - return func_addr_resolver_(eid); + auto addr = func_addr_resolver_(eid); + if (PyErr_Occurred()) { + capture_exception(); + } + return addr; } return std::nullopt; } @@ -275,12 +313,24 @@ class PythonPolicy } void on_instruction_impl_inner(const IRInstruction &inst); + // Fires when a GLOBAL_INITIALIZER / THREAD_LOCAL_INITIALIZER frame + // returns. Fans out to the Python policy's `on_global_initialized` + // method (delegated to InterceptorPolicy in dispatch.py). + template + void on_global_initialized_impl(SchedT &, const IRFunction &init_func, + const SharedPyPtr &addr) { + on_global_initialized_impl_inner(init_func, addr); + } + void on_global_initialized_impl_inner(const IRFunction &init_func, + const SharedPyPtr &addr); + private: SharedPyPtr py_policy_; ConcreteMemory &memory_; FunctionResolver func_resolver_; GlobalResolver global_resolver_; FunctionAddressResolver func_addr_resolver_; + EntityByAddressResolver entity_by_addr_resolver_; PyObject *cached_make_const_{nullptr}; PyObject *cached_binary_op_{nullptr}; @@ -299,6 +349,7 @@ class PythonPolicy PyObject *cached_symbolic_store_{nullptr}; PyObject *cached_on_enter_block_{nullptr}; PyObject *cached_on_instruction_{nullptr}; + PyObject *cached_on_global_initialized_{nullptr}; // Pending exception state. Captured when a Python hook raises so the // interpreter loop can exit cleanly and SymbolicStep can re-raise it. 
@@ -319,14 +370,16 @@ PyObject *SymbolicInitState(PyObject *state_obj, PyObject *memory_obj, PyObject *args_list, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj); + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj); PyObject *SymbolicInitStateFrame(PyObject *state_obj, PyObject *memory_obj, PyObject *py_policy, PyObject *func_obj, PyObject *param_addrs_list, PyObject *return_addr_obj, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj); + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj); // Mid-block entry: start at a chosen IRBlock with a caller-supplied seed // of live-in values (dict mapping eid -> Python value). Symex driver uses // this for under-constrained execution that begins partway through a @@ -339,12 +392,14 @@ PyObject *SymbolicInitStateAt(PyObject *state_obj, PyObject *memory_obj, PyObject *value_seed_dict, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj); + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj); PyObject *SymbolicStep(PyObject *state_obj, PyObject *memory_obj, PyObject *py_policy, uint64_t max_steps, PyObject *func_resolver_obj, PyObject *global_resolver_obj, - PyObject *func_addr_resolver_obj); + PyObject *func_addr_resolver_obj, + PyObject *entity_by_addr_resolver_obj); // Register the private `_SymbolicState` PyTypeObject into the // interpreter submodule. Called once during module init. 
diff --git a/bindings/Python/symex/__init__.py b/bindings/Python/symex/__init__.py index d5e6c0fe3..bb1645cf2 100644 --- a/bindings/Python/symex/__init__.py +++ b/bindings/Python/symex/__init__.py @@ -15,7 +15,7 @@ from .lens import MemView, ArgsView, LocalsView from .ctx import Ctx from .path import Path -from .events import EventLog +from .events import EventLog, StopNow from .until import ExploreUntil from .engine import SymExEngine, PathSet from .dispatch import InterceptorPolicy, SymExpr @@ -39,6 +39,7 @@ "Ctx", "Path", "EventLog", + "StopNow", "PathSet", "ExploreUntil", "SymExEngine", diff --git a/bindings/Python/symex/_types.py b/bindings/Python/symex/_types.py new file mode 100644 index 000000000..3f09d2a8f --- /dev/null +++ b/bindings/Python/symex/_types.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026-present, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in +# the LICENSE file found in the root directory of this source tree. + +"""Shared type-check constants and leaf-level enums. + +Leaf module — no imports from the symex package so any file can import +these without triggering circular-import issues. +""" + +from enum import StrEnum + +# Pre-built type tuples for isinstance checks. +_BYTES_TYPES = (bytes, bytearray, memoryview) +_INT_TYPES = (int, bool) +_SEQ_TYPES = (list, tuple) + + +class Endian(StrEnum): + """Byte order for Python-side memory helpers (`to_bytes` / `from_bytes`). + + The string values match Python's `int.to_bytes(..., byteorder=...)` + argument so they can be passed through directly. + """ + LITTLE = "little" + BIG = "big" diff --git a/bindings/Python/symex/ctx.py b/bindings/Python/symex/ctx.py index 153d26b76..cdd7a70d4 100644 --- a/bindings/Python/symex/ctx.py +++ b/bindings/Python/symex/ctx.py @@ -17,7 +17,7 @@ `ctx.stop_path()` and then return `ctx.default()` (or any value). 
""" -from .events import Terminal +from .events import Terminal, StopNow class Ctx: @@ -56,3 +56,19 @@ def stop_path(self): """ if self.path is not None: self.path.terminal = Terminal.STOPPED + + def stop_now(self, terminal=Terminal.STOPPED): + """Halt the current interpreter slice immediately. + + Sets `path.terminal` to `terminal` and raises `StopNow`, which + the C++ interpreter catches to clear the work stack. The engine + driver intercepts `StopNow` after `_interp.step()` returns and + treats the path as stopped — no further slices are taken. + + Works from any hook type (intercept or observe). Unlike + `stop_path()`, execution does not continue to the end of the + current slice. + """ + if self.path is not None: + self.path.terminal = terminal + raise StopNow() diff --git a/bindings/Python/symex/dispatch.py b/bindings/Python/symex/dispatch.py index b0f135792..13d70e38b 100644 --- a/bindings/Python/symex/dispatch.py +++ b/bindings/Python/symex/dispatch.py @@ -36,7 +36,7 @@ from .events import ( MEMORY_READ, MEMORY_WRITE, SYMBOLIC_LOAD, SYMBOLIC_STORE, - GLOBAL_READ, GLOBAL_WRITE, + GLOBAL_READ, GLOBAL_WRITE, GLOBAL_INITIALIZED, CALL, INDIRECT_CALL, BRANCH, LOOP, CONCRETIZE, BLOCK_ENTER, INSTRUCTION, @@ -105,14 +105,29 @@ class _Selector: """ __slots__ = ("addr_range", "name", "eid", "func", "block", "region", - "_layout", "_resolved_range", "kind", "target_kind") + "_layout", "_resolved_range", "kind", "target_kind", "decl") def __init__(self, addr_range=None, name=None, eid=None, func=None, block=None, region=None, layout=None, - kind=None, target_kind=None): + kind=None, target_kind=None, decl=None): self.addr_range = addr_range self.name = name + # Allow `decl=` to substitute for both name and eid: + # we extract them up front so matches_name / matches_eid still + # do their cheap equality checks at dispatch time. 
+ if decl is not None: + try: + if name is None: + self.name = str(decl.name) if decl.name else None + except (AttributeError, TypeError): + pass + if eid is None: + try: + eid = int(decl.id) + except (AttributeError, TypeError): + eid = None self.eid = eid + self.decl = decl self.func = func self.block = block self.region = region @@ -195,6 +210,7 @@ def make_selector(layout, **kwargs): addr_range=kwargs.get("addr_range"), name=kwargs.get("name"), eid=kwargs.get("eid"), + decl=kwargs.get("decl"), func=kwargs.get("func"), block=kwargs.get("block"), region=kwargs.get("region"), @@ -268,80 +284,163 @@ def _build_chain(handlers, default_fn): # ---- per-event default (chain bottom) functions ----------------------- -def _make_default_mem_read(is_float, shadow=None): +from ._types import _BYTES_TYPES, _INT_TYPES, _SEQ_TYPES + +# Pre-built zero-fill byte strings indexed by size (0-16). +_ZERO_BYTES = tuple(bytes(n) for n in range(17)) + + +def _shadow_write(shadow, addr, val, size): + """Decompose a z3 value into per-byte shadow entries (little-endian). + + Each ``shadow[addr + i]`` slot holds a ``BitVec(8)`` extract of ``val`` + (or, when ``val.size() == 8``, ``val`` itself — avoids accumulating + nested ``Extract(7, 0, ...)`` layers when a symbolic byte is + read-and-rewritten). + + Concrete writes erase covered shadow slots (``ConcreteMemory`` is the + source of truth for concrete bytes); the caller is responsible for + writing the actual bytes to ``ConcreteMemory``. Partial overlap is + handled naturally — only touched positions are updated. + """ + z3 = _z3_module() + if z3 is not None and _is_z3(val): + # Simplify before per-byte decomposition: when ``val`` was + # constructed by re-concatenating bytes earlier read from the + # shadow (the common round-trip), z3 collapses the + # ``Extract(i+7, i, Concat(b_n, …, b_0))`` patterns back to the + # original byte expressions. 
Without this, every load-modify-store + # cycle accumulates Concat/Extract layers around the same bytes. + val = z3.simplify(val) + i = 0 + while i < size: + shadow[addr + i] = _z3_byte_at(val, i) + i += 1 + else: + i = 0 + while i < size: + shadow.pop(addr + i, None) + i += 1 + + +def _shadow_read(shadow, addr, size, data, buf): + """Reconstruct a symbolic value from the byte-granular shadow. + + ``data`` is the already-read concrete bytes for ``[addr, addr+size)``. + Returns ``None`` immediately when no shadow entry overlaps the range + (zero dict lookups beyond the in-check on the fast path). + + For each covered position: shadow-dict hit → ``z3.BitVec(8)`` extract; + miss → ``BitVecVal`` of the concrete byte. ``buf`` is a pre-allocated + list (held on ``Path``) reused across calls. ``z3.simplify`` collapses + a same-width read of a single symbolic write back to the original + variable. + """ + # Fast path: check whether any byte in the range is shadowed. Sparse + # paths see a non-shadowed range and return immediately, so this is + # the no-symbolic-data hot path. + found = False + i = 0 + while i < size: + if (addr + i) in shadow: + found = True + break + i += 1 + if not found: + return None + + z3 = _z3_module() + + while len(buf) < size: + buf.append(None) + + i = 0 + while i < size: + entry = shadow.get(addr + i) + buf[i] = entry if entry is not None else z3.BitVecVal(data[i], 8) + i += 1 + + if size == 1: + return buf[0] + + # Concat is MSB-first; little-endian layout means buf[size-1] is the MSB. + result = buf[size - 1] + j = size - 2 + while j >= 0: + result = z3.Concat(result, buf[j]) + j -= 1 + return z3.simplify(result) + + +def _make_default_mem_read(is_float, shadow=None, buf=None, byte_order="little"): """Return the chain bottom for a memory_read event. - Reads concrete bytes via the lens; for `is_float`, unpacks IEEE - float32 / float64. Returns the loaded value. 
Falls back to - zero on a read failure (the address has no backing) — matching - the C++ `ConcreteMemory::read` zero-fill semantics — so the - surrounding event still fires and sinks (like `OOBSink`) get - a chance to surface the OOB. - - Phase 8c: when `shadow` (a `(addr, size) -> z3 expr` dict) holds - an exact-match entry for `(addr, size)`, the shadow value wins. - Partial-overlap and width-mismatch hits are out of scope; the - substrate's IR lowering keeps slot widths consistent. + Reads baseline bytes from ``ConcreteMemory`` first (always needed as + fallback), then defers to ``_shadow_read`` which checks the per-path + shadow dict. Shadow entries shadow the baseline; any range with no + shadow entry goes straight to the concrete fallback. Falls back to + zero-fill on a read failure so OOB sinks still fire. """ + _buf = buf if buf is not None else [] + def default(ctx, addr, size): - if shadow is not None: - cached = shadow.get((addr, size)) - if cached is not None and _is_z3(cached): - return cached try: data = ctx.mem.read_bytes(addr, size) except RuntimeError: - data = b"\x00" * size + data = _ZERO_BYTES[size] if size < len(_ZERO_BYTES) else bytes(size) + if shadow is not None: + sym = _shadow_read(shadow, addr, size, data, _buf) + if sym is not None: + return sym if is_float and size == 4: - return _struct.unpack("f" + return _struct.unpack(fmt, data)[0] if is_float and size == 8: - return _struct.unpack("d" + return _struct.unpack(fmt, data)[0] + return int.from_bytes(data, byte_order, signed=False) return default -def _make_default_mem_write(is_float, shadow=None): +def _make_default_mem_write(is_float, shadow=None, byte_order="little"): """Return the chain bottom for a memory_write event. - Writes concrete bytes via the lens. Handles ints, ("ptr", N) - pointer tuples, raw bytes, and IEEE floats. Returns None. 
- - Phase 8c: a z3 expression written to a concrete address goes - into `shadow[(addr, size)]` (when `shadow` is provided) so a - later exact-match read returns it. Without a shadow the value - is dropped, preserving pre-Phase-8c behavior for tests that - construct a policy without a Path. + Writes concrete bytes via the lens. Handles ints, ("ptr", N) pointer + tuples, raw bytes, and IEEE floats. A z3 write decomposes ``val`` into + per-byte extracts in the shadow dict; a concrete write clears any + covered shadow slots (concrete memory is the source of truth) and + writes the real bytes to ``ConcreteMemory``. """ def default(ctx, addr, val, size): if isinstance(val, bool): val = int(val) if isinstance(val, int): if shadow is not None: - shadow.pop((addr, size), None) + _shadow_write(shadow, addr, val, size) ctx.mem.write_bytes( - addr, val.to_bytes(size, "little", signed=(val < 0))) + addr, val.to_bytes(size, byte_order, signed=(val < 0))) return None if isinstance(val, tuple) and len(val) == 2 and val[0] == VALUE_TAG_PTR: if shadow is not None: - shadow.pop((addr, size), None) + _shadow_write(shadow, addr, val, size) ctx.mem.write_bytes( - addr, int(val[1]).to_bytes(size, "little", signed=False)) + addr, int(val[1]).to_bytes(size, byte_order, signed=False)) return None - if isinstance(val, (bytes, bytearray)): + if isinstance(val, _BYTES_TYPES): if shadow is not None: - shadow.pop((addr, size), None) - ctx.mem.write_bytes(addr, bytes(val)) + _shadow_write(shadow, addr, val, size) + ctx.mem.write_bytes(addr, val) return None if isinstance(val, float): if shadow is not None: - shadow.pop((addr, size), None) - fmt = "" + fmt = f"{prefix}f" if size == 4 else f"{prefix}d" ctx.mem.write_bytes(addr, _struct.pack(fmt, val)) return None if shadow is not None and _is_z3(val): - shadow[(addr, size)] = val + _shadow_write(shadow, addr, val, size) return None - # Symbolic or otherwise unknown without a shadow — drop. 
return None return default @@ -351,6 +450,17 @@ def _default_call(ctx): return _DEFER +def _default_indirect_call(ctx, target_addr): + """Chain bottom for an indirect_call event — defer to substrate inline. + + Indirect-call hooks receive `target_addr` as a positional argument + (per the design doc); the chain bottom accepts and ignores it so + handlers that forward via `next_hook(ctx, target_addr)` reach a + concrete sentinel. + """ + return _DEFER + + def _default_branch(ctx, condition): """Chain bottom for a branch event. @@ -359,7 +469,7 @@ def _default_branch(ctx, condition): control back to the substrate so it can enumerate edges via a BranchContinuation. """ - if isinstance(condition, (int, bool)): + if isinstance(condition, _INT_TYPES): return condition != 0 return _FORK @@ -427,6 +537,38 @@ def _z3_coerce(v, sample): return None +def _z3_resize(val, target_bits, signed=False): + """Return a z3 BitVec of exactly ``target_bits`` derived from ``val``. + + If ``val`` is already that wide, returned unchanged — z3 doesn't + auto-simplify ``Extract(N-1, 0, BitVec(N))`` or ``ZeroExt(0, …)`` / + ``SignExt(0, …)`` to identity, so wrapping a same-width value in those + ops just bloats every downstream expression. Avoiding the wrap keeps + expressions canonical across repeated read-modify-write cycles. + """ + cur = val.size() + if cur == target_bits: + return val + z3 = _z3_module() + if cur > target_bits: + return z3.Extract(target_bits - 1, 0, val) + return z3.SignExt(target_bits - cur, val) if signed \ + else z3.ZeroExt(target_bits - cur, val) + + +def _z3_byte_at(val, i): + """Extract the i-th byte (8 bits, little-endian) of ``val``. + + Skips the wrap when ``val`` is already exactly one byte wide and + we're asking for byte 0 — the common case for byte-granular shadow + writes of an already-extracted symbolic byte. 
+ """ + if val.size() == 8 and i == 0: + return val + z3 = _z3_module() + return z3.Extract(8 * (i + 1) - 1, 8 * i, val) + + # Per-arity dispatch tables: each entry is `(opcode_range, builder)`. # Builders for ops that need a z3 module function take (z3, a[, b]) so # the cached `_z3_module()` reference flows through; the rest are pure @@ -476,11 +618,28 @@ def _dispatch_op(table, op): return None +def _dispatch_op_sized(table, op): + """Like ``_dispatch_op`` but also returns the operand width in bits. + + Assumes the entry has _8/_16/_32/_64 variants in that order (so a + range of size 4). Used by integer binary/unary dispatch so a + width-mismatched operand pair is reconciled before reaching z3. + """ + for lo_hi, builder in table: + if lo_hi[0] <= op <= lo_hi[1]: + return builder, 8 << (op - lo_hi[0]) + return None + + def _z3_compare(op, lhs, rhs): """Build a z3 boolean from a comparison opcode and two operands. - Returns None if the opcode isn't recognised — caller falls back to - `SymExpr` so analysts who care can still detect propagation. + For BitVec operands, both are resized to a matching width (the + larger of the two) before the op is built — z3 sort-checks + aggressively and an upstream cast may have left the pair mismatched. + FP operands fall through unchanged (already same-width by construction). + + Returns None if the opcode isn't recognised. Result is simplified. """ z3 = _z3_module() builder = _dispatch_op(_COMPARE_TABLE, op) @@ -491,31 +650,89 @@ def _z3_compare(op, lhs, rhs): b = _z3_coerce(rhs, sample) if a is None or b is None: return None - return builder(z3, a, b) + if isinstance(a, z3.BitVecRef) and isinstance(b, z3.BitVecRef) \ + and a.size() != b.size(): + target = max(a.size(), b.size()) + a = _z3_resize(a, target) + b = _z3_resize(b, target) + return z3.simplify(builder(z3, a, b)) def _z3_binary(op, lhs, rhs): - """Build a z3 BitVec from a binary opcode. None on unsupported.""" + """Build a z3 BitVec from a binary opcode. 
None on unsupported. + + Both operands are coerced to the opcode's declared width before the + op is built — without this, an upstream cast that left an operand + wider/narrower than its companion (or wider/narrower than the + opcode's width) crashes z3 with a sort-mismatch. + Result is simplified so each step in a chain of arithmetic ops stays + canonical. + """ z3 = _z3_module() - builder = _dispatch_op(_BINARY_TABLE, op) - if z3 is None or builder is None: + sized = _dispatch_op_sized(_BINARY_TABLE, op) + if z3 is None or sized is None: return None + builder, bits = sized sample = lhs if _is_z3(lhs) else rhs a = _z3_coerce(lhs, sample) b = _z3_coerce(rhs, sample) if a is None or b is None: return None - return builder(z3, a, b) + a = _z3_resize(a, bits) + b = _z3_resize(b, bits) + return z3.simplify(builder(z3, a, b)) def _z3_unary(op, operand): if not _is_z3(operand): return None z3 = _z3_module() - builder = _dispatch_op(_UNARY_TABLE, op) - if z3 is None or builder is None: + sized = _dispatch_op_sized(_UNARY_TABLE, op) + if z3 is None or sized is None: return None - return builder(z3, operand) + builder, bits = sized + return z3.simplify(builder(z3, _z3_resize(operand, bits))) + + +def _z3_cast(op, operand): + """Lower a CastOp to z3 ops so symbolic values flow through casts. + + Returns None for casts we don't model symbolically (float ↔ int, + F32_TO_F64, etc.); the caller falls back to a SymExpr placeholder + in that case, which keeps the value visible but blocks z3 reasoning + until those casts are wired up. + """ + z3 = _z3_module() + if z3 is None or not isinstance(operand, z3.ExprRef): + return None + name = mx.ir.CastOp(int(op)).name + + # Each cast result is z3.simplify'd — `_z3_resize` may emit a fresh + # Extract/SignExt/ZeroExt node that the simplifier can immediately + # collapse against an existing one (common when a SEXT feeds into a + # TRUNC on the same value). 
+ if name.startswith("SEXT_I") or name.startswith("ZEXT_I"): + # SEXT_I{src}_I{tgt} or ZEXT_I{src}_I{tgt}. Use the operand's + # actual width (not src) — earlier casts may already have + # widened or narrowed it past what the opcode claims. + prefix = "SEXT_I" if name.startswith("SEXT_I") else "ZEXT_I" + parts = name[len(prefix):].split("_I") + _, tgt = int(parts[0]), int(parts[1]) + return z3.simplify(_z3_resize(operand, tgt, signed=(prefix == "SEXT_I"))) + if name.startswith("TRUNC_I"): + # TRUNC_I{src}_I{tgt}. + parts = name[len("TRUNC_I"):].split("_I") + _, tgt = int(parts[0]), int(parts[1]) + return z3.simplify(_z3_resize(operand, tgt)) + if name in ("BITCAST", "IDENTITY"): + return operand + if name in ("PTR_TO_I64", "I64_TO_PTR"): + return z3.simplify(_z3_resize(operand, 64)) + if name == "PTR_TO_I32": + return z3.simplify(_z3_resize(operand, 32)) + if name == "I32_TO_PTR": + return z3.simplify(_z3_resize(operand, 64)) + return None # float casts unhandled; caller produces a SymExpr # ----------------------------------------------------------------------- @@ -546,7 +763,8 @@ def __init__(self, engine, path, *, layout=None, memory=None): self._memory = memory if memory is not None else ( self._layout.memory if self._layout is not None else None) # Lazily built per-step MemView; ArgsView is per-call event. - self._mem_view = MemView(self._memory) if self._memory else None + self._mem_view = (MemView(self._memory, endian=engine.endian) + if self._memory else None) # Phase 8c: shadow map for symbolic values written to concrete # substrate-allocated addresses. 
Held as a shared reference to # the path's dict so writes and reads survive across steps @@ -556,6 +774,9 @@ def __init__(self, engine, path, *, layout=None, memory=None): self._shadow = getattr(path, "_symbolic_shadow", None) if self._shadow is None: self._shadow = {} + self._shadow_buf = getattr(path, "_shadow_buf", None) + if self._shadow_buf is None: + self._shadow_buf = [] # ------------------------------------------------------------------ # Hook entry points (lookup_method on PyPolicy fires these) @@ -578,17 +799,10 @@ def mem_read(self, addr, size, is_float): MEMORY_READ, lambda sel: sel.matches_addr(addr_int)) chain = _build_chain(handlers, _make_default_mem_read(bool(is_float), - self._shadow)) - try: - value = chain(ctx, addr_int, size_i) - except Exception as exc: # noqa: BLE001 - self._record_handler_error(ctx, MEMORY_READ, exc, - role="intercept") - self._fire_observers(MEMORY_READ, Phase.AFTER, ctx, - addr=addr_int, size=size_i, - is_float=bool(is_float), value=None, - handled=False, region=region_name) - return NotImplemented + self._shadow, + self._shadow_buf, + byte_order=str(self._engine.endian))) + value = chain(ctx, addr_int, size_i) self._fire_observers(MEMORY_READ, Phase.AFTER, ctx, addr=addr_int, size=size_i, @@ -624,12 +838,9 @@ def mem_write(self, addr, val, size, is_float): MEMORY_WRITE, lambda sel: sel.matches_addr(addr_int)) chain = _build_chain(handlers, _make_default_mem_write(bool(is_float), - self._shadow)) - try: - chain(ctx, addr_int, val, size_i) - except Exception as exc: # noqa: BLE001 - self._record_handler_error(ctx, MEMORY_WRITE, exc, - role="intercept") + self._shadow, + byte_order=str(self._engine.endian))) + chain(ctx, addr_int, val, size_i) # Phase 6 invariant: when the region's overlay has been # materialized (some prior symbolic access forced its @@ -714,12 +925,7 @@ def _default(c, a, sz): lambda sel: (sel.matches_region(region_name) and sel.matches_name(region_name))) chain = _build_chain(handlers, _default) - try: - 
result = chain(ctx, addr, size_i) - except Exception as exc: # noqa: BLE001 - self._record_handler_error(ctx, SYMBOLIC_LOAD, exc, - role="intercept") - return NotImplemented + result = chain(ctx, addr, size_i) if result is NotImplemented: return NotImplemented @@ -781,8 +987,7 @@ def _default(c, a, v, sz): if val_z is None: return NotImplemented for i in range(sz): - byte_i = z3.Extract(8 * i + 7, 8 * i, val_z) - region.store_byte(a + i, byte_i) + region.store_byte(a + i, _z3_byte_at(val_z, i)) return True handlers = self._matching_handlers( @@ -790,12 +995,7 @@ def _default(c, a, v, sz): lambda sel: (sel.matches_region(region_name) and sel.matches_name(region_name))) chain = _build_chain(handlers, _default) - try: - result = chain(ctx, addr, val, size_i) - except Exception as exc: # noqa: BLE001 - self._record_handler_error(ctx, SYMBOLIC_STORE, exc, - role="intercept") - return NotImplemented + result = chain(ctx, addr, val, size_i) if result is NotImplemented: return NotImplemented @@ -839,27 +1039,24 @@ def _coerce_store_value(self, val, size, z3): if isinstance(val, int): return z3.BitVecVal(val & ((1 << bits) - 1), bits) if isinstance(val, float): + byte_order = str(self._engine.endian) + prefix = "<" if byte_order == "little" else ">" if int(size) == 4: - packed = _struct.pack(" 64: - base_z = z3.Extract(63, 0, base_z) - if idx_z.size() < 64: - idx_z = z3.SignExt(64 - idx_z.size(), idx_z) - elif idx_z.size() > 64: - idx_z = z3.Extract(63, 0, idx_z) - return base_z + idx_z * z3.BitVecVal(int(element_size), 64) + base_z = _z3_resize(base_z, 64) + idx_z = _z3_resize(idx_z, 64, signed=True) + return z3.simplify(base_z + idx_z * z3.BitVecVal(int(element_size), 64)) def ptr_diff(self, lhs, rhs, element_size): if _is_concrete(lhs) and _is_concrete(rhs): @@ -1060,6 +1263,34 @@ def on_instruction(self, inst): ctx.inst = inst self._fire_observers(INSTRUCTION, Phase.AFTER, ctx, inst=inst) + def on_global_initialized(self, init_func, addr): + """Fired by the substrate 
when a GLOBAL_INITIALIZER / + THREAD_LOCAL_INITIALIZER frame returns — i.e. the global's IR + initializer has finished executing. Fans out to + `engine.observe.global_initialized` observers. + + `init_func` is the initializer IRFunction; `init_func.source_declaration` + is the VarDecl. `addr` is the global's virtual address (int).""" + if not self._engine._observers.lookup( + (GLOBAL_INITIALIZED, Phase.AFTER)): + return + decl = init_func.source_declaration if init_func is not None else None + name = None + eid = 0 + if decl is not None: + try: + name = str(decl.name) if decl.name else None + except Exception: + name = None + try: + eid = int(decl.id) + except (AttributeError, TypeError): + eid = 0 + ctx = self._make_ctx() + self._fire_observers(GLOBAL_INITIALIZED, Phase.AFTER, ctx, + init_func=init_func, decl=decl, + name=name, eid=eid, addr=int(addr)) + # ----- truth + branch resolution: fork on non-concrete ----- def is_true(self, val): @@ -1074,7 +1305,7 @@ def is_true(self, val): # handler is interested. 
if self._engine._intercepts.lookup(BRANCH): return None - if isinstance(val, (int, bool)): + if isinstance(val, _INT_TYPES): return val != 0 return None @@ -1105,18 +1336,14 @@ def _match(sel): return True handlers = self._matching_handlers(BRANCH, _match) - if not handlers and isinstance(condition, (int, bool)) and \ + if not handlers and isinstance(condition, _INT_TYPES) and \ not _is_z3(condition): return condition != 0 # Phase 2 fast path if not handlers: return None # symbolic, no handler — let substrate fork chain = _build_chain(handlers, _default_branch) - try: - chosen = chain(ctx, condition) - except Exception as exc: # noqa: BLE001 - self._record_handler_error(ctx, BRANCH, exc, role="intercept") - chosen = _FORK + chosen = chain(ctx, condition) if chosen is _FORK: return None @@ -1216,9 +1443,11 @@ def _mirror_concrete_write_to_overlay(self, addr, val, size, region_name): elif (isinstance(val, tuple) and len(val) == 2 and val[0] == VALUE_TAG_PTR): int_val = int(val[1]) - elif isinstance(val, (bytes, bytearray)): - for i, b in enumerate(val): + elif isinstance(val, _BYTES_TYPES): + i = 0 + for b in val: region.store_byte(addr + i, int(b) & 0xFF) + i += 1 return else: # Unhandled value shape (z3 expr written through concrete @@ -1238,12 +1467,7 @@ def _fire_observers(self, event, phase, ctx, **payload): for selector, handler in registry.lookup(key): if not _selector_matches_payload(selector, event, payload): continue - try: - handler(ctx, **payload) - except Exception as exc: # noqa: BLE001 — swallow + log - self._record_handler_error(ctx, event, exc, - role=f"observer.{phase}") - continue + handler(ctx, **payload) self._auto_record_event(ctx, event, phase, payload) def _auto_record_event(self, ctx, event, phase, payload): @@ -1262,19 +1486,6 @@ def _auto_record_event(self, ctx, event, phase, payload): entry[k] = v path.events.append(entry) - def _record_handler_error(self, ctx, event, exc, *, role): - path = ctx.path - if path is None: - return - kind = 
(EventKind.OBSERVER_ERROR if role.startswith("observer") - else EventKind.INTERCEPT_ERROR) - path.events.append({ - "kind": kind, - "event": event, - "role": role, - "error": repr(exc), - }) - def _lookup_name(self, eid): """Best-effort resolution of an entity id to a function name. @@ -1295,7 +1506,7 @@ def _is_concrete(value): policy help: ints, bools, None, and ("ptr", N) tuples.""" if value is None: return True - if isinstance(value, (int, bool)): + if isinstance(value, _INT_TYPES): return True if isinstance(value, tuple) and len(value) == 2 and \ value[0] == VALUE_TAG_PTR: @@ -1322,6 +1533,15 @@ def _selector_matches_payload(selector, event, payload): if not selector.matches_eid(payload.get("eid")): return False return True + if event == GLOBAL_INITIALIZED: + if not selector.matches_name(payload.get("name")): + return False + if not selector.matches_eid(payload.get("eid")): + return False + addr = payload.get("addr") + if addr is not None and not selector.matches_addr(addr): + return False + return True # Phase 9: address_for / address_resolved filter on kind= / name= / eid= if event in (ADDRESS_FOR, ADDRESS_RESOLVED): if not selector.matches_kind(payload.get("kind")): diff --git a/bindings/Python/symex/engine.py b/bindings/Python/symex/engine.py index 88d8a1f4a..07b556b75 100644 --- a/bindings/Python/symex/engine.py +++ b/bindings/Python/symex/engine.py @@ -26,16 +26,19 @@ from .layout import Layout from .lens import MemView, ArgsView from .path import Path, FindingsList +from ._types import Endian from .events import ( - EventLog, EventKind, BRANCH, BranchDirection, Terminal, + EventLog, EventKind, BRANCH, BranchDirection, Terminal, StopNow, StepResultKind, Strategy, _FilterableList, ADDRESS_FOR, ADDRESS_RESOLVED, INDIRECT_CALL_RESOLVED, + SWITCH_CASE, SWITCH_DEFAULT, ) from .until import ExploreUntil from .dispatch import ( InterceptorPolicy, _Registry, SymExpr, _is_z3, _z3_bool, _z3_module, make_selector, ) +from ._types import _SEQ_TYPES from .intercept 
import InterceptDispatcher from .observe import ObserveDispatcher from .concretize import ( @@ -60,7 +63,7 @@ def __init__(self, paths, total_steps): def _resolve_function(index, name): for fd in mx.ast.FunctionDecl.IN(index): if str(fd.name) == name: - ir = mx.ir.IRFunction.FROM(fd) + ir = _ir_for_decl(fd) if ir is not None: return ir return None @@ -77,6 +80,17 @@ def _func_decl_for(entity): return None +def _ir_for_decl(decl): + """Return the IRFunction for `decl`, falling back to the canonical decl.""" + ir = mx.ir.IRFunction.FROM(decl) + if ir is not None: + return ir + canonical = decl.canonical_declaration + if canonical is not None and canonical is not decl: + ir = mx.ir.IRFunction.FROM(canonical) + return ir + + def _var_decl_for(entity): """Pull a VarDecl out of either a direct decl or a DeclRefExpr.""" if isinstance(entity, mx.ast.VarDecl): @@ -93,7 +107,7 @@ def resolve(eid): fd = _func_decl_for(index.entity(eid)) if fd is None: return None - return mx.ir.IRFunction.FROM(fd) + return _ir_for_decl(fd) return resolve @@ -143,7 +157,7 @@ def resolve(eid): align = align_bits // 8 if align_bits is not None else 8 if align == 0: align = 8 - initializer = mx.ir.IRFunction.FROM(vd) + initializer = _ir_for_decl(vd) return (vd.id, size, align, initializer) return resolve @@ -152,26 +166,44 @@ def _make_global_resolver_with_hints(index, engine): """Like `_make_global_resolver` but returns 5-tuples with an optional address hint consulted before the substrate auto-allocates.""" def resolve(eid): - vd = _var_decl_for(index.entity(eid)) - if vd is None: - return None - ty = vd.type - bits = ty.size_in_bits - size = (bits + 7) // 8 if bits is not None else 0 - align_bits = ty.alignment - align = align_bits // 8 if align_bits is not None else 8 - if align == 0: - align = 8 - initializer = mx.ir.IRFunction.FROM(vd) - canonical_eid = int(vd.id) - is_tls = _detect_tls(vd) - kind = "thread_local" if is_tls else "global" try: - name = str(vd.name) if vd.name else None + 
entity = index.entity(eid) except Exception: + entity = None + vd = _var_decl_for(entity) + if vd is not None: + ty = vd.type + bits = ty.size_in_bits + size = (bits + 7) // 8 if bits is not None else 0 + align_bits = ty.alignment + align = align_bits // 8 if align_bits is not None else 8 + if align == 0: + align = 8 + initializer = _ir_for_decl(vd) + canonical_eid = int(vd.id) + is_tls = _detect_tls(vd) + kind = "thread_local" if is_tls else "global" + try: + name = str(vd.name) if vd.name else None + except Exception: + name = None + else: + # Entity didn't resolve to a VarDecl (may be a non-AST entity, + # unindexed, or truly anonymous). Still give the interceptor a + # chance to supply an address so that handlers registered with + # `intercept.address_for(kind="global")` without a name filter + # can map the entire global address space. + canonical_eid = int(eid) + size = 0 + align = 8 + initializer = None + kind = "global" name = None addr_hint = engine._resolve_address_for( canonical_eid, name, kind, size, align) + + if vd is None and addr_hint is None: + return None return (canonical_eid, size, align, initializer, addr_hint) return resolve @@ -295,7 +327,11 @@ def _path_match_one(path, key, target): class SymExEngine: def __init__(self, index): self._index = index - self.layout = None + # Target endianness used by every Python-side `to_bytes` / + # `from_bytes` helper. Defaults to little; set explicitly for + # big-endian targets BEFORE creating paths or placing globals. + self.endian: Endian = Endian.LITTLE + self._layout = None self._func_resolver = _make_func_resolver(index) # Phase 9: upgraded to 5-tuple with address hint. self._global_resolver = _make_global_resolver_with_hints(index, self) @@ -332,6 +368,11 @@ def __init__(self, index): # explore call so repeated references to the same entity reuse # the same address (memoized globally, not per-path). self._address_for_cache: dict[int, int] = {} + # Reverse map: virtual address → entity id. 
Kept in sync with + # _address_for_cache so exec_call can resolve indirect callee + # addresses back to declarations without reading synthetic data + # from the interpreter's flat memory. + self._addr_to_eid_cache: dict[int, int] = {} # Phase 9: value-origin side-table. Mint sites (address_for, # indirect_call, fresh_int) populate this for Phase 10 lineage. self.value_origins: dict[int, dict] = {} @@ -342,6 +383,18 @@ def __init__(self, index): self._current_path = None # Phase 9: func_addr_resolver closure passed to the substrate. self._func_addr_resolver = self._make_func_addr_resolver() + # Reverse resolver: virtual address → entity id. + self._entity_by_addr_resolver = self._make_entity_by_addr_resolver() + + @property + def layout(self): + return self._layout + + @layout.setter + def layout(self, value): + if value is not None: + value.byte_order = str(self.endian) + self._layout = value def _get_cfg(self, ir_func): """Return the cached CFGInfo for `ir_func`, computing it on @@ -371,6 +424,11 @@ def _get_cfg(self, ir_func): # Phase 9: address resolution machinery # ------------------------------------------------------------------ + def _cache_address(self, eid: int, addr: int): + """Record eid↔addr in both the forward and reverse caches.""" + self._address_for_cache[eid] = addr + self._addr_to_eid_cache[addr] = eid + def _make_func_addr_resolver(self): """Build the func_addr_resolver closure passed to the substrate. 
@@ -388,17 +446,31 @@ def resolve(eid): name = engine._func_name_resolver(eid_i) if name is not None and engine.layout is not None and name in engine.layout: addr = engine.layout[name] - engine._address_for_cache[eid_i] = addr + engine._cache_address(eid_i, addr) engine._fire_address_resolved( eid_i, name, "function", 0, 8, addr, "pre_placed", None) return addr addr = engine._dispatch_address_for(eid_i, name, "function", 0, 8) if addr is not None: - engine._address_for_cache[eid_i] = addr + engine._cache_address(eid_i, addr) return addr return resolve + def _make_entity_by_addr_resolver(self): + """Build the entity_by_addr_resolver closure passed to the substrate. + + Called by exec_call in the C++ interpreter loop to map an indirect + callee virtual address back to a known entity id. Returns the int + entity id or 0 (kInvalidEntityId) on miss. + """ + engine = self + + def resolve(addr): + return engine._addr_to_eid_cache.get(int(addr), 0) + + return resolve + def _resolve_address_for(self, eid: int, name, kind: str, size: int, align: int): """Resolve an address for `eid` by checking the cache, layout @@ -412,13 +484,13 @@ def _resolve_address_for(self, eid: int, name, kind: str, return cached if name is not None and self.layout is not None and name in self.layout: addr = self.layout[name] - self._address_for_cache[eid] = addr + self._cache_address(eid, addr) self._fire_address_resolved( eid, name, kind, size, align, addr, "pre_placed", None) return addr addr = self._dispatch_address_for(eid, name, kind, size, align) if addr is not None: - self._address_for_cache[eid] = addr + self._cache_address(eid, addr) return addr def _dispatch_address_for(self, eid: int, name, kind: str, @@ -456,7 +528,7 @@ def _default_address_for(ctx, eid_, name_, kind_, size_, align_): path = self._current_path layout = self.layout mem = layout.memory if layout is not None else None - mem_view = MemView(mem) if mem is not None else None + mem_view = MemView(mem, endian=self.endian) if mem 
is not None else None ctx = Ctx( path=path, mem=mem_view, @@ -464,6 +536,15 @@ def _default_address_for(ctx, eid_, name_, kind_, size_, align_): layout=layout, solver=getattr(path, "solver", None), ) + # Stash the decl on ctx so handlers can do `ctx.decl.type` etc. + # without enclosing over the index — Index.entity uses its own + # in-memory cache, so the lookup is cheap. + ctx.decl = None + if eid is not None and int(eid) != 0: + try: + ctx.decl = self._index.entity(int(eid)) + except Exception: + pass handler_name = None if handlers: @@ -474,10 +555,7 @@ def _default_address_for(ctx, eid_, name_, kind_, size_, align_): except Exception: pass - try: - result = chain(ctx, eid, name, kind, size, align) - except Exception: - return None + result = chain(ctx, eid, name, kind, size, align) if result is None: return None @@ -501,7 +579,18 @@ def _fire_address_resolved(self, eid: int, name, kind: str, path = self._current_path layout = self.layout mem = layout.memory if layout is not None else None - mem_view = MemView(mem) if mem is not None else None + mem_view = MemView(mem, endian=self.endian) if mem is not None else None + + # The Index has its own in-memory cache; entity() is cheap. The + # decl is what most analyst code actually wants to operate on, + # so put it directly on ctx alongside the kwargs payload. 
+ decl = None + if eid is not None and int(eid) != 0: + try: + decl = self._index.entity(int(eid)) + except Exception: + decl = None + ctx = Ctx( path=path, mem=mem_view, @@ -509,10 +598,12 @@ def _fire_address_resolved(self, eid: int, name, kind: str, layout=layout, solver=getattr(path, "solver", None), ) + ctx.decl = decl payload = { "eid": eid, "name": name, + "decl": decl, "kind": kind, "addr": addr, "source": source, @@ -525,10 +616,7 @@ def _fire_address_resolved(self, eid: int, name, kind: str, (ADDRESS_RESOLVED, Phase.AFTER)): if not _selector_matches_payload(selector, ADDRESS_RESOLVED, payload): continue - try: - handler(ctx, **payload) - except Exception: - pass + handler(ctx, **payload) if path is not None: path.events.append(dict({"kind": ADDRESS_RESOLVED, @@ -602,7 +690,7 @@ def explore(self, start_func, *, start_block=None, args=None, seed=None, # address space — otherwise a freshly minted Layout() per # call would diverge from the one whose addresses are baked # into a previously-cloned state. - self.layout = Layout() + self.layout = Layout(endian=self.endian) layout = self.layout memory = layout.memory @@ -655,7 +743,7 @@ def explore_many(self, start_funcs, *, args=None, until=None, raise ValueError(f"unknown explore strategy {strategy!r}") if self.layout is None: - self.layout = Layout() + self.layout = Layout(endian=self.endian) layout = self.layout memory = layout.memory @@ -707,7 +795,7 @@ def _resolve_start_many(self, start_funcs): raise TypeError( "explore_many: start_funcs is a bare string; pass a " "list (e.g. 
[name]) or a regex / predicate") - if isinstance(start_funcs, (list, tuple)): + if isinstance(start_funcs, _SEQ_TYPES): return self._resolve_explicit_list(start_funcs) if callable(start_funcs): return self._resolve_by_predicate(start_funcs) @@ -750,7 +838,7 @@ def _resolve_by_predicate(self, pred): continue if not pred(name): continue - ir = mx.ir.IRFunction.FROM(fd) + ir = _ir_for_decl(fd) if ir is None: continue fid = int(ir.id) @@ -859,14 +947,14 @@ def resume_from(self, snapshot, *, modify=None, Returns the list of paths produced. Used by `Path.replay`. """ - layout = self.layout if self.layout is not None else Layout() + layout = self.layout if self.layout is not None else Layout(endian=self.endian) memory = layout.memory strategy_obj = (_coerce_strategy(concretize) if concretize is not None else self.address_strategy) until_pred = until if until is not None else ExploreUntil.never() fresh_state = _interp.clone_state(snapshot.state) - path = Path(fresh_state, memory, parent_id=parent_id) + path = Path(fresh_state, memory, parent_id=parent_id, endian=self.endian) path.events = EventLog(snapshot.events) path.tags = set(snapshot.tags) path.path_condition = list(snapshot.path_condition) @@ -888,6 +976,9 @@ def resume_from(self, snapshot, *, modify=None, path._tls_shadow = dict(getattr(snapshot, "tls_shadow", {})) # Phase 10 path._origin_by_name = dict(getattr(snapshot, "origin_by_name", {})) + # Phase 15 + path.vars = dict(getattr(snapshot, "vars", {})) + path.shared = dict(getattr(snapshot, "shared", {})) if modify is not None: modify(path) @@ -924,7 +1015,8 @@ def _init_path(self, ir_func, memory, policy, *, args, start_block, if start_block is None: _interp.init_state(state, memory, policy, ir_func, list(args), self._func_resolver, self._global_resolver, - self._func_addr_resolver) + self._func_addr_resolver, + self._entity_by_addr_resolver) else: block = self._resolve_block(ir_func, start_block) param_addrs = self._allocate_param_slots(ir_func, memory, args) 
@@ -933,11 +1025,12 @@ def _init_path(self, ir_func, memory, policy, *, args, start_block, state, memory, policy, ir_func, block, param_addrs, None, seed_dict, self._func_resolver, self._global_resolver, - self._func_addr_resolver) + self._func_addr_resolver, + self._entity_by_addr_resolver) finally: self._current_path = prev_path - path = Path(state, memory) + path = Path(state, memory, endian=self.endian) path._func_name = self._function_name(ir_func) path._layout = self.layout path.entry_func = ir_func @@ -987,7 +1080,8 @@ def _allocate_param_slots(self, ir_func, memory, args): fd = ir_func.declaration if fd is None: return addrs - for i, p in enumerate(fd.parameters): + i = 0 + for p in fd.parameters: ty = p.type bits = ty.size_in_bits size = max(1, (bits + 7) // 8) if bits is not None else 8 @@ -1001,8 +1095,9 @@ def _allocate_param_slots(self, ir_func, memory, args): val = int(val) if isinstance(val, int): memory.write_bytes( - addr, val.to_bytes(size, "little", + addr, val.to_bytes(size, str(self.endian), signed=(val < 0))) + i += 1 return addrs def _step_one(self, path, memory, policy, slice_steps, concretize): @@ -1010,7 +1105,13 @@ def _step_one(self, path, memory, policy, slice_steps, concretize): try: out = _interp.step(path.state, memory, policy, slice_steps, self._func_resolver, self._global_resolver, - self._func_addr_resolver) + self._func_addr_resolver, + self._entity_by_addr_resolver) + except StopNow: + self._current_path = None + if path.terminal is None: + path.terminal = Terminal.STOPPED + return [path] finally: self._current_path = None @@ -1037,6 +1138,8 @@ def _step_one(self, path, memory, policy, slice_steps, concretize): return [path] if kind == StepResultKind.BRANCH: return self._handle_branch_forks(path, result, forks) + if kind == StepResultKind.SWITCH: + return self._handle_switch_forks(path, result, forks) if kind == StepResultKind.SUSPENDED: sub_kind = forks[0].get("sub_kind") if forks else None if sub_kind == "call-addr": @@ 
-1077,6 +1180,104 @@ def _handle_branch_forks(self, path, result, forks): children.append(child) return children + def _handle_switch_forks(self, path, result, forks): + """Realize a symbolic-SWITCH suspension as one forked path per + case (and one for the default block). Each child gets a + path-condition constraint pinning the selector into that case's + range. Infeasible children are dropped.""" + if not forks: + path.terminal = Terminal.STUCK_BRANCH + return [path] + + fork = forks[0] + sel = fork.get("selector") + sel_eid = int(fork.get("selector_eid") or 0) + cases = fork.get("cases") or [] + default_block_obj = fork.get("default_block") + default_eid = fork.get("default_block_eid") + snapshot_state = fork["state"] + + sel_is_z3 = _is_z3(sel) + z3 = _z3_module() if sel_is_z3 else None + + children = [] + first_state_used = [False] + + def _take_state(): + if not first_state_used[0]: + first_state_used[0] = True + return snapshot_state + return _interp.clone_state(snapshot_state) + + def _bv_val(v): + return z3.BitVecVal(int(v), sel.size()) + + def _feasible(child): + if z3 is None: + return True + return child.solver.solver.check() == z3.sat + + for low, high, target_eid, target_block_obj in cases: + child_state = _take_state() + _interp.resume_switch_case(child_state, target_block_obj) + child = self._fork_child(path, child_state) + if sel_is_z3: + if low == high: + child.path_condition.append(sel == _bv_val(low)) + else: + child.path_condition.append( + z3.And(sel >= _bv_val(low), + sel <= _bv_val(high))) + child.solver.invalidate() + if not _feasible(child): + continue + child.events.append({ + "kind": SWITCH_CASE, + "selector_eid": sel_eid, + "low": low, + "high": high, + "target_block": target_eid, + "step": path.steps, + }) + children.append(child) + + if default_block_obj is not None and default_eid is not None: + child_state = _take_state() + _interp.resume_switch_case(child_state, default_block_obj) + child = self._fork_child(path, child_state) + if 
sel_is_z3 and cases: + disjuncts = [] + for low, high, _eid, _blk in cases: + if low == high: + disjuncts.append(sel == _bv_val(low)) + else: + disjuncts.append( + z3.And(sel >= _bv_val(low), + sel <= _bv_val(high))) + child.path_condition.append(z3.Not(z3.Or(*disjuncts))) + child.solver.invalidate() + if _feasible(child): + child.events.append({ + "kind": SWITCH_DEFAULT, + "selector_eid": sel_eid, + "target_block": default_eid, + "step": path.steps, + }) + children.append(child) + else: + child.events.append({ + "kind": SWITCH_DEFAULT, + "selector_eid": sel_eid, + "target_block": default_eid, + "step": path.steps, + }) + children.append(child) + + if not children: + path.terminal = Terminal.INFEASIBLE + return [path] + return children + def _handle_suspension(self, path, result, forks, strategy): if not forks: path.suspended = result @@ -1164,7 +1365,7 @@ def _handle_symbolic_indirect_call(self, path, result, forks, concretize): layout = self.layout mem = layout.memory if layout is not None else None - mem_view = MemView(mem) if mem is not None else None + mem_view = MemView(mem, endian=self.endian) if mem is not None else None ctx = Ctx( path=path, mem=mem_view, @@ -1197,18 +1398,14 @@ def _default_indirect(ctx_, target_expr_): except Exception: pass - try: - chosen = chain(ctx, addr_expr) - except Exception: - path.terminal = Terminal.UNRESOLVED_CALL - return [path] + chosen = chain(ctx, addr_expr) if chosen is None or chosen is _DEFER: path.terminal = Terminal.UNRESOLVED_CALL return [path] # Normalize to a list of candidates. 
- if not isinstance(chosen, (list, tuple)): + if not isinstance(chosen, _SEQ_TYPES): candidates = [chosen] else: candidates = list(chosen) @@ -1250,7 +1447,8 @@ def take_state(): return _interp.clone_state(fork_entry["state"]) z3 = _z3_module() - for fork_idx, addr in enumerate(resolved_addrs): + fork_idx = 0 + for addr in resolved_addrs: child_state = take_state() _interp.resume_addr(child_state, address_eid, addr) child = self._fork_child(path, child_state) @@ -1272,6 +1470,7 @@ def take_state(): "step": path.steps, }) children.append(child) + fork_idx += 1 return children @@ -1513,7 +1712,8 @@ def _fork_child(self, parent, child_state): """Build a fresh Path that inherits everything from `parent` except its interpreter state. Used by branch and suspension fork handling so the propagation rules stay in one place.""" - child = Path(child_state, parent.mem, parent_id=parent.id) + child = Path(child_state, parent.mem, parent_id=parent.id, + endian=parent._byte_order) child.events = EventLog(parent.events) child.tags = set(parent.tags) child.path_condition = list(parent.path_condition) @@ -1522,10 +1722,19 @@ def _fork_child(self, parent, child_state): child._layout = parent._layout child.entry_func = parent.entry_func child.solver.adopt_fresh_vars(parent.solver._fresh_vars) + # Phase 8c: inherit the byte-granular symbolic shadow. Without + # this the child sees parent's stamped 0xCD sentinel bytes in + # shared concrete memory but has no shadow entry to resolve them + # — every symbolic byte the parent wrote shows up as a literal + # 0xCD on the child, which is the wrong load value. + child._symbolic_shadow = dict(parent._symbolic_shadow) # Phase 9: inherit TLS base and shadow (isolation via per-path shadow). child.tls_base = parent.tls_base child._tls_shadow = dict(parent._tls_shadow) # Phase 10: inherit provenance table so forked paths know the # origins of all symbolic inputs minted before the fork. 
child._origin_by_name = dict(parent._origin_by_name) + # Phase 15: per-path copy of vars; shared reference for shared. + child.vars = dict(parent.vars) + child.shared = parent.shared return child diff --git a/bindings/Python/symex/events.py b/bindings/Python/symex/events.py index 7adab4bfa..1fc19aca9 100644 --- a/bindings/Python/symex/events.py +++ b/bindings/Python/symex/events.py @@ -22,17 +22,18 @@ class EventKind(StrEnum): SYMBOLIC_STORE = "symbolic_store" GLOBAL_READ = "global_read" GLOBAL_WRITE = "global_write" + GLOBAL_INITIALIZED = "global_initialized" CALL = "call" INDIRECT_CALL = "indirect_call" BRANCH = "branch" + SWITCH_CASE = "switch_case" + SWITCH_DEFAULT = "switch_default" LOOP = "loop" CONCRETIZE = "concretize" BINARY_OP = "binary_op" MEMADDR_CONCRETIZE = "memaddr_concretize" CONCRETIZATION_TRUNCATED = "concretization_truncated" CONCRETIZATION_INFEASIBLE = "concretization_infeasible" - OBSERVER_ERROR = "observer_error" - INTERCEPT_ERROR = "intercept_error" REGION_MATERIALIZED = "region_materialized" LAZY_BUDGET_EXHAUSTED = "lazy_budget_exhausted" CONSTRAIN_TO_CONCRETE_ADDR = "constrain_to_concrete_addr" @@ -53,9 +54,12 @@ class EventKind(StrEnum): SYMBOLIC_STORE = EventKind.SYMBOLIC_STORE GLOBAL_READ = EventKind.GLOBAL_READ GLOBAL_WRITE = EventKind.GLOBAL_WRITE +GLOBAL_INITIALIZED = EventKind.GLOBAL_INITIALIZED CALL = EventKind.CALL INDIRECT_CALL = EventKind.INDIRECT_CALL BRANCH = EventKind.BRANCH +SWITCH_CASE = EventKind.SWITCH_CASE +SWITCH_DEFAULT = EventKind.SWITCH_DEFAULT LOOP = EventKind.LOOP CONCRETIZE = EventKind.CONCRETIZE BLOCK_ENTER = EventKind.BLOCK_ENTER @@ -68,9 +72,9 @@ class EventKind(StrEnum): ALL_EVENTS = frozenset({ MEMORY_READ, MEMORY_WRITE, SYMBOLIC_LOAD, SYMBOLIC_STORE, - GLOBAL_READ, GLOBAL_WRITE, + GLOBAL_READ, GLOBAL_WRITE, GLOBAL_INITIALIZED, CALL, INDIRECT_CALL, - BRANCH, LOOP, CONCRETIZE, + BRANCH, SWITCH_CASE, SWITCH_DEFAULT, LOOP, CONCRETIZE, BLOCK_ENTER, INSTRUCTION, ADDRESS_FOR, ADDRESS_RESOLVED, 
INDIRECT_CALL_RESOLVED, @@ -88,6 +92,18 @@ class BranchDirection(StrEnum): UNKNOWN = "?" +class StopNow(BaseException): + """Raise from any hook to halt the current slice immediately. + + The C++ interpreter catches this via the existing exception-propagation + mechanism, clears the work stack, and surfaces it back to the Python + driver. `_step_one` intercepts it and treats the path as stopped. + + Prefer `ctx.stop_now(terminal=...)` over raising this directly so the + terminal value is recorded before the interpreter halts. + """ + + class Terminal(StrEnum): """Reasons a path stops stepping. Strings rather than ints so the legacy `path.terminal == "completed"` checks keep working.""" @@ -115,6 +131,7 @@ class StepResultKind(StrEnum): ERROR = "error" BUDGET = "budget" BRANCH = "branch" + SWITCH = "switch" SUSPENDED = "suspended" @@ -123,6 +140,12 @@ class Strategy(StrEnum): DFS = "dfs" +# Endian lives in `_types` (a leaf module) since it's a fundamental +# constant used by Layout, MemView, Path, and the engine. Re-exported +# here for back-compat with existing `from .events import Endian` users. +from ._types import Endian # noqa: E402,F401 + + class CallAction(StrEnum): """Substrate-facing tag for the second slot of resolve_call's return tuple: ("skip", value) replaces the call with `value`; diff --git a/bindings/Python/symex/layout.py b/bindings/Python/symex/layout.py index b65d5b686..a0c63109e 100644 --- a/bindings/Python/symex/layout.py +++ b/bindings/Python/symex/layout.py @@ -22,6 +22,7 @@ import multiplier as mx from .region import LazyRegion, Region, RegionTable +from ._types import _BYTES_TYPES, Endian _interp = mx.ir.interpret @@ -32,7 +33,7 @@ class Layout: # Function alloc: 0x4000_0000_0000_0000 upward (next_function_address) # Lazy regions: 0x7000_0000_0000_0000 upward (declare_lazy) - def __init__(self, memory=None): + def __init__(self, memory=None, endian: Endian = Endian.LITTLE): """Create a Layout. 
Parameters @@ -47,6 +48,7 @@ def __init__(self, memory=None): layout = Layout(mem) """ self._memory = memory if memory is not None else _interp.ConcreteMemory() + self._byte_order = str(endian) self._regions = RegionTable() # name -> addr fast lookup for __getitem__ / __contains__. self._by_name: dict[str, Region] = {} @@ -71,6 +73,14 @@ def __init__(self, memory=None): def memory(self): return self._memory + @property + def byte_order(self) -> str: + return self._byte_order + + @byte_order.setter + def byte_order(self, value): + self._byte_order = str(value) + def place_global(self, name, addr, size, init=None, align=8): if name in self._by_name: raise ValueError(f"layout name already in use: {name!r}") @@ -319,7 +329,7 @@ def place_string(self, name: str, value, *, if isinstance(value, str): data = value.encode(encoding) - elif isinstance(value, (bytes, bytearray)): + elif isinstance(value, _BYTES_TYPES): data = bytes(value) else: raise TypeError( @@ -350,9 +360,10 @@ def _write_init(self, name, addr, size, init): init = int(init) if isinstance(init, int): self._memory.write_bytes( - addr, init.to_bytes(size, "little", signed=(init < 0))) + addr, init.to_bytes(size, self._byte_order, + signed=(init < 0))) return - if isinstance(init, (bytes, bytearray)): + if isinstance(init, _BYTES_TYPES): data = bytes(init) if len(data) > size: raise ValueError( diff --git a/bindings/Python/symex/lens.py b/bindings/Python/symex/lens.py index f305f7d69..70a78438d 100644 --- a/bindings/Python/symex/lens.py +++ b/bindings/Python/symex/lens.py @@ -18,6 +18,8 @@ import struct as _struct import multiplier as mx +from ._types import _BYTES_TYPES, _INT_TYPES +from ._types import Endian def _coerce_addr(value): @@ -36,13 +38,18 @@ class MemView: only — it never copies the underlying memory. 
""" - def __init__(self, memory): + def __init__(self, memory, endian: Endian = Endian.LITTLE): self._memory = memory + self._byte_order = str(endian) @property def memory(self): return self._memory + @property + def byte_order(self) -> str: + return self._byte_order + # ---- read primitives ---------------------------------------------- def __getitem__(self, key): @@ -61,7 +68,7 @@ def read_bytes(self, addr, size): def read_int(self, addr, size, signed=False): data = self._memory.read_bytes(addr, size) - return int.from_bytes(data, "little", signed=signed) + return int.from_bytes(data, self._byte_order, signed=signed) def read_str(self, addr, max=4096, encoding="utf-8"): out = bytearray() @@ -75,8 +82,8 @@ def read_str(self, addr, max=4096, encoding="utf-8"): # ---- write primitives --------------------------------------------- def write(self, addr, value, size=None): - if isinstance(value, (bytes, bytearray)): - self._memory.write_bytes(addr, bytes(value)) + if isinstance(value, _BYTES_TYPES): + self._memory.write_bytes(addr, value) return if isinstance(value, bool): value = int(value) @@ -84,7 +91,8 @@ def write(self, addr, value, size=None): if size is None: raise ValueError("size required when writing an integer") self._memory.write_bytes( - addr, value.to_bytes(size, "little", signed=(value < 0))) + addr, value.to_bytes(size, self._byte_order, + signed=(value < 0))) return raise TypeError( f"MemView.write does not yet handle values of type {type(value)}") @@ -103,7 +111,7 @@ def write_bytes(self, addr, data): "constrain them via path.solver.") except ImportError: pass - self._memory.write_bytes(addr, bytes(data)) + self._memory.write_bytes(addr, data) # ---- struct lens (Phase 2) ---------------------------------------- @@ -130,7 +138,7 @@ def write_struct(self, addr, layout, **fields): value = int(value) self._memory.write_bytes( addr + offset, - int(value).to_bytes(size, "little", signed=signed)) + int(value).to_bytes(size, self._byte_order, signed=signed)) 
class ArgsView: @@ -241,6 +249,7 @@ def __init__(self, index: mx.Index, ir_func, layout, kinds: str = "all"): self._index = index self._layout = layout self._memory = layout.memory + self._byte_order = getattr(layout, "byte_order", str(Endian.LITTLE)) # name → (inst_id, size_bytes, align_bytes, addr) self._locals: dict[str, tuple[int, int, int, int]] = {} # Symbolic initial values deferred until install_hooks(). @@ -256,11 +265,9 @@ def _discover(self, ir_func, kinds: str): seen_ids: set[int] = set() for block in ir_func.blocks: for inst in block.all_instructions: - # AllocaInst.FROM returns None for non-ALLOCA instructions, - # so this doubles as both the isinstance check and the upcast. - alloca = mx.ir.AllocaInst.FROM(inst) - if alloca is None: + if not isinstance(inst, mx.ir.AllocaInst): continue + alloca = inst inst_id = int(alloca.id) if inst_id in seen_ids: @@ -305,12 +312,12 @@ def __setitem__(self, name: str, value): if _is_z3(value): self._symbolic_inits[name] = value - elif isinstance(value, (int, bool)): + elif isinstance(value, _INT_TYPES): val = int(value) - data = val.to_bytes(size, "little", signed=(val < 0)) + data = val.to_bytes(size, self._byte_order, signed=(val < 0)) self._memory.write_bytes(addr, data) - elif isinstance(value, (bytes, bytearray)): - self._memory.write_bytes(addr, bytes(value)[:size]) + elif isinstance(value, _BYTES_TYPES): + self._memory.write_bytes(addr, value[:size]) else: raise TypeError( f"value must be int, bytes, or z3 expression, got {type(value)}") @@ -365,7 +372,7 @@ def read(self, path, name: str): data = path.mem.read_bytes(addr, size) if data: - return int.from_bytes(data, "little") + return int.from_bytes(data, self._byte_order) return 0 def write(self, path, name: str, value): @@ -380,8 +387,9 @@ def write(self, path, name: str, value): inst_id, size, align, addr = self._locals[name] if _is_z3(value): - path._symbolic_shadow[(addr, size)] = value - elif isinstance(value, (int, bool)): + from .dispatch import 
_shadow_write + _shadow_write(path._symbolic_shadow, addr, value, size) + elif isinstance(value, _INT_TYPES): val = int(value) path.mem.write(addr, val, size) else: @@ -419,7 +427,8 @@ def dump(self, path=None): init_str = f"" else: data = self._memory.read_bytes(addr, size) - init_str = hex(int.from_bytes(data, "little")) if data else "0x0" + init_str = (hex(int.from_bytes(data, self._byte_order)) + if data else "0x0") print(f" {name:20s} addr=0x{addr:016x} " f"size={size} init={init_str}") diff --git a/bindings/Python/symex/observe.py b/bindings/Python/symex/observe.py index 39f68233f..f61d16594 100644 --- a/bindings/Python/symex/observe.py +++ b/bindings/Python/symex/observe.py @@ -11,10 +11,8 @@ see the resolved decision. Use `engine.observe.before.` for the narrower pre-dispatch window. -Observer dispatch is "fire all in registration order". An exception in -one observer does NOT propagate; the dispatcher records it as an -`observer_error` entry on `path.events` and continues with the next -observer (and the rest of execution). +Observer dispatch is "fire all in registration order". Exceptions +propagate immediately — a buggy observer aborts the step. """ from .dispatch import make_selector diff --git a/bindings/Python/symex/path.py b/bindings/Python/symex/path.py index c3b7e12fe..71ccc54c3 100644 --- a/bindings/Python/symex/path.py +++ b/bindings/Python/symex/path.py @@ -19,6 +19,7 @@ import multiplier as mx from .dispatch import _z3_module +from ._types import _BYTES_TYPES, _INT_TYPES, _SEQ_TYPES, Endian from .events import ( EventLog, BRANCH, BLOCK_ENTER, MEMORY_READ, MEMORY_WRITE, BranchDirection, Terminal, @@ -27,6 +28,10 @@ _interp = mx.ir.interpret +# 256-entry table of single-byte byte strings — avoids any +# allocation when writing one concrete byte at a time. 
+_BYTE_TABLE = tuple(bytes([i]) for i in range(256)) + _id_counter = [0] @@ -137,10 +142,12 @@ def adopt_fresh_vars(self, fresh_vars): class Path: - def __init__(self, state, mem, *, parent_id=None): + def __init__(self, state, mem, *, parent_id=None, + endian: Endian = Endian.LITTLE): self.id = _next_id() self._state = state self._mem = mem + self._byte_order = str(endian) self._parent_id = parent_id self.events = EventLog() self.tags = set() @@ -171,10 +178,17 @@ def __init__(self, state, mem, *, parent_id=None): # Phase 6: count of LazyRegion materializations charged to # this path so far (engine.lazy_region_budget caps). self._lazy_regions_used = 0 - # Phase 8c: shadow map for symbolic values written to concrete - # substrate-allocated addresses (return slot, ALLOCA/ARG, - # ALLOCA/LOCAL). Keyed on (addr, size) -> z3 expression. + # Phase 8c: byte-granular shadow for symbolic values written to + # concrete addresses. Keyed addr -> z3.BitVec(8) (one entry per + # byte). z3 writes are decomposed into per-byte Extract()s; + # reads reconstruct via Concat so memcpy-style byte-at-a-time + # accesses see the correct symbolic value. self._symbolic_shadow: dict = {} + # Pre-allocated working buffer for _shadow_read reconstruction. + # Grown in-place if a read exceeds its current length; reused + # across all reads within and between steps to avoid per-read + # list allocation. + self._shadow_buf: list = [] # Phase 9: TLS base address for this logical thread. Set by # the engine to layout.tls_base at init time. Inherited by # forks (cloned, then diverge via _tls_shadow). @@ -187,6 +201,12 @@ def __init__(self, state, mem, *, parent_id=None): # record dict. Populated by solver.fresh_int and any engine hook # that mints a named symbolic value (e.g. address_for). self._origin_by_name: dict = {} + # Phase 15: analyst variable bags. + # `vars` — copied on fork; each path gets its own dict. 
+ # `shared` — same dict object across all forks; mutations are + # visible to every path that shares the reference. + self.vars: dict = {} + self.shared: dict = {} @property def state(self): @@ -200,9 +220,14 @@ def mem(self): def steps(self): return self._state.steps + @property + def byte_order(self) -> str: + return self._byte_order + def clone(self): cloned_state = _interp.clone_state(self._state) new_path = Path(cloned_state, self._mem, parent_id=self.id) + new_path._byte_order = self._byte_order new_path.events = EventLog(self.events) new_path.tags = set(self.tags) new_path.path_condition = list(self.path_condition) @@ -218,6 +243,8 @@ def clone(self): new_path._tls_shadow = dict(self._tls_shadow) new_path._origin_by_name = dict(self._origin_by_name) new_path.findings = FindingsList(self.findings) + new_path.vars = dict(self.vars) + new_path.shared = self.shared return new_path def snapshot(self): @@ -251,6 +278,9 @@ def snapshot(self): tls_shadow=dict(self._tls_shadow), # Phase 10 origin_by_name=dict(self._origin_by_name), + # Phase 15 + vars=dict(self.vars), + shared=dict(self.shared), ) def restore(self, snap): @@ -283,6 +313,9 @@ def restore(self, snap): # Phase 10 self._origin_by_name.clear() self._origin_by_name.update(snap.origin_by_name) + # Phase 15 + self.vars = dict(snap.vars) + self.shared = dict(snap.shared) def replay(self, *, modify, engine, slice_steps=1024, concretize=None, until=None): @@ -460,6 +493,97 @@ def is_tainted(self, expr) -> bool: """Return True if any `fresh_int` variable contributes to `expr`.""" return bool(self.taint_sources(expr)) + def _resolve_endian(self, endian, byte_order) -> Endian: + """Pick the effective byte order for a write call. + + Accepts both ``endian=`` and ``byte_order=`` for ergonomic flexibility + (the rest of the codebase mixes the two names — ``Layout(endian=...)`` + and ``Path.byte_order`` for the property). Falls back to the path's + own byte order when neither is given. 
+ """ + if endian is not None and byte_order is not None: + raise TypeError( + "pass either endian= or byte_order=, not both") + choice = endian if endian is not None else byte_order + return choice if choice is not None else self._byte_order + + def write_symbolic(self, addr: int, value, size: int = None, + *, endian: Endian = None, + byte_order: Endian = None) -> None: + """Write a z3 expression into the symbolic shadow at ``addr``. + + ``size`` defaults to ``value.size() // 8`` (the BitVec's byte width). + ``endian`` (or its alias ``byte_order``) defaults to the path's byte + order (``path.byte_order``); pass ``Endian.BIG`` or ``Endian.LITTLE`` + to override per-call. + Raises ``TypeError`` if ``value`` is not a z3 expression. + """ + from .dispatch import _shadow_write, _is_z3, _z3_byte_at + if not _is_z3(value): + raise TypeError( + f"write_symbolic: value must be a z3 expression, " + f"got {type(value).__name__}") + if size is None: + size = value.size() // 8 + bo = self._resolve_endian(endian, byte_order) + if bo == Endian.BIG: + # Byte at addr is the MSB; LSB lands at addr + size - 1. + for i in range(size): + self._symbolic_shadow[addr + i] = _z3_byte_at(value, size - 1 - i) + else: + _shadow_write(self._symbolic_shadow, addr, value, size) + + def write_memory(self, addr: int, value, size: int = None, + *, endian: Endian = None, + byte_order: Endian = None) -> None: + """Write ``value`` to ``addr``, handling symbolic and concrete cases. + + - z3 expression → symbolic shadow + - int / bool → concrete memory; ``size`` is required + - bytes/bytearray → concrete memory; ``size`` is ignored + + ``endian`` (or its alias ``byte_order``) defaults to the path's byte + order (``path.byte_order``); pass ``Endian.BIG`` or ``Endian.LITTLE`` + to override for this write only. For ``bytes`` input the byte order + is moot — bytes are written as-is. 
+ """ + from .dispatch import _shadow_write, _is_z3, _z3_byte_at + bo = self._resolve_endian(endian, byte_order) + if _is_z3(value): + if size is None: + size = value.size() // 8 + if bo == Endian.BIG: + for i in range(size): + self._symbolic_shadow[addr + i] = _z3_byte_at(value, size - 1 - i) + else: + _shadow_write(self._symbolic_shadow, addr, value, size) + elif isinstance(value, _BYTES_TYPES): + self._mem.write_bytes(addr, value) + elif isinstance(value, _SEQ_TYPES): + i = 0 + for elem in value: + if _is_z3(elem): + _shadow_write(self._symbolic_shadow, addr + i, elem, 1) + elif isinstance(elem, _INT_TYPES): + self._mem.write_bytes(addr + i, _BYTE_TABLE[int(elem) & 0xFF]) + else: + raise TypeError( + f"write_memory: element {i} has unsupported type " + f"{type(elem).__name__}; expected int or z3 expression") + i += 1 + elif isinstance(value, _INT_TYPES): + if size is None: + raise ValueError( + "write_memory: size is required when writing an integer") + val = int(value) + self._mem.write_bytes(addr, + val.to_bytes(size, str(bo), + signed=(val < 0))) + else: + raise TypeError( + f"write_memory: unsupported value type {type(value).__name__}; " + f"expected z3 expression, int, or bytes") + def summary(self): """Single human-readable summary of what happened on this path. 
@@ -588,13 +712,16 @@ class _Snapshot: # Phase 9 "tls_base", "tls_shadow", # Phase 10 - "origin_by_name") + "origin_by_name", + # Phase 15 + "vars", "shared") def __init__(self, *, state, events, tags, path_condition, terminal, return_value, error_kind, loop_iters, func_name, fresh_vars, symbolic_shadow, findings, region_at_suspension, lazy_regions_used, - entry_func, tls_base, tls_shadow, origin_by_name): + entry_func, tls_base, tls_shadow, origin_by_name, + vars, shared): self.state = state self.events = events self.tags = tags @@ -613,3 +740,5 @@ def __init__(self, *, state, events, tags, path_condition, terminal, self.tls_base = tls_base self.tls_shadow = tls_shadow self.origin_by_name = origin_by_name + self.vars = vars + self.shared = shared diff --git a/bindings/Python/symex/region.py b/bindings/Python/symex/region.py index 2b9a53b57..fb894ce13 100644 --- a/bindings/Python/symex/region.py +++ b/bindings/Python/symex/region.py @@ -112,6 +112,7 @@ def __init__(self): # `_bases[i]` is `_regions[i].base`; kept parallel for bisect. self._regions: list[Region] = [] self._bases: list[int] = [] + self._by_name: dict[str, Region] = {} def add(self, region: Region) -> None: """Insert `region`. 
Raises `ValueError` if it would overlap @@ -133,20 +134,20 @@ def add(self, region: Region) -> None: f"size={region.size}) overlaps {prev.name!r}") self._regions.insert(idx, region) self._bases.insert(idx, region.base) + self._by_name[region.name] = region def remove(self, name: str) -> None: - for i, r in enumerate(self._regions): - if r.name == name: - del self._regions[i] - del self._bases[i] - return - raise KeyError(name) + region = self._by_name.pop(name, None) + if region is None: + raise KeyError(name) + idx = bisect.bisect_left(self._bases, region.base) + while idx < len(self._regions) and self._regions[idx] is not region: + idx += 1 + del self._regions[idx] + del self._bases[idx] def get(self, name: str) -> Optional[Region]: - for r in self._regions: - if r.name == name: - return r - return None + return self._by_name.get(name) def __iter__(self): return iter(self._regions) diff --git a/docs/symex-hook-dispatch-as-work-item.md b/docs/symex-hook-dispatch-as-work-item.md new file mode 100644 index 000000000..f1bb2c800 --- /dev/null +++ b/docs/symex-hook-dispatch-as-work-item.md @@ -0,0 +1,55 @@ +# Future Direction: Hook Dispatch as a First-Class Work-Stack Item + +**Status: NOT IMPLEMENTED — design notes only** + +## Background + +The interpreter work stack is already a two-level structure: + +- **IR level** ("machine code"): `CALL`, `LOAD`, `COND_BRANCH` — the IR opcodes the analyst sees. +- **Work-stack level** ("microcode"): `EXEC_CALL`, `ANALYZE`, `DECIDE_COND_BRANCH` — how each IR op is realized across one or more work items popped by the dispatch loop. + +Hook dispatch (intercept/observe chains) currently lives *inside* `dispatch()` as an inline C++ call into Python. This conflation means there is no clean point at which to suspend the interpreter between "hooks decided what to do about this instruction" and "instruction executes." + +## The Problem + +When a Python hook raises `StopNow` (or any exception), the C++ abort path fires. 
The item that triggered the hook was already popped from the work stack before `dispatch` was called, so the work stack at abort time represents *everything after the current instruction*, not *the current instruction itself*. Preserving that work stack and re-running gives re-fire semantics: on next entry, the hook fires again for the same instruction. + +Re-fire is the correct and accepted semantic for user-driven resumption — the user makes an explicit choice to resume and can set `path.vars` state to change hook behavior before doing so. The current implementation is fine for this use case. + +The limitation appears when you want suspension *without* re-fire — i.e., to resume past the hook without re-invoking it. That requires hook dispatch to be a distinct, independently resumable work item. + +## Proposed Architecture + +Add a new work-item kind, tentatively `FIRE_HOOKS`, that encapsulates a single hook-dispatch event: + +``` +EXEC_CALL + └─► FIRE_CALL_HOOKS(inst, serialized_args, resolved_name) + └─► EXEC_CALL_INLINE(callee) ← only if hooks deferred +``` + +`FIRE_CALL_HOOKS` is popped, the Python hook chain runs, and: + +- If a hook intercepts (returns a value), `EXEC_CALL_INLINE` is removed and the call result is written. +- If the chain defers, `EXEC_CALL_INLINE` stays and executes next. +- If a hook suspends (raises `StopNow` or similar), the work stack is left with `FIRE_CALL_HOOKS` at the top. On resume, the hook chain runs again from scratch — but the item is the *same* item, so no special re-entry logic is needed. + +The same pattern applies to `FIRE_MEM_READ_HOOKS`, `FIRE_BRANCH_HOOKS`, etc. + +## Benefits + +- Suspension at a hook boundary leaves the work stack pointing exactly at the hook-dispatch item. Resume re-executes the hook chain cleanly with no "was this already fired?" bookkeeping. +- The current work-stack/clear distinction (`abort_requested` clears, user-pause doesn't) becomes unnecessary: suspension is just "don't pop the item." 
+- Every suspension point in the system has the same shape: a work item that is safe to re-pop. + +## Cost and Re-architecting Scope + +- Hook arguments (args list, resolved name, target eid, etc.) must be serializable into the `WorkItem` struct so they survive across a suspend/resume cycle. Today they are computed inline and discarded. +- `dispatch()` gains new `FIRE_*_HOOKS` cases; the inline hook calls in `exec_call`, `exec_load`, `exec_store`, etc. are replaced by work-item pushes. +- The Python `InterceptorPolicy` stays largely unchanged; `dispatch()` just calls it from a different control-flow site. +- Estimated scope: medium — touches `InterpreterLoop.h`, `SymbolicInterpreter.cpp/h`, and the work-item struct, but not the Python symex layer above. + +## Current Interim Behavior + +`ctx.stop_now()` / `StopNow` uses the existing exception-abort path, which clears the work stack. The path is marked terminal and is not stepped again. Re-fire on a hypothetical user resume is accepted as correct: the user controls when and whether to resume, and sets path state accordingly before doing so. diff --git a/docs/symex-symbolic-switch-plan.md b/docs/symex-symbolic-switch-plan.md new file mode 100644 index 000000000..7b3ae7cd8 --- /dev/null +++ b/docs/symex-symbolic-switch-plan.md @@ -0,0 +1,330 @@ +# Symbolic SWITCH plan + +## Goal + +When a SWITCH instruction's selector is a symbolic value (z3 BitVec), the +substrate currently calls `policy.extract_int(sel)` which returns `0` for +non-concrete values, picks whichever case covers `0` (or the default), and +proceeds down a single path. This loses the other branches. + +Target behavior: fork into one path per case (and one for the default), each +carrying a path-condition constraint that pins the selector into that case's +range. Each case is treated like an entry point on a state-machine — the +substrate hands the driver a list of (range, target_block) pairs and the +driver realizes them as forked paths. 
+ +## Conceptual model + +A SWITCH with cases `[1 → A]`, `[2..5 → B]`, `[7 → C]`, `[default → D]` and +symbolic selector `S` becomes 4 forked paths: + +| path | target | path-condition addition | +|------|--------|---------------------------------| +| 1 | A | `S == 1` | +| 2 | B | `2 <= S && S <= 5` | +| 3 | C | `S == 7` | +| 4 | D | `!(S == 1 \|\| (2 <= S && S <= 5) \|\| S == 7)` | + +A case is treated as part of the instruction stack/state machine: the driver +clones the snapshot, enters the case's target block, then the path's solver +gets the constraint added. Infeasible paths (where the constraint conflicts +with the path's existing condition) are dropped. + +## Files to touch + +### 1. `include/multiplier/IR/Interpret/Continuation.h` + +Add a new continuation type alongside `BranchContinuation`: + +```cpp +struct SwitchCase { + int64_t low; + int64_t high; + IRBlock target_block; +}; + +template +class SwitchContinuation : public Continuation { + ValueT selector_; + RawEntityId sel_eid_; + std::vector cases_; // non-default cases, in source order + IRBlock default_block_; // default — `id().Pack() == 0` if none +public: + SwitchContinuation(ref_t> snap, + ValueT selector, RawEntityId sel_eid, + std::vector cases, + IRBlock default_block); + const ValueT &selector() const { return selector_; } + RawEntityId selector_eid() const { return sel_eid_; } + const std::vector &cases() const { return cases_; } + const IRBlock &default_block() const { return default_block_; } + // Accept(visitor) routing — match the existing pattern for branches. +}; +``` + +### 2. `include/multiplier/IR/Interpret/Policy.h` + +Add a precise concreteness predicate for ints (the existing `extract_int` +silently returns 0 for symbolic values, which is what bit us). 
Default is +"always concrete" so ConcretePolicy needs no override: + +```cpp +std::optional try_extract_int(const ValueT &val) { + return self().try_extract_int_impl(val); +} +std::optional try_extract_int_impl(const ValueT &val) { + return self().extract_int(val); // Concrete: always succeeds +} +``` + +PythonPolicy overrides to return `std::nullopt` for non-PyLong values +(the way `is_true` already does). + +### 3. `include/multiplier/IR/Interpret/InterpreterLoop.h::decide_switch` + +Replace the unconditional `extract_int` with a `try_extract_int`. On +`std::nullopt` (symbolic), gather all cases + default and emit a switch +continuation: + +```cpp +auto maybe_sel = policy.try_extract_int(sel); +if (maybe_sel) { + int64_t sel_val = *maybe_sel; + // ... existing concrete path: pick matching case or default ... + return; +} + +// Symbolic selector: collect case ranges and emit a switch continuation. +std::vector case_list; +IRBlock default_block{}; +for (auto sc : sw->cases()) { + if (sc.is_default()) { + default_block = sc.target_block(); + } else { + case_list.push_back({sc.low(), sc.high(), sc.target_block()}); + } +} +sched.on_switch(sel, sel_eid, std::move(case_list), + default_block, state.clone()); +state.work_stack.clear(); +``` + +### 4. `include/multiplier/IR/Interpret/Policy.h::Scheduler` + scheduler impls + +Add to the CRTP base: + +```cpp +void on_switch(ValueT selector, RawEntityId sel_eid, + std::vector cases, + IRBlock default_block, + auto &&state) { + self().on_switch(std::move(selector), sel_eid, + std::move(cases), default_block, + std::forward(state)); +} +``` + +`NoOpScheduler::on_switch` (in Policy.h) pushes a `SwitchContinuation` onto +`outcome.continuations`. + +### 5. `bindings/Python/SymbolicInterpreter.h::PythonScheduler` + +Add `on_switch` mirroring the other scheduler hooks — pushes a +`SwitchContinuation` onto `outcome.continuations`. + +### 6. 
`bindings/Python/SymbolicInterpreter.cpp::SymbolicStep` + +In the existing block that translates `outcome.continuations` into the +result dict's `forks` list, add a case for `SwitchContinuation`. Each fork +entry is a Python dict shaped like: + +```python +{ + "kind": "switch", + "selector": <selector value>, + "selector_eid": <packed entity id>, + "cases": [(low, high, target_block_eid), ...], + "default_block_eid": <packed entity id or None>, + "snapshot": <opaque state snapshot>, +} +``` + +`target_block_eid` and `default_block_eid` are the packed entity ids of the +IRBlocks (matching how `branch` forks already report block ids). The snapshot +is opaque to Python — it's handed back to `_interp.resume_switch_case` +or similar (see step 7). + +### 7. New substrate entry point: resume at a switch case + +Add a Python-callable to the existing interpreter dispatch (`Interpreter.cpp`) +that takes a switch-fork dict, picks one case (by index, or `-1` for default), +and: + + * clones the snapshot, + * pushes an ENTER_BLOCK work item for the chosen target, + * returns the new state (or wraps it in the engine's existing state + container). + +Two sketches; pick whichever fits the existing API better: + +```python +# Option A — one call per case, like resume_addr does for memory: +_interp.resume_switch_case(snapshot, target_block_eid) + +# Option B — caller passes the case index from the fork dict: +_interp.resume_switch_case(fork_dict, case_index_or_neg1_for_default) +``` + +The constraint addition (`solver.add(low <= sel && sel <= high)`) happens +on the Python driver side, not in the substrate. + +### 8. Driver: `bindings/Python/symex/engine.py` + +In the BFS / DFS fork-handling section (where `fork["kind"] == "branch"` is +handled today), add a `"switch"` arm: + +```python +elif fork["kind"] == "switch": + sel = fork["selector"] + cases = fork["cases"] # list of (lo, hi, target) + default_eid = fork["default_block_eid"] # may be None + + # One child per case.
+ for low, high, target_eid in cases: + child = self._fork_child_for_switch_case( + path, fork["snapshot"], sel, fork["selector_eid"], low, high, target_eid) + if child is not None: # feasibility check passed + queue.append(child) + + # Default branch: constrain selector to be outside every case. + if default_eid is not None: + child = self._fork_child_for_switch_default( + path, fork["snapshot"], sel, fork["selector_eid"], cases, default_eid) + if child is not None: + queue.append(child) +``` + +Helpers (new in engine.py): + +```python +def _fork_child_for_switch_case(self, parent, snapshot, sel, sel_eid, low, high, target_eid): + child = self._fork_child(parent, _interp.clone_state(snapshot)) + self._enter_block_in_state(child._state, target_eid) + if _is_z3(sel): + if low == high: + child.solver.add(sel == low) + else: + child.solver.add(z3.And(sel >= low, sel <= high)) + if not child.solver.feasible(): + return None + child.path_condition.append({ + "kind": "switch_case", + "selector_eid": sel_eid, + "low": low, "high": high, + }) + return child + +def _fork_child_for_switch_default(self, parent, snapshot, sel, sel_eid, cases, default_eid): + child = self._fork_child(parent, _interp.clone_state(snapshot)) + self._enter_block_in_state(child._state, default_eid) + if _is_z3(sel): + for low, high, _ in cases: + if low == high: + child.solver.add(sel != low) + else: + child.solver.add(z3.Or(sel < low, sel > high)) + if not child.solver.feasible(): + return None + child.path_condition.append({ + "kind": "switch_default", + "selector_eid": sel_eid, + }) + return child +``` + +`_enter_block_in_state` is the small new helper that pushes an ENTER_BLOCK +work item targeting the resolved IRBlock — this lives next to the existing +`_init_path` block-resolving code. + +### 9. Path condition record shape + +`Path.path_condition` is a list of dicts (matching the existing branch +record).
Switch-case records use: + +```python +{"kind": "switch_case", "selector_eid": , "low": , "high": } +{"kind": "switch_default", "selector_eid": } +``` + +`Path.condition_str()` should learn to render these as `S == 7` / +`2 <= S <= 5` / `S ∉ {1, 2..5, 7}` for human-readable summaries. + +### 10. Events + +`path.events` should record each switch fork the same way branch forks are +recorded today — one entry per case actually taken. Re-use the existing +`branch` event kind with a `direction` field, or introduce a `switch_case` +event kind. Recommend the latter for clarity: + +```python +{"kind": "switch_case", "selector_eid": , + "low": lo, "high": hi, "target_block": } +{"kind": "switch_default", "selector_eid": , "target_block": } +``` + +### 11. Tests + +`tests/symex/test_phaseN_symbolic_switch.py` (pick the next free phase +number): + + * `test_symbolic_switch_three_cases_plus_default` — symbolic 8-bit + selector, switch with 3 single-value cases + default, expect 4 paths. + * `test_symbolic_switch_range_case` — case `[2..5]` adds the + range constraint correctly; only one path per range. + * `test_symbolic_switch_infeasible_case` — pre-constrain the path's + solver so one case is unreachable; expect that path to be dropped. + * `test_symbolic_switch_no_default` — a switch with no default block + correctly emits only the case forks. + * `test_concrete_switch_unchanged` — concrete selector still picks + exactly one path (regression guard). + +Each test inspects the path count, terminal kinds, and per-path +`path_condition` to confirm the right constraints landed. + +## Order of work + +1. `Continuation.h` — add `SwitchCase` and `SwitchContinuation`. +2. `Policy.h` — add `try_extract_int` (CRTP default + Python override). +3. Scheduler `on_switch` in both base and concrete schedulers. +4. `decide_switch` change. +5. `PythonScheduler::on_switch` and `SymbolicStep` translation. +6. `_interp.resume_switch_case` (or equivalent) Python-callable substrate + entry. +7. 
Driver fork dispatch + helpers in `engine.py`. +8. `Path.path_condition` recording + `condition_str()` rendering. +9. Tests. + +Each step should leave `python3 -m pytest tests/symex/ -x -q` green; the +last step adds new tests. + +## Notes / pitfalls + +* The selector might be a SymExpr (our Python placeholder) rather than a + real z3 BitVec if some upstream cast hasn't been lowered to z3 yet. + In that case `_is_z3(sel)` is False and the solver constraints are + no-ops — the paths still fork structurally but feasibility checks + always pass. Acceptable as a stop-gap; can be tightened later. + +* `decide_switch` runs *after* the selector is computed (the work-stack + push of `ANALYZE` for the selector at `decide_switch:1518` ensures + this). The continuation captures the selector value, not the operand + — so the symbolic value flows through cleanly without re-evaluation. + +* Snapshot cloning is cheap (shared_ptr / PyObjectRC segments); the cost + per switch is N small clones, not N full state copies. + +* Path-condition feasibility: for paths with no prior symbolic + constraints, every case is feasible and we always fork into N + branches. That's expected — switches with O(256) ranges produce + O(256) paths. If this becomes a problem, add a per-explore cap akin + to `engine.lazy_region_budget`. diff --git a/include/multiplier/IR/Interpret/ConcreteOps.h b/include/multiplier/IR/Interpret/ConcreteOps.h index e3909de95..80ba0e4b7 100644 --- a/include/multiplier/IR/Interpret/ConcreteOps.h +++ b/include/multiplier/IR/Interpret/ConcreteOps.h @@ -115,11 +115,21 @@ MX_EXPORT bool concrete_has_address(const Value &val); class ConcreteMemory; -MX_EXPORT void concrete_write_to_mem(ConcreteMemory &memory, uint64_t address, - const Value &val, size_t size, - bool is_float = false); -MX_EXPORT Value concrete_read_from_mem(ConcreteMemory &memory, uint64_t address, - size_t size, bool is_float); +// Endian-explicit memory accessors. 
The `Value` type is host-agnostic; +// the byte order lives entirely in these encode/decode steps. Pick +// `_le` or `_be` based on the IR memory op's `IsBigEndian(sub_op)`. +MX_EXPORT void concrete_write_to_mem_le(ConcreteMemory &memory, + uint64_t address, const Value &val, + size_t size, bool is_float = false); +MX_EXPORT void concrete_write_to_mem_be(ConcreteMemory &memory, + uint64_t address, const Value &val, + size_t size, bool is_float = false); +MX_EXPORT Value concrete_read_from_mem_le(ConcreteMemory &memory, + uint64_t address, size_t size, + bool is_float); +MX_EXPORT Value concrete_read_from_mem_be(ConcreteMemory &memory, + uint64_t address, size_t size, + bool is_float); MX_EXPORT bool concrete_mem_bulk_op(ConcreteMemory &memory, MemOp sub, const std::vector &ops, const MemoryInst &mi, Value &result); diff --git a/include/multiplier/IR/Interpret/ConcretePolicy.h b/include/multiplier/IR/Interpret/ConcretePolicy.h index be9eeeaf1..c7ab2e1d5 100644 --- a/include/multiplier/IR/Interpret/ConcretePolicy.h +++ b/include/multiplier/IR/Interpret/ConcretePolicy.h @@ -104,8 +104,13 @@ class MX_EXPORT ConcretePolicy result = make_undef(); return true; } - result = concrete_read_from_mem(memory_, addr.u64, - hint.size_bytes, hint.is_float); + if (IsBigEndian(hint.sub_op)) { + result = concrete_read_from_mem_be(memory_, addr.u64, + hint.size_bytes, hint.is_float); + } else { + result = concrete_read_from_mem_le(memory_, addr.u64, + hint.size_bytes, hint.is_float); + } return true; } @@ -113,8 +118,13 @@ class MX_EXPORT ConcretePolicy bool mem_write(Sched &, const Value &addr, const Value &val, const MemAccessHint &hint) { if (addr.u64 == 0) return true; - concrete_write_to_mem(memory_, addr.u64, val, - hint.size_bytes, hint.is_float); + if (IsBigEndian(hint.sub_op)) { + concrete_write_to_mem_be(memory_, addr.u64, val, + hint.size_bytes, hint.is_float); + } else { + concrete_write_to_mem_le(memory_, addr.u64, val, + hint.size_bytes, hint.is_float); + } return true; 
} @@ -150,6 +160,7 @@ class MX_EXPORT ConcretePolicy bool resolve_call(Sched &, const IRInstruction &, RawEntityId target_eid, RawEntityId indirect_target_eid, + uint64_t /*target_addr*/, const std::vector &, bool, CallResolution &resolution) { if (func_resolver_) { diff --git a/include/multiplier/IR/Interpret/Continuation.h b/include/multiplier/IR/Interpret/Continuation.h index 5bb473e82..12d070dff 100644 --- a/include/multiplier/IR/Interpret/Continuation.h +++ b/include/multiplier/IR/Interpret/Continuation.h @@ -171,6 +171,58 @@ class BranchContinuation final uint8_t step_{0}; }; +// =========================================================================== +// SwitchContinuation — driver must pick a switch case (or default). +// +// Emitted when a SWITCH instruction's selector is non-extractable (e.g. a +// symbolic z3 BitVec). The driver realizes each case as a forked path, +// adding a path-condition constraint that pins the selector into that +// case's range; the default fork constrains the selector outside every +// case. +// =========================================================================== + +// Plain struct (named `SwitchCaseRange` to avoid colliding with +// `mx::SwitchCase` from the AST when both namespaces are pulled in via +// `using namespace ir::interpret;`). 
+struct SwitchCaseRange { + int64_t low; + int64_t high; + IRBlock target_block; +}; + +template +class SwitchContinuation final : public Continuation { + public: + using state_ref = typename Continuation::state_ref; + + SwitchContinuation(state_ref snap, ValueT selector, RawEntityId sel_eid, + std::vector cases, + IRBlock default_block) + : snapshot_(std::move(snap)), + selector_(std::move(selector)), + sel_eid_(sel_eid), + cases_(std::move(cases)), + default_block_(default_block) {} + + state_ref snapshot(void) const override { return snapshot_; } + + std::string describe(void) const override { return "switch"; } + + RawEntityId operand_eid(void) const override { return sel_eid_; } + + const ValueT &selector(void) const { return selector_; } + RawEntityId selector_eid(void) const { return sel_eid_; } + const std::vector &cases(void) const { return cases_; } + const IRBlock &default_block(void) const { return default_block_; } + + private: + state_ref snapshot_; + ValueT selector_; + RawEntityId sel_eid_{kInvalidEntityId}; + std::vector cases_; + IRBlock default_block_; +}; + // =========================================================================== // CallContinuation — driver must resolve an unresolved CALL. 
// diff --git a/include/multiplier/IR/Interpret/Interpreter.h b/include/multiplier/IR/Interpret/Interpreter.h index 57a659313..ccd8a0dc5 100644 --- a/include/multiplier/IR/Interpret/Interpreter.h +++ b/include/multiplier/IR/Interpret/Interpreter.h @@ -52,6 +52,7 @@ enum class WorkKind : uint8_t { COMPUTE_CONST, COMPUTE_ALLOCA, COMPUTE_BINARY, + COMPUTE_LOGICAL, COMPUTE_COMPARE, COMPUTE_UNARY, COMPUTE_CAST, diff --git a/include/multiplier/IR/Interpret/InterpreterLoop.h b/include/multiplier/IR/Interpret/InterpreterLoop.h index 6f4f36e21..fe145d121 100644 --- a/include/multiplier/IR/Interpret/InterpreterLoop.h +++ b/include/multiplier/IR/Interpret/InterpreterLoop.h @@ -162,8 +162,19 @@ inline void enter_block(auto &state, PolicyT &policy, const IRBlock &block) { // Phase 8d: notify the policy of the block entry so analysts can // observe every block visit (not only branch transitions). policy.on_enter_block(state, block); - // Clear transient values cache, then push roots. - state.call_stack.top().values.clear(); + // Clear the transient values cache. + auto &frame = state.call_stack.top(); + frame.values.clear(); + // Evict call_results entries for CALL instructions that live in this + // block. Cross-block entries (calls from other blocks whose results + // this block reads as operands) are preserved. Without this, a loop + // re-entering this block would reuse the first iteration's return + // values for every subsequent call, skipping resolve_call entirely. 
+ for (auto inst : block.all_instructions()) { + if (inst.opcode() == OpCode::CALL) { + frame.call_results.erase(EntityId(inst.id()).Pack()); + } + } push_block_work_items(state, block); } @@ -287,7 +298,6 @@ inline void analyze(auto &state, case OpCode::FMUL_32: case OpCode::FMUL_64: case OpCode::FDIV_32: case OpCode::FDIV_64: case OpCode::FREM_32: case OpCode::FREM_64: - case OpCode::LOGICAL_AND: case OpCode::LOGICAL_OR: if (auto bin = BinaryInst::from(inst)) { push(WorkKind::COMPUTE_BINARY); push_operand(bin->rhs()); @@ -295,6 +305,13 @@ inline void analyze(auto &state, } break; + case OpCode::LOGICAL_AND: case OpCode::LOGICAL_OR: + if (auto bin = BinaryInst::from(inst)) { + push(WorkKind::COMPUTE_LOGICAL); + push_operand(bin->lhs()); + } + break; + // Comparisons (all widths). case OpCode::CMP_EQ_8: case OpCode::CMP_EQ_16: case OpCode::CMP_EQ_32: case OpCode::CMP_EQ_64: @@ -609,6 +626,37 @@ inline void compute_binary(CallFrame &frame, PolicyT &policy, val(frame, bin->rhs())); } +// Short-circuit evaluation for LOGICAL_AND / LOGICAL_OR. +// Called after LHS has been evaluated; pushes RHS + COMPUTE_BINARY if the +// result can't be determined from LHS alone, otherwise stores the result +// directly without evaluating RHS. +template +inline void compute_logical(auto &state, PolicyT &policy, SchedT &sched, + const IRInstruction &inst) { + auto &frame = state.call_stack.top(); + auto bin = BinaryInst::from(inst); + if (!bin) { + frame.values[eid(inst)] = ValueTraits::default_value(); + return; + } + const ValueT &lhs = val(frame, bin->lhs()); + auto truth = policy.is_true(lhs); + bool is_and = inst.opcode() == OpCode::LOGICAL_AND; + if (truth.has_value()) { + bool lhs_true = *truth; + if (is_and ? !lhs_true : lhs_true) { + // LOGICAL_AND + false LHS → 0; LOGICAL_OR + true LHS → 1. + frame.values[eid(inst)] = policy.make_const( + ConstOp::UINT8, lhs_true ? 1 : 0, lhs_true ? 
1u : 0u); + return; + } + } + // Symbolic LHS or non-short-circuit concrete: evaluate RHS, then + // combine with binary_op (which handles the symbolic case correctly). + state.work_stack.push_back({WorkKind::COMPUTE_BINARY, inst, {}}); + state.work_stack.push_back({WorkKind::ANALYZE, bin->rhs(), {}}); +} + template inline void compute_compare(CallFrame &frame, PolicyT &policy, const IRInstruction &inst) { @@ -799,7 +847,7 @@ inline void compute_global_ptr(auto &state, PolicyT &policy, // the push may reallocate the segment's vector and invalidate `frame`. frame.values[id] = addr; - if (info.initializer && !placed_via_hint) { + if (info.initializer) { CallFrame init_frame; init_frame.func = *info.initializer; init_frame.params = {addr}; @@ -858,7 +906,6 @@ inline void compute_func_ptr(auto &state, PolicyT &policy, } slot_addr = *a; } - policy.memory().write(slot_addr, &src_eid, 8); state.function_addresses[src_eid] = slot_addr; } frame.locals[src_eid] = slot_addr; @@ -875,7 +922,8 @@ inline void exec_load(auto &state, PolicyT &policy, auto mi = MemoryInst::from(inst); if (!mi) return; auto sub = mi->sub_opcode(); - MemAccessHint hint{ir::AccessSize(sub), ir::IsFloatLoad(sub), false}; + MemAccessHint hint{ir::AccessSize(sub), ir::IsFloatLoad(sub), false, + false, false, sub}; auto &frame = state.call_stack.top(); ValueT addr = val(frame, mi->address()); auto inst_eid = eid(inst); @@ -909,12 +957,21 @@ inline void exec_store(auto &state, PolicyT &policy, auto mi = MemoryInst::from(inst); if (!mi) return; auto sub = mi->sub_opcode(); - MemAccessHint hint{ir::AccessSize(sub), ir::IsFloatStore(sub), true}; + MemAccessHint hint{ir::AccessSize(sub), ir::IsFloatStore(sub), true, + false, false, sub}; auto &frame = state.call_stack.top(); ValueT addr = val(frame, mi->address()); ValueT stored = val(frame, mi->stored_value()); + auto inst_eid = eid(inst); auto addr_eid = eid(mi->address()); + // The STORE's result slot holds the stored value, so an assignment + // expression 
like `(p = expr)` evaluates to `expr` — i.e. the IR + // operand referencing the STORE instruction yields the value that + // was just written. Set this BEFORE the side effect so it's also + // visible if `with_address` suspends (the snapshot already has it). + frame.values[inst_eid] = stored; + // Phase 8a: symmetric symbolic-STORE short-circuit (see exec_load). if (!policy.extract_address(addr) && addr_eid != kInvalidEntityId) { if (policy.exec_symbolic_store(sched, addr, stored, hint)) { @@ -1077,6 +1134,7 @@ inline void exec_call(auto &state, PolicyT &policy, std::optional callee_ir; auto target_decl = ci->target(); RawEntityId indirect_eid = kInvalidEntityId; + uint64_t target_addr = 0; if (ci->is_indirect()) { auto callee_op = inst.nth_operand(0); @@ -1091,10 +1149,14 @@ inline void exec_call(auto &state, PolicyT &policy, policy.with_address(callee_val, policy.memory(), hint, eid(callee_op), state, sched, [&](auto &p, ConcreteMemory & /*mem*/, uint64_t a) { - ValueT addr_val = p.make_literal_ptr(a); - ValueT eid_val; - p.mem_read(sched, addr_val, hint, eid_val); - indirect_eid = static_cast(p.extract_uint(eid_val)); + // `a` IS the function's virtual address (the value a function + // pointer carries after store/load roundtrip). Resolve it back + // to a declaration entity id via the policy's reverse resolver + // rather than reading synthetic data from memory; the address + // itself is also forwarded to resolve_call so analyst hooks + // can act on addresses that aren't (yet) in the reverse map. + target_addr = a; + indirect_eid = p.entity_for_address(a); }); } @@ -1106,7 +1168,7 @@ inline void exec_call(auto &state, PolicyT &policy, ? 
target_decl->id().Pack() : kInvalidEntityId; CallResolution resolution; bool alive = policy.resolve_call( - sched, inst, target_eid, indirect_eid, + sched, inst, target_eid, indirect_eid, target_addr, call_args, ci->is_indirect(), resolution); if (!alive) { frame.values[id] = policy.make_default(); @@ -1367,6 +1429,20 @@ inline void exec_ret(auto &state, PolicyT &policy, ValueT callee_result = read_return_value( state, policy, frame, ret_from_inst, sched); auto call_site = frame.call_site; + auto func_kind = frame.func.kind(); + bool is_global_init = + (func_kind == ir::FunctionKind::GLOBAL_INITIALIZER || + func_kind == ir::FunctionKind::THREAD_LOCAL_INITIALIZER); + IRFunction init_func; + ValueT init_addr; + if (is_global_init) { + init_func = frame.func; + if (!frame.params.empty()) { + init_addr = frame.params[0]; + } else { + init_addr = ValueTraits::default_value(); + } + } state.call_stack.pop(); if (call_site != kInvalidEntityId) { // Store in both caches: values (for within-block use) and @@ -1374,6 +1450,9 @@ inline void exec_ret(auto &state, PolicyT &policy, state.call_stack.top().values[call_site] = callee_result; state.call_stack.top().call_results[call_site] = callee_result; } + if (is_global_init) { + policy.on_global_initialized(sched, init_func, init_addr); + } return; } @@ -1436,8 +1515,8 @@ inline void decide_cond_branch(auto &state, PolicyT &policy, state.work_stack.clear(); } -template -inline void decide_switch(auto &state, PolicyT &policy, +template +inline void decide_switch(auto &state, PolicyT &policy, SchedT &sched, const IRInstruction &inst) { auto sw = SwitchInst::from(inst); if (!sw) return; @@ -1452,21 +1531,41 @@ inline void decide_switch(auto &state, PolicyT &policy, } ValueT sel = frame.values[sel_eid]; - int64_t sel_val = policy.extract_int(sel); + auto maybe_sel = policy.try_extract_int(sel); + if (maybe_sel) { + int64_t sel_val = *maybe_sel; + IRBlock default_block{}; + for (auto sc : sw->cases()) { + if (sc.is_default()) { + 
default_block = sc.target_block(); + continue; + } + if (sel_val >= sc.low() && sel_val <= sc.high()) { + enter_block(state, policy, sc.target_block()); + return; + } + } + if (EntityId(default_block.id()).Pack()) { + enter_block(state, policy, default_block); + } + return; + } + + // Symbolic selector: collect cases + default and emit a switch + // continuation. The driver clones the snapshot per case, enters the + // case's target block, and adds a path-condition constraint. + std::vector case_list; IRBlock default_block{}; for (auto sc : sw->cases()) { if (sc.is_default()) { default_block = sc.target_block(); - continue; - } - if (sel_val >= sc.low() && sel_val <= sc.high()) { - enter_block(state, policy, sc.target_block()); - return; + } else { + case_list.push_back({sc.low(), sc.high(), sc.target_block()}); } } - if (EntityId(default_block.id()).Pack()) { - enter_block(state, policy, default_block); - } + sched.on_switch(std::move(sel), sel_eid, std::move(case_list), + default_block, state.clone()); + state.work_stack.clear(); } // =========================================================================== @@ -1481,7 +1580,8 @@ inline void dispatch(auto &state, PolicyT &policy, // Fire the per-instruction observe hook for every work item that // represents a real instruction execution (not scheduling helpers). 
if (item.kind != WorkKind::ENTER_BLOCK && - item.kind != WorkKind::ANALYZE) { + item.kind != WorkKind::ANALYZE && + item.kind != WorkKind::COMPUTE_LOGICAL) { policy.on_instruction(state, sched, item.inst); if (policy.abort_requested()) { state.work_stack.clear(); @@ -1515,6 +1615,10 @@ inline void dispatch(auto &state, PolicyT &policy, ++state.steps; compute_binary(frame, policy, item.inst); break; + case WorkKind::COMPUTE_LOGICAL: + ++state.steps; + compute_logical(state, policy, sched, item.inst); + break; case WorkKind::COMPUTE_COMPARE: ++state.steps; compute_compare(frame, policy, item.inst); @@ -1629,7 +1733,7 @@ inline void dispatch(auto &state, PolicyT &policy, break; case WorkKind::DECIDE_SWITCH: ++state.steps; - decide_switch(state, policy, item.inst); + decide_switch(state, policy, sched, item.inst); break; case WorkKind::EXEC_RET: ++state.steps; diff --git a/include/multiplier/IR/Interpret/Policy.h b/include/multiplier/IR/Interpret/Policy.h index 08c337337..b78e3803a 100644 --- a/include/multiplier/IR/Interpret/Policy.h +++ b/include/multiplier/IR/Interpret/Policy.h @@ -81,6 +81,17 @@ struct Scheduler { std::move(false_val), std::move(true_val), std::forward(state)); } + + // Symbolic SWITCH: scheduler emits one continuation describing every + // case + default. The driver realizes them as forked paths. + void on_switch(ValueT selector, RawEntityId sel_eid, + std::vector cases, + IRBlock default_block, + auto &&state) { + self().on_switch(std::move(selector), sel_eid, std::move(cases), + default_block, + std::forward(state)); + } }; // Concrete execution: no forking, no error collection. 
@@ -115,6 +126,16 @@ struct NoOpScheduler : Scheduler { true_block, false_block, std::move(false_val), std::move(true_val))); } + + void on_switch(Value selector, RawEntityId sel_eid, + std::vector cases, + IRBlock default_block, + ref_t> state) { + outcome.continuations.emplace_back( + std::make_unique>( + std::move(state), std::move(selector), sel_eid, + std::move(cases), default_block)); + } }; // =========================================================================== @@ -149,6 +170,18 @@ struct Policy { return self().extract_int(val); } + // Concreteness-preserving variant of extract_int: returns nullopt for + // values the policy cannot represent as a concrete int (e.g. symbolic + // z3 BitVecs in PythonPolicy). Used by `decide_switch` to fork on + // symbolic selectors instead of silently picking case-zero. + std::optional try_extract_int(const ValueT &val) { + return self().try_extract_int_impl(val); + } + // Default: defer to extract_int — concrete policies always produce an int. + std::optional try_extract_int_impl(const ValueT &val) { + return self().extract_int(val); + } + uint64_t extract_uint(const ValueT &val) { return self().extract_uint(val); } @@ -361,12 +394,13 @@ struct Policy { const IRInstruction &call_inst, RawEntityId target_eid, RawEntityId indirect_target_eid, + uint64_t target_addr, const std::vector &arguments, bool is_indirect, CallResolution &resolution) { return self().resolve_call( sched, call_inst, target_eid, indirect_target_eid, - arguments, is_indirect, resolution); + target_addr, arguments, is_indirect, resolution); } bool resolve_global(auto &sched, RawEntityId entity_id, @@ -392,6 +426,16 @@ struct Policy { return std::nullopt; } + // Reverse-direction resolver: given a virtual address previously assigned + // to some entity (function, global), return the entity's RawEntityId. 
+ // Used by exec_call to map an indirect callee address back to a known + // declaration without reading synthetic data from the interpreter's + // flat address space. Default returns kInvalidEntityId (no mapping). + RawEntityId entity_for_address(uint64_t addr) { + return self().entity_for_address_impl(addr); + } + RawEntityId entity_for_address_impl(uint64_t) { return kInvalidEntityId; } + // Phase 9: marks the next `with_address` suspension (when emitted from // an indirect-call callee load) as a call-target suspension. Policies // that care override `_impl`; the default is a no-op so concrete @@ -412,6 +456,20 @@ struct Policy { template void on_instruction_impl(StateT &, SchedT &, const IRInstruction &) {} + // Fires when a GLOBAL_INITIALIZER / THREAD_LOCAL_INITIALIZER frame + // returns — i.e. after the global's IR initializer has finished + // executing. `init_func` is the initializer IRFunction; its + // source_declaration is the VarDecl. `addr` is the global's + // virtual address (ValueT — concrete pointer). + template + void on_global_initialized(SchedT &sched, const IRFunction &init_func, + const ValueT &addr) { + self().on_global_initialized_impl(sched, init_func, addr); + } + template + void on_global_initialized_impl(SchedT &, const IRFunction &, + const ValueT &) {} + // Abort-request gate. PythonPolicy sets this when a Python hook raises // an exception so the loop can exit cleanly after the current item. 
bool abort_requested() const { diff --git a/lib/IR/Instruction.cpp b/lib/IR/Instruction.cpp index 9e6fe3609..244f39995 100644 --- a/lib/IR/Instruction.cpp +++ b/lib/IR/Instruction.cpp @@ -6,10 +6,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -267,15 +269,47 @@ void IRInstruction::format(std::ostream &os) const { os << "/" << ir::EnumeratorName(bi->sub_opcode()); } else if (auto fi = FloatOpInst::from(*this)) { os << "/" << ir::EnumeratorName(fi->sub_opcode()); + } else if (auto rmw = ReadModifyWriteInst::from(*this)) { + os << "/" << ir::EnumeratorName(rmw->underlying_op()); } - // Call target name. + // Named entity annotations. if (auto ci = CallInst::from(*this)) { if (auto fd = ci->target()) { os << " @" << fd->name(); } else if (ci->is_indirect()) { os << " @"; } + } else if (auto gp = GlobalPtrInst::from(*this)) { + if (auto vd = gp->variable()) { + os << " @" << vd->name(); + } + } else if (auto tp = ThreadLocalPtrInst::from(*this)) { + if (auto vd = tp->variable()) { + os << " @" << vd->name(); + } + } else if (auto fp = FuncPtrInst::from(*this)) { + if (auto fd = fp->function()) { + os << " @" << fd->name(); + } + } else if (auto gep = GEPFieldInst::from(*this)) { + os << " ." << gep->field().name() << "+" << gep->byte_offset(); + } else if (auto bi = BranchInst::from(*this)) { + os << " -> %" << bi->target_block().id().Pack(); + } else if (auto cb = CondBranchInst::from(*this)) { + os << " -> %" << cb->true_block().id().Pack() + << " / %" << cb->false_block().id().Pack(); + } else if (auto sw = SwitchInst::from(*this)) { + for (auto c : sw->cases()) { + if (c.is_default()) { + os << " [default -> %" << c.target_block().id().Pack() << "]"; + } else if (c.is_range()) { + os << " [" << c.low() << ".." << c.high() + << " -> %" << c.target_block().id().Pack() << "]"; + } else { + os << " [" << c.low() << " -> %" << c.target_block().id().Pack() << "]"; + } + } } // Operands. 
diff --git a/lib/IR/Interpret/ConcretePolicy.cpp b/lib/IR/Interpret/ConcretePolicy.cpp index 3fdc02755..3dace6d4b 100644 --- a/lib/IR/Interpret/ConcretePolicy.cpp +++ b/lib/IR/Interpret/ConcretePolicy.cpp @@ -40,39 +40,104 @@ bool concrete_has_address(const Value &val) { return val.u64 != 0; } -void concrete_write_to_mem(ConcreteMemory &memory_, uint64_t address, - const Value &val, size_t size, bool is_float) { +// Endian-explicit byte composition / decomposition. Both helpers +// operate on `Value::u64` directly so they are independent of the +// host's byte order and never alias the anonymous-struct fields. +// +// Sign extension for integer reads is done via integer arithmetic on +// `u64`, again so the result is host-agnostic. + +namespace { + +inline uint64_t narrow_float_bits(const Value &val, size_t size) { uint64_t bits = val.u64; - // If storing a float to a 4-byte slot and the value holds f64 bits - // (high 32 non-zero), narrow f64 → f32. - if (is_float && size <= 4 && (bits >> 32) != 0) { + if (size <= 4 && (bits >> 32) != 0) { float f = static_cast(val.f64); uint32_t fbits; std::memcpy(&fbits, &f, sizeof(fbits)); bits = fbits; } - memory_.write(address, &bits, - static_cast(std::min(size, sizeof(bits)))); + return bits; } -Value concrete_read_from_mem(ConcreteMemory &memory_, uint64_t address, - size_t size, bool is_float) { +inline Value sign_extend_int_value(uint64_t bits, size_t size) { Value result; - result.u64 = 0; - memory_.read(address, &result.u64, - static_cast(std::min(size, sizeof(result.u64)))); - if (!is_float) { - // Sign-extend integer reads. 
- switch (size) { - case 1: result.i64 = static_cast(result.i8.val); break; - case 2: result.i64 = static_cast(result.i16.val); break; - case 4: result.i64 = static_cast(result.i32.val); break; - default: break; - } + result.u64 = bits; + switch (size) { + case 1: + result.i64 = static_cast(static_cast(bits & 0xff)); + break; + case 2: + result.i64 = static_cast(static_cast(bits & 0xffff)); + break; + case 4: + result.i64 = static_cast( + static_cast(bits & 0xffffffff)); + break; + default: + break; } return result; } +} // namespace + +void concrete_write_to_mem_le(ConcreteMemory &memory_, uint64_t address, + const Value &val, size_t size, bool is_float) { + uint64_t bits = is_float ? narrow_float_bits(val, size) : val.u64; + uint8_t buf[8]; + size_t n = std::min(size, sizeof(buf)); + for (size_t i = 0; i < n; ++i) { + buf[i] = static_cast((bits >> (i * 8)) & 0xff); + } + memory_.write(address, buf, static_cast(n)); +} + +void concrete_write_to_mem_be(ConcreteMemory &memory_, uint64_t address, + const Value &val, size_t size, bool is_float) { + uint64_t bits = is_float ? 
narrow_float_bits(val, size) : val.u64; + uint8_t buf[8]; + size_t n = std::min(size, sizeof(buf)); + for (size_t i = 0; i < n; ++i) { + buf[n - 1 - i] = static_cast((bits >> (i * 8)) & 0xff); + } + memory_.write(address, buf, static_cast(n)); +} + +Value concrete_read_from_mem_le(ConcreteMemory &memory_, uint64_t address, + size_t size, bool is_float) { + uint8_t buf[8] = {0}; + size_t n = std::min(size, sizeof(buf)); + memory_.read(address, buf, static_cast(n)); + uint64_t bits = 0; + for (size_t i = 0; i < n; ++i) { + bits |= static_cast(buf[i]) << (i * 8); + } + if (is_float) { + Value result; + result.u64 = bits; + return result; + } + return sign_extend_int_value(bits, size); +} + +Value concrete_read_from_mem_be(ConcreteMemory &memory_, uint64_t address, + size_t size, bool is_float) { + uint8_t buf[8] = {0}; + size_t n = std::min(size, sizeof(buf)); + memory_.read(address, buf, static_cast(n)); + uint64_t bits = 0; + for (size_t i = 0; i < n; ++i) { + bits |= static_cast(buf[n - 1 - i]) << (i * 8); + } + if (is_float) { + Value result; + result.u64 = bits; + return result; + } + return sign_extend_int_value(bits, size); +} + // =========================================================================== // Free functions: stateless concrete value operations // =========================================================================== diff --git a/lib/Re2.cpp b/lib/Re2.cpp index 554130a8a..6ffd6cb29 100644 --- a/lib/Re2.cpp +++ b/lib/Re2.cpp @@ -21,7 +21,7 @@ static const std::string_view kEmptyStringView(""); // arguments per sub-match, hence the requirement for enclosing // `pattern_` in a match group with `(` and `)`. 
RegexQueryImpl::RegexQueryImpl(std::string pattern_) - : pattern("(" + pattern_ + ")"), + : pattern("(?m)(" + pattern_ + ")"), re(pattern) {} void RegexQueryImpl::ForEachMatch( diff --git a/lib/Re2Impl.cpp b/lib/Re2Impl.cpp index 2b4cb9fac..f9338e3d3 100644 --- a/lib/Re2Impl.cpp +++ b/lib/Re2Impl.cpp @@ -371,9 +371,9 @@ gap::generator RegexQueryResultImpl::Enumerate(void) & { if (auto result = GetNextMatchInFragment()) { co_yield *result; + } else { + ++index; } - - ++index; } } diff --git a/tests/symex/c/symex_integration.c b/tests/symex/c/symex_integration.c index 5857ad737..fa064cab6 100644 --- a/tests/symex/c/symex_integration.c +++ b/tests/symex/c/symex_integration.c @@ -122,3 +122,49 @@ int32_t si_rw_branch(int32_t *arr, int32_t n, int32_t val) { arr[0] = val; return arr[0]; } + +// ----------------------------------------------------------------------- +// Symbolic-switch corpus. +// ----------------------------------------------------------------------- + +// Three single-value cases + default. Symbolic selector forks 4 paths. +int32_t si_switch_three(int32_t sel) { + switch (sel) { + case 1: return 10; + case 2: return 20; + case 3: return 30; + default: return -1; + } +} + +// GNU range case + default — symbolic selector forks 2 paths. +int32_t si_switch_range(int32_t sel) { + switch (sel) { + case 2 ... 5: return 100; + default: return -1; + } +} + +// No default — symbolic selector forks exactly 2 paths (one per case). +// Falls through to an implicit return 0 for unmatched selectors. +int32_t si_switch_no_default(int32_t sel) { + int32_t r = 0; + switch (sel) { + case 1: r = 10; break; + case 2: r = 20; break; + } + return r; +} + +// Pre-constrained selector: an early branch narrows `sel` to `sel <= 1`. +// Case 2 becomes infeasible and the engine should drop it rather than +// enqueueing a dead path; the default stays feasible (negative `sel`).
+int32_t si_switch_constrained(int32_t sel) { + if (sel > 1) return -2; + switch (sel) { + case 0: return 100; + case 1: return 200; + case 2: return 300; + default: return -1; + } +} diff --git a/tests/symex/test_phase15.py b/tests/symex/test_phase15.py new file mode 100644 index 000000000..df73f9014 --- /dev/null +++ b/tests/symex/test_phase15.py @@ -0,0 +1,232 @@ +# Copyright (c) 2026-present, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in +# the LICENSE file found in the root directory of this source tree. + +"""Phase 15 — symbolic SWITCH selector forking. + +When a SWITCH instruction's selector is a symbolic z3 BitVec, the +substrate emits a SwitchContinuation instead of falling through to +case-zero. The driver realizes one forked path per case (and one for +the default), each carrying a path-condition constraint that pins the +selector into the case's range. + +Tests: + P15.1 Symbolic 3-case + default switch produces 4 paths. + P15.2 Each path's path_condition contains the right z3 constraint. + P15.3 GNU range case [2..5] adds the range constraint. + P15.4 No-default switch produces exactly N case paths. + P15.5 Concrete selector still picks exactly one path (regression). + P15.6 switch_case / switch_default events are recorded; P15.7 infeasible case (pre-constrained selector) is dropped.
+""" + +import pytest + +z3 = pytest.importorskip("z3") + +from multiplier.symex.engine import SymExEngine +from multiplier.symex.layout import Layout +from multiplier.symex.events import ( + Terminal, EventKind, SWITCH_CASE, SWITCH_DEFAULT, +) + + +def _engine(symex_index): + e = SymExEngine(symex_index) + e.layout = Layout() + return e + + +def _completed(paths): + return [p for p in paths if p.terminal == Terminal.COMPLETED] + + +# --------------------------------------------------------------------------- +# P15.1 — three single-value cases + default → 4 paths +# --------------------------------------------------------------------------- + +def test_p15_1_symbolic_switch_three_cases_plus_default(symex_index): + e = _engine(symex_index) + sel = z3.BitVec("sel", 32) + paths = e.explore("si_switch_three", args=[sel]) + completed = _completed(paths) + # Three single-value cases (1, 2, 3) + default → exactly 4 completed paths. + assert len(completed) == 4, \ + f"expected 4 paths, got {len(completed)}: " \ + f"{[p.return_value for p in completed]}" + + +# --------------------------------------------------------------------------- +# P15.2 — path conditions identify the chosen case +# --------------------------------------------------------------------------- + +def test_p15_2_symbolic_switch_path_conditions(symex_index): + e = _engine(symex_index) + sel = z3.BitVec("sel", 32) + paths = e.explore("si_switch_three", args=[sel]) + completed = _completed(paths) + + # For each completed path, ask the solver what `sel` must be. + case_witnesses = {} + for p in completed: + s = z3.Solver() + for c in p.path_condition: + s.add(c) + if s.check() == z3.sat: + m = s.model() + sel_val = m.eval(sel, model_completion=True).as_long() + case_witnesses.setdefault(p.return_value, []).append(sel_val) + + # Every case path picks exactly its case value; default picks + # something that is not 1, 2, or 3. 
+ assert any(p.return_value == 10 for p in completed), \ + "case 1 (sel == 1 → return 10) must produce a path" + assert any(p.return_value == 20 for p in completed) + assert any(p.return_value == 30 for p in completed) + assert any(p.return_value in (-1, 0xFFFFFFFF) for p in completed), \ + "default path (return -1) must exist" + + # Path condition must constrain `sel` to the matching case value. + for p in completed: + if p.return_value == 10: + assert p.must_be(sel, 1), \ + f"return=10 path must force sel==1, got conditions={p.condition_str()}" + elif p.return_value == 20: + assert p.must_be(sel, 2) + elif p.return_value == 30: + assert p.must_be(sel, 3) + elif p.return_value in (-1, 0xFFFFFFFF): + # Default path: sel is none of {1, 2, 3}. + for forbidden in (1, 2, 3): + assert not p.can_be(sel, forbidden), \ + f"default path must forbid sel=={forbidden}, " \ + f"conditions={p.condition_str()}" + + +# --------------------------------------------------------------------------- +# P15.3 — range case [2..5] adds the right constraint +# --------------------------------------------------------------------------- + +def test_p15_3_symbolic_switch_range_case(symex_index): + e = _engine(symex_index) + sel = z3.BitVec("sel", 32) + paths = e.explore("si_switch_range", args=[sel]) + completed = _completed(paths) + # One path for the [2..5] range, one for default → 2 paths. + assert len(completed) == 2, \ + f"expected 2 paths, got {len(completed)}" + + range_paths = [p for p in completed if p.return_value == 100] + default_paths = [p for p in completed if p.return_value in (-1, 0xFFFFFFFF)] + assert len(range_paths) == 1 + assert len(default_paths) == 1 + + p_range = range_paths[0] + # On the range path, sel can be any value in [2, 5] but nothing else. 
+ for v in (2, 3, 4, 5): + assert p_range.can_be(sel, v), \ + f"range path must admit sel=={v}, " \ + f"conditions={p_range.condition_str()}" + for v in (1, 6): + assert not p_range.can_be(sel, v), \ + f"range path must forbid sel=={v}, " \ + f"conditions={p_range.condition_str()}" + + p_def = default_paths[0] + for v in (2, 3, 4, 5): + assert not p_def.can_be(sel, v), \ + f"default path must forbid sel=={v}, " \ + f"conditions={p_def.condition_str()}" + assert p_def.can_be(sel, 0) or p_def.can_be(sel, 6) + + +# --------------------------------------------------------------------------- +# P15.4 — switch without default produces only the case forks +# --------------------------------------------------------------------------- + +def test_p15_4_symbolic_switch_no_default(symex_index): + e = _engine(symex_index) + sel = z3.BitVec("sel", 32) + paths = e.explore("si_switch_no_default", args=[sel]) + completed = _completed(paths) + # Two case paths only. + assert len(completed) == 2, \ + f"expected 2 paths (no default), got {len(completed)}: " \ + f"{[p.return_value for p in completed]}" + return_values = {p.return_value for p in completed} + assert return_values == {10, 20}, \ + f"return values must be {{10, 20}}, got {return_values}" + + +# --------------------------------------------------------------------------- +# P15.5 — concrete selector still picks exactly one path (regression) +# --------------------------------------------------------------------------- + +def test_p15_5_concrete_switch_unchanged(symex_index): + e = _engine(symex_index) + paths = e.explore("si_switch_three", args=[2]) + completed = _completed(paths) + assert len(completed) == 1, \ + f"concrete selector must yield exactly 1 path, got {len(completed)}" + assert completed[0].return_value == 20, \ + f"concrete sel=2 must return 20, got {completed[0].return_value}" + + +# --------------------------------------------------------------------------- +# P15.6 — switch_case events recorded with 
selector_eid + range +# --------------------------------------------------------------------------- + +def test_p15_6_switch_case_events_recorded(symex_index): + e = _engine(symex_index) + sel = z3.BitVec("sel", 32) + paths = e.explore("si_switch_three", args=[sel]) + completed = _completed(paths) + + # Each non-default path has exactly one switch_case event with matching range. + case_events = [] + default_events = [] + for p in completed: + for ev in p.events: + if ev.get("kind") == SWITCH_CASE: + case_events.append(ev) + elif ev.get("kind") == SWITCH_DEFAULT: + default_events.append(ev) + + assert len(case_events) == 3, \ + f"expected 3 switch_case events across paths, got {len(case_events)}" + assert len(default_events) == 1, \ + f"expected 1 switch_default event, got {len(default_events)}" + + case_ranges = sorted(((e["low"], e["high"]) for e in case_events)) + assert case_ranges == [(1, 1), (2, 2), (3, 3)], \ + f"unexpected case ranges {case_ranges}" + + # Selector-eid is consistent across all events. + eids = {ev["selector_eid"] for ev in case_events + default_events} + assert len(eids) == 1 + + +# --------------------------------------------------------------------------- +# P15.7 — pre-constrained selector drops infeasible cases +# --------------------------------------------------------------------------- + +def test_p15_7_symbolic_switch_infeasible_dropped(symex_index): + """si_switch_constrained's early `if (sel > 1) return -2` narrows + `sel <= 1` on the surviving switch path. Case 2 (sel == 2) is + UNSAT under that constraint and must be dropped — without that + drop, every fork would reach exec_ret regardless of feasibility.""" + e = _engine(symex_index) + sel = z3.BitVec("sel", 32) + paths = e.explore("si_switch_constrained", args=[sel]) + completed = _completed(paths) + + return_values = sorted(p.return_value for p in completed) + # Cases 0 and 1 are feasible; case 2 is not (sel == 2 conflicts with + # the prior `sel <= 1`). 
The default path is feasible because + # `sel <= 1` admits negative values that aren't 0 or 1. + assert 100 in return_values, \ + f"case 0 path must survive, got {return_values}" + assert 200 in return_values, \ + f"case 1 path must survive, got {return_values}" + assert 300 not in return_values, \ + f"case 2 must be dropped as infeasible, got {return_values}" diff --git a/tests/symex/test_phase2.py b/tests/symex/test_phase2.py index cd78b694a..6fce6cb5d 100644 --- a/tests/symex/test_phase2.py +++ b/tests/symex/test_phase2.py @@ -193,8 +193,8 @@ def test_p2_6_intercept_indirect_call_resolution(index): fired = [] @engine.intercept.indirect_call - def hook(ctx, next_hook): - fired.append(True) + def hook(ctx, target_addr, next_hook): + fired.append(target_addr) return 4242 # sentinel return paths = engine.explore("test_function_calls") @@ -353,20 +353,16 @@ def watch(ctx, **payload): # --- P2.13 --------------------------------------------------------------- -def test_p2_13_observer_exception_does_not_corrupt_path(index): +def test_p2_13_observer_exception_propagates(index): engine = SymExEngine(index) - @engine.observe.memory_read - def boom(ctx, **payload): + @engine.observe.instruction + def boom(ctx, inst, **_): raise RuntimeError("test bug") - paths = engine.explore("symbolic_test_add_i32", args=[2, 3]) - # Path completes normally despite observer raising. 
- assert paths[0].return_value == 5 - errors = [e for e in paths[0].events if e.get("kind") == - "observer_error"] - assert errors, "observer_error not recorded on path.events" - assert "test bug" in errors[0]["error"] + import pytest + with pytest.raises(RuntimeError, match="test bug"): + engine.explore("symbolic_test_add_i32", args=[2, 3]) # --- P2.14 --------------------------------------------------------------- diff --git a/tests/symex/test_phase8c.py b/tests/symex/test_phase8c.py index 6cb281df1..60a72297f 100644 --- a/tests/symex/test_phase8c.py +++ b/tests/symex/test_phase8c.py @@ -70,5 +70,8 @@ def test_p8c_1_shadow_roundtrips_symbolic_write(index): assert _is_z3(got), \ f"expected z3 expression from shadow read; got {type(got).__name__}" - assert got is sym or got.eq(sym), \ + # The byte-granular shadow reconstructs via Concat(Extract(...)); z3.simplify + # collapses that back to the original variable, so structural eq holds. + import z3 as _z3 + assert _z3.simplify(got).eq(_z3.simplify(sym)), \ f"shadow returned a different expression: {got!r} vs {sym!r}" diff --git a/tests/symex/test_phase9.py b/tests/symex/test_phase9.py index 84106ca77..90a16df4c 100644 --- a/tests/symex/test_phase9.py +++ b/tests/symex/test_phase9.py @@ -296,9 +296,9 @@ def test_p9_8_indirect_call_target_kind_concrete(index): symbolic_count = [0] @engine.intercept.indirect_call(target_kind="concrete") - def on_concrete(ctx, next_hook): + def on_concrete(ctx, target_addr, next_hook): concrete_count[0] += 1 - return next_hook(ctx) + return next_hook(ctx, target_addr) @engine.intercept.indirect_call(target_kind="symbolic") def on_symbolic(ctx, next_hook):