Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions Compiler/ContractSpec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,15 @@ inductive Stmt
(resultVars : List String) -- local vars to bind return values to
(externalName : String) -- name of the external function declaration
(args : List Expr) -- call arguments
/-- High-level ABI-encoded external call with return value binding.
Compiles to: mstore(selector+args), call/staticcall, revert forwarding, returndatacopy+mload.
Covers the common `call(gas(), target, 0, 0, calldataSize, 0, 32)` + decode pattern. -/
| externalCallWithReturn
(resultVar : String) -- local variable to bind the returned uint256
(target : Expr) -- target contract address
(selector : Nat) -- 4-byte function selector (e.g., 0xa035b1fe)
(args : List Expr) -- ABI-encoded arguments (each occupies 32 bytes)
(isStatic : Bool := false) -- use staticcall instead of call
deriving Repr

structure FunctionSpec where
Expand Down Expand Up @@ -633,6 +642,8 @@ private partial def collectStmtNames : Stmt → List String
collectExprListNames topics ++ collectExprNames dataOffset ++ collectExprNames dataSize
| Stmt.externalCallBind resultVars externalName args =>
resultVars ++ externalName :: collectExprListNames args
| Stmt.externalCallWithReturn resultVar target _ args _ =>
resultVar :: collectExprNames target ++ collectExprListNames args

private partial def collectStmtListNames : List Stmt → List String
| [] => []
Expand Down Expand Up @@ -1098,6 +1109,8 @@ private partial def stmtContainsUnsafeLogicalCallLike : Stmt → Bool
exprContainsUnsafeLogicalCallLike dataSize
| Stmt.externalCallBind _ _ args =>
args.any exprContainsUnsafeLogicalCallLike
| Stmt.externalCallWithReturn _ target _ args _ =>
exprContainsUnsafeLogicalCallLike target || args.any exprContainsUnsafeLogicalCallLike

private partial def staticParamBindingNames (name : String) (ty : ParamType) : List String :=
match ty with
Expand Down Expand Up @@ -1327,6 +1340,14 @@ private partial def validateScopedStmtIdentifiers
| Stmt.externalCallBind resultVars _ args => do
args.forM (validateScopedExprIdentifiers context params paramScope dynamicParams localScope constructorArgCount)
pure (resultVars.reverse ++ localScope)
| Stmt.externalCallWithReturn resultVar target _ args _ => do
validateScopedExprIdentifiers context params paramScope dynamicParams localScope constructorArgCount target
args.forM (validateScopedExprIdentifiers context params paramScope dynamicParams localScope constructorArgCount)
if paramScope.contains resultVar then
throw s!"Compilation error: {context} uses Stmt.externalCallWithReturn with result variable '{resultVar}' that shadows a parameter"
if localScope.contains resultVar then
throw s!"Compilation error: {context} uses Stmt.externalCallWithReturn with result variable '{resultVar}' that redeclares an existing local variable"
pure (resultVar :: localScope)
Comment thread
cursor[bot] marked this conversation as resolved.
| _ => pure localScope

private partial def validateScopedStmtListIdentifiers
Expand Down Expand Up @@ -1563,6 +1584,8 @@ private partial def stmtWritesState : Stmt → Bool
args.any exprWritesState || true
| Stmt.externalCallBind _ _ args =>
args.any exprWritesState || true
| Stmt.externalCallWithReturn _ target _ args isStatic =>
exprWritesState target || args.any exprWritesState || !isStatic
where
exprWritesState (expr : Expr) : Bool :=
match expr with
Expand Down Expand Up @@ -1639,6 +1662,8 @@ private partial def stmtReadsStateOrEnv : Stmt → Bool
args.any exprReadsStateOrEnv || true
| Stmt.externalCallBind _ _ args =>
args.any exprReadsStateOrEnv || true
| Stmt.externalCallWithReturn _ target _ args _ =>
exprReadsStateOrEnv target || args.any exprReadsStateOrEnv || true

private def validateFunctionSpec (spec : FunctionSpec) : Except String Unit := do
if spec.isPayable && (spec.isView || spec.isPure) then
Expand Down Expand Up @@ -2233,6 +2258,9 @@ private partial def validateInteropStmt (context : String) : Stmt → Except Str
args.forM (validateInteropExpr context)
| Stmt.externalCallBind _ _ args =>
args.forM (validateInteropExpr context)
| Stmt.externalCallWithReturn _ target _ args _ => do
validateInteropExpr context target
args.forM (validateInteropExpr context)
| Stmt.returnValues values =>
values.forM (validateInteropExpr context)
| Stmt.rawLog topics dataOffset dataSize => do
Expand Down Expand Up @@ -2460,6 +2488,9 @@ private partial def validateInternalCallShapesInStmt
validateInternalCallShapesInExpr functions callerName dataSize
| Stmt.externalCallBind _resultVars _ args =>
args.forM (validateInternalCallShapesInExpr functions callerName)
| Stmt.externalCallWithReturn _ target _ args _ => do
validateInternalCallShapesInExpr functions callerName target
args.forM (validateInternalCallShapesInExpr functions callerName)
| _ =>
pure ()

