Skip to content

Commit 4cf3cd7

Browse files
authored
Translate function pointers and lambda (#7)
Translate function pointers and non-capturing lambdas as Option<fn> in unsafe and FnPtr in refcount. * Address is computed inside using the `FnAddr` trait * Record original function pointer + current cast * In unsafe function pointers + casts are translated using: `Option<fn>` + `std::mem::transmute` * In refcount function pointers + casts are translated using: `FnPtr<fn>` + `FnPtr<fn>::cast`
1 parent 2bf5df6 commit 4cf3cd7

66 files changed

Lines changed: 2978 additions & 120 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cpp2rust/converter/converter.cpp

Lines changed: 107 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,15 @@ bool Converter::VisitRecordType(clang::RecordType *type) {
136136
auto *decl = type->getDecl();
137137
if (auto lambda = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
138138
if (lambda->isLambda()) {
139-
auto call_op = lambda->getLambdaCallOperator();
140-
StrCat("Rc<dyn Fn(");
141-
for (auto p : call_op->parameters()) {
142-
StrCat(std::format("{},", ToStringBase(p->getType())));
143-
}
144-
StrCat(")");
145-
if (!call_op->getReturnType()->isVoidType()) {
146-
StrCat("->");
147-
StrCat(ToStringBase(call_op->getReturnType()));
139+
if (in_function_formals_) {
140+
StrCat(
141+
ConvertFunctionPointerType(lambda->getLambdaCallOperator()
142+
->getType()
143+
->getAs<clang::FunctionProtoType>(),
144+
FnProtoType::LambdaCallOperator));
145+
} else {
146+
StrCat("_");
148147
}
149-
StrCat(">");
150148
return false;
151149
}
152150
}
@@ -228,24 +226,25 @@ bool Converter::VisitLValueReferenceType(clang::LValueReferenceType *type) {
228226
return Convert(pointee_type);
229227
}
230228

231-
void Converter::ConvertFunctionPointerType(clang::PointerType *type) {
232-
auto proto = type->getPointeeType()->getAs<clang::FunctionProtoType>();
233-
assert(proto && "Type should be a function prototype");
234-
235-
StrCat("Rc<dyn Fn(");
229+
std::string
230+
Converter::ConvertFunctionPointerType(const clang::FunctionProtoType *proto,
231+
FnProtoType kind) {
232+
std::string result =
233+
(kind == FnProtoType::LambdaCallOperator ? "impl Fn(" : "fn(");
236234
for (auto p_ty : proto->param_types()) {
237-
StrCat(std::format("{},", ToString(p_ty)));
235+
result += ToString(p_ty) + ",";
238236
}
239-
StrCat(")");
237+
result += ")";
240238
if (!proto->getReturnType()->isVoidType()) {
241-
StrCat(std::format("-> {}", ToString(proto->getReturnType())));
239+
result += std::format(" -> {}", ToString(proto->getReturnType()));
242240
}
243-
StrCat(">");
241+
return result;
244242
}
245243

246244
bool Converter::VisitPointerType(clang::PointerType *type) {
247-
if (type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
248-
ConvertFunctionPointerType(type);
245+
if (auto proto = type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
246+
StrCat(std::format("Option<{} {}>", keyword_unsafe_,
247+
ConvertFunctionPointerType(proto)));
249248
return false;
250249
}
251250

@@ -429,6 +428,9 @@ bool Converter::ConvertVarDeclSkipInit(clang::VarDecl *decl) {
429428
}
430429

431430
bool Converter::ConvertLambdaVarDecl(clang::VarDecl *decl) {
431+
if (decl->getType()->isFunctionPointerType()) {
432+
return false;
433+
}
432434
if (decl->hasInit()) {
433435
if (clang::isa<clang::LambdaExpr>(
434436
decl->getInit()->IgnoreUnlessSpelledInSource())) {
@@ -1383,6 +1385,17 @@ bool Converter::VisitCallExpr(clang::CallExpr *expr) {
13831385
return false;
13841386
}
13851387

1388+
void Converter::EmitFnPtrCall(clang::Expr *callee) {
1389+
StrCat(token::kOpenParen);
1390+
Convert(callee);
1391+
StrCat(").unwrap()");
1392+
}
1393+
1394+
void Converter::ConvertFunctionToFunctionPointer(
1395+
const clang::FunctionDecl *fn_decl) {
1396+
StrCat(std::format("Some({})", GetNamedDeclAsString(fn_decl)));
1397+
}
1398+
13861399
void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
13871400
clang::Expr *callee = expr->getCallee();
13881401
auto convert_param_ty = [&](clang::QualType param_type, clang::Expr *expr) {
@@ -1405,7 +1418,8 @@ void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
14051418
StrCat(token::kOpenParen);
14061419
StrCat(keyword_unsafe_);
14071420
StrCat(token::kOpenCurlyBracket);
1408-
const auto *function = expr->getCalleeDecl()->getAsFunction();
1421+
const auto *function =
1422+
expr->getCalleeDecl() ? expr->getCalleeDecl()->getAsFunction() : nullptr;
14091423
const clang::FunctionProtoType *proto = nullptr;
14101424

14111425
if (!function) {
@@ -1453,7 +1467,12 @@ void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
14531467
}
14541468
}
14551469

1456-
Convert(callee);
1470+
if (proto && !function) {
1471+
EmitFnPtrCall(callee);
1472+
} else {
1473+
PushExprKind push(*this, ExprKind::Callee);
1474+
Convert(callee);
1475+
}
14571476
StrCat(token::kOpenParen);
14581477
for (unsigned i = 0; i < num_named_params && i < num_args; ++i) {
14591478
auto *arg = expr->getArg(i + arg_begin);
@@ -1674,7 +1693,15 @@ bool Converter::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
16741693
break;
16751694
}
16761695
case clang::CastKind::CK_FunctionToPointerDecay:
1677-
case clang::CastKind::CK_BuiltinFnToFnPtr:
1696+
case clang::CastKind::CK_BuiltinFnToFnPtr: {
1697+
if (isCallee()) {
1698+
Convert(sub_expr);
1699+
} else {
1700+
PushExprKind push(*this, ExprKind::AddrOf);
1701+
Convert(sub_expr);
1702+
}
1703+
break;
1704+
}
16781705
case clang::CastKind::CK_ConstructorConversion:
16791706
case clang::CastKind::CK_DerivedToBase:
16801707
Convert(sub_expr);
@@ -1698,7 +1725,11 @@ bool Converter::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
16981725
ConvertEqualsNullPtr(sub_expr);
16991726
break;
17001727
case clang::CastKind::CK_NullToPointer:
1701-
StrCat(keyword_default_);
1728+
if (type->isFunctionPointerType()) {
1729+
StrCat("None");
1730+
} else {
1731+
StrCat(keyword_default_);
1732+
}
17021733
computed_expr_type_ = ComputedExprType::FreshPointer;
17031734
break;
17041735
default:
@@ -1743,6 +1774,17 @@ bool Converter::VisitExplicitCastExpr(clang::ExplicitCastExpr *expr) {
17431774
if (expr->getType() == sub_expr->getType()) {
17441775
return Convert(sub_expr);
17451776
}
1777+
if (type->isFunctionPointerType() ||
1778+
sub_expr->getType()->isFunctionPointerType()) {
1779+
StrCat("std::mem::transmute::<");
1780+
Convert(sub_expr->getType());
1781+
StrCat(",");
1782+
Convert(type);
1783+
StrCat(">(");
1784+
Convert(sub_expr);
1785+
StrCat(")");
1786+
return false;
1787+
}
17461788
StrCat(token::kOpenParen);
17471789
Convert(sub_expr);
17481790
if (auto *unary_oper = clang::dyn_cast<clang::UnaryOperator>(sub_expr);
@@ -1969,12 +2011,12 @@ bool Converter::VisitConditionalOperator(clang::ConditionalOperator *expr) {
19692011
StrCat(keyword::kIf);
19702012
Convert(expr->getCond());
19712013
StrCat(token::kOpenCurlyBracket);
1972-
if (expr->isLValue() && !isRValue()) {
2014+
if (expr->isLValue() && !isRValue() && !expr->getType()->isFunctionType()) {
19732015
StrCat(token::kRef, keyword_mut_);
19742016
}
19752017
Convert(expr->getTrueExpr());
19762018
StrCat(token::kCloseCurlyBracket, keyword::kElse, token::kOpenCurlyBracket);
1977-
if (expr->isLValue() && !isRValue()) {
2019+
if (expr->isLValue() && !isRValue() && !expr->getType()->isFunctionType()) {
19782020
StrCat(token::kRef, keyword_mut_);
19792021
}
19802022
Convert(expr->getFalseExpr());
@@ -2028,35 +2070,23 @@ bool Converter::VisitDeclRefExpr(clang::DeclRefExpr *expr) {
20282070
return false;
20292071
}
20302072

2031-
if (auto function = clang::dyn_cast<clang::FunctionDecl>(decl)) {
2073+
if (auto *fn_decl = clang::dyn_cast<clang::FunctionDecl>(decl)) {
20322074
if (isAddrOf()) {
2033-
// Wrap unsafe function in safe closure because the Fn trait only accepts
2034-
// safe functions
2035-
std::string arguments;
2036-
for (unsigned i = 0; i < function->getNumParams(); ++i) {
2037-
arguments += (i ? ", a" : "a") + std::to_string(i);
2038-
}
2039-
StrCat("Rc::new", token::kOpenParen);
2040-
StrCat(std::format("|{}|", arguments));
2041-
StrCat(keyword_unsafe_, token::kOpenCurlyBracket);
2042-
StrCat(str);
2043-
StrCat(token::kOpenParen);
2044-
StrCat(arguments);
2045-
StrCat(token::kCloseParen);
2046-
StrCat(token::kCloseCurlyBracket);
2047-
StrCat(token::kCloseParen);
2075+
ConvertFunctionToFunctionPointer(fn_decl);
20482076
return false;
20492077
}
20502078
}
20512079

20522080
if (auto var_decl = clang::dyn_cast<clang::VarDecl>(decl)) {
2053-
if (auto init = var_decl->getInit()) {
2054-
if (auto lambda = clang::dyn_cast<clang::LambdaExpr>(
2055-
init->IgnoreUnlessSpelledInSource())) {
2056-
StrCat(token::kOpenParen);
2057-
VisitLambdaExpr(lambda);
2058-
StrCat(token::kCloseParen);
2059-
return false;
2081+
if (!var_decl->getType()->isFunctionPointerType()) {
2082+
if (auto init = var_decl->getInit()) {
2083+
if (auto lambda = clang::dyn_cast<clang::LambdaExpr>(
2084+
init->IgnoreUnlessSpelledInSource())) {
2085+
StrCat(token::kOpenParen);
2086+
VisitLambdaExpr(lambda);
2087+
StrCat(token::kCloseParen);
2088+
return false;
2089+
}
20602090
}
20612091
}
20622092
}
@@ -2515,6 +2545,9 @@ bool Converter::VisitCXXDefaultArgExpr(clang::CXXDefaultArgExpr *expr) {
25152545
}
25162546

25172547
bool Converter::VisitLambdaExpr(clang::LambdaExpr *expr) {
2548+
if (isAddrOf() && expr->capture_size() == 0) {
2549+
StrCat("Some");
2550+
}
25182551
StrCat(token::kOpenParen);
25192552
StrCat("|");
25202553
for (auto p : expr->getLambdaClass()->getLambdaCallOperator()->parameters()) {
@@ -2642,19 +2675,6 @@ bool Converter::VisitCXXStdInitializerListExpr(
26422675
return false;
26432676
}
26442677

2645-
std::string
2646-
Converter::GetFunctionPointerDefaultAsString(clang::QualType qual_type) {
2647-
std::string ret;
2648-
auto proto = qual_type->getPointeeType()->getAs<clang::FunctionProtoType>();
2649-
assert(proto);
2650-
ret = "Rc::new(|";
2651-
for (unsigned i = 0; i < proto->getNumParams(); ++i) {
2652-
ret += "_,";
2653-
}
2654-
ret += R"(| { panic!("ub: uninit function pointer") }))";
2655-
return ret;
2656-
}
2657-
26582678
std::string Converter::GetDefaultAsString(clang::QualType qual_type) {
26592679
if (IsVaListType(qual_type)) {
26602680
computed_expr_type_ = ComputedExprType::FreshValue;
@@ -2663,7 +2683,7 @@ std::string Converter::GetDefaultAsString(clang::QualType qual_type) {
26632683

26642684
if (qual_type->isPointerType()) {
26652685
if (qual_type->getPointeeType()->isFunctionType()) {
2666-
return GetFunctionPointerDefaultAsString(qual_type);
2686+
return "None";
26672687
} else {
26682688
computed_expr_type_ = ComputedExprType::FreshPointer;
26692689
return keyword_default_;
@@ -2811,6 +2831,16 @@ void Converter::ConvertVarInit(clang::QualType qual_type, clang::Expr *expr) {
28112831
StrCat(keyword_mut_);
28122832
}
28132833
}
2834+
if (qual_type->isFunctionPointerType()) {
2835+
if (auto *lambda = clang::dyn_cast<clang::LambdaExpr>(
2836+
expr->IgnoreUnlessSpelledInSource())) {
2837+
PushExprKind push(*this, ExprKind::AddrOf);
2838+
curr_init_type_.push(qual_type);
2839+
VisitLambdaExpr(lambda);
2840+
curr_init_type_.pop();
2841+
return;
2842+
}
2843+
}
28142844
auto *ignore_casts = expr->IgnoreCasts();
28152845
// FIXME: this looks very complicated
28162846
if (auto *ctor = clang::dyn_cast<clang::CXXConstructExpr>(ignore_casts);
@@ -2856,7 +2886,8 @@ void Converter::ConvertUnsignedArithOperand(clang::Expr *expr,
28562886
void Converter::ConvertEqualsNullPtr(clang::Expr *expr) {
28572887
StrCat("(");
28582888
Convert(expr);
2859-
if (IsUniquePtr(expr->getType())) {
2889+
if (IsUniquePtr(expr->getType()) ||
2890+
expr->getType()->isFunctionPointerType()) {
28602891
StrCat(").is_none()");
28612892
} else {
28622893
StrCat(").is_null()");
@@ -3240,6 +3271,13 @@ void Converter::PlaceholderCtx::dump() const {
32403271

32413272
std::string Converter::ConvertPlaceholder(clang::Expr *expr, clang::Expr *arg,
32423273
const PlaceholderCtx &ph_ctx) {
3274+
if (arg->getType()->isFunctionPointerType()) {
3275+
PushExprKind push(*this, ExprKind::Callee);
3276+
Buffer buf(*this);
3277+
Convert(arg);
3278+
return std::move(buf).str();
3279+
}
3280+
32433281
if (ph_ctx.needs_materialization()) {
32443282
auto materialized = ph_ctx.materialize_ctx->GetOrMaterialize(
32453283
static_cast<unsigned>(ph_ctx.materialize_idx),
@@ -3387,6 +3425,10 @@ bool Converter::isVoid() const {
33873425
return curr_expr_kind_.empty() || curr_expr_kind_.back() == ExprKind::Void;
33883426
}
33893427

3428+
bool Converter::isCallee() const {
3429+
return !curr_expr_kind_.empty() && curr_expr_kind_.back() == ExprKind::Callee;
3430+
}
3431+
33903432
void Converter::SetFresh() {
33913433
switch (computed_expr_type_) {
33923434
case ComputedExprType::Value:

cpp2rust/converter/converter.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
6161

6262
virtual bool VisitPointerType(clang::PointerType *type);
6363

64-
void ConvertFunctionPointerType(clang::PointerType *type);
64+
enum class FnProtoType { LambdaCallOperator, FnPtr };
65+
66+
virtual std::string
67+
ConvertFunctionPointerType(const clang::FunctionProtoType *proto,
68+
FnProtoType kind = FnProtoType::FnPtr);
6569

6670
virtual bool VisitDecayedType(clang::DecayedType *type);
6771

@@ -201,6 +205,11 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
201205

202206
void ConvertGenericCallExpr(clang::CallExpr *expr);
203207

208+
virtual void EmitFnPtrCall(clang::Expr *callee);
209+
210+
virtual void
211+
ConvertFunctionToFunctionPointer(const clang::FunctionDecl *fn_decl);
212+
204213
virtual void ConvertPrintf(clang::CallExpr *expr);
205214

206215
void ConvertVAArgCall(clang::CallExpr *expr);
@@ -334,8 +343,6 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
334343
virtual bool Convert(clang::Stmt *stmt);
335344
virtual bool Convert(clang::Expr *expr);
336345

337-
std::string GetFunctionPointerDefaultAsString(clang::QualType qual_type);
338-
339346
virtual std::string GetDefaultAsString(clang::QualType qual_type);
340347

341348
virtual std::string GetDefaultAsStringFallback(clang::QualType qual_type);
@@ -472,6 +479,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
472479
static std::unordered_set<std::string> abstract_structs_;
473480

474481
enum class ExprKind : uint8_t {
482+
Callee,
475483
LValue,
476484
RValue,
477485
XValue,
@@ -482,6 +490,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
482490

483491
static const char *expr_kind_to_string(ExprKind kind) {
484492
switch (kind) {
493+
case ExprKind::Callee:
494+
return "Callee";
485495
case ExprKind::LValue:
486496
return "LValue";
487497
case ExprKind::RValue:
@@ -505,6 +515,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
505515
bool isAddrOf() const;
506516
bool isObject() const;
507517
bool isVoid() const;
518+
bool isCallee() const;
508519

509520
void dump_expr_kinds();
510521

0 commit comments

Comments
 (0)