Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sol-core.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ library
Solcore.Backend.MastEval
Solcore.Backend.Specialise
Solcore.Desugarer.DecisionTreeCompiler
Solcore.Desugarer.DeriveGeneric
Solcore.Desugarer.FieldAccess
Solcore.Desugarer.IfDesugarer
Solcore.Desugarer.IndirectCall
Expand Down
199 changes: 199 additions & 0 deletions src/Solcore/Desugarer/DeriveGeneric.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
module Solcore.Desugarer.DeriveGeneric where

import Data.List (nub)
import Data.List.NonEmpty (toList)
import Solcore.Frontend.Syntax

-- Generate Generic instances for data types

deriveGenericTopDecls :: [DataTy] -> [TopDecl Name] -> Either String [TopDecl Name]
deriveGenericTopDecls localData allDecls
| not (genericClassVisible allDecls) = Right allDecls
| (n : _) <- conflicts = Left (conflictError n)
| otherwise = Right (allDecls ++ newInsts)
where
excluded = pragmaExcluded allDecls
hasInst = existingGenericTypes allDecls
conflicts =
[ dataName dt
| dt <- localData,
dataName dt `elem` hasInst,
dataName dt `notElem` excluded
]
newInsts =
[ TInstDef (buildInstance dt)
| dt <- localData,
not (null (dataConstrs dt)),
dataName dt `notElem` excluded,
dataName dt `notElem` hasInst
]
conflictError n =
"type '"
++ show n
++ "' has a manual Generic instance "
++ "but no 'pragma no-generic-instance-for "
++ show n
++ "'; "
++ "add the pragma to suppress auto-derivation"

genericClassVisible :: [TopDecl Name] -> Bool
genericClassVisible = any isGenericClass
where
isGenericClass (TClassDef cls) = className cls == Name "Generic"
isGenericClass _ = False

collectDataDefs :: [TopDecl Name] -> [DataTy]
collectDataDefs = concatMap go
where
go (TDataDef dt) = [dt]
go (TContr (Contract _ _ ds)) = [dt | CDataDecl dt <- ds]
go _ = []

existingGenericTypes :: [TopDecl Name] -> [Name]
existingGenericTypes = concatMap go
where
go (TInstDef inst)
| instName inst == Name "Generic" = [tyConName (mainTy inst)]
go _ = []
tyConName (TyCon n _) = n
tyConName _ = Name ""

pragmaExcluded :: [TopDecl Name] -> [Name]
pragmaExcluded = nub . concatMap go
where
go (TPragmaDecl (Pragma NoGenericInstanceFor (DisableFor names))) =
toList names
go _ = []

-- SOP representation type

unitTy :: Ty
unitTy = TyCon (Name "()") []

mkProdOf :: [Ty] -> Ty
mkProdOf [] = unitTy
mkProdOf [t] = t
mkProdOf (t : ts) = TyCon (Name "pair") [t, mkProdOf ts]

mkSumOf :: [Ty] -> Ty
mkSumOf [] = unitTy
mkSumOf [t] = t
mkSumOf (t : ts) = TyCon (Name "sum") [t, mkSumOf ts]

constrRep :: Constr -> Ty
constrRep (Constr _ []) = unitTy
constrRep (Constr _ [t]) = t
constrRep (Constr _ ts) = mkProdOf ts

sopRep :: DataTy -> Ty
sopRep dt = mkSumOf (map constrRep (dataConstrs dt))

-- Expression helpers

mkProdExp :: [Exp Name] -> Exp Name
mkProdExp [] = Con (Name "()") []
mkProdExp [e] = e
mkProdExp (e : es) = Con (Name "pair") [e, mkProdExp es]

applyInr :: Int -> Exp Name -> Exp Name
applyInr 0 e = e
applyInr n e = Con (Name "inr") [applyInr (n - 1) e]

wrapSumExp :: Int -> Int -> Exp Name -> Exp Name
wrapSumExp _ 1 inner = inner
wrapSumExp idx total inner
| idx == total - 1 = applyInr (total - 1) inner
| otherwise = applyInr idx (Con (Name "inl") [inner])

pairPat :: Pat Name -> Pat Name -> Pat Name
pairPat p1 p2 = PCon (Name "pair") [p1, p2]

