diff --git a/src/nagini_contracts/contracts.py b/src/nagini_contracts/contracts.py index 0befa5200..621a7e1df 100644 --- a/src/nagini_contracts/contracts.py +++ b/src/nagini_contracts/contracts.py @@ -27,7 +27,8 @@ 'Exists', 'Low', 'Acc', 'Rd', 'Fold', 'Unfold', 'Unfolding', 'Previous', 'RaisedException', 'Sequence', 'ToSeq', 'MaySet', 'MayCreate', 'getMethod', 'getArg', 'getOld', 'arg', 'Joinable', - 'MayStart',] + 'MayStart', 'CallSlot', 'CallSlotProof', + 'UniversallyQuantified', 'ClosureCall',] T = TypeVar('T') V = TypeVar('V') @@ -261,7 +262,7 @@ def ContractOnly(func: T) -> T: """ return func - + def GhostReturns(start_index: int) -> Callable[[T], T]: """ Decorator for functions which specifies which return values are ghost @@ -278,6 +279,45 @@ def wrap(func: T) -> T: return wrap +def CallSlot(call_slot: Callable[..., None]) -> Callable[..., Any]: + """ + Decorator to mark a method as a call slot declaration. + """ + + def call_slot_handler(*args, **kwargs) -> Any: + + def uq_handler(*args, **kwargs) -> None: + pass + + return uq_handler + + return call_slot_handler + + +def UniversallyQuantified(uq: Callable[..., None]) -> None: + """ + Decorator to mark a method as introducing universally quantified + variables inside a call slot. + """ + pass + + +def CallSlotProof(call_slot: Callable[..., Any]) -> Callable[[Callable[..., None]], None]: + """ + Decorator to mark a method as a proof for a call slot. + """ + pass + + +def ClosureCall(call: T, justification: Any) -> T: + """ + Justifies a closure call through either + * a CallSlot (justification == the callslot instance) + * proofing static dispatch (justification == the static method) + """ + pass + + def list_pred(l: object) -> bool: """ Special, predefined predicate that represents the permissions belonging @@ -335,6 +375,10 @@ def dict_pred(d: object) -> bool: 'set_pred', 'Sequence', 'ToSeq', + 'CallSlot', + 'UniversallyQuantified', + 'CallSlotProof', + 'ClosureCall', 'MaySet', 'MayCreate', ] diff --git a/src/nagini_translation/analyzer.py b/src/nagini_translation/analyzer.py index 7cf83d849..9dba7b045 100644 --- a/src/nagini_translation/analyzer.py +++ b/src/nagini_translation/analyzer.py @@ -57,6 +57,14 @@ ) from nagini_translation.lib.views import PythonModuleView from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from nagini_translation.call_slot_analyzers import ( + CallSlotAnalyzer, + is_call_slot, + CallSlotProofAnalyzer, + is_call_slot_proof, + is_closure_call, + check_closure_call +) logger = logging.getLogger('nagini_translation.analyzer') @@ -77,6 +85,7 @@ def __init__(self, types: TypeInfo, path: str, selected: Set[str]): self.global_module, None, sil_names=self.global_module.sil_names, file=file) self.current_class = None + self.outer_functions = [] # type: List[PythonMethod] self.current_function = None self.current_scopes = [] self.contract_only = False @@ -90,6 +99,8 @@ def __init__(self, types: TypeInfo, path: str, selected: Set[str]): self._aliases = {} # Dict[str, PythonBaseVar] self.current_loop_invariant = None self.selected = selected + self.call_slot_analyzer = CallSlotAnalyzer(self) + self.call_slot_proof_analyzer = CallSlotProofAnalyzer(self) self.deferred_tasks = [] @property @@ -125,7 +136,9 @@ def define_new(self, container: Union[PythonModule, PythonClass], name in container.methods or name in container.predicates or (isinstance(container, PythonClass) and - name in container.static_methods)): + name in container.static_methods) or + (isinstance(container, PythonModule) and + name in container.call_slots)): raise InvalidProgramException(node, 'multiple.definitions') def collect_imports(self, abs_path: str) -> None: @@ -519,6 +532,12 @@ def visit_ImportFrom(self, node: ast.ImportFrom): self.analyze_import(node.module) def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + if is_call_slot(node): + self.call_slot_analyzer.analyze(node) + return + if is_call_slot_proof(node): + self.call_slot_proof_analyzer.analyze(node) + return if self.current_function: raise InvalidProgramException(node, 'nested.function.declaration') name = node.name @@ -584,6 +603,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: if not is_property_setter: func.type = self.convert_type(functype) + func.callable_type = self.convert_type( + self.module.get_func_type(func.scope_prefix, callable=True)) for child in node.body: if is_io_existential(child): @@ -778,6 +799,8 @@ def visit_Call(self, node: ast.Call) -> None: Collects preconditions, postconditions, raised exceptions and invariants. """ + if is_closure_call(node): + check_closure_call(node) if (isinstance(node.func, ast.Name) and node.func.id in CONTRACT_WRAPPER_FUNCS): if node.func.id == 'Requires': @@ -888,14 +911,18 @@ def visit_Name(self, node: ast.Name) -> None: if isinstance(assign, ast.Assign) and len(assign.targets) == 1: var.value = assign.value self.module.global_vars[node.id] = var + current_module = self.module + # track access later def todo(): var = self.get_target(node, current_module) if isinstance(var, PythonGlobalVar): self.track_access(node, var) + self.deferred_tasks.append(todo) return + else: # Node is a static field. if isinstance(node.ctx, ast.Load): @@ -1036,6 +1063,8 @@ def convert_type(self, mypy_type, node=None) -> PythonType: return self._convert_type_type(mypy_type, node) elif self.types.is_callable_type(mypy_type): return self._convert_callable_type(mypy_type, node) + elif self.types.is_any_type(mypy_type): + raise UnsupportedException(node, 'Found Any type. Type annotation missing?') else: raise UnsupportedException(mypy_type) return result @@ -1054,7 +1083,13 @@ def _convert_normal_type(self, mypy_type) -> PythonType: return result def _convert_callable_type(self, mypy_type, node) -> PythonType: - return self.find_or_create_class(CALLABLE_TYPE, module=self.module.global_module) + return GenericType( + self.find_or_create_class(CALLABLE_TYPE, module=self.module.global_module), + ( + [self.convert_type(arg_type, node) for arg_type in mypy_type.arg_types] + + [self.convert_type(mypy_type.ret_type, node)] + ) + ) def _convert_union_type(self, mypy_type, node) -> PythonType: args = [self.convert_type(arg_type, node) @@ -1103,6 +1138,7 @@ def get_alt_types(self, node: ast.AST) -> Dict[int, PythonType]: context = [] if self.current_class is not None: context.append(self.current_class.name) + context.extend(map(lambda method: method.name, self.outer_functions)) if self.current_function is not None: context.append(self.current_function.name) name = node.id if isinstance(node, ast.Name) else node.arg @@ -1127,6 +1163,7 @@ def typeof(self, node: ast.AST) -> PythonType: context = [] if self.current_class is not None: context.append(self.current_class.name) + context.extend(map(lambda method: method.name, self.outer_functions)) if self.current_function is not None: context.append(self.current_function.name) context.extend(self.current_scopes) @@ -1165,6 +1202,7 @@ def typeof(self, node: ast.AST) -> PythonType: context = [] if self.current_class is not None: context.append(self.current_class.name) + context.extend(map(lambda method: method.name, self.outer_functions)) context.append(self.current_function.name) context.extend(self.current_scopes) type, _ = self.module.get_type(context, node.arg) @@ -1259,7 +1297,13 @@ def visit_Try(self, node: ast.Try) -> None: def _incompatible_decorators(self, decorators) -> bool: return ((('Predicate' in decorators) and ('Pure' in decorators)) or (('IOOperation' in decorators) and (len(decorators) != 1)) or - (('property' in decorators) and (len(decorators) != 1))) + (('property' in decorators) and (len(decorators) != 1)) or + ( + (('CallSlot' in decorators) and (len(decorators) != 1)) and + (('Pure' in decorators) and ('CallSlot' in decorators) and (len(decorators) != 2)) + ) or + (('UniversallyQuantified' in decorators) and (len(decorators) != 1)) or + (('CallSlotProof' in decorators) and (len(decorators) != 1))) def is_declared_contract_only(self, func: ast.FunctionDef) -> bool: """ diff --git a/src/nagini_translation/call_slot_analyzers.py b/src/nagini_translation/call_slot_analyzers.py new file mode 100644 index 000000000..f6ca50a0a --- /dev/null +++ b/src/nagini_translation/call_slot_analyzers.py @@ -0,0 +1,652 @@ +import ast +from typing import Union, List, Set +from nagini_contracts.contracts import ( + CONTRACT_WRAPPER_FUNCS, + CONTRACT_FUNCS +) +from nagini_translation.lib.program_nodes import ( + PythonModule +) +from nagini_translation.lib.util import ( + UnsupportedException, + InvalidProgramException, +) + + +class _CallSlotBaseAnalyzer: + + # TODO: pure call slots + + __ILLEGAL_VARIABLE_NAMES = set(CONTRACT_FUNCS + CONTRACT_WRAPPER_FUNCS) + + def __init__(self, analyzer: 'Analyzer') -> None: + self.analyzer = analyzer + self.call_slot = None # type: CallSlotBase + + def analyze(self, node: ast.FunctionDef) -> None: + """ + Preprocess the call slot `node'. + """ + + old_current_function = self.analyzer.current_function + if old_current_function is not None: + self.analyzer.outer_functions.append(old_current_function) + self._pre_process(node) + + self.analyzer.current_function = self.call_slot + self._check_method_declaration(self.call_slot) + + body_node = node + + has_uq_vars = _is_uq_vars(node.body) + if has_uq_vars: + body_node = node.body[0] + mock_call_slot = self.analyzer.node_factory.create_call_slot( + body_node.name, + body_node, + self.call_slot, + self.analyzer.node_factory + ) + + self.analyzer.outer_functions.append(self.call_slot) + self.analyzer.current_function = mock_call_slot + + self._check_method_declaration(mock_call_slot) + self.call_slot.uq_variables = mock_call_slot.args + + body_node._parent = node + + self.call_slot.body = body_node.body + for child in body_node.body: + self.analyzer.visit(child, body_node) + for decorator in node.decorator_list: + self.analyzer.visit(decorator, node) + + self._check_body(body_node.body) + + self._check_variables() + + # cleanup + if has_uq_vars: + self.call_slot._locals = mock_call_slot._locals + self.call_slot.precondition = mock_call_slot.precondition + self.call_slot.postcondition = mock_call_slot.postcondition + self.analyzer.outer_functions.pop() + + if self.call_slot.return_variables: + self.call_slot.type = self.call_slot.locals[self.call_slot.return_variables[0].id].type + self.analyzer.current_function = old_current_function + if old_current_function is not None: + self.analyzer.outer_functions.pop() + + def _check_method_declaration(self, call_slot: 'CallSlotBase') -> None: + """ + Checks whether `node' is a method declaration valid for a call slot or + universally quantified variables. If not raises an appropriate + exception. Expects analyzer.{current_function, outer_functions} to be + set correctly. + + * No magic name ('__...__') + * Return type = None + * No *args + * No **kwargs + """ + + analyzer = self.analyzer + + if analyzer._is_illegal_magic_method_name(call_slot.node.name): + raise InvalidProgramException(call_slot.node, 'illegal.magic.method') + + _type = analyzer.convert_type( + analyzer.module.get_func_type(call_slot.scope_prefix)) + + if _type is not None: + raise InvalidProgramException( + call_slot.node, + 'call_slots.return.not_none', + "Method '%s' doesn't return 'None'" % call_slot.node.name + ) + + if 0 < len(call_slot.node.args.defaults): + raise InvalidProgramException( + call_slot.node.args.defaults[0], + 'call_slots.parameters.default', + "Method '%s' has a default parameter" % call_slot.node.name + ) + + analyzer.visit(call_slot.node.args, call_slot.node) + + if call_slot.var_arg is not None: + raise InvalidProgramException( + call_slot.node, + 'call_slots.parameters.var_args', + ("Method '%s' contains illegal variadic parameters" + % call_slot.node.name) + ) + + if call_slot.kw_arg is not None: + raise InvalidProgramException( + call_slot.node, + 'call_slots.parameters.kw_args', + ("Method '%s' contains illegal keyword parameters" + % call_slot.node.name) + ) + + def _pre_process(self, node: ast.FunctionDef) -> None: + """ + Abstract method for pre processing. + Has to initialize self.call_slot + """ + raise NotImplementedError() + + def _check_body(self, body: List[ast.stmt]) -> None: + """ + Abstract method to check whether the body is valid. + """ + raise NotImplementedError() + + def _check_variables(self) -> None: + + # argument variables + argv = self.call_slot.args + # universally quantified variables + uqv = self.call_slot.uq_variables + # return variables + rtv = self.call_slot.return_variables + + shadowed_variables = argv.keys() & uqv.keys() + + if 0 < len(shadowed_variables): + shadowed_variable_name = next(iter(shadowed_variables)) + raise InvalidProgramException( + uqv[shadowed_variable_name].node, + "call_slots.parameters.illegal_shadowing", + "UQ variable '%s' illegally shadows an outer variable" % shadowed_variable_name + ) + + all_variable_names = argv.keys() | uqv.keys() + + assert rtv is not None + if 0 < len(rtv): + + assert len(rtv) == 1 + return_variable = rtv[0] + assert isinstance(return_variable, ast.Name) + + if return_variable.id in all_variable_names: + raise InvalidProgramException( + return_variable, + "call_slots.parameters.illegal_shadowing", + "return variable '%s' illegally shadows an outer variable" % return_variable.id + ) + + all_variable_names.add(return_variable.id) + + invalid_variable_names = all_variable_names & _CallSlotBaseAnalyzer.__ILLEGAL_VARIABLE_NAMES + + if 0 < len(invalid_variable_names): + illegal_variable_name = next(iter(invalid_variable_names)) + + if illegal_variable_name in argv: + illegal_variable = argv[illegal_variable_name].node + elif illegal_variable_name in uqv: + illegal_variable = uqv[illegal_variable_name].node + else: + illegal_variable = rtv[0] + + raise InvalidProgramException( + illegal_variable, + "call_slots.parameters.illegal_name", + "Variable '%s' has an illegal name" % illegal_variable_name + ) + + +class CallSlotAnalyzer(_CallSlotBaseAnalyzer): + + def _pre_process(self, node: ast.FunctionDef) -> None: + + assert is_call_slot(node) + scope = self.analyzer.module + assert isinstance(scope, PythonModule) + + if self.analyzer.current_function is not None: + raise InvalidProgramException( + node, + 'call_slots.nested.declaration', + "Callslot '%s' occurs inside a method" % node.name + ) + if self.analyzer.current_class is not None: + raise InvalidProgramException( + node, + 'call_slots.nested.declaration', + "Callslot '%s' occurs inside a class" % node.name + ) + + self.analyzer.define_new(scope, node.name, node) + + self.call_slot = self.analyzer.node_factory.create_call_slot( + node.name, + node, + scope, + self.analyzer.node_factory + ) + + + if _has_double_decorator_name(node, 'Pure', 'CallSlot'): + self.call_slot.pure = True + + scope.call_slots[node.name] = self.call_slot + + def _check_body(self, body: List[ast.stmt]) -> None: + for child in body: + + if isinstance(child, ast.Expr) and isinstance(child.value, ast.Call): + if is_precondition(child.value) or is_postcondition(child.value): + continue + self._check_call_declaration(child.value) + + elif isinstance(child, ast.Assign): + self._check_call_declaration(child) + + else: + raise InvalidProgramException( + child, + 'call_slots.body.invalid_stmt', + 'Callslot declarations must only consist of contracts and a single call' + ) + + if self.call_slot.return_variables is None: + raise InvalidProgramException( + self.call_slot.node, + 'call_slots.no.call', + "Callslot '%s' doesn't declare a call" % self.call_slot.node.name + ) + + def _check_call_declaration(self, node: Union[ast.Call, ast.Assign]) -> None: + + if self.call_slot.return_variables is not None: + raise InvalidProgramException( + node, + 'call_slots.multiple.calls', + "Callslot '%s' declares more than one call" % self.call_slot.node.name + ) + + if isinstance(node, ast.Assign): + + if not isinstance(node.value, ast.Call): + raise InvalidProgramException( + node, + 'call_slots.body.invalid_stmt', + 'Callslot declarations must only consist of contracts and a single call' + ) + call = node.value + + if len(node.targets) > 1: + raise UnsupportedException( + node, + "Callslot's call can't have more than one return target" + ) + + assert len(node.targets) == 1 + + if isinstance(node.targets[0], ast.Name): + self.call_slot.return_variables = [node.targets[0]] + else: + raise UnsupportedException( + node, + "Callslot's call supports only a single variable as return target" + ) + else: + self.call_slot.return_variables = [] + call = node + + self._check_call(call) + + def _check_call(self, call: ast.Call) -> None: + self.call_slot.call = call + + if not isinstance(call.func, ast.Name): + raise InvalidProgramException( + call.func, + 'call_slots.call_declaration.invalid_target', + "Callslot '%s' has an invalid call target" % self.call_slot.node.name + ) + if call.func.id not in self.call_slot.args: + raise InvalidProgramException( + call.func, + 'call_slots.call_declaration.invalid_target', + ("Callslot '%s' has an invalid call target (target must be a normal variable)" % + self.call_slot.node.name) + ) + + +class CallSlotProofAnalyzer(_CallSlotBaseAnalyzer): + + def _pre_process(self, node: ast.FunctionDef) -> None: + assert is_call_slot_proof(node) + + if self.analyzer.current_function is None: + raise InvalidProgramException( + node, + 'call_slots.proof.outside_method', + "Callslotproof '%s' occurs outside a method" % node.name + ) + + assert len(node.decorator_list) == 1 + assert isinstance(node.decorator_list[0], ast.Call) + assert isinstance(node.decorator_list[0].func, ast.Name) + assert node.decorator_list[0].func.id == 'CallSlotProof' + + proof_annotation = node.decorator_list[0] # type: ast.Call + + assert len(proof_annotation.args) == 1 + + call_slot_instantiation = proof_annotation.args[0] + if not isinstance(call_slot_instantiation, ast.Call): + raise InvalidProgramException( + proof_annotation, + 'call_slots.proof_annotation.invalid_arg', + "Callslot proof '%s' doesn't have a valid call slot instantiation" + ) + + if not isinstance(call_slot_instantiation.func, ast.Name): + raise InvalidProgramException( + proof_annotation, + 'call_slots.proof_annotation.invalid_arg', + "Callslot proof '%s' doesn't have a valid call slot instantiation" + ) + + if len(call_slot_instantiation.args) != len(node.args.args): + raise InvalidProgramException( + proof_annotation, + 'call_slots.proof_annotation.invalid_arg', + ) + + self.call_slot = self.analyzer.node_factory.create_call_slot_proof( + node.name, + node, + self.analyzer.current_function, + self.analyzer.node_factory, + call_slot_instantiation, + self.analyzer.current_function.get_fresh_name('__proof_old_label') + ) + + self.analyzer.current_function.call_slot_proofs[node] = self.call_slot + + def _check_body(self, body: List[ast.stmt]) -> None: + + # Possible extensions: + # - local variables with assignments + # - while loops + # - new statements (for local variables) + # - restricted method calls (only for 'local state') + + for child in body: + + if isinstance(child, ast.Expr) and isinstance(child.value, ast.Call): + if is_precondition(child.value) or is_postcondition(child.value): + continue + if is_fold(child.value) or is_unfold(child.value): + continue + if is_assume(child.value) or is_assert(child.value): + continue + self._check_call_declaration(child.value) + + elif isinstance(child, ast.Assign): + self._check_call_declaration(child) + + elif isinstance(child, ast.Assert): + continue + + elif isinstance(child, ast.FunctionDef): + if not is_call_slot_proof(child): + # NOTE: dead code, Analyzer will throw before we can reach this + raise InvalidProgramException( + child, + 'call_slots.proof_body.invalid_stmt', + "Illegal statement in call slot proof '%s'" % self.call_slot.node.name + ) + # other call slot proof checks are done elsewhere + + elif isinstance(child, ast.If): + # check purity of condition later + self._check_body(child.body) + self._check_body(child.orelse) + + else: + raise InvalidProgramException( + child, + 'call_slots.proof_body.invalid_stmt', + "Illegal statement in call slot proof '%s'" % self.call_slot.node.name + ) + + def _check_call_declaration(self, node: Union[ast.Call, ast.Assign]) -> None: + + if isinstance(node, ast.Assign): + + if not isinstance(node.value, ast.Call): + raise InvalidProgramException( + node, + 'call_slots.proof_body.invalid_stmt', + "Callslot proof '%s' contains an illegal assignment" % self.call_slot.node.name + ) + call = node.value + + if len(node.targets) > 1: + raise UnsupportedException( + node, + "Callslot proof's call can't have more than one return target" + ) + + assert len(node.targets) == 1 + + if isinstance(node.targets[0], ast.Name): + if self.call_slot.return_variables is None: + self.call_slot.return_variables = [node.targets[0]] + elif ( + len(self.call_slot.return_variables) != 1 or + self.call_slot.return_variables[0].id != node.targets[0].id + ): + raise InvalidProgramException( + node, + 'call_slots.proof_body.different_return_variables', + "Callslot proof '%s' uses different return variables" % self.call_slot.node.name + ) + else: + raise UnsupportedException( + node, + "Callslot proof's call supports only a single variable as return target" + ) + else: + if self.call_slot.return_variables is None: + self.call_slot.return_variables = [] + elif len(self.call_slot.return_variables) != 0: + raise InvalidProgramException( + node, + 'call_slots.proof_body.different_return_variables', + "Callslot proof '%s' uses different return variables" % self.call_slot.node.name + ) + call = node + + self._check_closure_call(call) + + def _check_closure_call(self, closureCall: ast.Call) -> None: + + if not is_closure_call(closureCall): + raise InvalidProgramException( + closureCall, + 'call_slots.proof_call.not_closure_call', + "Callslot proof '%s' has a call which is not a ClosureCall" % self.call_slot.node.name + ) + + assert len(closureCall.args) == 2 # guaranteed by type checker + assert isinstance(closureCall.args[0], ast.Call) + + self._check_call(closureCall.args[0]) + + def _check_call(self, call: ast.Call) -> None: + + if not isinstance(call.func, ast.Name): + raise InvalidProgramException( + call.func, + 'call_slots.proof_call.invalid_target', + "Callslot proof '%s' has an invalid call target" % self.call_slot.node.name + ) + if call.func.id not in self.call_slot.args: + raise InvalidProgramException( + call.func, + 'call_slots.proof_call.invalid_target', + "Callslot '%s' has an invalid call target" % self.call_slot.node.name + ) + + +def check_closure_call(closureCall: ast.Call) -> None: + assert is_closure_call(closureCall) + assert len(closureCall.args) == 2 # guaranteed by type checker + + if not isinstance(closureCall.args[0], ast.Call): + raise InvalidProgramException( + closureCall.args[0], + 'call_slots.closure_call.invalid_call', + "ClosureCall's first argument has to be a call of a closure" + ) + + justification = closureCall.args[1] + + if not isinstance(justification, (ast.Name, ast.Call)): + raise InvalidProgramException( + justification, + 'call_slots.closure_call.invalid_justification', + "ClosureCall's justification has to be either a call slot or static dispatch" + ) + + if isinstance(justification, ast.Call): + if not isinstance(justification.func, ast.Call): + raise InvalidProgramException( + justification, + 'call_slots.closure_call.invalid_justification', + "ClosureCall's justification has to instatiate uq variables if it's a call slot" + ) + + if not isinstance(justification.func.func, ast.Name): + raise InvalidProgramException( + justification, + 'call_slots.closure_call.invalid_justification', + "ClosureCall's justification has to be a named if it's a call slot" + ) + + +def _is_uq_vars(body) -> bool: + return ( + len(body) == 1 and + isinstance(body[0], ast.FunctionDef) and + is_universally_quantified(body[0]) + ) + + +def is_call_slot(node: ast.FunctionDef) -> bool: + """ + Whether node is a call slot declaration. + """ + return ( + _has_single_decorator_name(node, 'CallSlot') or + _has_double_decorator_name(node, 'Pure', 'CallSlot') + ) + + +def is_universally_quantified(node: ast.FunctionDef) -> bool: + """ + Whether a function introduces universally quantified variables + """ + return _has_single_decorator_name(node, 'UniversallyQuantified') + + +def is_call_slot_proof(node: ast.FunctionDef) -> bool: + """ + Whether a function introduces universally quantified variables + """ + return _has_single_decorator_call(node, 'CallSlotProof') + + +def _has_single_decorator_name(node: ast.FunctionDef, decorator_name: str) -> bool: + """ + Whether `node' has only one decorator that equals to `decorator' + """ + if len(node.decorator_list) != 1: + return False + + decorator = node.decorator_list[0] + if not isinstance(decorator, ast.Name): + return False + + return decorator.id == decorator_name + + +def _has_double_decorator_name( + node: ast.FunctionDef, + decorator_name1: str, + decorator_name2: str +) -> bool: + """ + Whether `node' has only one decorator that equals to `decorator' + """ + if len(node.decorator_list) != 2: + return False + + decorator1, decorator2 = node.decorator_list + if not isinstance(decorator1, ast.Name): + return False + + if not isinstance(decorator2, ast.Name): + return False + + if decorator1 == decorator_name2: + decorator1, decorator2 = decorator2, decorator1 + + return decorator1.id == decorator_name1 and decorator2.id == decorator_name2 + + +def _has_single_decorator_call(node: ast.FunctionDef, decorator_name: str) -> bool: + """ + Whether `node' has only one decorator that equals to `decorator' + """ + if len(node.decorator_list) != 1: + return False + decorator = node.decorator_list[0] + + if not isinstance(decorator, ast.Call): + return False + + return isinstance(decorator.func, ast.Name) and decorator.func.id == decorator_name + + +def is_closure_call(call: ast.Call) -> bool: + return is_named_call(call, 'ClosureCall') + + +def is_precondition(call: ast.Call) -> bool: + return is_named_call(call, 'Requires') + + +def is_postcondition(call: ast.Call) -> bool: + return is_named_call(call, 'Ensures') + + +def is_fold(call: ast.Call) -> bool: + return is_named_call(call, 'Unfold') + + +def is_unfold(call: ast.Call) -> bool: + return is_named_call(call, 'Fold') + + +def is_assume(call: ast.Call) -> bool: + return is_named_call(call, 'Assume') + + +def is_assert(call: ast.Call) -> bool: + return is_named_call(call, 'Assert') + + +def is_named_call(call: ast.Call, name: str) -> bool: + return isinstance(call.func, ast.Name) and call.func.id == name diff --git a/src/nagini_translation/conftest.py b/src/nagini_translation/conftest.py index 17554a566..7ef1ddb01 100644 --- a/src/nagini_translation/conftest.py +++ b/src/nagini_translation/conftest.py @@ -26,6 +26,7 @@ _SIF_TESTS_DIR = 'tests/sif/' _IO_TESTS_DIR = 'tests/io/' _OBLIGATIONS_TESTS_DIR = 'tests/obligations/' +_CLOSURES_TESTS_DIR = 'tests/closures/' class PyTestConfig: @@ -55,6 +56,8 @@ def add_test(self, test: str): self._add_test_dir(_IO_TESTS_DIR) elif test == 'obligations': self._add_test_dir(_OBLIGATIONS_TESTS_DIR) + elif test == 'closures': + self._add_test_dir(_CLOSURES_TESTS_DIR) else: print('Unrecognized test set.') @@ -114,6 +117,7 @@ def pytest_addoption(parser: 'pytest.config.Parser'): parser.addoption('--sif', dest='sif', action='store_true') parser.addoption('--io', dest='io', action='store_true') parser.addoption('--obligations', dest='obligations', action='store_true') + parser.addoption('--closures', dest='closures', action='store_true') parser.addoption('--all-verifiers', dest='all_verifiers', action='store_true') parser.addoption('--silicon', dest='silicon', action='store_true') @@ -125,7 +129,7 @@ def pytest_configure(config: 'pytest.config.Config'): # Setup tests. tests = [] if config.option.all_tests: - tests = ['functional', 'sif', 'io', 'obligations'] + tests = ['functional', 'sif', 'io', 'obligations', 'closures'] else: if config.option.functional: tests.append('functional') @@ -135,6 +139,8 @@ def pytest_configure(config: 'pytest.config.Config'): tests.append('io') if config.option.obligations: tests.append('obligations') + if config.option.closures: + tests.append('closures') if tests: # Overwrite config file options. _pytest_config.clear_tests() diff --git a/src/nagini_translation/lib/config.py b/src/nagini_translation/lib/config.py index bb28cb0cc..b7c04ee86 100644 --- a/src/nagini_translation/lib/config.py +++ b/src/nagini_translation/lib/config.py @@ -125,7 +125,7 @@ def __init__(self, config) -> None: tests_value = self._info.get('tests') if not tests_value: - self.tests = ['functional', 'sif', 'io', 'obligations'] + self.tests = ['functional', 'sif', 'io', 'obligations', 'closures'] else: self.tests = tests_value.strip().split() diff --git a/src/nagini_translation/lib/constants.py b/src/nagini_translation/lib/constants.py index 6cdc4468d..37d2402d2 100644 --- a/src/nagini_translation/lib/constants.py +++ b/src/nagini_translation/lib/constants.py @@ -55,8 +55,6 @@ THREAD_DOMAIN = 'Thread' -METHOD_ID_DOMAIN = 'ThreadingID' - GET_ARG_FUNC = 'getArg' GET_OLD_FUNC = 'getOld' @@ -101,7 +99,6 @@ JOINABLE_FUNC, THREAD_POST_PRED, THREAD_START_PRED, - METHOD_ID_DOMAIN, GET_ARG_FUNC, GET_OLD_FUNC, GET_METHOD_FUNC, @@ -249,7 +246,7 @@ PRIMITIVES = {PRIMITIVE_INT_TYPE, PRIMITIVE_BOOL_TYPE, PRIMITIVE_SEQ_TYPE, CALLABLE_TYPE} -BOXED_PRIMITIVES = {INT_TYPE, BOOL_TYPE} +BOXED_PRIMITIVES = {INT_TYPE, BOOL_TYPE, CALLABLE_TYPE} NAME_VAR = '__name__' diff --git a/src/nagini_translation/lib/context.py b/src/nagini_translation/lib/context.py index 2b6453550..7578f44c0 100644 --- a/src/nagini_translation/lib/context.py +++ b/src/nagini_translation/lib/context.py @@ -13,6 +13,7 @@ PythonType, PythonVar, PythonVarBase, + CallSlotProof, ) from nagini_translation.lib.typedefs import Expr from typing import Dict, List @@ -43,6 +44,7 @@ def __init__(self) -> None: self._current_alias_context = [] self.bound_type_vars = {} self._global_counter = 0 + self.current_call_slot_proof = None # type: CallSlotProof self.perm_factor = None # If this is set, all translated permission amounts # are multiplied by this factor. self._old_aliases = {} # Keys are pretty-printed Python expressions, diff --git a/src/nagini_translation/lib/jvmaccess.py b/src/nagini_translation/lib/jvmaccess.py index bbe6937ab..372e34a96 100644 --- a/src/nagini_translation/lib/jvmaccess.py +++ b/src/nagini_translation/lib/jvmaccess.py @@ -13,8 +13,9 @@ class JVM: """ def __init__(self, classpath: str): - jpype.startJVM(jpype.getDefaultJVMPath(), - '-Djava.class.path=' + classpath, '-Xss128m') + if not jpype.isJVMStarted(): + jpype.startJVM(jpype.getDefaultJVMPath(), + '-Djava.class.path=' + classpath, '-Xss128m') self.java = jpype.JPackage('java') self.scala = jpype.JPackage('scala') self.viper = jpype.JPackage('viper') diff --git a/src/nagini_translation/lib/program_nodes.py b/src/nagini_translation/lib/program_nodes.py index e162f3554..7075c170a 100644 --- a/src/nagini_translation/lib/program_nodes.py +++ b/src/nagini_translation/lib/program_nodes.py @@ -146,6 +146,7 @@ def __init__(self, types: TypeInfo, self.classes = OrderedDict() self.functions = OrderedDict() self.methods = OrderedDict() + self.call_slots = OrderedDict() # type: Dict[str, CallSlot] self.predicates = OrderedDict() self.io_operations = OrderedDict() self.global_vars = OrderedDict() @@ -189,6 +190,8 @@ def process(self, translator: 'Translator') -> None: function.process(self.get_fresh_name(name), translator) for name, method in self.methods.items(): method.process(self.get_fresh_name(name), translator) + for name, call_slot in self.call_slots.items(): + call_slot.process(self.get_fresh_name(name), translator) for name, predicate in self.predicates.items(): predicate.process(self.get_fresh_name(name), translator) for name, var in self.global_vars.items(): @@ -230,7 +233,7 @@ def get_type(self, prefixes: List[str], return module_result return None, None - def get_func_type(self, path: List[str]): + def get_func_type(self, path: List[str], callable=False): """ Returns the type of the function identified by the given path in the current module (including imported other modules). It is assumed that @@ -239,11 +242,11 @@ def get_func_type(self, path: List[str]): """ actual_prefix = self.type_prefix.split('.') if self.type_prefix else [] actual_prefix.extend(path) - local_result = self.types.get_func_type(actual_prefix) + local_result = self.types.get_func_type(actual_prefix, callable) if local_result is not None: return local_result for module in self.from_imports: - module_result = module.get_func_type(prefix) + module_result = module.get_func_type(prefix, callable) if module_result is not None: return module_result return None @@ -270,7 +273,8 @@ def get_contents(self, only_top: bool) -> Dict: elements that can be accessed without a receiver. """ dicts = [self.classes, self.functions, self.global_vars, self.methods, - self.predicates, self.io_operations, self.namespaces] + self.predicates, self.io_operations, self.namespaces, + self.call_slots] return CombinedDict([], dicts) @@ -908,6 +912,7 @@ def __init__(self, name: str, node: ast.AST, cls: PythonClass, self.var_arg = None # direct self.kw_arg = None # direct self.type = None # infer + self.callable_type = None # type: PythonType self.generic_type = -1 self.result = None # infer self.error_var = None # infer @@ -924,7 +929,7 @@ def __init__(self, name: str, node: ast.AST, cls: PythonClass, self.type_vars = OrderedDict() self.setter = None self.func_constant = None - self.threading_id = None + self.call_slot_proofs = {} # type: Dict[ast.FunctionDef, CallSlotProof] self.definition_deps = set() self.call_deps = set() @@ -935,7 +940,6 @@ def add_all_call_deps(self, res: Set[Tuple[ast.AST, PythonNode, PythonModule]], """ add_all_call_deps(self.call_deps, res, prefix) - def process(self, sil_name: str, translator: 'Translator') -> None: """ Creates fresh Silver names for all parameters and initializes them, @@ -943,9 +947,7 @@ def process(self, sil_name: str, translator: 'Translator') -> None: checks if this method overrides one from a superclass, """ self.sil_name = sil_name - self.threading_id = self.superscope.get_fresh_name(self.name + "_threading") - if self.pure: - self.func_constant = self.superscope.get_fresh_name(self.name) + self.func_constant = self.superscope.get_fresh_name(self.name) for name, arg in self.args.items(): arg.process(self.get_fresh_name(name), translator) if self.var_arg: @@ -1726,6 +1728,38 @@ def create_python_method( container_factory, interface, interface_dict, method_type) + def create_call_slot( + self, + name: str, + node: ast.FunctionDef, + superscope: PythonScope, + container_factory: 'ProgramNodeFactory' + ) -> 'CallSlot': + return CallSlot( + name, + node, + superscope, + container_factory + ) + + def create_call_slot_proof( + self, + name: str, + node: ast.FunctionDef, + superscope: PythonScope, + container_factory: 'ProgramNodeFactory', + call_slot_instantiation: ast.Call, + old_label: str + ) -> 'CallSlotProof': + return CallSlotProof( + name, + node, + superscope, + container_factory, + call_slot_instantiation, + old_label + ) + def create_python_io_operation(self, name: str, node: ast.AST, superscope: PythonScope, container_factory: 'ProgramNodeFactory', @@ -1744,6 +1778,131 @@ def create_python_class(self, name: str, superscope: PythonScope, return PythonClass(name, superscope, node_factory, node, superclass, interface) + +class CallSlotBase(PythonMethod): + + def __init__( + self, + name: str, + node: ast.FunctionDef, + superscope: PythonScope, + node_factory: 'ProgramNodeFactory', + ) -> None: + + PythonMethod.__init__( + self, + name, + node, + None, # cls: PythonClass + superscope, + False, # pure: bool + False, # contract_only: bool + node_factory # node_factory: 'ProgramNodeFactory' + ) + + # universally quantified variables + self.uq_variables = OrderedDict() # type: Dict[str, PythonVar] + # NOTE: currently we only support one single return variable + self.return_variables = None # type: List[ast.Name] + self.body = None # type: List[ast.stmt] + + +class CallSlot(CallSlotBase): + + def __init__( + self, + name: str, + node: ast.FunctionDef, + superscope: PythonScope, + node_factory: 'ProgramNodeFactory', + ) -> None: + + CallSlotBase.__init__( + self, + name, + node, + superscope, + node_factory, + ) + + self.call = None # type: ast.Call + self.sil_application_name = None # type: str + + def process(self, sil_name: str, translator: 'Translator') -> None: + """ + Creates fresh Silver names for all parameters and initializes them, + same for local variables. Also sets the method type and + checks if this method overrides one from a superclass, + """ + self.sil_name = sil_name + self.sil_application_name = self.get_fresh_name(self.name + '_apply') + for name, arg in self.args.items(): + arg.process(self.get_fresh_name(name), translator) + for name, uq_var in self.uq_variables.items(): + uq_var.process(self.get_fresh_name(name), translator) + for local in self.locals: + self.locals[local].process(self.get_fresh_name(local), translator) + self.obligation_info = translator.create_obligation_info(self) + if self.type is not None and self.return_variables: + self.result = self.locals[self.return_variables[0].id] + + def get_contents(self, only_top: bool) -> Dict: + """ + Returns the elements that can be accessed from this container (to be + used by get_target). If 'only_top' is true, returns only top level + elements that can be accessed without a receiver. + """ + dicts = [super().get_contents(only_top), self.uq_variables] + return CombinedDict([], dicts) + + def get_variable(self, name: str) -> Optional['PythonVar']: + """ + Returns the variable (local variable or method parameter) with the + given name. + """ + + if name in self.locals: + return self.locals[name] + elif name in self.args: + return self.args[name] + elif name in self.uq_variables: + return self.uq_variables[name] + elif name in self.special_vars: + return self.special_vars[name] + elif name in self.io_existential_vars: + return self.io_existential_vars[name] + elif self.var_arg and self.var_arg.name == name: + return self.var_arg + elif self.kw_arg and self.kw_arg.name == name: + return self.kw_arg + else: + return self.module.global_vars.get(name) + + +class CallSlotProof(CallSlotBase): + + def __init__( + self, + name: str, + node: ast.FunctionDef, + superscope: PythonScope, + node_factory: 'ProgramNodeFactory', + call_slot_instantiation: ast.Call, + old_label: str + ) -> None: + + CallSlotBase.__init__( + self, + name, + node, + superscope, + node_factory, + ) + + self.call_slot_instantiation = call_slot_instantiation + self.old_label = old_label # type: str + + def toposort_classes(class_set: Set[PythonClass]) -> List[PythonClass]: """ Topological sorting of classes in a set, ensuring that derived classes @@ -1782,4 +1941,4 @@ def chain_cond_exp(guarded_expr: List[Tuple[Expr, Expr]], _, else_exp = guarded_expr[1] else: else_exp = chain_cond_exp(guarded_expr[1:], viper, position, info, ctx) - return viper.CondExp(guard, then_exp, else_exp, position, info) \ No newline at end of file + return viper.CondExp(guard, then_exp, else_exp, position, info) diff --git a/src/nagini_translation/lib/resolver.py b/src/nagini_translation/lib/resolver.py index 3be3f2ddf..a6bd98ea3 100644 --- a/src/nagini_translation/lib/resolver.py +++ b/src/nagini_translation/lib/resolver.py @@ -197,6 +197,8 @@ def _do_get_type(node: ast.AST, containers: List[ContainerInterface], rectype = get_type(node.func.value, containers, container) if target.generic_type != -1: return rectype.type_args[target.generic_type] + if isinstance(node, ast.Name) and not isinstance(node._parent, ast.Call): + return target.callable_type return target.type if isinstance(target, PythonField): result = target.type @@ -378,6 +380,8 @@ def _get_call_type(node: ast.Call, module: PythonModule, arg_type = get_type(node.args[0], containers, container) list_class = module.global_module.classes[LIST_TYPE] return GenericType(list_class, [arg_type]) + elif node.func.id == 'ClosureCall': + return get_type(node.args[0], containers, container).type_args[-1] elif node.func.id in ('getArg', 'getOld', 'getMethod'): object_class = module.global_module.classes[OBJECT_TYPE] return object_class @@ -485,4 +489,4 @@ def pairwise_supertype(t1: PythonType, t2: PythonType) -> Optional[PythonType]: return None if not t1.superclass: return pairwise_supertype(t2.superclass, t1) - return pairwise_supertype(t2, t1.superclass) \ No newline at end of file + return pairwise_supertype(t2, t1.superclass) diff --git a/src/nagini_translation/lib/typeinfo.py b/src/nagini_translation/lib/typeinfo.py index e42dba82a..e80a38707 100644 --- a/src/nagini_translation/lib/typeinfo.py +++ b/src/nagini_translation/lib/typeinfo.py @@ -104,9 +104,8 @@ def visit_name_expr(self, node: mypy.nodes.NameExpr): break if (node.name not in LITERALS and not is_alias): name_type = self.type_of(node) - if not isinstance(name_type, mypy.types.CallableType): - self.set_type(self.prefix + [node.name], name_type, - node.line, col(node)) + self.set_type(self.prefix + [node.name], name_type, + node.line, col(node)) def visit_star_expr(self, node: mypy.nodes.StarExpr): node.expr.accept(self) @@ -115,7 +114,7 @@ def visit_func_def(self, node: mypy.nodes.FuncDef): oldprefix = self.prefix self.prefix = self.prefix + [node.name()] functype = self.type_of(node) - self.set_type(self.prefix, functype, node.line, col(node), True) + self.set_type(self.prefix, functype, node.line, col(node)) for arg in node.arguments: self.set_type(self.prefix + [arg.variable.name()], arg.variable.type, arg.line, col(arg)) @@ -138,9 +137,7 @@ def visit_class_def(self, node: mypy.nodes.ClassDef): super().visit_class_def(node) self.prefix = oldprefix - def set_type(self, fqn, type, line, col, return_type=False): - if return_type and isinstance(type, mypy.types.CallableType): - type = type.ret_type + def set_type(self, fqn, type, line, col): if not type or isinstance(type, mypy.types.AnyType): if line in self.ignored_lines: return @@ -307,7 +304,7 @@ def get_type(self, prefix: List[str], name: str): else: return result, alts - def get_func_type(self, prefix: List[str]): + def get_func_type(self, prefix: List[str], callable=False): """ Looks up the type of the function which creates the given context """ @@ -316,9 +313,9 @@ def get_func_type(self, prefix: List[str]): if len(prefix) == 0: return None else: - return self.get_func_type(prefix[:len(prefix) - 1]) + return self.get_func_type(prefix[:len(prefix) - 1], callable) else: - if isinstance(result, mypy.types.FunctionLike): + if not callable and isinstance(result, mypy.types.FunctionLike): result = result.ret_type return result @@ -340,6 +337,9 @@ def is_union_type(self, type: mypy.types.Type) -> bool: def is_callable_type(self, type: mypy.types.Type) -> bool: return isinstance(type, mypy.types.CallableType) + def is_any_type(self, type: mypy.types.Type) -> bool: + return isinstance(type, mypy.types.AnyType) + def is_type_type(self, type: mypy.types.Type) -> bool: return isinstance(type, mypy.types.TypeType) diff --git a/src/nagini_translation/lib/viper_ast.py b/src/nagini_translation/lib/viper_ast.py index 6a62da2af..412182041 100644 --- a/src/nagini_translation/lib/viper_ast.py +++ b/src/nagini_translation/lib/viper_ast.py @@ -148,6 +148,12 @@ def Method(self, name, args, returns, pres, posts, locals, body, position, body_with_locals, position, info, self.NoTrafos) + def from_option(self, option): + if option == self.none: + return None + else: + return option.get() + def Field(self, name, type, position, info): return self.ast.Field(name, type, position, info, self.NoTrafos) @@ -254,6 +260,9 @@ def FieldAccessPredicate(self, fieldacc, perm, position, info): def Old(self, expr, position, info): return self.ast.Old(expr, position, info, self.NoTrafos) + def LabelledOld(self, expr, label, position, info): + return self.ast.LabelledOld(expr, label, position, info, self.NoTrafos) + def Inhale(self, expr, position, info): return self.ast.Inhale(expr, position, info, self.NoTrafos) @@ -468,12 +477,6 @@ def While(self, cond, invariants, locals, body, position, info): def Let(self, variable, exp, body, position, info): return self.ast.Let(variable, exp, body, position, info, self.NoTrafos) - def from_option(self, option): - if option == self.none: - return None - else: - return option.get() - def to_function0(self, func): func0 = Function0() func0.apply = types.MethodType(func, func0) diff --git a/src/nagini_translation/translator.py b/src/nagini_translation/translator.py index e068b6e21..b22983b81 100644 --- a/src/nagini_translation/translator.py +++ b/src/nagini_translation/translator.py @@ -22,6 +22,7 @@ TranslatorConfig, ) from nagini_translation.translators.call import CallTranslator +from nagini_translation.translators.call_slots.call_slot_translator import CallSlotTranslator from nagini_translation.translators.contract import ContractTranslator from nagini_translation.translators.expression import ExpressionTranslator from nagini_translation.translators.io_operation import ( @@ -59,6 +60,8 @@ def __init__(self, jvm: JVM, source_file: str, type_info: TypeInfo, type_info, viper_ast) config.call_translator = CallTranslator(config, jvm, source_file, type_info, viper_ast) + config.call_slot_translator = CallSlotTranslator(config, jvm, source_file, + type_info, viper_ast) config.contract_translator = ContractTranslator(config, jvm, source_file, type_info, viper_ast) diff --git a/src/nagini_translation/translators/abstract.py b/src/nagini_translation/translators/abstract.py index 5436b26f7..a694aad12 100644 --- a/src/nagini_translation/translators/abstract.py +++ b/src/nagini_translation/translators/abstract.py @@ -51,6 +51,7 @@ def __init__(self, translator: 'Translator'): self.expr_translator = None self.stmt_translator = None self.call_translator = None + self.call_slot_translator = None self.contract_translator = None self.perm_translator = None self.pure_translator = None @@ -367,3 +368,20 @@ def translate_args(self, target: PythonMethod, arg_nodes: List, keywords, node, ctx, implicit_receiver) + def translate_call_slot_check(self, target: PythonMethod, args: List[Expr], + formal_args: List[Expr], arg_stmts: List[Stmt], + position: 'silver.ast.Position', node: ast.AST, + ctx: Context) -> StmtsAndExpr: + return self.config.call_slot_translator.translate_call_slot_check( + target, args, formal_args, arg_stmts, position, node, ctx + ) + + def translate_call_slot_application(self, call: ast.Call, ctx: Context) -> StmtsAndExpr: + return self.config.call_slot_translator.translate_call_slot_application(call, ctx) + + def translate_call_slot_proof(self, proof: ast.FunctionDef, ctx: Context) -> List[Stmt]: + return self.config.call_slot_translator.translate_call_slot_proof(proof, ctx) + + def translate_call_slot(self, call_slot: 'CallSlot', ctx: Context + ) -> Tuple['silver.ast.Function', 'silver.ast.Method']: + return self.config.call_slot_translator.translate_call_slot(call_slot, ctx) diff --git a/src/nagini_translation/translators/call.py b/src/nagini_translation/translators/call.py index b1619cbcd..00449aff5 100644 --- a/src/nagini_translation/translators/call.py +++ b/src/nagini_translation/translators/call.py @@ -21,12 +21,12 @@ DICT_TYPE, END_LABEL, ERROR_NAME, + FUNCTION_DOMAIN_NAME, GET_ARG_FUNC, GET_METHOD_FUNC, GET_OLD_FUNC, JOINABLE_FUNC, LIST_TYPE, - METHOD_ID_DOMAIN, OBJECT_TYPE, RANGE_TYPE, RESULT_NAME, @@ -39,6 +39,7 @@ ) from nagini_translation.lib.errors import rules from nagini_translation.lib.program_nodes import ( + CallSlot, GenericType, MethodType, PythonClass, @@ -924,6 +925,9 @@ def translate_normal_call(self, target: PythonMethod, arg_stmts: List[Stmt], raise InvalidProgramException(node, 'invalid.contract.position') return arg_stmts, self.create_predicate_access(target_name, args, perm, node, ctx) + elif isinstance(target, CallSlot): + return self.translate_call_slot_check(target, args, formal_args, + arg_stmts, position, node, ctx) elif target.pure: return self._translate_function_call(target, args, formal_args, arg_stmts, position, node, ctx) @@ -956,6 +960,8 @@ def translate_Call(self, node: ast.Call, ctx: Context, impure=False) -> StmtsAnd if is_name: if func_name in CONTRACT_WRAPPER_FUNCS: raise InvalidProgramException(node, 'invalid.contract.position') + elif func_name == 'ClosureCall': + return self.translate_call_slot_application(node, ctx) elif func_name in CONTRACT_FUNCS: return self.translate_contractfunc_call(node, ctx, impure) elif func_name in IO_CONTRACT_FUNCS: @@ -1095,11 +1101,11 @@ def _translate_thread_creation(self, node: ast.Call, arg_type_check = self.type_check(arg_val, method_args[i].type, pos, ctx) arg_type_checks = self.viper.EqCmp(arg_type_checks, arg_type_check, pos, info) - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) thread_method = self.viper.DomainFuncApp(GET_METHOD_FUNC, [thread], method_id_type, pos, info, THREAD_DOMAIN) - actual_method = self.viper.DomainFuncApp(target.threading_id, [], method_id_type, - pos, info, METHOD_ID_DOMAIN) + actual_method = self.viper.DomainFuncApp(target.func_constant, [], method_id_type, + pos, info, FUNCTION_DOMAIN_NAME) inhale_method = self.viper.Inhale(self.viper.EqCmp(thread_method, actual_method, pos, info), pos, info) @@ -1143,14 +1149,14 @@ def _translate_thread_start(self, node: ast.Call, options_pos = self.to_position(node, ctx, rules=rules.THREAD_START_METHOD_UNLISTED) correct_method = self.viper.FalseLit(options_pos, info) - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) actual_method = self.viper.DomainFuncApp(GET_METHOD_FUNC, [thread], method_id_type, options_pos, info, THREAD_DOMAIN) for method in method_options: - this_method = self.viper.DomainFuncApp(method.threading_id, [], + this_method = self.viper.DomainFuncApp(method.func_constant, [], method_id_type, options_pos, info, - METHOD_ID_DOMAIN) + FUNCTION_DOMAIN_NAME) this_option = self.viper.EqCmp(actual_method, this_method, options_pos, info) correct_method = self.viper.Or(correct_method, this_option, options_pos, info) stmts.append(self.viper.Assert(correct_method, options_pos, info)) @@ -1234,7 +1240,7 @@ def _inhale_possible_thread_post(self, method: PythonMethod, thread: Expr, method_stmts = [] info = self.no_info(ctx) else_block = self.translate_block([], pos, info) - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) actual_method = self.viper.DomainFuncApp(GET_METHOD_FUNC, [thread], method_id_type, pos, info, THREAD_DOMAIN) @@ -1295,9 +1301,9 @@ def _inhale_possible_thread_post(self, method: PythonMethod, thread: Expr, # is the actual method. method_stmts.append(self.viper.Inhale(post_assertion, pos, info)) then_block = self.translate_block(method_stmts, pos, info) - this_method = self.viper.DomainFuncApp(method.threading_id, [], + this_method = self.viper.DomainFuncApp(method.func_constant, [], method_id_type, pos, info, - METHOD_ID_DOMAIN) + FUNCTION_DOMAIN_NAME) correct_method = self.viper.EqCmp(actual_method, this_method, pos, info) cond = self.viper.And(any_perm, correct_method, pos, info) return self.viper.If(cond, then_block, else_block, pos, info) diff --git a/src/nagini_translation/translators/call_slots/call_slot_translator.py b/src/nagini_translation/translators/call_slots/call_slot_translator.py new file mode 100644 index 000000000..99ce279a2 --- /dev/null +++ b/src/nagini_translation/translators/call_slots/call_slot_translator.py @@ -0,0 +1,716 @@ +import ast +from copy import deepcopy +from functools import reduce +from itertools import chain +from typing import Dict, List, Tuple, Union +from nagini_translation.lib.constants import ERROR_NAME, PRIMITIVES +from nagini_translation.lib.context import Context +from nagini_translation.lib.program_nodes import ( + CallSlot, + CallSlotProof, + PythonMethod, + PythonVar, + TypeInfo +) +from nagini_translation.lib.typedefs import ( + Expr, + Stmt, + StmtsAndExpr, + Function, + Method, + Position +) +from nagini_translation.lib.util import InvalidProgramException +from nagini_translation.translators.common import CommonTranslator +from nagini_translation.call_slot_analyzers import ( + is_call_slot_proof, + is_closure_call, + is_precondition, + is_postcondition, + is_fold, + is_unfold, + is_assume, + is_assert, +) + + +class CallSlotTranslator(CommonTranslator): + + def __init__(self, config: 'TranslatorConfig', jvm: 'JVM', source_file: str, + type_info: TypeInfo, viper_ast: 'ViperAST') -> None: + super().__init__(config, jvm, source_file, type_info, viper_ast) + self._var_replacer = _VarReplacer() + + def translate_call_slot(self, call_slot: CallSlot, ctx: Context) -> Tuple[Function, Union[Method, Function]]: + + old_function = ctx.current_function + ctx.current_function = call_slot + + position = self.to_position(call_slot.node, ctx) + info = self.no_info(ctx) + + call_slot_holds = self.viper.Function( + call_slot.sil_name, + [arg.decl for arg in call_slot.get_args()], + self.viper.Bool, + [], + [], + None, + position, + info + ) + + args = [arg.decl for arg in chain(call_slot.get_args(), call_slot.uq_variables.values())] + + if call_slot.pure: + + pres = [] + for arg in chain(call_slot.get_args(), call_slot.uq_variables.values()): + if arg.type.name in PRIMITIVES: + continue + pres.append(self.type_check(arg.ref(), arg.type, position, ctx)) + for pre, _ in call_slot.precondition: + pres.append(self._translate_pure_expr( + pre, ctx, target_type=self.viper.Bool, impure=True)) + + if call_slot.pure and call_slot.type is not None: + old_posts = call_slot.postcondition + call_slot.postcondition = [ + ( + self._var_replacer.replace(deepcopy(post), { + call_slot.return_variables[0].id: + ast.Call(ast.Name('Result', ast.Load), [], []) + }), + aliases + ) + for post, aliases in old_posts + ] + posts = [] + + if call_slot.type is not None and call_slot.type.name not in PRIMITIVES: + viper_type = self.translate_type(call_slot.type, ctx) + posts.append(self.type_check( + self.viper.Result(viper_type, position, info), + call_slot.type, position, ctx + )) + + for post, _ in call_slot.postcondition: + posts.append(self._translate_pure_expr( + post, ctx, target_type=self.viper.Bool, impure=True)) + + if call_slot.pure and call_slot.type is not None: + call_slot.postcondition = old_posts + + _type = self.translate_type(call_slot.type, ctx) + call_slot_apply = self.viper.Function( + call_slot.sil_application_name, args, _type, pres, posts, + None, position, info + ) + + else: + pres, posts = self.extract_contract(call_slot, ERROR_NAME, False, ctx) + call_slot_apply = self.create_method_node( + ctx, + call_slot.sil_application_name, + args, + [res.decl for res in call_slot.get_results()], + pres, + posts, + [], + [self.viper.Inhale(self.viper.FalseLit(position, info), position, info)], + position, + info, + method=call_slot + ) + + ctx.current_function = old_function + + return call_slot_holds, call_slot_apply + + def translate_call_slot_check( + self, target: CallSlot, args: List[Expr], + formal_args: List[Expr], arg_stmts: List[Stmt], + position: 'silver.ast.Position', node: ast.AST, + ctx: Context + ) -> StmtsAndExpr: + check = self.viper.FuncApp( + target.sil_name, + args, + position, + self.no_info(ctx), + self.viper.Bool, + formal_args + ) + return arg_stmts, check + + def translate_call_slot_application(self, closureCall: ast.Call, ctx: Context) -> StmtsAndExpr: + call, justification = closureCall.args + + assert isinstance(call, ast.Call) + assert isinstance(justification, (ast.Call, ast.Name)) + if isinstance(justification, ast.Call): + return self._application_call_slot(call, justification, ctx) + else: + return self._application_static_dispatch(call, justification, ctx) + + def _application_call_slot(self, call: ast.Call, justification: ast.Call, ctx: Context) -> StmtsAndExpr: + assert isinstance(call.func, ast.Name) + assert isinstance(justification.func, ast.Call) + assert isinstance(justification.func.func, ast.Name) + + call_slot = ctx.module.call_slots[justification.func.func.id] + + stmts = [self._application_call_slot_justification(call_slot, justification, ctx)] + + assert len(call_slot.get_args()) == len(justification.func.args) + arg_map = { + py_var.name: arg + for py_var, arg in zip( + chain(call_slot.get_args(), call_slot.uq_variables.values()), + chain(justification.func.args, justification.args) + ) + } + + assert len(call.args) == len(call_slot.call.args) + for arg_call, arg_slot in zip(call.args, call_slot.call.args): + stmts.append(self._application_call_slot_arg_match( + arg_call, arg_slot, arg_map, ctx)) + + stmts.append(self._application_call_slot_arg_match( + call.func, call_slot.call.func, arg_map, ctx)) + + formal_args = [] + if call_slot.pure: + for arg in call_slot.get_args(): + formal_args.append(arg.decl) + for arg in call_slot.uq_variables.values(): + formal_args.append(arg.decl) + + call_stmts, call_expr = self._translate_normal_call( + call_slot.sil_application_name, call_slot, justification.func.args + justification.args, + self.to_position(call, ctx), ctx, pure_call=call_slot.pure, formal_args=formal_args + ) + + return stmts + call_stmts, call_expr + + def _application_call_slot_justification( + self, + call_slot: CallSlot, + justification: ast.Call, + ctx: Context + ) -> Stmt: + + assert isinstance(justification.func, ast.Call) # uq vars + justification = deepcopy(justification) + + justification.func.id = call_slot.sil_name + expr = self._translate_pure_expr(justification.func, ctx, target_type=self.viper.Bool) + return self.viper.Assert(expr, self.to_position(justification.func, ctx), self.no_info(ctx)) + + def _application_call_slot_arg_match( + self, + call_arg: ast.expr, + slot_arg: ast.expr, + arg_map: Dict[str, ast.expr], + ctx: Context + ) -> Stmt: + + slot_arg = deepcopy(slot_arg) + slot_arg = self._var_replacer.replace(slot_arg, arg_map) + + viper_call_arg = self._translate_pure_expr(call_arg, ctx) + + viper_slot_arg = self._translate_pure_expr(slot_arg, ctx) + + return self.viper.Assert( + self.viper.EqCmp( + viper_call_arg, + viper_slot_arg, + self.to_position(call_arg, ctx), + self.no_info(ctx) + ), + self.to_position(call_arg, ctx), + self.no_info(ctx) + ) + + def _translate_normal_call( + self, + name: str, + target: PythonMethod, + args: List[ast.expr], + position: Position, + ctx: Context, + pure_call: bool = False, + formal_args: List[Expr] = [] + ) -> StmtsAndExpr: + + result_var = None + if target.type is not None and not pure_call: + result_var = ctx.current_function.create_variable( + target.name + '_res', target.type, self.translator).ref() + + stmts = [] # type: List[Stmt] + arg_exprs = [] # type: List[Expr] + + for arg in args: + expr = self._translate_pure_expr(arg, ctx) + arg_exprs.append(expr) + + if pure_call: + _type = self.translate_type(target.type, ctx) + expr = self.viper.FuncApp( + name, arg_exprs, position, self.no_info(ctx), _type, formal_args + ) + else: + stmts = self.create_method_call_node( + ctx, name, arg_exprs, [result_var] if result_var else [], + position, self.no_info(ctx), target_method=target + ) + expr = result_var + + return stmts, expr + + def _application_static_dispatch(self, call: ast.Call, justification: ast.Name, ctx: Context) -> StmtsAndExpr: + + stmts = [] # type: List[Stmt] + position = self.to_position(call, ctx) + info = self.no_info(ctx) + target = self.get_target(justification, ctx) + assert isinstance(target, PythonMethod) + + closure_expr = self._translate_pure_expr(call.func, ctx) + + justification_expr = self._translate_pure_expr(justification, ctx) + + stmts.append(self.viper.Assert( + self.viper.EqCmp( + closure_expr, + justification_expr, + position, + info + ), + position, + info + )) + + method = ctx.module.get_func_or_method(justification.id) + + formal_args = [] + if target.pure: + for arg in target.get_args(): + formal_args.append(arg.decl) + + call_stmts, call_expr = self._translate_normal_call( + method.sil_name, method, call.args, self.to_position(call, ctx), + ctx, pure_call=target.pure, formal_args=formal_args + ) + stmts.extend(call_stmts) + + return stmts, call_expr + + def translate_call_slot_proof(self, proof_node: ast.FunctionDef, ctx: Context) -> List[Stmt]: + assert is_call_slot_proof(proof_node) + + proof = ctx.current_function.call_slot_proofs[proof_node] + old_proof = ctx.current_call_slot_proof + ctx.current_call_slot_proof = proof + + call_slot = self._get_call_slot(proof, ctx) + cl_map = self._get_cl_map(proof, call_slot) + + with ctx.aliases_context(): + vars_stmts = self._proof_extract_vars(proof, ctx) + + body_stmts = self._proof_translate_body(proof, call_slot, cl_map, ctx) + + while_loop = self._proof_create_non_deterministic_while_loop( + proof, body_stmts, ctx + ) + + instantiation = deepcopy(proof.call_slot_instantiation) + instantiation.func.id = call_slot.sil_name + + instantiation_expr = self._translate_pure_expr( + instantiation, ctx, target_type=self.viper.Bool) + + instantiation_stmt = self.viper.Inhale( + instantiation_expr, self.to_position(instantiation, ctx), self.no_info(ctx) + ) + + ctx.current_call_slot_proof = old_proof + + return vars_stmts + [while_loop] + [instantiation_stmt] + + def _proof_extract_vars(self, proof: CallSlotProof, ctx: Context) -> List[Stmt]: + + vars = proof.get_args() + values = proof.call_slot_instantiation.args + + stmts = [] # type: List[Stmt] + + assert len(vars) == len(values) + for var, value in zip(vars, values): + stmts.extend(self._proof_extract_var(var, value, ctx)) + + return stmts + + def _get_call_slot(self, proof: CallSlotProof, ctx: Context) -> CallSlot: + call_slot_name = proof.call_slot_instantiation.func.id + if call_slot_name not in ctx.module.call_slots: + raise InvalidProgramException( + proof.node, + 'call_slots.proof_annotation.invalid_call_slot' + ) + + return ctx.module.call_slots[call_slot_name] + + def _get_cl_map(self, proof: CallSlotProof, call_slot: CallSlot) -> Dict[str, ast.expr]: + + proof_nv = proof.args.values() + cl_nv = call_slot.args.values() + if len(proof_nv) != len(cl_nv): + raise InvalidProgramException( + proof.node, + 'call_slots.proof_annotation.invalid_call_slot' + ) + + proof_uqv = proof.uq_variables.values() + cl_uqv = call_slot.uq_variables.values() + if len(proof_uqv) != len(cl_uqv): + raise InvalidProgramException( + proof.node, + 'call_slots.proof_annotation.invalid_call_slot' + ) + + proof_rv = [proof.locals[rv.id] for rv in proof.return_variables] + cl_rv = [call_slot.locals[rv.id] for rv in call_slot.return_variables] + if len(proof_rv) != len(cl_rv): + raise InvalidProgramException( + proof.node, + 'call_slots.proof_annotation.invalid_call_slot' + ) + + return { + cl_var.name: ast.Name( + proof_var.name, + ast.Load, + lineno=proof_var.node.lineno, + col_offset=proof_var.node.col_offset, + ) + for cl_var, proof_var in zip( + chain(cl_nv, cl_uqv, cl_rv), + chain(proof_nv, proof_uqv, proof_rv) + ) + } + + def _proof_extract_var(self, var: PythonVar, val: ast.expr, ctx: Context) -> Stmt: + viper_val = self._translate_pure_expr(val, ctx) + + proof_var = ctx.current_function.create_variable( + '__proof_' + var.name, var.type, self.config.translator + ) + + stmts = [] # type: List[Stmt] + position = self.to_position(val, ctx) + info = self.no_info(ctx) + stmts.append(self.set_var_defined( + proof_var, position, info)) + + if proof_var.type.name not in PRIMITIVES: + stmts.append(self.viper.Inhale( + self.var_type_check( + proof_var.sil_name, proof_var.type, position, ctx + ), + position, + info + )) + + ctx.set_alias(var.name, proof_var, var) + + stmts.append(self.viper.LocalVarAssign( + proof_var.ref(proof_var.node, ctx), + viper_val, + self.to_position(proof_var.node, ctx), + self.no_info(ctx) + )) + + return stmts + + def _proof_create_non_deterministic_while_loop( + self, + proof: CallSlotProof, + body: List[Stmt], + ctx: Context + ) -> Stmt: + + non_deterministic_bool = ctx.current_function.create_variable( + '__proof_non_deterministic_choice', ctx.module.global_module.classes['bool'].try_unbox(), self.translator + ) + + position = self.to_position(proof.node, ctx) + info = self.no_info(ctx) + + return self.viper.While( + non_deterministic_bool.ref(), [], [], + self.translate_block(body, position, info), + position, info + ) + + def _proof_translate_body( + self, + proof: CallSlotProof, + call_slot: CallSlot, + cl_map: Dict[str, ast.expr], + ctx: Context + ) -> List[Stmt]: + + stmts = [] # type: List[Stmt] + position = self.to_position(proof.node, ctx) + info = self.no_info(ctx) + + stmts.extend(self._proof_introduce_uq_ret_vars(proof, ctx)) + + stmts.append(self.viper.Inhale( + self._proof_translate_contract(proof, call_slot.precondition, cl_map, ctx), + position, + info + )) + + stmts.append(self.viper.Label(proof.old_label, position, info)) + + call_counter = ctx.current_function.create_variable( + '__proof_call_counter', ctx.module.global_module.classes['int'].try_unbox(), self.translator + ) + + stmts.append(self.viper.LocalVarAssign( + call_counter.ref(), + self.viper.IntLit(0, position, info), + position, + info + )) + + stmts.extend(self._proof_translate_body_only( + proof.body, + proof, + call_slot, + call_counter, + cl_map, + ctx + )) + + stmts.append(self.viper.Exhale( + self._proof_translate_contract(proof, call_slot.postcondition, cl_map, ctx), + position, + info + )) + + stmts.append(self.viper.Assert( + self.viper.EqCmp( + call_counter.ref(), + self.viper.IntLit(1, position, info), + position, + info + ), + position, + info + )) + + return stmts + + def _proof_introduce_uq_ret_vars(self, proof: CallSlotProof, ctx: Context) -> List[Stmt]: + stmts = [] # type: List[Stmt] + + for var in proof.uq_variables.values(): + proof_var = ctx.current_function.create_variable( + '__proof_' + var.name, var.type, self.translator + ) + + position = self.to_position(proof_var.node, ctx) + info = self.no_info(ctx) + + stmts.append(self.set_var_defined(proof_var, position, info)) + + if proof_var.type.name not in PRIMITIVES: + stmts.append(self.viper.Inhale( + self.var_type_check( + proof_var.sil_name, proof_var.type, position, ctx + ), + position, + info + )) + + ctx.set_alias(var.name, proof_var, var) + + if proof.return_variables: + ret_var = proof.locals[proof.return_variables[0].id] + proof_var = ctx.current_function.create_variable( + '__proof_' + ret_var.name, ret_var.type, self.translator + ) + ctx.set_alias(ret_var.name, proof_var, ret_var) + + return stmts + + def _proof_translate_contract( + self, + proof: CallSlotProof, + contract: List[Tuple[ast.expr, Dict]], + cl_map: Dict[str, ast.expr], + ctx: Context + ) -> Expr: + + position = self.to_position(proof.node, ctx) + info = self.no_info(ctx) + + contract = [ + self._translate_pure_expr(self._var_replacer.replace(deepcopy(pre[0]), cl_map), + ctx, impure=True, target_type=self.viper.Bool) + for pre in contract + ] + + return reduce( + lambda left, right: self.viper.And(left, right, position, info), + contract, self.viper.TrueLit(position, info) + ) + + def _proof_translate_body_only( + self, + body: List[ast.stmt], + proof: CallSlotProof, + call_slot: CallSlot, + call_counter: PythonVar, + cl_map: Dict[str, ast.expr], + ctx: Context + ) -> List[Stmt]: + + viper_stmts = [] # type: List[Stmt] + + for stmt in body: + + if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call): + if is_precondition(stmt.value) or is_postcondition(stmt.value): + continue # ignore + if is_fold(stmt.value) or is_unfold(stmt.value): + viper_stmts.extend(self.translate_stmt(stmt, ctx)) + continue + if is_assume(stmt.value) or is_assert(stmt.value): + viper_stmts.extend(self.translate_stmt(stmt, ctx)) + continue + + assert is_closure_call(stmt.value) + + viper_stmts.extend(self._proof_translate_application( + stmt, call_slot, call_counter, cl_map, ctx + )) + + elif isinstance(stmt, ast.Assign): + + assert is_closure_call(stmt.value) + + viper_stmts.extend(self._proof_translate_application( + stmt, call_slot, call_counter, cl_map, ctx + )) + + elif isinstance(stmt, ast.Assert): + viper_stmts.extend(self.translate_stmt(stmt, ctx)) + + elif isinstance(stmt, ast.FunctionDef): + assert is_call_slot_proof(stmt) + viper_stmts.extend(self.translate_stmt(stmt, ctx)) + + elif isinstance(stmt, ast.If): + position = self.to_position(stmt, ctx) + info = self.no_info(ctx) + + cond = self._translate_pure_expr( + stmt.test, ctx, target_type=self.viper.Bool + ) + + then_body = self._proof_translate_body_only( + stmt.body, proof, call_slot, call_counter, cl_map, ctx + ) + then_block = self.translate_block(then_body, position, info) + + else_body = self._proof_translate_body_only( + stmt.orelse, proof, call_slot, call_counter, cl_map, ctx + ) + else_block = self.translate_block(else_body, position, info) + + viper_stmts.append(self.viper.If( + cond, then_block, else_block, position, info + )) + + else: + assert False + + return viper_stmts + + def _proof_translate_application( + self, + stmt: Union[ast.Assign, ast.Expr], + call_slot: CallSlot, + call_counter: PythonVar, + cl_map: Dict[str, ast.expr], + ctx: Context + ) -> List[Stmt]: + + stmts = [] # type: List[Stmt] + + for arg_call, arg_slot in zip(stmt.value.args[0].args, call_slot.call.args): + stmts.append(self._application_call_slot_arg_match( + arg_call, arg_slot, cl_map, ctx + )) + stmts.append(self._application_call_slot_arg_match( + stmt.value.args[0].func, call_slot.call.func, cl_map, ctx + )) + + if call_slot.pure: + if isinstance(stmt.value.args[1], ast.Name): + target = self.get_target(stmt.value.args[1], ctx) + else: + target = self.get_target(stmt.value.args[1].func.func, ctx) + + if not target.pure: + raise InvalidProgramException( + stmt.value, + 'call_slots.impure_closure_call.inside_pure_proof' + ) + + stmts.extend(self.translate_stmt(stmt, ctx)) + + position = self.to_position(stmt, ctx) + info = self.no_info(ctx) + stmts.append(self.viper.LocalVarAssign( + call_counter.ref(), + self.viper.Add( + call_counter.ref(), + self.viper.IntLit(1, position, info), + position, + info + ), + position, + info + )) + + return stmts + + def _translate_pure_expr( + self, + node: ast.expr, + ctx: Context, + target_type: object = None, + impure: bool = False + ) -> Expr: + stmts, expr = self.translate_expr(node, ctx, target_type=target_type, impure=impure) + if stmts: + raise InvalidProgramException(node, 'purity.violated') + return expr + + +class _VarReplacer(ast.NodeTransformer): + + def replace(self, node: ast.expr, arg_map: Dict[str, ast.expr]) -> ast.expr: + self.arg_map = arg_map + return self.visit(node) + + def visit_Name(self, name: ast.Name) -> ast.expr: + return deepcopy(self.arg_map[name.id]) if name.id in self.arg_map else name diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index b731fc5a8..a72cf5776 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -32,6 +32,7 @@ PythonNode, PythonType, PythonVar, + CallSlot, ) from nagini_translation.lib.resolver import get_target as do_get_target from nagini_translation.lib.typedefs import ( @@ -240,6 +241,12 @@ def is_local_variable(self, var: PythonVar, ctx: Context) -> bool: return False if var.name in ctx.actual_function.args: return False + if ( + isinstance(ctx.actual_function, CallSlot) and + ctx.actual_function.result and + var.name == ctx.actual_function.result.name + ): + return False return var in ctx.actual_function.locals.values() def get_may_set_predicate(self, rec: Expr, field: PythonField, ctx: Context, diff --git a/src/nagini_translation/translators/contract.py b/src/nagini_translation/translators/contract.py index 2bf97b48f..1b62eb548 100644 --- a/src/nagini_translation/translators/contract.py +++ b/src/nagini_translation/translators/contract.py @@ -11,13 +11,13 @@ from nagini_translation.lib.constants import ( BOOL_TYPE, BUILTIN_PREDICATES, + FUNCTION_DOMAIN_NAME, GET_ARG_FUNC, GET_METHOD_FUNC, GET_OLD_FUNC, GLOBAL_VAR_FIELD, INT_TYPE, JOINABLE_FUNC, - METHOD_ID_DOMAIN, PRIMITIVES, RANGE_TYPE, SEQ_TYPE, @@ -377,8 +377,15 @@ def translate_old(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: return [], ctx.old_expr_aliases[key] stmt, exp = self.translate_expr(node.args[0], ctx) - res = self.viper.Old(exp, self.to_position(node, ctx), - self.no_info(ctx)) + + if not ctx.current_call_slot_proof: + res = self.viper.Old(exp, self.to_position(node, ctx), + self.no_info(ctx)) + else: + res = self.viper.LabelledOld( + exp, ctx.current_call_slot_proof.old_label, + self.to_position(node, ctx), self.no_info(ctx) + ) return stmt, res def translate_fold(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: @@ -586,7 +593,7 @@ def translate_get_arg(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: target_type = self.viper.Int) pos = self.to_position(node, ctx) info = self.no_info(ctx) - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) func = self.viper.DomainFuncApp(GET_ARG_FUNC, [thread, index], self.viper.Ref, pos, info, THREAD_DOMAIN) return thread_stmt + index_stmt, func diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index 90b0e8cd8..7f06018a0 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -19,8 +19,8 @@ GET_METHOD_FUNC, GLOBAL_VAR_FIELD, INT_TYPE, + CALLABLE_TYPE, LIST_TYPE, - METHOD_ID_DOMAIN, OPERATOR_FUNCTIONS, PRIMITIVE_INT_TYPE, SET_TYPE, @@ -148,7 +148,7 @@ def _translate_only(self, node: ast.AST, ctx: Context, impure=False): def translate_Return(self, node: ast.Return, ctx: Context, impure=False) -> StmtsAndExpr: - return self.translate_expr(node.value, ctx, impure=impure) + return self.translate_expr(node.value, ctx, impure=impure, target_type=self._target_type) def translate_ListComp(self, node: ast.ListComp, ctx: Context) -> StmtsAndExpr: if len(node.generators) != 1: @@ -630,7 +630,8 @@ def translate_Name(self, node: ast.Name, ctx: Context) -> StmtsAndExpr: if (isinstance(ctx.actual_function, PythonMethod) and not (ctx.actual_function.pure or ctx.actual_function.predicate) and not isinstance(node.ctx, ast.Store) and - self.is_local_variable(var, ctx)): + self.is_local_variable(var, ctx) and + var.type.name != 'Callable'): result = self.wrap_definedness_check(var.ref(node, ctx), var, node, ctx) else: result = var.ref(node, ctx) @@ -831,7 +832,7 @@ def _is_primitive_operation(self, op: ast.operator, left_type: PythonType, translated as a native silver binary operation. True iff both types are identical and primitives. """ - if op not in self._primitive_operations: + if type(op) not in self._primitive_operations: return False left_type_boxed = left_type.python_class.try_box() right_type_boxed = right_type.python_class.try_box() @@ -848,6 +849,8 @@ def _translate_primitive_operation(self, left: Expr, right: Expr, op = self._primitive_operations[type(op)] if op_type.python_class.try_box().name == INT_TYPE: wrap = self.to_int + elif op_type.python_class.try_box().name == CALLABLE_TYPE: + wrap = lambda node, ctx: node else: wrap = self.to_bool result = op(wrap(left, ctx), wrap(right, ctx), pos, self.no_info(ctx)) @@ -889,6 +892,13 @@ def is_thread_method_definition(self, node: ast.Compare, ctx: Context) -> bool: return True return False + def is_callable_equality(self, node: ast.Compare, ctx: Context) -> bool: + if len(node.ops) != 1 or len(node.comparators) != 1: + return False + if not isinstance(node.ops[0], (ast.Eq, ast.Is, ast.NotEq, ast.IsNot)): + return False + # TODO + def is_type_equality(self, node: ast.Compare, ctx: Context) -> bool: """ Checks if a comparison checks the equality of the type of an object with @@ -946,12 +956,12 @@ def translate_thread_method_definition(self, node: ast.Compare, thread_stmt, thread = self.translate_expr(get_call.args[0], ctx) pos = self.to_position(node, ctx) info = self.no_info(ctx) - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) thread_method = self.viper.DomainFuncApp(GET_METHOD_FUNC, [thread], method_id_type, pos, info, THREAD_DOMAIN) - method_literal = self.viper.DomainFuncApp(target_method.threading_id, [], + method_literal = self.viper.DomainFuncApp(target_method.func_constant, [], method_id_type, - pos, info, METHOD_ID_DOMAIN) + pos, info, FUNCTION_DOMAIN_NAME) comparison = self.viper.EqCmp(thread_method, method_literal, pos, info) return thread_stmt, comparison @@ -967,6 +977,8 @@ def translate_Compare(self, node: ast.Compare, return self.translate_thread_method_definition(node, ctx) if self.is_type_equality(node, ctx): return self.translate_type_equality(node, ctx) + # if self.is_callable_equality(node, ctx): + # return self.translate_callable_equality(node, ctx) if len(node.ops) != 1 or len(node.comparators) != 1: raise UnsupportedException(node) left_stmt, left = self.translate_expr(node.left, ctx) diff --git a/src/nagini_translation/translators/method.py b/src/nagini_translation/translators/method.py index 1e47a2cb7..7b78a64ee 100644 --- a/src/nagini_translation/translators/method.py +++ b/src/nagini_translation/translators/method.py @@ -19,6 +19,7 @@ STRING_TYPE, ) from nagini_translation.lib.program_nodes import ( + CallSlot, GenericType, MethodType, PythonExceptionHandler, @@ -173,7 +174,11 @@ def _create_typeof_pres(self, func: PythonMethod, is_constructor: bool, """ Creates 'typeof' preconditions for function arguments. """ - args = func.get_args() + if isinstance(func, CallSlot): + args = func.get_args() + args.extend(func.uq_variables.values()) + else: + args = func.get_args() pres = [] for i, arg in enumerate(args): if not (arg.type.name in PRIMITIVES): diff --git a/src/nagini_translation/translators/obligation/fork.py b/src/nagini_translation/translators/obligation/fork.py index 796f272b0..c5b5babd7 100644 --- a/src/nagini_translation/translators/obligation/fork.py +++ b/src/nagini_translation/translators/obligation/fork.py @@ -6,11 +6,11 @@ from nagini_translation.lib import silver_nodes as sil from nagini_translation.lib.constants import ( + FUNCTION_DOMAIN_NAME, GET_ARG_FUNC, GET_METHOD_FUNC, GET_OLD_FUNC, JOINABLE_FUNC, - METHOD_ID_DOMAIN, OBJECT_TYPE, THREAD_DOMAIN, THREAD_POST_PRED, @@ -138,7 +138,7 @@ def _add_precondition_exhales(self) -> None: thread targets, inhale join permissions, remember values of old expressions in method postconditions. """ - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) actual_method = self.viper.DomainFuncApp(GET_METHOD_FUNC, [self._thread], method_id_type, self._position, self._info, THREAD_DOMAIN) @@ -188,9 +188,9 @@ def _add_precondition_exhales(self) -> None: # Do all this under the condition that the current method is the thread's # actual target method. - this_method = self.viper.DomainFuncApp(method.threading_id, [], + this_method = self.viper.DomainFuncApp(method.func_constant, [], method_id_type, self._position, - self._info, METHOD_ID_DOMAIN) + self._info, FUNCTION_DOMAIN_NAME) method_cond = self.viper.EqCmp(actual_method, this_method, self._position, self._info) then_block = self._translator.translate_block(stmts, self._position, diff --git a/src/nagini_translation/translators/program.py b/src/nagini_translation/translators/program.py index b59a59868..50c1408fa 100644 --- a/src/nagini_translation/translators/program.py +++ b/src/nagini_translation/translators/program.py @@ -22,7 +22,6 @@ IS_DEFINED_FUNC, JOINABLE_FUNC, MAY_SET_PRED, - METHOD_ID_DOMAIN, PRIMITIVES, RESULT_NAME, STRING_TYPE, @@ -592,20 +591,9 @@ def create_thread_predicates(self, ctx: Context) -> Function: self.no_position(ctx), self.no_info(ctx)) return [start_pred, post_pred] - def create_method_id_domain(self, constants: List['silver.ast.DomainFunc'], - ctx: Context) -> 'silver.ast.Domain': - return self.viper.Domain(METHOD_ID_DOMAIN, constants, [], [], - self.no_position(ctx), self.no_info(ctx)) - - def translate_method_id_to_constant(self, method, ctx) -> 'silver.ast.DomainFunc': - func_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) - return self.viper.DomainFunc(method.threading_id,[],func_type, True, - self.to_position(method.node,ctx), self.no_info(ctx), - METHOD_ID_DOMAIN) - def create_thread_domain(self, ctx: Context) -> 'silver.ast.Domain': pos, info = self.no_position(ctx), self.no_info(ctx) - method_id_type = self.viper.DomainType(METHOD_ID_DOMAIN, {}, []) + method_id_type = self.viper.DomainType(FUNCTION_DOMAIN_NAME, {}, []) thread_param = self.viper.LocalVarDecl('t', self.viper.Ref, pos, info) index_param = self.viper.LocalVarDecl('i', self.viper.Int, pos, info) get_method = self.viper.DomainFunc(GET_METHOD_FUNC, [thread_param], @@ -698,7 +686,6 @@ def translate_program(self, modules: List[PythonModule], static_fields = OrderedDict() func_constants = [] - threading_ids_constants = [] # Silver names of the set of nodes which have been selected by the user # to be verified (if any). selected_names = [] @@ -760,8 +747,7 @@ def translate_program(self, modules: List[PythonModule], functions.append(self.translate_function(function, ctx)) func_constants.append(self.translate_function_constant(function, ctx)) for method in module.methods.values(): - id_constant = self.translate_method_id_to_constant(method, ctx) - threading_ids_constants.append(id_constant) + func_constants.append(self.translate_function_constant(method, ctx)) if method.interface: continue self.track_dependencies(selected_names, selected, method, ctx) @@ -769,6 +755,13 @@ def translate_program(self, modules: List[PythonModule], for pred in module.predicates.values(): self.track_dependencies(selected_names, selected, pred, ctx) predicates.append(self.translate_predicate(pred, ctx)) + for call_slot in module.call_slots.values(): + call_slot_holds, call_slot_apply = self.translate_call_slot(call_slot, ctx) + functions.append(call_slot_holds) + if call_slot.pure: + functions.append(call_slot_apply) + else: + methods.append(call_slot_apply) for class_name, cls in module.classes.items(): if class_name in PRIMITIVES or class_name != cls.name: # Skip primitives and type variable entries. @@ -794,8 +787,7 @@ def translate_program(self, modules: List[PythonModule], 'invalid.override') for method_name in cls.methods: method = cls.methods[method_name] - threading_ids_constants.append( - self.translate_method_id_to_constant(method, ctx)) + func_constants.append(self.translate_function_constant(method, ctx)) if method.interface: continue self.track_dependencies(selected_names, selected, method, ctx) @@ -807,8 +799,7 @@ def translate_program(self, modules: List[PythonModule], methods.append(self.create_override_check(method, ctx)) for method_name in cls.static_methods: method = cls.static_methods[method_name] - threading_ids_constants.append( - self.translate_method_id_to_constant(method, ctx)) + func_constants.append(self.translate_function_constant(method, ctx)) self.track_dependencies(selected_names, selected, method, ctx) methods.append(self.translate_method(method, ctx)) if method.overrides: @@ -881,7 +872,6 @@ def translate_program(self, modules: List[PythonModule], type_axioms, ctx)) domains.append(self.create_thread_domain(ctx)) domains.append(self.create_functions_domain(func_constants, ctx)) - domains.append(self.create_method_id_domain(threading_ids_constants, ctx)) converted_sil_progs = self._convert_silver_elements(sil_progs, all_used_names, ctx) diff --git a/src/nagini_translation/translators/statement.py b/src/nagini_translation/translators/statement.py index 98acaa6ac..42acede4b 100644 --- a/src/nagini_translation/translators/statement.py +++ b/src/nagini_translation/translators/statement.py @@ -58,6 +58,7 @@ ) from nagini_translation.translators.abstract import Context from nagini_translation.translators.common import CommonTranslator +from nagini_translation.call_slot_analyzers import is_call_slot_proof from typing import List, Optional, Tuple, Union @@ -277,6 +278,9 @@ def translate_stmt_FunctionDef(self, node: ast.FunctionDef, dependencies of the declaration are defined, and subsequently an assignment that sets the function name to be defined. """ + if is_call_slot_proof(node) and not self.is_main_method(ctx): + return self.translate_call_slot_proof(node, ctx) + assert self.is_main_method(ctx) if ctx.current_class: method = ctx.current_class.get_func_or_method(node.name) diff --git a/tests/closures/translation/call_slot_declaration/test_arg_contract_shadowing.py b/tests/closures/translation/call_slot_declaration/test_arg_contract_shadowing.py new file mode 100644 index 000000000..76f2f04a0 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_arg_contract_shadowing.py @@ -0,0 +1,8 @@ +from typing import Callable +from nagini_contracts.contracts import CallSlot + + +@CallSlot +#:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_name) +def call_slot(Acc: Callable[[int], None]) -> None: + Acc(2) diff --git a/tests/closures/translation/call_slot_declaration/test_double.py b/tests/closures/translation/call_slot_declaration/test_double.py new file mode 100644 index 000000000..2c45df7b5 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_double.py @@ -0,0 +1,26 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures +) + + +@CallSlot +def call_slot1(f: Callable[[int, int], int], x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + +@CallSlot +def call_slot2(f: Callable[[int, int], int], x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_declaration/test_illegal_assign.py b/tests/closures/translation/call_slot_declaration/test_illegal_assign.py new file mode 100644 index 000000000..c935fbb0d --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_illegal_assign.py @@ -0,0 +1,19 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures +) + + +@CallSlot +def call_slot(f: Callable[[int, int], int], x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + #:: ExpectedOutput(invalid.program:call_slots.body.invalid_stmt) + x += 2 + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_declaration/test_illegal_return_name.py b/tests/closures/translation/call_slot_declaration/test_illegal_return_name.py new file mode 100644 index 000000000..553ab03f8 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_illegal_return_name.py @@ -0,0 +1,18 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures +) + + +@CallSlot +def call_slot(f: Callable[[int, int], int], x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + #:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_name) + Acc = f(x, y) + Ensures(Acc == x + y) diff --git a/tests/closures/translation/call_slot_declaration/test_magic_name.py b/tests/closures/translation/call_slot_declaration/test_magic_name.py new file mode 100644 index 000000000..bcde7834f --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_magic_name.py @@ -0,0 +1,8 @@ +from typing import Callable +from nagini_contracts.contracts import CallSlot + + +#:: ExpectedOutput(invalid.program:illegal.magic.method) +@CallSlot +def __call_slot__(f: Callable[[int], None]) -> None: + f(2) diff --git a/tests/closures/translation/call_slot_declaration/test_method_in_contract.py b/tests/closures/translation/call_slot_declaration/test_method_in_contract.py new file mode 100644 index 000000000..d618fcc8f --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_method_in_contract.py @@ -0,0 +1,31 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures, + Acc, + Fold +) + + +@CallSlot +def call_slot(f: Callable[[int, int], int], argm: 'Arg') -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + #:: ExpectedOutput(invalid.program:purity.violated) + Requires(is_arg(argm) and y > 0) + z = f(argm.val, y) + Ensures(z >= y) + + +class Arg: + + def __init__(self) -> None: + Ensures(Acc(self.val)) + self.val = 1 # type: int + + +def is_arg(arg: Arg) -> bool: + return True diff --git a/tests/closures/translation/call_slot_declaration/test_method_name_collision.py b/tests/closures/translation/call_slot_declaration/test_method_name_collision.py new file mode 100644 index 000000000..dd4edd4d6 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_method_name_collision.py @@ -0,0 +1,11 @@ +from typing import Callable +from nagini_contracts.contracts import CallSlot + + +def some_name() -> None: + pass + +#:: ExpectedOutput(type.error:Name 'some_name' already defined) +@CallSlot +def some_name(f: Callable[[int], None]) -> None: + f(2) diff --git a/tests/closures/translation/call_slot_declaration/test_nested_class.py b/tests/closures/translation/call_slot_declaration/test_nested_class.py new file mode 100644 index 000000000..6e925ef92 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_nested_class.py @@ -0,0 +1,10 @@ +from typing import Callable +from nagini_contracts.contracts import CallSlot + + +class class_with_call_slot_inside: + + #:: ExpectedOutput(invalid.program:call_slots.nested.declaration) + @CallSlot + def call_slot(self, f: Callable[[int], None]) -> None: + f(2) diff --git a/tests/closures/translation/call_slot_declaration/test_nested_method.py b/tests/closures/translation/call_slot_declaration/test_nested_method.py new file mode 100644 index 000000000..a3212fe89 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_nested_method.py @@ -0,0 +1,10 @@ +from typing import Callable +from nagini_contracts.contracts import CallSlot + + +def method_with_call_slot_inside() -> None: + + #:: ExpectedOutput(invalid.program:call_slots.nested.declaration) + @CallSlot + def call_slot(f: Callable[[int], None]) -> None: + f(2) diff --git a/tests/closures/translation/call_slot_declaration/test_predicate_in_contract.py b/tests/closures/translation/call_slot_declaration/test_predicate_in_contract.py new file mode 100644 index 000000000..190c63a5d --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_predicate_in_contract.py @@ -0,0 +1,33 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures, + Acc, + Predicate, + Fold +) + + +@CallSlot +def call_slot(f: Callable[[int, int], int], argm: 'Arg') -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(is_arg(argm) and y > 0) + z = f(argm.val, y) + Ensures(z >= y) + + +class Arg: + + def __init__(self) -> None: + Ensures(is_arg(self)) + self.val = 1 # type: int + Fold(is_arg(self)) + + +@Predicate +def is_arg(argm: Arg) -> bool: + return Acc(argm.val) diff --git a/tests/closures/translation/call_slot_declaration/test_return_shadowing.py b/tests/closures/translation/call_slot_declaration/test_return_shadowing.py new file mode 100644 index 000000000..8766882c2 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_return_shadowing.py @@ -0,0 +1,19 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures, + Old +) + + +@CallSlot +def call_slot(f: Callable[[int, int], int], x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + #:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_shadowing) + y = f(x, y) + Ensures(y == x + Old(y)) diff --git a/tests/closures/translation/call_slot_declaration/test_simple.py b/tests/closures/translation/call_slot_declaration/test_simple.py new file mode 100644 index 000000000..c0c639837 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_simple.py @@ -0,0 +1,17 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified, + Requires, + Ensures +) + + +@CallSlot +def call_slot(f: Callable[[int, int], int], x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_declaration/test_simple_no_uq.py b/tests/closures/translation/call_slot_declaration/test_simple_no_uq.py new file mode 100644 index 000000000..a50236ff4 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_simple_no_uq.py @@ -0,0 +1,13 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + Requires, + Ensures +) + + +@CallSlot +def call_slot(f: Callable[[int], int], x: int) -> None: + Requires(x >= 0) + z = f(x) + Ensures(z >= x) diff --git a/tests/closures/translation/call_slot_declaration/test_uq_shadowing.py b/tests/closures/translation/call_slot_declaration/test_uq_shadowing.py new file mode 100644 index 000000000..081724d29 --- /dev/null +++ b/tests/closures/translation/call_slot_declaration/test_uq_shadowing.py @@ -0,0 +1,15 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + UniversallyQuantified +) + + +@CallSlot +def some_slot(f: Callable[[int], None], x: int) -> None: + + @UniversallyQuantified + #:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_shadowing) + def uq(x: int) -> None: + f(x) + diff --git a/tests/closures/translation/call_slot_proof/test_arg_contract_shadowing.py b/tests/closures/translation/call_slot_proof/test_arg_contract_shadowing.py new file mode 100644 index 000000000..35b6238d1 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_arg_contract_shadowing.py @@ -0,0 +1,42 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + Requires(x >= 0) + z = f(x) + Ensures(z >= x) + + +def method() -> None: + + x = 1 + f = add + + @CallSlotProof(call_slot(f, x)) + #:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_name) + def call_slot_proof(Acc: F_Type, x: int) -> None: + Requires(x >= 0) + + # justified because f == add + z = ClosureCall(Acc(x), add) # type: int + Ensures(z >= x) diff --git a/tests/closures/translation/call_slot_proof/test_double.py b/tests/closures/translation/call_slot_proof/test_double.py new file mode 100644 index 000000000..9ee2e7bb4 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_double.py @@ -0,0 +1,55 @@ +from typing import Callable, Optional +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> Optional[int]: + return x * y + + +F_Type = Callable[[int, int], Optional[int]] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + f(x, y) + + +def method() -> None: + + x = 1 + f = add + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof1(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + ClosureCall(f(x, y), add) # type: int + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof2(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + ClosureCall(f(x, y), add) # type: int diff --git a/tests/closures/translation/call_slot_proof/test_empty_instantiation.py b/tests/closures/translation/call_slot_proof/test_empty_instantiation.py new file mode 100644 index 000000000..0382d2f67 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_empty_instantiation.py @@ -0,0 +1,49 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + #:: ExpectedOutput(type.error:Too few arguments for "CallSlotProof") + @CallSlotProof() + def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_illegal_assign.py b/tests/closures/translation/call_slot_proof/test_illegal_assign.py new file mode 100644 index 000000000..5a3d2f8f0 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_illegal_assign.py @@ -0,0 +1,50 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + #:: ExpectedOutput(invalid.program:call_slots.proof_body.invalid_stmt) + x = 2 + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_invalid_instantiation.py b/tests/closures/translation/call_slot_proof/test_invalid_instantiation.py new file mode 100644 index 000000000..691d3bf53 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_invalid_instantiation.py @@ -0,0 +1,49 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + #:: ExpectedOutput(invalid.program:call_slots.proof_annotation.invalid_arg) + @CallSlotProof(call_slot) + def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_magic_name.py b/tests/closures/translation/call_slot_proof/test_magic_name.py new file mode 100644 index 000000000..ecfdd637c --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_magic_name.py @@ -0,0 +1,49 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + #:: ExpectedOutput(invalid.program:illegal.magic.method) + @CallSlotProof(call_slot(f, x)) + def __call_slot_proof__(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_method_name_collision.py b/tests/closures/translation/call_slot_proof/test_method_name_collision.py new file mode 100644 index 000000000..c2f5b77f0 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_method_name_collision.py @@ -0,0 +1,52 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + def some_name() -> None: + pass + + #:: ExpectedOutput(type.error:Name 'some_name' already defined) + @CallSlotProof(call_slot(f, x)) + def some_name(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_proof_outside_function.py b/tests/closures/translation/call_slot_proof/test_proof_outside_function.py new file mode 100644 index 000000000..0399fb97a --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_proof_outside_function.py @@ -0,0 +1,48 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +x = 1 +f = add + + +#:: ExpectedOutput(invalid.program:call_slots.proof.outside_method) +@CallSlotProof(call_slot(f, x)) +def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_return_shadowing.py b/tests/closures/translation/call_slot_proof/test_return_shadowing.py new file mode 100644 index 000000000..3cb3d9ea6 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_return_shadowing.py @@ -0,0 +1,49 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == y) + + +def method() -> None: + + x = 1 + f = add + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + #:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_shadowing) + x = ClosureCall(f(x, y), add) # type: int + + Ensures(x == y) diff --git a/tests/closures/translation/call_slot_proof/test_simple.py b/tests/closures/translation/call_slot_proof/test_simple.py new file mode 100644 index 000000000..ecf5197a2 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_simple.py @@ -0,0 +1,48 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + + # justified because f == add + z = ClosureCall(f(x, y), add) # type: int + + Ensures(z == x + y) diff --git a/tests/closures/translation/call_slot_proof/test_simple_no_uq.py b/tests/closures/translation/call_slot_proof/test_simple_no_uq.py new file mode 100644 index 000000000..86b01d8c3 --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_simple_no_uq.py @@ -0,0 +1,43 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + ClosureCall, + Requires, + Ensures +) + + +def inc(x: int) -> int: + return x + 1 + + +def twice(x: int) -> int: + return 2 * x + + +F_Type = Callable[[int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + Requires(x >= 0) + z = f(x) + Ensures(z >= x) + + +def method() -> None: + + x = 1 + f = inc + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof(f: F_Type, x: int) -> None: + + Requires(x >= 0) + + # justified because f == add + z = ClosureCall(f(x), inc) # type: int + + Ensures(z >= x) diff --git a/tests/closures/translation/call_slot_proof/test_uq_shadowing.py b/tests/closures/translation/call_slot_proof/test_uq_shadowing.py new file mode 100644 index 000000000..59155e9ed --- /dev/null +++ b/tests/closures/translation/call_slot_proof/test_uq_shadowing.py @@ -0,0 +1,49 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method() -> None: + + x = 1 + f = add + + @CallSlotProof(call_slot(f, x)) + def call_slot_proof(f: F_Type, x: int) -> None: + + @UniversallyQuantified + #:: ExpectedOutput(invalid.program:call_slots.parameters.illegal_shadowing) + def uq(x: int) -> None: + Requires(x >= 0 and x > x) + + # justified because f == add + z = ClosureCall(f(x, x), add) # type: int + + Ensures(z == x + x) diff --git a/tests/closures/translation/closure_call/test_bad_justification.py b/tests/closures/translation/closure_call/test_bad_justification.py new file mode 100644 index 000000000..df6a87330 --- /dev/null +++ b/tests/closures/translation/closure_call/test_bad_justification.py @@ -0,0 +1,39 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method(f: F_Type, x: int) -> None: + Requires(call_slot(f, x)) + + y = 1 + + #:: ExpectedOutput(invalid.program:call_slots.closure_call.invalid_justification) + z = ClosureCall(f(x, y), x + y) # type: int diff --git a/tests/closures/translation/closure_call/test_no_call.py b/tests/closures/translation/closure_call/test_no_call.py new file mode 100644 index 000000000..194ef894b --- /dev/null +++ b/tests/closures/translation/closure_call/test_no_call.py @@ -0,0 +1,39 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method(f: F_Type, x: int) -> None: + Requires(call_slot(f, x)) + + y = 1 + + #:: ExpectedOutput(invalid.program:call_slots.closure_call.invalid_call) + z = ClosureCall(x, call_slot(f, x)(y)) # type: int diff --git a/tests/closures/translation/closure_call/test_no_uq_vars.py b/tests/closures/translation/closure_call/test_no_uq_vars.py new file mode 100644 index 000000000..cbfd34862 --- /dev/null +++ b/tests/closures/translation/closure_call/test_no_uq_vars.py @@ -0,0 +1,39 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method(f: F_Type, x: int) -> None: + Requires(call_slot(f, x)) + + y = 1 + + #:: ExpectedOutput(invalid.program:call_slots.closure_call.invalid_justification) + z = ClosureCall(f(x, y), call_slot(f, x)) # type: int diff --git a/tests/closures/translation/closure_call/test_not_named_call_slot.py b/tests/closures/translation/closure_call/test_not_named_call_slot.py new file mode 100644 index 000000000..43243478f --- /dev/null +++ b/tests/closures/translation/closure_call/test_not_named_call_slot.py @@ -0,0 +1,48 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures, + Acc +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +class A: + + def __init__(self) -> None: + self.call_slot = call_slot # type: Callable[[F_Type, int], Callable[[int], int]] + Ensures(Acc(self.call_slot)) + + +def method(f: F_Type, x: int) -> None: + Requires(call_slot(f, x)) + + y = 1 + a = A() + + #:: ExpectedOutput(invalid.program:call_slots.closure_call.invalid_justification) + z = ClosureCall(f(x, y), (a.call_slot)(f, x)(y)) # type: int diff --git a/tests/closures/translation/closure_call/test_simple.py b/tests/closures/translation/closure_call/test_simple.py new file mode 100644 index 000000000..cbb3bc5a2 --- /dev/null +++ b/tests/closures/translation/closure_call/test_simple.py @@ -0,0 +1,38 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Requires, + Ensures +) + + +def add(x: int, y: int) -> int: + return x + y + + +def mul(x: int, y: int) -> int: + return x * y + + +F_Type = Callable[[int, int], int] + + +@CallSlot +def call_slot(f: F_Type, x: int) -> None: + + @UniversallyQuantified + def uq(y: int) -> None: + Requires(x >= 0 and y > x) + z = f(x, y) + Ensures(z == x + y) + + +def method(f: F_Type, x: int) -> None: + Requires(call_slot(f, x)) + + y = 1 + + z = ClosureCall(f(x, y), call_slot(f, x)(y)) # type: int diff --git a/tests/closures/verification/examples/hof_forwarding.py b/tests/closures/verification/examples/hof_forwarding.py new file mode 100644 index 000000000..3a41024aa --- /dev/null +++ b/tests/closures/verification/examples/hof_forwarding.py @@ -0,0 +1,172 @@ +from typing import Callable +from nagini_contracts.contracts import ( + Implies, + Requires, + Ensures, + Result, + Old, + Invariant, + Predicate, + Fold, + Unfold, + Unfolding, + Acc, + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, +) + + +def inc(x: int) -> int: + Ensures(Result() == x + 1) + return x + 1 + + +def mul(x: int) -> int: + Requires(x > 0) + Ensures(Result() == x * 2) + return x * 2 + + +@Predicate +def pre(x: int, pre_token: int) -> bool: + return ( + True if pre_token == 1 else + x > 0 if pre_token == 2 else + True + ) + + +@Predicate +def post(x: int, ret: int, post_token: int) -> bool: + return ( + ret == x + 1 if post_token == 1 else + ret == x * 2 if post_token == 2 else + True + ) + + +f0_type = Callable[[int], int] + + +@CallSlot +def f1_slot(f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + ret = f0(x) + Ensures(post(x, ret, post_token)) + + +def f1(f0: f0_type, x: int, pre_token: int, post_token: int) -> int: + Requires(f1_slot(f0, x, pre_token, post_token)) + Requires(pre(x, pre_token)) + Ensures(post(x, Result(), post_token)) + return ClosureCall(f0(x), f1_slot(f0, x, pre_token, post_token)()) + + +f1_type = Callable[[f0_type, int, int, int], int] + + +@CallSlot +def f2_slot(f1: f1_type, f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + ret = f1(f0, x, pre_token, post_token) + Ensures(post(x, ret, post_token)) + + +def f2(f1: f1_type, f0: f0_type, x: int, pre_token: int, post_token: int) -> int: + Requires(f2_slot(f1, f0, x, pre_token, post_token)) + Requires(pre(x, pre_token)) + Ensures(post(x, Result(), post_token)) + return ClosureCall( + f1(f0, x, pre_token, post_token), + f2_slot(f1, f0, x, pre_token, post_token)() + ) + + +f2_type = Callable[[f1_type, f0_type, int, int, int], int] + + +@CallSlot +def f3_slot(f2: f2_type, f1: f1_type, f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + ret = f2(f1, f0, x, pre_token, post_token) + Ensures(post(x, ret, post_token)) + + +def f3(f2: f2_type, f1: f1_type, f0: f0_type, x: int, pre_token: int, post_token: int) -> int: + Requires(f3_slot(f2, f1, f0, x, pre_token, post_token)) + Requires(pre(x, pre_token)) + Ensures(post(x, Result(), post_token)) + return ClosureCall( + f2(f1, f0, x, pre_token, post_token), + f3_slot(f2, f1, f0, x, pre_token, post_token)() + ) + + +def client() -> None: + + _inc = inc + _mul = mul + + _f1 = f1 + _f2 = f2 + _f3 = f3 + + @CallSlotProof(f1_slot(_inc, 5, 1, 1)) + def f1_slot_inc(f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + Unfold(pre(x, pre_token)) + ret = ClosureCall(f0(x), inc) # type: int + Fold(post(x, ret, post_token)) + Ensures(post(x, ret, post_token)) + + @CallSlotProof(f2_slot(_f1, _inc, 5, 1, 1)) + def f2_slot_inc(_f1: f1_type, _f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + + ret = ClosureCall(_f1(_f0, x, pre_token, post_token), f1) # type: int + + Ensures(post(x, ret, post_token)) + + @CallSlotProof(f3_slot(_f2, _f1, _inc, 5, 1, 1)) + def f3_slot_inc(_f2: f2_type, _f1: f1_type, _f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + + ret = ClosureCall(_f2(_f1, _f0, x, pre_token, post_token), f2) # type: int + + Ensures(post(x, ret, post_token)) + + Fold(pre(5, 1)) + y1 = ClosureCall(_f3(_f2, _f1, _inc, 5, 1, 1), f3) # type: int + Unfold(post(5, y1, 1)) + assert y1 == 6 + + @CallSlotProof(f1_slot(_mul, 5, 2, 2)) + def f1_slot_mul(f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + Unfold(pre(x, pre_token)) + ret = ClosureCall(f0(x), mul) # type: int + Fold(post(x, ret, post_token)) + Ensures(post(x, ret, post_token)) + + @CallSlotProof(f2_slot(_f1, _mul, 5, 2, 2)) + def f2_slot_mul(_f1: f1_type, _f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + + ret = ClosureCall(_f1(_f0, x, pre_token, post_token), f1) # type: int + + Ensures(post(x, ret, post_token)) + + @CallSlotProof(f3_slot(_f2, _f1, _mul, 5, 2, 2)) + def f3_slot_mul(_f2: f2_type, _f1: f1_type, _f0: f0_type, x: int, pre_token: int, post_token: int) -> None: + Requires(pre(x, pre_token)) + + ret = ClosureCall(_f2(_f1, _f0, x, pre_token, post_token), f2) # type: int + + Ensures(post(x, ret, post_token)) + + Fold(pre(5, 2)) + y2 = ClosureCall(_f3(_f2, _f1, _mul, 5, 2, 2), f3) # type: int + Unfold(post(5, y2, 2)) + assert y2 == 10 diff --git a/tests/closures/verification/examples/n_times.py b/tests/closures/verification/examples/n_times.py new file mode 100644 index 000000000..5d17f4eb8 --- /dev/null +++ b/tests/closures/verification/examples/n_times.py @@ -0,0 +1,132 @@ +from typing import Callable, Optional +from nagini_contracts.contracts import ( + Implies, + Requires, + Ensures, + Result, + Old, + Invariant, + Predicate, + Fold, + Unfold, + Unfolding, + Pure, + Acc, + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Assert, +) + + +class State: + + def __init__( + self, + counter: int, + value: int, + offset: int, + next: Optional['State'] + ) -> None: + + self.counter = counter + self.value = value + self.offset = offset + self.next = next + Ensures(Acc(self.counter) and Acc(self.value)) + Ensures(Acc(self.offset) and Acc(self.next)) + Ensures(self.counter == counter and self.value == value) + Ensures(self.offset == offset and self.next == next) + + +f_type = Callable[[State], Optional[object]] + + +@Predicate +def n_inv(s: State, i: int, n_inv_token: int) -> bool: + return ( + Acc(s.counter) and Acc(s.value) and + s.counter == i and s.value == i * (i + 1) // 2 + if n_inv_token == 1 else + True + ) + + +@CallSlot +def n_times_slot(f: f_type, n_inv_token: int) -> None: + + @UniversallyQuantified + def uq(s: State, i: int) -> None: + Requires(n_inv(s, i, n_inv_token)) + + f(s) + + Ensures(n_inv(s, i + 1, n_inv_token)) + + +def n_times(f: f_type, n: int, s: State, n_inv_token: int) -> None: + Requires(0 <= n) + Requires(n_inv(s, 0, n_inv_token)) + Requires(n_times_slot(f, n_inv_token)) + Ensures(n_inv(s, n, n_inv_token)) + + i = 0 + while i < n: + Invariant(0 <= i and i <= n) + Invariant(n_inv(s, i, n_inv_token)) + + ClosureCall(f(s), n_times_slot(f, n_inv_token)(s, i)) + i += 1 + + # FIXME: this shouldn't be necessary + # (would be impossible with proper 'parametric assertions') + # However while i == n, theyr values as Ref types seem to be uneqal + Unfold(n_inv(s, i, n_inv_token)) + Fold(n_inv(s, n, n_inv_token)) + + +def sum_range(s: State) -> Optional[object]: + Requires(Acc(s.counter) and Acc(s.value)) + Ensures(Acc(s.counter) and Acc(s.value)) + Ensures(s.counter == Old(s.counter) + 1) + Ensures(s.value == Old(s.value + s.counter + 1)) + + s.counter += 1 + s.value += s.counter + + return None + + +def n_times_client() -> None: + + s = State(0, 0, 0, None) + + f = sum_range + + @CallSlotProof(n_times_slot(f, 1)) + def n_times_slot(f: f_type, n_inv_token: int) -> None: + + @UniversallyQuantified + def uq(s: State, i: int) -> None: + Requires(n_inv(s, i, n_inv_token)) + Unfold(n_inv(s, i, n_inv_token)) + + assert i == s.counter + assert s.value == i * (i + 1) // 2 + ClosureCall(f(s), sum_range) + + assert s.counter == i + 1 + assert s.value == i * (i + 1) // 2 + i + 1 + assert s.value == i * (i + 1) // 2 + i + 1 + assert s.value == i * (i + 1) // 2 + 2 * (i + 1) // 2 + assert s.value == (i * (i + 1) + 2 * (i + 1)) // 2 + + Fold(n_inv(s, i + 1, n_inv_token)) + Ensures(n_inv(s, i + 1, n_inv_token)) + + Fold(n_inv(s, 0, 1)) + n_times(f, 23, s, 1) + Unfold(n_inv(s, 23, 1)) + assert s.value == 276 + assert s.counter == 23 diff --git a/tests/closures/verification/examples/pure.py b/tests/closures/verification/examples/pure.py new file mode 100644 index 000000000..e631086b8 --- /dev/null +++ b/tests/closures/verification/examples/pure.py @@ -0,0 +1,144 @@ +from typing import Callable +from nagini_contracts.contracts import ( + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Pure, + Result, + Requires, + Ensures, + Acc +) + + +def choice() -> bool: + return True + + +class Argument: + + def __init__(self, value_a: int, value_b: int) -> None: + self.value_a = value_a + self.value_b = value_b + Ensures(Acc(self.value_a) and Acc(self.value_b)) + Ensures(self.value_a == value_a and self.value_b == value_b) + + +f_type = Callable[[Argument, int], int] + + +@Pure +def add(argm: Argument, x: int) -> int: + Requires(Acc(argm.value_a)) + Ensures(Result() == argm.value_a + x) + return argm.value_a + x + + +@Pure +def mul(argm: Argument, x: int) -> int: + Requires(Acc(argm.value_b)) + Ensures(Result() == argm.value_b * x) + return argm.value_b * x + + +@Pure +@CallSlot +def add_or_mul(f: f_type) -> None: + + @UniversallyQuantified + def uq(argm: Argument, x: int) -> None: + Requires(Acc(argm.value_a) and Acc(argm.value_b)) + + y = f(argm, x) + + Ensures(y == argm.value_a + x or y == argm.value_b * x) + + +@CallSlot +def hof_slot(f: f_type) -> None: + + @UniversallyQuantified + def uq(argm: Argument, x: int) -> None: + Requires(Acc(argm.value_a, 1 / 2) and Acc(argm.value_b, 1 / 3)) + + y = f(argm, x) + + Ensures(Acc(argm.value_a, 1 / 2) and Acc(argm.value_b, 1 / 3)) + Ensures(y <= argm.value_a + x or y >= argm.value_b * x) + + +def hof(f: f_type, argm: Argument) -> int: + Requires(Acc(argm.value_a, 1 / 2) and Acc(argm.value_b, 1 / 3)) + Requires(hof_slot(f)) + + Ensures(Acc(argm.value_a, 1 / 2) and Acc(argm.value_b, 1 / 3)) + Ensures(Result() <= argm.value_a + 5 or Result() >= argm.value_b * 5) + + return ClosureCall(f(argm, 5), hof_slot(f)(argm, 5)) + + +def client() -> None: + + argm = Argument(1, 2) + assert argm.value_a == 1 + assert argm.value_b == 2 + + f = add + y = ClosureCall(f(argm, 3), add) # type: int + assert y == 4 + assert argm.value_a == 1 + assert argm.value_b == 2 + + f = mul + y = ClosureCall(f(argm, 3), mul) + assert y == 6 + assert argm.value_a == 1 + assert argm.value_b == 2 + + if choice(): + f = add + else: + f = mul + + @CallSlotProof(add_or_mul(f)) + def add_or_mul_proof(f: f_type) -> None: + + @UniversallyQuantified + def uq(argm: Argument, x: int) -> None: + Requires(Acc(argm.value_a) and Acc(argm.value_b)) + + if f == add: + y = ClosureCall(f(argm, x), add) # type: int + else: + y = ClosureCall(f(argm, x), mul) + + Ensures(y == argm.value_a + x or y == argm.value_b * x) + + y1 = ClosureCall(f(argm, 3), add_or_mul(f)(argm, 3)) # type: int + y2 = ClosureCall(f(argm, 3), add_or_mul(f)(argm, 3)) # type: int + assert y1 == 4 or y2 == 6 + assert y1 == y2 + assert argm.value_a == 1 + assert argm.value_b == 2 + assert y1 == ClosureCall(f(argm, 3), add_or_mul(f)(argm, 3)) + + @CallSlotProof(hof_slot(f)) + def hof_slot_proof(f: f_type) -> None: + + @UniversallyQuantified + def uq(argm: Argument, x: int) -> None: + Requires(Acc(argm.value_a) and Acc(argm.value_b)) + + if f == add: + y = ClosureCall(f(argm, x), add) # type: int + else: + y = ClosureCall(f(argm, x), mul) + + Ensures(y <= argm.value_a + x or y >= argm.value_b * x) + + h = hof + y = ClosureCall(h(f, argm), hof) + assert y <= 6 or y >= 10 + assert argm.value_a == 1 + assert argm.value_b == 2 diff --git a/tests/closures/verification/examples/setup_compute.py b/tests/closures/verification/examples/setup_compute.py new file mode 100644 index 000000000..4dcc8dc40 --- /dev/null +++ b/tests/closures/verification/examples/setup_compute.py @@ -0,0 +1,155 @@ +from typing import Callable, Optional +from nagini_contracts.contracts import ( + Requires, + Ensures, + Acc, + Predicate, + Fold, + Unfold, + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, +) + + +class Argument: + + def __init__(self, a: int, b: int) -> None: + self.a = a # type: int + self.b = b # type: int + + Ensures(Acc(self.a) and Acc(self.b)) + Ensures(self.a == a and self.b == b) + + +F_Type = Callable[[Argument, int], Optional[object]] + + +@CallSlot +def f_setup(setup: F_Type, c: int, d: int, before_token: int, between_token: int) -> None: + + @UniversallyQuantified + def uq(argm: Argument) -> None: + Requires(before(argm, c, d, before_token)) + + setup(argm, c) + + Ensures(between(argm, c, d, between_token)) + + +@CallSlot +def f_compute(compute: F_Type, c: int, d: int, between_token: int, after_token: int) -> None: + + @UniversallyQuantified + def uq(argm: Argument) -> None: + Requires(between(argm, c, d, between_token)) + + compute(argm, d) + + Ensures(after(argm, c, d, after_token)) + + +def f( + setup: F_Type, + compute: F_Type, + argm: Argument, + c: int, + d: int, + before_token: int, + between_token: int, + after_token: int +) -> None: + + Requires(f_setup(setup, c, d, before_token, between_token)) + Requires(f_compute(compute, c, d, between_token, after_token)) + Requires(before(argm, c, d, before_token)) + + ClosureCall(setup(argm, c), f_setup(setup, c, d, before_token, between_token)(argm)) + + # assert between(arg, c, d, between_token) + + ClosureCall(compute(argm, d), f_compute(compute, c, d, between_token, after_token)(argm)) + + Ensures(after(argm, c, d, after_token)) + + +def setup(argm: Argument, c: int) -> Optional[object]: + Requires(Acc(argm.a)) + Requires(c > 1) + Ensures(Acc(argm.a)) + Ensures(argm.a == 3 * c) + + argm.a = 3 * c + + +def compute(argm: Argument, d: int) -> Optional[object]: + Requires(Acc(argm.a, 1 / 2) and Acc(argm.b)) + Ensures(Acc(argm.a, 1 / 2) and Acc(argm.b)) + Ensures(argm.b == argm.a + d) + + argm.b = argm.a + d + + +@Predicate +def before(argm: Argument, c: int, d: int, token: int) -> bool: + return ( + Acc(argm.a) and Acc(argm.b) and c > 1 if token == 1 else + True + ) + + +@Predicate +def between(argm: Argument, c: int, d: int, token: int) -> bool: + return ( + Acc(argm.b) and Acc(argm.a) and argm.a == 3 * c if token == 1 else + True + ) + + +@Predicate +def after(argm: Argument, c: int, d: int, token: int) -> bool: + return ( + Acc(argm.a) and Acc(argm.b) and argm.a == 3 * c and argm.b == argm.a + d if token == 1 else + True + ) + + +def client() -> None: + + argm = Argument(2, 2) + c, d = 2, 3 + before_token = between_token = after_token = 1 + + @CallSlotProof(f_setup(setup, c, d, before_token, between_token)) + def f_setup_proof(f: F_Type, c: int, d: int, A: int, B: int) -> None: + + @UniversallyQuantified + def uq(argm: Argument) -> None: + Requires(before(argm, c, d, before_token)) + Unfold(before(argm, c, d, between_token)) + + ClosureCall(f(argm, c), setup) + + Fold(between(argm, c, d, between_token)) + Ensures(between(argm, c, d, between_token)) + + @CallSlotProof(f_compute(compute, c, d, between_token, after_token)) + def f_compute_proof(f: F_Type, c: int, d: int, B: int, C: int) -> None: + + @UniversallyQuantified + def uq(argm: Argument) -> None: + Requires(between(argm, c, d, between_token)) + Unfold(between(argm, c, d, between_token)) + + ClosureCall(f(argm, d), compute) + + Fold(after(argm, c, d, after_token)) + Ensures(after(argm, c, d, after_token)) + + Fold(before(argm, c, d, between_token)) + f(setup, compute, argm, c, d, before_token, between_token, after_token) + Unfold(after(argm, c, d, after_token)) + + assert argm.a == 6 + assert argm.b == 9 diff --git a/tests/closures/verification/examples/static_code.py b/tests/closures/verification/examples/static_code.py new file mode 100644 index 000000000..58a092ed6 --- /dev/null +++ b/tests/closures/verification/examples/static_code.py @@ -0,0 +1,145 @@ +from typing import Callable, Tuple +from nagini_contracts.contracts import ( + Requires, + Ensures, + Invariant, + Acc, + Result, + Pure, + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, +) + + +@Pure +def idiv(x: int, y: int) -> int: + Requires(y != 0) + + +@Pure +def sqrt(x: int) -> int: + Requires(x >= 0) + + +@Pure +def log(x: int) -> int: + Requires(x > 0) + + +def choice() -> bool: + return True + + +class Argument: + + def __init__(self, parameter: int, result: int) -> None: + self.parameter = parameter # type: int + self.result = result # type: int + + Ensures(Acc(self.parameter) and Acc(self.result)) + Ensures(self.parameter == parameter and self.result == result) + + +G_Type = Callable[[Argument, int, int], Tuple[int, int]] + + +@CallSlot +def func_call_slot(g: G_Type, b: int, c: int) -> None: + + @UniversallyQuantified + def uq(a: Argument) -> None: + + Requires(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Requires(a.parameter >= b + c) + + ret = g(a, b, c) + + Ensures(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Ensures(a.result != 0 and ret[0] >= 0 and ret[1] > 0) + + +def func( + g: G_Type, + a: Argument, + b: int, + c: int +) -> Tuple[int, int, int]: + + Requires(func_call_slot(g, b, c)) + Requires(Acc(a.parameter)) + Requires(Acc(a.result)) + Ensures(Acc(a.parameter)) + Ensures(Acc(a.result)) + + while a.parameter < b + c: + Invariant(Acc(a.parameter)) + a.parameter = a.parameter * a.parameter + + # g reads a.parameter and writes to a.result + + # closure call justified because the call slot holds: + # func_call_slot(g, b, c) + d, e = ClosureCall(g(a, b, c), func_call_slot(g, b, c)(a)) # type: Tuple[int, int] + + return idiv(1, a.result), sqrt(d), log(e) + + +def concrete_g_1(a: Argument, b: int, c: int) -> Tuple[int, int]: + Requires(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Requires(a.parameter >= b + c) + Ensures(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Ensures(a.result == -1 and Result()[0] == 0 and Result()[1] >= 1) + + a.result = -1 + return 0, a.parameter - b - c + 1 + + +def concrete_g_2(a: Argument, b: int, c: int) -> Tuple[int, int]: + Requires(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Requires(a.parameter >= b + c) + Ensures(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Ensures(a.result == 1 and Result()[0] >= 0 and Result()[1] == 1) + + a.result = 1 + return a.parameter - b - c, 1 + + +def client() -> None: + + if choice(): # non-deterministic choice + concrete_g = concrete_g_1 + else: + concrete_g = concrete_g_2 + + a = Argument(2, 2) + b, c = 2, 3 + + @CallSlotProof(func_call_slot(concrete_g, b, c)) + def func_call_slot_proof(concrete_g: G_Type, b: int, c: int) -> None: + + @UniversallyQuantified + def uq(a: Argument) -> None: + + Requires(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Requires(a.parameter >= b + c) + + if concrete_g == concrete_g_1: + # closure call justified, because we can prove static dispatch: + # concrete_g == concrete_g_1 + # and concrete_g_1 is a method whose contracts we can look up + # statically/in nagini + ret = ClosureCall(concrete_g(a, b, c), concrete_g_1) # type: Tuple[int, int] + else: + assert concrete_g == concrete_g_2 + # closure call justified, because we can prove static dispatch: + # concrete_g == concrete_g_2 + # and concrete_g_2 is a method whose contracts we can look up + # statically/in nagini + ret = ClosureCall(concrete_g(a, b, c), concrete_g_2) + + Ensures(Acc(a.parameter, 1 / 2) and Acc(a.result)) + Ensures(a.result != 0 and ret[0] >= 0 and ret[1] > 0) + + func(concrete_g, a, b, c) diff --git a/tests/closures/verification/examples/while_loop.py b/tests/closures/verification/examples/while_loop.py new file mode 100644 index 000000000..ceaadaca5 --- /dev/null +++ b/tests/closures/verification/examples/while_loop.py @@ -0,0 +1,193 @@ +from typing import Callable, Optional +from nagini_contracts.contracts import ( + Implies, + Requires, + Ensures, + Result, + Old, + Invariant, + Predicate, + Fold, + Unfold, + Unfolding, + Pure, + Acc, + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall, + Assert, +) + + +class State: + + def __init__( + self, + counter: int, + value: int, + offset: int, + next: Optional['State'] + ) -> None: + + self.counter = counter + self.value = value + self.offset = offset + self.next = next + Ensures(Acc(self.counter) and Acc(self.value)) + Ensures(Acc(self.offset) and Acc(self.next)) + Ensures(self.counter == counter and self.value == value) + Ensures(self.offset == offset and self.next == next) + + +cond_t = Callable[[State], bool] +body_t = Callable[[State], Optional[object]] + + +@Predicate +def inv(s: State, inv_token: int) -> bool: + return ( + ( + Acc(s.counter) and Acc(s.value, 1 / 2) and + s.counter <= s.value + ) if inv_token == 1 else + True + ) + + +@Pure +def cond_expr(s: State, inv_token: int) -> bool: + Requires(inv(s, inv_token)) + return Unfolding( + inv(s, inv_token), + s.counter < s.value if inv_token == 1 else + True + ) + + +@Pure +@CallSlot +def cond_slot( + cond: cond_t, + inv_token: int +) -> None: + + @UniversallyQuantified + def uq(s: State) -> None: + Requires(inv(s, inv_token)) + + b = cond(s) + + Ensures(b == cond_expr(s, inv_token)) + + +@CallSlot +def body_slot(body: body_t, inv_token: int) -> None: + + @UniversallyQuantified + def uq(s: State) -> None: + Requires(inv(s, inv_token)) + Requires(cond_expr(s, inv_token)) + + body(s) + + Ensures(inv(s, inv_token)) + + +def while_loop( + cond: cond_t, + body: body_t, + s: State, + inv_token: int +) -> None: + + Requires(inv(s, inv_token)) + Requires(cond_slot(cond, inv_token)) + Requires(body_slot(body, inv_token)) + + Ensures(inv(s, inv_token)) + Ensures(not cond_expr(s, inv_token)) + + b = ClosureCall( + cond(s), + cond_slot(cond, inv_token)(s) + ) # type: bool + + assert b == cond_expr(s, inv_token) + + while b: + Invariant(inv(s, inv_token)) + Invariant(b == cond_expr(s, inv_token)) + + ClosureCall( + body(s), + body_slot(body, inv_token)(s) + ) + + b = ClosureCall( + cond(s), + cond_slot(cond, inv_token)(s) + ) + + +@Pure +def count_to_cond(s: State) -> bool: + Requires(Acc(s.counter) and Acc(s.value)) + Ensures(Result() == (s.counter < s.value)) + return s.counter < s.value + + +def count_to_body(s: State) -> Optional[object]: + Requires(Acc(s.counter) and Acc(s.value, 1 / 2)) + Requires(s.counter < s.value) + + Ensures(Acc(s.counter) and Acc(s.value, 1 / 2)) + Ensures(s.counter <= s.value) + Ensures(s.counter == Old(s.counter) + 1) + + s.counter += 1 + return None + + +def while_loop_client() -> None: + + cond_f = count_to_cond + body_f = count_to_body + + s = State(0, 20, 0, None) + + @CallSlotProof(cond_slot(cond_f, 1)) + def cond_slot( + cond: cond_t, + inv_token: int + ) -> None: + + @UniversallyQuantified + def uq(s: State) -> None: + Requires(inv(s, inv_token)) + + Unfold(inv(s, inv_token)) + b = ClosureCall(cond(s), count_to_cond) # type: bool + Fold(inv(s, inv_token)) + + Ensures(b == cond_expr(s, inv_token)) + + @CallSlotProof(body_slot(body_f, 1)) + def body_slot(body: body_t, inv_token: int) -> None: + + @UniversallyQuantified + def uq(s: State) -> None: + Requires(inv(s, inv_token)) + Requires(Unfolding(inv(s, inv_token), cond_expr(s, inv_token))) + + Unfold(inv(s, inv_token)) + ClosureCall(body(s), count_to_body) + Fold(inv(s, inv_token)) + + Ensures(inv(s, inv_token)) + + Fold(inv(s, 1)) + while_loop(cond_f, body_f, s, 1) + Unfold(inv(s, 1)) + + assert s.counter == 20 and s.value == 20 diff --git a/tests/closures/verification/test_old.py b/tests/closures/verification/test_old.py new file mode 100644 index 000000000..41a702eb1 --- /dev/null +++ b/tests/closures/verification/test_old.py @@ -0,0 +1,71 @@ +from typing import Callable, Optional +from nagini_contracts.contracts import ( + Requires, + Ensures, + Predicate, + Acc, + Old, + CallSlot, + CallSlotProof, + ClosureCall, +) + + +class Argument: + + def __init__(self, parameter: int, result: int) -> None: + self.parameter = parameter # type: int + self.result = result # type: int + + Ensures(Acc(self.parameter) and Acc(self.result)) + Ensures(self.parameter == parameter and self.result == result) + + +inc_type = Callable[[Argument], Optional[object]] + + +def inc(argm: Argument) -> Optional[object]: + Requires(Acc(argm.parameter) and Acc(argm.result)) + + Ensures(Acc(argm.parameter) and Acc(argm.result)) + Ensures(argm.result == Old(argm.result) + argm.parameter) + Ensures(argm.parameter == Old(argm.parameter)) + + argm.result = argm.result + argm.parameter + + return None + + +@CallSlot +def inc_call_slot(f: inc_type, argm: Argument) -> None: + Requires(Acc(argm.parameter) and Acc(argm.result)) + + f(argm) + + Ensures(Acc(argm.parameter) and Acc(argm.result)) + Ensures(argm.result >= Old(argm.result) + argm.parameter) + Ensures(argm.parameter == Old(argm.parameter)) + + +def test() -> None: + + argm = Argument(1, 2) + + argm.result = 20 + argm.parameter = 50 + + f = inc + + @CallSlotProof(inc_call_slot(inc, argm)) + def inc_proof(f: inc_type, argm: Argument) -> None: + Requires(Acc(argm.parameter) and Acc(argm.result)) + + ClosureCall(f(argm), inc) + + Ensures(Acc(argm.parameter) and Acc(argm.result)) + Ensures(argm.result >= Old(argm.result) + argm.parameter) + Ensures(argm.parameter == Old(argm.parameter)) + + ClosureCall(f(argm), inc_call_slot(f, argm)()) + + assert argm.result >= 70 and argm.parameter == 50 diff --git a/tests/closures/verification/test_pure.py b/tests/closures/verification/test_pure.py new file mode 100644 index 000000000..b4719c4fa --- /dev/null +++ b/tests/closures/verification/test_pure.py @@ -0,0 +1,97 @@ +from typing import Callable +from nagini_contracts.contracts import ( + Requires, + Ensures, + Acc, + Pure, + Result, + Old, + CallSlot, + CallSlotProof, + UniversallyQuantified, + ClosureCall +) + + +def choice() -> bool: + return True + + +class Argument: + + def __init__(self, parameter: int, result: int) -> None: + self.parameter = parameter # type: int + self.result = result # type: int + + Ensures(Acc(self.parameter) and Acc(self.result)) + Ensures(self.parameter == parameter and self.result == result) + + +F_Type = Callable[[Argument, int], int] + + +@Pure +def add(argm: Argument, x: int) -> int: + Requires(Acc(argm.parameter)) + Ensures(Result() == x + argm.parameter) + return x + argm.parameter + + +@Pure +def mul(argm: Argument, x: int) -> int: + Requires(Acc(argm.parameter)) + Ensures(Result() == x * argm.parameter) + return x * argm.parameter + + +@Pure +@CallSlot +def pure_call_slot(f: F_Type, argm: Argument) -> None: + + @UniversallyQuantified + def uq(x: int) -> None: + Requires(Acc(argm.parameter) and argm.parameter > 0 and x > 1) + + y = f(argm, x) + + Ensures(y > argm.parameter) + + +def client(f: F_Type, argm: Argument) -> None: + Requires(Acc(argm.parameter) and Acc(argm.result)) + Requires(argm.parameter > 0) + Requires(pure_call_slot(f, argm)) + Ensures(Acc(argm.parameter) and Acc(argm.result)) + Ensures(argm.parameter == Old(argm.parameter)) + Ensures(argm.result > argm.parameter) + + argm.result = ClosureCall(f(argm, 20), pure_call_slot(f, argm)(20)) + + +def method() -> None: + + if choice(): + f = add + else: + f = mul + + argm = Argument(10, 5) + + @CallSlotProof(pure_call_slot(f, argm)) + def pure_call_slot(f: F_Type, argm: Argument) -> None: + + @UniversallyQuantified + def uq(x: int) -> None: + Requires(Acc(argm.parameter) and argm.parameter > 0 and x > 1) + + if f == add: + y = ClosureCall(f(argm, x), add) # type: int + else: + y = ClosureCall(f(argm, x), mul) + + Ensures(y > argm.parameter) + + client(f, argm) + + assert argm.parameter == 10 + assert argm.result > argm.parameter