Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion src/Solcore/Frontend/TypeInference/TcContract.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module Solcore.Frontend.TypeInference.TcContract where

import Algebra.Graph.AdjacencyMap
import Algebra.Graph.AdjacencyMap.Algorithm
import Algebra.Graph.NonEmpty.AdjacencyMap qualified as NAG
import Control.Monad
import Control.Monad.Except
import Control.Monad.State
Expand All @@ -8,6 +11,7 @@ import Data.List
import Data.List.NonEmpty qualified as N
import Data.Map qualified as Map
import Data.Maybe
import Data.Set (Set)
import Data.Set qualified as Set
import Solcore.Frontend.Pretty.ShortName
import Solcore.Frontend.Pretty.SolcorePretty
Expand Down Expand Up @@ -94,6 +98,8 @@ tcTopDeclChecks topDeclChecks =
checkSynonymCycles syns
let st = buildSynTable syns
csExpanded <- everywhereM (mkM (expandTyM st)) cs
checkRecursiveTypes (topLevelDts csExpanded)
mapM_ checkRecursiveTypes (perContractDts csExpanded)
mapM_ checkTopDecl (filter isClass csExpanded)
mapM_ checkTopDecl (filter (not . isClass) csExpanded)
trustImportedDecls csExpanded
Expand All @@ -120,6 +126,8 @@ tcTopDeclChecks topDeclChecks =
TrustTopDeclBody ->
withPartialDataTypesDisabled $
pure Nothing
topLevelDts cs' = [d | TDataDef d <- cs']
perContractDts cs' = [[d | CDataDecl d <- cds] | TContr (Contract _ _ cds) <- cs']
tcTopDecl' d = timeItNamed (shortName d) $ do
clearSubst
tcTopDecl d
Expand Down Expand Up @@ -168,6 +176,102 @@ recursiveSynonymError cyclePath =
" " ++ intercalate " -> " (map pretty cyclePath)
]

-- check for recursive data types

allDataTys :: [TopDecl Name] -> [DataTy]
allDataTys = concatMap collect
where
collect (TDataDef d) = [d]
collect (TContr (Contract _ _ cds)) = [d | CDataDecl d <- cds]
collect _ = []

tyVarNames :: Ty -> [Name]
tyVarNames (TyVar tv) = [tyvarName tv]
tyVarNames (TyCon _ ts) = concatMap tyVarNames ts
tyVarNames _ = []

-- Collect type variable names that appear in non-phantom argument positions.
-- Phantom positions (indices in the map for the head type constructor) are skipped.
nonPhantomVarNames :: Map.Map Name (Set Int) -> Ty -> [Name]
nonPhantomVarNames m (TyCon n args) =
let phantomIdxs = Map.findWithDefault Set.empty n m
in concatMap
( \(i, arg) ->
if Set.member i phantomIdxs then [] else nonPhantomVarNames m arg
)
(zip [0 ..] args)
nonPhantomVarNames _ (TyVar v) = [tyvarName v]
nonPhantomVarNames _ _ = []

-- Build the phantom-parameter map using fixpoint iteration so that
-- transitively-phantom positions are discovered. A parameter at index i of
-- type T is phantom when it never appears in a non-phantom position across all
-- constructor field types (using the current map to decide what counts as
-- non-phantom). Starting from the empty map and iterating monotonically to a
-- fixpoint ensures that every position that can be proved phantom eventually is.
buildPhantomMap :: [DataTy] -> Map.Map Name (Set Int)
buildPhantomMap dts = fixpoint initial
where
initial = Map.fromList [(dataName dt, Set.empty) | dt <- dts]

fixpoint m =
let m' = Map.fromList (map (refineEntry m) dts)
in if m == m' then m else fixpoint m'

refineEntry m (DataTy n params ctors) =
let allFieldTys = concatMap constrTy ctors
isPhantomParam p =
let pName = tyvarName p
in all (\ty -> pName `notElem` nonPhantomVarNames m ty) allFieldTys
phantomIdxs = Set.fromList [i | (i, p) <- zip [0 ..] params, isPhantomParam p]
in (n, phantomIdxs)
Comment thread
rodrigogribeiro marked this conversation as resolved.

nonPhantomTyNames :: Map.Map Name (Set Int) -> Ty -> [Name]
nonPhantomTyNames phantomMap (TyCon n args) =
n : concatMap processArg (zip [0 ..] args)
where
phantomIdxs = Map.findWithDefault Set.empty n phantomMap
processArg (i, arg)
| Set.member i phantomIdxs = []
| otherwise = nonPhantomTyNames phantomMap arg
nonPhantomTyNames _ _ = []

buildTypeDepsGraph :: Set Name -> [DataTy] -> AdjacencyMap Name
buildTypeDepsGraph userTypes dts =
overlay isolated edged
where
phantomMap = buildPhantomMap dts
isolated = vertices (Set.toList userTypes)
edged = stars [(dataName dt, deps dt) | dt <- dts]
deps (DataTy _ _ ctors) =
nub
. filter (`Set.member` userTypes)
. concatMap (\(Constr _ tys) -> concatMap (nonPhantomTyNames phantomMap) tys)
$ ctors
Comment thread
rodrigogribeiro marked this conversation as resolved.

