diff --git a/src/Solcore/Frontend/TypeInference/TcContract.hs b/src/Solcore/Frontend/TypeInference/TcContract.hs index 0a5e3972c..6b1431ad0 100644 --- a/src/Solcore/Frontend/TypeInference/TcContract.hs +++ b/src/Solcore/Frontend/TypeInference/TcContract.hs @@ -1,5 +1,8 @@ module Solcore.Frontend.TypeInference.TcContract where +import Algebra.Graph.AdjacencyMap +import Algebra.Graph.AdjacencyMap.Algorithm +import Algebra.Graph.NonEmpty.AdjacencyMap qualified as NAG import Control.Monad import Control.Monad.Except import Control.Monad.State @@ -8,6 +11,7 @@ import Data.List import Data.List.NonEmpty qualified as N import Data.Map qualified as Map import Data.Maybe +import Data.Set (Set) import Data.Set qualified as Set import Solcore.Frontend.Pretty.ShortName import Solcore.Frontend.Pretty.SolcorePretty @@ -94,6 +98,8 @@ tcTopDeclChecks topDeclChecks = checkSynonymCycles syns let st = buildSynTable syns csExpanded <- everywhereM (mkM (expandTyM st)) cs + checkRecursiveTypes (topLevelDts csExpanded) + mapM_ checkRecursiveTypes (perContractDts csExpanded) mapM_ checkTopDecl (filter isClass csExpanded) mapM_ checkTopDecl (filter (not . isClass) csExpanded) trustImportedDecls csExpanded @@ -120,6 +126,8 @@ tcTopDeclChecks topDeclChecks = TrustTopDeclBody -> withPartialDataTypesDisabled $ pure Nothing + topLevelDts cs' = [d | TDataDef d <- cs'] + perContractDts cs' = [[d | CDataDecl d <- cds] | TContr (Contract _ _ cds) <- cs'] tcTopDecl' d = timeItNamed (shortName d) $ do clearSubst tcTopDecl d @@ -168,6 +176,102 @@ recursiveSynonymError cyclePath = " " ++ intercalate " -> " (map pretty cyclePath) ] +-- check for recursive data types + +allDataTys :: [TopDecl Name] -> [DataTy] +allDataTys = concatMap collect + where + collect (TDataDef d) = [d] + collect (TContr (Contract _ _ cds)) = [d | CDataDecl d <- cds] + collect _ = [] + +tyVarNames :: Ty -> [Name] +tyVarNames (TyVar tv) = [tyvarName tv] +tyVarNames (TyCon _ ts) = concatMap tyVarNames ts +tyVarNames _ = [] + +-- Collect type variable names that appear in non-phantom argument positions. +-- Phantom positions (indices in the map for the head type constructor) are skipped. +nonPhantomVarNames :: Map.Map Name (Set Int) -> Ty -> [Name] +nonPhantomVarNames m (TyCon n args) = + let phantomIdxs = Map.findWithDefault Set.empty n m + in concatMap + ( \(i, arg) -> + if Set.member i phantomIdxs then [] else nonPhantomVarNames m arg + ) + (zip [0 ..] args) +nonPhantomVarNames _ (TyVar v) = [tyvarName v] +nonPhantomVarNames _ _ = [] + +-- Build the phantom-parameter map using fixpoint iteration so that +-- transitively-phantom positions are discovered. A parameter at index i of +-- type T is phantom when it never appears in a non-phantom position across all +-- constructor field types (using the current map to decide what counts as +-- non-phantom). Starting from the empty map and iterating monotonically to a +-- fixpoint ensures that every position that can be proved phantom eventually is. +buildPhantomMap :: [DataTy] -> Map.Map Name (Set Int) +buildPhantomMap dts = fixpoint initial + where + initial = Map.fromList [(dataName dt, Set.empty) | dt <- dts] + + fixpoint m = + let m' = Map.fromList (map (refineEntry m) dts) + in if m == m' then m else fixpoint m' + + refineEntry m (DataTy n params ctors) = + let allFieldTys = concatMap constrTy ctors + isPhantomParam p = + let pName = tyvarName p + in all (\ty -> pName `notElem` nonPhantomVarNames m ty) allFieldTys + phantomIdxs = Set.fromList [i | (i, p) <- zip [0 ..] params, isPhantomParam p] + in (n, phantomIdxs) + +nonPhantomTyNames :: Map.Map Name (Set Int) -> Ty -> [Name] +nonPhantomTyNames phantomMap (TyCon n args) = + n : concatMap processArg (zip [0 ..] args) + where + phantomIdxs = Map.findWithDefault Set.empty n phantomMap + processArg (i, arg) + | Set.member i phantomIdxs = [] + | otherwise = nonPhantomTyNames phantomMap arg +nonPhantomTyNames _ _ = [] + +buildTypeDepsGraph :: Set Name -> [DataTy] -> AdjacencyMap Name +buildTypeDepsGraph userTypes dts = + overlay isolated edged + where + phantomMap = buildPhantomMap dts + isolated = vertices (Set.toList userTypes) + edged = stars [(dataName dt, deps dt) | dt <- dts] + deps (DataTy _ _ ctors) = + nub + . filter (`Set.member` userTypes) + . concatMap (\(Constr _ tys) -> concatMap (nonPhantomTyNames phantomMap) tys) + $ ctors + +checkRecursiveTypes :: [DataTy] -> TcM () +checkRecursiveTypes dts = + case cyclicSccs of + [] -> pure () + (c : _) -> recursiveTypeError (NAG.vertexList1 c) + where + userTypes = Set.fromList (map dataName dts) + graph = buildTypeDepsGraph userTypes dts + cyclicSccs = filter (isCyclic graph) (vertexList (scc graph)) + isCyclic origGraph sccComp = + case N.toList (NAG.vertexList1 sccComp) of + [v] -> hasEdge v v origGraph -- singleton SCC: cyclic only if self-loop + _ -> True -- 2+ vertices: always a mutual cycle + +recursiveTypeError :: N.NonEmpty Name -> TcM a +recursiveTypeError cycleVerts = + throwError $ + unlines + [ "Recursive data type detected:", + " " ++ intercalate ", " (map pretty (N.toList cycleVerts)), + " (Data types must be non-recursive)" + ] + -- setting up pragmas for type checking setupPragmas :: [Pragma] -> TcM () @@ -265,7 +369,7 @@ checkTopDecl _ = pure () tcContract :: Contract Name -> TcM (Contract Id, [(Name, Scheme)]) tcContract c@(Contract n vs cdecls) = - withLocalEnv $ withContractName n $ do + withLocalContractEnv $ withContractName n $ do ctx' <- gets ctx initializeEnv c decls' <- mapM tcDecl' cdecls diff --git a/src/Solcore/Frontend/TypeInference/TcMonad.hs b/src/Solcore/Frontend/TypeInference/TcMonad.hs index e44176cbe..0b7bdee83 100644 --- a/src/Solcore/Frontend/TypeInference/TcMonad.hs +++ b/src/Solcore/Frontend/TypeInference/TcMonad.hs @@ -417,6 +417,18 @@ withLocalEnv ta = putEnv savedCtx pure a +-- Like withLocalEnv but also restores the typeTable, for contract scopes +-- where data type names must not leak between sibling contracts. +withLocalContractEnv :: TcM a -> TcM a +withLocalContractEnv ta = + do + savedCtx <- gets ctx + savedTypes <- gets typeTable + a <- ta + putEnv savedCtx + modify (\env -> env {typeTable = savedTypes}) + pure a + envList :: TcM [(Name, Scheme)] envList = gets (Map.toList . ctx) diff --git a/test/Cases.hs b/test/Cases.hs index 0198a73a1..58a46359b 100644 --- a/test/Cases.hs +++ b/test/Cases.hs @@ -243,6 +243,7 @@ cases = runTestForFile "const.solc" caseFolder, runTestExpectingFailure "const-array.solc" caseFolder, runTestForFile "constructor-weak-args.solc" caseFolder, + runTestForFile "constructors-contract.solc" caseFolder, runTestExpectingFailure "complexproxy.solc" caseFolder, runTestForFile "cyclical-defs.solc" caseFolder, runTestForFile "cyclical-defs-inferred.solc" caseFolder, @@ -263,12 +264,12 @@ cases = runTestExpectingFailure "duplicated-type-name.solc" caseFolder, runTestForFile "DuplicateFun.solc" caseFolder, runTestExpectingFailure "DupFun.solc" caseFolder, - runTestForFile "EitherModule.solc" caseFolder, + runTestExpectingFailure "EitherModule.solc" caseFolder, runTestForFile "empty-asm.solc" caseFolder, runTestExpectingFailure "Enum.solc" caseFolder, runTestExpectingFailure "Eq.solc" caseFolder, runTestForFile "EqQual.solc" caseFolder, - runTestForFile "EvenOdd.solc" caseFolder, + runTestExpectingFailure "EvenOdd.solc" caseFolder, runTestExpectingFailure "Filter.solc" caseFolder, runTestForFile "foo-class.solc" caseFolder, runTestForFile "Foo.solc" caseFolder, @@ -293,8 +294,8 @@ cases = runTestExpectingFailure "joinErr.solc" caseFolder, runTestExpectingFailure "KindTest.solc" caseFolder, runTestExpectingFailure "listeq.solc" caseFolder, - runTestForFile "ListModule.solc" caseFolder, - runTestForFile "listid.solc" caseFolder, + runTestExpectingFailure "ListModule.solc" caseFolder, + runTestExpectingFailure "listid.solc" caseFolder, runTestForFile "Logic.solc" caseFolder, runTestExpectingFailure "mainproxy.solc" caseFolder, runTestForFile "MatchCall.solc" caseFolder, @@ -309,6 +310,7 @@ cases = runTestForFile "monomorphic-require.solc" caseFolder, runTestForFile "morefun.solc" caseFolder, runTestForFile "Mutuals.solc" caseFolder, + runTestForFile "rec-memory.solc" caseFolder, runTestExpectingFailure "nano-desugared.solc" caseFolder, runTestForFile "NegPair.solc" caseFolder, runTestForFile "nid.solc" caseFolder, @@ -323,8 +325,8 @@ cases = runTestExpectingFailure "PairMatch2.solc" caseFolder, -- failing due to missing assign constraint runTestExpectingFailure "patterson-bug.solc" caseFolder, - runTestForFile "Peano.solc" caseFolder, - runTestForFile "PeanoMatch.solc" caseFolder, + runTestExpectingFailure "Peano.solc" caseFolder, + runTestExpectingFailure "PeanoMatch.solc" caseFolder, runTestForFile "polymatch-error.solc" caseFolder, runTestForFile "polymorphic-require.solc" caseFolder, runTestExpectingFailure "pragma_merge_fail_coverage.solc" caseFolder, @@ -336,6 +338,8 @@ cases = runTestForFile "proxy.solc" caseFolder, runTestExpectingFailure "proxy1.solc" caseFolder, runTestForFile "rec.solc" caseFolder, + runTestExpectingFailure "recursive-type-direct.solc" caseFolder, + runTestExpectingFailure "recursive-type-mutual.solc" caseFolder, runTestExpectingFailure "require-annotation-missing-param.solc" caseFolder, runTestExpectingFailure "require-annotation-missing-return.solc" caseFolder, runTestExpectingFailure "require-annotation-missing-both.solc" caseFolder, @@ -375,7 +379,7 @@ cases = runTestExpectingFailure "subject-index.solc" caseFolder, runTestExpectingFailure "subject-reduction.solc" caseFolder, runTestExpectingFailure "subsumption-test.solc" caseFolder, - runTestForFile "super-class.solc" caseFolder, + runTestExpectingFailure "super-class.solc" caseFolder, runTestForFile "super-class-num.solc" caseFolder, runTestForFile "tiamat.solc" caseFolder, runTestForFile "tuple-trick.solc" caseFolder, @@ -423,6 +427,7 @@ cases = runTestForFile "redundant-match.solc" caseFolder, runTestForFile "false-redundant-warning.solc" caseFolder, runTestForFile "proxy-desugar.solc" caseFolder, + runTestForFile "box.solc" caseFolder, runTestForFile "invokable-issue.solc" caseFolder, runTestForFile "td.solc" caseFolder, runTestForFile "bar.solc" caseFolder, @@ -450,7 +455,8 @@ cases = runTestForFile "ltimp.solc" caseFolder, runTestExpectingFailure "class-return-type-miss.solc" caseFolder, runTestExpectingFailure "catenable-err.solc" caseFolder, - runTestForFile "pars.solc" caseFolder + runTestForFile "pars.solc" caseFolder, + runTestExpectingFailure "synonym-example.solc" caseFolder ] where caseFolder = "./test/examples/cases" diff --git a/test/examples/cases/Ackermann.solc b/test/examples/cases/Ackermann.solc index c2181cc01..de305844a 100644 --- a/test/examples/cases/Ackermann.solc +++ b/test/examples/cases/Ackermann.solc @@ -1,4 +1,6 @@ -data Nat = Zero | Succ(Nat) ; +data memory(a) = memory(word); + +data Nat = Zero | Succ(memory(Nat)) ; function foo (x : Nat, y : Nat) -> word { match y, x { diff --git a/test/examples/cases/EitherModule.solc b/test/examples/cases/EitherModule.solc index 347af8baa..49a604e23 100644 --- a/test/examples/cases/EitherModule.solc +++ b/test/examples/cases/EitherModule.solc @@ -1,6 +1,7 @@ contract EitherModule { + data memory(a) = memory(word); data Either(a,b) = Left(a) | Right(b); - data List(a) = Nil | Cons(a,List(a)); + data List(a) = Nil | Cons(a, memory(List(a))); function lefts(xs : List(Either(word,word))) -> List(word) { match xs { diff --git a/test/examples/cases/box.solc b/test/examples/cases/box.solc new file mode 100644 index 000000000..060928a85 --- /dev/null +++ b/test/examples/cases/box.solc @@ -0,0 +1,3 @@ +data memory(a) = memory(word); +data Box(a) = Box(memory(a)); +data Rec = Rec(Box(Rec)); diff --git a/test/examples/cases/constructors-contract.solc b/test/examples/cases/constructors-contract.solc new file mode 100644 index 000000000..6ea5bd923 --- /dev/null +++ b/test/examples/cases/constructors-contract.solc @@ -0,0 +1,9 @@ +contract A { + data T = MkT(U); + data U = MkU(word); +} + +contract B { + data T = MkT2(word); + data U = MkU2(T); +} diff --git a/test/examples/cases/rec-memory.solc b/test/examples/cases/rec-memory.solc new file mode 100644 index 000000000..f3807c4f6 --- /dev/null +++ b/test/examples/cases/rec-memory.solc @@ -0,0 +1,3 @@ +data memory(a) = memory(word); + +data List(a) = Nil | Cons(a, memory(List(a))); diff --git a/test/examples/cases/recursive-type-direct.solc b/test/examples/cases/recursive-type-direct.solc new file mode 100644 index 000000000..ff180526c --- /dev/null +++ b/test/examples/cases/recursive-type-direct.solc @@ -0,0 +1,9 @@ +-- Direct recursion: Nat refers to itself in the Succ constructor. +-- Expected: type checker rejects with "Recursive data type detected". +data Nat = Zero | Succ(Nat); + +contract RecursiveTypeDirect { + function main() -> Nat { + Zero + } +} diff --git a/test/examples/cases/recursive-type-mutual.solc b/test/examples/cases/recursive-type-mutual.solc new file mode 100644 index 000000000..8b4452438 --- /dev/null +++ b/test/examples/cases/recursive-type-mutual.solc @@ -0,0 +1,10 @@ +-- Mutual recursion: Even refers to Odd and Odd refers to Even. +-- Expected: type checker rejects with "Recursive data type detected". +data Even = Zero | SuccE(Odd); +data Odd = SuccO(Even); + +contract RecursiveTypeMutual { + function main() -> Even { + Zero + } +} diff --git a/test/examples/cases/synonym-example.solc b/test/examples/cases/synonym-example.solc new file mode 100644 index 000000000..bae9ff259 --- /dev/null +++ b/test/examples/cases/synonym-example.solc @@ -0,0 +1,3 @@ + +type Ref = T; +data T = Mk(Ref); diff --git a/test/examples/pragmas/coverage.solc b/test/examples/pragmas/coverage.solc index c412dc914..62f2ba545 100644 --- a/test/examples/pragmas/coverage.solc +++ b/test/examples/pragmas/coverage.solc @@ -1,6 +1,8 @@ pragma no-coverage-condition ; -data List(a) = Nil | Cons(a,List(a)); +data memory(a) = memory(word); + +data List(a) = Nil | Cons(a,memory(List(a))); data Bool = True | False ; forall a b c . class a : C(b,c) {}