diff --git a/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py b/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py index 267c54217..bd025354d 100644 --- a/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py +++ b/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py @@ -1,3 +1,5 @@ +import math + def get_continuous_operations(): """ Returns the MiniZinc code required for continuous correlation propagation. @@ -55,4 +57,40 @@ def get_continuous_operations(): int: n = length(int_var); } in array1d(0..n-1, [if int_var[i] = 0 then -1.0 else 1.0 endif|i in 0..n-1]); - """ \ No newline at end of file + """ + + +def active_bit_correlation_expression(mask_expr, correlation_expr): + """Return the masked active-bit expression used in continuous correlation products.""" + return ( + f"if {mask_expr} = 0 then 1.0 " + f"else {mask_expr} * abs({correlation_expr}) endif" + ) + + +def piecewise_log2_approximation_expression(correlation_expr, scale=1.0, else_value="0.0"): + """Return the shared piecewise linear approximation of -log2(correlation).""" + scale_prefix = "" if math.isclose(scale, 1.0) else f"{scale} * " + return ( + f"{scale_prefix}(\n" + f"if {correlation_expr} <= 0.001021453702391378 then\n" + f"-19931.57001201849*{correlation_expr}+29.89737278555626\n" + f"elseif {correlation_expr} <= 0.004151650554233785 /\\ {correlation_expr} > 0.001021453702391378 then\n" + f"-584.962260272084*{correlation_expr}+10.13570866882117\n" + f"elseif {correlation_expr} <= 0.01359667098324998 /\\ {correlation_expr} > 0.004151650554233785 then\n" + f"-192.6450521799878*{correlation_expr}+8.506944714410169\n" + f"elseif {correlation_expr} <= 0.05399137458004444 /\\ {correlation_expr} > 0.01359667098324998 then\n" + f"-50.62607129324977*{correlation_expr}+6.575959357916722\n" + f"elseif {correlation_expr} <= 0.1420480516058986 /\\ {correlation_expr} > 0.05399137458004444 then\n" + f"-11.87410019056137*{correlation_expr}+4.483687170396419\n" + f"elseif {correlation_expr} <= 0.2463455066216964 /\\ {correlation_expr} > 0.1420480516058986 then\n" + f"-8.613130253286352*{correlation_expr}+4.020472744461092\n" + f"elseif {correlation_expr} <= 0.595815289564374 /\\ {correlation_expr} > 0.2463455066216964 then\n" + f"-3.761918786389538*{correlation_expr}+2.825398597919413\n" + f"elseif {correlation_expr} <= 0.998000001 /\\ {correlation_expr} > 0.595815289564374 then\n" + f"-1.444862453710759*{correlation_expr}+1.44486100812744\n" + f"else\n" + f"{else_value}\n" + f"endif\n" + ")" + ) \ No newline at end of file diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py index b6854372b..b58757e38 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py @@ -3,7 +3,11 @@ import time from minizinc import Instance, Model, Solver, Status from claasp.cipher_modules.models.cp.mzn_model import MznModel -from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import get_continuous_operations +from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import ( + active_bit_correlation_expression, + get_continuous_operations, + piecewise_log2_approximation_expression, +) from claasp.name_mappings import CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, PERMUTATION_COMPONENT, WORD_OPERATION class MznDifferentialLinearContinuousModel(MznModel): @@ -117,8 +121,7 @@ def _build_linear_mask_correlation_constraints(self): cipher_output_id = self._get_cipher_output_id() active_bit_correlations_entries = ", ".join([ - f"if output_mask[{i}] = 0 then 1.0 " - f"else output_mask[{i}] * abs({cipher_output_id}[{i}]) endif" + active_bit_correlation_expression(f"output_mask[{i}]", f"{cipher_output_id}[{i}]") for i in range(block_size) ]) @@ -141,28 +144,10 @@ def _build_difflin_corr_constraints(self): self._model_constraints.append( "constraint sum(array1d(output_mask)) >= 1;" ) - self._model_constraints.append(""" - constraint correlation_log2_approximation = - if differential_linear_correlation <= 0.001021453702391378 then - -19931.57001201849*differential_linear_correlation+29.89737278555626 - elseif differential_linear_correlation <= 0.004151650554233785 /\\ differential_linear_correlation > 0.001021453702391378 then - -584.962260272084*differential_linear_correlation+10.13570866882117 - elseif differential_linear_correlation <= 0.01359667098324998 /\\ differential_linear_correlation > 0.004151650554233785 then - -192.6450521799878*differential_linear_correlation+8.506944714410169 - elseif differential_linear_correlation <= 0.05399137458004444 /\\ differential_linear_correlation > 0.01359667098324998 then - -50.62607129324977*differential_linear_correlation+6.575959357916722 - elseif differential_linear_correlation <= 0.1420480516058986 /\\ differential_linear_correlation > 0.05399137458004444 then - -11.87410019056137*differential_linear_correlation+4.483687170396419 - elseif differential_linear_correlation <= 0.2463455066216964 /\\ differential_linear_correlation > 0.1420480516058986 then - -8.613130253286352*differential_linear_correlation+4.020472744461092 - elseif differential_linear_correlation <= 0.595815289564374 /\\ differential_linear_correlation > 0.2463455066216964 then - -3.761918786389538*differential_linear_correlation+2.825398597919413 - elseif differential_linear_correlation <= 0.998000001 /\\ differential_linear_correlation > 0.595815289564374 then - -1.444862453710759*differential_linear_correlation+1.44486100812744 - else - 1=1 - endif; - """) + self._model_constraints.append( + "constraint correlation_log2_approximation = " + f"{piecewise_log2_approximation_expression('differential_linear_correlation')};" + ) def find_lowest_continuous_correlation(self, fixed_values=[], solver_name="scip"): self.build_differential_linear_continuous_trail_model(fixed_values=fixed_values) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py index 677a7e3c1..c4bfcc653 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py @@ -19,17 +19,20 @@ import re import time as tm +from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import ( + active_bit_correlation_expression, + piecewise_log2_approximation_expression, +) from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT -from claasp.cipher_modules.models.utils import get_bit_bindings +from claasp.cipher_modules.models.utils import get_bit_bindings, integer_to_bit_list, set_fixed_variables from claasp.name_mappings import ( CIPHER_OUTPUT, CONSTANT, INPUT_KEY, INPUT_PLAINTEXT, INPUT_TWEAK, - INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, @@ -45,12 +48,17 @@ class MznDifferentialLinearModel(MznModel): - top: XOR differential - middle: deterministic/semi-deterministic truncated XOR differential - bottom: XOR linear + + If ``single_key`` is True, the linear-weight contribution only counts + bottom components that remain after removing the key schedule from the + cipher. """ _ALLOWED_MIDDLE_MODELS = { "cp_deterministic_truncated_xor_differential_constraints", "cp_semi_deterministic_truncated_xor_differential_constraints", "cp_deterministic_truncated_xor_differential_trail_constraints", + "cp_continuous_differential_propagation_constraints", } _ALLOWED_WORD_OPERATIONS = { @@ -71,6 +79,7 @@ def __init__( list_of_components, middle_part_model="cp_semi_deterministic_truncated_xor_differential_constraints", standard_differential_part=True, + single_key=False, ): super().__init__(cipher) self.standard_differential_part = standard_differential_part @@ -87,6 +96,8 @@ def __init__( } self.middle_part_model = middle_part_model + self.single_key = single_key + self._cached_weight_bottom_component_ids = None if self.middle_part_model not in self._ALLOWED_MIDDLE_MODELS: raise ValueError( f"middle_part_model should be one of {sorted(self._ALLOWED_MIDDLE_MODELS)}" @@ -108,6 +119,10 @@ def format_func(record): self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, format_func) self.raw_bit_bindings, self.raw_bit_bindings_for_intermediate_output = get_bit_bindings(cipher) + def _is_continuous_middle(self): + """Return True when the middle part uses continuous correlation propagation.""" + return self.middle_part_model == "cp_continuous_differential_propagation_constraints" + def _validate_component_partitioning(self): allowed_overlapping_ids = set(self._get_truncated_xor_differential_components_in_border()) overlap = (self.middle_part_component_ids & self.bottom_part_component_ids) - allowed_overlapping_ids @@ -126,6 +141,23 @@ def _validate_arx_only_cipher(self): def _get_component_by_id(self, component_id): return self._cipher.component_from_id(component_id) + def _weight_bottom_component_ids(self): + if not self.single_key: + return self.bottom_part_component_ids + + if self._cached_weight_bottom_component_ids is not None: + return self._cached_weight_bottom_component_ids + + try: + cipher_without_key_schedule = self._cipher.remove_key_schedule() + no_key_schedule_ids = set(cipher_without_key_schedule.get_all_components_ids()) + effective_bottom_component_ids = self.bottom_part_component_ids & no_key_schedule_ids + except Exception: + effective_bottom_component_ids = set(self.bottom_part_component_ids) + + self._cached_weight_bottom_component_ids = effective_bottom_component_ids + return self._cached_weight_bottom_component_ids + def _parse_linear_bit_id(self, bit_id): match = re.match(r"^(.*)_([io])\[(\d+)\]$", bit_id) if not match: @@ -159,7 +191,16 @@ def _state_declarations(self): continue if component.id in self.middle_part_component_ids: - domain = "0..2" + if self._is_continuous_middle(): + # Continuous components self-declare their variables + # (x1_, x2_, ) in cp_continuous_differential_propagation_constraints + continue + else: + domain = "0..2" + elif component.id in self.bottom_part_component_ids: + # Linear components self-declare their variables + # (_i, _o) in cp_xor_linear_mask_propagation_constraints + continue else: domain = "0..1" declarations.append( @@ -209,33 +250,172 @@ def _get_truncated_xor_differential_components_in_border(self): return list(set(border_components)) + def _continuous_middle_input_expr(self, component_id, input_bit_index): + component = self._get_component_by_id(component_id) + accumulated = 0 + for input_idx, bit_positions in enumerate(component.input_bit_positions, start=1): + input_size = len(bit_positions) + if input_bit_index < accumulated + input_size: + local_index = input_bit_index - accumulated + return f"x{input_idx}_{component_id}[{local_index}]" + accumulated += input_size + + raise ValueError( + f"Invalid continuous input bit index {input_bit_index} for component {component_id}" + ) + + def _continuous_middle_connecting_constraints(self, include_middle_sources=True): + """ + Build wiring constraints for the continuous middle part. + + Continuous component generators declare x1_, x2_, ... input arrays, + but they do not connect them to predecessor components. This method links + all incoming arcs to those x* arrays. + """ + constraints = [] + constraints.extend(self._continuous_middle_connecting_from_raw_bindings(include_middle_sources)) + constraints.extend(self._continuous_middle_connecting_for_intermediate_outputs(include_middle_sources)) + return constraints + + def _is_valid_continuous_middle_source(self, source_component_id, source_side, include_middle_sources): + is_input = source_component_id in self._cipher.inputs + if not is_input and source_side != "o": + return False + if not include_middle_sources and source_component_id in self.middle_part_component_ids: + return False + return True + + def _is_valid_continuous_middle_successor(self, successor_component_id, successor_side): + successor_is_output = successor_component_id in ( + self._cipher.outputs if hasattr(self._cipher, "outputs") else [] + ) + return (successor_side == "i" or successor_is_output) and ( + successor_component_id in self.middle_part_component_ids + ) + + @staticmethod + def _continuous_middle_source_expr(source_component_id, source_bit_index): + return f"{source_component_id}[{int(source_bit_index)}]" + + def _continuous_middle_assignment_constraint(self, source_component_id, source_bit_expr, successor_bit_expr): + if source_component_id in self.middle_part_component_ids: + return f"constraint {successor_bit_expr} = {source_bit_expr};" + return f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;" + + def _continuous_middle_constraints_for_successors(self, source_component_id, source_bit_expr, successor_bits): + constraints = [] + for successor_component_id, successor_bit_index, successor_side in successor_bits: + if not self._is_valid_continuous_middle_successor(successor_component_id, successor_side): + continue + + successor_bit_expr = self._continuous_middle_input_expr( + successor_component_id, int(successor_bit_index) + ) + constraints.append( + self._continuous_middle_assignment_constraint( + source_component_id, source_bit_expr, successor_bit_expr + ) + ) + return constraints + + def _continuous_middle_connecting_from_raw_bindings(self, include_middle_sources): + constraints = [] + for output_bit_id, successor_bits in self.raw_bit_bindings.items(): + source_component_id, source_bit_index, source_side = output_bit_id + if not self._is_valid_continuous_middle_source( + source_component_id, source_side, include_middle_sources + ): + continue + + source_bit_expr = self._continuous_middle_source_expr(source_component_id, source_bit_index) + constraints.extend( + self._continuous_middle_constraints_for_successors( + source_component_id, source_bit_expr, successor_bits + ) + ) + return constraints + + def _continuous_middle_intermediate_source_pin(self, pins): + return next((pin for pin in pins if pin[2] == "o" or pin[0] in self._cipher.inputs), None) + + def _should_skip_continuous_middle_intermediate_source(self, source_id, include_middle_sources): + return (not include_middle_sources) and (source_id in self.middle_part_component_ids) + + def _continuous_middle_constraint_from_intermediate_pin( + self, inter_id, inter_bit, source_pin, include_middle_sources + ): + if not source_pin: + return None + + source_id, source_bit, _ = source_pin + if self._should_skip_continuous_middle_intermediate_source(source_id, include_middle_sources): + return None + + source_expr = self._continuous_middle_source_expr(source_id, source_bit) + successor_expr = self._continuous_middle_input_expr(inter_id, int(inter_bit)) + return self._continuous_middle_assignment_constraint(source_id, source_expr, successor_expr) + + def _continuous_middle_constraints_for_intermediate_bit_dict(self, bit_dict, include_middle_sources): + constraints = [] + for inter_bit_tuple, pins in bit_dict.items(): + inter_id, inter_bit, _ = inter_bit_tuple + source_pin = self._continuous_middle_intermediate_source_pin(pins) + constraint = self._continuous_middle_constraint_from_intermediate_pin( + inter_id, inter_bit, source_pin, include_middle_sources + ) + if constraint: + constraints.append(constraint) + return constraints + + def _continuous_middle_connecting_for_intermediate_outputs(self, include_middle_sources): + constraints = [] + for comp_id, bit_dict in self.raw_bit_bindings_for_intermediate_output.items(): + if comp_id not in self.middle_part_component_ids: + continue + constraints.extend( + self._continuous_middle_constraints_for_intermediate_bit_dict( + bit_dict, include_middle_sources + ) + ) + return constraints + def _top_to_middle_connecting_constraints(self): + if self._is_continuous_middle(): + return self._continuous_middle_connecting_constraints(include_middle_sources=False) + constraints = [] border_components = set(self._get_regular_xor_differential_components_in_border()) border_components.update(set(self._cipher.inputs)) for output_bit_id, successor_bits in self.raw_bit_bindings.items(): source_component_id, source_bit_index, source_side = output_bit_id - if source_side != "o" or source_component_id not in border_components: + is_input = source_component_id in self._cipher.inputs + if (not is_input and source_side != "o") or source_component_id not in border_components: continue source_bit_expr = f"{source_component_id}[{int(source_bit_index)}]" - for successor_bit in successor_bits: - successor_component_id, successor_bit_index, successor_side = successor_bit - if ( - successor_side != "i" - or successor_component_id not in self.middle_part_component_ids - ): - continue + constraints.extend(self._top_to_middle_successor_constraints(successor_bits, source_bit_expr)) - successor_bit_expr = f"{successor_component_id}[{int(successor_bit_index)}]" - constraints.append( - f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else 0 endif;" - ) + return constraints + def _top_to_middle_successor_constraints(self, successor_bits, source_bit_expr): + constraints = [] + for successor_bit in successor_bits: + successor_component_id, successor_bit_index, successor_side = successor_bit + if successor_side != "i" or successor_component_id not in self.middle_part_component_ids: + continue + + successor_bit_expr = f"{successor_component_id}[{int(successor_bit_index)}]" + if self._is_continuous_middle(): + constraints.append(f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;") + else: + constraints.append(f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else 0 endif;") return constraints def _middle_to_bottom_connecting_constraints(self): + if self._is_continuous_middle(): + return self._continuous_middle_to_bottom_connecting_constraints() + constraints = [] truncated_border_components = set(self._get_truncated_xor_differential_components_in_border()) # ensure that at least one bit difference exist in the concatenation of the output of the truncated border components @@ -257,6 +437,95 @@ def _middle_to_bottom_connecting_constraints(self): return constraints + def _continuous_middle_to_bottom_connecting_constraints(self): + """ + Connection from differential-linear (continuous) part to linear part. + + Implements the semantics from [BGGMP2023]:: + + combined[i] = if linear_mask[i] == 0 then 1 + else linear_mask[i] * abs(continuous_correlation[i]) + + differential_linear_correlation = product(combined) + + When the linear mask bit is 0 (inactive), the contribution is 1 (neutral + in the product). When it is 1 (active), the contribution is the absolute + value of the continuous correlation at that position. + """ + constraints = [] + border_components = set(self._get_truncated_xor_differential_components_in_border()) + border_components = self._filter_border_components_for_single_key(border_components) + + border_sources = self._collect_border_sources(border_components) + if not border_sources: + return constraints + + ordered_border_sources = self._order_border_sources(border_sources) + n = len(ordered_border_sources) + + self._variables_list.append(f"array[0..{n - 1}] of var 0..1: linear_border_mask;") + self._add_linear_border_mask_constraints(constraints, ordered_border_sources) + + self._variables_list.append(f"array[0..{n - 1}] of var -1.0..1.0: linear_mask_times_diff_lin_output;") + for idx, (cont_bit, _) in enumerate(ordered_border_sources): + constraints.append( + f"constraint linear_mask_times_diff_lin_output[{idx}] = " + f"{active_bit_correlation_expression(f'linear_border_mask[{idx}]', cont_bit)};" + ) + + self._variables_list.append("var lower..upper: differential_linear_correlation;") + constraints.append("constraint differential_linear_correlation = product(linear_mask_times_diff_lin_output);") + constraints.append("constraint differential_linear_correlation != 0.0;") + + return constraints + + def _filter_border_components_for_single_key(self, border_components): + if not self.single_key: + return border_components + last_middle_round = max( + self._cipher.get_round_from_component_id(cid) for cid in self.middle_part_component_ids + ) + state_input_ids = set() + for comp in self._cipher.get_components_in_round(last_middle_round): + if comp.id in self.middle_part_component_ids and comp.type == "intermediate_output": + if comp.output_bit_size > 16: + state_input_ids.update(comp.input_id_links) + if state_input_ids: + return border_components & state_input_ids + return border_components + + def _collect_border_sources(self, border_components): + border_sources = {} + for output_bit_id, successor_bits in self.bit_bindings.items(): + source_component_id, _, _ = self._parse_linear_bit_id(output_bit_id) + if source_component_id not in border_components: + continue + + source_bit_expr = output_bit_id.replace("_o[", "[") + for successor_bit in successor_bits: + successor_component_id, _, _ = self._parse_linear_bit_id(successor_bit) + if successor_component_id in self.bottom_part_component_ids: + border_sources.setdefault(source_bit_expr, set()).add(successor_bit) + return border_sources + + def _order_border_sources(self, border_sources): + def _sort_bit_expr(bit_expr): + component_id, _, bit_index = self._parse_linear_bit_id(bit_expr) + return component_id, bit_index + + ordered_border_sources = [] + for source_bit_expr in sorted(border_sources, key=_sort_bit_expr): + ordered_successors = sorted(border_sources[source_bit_expr], key=_sort_bit_expr) + ordered_border_sources.append((source_bit_expr, ordered_successors)) + return ordered_border_sources + + def _add_linear_border_mask_constraints(self, constraints, ordered_border_sources): + for idx, (_, successors) in enumerate(ordered_border_sources): + if len(successors) == 1: + constraints.append(f"constraint linear_border_mask[{idx}] = {successors[0]};") + else: + constraints.append(f"constraint linear_border_mask[{idx}] = ({' + '.join(successors)}) mod 2;") + def _branch_xor_linear_constraints_for_bottom_part(self): constraints = [] @@ -273,7 +542,10 @@ def _branch_xor_linear_constraints_for_bottom_part(self): return constraints def _build_weight_constraints(self, weight): - declarations = ["var int: weight;"] + if self._is_continuous_middle(): + declarations = ["var float: weight;"] + else: + declarations = ["var int: weight;"] def _sum_component_probability(component_ids): terms = [] @@ -288,18 +560,34 @@ def _sum_component_probability(component_ids): # Keep model-assignment semantics consistent with _component_model_entries: # if a component appears in both middle and bottom, it is modeled as bottom. effective_middle_component_ids = self.middle_part_component_ids - self.bottom_part_component_ids + effective_bottom_component_ids = self._weight_bottom_component_ids() top_probability_sum = _sum_component_probability(self.top_part_component_ids) - middle_probability_sum = _sum_component_probability(effective_middle_component_ids) - bottom_probability_sum = _sum_component_probability(self.bottom_part_component_ids) - # Approximate DL objective in log-domain: - # weight ≈ p + r + 2q - # where p = top, r = middle, q = bottom. - # The factor 2*bottom reflects that the exact term is 2^(2q), - # and the middle term (2·2^r - 1) is approximated by r. - constraints = [ - f"constraint weight = {top_probability_sum} + {middle_probability_sum} + 2*{bottom_probability_sum};" - ] + bottom_probability_sum = _sum_component_probability(effective_bottom_component_ids) + + constraints = [] + + if self._is_continuous_middle(): + # Continuous middle correlation is converted via a piece-wise linear approximation + # of log2(abs(correlation)) scaled for integer weights. + declarations.append("var float: correlation_log2_approximation;") + constraints.append( + "constraint correlation_log2_approximation = " + f"{piecewise_log2_approximation_expression('differential_linear_correlation', scale=100.0)};" + ) + constraints.append( + f"constraint weight = {top_probability_sum} + correlation_log2_approximation + 2*{bottom_probability_sum};" + ) + else: + middle_probability_sum = _sum_component_probability(effective_middle_component_ids) + # Approximate DL objective in log-domain: + # weight ≈ p + r + 2q + # where p = top, r = middle, q = bottom. + # The factor 2*bottom reflects that the exact term is 2^(2q), + # and the middle term (2·2^r - 1) is approximated by r. + constraints = [ + f"constraint weight = {top_probability_sum} + {middle_probability_sum} + 2*{bottom_probability_sum};" + ] if weight != -1: constraints.append(f"constraint weight <= {100 * weight};") @@ -370,7 +658,13 @@ def _build_output_block(self, weight): output_probability = self._component_probability_output(component, probability_output) output = self._append_probability_output(output, output_probability) - output += '"Trail weight = " ++ show(weight)];' + output += '"Trail weight = " ++ show(weight)' + + # Include the DL correlation in the output for continuous models + if self._is_continuous_middle(): + output += ' ++ "\\n" ++ "differential_linear_correlation = " ++ show(differential_linear_correlation)' + + output += '];' constraints.append(output) constraints.extend(self.mzn_output_directives) @@ -470,6 +764,7 @@ def _collect_differential_linear_component_weights(self, components_values): q_weight = 0.0 seen_middle = set() seen_bottom = set() + effective_bottom_component_ids = self._weight_bottom_component_ids() for component_id, component_solution in components_values.items(): if not isinstance(component_solution, dict): @@ -486,21 +781,24 @@ def _collect_differential_linear_component_weights(self, components_values): if base_component_id not in seen_middle: middle_sum += weight seen_middle.add(base_component_id) - continue - - if base_component_id in self.bottom_part_component_ids and base_component_id not in seen_bottom: - q_weight += weight - seen_bottom.add(base_component_id) + elif base_component_id in effective_bottom_component_ids: + if base_component_id not in seen_bottom: + q_weight += weight + seen_bottom.add(base_component_id) return p_weight, middle_sum, q_weight, bool(seen_middle) - @staticmethod - def _middle_weight_term(middle_sum, has_middle_components): + def _middle_weight_term(self, middle_sum, has_middle_components): if not has_middle_components: return 0.0 - if middle_sum == 1: - raise ValueError("Unexpected probability weight 1 in middle part") - return abs(math.log(abs(2 * (2**(-1*middle_sum)) - 1), 2)) + if self._is_continuous_middle(): + import math + return math.log(2 * (2**middle_sum) - 1, 2) + else: + import math + if middle_sum == 1: + raise ValueError("Unexpected probability weight 1 in middle part") + return abs(math.log(abs(2 * (2**(-1*middle_sum)) - 1), 2)) def _differential_linear_total_weight_from_components(self, components_values): p_weight, middle_sum, q_weight, _ = ( @@ -521,10 +819,15 @@ def _set_differential_linear_total_weight(self, solution): def _parse_solver_output( self, output_to_parse, model_type, truncated=False, solve_external=False, solver_name=SOLVER_DEFAULT ): + continuous_xor_differential_linear_model = self._is_continuous_middle() and model_type in ( + XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, + XOR_DIFFERENTIAL_LINEAR_OPTIMAL_SOLUTION, + ) + parsed = super()._parse_solver_output( output_to_parse, model_type, - truncated=truncated, + truncated=truncated or continuous_xor_differential_linear_model, solve_external=solve_external, solver_name=solver_name, ) @@ -535,6 +838,9 @@ def _parse_solver_output( ): return parsed + return self._process_differential_linear_parsed_output(parsed, solve_external, continuous_xor_differential_linear_model) + + def _process_differential_linear_parsed_output(self, parsed, solve_external, continuous_xor_differential_linear_model): if not solve_external: if isinstance(parsed, list): for solution in parsed: @@ -544,19 +850,20 @@ def _parse_solver_output( self._set_differential_linear_total_weight(parsed) return parsed - if solve_external: + if continuous_xor_differential_linear_model: + solver_time, memory, components_values = parsed + else: solver_time, memory, components_values, _ = parsed - total_weight = [] - solution_keys = sorted( - components_values.keys(), - key=lambda key: int(key.replace("solution", "")) if key.startswith("solution") else 0, - ) - for solution_key in solution_keys: - solution_components_values = components_values.get(solution_key, {}) - total_weight.append(str(self._differential_linear_total_weight_from_components(solution_components_values))) - return solver_time, memory, components_values, total_weight - - return parsed + + total_weight = [] + solution_keys = sorted( + components_values.keys(), + key=lambda key: int(key.replace("solution", "")) if key.startswith("solution") else 0, + ) + for solution_key in solution_keys: + solution_components_values = components_values.get(solution_key, {}) + total_weight.append(str(self._differential_linear_total_weight_from_components(solution_components_values))) + return solver_time, memory, components_values, total_weight def _ensure_components_values(self, solution): if not isinstance(solution, dict): @@ -565,18 +872,48 @@ def _ensure_components_values(self, solution): return self._normalize_middle_part_components_values(solution) + def get_fixed_variable_from_hex(self, component_id, hex_value, endianness="big"): + if component_id in self._cipher.inputs: + idx = list(self._cipher.inputs).index(component_id) + bit_size = self._cipher.inputs_bit_size[idx] + else: + bit_size = self._cipher.get_component_from_id(component_id).output_bit_size + + if isinstance(hex_value, str) and hex_value.startswith("0x"): + int_val = int(hex_value, 16) + else: + int_val = int(hex_value) + + return set_fixed_variables( + component_id=component_id, + constraint_type="equal", + bit_positions=list(range(bit_size)), + bit_values=integer_to_bit_list(int_val, bit_size, endianness), + ) - def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None): + def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None, optimization_objective=None): if fixed_variables is None: fixed_variables = [] self.initialise_model() + + # Include continuous predicates (continuous_modadd, continuous_xor, etc.) + if self._is_continuous_middle(): + from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import ( + get_continuous_operations, + ) + self._model_prefix.append(get_continuous_operations()) + model_entries = self._component_model_entries() fixed_constraints = self._partition_fixed_value_constraints(fixed_variables) self.build_generic_cp_model_from_dictionary(model_entries) self._model_constraints = fixed_constraints + self._model_constraints + continuous_middle_constraints = [] + if self._is_continuous_middle(): + continuous_middle_constraints = self._continuous_middle_connecting_constraints() + probability_array_declaration = self._probability_array_declaration_from_component_map() middle_bottom_constraints = self._middle_to_bottom_connecting_constraints() @@ -591,10 +928,15 @@ def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None): self._variables_declarations = declarations + self._model_constraints.extend(continuous_middle_constraints) self._model_constraints.extend(middle_bottom_constraints) self._model_constraints.extend(branch_constraints) self._model_constraints.extend(weight_constraints) - self._model_constraints.extend(self._build_output_block(weight)) + output_block = self._build_output_block(weight) + if optimization_objective: + output_block[0] = optimization_objective + + self._model_constraints.extend(output_block) def find_one_differential_linear_trail_with_fixed_weight( self, diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py index 352bae5c3..103a4d15e 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py @@ -1,9 +1,10 @@ import itertools import math +import subprocess import pytest from claasp.cipher_modules.models.cp.mzn_models.mzn_differential_linear_model import MznDifferentialLinearModel -from claasp.cipher_modules.models.cp.solvers import CPSAT +from claasp.cipher_modules.models.cp.solvers import CPSAT, SCIP from claasp.cipher_modules.models.utils import ( differential_linear_checker_for_block_cipher_single_key, integer_to_bit_list, @@ -14,9 +15,10 @@ from claasp.ciphers.block_ciphers.ballet_block_cipher import BalletBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation, ROUND_MODE_HALF -from claasp.name_mappings import INPUT_PLAINTEXT, SATISFIABLE, INPUT_KEY, INPUT_MESSAGE +from claasp.name_mappings import INPUT_PLAINTEXT, SATISFIABLE, INPUT_KEY, INPUT_MESSAGE, XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION from claasp.ciphers.mac.siphash_mac import SiphashMAC + def _split_components(cipher, top_rounds_end, middle_rounds_end): top_part_components = [] middle_part_components = [] @@ -40,6 +42,18 @@ def _split_components(cipher, top_rounds_end, middle_rounds_end): "bottom_part_components": bottom_part_components, } +def _fixed_from_hex_test(cipher, component_id, hex_value): + if component_id in cipher.inputs: + idx = list(cipher.inputs).index(component_id) + bit_size = cipher.inputs_bit_size[idx] + else: + bit_size = cipher.get_component_from_id(component_id).output_bit_size + return set_fixed_variables( + component_id=component_id, + constraint_type="equal", + bit_positions=list(range(bit_size)), + bit_values=integer_to_bit_list(int(hex_value, 16), bit_size, "big"), + ) @pytest.mark.parametrize( "cipher_cls,cipher_kwargs,top_rounds_end,middle_rounds_end", @@ -87,6 +101,29 @@ def test_parse_linear_bit_id_handles_valid_and_invalid_formats(): model._parse_linear_bit_id("invalid_bit_id") +def test_single_key_weight_bottom_component_ids_exclude_key_schedule_for_speck(): + speck = SpeckBlockCipher(number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + + effective_bottom_ids = model._weight_bottom_component_ids() + + assert effective_bottom_ids.issubset(model.bottom_part_component_ids) + # Speck key-schedule branch components must not contribute to the single-key linear weight. + assert "modadd_6_2" not in effective_bottom_ids + assert "modadd_7_2" not in effective_bottom_ids + assert "modadd_8_2" not in effective_bottom_ids + # Data-path branch components must still contribute. + assert "modadd_6_7" in effective_bottom_ids + assert "modadd_7_7" in effective_bottom_ids + assert "modadd_8_7" in effective_bottom_ids + + def test_normalize_middle_part_components_values_hex_and_unknown_bits(): speck = SpeckBlockCipher(number_of_rounds=6) component_model_list = _split_components(speck, top_rounds_end=2, middle_rounds_end=3) @@ -813,3 +850,246 @@ def test_optimal_semi_deterministic_differential_linear_trail_siphash(): print("Status:", trail["status"]) print("Total weight:", trail["total_weight"]) assert trail["status"] == SATISFIABLE +# ────────────────────────────────────────────────────────────────────── +# Tests for continuous middle model (cp_continuous_differential_propagation_constraints) +# ────────────────────────────────────────────────────────────────────── + +def test_differential_linear_trail_continuous_middle_speck_build(): + """ + Speck32/64, 3 rounds: 1 top (diff) + 1 middle (continuous) + 1 bottom (linear). + + Input difference from Table 4 of [BGGMP2023]: (0x0010, 0x5000). + Verifies the model builds correctly and can be solved with SCIP. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=3) + component_list = _split_components(speck, top_rounds_end=1, middle_rounds_end=2) + model = MznDifferentialLinearModel( + speck, + component_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + ) + + plaintext = set_fixed_variables( + "plaintext", "equal", list(range(32)), + integer_to_bit_list(0x00105000, 32, "big"), + ) + key = set_fixed_variables( + "key", "equal", list(range(64)), + integer_to_bit_list(0, 64, "big"), + ) + + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=[plaintext, key], + ) + + # Verify model structure + full_model = "\n".join(model._variables_list + model._model_constraints) + + # Check continuous predicates are present + assert "continuous_modadd" in full_model + assert "continuous_xor" in full_model + + # Check DL-to-linear connection + assert "linear_mask_times_diff_lin_output" in full_model + assert "differential_linear_correlation" in full_model + assert "product(linear_mask_times_diff_lin_output)" in full_model + + # Check weight constraint does NOT include middle probability for continuous + assert "weight = " in full_model + + # Check state declarations don't have var 0..2 (no truncated middle) + declarations = model._state_declarations() + decl_text = "\n".join(declarations) + assert "var 0..2" not in decl_text + +def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_model_float_search_scip(): + """ + Full integrated 9-round Speck32/64 differential-linear model using native MiniZinc SCIP. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + TABLE_4_COMBINED_DIFF_FIXED = { + "plaintext": "0xA8400010", + "key": "0x0000000000000000", + "intermediate_output_0_6": "0x81408100", + "intermediate_output_1_12": "0x00020400", + "intermediate_output_2_12": "0x00001000", + "intermediate_output_3_12": "0x10005000", + "xor_1_5": "0x0000", + "xor_2_5": "0x0000", + "xor_3_5": "0x0000", + } + TABLE_4_LIN_FIXED = { + "intermediate_output_6_12": "0x00000020", + "intermediate_output_7_12": "0x00800080", + "cipher_output_8_12": "0x02050204", + } + fixed_values = [ + model.get_fixed_variable_from_hex(comp_id, hex_val) + for comp_id, hex_val in {**TABLE_4_COMBINED_DIFF_FIXED, **TABLE_4_LIN_FIXED}.items() + ] + + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=fixed_values, + optimization_objective=( + "solve :: float_search(linear_mask_times_diff_lin_output, 1e-12, largest, indomain_split, complete) " + "minimize correlation_log2_approximation;" + ) + ) + + full_model = "\n".join(model._variables_list + model._model_constraints) + assert "minimize correlation_log2_approximation;" in full_model + + command = ["minizinc", "--input-from-stdin", "--solver-statistics", "--solver", "scip"] + solver_process = subprocess.run(command, input=full_model, capture_output=True, text=True, check=False) + + if solver_process.returncode != 0 and "Unknown solver" in solver_process.stderr: + pytest.skip("Native MiniZinc SCIP solver is not available in this environment") + + assert solver_process.returncode == 0 + solver_output = solver_process.stdout.splitlines() + assert all("--solver fscip" not in line for line in solver_output) + assert "UNSATISFIABLE" not in solver_process.stdout + assert "----------" in solver_process.stdout + + _, _, components_values, _ = model._parse_solver_output( + solver_output, XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, solve_external=True, solver_name=SCIP + ) + + solution = components_values["solution1"] + top_sum, bottom_sum = 0.0, 0.0 + seen_top, seen_bottom = set(), set() + + for component_id, component_solution in solution.items(): + if not isinstance(component_solution, dict): + continue + weight = float(component_solution.get("weight", 0)) + base_component_id = model._base_component_id(component_id) + if base_component_id in model.top_part_component_ids and base_component_id not in seen_top: + top_sum += weight + seen_top.add(base_component_id) + if base_component_id in model._weight_bottom_component_ids() and base_component_id not in seen_bottom: + bottom_sum += weight + seen_bottom.add(base_component_id) + + assert top_sum == 7 + assert bottom_sum == 2 + + correlation_values = [ + float(line.split("=", 1)[1].strip()) + for line in solver_output if line.startswith("differential_linear_correlation =") + ] + assert correlation_values + correlation_value = max(correlation_values) + + expected_corr = 0.745481409287476 + expected_total_probability = -11.42375571965992 + total_probability = math.log2(correlation_value) - top_sum - bottom_sum*2 + + assert min(abs(value - expected_corr) for value in correlation_values) < 1e-6 + assert abs(total_probability - expected_total_probability) < 1e-6 + + +def test_differential_linear_trail_continuous_middle_9_rounds_speck_table6_full_model_float_search_scip(): + """ + Full integrated 9-round Speck32/64 differential-linear model for Table 6 using native MiniZinc SCIP. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + TABLE_6_COMBINED_DIFF_FIXED = { + "plaintext": "0x20540800", + "key": "0x0000000000000000", + "intermediate_output_0_6": "0xA0408040", + "intermediate_output_1_12": "0x01000002", + "intermediate_output_2_12": "0x00000008", + "intermediate_output_3_12": "0x00080028", + "xor_1_5": "0x0000", + "xor_2_5": "0x0000", + "xor_3_5": "0x0000", + } + TABLE_6_LIN_FIXED = { + "intermediate_output_6_12": "0x00200020", + "intermediate_output_7_12": "0x00814081", + "cipher_output_8_12": "0x0C000E01", + } + fixed_values = [ + model.get_fixed_variable_from_hex(comp_id, hex_val) + for comp_id, hex_val in {**TABLE_6_COMBINED_DIFF_FIXED, **TABLE_6_LIN_FIXED}.items() + ] + + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=fixed_values, + optimization_objective=( + "solve :: float_search(linear_mask_times_diff_lin_output, 1e-12, largest, indomain_split, complete) " + "minimize correlation_log2_approximation;" + ) + ) + + full_model = "\n".join(model._variables_list + model._model_constraints) + assert "minimize correlation_log2_approximation;" in full_model + + command = ["minizinc", "--input-from-stdin", "--solver-statistics", "--solver", "scip"] + solver_process = subprocess.run(command, input=full_model, capture_output=True, text=True, check=False) + + if solver_process.returncode != 0 and "Unknown solver" in solver_process.stderr: + pytest.skip("Native MiniZinc SCIP solver is not available in this environment") + + assert solver_process.returncode == 0 + solver_output = solver_process.stdout.splitlines() + assert all("--solver fscip" not in line for line in solver_output) + assert "UNSATISFIABLE" not in solver_process.stdout + assert "----------" in solver_process.stdout + + _, _, components_values, _ = model._parse_solver_output( + solver_output, XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, solve_external=True, solver_name=SCIP + ) + + solution = components_values["solution1"] + top_sum, bottom_sum = 0.0, 0.0 + seen_top, seen_bottom = set(), set() + + for component_id, component_solution in solution.items(): + if not isinstance(component_solution, dict): + continue + weight = float(component_solution.get("weight", 0)) + base_component_id = model._base_component_id(component_id) + if base_component_id in model.top_part_component_ids and base_component_id not in seen_top: + top_sum += weight + seen_top.add(base_component_id) + if base_component_id in model._weight_bottom_component_ids() and base_component_id not in seen_bottom: + bottom_sum += weight + seen_bottom.add(base_component_id) + + assert top_sum == 7 + assert bottom_sum == 3 + + correlation_values = [ + float(line.split("=", 1)[1].strip()) + for line in solver_output if line.startswith("differential_linear_correlation =") + ] + assert correlation_values + correlation_value = max(correlation_values) + + expected_corr = 0.7759094272693416 + expected_total_probability = -13.36603983996931 + total_probability = math.log2(correlation_value) - top_sum - bottom_sum*2 + + assert min(abs(value - expected_corr) for value in correlation_values) < 1e-6 + assert abs(total_probability - expected_total_probability) < 1e-6 \ No newline at end of file