mkProdPat :: [Name] -> Pat Name
mkProdPat [] = PCon (Name "()") []
mkProdPat [v] = PVar v
mkProdPat vs = foldr1 pairPat (map PVar vs)

applyPInr :: Int -> Pat Name -> Pat Name
applyPInr 0 p = p
applyPInr n p = PCon (Name "inr") [applyPInr (n - 1) p]

wrapSumPat :: Int -> Int -> Pat Name -> Pat Name
wrapSumPat _ 1 inner = inner
wrapSumPat idx total inner
| idx == total - 1 = applyPInr (total - 1) inner
| otherwise = applyPInr idx (PCon (Name "inl") [inner])

freshVarNames :: Int -> [Name]
freshVarNames n = [Name ("_gv" ++ show i) | i <- [0 .. n - 1]]

fromClause :: Int -> Int -> Constr -> Equation Name
fromClause idx total (Constr cname tys) =
let vars = freshVarNames (length tys)
pat = PCon cname (map PVar vars)
prodExp = mkProdExp (map Var vars)
sumExp = wrapSumExp idx total prodExp
in ([pat], [Return sumExp])

fromBody :: DataTy -> Body Name
fromBody dt =
let constrs = dataConstrs dt
total = length constrs
in [Match [Var (Name "_x")] (zipWith (\i c -> fromClause i total c) [0 ..] constrs)]

toClause :: Int -> Int -> Constr -> Equation Name
toClause idx total (Constr cname tys) =
let vars = freshVarNames (length tys)
prodPat = mkProdPat vars
sumPat = wrapSumPat idx total prodPat
conExp = Con cname (map Var vars)
in ([sumPat], [Return conExp])

toBody :: DataTy -> Body Name
toBody dt =
let constrs = dataConstrs dt
total = length constrs
in [Match [Var (Name "_r")] (zipWith (\i c -> toClause i total c) [0 ..] constrs)]

buildFrom :: DataTy -> FunDef Name
buildFrom dt = FunDef False sig (fromBody dt)
where
mainT = TyCon (dataName dt) (map TyVar (dataParams dt))
repT = sopRep dt
sig =
Signature
{ sigVars = [],
sigContext = [],
sigName = Name "from",
sigParams = [Typed False (Name "_x") mainT],
sigRetComptime = False,
sigReturn = Just repT,
sigPayable = False
}

buildTo :: DataTy -> FunDef Name
buildTo dt = FunDef False sig (toBody dt)
where
mainT = TyCon (dataName dt) (map TyVar (dataParams dt))
repT = sopRep dt
sig =
Signature
{ sigVars = [],
sigContext = [],
sigName = Name "to",
sigParams = [Typed False (Name "_r") repT],
sigRetComptime = False,
sigReturn = Just mainT,
sigPayable = False
}

buildInstance :: DataTy -> Instance Name
buildInstance dt =
Instance
{ instDefault = False,
instVars = dataParams dt,
instContext = [],
instName = Name "Generic",
paramsTy = [sopRep dt],
mainTy = TyCon (dataName dt) (map TyVar (dataParams dt)),
instFunctions = [buildFrom dt, buildTo dt]
}
16 changes: 12 additions & 4 deletions src/Solcore/Frontend/Parser/Decl.hs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ pragmaP :: Parser Pragma
pragmaP = do
keyword "pragma"
ty <- pragmaTypeP
st <- pragmaStatusP
st <- pragmaStatusForP ty
_ <- semicolon
return (Pragma ty st)

Expand All @@ -195,9 +195,17 @@ pragmaTypeP =
<$ keyword "no-patterson-condition"
<|> NoBoundVariableCondition
<$ keyword "no-bounded-variable-condition"

pragmaStatusP :: Parser PragmaStatus
pragmaStatusP = option DisableAll $ do
<|> NoGenericInstanceFor
<$ keyword "no-generic-instance-for"

-- | Parse the pragma status. For 'NoGenericInstanceFor' a non-empty list of
-- type names is mandatory; for all other pragma types the list is optional and
-- defaults to 'DisableAll'.
pragmaStatusForP :: PragmaType -> Parser PragmaStatus
pragmaStatusForP NoGenericInstanceFor = do
names <- (Name <$> identifier) `sepBy1` comma
return (DisableFor (NE.fromList names))
pragmaStatusForP _ = option DisableAll $ do
names <- (Name <$> identifier) `sepBy1` comma
return (DisableFor (NE.fromList names))

