diff --git a/.gitignore b/.gitignore index cca6cf503..7e57a365e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ syntax: glob env bin +build deps/JPype1 __pycache__ .virtualenv @@ -14,4 +15,4 @@ tmp docs/build .idea *.pyc -viper_out +viper_out \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 37cdaf977..6a224085d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -61,23 +61,20 @@ }, { "name": "Debug Nagini verifying Python", - "type": "python", + "type": "debugpy", "request": "launch", - "stopOnEntry": false, - "pythonPath": "${workspaceRoot}/nagini/env/bin/python3", - "program": "${workspaceRoot}/nagini/env/bin/nagini", + "program": "${workspaceFolder}/src/nagini_translation/main.py", "args": [ + "--base-dir=${fileDirname}/..", "${file}" ], - "cwd": "${workspaceRoot}", - "env": {}, - "envFile": "${workspaceRoot}/.env", + "cwd": "${workspaceFolder}", "console": "integratedTerminal", - "debugOptions": [ - "WaitOnAbnormalExit", - "WaitOnNormalExit", - "RedirectOutput" - ] + "justMyCode": false, + "python": "${command:python.interpreterPath}", + "env": { + "PYTHONPATH": "${workspaceFolder}/src" + } }, { "name": "Translate Python to Viper", diff --git a/setup.py b/setup.py index 78ddf6c20..22f10d487 100644 --- a/setup.py +++ b/setup.py @@ -29,8 +29,9 @@ 'toposort==1.5', 'jpype1==1.5.0', 'astunparse==1.6.2', - 'pytest==7.0.0', - 'z3-solver==4.8.7.0' + 'pytest', + 'z3-solver==4.8.7.0', + 'setuptools==68.2.0' ], entry_points={ 'console_scripts': [ diff --git a/src/nagini_contracts/contracts.py b/src/nagini_contracts/contracts.py index 6ff201183..553f769a6 100644 --- a/src/nagini_contracts/contracts.py +++ b/src/nagini_contracts/contracts.py @@ -30,7 +30,7 @@ CONTRACT_FUNCS = ['Assume', 'Assert', 'Old', 'Result', 'ResultT', 'Implies', 'Forall', 'IOForall', 'Forall2', 'Forall3', 'Forall6', 'Exists', 'Low', 'LowVal', 'LowEvent', 'Declassify', 'TerminatesSif', 'Acc', 'Rd', 'Wildcard', 'Fold', 'Unfold', 'Unfolding', 'Previous', - 'RaisedException', 'PSeq', 'PSet', 'ToSeq', 'ToMS', 'MaySet', 'MayCreate', + 'RaisedException', 'PSeq', 'PByteSeq', 'PSet', 'ToSeq', 'ToByteSeq', 'ToMS', 'MaySet', 'MayCreate', 'getMethod', 'getArg', 'getOld', 'arg', 'Joinable', 'MayStart', 'Let', 'PMultiset', 'LowExit', 'Refute', 'isNaN', 'Reveal'] @@ -205,6 +205,10 @@ def TerminatesSif(cond: bool, rank: int) -> bool: """ pass +PBool = bool + +PInt = int + class PSeq(Generic[T], Sized, Iterable[T]): """ A PSeq[T] represents a pure sequence of instances of subtypes of T, and @@ -265,6 +269,66 @@ def __iter__(self) -> Iterator[T]: can be used as arguments for Forall. """ +class PByteSeq(Sized, Iterable[int]): + """ + A PByteSeq represents a pure sequence of instances of int, and + is translated to native Viper sequences. + """ + + def __init__(self, *args: int) -> None: + """ + ``PByteSeq(a, b, c)`` creates a PByteSeq instance containing the objects + a, b and c in that order. + """ + + def __contains__(self, item: object) -> bool: + """ + True iff this PByteSeq contains the given object (not taking ``__eq__`` + into account). + """ + + def __getitem__(self, item: int) -> int: + """ + Returns the item at the given position. + """ + + def __len__(self) -> int: + """ + Returns the length of this PByteSeq. + """ + + def __add__(self, other: 'PByteSeq') -> 'PByteSeq': + """ + Concatenates two PByteSeqs to get a new PByteSeq. + """ + + def take(self, until: int) -> 'PByteSeq': + """ + Returns a new PByteSeq containing all elements starting + from the beginning until the given index. ``PByteSeq(3,2,5,6).take(3)`` + is equal to ``PByteSeq(3,2,5)``. + """ + + def drop(self, until: int) -> 'PByteSeq': + """ + Returns a new PByteSeq containing all elements starting + from the given index (i.e., drops all elements until that index). + ``PByteSeq(2,3,5,6).drop(2)`` is equal to ``PByteSeq(5,6)``. + """ + + def update(self, index: int, new_val: int) -> 'PByteSeq': + """ + Returns a new PByteSeq, containing the same elements + except for the element at index ``index``, which is replaced by + ``new_val``. + """ + + def __iter__(self) -> Iterator[int]: + """ + PByteSeqs can be quantified over; this is only here so thatPByteSeqs + can be used as arguments for Forall. + """ + def Previous(it: T) -> PSeq[T]: """ Within the body of a loop 'for x in xs', Previous(x) represents the list of @@ -356,6 +420,12 @@ def ToSeq(l: Iterable[T]) -> PSeq[T]: Converts the given iterable of a built-in type (list, set, dict, range) to a pure PSeq. """ + +def ToByteSeq(l: Iterable[int]) -> PByteSeq: + """ + Converts the given iterable of a compatible built-in type (bytearray) to + a pure PByteSeq. + """ def ToMS(s: PSeq[T]) -> PMultiset[T]: @@ -542,6 +612,13 @@ def dict_pred(d: object) -> bool: be folded or unfolded. """ +def bytearray_pred(d: object) -> bool: + """ + Special, predefined predicate that represents the permissions belonging + to a bytearray. To be used like normal predicates, except it does not need to + be folded or unfolded. + """ + def isNaN(f: float) -> bool: pass @@ -595,10 +672,15 @@ def isNaN(f: float) -> bool: 'list_pred', 'dict_pred', 'set_pred', + 'bytearray_pred', + 'PBool', + 'PInt', 'PSeq', + 'PByteSeq', 'PSet', 'PMultiset', 'ToSeq', + 'ToByteSeq', 'ToMS', 'MaySet', 'MayCreate', diff --git a/src/nagini_translation/analyzer.py b/src/nagini_translation/analyzer.py index 570a807ed..06b7c5563 100644 --- a/src/nagini_translation/analyzer.py +++ b/src/nagini_translation/analyzer.py @@ -10,7 +10,6 @@ import os import nagini_contracts.io_builtins import nagini_contracts.lock -import tokenize from collections import OrderedDict from nagini_contracts.contracts import CONTRACT_FUNCS, CONTRACT_WRAPPER_FUNCS @@ -27,6 +26,7 @@ EXTENDABLE_BUILTINS, IGNORED_IMPORTS, INT_TYPE, + PRIMITIVE_INT_TYPE, LEGAL_MAGIC_METHODS, LITERALS, MYPY_SUPERCLASSES, @@ -64,6 +64,7 @@ InvalidProgramException, is_io_existential, isStr, + read_source_file, UnsupportedException, ) from nagini_translation.lib.views import PythonModuleView @@ -103,6 +104,7 @@ def __init__(self, types: TypeInfo, path: str, selected: Set[str]): self.deferred_tasks = [] self.has_all_low = False self.enable_obligations = False + self.comment_pattern = "#@nagini" def initialize_io_analyzer(self) -> None: self.io_operation_analyzer = IOOperationAnalyzer( @@ -154,8 +156,8 @@ def collect_imports(self, abs_path: str) -> None: # This is a module that corresponds to a directory, so it has no # contents of its own. return - with tokenize.open(abs_path) as file: - text = file.read() + + text = read_source_file(abs_path) parse_result = ast.parse(text) try: mark_text_ranges(parse_result, text) @@ -581,18 +583,91 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: cls.superclass = self.find_or_create_class(OBJECT_TYPE) if cls.python_class not in cls.superclass.python_class.direct_subclasses: cls.superclass.python_class.direct_subclasses.append(cls.python_class) - + if cls.superclass.python_class.enum: + raise InvalidProgramException(node, 'Cannot extend enumeration') + if cls.superclass.name == "IntEnum": + cls.enum = True + cls.enum_type = INT_TYPE + if self.is_dataclass(node): + cls.dataclass = True + if self.is_frozen_dataclass(node): + cls.frozen = True for kw in node.keywords: if kw.arg == 'metaclass' and isinstance(kw.value, ast.Name) and kw.value.id == 'ABCMeta': continue if kw.arg == 'metaclass': raise UnsupportedException(kw, "Unsupported metaclass") raise UnsupportedException(kw, "Unsupported keyword argument") - - for member in node.body: + + for member in node.body.copy(): self.visit(member, node) + if cls.dataclass and "__init__" not in cls.methods.keys(): + self._add_dataclass_init_method(node) + self.current_class = None + def _add_dataclass_init_method(self, node: ast.ClassDef) -> None: + """Adds the implicit __init__ method for dataclasses""" + assert self.current_class != None + + args: list[ast.arg] = [] + defaults: list[ast.expr] = [] + postconditions: list[ast.stmt] = [] + + # Parse fields, add implicit args and post conditions + args.append(self._create_arg_ast(node, 'self', None)) + for name, field in self.current_class.fields.items(): + args.append(self._create_arg_ast(node, name, field.type.name)) + self_attr = ast.Attribute(self._create_name_ast('self', node), name, ast.Load(), lineno=node.lineno, col_offset=0) + if not self.current_class.frozen: + postconditions.append(self._create_acc_postcondition(node, self_attr)) + postconditions.append(self._create_comp_postcondition(node, + ast.Attribute(self._create_name_ast('self', node), name, ast.Load(), lineno=node.lineno, col_offset=0), + self._create_name_ast(name, node), ast.Is())) + if getattr(field, 'result', None) is not None: + defaults.append(field.result) + field.result = None + + ast_arguments = ast.arguments([], args, None, [], [], None, defaults) + + # Could add implicit field assignments for non-frozen dataclass + + # Add decorators + decorator_list: list[ast.expr] = [self._create_name_ast('ContractOnly', node)] + + stmts = postconditions + function_def = ast.FunctionDef('__init__', ast_arguments, stmts, decorator_list, returns=None, lineno=node.lineno, col_offset=0) + self.visit(function_def, node) + + # Propagate default_factory info to the method args + method = self.current_class.methods['__init__'] + for name, field in self.current_class.fields.items(): + if getattr(field, 'default_factory', None): + method.args[name].default_factory = field.default_factory + + node.body.append(function_def) + self.current_class.implicit_init = True + return + + def _create_arg_ast(self, node, arg: str, type_name: Optional[str] = None) -> ast.arg: + name_node = None + if type_name != None: + name_node = self._create_name_ast(type_name, node) + return ast.arg(arg, name_node, lineno=node.lineno, col_offset=0) + + def _create_comp_postcondition(self, node, left: ast.expr, right: ast.expr, op: ast.cmpop) -> ast.stmt: + compare = ast.Compare(left, ops=[op], comparators=[right], + lineno=node.lineno, col_offset=0) + return ast.Expr(ast.Call(self._create_name_ast('Ensures', node), [compare], [], lineno=node.lineno, col_offset=0)) + + def _create_acc_postcondition(self, node, attr: ast.expr) -> ast.stmt: + acc_call = ast.Call(self._create_name_ast('Acc', node), [attr], [], + lineno=node.lineno, col_offset=0) + return ast.Expr(ast.Call(self._create_name_ast('Ensures', node), [acc_call], [], lineno=node.lineno, col_offset=0)) + + def _create_name_ast(self, id: str, node) -> ast.Name: + return ast.Name(id, ast.Load(), lineno=node.lineno, col_offset=0) + def _is_illegal_magic_method_name(self, name: str) -> bool: """ Anything that could potentially be a magic method, i.e. anything that @@ -889,6 +964,14 @@ def visit_Lambda(self, node: ast.Lambda) -> None: def visit_arg(self, node: ast.arg) -> None: assert self.current_function is not None node_type = self.typeof(node) + if isinstance(node.annotation, ast.Name) and node.annotation.id in ('PInt', 'PBool'): + if node.annotation.id == 'PInt': + assert node_type.name == 'int' + node_type = node_type.module.classes['__prim__int'] + elif node.annotation.id == 'PBool': + assert node_type.name == 'bool' + node_type = node_type.module.classes['__prim__bool'] + self.current_function.args[node.arg] = \ self.node_factory.create_python_var(node.arg, node, node_type) # If we just introduced new type variables, create the expression that @@ -1132,30 +1215,108 @@ def todo(): self.track_access(node, var) self.deferred_tasks.append(todo) return + elif self.current_class.dataclass: + # Node is a field of a dataclass + if isinstance(node.ctx, ast.Load): + return + + assign = node._parent + if isinstance(assign, ast.Assign): + if not len(assign.targets) == 1: + raise UnsupportedException(assign, + 'only simple assignments allowed for dataclass fields') + if (isinstance(assign.value, ast.Call) + and isinstance(assign.value.func, ast.Name) + and assign.value.func.id == 'field'): + raise UnsupportedException(assign, + 'field() requires a type annotation') + # Infer type from value + annotation = self._create_name_ast(self.typeof(node).name, node) + elif isinstance(assign, ast.AnnAssign) and assign.simple == 1: + annotation = assign.annotation + else: + msg = ('only simple assignments and reads allowed for ' + 'dataclass fields') + raise UnsupportedException(assign, msg) + + # Add type info for self in this context, can retrieve the correct type from __init__.self + prefix = self.module.type_prefix.split('.') if self.module.type_prefix else [] + prefix.extend([self.current_class.name, node.id, 'self']) + context = tuple(prefix) + self_type, _ = self.module.get_type([self.current_class.name, '__init__'], 'self') + self.module.types.all_types[context] = self_type + + # Create a property for frozen fields, or a normal field for non-frozen + if self.current_class.frozen: + ast_arguments = ast.arguments([], [self._create_arg_ast(node, 'self', None)], None, [], [], None, []) + stmts = [ast.Expr(ast.Call(self._create_name_ast('Decreases', node), [ast.Constant(None)], []))] + decorator_list: list[ast.expr] = [ast.Name('property'), ast.Name('ContractOnly')] + function_def = ast.FunctionDef(node.id, ast_arguments, stmts, decorator_list, returns=annotation, lineno=node.lineno, col_offset=0) + self.visit(function_def, self.current_class.node) + else: + self.current_class.add_field(node.id, node, self.typeof(node)) + + # Adjust the class body + self.current_class.node.body.remove(assign) + if self.current_class.frozen: + self.current_class.node.body.append(function_def) + + if assign.value != None: + field_obj = self.current_class.fields[node.id] + if (isinstance(assign.value, ast.Call) + and isinstance(assign.value.func, ast.Name) + and assign.value.func.id == 'field'): + # Handle dataclasses.field(default_factory=...) + factory = None + for kw in assign.value.keywords: + if kw.arg == 'default_factory': + factory = kw.value + else: + raise UnsupportedException(assign, 'unsupported keyword') + if factory is None: + raise UnsupportedException(assign, + 'field() without default_factory not supported') + # Use None as sentinel default + field_obj.result = ast.Constant(None, + lineno=node.lineno, col_offset=0) + field_obj.default_factory = factory.id + elif not isinstance(assign.value, (ast.Constant, ast.Attribute)): + raise UnsupportedException(assign, 'Illegal default value for datafield creation') + else: + # Temporarily set value, because it will be used as default + field_obj.result = assign.value + return + elif self.current_class.superclass.name == "IntEnum": + # Node is an enum member. Basically a static field that returns an instance of the enum instead + if isinstance(node.ctx, ast.Load): + return + + node_type = self.typeof(node) + if node_type.name != INT_TYPE: + raise InvalidProgramException(node, 'invalid literal for int() with base 10') + + assign = node._parent + if (not isinstance(assign, ast.Assign) + or len(assign.targets) != 1): + msg = ('only simple assignments and reads allowed for ' + 'enum members') + raise UnsupportedException(assign, msg) + self.create_static_field(node, self.current_class, assign.value) + return else: # Node is a static field. if isinstance(node.ctx, ast.Load): return - cls = self.typeof(node) - self.define_new(self.current_class, node.id, node) - var = self.node_factory.create_static_field(node.id, node, cls, - self.module, - self.current_class) + assign = node._parent if (not isinstance(assign, ast.Assign) or len(assign.targets) != 1): msg = ('only simple assignments and reads allowed for ' - 'static fields') + 'static fields') raise UnsupportedException(assign, msg) - var.value = assign.value - self.current_class.static_fields[node.id] = var - if node.id in self.current_class.fields: - # It's possible that we encountered a read of this field - # before seeing the definition, assumed it's a normal - # (non-static) field, and created the field. We remove it - # again now that we now it's actually static. - del self.current_class.fields[node.id] - self.track_access(node, var) + + cls = self.typeof(node) + self.create_static_field(node, cls, assign.value) return # We're in a function if isinstance(node.ctx, ast.Store): @@ -1207,6 +1368,23 @@ def todo(): self.track_access(node, var) + def create_static_field(self, node: ast.Name, type_: PythonType, val: ast.expr) -> None: + assert self.current_class != None + + self.define_new(self.current_class, node.id, node) + var = self.node_factory.create_static_field(node.id, node, type_, + self.module, + self.current_class) + var.value = val + self.current_class.static_fields[node.id] = var + if node.id in self.current_class.fields: + # It's possible that we encountered a read of this field + # before seeing the definition, assumed it's a normal + # (non-static) field, and created the field. We remove it + # again now that we now it's actually static. + del self.current_class.fields[node.id] + self.track_access(node, var) + def visit_Attribute(self, node: ast.Attribute) -> None: """ Tracks field accesses to find out which fields exist. @@ -1255,6 +1433,8 @@ def convert_type(self, mypy_type, node, bound_type_vars: Dict[str, PythonType] = """ Converts an internal mypy type to a PythonType. """ + if (self.types.is_literal_type(mypy_type)): + mypy_type = mypy_type.fallback if (self.types.is_void_type(mypy_type) or self.types.is_none_type(mypy_type)): result = None @@ -1313,7 +1493,18 @@ def convert_type(self, mypy_type, node, bound_type_vars: Dict[str, PythonType] = msg = f'Type could not be fully inferred (this usually means that a type argument is unknown)' raise InvalidProgramException(node, 'partial.type', message=msg) else: - msg = 'Unsupported type: {}'.format(mypy_type.__class__.__name__) + name = "" + if hasattr(node, 'id'): + name = node.id + elif hasattr(node, 'name'): + name = node.name + elif isinstance(node, ast.Attribute): + name = node.attr + if hasattr(node.value, 'id'): + name = node.value.id + "." + name + elif isinstance(node, ast.arg): + name = node.arg + msg = 'Unsupported type: {} for node {} of type {}'.format(mypy_type.__class__.__name__, name, type(node)) raise UnsupportedException(node, desc=msg) return result @@ -1577,13 +1768,18 @@ def visit_Try(self, node: ast.Try) -> None: self.stmt_container.labels.append(finally_name) self.visit_default(node) - def _incompatible_decorators(self, decorators) -> bool: + def _class_incompatible_decorators(self, decorators: set[str]) -> bool: + return ((('dataclass' in decorators) and (len(decorators) != 1)) or + (('dataclass' not in decorators) and (len(decorators) > 0)) + ) + + def _function_incompatible_decorators(self, decorators) -> bool: return ((('Predicate' in decorators) and ('Pure' in decorators)) or (('Opaque' in decorators) and ('Pure' not in decorators)) or (('Predicate' in decorators) and ('Inline' in decorators)) or (('Inline' in decorators) and ('Pure' in decorators)) or (('IOOperation' in decorators) and (len(decorators) != 1)) or - (('property' in decorators) and (len(decorators) != 1)) or + (('property' in decorators) and not(len(decorators) == 1 or (len(decorators) == 2 and 'ContractOnly' in decorators))) or (('AllLow' in decorators) and ('PreservesLow' in decorators)) or ((('AllLow' in decorators) or ('PreservesLow' in decorators)) and ( ('Predicate' in decorators) or ('Pure' in decorators))) @@ -1595,7 +1791,7 @@ def is_declared_contract_only(self, func: ast.FunctionDef) -> bool: respective decorator. """ decorators = {d.id for d in func.decorator_list if isinstance(d, ast.Name)} - if self._incompatible_decorators(decorators): + if self._function_incompatible_decorators(decorators): raise InvalidProgramException(func, "decorators.incompatible") result = 'ContractOnly' in decorators or 'abstractmethod' in decorators return result @@ -1622,35 +1818,79 @@ def is_contract_only(self, func: ast.FunctionDef) -> bool: result = result or (not selected) return result - def has_decorator(self, func: ast.FunctionDef, decorator: str) -> bool: + def __resolve_decorator(self, decorator: ast.expr) -> Tuple[bool, str]: + if isinstance(decorator, ast.Name): + return (True, decorator.id) + elif isinstance(decorator, ast.Call): + return self.__resolve_decorator(decorator.func) + return (False, "") + + def __get_decorators(self, decorator_list: list[ast.expr]) -> set[str]: + return {res[1] for d in decorator_list if (res := self.__resolve_decorator(d))[0]} + + def __decorator_has_keyword_value(self, decorator_list: list[ast.expr], decorator: str, keyword: str, value) -> bool: + for d in decorator_list: + if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == decorator: + for k in d.keywords: + if k.arg == keyword and isinstance(k.value, ast.Constant): + return k.value.value == value + return False + + def class_has_decorator(self, cls: ast.ClassDef, decorator: str) -> bool: + decorators = self.__get_decorators(cls.decorator_list) + if self._class_incompatible_decorators(decorators): + raise InvalidProgramException(cls, "decorators.incompatible") + return decorator in decorators + + def is_dataclass(self, cls: ast.ClassDef) -> bool: + is_dataclass = self.class_has_decorator(cls, 'dataclass') + if is_dataclass: + self._dataclass_check_unsupported_keywords(cls) + return is_dataclass + + def _dataclass_check_unsupported_keywords(self, cls: ast.ClassDef) -> None: + decorator = [d for d in cls.decorator_list if self.__resolve_decorator(d)[1] == 'dataclass'][0] + if isinstance(decorator, ast.Name): + return + assert isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name) + supported_keywords = ["frozen"] + + for k in decorator.keywords: + if not k.arg in supported_keywords: + raise UnsupportedException(decorator, "keyword unsupported") + + def is_frozen_dataclass(self, cls: ast.ClassDef) -> bool: + return self.__decorator_has_keyword_value(cls.decorator_list, 'dataclass', 'frozen', True) + + def function_has_decorator(self, func: ast.FunctionDef, decorator: str) -> bool: decorators = {d.id for d in func.decorator_list if isinstance(d, ast.Name)} - if self._incompatible_decorators(decorators): + if self._function_incompatible_decorators(decorators): raise InvalidProgramException(func, "decorators.incompatible") return decorator in decorators def is_pure(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'Pure') + return self.function_has_decorator(func, 'Pure') def is_opaque(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'Opaque') + return self.function_has_decorator(func, 'Opaque') def is_predicate(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'Predicate') + return self.function_has_decorator(func, 'Predicate') def is_inline_method(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'Inline') + return self.function_has_decorator(func, 'Inline') def is_static_method(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'staticmethod') + return self.function_has_decorator(func, 'staticmethod') def is_class_method(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'classmethod') - + return self.function_has_decorator(func, 'classmethod') + def is_io_operation(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'IOOperation') + return self.function_has_decorator(func, 'IOOperation') def is_property_getter(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'property') + return self.function_has_decorator(func, 'property') def is_property_setter(self, func: ast.FunctionDef) -> bool: setter_decorator = [d for d in func.decorator_list @@ -1662,7 +1902,7 @@ def is_property_setter(self, func: ast.FunctionDef) -> bool: return self.current_class.fields[setter_decorator[0].value.id] def is_all_low(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'AllLow') + return self.function_has_decorator(func, 'AllLow') def preserves_low(self, func: ast.FunctionDef) -> bool: - return self.has_decorator(func, 'PreservesLow') + return self.function_has_decorator(func, 'PreservesLow') diff --git a/src/nagini_translation/lib/config.py b/src/nagini_translation/lib/config.py index c9a4a8677..fd8ae273f 100644 --- a/src/nagini_translation/lib/config.py +++ b/src/nagini_translation/lib/config.py @@ -324,6 +324,16 @@ def set_verifier(v: str): Test configuration. """ +enable_preprocessing = False +""" +Enable Preprocessing of files. +""" + +comment_pattern = "" +""" +Comment pattern to preprocess. +""" + __all__ = ( 'classpath', @@ -333,4 +343,6 @@ def set_verifier(v: str): 'mypy_dir', 'obligation_config', 'set_verifier', + 'enable_preprocessing', + 'comment_pattern' ) diff --git a/src/nagini_translation/lib/constants.py b/src/nagini_translation/lib/constants.py index 6110f8e81..9df4d076e 100644 --- a/src/nagini_translation/lib/constants.py +++ b/src/nagini_translation/lib/constants.py @@ -25,18 +25,20 @@ 'range', 'type', 'list', - 'enumerate'] + 'enumerate', + 'bytearray'] EXTENDABLE_BUILTINS = [ 'object', 'Exception', 'Lock', - 'int' + 'int', + 'IntEnum' ] THREADING = ['Thread'] -BUILTIN_PREDICATES = ['list_pred', 'set_pred', 'dict_pred', 'MayStart', 'ThreadPost'] +BUILTIN_PREDICATES = ['list_pred', 'set_pred', 'dict_pred', 'bytearray_pred', 'MayStart', 'ThreadPost'] FUNCTION_DOMAIN_NAME = 'Function' @@ -258,11 +260,15 @@ '__ror__', '__init__', + '__post_init__', '__enter__', '__exit__', '__str__', + '__repr__', '__len__', '__bool__', + '__format__', + '__hash__', '__getitem__', '__setitem__', @@ -294,6 +300,8 @@ PSEQ_TYPE = 'PSeq' +PBYTESEQ_TYPE = 'PByteSeq' + PSET_TYPE = 'PSet' PMSET_TYPE = 'PMultiset' @@ -310,6 +318,8 @@ BYTES_TYPE = 'bytes' +BYTEARRAY_TYPE = 'bytearray' + INT_TYPE = 'int' PERM_TYPE = 'perm' @@ -361,6 +371,8 @@ IGNORED_IMPORTS = {'_importlib_modulespec', 'abc', 'builtins', + 'dataclasses', + 'enum', 'nagini_contracts', 'nagini_contracts.adt', 'nagini_contracts.contracts', @@ -377,6 +389,8 @@ '_importlib_modulespec': [], 'abc': [], 'builtins': [], + 'dataclasses': [], + 'enum': [], 'nagini_contracts': [], 'nagini_contracts.contracts': [], 'nagini_contracts.io_contracts': [], diff --git a/src/nagini_translation/lib/program_nodes.py b/src/nagini_translation/lib/program_nodes.py index a503ebf1b..cd238c439 100644 --- a/src/nagini_translation/lib/program_nodes.py +++ b/src/nagini_translation/lib/program_nodes.py @@ -26,6 +26,7 @@ PRIMITIVE_SET_TYPE, PRIMITIVES, PSEQ_TYPE, + PBYTESEQ_TYPE, PSET_TYPE, RESULT_NAME, STRING_TYPE, @@ -410,6 +411,11 @@ def __init__(self, name: str, superscope: PythonScope, self.static_fields = OrderedDict() self.type = None # infer, domain type self.interface = interface + self.dataclass = False + self.implicit_init = False + self.frozen = False + self.enum = False + self.enum_type = None self.defined = False self._has_classmethod = False self.type_vars = OrderedDict() @@ -750,6 +756,8 @@ def try_box(self) -> 'PythonClass': boxed_name = PMSET_TYPE if boxed_name == 'Seq': boxed_name = PSEQ_TYPE + if boxed_name == 'PByteSeq': + boxed_name = PBYTESEQ_TYPE return self.module.classes[boxed_name] return self @@ -1593,6 +1601,7 @@ def __init__(self, name: str, node: ast.AST, type: PythonClass): self.alt_types = {} self.default = None self.default_expr = None + self.default_factory = None self.show_in_ce = True def process(self, sil_name: str, translator: 'Translator') -> None: diff --git a/src/nagini_translation/lib/resolver.py b/src/nagini_translation/lib/resolver.py index 48961e23d..aac560979 100644 --- a/src/nagini_translation/lib/resolver.py +++ b/src/nagini_translation/lib/resolver.py @@ -12,6 +12,7 @@ BOOL_TYPE, BUILTINS, BYTES_TYPE, + BYTEARRAY_TYPE, DICT_TYPE, ELLIPSIS_TYPE, FLOAT_TYPE, @@ -22,6 +23,7 @@ RIGHT_OPERATOR_FUNCTIONS, PMSET_TYPE, PSEQ_TYPE, + PBYTESEQ_TYPE, PSET_TYPE, RANGE_TYPE, SET_TYPE, @@ -297,6 +299,8 @@ def _do_get_type(node: ast.AST, containers: List[ContainerInterface], return module.global_module.classes[ELLIPSIS_TYPE] else: raise UnsupportedException(node, f"Unsupported constant value type {type(node.value)}") + if isinstance(node, (ast.JoinedStr, ast.FormattedValue)): + return module.global_module.classes[STRING_TYPE] if isNum(node): if isinstance(node.value, int): return module.global_module.classes[INT_TYPE] @@ -447,6 +451,9 @@ def _get_call_type(node: ast.Call, module: PythonModule, return module.global_module.classes[INT_TYPE] if func_name in ('token', 'ctoken', 'MustTerminate', 'MustRelease'): return module.global_module.classes[BOOL_TYPE] + # if func_name == BYTEARRAY_TYPE: + # return _get_collection_literal_type(node, ['args'], BYTEARRAY_TYPE, module, + # containers, container) if func_name == PSEQ_TYPE: return _get_collection_literal_type(node, ['args'], PSEQ_TYPE, module, containers, container) @@ -494,6 +501,8 @@ def _get_call_type(node: ast.Call, module: PythonModule, seq_class = module.global_module.classes[PSEQ_TYPE] content_type = _get_iteration_type(arg_type, module, node) return GenericType(seq_class, [content_type]) + elif node.func.id == 'ToByteSeq': + return module.global_module.classes[PBYTESEQ_TYPE] elif node.func.id == 'ToMS': arg_type = get_type(node.args[0], containers, container) ms_class = module.global_module.classes[PMSET_TYPE] @@ -601,7 +610,7 @@ def _get_subscript_type(value_type: PythonType, module: PythonModule, # FIXME: This is very unfortunate, but right now we cannot handle this # generically, so we have to hard code these two cases for the moment. return value_type.type_args[1] - elif value_type.name in (RANGE_TYPE, BYTES_TYPE): + elif value_type.name in (RANGE_TYPE, BYTES_TYPE, BYTEARRAY_TYPE, PBYTESEQ_TYPE): return module.global_module.classes[INT_TYPE] elif value_type.name == PSEQ_TYPE: return value_type.type_args[0] diff --git a/src/nagini_translation/lib/silver_nodes/types.py b/src/nagini_translation/lib/silver_nodes/types.py index 025863975..92c69d2cc 100644 --- a/src/nagini_translation/lib/silver_nodes/types.py +++ b/src/nagini_translation/lib/silver_nodes/types.py @@ -164,3 +164,19 @@ def translate(self, translator: 'AbstractTranslator', ctx: 'Context', elements = [element.translate(translator, ctx, position, info) for element in self._elements] return translator.viper.ExplicitSeq(elements, position, info) + +class PByteSeq: + """A helper class for generating Silver sequences.""" + + def __init__(self, elements: List['Expression']) -> None: + self._elements = elements + + def translate(self, translator: 'AbstractTranslator', ctx: 'Context', + position: Position, info: Info) -> Expr: + """Translate to Silver sequence.""" + if not self._elements: + return translator.viper.EmptySeq(self.viper.Int, position, info) + else: + elements = [element.translate(translator, ctx, position, info) + for element in self._elements] + return translator.viper.ExplicitSeq(elements, position, info) diff --git a/src/nagini_translation/lib/typeinfo.py b/src/nagini_translation/lib/typeinfo.py index d3c853836..dc5acdbe3 100644 --- a/src/nagini_translation/lib/typeinfo.py +++ b/src/nagini_translation/lib/typeinfo.py @@ -12,6 +12,7 @@ import os from mypy.build import BuildSource +from mypy.fscache import FileSystemCache from nagini_translation.lib import config from nagini_translation.lib.constants import IGNORED_IMPORTS, LITERALS from nagini_translation.mypy_patches.visitor import TraverserVisitor @@ -19,6 +20,7 @@ from nagini_translation.lib.util import ( construct_lambda_prefix, + read_source_file, ) from typing import List, Optional @@ -267,6 +269,37 @@ def visit_comparison_expr(self, o: mypy.nodes.ComparisonExpr): if 'is' not in o.operators and 'is not' not in o.operators: super().visit_comparison_expr(o) +class PreprocessingFileSystemCache(FileSystemCache): + """ + Slightly adjusted FileSystemCache that invokes the custom read_source_file + to read file data. + """ + def read(self, path: str) -> bytes: + if path in self.read_cache: + return self.read_cache[path] + if path in self.read_error_cache: + raise self.read_error_cache[path] + + # Need to stat first so that the contents of file are from no + # earlier instant than the mtime reported by self.stat(). + self.stat(path) + + dirname, basename = os.path.split(path) + dirname = os.path.normpath(dirname) + # Check the fake cache. + if basename == '__init__.py' and dirname in self.fake_package_cache: + data = b'' + else: + try: + text = read_source_file(path) + data = text.encode() + except OSError as err: + self.read_error_cache[path] = err + raise + + self.read_cache[path] = data + self.hash_cache[path] = mypy.util.hash_digest(data) + return data class TypeInfo: """ @@ -303,10 +336,10 @@ def _create_options(self, strict_optional: bool): result.cache_dir = '.mypy_cache_strict' if strict_optional else '.mypy_cache_nonstrict' return result - def check(self, filename: str, base_dir: str = None) -> bool: + def check(self, filename: str, base_dir: str = None, text: Optional[str] = None) -> bool: """ Typechecks the given file and collects all type information needed for - the translation to Viper + the translation to Viper. Optionally pass preprocessed text content. """ def report_errors(errors: List[str]) -> None: @@ -365,18 +398,17 @@ def my_find_cache_meta(id, path, mgr): # In Python 3.9 or newer, we use the incremental mode, and we have to monkey-patch mypy. mypy.build.find_cache_meta = my_find_cache_meta - sources = [BuildSource(filename, module_name, None, base_dir=base_dir)] + sources = [BuildSource(filename, module_name, base_dir=base_dir)] + # fscache = PreprocessingFileSystemCache() + # Ignoring preprocessing for now res_strict = mypy.build.build(sources, options_strict) if res_strict.errors: # Run mypy a second time with strict optional checking disabled, # s.t. we don't get overapproximated none-related errors. options_non_strict = self._create_options(False) - res_non_strict = mypy.build.build( - [BuildSource(filename, module_name, None, base_dir=base_dir)], - options_non_strict - ) + res_non_strict = mypy.build.build(sources, options_non_strict) if res_non_strict.errors: report_errors(res_non_strict.errors) relevant_files = [next(iter(res_strict.graph))] @@ -463,6 +495,9 @@ def is_normal_type(self, type: mypy.types.Type) -> bool: def is_instance_type(self, type: mypy.types.Type) -> bool: return isinstance(type, mypy.types.Instance) + def is_literal_type(self, type: mypy.types.Type) -> bool: + return isinstance(type, mypy.types.LiteralType) + def is_tuple_type(self, type: mypy.types.Type) -> bool: return isinstance(type, mypy.types.TupleType) diff --git a/src/nagini_translation/lib/util.py b/src/nagini_translation/lib/util.py index f11f2fb45..1790eba22 100644 --- a/src/nagini_translation/lib/util.py +++ b/src/nagini_translation/lib/util.py @@ -7,6 +7,9 @@ import ast import astunparse +import re +import tokenize +from nagini_translation.lib import config from typing import ( Any, @@ -23,6 +26,32 @@ V = TypeVar('V') +def preprocess_text(text: str, comment_prefix: str) -> str: + """ + Preprocesses the file text by transforming special comments into code. + Comments starting with the specified prefix will be converted to regular code. + """ + # Pattern: (whitespace)(comment_prefix)(optional space)(rest of line) + escaped_prefix = re.escape(comment_prefix) + pattern = re.compile(r'^(\s*)' + escaped_prefix + r' ?(.*)') + + def process_line(line: str) -> str: + match = pattern.match(line) + if match: + # Return indentation + code + return match.group(1) + match.group(2) + return line + + return '\n'.join(process_line(line) for line in text.split('\n')) + +def read_source_file(path: str) -> str: + with tokenize.open(path) as file: + text = file.read() + + if config.enable_preprocessing: + text = preprocess_text(text, config.comment_pattern) + return text + def flatten(lists: List[List[T]]) -> List[T]: """ Flattens a list of lists into a flat list diff --git a/src/nagini_translation/main.py b/src/nagini_translation/main.py index 9bad89d54..b3b11508d 100755 --- a/src/nagini_translation/main.py +++ b/src/nagini_translation/main.py @@ -12,6 +12,7 @@ import logging import os import sys +import threading import re import time import traceback @@ -51,7 +52,7 @@ ViperVerifier ) from nagini_translation import verifier -from typing import List, Set, Tuple +from typing import List, Set, Tuple, Union TYPE_ERROR_PATTERN = r"^(?P.*):(?P\d+): error: (?P.*)$" @@ -65,8 +66,8 @@ def parse_sil_file(sil_path: str, bv_path: str, bv_size: int, jvm, float_option: with open(sil_path, 'r') as file: text = file.read() with open(bv_path, 'r') as file: - int_min = -(2 ** (bv_size - 1)) - int_max = 2 ** (bv_size - 1) - 1 + int_min = -(2 ** (bv_size)) + int_max = 2 ** (bv_size) - 1 text += "\n" + file.read().replace("NBITS", str(bv_size)).replace("INT_MIN_VAL", str(int_min)).replace("INT_MAX_VAL", str(int_max)) if float_option == "real": text = text.replace("float.sil", "float_real.sil") @@ -120,12 +121,14 @@ def translate(path: str, jvm: JVM, bv_size: int, selected: Set[str] = set(), bas raise Exception('Viper not found on classpath.') if sif and not viper_ast.is_extension_available(): raise Exception('Viper AST SIF extension not found on classpath.') + types = TypeInfo() type_correct = types.check(path, base_dir) if not type_correct: return None analyzer = Analyzer(types, path, selected) + analyzer.comment_pattern = config.comment_pattern main_module = analyzer.module with open(os.path.join(builtins_index_path, 'builtins.json'), 'r') as file: analyzer.add_native_silver_builtins(json.loads(file.read())) @@ -193,6 +196,11 @@ def collect_modules(analyzer: Analyzer, path: str) -> None: for task in analyzer.deferred_tasks: task() +def get_verifier(path: str, jvm: JVM, viper_args: List[str], backend=ViperVerifier.silicon, counterexample=False) -> Union[Silicon, Carbon]: + if backend == ViperVerifier.silicon: + return Silicon(jvm, path, viper_args, counterexample) + elif backend == ViperVerifier.carbon: + return Carbon(jvm, path, viper_args) def verify(modules, prog: 'viper.silver.ast.Program', path: str, jvm: JVM, viper_args: List[str], backend=ViperVerifier.silicon, arp=False, counterexample=False, sif=False) -> VerificationResult: @@ -200,10 +208,7 @@ def verify(modules, prog: 'viper.silver.ast.Program', path: str, jvm: JVM, viper Verifies the given Viper program """ try: - if backend == ViperVerifier.silicon: - verifier = Silicon(jvm, path, viper_args, counterexample) - elif backend == ViperVerifier.carbon: - verifier = Carbon(jvm, path, viper_args) + verifier = get_verifier(path, jvm, viper_args, backend, counterexample) vresult = verifier.verify(modules, prog, arp=arp, sif=sif) return vresult except JException as je: @@ -295,6 +300,11 @@ def main() -> None: help=('run verification the given number of times to benchmark ' 'performance'), default=-1) + parser.add_argument( + '--benchmark-timeout', + type=int, + help='timeout in seconds for each benchmark run', + default=-1) parser.add_argument( '--ide-mode', action='store_true', @@ -350,6 +360,17 @@ def main() -> None: type=int, default=8 ) + parser.add_argument( + '--comment-pattern', + help='Preprocess comments with pattern', + type=str, + default="#@nagini" + ) + parser.add_argument( + '--preprocess', + action='store_true', + help='Enable preprocessing', + ) args = parser.parse_args() config.classpath = args.viper_jar_path @@ -357,6 +378,8 @@ def main() -> None: config.z3_path = args.z3 config.mypy_path = args.mypy_path config.set_verifier(args.verifier) + config.enable_preprocessing = args.preprocess + config.comment_pattern = args.comment_pattern if args.ignore_obligations: if args.force_obligations: parser.error('incompatible arguments: --ignore-obligations and --force-obligations') @@ -431,15 +454,54 @@ def translate_and_verify(python_file, jvm, args, print=print, arp=False, base_di raise ValueError('Unknown verifier specified: ' + args.verifier) viper_args = [] if args.viper_arg is None else args.viper_arg.split(",") if args.benchmark >= 1: - print("Run, Total, Start, End, Time".format()) + print("Run, Total, Start, End, Time, Result") + n_success = 0 + n_failure = 0 + n_timeout = 0 for i in range(args.benchmark): start = time.time() - modules, prog = translate(python_file, jvm, args.int_bitops_size, selected=selected, sif=args.sif, arp=arp, base_dir=base_dir, - ignore_global=args.ignore_global, float_encoding=args.float_encoding) - vresult = verify(modules, prog, python_file, jvm, viper_args, backend=backend, arp=arp) + timed_out = False + + verifier_ref = [None] + result_ref = [None] + def _run_iteration(vholder=verifier_ref, rholder=result_ref): + try: + modules_local, prog_local = translate(python_file, jvm, args.int_bitops_size, selected=selected, sif=args.sif, + arp=arp, base_dir=base_dir, ignore_global=args.ignore_global, float_encoding=args.float_encoding) + ver = get_verifier(python_file, jvm, viper_args, backend) + vholder[0] = ver + rholder[0] = ver.verify(modules_local, prog_local, arp=arp) + except Exception: + pass + + thread = threading.Thread(target=_run_iteration, daemon=True) + thread.start() + timeout = args.benchmark_timeout if args.benchmark_timeout > 0 else None + thread.join(timeout=timeout) + + if thread.is_alive(): + timed_out = True + ver = verifier_ref[0] + if ver is not None: + ver.stop() + thread.join(timeout=10) end = time.time() - print("{}, {}, {}, {}, {}".format( - i, args.benchmark, start, end, end - start)) + + if timed_out: + n_timeout += 1 + print("{}, {}, {}, {}, TIMEOUT, TIMEOUT".format( + i, args.benchmark, start, end)) + else: + vresult = result_ref[0] + result_str = "SUCCESS" if isinstance(vresult, verifier.Success) else "FAILURE" + if isinstance(vresult, verifier.Success): + n_success += 1 + else: + n_failure += 1 + print("{}, {}, {}, {}, {}, {}".format( + i, args.benchmark, start, end, end - start, result_str)) + print("Results: {} success, {} failure, {} timeout out of {} runs".format( + n_success, n_failure, n_timeout, args.benchmark)) else: submitter = None if args.submit_for_evaluation: @@ -452,12 +514,13 @@ def translate_and_verify(python_file, jvm, args, print=print, arp=False, base_di if submitter is not None: submitter.setSuccess(vresult.__bool__()) submitter.submit() - if args.verbose: - print("Verification completed.") - print(vresult.to_string(args.ide_mode, args.show_viper_errors)) - duration = '{:.2f}'.format(time.time() - start) - print('Verification took ' + duration + ' seconds.') - return isinstance(vresult, verifier.Success) + if args.verbose: + print("Verification completed.") + print(vresult.to_string(args.ide_mode, args.show_viper_errors)) + duration = '{:.2f}'.format(time.time() - start) + print('Verification took ' + duration + ' seconds.') + return isinstance(vresult, verifier.Success) + return True except (TypeException, InvalidProgramException, UnsupportedException) as e: print("Translation failed") if isinstance(e, (InvalidProgramException, UnsupportedException)): @@ -471,11 +534,13 @@ def translate_and_verify(python_file, jvm, args, print=print, arp=False, base_di issue = 'Not supported: ' if e.args[0]: issue += e.args[0] - else: + elif e.node != None: issue += astunparse.unparse(e.node) - line = str(e.node.lineno) - col = str(e.node.col_offset) - print(issue + ' (' + python_file + '@' + line + '.' + col + ')') + if e.node != None: + line = str(e.node.lineno) + col = str(e.node.col_offset) + print(issue + ' (' + python_file + '@' + line + '.' + col + ')') + traceback.print_exc() if isinstance(e, TypeException): for msg in e.messages: parts = TYPE_ERROR_MATCHER.match(msg) @@ -487,6 +552,7 @@ def translate_and_verify(python_file, jvm, args, print=print, arp=False, base_di msg = parts['msg'] line = parts['line'] print('Type error: ' + msg + ' (' + file + '@' + line + '.0)') + traceback.print_exc() else: print(msg) return False diff --git a/src/nagini_translation/models/converter.py b/src/nagini_translation/models/converter.py index cb5701bb9..0b0219d3e 100644 --- a/src/nagini_translation/models/converter.py +++ b/src/nagini_translation/models/converter.py @@ -20,6 +20,7 @@ UNBOX_INT = 'int___unbox__%limited' UNBOX_BOOL = 'bool___unbox__%limited' UNBOX_PSEQ = 'PSeq___sil_seq__%limited' +UNBOX_PBYTESEQ = 'PByteSeq___val__%limited' TYPEOF = 'typeof' SNAP_TO = '$SortWrappers.' SET_CARD = 'Set_card' @@ -279,6 +280,8 @@ def convert_python_field(self, recv, field, value, heap_contents, target, target receiver_type = global_module.classes['list'] elif field == 'set_acc': receiver_type = global_module.classes['set'] + elif field == 'bytearray_acc': + receiver_type = global_module.classes['bytearray'] elif field == '_val': # This is a global variable. var_sil_name = str(recv.applicable().id()) @@ -513,6 +516,8 @@ def convert_value(self, val, t: PythonType, name: str = None): return self.convert_bool_value(val) elif t.python_class.name == 'PSeq': return self.convert_pseq_value(val, t, name) + elif t.python_class.name == 'PByteSeq': + return self.convert_PByteSeq_value(val, name) elif t.python_class.is_adt: return self.convert_adt_value(val, t) elif isinstance(t, GenericType) and t.python_class.name == 'tuple': @@ -625,6 +630,12 @@ def convert_pseq_value(self, val, t: PythonType, name): sequence_info = self.convert_sequence_value(sequence, t.type_args[0], name) return 'Sequence: {{ {} }}'.format(', '.join(['{} -> {}'.format(k, v) for k, v in sequence_info.items()])) + def convert_PByteSeq_value(self, val, name): + sequence = self.get_func_value(UNBOX_PBYTESEQ, (UNIT, val)) + int_type = self.modules[0].global_module.classes['int'] + sequence_info = self.convert_sequence_value(sequence, int_type, name) + return 'Sequence: {{ {} }}'.format(', '.join(['{} -> {}'.format(k, v) for k, v in sequence_info.items()])) + def convert_int_value(self, val): if self.ref_has_type(val, 'bool'): return self.convert_bool_value(val) diff --git a/src/nagini_translation/models/extractor.py b/src/nagini_translation/models/extractor.py index 77152666c..dfe3d541e 100644 --- a/src/nagini_translation/models/extractor.py +++ b/src/nagini_translation/models/extractor.py @@ -81,7 +81,7 @@ def extract_field_chunk(self, chunk, jvm, modules, model, target): value = None if field_name in ('__iter_index', '__previous', '__container'): return - if field_name in ('list_acc', 'set_acc', 'dict_acc', '_val', 'MustReleaseBounded', 'MustReleaseUnbounded'): + if field_name in ('list_acc', 'set_acc', 'dict_acc', 'bytearray_acc', '_val', 'MustReleaseBounded', 'MustReleaseUnbounded'): # Special handling, pyfield = field_name else: diff --git a/src/nagini_translation/resources/all.sil b/src/nagini_translation/resources/all.sil index 2043eb14f..a1e7fc00e 100644 --- a/src/nagini_translation/resources/all.sil +++ b/src/nagini_translation/resources/all.sil @@ -23,6 +23,7 @@ domain SIFDomain[T] { import "bool.sil" import "float.sil" import "references.sil" +import "bytearray.sil" import "bytes.sil" import "iterator.sil" import "list.sil" @@ -31,6 +32,7 @@ import "measures.sil" import "pytype.sil" import "range.sil" import "seq.sil" +import "byteseq.sil" import "pset.sil" import "set_dict.sil" import "slice.sil" diff --git a/src/nagini_translation/resources/bool.sil b/src/nagini_translation/resources/bool.sil index 3718a3d7d..0be94ec48 100644 --- a/src/nagini_translation/resources/bool.sil +++ b/src/nagini_translation/resources/bool.sil @@ -13,7 +13,8 @@ function object___bool__(self: Ref) : Bool ensures self != null ==> (let t == (typeof(self)) in ((!issubtype(t, float()) && !issubtype(t, list(list_arg(t, 0))) && !issubtype(t, set(set_arg(t, 0))) && - !issubtype(t, dict(dict_arg(t, 0), dict_arg(t, 1)))) ==> result)) + !issubtype(t, dict(dict_arg(t, 0), dict_arg(t, 1))) && + !issubtype(t, bytearray())) ==> result)) function NoneType___bool__(self: Ref) : Bool decreases _ @@ -118,12 +119,182 @@ function int___rxor__(self: Ref, other: Ref): Ref int___xor__(self, other) } +function __shift_factor8(amount: Int): Int + decreases _ + requires 0 <= amount && amount <= 8 +{ + amount == 0 ? 1 : + amount == 1 ? 2 : + amount == 2 ? 4 : + amount == 3 ? 8 : + amount == 4 ? 16 : + amount == 5 ? 32 : + amount == 6 ? 64 : + amount == 7 ? 128 : + 256 +} + +function __shift_factor32(amount: Int): Int + decreases _ + requires 0 <= amount && amount <= 32 +{ + amount <= 8 ? __shift_factor8(amount) : + amount <= 16 ? __shift_factor8(amount - 8) * 256 : + amount <= 24 ? __shift_factor8(amount - 16) * 256 * 256 : + __shift_factor8(amount - 24) * 256 * 256 * 256 +} + +function __shift_factor64(amount: Int): Int + decreases _ + requires 0 <= amount && amount <= 64 +{ + amount <= 32 ? __shift_factor32(amount) : + __shift_factor32(amount - 32) * __shift_factor32(32) +} + +function __prim__int___lshift__(self: Int, other: Int): Int + decreases _ + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self <= _INT_MAX) + requires @error("Negative shift count.")(other >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= _INT_MAX) +{ + 0 <= other <= 64 ? self * __shift_factor64(other) : + self >= 0 ? fromBVInt(shlBVInt(toBVInt(self), toBVInt(other))) : + -fromBVInt(shlBVInt(toBVInt(-self), toBVInt(other))) +} + +function __prim__int___rlshift__(self: Int, other: Int): Int + decreases _ + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self <= _INT_MAX) + requires @error("Negative shift count.")(other >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= _INT_MAX) +{ + __prim__int___lshift__(self, other) +} + +function int___lshift__(self: Ref, other: Ref): Ref + decreases _ + requires issubtype(typeof(self), int()) + requires issubtype(typeof(other), int()) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) <= _INT_MAX) + requires @error("Negative shift count.")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) <= _INT_MAX) +{ + __prim__int___box__(__prim__int___lshift__(int___unbox__(self), int___unbox__(other))) +} + + +function int___rlshift__(self: Ref, other: Ref): Ref + decreases _ + requires issubtype(typeof(self), int()) + requires issubtype(typeof(other), int()) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) <= _INT_MAX) + requires @error("Negative shift count.")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) <= _INT_MAX) +{ + int___lshift__(self, other) +} + +function __prim__int___rshift__(self: Int, other: Int): Int + decreases _ + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self <= _INT_MAX) + requires @error("Negative shift count.")(other >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= _INT_MAX) +{ + + 0 <= other <= 64 ? self / __shift_factor64(other) : + self >= 0 ? fromBVInt(shrBVInt(toBVInt(self), toBVInt(other))) : + -fromBVInt(shrBVInt(toBVInt(-self), toBVInt(other))) +} + +function __prim__int___rrshift__(self: Int, other: Int): Int + decreases _ + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= 64 || self <= _INT_MAX) + requires @error("Negative shift count.")(other >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(other <= _INT_MAX) +{ + __prim__int___rshift__(self, other) +} + +function int___rshift__(self: Ref, other: Ref): Ref + decreases _ + requires issubtype(typeof(self), int()) + requires issubtype(typeof(other), int()) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) <= _INT_MAX) + requires @error("Negative shift count.")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) <= _INT_MAX) +{ + __prim__int___box__(__prim__int___rshift__(int___unbox__(self), int___unbox__(other))) +} + + +function int___rrshift__(self: Ref, other: Ref): Ref + decreases _ + requires issubtype(typeof(self), int()) + requires issubtype(typeof(other), int()) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) >= _INT_MIN) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(self), bool()) ==> int___unbox__(other) <= 64 || int___unbox__(self) <= _INT_MAX) + requires @error("Negative shift count.")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) >= 0) + requires @error("Bitwise operations on ints can only be performed in the range set by the --int-bitops-size setting (default: 8 bits).")(!issubtype(typeof(other), bool()) ==> int___unbox__(other) <= _INT_MAX) +{ + int___rshift__(self, other) +} + function int___bool__(self: Ref) : Bool decreases _ requires self != null ==> issubtype(typeof(self), int()) ensures self == null ==> !result ensures self != null ==> result == (int___unbox__(self) != 0) +function __int_bit_length8(val: Int): Int + decreases _ + requires 0 <= val && val < 256 +{ + val < 1 ? 0 : + val < 2 ? 1 : + val < 4 ? 2 : + val < 8 ? 3 : + val < 16 ? 4 : + val < 32 ? 5 : + val < 64 ? 6 : + val < 128 ? 7 : + 8 +} + +function __int_bit_length32(val: Int): Int + decreases _ + requires 0 <= val && val < __shift_factor32(32) +{ + val < __shift_factor32(8) ? __int_bit_length8(val) : + val < __shift_factor32(16) ? __int_bit_length8(val / __shift_factor32(8)) + 8 : + val < __shift_factor32(24) ? __int_bit_length8(val / __shift_factor32(16)) + 16 : + __int_bit_length8(val / __shift_factor32(24)) + 24 +} + +function __int_bit_length64(val: Int): Int + decreases _ + requires 0 <= val && val < __shift_factor64(64) +{ + val < __shift_factor32(32) ? __int_bit_length32(val) : + __int_bit_length32(val / __shift_factor32(32)) + 32 +} + +function int_bit_length(self: Ref): Int + decreases _ + requires issubtype(typeof(self), int()) + requires @error("bit_length only supported up to 64 bits")(!issubtype(typeof(self), bool()) ==> int___unbox__(self) < __shift_factor64(64)) + requires @error("bit_length only supported up to 64 bits")(!issubtype(typeof(self), bool()) ==> int___unbox__(self) > -__shift_factor64(64)) +{ + __int_bit_length64(abs(int___unbox__(self))) +} + function __prim__int___box__(prim: Int): Ref decreases _ ensures typeof(result) == int() @@ -289,6 +460,12 @@ function int___int__(self: Ref): Ref requires issubtype(typeof(self), int()) ensures result == self +function int___byte_bounds__(value: Int): Bool + decreases _ +{ + 0 <= value < 256 +} + domain __ObjectEquality { function object___eq__(Ref, Ref): Bool @@ -358,6 +535,8 @@ method sorted(r: Ref) returns (rs: Ref) ensures list___len__(r) > 1 ==> forall i: Int :: { r.list_acc[i] } i >= 0 && i < list___len__(r) ==> int___unbox__(list___getitem__(r, __prim__int___box__(i))) >= int___unbox__(list___getitem__(rs, __prim__int___box__(0))) ensures list___len__(r) > 1 ==> forall i: Int :: { r.list_acc[i] } i >= 0 && i < list___len__(r) ==> int___unbox__(list___getitem__(r, __prim__int___box__(i))) <= int___unbox__(list___getitem__(rs, __prim__int___box__(list___len__(r) - 1))) +method hash(r: Ref) returns (h: Ref) + ensures issubtype(typeof(h), int()) function sum(r: Ref): Int requires issubtype(typeof(r), list(int())) @@ -370,6 +549,8 @@ function __seq_ref_to_seq_int(sr: Seq[Ref]): Seq[Int] ensures sr == Seq() ==> result == Seq() ensures forall r: Ref :: {__seq_ref_to_seq_int(Seq(r))} issubtype(typeof(r), int()) ==> __seq_ref_to_seq_int(Seq(r)) == Seq(int___unbox__(r)) ensures forall sr1: Seq[Ref], sr2: Seq[Ref] :: {__seq_ref_to_seq_int(sr1 ++ sr2)} __seq_ref_to_seq_int(sr1 ++ sr2) == __seq_ref_to_seq_int(sr1) ++ __seq_ref_to_seq_int(sr2) + ensures forall i: Int :: {(i in result)} (i in result) ==> (exists j: Int :: 0 <= j < |sr| && typeof(sr[j]) == int() && i == int___unbox__(sr[j]) && sr[j] == __prim__int___box__(i) && (sr[j] in sr)) + ensures forall i: Int :: {(i in result)} (i in result) ==> (exists j: Int :: 0 <= j && j < |sr| && issubtype(typeof(sr[j]), int()) && i == int___unbox__(sr[j]) && (sr[j] in sr)) decreases _ diff --git a/src/nagini_translation/resources/builtins.json b/src/nagini_translation/resources/builtins.json index fdd6e2b24..a5d17d4d4 100644 --- a/src/nagini_translation/resources/builtins.json +++ b/src/nagini_translation/resources/builtins.json @@ -18,6 +18,10 @@ "__str__": { "args": ["object"], "type": "str" + }, + "__repr__": { + "args": ["object"], + "type": "str" } } }, @@ -99,6 +103,10 @@ "__sil_seq__": { "args": ["list"], "type": "__prim__Seq" + }, + "__eq__": { + "args": ["list", "object"], + "type": "__prim__bool" } }, "type_vars": 1, @@ -313,6 +321,30 @@ "__rxor__": { "args": ["int", "int"], "type": "int" + }, + "__lshift__": { + "args": ["int", "int"], + "type": "int" + }, + "__rlshift__": { + "args": ["int", "int"], + "type": "int" + }, + "__rshift__": { + "args": ["int", "int"], + "type": "int" + }, + "__rrshift__": { + "args": ["int", "int"], + "type": "int" + }, + "bit_length": { + "args": ["int"], + "type": "__prim__int" + }, + "__byte_bounds__": { + "args": ["__prim__int"], + "type": "__prim__bool" } }, "extends": "float" @@ -505,6 +537,16 @@ "args": ["str"], "type": "list", "MustTerminate": true + }, + "format": { + "args": ["str", "object", "object"], + "type": "str", + "MustTerminate": true + }, + "__format__": { + "args": ["str", "str"], + "type": "str", + "MustTerminate": true } }, "extends": "object" @@ -559,6 +601,92 @@ }, "extends": "object" }, +"bytearray": { + "methods": { + "__init__": { + "args": [], + "type": null, + "MustTerminate": true + }, + "__initFromInt__": { + "args": ["int"], + "type": null, + "MustTerminate": true + }, + "__initFromList__": { + "args": ["list"], + "type": null, + "MustTerminate": true + }, + "__initFromBytearray__": { + "args": ["bytearray"], + "type": null, + "MustTerminate": true + }, + "append": { + "args": ["bytearray", "int"], + "type": null, + "MustTerminate": true + }, + "extend": { + "args": ["bytearray", "bytearray"], + "type": null, + "MustTerminate": true + }, + "reverse": { + "args": ["bytearray"], + "type": null, + "MustTerminate": true + }, + "__setitem__": { + "args": ["bytearray", "__prim__int", "int"], + "type": null, + "MustTerminate": true + }, + "__iter__": { + "args": ["bytearray"], + "type": "Iterator", + "MustTerminate": true + }, + "__getitem_slice__": { + "args": ["bytearray", "slice"], + "type": "bytearray", + "display_name": "__getitem__", + "MustTerminate": true + } + }, + "functions": { + "hex": { + "args": ["bytearray"], + "type": "str" + }, + "__len__": { + "args": ["bytearray"], + "type": "__prim__int" + }, + "__getitem__": { + "args": ["bytearray", "int"], + "type": "__prim__int" + }, + "__contains__": { + "args": ["bytearray", "int"], + "type": "__prim__bool" + }, + "__bool__": { + "args": ["bytearray"], + "type": "__prim__bool" + }, + "__eq__": { + "args": ["bytearray", "object"], + "type": "__prim__bool" + }, + "__sil_seq__": { + "args": ["bytearray"], + "type": "__prim__Seq" + } + }, + "extends": "object" +}, "tuple": { "functions": { "__create0__": { @@ -643,6 +771,22 @@ "__box__": { "args": ["__prim__int"], "type": "int" + }, + "__lshift__": { + "args": ["__prim__int", "__prim__int"], + "type": "__prim__int" + }, + "__rlshift__": { + "args": ["__prim__int", "__prim__int"], + "type": "__prim__int" + }, + "__rshift__": { + "args": ["__prim__int", "__prim__int"], + "type": "__prim__int" + }, + "__rrshift__": { + "args": ["__prim__int", "__prim__int"], + "type": "__prim__int" } } }, @@ -717,6 +861,67 @@ "type_vars": 1, "extends": "object" }, +"PByteSeq": { + "functions": { + "__create__": { + "args": ["__prim__Seq"], + "type": "PByteSeq" + }, + "__from_bytes__": { + "args": ["__prim__Seq"], + "type": "PByteSeq" + }, + "__unbox__": { + "args": ["PByteSeq"], + "type": "__prim__Seq" + }, + "__contains__": { + "args": ["PByteSeq", "__prim__int"], + "type": "__prim__bool" + }, + "__getitem__": { + "args": ["PByteSeq", "int"], + "type": "__prim__int" + }, + "__sil_seq__": { + "args": ["PByteSeq"], + "type": "__prim__Seq" + }, + "__seq_ref_to_seq_int__": { + "args": ["__prim__Seq"], + "type": "__prim__Seq" + }, + "__val__": { + "args": ["PByteSeq"], + "type": "__prim__Seq" + }, + "__len__": { + "args": ["PByteSeq"], + "type": "__prim__int" + }, + "take": { + "args": ["PByteSeq", "__prim__int"], + "type": "PByteSeq" + }, + "drop": { + "args": ["PByteSeq", "__prim__int"], + "type": "PByteSeq" + }, + "update": { + "args": ["PByteSeq", "__prim__int", "__prim__int"], + "type": "PByteSeq" + }, + "__add__": { + "args": ["PByteSeq", "PByteSeq"], + "type": "PByteSeq" + }, + "__eq__": { + "args": ["PByteSeq", "PByteSeq"], + "type": "__prim__bool" + } + }, + "extends": "object" +}, "PSet": { "functions": { "__create__": { @@ -911,6 +1116,9 @@ }, "extends": "object" }, +"IntEnum": { + "extends": "int" +}, "global": { "functions": { "max": { @@ -940,6 +1148,11 @@ "args": ["list"], "type": "list", "MustTerminate": true + }, + "hash": { + "args": ["object"], + "type": "int", + "MustTerminate": true } } } diff --git a/src/nagini_translation/resources/bytearray.sil b/src/nagini_translation/resources/bytearray.sil new file mode 100644 index 000000000..98ed1aec4 --- /dev/null +++ b/src/nagini_translation/resources/bytearray.sil @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2025 ETH Zurich + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +// Bytearray only accepts values in the range [0, 255] +field bytearray_acc : Seq[Int] + +method bytearray___init__() returns (res: Ref) + ensures acc(res.bytearray_acc) + ensures res.bytearray_acc == Seq[Int]() + ensures typeof(res) == bytearray() + ensures Low(res) + +method bytearray___initFromInt__(length: Ref) returns (res: Ref) + requires issubtype(typeof(length), int()) + ensures typeof(res) == bytearray() + ensures acc(res.bytearray_acc) + ensures |res.bytearray_acc| == int___unbox__(length) + ensures (forall i: Int :: { res.bytearray_acc[i] } 0 <= i < int___unbox__(length) ==> res.bytearray_acc[i] == 0) + ensures Low(length) ==> Low(res) + +method bytearray___initFromBytearray__(other: Ref) returns (res: Ref) + requires issubtype(typeof(other), bytearray()) + requires acc(other.bytearray_acc, 1/1000) + ensures acc(other.bytearray_acc, 1/1000) + ensures acc(res.bytearray_acc) + ensures res.bytearray_acc == other.bytearray_acc + ensures typeof(res) == bytearray() + ensures Low(other) ==> Low(res) + +method bytearray___initFromList__(values: Ref) returns (res: Ref) + requires issubtype(typeof(values), list(int())) + requires acc(values.list_acc, 1/1000) + requires forall i: Int :: {values.list_acc[i]} ((0 <= i < list___len__(values)) ==> int___byte_bounds__(int___unbox__(list___getitem__(values, __prim__int___box__(i))))) + ensures acc(values.list_acc, 1/1000) + ensures acc(res.bytearray_acc) + ensures typeof(res) == bytearray() + ensures res.bytearray_acc == __seq_ref_to_seq_int(values.list_acc) + ensures Low(values) ==> Low(res) + +function bytearray_hex(self: Ref): Ref + decreases _ + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, wildcard) + ensures typeof(result) == str() + +function bytearray___len__(self: Ref) : Int + decreases _ + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, wildcard) + ensures result >= 0 +{ + |self.bytearray_acc| +} + +function bytearray___contains__(self: Ref, key: Ref): Bool + decreases _ + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, wildcard) + requires issubtype(typeof(key), int()) + ensures result == (int___unbox__(key) in self.bytearray_acc) + +function bytearray___bool__(self: Ref) : Bool + decreases _ + requires self != null ==> issubtype(typeof(self), bytearray()) + requires self != null ==> acc(self.bytearray_acc, wildcard) + ensures self == null ==> !result + ensures self != null ==> result == (|self.bytearray_acc| != 0) + +// Currently only supports comparing to another bytearray +function bytearray___eq__(self: Ref, other: Ref): Bool + decreases _ + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, wildcard) + requires issubtype(typeof(other), bytearray()) + requires acc(other.bytearray_acc, wildcard) + ensures result <==> self.bytearray_acc == other.bytearray_acc + +// function bytearray___hex__(self: Ref): Str +// decreases _ +// requires issubtype(typeof(self), bytearray()) +// requires acc(self.bytearray_acc, wildcard) +// ensures result == + +function bytearray___getitem__(self: Ref, key: Ref): Int + decreases _ + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, wildcard) + requires issubtype(typeof(key), int()) + requires @error("Bytearray index may be out of bounds.")(let ln == (bytearray___len__(self)) in (int___unbox__(key) < 0 ==> int___unbox__(key) >= -ln)) + requires @error("Bytearray index may be out of bounds.")(let ln == (bytearray___len__(self)) in (int___unbox__(key) >= 0 ==> int___unbox__(key) < ln)) + ensures result == (int___unbox__(key) >= 0 ? self.bytearray_acc[int___unbox__(key)] : self.bytearray_acc[bytearray___len__(self) + int___unbox__(key)]) + ensures int___byte_bounds__(result) + +method bytearray___getitem_slice__(self: Ref, key: Ref) returns (_res: Ref) + requires issubtype(typeof(self), bytearray()) + requires issubtype(typeof(key), slice()) + requires acc(self.bytearray_acc, 1/1000) + ensures acc(self.bytearray_acc, 1/1000) + ensures acc(_res.bytearray_acc) + ensures typeof(_res) == bytearray() + ensures _res.bytearray_acc == self.bytearray_acc[slice___start__(key, bytearray___len__(self))..slice___stop__(key, bytearray___len__(self))] + +method bytearray___setitem__(self: Ref, key: Int, value: Ref) returns () + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc) + requires @error("Bytearray index may be negative.")(key >= 0) + requires @error("Bytearray index may be out of bounds.")(key < bytearray___len__(self)) + requires issubtype(typeof(value), int()) + requires @error("Provided value may be out of bounds.")int___byte_bounds__(int___unbox__(value)) + ensures acc(self.bytearray_acc) + ensures self.bytearray_acc == old(self.bytearray_acc)[key := int___unbox__(value)] + ensures (Low(key) && Low(value)) ==> (forall i: Ref :: {bytearray___getitem__(self, i)} ((issubtype(typeof(i), int()) && int___unbox__(i) >= 0 && int___unbox__(i) < bytearray___len__(self) && Low(old(bytearray___getitem__(self, i)))) ==> Low(bytearray___getitem__(self, i)))) + +method bytearray_append(self: Ref, item: Ref) returns () + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc) + requires issubtype(typeof(item), int()) + requires @error("Provided item may be out of bounds.")int___byte_bounds__(int___unbox__(item)) + ensures acc(self.bytearray_acc) + ensures self.bytearray_acc == old(self.bytearray_acc) ++ Seq(int___unbox__(item)) + +// Actual type of other is Iterable[SupportsIndex] +method bytearray_extend(self: Ref, other: Ref) returns () + requires issubtype(typeof(self), bytearray()) + requires issubtype(typeof(other), bytearray()) + requires acc(self.bytearray_acc) + requires acc(other.bytearray_acc, 1/100) + ensures acc(self.bytearray_acc) + ensures acc(other.bytearray_acc, 1/100) + ensures self.bytearray_acc == old(self.bytearray_acc) ++ other.bytearray_acc + +method bytearray_reverse(self: Ref) returns () + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc) + ensures acc(self.bytearray_acc) + ensures old(bytearray___len__(self)) == bytearray___len__(self) + ensures forall i: Int :: {self.bytearray_acc[i]} ((i >= 0 && i < bytearray___len__(self)) ==> (self.bytearray_acc[i] == old(self.bytearray_acc[bytearray___len__(self) - 1 - i]))) + +method bytearray___iter__(self: Ref) returns (_res: Ref) + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, 1/ 20) + ensures _res != self + ensures acc(_res.list_acc, 1 / 20) + ensures acc(self.bytearray_acc, 1 / 20) + ensures _res.list_acc == bytearray___sil_seq__(self) + ensures acc(_res.__container, write) && (_res.__container == self) + ensures acc(_res.__iter_index, write) && (_res.__iter_index == 0) + ensures acc(_res.__previous, write) && _res.__previous == Seq[Ref]() + ensures issubtype(typeof(_res), Iterator(int())) + +function bytearray___sil_seq__(self: Ref): Seq[Ref] + decreases _ + requires issubtype(typeof(self), bytearray()) + requires acc(self.bytearray_acc, wildcard) + ensures |result| == bytearray___len__(self) + ensures (forall i: Int :: { result[i] } 0 <= i < bytearray___len__(self) ==> result[i] == __prim__int___box__(self.bytearray_acc[i])) + ensures (forall i: Ref :: { (i in result) } (i in result) == (typeof(i) == int() && (int___unbox__(i) in self.bytearray_acc))) + ensures (forall i: Ref :: { (i in result) } (i in result) ==> int___byte_bounds__(int___unbox__(i))) \ No newline at end of file diff --git a/src/nagini_translation/resources/byteseq.sil b/src/nagini_translation/resources/byteseq.sil new file mode 100644 index 000000000..0dd656810 --- /dev/null +++ b/src/nagini_translation/resources/byteseq.sil @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2019 ETH Zurich + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + + +function PByteSeq___val__(self: Ref): Seq[Int] + decreases _ + requires issubtype(typeof(self), PByteSeq()) + ensures forall i: Int :: {i in result} (i in result) ==> int___byte_bounds__(i) + +function PByteSeq___len__(self: Ref): Int + decreases _ + requires issubtype(typeof(self), PByteSeq()) + ensures result == |PByteSeq___val__(self)| + +function PByteSeq___create__(values: Seq[Int]): Ref + decreases _ + requires forall i: Int :: {i in values} (i in values) ==> int___byte_bounds__(i) + ensures typeof(result) == PByteSeq() + ensures PByteSeq___val__(result) == values + +function PByteSeq___from_bytes__(values: Seq[Int]): Ref + decreases _ + ensures typeof(result) == PByteSeq() + ensures PByteSeq___val__(result) == values + +function PByteSeq___contains__(self: Ref, item: Int): Bool + decreases _ + requires issubtype(typeof(self), PByteSeq()) + ensures result == (item in PByteSeq___val__(self)) + +function PByteSeq___getitem__(self: Ref, index: Ref): Int + decreases _ + requires issubtype(typeof(self), PByteSeq()) + requires issubtype(typeof(index), int()) + requires @error("Index may be out of bounds.")(let ln == (PByteSeq___len__(self)) in + @error("Index may be out of bounds.")((int___unbox__(index) < 0 ==> int___unbox__(index) >= -ln) && (int___unbox__(index) >= 0 ==> int___unbox__(index) < ln))) + ensures result == (int___unbox__(index) >= 0 ? PByteSeq___val__(self)[int___unbox__(index)] : PByteSeq___val__(self)[PByteSeq___len__(self) + int___unbox__(index)]) + ensures int___byte_bounds__(result) + +function PByteSeq_take(self: Ref, no: Int): Ref + decreases _ + requires issubtype(typeof(self), PByteSeq()) + ensures result == PByteSeq___create__(PByteSeq___val__(self)[..no]) + ensures no == PByteSeq___len__(self) ==> result == self + +function PByteSeq_drop(self: Ref, no: Int): Ref + decreases _ + requires issubtype(typeof(self), PByteSeq()) + ensures result == PByteSeq___create__(PByteSeq___val__(self)[no..]) + ensures no == 0 ==> result == self + +function PByteSeq_update(self: Ref, index: Int, val: Int): Ref + decreases _ + requires issubtype(typeof(self), PByteSeq()) + requires index >= 0 && index < PByteSeq___len__(self) + ensures int___byte_bounds__(val) + ensures result == PByteSeq___create__(PByteSeq___val__(self)[index := val]) + +function PByteSeq___add__(self: Ref, other: Ref): Ref + decreases _ + requires issubtype(typeof(self), PByteSeq()) + requires issubtype(typeof(other), PByteSeq()) + ensures result == PByteSeq___create__(PByteSeq___val__(self) ++ PByteSeq___val__(other)) + +function PByteSeq___eq__(self: Ref, other: Ref): Bool + decreases _ + requires issubtype(typeof(self), PByteSeq()) + requires issubtype(typeof(other), PByteSeq()) + ensures result == (PByteSeq___val__(self) == PByteSeq___val__(other)) + ensures result ==> self == other // extensionality + ensures result == object___eq__(self, other) + + +function PByteSeq___sil_seq__(self: Ref): Seq[Ref] + decreases _ + requires issubtype(typeof(self), PByteSeq()) + ensures PByteSeq___val__(self) == __seq_ref_to_seq_int(result) + + +// Helper function to wrap generic __sil_seq__ calls for conversion to PByteSeq +function PByteSeq___seq_ref_to_seq_int__(sr: Seq[Ref]): Seq[Int] + decreases _ +{ + __seq_ref_to_seq_int(sr) +} \ No newline at end of file diff --git a/src/nagini_translation/resources/intbv.sil b/src/nagini_translation/resources/intbv.sil index e425c76d8..bcf626bbf 100644 --- a/src/nagini_translation/resources/intbv.sil +++ b/src/nagini_translation/resources/intbv.sil @@ -7,6 +7,7 @@ // File template: NBITS is to be replaced by the number of bits, INT_MAX_VAL and INT_MIN_VAL by the actual values. +define _BITOPS_SIZE (NBITS) define _INT_MAX (INT_MAX_VAL) define _INT_MIN (INT_MIN_VAL) @@ -16,4 +17,7 @@ domain ___intbv interpretation (SMTLIB: "(_ BitVec NBITS)", Boogie: "bvNBITS") { function andBVInt(___intbv, ___intbv): ___intbv interpretation "bvand" function orBVInt(___intbv, ___intbv): ___intbv interpretation "bvor" function xorBVInt(___intbv, ___intbv): ___intbv interpretation "bvxor" + function shlBVInt(___intbv, ___intbv): ___intbv interpretation "bvshl" + function shrBVInt(___intbv, ___intbv): ___intbv interpretation "bvlshr" + function notBVInt(___intbv): ___intbv interpretation "bvnot" } \ No newline at end of file diff --git a/src/nagini_translation/resources/list.sil b/src/nagini_translation/resources/list.sil index a335abb69..cee93c3ae 100644 --- a/src/nagini_translation/resources/list.sil +++ b/src/nagini_translation/resources/list.sil @@ -138,6 +138,23 @@ method list___iter__(self: Ref) returns (_res: Ref) inhale false } +function list___eq__(self: Ref, other: Ref): Bool + decreases _ + requires issubtype(typeof(self), list(list_arg(typeof(self), 0))) + requires issubtype(typeof(other), list(list_arg(typeof(other), 0))) + requires acc(self.list_acc, wildcard) + requires acc(other.list_acc, wildcard) + ensures result <==> + (list___len__(self) == list___len__(other) && + (forall i: Int :: {self.list_acc[i]} {other.list_acc[i]} + i >= 0 && i < list___len__(self) + ==> object___eq__(self.list_acc[i], other.list_acc[i]))) + ensures result <==> + (list___len__(self) == list___len__(other) && + (forall i: Ref :: {list___getitem__(self, i)} {list___getitem__(other, i)} + issubtype(typeof(i), int()) && int___unbox__(i) >= 0 && int___unbox__(i) < list___len__(self) + ==> object___eq__(list___getitem__(self, i), list___getitem__(other, i)))) + function list___sil_seq__(self: Ref): Seq[Ref] decreases _ requires acc(self.list_acc, wildcard) diff --git a/src/nagini_translation/resources/pytype.sil b/src/nagini_translation/resources/pytype.sil index 47973ec01..c44a0aaab 100644 --- a/src/nagini_translation/resources/pytype.sil +++ b/src/nagini_translation/resources/pytype.sil @@ -35,6 +35,7 @@ domain PyType { unique function bool(): PyType unique function bytes(): PyType unique function bytearray(): PyType + unique function PByteSeq(): PyType unique function range_0(): PyType unique function slice(): PyType unique function str(): PyType diff --git a/src/nagini_translation/resources/references.sil b/src/nagini_translation/resources/references.sil index 67acd5901..0315917a0 100644 --- a/src/nagini_translation/resources/references.sil +++ b/src/nagini_translation/resources/references.sil @@ -11,4 +11,11 @@ method object___str__(self: Ref) returns (res: Ref) ensures str___val__(res) == object___str_val__(self) function object___str_val__(self: Ref): Seq[Int] + decreases _ + +method object___repr__(self: Ref) returns (res: Ref) + ensures issubtype(typeof(res), str()) + ensures str___val__(res) == object___repr_val__(self) + +function object___repr_val__(self: Ref): Seq[Int] decreases _ \ No newline at end of file diff --git a/src/nagini_translation/resources/seq.sil b/src/nagini_translation/resources/seq.sil index 38e1fd88d..f494849bf 100644 --- a/src/nagini_translation/resources/seq.sil +++ b/src/nagini_translation/resources/seq.sil @@ -38,11 +38,13 @@ function PSeq_take(self: Ref, no: Int): Ref decreases _ requires issubtype(typeof(self), PSeq(PSeq_arg(typeof(self), 0))) ensures result == PSeq___create__(PSeq___sil_seq__(self)[..no], PSeq_arg(typeof(self), 0)) + ensures no == PSeq___len__(self) ==> result == self function PSeq_drop(self: Ref, no: Int): Ref decreases _ requires issubtype(typeof(self), PSeq(PSeq_arg(typeof(self), 0))) ensures result == PSeq___create__(PSeq___sil_seq__(self)[no..], PSeq_arg(typeof(self), 0)) + ensures no == 0 ==> result == self function PSeq_update(self: Ref, index: Int, val: Ref): Ref decreases _ diff --git a/src/nagini_translation/resources/str.sil b/src/nagini_translation/resources/str.sil index 32d477206..e5a3ea406 100644 --- a/src/nagini_translation/resources/str.sil +++ b/src/nagini_translation/resources/str.sil @@ -48,6 +48,15 @@ function str___mod__(self: Ref, other: Ref): Ref requires issubtype(typeof(self), str()) ensures issubtype(typeof(result), str()) +method str_format(self: Ref, args: Ref, kwargs: Ref) returns (_res: Ref) + requires issubtype(typeof(self), str()) + ensures issubtype(typeof(_res), str()) + +method str___format__(self: Ref, other: Ref) returns (_res: Ref) + requires issubtype(typeof(self), str()) + requires issubtype(typeof(other), str()) + ensures issubtype(typeof(_res), str()) + method str_split(self: Ref) returns (res: Ref) decreases _ requires issubtype(typeof(self), str()) diff --git a/src/nagini_translation/translators/call.py b/src/nagini_translation/translators/call.py index 635a52c98..9852e4d1a 100644 --- a/src/nagini_translation/translators/call.py +++ b/src/nagini_translation/translators/call.py @@ -42,6 +42,7 @@ THREAD_POST_PRED, THREAD_START_PRED, TUPLE_TYPE, + BYTEARRAY_TYPE, ) from nagini_translation.lib.errors import rules from nagini_translation.lib.program_nodes import ( @@ -208,6 +209,21 @@ def translate_adt_cons(self, cons: PythonClass, args: List[FuncApp], return box_func + def translate_enum_cons(self, enum: PythonClass, args: List[FuncApp], + pos: Position, ctx: Context) -> Expr: + """ + Cosntruct Enums via a sequence of constructor calls and + boxing/unboxing calls. + """ + assert len(args) == 1 + + info = self.no_info(ctx) + args[0] = self.to_type(args[0], self.viper.Int, ctx) + box_func_name = enum.sil_name + '__box__' + box_func = self.viper.FuncApp(box_func_name, args, pos, info, self.viper.Ref) + return box_func + + def _is_lock_subtype(self, cls: PythonClass) -> bool: if cls is None: return False @@ -229,6 +245,9 @@ def translate_constructor_call(self, target_class: PythonClass, if target_class.python_class.is_adt: return arg_stmts, self.translate_adt_cons(target_class, args, pos, ctx) + if target_class.python_class.enum: + return arg_stmts, self.translate_enum_cons(target_class, args, pos, ctx) + res_var = ctx.current_function.create_variable(target_class.name + '_res', target_class, @@ -317,22 +336,35 @@ def translate_constructor_call(self, target_class: PythonClass, if target: - target_class = target.cls - targets = [] - if target.declared_exceptions: - error_var = self.get_error_var(node, ctx) - targets.append(error_var) - method_name = target_class.get_method('__init__').sil_name - init = self.create_method_call_node( - ctx, method_name, args, targets, self.to_position(node, ctx), - self.no_info(ctx), target_method=target, target_node=node) - stmts.extend(init) - if target.declared_exceptions: - catchers = self.create_exception_catchers(error_var, - ctx.actual_function.try_blocks, node, ctx) - stmts = stmts + catchers + init_stmts = self._translate_init_call(target, args, node, ctx) + stmts.extend(init_stmts) + + # If the init method was created implicitly, we have to check for __post_init__ + if target_class.python_class.dataclass and target_class.python_class.implicit_init: + target = target_class.get_method('__post_init__') + if target: + post_init_stmts = self._translate_init_call(target, [res_var.ref()], node, ctx, '__post_init__') + stmts.extend(post_init_stmts) + return arg_stmts + defined_check + stmts, res_var.ref() + def _translate_init_call(self, target: PythonMethod, args: list, node: ast.Call, ctx: Context, name = '__init__') -> list: + target_class = target.cls + targets = [] + + if target.declared_exceptions: + error_var = self.get_error_var(node, ctx) + targets.append(error_var) + method_name = target_class.get_method(name).sil_name + stmts = self.create_method_call_node( + ctx, method_name, args, targets, self.to_position(node, ctx), + self.no_info(ctx), target_method=target, target_node=node) + if target.declared_exceptions: + catchers = self.create_exception_catchers(error_var, + ctx.actual_function.try_blocks, node, ctx) + stmts = stmts + catchers + return stmts + def _translate_list(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: contents = None stmts = [] @@ -377,6 +409,27 @@ def _translate_list(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: stmts.append(self.viper.Inhale(seq_equal, position, info)) return stmts, result_var + def _translate_default_factory(self, arg, node: ast.AST, + ctx: Context) -> Tuple[List, Expr, PythonType]: + """Translates a default_factory for a dataclass field argument.""" + if arg.default_factory == 'list': + list_class = ctx.module.global_module.classes[LIST_TYPE] + res_var = ctx.current_function.create_variable('list', + list_class, + self.translator) + targets = [res_var.ref()] + constr_call = self.get_method_call(list_class, '__init__', [], + [], targets, node, ctx) + stmts = list(constr_call) + position = self.to_position(node, ctx) + result = res_var.ref(node, ctx) + stmts.append(self.viper.Inhale( + self.type_check(result, arg.type, position, ctx), + position, self.no_info(ctx))) + return stmts, result, arg.type + raise UnsupportedException(node, + 'Unsupported default_factory: ' + str(arg.default_factory)) + def _translate_set(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: contents = None stmts = [] @@ -509,6 +562,43 @@ def _translate_enumerate(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: return arg_stmt + [new_stmt, type_inhale, contents_inhale], new_list.ref(node, ctx) + def _translate_bytearray(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: + """ + Translates a call to bytearray() + """ + bytearray_class = ctx.module.global_module.classes[BYTEARRAY_TYPE] + res_var = ctx.current_function.create_variable('bytearray', bytearray_class, self.translator) + targets = [res_var.ref()] + result_var = res_var.ref(node, ctx) + method_name = None + + # This could potentially be merged using the "display_name" field + # by extending the general code for selecting a specific __init__ call + if len(node.args) == 0: + call = self.get_method_call(bytearray_class, '__init__', [], [], targets, node, ctx) + return call, result_var + + elif len(node.args) == 1: + arg_type = self.get_type(node.args[0], ctx) + + if arg_type.name == BYTEARRAY_TYPE: + method_name = '__initFromBytearray__' + + if arg_type.name == LIST_TYPE: + method_name = '__initFromList__' + + if arg_type.name == INT_TYPE: + method_name = '__initFromInt__' + + if method_name: + target_method = bytearray_class.get_method(method_name) + arg_stmts, arg_vals, arg_types = self.translate_args(target_method, node.args, node.keywords, node, ctx) + constr_call = self.get_method_call(bytearray_class, method_name, arg_vals, arg_types, targets, node, ctx) + return arg_stmts + constr_call, res_var.ref(node, ctx) + + raise UnsupportedException(node, 'Unsupported variant of bytearray().') + + def _translate_builtin_func(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: """ @@ -541,6 +631,8 @@ def _translate_builtin_func(self, node: ast.Call, return self._translate_type_func(node, ctx) elif func_name == 'cast': return self._translate_cast_func(node, ctx) + elif func_name == 'bytearray': + return self._translate_bytearray(node, ctx) else: raise UnsupportedException(node) @@ -800,11 +892,11 @@ def translate_args(self, target: PythonMethod, arg_nodes: List, # are just set to null. if keywords: raise UnsupportedException(node, desc='Keyword arguments in call to ' - 'builtin function.') + 'builtin function: ' + target.name) diff = target.nargs - len(unpacked_args) if diff < 0: raise UnsupportedException(node, 'Unsupported version of builtin ' - 'function.') + 'function: ' + target.name) if diff > 0: null = self.viper.NullLit(self.no_position(ctx), self.no_info(ctx)) unpacked_args += [null] * diff @@ -845,9 +937,16 @@ def translate_args(self, target: PythonMethod, arg_nodes: List, for index, (arg, key) in enumerate(zip(args, keys)): if arg is False: # Not set yet, need default - args[index] = target.args[key].default_expr - assert args[index], '{} arg={}'.format(target.name, key) - arg_types[index] = self.get_type(target.args[key].default, ctx) + if target.args[key].default_factory: + factory_stmts, factory_expr, factory_type = \ + self._translate_default_factory(target.args[key], node, ctx) + arg_stmts += factory_stmts + args[index] = factory_expr + arg_types[index] = factory_type + else: + args[index] = target.args[key].default_expr + assert args[index], '{} arg={}'.format(target.name, key) + arg_types[index] = self.get_type(target.args[key].default, ctx) if target.var_arg: var_arg_list = self.create_tuple(var_args, var_arg_types, node, ctx) diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index 7eff612d7..ebd9dee15 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -17,6 +17,7 @@ INT_TYPE, IS_DEFINED_FUNC, LIST_TYPE, + BYTEARRAY_TYPE, MAIN_METHOD_NAME, MAY_SET_PRED, NAME_DOMAIN, @@ -25,6 +26,7 @@ PRIMITIVE_INT_TYPE, RANGE_TYPE, PSEQ_TYPE, + PBYTESEQ_TYPE, PSET_TYPE, SET_TYPE, SINGLE_NAME, @@ -186,7 +188,8 @@ def to_bool(self, e: Expr, ctx: Context, node: ast.AST = None) -> Expr: position=e.pos()) return result - def to_int(self, e: Expr, ctx: Context) -> Expr: + def to_int(self, e: Expr, ctx: Context, + python_type: 'PythonType' = None) -> Expr: """ Converts the given expression to an expression of the Silver type Int if it isn't already, either by unboxing a reference or undoing a @@ -204,10 +207,16 @@ def to_int(self, e: Expr, ctx: Context) -> Expr: e.funcname() == '__prim__int___box__'): return e.args().head() result = e - int_type = ctx.module.global_module.classes[INT_TYPE] - result = self.get_function_call(int_type, '__unbox__', - [result], [None], None, ctx, - position=e.pos()) + if python_type and python_type.python_class.enum and python_type.python_class.enum_type == INT_TYPE: + unbox_name = python_type.python_class.functions['__int__'].sil_name + result = self.viper.FuncApp(unbox_name, [result], + e.pos(), self.no_info(ctx), + self.viper.Int) + else: + int_type = ctx.module.global_module.classes[INT_TYPE] + result = self.get_function_call(int_type, '__unbox__', + [result], [None], None, ctx, + position=e.pos()) return result def unwrap(self, e: Expr) -> Expr: @@ -632,6 +641,40 @@ def get_sequence(self, receiver: PythonType, arg: Expr, arg_type: PythonType, res_trigger = res return res_trigger, res + def get_int_sequence(self, receiver: PythonType, arg: Expr, + node: ast.AST, ctx: Context, + position: Position = None) -> Expr: + """ + Returns a sequence (Viper type Seq[Int]) representing the contents of arg. + Defaults to type___sil_seq__, but used simpler expressions for known types + to improve performance/triggering. + """ + position = position if position else self.to_position(node, ctx) + info = self.no_info(ctx) + int_type = INT_TYPE + if not isinstance(receiver, UnionType) or isinstance(receiver, OptionalType): + if receiver.name == BYTEARRAY_TYPE: + seq_int = self.viper.SeqType(self.viper.Int) + field = self.viper.Field('bytearray_acc', seq_int, position, info) + res = self.viper.FieldAccess(arg, field, position, info) + return res + if receiver.name == PBYTESEQ_TYPE: + if (isinstance(arg, self.viper.ast.FuncApp) and + arg.funcname() == 'PByteSeq___create__'): + args = self.viper.to_list(arg.args()) + return args[0] + int_seq_op = getattr(receiver.cls, '__sil_int_seq__', None) + if callable(int_seq_op): + self.get_function_call(receiver, '__sil_int_seq__', [arg], [None], + node, ctx, position) + + # Fallback to getting a Seq[Ref] and then converting to Seq[Int] + PByteSeq_class = ctx.module.global_module.classes[PBYTESEQ_TYPE] + seq_ref_exp = self.get_function_call(receiver, '__sil_seq__', [arg], [None], + node, ctx, position) + return self.get_function_call(PByteSeq_class, '__seq_ref_to_seq_int__', [seq_ref_exp], [None], + node, ctx, position) + def _get_function_call(self, receiver: PythonType, func_name: str, args: List[Expr], arg_types: List[PythonType], node: ast.AST, diff --git a/src/nagini_translation/translators/contract.py b/src/nagini_translation/translators/contract.py index 5ecf51bf0..32f92bfcf 100644 --- a/src/nagini_translation/translators/contract.py +++ b/src/nagini_translation/translators/contract.py @@ -22,8 +22,11 @@ PMSET_TYPE, PRIMITIVES, PSEQ_TYPE, + PBYTESEQ_TYPE, PSET_TYPE, RANGE_TYPE, + BYTEARRAY_TYPE, + BYTES_TYPE, THREAD_DOMAIN, THREAD_POST_PRED, THREAD_START_PRED, @@ -134,6 +137,7 @@ def _get_perm(self, node: ast.Call, ctx: Context) -> Expr: def translate_builtin_predicate(self, node: ast.Call, perm: Expr, args: List[Expr], ctx: Context) -> Expr: name = node.func.id + seq_int = self.viper.SeqType(self.viper.Int) seq_ref = self.viper.SeqType(self.viper.Ref) set_ref = self.viper.SetType(self.viper.Ref) map_ref_ref = self.viper.MapType(self.viper.Ref, self.viper.Ref) @@ -146,6 +150,8 @@ def translate_builtin_predicate(self, node: ast.Call, perm: Expr, return self._get_field_perm('set_acc', set_ref, perm, args[0], pos, ctx) elif name == 'dict_pred': return self._get_field_perm('dict_acc', map_ref_ref, perm, args[0], pos, ctx) + elif name == 'bytearray_pred': + return self._get_field_perm('bytearray_acc', seq_int, perm, args[0], pos, ctx) elif name == 'MayStart': return self.translate_may_start(node, args, perm, ctx) elif name == 'ThreadPost': @@ -711,7 +717,7 @@ def translate_to_sequence(self, node: ast.Call, # iterable (which gives no information about order for unordered types). seq_call, _ = self.get_sequence(coll_type, arg, None, node, ctx) seq_class = ctx.module.global_module.classes[PSEQ_TYPE] - if coll_type.name == RANGE_TYPE: + if coll_type.name == RANGE_TYPE or coll_type.name == BYTEARRAY_TYPE: type_arg = ctx.module.global_module.classes[INT_TYPE] else: type_arg = coll_type.type_args[0] @@ -722,6 +728,22 @@ def translate_to_sequence(self, node: ast.Call, [seq_call, type_lit], [None, None], node, ctx) return stmt, result + + def translate_to_int_sequence(self, node: ast.Call, + ctx: Context) -> StmtsAndExpr: + coll_type = self.get_type(node.args[0], ctx) + stmt, arg = self.translate_expr(node.args[0], ctx) + + seq_call = self.get_int_sequence(coll_type, arg, node, ctx) + seq_class = ctx.module.global_module.classes[PBYTESEQ_TYPE] + if coll_type.name == BYTEARRAY_TYPE: + call_name = '__from_bytes__' + else: + call_name = '__create__' + result = self.get_function_call(seq_class, call_name, + [seq_call], [None], + node, ctx) + return stmt, result def translate_sequence(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: @@ -750,6 +772,31 @@ def translate_sequence(self, node: ast.Call, [result, type_lit], [None, None], node, ctx) return val_stmts, result + + def translate_int_sequence(self, node: ast.Call, + ctx: Context) -> StmtsAndExpr: + intseq_class = ctx.module.global_module.classes[PBYTESEQ_TYPE] + viper_type = self.viper.Int + val_stmts = [] + if node.args: + vals = [] + for arg in node.args: + arg_stmt, arg_val = self.translate_expr(arg, ctx, + target_type=viper_type) + val_stmts += arg_stmt + vals.append(arg_val) + result = self.viper.ExplicitSeq(vals, self.to_position(node, + ctx), + self.no_info(ctx)) + else: + result = self.viper.EmptySeq(viper_type, + self.to_position(node, ctx), + self.no_info(ctx)) + + result = self.get_function_call(intseq_class, '__create__', + [result], [None], node, + ctx) + return val_stmts, result def translate_pset(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: set_type = self.get_type(node, ctx) @@ -850,7 +897,7 @@ def translate_let(self, node: ast.Call, ctx: Context, arg = lambda_.args.args[0] var = ctx.actual_function.get_variable(lambda_prefix + arg.arg) - exp_stmt, exp_val = self.translate_expr(node.args[0], ctx) + exp_stmt, exp_val = self.translate_expr(node.args[0], ctx, target_type=var.decl.typ()) ctx.set_alias(arg.arg, var, None) @@ -1151,12 +1198,16 @@ def translate_contractfunc_call(self, node: ast.Call, ctx: Context, return self.translate_let(node, ctx, impure) elif func_name == PSEQ_TYPE: return self.translate_sequence(node, ctx) + elif func_name == PBYTESEQ_TYPE: + return self.translate_int_sequence(node, ctx) elif func_name == PSET_TYPE: return self.translate_pset(node, ctx) elif func_name == PMSET_TYPE: return self.translate_mset(node, ctx) elif func_name == 'ToSeq': return self.translate_to_sequence(node, ctx) + elif func_name == 'ToByteSeq': + return self.translate_to_int_sequence(node, ctx) elif func_name == 'ToMS': return self.translate_to_multiset(node, ctx) elif func_name == 'Joinable': @@ -1172,4 +1223,4 @@ def translate_contractfunc_call(self, node: ast.Call, ctx: Context, elif func_name == 'arg': raise InvalidProgramException(node, 'invalid.arg.use') else: - raise UnsupportedException(node) + raise UnsupportedException(node, func_name) diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index 6a7451d71..c3fba692a 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -408,13 +408,40 @@ def translate_string(self, s: str, node: ast.AST, ctx: Context) -> Expr: node, ctx) return call + def translate_JoinedStr(self, node: ast.JoinedStr, ctx: Context) -> StmtsAndExpr: + """ + Dummy implementation for ast.JoinedStr, only translates contained expressions. + Provides no guarantees about resulting value. + """ + stmts = [] + exps = [] + for val in node.values: + val_stmt, val_exp = self.translate_expr(val, ctx) + stmts += val_stmt + exps.append(val_exp) + + str_class = ctx.module.global_module.classes[STRING_TYPE] + res_var = ctx.current_function.create_variable('joined_str', str_class, self.translator) + result_var = res_var.ref(node, ctx) + position = self.to_position(node, ctx) + stmts.append(self.viper.Inhale(self.type_check(result_var, str_class, position, ctx), + position, self.no_info(ctx))) + return stmts, result_var + + def translate_FormattedValue(self, node: ast.FormattedValue, ctx: Context) -> StmtsAndExpr: + """ + Dummy implementation for ast.FormattedValue, only translates contained expression. + Does not apply given formatting rules + """ + stmt, exp = self.translate_expr(node.value, ctx) + return stmt, exp + def translate_Ellipsis(self, node: ast.Ellipsis, ctx: Context) -> StmtsAndExpr: ellipsis_class = ctx.module.global_module.classes[ELLIPSIS_TYPE] func_name = '__create__' call = self.get_function_call(ellipsis_class, func_name, [], [], node, ctx) return [], call - def translate_Bytes(self, node: ast.Constant, ctx: Context) -> StmtsAndExpr: elems = [] for c in node.s: @@ -1150,16 +1177,24 @@ def translate_Compare(self, node: ast.Compare, position = self.to_position(node, ctx) info = self.no_info(ctx) + if isinstance(node.ops[0], ast.Is): + return (stmts, self.viper.EqCmp(left, right, position, info)) + elif isinstance(node.ops[0], ast.IsNot): + return (stmts, self.viper.NeCmp(left, right, position, info)) + + # Unbox IntEnum to int + if left_type.python_class.enum and left_type.python_class.enum_type == INT_TYPE: + left = self.to_int(left, ctx, left_type) + left_type = ctx.module.global_module.classes[INT_TYPE] + if right_type.python_class.enum and right_type.python_class.enum_type == INT_TYPE: + right = self.to_int(right, ctx, right_type) + right_type = ctx.module.global_module.classes[INT_TYPE] + if self._is_primitive_operation(node.ops[0], left_type, right_type): result = self._translate_primitive_operation(left, right, left_type, node.ops[0], position, ctx) return stmts, result - - if isinstance(node.ops[0], ast.Is): - return (stmts, self.viper.EqCmp(left, right, position, info)) - elif isinstance(node.ops[0], ast.IsNot): - return (stmts, self.viper.NeCmp(left, right, position, info)) elif isinstance(node.ops[0], (ast.In, ast.NotIn)): contains_stmts, contains_expr = self._translate_contains( left, right, left_type, right_type, node, ctx) diff --git a/src/nagini_translation/translators/method.py b/src/nagini_translation/translators/method.py index 702753b92..5bf40b6d8 100644 --- a/src/nagini_translation/translators/method.py +++ b/src/nagini_translation/translators/method.py @@ -329,16 +329,18 @@ def translate_function(self, func: PythonMethod, check = self.type_check(result, func.type, res_type_pos, ctx) posts = [check] + posts - statements = func.node.body - start, end = get_body_indices(statements) - # Translate body - actual_body = statements[start:end] - if (func.contract_only or - (len(actual_body) == 1 and isinstance(actual_body[0], ast.Expr) and - isEllipsis(actual_body[0].value))): + if func.contract_only: body = None else: - body = self.translate_exprs(actual_body, func, ctx) + statements = func.node.body + start, end = get_body_indices(statements) + # Translate body + actual_body = statements[start:end] + if ((len(actual_body) == 1 and isinstance(actual_body[0], ast.Expr) and + isEllipsis(actual_body[0].value))): + body = None + else: + body = self.translate_exprs(actual_body, func, ctx) ctx.current_function = old_function name = func.sil_name diff --git a/src/nagini_translation/translators/program.py b/src/nagini_translation/translators/program.py index e5ac823c4..b6325cdfb 100644 --- a/src/nagini_translation/translators/program.py +++ b/src/nagini_translation/translators/program.py @@ -24,6 +24,7 @@ MAY_SET_PRED, METHOD_ID_DOMAIN, NAME_DOMAIN, + PRIMITIVE_INT_TYPE, PRIMITIVES, RESULT_NAME, THREAD_DOMAIN, @@ -146,7 +147,11 @@ def create_static_field_function(self, root: PythonVar, ctx.current_class = field.cls ctx.module = field.cls.module # Compute the field value - stmt, value = self.translate_expr(field.value, ctx) + if cls.enum: + stmt, value = self.translate_expr(field.value, ctx, self.viper.Int) + value = self.viper.FuncApp(cls.name + '__box__', [value], position, info, self.viper.Ref) + else: + stmt, value = self.translate_expr(field.value, ctx) if stmt: raise InvalidProgramException('purity.violated', field.node) field_position = self.to_position(field.node, ctx) @@ -442,7 +447,7 @@ def translate_default_args(self, method: PythonMethod, if type and not type.python_class.interface and not type.contains_type_var(): definition_deps.add((arg.node.annotation, type.python_class, method.module)) - if arg.default: + if arg.default and not arg.default_factory: stmt, expr = self.translate_expr(arg.default, ctx) if not stmt and expr: arg.default_expr = expr @@ -485,6 +490,10 @@ def _create_predefined_fields(self, self.viper.MapType(self.viper.Ref, self.viper.Ref), self.no_position(ctx), self.no_info(ctx))) + fields.append(self.viper.Field('bytearray_acc', + self.viper.SeqType(self.viper.Int), + self.no_position(ctx), + self.no_info(ctx))) fields.append(self.viper.Field('Measure$acc', self.viper.SeqType(self.viper.Ref), self.no_position(ctx), @@ -1162,6 +1171,98 @@ def create_adts_domains_and_functions(self, adts: List[PythonClass], return domains, functions + def _register_enum_int_function(self, enum: PythonClass, ctx: Context, suffix: str) -> PythonMethod: + """Register __int__ as an interface function on the enum class.""" + int_func = enum.node_factory.create_python_method( + suffix, None, enum, enum, True, False, enum.node_factory, + interface=True, interface_dict={'args': [enum.name], 'type': PRIMITIVE_INT_TYPE}) + arg = enum.node_factory.create_python_var('self', None, enum) + int_func.add_arg('self', arg) + int_func.type = ctx.module.global_module.classes[PRIMITIVE_INT_TYPE] + sil_name = enum.get_fresh_name(enum.name + suffix) + int_func.process(sil_name, self.translator) + enum.functions[suffix] = int_func + return int_func + + def _create_enum_func_box_and_unbox(self, enum: PythonClass, ctx: Context) -> list[Function]: + """Create __box__ and __int__ functions for IntEnum. Other enum types currently not supported.""" + + pos = self.to_position(enum, ctx) + info = self.no_info(ctx) + terminates_wildcard = self.viper.DecreasesWildcard(None, pos, info) + + box_func_suffix = '__box__' + box_func_name = enum.sil_name + box_func_suffix + unbox_func_suffix = '__int__' + int_func = self._register_enum_int_function(enum, ctx, unbox_func_suffix) + unbox_func_name = int_func.sil_name + + # Create box function (Int -> Ref) + int_val_use = self.viper.LocalVar('value', self.viper.Int, pos, info) + int_val_decl = self.viper.LocalVarDecl('value', self.viper.Int, pos, info) + result = self.viper.Result(self.viper.Ref, pos, info) + preconds = [terminates_wildcard] + + # Add precondition for allowed values + enum_value_precond = self.viper.FalseLit(pos, info) + for field_name in enum.static_fields: + field = enum.get_static_field(field_name) + if field and field.value: + _, enum_value_expr = self.translate_expr(field.value, ctx, self.viper.Int) + value_check = self.viper.EqCmp(int_val_use, enum_value_expr, pos, info) + enum_value_precond = self.viper.Or(enum_value_precond, value_check, pos, info) + preconds.append(enum_value_precond) + + postconds = [] + postconds.append(self.type_factory.type_check(result, enum, pos, ctx, True)) + + unbox_func = self.viper.FuncApp(unbox_func_name, [result], pos, info, self.viper.Int) + postconds.append(self.viper.EqCmp(unbox_func, int_val_use, pos, info)) + int_unbox_func = self.viper.FuncApp('int___unbox__', [result], pos, info, self.viper.Int) + postconds.append(self.viper.EqCmp(int_unbox_func, int_val_use, pos, info)) + + yield self.viper.Function(box_func_name, + [int_val_decl], self.viper.Ref, preconds, postconds, + None, pos, info) + + # Create unbox function (Ref -> Int) + ref_use = self.viper.LocalVar('box', self.viper.Ref, pos, info) + ref_decl = self.viper.LocalVarDecl('box', self.viper.Ref, pos, info) + result = self.viper.Result(self.viper.Int, pos, info) + preconds = [terminates_wildcard] + postconds = [] + + preconds.append(self.type_factory.type_check(ref_use, enum, pos, ctx, True)) + + # Add postcondition constraining the result to valid enum values + enum_value_postcond = self.viper.FalseLit(pos, info) + for field_name in enum.static_fields: + field = enum.get_static_field(field_name) + if field and field.value: + _, enum_value_expr = self.translate_expr(field.value, ctx, self.viper.Int) + value_check = self.viper.EqCmp(result, enum_value_expr, pos, info) + enum_value_postcond = self.viper.Or(enum_value_postcond, value_check, pos, info) + postconds.append(enum_value_postcond) + + # Add forall postcondition + i_var_use = self.viper.LocalVar('i', self.viper.Ref, pos, info) + i_var_decl = self.viper.LocalVarDecl('i', self.viper.Ref, pos, info) + obj_eq_check = self.viper.DomainFuncApp('object___eq__', [ref_use, i_var_use], self.viper.Bool, pos, info, '__ObjectEquality') + type_check = self.type_factory.type_check(i_var_use, enum, pos, ctx, True) + condition = self.viper.And(obj_eq_check, type_check, pos, info) + unbox_apply = self.viper.FuncApp(unbox_func_name, [i_var_use], pos, info, self.viper.Int) + unbox_eq = self.viper.EqCmp(unbox_apply, result, pos, info) + implication = self.viper.Implies(condition, unbox_eq, pos, info) + + trigger = self.viper.Trigger([obj_eq_check, unbox_apply], pos, info) + forall_postcond = self.viper.Forall([i_var_decl], [trigger], implication, pos, info) + postconds.append(forall_postcond) + + yield self.viper.Function(unbox_func_name, + [ref_decl], self.viper.Int, preconds, postconds, + None, pos, info) + + def translate_program(self, modules: List[PythonModule], sil_progs: Program, ctx: Context, selected: Set[str] = None, ignore_global: bool = False) -> Program: @@ -1229,6 +1330,12 @@ def translate_program(self, modules: List[PythonModule], sil_progs: Program, while current_field.overrides: current_field = current_field.overrides static_fields.setdefault(current_field, []).append(cls) + if cls.enum: + enum_functions = list(self._create_enum_func_box_and_unbox(cls, ctx)) + functions.extend(enum_functions) + if module is not module.global_module: + for function in enum_functions: + all_names.append(function.name()) ctx.current_class = None # Translate default args @@ -1292,7 +1399,7 @@ def translate_program(self, modules: List[PythonModule], sil_progs: Program, self.track_dependencies(selected_names, selected, func, ctx) functions.append(self.translate_function(func, ctx)) func_constants.append(self.translate_function_constant(func, ctx)) - if func.overrides and not ((func_name in ('__str__', '__bool__') and + if func.overrides and not ((func_name in ('__str__', '__bool__', '__eq__') and func.overrides.cls.name == 'object') or (func_name in ('__getitem__',) and func.overrides.cls.name == 'dict')): # We allow overriding certain methods, since the basic versions @@ -1350,6 +1457,7 @@ def translate_program(self, modules: List[PythonModule], sil_progs: Program, predicate_families[cpred].append(pred) else: predicate_families[cpred] = [pred] + ctx.current_class = old_class if not ignore_global: diff --git a/src/nagini_translation/translators/statement.py b/src/nagini_translation/translators/statement.py index 24954f9f6..a970f86c6 100644 --- a/src/nagini_translation/translators/statement.py +++ b/src/nagini_translation/translators/statement.py @@ -11,6 +11,7 @@ BYTES_TYPE, COMBINED_NAME_ACCESSOR, DICT_TYPE, + BYTEARRAY_TYPE, END_LABEL, IGNORED_IMPORTS, IGNORED_MODULE_NAMES, @@ -455,6 +456,7 @@ def _create_for_loop_invariant(self, iter_var: PythonVar, seq_temp_var: PythonVa """ pos = self.to_position(node, ctx) info = self.no_info(ctx) + seq_int = self.viper.SeqType(self.viper.Int) seq_ref = self.viper.SeqType(self.viper.Ref) set_ref = self.viper.SetType(self.viper.Ref) map_ref_ref = self.viper.MapType(self.viper.Ref, self.viper.Ref) @@ -484,6 +486,13 @@ def _create_for_loop_invariant(self, iter_var: PythonVar, seq_temp_var: PythonVa frac_perm_120, pos, info) invariant.append(field_pred) + elif iterable_type.name == BYTEARRAY_TYPE: + acc_field = self.viper.Field('bytearray_acc', seq_int, pos, info) + field_acc = self.viper.FieldAccess(iterable, acc_field, pos, info) + field_pred = self.viper.FieldAccessPredicate(field_acc, + frac_perm_120, pos, + info) + invariant.append(field_pred) elif iterable_type.name == RANGE_TYPE: pass else: @@ -714,7 +723,7 @@ def translate_stmt_For(self, node: ast.For, ctx: Context) -> List[Stmt]: # Find type of the collection content we're iterating over. if iterable_type.name in (LIST_TYPE, DICT_TYPE, SET_TYPE): target_type = iterable_type.type_args[0] - elif iterable_type.name in (RANGE_TYPE, BYTES_TYPE): + elif iterable_type.name in (RANGE_TYPE, BYTES_TYPE, BYTEARRAY_TYPE): target_type = ctx.module.global_module.classes[INT_TYPE] else: raise UnsupportedException(node, 'unknown.iterable') diff --git a/src/nagini_translation/translators/type.py b/src/nagini_translation/translators/type.py index 431c2712c..4552593bf 100644 --- a/src/nagini_translation/translators/type.py +++ b/src/nagini_translation/translators/type.py @@ -44,6 +44,7 @@ def builtins(self): return {'builtins.int': self.viper.Int, 'builtins.bool': self.viper.Bool, 'builtins.PSeq': self.viper.SeqType(self.viper.Ref), + 'builtins.PSeqInt': self.viper.SeqType(self.viper.Int), 'builtins.PSet': self.viper.SetType(self.viper.Ref), 'builtins.PMultiset': self.viper.MultisetType(self.viper.Ref), } diff --git a/src/nagini_translation/translators/type_domain_factory.py b/src/nagini_translation/translators/type_domain_factory.py index 6082137a5..1f0fc8c97 100644 --- a/src/nagini_translation/translators/type_domain_factory.py +++ b/src/nagini_translation/translators/type_domain_factory.py @@ -829,6 +829,9 @@ def type_check(self, lhs: 'Expr', type: 'PythonType', just_part = self.subtype_check(type_func, type.cls, position, ctx, concrete=concrete) return self.viper.Or(none_part, just_part, position, info) + # Enums cannot be subclassed, so we can use exact type equality + if type.python_class.enum: + concrete = True return self.subtype_check(type_func, type, position, ctx, concrete=concrete) diff --git a/src/nagini_translation/verifier.py b/src/nagini_translation/verifier.py index bfc995239..6580a8ebd 100644 --- a/src/nagini_translation/verifier.py +++ b/src/nagini_translation/verifier.py @@ -146,9 +146,15 @@ def verify(self, modules, prog: 'silver.ast.Program', arp=False, sif=False) -> V else: return Success() - def __del__(self): + def stop(self): if hasattr(self, 'silicon') and self.silicon: - self.silicon.stop() + try: + self.silicon.stop() + except Exception: + pass + + def __del__(self): + self.stop() class Carbon: @@ -193,4 +199,11 @@ def verify(self, modules, prog: 'silver.ast.Program', arp=False, sif=False) -> V errors += [it.next()] return Failure(errors, self.jvm, modules, sif) else: - return Success() \ No newline at end of file + return Success() + + def stop(self): + if hasattr(self, 'carbon') and self.carbon: + try: + self.carbon.stop() + except Exception: + pass \ No newline at end of file diff --git a/tests/functional/translation/test_dataclass.py b/tests/functional/translation/test_dataclass.py new file mode 100644 index 000000000..48cf3650e --- /dev/null +++ b/tests/functional/translation/test_dataclass.py @@ -0,0 +1,34 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from typing import List +from nagini_contracts.contracts import * +from dataclasses import dataclass, field + +@dataclass(frozen=True) +class foo: + num: int + name: str + obj: list[int] + +@dataclass(frozen=True) +class A: + data: foo + +@dataclass(frozen=True) +class B: + num: int = 2 + direct = 3 + +@dataclass +class FactoryClass: + arr: List[int] = field(default_factory=list) + +@dataclass(frozen=False) +class NonFrozenKeyword: + data: int + +def test_cons() -> None: + f1 = foo(1, "hello", []) + + f2 = foo(num=2, name="hello", obj=[]) \ No newline at end of file diff --git a/tests/functional/translation/test_dataclass_defaults.py b/tests/functional/translation/test_dataclass_defaults.py new file mode 100644 index 000000000..a689c653e --- /dev/null +++ b/tests/functional/translation/test_dataclass_defaults.py @@ -0,0 +1,20 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from dataclasses import dataclass +from nagini_contracts.contracts import * +from enum import IntEnum + +class Color_Enum(IntEnum): + red = 0 + green = 1 + blue = 2 + yellow = 3 + +@dataclass +class MyClass(): + color: Color_Enum = Color_Enum.red + +@dataclass(frozen=True) +class MyClassF(): + color: Color_Enum = Color_Enum.red \ No newline at end of file diff --git a/tests/functional/translation/test_dataclass_field.py b/tests/functional/translation/test_dataclass_field.py new file mode 100644 index 000000000..d66d76254 --- /dev/null +++ b/tests/functional/translation/test_dataclass_field.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * +from dataclasses import dataclass, field + +@dataclass(frozen=True) +class MissingType: + #:: ExpectedOutput(unsupported:field() requires a type annotation) + arr = field(default_factory=list) \ No newline at end of file diff --git a/tests/functional/translation/test_dataclass_no_init.py b/tests/functional/translation/test_dataclass_no_init.py new file mode 100644 index 000000000..877f48bdf --- /dev/null +++ b/tests/functional/translation/test_dataclass_no_init.py @@ -0,0 +1,10 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * +from dataclasses import dataclass + +#:: ExpectedOutput(unsupported:keyword unsupported) +@dataclass(frozen=True, init=False) +class NonInit: + data: int \ No newline at end of file diff --git a/tests/functional/translation/test_enum.py b/tests/functional/translation/test_enum.py new file mode 100644 index 000000000..91e68332f --- /dev/null +++ b/tests/functional/translation/test_enum.py @@ -0,0 +1,13 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * +from enum import IntEnum + +class flag(IntEnum): + success = 0 + failure = 1 + +#:: ExpectedOutput(invalid.program:Cannot extend enumeration) +class sub_flag(flag): + unknown = 3 \ No newline at end of file diff --git a/tests/functional/verification/test_bitwise_op.py b/tests/functional/verification/test_bitwise_op.py index 6b65f3caf..3337f0d73 100644 --- a/tests/functional/verification/test_bitwise_op.py +++ b/tests/functional/verification/test_bitwise_op.py @@ -97,7 +97,7 @@ def and_3a(a: int, b: bool, c: int) -> None: def and_4(a: int, b: bool, c: int) -> None: Requires(a > -100 and a < 100) - Requires(c > -130 and c < 127) + Requires(c > -260 and c < 127) #:: ExpectedOutput(application.precondition:assertion.false) intint = a & c @@ -133,7 +133,7 @@ def or_3a(a: int, b: bool, c: int) -> None: def or_4(a: int, b: bool, c: int) -> None: Requires(a > -100 and a < 100) - Requires(c > -130 and c < 127) + Requires(c > -260 and c < 127) #:: ExpectedOutput(application.precondition:assertion.false) intint = a | c @@ -169,6 +169,121 @@ def xor_3a(a: int, b: bool, c: int) -> None: def xor_4(a: int, b: bool, c: int) -> None: Requires(a > -100 and a < 100) - Requires(c >= -128 and c < 129) + Requires(c >= -128 and c < 257) #:: ExpectedOutput(application.precondition:assertion.false) - intint = a ^ c \ No newline at end of file + intint = a ^ c + +def lshift_general(a: int, b: int) -> None: + Requires(a > -128 and a <= 127) + Requires(b >=0 and b <= 127) + + shift = a << b + + if b <= 8: + lshift_unlimited(a, b) + + if b == 9: + assert shift == a * 512 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert shift == 1 + +def lshift_unlimited(a: int, b: int) -> None: + Requires(b >= 0 and b <= 64) + + shift = a << b + + # Unfortunately we cannot prove the equivalence shift == a * (2**b) + if b == 0: + assert shift == a + if b == 1: + assert shift == a * 2 + if b == 2: + assert shift == a * 4 + if b == 3: + assert shift == a * 8 + if b == 4: + assert shift == a * 16 + if b == 5: + assert shift == a * 32 + if b == 6: + assert shift == a * 64 + if b == 7: + assert shift == a * 128 + if b == 8: + assert shift == a * 256 + if b == 32: + assert shift == a * 4_294_967_296 + if b == 33: + assert shift == a * 8_589_934_592 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert shift == 1 + +def lshift_neg(a: int, b: int) -> None: + Requires(a > -128 and a <= 127) + Requires(b >= -128 and b <= 127) + + #:: ExpectedOutput(application.precondition:assertion.false) + shift = a << b + +def rshift_general(a: int, b: int) -> None: + Requires(a >= -128 and a <= 127) + Requires(b >=0 and b <= 127) + + shift = a >> b + + if b <= 8: + rshift_unlimited(a, b) + + if b == 9: + assert shift == a // 512 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert shift == 1 + +def rshift_unlimited(a: int, b: int) -> None: + Requires(b >= 0 and b <= 8) + + shift = a >> b + + # Unfortunately we cannot prove the equivalence shift == a // (2 ** b) + if b == 0: + assert shift == a + if b == 1: + assert shift == a // 2 + if b == 2: + assert shift == a // 4 + if b == 3: + assert shift == a // 8 + if b == 4: + assert shift == a // 16 + if b == 5: + assert shift == a // 32 + if b == 6: + assert shift == a // 64 + if b == 7: + assert shift == a // 128 + if b == 8: + assert shift == a // 256 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert shift == 1 + +def int_bit_length() -> None: + + assert (0).bit_length() == 0 + assert (1).bit_length() == 1 + assert (3).bit_length() == 2 + assert (7).bit_length() == 3 + assert (15).bit_length() == 4 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert (3245).bit_length() == 6 + +def int_bit_length_general(a: int) -> None: + Requires(a >= 0 and a < (1 << 64)) + + if a > 2: + assert a.bit_length() == (a >> 1).bit_length() + 1 + \ No newline at end of file diff --git a/tests/functional/verification/test_bytearray.py b/tests/functional/verification/test_bytearray.py new file mode 100644 index 000000000..e79fac036 --- /dev/null +++ b/tests/functional/verification/test_bytearray.py @@ -0,0 +1,209 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + + +def test_bytearray_constr() -> None: + a = bytearray() + assert len(a) == 0 + assert 6 not in a + #:: ExpectedOutput(assert.failed:assertion.false) + assert 2 in a + +def test_bytearray_constr_int() -> None: + a = bytearray(7) + assert len(a) == 7 + assert a[3] == 0 + assert 6 not in a + #:: ExpectedOutput(assert.failed:assertion.false) + assert 2 in a + +def test_bytearray_constr_list() -> None: + a = bytearray([2,3,4]) + assert len(a) == 3 + assert a[0] == 2 + assert a[1] == 3 + assert a[2] == 4 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert 5 in a + +def test_bytearray_constr_list_bounds_low() -> None: + #:: ExpectedOutput(call.precondition:assertion.false) + a = bytearray([-1,3,4]) + +def test_bytearray_constr_list_bounds_high() -> None: + #:: ExpectedOutput(call.precondition:assertion.false) + a = bytearray([0,3,256]) + +def test_bytearray_constr_bytearray() -> None: + a = bytearray([2,3,4]) + b = bytearray(a) + + assert len(b) == 3 + assert b[0] == 2 + assert b[1] == 3 + assert b[2] == 4 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert 5 in b + +def test_bytearray_bool() -> None: + a = bytearray([0]) + b = bytearray(3) + c = bytearray() + + assert a + assert b + + #:: ExpectedOutput(assert.failed:assertion.false) + assert c + +def test_bytearray_eq1() -> None: + a = bytearray([1,2,3]) + b = bytearray([1,2,3]) + + assert a == b + +def test_bytearray_eq2() -> None: + a = bytearray([1,2,3]) + b = bytearray([2,2,3]) + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a == b + +def test_bytearray_eq_client1(b1: bytearray, b2: bytearray) -> None: + Requires(bytearray_pred(b1)) + Requires(bytearray_pred(b2)) + Requires(b1 == b2) + Requires(len(b1) > 0) + + assert b1[0] == b2[0] + +def test_bytearray_eq_client2(b1: bytearray, b2: bytearray) -> None: + Requires(bytearray_pred(b1)) + Requires(bytearray_pred(b2)) + Requires(b1 == b2) + Requires(len(b1) > 0) + + b1[0] = 42 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert b1[0] == b2[0] + +def test_byterray_append() -> None: + a = bytearray([2,3,4]) + a.append(5) + + assert len(a) == 4 + assert a[0] == 2 + assert a[1] == 3 + assert a[2] == 4 + assert a[3] == 5 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a[2] == 8 + +def test_bytearray_append_bounds_low() -> None: + a = bytearray() + #:: ExpectedOutput(call.precondition:assertion.false) + a.append(-10) + +def test_bytearray_append_bounds_high() -> None: + a = bytearray() + #:: ExpectedOutput(call.precondition:assertion.false) + a.append(256) + +def test_bytearray_setitem() -> None: + a = bytearray([2,3,4]) + + a[0] = 10 + a[1] = 0 + a[2] = 255 + assert a[0] == 10 + assert a[1] == 0 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a[2] == 254 + +def test_bytearray_setitem_bounds_low() -> None: + a = bytearray([0,128,255]) + #:: ExpectedOutput(call.precondition:assertion.false) + a[0] = -1 + +def test_bytearray_setitem_bounds_high() -> None: + a = bytearray([0,128,255]) + #:: ExpectedOutput(call.precondition:assertion.false) + a[0] = 256 + +def test_byterray_extend() -> None: + a = bytearray([2,3,4]) + b = bytearray([5,6,7]) + a.extend(b) + + assert len(a) == 6 + assert a[4] == 6 + assert a[5] == 7 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a[3] == 8 + +def test_bytearray_reverse() -> None: + a = bytearray([2,3,4]) + a.reverse() + + assert len(a) == 3 + assert a[0] == 4 + assert a[1] == 3 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a[2] == 4 + +def test_bytearray_getitem_slice() -> None: + a = bytearray([2,3,4]) + b = a[1:] + + assert len(b) == 2 + assert b[0] == 3 + assert b[1] == 4 + + c = a[:-1] + assert len(c) == 2 + #:: ExpectedOutput(assert.failed:assertion.false) + assert c[0] == 3 + + +def test_bytearray_perm(b: bytearray) -> None: + #:: ExpectedOutput(call.precondition:insufficient.permission) + b.append(6) + +def test_bytearray_bounds_low(b: bytearray) -> None: + Requires(bytearray_pred(b)) + Requires(len(b) > 1) + + assert b[0] >= 0 + #:: ExpectedOutput(assert.failed:assertion.false) + assert b[1] < 0 + +def test_bytearray_bounds_high(b: bytearray) -> None: + Requires(bytearray_pred(b)) + Requires(len(b) > 1) + + assert b[0] <= 255 + #:: ExpectedOutput(assert.failed:assertion.false) + assert b[1] > 255 + +def test_bytearray_iter_bounds(b: bytearray) -> None: + Requires(bytearray_pred(b)) + + for byte in b: + assert 0 <= byte and byte < 256 + +def test_bytearray_hex(b: bytearray) -> None: + Requires(bytearray_pred(b)) + + value = b.hex() + + #:: ExpectedOutput(assert.failed:assertion.false) + assert value == "" \ No newline at end of file diff --git a/tests/functional/verification/test_dataclass.py b/tests/functional/verification/test_dataclass.py new file mode 100644 index 000000000..f404a3a5e --- /dev/null +++ b/tests/functional/verification/test_dataclass.py @@ -0,0 +1,121 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from typing import cast + +from nagini_contracts.contracts import * +from dataclasses import dataclass, field + +@dataclass +class A: + data: int + + @Pure + def __eq__(self, other: object) -> bool: + Requires(Rd(self.data)) + Requires(Implies(isinstance(other, A), Rd(cast(A, other).data))) + if not isinstance(other, A): + return False + + return self.data == other.data + +@dataclass +class C: + fields: list[A] + +@dataclass +class D: + value: int + length: int + text: str + +@dataclass +class ListClass: + arr: list[int] = field(default_factory=list) + +def test_1(val: int) -> None: + a = A(val) + + assert a.data == val + + a.data = 3 + #:: ExpectedOutput(assert.failed:assertion.false) + assert a.data == val + +def test_2() -> None: + a1 = A(0) + a2 = A(3) + a3 = A(42) + c = C([a1, a2, a3]) + + assert len(c.fields) == 3 + assert c.fields[0].data == 0 + + c.fields.append(A(20)) + assert len(c.fields) == 4 + assert c.fields[3].data == 20 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert c.fields[1].data == c.fields[2].data + +def test_3() -> None: + c = C([A(0)]) + + assert c.fields[0].data == 0 + + c.fields = [] + assert len(c.fields) == 0 + +def test_named_param(val: int, length: int) -> None: + d = D(length=length, value=val, text="") + + assert d.value == val + assert d.text == "" + + #:: ExpectedOutput(assert.failed:assertion.false) + assert d.length == 2 + +def test_eq_1(val: int) -> None: + a1 = A(val) + a2 = A(val) + a3 = A(0) + + assert a1 == a2 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a1 == a3 + +def test_list_ref() -> None: + l = [1,2,3] + f = ListClass(l) + + l.append(4) + assert len(f.arr) == 4 + assert ToSeq(f.arr) == PSeq(1,2,3,4) + #:: ExpectedOutput(assert.failed:assertion.false) + assert f.arr[0] == 5 + +def test_list_conditions(l: list[int]) -> None: + Requires(list_pred(l)) + Requires(Forall(l, lambda i: 0 <= i and i < 10)) + + f = ListClass(l) + assert Forall(f.arr, lambda i: 0 <= i and i < 10) + +def test_list_eq(left: ListClass, right: ListClass) -> None: + Requires(Acc(left.arr) and list_pred(left.arr)) + Requires(Acc(right.arr) and list_pred(right.arr)) + Requires(len(left.arr) == len(right.arr)) + + #:: ExpectedOutput(assert.failed:assertion.false) + assert left.arr == right.arr + +def test_list_eq_elements(left: ListClass, right: ListClass) -> None: + Requires(Acc(left.arr) and list_pred(left.arr)) + Requires(Acc(right.arr) and list_pred(right.arr)) + Requires(len(left.arr) == len(right.arr)) + Requires(Forall(int, lambda i: Implies(0 <= i and i < len(left.arr), left.arr[i] == right.arr[i]))) + + assert left.arr == right.arr + #:: ExpectedOutput(assert.failed:assertion.false) + assert left.arr is right.arr \ No newline at end of file diff --git a/tests/functional/verification/test_dataclass_defaults.py b/tests/functional/verification/test_dataclass_defaults.py new file mode 100644 index 000000000..8733d97e8 --- /dev/null +++ b/tests/functional/verification/test_dataclass_defaults.py @@ -0,0 +1,113 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from enum import IntEnum +from typing import List +from nagini_contracts.contracts import * +from dataclasses import dataclass, field + +@dataclass +class A: + num: int = 2 + num2: int = 10 + +@dataclass(frozen=True) +class B: + num: int + my_field: int = 5 + +@dataclass(frozen=True) +class FieldClass: + arr: List[int] = field(default_factory=list) + +@dataclass(frozen=True) +class ComplexClass: + num1: int + num2: int = 3 + arr: List[int] = field(default_factory=list) + arr2: List[int] = field(default_factory=list) + +class Color_Enum(IntEnum): + red = 0 + green = 1 + blue = 2 + yellow = 3 + +@dataclass +class C: + color: Color_Enum = Color_Enum.green + +def test_default_vals1() -> None: + a = A() + + assert a.num == 2 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a.num == 3 + +def test_default_vals2(val: int) -> None: + b = B(val) + + assert b.num == val + assert b.my_field == 5 + + b2 = B(val, val) + assert b2.num == b2.my_field + + #:: ExpectedOutput(assert.failed:assertion.false) + assert b2.my_field == 5 + +def test_default_factory_list1() -> None: + a = FieldClass() + b = FieldClass() + + a.arr.append(1) + assert len(a.arr) == 1 + assert len(b.arr) == 0 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a.arr == b.arr + +def test_default_factory_list2() -> None: + l = [1,2,3] + a = FieldClass(l) + b = FieldClass(l) + + a.arr.append(1) + assert len(a.arr) == 4 + assert len(b.arr) == 4 + + assert a.arr is b.arr + +def test_default_factory_list3() -> None: + a = ComplexClass(7) + b = ComplexClass(5, arr=[1]) + + assert a.num1 == 7 + assert b.num1 == 5 + + assert a.num2 == 3 + assert b.num2 == 3 + + a.arr.append(1) + assert len(a.arr) == 1 + assert len(b.arr) == 1 + + assert a.arr[0] == b.arr[0] + + assert len(a.arr2) == 0 + assert len(b.arr2) == 0 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a.arr is b.arr + +def test_default_val_enum() -> None: + c = C() + + assert c.color == Color_Enum.green + + c2 = C(Color_Enum.yellow) + assert c2.color == Color_Enum.yellow + + #:: ExpectedOutput(assert.failed:assertion.false) + assert c.color == c2.color \ No newline at end of file diff --git a/tests/functional/verification/test_dataclass_frozen.py b/tests/functional/verification/test_dataclass_frozen.py new file mode 100644 index 000000000..f79e796bf --- /dev/null +++ b/tests/functional/verification/test_dataclass_frozen.py @@ -0,0 +1,132 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * +from dataclasses import dataclass, field + +@dataclass(frozen=True) +class A: + data: int + + @Pure + def __eq__(self, other: object) -> bool: + if not isinstance(other, A): + return False + + return self.data == other.data + +@dataclass(frozen=True) +class B: + field: A + + @Pure + def __eq__(self, other: object) -> bool: + if not isinstance(other, B): + return False + + return self.field == other.field + +@dataclass(frozen=True) +class C: + fields: list[A] + +@dataclass(frozen=True) +class D: + value: int + length: int + text: str + +@dataclass(frozen=True) +class ListClass: + arr: list[int] = field(default_factory=list) + +def test_1(val: int) -> None: + a = A(val) + + assert a.data == val + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a.data == 2 + +def test_2() -> None: + a1 = A(0) + a2 = A(3) + a3 = A(42) + c = C([a1, a2, a3]) + + assert len(c.fields) == 3 + assert c.fields[0].data == 0 + + c.fields.append(A(20)) + assert len(c.fields) == 4 + assert c.fields[3].data == 20 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert c.fields[1].data == c.fields[2].data + +def test_named_param(val: int, length: int) -> None: + d = D(length=length, value=val, text="") + + assert d.value == val + assert d.text == "" + + #:: ExpectedOutput(assert.failed:assertion.false) + assert d.length == 2 + +def test_eq_1(val: int) -> None: + a1 = A(val) + a2 = A(val) + a3 = A(0) + + assert a1 == a2 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert a1 == a3 + +def test_eq_2(a1: A, a2: A) -> None: + b1 = B(a1) + b2 = B(a1) + b3 = B(a2) + + assert b1 == b2 + + if a1 == a2: + assert b1 == b3 + else: + #:: ExpectedOutput(assert.failed:assertion.false) + assert b1 == b3 + +def test_list_ref() -> None: + l = [1,2,3] + f = ListClass(l) + + l.append(4) + assert len(f.arr) == 4 + assert ToSeq(f.arr) == PSeq(1,2,3,4) + #:: ExpectedOutput(assert.failed:assertion.false) + assert f.arr[0] == 5 + +def test_list_conditions(l: list[int]) -> None: + Requires(list_pred(l)) + Requires(Forall(l, lambda i: 0 <= i and i < 10)) + + f = ListClass(l) + assert Forall(f.arr, lambda i: 0 <= i and i < 10) + +def test_list_eq(left: ListClass, right: ListClass) -> None: + Requires(list_pred(left.arr)) + Requires(list_pred(right.arr)) + Requires(len(left.arr) == len(right.arr)) + + #:: ExpectedOutput(assert.failed:assertion.false) + assert left.arr == right.arr + +def test_list_eq_elements(left: ListClass, right: ListClass) -> None: + Requires(list_pred(left.arr)) + Requires(list_pred(right.arr)) + Requires(len(left.arr) == len(right.arr)) + Requires(Forall(int, lambda i: Implies(0 <= i and i < len(left.arr), left.arr[i] == right.arr[i]))) + + assert left.arr == right.arr + #:: ExpectedOutput(assert.failed:assertion.false) + assert left.arr is right.arr \ No newline at end of file diff --git a/tests/functional/verification/test_dataclass_override.py b/tests/functional/verification/test_dataclass_override.py new file mode 100644 index 000000000..d7a4751da --- /dev/null +++ b/tests/functional/verification/test_dataclass_override.py @@ -0,0 +1,68 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * +from typing import Optional, TypeVar, Generic +from dataclasses import dataclass + +T = TypeVar('T') + +@dataclass(frozen=True) +class ResultDataclass(Generic[T]): + success: bool + error_code: int = 0 + data: Optional[T] = None + + @Pure + def __bool__(self) -> bool: + return self.success + + def __post_init__(self) -> None: + Requires(Implies(self.success, self.data != None)) + Requires((self.success and self.error_code <= 0) or (not self.success and self.error_code > 0)) + + if not self.success and self.error_code <= 0: + raise Exception() + + if self.success and self.error_code > 0: + raise Exception() + +def test_bool() -> None: + res = ResultDataclass(True, 0, 'data') + + assert res.success + assert res + + #:: ExpectedOutput(assert.failed:assertion.false) + assert res.error_code == 1 + +def test_bool2(res: ResultDataclass[int]) -> None: + + if res.success: + assert res + + #:: ExpectedOutput(assert.failed:assertion.false) + assert res + +def test_init_False() -> None: + res = ResultDataclass(False, 1, '') + + #:: ExpectedOutput(call.precondition:assertion.false) + res = ResultDataclass(False, 0, None) + +def test_init_code() -> None: + res = ResultDataclass(True, 0, 'data') + + #:: ExpectedOutput(call.precondition:assertion.false) + res = ResultDataclass(True, 1, 'data') + +def test_init_None() -> None: + res = ResultDataclass(True, 0, 'data') + + #:: ExpectedOutput(call.precondition:assertion.false) + res = ResultDataclass(True, 0, None) + +def test_data(res: ResultDataclass[str]) -> None: + if res.success: + #:: ExpectedOutput(assert.failed:assertion.false) + assert isinstance(res.data, str) \ No newline at end of file diff --git a/tests/functional/verification/test_enum_int.py b/tests/functional/verification/test_enum_int.py new file mode 100644 index 000000000..9625e2dbf --- /dev/null +++ b/tests/functional/verification/test_enum_int.py @@ -0,0 +1,71 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * +from enum import IntEnum + +class flag(IntEnum): + success = 0 + failure = 1 + +class flag2(IntEnum): + success = 0 + failure = 2 + +def test_value() -> None: + f = flag(1) + + assert f == flag(1) + assert f == 1 + assert f == flag.failure + + assert flag.success == 0 + assert flag.success == flag(0) + assert flag.success == False + assert flag2(0) == flag(0) + + #:: ExpectedOutput(assert.failed:assertion.false) + assert flag.success == flag.failure + +def test_comparison() -> None: + f0 = flag(0) + f1 = flag(0) + f2 = flag2(0) + + assert f0 == f1 + assert f0 is f1 + + assert f1 == f2 + #:: ExpectedOutput(assert.failed:assertion.false) + assert f1 is f2 + +def test_value3() -> None: + assert flag.success == flag2.success + + #:: ExpectedOutput(assert.failed:assertion.false) + assert flag.failure == flag2.failure + +def test_contraints(f: flag) -> None: + assert 0 <= f + assert f <= 1 + + assert 0 <= int(f) + assert int(f) <= 1 + +def test_contraints2(f: flag2) -> None: + assert f == 0 or f == 2 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert f == 0 + +def test_contraints3(f: flag2) -> None: + assert f != 1 + + #:: ExpectedOutput(assert.failed:assertion.false) + assert f == 3 + +def test_precond(f: flag) -> None: + Requires(f == flag.success) + + #:: ExpectedOutput(assert.failed:assertion.false) + assert f == 1 \ No newline at end of file diff --git a/tests/functional/verification/test_lists.py b/tests/functional/verification/test_lists.py index 3b054bba7..13fedfddd 100644 --- a/tests/functional/verification/test_lists.py +++ b/tests/functional/verification/test_lists.py @@ -151,4 +151,33 @@ def test_mul() -> None: assert newlist[2] is super1 assert newlist[5] is super2 #:: ExpectedOutput(assert.failed:assertion.false) - assert mylist[1] is super1 \ No newline at end of file + assert mylist[1] is super1 + +def test_eq1() -> None: + l1: list[int] = [1,2,3] + l2: list[int] = [1,2,3] + + assert l1 == l2 + #:: ExpectedOutput(assert.failed:assertion.false) + assert l1 is l2 + +def test_eq2() -> None: + l1: list[int] = [1,2] + l2: list[int] = [1,2,3] + + assert l1 != l2 + l1.append(3) + assert l1 == l2 + l2.append(4) + #:: ExpectedOutput(assert.failed:assertion.false) + assert l1 == l2 + +# TODO +# def test_index_to_elem(l: list[int]) -> None: +# Requires(list_pred(l)) +# Requires(Forall(int, lambda j: (Implies(0 <= j and j < len(l), 0 <= l[j] and l[j] < 256), [[l[j]]]))) + +# assert Forall(l, lambda el: 0 <= el and el < 256) + +# #:: ExpectedOutput(assert.failed:assertion.false) +# assert Forall(l, lambda el: 0 <= el and el < 255) \ No newline at end of file diff --git a/tests/functional/verification/test_pbyteseq.py b/tests/functional/verification/test_pbyteseq.py new file mode 100644 index 000000000..1a040b390 --- /dev/null +++ b/tests/functional/verification/test_pbyteseq.py @@ -0,0 +1,101 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from typing import List +from nagini_contracts.contracts import * + + +def test_seq() -> None: + no_ints = PByteSeq() + assert len(no_ints) == 0 + ints = PByteSeq(1, 2, 3) + + assert 3 in ints and 1 in ints + assert 4 not in ints + assert ints[1] == 2 + assert len(ints) == 3 + ints2 = ints + ints + assert len(ints2) == 6 + assert ints2[3] == 1 + ints3 = ints2.take(4) + assert len(ints3) == 4 + assert ints3[1] == ints2[1] + ints4 = ints.update(0, 3) + assert 1 not in ints4 + assert ints4[0] == 3 + ints5 = ints.drop(2) + assert len(ints5) == 1 + assert ints5[0] == 3 + #:: ExpectedOutput(assert.failed:assertion.false) + assert False + +def test_byteseq_bounds_low(b: PByteSeq) -> None: + Requires(len(b) > 1) + + assert b[0] >= 0 + #:: ExpectedOutput(assert.failed:assertion.false) + assert b[1] < 0 + +def test_byteseq_bounds_high(b: PByteSeq) -> None: + Requires(len(b) > 1) + + assert b[0] <= 255 + #:: ExpectedOutput(assert.failed:assertion.false) + assert b[1] > 255 + +def test_range() -> None: + ints = PByteSeq(1,3,5,6,8) + r = ints.drop(1).take(2) + + assert len(ints) == 5 + assert len(r) == 2 + assert 5 in r + assert r[0] == 3 + assert 1 not in r + assert 8 not in r + + #:: ExpectedOutput(assert.failed:assertion.false) + assert r[1] == 6 + +def test_list_ToByteSeq() -> None: + a = [1,2,3] + assert ToByteSeq(a) == PByteSeq(1,2,3) + #:: ExpectedOutput(assert.failed:assertion.false) + assert False + + +def test_bytearray_ToByteSeq() -> None: + a = bytearray([1,2,3]) + assert ToByteSeq(a) == PByteSeq(1,2,3) + #:: ExpectedOutput(assert.failed:assertion.false) + assert False + +def test_bytearray_bounds(b_array: bytearray) -> None: + Requires(bytearray_pred(b_array)) + Requires(len(b_array) > 2) + seq = ToByteSeq(b_array) + + assert 0 <= seq[0] and seq[0] <= 0xFF + + #:: ExpectedOutput(assert.failed:assertion.false) + assert seq[1] >= 256 + +def test_list_interop(b_array: bytearray) -> None: + Requires(bytearray_pred(b_array)) + l = list(b_array) + + byteseq_direct = ToByteSeq(b_array) + byteseq = ToByteSeq(l) + + assert byteseq_direct == byteseq + +def test_list_interop2(b_list: List[int]) -> None: + Requires(list_pred(b_list)) + Requires(Forall(b_list, lambda el: 0 <= el and el < 256)) + + byteseq_direct = ToByteSeq(b_list) + + b_array = bytearray(b_list) + byteseq = ToByteSeq(b_array) + + assert byteseq_direct == byteseq diff --git a/tests/functional/verification/test_primitives.py b/tests/functional/verification/test_primitives.py new file mode 100644 index 000000000..ab0defb27 --- /dev/null +++ b/tests/functional/verification/test_primitives.py @@ -0,0 +1,44 @@ +from nagini_contracts.contracts import * + +@Opaque +@Pure +def positive1(i1: int) -> bool: + return i1 > 0 + +@Opaque +@Pure +def positive2(i1: PInt) -> bool: + return i1 > 0 + +def client1(i1: int, i2: int) -> None: + if i1 == i2: + if positive1(i1): + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(positive1(i2)) + +def client2(i1: int, i2: int) -> None: + if i1 == i2: + if positive2(i1): + Assert(positive2(i2)) + + +@Opaque +@Pure +def true1(i1: bool) -> bool: + return i1 + +@Opaque +@Pure +def true2(i1: PBool) -> bool: + return i1 + + +def bclient1(i1: bool, i2: bool) -> None: + if i1 == i2: + if true1(i1): + Assert(true1(i2)) + +def bclient2(i1: bool, i2: bool) -> None: + if i1 == i2: + if true2(i1): + Assert(true2(i2)) \ No newline at end of file diff --git a/tests/functional/verification/test_property_inherited.py b/tests/functional/verification/test_property_inherited.py new file mode 100644 index 000000000..214699b2d --- /dev/null +++ b/tests/functional/verification/test_property_inherited.py @@ -0,0 +1,32 @@ +# Any copyright is dedicated to the Public Domain. +# http://creativecommons.org/publicdomain/zero/1.0/ + +from nagini_contracts.contracts import * + +class A: + def __init__(self, val: int) -> None: + Ensures(Acc(self._field)) # type: ignore + Ensures(self.field == val) + self._field = val + + @property + def field(self) -> int: + Requires(Acc(self._field)) + return self._field + +class B(A): + + def __init__(self, val: int) -> None: + Ensures(Acc(self._field)) + Ensures(self.field == val) + super().__init__(val) + + @Pure + def non_zero(self) -> bool: + Requires(Acc(self._field)) + return self.field != 0 + +def test() -> None: + b = B(5) + + assert b.non_zero() \ No newline at end of file diff --git a/tests/functional/verification/test_pseq.py b/tests/functional/verification/test_pseq.py index 9a257c105..7962329b8 100644 --- a/tests/functional/verification/test_pseq.py +++ b/tests/functional/verification/test_pseq.py @@ -35,7 +35,6 @@ def test_seq() -> None: #:: ExpectedOutput(assert.failed:assertion.false) assert False - def test_list_ToSeq() -> None: a = [1,2,3] assert ToSeq(a) == PSeq(1,2,3) diff --git a/tests/functional/verification/test_string.py b/tests/functional/verification/test_string.py index de23598e0..64d102595 100644 --- a/tests/functional/verification/test_string.py +++ b/tests/functional/verification/test_string.py @@ -29,4 +29,20 @@ def main() -> None: my_string2 = "a" Assert(my_string2 == "a") #:: ExpectedOutput(assert.failed:assertion.false) - Assert(my_string2 == "b") \ No newline at end of file + Assert(my_string2 == "b") + + +def test_str_format_wrong1() -> None: + a = "".format() + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(a == " ") + +def test_str_format_wrong2() -> None: + a = "{0}".format(2) + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(a == "3") + +def test_fstr_wrong() -> None: + a = f"{8}" + #:: ExpectedOutput(assert.failed:assertion.false) + Assert(a == "3") \ No newline at end of file