Expand Down Expand Up @@ -2606,6 +2637,9 @@ private partial def validateExternalCallTargetsInStmt
else
checkDuplicateVars (name :: seen) rest
checkDuplicateVars [] resultVars
| Stmt.externalCallWithReturn _ target _ args _ => do
validateExternalCallTargetsInExpr externals context target
args.forM (validateExternalCallTargetsInExpr externals context)
| Stmt.returnValues values =>
values.forM (validateExternalCallTargetsInExpr externals context)
| Stmt.rawLog topics dataOffset dataSize => do
Expand Down Expand Up @@ -3751,6 +3785,52 @@ def compileStmt (fields : List Field) (events : List EventDef := [])
| Stmt.externalCallBind resultVars externalName args => do
let argExprs ← compileExprList fields dynamicSource args
pure [YulStmt.letMany resultVars (YulExpr.call externalName argExprs)]
| Stmt.externalCallWithReturn resultVar target selector args isStatic => do
let targetExpr ← compileExpr fields dynamicSource target
let argCompiledExprs ← compileExprList fields dynamicSource args
-- Step 1: store selector (left-shifted 224 bits) at memory offset 0
let selectorExpr := YulExpr.call "shl" [YulExpr.lit 224, YulExpr.hex selector]
let storeSelector := YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit 0, selectorExpr])
-- Step 2: store each arg at offsets 4, 36, 68, ...
let storeArgs := argCompiledExprs.zipIdx.map fun (argExpr, i) =>
YulStmt.expr (YulExpr.call "mstore" [YulExpr.lit (4 + i * 32), argExpr])
-- Step 3: perform call/staticcall
let calldataSize := 4 + args.length * 32
let callExpr :=
if isStatic then
YulExpr.call "staticcall" [
YulExpr.call "gas" [],
targetExpr,
YulExpr.lit 0, YulExpr.lit calldataSize,
YulExpr.lit 0, YulExpr.lit 32
]
else
YulExpr.call "call" [
YulExpr.call "gas" [],
targetExpr,
YulExpr.lit 0,
YulExpr.lit 0, YulExpr.lit calldataSize,
YulExpr.lit 0, YulExpr.lit 32
]
let letSuccess := YulStmt.let_ "__ecwr_success" callExpr
-- Step 4: revert forwarding on failure
let revertBlock := YulStmt.if_ (YulExpr.call "iszero" [YulExpr.ident "__ecwr_success"]) [
YulStmt.let_ "__ecwr_rds" (YulExpr.call "returndatasize" []),
YulStmt.expr (YulExpr.call "returndatacopy" [YulExpr.lit 0, YulExpr.lit 0, YulExpr.ident "__ecwr_rds"]),
YulStmt.expr (YulExpr.call "revert" [YulExpr.lit 0, YulExpr.ident "__ecwr_rds"])
]
-- Step 5: validate return data size ≥ 32
let sizeCheck := YulStmt.if_ (YulExpr.call "lt" [YulExpr.call "returndatasize" [], YulExpr.lit 32]) [
YulStmt.expr (YulExpr.call "revert" [YulExpr.lit 0, YulExpr.lit 0])
]
-- Wrap call + checks in a block so __ecwr_success is block-scoped.
-- This avoids duplicate let declarations when multiple externalCallWithReturn
-- statements appear in the same function body.
let callBlock := YulStmt.block ([storeSelector] ++ storeArgs ++ [letSuccess, revertBlock, sizeCheck])
-- Step 6: extract return value outside the block (call already copied returndata to memory[0..32])
-- resultVar is flat-scoped so subsequent statements can reference it.
let bindResult := YulStmt.let_ resultVar (YulExpr.call "mload" [YulExpr.lit 0])
pure [callBlock, bindResult]
| Stmt.returnValues values => do
if isInternal then
if values.length != internalRetNames.length then
Expand Down Expand Up @@ -4019,6 +4099,7 @@ private partial def collectStmtBindNames : Stmt → List String
varName :: collectStmtListBindNames body
| Stmt.internalCallAssign names _ _ => names
| Stmt.externalCallBind resultVars _ _ => resultVars
| Stmt.externalCallWithReturn resultVar _ _ _ _ => [resultVar]
| _ => []

