Skip to content
22 changes: 12 additions & 10 deletions claasp/cipher_modules/models/cp/mzn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import itertools
import math
import os
import subprocess
import time
from copy import deepcopy
Expand Down Expand Up @@ -96,7 +95,7 @@ def __init__(self, cipher, sat_or_milp='sat'):
self.component_probability_var = {}

def initialise_model(self):
self._variables_list = []
self._variables_declarations = []
self._model_constraints = []
self.c = 0
if self._cipher.is_spn():
Expand All @@ -121,7 +120,8 @@ def initialise_model(self):

def current_model_parts(self):
return MiniZincModelParts(
variables=list(self._variables_list),
prefix=list(self._model_prefix),
variables=list(self._variables_declarations),
constraints=list(self._model_constraints),
outputs=list(self.mzn_output_directives),
carries_outputs=list(self.mzn_carries_output_directives),
Expand Down Expand Up @@ -209,7 +209,7 @@ def add_solution_to_components_values_internal(

def build_generic_cp_model_from_dictionary(self, component_and_model_types, fixed_variables=None):
variables = []
self._variables_list = []
self._variables_declarations = []
self._model_constraints = []
component_types = [CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION]
operation_types = ['AND', 'MODADD', 'MODSUB', 'NOT', 'OR', 'ROTATE', 'SHIFT', 'SHIFT_BY_VARIABLE_AMOUNT', 'XOR']
Expand Down Expand Up @@ -237,7 +237,7 @@ def build_generic_cp_model_from_dictionary(self, component_and_model_types, fixe
raise ValueError("Unexpected return value from component generator")

self._model_constraints.extend(constraints)
self._variables_list.extend(variables)
self._variables_declarations.extend(variables)

if metadata:
probability_var = metadata.get("probability_var")
Expand Down Expand Up @@ -999,8 +999,9 @@ def solve_for_ARX(
"""
mzn_model_string = self.assemble_model(
MiniZincModelParts(
prefix=list(self._model_prefix),
variables=list(self._model_constraints),
constraints=list(self._variables_list),
constraints=list(self._variables_declarations),
)
)
solver_name_mzn = Solver.lookup(solver_name)
Expand Down Expand Up @@ -1114,7 +1115,8 @@ def write_minizinc_model_to_file(self, file_path, prefix=""):
file.write(
self.assemble_model(
MiniZincModelParts(
variables=self.mzn_comments + list(self._variables_list),
prefix=list(self._model_prefix),
variables=self.mzn_comments + list(self._variables_declarations),
constraints=list(self._model_constraints),
outputs=list(self.mzn_output_directives),
carries_outputs=list(self.mzn_carries_output_directives),
Expand Down Expand Up @@ -1152,7 +1154,7 @@ def model_constraints(self):
sage: constraints = mzn.model_constraints
sage: len(constraints) > 0
True
sage: 'constraint rot_0_0[2] = plaintext[11];' == constraints[29]
sage: 'constraint rot_0_0[2] = plaintext[11];' in constraints
True
"""
if not self._model_constraints:
Expand Down Expand Up @@ -1180,6 +1182,6 @@ def model_variables(self):
sage: 'array[0..15] of var 0..1: pre_modadd_0_1_0;' == variables[0]
True
"""
if not self._variables_list:
if not self._variables_declarations:
raise ValueError("No model generated")
return self._variables_list
return self._variables_declarations
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def create_boomerang_model(self, fixed_variables_for_top_cipher, fixed_variables
self._model_constraints.extend([get_word_operations()])
self._model_constraints.extend([get_bct_operations()])

self._variables_list.extend(
self._variables_declarations.extend(
self.differential_model_top_cipher.get_variables() + self.differential_model_bottom_cipher.get_variables()
)
self._model_constraints.extend(
Expand Down Expand Up @@ -226,7 +226,7 @@ def write_minizinc_model_to_file(self, file_path, prefix=""):
+ "\n"
+ model_string_bottom
+ "\n"
+ "\n".join(self._variables_list)
+ "\n".join(self._variables_declarations)
+ "\n"
+ "\n".join(self._model_constraints)
)
Expand All @@ -249,7 +249,7 @@ def get_hex_from_sublists(sublists, bool_dict):
return hex_values

if result.status not in [Status.UNKNOWN, Status.UNSATISFIABLE, Status.ERROR]:
list_of_sublist_of_vars = group_strings_by_pattern(self._variables_list)
list_of_sublist_of_vars = group_strings_by_pattern(self._variables_declarations)
dict_of_component_value = get_hex_from_sublists(list_of_sublist_of_vars, solution.__dict__)

return {"component_values": dict_of_component_value}
Expand Down
16 changes: 6 additions & 10 deletions claasp/cipher_modules/models/cp/mzn_models/mzn_cipher_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MznCipherModel(MznModel):
def __init__(self, cipher):
super().__init__(cipher)

def build_cipher_model(self, fixed_variables=[], second=False):
def build_cipher_model(self, fixed_variables=[]):
"""
Build the cipher model.

Expand All @@ -55,10 +55,9 @@ def build_cipher_model(self, fixed_variables=[], second=False):
sage: cp.build_cipher_model(fixed_variables)
"""
self.initialise_model()
self._model_prefix.extend(self.input_constraints())
self.sbox_mant = []
variables = []
self._variables_list = []
self._variables_declarations = self.input_declarations()
constraints = self.fix_variables_value_constraints(fixed_variables)
component_types = (CIPHER_OUTPUT, CONSTANT, INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, WORD_OPERATION)
operation_types = ("AND", "MODADD", "MODSUB", "NOT", "OR", "ROTATE", "SHIFT", "SHIFT_BY_VARIABLE_AMOUNT", "XOR")
Expand All @@ -77,13 +76,10 @@ def build_cipher_model(self, fixed_variables=[], second=False):
variables, constraints = component.cp_constraints(self.sbox_mant)

self._model_constraints.extend(constraints)
self._variables_list.extend(variables)
self._variables_declarations.extend(variables)

self._model_constraints.extend(self.final_constraints())

if not second:
self._model_constraints = self._model_prefix + self._model_constraints

def find_missing_bits(self, fixed_values=[], solver_name=SOLVER_DEFAULT, solver_external=True):
self.build_cipher_model(fixed_variables=fixed_values)
solution = self.solve(CIPHER, solver_name=solver_name, solve_external=solver_external)
Expand Down Expand Up @@ -119,9 +115,9 @@ def final_constraints(self):

return cp_constraints

def input_constraints(self):
def input_declarations(self):
"""
Return a list of CP constraints for the inputs of the cipher.
Return a list of CP variable declarations for the inputs of the cipher.

INPUT:

Expand All @@ -133,7 +129,7 @@ def input_constraints(self):
sage: from claasp.cipher_modules.models.cp.mzn_models.mzn_cipher_model import MznCipherModel
sage: speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4)
sage: cp = MznCipherModel(speck)
sage: cp.input_constraints()
sage: cp.input_declarations()
['array[0..31] of var 0..1: plaintext;',
...
'array[0..31] of var 0..1: cipher_output_3_12;']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def build_cipher_model(self, fixed_variables=[]):
sage: minizinc.build_cipher_model()
...
"""
self._variables_list = []
self._variables_declarations = []
variables = []
constraints = self.fix_variables_value_constraints_for_ARX(fixed_variables)
self._model_constraints = constraints
Expand All @@ -63,4 +63,4 @@ def build_cipher_model(self, fixed_variables=[]):
variables, constraints = component.minizinc_constraints(self)

self._model_constraints.extend(constraints)
self._variables_list.extend(variables)
self._variables_declarations.extend(variables)
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,22 @@ def build_deterministic_truncated_xor_differential_trail_model(
if number_of_rounds is None:
number_of_rounds = self._cipher.number_of_rounds

self._variables_list = []
self._variables_declarations = []
constraints = self.fix_variables_value_constraints(fixed_variables)
deterministic_truncated_xor_differential = constraints

for component in self._cipher.get_all_components():
if check_if_implemented_component(component):
variables, constraints = self.propagate_deterministically(component, wordwise)
self._variables_list.extend(variables)
self._variables_declarations.extend(variables)
deterministic_truncated_xor_differential.extend(constraints)

if not wordwise:
variables, constraints = self.input_deterministic_truncated_xor_differential_constraints()
else:
variables, constraints = self.input_wordwise_deterministic_truncated_xor_differential_constraints()
self._model_prefix.extend(variables)
self._variables_list.extend(constraints)
self._variables_declarations.extend(variables)
deterministic_truncated_xor_differential.extend(constraints)
if not wordwise:
deterministic_truncated_xor_differential.extend(
self.final_deterministic_truncated_xor_differential_constraints(minimize)
Expand All @@ -130,7 +130,7 @@ def build_deterministic_truncated_xor_differential_trail_model(
self.final_wordwise_deterministic_truncated_xor_differential_constraints(minimize)
)

self._model_constraints = self._model_prefix + deterministic_truncated_xor_differential
self._model_constraints = deterministic_truncated_xor_differential

def final_deterministic_truncated_xor_differential_constraints(self, minimize=False):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia
"""
variables = []
constraints = self.fix_variables_value_constraints_for_ARX(fixed_variables)
self._variables_list = []
self._variables_declarations = []
self._model_constraints = constraints

for component in self._cipher.get_all_components():
Expand All @@ -63,5 +63,5 @@ def build_deterministic_truncated_xor_differential_trail_model(self, fixed_varia
else:
print(f"{component.id} not yet implemented")

self._variables_list.extend(variables)
self._variables_declarations.extend(variables)
self._model_constraints.extend(constraints)
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,22 @@ def build_differential_linear_continuous_trail_model(self, fixed_values=[]):
self.init_input_declarations()

self._model_constraints.extend(self.connect_components())
self._variables_list.insert(0, get_continuous_operations())
self._variables_declarations.insert(0, get_continuous_operations())
self.add_linear_mask_variables()

def add_linear_mask_variables(self):
block_size = self._cipher.output_bit_size
output_mask = (
f"array[0..{block_size - 1}] of var 0..1: output_mask;"
)
self._variables_list.append(output_mask)
self._variables_declarations.append(output_mask)

def init_input_declarations(self):
input_declarations = [
f"array[0..{size - 1}] of var -1.0..1.0: {name};"
for name, size in zip(self._cipher.inputs, self._cipher.inputs_bit_size)
]
self._variables_list.extend(input_declarations)
self._variables_declarations.extend(input_declarations)

def connect_components(self):
constraints = []
Expand Down Expand Up @@ -126,11 +126,11 @@ def _build_linear_mask_correlation_constraints(self):
f"array[0..{block_size - 1}] of var lower..upper: active_bit_correlations = "
f"array1d(0..{block_size - 1}, [{active_bit_correlations_entries}]);"
)
self._variables_list.append(active_bit_correlations_decl)
self._variables_declarations.append(active_bit_correlations_decl)

def _build_difflin_corr_constraints(self):
self._variables_list.append("var lower..upper: differential_linear_correlation;")
self._variables_list.append("var float: correlation_log2_approximation;")
self._variables_declarations.append("var lower..upper: differential_linear_correlation;")
self._variables_declarations.append("var float: correlation_log2_approximation;")

self._model_constraints.append(
"constraint differential_linear_correlation = product(active_bit_correlations);"
Expand Down Expand Up @@ -274,7 +274,7 @@ def _format_continuous_value(self, val):

def solve_for_ARX(self, solver_name="scip", timeout_in_seconds_=30, processes_=4):
constraints = self._model_constraints
variables = self._variables_list
variables = self._variables_declarations
mzn_model_string = "\n".join(variables) + "\n".join(constraints)
solver_name_mzn = Solver.lookup(solver_name)
bit_mzn_model = Model()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _is_probability_array_declaration(var):
normalized = " ".join(var.strip().split())
return normalized.startswith("array[") and "] of var " in normalized and normalized.endswith(": p;")

has_declaration = any(_is_probability_array_declaration(var) for var in self._variables_list)
has_declaration = any(_is_probability_array_declaration(var) for var in self._variables_declarations)
if has_declaration:
return None

Expand Down Expand Up @@ -584,20 +584,18 @@ def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None):

weight_declarations, weight_constraints = self._build_weight_constraints(weight)

declarations = self._state_declarations() + self._variables_list
declarations = self._state_declarations() + self._variables_declarations
if probability_array_declaration is not None:
declarations.append(probability_array_declaration)
declarations.extend(weight_declarations)

self._variables_list = declarations
self._variables_declarations = declarations

self._model_constraints.extend(middle_bottom_constraints)
self._model_constraints.extend(branch_constraints)
self._model_constraints.extend(weight_constraints)
self._model_constraints.extend(self._build_output_block(weight))

self._model_constraints = self._model_prefix + self._model_constraints

def find_one_differential_linear_trail_with_fixed_weight(
self,
weight,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def build_hybrid_impossible_xor_differential_trail_model(
round_num, component_num = map(int, component.id.split("_")[-2:])
self.sboxes_component_number_list[round_num] += [component_num]

self._variables_list = []
self._variables_declarations = []
constraints = self.fix_variables_value_constraints(fixed_variables)
deterministic_truncated_xor_differential = constraints
self.middle_round = middle_round
Expand All @@ -138,20 +138,20 @@ def build_hybrid_impossible_xor_differential_trail_model(
backward_components, clean=False
)

self._variables_list.extend(direct_variables)
self._variables_declarations.extend(direct_variables)
deterministic_truncated_xor_differential.extend(direct_constraints)

inverse_variables, inverse_constraints = self.clean_inverse_impossible_variables_constraints(
backward_components, inverse_variables, inverse_constraints
)
self._variables_list.extend(inverse_variables)
self._variables_declarations.extend(inverse_variables)
deterministic_truncated_xor_differential.extend(inverse_constraints)

variables, constraints = self.input_constraints(
number_of_rounds=number_of_rounds, middle_round=middle_round, probabilistic=probabilistic
)
self._model_prefix.extend(variables)
self._variables_list.extend(constraints)
self._variables_declarations.extend(variables)
deterministic_truncated_xor_differential.extend(constraints)

deterministic_truncated_xor_differential.extend(
self.final_impossible_constraints(
Expand All @@ -165,7 +165,7 @@ def build_hybrid_impossible_xor_differential_trail_model(
)
set_of_constraints = deterministic_truncated_xor_differential

self._model_constraints = self._model_prefix + self.clean_constraints(
self._model_constraints = self.clean_constraints(
set_of_constraints, initial_round, middle_round, final_round
)

Expand Down Expand Up @@ -836,7 +836,7 @@ def solve(
if final_round is None:
final_round = self._cipher.number_of_rounds
command = self.get_command_for_solver_process(model_type, solver_name, processes_, timeout_in_seconds_)
model = "\n".join(self._variables_list + self._model_constraints) + "\n"
model = "\n".join(self._model_prefix + self._variables_declarations + self._model_constraints) + "\n"
solver_process = subprocess.run(command, input=model, capture_output=True, text=True)
if solver_process.returncode >= 0:
solutions = []
Expand Down
Loading
Loading