checkRecursiveTypes :: [DataTy] -> TcM ()
checkRecursiveTypes dts =
case cyclicSccs of
[] -> pure ()
(c : _) -> recursiveTypeError (NAG.vertexList1 c)
where
userTypes = Set.fromList (map dataName dts)
Comment thread
rodrigogribeiro marked this conversation as resolved.
graph = buildTypeDepsGraph userTypes dts
cyclicSccs = filter (isCyclic graph) (vertexList (scc graph))
isCyclic origGraph sccComp =
case N.toList (NAG.vertexList1 sccComp) of
[v] -> hasEdge v v origGraph -- singleton SCC: cyclic only if self-loop
_ -> True -- 2+ vertices: always a mutual cycle

recursiveTypeError :: N.NonEmpty Name -> TcM a
recursiveTypeError cycleVerts =
throwError $
unlines
[ "Recursive data type detected:",
" " ++ intercalate ", " (map pretty (N.toList cycleVerts)),
" (Data types must be non-recursive)"
]

-- setting up pragmas for type checking

setupPragmas :: [Pragma] -> TcM ()
Expand Down Expand Up @@ -265,7 +369,7 @@ checkTopDecl _ = pure ()

tcContract :: Contract Name -> TcM (Contract Id, [(Name, Scheme)])
tcContract c@(Contract n vs cdecls) =
withLocalEnv $ withContractName n $ do
withLocalContractEnv $ withContractName n $ do
ctx' <- gets ctx
initializeEnv c
decls' <- mapM tcDecl' cdecls
Expand Down
12 changes: 12 additions & 0 deletions src/Solcore/Frontend/TypeInference/TcMonad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,18 @@ withLocalEnv ta =
putEnv savedCtx
pure a

-- Like withLocalEnv but also restores the typeTable, for contract scopes
-- where data type names must not leak between sibling contracts.
withLocalContractEnv :: TcM a -> TcM a
withLocalContractEnv ta =
do
savedCtx <- gets ctx
savedTypes <- gets typeTable
a <- ta
putEnv savedCtx
modify (\env -> env {typeTable = savedTypes})
pure a

envList :: TcM [(Name, Scheme)]
envList = gets (Map.toList . ctx)

Expand Down
22 changes: 14 additions & 8 deletions test/Cases.hs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ cases =
runTestForFile "const.solc" caseFolder,
runTestExpectingFailure "const-array.solc" caseFolder,
runTestForFile "constructor-weak-args.solc" caseFolder,
runTestForFile "constructors-contract.solc" caseFolder,
runTestExpectingFailure "complexproxy.solc" caseFolder,
runTestForFile "cyclical-defs.solc" caseFolder,
runTestForFile "cyclical-defs-inferred.solc" caseFolder,
Expand All @@ -263,12 +264,12 @@ cases =
runTestExpectingFailure "duplicated-type-name.solc" caseFolder,
runTestForFile "DuplicateFun.solc" caseFolder,
runTestExpectingFailure "DupFun.solc" caseFolder,
runTestForFile "EitherModule.solc" caseFolder,
runTestExpectingFailure "EitherModule.solc" caseFolder,
runTestForFile "empty-asm.solc" caseFolder,
runTestExpectingFailure "Enum.solc" caseFolder,
runTestExpectingFailure "Eq.solc" caseFolder,
runTestForFile "EqQual.solc" caseFolder,
runTestForFile "EvenOdd.solc" caseFolder,
runTestExpectingFailure "EvenOdd.solc" caseFolder,
runTestExpectingFailure "Filter.solc" caseFolder,
runTestForFile "foo-class.solc" caseFolder,
runTestForFile "Foo.solc" caseFolder,
Expand All @@ -293,8 +294,8 @@ cases =
runTestExpectingFailure "joinErr.solc" caseFolder,
runTestExpectingFailure "KindTest.solc" caseFolder,
runTestExpectingFailure "listeq.solc" caseFolder,
runTestForFile "ListModule.solc" caseFolder,
runTestForFile "listid.solc" caseFolder,
runTestExpectingFailure "ListModule.solc" caseFolder,
runTestExpectingFailure "listid.solc" caseFolder,
runTestForFile "Logic.solc" caseFolder,
runTestExpectingFailure "mainproxy.solc" caseFolder,
runTestForFile "MatchCall.solc" caseFolder,
Expand All @@ -309,6 +310,7 @@ cases =
runTestForFile "monomorphic-require.solc" caseFolder,
runTestForFile "morefun.solc" caseFolder,
runTestForFile "Mutuals.solc" caseFolder,
runTestForFile "rec-memory.solc" caseFolder,
runTestExpectingFailure "nano-desugared.solc" caseFolder,
runTestForFile "NegPair.solc" caseFolder,
runTestForFile "nid.solc" caseFolder,
Expand All @@ -323,8 +325,8 @@ cases =
runTestExpectingFailure "PairMatch2.solc" caseFolder,
-- failing due to missing assign constraint
runTestExpectingFailure "patterson-bug.solc" caseFolder,
runTestForFile "Peano.solc" caseFolder,
runTestForFile "PeanoMatch.solc" caseFolder,
runTestExpectingFailure "Peano.solc" caseFolder,
runTestExpectingFailure "PeanoMatch.solc" caseFolder,
runTestForFile "polymatch-error.solc" caseFolder,
runTestForFile "polymorphic-require.solc" caseFolder,
runTestExpectingFailure "pragma_merge_fail_coverage.solc" caseFolder,
Expand All @@ -336,6 +338,8 @@ cases =
runTestForFile "proxy.solc" caseFolder,
runTestExpectingFailure "proxy1.solc" caseFolder,
runTestForFile "rec.solc" caseFolder,
runTestExpectingFailure "recursive-type-direct.solc" caseFolder,
runTestExpectingFailure "recursive-type-mutual.solc" caseFolder,
runTestExpectingFailure "require-annotation-missing-param.solc" caseFolder,
runTestExpectingFailure "require-annotation-missing-return.solc" caseFolder,
runTestExpectingFailure "require-annotation-missing-both.solc" caseFolder,
Expand Down Expand Up @@ -375,7 +379,7 @@ cases =
runTestExpectingFailure "subject-index.solc" caseFolder,
runTestExpectingFailure "subject-reduction.solc" caseFolder,
runTestExpectingFailure "subsumption-test.solc" caseFolder,
runTestForFile "super-class.solc" caseFolder,
runTestExpectingFailure "super-class.solc" caseFolder,
runTestForFile "super-class-num.solc" caseFolder,
runTestForFile "tiamat.solc" caseFolder,
runTestForFile "tuple-trick.solc" caseFolder,
Expand Down Expand Up @@ -423,6 +427,7 @@ cases =
runTestForFile "redundant-match.solc" caseFolder,
runTestForFile "false-redundant-warning.solc" caseFolder,
runTestForFile "proxy-desugar.solc" caseFolder,
runTestForFile "box.solc" caseFolder,
runTestForFile "invokable-issue.solc" caseFolder,
runTestForFile "td.solc" caseFolder,
runTestForFile "bar.solc" caseFolder,
Expand Down Expand Up @@ -450,7 +455,8 @@ cases =
runTestForFile "ltimp.solc" caseFolder,
runTestExpectingFailure "class-return-type-miss.solc" caseFolder,
runTestExpectingFailure "catenable-err.solc" caseFolder,
runTestForFile "pars.solc" caseFolder
runTestForFile "pars.solc" caseFolder,
runTestExpectingFailure "synonym-example.solc" caseFolder
]
where
caseFolder = "./test/examples/cases"
Expand Down
4 changes: 3 additions & 1 deletion test/examples/cases/Ackermann.solc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
data Nat = Zero | Succ(Nat) ;
data memory(a) = memory(word);

