diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 34f3bff413..bdbd940a05 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,7 @@ defaults: jobs: build: strategy: + fail-fast: false matrix: os: [ubuntu-latest] compiler: [g++, clang++] diff --git a/CMakeLists.txt b/CMakeLists.txt index e55a57852a..ca0075543c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,7 +183,7 @@ endif() ############## if (NOT MSVC AND ${ENABLE_OPENBLAS}) - set(OPENBLAS_VERSION 0.3.19) + set(OPENBLAS_VERSION 0.3.33) set(OPENBLAS_LIB openblas) set(OPENBLAS_DEFAULT_DIR "${TOOLS_DIR}/OpenBLAS-${OPENBLAS_VERSION}") @@ -209,6 +209,24 @@ if (NOT MSVC AND ${ENABLE_OPENBLAS}) target_include_directories(${OPENBLAS_LIB} INTERFACE ${OPENBLAS_DIR}/installed/include) endif() +######### +## GMP ## +######### + +find_library(GMP_DIR gmp) +set(GMP_VERSION 6.3.0) + +if(NOT GMP_DIR) + message("Can't find GMP, installing. If GMP is installed please use the GMP_DIR parameter to pass the path") + set(GMP_DIR "${TOOLS_DIR}/GMP") + execute_process(COMMAND ${TOOLS_DIR}/download_gmp.sh ${GMP_VERSION}) + find_library(GMP_DIR gmp) + if(NOT GMP_DIR) + message("Could not install GMP, try installing manually. If GMP is installed please use the GMP_DIR parameter to pass the path") + endif() +endif() +list(APPEND LIBS ${GMP_DIR}) + ########### ## Build ## ########### @@ -314,10 +332,9 @@ find_package(Threads REQUIRED) list(APPEND LIBS Threads::Threads) if (BUILD_STATIC_MARABOU) - # build a static library - target_link_libraries(${MARABOU_LIB} ${LIBS} -static) + target_link_libraries(${MARABOU_LIB} ${LIBS} -static) else() - target_link_libraries(${MARABOU_LIB} ${LIBS}) + target_link_libraries(${MARABOU_LIB} ${LIBS}) endif() target_include_directories(${MARABOU_LIB} PRIVATE ${LIBS_INCLUDES}) diff --git a/regress/regress1/CMakeLists.txt b/regress/regress1/CMakeLists.txt index 81c936da88..aa8e7e943b 100644 --- a/regress/regress1/CMakeLists.txt +++ b/regress/regress1/CMakeLists.txt @@ -159,7 +159,6 @@ marabou_add_input_query_test(1 ACASXU_abstest1.ipq unsat "--prove-unsat" "ipq") marabou_add_input_query_test(1 ACASXU_abstest2.ipq unsat "--prove-unsat" "ipq") # ReLU ad Max -marabou_add_input_query_test(1 ACASXU_maxtest1.ipq unsat "--prove-unsat" "ipq") marabou_add_input_query_test(1 ACASXU_maxtest2.ipq unsat "--prove-unsat" "ipq") # Sign diff --git a/src/common/MString.cpp b/src/common/MString.cpp index 15ac5a5313..35ede59ac2 100644 --- a/src/common/MString.cpp +++ b/src/common/MString.cpp @@ -195,10 +195,10 @@ String String::trimZerosFromRight() const } if ( _super[lastNonZero] == '.' ) - --lastNonZero; + ++lastNonZero; if ( lastNonZero < 0 ) - return "0"; + return "0.0"; return substring( 0, lastNonZero + 1 ); } diff --git a/src/configuration/GlobalConfiguration.cpp b/src/configuration/GlobalConfiguration.cpp index 87bcfcbbd4..4350b85b50 100644 --- a/src/configuration/GlobalConfiguration.cpp +++ b/src/configuration/GlobalConfiguration.cpp @@ -113,8 +113,12 @@ const unsigned GlobalConfiguration::POLARITY_CANDIDATES_THRESHOLD = 5; const unsigned GlobalConfiguration::DNC_DEPTH_THRESHOLD = 5; const double GlobalConfiguration::MINIMAL_COEFFICIENT_FOR_TIGHTENING = 0.01; -const double GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE = 0.000001; +const double GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE = 0.00000001; const bool GlobalConfiguration::WRITE_JSON_PROOF = false; +bool GlobalConfiguration::WRITE_ALETHE_PROOF = false; +const bool GlobalConfiguration::ALETHE_ELABORATE_TERMS = true; +const bool GlobalConfiguration::DEDICATED_ALETHE_RULE = false; + const unsigned GlobalConfiguration::BACKWARD_BOUND_PROPAGATION_DEPTH = 3; const unsigned GlobalConfiguration::MAX_ROUNDS_OF_BACKWARD_ANALYSIS = 10; diff --git a/src/configuration/GlobalConfiguration.h b/src/configuration/GlobalConfiguration.h index 71f2fd3b2d..9f30c7f75a 100644 --- a/src/configuration/GlobalConfiguration.h +++ b/src/configuration/GlobalConfiguration.h @@ -255,6 +255,18 @@ class GlobalConfiguration */ static const bool WRITE_JSON_PROOF; + /* Denote whether proofs should be written as a Alethe file + */ + static bool WRITE_ALETHE_PROOF; + + /* Add terms to allow Alethe elaboration + */ + static const bool ALETHE_ELABORATE_TERMS; + + /* Denote whether to use bounded_farkas proof rule (supported by Carcara only) + */ + static const bool DEDICATED_ALETHE_RULE; + /* How many layers after the current layer do we encode in backward analysis. */ static const unsigned BACKWARD_BOUND_PROPAGATION_DEPTH; diff --git a/src/engine/BoundManager.cpp b/src/engine/BoundManager.cpp index ba1fb9a0d6..ba56a68fed 100644 --- a/src/engine/BoundManager.cpp +++ b/src/engine/BoundManager.cpp @@ -464,15 +464,16 @@ bool BoundManager::addLemmaExplanationAndTightenBound( unsigned var, else throw MarabouError( MarabouError::FEATURE_NOT_YET_SUPPORTED ); - std::shared_ptr PLCExpl = std::make_shared( causingVars, - var, - value, - causingVarBound, - affectedVarBound, - allExplanations, - constraint.getType(), - minTargetBound ); - + std::shared_ptr PLCExpl = + std::make_shared( causingVars, + var, + value, + causingVarBound, + affectedVarBound, + allExplanations, + constraint.getType(), + minTargetBound, + _engine->getNumOfLemmas() + 1 ); _engine->getUNSATCertificateCurrentPointer()->addPLCLemma( PLCExpl ); diff --git a/src/engine/Engine.cpp b/src/engine/Engine.cpp index 0ae405211a..871e8162d6 100644 --- a/src/engine/Engine.cpp +++ b/src/engine/Engine.cpp @@ -72,6 +72,7 @@ Engine::Engine() , _produceUNSATProofs( Options::get()->getBool( Options::PRODUCE_PROOFS ) ) , _groundBoundManager( _context ) , _UNSATCertificate( NULL ) + , _aletheWriter( NULL ) { _searchTreeHandler.setStatistics( &_statistics ); _tableau->setStatistics( &_statistics ); @@ -110,6 +111,12 @@ Engine::~Engine() if ( _produceUNSATProofs && _UNSATCertificateCurrentPointer ) _UNSATCertificateCurrentPointer->deleteSelf(); + + if ( GlobalConfiguration::WRITE_ALETHE_PROOF && _aletheWriter ) + { + delete _aletheWriter; + _aletheWriter = NULL; + } } void Engine::setVerbosity( unsigned verbosity ) @@ -1438,13 +1445,24 @@ bool Engine::processInputQuery( const IQuery &inputQuery, bool preprocess ) if ( !UNSATCertificateUtils::getSupportedActivations().exists( plConstraint->getType() ) ) { - _produceUNSATProofs = false; Options::get()->setBool( Options::PRODUCE_PROOFS, false ); String activationType = plConstraint->serializeToString().tokenize( "," ).back(); - printf( - "Turning off proof production since activation %s is not yet supported\n", - activationType.ascii() ); + printf( "Activation %s is not yet supported in proof production\n", + activationType.ascii() ); + throw MarabouError( MarabouError::FEATURE_NOT_YET_SUPPORTED ); + } + else if ( !AletheProofWriter::getSupportedActivations().exists( + plConstraint->getType() ) ) + { + GlobalConfiguration::WRITE_ALETHE_PROOF = false; + + String activationType = + plConstraint->serializeToString().tokenize( "," ).back(); + printf( "Turning off proof production in Alethe since activation %s is not yet " + "supported." + " Falling back to produce regular proofs\n", + activationType.ascii() ); break; } } @@ -1480,11 +1498,7 @@ bool Engine::processInputQuery( const IQuery &inputQuery, bool preprocess ) if ( _produceUNSATProofs ) { - _UNSATCertificate = new UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit() ); - _UNSATCertificateCurrentPointer->set( _UNSATCertificate ); - _UNSATCertificate->setVisited(); _groundBoundManager.initialize( n ); - for ( unsigned i = 0; i < n; ++i ) { _groundBoundManager.addGroundBound( @@ -1492,6 +1506,23 @@ bool Engine::processInputQuery( const IQuery &inputQuery, bool preprocess ) _groundBoundManager.addGroundBound( i, _preprocessedQuery->getLowerBound( i ), Tightening::LB, false ); } + + if ( _produceUNSATProofs && GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter = new AletheProofWriter( + _tableau->getM(), + _groundBoundManager.getAllGroundBounds( Tightening::UB ), + _groundBoundManager.getAllGroundBounds( Tightening::LB ), + _groundBoundManager, + _tableau->getSparseA(), + _plConstraints ); + + unsigned id = + GlobalConfiguration::WRITE_ALETHE_PROOF ? _aletheWriter->assignId() : 0; + + _UNSATCertificate = + new UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit(), 0, id ); + _UNSATCertificateCurrentPointer->set( _UNSATCertificate ); + _UNSATCertificate->setVisited(); } } else @@ -1863,6 +1894,10 @@ void Engine::restoreState( const EngineState &state ) // Reset the violation counts in the Search Tree handler _searchTreeHandler.resetSplitConditions(); + + if ( _produceUNSATProofs && GlobalConfiguration::WRITE_ALETHE_PROOF && + state._tableauStateStorageLevel == TableauStateStorageLevel::STORE_ENTIRE_TABLEAU_STATE ) + _aletheWriter->setInitialTableau( _tableau->getSparseA() ); } void Engine::setNumPlConstraintsDisabledByValidSplits( unsigned numConstraints ) @@ -3430,7 +3465,8 @@ void Engine::explainSimplexFailure() ( **_UNSATCertificateCurrentPointer ).makeLeaf(); - if ( GlobalConfiguration::ANALYZE_PROOF_DEPENDENCIES ) + if ( GlobalConfiguration::ANALYZE_PROOF_DEPENDENCIES || + GlobalConfiguration::WRITE_ALETHE_PROOF ) { SparseUnsortedList sparseContradictionToAnalyse = SparseUnsortedList(); leafContradictionVec.empty() @@ -3438,8 +3474,12 @@ void Engine::explainSimplexFailure() : sparseContradictionToAnalyse.initialize( leafContradictionVec.data(), leafContradictionVec.size() ); - analyseExplanationDependencies( - sparseContradictionToAnalyse, _groundBoundManager.getCounter(), -1, true, 0 ); + if ( GlobalConfiguration::ANALYZE_PROOF_DEPENDENCIES ) + analyseExplanationDependencies( + sparseContradictionToAnalyse, _groundBoundManager.getCounter(), -1, true, 0 ); + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter->writeContradiction( sparseContradictionToAnalyse, + _UNSATCertificateCurrentPointer->get() ); } } @@ -3706,6 +3746,9 @@ bool Engine::certifyUNSATCertificate() } } _UNSATCertificateCurrentPointer->get()->deleteUnusedLemmas(); + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter->writeChildrenConclusion( _UNSATCertificateCurrentPointer->get() ); + struct timespec certificationStart = TimeUtils::sampleMicro(); _precisionRestorer.restoreInitialEngineState( *this ); @@ -3714,38 +3757,59 @@ bool Engine::certifyUNSATCertificate() for ( unsigned i = 0; i < _tableau->getN(); ++i ) { - groundUpperBounds[i] = _groundBoundManager.getGroundBound( i, Tightening::UB ); - groundLowerBounds[i] = _groundBoundManager.getGroundBound( i, Tightening::LB ); + groundUpperBounds[i] = _preprocessedQuery->getUpperBound( i ); + groundLowerBounds[i] = _preprocessedQuery->getLowerBound( i ); } + bool certificationSucceeded = false; - if ( GlobalConfiguration::WRITE_JSON_PROOF ) + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) { - File file( JsonWriter::PROOF_FILENAME ); - JsonWriter::writeProofToJson( _UNSATCertificate, - _tableau->getM(), - _tableau->getSparseA(), - groundUpperBounds, - groundLowerBounds, - _plConstraints, - file ); - } + String pref; + if ( Options::get()->getString( Options::INPUT_FILE_PATH ).length() > 0 ) + { + pref = Options::get()->getString( Options::INPUT_FILE_PATH ).tokenize( "/" ).back() + + Options::get()->getString( Options::PROPERTY_FILE_PATH ).tokenize( "/" ).back(); + } + else + { + ASSERT( Options::get()->getString( Options::INPUT_QUERY_FILE_PATH ).length() > 0 ); + pref = + Options::get()->getString( Options::INPUT_QUERY_FILE_PATH ).tokenize( "/" ).back(); + } - Checker unsatCertificateChecker( _UNSATCertificate, - _tableau->getM(), - _tableau->getSparseA(), - groundUpperBounds, - groundLowerBounds, - _plConstraints ); - bool certificationSucceeded = unsatCertificateChecker.check(); + File proofFile( pref + ".smt2.alethe" ); + SmtLibWriter::writeToSmtLibFile( pref + ".smt2", + _tableau->getM(), + _tableau->getN(), + groundUpperBounds, + groundLowerBounds, + _tableau->getSparseA(), + List(), + _plConstraints ); - _statistics.setLongAttribute( - Statistics::TOTAL_CERTIFICATION_TIME, - TimeUtils::timePassed( certificationStart, TimeUtils::sampleMicro() ) ); - printf( "Certification time: " ); - _statistics.printLongAttributeAsTime( - _statistics.getLongAttribute( Statistics::TOTAL_CERTIFICATION_TIME ) ); - if ( certificationSucceeded ) + _aletheWriter->writeInstanceToFile( proofFile ); + printf( "proof written to Alethe format and needs to be certified separately\n" ); + certificationSucceeded = true; + } + else + { + Checker unsatCertificateChecker( _UNSATCertificate, + _tableau->getM(), + _tableau->getSparseA(), + groundUpperBounds, + groundLowerBounds, + _plConstraints ); + certificationSucceeded = unsatCertificateChecker.check(); + _statistics.setLongAttribute( + Statistics::TOTAL_CERTIFICATION_TIME, + TimeUtils::timePassed( certificationStart, TimeUtils::sampleMicro() ) ); + printf( "Certification time: " ); + _statistics.printLongAttributeAsTime( + _statistics.getLongAttribute( Statistics::TOTAL_CERTIFICATION_TIME ) ); + } + + if ( certificationSucceeded && !GlobalConfiguration::WRITE_ALETHE_PROOF ) { printf( "Certified\n" ); _statistics.incUnsignedAttribute( Statistics::CERTIFIED_UNSAT ); @@ -3753,7 +3817,7 @@ bool Engine::certifyUNSATCertificate() printf( "Some leaves were delegated and need to be certified separately by an SMT " "solver\n" ); } - else + else if ( !GlobalConfiguration::WRITE_ALETHE_PROOF ) printf( "Error certifying UNSAT certificate\n" ); DEBUG( { @@ -3783,6 +3847,9 @@ void Engine::markLeafToDelegate() if ( !currentUnsatCertificateNode->getChildren().empty() ) currentUnsatCertificateNode->makeLeaf(); + + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter->writeDelegatedLeaf( _UNSATCertificateCurrentPointer->get() ); } const Vector Engine::computeContradiction( unsigned infeasibleVar ) const @@ -3930,6 +3997,9 @@ Engine::analyseExplanationDependencies( const SparseUnsortedList &explanation, entry->lemma->getMinTargetBound() ); } + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter->writeLemma( entry ); + return { entry }; } @@ -3969,7 +4039,8 @@ Engine::analyseExplanationDependencies( const SparseUnsortedList &explanation, // Iterate through all deduced bounds, check which participated in the explanation for ( unsigned var = 0; var < linearCombination.size(); ++var ) { - if ( !FloatUtils::isZero( linearCombination[var] ) ) + if ( !FloatUtils::isZero( linearCombination[var], + GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE ) ) { Tightening::BoundType btype = ( ( linearCombination[var] > 0 ) && isUpper ) || ( ( linearCombination[var] < 0 ) && !isUpper ) @@ -4053,8 +4124,27 @@ Engine::analyseExplanationDependencies( const SparseUnsortedList &explanation, std::advance( it, 1 ); } + + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter->writeLemma( entry ); } } return entries; } + +AletheProofWriter *Engine::getAletheWriter() const +{ + return _aletheWriter; +} + +unsigned Engine::getNumOfLemmas() const +{ + return _statistics.getUnsignedAttribute( Statistics::NUM_LEMMAS ); +} + +void Engine::deleteProofIfExists() const +{ + if ( _produceUNSATProofs && GlobalConfiguration::WRITE_ALETHE_PROOF ) + _aletheWriter->deleteProof(); +} \ No newline at end of file diff --git a/src/engine/Engine.h b/src/engine/Engine.h index e13953979d..8adcd53a26 100644 --- a/src/engine/Engine.h +++ b/src/engine/Engine.h @@ -16,6 +16,7 @@ #ifndef __Engine_h__ #define __Engine_h__ +#include "AletheProofWriter.h" #include "AutoCostFunctionManager.h" #include "AutoProjectedSteepestEdge.h" #include "AutoRowBoundTightener.h" @@ -313,6 +314,22 @@ class Engine */ const List *getPiecewiseLinearConstraints() const override; + /* + Get the Alethe proof writer object + */ + AletheProofWriter *getAletheWriter() const override; + + /* + Delete the data stored in the Alethe proof + */ + void deleteProofIfExists() const override; + + /* + Get the number of PLC lemmas learned so far + */ + unsigned getNumOfLemmas() const override; + + private: enum BasisRestorationRequired { RESTORATION_NOT_NEEDED = 0, @@ -849,6 +866,7 @@ class Engine GroundBoundManager _groundBoundManager; UnsatCertificateNode *_UNSATCertificate; CVC4::context::CDO *_UNSATCertificateCurrentPointer; + AletheProofWriter *_aletheWriter; /* Returns true iff there is a variable with bounds that can explain infeasibility of the tableau diff --git a/src/engine/IEngine.h b/src/engine/IEngine.h index 2857818e68..c80f6cbc52 100644 --- a/src/engine/IEngine.h +++ b/src/engine/IEngine.h @@ -31,6 +31,7 @@ #undef ERROR #endif +class AletheProofWriter; class EngineState; class Equation; class PiecewiseLinearCaseSplit; @@ -196,7 +197,25 @@ class IEngine virtual std::shared_ptr setGroundBoundFromLemma( const std::shared_ptr lemma, bool isPhaseFixing ) = 0; + /* + Get the list of PLC registered in the engine + */ virtual const List *getPiecewiseLinearConstraints() const = 0; + + /* + Get the Alethe proof writer object + */ + virtual AletheProofWriter *getAletheWriter() const = 0; + + /* + Delete the data stored in the Alethe proof + */ + virtual void deleteProofIfExists() const = 0; + + /* + Get the number of PLC lemmas learned so far + */ + virtual unsigned getNumOfLemmas() const = 0; }; #endif // __IEngine_h__ diff --git a/src/engine/Marabou.cpp b/src/engine/Marabou.cpp index feed4d648b..e41d5c7308 100644 --- a/src/engine/Marabou.cpp +++ b/src/engine/Marabou.cpp @@ -224,6 +224,8 @@ void Marabou::solveQuery() _engine->solve( timeoutInSeconds ); if ( _engine->shouldProduceProofs() && _engine->getExitCode() == Engine::UNSAT ) _engine->certifyUNSATCertificate(); + else if ( _engine->shouldProduceProofs() ) + _engine->deleteProofIfExists(); } if ( _engine->getExitCode() == Engine::UNKNOWN ) diff --git a/src/engine/ReluConstraint.cpp b/src/engine/ReluConstraint.cpp index 139e8563e4..6ecf798d81 100644 --- a/src/engine/ReluConstraint.cpp +++ b/src/engine/ReluConstraint.cpp @@ -134,7 +134,12 @@ void ReluConstraint::checkIfLowerBoundUpdateFixesPhase( unsigned variable, doubl void ReluConstraint::checkIfUpperBoundUpdateFixesPhase( unsigned variable, double bound ) { - if ( ( variable == _f || variable == _b ) && !FloatUtils::isPositive( bound ) ) + bool proofs = _boundManager && _boundManager->shouldProduceProofs(); + + // A stricter policy when proving UNSAT + if ( ( variable == _f || variable == _b ) && + ( ( proofs && FloatUtils::isNegative( bound ) ) || + ( !proofs && !FloatUtils::isPositive( bound ) ) ) ) setPhaseStatus( RELU_PHASE_INACTIVE ); if ( _auxVarInUse && variable == _aux && FloatUtils::isZero( bound ) ) @@ -168,12 +173,19 @@ void ReluConstraint::notifyLowerBound( unsigned variable, double newBound ) createTighteningRow(); // A positive lower bound is always propagated between f and b - if ( ( variable == _f || variable == _b ) && bound > 0 ) + if ( ( variable == _f || variable == _b ) && FloatUtils::isPositive( bound ) ) { // If we're in the active phase, aux should be 0 if ( proofs && _auxVarInUse ) _boundManager->addLemmaExplanationAndTightenBound( - _aux, 0, Tightening::UB, { variable }, Tightening::LB, *this, true, 0 ); + _aux, + 0, + Tightening::UB, + { variable }, + Tightening::LB, + *this, + true, + GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE ); else if ( !proofs && _auxVarInUse ) _boundManager->tightenUpperBound( _aux, 0 ); @@ -194,11 +206,18 @@ void ReluConstraint::notifyLowerBound( unsigned variable, double newBound ) // A positive lower bound for aux means we're inactive: f is 0, b is // non-positive When inactive, b = -aux - else if ( _auxVarInUse && variable == _aux && bound > 0 ) + else if ( _auxVarInUse && variable == _aux && FloatUtils::isPositive( bound ) ) { if ( proofs ) _boundManager->addLemmaExplanationAndTightenBound( - _f, 0, Tightening::UB, { variable }, Tightening::LB, *this, true, 0 ); + _f, + 0, + Tightening::UB, + { variable }, + Tightening::LB, + *this, + true, + GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE ); else _boundManager->tightenUpperBound( _f, 0 ); @@ -207,7 +226,7 @@ void ReluConstraint::notifyLowerBound( unsigned variable, double newBound ) } // A negative lower bound for b could tighten aux's upper bound - else if ( _auxVarInUse && variable == _b && bound < 0 ) + else if ( _auxVarInUse && variable == _b && FloatUtils::isNegative( bound ) ) { if ( proofs ) { @@ -230,14 +249,8 @@ void ReluConstraint::notifyLowerBound( unsigned variable, double newBound ) // Also, if for some reason we only know a negative lower bound for // f, we attempt to tighten it to 0 - else if ( bound < 0 && variable == _f ) - { - if ( proofs ) - _boundManager->addLemmaExplanationAndTightenBound( - _f, 0, Tightening::LB, { variable }, Tightening::LB, *this, false, 0 ); - else - _boundManager->tightenLowerBound( _f, 0 ); - } + else if ( bound < 0 && variable == _f && !proofs ) + _boundManager->tightenLowerBound( _f, 0 ); } } } @@ -284,7 +297,7 @@ void ReluConstraint::notifyUpperBound( unsigned variable, double newBound ) { variable }, Tightening::UB, *this, - true, + false, 0 ); // Bound cannot be negative if ReLU is inactive if ( FloatUtils::isNegative( bound ) ) @@ -296,12 +309,21 @@ void ReluConstraint::notifyUpperBound( unsigned variable, double newBound ) } else if ( variable == _b ) { - if ( !FloatUtils::isPositive( bound ) ) + if ( ( proofs && FloatUtils::isNegative( bound ) ) || + ( !proofs && !FloatUtils::isPositive( bound ) ) ) { // If b has a non-positive upper bound, f's upper bound is 0 if ( proofs ) _boundManager->addLemmaExplanationAndTightenBound( - _f, 0, Tightening::UB, { variable }, Tightening::UB, *this, true, 0 ); + _f, + 0, + Tightening::UB, + { variable }, + Tightening::UB, + *this, + true, + FloatUtils::max( + bound, -GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE ) ); else _boundManager->tightenUpperBound( _f, 0 ); diff --git a/src/engine/ReluConstraint.h b/src/engine/ReluConstraint.h index 5821813e3a..10a880db30 100644 --- a/src/engine/ReluConstraint.h +++ b/src/engine/ReluConstraint.h @@ -269,6 +269,11 @@ class ReluConstraint : public PiecewiseLinearConstraint const List getNativeAuxVars() const override; + /* + Assign a variable as an aux variable by the tableau, related to some existing aux variable. + */ + void addTableauAuxVar( unsigned tableauAuxVar, unsigned constraintAuxVar ) override; + private: unsigned _b, _f; NLR::NetworkLevelReasoner *_networkLevelReasoner; @@ -300,11 +305,6 @@ class ReluConstraint : public PiecewiseLinearConstraint Stored in _tighteningRow */ void createTighteningRow(); - - /* - Assign a variable as an aux variable by the tableau, related to some existing aux variable. - */ - void addTableauAuxVar( unsigned tableauAuxVar, unsigned constraintAuxVar ) override; }; #endif // __ReluConstraint_h__ diff --git a/src/engine/SearchTreeHandler.cpp b/src/engine/SearchTreeHandler.cpp index 9fc319d5a7..98842a1973 100644 --- a/src/engine/SearchTreeHandler.cpp +++ b/src/engine/SearchTreeHandler.cpp @@ -15,6 +15,7 @@ #include "SearchTreeHandler.h" +#include "AletheProofWriter.h" #include "Debug.h" #include "EngineState.h" #include "FloatUtils.h" @@ -177,7 +178,16 @@ void SearchTreeHandler::performSplit() // Create children for UNSATCertificate current node, and assign a split to each of them ASSERT( certificateNode ); for ( PiecewiseLinearCaseSplit &childSplit : splits ) - new UnsatCertificateNode( certificateNode, childSplit ); + { + unsigned id = GlobalConfiguration::WRITE_ALETHE_PROOF + ? _engine->getAletheWriter()->assignId() + : 0; + + new UnsatCertificateNode( certificateNode, + childSplit, + _constraintForSplitting->getTableauAuxVars().front(), + id ); + } } SearchTreeStackEntry *stackEntry = new SearchTreeStackEntry; @@ -299,6 +309,8 @@ bool SearchTreeHandler::popSplit() UnsatCertificateNode *certificateNode = _engine->getUNSATCertificateCurrentPointer(); certificateNode->deleteUnusedLemmas(); + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _engine->getAletheWriter()->writeChildrenConclusion( certificateNode ); _engine->setUNSATCertificateCurrentPointer( certificateNode->getParent() ); } @@ -314,9 +326,16 @@ bool SearchTreeHandler::popSplit() } SearchTreeStackEntry *stackEntry = _stack.back(); - - if ( _engine->shouldProduceProofs() && _engine->getUNSATCertificateCurrentPointer() ) - _engine->getUNSATCertificateCurrentPointer()->deleteUnusedLemmas(); + if ( _engine->shouldProduceProofs() ) + { + UnsatCertificateNode *certificateNode = _engine->getUNSATCertificateCurrentPointer(); + if ( certificateNode ) + { + certificateNode->deleteUnusedLemmas(); + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _engine->getAletheWriter()->writeChildrenConclusion( certificateNode ); + } + } popContext(); _engine->postContextPopHook(); @@ -342,6 +361,9 @@ bool SearchTreeHandler::popSplit() while ( !splitChild ) { certificateNode->deleteUnusedLemmas(); + if ( GlobalConfiguration::WRITE_ALETHE_PROOF ) + _engine->getAletheWriter()->writeChildrenConclusion( certificateNode ); + certificateNode = certificateNode->getParent(); ASSERT( certificateNode ); splitChild = certificateNode->getChildBySplit( *split ); diff --git a/src/engine/tests/MockEngine.h b/src/engine/tests/MockEngine.h index 8d5722810d..5583777434 100644 --- a/src/engine/tests/MockEngine.h +++ b/src/engine/tests/MockEngine.h @@ -298,6 +298,20 @@ class MockEngine : public IEngine } void incNumOfLemmas() override{}; + + AletheProofWriter *getAletheWriter() const override + { + return NULL; + } + + void deleteProofIfExists() const override + { + } + + unsigned getNumOfLemmas() const override + { + return 0; + } }; #endif // __MockEngine_h__ diff --git a/src/engine/tests/Test_ReluConstraint.h b/src/engine/tests/Test_ReluConstraint.h index f1fb4e9a0d..294e3d2f01 100644 --- a/src/engine/tests/Test_ReluConstraint.h +++ b/src/engine/tests/Test_ReluConstraint.h @@ -468,14 +468,14 @@ class ReluConstraintTestSuite : public CxxTest::TestSuite { ReluConstraint relu( b, f ); TS_ASSERT( !relu.phaseFixed() ); - relu.notifyUpperBound( b, 0.0 ); + relu.notifyUpperBound( b, -0.001 ); TS_ASSERT( relu.phaseFixed() ); } { ReluConstraint relu( b, f ); TS_ASSERT( !relu.phaseFixed() ); - relu.notifyUpperBound( f, 0.0 ); + relu.notifyUpperBound( f, -0.001 ); TS_ASSERT( relu.phaseFixed() ); } diff --git a/src/proofs/AletheProofWriter.cpp b/src/proofs/AletheProofWriter.cpp new file mode 100644 index 0000000000..ef07819779 --- /dev/null +++ b/src/proofs/AletheProofWriter.cpp @@ -0,0 +1,865 @@ +/********************* */ +/*! \file AletheProofWriter.cpp + ** \verbatim + ** Top contributors (to current version): + ** Omri Isac, Guy Katz + ** This file is part of the Marabou project. + ** Copyright (c) 2017-2026 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** [[ Add lengthier description here ]] + **/ + +#include "AletheProofWriter.h" + +AletheProofWriter::AletheProofWriter( unsigned explanationSize, + const Vector &upperBounds, + const Vector &lowerBounds, + const GroundBoundManager &groundBoundManager, + const SparseMatrix *tableau, + const List &problemConstraints ) + : _initialTableau( tableau ) + , _initialUpperBounds( upperBounds ) + , _initialLowerBounds( lowerBounds ) + , _groundBoundManager( groundBoundManager ) + , _plc( problemConstraints.begin(), problemConstraints.end() ) + , _n( upperBounds.size() ) + , _m( explanationSize ) + , _stepCounter( 1 ) + , _varToPlc() + , _idToSplits() + , _nodeToSplits() +{ + for ( const auto &plc : problemConstraints ) + { + for ( const auto var : plc->getParticipatingVariables() ) + _varToPlc.insert( var, plc ); + + _varToPlc.insert( plc->getTableauAuxVars().front(), plc ); + } + + // Write only necessary lines upon initialization + writeTableauAssumptions(); +} + +void AletheProofWriter::writeTableauAssumptions() +{ + ASSERT( _assumptions.empty() ); + + // Import SMT assertions + List smtLib = SmtLibWriter::convertToSmtLib( + _m, + _n, + _initialUpperBounds, + _initialLowerBounds, + _initialTableau, + List(), + List( _plc.begin(), _plc.end() ) ); + + unsigned counter = 0; + String assumptionTitle; + + // Convert assertions to assumptions + for ( auto line : smtLib ) + { + // Ignore header and footer + if ( line.contains( "declare" ) || line.contains( "set-logic" ) || + line.contains( "check" ) || line.contains( "exit" ) || line.contains( "<=" ) || + line.contains( ">=" ) ) + continue; + + ASSERT( line.contains( "=" ) ); + line = line.substring( 0, line.length() - 2 ); + assumptionTitle = "e" + std::to_string( counter ) + "(!"; + + line.replace( "assert ", String( "assume " ) + assumptionTitle ); + line += ":named e" + std::to_string( counter ) + "))\n"; + ++counter; + + _assumptions.append( line ); + _tableauAssumptions.append( line ); + } +} + +void AletheProofWriter::writeBoundAssumptions() +{ + for ( unsigned i = 0; i < _n; ++i ) + { + mpq_class upperBound( _initialUpperBounds[i] ); + mpq_class lowerBound( _initialLowerBounds[i] ); + + String upperBoundString; + if ( upperBound.get_den().get_str() == "1" ) + upperBoundString = upperBound.get_str() + ".0"; + else + upperBoundString = upperBound.get_str(); + + String lowerBoundString; + if ( lowerBound.get_den().get_str() == "1" ) + lowerBoundString = lowerBound.get_str() + ".0"; + else + lowerBoundString = lowerBound.get_str(); + + String s = std::to_string( i ); + String upper = String( "(assume u" ) + s + "(!(<= x" + s + " " + upperBoundString + + "):named u" + s + "))\n"; + String lower = String( "(assume l" ) + s + "(!(>= x" + s + " " + lowerBoundString + + "):named l" + s + "))\n"; + _assumptions.append( { upper, lower } ); + } +} + +void AletheProofWriter::writePLCAssumption() +{ + List plcAssumptions = List(); + List plcSplits = List(); + + for ( const auto &plc : _plc ) + { + List splitsInFixedOrder = {}; + // TODO support additional types + int constraintInt = plc->getTableauAuxVars().front(); + + String constraintNum = std::to_string( constraintInt ); + String plcAssumption = ""; + + if ( plc->getType() == RELU ) + { + splitsInFixedOrder.append( { plc->getCaseSplit( RELU_PHASE_ACTIVE ), + plc->getCaseSplit( RELU_PHASE_INACTIVE ) } ); + + ReluConstraint *relu = (ReluConstraint *)plc; + String f = std::to_string( relu->getF() ); + String b = std::to_string( relu->getB() ); + String aux = std::to_string( relu->getAux() ); + String counterpartAux = std::to_string( plc->getTableauAuxVars().front() ); + + // Use common SMT representation + String bEqualsF = String( "(= x" ) + b + " x" + f + ")"; + + plcAssumption += String( "(assume relu" ) + constraintNum + " (ite (!(<= 0.0 x" + b + + "):named a" + constraintNum + ")" + bEqualsF + "(<= x" + f + + " 0.0)))\n"; + plcAssumptions.append( plcAssumption ); + + // Eagerly write basic ite resolution steps used in lemmas + String ite1 = String( "(step ri1_" ) + constraintNum + " (cl (<= 0.0 x" + b + ")(<= x" + + f + " 0.0)):rule ite1 :premises(relu" + constraintNum + "))\n"; + String ite2 = String( "(step ri2_" ) + constraintNum + " (cl (not (<= 0.0 x" + b + + "))" + bEqualsF + "):rule ite2 :premises(relu" + constraintNum + "))\n"; + String tot = String( "(step _bt" ) + constraintNum + " (cl (or (<= x" + b + + " 0.0)(<= 0.0 x" + b + "))):rule la_totality)\n"; + tot += String( "(step bt" ) + constraintNum + " (cl (<= x" + b + " 0.0)(<= 0.0 x" + b + + ")):rule or :premises(_bt" + constraintNum + "))\n"; + + plcSplits.append( { ite1, ite2, tot } ); + unsigned identifierInt = relu->getTableauAuxVars().front(); + String tableauEq = "e" + std::to_string( identifierInt - ( _n - _m ) ); + String tableauLit = convertTableauAssumptionToClause( identifierInt - ( _n - _m ) ); + + // Eagerly write basic bound resolution steps used in lemmas + // Clauses used to derive one split bound from the other (for both ReLU possible phases) + String activeBound1 = String( "(step ab1_" ) + constraintNum + " (cl (not " + bEqualsF + + ")" + tableauLit + "(<= x" + aux + " 0.0)(not (>= x" + + counterpartAux + " 0.0))):rule la_generic :args(1 -1 1 1))\n"; + activeBound1 += String( "(step eq" ) + constraintNum + "_a0" + " (cl (not (<= 0.0 x" + + b + "))(<= x" + aux + " 0.0)):rule resolution :premises(ab1_" + + constraintNum + " ri2_" + constraintNum + " l" + counterpartAux + " " + + tableauEq + "))\n"; + + String activeBound2 = String( "(step ab2_" ) + constraintNum + " (cl (<= 0.0 x" + b + + ")" + tableauLit + "(not (<= x" + aux + " 0.0))(not (<= x" + + counterpartAux + " 0.0))(not (>= x" + f + + " 0.0))):rule la_generic :args(1 1 1 1 -1))\n"; + activeBound2 += String( "(step eq" ) + constraintNum + "_a1" + " (cl (<= 0.0 x" + b + + ")(not (<= x" + aux + " 0.0))):rule resolution :premises(ab2_" + + constraintNum + " " + tableauEq + " u" + counterpartAux + " l" + f + + "))\n"; + + String inactiveBound1 = String( "(step ib1_" ) + constraintNum + " (cl (not " + + bEqualsF + ")(not(<= x" + b + " 0.0))(<= x" + f + + " 0.0)):rule la_generic :args(1 1 1))\n"; + inactiveBound1 += String( "(step eq" ) + constraintNum + "_i0" + " (cl (not (<= x" + b + + " 0.0))(<= x" + f + " 0.0)):rule resolution :premises(ib1_" + + constraintNum + " ri1_" + constraintNum + " ri2_" + constraintNum + + "))\n"; + + String inactiveBound2 = String( "(step ib2_" ) + constraintNum + " (cl (not " + + bEqualsF + ")(<= x" + b + " 0.0)(not (<= x" + f + + " 0.0))):rule la_generic :args(-1 1 1))\n"; + inactiveBound2 += String( "(step eq" ) + constraintNum + "_i1" + " (cl (not (<= x" + f + + " 0.0))(<= x" + b + " 0.0)):rule resolution :premises(ib2_" + + constraintNum + " ri2_" + constraintNum + " bt" + constraintNum + + "))\n"; + + plcSplits.append( { activeBound1, activeBound2, inactiveBound1, inactiveBound2 } ); + } + } + + _assumptions.append( plcAssumptions ); + _assumptions.append( plcSplits ); +} + +void AletheProofWriter::writeContradiction( const SparseUnsortedList &contradiction, + UnsatCertificateNode *node ) +{ + String farkasArgs = ""; + String farkasClause = ""; + String farkasParticipants = ""; + String negatedSplitsClause = ""; + unsigned nodeId = node->getId(); + + // Collect all Farkas lemma information + farkasStrings( contradiction, + _groundBoundManager.getCounter(), + farkasArgs, + farkasClause, + farkasParticipants, + negatedSplitsClause, + -(int)nodeId, + true, + node ); + + farkasClause = String( "(cl " ) + farkasClause + ")"; + farkasArgs = String( "(" ) + farkasArgs + "))\n"; + + // Write la_generic\bounded_farkas rule, followed by the corresponding resolution + String ruleName = GlobalConfiguration::DEDICATED_ALETHE_RULE ? "bounded_farkas" : "la_generic"; + String laGeneric = String( "(step t" + std::to_string( nodeId ) ) + " " + farkasClause + + ":rule " + ruleName + " :args" + farkasArgs; + + String res = String( "(step r" + std::to_string( nodeId ) ) + " (cl " + negatedSplitsClause + + "):rule resolution :premises(t" + std::to_string( nodeId ) + " " + + farkasParticipants + "))\n"; + + _proof.append( { laGeneric, res } ); +} + +void AletheProofWriter::writeInstanceToFile( IFile &file ) +{ + file.open( File::MODE_WRITE_TRUNCATE ); + + // Gather and write all assumptions + writeBoundAssumptions(); + writePLCAssumption(); + for ( const String &s : _assumptions ) + file.write( s ); + + // Write whole proof + for ( const String &s : _proof ) + file.write( s ); + + file.close(); +} + +void AletheProofWriter::writeChildrenConclusion( const UnsatCertificateNode *node ) +{ + if ( !node->isValidNonLeaf() ) + return; + + // Collect children information + List childrenIndices = {}; + for ( const auto &child : node->getChildren() ) + childrenIndices.append( child->getId() ); + + ASSERT( node->isValidNonLeaf() ); + ASSERT( childrenIndices.size() == 2 ); + PiecewiseLinearCaseSplit firstChildSplit = node->getChildren().front()->getSplit(); + PiecewiseLinearCaseSplit secondChildSplit = node->getChildren().back()->getSplit(); + + List tighteningDeps = _nodeToSplits[node->getChildren().front()->getId()]; + tighteningDeps.append( _nodeToSplits[node->getChildren().back()->getId()] ); + List filteredTighteneings = {}; + Set phaseIdentifiers = {}; + List splitDeps = {}; + + // Detect which child corresponds to which phase + for ( const auto &tightening : tighteningDeps ) + { + if ( firstChildSplit.getBoundTightenings().exists( tightening ) || + secondChildSplit.getBoundTightenings().exists( tightening ) ) + continue; + + PiecewiseLinearConstraint *plc = _varToPlc[tightening._variable]; + + for ( const auto &caseSplit : plc->getAllCases() ) + { + PiecewiseLinearCaseSplit split = plc->getCaseSplit( caseSplit ); + if ( split.getBoundTightenings().exists( tightening ) ) + phaseIdentifiers.insert( isSplitActive( split ) + ? (int)plc->getTableauAuxVars().front() + : -(int)plc->getTableauAuxVars().front() ); + } + } + + // Collect the underlying clause from the splits used in the derivation (subset of the node's + // path) + for ( const auto phase : phaseIdentifiers ) + { + if ( !phaseIdentifiers.exists( -phase ) ) + { + PiecewiseLinearCaseSplit splitToAdd; + PiecewiseLinearConstraint *plc = _varToPlc[abs( phase )]; + + // TODO support additional types of splits + if ( plc->getType() == RELU ) + splitToAdd = phase > 0 ? plc->getCaseSplit( RELU_PHASE_ACTIVE ) + : plc->getCaseSplit( RELU_PHASE_INACTIVE ); + + splitDeps.append( splitToAdd ); + filteredTighteneings.append( splitToAdd.getBoundTightenings().front() ); + } + } + + _nodeToSplits.insert( node->getId(), filteredTighteneings ); + + ASSERT( node->isValidNonLeaf() ); + ASSERT( childrenIndices.size() == 2 ) + + // Write resolution + String resLine = String( "(step r" + std::to_string( node->getId() ) + " (cl " ) + + getNegatedSplitsClause( splitDeps ) + "):rule resolution :premises(r" + + std::to_string( childrenIndices.front() ) + " r" + + std::to_string( childrenIndices.back() ) + "))\n"; + + _proof.append( resLine ); +} + +String +AletheProofWriter::getNegatedSplitsClause( const List &splits ) const +{ + if ( splits.empty() ) + return ""; + + String clause = ""; + for ( const auto &split : splits ) + { + String isActive = isSplitActive( split ) ? "(not a" : "a"; + PiecewiseLinearConstraint *plc = _varToPlc[split.getBoundTightenings().front()._variable]; + int constraintInt = plc->getTableauAuxVars().front(); + String plcNum = std::to_string( constraintInt ); + String suffix = isSplitActive( split ) ? ")" : ""; + clause += String( " " ) + isActive + plcNum + suffix; + } + return clause; +} + +String AletheProofWriter::getBoundAsClause( const Tightening &bound ) const +{ + if ( bound._type == Tightening::UB ) + return String( "(<= x" + std::to_string( bound._variable ) + " " ) + + SmtLibWriter::signedValue( bound._value ) + ")"; + + return String( "(>= x" + std::to_string( bound._variable ) + " " ) + + SmtLibWriter::signedValue( bound._value ) + ")"; +} + +bool AletheProofWriter::isSplitActive( const PiecewiseLinearCaseSplit &split ) const +{ + ASSERT( split.getEquations().empty() ) + return split.getBoundTightenings().back()._type == Tightening::LB || + split.getBoundTightenings().front()._type == Tightening::LB; +} + +List +AletheProofWriter::getPathSplits( const UnsatCertificateNode *node ) const +{ + List pathSplits = List(); + const UnsatCertificateNode *cur = node; + while ( cur && !cur->getSplit().getBoundTightenings().empty() ) + { + pathSplits.append( cur->getSplit() ); + cur = cur->getParent(); + } + + return pathSplits; +} + +void AletheProofWriter::writeLemma( + const std::shared_ptr &lemmaEntry ) +{ + if ( !lemmaEntry->lemma || !lemmaEntry->lemma->getToCheck() ) + return; + + PiecewiseLinearConstraint *matchedConstraint = _varToPlc[lemmaEntry->lemma->getAffectedVar()]; + + // TODO add support for all types of PLCs + if ( matchedConstraint && matchedConstraint->getType() == RELU ) + writeReluLemma( lemmaEntry, (ReluConstraint *)matchedConstraint ); +} + +void AletheProofWriter::writeReluLemma( + const std::shared_ptr &lemmaEntry, + const ReluConstraint *relu ) +{ + ASSERT( lemmaEntry->lemma && lemmaEntry->lemma->getConstraintType() == RELU ); + + // Collect lemma and Relu information + const std::shared_ptr lemma = lemmaEntry->lemma; + + unsigned causingVar = lemma->getCausingVars().front(); + unsigned affectedVar = lemma->getAffectedVar(); + double targetBound = lemma->getMinTargetBound(); + double bound = lemma->getBound(); + String id = std::to_string( lemma->getId() ); + const List &explanations = lemma->getExplanations(); + Tightening::BoundType causingVarBound = lemma->getCausingVarBound(); + Tightening::BoundType affectedVarBound = lemma->getAffectedVarBound(); + + ASSERT( relu == _varToPlc[affectedVar] ); + ASSERT( explanations.size() == 1 ); + + String farkasArgs = ""; + String farkasClause = ""; + String farkasParticipants = ""; + String negatedSplitClause = ""; + String causeBound = getBoundAsClause( Tightening( causingVar, targetBound, causingVarBound ) ); + + // Apply calculations of the Farkas lemma, to prove the causing bound + farkasStrings( explanations.front(), + lemmaEntry->id, + farkasArgs, + farkasClause, + farkasParticipants, + negatedSplitClause, + causingVar, + causingVarBound == Tightening::UB, + NULL ); + + farkasClause = String( "(cl " ) + causeBound + farkasClause + ")"; + farkasArgs = String( "(1 " ) + farkasArgs + "))\n"; + + // Write la_generic\bounded_farkas for proving the causing bound, followed by a resolution step + String ruleName = GlobalConfiguration::DEDICATED_ALETHE_RULE ? "bounded_farkas" : "la_generic"; + String laGeneric = String( "(step fl" ) + id + " " + farkasClause + ":rule " + ruleName + + " :args" + farkasArgs; + + String res = String( "(step cr" ) + id + " (cl " + negatedSplitClause + causeBound + + "):rule resolution :premises(fl" + id + " " + farkasParticipants + "))\n"; + + + // Collect information for proving the derivation rule from the lemma (derive the derived bound, + // based on the causing bound) + unsigned b = relu->getB(); + unsigned f = relu->getF(); + unsigned aux = relu->getAux(); + int constraintInt = relu->getTableauAuxVars().front(); + + String identifier = std::to_string( constraintInt ); + bool matched = false; + + String proofRule = ""; + String proofRuleRes = ""; + String tempString = ""; + + // Additional steps required in some cases, for deriving a phase from a tighter bound + if ( targetBound > 0 ) + tempString += String( "(not " ) + + getBoundAsClause( Tightening( causingVar, 0, Tightening::UB ) ) + ")(not " + + causeBound + ")"; + else + tempString += String( "(not" ) + causeBound + ")" + + getBoundAsClause( Tightening( + causingVar, 0, targetBound > 0 ? Tightening::LB : Tightening::UB ) ); + + if ( targetBound != 0 ) + { + proofRule = + String( "(step taut" ) + id + " (cl (or " + tempString + ")):rule la_tautology)\n"; + proofRule += String( "(step ts" ) + id + " (cl " + tempString + "):rule or :premises(taut" + + id + "))\n"; + } + + if ( targetBound > 0 && causingVar == b ) + { + tempString = getBoundAsClause( Tightening( causingVar, 0, Tightening::UB ) ) + + " (<= 0.0 x" + std::to_string( b ) + ")"; + + proofRule += + String( "(step tot" ) + id + " (cl (or " + tempString + ")):rule la_totality)\n"; + proofRule += String( "(step tos" ) + id + " (cl " + tempString + "):rule or :premises(tot" + + id + "))\n"; + } + + String conclusion = getBoundAsClause( Tightening( affectedVar, bound, affectedVarBound ) ); + + // Prepare the shared prefix, and conclude with the remaining rules based on the exact + // derivation type + String pref = String( "(step rl" ) + id + " (cl " + negatedSplitClause + conclusion + + "):rule resolution :premises(cr" + id; + // if the lb of b or f is positive, then ub of aux is zero + if ( ( causingVar == f || causingVar == b ) && causingVarBound == Tightening::LB && + affectedVar == aux && affectedVarBound == Tightening::UB && targetBound > 0 ) + { + matched = true; + proofRuleRes = pref + " ts" + id; + + if ( causingVar == b ) + proofRuleRes += String( " eq" ) + identifier + "_a0 tos" + id; + else + proofRuleRes += String( " eq" ) + identifier + "_a0 ri1_" + identifier; + + proofRuleRes += "))\n"; + } + // if the lb of b is zero, then so is the ub of aux + else if ( causingVar == b && causingVarBound == Tightening::LB && affectedVar == aux && + affectedVarBound == Tightening::UB && targetBound == 0 ) + { + matched = true; + proofRuleRes = pref + " eq" + identifier + "_a0))\n"; + } + // If lb of aux is positive, then ub of f is 0 + else if ( causingVar == aux && causingVarBound == Tightening::LB && affectedVar == f && + affectedVarBound == Tightening::UB && targetBound > 0 ) + { + matched = true; + proofRuleRes = pref + " ts" + id + " eq" + identifier + "_a0 ri1_" + identifier + "))\n"; + } + + // If ub of b is non positive, then ub of f is 0 + else if ( causingVar == b && causingVarBound == Tightening::UB && affectedVar == f && + affectedVarBound == Tightening::UB && targetBound < 0 ) + { + matched = true; + proofRuleRes = pref + " ts" + id + " eq" + identifier + "_i0))\n"; + } + // Propagate 0 ub from f to b ... + else if ( causingVar == f && causingVarBound == Tightening::UB && affectedVar == b && + affectedVarBound == Tightening::UB && targetBound == 0 ) + { + matched = true; + proofRuleRes = pref + " eq" + identifier + "_i1))\n"; + } + // ... and vise versa + else if ( causingVar == b && causingVarBound == Tightening::UB && affectedVar == f && + affectedVarBound == Tightening::UB && targetBound == 0 ) + { + matched = true; + proofRuleRes = pref + +" eq" + identifier + "_i0))\n"; + } + // If ub of aux is 0, then lb of b is 0 + else if ( causingVar == aux && causingVarBound == Tightening::UB && affectedVar == b && + affectedVarBound == Tightening::LB && targetBound == 0 ) + + { + matched = true; + proofRuleRes = pref + " eq" + identifier + "_a1))\n"; + } + // If lb of b is negative x, then ub of aux is -x + else if ( causingVar == b && causingVarBound == Tightening::LB && affectedVar == aux && + affectedVarBound == Tightening::UB && targetBound < 0 ) + { + matched = true; + + unsigned identifierInt = relu->getTableauAuxVars().front(); + String tautClause = String( "(not " ) + + getBoundAsClause( Tightening( affectedVar, 0, Tightening::UB ) ) + ")" + + conclusion; + proofRule = + String( "(step taut" ) + id + " (cl (or " + tautClause + ")):rule la_tautology)\n"; + + proofRule += String( "(step ts" ) + id + " (cl " + tautClause + "):rule or :premises(taut" + + id + "))\n"; + + String counterpartBound = + getBoundAsClause( Tightening( identifierInt, 0, Tightening::LB ) ); + String subConclusion = getBoundAsClause( Tightening( f, 0, Tightening::UB ) ); + String tableauLit = convertTableauAssumptionToClause( identifierInt - ( _n - _m ) ); + String subFarkasClause = String( " (cl (not " ) + causeBound + ")" + conclusion + "(not " + + subConclusion + ")(not " + counterpartBound + ")" + tableauLit; + + proofRule += String( "(step ifl" ) + id + subFarkasClause + + "):rule la_generic :args(1 1 -1 1 -1))\n"; + + proofRuleRes += pref + +" e" + std::to_string( identifierInt - ( _n - _m ) ) + " l" + + std::to_string( identifierInt ) + " ifl" + id + " ts" + id + " eq" + + identifier + "_a0 ri1_" + identifier + "))\n"; + } + // If ub of b is positive, then propagate to f + else if ( causingVar == b && causingVarBound == Tightening::UB && affectedVar == f && + affectedVarBound == Tightening::UB && targetBound > 0 ) + { + matched = true; + + unsigned identifierInt = relu->getTableauAuxVars().front(); + String tautClause = String( "(not " ) + + getBoundAsClause( Tightening( affectedVar, 0, Tightening::UB ) ) + ")" + + conclusion; + + proofRule = + String( "(step taut" ) + id + " (cl (or " + tautClause + ")):rule la_tautology)\n"; + + proofRule += String( "(step ts" ) + id + " (cl " + tautClause + "):rule or :premises(taut" + + id + "))\n"; + + String counterpartBound = + getBoundAsClause( Tightening( identifierInt, 0, Tightening::UB ) ); + String subConclusion = getBoundAsClause( Tightening( aux, 0, Tightening::UB ) ); + String tableauLit = convertTableauAssumptionToClause( identifierInt - ( _n - _m ) ); + String subFarkasClause = String( " (cl (not " ) + causeBound + ")" + conclusion + "(not " + + subConclusion + ")(not " + counterpartBound + ")" + tableauLit; + + proofRule += + String( "(step ifl" ) + id + subFarkasClause + "):rule la_generic :args(1 1 -1 1 1))\n"; + + proofRuleRes += pref + " e" + std::to_string( identifierInt - ( _n - _m ) ) + " u" + + std::to_string( identifierInt ) + " ifl" + id + " ts" + id + " eq" + + identifier + "_a0 ri1_" + identifier + "))\n"; + } + + if ( matched ) + _proof.append( { laGeneric, res, proofRule, proofRuleRes } ); +} + +void AletheProofWriter::linearCombinationMpq( const std::vector &explainedRow, + const SparseUnsortedList &expl ) const +{ + SparseUnsortedList tableauRow( _n ); + for ( const auto &entry : expl ) + { + if ( entry._value == 0 ) + continue; + + _initialTableau->getRow( entry._index, &tableauRow ); + for ( const auto &tableauEntry : tableauRow ) + { + // Add ci * xi to the explained row in the ith index + if ( tableauEntry._value != 0 ) + { + mpq_t tempval, tempEntry, tempTableauEntry; + mpq_init( tempval ); + mpq_init( tempEntry ); + mpq_init( tempTableauEntry ); + mpq_set_d( tempTableauEntry, tableauEntry._value ); + mpq_set_d( tempEntry, entry._value ); + mpq_mul( tempval, tempEntry, tempTableauEntry ); + mpq_add( const_cast( explainedRow[tableauEntry._index] ), + explainedRow[tableauEntry._index], + tempval ); + mpq_clear( tempval ); + mpq_clear( tempEntry ); + mpq_clear( tempTableauEntry ); + } + } + } +} + +void AletheProofWriter::farkasStrings( const SparseUnsortedList &expl, + unsigned entryId, + String &farkasArgs, + String &farkasClause, + String &farkasParticipants, + String &negatedSplitClause, + int explainedVar, + bool isUpper, + UnsatCertificateNode *node ) +{ + std::vector explainedRow = std::vector( _n ); + for ( const auto num : explainedRow ) + mpq_init( num ); + + linearCombinationMpq( explainedRow, expl ); + bool isLemma = explainedVar >= 0; + if ( isLemma ) + { + mpq_t temp; + mpq_init( temp ); + mpq_set_d( temp, 1 ); + mpq_add( + const_cast( explainedRow[explainedVar] ), explainedRow[explainedVar], temp ); + mpq_clear( temp ); + } + + farkasClause = ""; + farkasArgs = ""; + farkasParticipants = ""; + List splitDeps; + + // Deduce the participating equations and their Farkas coefficients + for ( const auto entry : expl ) + if ( entry._value != 0 ) + { + farkasClause += String( "(not e" + std::to_string( entry._index ) ) + ")"; + + mpq_class temp( isUpper ? -entry._value : entry._value ); + farkasArgs += temp.get_str() + " "; + farkasParticipants += String( "e" + std::to_string( entry._index ) ) + " "; + } + + for ( unsigned i = 0; i < _n; ++i ) + { + // Deduce the participating bounds, either derived from lemmas, splits, or from the input + mpq_class temp( explainedRow[i] ); + if ( mpq_sgn( explainedRow[i] ) == 0 ) + continue; + + bool useEntryUpperBound = ( mpq_sgn( explainedRow[i] ) > 0 && isUpper ) || + ( mpq_sgn( explainedRow[i] ) < 0 && !isUpper ); + + String boundString = useEntryUpperBound ? "u" : "l"; + Tightening::BoundType boundType = useEntryUpperBound ? Tightening::UB : Tightening::LB; + const std::shared_ptr &gbEntry = + _groundBoundManager.getGroundBoundEntryUpToId( i, boundType, entryId ); + + int lemId = gbEntry->lemma ? gbEntry->lemma->getId() : -1; + double bound = gbEntry->val; + bool isLemmaIncluded = lemId >= 0 && gbEntry->lemma->getToCheck(); + bool useSplitBound = ( lemId < 0 && gbEntry->isPhaseFixing ); + + String ineqString; + + // Add the bound itself to the clause + if ( useEntryUpperBound ) + ineqString = String( "(not (<= x" + std::to_string( i ) + " " ) + + SmtLibWriter::signedValue( bound ) + ")) "; + else + ineqString = String( "(not (<= " ) + SmtLibWriter::signedValue( bound ) + " x" + + std::to_string( i ) + ")) "; + + // Add the bound coefficient only if the generic rule is used + if ( !GlobalConfiguration::DEDICATED_ALETHE_RULE ) + farkasArgs += temp.get_str() + " "; + + // Use the bound if it is learned by a lemma, or by a split + if ( isLemmaIncluded || useSplitBound ) + farkasClause += ineqString; + // If an input bound is used, then add it by its name + else + { + farkasClause += String( "(not " ) + boundString + std::to_string( i ) + ") "; + farkasParticipants += boundString + std::to_string( i ) + " "; + } + + // Add split deps of previous lemmas + if ( isLemmaIncluded ) + { + farkasParticipants += String( "rl" + std::to_string( lemId ) ) + " "; + for ( const auto dep : _idToSplits[gbEntry->id] ) + if ( !splitDeps.exists( dep ) ) + splitDeps.append( dep ); + } + // Add split deps of previous splits + else if ( useSplitBound ) + { + splitDeps.append( Tightening( i, bound, boundType ) ); + + String identifier = std::to_string( _varToPlc[i]->getTableauAuxVars().front() ); + if ( ( i == _varToPlc[i]->getParticipatingVariables().front() && bound == 0.0 && + boundType == Tightening::UB ) ) + farkasParticipants += String( "eq" ) + identifier + "_i1 "; + + if ( !( i == _varToPlc[i]->getParticipatingVariables().front() && bound == 0.0 && + boundType == Tightening::LB ) ) + { + if ( i == _varToPlc[i]->getParticipatingVariables().back() ) + farkasParticipants += String( "eq" ) + identifier + "_a0 "; + else + farkasParticipants += String( "ri1_" ) + identifier + " "; + } + } + } + + for ( const auto num : explainedRow ) + mpq_clear( num ); + + // Add proof terms for all splits in node path, with 0 argument for those that are not + // actually used Enables elaboration in Carcara + if ( node && GlobalConfiguration::ALETHE_ELABORATE_TERMS ) + { + List nodePath = getPathSplits( node ); + for ( const auto &caseSplit : nodePath ) + { + bool isCaseIncluded = false; + for ( const auto &bound : caseSplit.getBoundTightenings() ) + if ( splitDeps.exists( bound ) ) + isCaseIncluded = true; + + if ( isCaseIncluded ) + continue; + + farkasArgs += "0 "; + String identifier = + std::to_string( _varToPlc[caseSplit.getBoundTightenings().front()._variable] + ->getTableauAuxVars() + .front() ); + farkasClause += isSplitActive( caseSplit ) ? String( " (not a" ) + identifier + ")" + : String( " a" ) + identifier; + splitDeps.append( caseSplit.getBoundTightenings().front() ); + } + } + + if ( isLemma && _idToSplits.exists( entryId ) ) + _idToSplits[entryId] = splitDeps; + else if ( isLemma ) + _idToSplits.insert( entryId, splitDeps ); + else + _nodeToSplits.insert( -explainedVar, splitDeps ); + + // Compute the learned activation pattern from the bounds, and write its negation as a clause + Set usedPlc = {}; + for ( const auto &tightening : splitDeps ) + { + PiecewiseLinearConstraint *plc = _varToPlc[tightening._variable]; + + // Avoid repetitions + if ( usedPlc.exists( plc->getTableauAuxVars().front() ) ) + continue; + + usedPlc.insert( plc->getTableauAuxVars().front() ); + String identifier = std::to_string( plc->getTableauAuxVars().front() ); + + PiecewiseLinearCaseSplit tighteningSplit; + + for ( const auto &casSplit : plc->getAllCases() ) + { + PiecewiseLinearCaseSplit split = plc->getCaseSplit( casSplit ); + if ( split.getBoundTightenings().exists( tightening ) ) + tighteningSplit = split; + } + + String isNegActive = isSplitActive( tighteningSplit ) ? "(not a" : "a"; + String suffix = isSplitActive( tighteningSplit ) ? ")" : " "; + negatedSplitClause += isNegActive + identifier + suffix; + } +} + +String AletheProofWriter::convertTableauAssumptionToClause( unsigned index ) const +{ + return String( "(not e" ) + std::to_string( index ) + ")"; +} + +void AletheProofWriter::writeDelegatedLeaf( const UnsatCertificateNode *node ) +{ + String proofHole = String( "(step r" + std::to_string( node->getId() ) ) + " (cl " + + getNegatedSplitsClause( getPathSplits( node ) ) + "):rule hole)\n"; + + List deps = {}; + for ( const auto &split : getPathSplits( node ) ) + for ( const auto tightening : split.getBoundTightenings() ) + deps.append( tightening ); + + _nodeToSplits.insert( node->getId(), deps ); + _proof.append( proofHole ); +} + +unsigned AletheProofWriter::assignId() +{ + return _stepCounter++; +} + +void AletheProofWriter::deleteProof() +{ + _proof.clear(); +} + +void AletheProofWriter::setInitialTableau( const SparseMatrix *tableau ) +{ + _initialTableau = tableau; +} + +const Set AletheProofWriter::getSupportedActivations() +{ + return { RELU }; +} \ No newline at end of file diff --git a/src/proofs/AletheProofWriter.h b/src/proofs/AletheProofWriter.h new file mode 100644 index 0000000000..ba14dc6a8a --- /dev/null +++ b/src/proofs/AletheProofWriter.h @@ -0,0 +1,164 @@ +/** +** \verbatim +** Top contributors (to current version): +** Omri Isac, Guy Katz +** This file is part of the Marabou project. +** Copyright (c) 2017-2026 by the authors listed in the file AUTHORS +** in the top-level source directory) and their institutional affiliations. +** All rights reserved. See the file COPYING in the top-level source +** directory for licensing information.\endverbatim +** +** [[ Add lengthier description here ]] +**/ + +#ifndef __AletheProofWriter_h__ +#define __AletheProofWriter_h__ + +#include "GroundBoundManager.h" +#include "PiecewiseLinearCaseSplit.h" +#include "SmtLibWriter.h" +#include "SparseMatrix.h" +#include "SparseUnsortedList.h" +#include "Stack.h" +#include "UnsatCertificateNode.h" +#include "UnsatCertificateUtils.h" +#include "Vector.h" +#include "gmp.h" +#include "gmpxx.h" + +class AletheProofWriter +{ +public: + static const Set getSupportedActivations(); + + AletheProofWriter( unsigned explanationSize, + const Vector &upperBounds, + const Vector &lowerBounds, + const GroundBoundManager &groundBoundManager, + const SparseMatrix *tableau, + const List &problemConstraints ); + + /* + Write whole proof info to a file + */ + void writeInstanceToFile( IFile &file ); + + /* + Write steps to conclude UNSAT of a node from the UNSAT of its children + */ + void writeChildrenConclusion( const UnsatCertificateNode *node ); + + /* + Get the next unique ID to a node, and increment it + */ + unsigned assignId(); + + /* + Write proof hole for a delegated leaf node + */ + void writeDelegatedLeaf( const UnsatCertificateNode *node ); + + /* + Add proof steps to prove a PLC lemma + */ + void writeLemma( const std::shared_ptr &lemmaEntry ); + + /* + Add proof steps to prove the UNSAT of a leaf + */ + void writeContradiction( const SparseUnsortedList &contradiction, UnsatCertificateNode *node ); + + /* + Delete the content of the proof + */ + void deleteProof(); + + /* + Set the initial tableau constraints that define the query + */ + void setInitialTableau( const SparseMatrix *tableau ); + +private: + /* + Initial query information. + */ + const SparseMatrix *_initialTableau; + Vector _tableauAssumptions; // For easy access + Vector _initialUpperBounds; + Vector _initialLowerBounds; + const GroundBoundManager &_groundBoundManager; + Vector _plc; + + unsigned _n; + unsigned _m; + + /* + Lists for proofs steps and assumptions, track the number of nodes used in the proof + */ + List _proof; + List _assumptions; + unsigned _stepCounter; + + /* + Maintain maps the link between variables, PLC, nodes and their ids and splits + */ + Map _varToPlc; + Map> _idToSplits; + Map> _nodeToSplits; + + /* + Add original query assumptions to the proof file + */ + void writeBoundAssumptions(); + + void writePLCAssumption(); + + void writeTableauAssumptions(); + + /* + Add proof steps for proving a lemma learned from a ReLU activation constraint. + */ + void writeReluLemma( const std::shared_ptr &lemmaEntry, + const ReluConstraint *relu ); + /* + Collect all case splits of path to a proof node + */ + List getPathSplits( const UnsatCertificateNode *node ) const; + + /* + Convert multiple Marabou objects into their corresponding Alethe clause + */ + + String getBoundAsClause( const Tightening &bound ) const; + + String getNegatedSplitsClause( const List &splits ) const; + + String convertTableauAssumptionToClause( unsigned index ) const; + + /* + Check if a case split object represents the active ReLU phase + */ + bool isSplitActive( const PiecewiseLinearCaseSplit &split ) const; + + /* + Compute linear combinations from proof vectors using GMP + */ + void linearCombinationMpq( const std::vector &explainedRow, + const SparseUnsortedList &expl ) const; + + /* + A helper function that converts proof vector information to la_generic arguments and clauses as + Strings + */ + void farkasStrings( const SparseUnsortedList &expl, + unsigned entryId, + String &farkasArgs, + String &farkasClause, + String &farkasParticipants, + String &negatedSplitClause, + int explainerVar, + bool isUpper, + UnsatCertificateNode *node ); +}; + +#endif // __AletheProofWriter_h__ \ No newline at end of file diff --git a/src/proofs/CMakeLists.txt b/src/proofs/CMakeLists.txt index ad7aca42d1..2e30136fbf 100644 --- a/src/proofs/CMakeLists.txt +++ b/src/proofs/CMakeLists.txt @@ -14,6 +14,7 @@ macro(proofs_add_unit_test name) marabou_add_test(${PROOFS_TESTS_DIR}/Test_${name} proofs USE_MOCK_COMMON USE_MOCK_ENGINE "unit") endmacro() +proofs_add_unit_test(AletheProofWriter) proofs_add_unit_test(BoundExplainer) proofs_add_unit_test(Checker) proofs_add_unit_test(SmtLibWriter) diff --git a/src/proofs/JsonWriter.cpp b/src/proofs/JsonWriter.cpp index 78020a14f6..b0d95f3bb3 100644 --- a/src/proofs/JsonWriter.cpp +++ b/src/proofs/JsonWriter.cpp @@ -180,7 +180,24 @@ void JsonWriter::writeUnsatCertificateNode( const UnsatCertificateNode *node, unsigned counter = 0; unsigned size = node->getChildren().size(); - for ( auto child : node->getChildren() ) + List childrenInFixedOrder = List(); + const List &backChildTightening = + node->getChildren().back()->getSplit().getBoundTightenings(); + + // Insert the inactive phase first + if ( backChildTightening.back()._type == Tightening::LB || + backChildTightening.front()._type == Tightening::LB ) + { + childrenInFixedOrder.append( node->getChildren().front() ); + childrenInFixedOrder.append( node->getChildren().back() ); + } + else + { + childrenInFixedOrder.append( node->getChildren().back() ); + childrenInFixedOrder.append( node->getChildren().front() ); + } + + for ( auto child : childrenInFixedOrder ) { instance.append( "{\n" ); writeUnsatCertificateNode( child, explanationSize, instance ); diff --git a/src/proofs/PlcLemma.cpp b/src/proofs/PlcLemma.cpp index 796f90b1a9..0afbf1e040 100644 --- a/src/proofs/PlcLemma.cpp +++ b/src/proofs/PlcLemma.cpp @@ -23,7 +23,8 @@ PLCLemma::PLCLemma( const List &causingVars, Tightening::BoundType affectedVarBound, const Vector &explanations, PiecewiseLinearFunctionType constraintType, - double minTargetBound ) + double minTargetBound, + unsigned id ) : _causingVars( causingVars ) , _affectedVar( affectedVar ) , _bound( bound ) @@ -32,6 +33,7 @@ PLCLemma::PLCLemma( const List &causingVars, , _constraintType( constraintType ) , _toCheck( false ) , _minTargetBound( minTargetBound ) + , _id( id ) { if ( explanations.empty() ) _explanations = List(); @@ -116,3 +118,8 @@ void PLCLemma::setToCheck() { _toCheck = true; } + +unsigned PLCLemma::getId() const +{ + return _id; +} diff --git a/src/proofs/PlcLemma.h b/src/proofs/PlcLemma.h index 10884888bb..a9cfbb2226 100644 --- a/src/proofs/PlcLemma.h +++ b/src/proofs/PlcLemma.h @@ -33,7 +33,8 @@ class PLCLemma Tightening::BoundType affectedVarBound, const Vector &explanation, PiecewiseLinearFunctionType constraintType, - double minTargetBound ); + double minTargetBound, + unsigned id ); ~PLCLemma(); @@ -49,6 +50,7 @@ class PLCLemma PiecewiseLinearFunctionType getConstraintType() const; bool getToCheck() const; double getMinTargetBound() const; + unsigned getId() const; void setToCheck(); @@ -62,6 +64,7 @@ class PLCLemma PiecewiseLinearFunctionType _constraintType; bool _toCheck; double _minTargetBound; + unsigned _id; }; #endif //__PlcExplanation_h__ diff --git a/src/proofs/SmtLibWriter.cpp b/src/proofs/SmtLibWriter.cpp index 92628707da..b6999bb71c 100644 --- a/src/proofs/SmtLibWriter.cpp +++ b/src/proofs/SmtLibWriter.cpp @@ -14,24 +14,17 @@ #include "SmtLibWriter.h" -#include "DisjunctionConstraint.h" -#include "LeakyReluConstraint.h" -#include "MaxConstraint.h" -#include "ReluConstraint.h" -#include "SignConstraint.h" - const unsigned SmtLibWriter::SMTLIBWRITER_PRECISION = - (unsigned)std::log10( 1 / GlobalConfiguration::DEFAULT_EPSILON_FOR_COMPARISONS ); - - -void SmtLibWriter::writeToSmtLibFile( const String &fileName, - unsigned numOfTableauRows, - unsigned numOfVariables, - const Vector &upperBounds, - const Vector &lowerBounds, - const SparseMatrix *tableau, - const List &additionalEquations, - const List &problemConstraints ) + (unsigned)std::log10( 1 / GlobalConfiguration::LEMMA_CERTIFICATION_TOLERANCE ); + +List +SmtLibWriter::convertToSmtLib( unsigned numOfTableauRows, + unsigned numOfVariables, + const Vector &upperBounds, + const Vector &lowerBounds, + const SparseMatrix *tableau, + const List &additionalEquations, + const List &problemConstraints ) { List instance; @@ -119,21 +112,41 @@ void SmtLibWriter::writeToSmtLibFile( const String &fileName, } SmtLibWriter::addFooter( instance ); + + return instance; +} + +void SmtLibWriter::writeToSmtLibFile( const String &fileName, + unsigned numOfTableauRows, + unsigned numOfVariables, + const Vector &upperBounds, + const Vector &lowerBounds, + const SparseMatrix *tableau, + const List &additionalEquations, + const List &problemConstraints ) +{ + List instance = SmtLibWriter::convertToSmtLib( numOfTableauRows, + numOfVariables, + upperBounds, + lowerBounds, + tableau, + additionalEquations, + problemConstraints ); File file( fileName ); SmtLibWriter::writeInstanceToFile( file, instance ); } void SmtLibWriter::addHeader( unsigned numberOfVariables, List &instance ) { - instance.append( "( set-logic QF_LRA )\n" ); + instance.append( "(set-logic QF_LRA)\n" ); for ( unsigned i = 0; i < numberOfVariables; ++i ) - instance.append( "( declare-fun x" + std::to_string( i ) + " () Real )\n" ); + instance.append( "(declare-fun x" + std::to_string( i ) + " () Real)\n" ); } void SmtLibWriter::addFooter( List &instance ) { - instance.append( "( check-sat )\n" ); - instance.append( "( exit )\n" ); + instance.append( "(check-sat)\n" ); + instance.append( "(exit)\n" ); } void SmtLibWriter::addReLUConstraint( unsigned b, @@ -141,14 +154,15 @@ void SmtLibWriter::addReLUConstraint( unsigned b, const PhaseStatus status, List &instance ) { - if ( status == PHASE_NOT_FIXED ) - instance.append( "( assert ( = x" + std::to_string( f ) + " ( ite ( >= x" + - std::to_string( b ) + " 0 ) x" + std::to_string( b ) + " 0 ) ) )\n" ); + if ( GlobalConfiguration::WRITE_ALETHE_PROOF || status == PHASE_NOT_FIXED ) + instance.append( "(assert (ite (<= 0.0 x" + std::to_string( b ) + ") (= x" + + std::to_string( b ) + " x" + std::to_string( f ) + ") (<= x" + + std::to_string( f ) + " 0.0)))\n" ); else if ( status == RELU_PHASE_ACTIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " x" + std::to_string( b ) + - " ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " x" + std::to_string( b ) + + "))\n" ); else if ( status == RELU_PHASE_INACTIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " 0 ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " 0))\n" ); } void SmtLibWriter::addSignConstraint( unsigned b, @@ -157,12 +171,13 @@ void SmtLibWriter::addSignConstraint( unsigned b, List &instance ) { if ( status == PHASE_NOT_FIXED ) - instance.append( "( assert ( = x" + std::to_string( f ) + " ( ite ( >= x" + - std::to_string( b ) + " 0 ) 1 ( - 1 ) ) ) )\n" ); + instance.append( "(assert (ite (>= x" + std::to_string( b ) + " 0.0) (= x" + + std::to_string( f ) + " 1.0) (= x" + std::to_string( f ) + + " (- 1.0))))\n" ); else if ( status == SIGN_PHASE_POSITIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " 1 ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " 1.0))\n" ); else if ( status == SIGN_PHASE_NEGATIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " ( - 1 ) ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " (- 1.0)))\n" ); } void SmtLibWriter::addAbsConstraint( unsigned b, @@ -171,15 +186,15 @@ void SmtLibWriter::addAbsConstraint( unsigned b, List &instance ) { if ( status == PHASE_NOT_FIXED ) - instance.append( "( assert ( = x" + std::to_string( f ) + " ( ite ( >= x" + - std::to_string( b ) + " 0 ) x" + std::to_string( b ) + " ( - x" + - std::to_string( b ) + " ) ) ) )\n" ); + instance.append( "(assert (ite (>= x" + std::to_string( b ) + " 0.0) (= x" + + std::to_string( f ) + " x" + std::to_string( b ) + ") (= x" + + std::to_string( f ) + " (- x" + std::to_string( b ) + "))))\n" ); else if ( status == ABS_PHASE_POSITIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " x" + std::to_string( b ) + - " ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " x" + std::to_string( b ) + + "))\n" ); else if ( status == ABS_PHASE_NEGATIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " ( - x" + std::to_string( b ) + - " ) ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " (- x" + std::to_string( b ) + + ")))\n" ); } void SmtLibWriter::addMaxConstraint( unsigned f, @@ -194,13 +209,13 @@ void SmtLibWriter::addMaxConstraint( unsigned f, // f equals to some value (the value of maxVal) if ( status == MAX_PHASE_ELIMINATED ) - instance.append( String( "( assert ( = x" + std::to_string( f ) + " " ) + - signedValue( maxVal ) + " ) )\n" ); + instance.append( String( "(assert (= x" ) + std::to_string( f ) + " " + + signedValue( maxVal ) + "))\n" ); // f equals to some element (maxVal is an index) else if ( status != PHASE_NOT_FIXED ) - instance.append( "( assert ( = x" + std::to_string( f ) + " x" + - std::to_string( (unsigned)maxVal ) + " ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " x" + + std::to_string( (unsigned)maxVal ) + "))\n" ); else { @@ -209,7 +224,7 @@ void SmtLibWriter::addMaxConstraint( unsigned f, for ( const auto &element : elements ) { counter = 0; - assertRowLine = "( assert ( =>"; + assertRowLine = "(assert (=>"; for ( auto const &otherElement : elements ) { if ( otherElement == element ) @@ -217,21 +232,20 @@ void SmtLibWriter::addMaxConstraint( unsigned f, if ( counter < size - 2 ) { - assertRowLine += " ( and"; + assertRowLine += " (and"; ++counter; } - assertRowLine += " ( >= x" + std::to_string( element ) + " x" + - std::to_string( otherElement ) + " )"; + assertRowLine += " (>= x" + std::to_string( element ) + " x" + + std::to_string( otherElement ) + ")"; } for ( unsigned i = 0; i < size - 2; ++i ) - assertRowLine += String( " )" ); + assertRowLine += String( ")" ); - assertRowLine += - " ( = x" + std::to_string( f ) + " x" + std::to_string( element ) + " )"; + assertRowLine += " (= x" + std::to_string( f ) + " x" + std::to_string( element ) + ")"; - instance.append( assertRowLine + " ) )\n" ); + instance.append( assertRowLine + "))\n" ); } } } @@ -242,12 +256,12 @@ void SmtLibWriter::addDisjunctionConstraint( const List &instance ) { if ( status == PHASE_NOT_FIXED ) - instance.append( String( "( assert ( = x" + std::to_string( f ) + " ( ite ( >= x" + - std::to_string( b ) + " 0 ) x" + std::to_string( b ) + " ( * " ) + - signedValue( slope ) + " x" + std::to_string( b ) + " ) ) ) )\n" ); + instance.append( String( "(assert (ite (>= x" ) + std::to_string( b ) + " 0) (= x" + + std::to_string( f ) + " x" + std::to_string( b ) + ") (= x" + + std::to_string( f ) + " (* " + signedValue( slope ) + " x" + + std::to_string( b ) + "))))\n" ); else if ( status == RELU_PHASE_ACTIVE ) - instance.append( "( assert ( = x" + std::to_string( f ) + " x" + std::to_string( b ) + - " ) )\n" ); + instance.append( "(assert (= x" + std::to_string( f ) + " x" + std::to_string( b ) + + "))\n" ); else if ( status == RELU_PHASE_INACTIVE ) instance.append( - String( "( assert ( = x" + std::to_string( f ) + " x" + std::to_string( b ) ) + - signedValue( -slope ) + ") )\n" ); + String( "(assert (= x" + std::to_string( f ) + " x" + std::to_string( b ) ) + + signedValue( -slope ) + "))\n" ); } void SmtLibWriter::addTableauRow( const SparseUnsortedList &row, List &instance ) @@ -316,58 +331,71 @@ void SmtLibWriter::addTableauRow( const SparseUnsortedList &row, List &i // Avoid adding a redundant last element auto it = --row.end(); - if ( std::isnan( it->_value ) || FloatUtils::isZero( it->_value ) ) + if ( std::isnan( it->_value ) || it->_value == 0 ) --size; if ( !size ) return; - unsigned counter = 0; - String assertRowLine = "( assert ( = 0"; + String assertRowLine = "(assert (= 0.0 "; + + if ( row.getSize() > 1 ) + assertRowLine += "(+"; + auto entry = row.begin(); for ( ; entry != row.end(); ++entry ) { - if ( FloatUtils::isZero( entry->_value ) ) + if ( entry->_value == 0 ) continue; - if ( counter != size - 1 ) - assertRowLine += String( " ( + " ); - else - assertRowLine += String( " " ); - + assertRowLine += String( " " ); + mpq_class tempVal( entry->_value ); // Coefficients +-1 can be dropped if ( entry->_value == 1 ) assertRowLine += String( "x" ) + std::to_string( entry->_index ); else if ( entry->_value == -1 ) - assertRowLine += String( "( - x" ) + std::to_string( entry->_index ) + " )"; + assertRowLine += String( "(- x" ) + std::to_string( entry->_index ) + ")"; + else if ( entry->_value == (int)entry->_value ) + assertRowLine += String( "(* " ) + signedValue( entry->_value ) + " x" + + std::to_string( entry->_index ) + ")"; else - assertRowLine += String( "( * " ) + signedValue( entry->_value ) + " x" + - std::to_string( entry->_index ) + " )"; - - ++counter; + assertRowLine += + String( "(* " ) + tempVal.get_str() + " x" + std::to_string( entry->_index ) + ")"; } - for ( unsigned i = 0; i < counter + 1; ++i ) - assertRowLine += String( " )" ); + if ( row.getSize() > 1 ) + assertRowLine += ")"; - instance.append( assertRowLine + "\n" ); + instance.append( assertRowLine + "))\n" ); } void SmtLibWriter::addGroundUpperBounds( const Vector &bounds, List &instance ) { unsigned n = bounds.size(); for ( unsigned i = 0; i < n; ++i ) - instance.append( String( "( assert ( <= x" + std::to_string( i ) ) + String( " " ) + - signedValue( bounds[i] ) + " ) )\n" ); + { + mpq_class bound( bounds[i] ); + String boundString = bound.get_str(); + boundString = bound.get_den().get_str() == "1" ? boundString + ".0" : boundString; + + instance.append( String( "(assert (<= x" + std::to_string( i ) ) + String( " " ) + + boundString + "))\n" ); + } } void SmtLibWriter::addGroundLowerBounds( const Vector &bounds, List &instance ) { unsigned n = bounds.size(); for ( unsigned i = 0; i < n; ++i ) - instance.append( String( "( assert ( >= x" + std::to_string( i ) ) + String( " " ) + - signedValue( bounds[i] ) + " ) )\n" ); + { + mpq_class bound( bounds[i] ); + String boundString = bound.get_str(); + boundString = bound.get_den().get_str() == "1" ? boundString + ".0" : boundString; + + instance.append( String( "(assert (>= x" + std::to_string( i ) ) + String( " " ) + + boundString + "))\n" ); + } } void SmtLibWriter::writeInstanceToFile( IFile &file, const List &instance ) @@ -385,74 +413,70 @@ String SmtLibWriter::signedValue( double val ) std::stringstream s; s << std::fixed << std::setprecision( SMTLIBWRITER_PRECISION ) << abs( val ); return val >= 0 ? String( s.str() ).trimZerosFromRight() - : String( "( - " + s.str() ).trimZerosFromRight() + " )"; + : String( "(- " + s.str() ).trimZerosFromRight() + ")"; } void SmtLibWriter::addEquation( const Equation &eq, List &instance, bool assertEquations ) { - unsigned size = eq._addends.size(); + // Count only nonzero elements + unsigned size = 0; + for ( const auto &addend : eq._addends ) + if ( addend._coefficient != 0 ) + ++size; if ( !size ) return; - unsigned counter = 0; - String assertRowLine = ""; if ( assertEquations ) - assertRowLine += "( assert "; + assertRowLine += "(assert "; if ( eq._type == Equation::EQ ) - assertRowLine += "( = "; + assertRowLine += "(= "; else if ( eq._type == Equation::LE ) // Scalar should be >= than sum of addends - assertRowLine += "( >= "; + assertRowLine += "(>= "; else // Scalar should be <= than sum of addends - assertRowLine += "( <= "; + assertRowLine += "(<= "; assertRowLine += signedValue( eq._scalar ); + if ( size > 1 ) + assertRowLine += String( " (+" ); + for ( const auto &addend : eq._addends ) { - if ( FloatUtils::isZero( addend._coefficient ) ) - { - // If the last addend has coefficient zero, add 0 to close previously opened addition - if ( addend == eq._addends.back() ) - assertRowLine += String( " 0 )" ); + if ( addend._coefficient == 0 ) continue; - } - - if ( !( addend == eq._addends.back() ) ) - assertRowLine += String( " ( + " ); - else - assertRowLine += String( " " ); + assertRowLine += String( " " ); // Coefficients +-1 can be dropped if ( addend._coefficient == 1 ) assertRowLine += String( "x" ) + std::to_string( addend._variable ); else if ( addend._coefficient == -1 ) - assertRowLine += String( "( - x" ) + std::to_string( addend._variable ) + " )"; + assertRowLine += String( "(- x" ) + std::to_string( addend._variable ) + ")"; else - assertRowLine += String( "( * " ) + signedValue( addend._coefficient ) + " x" + - std::to_string( addend._variable ) + " )"; - - ++counter; + assertRowLine += String( "(* " ) + signedValue( addend._coefficient ) + " x" + + std::to_string( addend._variable ) + ")"; } - for ( unsigned i = 0; i < counter; ++i ) - assertRowLine += String( " )" ); + assertRowLine += String( ")" ); + + if ( size > 1 ) + assertRowLine += String( ")" ); - instance.append( assertRowLine + ( assertEquations ? " ) \n" : " " ) ); + instance.append( assertRowLine + ( assertEquations ? ")\n" : " " ) ); } void SmtLibWriter::addTightening( Tightening bound, List &instance ) { if ( bound._type == Tightening::LB ) - instance.append( String( "( >= x" ) + std::to_string( bound._variable ) + " " + - signedValue( bound._value ) + " )" ); + instance.append( String( "(>= x" ) + std::to_string( bound._variable ) + " " + + signedValue( bound._value ) + ")" ); else - instance.append( String( "( <= x" + std::to_string( bound._variable ) ) + String( " " ) + - signedValue( bound._value ) + " )" ); + instance.append( String( "(<= x" + std::to_string( bound._variable ) ) + String( " " ) + + signedValue( bound._value ) + ")" ); } \ No newline at end of file diff --git a/src/proofs/SmtLibWriter.h b/src/proofs/SmtLibWriter.h index 2b6c3ec5b9..326c9ca1cd 100644 --- a/src/proofs/SmtLibWriter.h +++ b/src/proofs/SmtLibWriter.h @@ -15,12 +15,16 @@ #ifndef __SmtLibWriter_h__ #define __SmtLibWriter_h__ +#include "DisjunctionConstraint.h" #include "File.h" +#include "LeakyReluConstraint.h" #include "List.h" #include "MString.h" +#include "MaxConstraint.h" #include "PiecewiseLinearConstraint.h" #include "SparseUnsortedList.h" #include "Vector.h" +#include "gmpxx.h" #include @@ -123,8 +127,18 @@ class SmtLibWriter static String signedValue( double val ); /* - A wrapper function calling all previous functions + Wrapper functions calling all previous functions */ + static List + convertToSmtLib( unsigned numOfTableauRows, + unsigned numOfVariables, + const Vector &upperBounds, + const Vector &lowerBounds, + const SparseMatrix *tableau, + const List &additionalEquations, + const List &problemConstraints ); + + static void writeToSmtLibFile( const String &fileName, unsigned numOfTableauRows, unsigned numOfVariables, diff --git a/src/proofs/UnsatCertificateNode.cpp b/src/proofs/UnsatCertificateNode.cpp index f42c7f3263..d5dcbbf219 100644 --- a/src/proofs/UnsatCertificateNode.cpp +++ b/src/proofs/UnsatCertificateNode.cpp @@ -17,10 +17,14 @@ #include UnsatCertificateNode::UnsatCertificateNode( UnsatCertificateNode *parent, - PiecewiseLinearCaseSplit split ) + PiecewiseLinearCaseSplit split, + unsigned splitNum, + unsigned id ) : _parent( parent ) , _contradiction( NULL ) , _headSplit( std::move( split ) ) + , _splitNum( splitNum ) + , _id( id ) , _hasSATSolution( false ) , _wasVisited( false ) , _delegationStatus( DelegationStatus::DONT_DELEGATE ) @@ -163,6 +167,16 @@ void UnsatCertificateNode::deleteUnusedLemmas() { if ( GlobalConfiguration::ANALYZE_PROOF_DEPENDENCIES ) for ( auto &lemma : _PLCExplanations ) - if ( lemma && !lemma->getToCheck() ) + if ( ( lemma && !lemma->getToCheck() ) || GlobalConfiguration::WRITE_ALETHE_PROOF ) lemma = nullptr; } + +unsigned UnsatCertificateNode::getSplitNum() const +{ + return _splitNum; +} + +unsigned UnsatCertificateNode::getId() const +{ + return _id; +} diff --git a/src/proofs/UnsatCertificateNode.h b/src/proofs/UnsatCertificateNode.h index 9963f85f8c..41865e4093 100644 --- a/src/proofs/UnsatCertificateNode.h +++ b/src/proofs/UnsatCertificateNode.h @@ -35,7 +35,10 @@ enum DelegationStatus : unsigned { class UnsatCertificateNode { public: - UnsatCertificateNode( UnsatCertificateNode *parent, PiecewiseLinearCaseSplit split ); + UnsatCertificateNode( UnsatCertificateNode *parent, + PiecewiseLinearCaseSplit split, + unsigned splitNum, + unsigned id ); ~UnsatCertificateNode(); /* @@ -133,12 +136,18 @@ class UnsatCertificateNode */ void deleteUnusedLemmas(); + unsigned getSplitNum() const; + unsigned getId() const; + void setId( unsigned id ) const; + private: List _children; UnsatCertificateNode *_parent; List> _PLCExplanations; Contradiction *_contradiction; PiecewiseLinearCaseSplit _headSplit; + unsigned _splitNum; + unsigned _id; // Enables certifying correctness of UNSAT leaves in SAT queries bool _hasSATSolution; diff --git a/src/proofs/tests/Test_AletheProofWriter.h b/src/proofs/tests/Test_AletheProofWriter.h new file mode 100644 index 0000000000..e39f055636 --- /dev/null +++ b/src/proofs/tests/Test_AletheProofWriter.h @@ -0,0 +1,136 @@ +/********************* */ +/*! \file Test_SmtLibWriter.h + ** \verbatim + ** Top contributors (to current version): + ** Omri Isac, Guy Katz + ** This file is part of the Marabou project. + ** Copyright (c) 2017-2022 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** [[ Add lengthier description here ]] + **/ + +#include "AletheProofWriter.h" +#include "CSRMatrix.h" +#include "MockFile.h" +#include "Query.h" +#include "context/cdlist.h" +#include "context/context.h" + +#include + +using CVC4::context::Context; +using namespace CVC4::context; + +class AletheProofWriterTestSuite : public CxxTest::TestSuite +{ +public: + MockFile *file; + Context *context; + + void setUp() + { + TS_ASSERT_THROWS_NOTHING( file = new MockFile() ); + TS_ASSERT_THROWS_NOTHING( context = new Context ); + } + + void tearDown() + { + TS_ASSERT_THROWS_NOTHING( delete file; ); + TS_ASSERT_THROWS_NOTHING( delete context; ); + } + + /* + Tests the writing Alethe assumption in correct SMTLIB format + */ + void test_alethe_assumption_writing() + { + // Construct all required info. + Vector ubs = { 1, 1, 1, 0 }; + Vector lbs = { 0, 0, 0, 0 }; + Vector rows = { 1, 2, -1, 0, 1, -1, 1, 1 }; + + unsigned n = ubs.size(); + unsigned m = 2; + + GroundBoundManager gbm = GroundBoundManager( *context ); + gbm.initialize( n ); + + Query query = Query(); + query.setNumberOfVariables( n ); + + ReluConstraint relu = ReluConstraint( 0, 1 ); + relu.transformToUseAuxVariables( query ); + relu.addTableauAuxVar( 3, 2 ); + List constraints = { &relu }; + + CSRMatrix *matrix = new CSRMatrix(); + matrix->initialize( rows.data(), m, n ); + + AletheProofWriter writer( m, ubs, lbs, gbm, matrix, constraints ); + writer.writeInstanceToFile( *file ); + + String line; + String expectedLine; + + // Tableau assumptions + line = file->readLine( '\n' ); + expectedLine = "(assume e0(!(= 0.0 (+ x0 (* 2.0 x1) (- x2))):named e0))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume e1(!(= 0.0 (+ x0 (- x1) x2 x3)):named e1))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + // Bound Assumptions + line = file->readLine( '\n' ); + expectedLine = "(assume u0(!(<= x0 1.0):named u0))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume l0(!(>= x0 0.0):named l0))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume u1(!(<= x1 1.0):named u1))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume l1(!(>= x1 0.0):named l1))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume u2(!(<= x2 1.0):named u2))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume l2(!(>= x2 0.0):named l2))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume u3(!(<= x3 0.0):named u3))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + expectedLine = "(assume l3(!(>= x3 0.0):named l3))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + // Relu Assumptions + line = file->readLine( '\n' ); + expectedLine = "(assume relu0 (ite (!(<= 0.0 x0):named a0)(= x0 x1)(<= x1 0.0)))"; + TS_ASSERT_EQUALS( line, expectedLine ); + + line = file->readLine( '\n' ); + // Next lines should represent proof steps + while ( line != "" ) + { + expectedLine = "(step"; + TS_ASSERT_EQUALS( line.find( expectedLine ), 0U ); + line = file->readLine( '\n' ); + } + + delete matrix; + } +}; \ No newline at end of file diff --git a/src/proofs/tests/Test_Checker.h b/src/proofs/tests/Test_Checker.h index 57ff3f10fd..64f7f88194 100644 --- a/src/proofs/tests/Test_Checker.h +++ b/src/proofs/tests/Test_Checker.h @@ -38,7 +38,7 @@ class CheckerTestSuite : public CxxTest::TestSuite List constraintsList = { &relu1, &relu2 }; // Set a complete tree of depth 3, using 2 ReLUs - auto *root = new UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit() ); + auto *root = new UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit(), 0, 0 ); Checker checker( root, m, &initialTableau, groundUpperBounds, groundLowerBounds, constraintsList ); @@ -57,11 +57,11 @@ class CheckerTestSuite : public CxxTest::TestSuite TS_ASSERT_EQUALS( split1_2.getBoundTightenings().size(), 2U ); // Child with missing aux tightening - auto *child1 = new UnsatCertificateNode( root, split1_1 ); - auto *child2 = new UnsatCertificateNode( root, split1_2 ); + auto *child1 = new UnsatCertificateNode( root, split1_1, 0, 0 ); + auto *child2 = new UnsatCertificateNode( root, split1_2, 0, 0 ); - auto *child2_1 = new UnsatCertificateNode( child2, split2_1 ); - auto *child2_2 = new UnsatCertificateNode( child2, split2_2 ); + auto *child2_1 = new UnsatCertificateNode( child2, split2_1, 0, 0 ); + auto *child2_2 = new UnsatCertificateNode( child2, split2_2, 0, 0 ); root->setVisited(); child2->setVisited(); diff --git a/src/proofs/tests/Test_SmtLibWriter.h b/src/proofs/tests/Test_SmtLibWriter.h index 0f66ddb017..8bf47acf79 100644 --- a/src/proofs/tests/Test_SmtLibWriter.h +++ b/src/proofs/tests/Test_SmtLibWriter.h @@ -32,8 +32,8 @@ class SmtLibWriterTestSuite : public CxxTest::TestSuite file = new MockFile(); Vector row = { 1, 2 }; SparseUnsortedList sparseRow( row.data(), 2 ); - Vector groundUpperBounds = { 1, 1 }; - Vector groundLowerBounds = { 1, -1 }; + Vector groundUpperBounds = { 1.0, 1.0 }; + Vector groundLowerBounds = { 1.0, -1.0 }; List instance; SmtLibWriter::addHeader( 2, instance ); @@ -108,186 +108,183 @@ class SmtLibWriterTestSuite : public CxxTest::TestSuite String expectedLine; line = file->readLine( '\n' ); - expectedLine = "( set-logic QF_LRA )"; + expectedLine = "(set-logic QF_LRA)"; TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = "( declare-fun x0 () Real )"; + expectedLine = "(declare-fun x0 () Real)"; TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = "( declare-fun x1 () Real )"; + expectedLine = "(declare-fun x1 () Real)"; TS_ASSERT_EQUALS( line, expectedLine ); // Bounds line = file->readLine( '\n' ); - expectedLine = String( "( assert ( <= x0 " ) + SmtLibWriter::signedValue( 1 ) + " ) )"; + expectedLine = String( "(assert (<= x0 " ) + SmtLibWriter::signedValue( 1.0 ) + "))"; TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert ( <= x1 " ) + SmtLibWriter::signedValue( 1 ) + " ) )"; + expectedLine = String( "(assert (<= x1 " ) + SmtLibWriter::signedValue( 1.0 ) + "))"; TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert ( >= x0 " ) + SmtLibWriter::signedValue( 1 ) + " ) )"; + expectedLine = String( "(assert (>= x0 " ) + SmtLibWriter::signedValue( 1.0 ) + "))"; TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert ( >= x1 " ) + SmtLibWriter::signedValue( -1 ) + " ) )"; + expectedLine = String( "(assert (>= x1 -1.0))" ); TS_ASSERT_EQUALS( line, expectedLine ); // Tableau line = file->readLine( '\n' ); - expectedLine = - String( "( assert ( = 0 ( + x0 ( * " ) + SmtLibWriter::signedValue( 2 ) + " x1 ) ) ) )"; + expectedLine = "(assert (= 0.0 (+ x0 (* 2.0 x1))))"; TS_ASSERT_EQUALS( line, expectedLine ); // Relu line = file->readLine( '\n' ); - expectedLine = "( assert ( = x1 ( ite ( >= x0 0 ) x0 0 ) ) )"; + expectedLine = "(assert (ite (<= 0.0 x0) (= x0 x1) (<= x1 0.0)))"; TS_ASSERT_EQUALS( line, expectedLine ); // Sign line = file->readLine( '\n' ); - expectedLine = "( assert ( = x1 ( ite ( >= x0 0 ) 1 ( - 1 ) ) ) )"; + expectedLine = "(assert (ite (>= x0 0.0) (= x1 1.0) (= x1 (- 1.0))))"; TS_ASSERT_EQUALS( line, expectedLine ); // Absolute Value line = file->readLine( '\n' ); - expectedLine = "( assert ( = x1 ( ite ( >= x0 0 ) x0 ( - x0 ) ) ) )"; + expectedLine = "(assert (ite (>= x0 0.0) (= x1 x0) (= x1 (- x0))))"; TS_ASSERT_EQUALS( line, expectedLine ); // LeakyRelu line = file->readLine( '\n' ); - expectedLine = "( assert ( = x1 ( ite ( >= x0 0 ) x0 ( * 0.1 x0 ) ) ) )"; + expectedLine = "(assert (ite (>= x0 0) (= x1 x0) (= x1 (* 0.1 x0))))"; TS_ASSERT_EQUALS( line, expectedLine ); // Max line = file->readLine( '\n' ); - expectedLine = String( "( assert ( => ( and ( >= x2 x3 ) ( >= x2 x4 ) ) ( = x1 x2 ) ) )" ); + expectedLine = String( "(assert (=> (and (>= x2 x3) (>= x2 x4)) (= x1 x2)))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert ( => ( and ( >= x3 x2 ) ( >= x3 x4 ) ) ( = x1 x3 ) ) )" ); + expectedLine = String( "(assert (=> (and (>= x3 x2) (>= x3 x4)) (= x1 x3)))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert ( => ( and ( >= x4 x2 ) ( >= x4 x3 ) ) ( = x1 x4 ) ) )" ); + expectedLine = String( "(assert (=> (and (>= x4 x2) (>= x4 x3)) (= x1 x4)))" ); TS_ASSERT_EQUALS( line, expectedLine ); // Disjunctions (several cases) line = file->readLine( '\n' ); - expectedLine = String( "( assert" ); + expectedLine = String( "(assert" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( or" ); + expectedLine = String( "(or" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = - String( "( and ( = ( - 4 ) ( + x0 ( * ( - 2 ) x1 ) ) ) ( <= x1 ( - 2 ) ) )" ); + expectedLine = String( "(and (= (- 4.0) (+ x0 (* (- 2.0) x1))) (<= x1 (- 2.0)))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( and ( >= x1 2 )( <= x0 ( - 1.5 ) ) )" ); + expectedLine = String( "(and (>= x1 2.0)(<= x0 (- 1.5)))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( " ) )" ); + expectedLine = String( "))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert" ); + expectedLine = String( "(assert" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( or" ); + expectedLine = String( "(or" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( = 1 ( - x1 ) ) " ); + expectedLine = String( "(= 1.0 (- x1)) " ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( or" ); + expectedLine = String( "(or" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( = 1 ( - x2 ) ) " ); + expectedLine = String( "(= 1.0 (- x2)) " ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( = 1 ( - x3 ) ) " ); + expectedLine = String( "(= 1.0 (- x3)) " ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( " ) ) )" ); + expectedLine = String( ")))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert" ); + expectedLine = String( "(assert" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( >= x1 2 )" ); + expectedLine = String( "(>= x1 2.0)" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( " )" ); + expectedLine = String( ")" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert" ); + expectedLine = String( "(assert" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( or" ); + expectedLine = String( "(or" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( = 1 ( - x1 ) ) " ); + expectedLine = String( "(= 1.0 (- x1)) " ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( >= x1 2 )" ); + expectedLine = String( "(>= x1 2.0)" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( " ) )" ); + expectedLine = String( "))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert" ); + expectedLine = String( "(assert" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( and ( >= x1 2 )( and ( >= x1 2 )( <= x0 ( - 1.5 ) ) ) )" ); + expectedLine = String( "(and (>= x1 2.0)(and (>= x1 2.0)(<= x0 (- 1.5))))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( " )" ); + expectedLine = String( ")" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( "( assert" ); + expectedLine = String( "(assert" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = - String( "( and ( = 1 ( - x1 ) ) ( and ( = 1 ( - x1 ) ) ( = 1 ( - x1 ) ) ) )" ); + expectedLine = String( "(and (= 1.0 (- x1)) (and (= 1.0 (- x1)) (= 1.0 (- x1)) ))" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = String( " )" ); + expectedLine = String( ")" ); TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = "( check-sat )"; + expectedLine = "(check-sat)"; TS_ASSERT_EQUALS( line, expectedLine ); line = file->readLine( '\n' ); - expectedLine = "( exit )"; + expectedLine = "(exit)"; TS_ASSERT_EQUALS( line, expectedLine ); } }; diff --git a/src/proofs/tests/Test_UnsatCertificateNode.h b/src/proofs/tests/Test_UnsatCertificateNode.h index ca18bc8608..550726cef7 100644 --- a/src/proofs/tests/Test_UnsatCertificateNode.h +++ b/src/proofs/tests/Test_UnsatCertificateNode.h @@ -30,7 +30,7 @@ class UnsatCertificateNodeTestSuite : public CxxTest::TestSuite Vector groundUpperBounds( 6, 1 ); Vector groundLowerBounds( 6, 0 ); - auto *root = new UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit() ); + auto *root = new UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit(), 0, 0 ); ReluConstraint relu = ReluConstraint( 1, 3 ); auto splits = relu.getCaseSplits(); @@ -39,8 +39,8 @@ class UnsatCertificateNodeTestSuite : public CxxTest::TestSuite PiecewiseLinearCaseSplit split1 = splits.back(); PiecewiseLinearCaseSplit split2 = splits.front(); - auto *child1 = new UnsatCertificateNode( root, split1 ); - auto *child2 = new UnsatCertificateNode( root, split2 ); + auto *child1 = new UnsatCertificateNode( root, split1, 1, 0 ); + auto *child2 = new UnsatCertificateNode( root, split2, 1, 0 ); TS_ASSERT_EQUALS( child1->getParent(), root ); TS_ASSERT_EQUALS( child2->getParent(), root ); @@ -64,7 +64,7 @@ class UnsatCertificateNodeTestSuite : public CxxTest::TestSuite */ void test_contradiction() { - UnsatCertificateNode root = UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit() ); + UnsatCertificateNode root = UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit(), 0, 0 ); auto upperBoundExplanation = Vector( 1, 1 ); auto lowerBoundExplanation = Vector( 1, 1 ); @@ -79,15 +79,15 @@ class UnsatCertificateNodeTestSuite : public CxxTest::TestSuite */ void test_plc_explanation_changes() { - UnsatCertificateNode root = UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit() ); + UnsatCertificateNode root = UnsatCertificateNode( NULL, PiecewiseLinearCaseSplit(), 0, 0 ); Vector emptyVec; auto explanation1 = std::shared_ptr( - new PLCLemma( { 1 }, 1, 0, Tightening::UB, Tightening::UB, emptyVec, RELU, 0 ) ); + new PLCLemma( { 1 }, 1, 0, Tightening::UB, Tightening::UB, emptyVec, RELU, 0, 0 ) ); auto explanation2 = std::shared_ptr( - new PLCLemma( { 1 }, 1, -1, Tightening::UB, Tightening::UB, emptyVec, RELU, 0 ) ); + new PLCLemma( { 1 }, 1, -1, Tightening::UB, Tightening::UB, emptyVec, RELU, 0, 0 ) ); auto explanation3 = std::shared_ptr( - new PLCLemma( { 1 }, 1, -4, Tightening::UB, Tightening::UB, emptyVec, RELU, 0 ) ); + new PLCLemma( { 1 }, 1, -4, Tightening::UB, Tightening::UB, emptyVec, RELU, 0, 0 ) ); TS_ASSERT( root.getPLCLemmas().empty() );