Expand Down
1 change: 1 addition & 0 deletions src/Solcore/Frontend/Pretty/SolcorePretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ instance Pretty PragmaType where
ppr NoBoundVariableCondition = text "no-bounded-variable-condition"
ppr NoCoverageCondition = text "no-coverage-condition"
ppr NoPattersonCondition = text "no-patterson-condition"
ppr NoGenericInstanceFor = text "no-generic-instance-for"

instance Pretty PragmaStatus where
ppr (DisableFor ns) =
Expand Down
1 change: 1 addition & 0 deletions src/Solcore/Frontend/Pretty/TreePretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ instance Pretty PragmaType where
ppr NoBoundVariableCondition = text "no-bounded-variable-condition"
ppr NoCoverageCondition = text "no-coverage-condition"
ppr NoPattersonCondition = text "no-patterson-condition"
ppr NoGenericInstanceFor = text "no-generic-instance-for"

instance Pretty PragmaStatus where
ppr (DisableFor ns) =
Expand Down
1 change: 1 addition & 0 deletions src/Solcore/Frontend/Syntax/Contract.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ data PragmaType
= NoCoverageCondition
| NoPattersonCondition
| NoBoundVariableCondition
| NoGenericInstanceFor
deriving (Eq, Ord, Show, Data, Typeable)

data PragmaStatus
Expand Down
2 changes: 2 additions & 0 deletions src/Solcore/Frontend/Syntax/NameResolution.hs
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ instance Resolve S.PragmaType where
pure NoPattersonCondition
resolve S.NoBoundVariableCondition =
pure NoBoundVariableCondition
resolve S.NoGenericInstanceFor =
pure NoGenericInstanceFor

instance Resolve S.PragmaStatus where
type Result S.PragmaStatus = PragmaStatus
Expand Down
1 change: 1 addition & 0 deletions src/Solcore/Frontend/Syntax/SyntaxTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ data PragmaType
= NoCoverageCondition
| NoPattersonCondition
| NoBoundVariableCondition
| NoGenericInstanceFor
deriving (Eq, Ord, Show, Data, Typeable)

data PragmaStatus
Expand Down
14 changes: 13 additions & 1 deletion src/Solcore/Pipeline/SolcorePipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Solcore.Backend.MastEval (defaultFuel, eliminateDeadCode, evalCompUnit)
import Solcore.Backend.Specialise (specialiseCompUnit)
import Solcore.Desugarer.ContractDispatch (contractDispatchTopDecls)
import Solcore.Desugarer.DecisionTreeCompiler (matchCompiler, showWarning)
import Solcore.Desugarer.DeriveGeneric (deriveGenericTopDecls)
import Solcore.Desugarer.FieldAccess (fieldDesugarTopDecls)
import Solcore.Desugarer.IfDesugarer (ifDesugarer)
import Solcore.Desugarer.IndirectCall (indirectCallTopDecls)
Expand Down Expand Up @@ -301,12 +302,23 @@ prepareInferenceDeclsForTypeInference opts emitOutput imps inferenceDecls = do
putStrLn "> Dispatch:"
putStrLn $ prettyInferenceDecls dispatched

-- Generic instance derivation (only for locally-defined data types)
let localData = [dt | ModuleInferenceDecl ModuleLocalDecl (TDataDef dt) <- dispatched]
derived <-
ExceptT $
runExceptT $
traverseModuleInferenceTopDecls (ExceptT . pure . deriveGenericTopDecls localData) dispatched

liftIO $ when verbose $ do
putStrLn "> Generic instance derivation:"
putStrLn $ prettyInferenceDecls derived

-- SCC analysis
connected <-
ExceptT $
timeItNamed "SCC " $
runExceptT $
traverseModuleInferenceTopDecls (ExceptT . sccAnalysisTopDecls) dispatched
traverseModuleInferenceTopDecls (ExceptT . sccAnalysisTopDecls) derived

liftIO $ when verbose $ do
putStrLn "> SCC Analysis:"
Expand Down
Loading
Loading