private partial def collectStmtListBindNames : List Stmt → List String
Expand Down Expand Up @@ -4091,6 +4172,8 @@ private partial def stmtUsesArrayElement : Stmt → Bool
topics.any exprUsesArrayElement || exprUsesArrayElement dataOffset || exprUsesArrayElement dataSize
| Stmt.externalCallBind _ _ args =>
args.any exprUsesArrayElement
| Stmt.externalCallWithReturn _ target _ args _ =>
exprUsesArrayElement target || args.any exprUsesArrayElement
| _ => false

private def functionUsesArrayElement (fn : FunctionSpec) : Bool :=
Expand Down
213 changes: 213 additions & 0 deletions Compiler/ContractSpecFeatureTest.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4887,6 +4887,51 @@ private def externalCallBindSpec : ContractSpec := {
]
}

-- Stmt.externalCallWithReturn: ABI-encoded external call with return (#926)

-- Test: externalCallWithReturn compiles to mstore+call+returndatacopy pattern
private def externalCallWithReturnSpec : ContractSpec := {
name := "ExternalCallWithReturn"
fields := []
constructor := none
functions := [
-- Test 1: simple staticcall with no args (oracle price feed pattern)
{ name := "getPrice"
params := [{ name := "oracle", ty := ParamType.address }]
returnType := some FieldType.uint256
body := [
Stmt.externalCallWithReturn "price" (Expr.param "oracle") 0xa035b1fe [] (isStatic := true),
Stmt.return (Expr.localVar "price")
]
},
-- Test 2: call with args (ERC20 balanceOf pattern)
{ name := "getBalance"
params := [
{ name := "token", ty := ParamType.address },
{ name := "account", ty := ParamType.address }
]
returnType := some FieldType.uint256
body := [
Stmt.externalCallWithReturn "bal" (Expr.param "token") 0x70a08231 [Expr.param "account"],
Stmt.return (Expr.localVar "bal")
]
},
-- Test 3: call with multiple args (IRM borrowRate pattern)
{ name := "getBorrowRate"
params := [
{ name := "irm", ty := ParamType.address },
{ name := "a", ty := ParamType.uint256 },
{ name := "b", ty := ParamType.uint256 }
]
returnType := some FieldType.uint256
body := [
Stmt.externalCallWithReturn "rate" (Expr.param "irm") 0x9451fed4 [Expr.param "a", Expr.param "b"],
Stmt.return (Expr.localVar "rate")
]
}
]
}

-- Test: single return externalCallBind compiles correctly
#eval! do
match compile externalCallBindSpec [1, 2, 3] with
Expand Down Expand Up @@ -5012,4 +5057,172 @@ private def externalCallBindSpec : ContractSpec := {
| .ok _ =>
throw (IO.userError "✗ externalCallBind should have rejected duplicate vars")

-- Test: staticcall with no args (oracle pattern)
#eval! do
match compile externalCallWithReturnSpec [1, 2, 3] with
| .error err =>
throw (IO.userError s!"✗ externalCallWithReturn spec compile failed: {err}")
| .ok ir =>
let rendered := Yul.render (emitYul ir)
-- Should use shl(224, selector) for selector encoding
assertContains "externalCallWithReturn staticcall selector" rendered
["shl(224, 0xa035b1fe)"]
-- Should use staticcall (not call) since isStatic=true
assertContains "externalCallWithReturn uses staticcall" rendered
["staticcall(gas(),"]
-- Should have revert forwarding
assertContains "externalCallWithReturn revert forwarding" rendered
["iszero(__ecwr_success)", "returndatacopy(0, 0, __ecwr_rds)", "revert(0, __ecwr_rds)"]
-- Should validate returndata size
assertContains "externalCallWithReturn size check" rendered
["lt(returndatasize(), 32)"]
-- Should extract return value (no redundant returndatacopy — call already copied to memory)
assertContains "externalCallWithReturn return extraction" rendered
["let price := mload(0)"]
-- Should NOT have a redundant returndatacopy(0, 0, 32) outside the revert block
IO.println s!"✓ externalCallWithReturn staticcall compiles correctly"

-- Test: call with one arg (balanceOf pattern)
#eval! do
match compile externalCallWithReturnSpec [1, 2, 3] with
| .error err =>
throw (IO.userError s!"✗ externalCallWithReturn spec compile failed: {err}")
| .ok ir =>
let rendered := Yul.render (emitYul ir)
-- Should use shl(224, selector) for balanceOf selector
assertContains "externalCallWithReturn call selector" rendered
["shl(224, 0x70a08231)"]
-- Should store arg at offset 4
assertContains "externalCallWithReturn arg encoding" rendered
["mstore(4,"]
-- Should use call (not staticcall) since isStatic=false
assertContains "externalCallWithReturn uses call" rendered
["call(gas(),"]
-- Should extract result to bal
assertContains "externalCallWithReturn bal binding" rendered
["let bal := mload(0)"]
IO.println s!"✓ externalCallWithReturn call with args compiles correctly"

-- Test: call with multiple args (IRM pattern)
#eval! do
match compile externalCallWithReturnSpec [1, 2, 3] with
| .error err =>
throw (IO.userError s!"✗ externalCallWithReturn spec compile failed: {err}")
| .ok ir =>
let rendered := Yul.render (emitYul ir)
-- Should store two args at offsets 4 and 36
assertContains "externalCallWithReturn multi-arg encoding" rendered
["shl(224, 0x9451fed4)", "mstore(4,", "mstore(36,"]
-- Calldata size should be 4 + 2*32 = 68
assertContains "externalCallWithReturn calldata size" rendered
["call(gas(),"]
-- Should extract result to rate
assertContains "externalCallWithReturn rate binding" rendered
["let rate := mload(0)"]
IO.println s!"✓ externalCallWithReturn multi-arg call compiles correctly"

-- Test: externalCallWithReturn rejects result variable shadowing a parameter
#eval! do
let shadowSpec : ContractSpec := {
name := "ShadowParam"
fields := []
constructor := none
functions := [
{ name := "bad"
params := [{ name := "oracle", ty := ParamType.address }]
returnType := some FieldType.uint256
body := [
Stmt.externalCallWithReturn "oracle" (Expr.param "oracle") 0xa035b1fe [] (isStatic := true),
Stmt.return (Expr.localVar "oracle")
]
}
]
}
match compile shadowSpec [1] with
| .error err =>
if contains err "shadows a parameter" then
IO.println s!"✓ externalCallWithReturn rejects parameter shadow: {err}"
else
throw (IO.userError s!"✗ externalCallWithReturn wrong error: {err}")
| .ok _ =>
throw (IO.userError "✗ externalCallWithReturn should have rejected parameter shadow")

-- Test: externalCallWithReturn rejects redeclaring existing local variable
#eval! do
let redeclareSpec : ContractSpec := {
name := "RedeclareLocal"
fields := []
constructor := none
functions := [
{ name := "bad"
params := [{ name := "oracle", ty := ParamType.address }]
returnType := some FieldType.uint256
body := [
Stmt.letVar "price" (Expr.literal 0),
Stmt.externalCallWithReturn "price" (Expr.param "oracle") 0xa035b1fe [] (isStatic := true),
Stmt.return (Expr.localVar "price")
]
}
]
}
match compile redeclareSpec [1] with
| .error err =>
if contains err "redeclares an existing local variable" then
IO.println s!"✓ externalCallWithReturn rejects local redeclaration: {err}"
else
throw (IO.userError s!"✗ externalCallWithReturn wrong error: {err}")
| .ok _ =>
throw (IO.userError "✗ externalCallWithReturn should have rejected redeclaration")

-- Test: staticcall external call allows view mutability
#eval! do
let viewSpec : ContractSpec := {
name := "ViewStaticCall"
fields := []
constructor := none
functions := [
{ name := "getPrice"
params := [{ name := "oracle", ty := ParamType.address }]
returnType := some FieldType.uint256
isView := true
body := [
Stmt.externalCallWithReturn "price" (Expr.param "oracle") 0xa035b1fe [] (isStatic := true),
Stmt.return (Expr.localVar "price")
]
}
]
}
match compile viewSpec [1] with
| .error err =>
throw (IO.userError s!"✗ view staticcall should compile: {err}")
| .ok _ =>
IO.println "✓ externalCallWithReturn staticcall accepted for view function"

-- Test: multiple externalCallWithReturn in same function (no duplicate let collision)
#eval! do
let multiCallSpec : ContractSpec := {
name := "MultiExternalCall"
fields := []
constructor := none
functions := [
{ name := "getPrices"
params := [{ name := "oracle1", ty := ParamType.address }, { name := "oracle2", ty := ParamType.address }]
returnType := none
body := [
Stmt.externalCallWithReturn "price1" (Expr.param "oracle1") 0xa035b1fe [] (isStatic := true),
Stmt.externalCallWithReturn "price2" (Expr.param "oracle2") 0xa035b1fe [] (isStatic := true),
Stmt.stop
]
}
]
}
match compile multiCallSpec [1] with
| .error err =>
throw (IO.userError s!"✗ multiple externalCallWithReturn should compile: {err}")
| .ok ir =>
let rendered := Yul.render (emitYul ir)
assertContains "multi externalCallWithReturn both bindings" rendered
["let price1 := mload(0)", "let price2 := mload(0)"]
IO.println "✓ multiple externalCallWithReturn in same function compiles without collision"

end Compiler.ContractSpecFeatureTest
Loading
Loading