From 1893c54ee7a2873c75f797beebe4ec3491a559f8 Mon Sep 17 00:00:00 2001 From: Rony Date: Sat, 18 Apr 2026 10:59:50 -0500 Subject: [PATCH 1/5] FEATURE/Add: continuous diff-lin connecting constraints and Speck 9-round test - Translated connection constraints from middle to bottom to use a canonical meaning per border bit, preventing mathematical errors from fan-out logic. - Adapted the integrated model to properly reflect standalone continuous boundary semantics. - Updated the differential linear model test to use float_search and explicitly minimize the correlation log2 approximation. - Improved solver output parsing to iterate over all solutions and accurately fetch the optimal correlation, avoiding premature factible states. --- .../mzn_continuous_predicates.py | 38 +- ...zn_differential_linear_continuous_model.py | 35 +- .../mzn_differential_linear_model.py | 359 ++++++++++++-- .../mzn_differential_linear_model_test.py | 468 +++++++++++++++++- 4 files changed, 840 insertions(+), 60 deletions(-) diff --git a/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py b/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py index 267c54217..a810a6819 100644 --- a/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py +++ b/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py @@ -55,4 +55,40 @@ def get_continuous_operations(): int: n = length(int_var); } in array1d(0..n-1, [if int_var[i] = 0 then -1.0 else 1.0 endif|i in 0..n-1]); - """ \ No newline at end of file + """ + + +def active_bit_correlation_expression(mask_expr, correlation_expr): + """Return the masked active-bit expression used in continuous correlation products.""" + return ( + f"if {mask_expr} = 0 then 1.0 " + f"else {mask_expr} * abs({correlation_expr}) endif" + ) + + +def piecewise_log2_approximation_expression(correlation_expr, scale=1.0, else_value="0.0"): + """Return the shared piecewise linear approximation of -log2(correlation).""" + scale_prefix = "" if scale == 1.0 else f"{scale} * " + return ( + f"{scale_prefix}(\n" + f"if {correlation_expr} <= 0.001021453702391378 then\n" + f"-19931.57001201849*{correlation_expr}+29.89737278555626\n" + f"elseif {correlation_expr} <= 0.004151650554233785 /\\ {correlation_expr} > 0.001021453702391378 then\n" + f"-584.962260272084*{correlation_expr}+10.13570866882117\n" + f"elseif {correlation_expr} <= 0.01359667098324998 /\\ {correlation_expr} > 0.004151650554233785 then\n" + f"-192.6450521799878*{correlation_expr}+8.506944714410169\n" + f"elseif {correlation_expr} <= 0.05399137458004444 /\\ {correlation_expr} > 0.01359667098324998 then\n" + f"-50.62607129324977*{correlation_expr}+6.575959357916722\n" + f"elseif {correlation_expr} <= 0.1420480516058986 /\\ {correlation_expr} > 0.05399137458004444 then\n" + f"-11.87410019056137*{correlation_expr}+4.483687170396419\n" + f"elseif {correlation_expr} <= 0.2463455066216964 /\\ {correlation_expr} > 0.1420480516058986 then\n" + f"-8.613130253286352*{correlation_expr}+4.020472744461092\n" + f"elseif {correlation_expr} <= 0.595815289564374 /\\ {correlation_expr} > 0.2463455066216964 then\n" + f"-3.761918786389538*{correlation_expr}+2.825398597919413\n" + f"elseif {correlation_expr} <= 0.998000001 /\\ {correlation_expr} > 0.595815289564374 then\n" + f"-1.444862453710759*{correlation_expr}+1.44486100812744\n" + f"else\n" + f"{else_value}\n" + f"endif\n" + ")" + ) \ No newline at end of file diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py index ad7b52fc2..5a94227d8 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_continuous_model.py @@ -3,7 +3,11 @@ import time from minizinc import Instance, Model, Solver, Status from claasp.cipher_modules.models.cp.mzn_model import MznModel -from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import get_continuous_operations +from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import ( + active_bit_correlation_expression, + get_continuous_operations, + piecewise_log2_approximation_expression, +) from claasp.name_mappings import CONSTANT, INTERMEDIATE_OUTPUT, CIPHER_OUTPUT, WORD_OPERATION class MznDifferentialLinearContinuousModel(MznModel): @@ -117,8 +121,7 @@ def _build_linear_mask_correlation_constraints(self): cipher_output_id = self._get_cipher_output_id() active_bit_correlations_entries = ", ".join([ - f"if output_mask[{i}] = 0 then 1.0 " - f"else output_mask[{i}] * abs({cipher_output_id}[{i}]) endif" + active_bit_correlation_expression(f"output_mask[{i}]", f"{cipher_output_id}[{i}]") for i in range(block_size) ]) @@ -141,28 +144,10 @@ def _build_difflin_corr_constraints(self): self._model_constraints.append( "constraint sum(array1d(output_mask)) >= 1;" ) - self._model_constraints.append(""" - constraint correlation_log2_approximation = - if differential_linear_correlation <= 0.001021453702391378 then - -19931.57001201849*differential_linear_correlation+29.89737278555626 - elseif differential_linear_correlation <= 0.004151650554233785 /\\ differential_linear_correlation > 0.001021453702391378 then - -584.962260272084*differential_linear_correlation+10.13570866882117 - elseif differential_linear_correlation <= 0.01359667098324998 /\\ differential_linear_correlation > 0.004151650554233785 then - -192.6450521799878*differential_linear_correlation+8.506944714410169 - elseif differential_linear_correlation <= 0.05399137458004444 /\\ differential_linear_correlation > 0.01359667098324998 then - -50.62607129324977*differential_linear_correlation+6.575959357916722 - elseif differential_linear_correlation <= 0.1420480516058986 /\\ differential_linear_correlation > 0.05399137458004444 then - -11.87410019056137*differential_linear_correlation+4.483687170396419 - elseif differential_linear_correlation <= 0.2463455066216964 /\\ differential_linear_correlation > 0.1420480516058986 then - -8.613130253286352*differential_linear_correlation+4.020472744461092 - elseif differential_linear_correlation <= 0.595815289564374 /\\ differential_linear_correlation > 0.2463455066216964 then - -3.761918786389538*differential_linear_correlation+2.825398597919413 - elseif differential_linear_correlation <= 0.998000001 /\\ differential_linear_correlation > 0.595815289564374 then - -1.444862453710759*differential_linear_correlation+1.44486100812744 - else - 1=1 - endif; - """) + self._model_constraints.append( + "constraint correlation_log2_approximation = " + f"{piecewise_log2_approximation_expression('differential_linear_correlation')};" + ) def find_lowest_continuous_correlation(self, fixed_values=[], solver_name="scip"): self.build_differential_linear_continuous_trail_model(fixed_values=fixed_values) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py index aef235066..01f6376ac 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py @@ -19,6 +19,10 @@ import re import time as tm +from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import ( + active_bit_correlation_expression, + piecewise_log2_approximation_expression, +) from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT @@ -29,7 +33,6 @@ INPUT_KEY, INPUT_PLAINTEXT, INPUT_TWEAK, - INTERMEDIATE_OUTPUT, LINEAR_LAYER, MIX_COLUMN, SBOX, @@ -45,12 +48,17 @@ class MznDifferentialLinearModel(MznModel): - top: XOR differential - middle: deterministic/semi-deterministic truncated XOR differential - bottom: XOR linear + + If ``single_key`` is True, the linear-weight contribution only counts + bottom components that remain after removing the key schedule from the + cipher. """ _ALLOWED_MIDDLE_MODELS = { "cp_deterministic_truncated_xor_differential_constraints", "cp_semi_deterministic_truncated_xor_differential_constraints", "cp_deterministic_truncated_xor_differential_trail_constraints", + "cp_continuous_differential_propagation_constraints", } _ALLOWED_WORD_OPERATIONS = { @@ -71,6 +79,7 @@ def __init__( list_of_components, middle_part_model="cp_semi_deterministic_truncated_xor_differential_constraints", standard_differential_part=True, + single_key=False, ): super().__init__(cipher) self.standard_differential_part = standard_differential_part @@ -87,6 +96,8 @@ def __init__( } self.middle_part_model = middle_part_model + self.single_key = single_key + self._cached_weight_bottom_component_ids = None if self.middle_part_model not in self._ALLOWED_MIDDLE_MODELS: raise ValueError( f"middle_part_model should be one of {sorted(self._ALLOWED_MIDDLE_MODELS)}" @@ -108,6 +119,10 @@ def format_func(record): self.bit_bindings, self.bit_bindings_for_intermediate_output = get_bit_bindings(cipher, format_func) self.raw_bit_bindings, self.raw_bit_bindings_for_intermediate_output = get_bit_bindings(cipher) + def _is_continuous_middle(self): + """Return True when the middle part uses continuous correlation propagation.""" + return self.middle_part_model == "cp_continuous_differential_propagation_constraints" + def _validate_component_partitioning(self): allowed_overlapping_ids = set(self._get_truncated_xor_differential_components_in_border()) overlap = (self.middle_part_component_ids & self.bottom_part_component_ids) - allowed_overlapping_ids @@ -126,6 +141,23 @@ def _validate_arx_only_cipher(self): def _get_component_by_id(self, component_id): return self._cipher.get_component_from_id(component_id) + def _weight_bottom_component_ids(self): + if not self.single_key: + return self.bottom_part_component_ids + + if self._cached_weight_bottom_component_ids is not None: + return self._cached_weight_bottom_component_ids + + try: + cipher_without_key_schedule = self._cipher.remove_key_schedule() + no_key_schedule_ids = set(cipher_without_key_schedule.get_all_components_ids()) + effective_bottom_component_ids = self.bottom_part_component_ids & no_key_schedule_ids + except Exception: + effective_bottom_component_ids = set(self.bottom_part_component_ids) + + self._cached_weight_bottom_component_ids = effective_bottom_component_ids + return self._cached_weight_bottom_component_ids + def _parse_linear_bit_id(self, bit_id): match = re.match(r"^(.*)_([io])\[(\d+)\]$", bit_id) if not match: @@ -159,7 +191,16 @@ def _state_declarations(self): continue if component.id in self.middle_part_component_ids: - domain = "0..2" + if self._is_continuous_middle(): + # Continuous components self-declare their variables + # (x1_, x2_, ) in cp_continuous_differential_propagation_constraints + continue + else: + domain = "0..2" + elif component.id in self.bottom_part_component_ids: + # Linear components self-declare their variables + # (_i, _o) in cp_xor_linear_mask_propagation_constraints + continue else: domain = "0..1" declarations.append( @@ -209,14 +250,107 @@ def _get_truncated_xor_differential_components_in_border(self): return list(set(border_components)) + def _continuous_middle_input_expr(self, component_id, input_bit_index): + component = self._get_component_by_id(component_id) + accumulated = 0 + for input_idx, bit_positions in enumerate(component.input_bit_positions, start=1): + input_size = len(bit_positions) + if input_bit_index < accumulated + input_size: + local_index = input_bit_index - accumulated + return f"x{input_idx}_{component_id}[{local_index}]" + accumulated += input_size + + raise ValueError( + f"Invalid continuous input bit index {input_bit_index} for component {component_id}" + ) + + def _continuous_middle_connecting_constraints(self, include_middle_sources=True): + """ + Build wiring constraints for the continuous middle part. + + Continuous component generators declare x1_, x2_, ... input arrays, + but they do not connect them to predecessor components. This method links + all incoming arcs to those x* arrays. + """ + constraints = [] + + for output_bit_id, successor_bits in self.raw_bit_bindings.items(): + source_component_id, source_bit_index, source_side = output_bit_id + + is_input = source_component_id in self._cipher.inputs + if not is_input and source_side != "o": + continue + + if (not include_middle_sources) and source_component_id in self.middle_part_component_ids: + continue + + source_bit_expr = f"{source_component_id}[{int(source_bit_index)}]" + + for successor_bit in successor_bits: + successor_component_id, successor_bit_index, successor_side = successor_bit + successor_is_output = successor_component_id in (self._cipher.outputs if hasattr(self._cipher, "outputs") else []) + if ( + not (successor_side == "i" or successor_is_output) + or successor_component_id not in self.middle_part_component_ids + ): + continue + + successor_bit_expr = self._continuous_middle_input_expr( + successor_component_id, + int(successor_bit_index), + ) + + if source_component_id in self.middle_part_component_ids: + constraints.append(f"constraint {successor_bit_expr} = {source_bit_expr};") + else: + constraints.append( + f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;" + ) + + # Wire intermediate outputs since they are not in raw_bit_bindings sources + for comp_id, bit_dict in self.raw_bit_bindings_for_intermediate_output.items(): + if comp_id not in self.middle_part_component_ids: + continue + for inter_bit_tuple, pins in bit_dict.items(): + inter_id, inter_bit, _ = inter_bit_tuple + + source_pin = None + for pin in pins: + pin_id, pin_bit, pin_side = pin + if pin_side == "o" or pin_id in self._cipher.inputs: + source_pin = pin + break + + if not source_pin: + continue + + source_id, source_bit, _ = source_pin + if (not include_middle_sources) and source_id in self.middle_part_component_ids: + continue + + source_expr = f"{source_id}[{int(source_bit)}]" + successor_expr = self._continuous_middle_input_expr(inter_id, int(inter_bit)) + + if source_id in self.middle_part_component_ids: + constraints.append(f"constraint {successor_expr} = {source_expr};") + else: + constraints.append(f"constraint {successor_expr} = if {source_expr} = 1 then 1 else -1 endif;") + + return constraints + def _top_to_middle_connecting_constraints(self): + if self._is_continuous_middle(): + return self._continuous_middle_connecting_constraints(include_middle_sources=False) + constraints = [] border_components = set(self._get_regular_xor_differential_components_in_border()) border_components.update(set(self._cipher.inputs)) for output_bit_id, successor_bits in self.raw_bit_bindings.items(): source_component_id, source_bit_index, source_side = output_bit_id - if source_side != "o" or source_component_id not in border_components: + + is_input = source_component_id in self._cipher.inputs + if (not is_input and source_side != "o") or source_component_id not in border_components: continue source_bit_expr = f"{source_component_id}[{int(source_bit_index)}]" @@ -229,13 +363,23 @@ def _top_to_middle_connecting_constraints(self): continue successor_bit_expr = f"{successor_component_id}[{int(successor_bit_index)}]" - constraints.append( - f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else 0 endif;" - ) + if self._is_continuous_middle(): + # Cast differential int bit to continuous float: + # 0 (no difference) -> -1.0, 1 (difference) -> 1.0 + constraints.append( + f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;" + ) + else: + constraints.append( + f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else 0 endif;" + ) return constraints def _middle_to_bottom_connecting_constraints(self): + if self._is_continuous_middle(): + return self._continuous_middle_to_bottom_connecting_constraints() + constraints = [] truncated_border_components = set(self._get_truncated_xor_differential_components_in_border()) # ensure that at least one bit difference exist in the concatenation of the output of the truncated border components @@ -257,6 +401,106 @@ def _middle_to_bottom_connecting_constraints(self): return constraints + def _continuous_middle_to_bottom_connecting_constraints(self): + """ + Connection from differential-linear (continuous) part to linear part. + + Implements the semantics from [BGGMP2023]:: + + combined[i] = if linear_mask[i] == 0 then 1 + else linear_mask[i] * abs(continuous_correlation[i]) + + differential_linear_correlation = product(combined) + + When the linear mask bit is 0 (inactive), the contribution is 1 (neutral + in the product). When it is 1 (active), the contribution is the absolute + value of the continuous correlation at that position. + """ + constraints = [] + border_components = set(self._get_truncated_xor_differential_components_in_border()) + + # In single-key mode we only account for the data-path branch at the + # middle/bottom border, matching the single-key linear-weight semantics. + # For single_key=False we intentionally keep all border components, + # including the key-schedule branch. + if self.single_key: + last_middle_round = max( + self._cipher.get_round_from_component_id(cid) + for cid in self.middle_part_component_ids + ) + state_input_ids = set() + for comp in self._cipher.get_components_in_round(last_middle_round): + if comp.id in self.middle_part_component_ids and comp.type == "intermediate_output": + # The 32-bit output is the state; 16-bit is typically key schedule + if comp.output_bit_size > 16: + state_input_ids.update(comp.input_id_links) + if state_input_ids: + border_components = border_components & state_input_ids + + # Build a canonical border mask per continuous source bit. This mirrors + # the standalone continuous model semantics where one mask bit gates one + # correlation bit, instead of multiplying once per fan-out edge. + border_sources = {} + for output_bit_id, successor_bits in self.bit_bindings.items(): + source_component_id, _, _ = self._parse_linear_bit_id(output_bit_id) + if source_component_id not in border_components: + continue + + source_bit_expr = output_bit_id.replace("_o[", "[") + for successor_bit in successor_bits: + successor_component_id, _, _ = self._parse_linear_bit_id(successor_bit) + if successor_component_id in self.bottom_part_component_ids: + border_sources.setdefault(source_bit_expr, set()).add(successor_bit) + + if not border_sources: + return constraints + + def _sort_bit_expr(bit_expr): + component_id, _, bit_index = self._parse_linear_bit_id(bit_expr) + return component_id, bit_index + + ordered_border_sources = [] + for source_bit_expr in sorted(border_sources, key=_sort_bit_expr): + ordered_successors = sorted(border_sources[source_bit_expr], key=_sort_bit_expr) + ordered_border_sources.append((source_bit_expr, ordered_successors)) + + n = len(ordered_border_sources) + + self._variables_list.append( + f"array[0..{n - 1}] of var 0..1: linear_border_mask;" + ) + + for idx, (_, successors) in enumerate(ordered_border_sources): + if len(successors) == 1: + constraints.append( + f"constraint linear_border_mask[{idx}] = {successors[0]};" + ) + else: + constraints.append( + f"constraint linear_border_mask[{idx}] = ({' + '.join(successors)}) mod 2;" + ) + + # Declare the combined array + self._variables_list.append( + f"array[0..{n - 1}] of var -1.0..1.0: linear_mask_times_diff_lin_output;" + ) + + # Build: combined[i] = if mask==0 then 1 else mask*abs(corr) endif + for idx, (cont_bit, _) in enumerate(ordered_border_sources): + constraints.append( + f"constraint linear_mask_times_diff_lin_output[{idx}] = " + f"{active_bit_correlation_expression(f'linear_border_mask[{idx}]', cont_bit)};" + ) + + # Declare and constrain the differential-linear correlation + self._variables_list.append("var lower..upper: differential_linear_correlation;") + constraints.append( + "constraint differential_linear_correlation = product(linear_mask_times_diff_lin_output);" + ) + constraints.append("constraint differential_linear_correlation != 0.0;") + + return constraints + def _branch_xor_linear_constraints_for_bottom_part(self): constraints = [] @@ -273,7 +517,10 @@ def _branch_xor_linear_constraints_for_bottom_part(self): return constraints def _build_weight_constraints(self, weight): - declarations = ["var int: weight;"] + if self._is_continuous_middle(): + declarations = ["var float: weight;"] + else: + declarations = ["var int: weight;"] def _sum_component_probability(component_ids): terms = [] @@ -288,18 +535,34 @@ def _sum_component_probability(component_ids): # Keep model-assignment semantics consistent with _component_model_entries: # if a component appears in both middle and bottom, it is modeled as bottom. effective_middle_component_ids = self.middle_part_component_ids - self.bottom_part_component_ids + effective_bottom_component_ids = self._weight_bottom_component_ids() top_probability_sum = _sum_component_probability(self.top_part_component_ids) - middle_probability_sum = _sum_component_probability(effective_middle_component_ids) - bottom_probability_sum = _sum_component_probability(self.bottom_part_component_ids) - # Approximate DL objective in log-domain: - # weight ≈ p + r + 2q - # where p = top, r = middle, q = bottom. - # The factor 2*bottom reflects that the exact term is 2^(2q), - # and the middle term (2·2^r - 1) is approximated by r. - constraints = [ - f"constraint weight = {top_probability_sum} + {middle_probability_sum} + 2*{bottom_probability_sum};" - ] + bottom_probability_sum = _sum_component_probability(effective_bottom_component_ids) + + constraints = [] + + if self._is_continuous_middle(): + # Continuous middle correlation is converted via a piece-wise linear approximation + # of log2(abs(correlation)) scaled for integer weights. + declarations.append("var float: correlation_log2_approximation;") + constraints.append( + "constraint correlation_log2_approximation = " + f"{piecewise_log2_approximation_expression('differential_linear_correlation', scale=100.0)};" + ) + constraints.append( + f"constraint weight = {top_probability_sum} + correlation_log2_approximation + 2*{bottom_probability_sum};" + ) + else: + middle_probability_sum = _sum_component_probability(effective_middle_component_ids) + # Approximate DL objective in log-domain: + # weight ≈ p + r + 2q + # where p = top, r = middle, q = bottom. + # The factor 2*bottom reflects that the exact term is 2^(2q), + # and the middle term (2·2^r - 1) is approximated by r. + constraints = [ + f"constraint weight = {top_probability_sum} + {middle_probability_sum} + 2*{bottom_probability_sum};" + ] if weight != -1: constraints.append(f"constraint weight <= {100 * weight};") @@ -370,7 +633,13 @@ def _build_output_block(self, weight): output_probability = self._component_probability_output(component, probability_output) output = self._append_probability_output(output, output_probability) - output += '"Trail weight = " ++ show(weight)];' + output += '"Trail weight = " ++ show(weight)' + + # Include the DL correlation in the output for continuous models + if self._is_continuous_middle(): + output += ' ++ "\\n" ++ "differential_linear_correlation = " ++ show(differential_linear_correlation)' + + output += '];' constraints.append(output) constraints.extend(self.mzn_output_directives) @@ -470,6 +739,7 @@ def _collect_differential_linear_component_weights(self, components_values): q_weight = 0.0 seen_middle = set() seen_bottom = set() + effective_bottom_component_ids = self._weight_bottom_component_ids() for component_id, component_solution in components_values.items(): if not isinstance(component_solution, dict): @@ -486,27 +756,31 @@ def _collect_differential_linear_component_weights(self, components_values): if base_component_id not in seen_middle: middle_sum += weight seen_middle.add(base_component_id) - continue - - if base_component_id in self.bottom_part_component_ids and base_component_id not in seen_bottom: - q_weight += weight - seen_bottom.add(base_component_id) + elif base_component_id in effective_bottom_component_ids: + if base_component_id not in seen_bottom: + q_weight += weight + seen_bottom.add(base_component_id) return p_weight, middle_sum, q_weight, bool(seen_middle) - @staticmethod - def _middle_weight_term(middle_sum, has_middle_components): + def _middle_weight_term(self, middle_sum, has_middle_components): if not has_middle_components: return 0.0 - if middle_sum == 1: - raise ValueError("Unexpected probability weight 1 in middle part") - return abs(math.log(abs(2 * (2**(-1*middle_sum)) - 1), 2)) + if self._is_continuous_middle(): + import math + return math.log(2 * (2**middle_sum) - 1, 2) + else: + import math + if middle_sum == 1: + raise ValueError("Unexpected probability weight 1 in middle part") + return abs(math.log(abs(2 * (2**(-1*middle_sum)) - 1), 2)) def _differential_linear_total_weight_from_components(self, components_values): - p_weight, middle_sum, q_weight, _ = ( + p_weight, middle_sum, q_weight, has_middle_components = ( self._collect_differential_linear_component_weights(components_values) ) - return round(p_weight + middle_sum + (2 * q_weight), 10) + r_weight = self._middle_weight_term(middle_sum, has_middle_components) + return round(p_weight + r_weight + (2 * q_weight), 10) def _set_differential_linear_total_weight(self, solution): if not isinstance(solution, dict): @@ -521,10 +795,15 @@ def _set_differential_linear_total_weight(self, solution): def _parse_solver_output( self, output_to_parse, model_type, truncated=False, solve_external=False, solver_name=SOLVER_DEFAULT ): + continuous_xor_differential_linear_model = self._is_continuous_middle() and model_type in ( + XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, + XOR_DIFFERENTIAL_LINEAR_OPTIMAL_SOLUTION, + ) + parsed = super()._parse_solver_output( output_to_parse, model_type, - truncated=truncated, + truncated=truncated or continuous_xor_differential_linear_model, solve_external=solve_external, solver_name=solver_name, ) @@ -545,7 +824,10 @@ def _parse_solver_output( return parsed if solve_external: - solver_time, memory, components_values, _ = parsed + if continuous_xor_differential_linear_model: + solver_time, memory, components_values = parsed + else: + solver_time, memory, components_values, _ = parsed total_weight = [] solution_keys = sorted( components_values.keys(), @@ -571,12 +853,24 @@ def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None): fixed_variables = [] self.initialise_model() + + # Include continuous predicates (continuous_modadd, continuous_xor, etc.) + if self._is_continuous_middle(): + from claasp.cipher_modules.models.cp.minizinc_utils.mzn_continuous_predicates import ( + get_continuous_operations, + ) + self._model_prefix.append(get_continuous_operations()) + model_entries = self._component_model_entries() fixed_constraints = self._partition_fixed_value_constraints(fixed_variables) self.build_generic_cp_model_from_dictionary(model_entries) self._model_constraints = fixed_constraints + self._model_constraints + continuous_middle_constraints = [] + if self._is_continuous_middle(): + continuous_middle_constraints = self._continuous_middle_connecting_constraints() + probability_array_declaration = self._probability_array_declaration_from_component_map() middle_bottom_constraints = self._middle_to_bottom_connecting_constraints() @@ -591,6 +885,7 @@ def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None): self._variables_list = declarations + self._model_constraints.extend(continuous_middle_constraints) self._model_constraints.extend(middle_bottom_constraints) self._model_constraints.extend(branch_constraints) self._model_constraints.extend(weight_constraints) diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py index e9e0de40c..d8b5148d8 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py @@ -1,9 +1,15 @@ import itertools import math +import subprocess import pytest from claasp.cipher_modules.models.cp.mzn_models.mzn_differential_linear_model import MznDifferentialLinearModel -from claasp.cipher_modules.models.cp.solvers import CPSAT +from claasp.cipher_modules.models.cp.mzn_models.mzn_differential_linear_continuous_model import ( + MznDifferentialLinearContinuousModel, +) +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel +from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel +from claasp.cipher_modules.models.cp.solvers import CPSAT, SCIP from claasp.cipher_modules.models.utils import ( differential_linear_checker_for_block_cipher_single_key, integer_to_bit_list, @@ -14,9 +20,10 @@ from claasp.ciphers.block_ciphers.ballet_block_cipher import BalletBlockCipher from claasp.ciphers.block_ciphers.speck_block_cipher import SpeckBlockCipher from claasp.ciphers.permutations.chacha_permutation import ChachaPermutation, ROUND_MODE_HALF -from claasp.name_mappings import INPUT_PLAINTEXT, SATISFIABLE, INPUT_KEY, INPUT_MESSAGE +from claasp.name_mappings import INPUT_PLAINTEXT, SATISFIABLE, INPUT_KEY, INPUT_MESSAGE, XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION from claasp.ciphers.mac.siphash_mac import SiphashMAC + def _split_components(cipher, top_rounds_end, middle_rounds_end): top_part_components = [] middle_part_components = [] @@ -87,6 +94,29 @@ def test_parse_linear_bit_id_handles_valid_and_invalid_formats(): model._parse_linear_bit_id("invalid_bit_id") +def test_single_key_weight_bottom_component_ids_exclude_key_schedule_for_speck(): + speck = SpeckBlockCipher(number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + + effective_bottom_ids = model._weight_bottom_component_ids() + + assert effective_bottom_ids.issubset(model.bottom_part_component_ids) + # Speck key-schedule branch components must not contribute to the single-key linear weight. + assert "modadd_6_2" not in effective_bottom_ids + assert "modadd_7_2" not in effective_bottom_ids + assert "modadd_8_2" not in effective_bottom_ids + # Data-path branch components must still contribute. + assert "modadd_6_7" in effective_bottom_ids + assert "modadd_7_7" in effective_bottom_ids + assert "modadd_8_7" in effective_bottom_ids + + def test_normalize_middle_part_components_values_hex_and_unknown_bits(): speck = SpeckBlockCipher(number_of_rounds=6) component_model_list = _split_components(speck, top_rounds_end=2, middle_rounds_end=3) @@ -813,3 +843,437 @@ def test_optimal_semi_deterministic_differential_linear_trail_siphash(): print("Status:", trail["status"]) print("Total weight:", trail["total_weight"]) assert trail["status"] == SATISFIABLE +# ────────────────────────────────────────────────────────────────────── +# Tests for continuous middle model (cp_continuous_differential_propagation_constraints) +# ────────────────────────────────────────────────────────────────────── + +def test_differential_linear_trail_continuous_middle_speck_build(): + """ + Speck32/64, 3 rounds: 1 top (diff) + 1 middle (continuous) + 1 bottom (linear). + + Input difference from Table 4 of [BGGMP2023]: (0x0010, 0x5000). + Verifies the model builds correctly and can be solved with SCIP. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=3) + component_list = _split_components(speck, top_rounds_end=1, middle_rounds_end=2) + model = MznDifferentialLinearModel( + speck, + component_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + ) + + plaintext = set_fixed_variables( + "plaintext", "equal", list(range(32)), + integer_to_bit_list(0x00105000, 32, "big"), + ) + key = set_fixed_variables( + "key", "equal", list(range(64)), + integer_to_bit_list(0, 64, "big"), + ) + + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=[plaintext, key], + ) + + # Verify model structure + full_model = "\n".join(model._variables_list + model._model_constraints) + + # Check continuous predicates are present + assert "continuous_modadd" in full_model + assert "continuous_xor" in full_model + + # Check DL-to-linear connection + assert "linear_mask_times_diff_lin_output" in full_model + assert "differential_linear_correlation" in full_model + assert "product(linear_mask_times_diff_lin_output)" in full_model + + # Check weight constraint does NOT include middle probability for continuous + assert "weight = " in full_model + + # Check state declarations don't have var 0..2 (no truncated middle) + declarations = model._state_declarations() + decl_text = "\n".join(declarations) + assert "var 0..2" not in decl_text + + +def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4(): + """ + Speck32/64, Table 4 split into two checks: + + - 4-round differential prefix with weight 7. + - 3-round linear suffix with paper input mask ``0x10000020`` and weight 2. + + The 9-round combined model is still built below to verify that the + continuous predicates and border wiring are present. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + + def _component_bit_size(cipher, component_id): + if component_id in cipher.inputs: + idx = list(cipher.inputs).index(component_id) + return cipher.inputs_bit_size[idx] + return cipher.get_component_from_id(component_id).output_bit_size + + def _fixed_from_hex(cipher, component_id, hex_value): + bit_size = _component_bit_size(cipher, component_id) + return set_fixed_variables( + component_id=component_id, + constraint_type="equal", + bit_positions=list(range(bit_size)), + bit_values=integer_to_bit_list(int(hex_value, 16), bit_size, "big"), + ) + + combined_diff_fixed = { + "plaintext": "0xA8400010", + "key": "0x0000000000000000", + "intermediate_output_0_6": "0x81408100", + "intermediate_output_1_12": "0x00020400", + "intermediate_output_2_12": "0x00001000", + "intermediate_output_3_12": "0x10005000", + # Keep the top key-schedule branch at 0 for the single-key paper setting. + "xor_1_5": "0x0000", + "xor_2_5": "0x0000", + "xor_3_5": "0x0000", + } + + prefix_fixed = dict(combined_diff_fixed) + prefix_fixed["cipher_output_3_12"] = prefix_fixed.pop("intermediate_output_3_12") + + lin_fixed = { + "intermediate_output_6_12": "0x00000020", + "intermediate_output_7_12": "0x00800080", + "cipher_output_8_12": "0x02050204", + } + + fixed_values = [] + for component_id, hex_val in {**combined_diff_fixed, **lin_fixed}.items(): + fixed_values.append(_fixed_from_hex(speck, component_id, hex_val)) + + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=fixed_values, + ) + + full_model = "\n".join(model._variables_list + model._model_constraints) + + assert "continuous_modadd" in full_model + assert "continuous_xor" in full_model + assert "linear_mask_times_diff_lin_output" in full_model + assert "differential_linear_correlation" in full_model + assert "product(linear_mask_times_diff_lin_output)" in full_model + declarations = model._state_declarations() + assert "var 0..2" not in "\n".join(declarations) + + prefix_cipher = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4) + prefix_model = MznXorDifferentialModel(prefix_cipher) + prefix_fixed_values = [] + for component_id, hex_val in prefix_fixed.items(): + prefix_fixed_values.append(_fixed_from_hex(prefix_cipher, component_id, hex_val)) + + prefix_trail = prefix_model.find_one_xor_differential_trail_with_fixed_weight( + fixed_weight=7, + fixed_values=prefix_fixed_values, + solver_name=CPSAT, + solve_external=True, + ) + + assert prefix_trail["status"] == SATISFIABLE + assert float(prefix_trail["total_weight"]) == 7.0 + assert prefix_trail["components_values"]["plaintext"]["value"] == "0xa8400010" + assert prefix_trail["components_values"]["intermediate_output_0_6"]["value"] == "0x81408100" + assert prefix_trail["components_values"]["intermediate_output_1_12"]["value"] == "0x00020400" + assert prefix_trail["components_values"]["intermediate_output_2_12"]["value"] == "0x00001000" + assert prefix_trail["components_values"]["cipher_output_3_12"]["value"] == "0x10005000" + + suffix_cipher = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=3) + suffix_model = MznXorLinearModel(suffix_cipher) + suffix_fixed_values = [ + _fixed_from_hex(suffix_cipher, "plaintext", "0x10000020"), + _fixed_from_hex(suffix_cipher, "intermediate_output_0_6", "0x00000020"), + _fixed_from_hex(suffix_cipher, "intermediate_output_1_12", "0x00800080"), + _fixed_from_hex(suffix_cipher, "cipher_output_2_12", "0x02050204"), + ] + + suffix_trail = suffix_model.find_lowest_weight_xor_linear_trail( + fixed_values=suffix_fixed_values, + solver_name=CPSAT, + solve_external=True, + ) + + assert suffix_trail["status"] == SATISFIABLE + assert float(suffix_trail["total_weight"]) == 2.0 + assert suffix_trail["components_values"]["plaintext"]["value"] == "0x10000020" + assert suffix_trail["components_values"]["intermediate_output_0_6_o"]["value"] == "0x00000020" + assert suffix_trail["components_values"]["intermediate_output_1_12_o"]["value"] == "0x00800080" + assert suffix_trail["components_values"]["cipher_output_2_12_o"]["value"] == "0x02050204" + + +def test_continuous_middle_model_signed_input_and_key_match_table4_values(): + """ + Validate the standalone continuous model using signed (-1..1) inputs. + + For the Table 4 middle part setting, plaintext and key are fixed as + signed arrays. The expected correlation corresponds to output mask + 0x10000020. + """ + + def _sign_bits_from_hex(hex_value, bit_size): + value = int(hex_value, 16) + return [1.0 if ((value >> (bit_size - 1 - i)) & 1) else -1.0 for i in range(bit_size)] + + def _mask_bits_from_hex(hex_value, bit_size): + value = int(hex_value, 16) + return [1 if ((value >> (bit_size - 1 - i)) & 1) else 0 for i in range(bit_size)] + + def _solve_for_mask(mask_hex): + cipher = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) + model = MznDifferentialLinearContinuousModel(cipher) + + fixed_inputs = [ + { + "component_id": "plaintext", + "bit_positions": list(range(32)), + "bit_values": _sign_bits_from_hex("0x10005000", 32), + }, + { + "component_id": "key", + "bit_positions": list(range(64)), + "bit_values": [-1.0] * 64, + }, + ] + + model.build_differential_linear_continuous_trail_model(fixed_values=fixed_inputs) + model._build_linear_mask_correlation_constraints() + model._build_difflin_corr_constraints() + + for index, bit in enumerate(_mask_bits_from_hex(mask_hex, 32)): + model._model_constraints.append(f"constraint output_mask[{index}] = {bit};") + + cipher_output_id = model._get_cipher_output_id() + model._model_constraints.append( + f"solve :: float_search({cipher_output_id}, 1e-12, smallest, indomain_min, complete) " + "minimize correlation_log2_approximation;" + ) + + try: + result = model.solve_for_ARX("scip", timeout_in_seconds_=120) + except Exception: + pytest.skip("Native MiniZinc SCIP solver is not available in this environment") + + parsed = model._parse_result(result, "scip") + assert parsed.get("status") in ("SATISFIED", "OPTIMAL_SOLUTION") + return parsed["differential_linear_correlation"] + + expected_corr = 0.745481409287476 + expected_total_probability = -11.42375571965992 + + corr_table4_mask = _solve_for_mask("0x10000020") + total_probability_table4_mask = math.log2(corr_table4_mask) - 7 - 2 * 2 + + assert abs(corr_table4_mask - expected_corr) < 1e-6 + assert abs(total_probability_table4_mask - expected_total_probability) < 1e-6 + + corr_integrated_bottom_mask = _solve_for_mask("0x00000020") + total_probability_integrated_bottom_mask = math.log2(corr_integrated_bottom_mask) - 7 - 2 * 2 + + assert abs(corr_integrated_bottom_mask - corr_table4_mask) > 1e-2 + assert abs(total_probability_integrated_bottom_mask - expected_total_probability) > 1e-2 + + +def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_model_float_search_scip(): + """ + Full integrated 9-round Speck32/64 differential-linear model. + + Uses native MiniZinc SCIP (not fscip), enforces paper split values + top=7 and bottom=2, and checks that the continuous differential-linear + correlation is non-zero under float_search. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + + def _component_bit_size(cipher, component_id): + if component_id in cipher.inputs: + idx = list(cipher.inputs).index(component_id) + return cipher.inputs_bit_size[idx] + return cipher.get_component_from_id(component_id).output_bit_size + + def _fixed_from_hex(cipher, component_id, hex_value): + bit_size = _component_bit_size(cipher, component_id) + return set_fixed_variables( + component_id=component_id, + constraint_type="equal", + bit_positions=list(range(bit_size)), + bit_values=integer_to_bit_list(int(hex_value, 16), bit_size, "big"), + ) + + def _sum_probability_expr(component_ids): + terms = [] + for component_id in sorted(component_ids): + probability_expr = model._component_probability_expression(component_id) + if probability_expr: + terms.append(f"({probability_expr})") + return "(" + " + ".join(terms) + ")" if terms else "0" + + def _base_component_id(component_id): + if component_id.endswith("_i") or component_id.endswith("_o"): + return component_id[:-2] + return component_id + + combined_diff_fixed = { + "plaintext": "0xA8400010", + "key": "0x0000000000000000", + "intermediate_output_0_6": "0x81408100", + "intermediate_output_1_12": "0x00020400", + "intermediate_output_2_12": "0x00001000", + "intermediate_output_3_12": "0x10005000", + "xor_1_5": "0x0000", + "xor_2_5": "0x0000", + "xor_3_5": "0x0000", + } + lin_fixed = { + "intermediate_output_6_12": "0x00000020", + "intermediate_output_7_12": "0x00800080", + "cipher_output_8_12": "0x02050204", + } + + fixed_values = [] + for component_id, hex_val in {**combined_diff_fixed, **lin_fixed}.items(): + fixed_values.append(_fixed_from_hex(speck, component_id, hex_val)) + + model.build_xor_differential_linear_model(weight=-1, fixed_variables=fixed_values) + + solve_idx = None + for idx, line in enumerate(model._model_constraints): + if line.strip().startswith("solve "): + solve_idx = idx + model._model_constraints[idx] = ( + "solve :: float_search(linear_mask_times_diff_lin_output, 1e-12, largest, indomain_split, complete) " + "minimize correlation_log2_approximation;" + ) + break + + assert solve_idx is not None + assert "minimize correlation_log2_approximation" in model._model_constraints[solve_idx] + + # top_expr = _sum_probability_expr(model.top_part_component_ids) + # bottom_expr = _sum_probability_expr(model._weight_bottom_component_ids()) + # model._model_constraints.insert(solve_idx, f"constraint {top_expr} = 700;") + # model._model_constraints.insert(solve_idx + 1, f"constraint {bottom_expr} = 200;") + + full_model = "\n".join(model._variables_list + model._model_constraints) + command = ["minizinc", "--input-from-stdin", "--solver-statistics", "--solver", "scip"] + solver_process = subprocess.run(command, input=full_model, capture_output=True, text=True, check=False) + + if solver_process.returncode != 0 and "Unknown solver" in solver_process.stderr: + pytest.skip("Native MiniZinc SCIP solver is not available in this environment") + + assert solver_process.returncode == 0 + + solver_output = solver_process.stdout.splitlines() + assert all("--solver fscip" not in line for line in solver_output) + assert "UNSATISFIABLE" not in solver_process.stdout + assert "----------" in solver_process.stdout + + _, _, components_values, _ = model._parse_solver_output( + solver_output, + XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, + solve_external=True, + solver_name=SCIP, + ) + assert "solution1" in components_values + + solution = components_values["solution1"] + top_sum = 0.0 + bottom_sum = 0.0 + seen_top = set() + seen_bottom = set() + + for component_id, component_solution in solution.items(): + if not isinstance(component_solution, dict): + continue + weight = float(component_solution.get("weight", 0)) + base_component_id = _base_component_id(component_id) + if base_component_id in model.top_part_component_ids and base_component_id not in seen_top: + top_sum += weight + seen_top.add(base_component_id) + if base_component_id in model._weight_bottom_component_ids() and base_component_id not in seen_bottom: + bottom_sum += weight + seen_bottom.add(base_component_id) + if weight > 0: + print(f"BOTTOM WEIGHT: {base_component_id} = {weight}") + + assert top_sum == 7.0 + assert bottom_sum == 2.0 + + correlation_values = [ + float(line.split("=", 1)[1].strip()) + for line in solver_output + if line.startswith("differential_linear_correlation =") + ] + assert correlation_values + correlation_value = max(correlation_values) + assert 0.0 < correlation_value <= 1.0 + + expected_corr = 0.745481409287476 + expected_total_probability = -11.42375571965992 + total_probability = math.log2(correlation_value) - 7 - 2 * 2 + + assert min(abs(value - expected_corr) for value in correlation_values) < 1e-6 + assert abs(total_probability - expected_total_probability) < 1e-6 + + trail_weight_values = [ + float(line.split("=", 1)[1].strip()) + for line in solver_output + if line.startswith("Trail weight =") + ] + assert trail_weight_values + trail_weight_raw = min(trail_weight_values) + + def _piecewise_log2_approximation(correlation): + if correlation <= 0.001021453702391378: + return -19931.57001201849 * correlation + 29.89737278555626 + if correlation <= 0.004151650554233785: + return -584.962260272084 * correlation + 10.13570866882117 + if correlation <= 0.01359667098324998: + return -192.6450521799878 * correlation + 8.506944714410169 + if correlation <= 0.05399137458004444: + return -50.62607129324977 * correlation + 6.575959357916722 + if correlation <= 0.1420480516058986: + return -11.87410019056137 * correlation + 4.483687170396419 + if correlation <= 0.2463455066216964: + return -8.613130253286352 * correlation + 4.020472744461092 + if correlation <= 0.595815289564374: + return -3.761918786389538 * correlation + 2.825398597919413 + if correlation <= 0.998000001: + return -1.444862453710759 * correlation + 1.44486100812744 + return 0.0 + + correlation_log2_approximation = _piecewise_log2_approximation(correlation_value) + expected_weight_raw = 100 * top_sum + 100 * correlation_log2_approximation + 200 * bottom_sum + expected_weight_without_correlation_scaling = 100 * top_sum + correlation_log2_approximation + 200 * bottom_sum + + assert abs(trail_weight_raw - expected_weight_raw) < 1e-3 + assert abs(trail_weight_raw - expected_weight_raw) < abs( + trail_weight_raw - expected_weight_without_correlation_scaling + ) + assert abs(trail_weight_raw / 100 - (top_sum + correlation_log2_approximation + 2 * bottom_sum)) < 1e-5 + + +if __name__ == "__main__": + test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_model_float_search_scip() From cdbe92d8532ff2542d3eff6bcb9797fcb02ad49f Mon Sep 17 00:00:00 2001 From: Rony Date: Wed, 22 Apr 2026 14:24:02 -0500 Subject: [PATCH 2/5] Refactor: simplify variable fixing and objective definition in DL model - Added get_fixed_variable_from_hex utility method to MznDifferentialLinearModel to standardize and simplify fixing component bits from hex strings. - Added optimization_objective parameter to build_xor_differential_linear_model to allow natively injecting custom solve statements (e.g., minimize correlation_log2_approximation) without manual constraint string manipulation. - Refactored mzn_differential_linear_model_test.py to use these new clean APIs, removing duplicated hex parsing helpers and hacky array replacements. --- .../mzn_differential_linear_model.py | 28 +- .../mzn_differential_linear_model_test.py | 362 +++++++----------- 2 files changed, 155 insertions(+), 235 deletions(-) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py index 01f6376ac..ff174dd07 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py @@ -26,7 +26,7 @@ from claasp.cipher_modules.models.cp.mzn_model import MznModel, SOLVE_SATISFY from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel from claasp.cipher_modules.models.cp.solvers import SOLVER_DEFAULT -from claasp.cipher_modules.models.utils import get_bit_bindings +from claasp.cipher_modules.models.utils import get_bit_bindings, integer_to_bit_list, set_fixed_variables from claasp.name_mappings import ( CIPHER_OUTPUT, CONSTANT, @@ -847,8 +847,26 @@ def _ensure_components_values(self, solution): return self._normalize_middle_part_components_values(solution) + def get_fixed_variable_from_hex(self, component_id, hex_value, endianness="big"): + if component_id in self._cipher.inputs: + idx = list(self._cipher.inputs).index(component_id) + bit_size = self._cipher.inputs_bit_size[idx] + else: + bit_size = self._cipher.get_component_from_id(component_id).output_bit_size + + if isinstance(hex_value, str) and hex_value.startswith("0x"): + int_val = int(hex_value, 16) + else: + int_val = int(hex_value) - def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None): + return set_fixed_variables( + component_id=component_id, + constraint_type="equal", + bit_positions=list(range(bit_size)), + bit_values=integer_to_bit_list(int_val, bit_size, endianness), + ) + + def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None, optimization_objective=None): if fixed_variables is None: fixed_variables = [] @@ -889,7 +907,11 @@ def build_xor_differential_linear_model(self, weight=-1, fixed_variables=None): self._model_constraints.extend(middle_bottom_constraints) self._model_constraints.extend(branch_constraints) self._model_constraints.extend(weight_constraints) - self._model_constraints.extend(self._build_output_block(weight)) + output_block = self._build_output_block(weight) + if optimization_objective: + output_block[0] = optimization_objective + + self._model_constraints.extend(output_block) self._model_constraints = self._model_prefix + self._model_constraints diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py index d8b5148d8..d5f0e7e38 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py @@ -47,6 +47,18 @@ def _split_components(cipher, top_rounds_end, middle_rounds_end): "bottom_part_components": bottom_part_components, } +def _fixed_from_hex_test(cipher, component_id, hex_value): + if component_id in cipher.inputs: + idx = list(cipher.inputs).index(component_id) + bit_size = cipher.inputs_bit_size[idx] + else: + bit_size = cipher.get_component_from_id(component_id).output_bit_size + return set_fixed_variables( + component_id=component_id, + constraint_type="equal", + bit_positions=list(range(bit_size)), + bit_values=integer_to_bit_list(int(hex_value, 16), bit_size, "big"), + ) @pytest.mark.parametrize( "cipher_cls,cipher_kwargs,top_rounds_end,middle_rounds_end", @@ -896,127 +908,6 @@ def test_differential_linear_trail_continuous_middle_speck_build(): decl_text = "\n".join(declarations) assert "var 0..2" not in decl_text - -def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4(): - """ - Speck32/64, Table 4 split into two checks: - - - 4-round differential prefix with weight 7. - - 3-round linear suffix with paper input mask ``0x10000020`` and weight 2. - - The 9-round combined model is still built below to verify that the - continuous predicates and border wiring are present. - """ - speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) - component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) - - model = MznDifferentialLinearModel( - speck, - component_model_list, - middle_part_model="cp_continuous_differential_propagation_constraints", - single_key=True, - ) - - def _component_bit_size(cipher, component_id): - if component_id in cipher.inputs: - idx = list(cipher.inputs).index(component_id) - return cipher.inputs_bit_size[idx] - return cipher.get_component_from_id(component_id).output_bit_size - - def _fixed_from_hex(cipher, component_id, hex_value): - bit_size = _component_bit_size(cipher, component_id) - return set_fixed_variables( - component_id=component_id, - constraint_type="equal", - bit_positions=list(range(bit_size)), - bit_values=integer_to_bit_list(int(hex_value, 16), bit_size, "big"), - ) - - combined_diff_fixed = { - "plaintext": "0xA8400010", - "key": "0x0000000000000000", - "intermediate_output_0_6": "0x81408100", - "intermediate_output_1_12": "0x00020400", - "intermediate_output_2_12": "0x00001000", - "intermediate_output_3_12": "0x10005000", - # Keep the top key-schedule branch at 0 for the single-key paper setting. - "xor_1_5": "0x0000", - "xor_2_5": "0x0000", - "xor_3_5": "0x0000", - } - - prefix_fixed = dict(combined_diff_fixed) - prefix_fixed["cipher_output_3_12"] = prefix_fixed.pop("intermediate_output_3_12") - - lin_fixed = { - "intermediate_output_6_12": "0x00000020", - "intermediate_output_7_12": "0x00800080", - "cipher_output_8_12": "0x02050204", - } - - fixed_values = [] - for component_id, hex_val in {**combined_diff_fixed, **lin_fixed}.items(): - fixed_values.append(_fixed_from_hex(speck, component_id, hex_val)) - - model.build_xor_differential_linear_model( - weight=-1, - fixed_variables=fixed_values, - ) - - full_model = "\n".join(model._variables_list + model._model_constraints) - - assert "continuous_modadd" in full_model - assert "continuous_xor" in full_model - assert "linear_mask_times_diff_lin_output" in full_model - assert "differential_linear_correlation" in full_model - assert "product(linear_mask_times_diff_lin_output)" in full_model - declarations = model._state_declarations() - assert "var 0..2" not in "\n".join(declarations) - - prefix_cipher = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=4) - prefix_model = MznXorDifferentialModel(prefix_cipher) - prefix_fixed_values = [] - for component_id, hex_val in prefix_fixed.items(): - prefix_fixed_values.append(_fixed_from_hex(prefix_cipher, component_id, hex_val)) - - prefix_trail = prefix_model.find_one_xor_differential_trail_with_fixed_weight( - fixed_weight=7, - fixed_values=prefix_fixed_values, - solver_name=CPSAT, - solve_external=True, - ) - - assert prefix_trail["status"] == SATISFIABLE - assert float(prefix_trail["total_weight"]) == 7.0 - assert prefix_trail["components_values"]["plaintext"]["value"] == "0xa8400010" - assert prefix_trail["components_values"]["intermediate_output_0_6"]["value"] == "0x81408100" - assert prefix_trail["components_values"]["intermediate_output_1_12"]["value"] == "0x00020400" - assert prefix_trail["components_values"]["intermediate_output_2_12"]["value"] == "0x00001000" - assert prefix_trail["components_values"]["cipher_output_3_12"]["value"] == "0x10005000" - - suffix_cipher = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=3) - suffix_model = MznXorLinearModel(suffix_cipher) - suffix_fixed_values = [ - _fixed_from_hex(suffix_cipher, "plaintext", "0x10000020"), - _fixed_from_hex(suffix_cipher, "intermediate_output_0_6", "0x00000020"), - _fixed_from_hex(suffix_cipher, "intermediate_output_1_12", "0x00800080"), - _fixed_from_hex(suffix_cipher, "cipher_output_2_12", "0x02050204"), - ] - - suffix_trail = suffix_model.find_lowest_weight_xor_linear_trail( - fixed_values=suffix_fixed_values, - solver_name=CPSAT, - solve_external=True, - ) - - assert suffix_trail["status"] == SATISFIABLE - assert float(suffix_trail["total_weight"]) == 2.0 - assert suffix_trail["components_values"]["plaintext"]["value"] == "0x10000020" - assert suffix_trail["components_values"]["intermediate_output_0_6_o"]["value"] == "0x00000020" - assert suffix_trail["components_values"]["intermediate_output_1_12_o"]["value"] == "0x00800080" - assert suffix_trail["components_values"]["cipher_output_2_12_o"]["value"] == "0x02050204" - - def test_continuous_middle_model_signed_input_and_key_match_table4_values(): """ Validate the standalone continuous model using signed (-1..1) inputs. @@ -1091,11 +982,7 @@ def _solve_for_mask(mask_hex): def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_model_float_search_scip(): """ - Full integrated 9-round Speck32/64 differential-linear model. - - Uses native MiniZinc SCIP (not fscip), enforces paper split values - top=7 and bottom=2, and checks that the continuous differential-linear - correlation is non-zero under float_search. + Full integrated 9-round Speck32/64 differential-linear model using native MiniZinc SCIP. """ speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) @@ -1106,36 +993,7 @@ def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_ middle_part_model="cp_continuous_differential_propagation_constraints", single_key=True, ) - - def _component_bit_size(cipher, component_id): - if component_id in cipher.inputs: - idx = list(cipher.inputs).index(component_id) - return cipher.inputs_bit_size[idx] - return cipher.get_component_from_id(component_id).output_bit_size - - def _fixed_from_hex(cipher, component_id, hex_value): - bit_size = _component_bit_size(cipher, component_id) - return set_fixed_variables( - component_id=component_id, - constraint_type="equal", - bit_positions=list(range(bit_size)), - bit_values=integer_to_bit_list(int(hex_value, 16), bit_size, "big"), - ) - - def _sum_probability_expr(component_ids): - terms = [] - for component_id in sorted(component_ids): - probability_expr = model._component_probability_expression(component_id) - if probability_expr: - terms.append(f"({probability_expr})") - return "(" + " + ".join(terms) + ")" if terms else "0" - - def _base_component_id(component_id): - if component_id.endswith("_i") or component_id.endswith("_o"): - return component_id[:-2] - return component_id - - combined_diff_fixed = { + TABLE_4_COMBINED_DIFF_FIXED = { "plaintext": "0xA8400010", "key": "0x0000000000000000", "intermediate_output_0_6": "0x81408100", @@ -1146,37 +1004,28 @@ def _base_component_id(component_id): "xor_2_5": "0x0000", "xor_3_5": "0x0000", } - lin_fixed = { + TABLE_4_LIN_FIXED = { "intermediate_output_6_12": "0x00000020", "intermediate_output_7_12": "0x00800080", "cipher_output_8_12": "0x02050204", } + fixed_values = [ + model.get_fixed_variable_from_hex(comp_id, hex_val) + for comp_id, hex_val in {**TABLE_4_COMBINED_DIFF_FIXED, **TABLE_4_LIN_FIXED}.items() + ] - fixed_values = [] - for component_id, hex_val in {**combined_diff_fixed, **lin_fixed}.items(): - fixed_values.append(_fixed_from_hex(speck, component_id, hex_val)) - - model.build_xor_differential_linear_model(weight=-1, fixed_variables=fixed_values) - - solve_idx = None - for idx, line in enumerate(model._model_constraints): - if line.strip().startswith("solve "): - solve_idx = idx - model._model_constraints[idx] = ( - "solve :: float_search(linear_mask_times_diff_lin_output, 1e-12, largest, indomain_split, complete) " - "minimize correlation_log2_approximation;" - ) - break - - assert solve_idx is not None - assert "minimize correlation_log2_approximation" in model._model_constraints[solve_idx] - - # top_expr = _sum_probability_expr(model.top_part_component_ids) - # bottom_expr = _sum_probability_expr(model._weight_bottom_component_ids()) - # model._model_constraints.insert(solve_idx, f"constraint {top_expr} = 700;") - # model._model_constraints.insert(solve_idx + 1, f"constraint {bottom_expr} = 200;") + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=fixed_values, + optimization_objective=( + "solve :: float_search(linear_mask_times_diff_lin_output, 1e-12, largest, indomain_split, complete) " + "minimize correlation_log2_approximation;" + ) + ) full_model = "\n".join(model._variables_list + model._model_constraints) + assert "minimize correlation_log2_approximation;" in full_model + command = ["minizinc", "--input-from-stdin", "--solver-statistics", "--solver", "scip"] solver_process = subprocess.run(command, input=full_model, capture_output=True, text=True, check=False) @@ -1184,95 +1033,144 @@ def _base_component_id(component_id): pytest.skip("Native MiniZinc SCIP solver is not available in this environment") assert solver_process.returncode == 0 - solver_output = solver_process.stdout.splitlines() assert all("--solver fscip" not in line for line in solver_output) assert "UNSATISFIABLE" not in solver_process.stdout assert "----------" in solver_process.stdout _, _, components_values, _ = model._parse_solver_output( - solver_output, - XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, - solve_external=True, - solver_name=SCIP, + solver_output, XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, solve_external=True, solver_name=SCIP ) - assert "solution1" in components_values - + solution = components_values["solution1"] - top_sum = 0.0 - bottom_sum = 0.0 - seen_top = set() - seen_bottom = set() + top_sum, bottom_sum = 0.0, 0.0 + seen_top, seen_bottom = set(), set() for component_id, component_solution in solution.items(): if not isinstance(component_solution, dict): continue weight = float(component_solution.get("weight", 0)) - base_component_id = _base_component_id(component_id) + base_component_id = model._base_component_id(component_id) if base_component_id in model.top_part_component_ids and base_component_id not in seen_top: top_sum += weight seen_top.add(base_component_id) if base_component_id in model._weight_bottom_component_ids() and base_component_id not in seen_bottom: bottom_sum += weight seen_bottom.add(base_component_id) - if weight > 0: - print(f"BOTTOM WEIGHT: {base_component_id} = {weight}") assert top_sum == 7.0 assert bottom_sum == 2.0 correlation_values = [ float(line.split("=", 1)[1].strip()) - for line in solver_output - if line.startswith("differential_linear_correlation =") + for line in solver_output if line.startswith("differential_linear_correlation =") ] assert correlation_values correlation_value = max(correlation_values) - assert 0.0 < correlation_value <= 1.0 expected_corr = 0.745481409287476 expected_total_probability = -11.42375571965992 - total_probability = math.log2(correlation_value) - 7 - 2 * 2 + total_probability = math.log2(correlation_value) - top_sum - bottom_sum*2 assert min(abs(value - expected_corr) for value in correlation_values) < 1e-6 assert abs(total_probability - expected_total_probability) < 1e-6 - trail_weight_values = [ - float(line.split("=", 1)[1].strip()) - for line in solver_output - if line.startswith("Trail weight =") + +def test_differential_linear_trail_continuous_middle_9_rounds_speck_table6_full_model_float_search_scip(): + """ + Full integrated 9-round Speck32/64 differential-linear model for Table 6 using native MiniZinc SCIP. + """ + speck = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=9) + component_model_list = _split_components(speck, top_rounds_end=4, middle_rounds_end=6) + + model = MznDifferentialLinearModel( + speck, + component_model_list, + middle_part_model="cp_continuous_differential_propagation_constraints", + single_key=True, + ) + TABLE_6_COMBINED_DIFF_FIXED = { + "plaintext": "0x20540800", + "key": "0x0000000000000000", + "intermediate_output_0_6": "0xA0408040", + "intermediate_output_1_12": "0x01000002", + "intermediate_output_2_12": "0x00000008", + "intermediate_output_3_12": "0x00080028", + "xor_1_5": "0x0000", + "xor_2_5": "0x0000", + "xor_3_5": "0x0000", + } + TABLE_6_LIN_FIXED = { + "intermediate_output_6_12": "0x00200020", + "intermediate_output_7_12": "0x00814081", + "cipher_output_8_12": "0x0C000E01", + } + fixed_values = [ + model.get_fixed_variable_from_hex(comp_id, hex_val) + for comp_id, hex_val in {**TABLE_6_COMBINED_DIFF_FIXED, **TABLE_6_LIN_FIXED}.items() ] - assert trail_weight_values - trail_weight_raw = min(trail_weight_values) - - def _piecewise_log2_approximation(correlation): - if correlation <= 0.001021453702391378: - return -19931.57001201849 * correlation + 29.89737278555626 - if correlation <= 0.004151650554233785: - return -584.962260272084 * correlation + 10.13570866882117 - if correlation <= 0.01359667098324998: - return -192.6450521799878 * correlation + 8.506944714410169 - if correlation <= 0.05399137458004444: - return -50.62607129324977 * correlation + 6.575959357916722 - if correlation <= 0.1420480516058986: - return -11.87410019056137 * correlation + 4.483687170396419 - if correlation <= 0.2463455066216964: - return -8.613130253286352 * correlation + 4.020472744461092 - if correlation <= 0.595815289564374: - return -3.761918786389538 * correlation + 2.825398597919413 - if correlation <= 0.998000001: - return -1.444862453710759 * correlation + 1.44486100812744 - return 0.0 - - correlation_log2_approximation = _piecewise_log2_approximation(correlation_value) - expected_weight_raw = 100 * top_sum + 100 * correlation_log2_approximation + 200 * bottom_sum - expected_weight_without_correlation_scaling = 100 * top_sum + correlation_log2_approximation + 200 * bottom_sum - - assert abs(trail_weight_raw - expected_weight_raw) < 1e-3 - assert abs(trail_weight_raw - expected_weight_raw) < abs( - trail_weight_raw - expected_weight_without_correlation_scaling + + model.build_xor_differential_linear_model( + weight=-1, + fixed_variables=fixed_values, + optimization_objective=( + "solve :: float_search(linear_mask_times_diff_lin_output, 1e-12, largest, indomain_split, complete) " + "minimize correlation_log2_approximation;" + ) + ) + + full_model = "\n".join(model._variables_list + model._model_constraints) + assert "minimize correlation_log2_approximation;" in full_model + + command = ["minizinc", "--input-from-stdin", "--solver-statistics", "--solver", "scip"] + solver_process = subprocess.run(command, input=full_model, capture_output=True, text=True, check=False) + + if solver_process.returncode != 0 and "Unknown solver" in solver_process.stderr: + pytest.skip("Native MiniZinc SCIP solver is not available in this environment") + + assert solver_process.returncode == 0 + solver_output = solver_process.stdout.splitlines() + assert all("--solver fscip" not in line for line in solver_output) + assert "UNSATISFIABLE" not in solver_process.stdout + assert "----------" in solver_process.stdout + + _, _, components_values, _ = model._parse_solver_output( + solver_output, XOR_DIFFERENTIAL_LINEAR_ONE_SOLUTION, solve_external=True, solver_name=SCIP ) - assert abs(trail_weight_raw / 100 - (top_sum + correlation_log2_approximation + 2 * bottom_sum)) < 1e-5 + + solution = components_values["solution1"] + top_sum, bottom_sum = 0.0, 0.0 + seen_top, seen_bottom = set(), set() + + for component_id, component_solution in solution.items(): + if not isinstance(component_solution, dict): + continue + weight = float(component_solution.get("weight", 0)) + base_component_id = model._base_component_id(component_id) + if base_component_id in model.top_part_component_ids and base_component_id not in seen_top: + top_sum += weight + seen_top.add(base_component_id) + if base_component_id in model._weight_bottom_component_ids() and base_component_id not in seen_bottom: + bottom_sum += weight + seen_bottom.add(base_component_id) + + assert top_sum == 7.0 + assert bottom_sum == 3.0 + + correlation_values = [ + float(line.split("=", 1)[1].strip()) + for line in solver_output if line.startswith("differential_linear_correlation =") + ] + assert correlation_values + correlation_value = max(correlation_values) + + expected_corr = 0.7759094272693416 + expected_total_probability = -13.36603983996931 + total_probability = math.log2(correlation_value) - top_sum - bottom_sum*2 + + assert min(abs(value - expected_corr) for value in correlation_values) < 1e-6 + assert abs(total_probability - expected_total_probability) < 1e-6 + if __name__ == "__main__": From 03a9d5408c8dd605fd567ab9273f44cc54bd52f5 Mon Sep 17 00:00:00 2001 From: Rony Date: Thu, 23 Apr 2026 09:25:11 -0500 Subject: [PATCH 3/5] Refactor: reduce SonarCloud issues (complexity & float comparison) - Replace direct float equality checks with math.isclose() in continuous predicates - Extract complex loop logic into helper functions in mzn_differential_linear_model.py to reduce cognitive complexity (border constraints, output parsing) - Simplify intermediate output pinning iteration and remove unused 'pin_bit' --- .../mzn_continuous_predicates.py | 4 +- .../mzn_differential_linear_model.py | 209 ++++++++---------- 2 files changed, 93 insertions(+), 120 deletions(-) diff --git a/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py b/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py index a810a6819..bd025354d 100644 --- a/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py +++ b/claasp/cipher_modules/models/cp/minizinc_utils/mzn_continuous_predicates.py @@ -1,3 +1,5 @@ +import math + def get_continuous_operations(): """ Returns the MiniZinc code required for continuous correlation propagation. @@ -68,7 +70,7 @@ def active_bit_correlation_expression(mask_expr, correlation_expr): def piecewise_log2_approximation_expression(correlation_expr, scale=1.0, else_value="0.0"): """Return the shared piecewise linear approximation of -log2(correlation).""" - scale_prefix = "" if scale == 1.0 else f"{scale} * " + scale_prefix = "" if math.isclose(scale, 1.0) else f"{scale} * " return ( f"{scale_prefix}(\n" f"if {correlation_expr} <= 0.001021453702391378 then\n" diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py index ff174dd07..d4ef63e8f 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py @@ -273,69 +273,55 @@ def _continuous_middle_connecting_constraints(self, include_middle_sources=True) all incoming arcs to those x* arrays. """ constraints = [] + constraints.extend(self._continuous_middle_connecting_from_raw_bindings(include_middle_sources)) + constraints.extend(self._continuous_middle_connecting_for_intermediate_outputs(include_middle_sources)) + return constraints + def _continuous_middle_connecting_from_raw_bindings(self, include_middle_sources): + constraints = [] for output_bit_id, successor_bits in self.raw_bit_bindings.items(): source_component_id, source_bit_index, source_side = output_bit_id - is_input = source_component_id in self._cipher.inputs if not is_input and source_side != "o": continue - - if (not include_middle_sources) and source_component_id in self.middle_part_component_ids: + if not include_middle_sources and source_component_id in self.middle_part_component_ids: continue source_bit_expr = f"{source_component_id}[{int(source_bit_index)}]" - for successor_bit in successor_bits: successor_component_id, successor_bit_index, successor_side = successor_bit successor_is_output = successor_component_id in (self._cipher.outputs if hasattr(self._cipher, "outputs") else []) - if ( - not (successor_side == "i" or successor_is_output) - or successor_component_id not in self.middle_part_component_ids - ): + if not (successor_side == "i" or successor_is_output) or successor_component_id not in self.middle_part_component_ids: continue - successor_bit_expr = self._continuous_middle_input_expr( - successor_component_id, - int(successor_bit_index), - ) - + successor_bit_expr = self._continuous_middle_input_expr(successor_component_id, int(successor_bit_index)) if source_component_id in self.middle_part_component_ids: constraints.append(f"constraint {successor_bit_expr} = {source_bit_expr};") else: - constraints.append( - f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;" - ) + constraints.append(f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;") + return constraints - # Wire intermediate outputs since they are not in raw_bit_bindings sources + def _continuous_middle_connecting_for_intermediate_outputs(self, include_middle_sources): + constraints = [] for comp_id, bit_dict in self.raw_bit_bindings_for_intermediate_output.items(): if comp_id not in self.middle_part_component_ids: continue for inter_bit_tuple, pins in bit_dict.items(): inter_id, inter_bit, _ = inter_bit_tuple - - source_pin = None - for pin in pins: - pin_id, pin_bit, pin_side = pin - if pin_side == "o" or pin_id in self._cipher.inputs: - source_pin = pin - break - + source_pin = next((pin for pin in pins if pin[2] == "o" or pin[0] in self._cipher.inputs), None) if not source_pin: continue source_id, source_bit, _ = source_pin - if (not include_middle_sources) and source_id in self.middle_part_component_ids: + if not include_middle_sources and source_id in self.middle_part_component_ids: continue source_expr = f"{source_id}[{int(source_bit)}]" successor_expr = self._continuous_middle_input_expr(inter_id, int(inter_bit)) - if source_id in self.middle_part_component_ids: constraints.append(f"constraint {successor_expr} = {source_expr};") else: constraints.append(f"constraint {successor_expr} = if {source_expr} = 1 then 1 else -1 endif;") - return constraints def _top_to_middle_connecting_constraints(self): @@ -348,32 +334,27 @@ def _top_to_middle_connecting_constraints(self): for output_bit_id, successor_bits in self.raw_bit_bindings.items(): source_component_id, source_bit_index, source_side = output_bit_id - is_input = source_component_id in self._cipher.inputs if (not is_input and source_side != "o") or source_component_id not in border_components: continue source_bit_expr = f"{source_component_id}[{int(source_bit_index)}]" - for successor_bit in successor_bits: - successor_component_id, successor_bit_index, successor_side = successor_bit - if ( - successor_side != "i" - or successor_component_id not in self.middle_part_component_ids - ): - continue + constraints.extend(self._top_to_middle_successor_constraints(successor_bits, source_bit_expr)) - successor_bit_expr = f"{successor_component_id}[{int(successor_bit_index)}]" - if self._is_continuous_middle(): - # Cast differential int bit to continuous float: - # 0 (no difference) -> -1.0, 1 (difference) -> 1.0 - constraints.append( - f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;" - ) - else: - constraints.append( - f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else 0 endif;" - ) + return constraints + def _top_to_middle_successor_constraints(self, successor_bits, source_bit_expr): + constraints = [] + for successor_bit in successor_bits: + successor_component_id, successor_bit_index, successor_side = successor_bit + if successor_side != "i" or successor_component_id not in self.middle_part_component_ids: + continue + + successor_bit_expr = f"{successor_component_id}[{int(successor_bit_index)}]" + if self._is_continuous_middle(): + constraints.append(f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;") + else: + constraints.append(f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else 0 endif;") return constraints def _middle_to_bottom_connecting_constraints(self): @@ -418,28 +399,47 @@ def _continuous_middle_to_bottom_connecting_constraints(self): """ constraints = [] border_components = set(self._get_truncated_xor_differential_components_in_border()) + border_components = self._filter_border_components_for_single_key(border_components) + + border_sources = self._collect_border_sources(border_components) + if not border_sources: + return constraints + + ordered_border_sources = self._order_border_sources(border_sources) + n = len(ordered_border_sources) + + self._variables_list.append(f"array[0..{n - 1}] of var 0..1: linear_border_mask;") + self._add_linear_border_mask_constraints(constraints, ordered_border_sources) - # In single-key mode we only account for the data-path branch at the - # middle/bottom border, matching the single-key linear-weight semantics. - # For single_key=False we intentionally keep all border components, - # including the key-schedule branch. - if self.single_key: - last_middle_round = max( - self._cipher.get_round_from_component_id(cid) - for cid in self.middle_part_component_ids + self._variables_list.append(f"array[0..{n - 1}] of var -1.0..1.0: linear_mask_times_diff_lin_output;") + for idx, (cont_bit, _) in enumerate(ordered_border_sources): + constraints.append( + f"constraint linear_mask_times_diff_lin_output[{idx}] = " + f"{active_bit_correlation_expression(f'linear_border_mask[{idx}]', cont_bit)};" ) - state_input_ids = set() - for comp in self._cipher.get_components_in_round(last_middle_round): - if comp.id in self.middle_part_component_ids and comp.type == "intermediate_output": - # The 32-bit output is the state; 16-bit is typically key schedule - if comp.output_bit_size > 16: - state_input_ids.update(comp.input_id_links) - if state_input_ids: - border_components = border_components & state_input_ids - - # Build a canonical border mask per continuous source bit. This mirrors - # the standalone continuous model semantics where one mask bit gates one - # correlation bit, instead of multiplying once per fan-out edge. + + self._variables_list.append("var lower..upper: differential_linear_correlation;") + constraints.append("constraint differential_linear_correlation = product(linear_mask_times_diff_lin_output);") + constraints.append("constraint differential_linear_correlation != 0.0;") + + return constraints + + def _filter_border_components_for_single_key(self, border_components): + if not self.single_key: + return border_components + last_middle_round = max( + self._cipher.get_round_from_component_id(cid) for cid in self.middle_part_component_ids + ) + state_input_ids = set() + for comp in self._cipher.get_components_in_round(last_middle_round): + if comp.id in self.middle_part_component_ids and comp.type == "intermediate_output": + if comp.output_bit_size > 16: + state_input_ids.update(comp.input_id_links) + if state_input_ids: + return border_components & state_input_ids + return border_components + + def _collect_border_sources(self, border_components): border_sources = {} for output_bit_id, successor_bits in self.bit_bindings.items(): source_component_id, _, _ = self._parse_linear_bit_id(output_bit_id) @@ -451,10 +451,9 @@ def _continuous_middle_to_bottom_connecting_constraints(self): successor_component_id, _, _ = self._parse_linear_bit_id(successor_bit) if successor_component_id in self.bottom_part_component_ids: border_sources.setdefault(source_bit_expr, set()).add(successor_bit) + return border_sources - if not border_sources: - return constraints - + def _order_border_sources(self, border_sources): def _sort_bit_expr(bit_expr): component_id, _, bit_index = self._parse_linear_bit_id(bit_expr) return component_id, bit_index @@ -463,43 +462,14 @@ def _sort_bit_expr(bit_expr): for source_bit_expr in sorted(border_sources, key=_sort_bit_expr): ordered_successors = sorted(border_sources[source_bit_expr], key=_sort_bit_expr) ordered_border_sources.append((source_bit_expr, ordered_successors)) + return ordered_border_sources - n = len(ordered_border_sources) - - self._variables_list.append( - f"array[0..{n - 1}] of var 0..1: linear_border_mask;" - ) - + def _add_linear_border_mask_constraints(self, constraints, ordered_border_sources): for idx, (_, successors) in enumerate(ordered_border_sources): if len(successors) == 1: - constraints.append( - f"constraint linear_border_mask[{idx}] = {successors[0]};" - ) + constraints.append(f"constraint linear_border_mask[{idx}] = {successors[0]};") else: - constraints.append( - f"constraint linear_border_mask[{idx}] = ({' + '.join(successors)}) mod 2;" - ) - - # Declare the combined array - self._variables_list.append( - f"array[0..{n - 1}] of var -1.0..1.0: linear_mask_times_diff_lin_output;" - ) - - # Build: combined[i] = if mask==0 then 1 else mask*abs(corr) endif - for idx, (cont_bit, _) in enumerate(ordered_border_sources): - constraints.append( - f"constraint linear_mask_times_diff_lin_output[{idx}] = " - f"{active_bit_correlation_expression(f'linear_border_mask[{idx}]', cont_bit)};" - ) - - # Declare and constrain the differential-linear correlation - self._variables_list.append("var lower..upper: differential_linear_correlation;") - constraints.append( - "constraint differential_linear_correlation = product(linear_mask_times_diff_lin_output);" - ) - constraints.append("constraint differential_linear_correlation != 0.0;") - - return constraints + constraints.append(f"constraint linear_border_mask[{idx}] = ({' + '.join(successors)}) mod 2;") def _branch_xor_linear_constraints_for_bottom_part(self): constraints = [] @@ -814,6 +784,9 @@ def _parse_solver_output( ): return parsed + return self._process_differential_linear_parsed_output(parsed, solve_external, continuous_xor_differential_linear_model) + + def _process_differential_linear_parsed_output(self, parsed, solve_external, continuous_xor_differential_linear_model): if not solve_external: if isinstance(parsed, list): for solution in parsed: @@ -823,22 +796,20 @@ def _parse_solver_output( self._set_differential_linear_total_weight(parsed) return parsed - if solve_external: - if continuous_xor_differential_linear_model: - solver_time, memory, components_values = parsed - else: - solver_time, memory, components_values, _ = parsed - total_weight = [] - solution_keys = sorted( - components_values.keys(), - key=lambda key: int(key.replace("solution", "")) if key.startswith("solution") else 0, - ) - for solution_key in solution_keys: - solution_components_values = components_values.get(solution_key, {}) - total_weight.append(str(self._differential_linear_total_weight_from_components(solution_components_values))) - return solver_time, memory, components_values, total_weight - - return parsed + if continuous_xor_differential_linear_model: + solver_time, memory, components_values = parsed + else: + solver_time, memory, components_values, _ = parsed + + total_weight = [] + solution_keys = sorted( + components_values.keys(), + key=lambda key: int(key.replace("solution", "")) if key.startswith("solution") else 0, + ) + for solution_key in solution_keys: + solution_components_values = components_values.get(solution_key, {}) + total_weight.append(str(self._differential_linear_total_weight_from_components(solution_components_values))) + return solver_time, memory, components_values, total_weight def _ensure_components_values(self, solution): if not isinstance(solution, dict): From 30f48b895672823d91c419a95e7097e425bd705f Mon Sep 17 00:00:00 2001 From: Rony Date: Thu, 23 Apr 2026 11:25:47 -0500 Subject: [PATCH 4/5] Refactor: simplify weight calculation in MznDifferentialLinearModel and clean up test cases --- .../mzn_differential_linear_model.py | 3 +- .../mzn_differential_linear_model_test.py | 84 +------------------ 2 files changed, 2 insertions(+), 85 deletions(-) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py index d4ef63e8f..29c357c97 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py @@ -749,8 +749,7 @@ def _differential_linear_total_weight_from_components(self, components_values): p_weight, middle_sum, q_weight, has_middle_components = ( self._collect_differential_linear_component_weights(components_values) ) - r_weight = self._middle_weight_term(middle_sum, has_middle_components) - return round(p_weight + r_weight + (2 * q_weight), 10) + return round(p_weight + middle_sum + (2 * q_weight), 10) def _set_differential_linear_total_weight(self, solution): if not isinstance(solution, dict): diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py index d5f0e7e38..bb6b2dc3e 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py @@ -4,11 +4,6 @@ import pytest from claasp.cipher_modules.models.cp.mzn_models.mzn_differential_linear_model import MznDifferentialLinearModel -from claasp.cipher_modules.models.cp.mzn_models.mzn_differential_linear_continuous_model import ( - MznDifferentialLinearContinuousModel, -) -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_differential_model import MznXorDifferentialModel -from claasp.cipher_modules.models.cp.mzn_models.mzn_xor_linear_model import MznXorLinearModel from claasp.cipher_modules.models.cp.solvers import CPSAT, SCIP from claasp.cipher_modules.models.utils import ( differential_linear_checker_for_block_cipher_single_key, @@ -908,78 +903,6 @@ def test_differential_linear_trail_continuous_middle_speck_build(): decl_text = "\n".join(declarations) assert "var 0..2" not in decl_text -def test_continuous_middle_model_signed_input_and_key_match_table4_values(): - """ - Validate the standalone continuous model using signed (-1..1) inputs. - - For the Table 4 middle part setting, plaintext and key are fixed as - signed arrays. The expected correlation corresponds to output mask - 0x10000020. - """ - - def _sign_bits_from_hex(hex_value, bit_size): - value = int(hex_value, 16) - return [1.0 if ((value >> (bit_size - 1 - i)) & 1) else -1.0 for i in range(bit_size)] - - def _mask_bits_from_hex(hex_value, bit_size): - value = int(hex_value, 16) - return [1 if ((value >> (bit_size - 1 - i)) & 1) else 0 for i in range(bit_size)] - - def _solve_for_mask(mask_hex): - cipher = SpeckBlockCipher(block_bit_size=32, key_bit_size=64, number_of_rounds=2) - model = MznDifferentialLinearContinuousModel(cipher) - - fixed_inputs = [ - { - "component_id": "plaintext", - "bit_positions": list(range(32)), - "bit_values": _sign_bits_from_hex("0x10005000", 32), - }, - { - "component_id": "key", - "bit_positions": list(range(64)), - "bit_values": [-1.0] * 64, - }, - ] - - model.build_differential_linear_continuous_trail_model(fixed_values=fixed_inputs) - model._build_linear_mask_correlation_constraints() - model._build_difflin_corr_constraints() - - for index, bit in enumerate(_mask_bits_from_hex(mask_hex, 32)): - model._model_constraints.append(f"constraint output_mask[{index}] = {bit};") - - cipher_output_id = model._get_cipher_output_id() - model._model_constraints.append( - f"solve :: float_search({cipher_output_id}, 1e-12, smallest, indomain_min, complete) " - "minimize correlation_log2_approximation;" - ) - - try: - result = model.solve_for_ARX("scip", timeout_in_seconds_=120) - except Exception: - pytest.skip("Native MiniZinc SCIP solver is not available in this environment") - - parsed = model._parse_result(result, "scip") - assert parsed.get("status") in ("SATISFIED", "OPTIMAL_SOLUTION") - return parsed["differential_linear_correlation"] - - expected_corr = 0.745481409287476 - expected_total_probability = -11.42375571965992 - - corr_table4_mask = _solve_for_mask("0x10000020") - total_probability_table4_mask = math.log2(corr_table4_mask) - 7 - 2 * 2 - - assert abs(corr_table4_mask - expected_corr) < 1e-6 - assert abs(total_probability_table4_mask - expected_total_probability) < 1e-6 - - corr_integrated_bottom_mask = _solve_for_mask("0x00000020") - total_probability_integrated_bottom_mask = math.log2(corr_integrated_bottom_mask) - 7 - 2 * 2 - - assert abs(corr_integrated_bottom_mask - corr_table4_mask) > 1e-2 - assert abs(total_probability_integrated_bottom_mask - expected_total_probability) > 1e-2 - - def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_model_float_search_scip(): """ Full integrated 9-round Speck32/64 differential-linear model using native MiniZinc SCIP. @@ -1169,9 +1092,4 @@ def test_differential_linear_trail_continuous_middle_9_rounds_speck_table6_full_ total_probability = math.log2(correlation_value) - top_sum - bottom_sum*2 assert min(abs(value - expected_corr) for value in correlation_values) < 1e-6 - assert abs(total_probability - expected_total_probability) < 1e-6 - - - -if __name__ == "__main__": - test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_model_float_search_scip() + assert abs(total_probability - expected_total_probability) < 1e-6 \ No newline at end of file From 47c5855884b5f9b6b60ba8614e9606ff66018bc8 Mon Sep 17 00:00:00 2001 From: Rony Date: Thu, 23 Apr 2026 23:49:14 -0500 Subject: [PATCH 5/5] FIX/Refactor: reduce cognitive complexity in continuous differential-linear wiring constraints Co-authored-by: Copilot --- .../mzn_differential_linear_model.py | 119 +++++++++++++----- .../mzn_differential_linear_model_test.py | 8 +- 2 files changed, 91 insertions(+), 36 deletions(-) diff --git a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py index 29c357c97..53cae5e2f 100644 --- a/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py +++ b/claasp/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model.py @@ -277,28 +277,94 @@ def _continuous_middle_connecting_constraints(self, include_middle_sources=True) constraints.extend(self._continuous_middle_connecting_for_intermediate_outputs(include_middle_sources)) return constraints + def _is_valid_continuous_middle_source(self, source_component_id, source_side, include_middle_sources): + is_input = source_component_id in self._cipher.inputs + if not is_input and source_side != "o": + return False + if not include_middle_sources and source_component_id in self.middle_part_component_ids: + return False + return True + + def _is_valid_continuous_middle_successor(self, successor_component_id, successor_side): + successor_is_output = successor_component_id in ( + self._cipher.outputs if hasattr(self._cipher, "outputs") else [] + ) + return (successor_side == "i" or successor_is_output) and ( + successor_component_id in self.middle_part_component_ids + ) + + @staticmethod + def _continuous_middle_source_expr(source_component_id, source_bit_index): + return f"{source_component_id}[{int(source_bit_index)}]" + + def _continuous_middle_assignment_constraint(self, source_component_id, source_bit_expr, successor_bit_expr): + if source_component_id in self.middle_part_component_ids: + return f"constraint {successor_bit_expr} = {source_bit_expr};" + return f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;" + + def _continuous_middle_constraints_for_successors(self, source_component_id, source_bit_expr, successor_bits): + constraints = [] + for successor_component_id, successor_bit_index, successor_side in successor_bits: + if not self._is_valid_continuous_middle_successor(successor_component_id, successor_side): + continue + + successor_bit_expr = self._continuous_middle_input_expr( + successor_component_id, int(successor_bit_index) + ) + constraints.append( + self._continuous_middle_assignment_constraint( + source_component_id, source_bit_expr, successor_bit_expr + ) + ) + return constraints + def _continuous_middle_connecting_from_raw_bindings(self, include_middle_sources): constraints = [] for output_bit_id, successor_bits in self.raw_bit_bindings.items(): source_component_id, source_bit_index, source_side = output_bit_id - is_input = source_component_id in self._cipher.inputs - if not is_input and source_side != "o": - continue - if not include_middle_sources and source_component_id in self.middle_part_component_ids: + if not self._is_valid_continuous_middle_source( + source_component_id, source_side, include_middle_sources + ): continue - source_bit_expr = f"{source_component_id}[{int(source_bit_index)}]" - for successor_bit in successor_bits: - successor_component_id, successor_bit_index, successor_side = successor_bit - successor_is_output = successor_component_id in (self._cipher.outputs if hasattr(self._cipher, "outputs") else []) - if not (successor_side == "i" or successor_is_output) or successor_component_id not in self.middle_part_component_ids: - continue + source_bit_expr = self._continuous_middle_source_expr(source_component_id, source_bit_index) + constraints.extend( + self._continuous_middle_constraints_for_successors( + source_component_id, source_bit_expr, successor_bits + ) + ) + return constraints - successor_bit_expr = self._continuous_middle_input_expr(successor_component_id, int(successor_bit_index)) - if source_component_id in self.middle_part_component_ids: - constraints.append(f"constraint {successor_bit_expr} = {source_bit_expr};") - else: - constraints.append(f"constraint {successor_bit_expr} = if {source_bit_expr} = 1 then 1 else -1 endif;") + def _continuous_middle_intermediate_source_pin(self, pins): + return next((pin for pin in pins if pin[2] == "o" or pin[0] in self._cipher.inputs), None) + + def _should_skip_continuous_middle_intermediate_source(self, source_id, include_middle_sources): + return (not include_middle_sources) and (source_id in self.middle_part_component_ids) + + def _continuous_middle_constraint_from_intermediate_pin( + self, inter_id, inter_bit, source_pin, include_middle_sources + ): + if not source_pin: + return None + + source_id, source_bit, _ = source_pin + if self._should_skip_continuous_middle_intermediate_source(source_id, include_middle_sources): + return None + + source_expr = self._continuous_middle_source_expr(source_id, source_bit) + successor_expr = self._continuous_middle_input_expr(inter_id, int(inter_bit)) + return self._continuous_middle_assignment_constraint(source_id, source_expr, successor_expr) + + def _continuous_middle_constraints_for_intermediate_bit_dict(self, bit_dict, include_middle_sources): + constraints = [] + for inter_bit_tuple, pins in bit_dict.items(): + inter_id, inter_bit, _ = inter_bit_tuple + source_pin = self._continuous_middle_intermediate_source_pin(pins) + constraint = self._continuous_middle_constraint_from_intermediate_pin( + inter_id, inter_bit, source_pin, include_middle_sources + ) + if constraint: + constraints.append(constraint) return constraints def _continuous_middle_connecting_for_intermediate_outputs(self, include_middle_sources): @@ -306,22 +372,11 @@ def _continuous_middle_connecting_for_intermediate_outputs(self, include_middle_ for comp_id, bit_dict in self.raw_bit_bindings_for_intermediate_output.items(): if comp_id not in self.middle_part_component_ids: continue - for inter_bit_tuple, pins in bit_dict.items(): - inter_id, inter_bit, _ = inter_bit_tuple - source_pin = next((pin for pin in pins if pin[2] == "o" or pin[0] in self._cipher.inputs), None) - if not source_pin: - continue - - source_id, source_bit, _ = source_pin - if not include_middle_sources and source_id in self.middle_part_component_ids: - continue - - source_expr = f"{source_id}[{int(source_bit)}]" - successor_expr = self._continuous_middle_input_expr(inter_id, int(inter_bit)) - if source_id in self.middle_part_component_ids: - constraints.append(f"constraint {successor_expr} = {source_expr};") - else: - constraints.append(f"constraint {successor_expr} = if {source_expr} = 1 then 1 else -1 endif;") + constraints.extend( + self._continuous_middle_constraints_for_intermediate_bit_dict( + bit_dict, include_middle_sources + ) + ) return constraints def _top_to_middle_connecting_constraints(self): @@ -746,7 +801,7 @@ def _middle_weight_term(self, middle_sum, has_middle_components): return abs(math.log(abs(2 * (2**(-1*middle_sum)) - 1), 2)) def _differential_linear_total_weight_from_components(self, components_values): - p_weight, middle_sum, q_weight, has_middle_components = ( + p_weight, middle_sum, q_weight, _ = ( self._collect_differential_linear_component_weights(components_values) ) return round(p_weight + middle_sum + (2 * q_weight), 10) diff --git a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py index bb6b2dc3e..f2fde52af 100644 --- a/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py +++ b/tests/unit/cipher_modules/models/cp/mzn_models/mzn_differential_linear_model_test.py @@ -981,8 +981,8 @@ def test_differential_linear_trail_continuous_middle_9_rounds_speck_table4_full_ bottom_sum += weight seen_bottom.add(base_component_id) - assert top_sum == 7.0 - assert bottom_sum == 2.0 + assert top_sum == 7 + assert bottom_sum == 2 correlation_values = [ float(line.split("=", 1)[1].strip()) @@ -1077,8 +1077,8 @@ def test_differential_linear_trail_continuous_middle_9_rounds_speck_table6_full_ bottom_sum += weight seen_bottom.add(base_component_id) - assert top_sum == 7.0 - assert bottom_sum == 3.0 + assert top_sum == 7 + assert bottom_sum == 3 correlation_values = [ float(line.split("=", 1)[1].strip())