
Incorrect SAT witness in disjunction query #873

@Zinoex


The following MWE, using a simple ReLU-activated fully connected network (I can send the .onnx file), returns 'sat' with output values (vals[outputVars[0]], vals[outputVars[1]]) = (-0.9808001445708098, -0.3497275237786019). The query is a disjunction: a satisfying assignment must have at least one output dimension i with output >= ub[i] or output <= lb[i]. However, both returned outputs lie strictly inside their bounds [lb[i], ub[i]], so the assignment is not a witness for 'sat', and the final assertion (assert example_found) fails.
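Independently of Marabou, the reported outputs can be checked against the ub and lb arrays from the MWE with a few lines of NumPy. This sketch copies the values from this report and computes each output's distance to both bounds. A valid witness needs at least one dimension at or beyond a bound, so any non-positive distance would make it a witness. The helper name is_witness is only illustrative.

```python
import numpy as np

# Bounds and reported output values copied from this issue
ub = np.array([-0.9600000381469727, -0.3095250427722931])
lb = np.array([-1.0199999809265137, -0.3695250451564789])
y = np.array([-0.9808001445708098, -0.3497275237786019])

def is_witness(y, lb, ub):
    """True if some output dimension is >= its upper bound or <= its lower bound."""
    return bool(np.any((y >= ub) | (y <= lb)))

# Distance of each output from each bound; positive means strictly inside the box
dist_to_ub = ub - y
dist_to_lb = y - lb

print("witness:", is_witness(y, lb, ub))  # → witness: False
print("distance to ub:", dist_to_ub)
print("distance to lb:", dist_to_lb)
print("smallest distance:", min(dist_to_ub.min(), dist_to_lb.min()))
```

The smallest distance is roughly 0.0198, on the lower bound of the second output. The reported point therefore sits well inside the box rather than on its boundary. The full reproduction script is below.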

import numpy as np
from maraboupy import Marabou, MarabouCore, MarabouUtils

onnx_path = 'simple_nn.onnx'
network = Marabou.read_onnx(onnx_path)  # Load the ONNX model

outputVars = network.outputVars[0].flatten()
inputVars = network.inputVars[0].flatten()
options = Marabou.createOptions(verbosity=1)

sample = np.array([-0.45, -0.99])
delta = 0.01
ub = np.array([-0.9600000381469727, -0.3095250427722931])
lb = np.array([-1.0199999809265137, -0.3695250451564789])

# Restrict each input to the interval [sample[i] - delta, sample[i] + delta]
for i, inputVar in enumerate(inputVars):
    network.setLowerBound(inputVar, sample[i] - delta)
    network.setUpperBound(inputVar, sample[i] + delta)

# Create disjunctive constraint for all output dimensions
disjuncts = []

for i, outputVar in enumerate(outputVars):
    # nn_output >= ub
    equation_GE = MarabouUtils.Equation(MarabouCore.Equation.GE)
    equation_GE.addAddend(1, outputVar)
    equation_GE.setScalar(ub[i])

    # nn_output <= lb
    equation_LE = MarabouUtils.Equation(MarabouCore.Equation.LE)
    equation_LE.addAddend(1, outputVar)
    equation_LE.setScalar(lb[i])

    # Each bound violation is its own disjunct; the constraint holds if any disjunct (over all dimensions) holds
    disjuncts.extend([[equation_GE], [equation_LE]])

network.addDisjunctionConstraint(disjuncts)

# Solve
res, vals, _ = network.solve(verbose=True, options=options)
if res == 'sat':
    for i, inputVar in enumerate(inputVars):
        assert vals[inputVar] >= sample[i] - delta
        assert vals[inputVar] <= sample[i] + delta
    example_found = False
    for i, outputVar in enumerate(outputVars):
        if vals[outputVar] >= ub[i] or vals[outputVar] <= lb[i]:
            example_found = True
            break
    assert example_found
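A solver-reported assignment can also be re-checked against a disjunction with the same shape as the list passed to addDisjunctionConstraint. Each disjunct is a list of linear (in)equalities, and the disjunction holds when every equation in at least one disjunct holds. The sketch below mirrors that structure with plain tuples. This is a hypothetical stand-in, not the MarabouUtils.Equation API. The helper names, the variable indices 10 and 11, and the optional tol parameter are all assumptions for illustration.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-in for one equation: (addends as (coefficient, variable), relation, scalar)
LinearConstraint = Tuple[List[Tuple[float, int]], str, float]

def equation_holds(eq: LinearConstraint, vals: Dict[int, float], tol: float = 0.0) -> bool:
    """Evaluate sum(c * x_v) <relation> scalar under the assignment vals."""
    addends, relation, scalar = eq
    lhs = sum(c * vals[v] for c, v in addends)
    if relation == "GE":
        return lhs >= scalar - tol
    if relation == "LE":
        return lhs <= scalar + tol
    if relation == "EQ":
        return abs(lhs - scalar) <= tol
    raise ValueError(f"unknown relation {relation!r}")

def disjunction_holds(disjuncts: List[List[LinearConstraint]],
                      vals: Dict[int, float], tol: float = 0.0) -> bool:
    """A disjunction holds if all equations of at least one disjunct hold."""
    return any(all(equation_holds(eq, vals, tol) for eq in d) for d in disjuncts)

def build_box_exit_disjuncts(output_vars: Sequence[int], lb: Sequence[float],
                             ub: Sequence[float]) -> List[List[LinearConstraint]]:
    """Same shape as the script: [[y_i >= ub_i], [y_i <= lb_i]] for every dimension i."""
    disjuncts: List[List[LinearConstraint]] = []
    for i, v in enumerate(output_vars):
        disjuncts.append([([(1.0, v)], "GE", ub[i])])
        disjuncts.append([([(1.0, v)], "LE", lb[i])])
    return disjuncts

# Hypothetical variable indices; the real ones come from network.outputVars
out_vars = [10, 11]
lb = [-1.0199999809265137, -0.3695250451564789]
ub = [-0.9600000381469727, -0.3095250427722931]
disjuncts = build_box_exit_disjuncts(out_vars, lb, ub)

reported = {10: -0.9808001445708098, 11: -0.3497275237786019}
print(disjunction_holds(disjuncts, reported))  # → False
```

Passing the real vals dictionary and output variable indices in place of the hypothetical ones evaluates the same four disjuncts the script adds. The result should agree with the example_found check.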