data Nat = Zero | Succ(memory(Nat)) ;

function foo (x : Nat, y : Nat) -> word {
match y, x {
Expand Down
3 changes: 2 additions & 1 deletion test/examples/cases/EitherModule.solc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
contract EitherModule {
data memory(a) = memory(word);
data Either(a,b) = Left(a) | Right(b);
data List(a) = Nil | Cons(a,List(a));
data List(a) = Nil | Cons(a, memory(List(a)));

function lefts(xs : List(Either(word,word))) -> List(word) {
match xs {
Expand Down
3 changes: 3 additions & 0 deletions test/examples/cases/box.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data memory(a) = memory(word);
data Box(a) = Box(memory(a));
data Rec = Rec(Box(Rec));
9 changes: 9 additions & 0 deletions test/examples/cases/constructors-contract.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
contract A {
data T = MkT(U);
data U = MkU(word);
}

contract B {
data T = MkT2(word);
data U = MkU2(T);
}
3 changes: 3 additions & 0 deletions test/examples/cases/rec-memory.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data memory(a) = memory(word);

data List(a) = Nil | Cons(a, memory(List(a)));
9 changes: 9 additions & 0 deletions test/examples/cases/recursive-type-direct.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
-- Direct recursion: Nat refers to itself in the Succ constructor.
-- Expected: type checker rejects with "Recursive data type detected".
data Nat = Zero | Succ(Nat);

contract RecursiveTypeDirect {
function main() -> Nat {
Zero
}
}
10 changes: 10 additions & 0 deletions test/examples/cases/recursive-type-mutual.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- Mutual recursion: Even refers to Odd and Odd refers to Even.
-- Expected: type checker rejects with "Recursive data type detected".
data Even = Zero | SuccE(Odd);
data Odd = SuccO(Even);

contract RecursiveTypeMutual {
function main() -> Even {
Zero
}
}
3 changes: 3 additions & 0 deletions test/examples/cases/synonym-example.solc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

type Ref = T;
data T = Mk(Ref);
4 changes: 3 additions & 1 deletion test/examples/pragmas/coverage.solc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pragma no-coverage-condition ;

data List(a) = Nil | Cons(a,List(a));
data memory(a) = memory(word);

data List(a) = Nil | Cons(a,memory(List(a)));
data Bool = True | False ;

forall a b c . class a : C(b,c) {}
Expand Down
Loading