Skip to content

Commit f35afcf

Browse files
committed
Add support for function pointers in PtrKind
Previously function pointers were modeled as Option<fn>, but this is problematic for function pointer casts. Option<fn> works well in unsafe with std::mem::transmute, however there is no safe way to achieve the same operation in refcount. This is solved using the new Ptr<fn>. To allow casting between different function types, use type erased Rc<dyn Any> inside the new PtrKind::Fn. Equality of function pointers is achieved through implementing the OriginalAlloc::address method. The C standard allows converting function pointers between incompatible function types. UB is triggered only when the incompatible pointer is called. For this reason the new FnState implements 2 new concepts: 1. casting adaptors (to allow argument casting between ABI compatible types) 2. provenance stack (to allow round-trip function pointer casts) For 1., consider the following cast: int fn_taking_int_ptr(int *p); int (*fn_taking_void_ptr)(void*) = (int (*)(void*))fn_taking_int_ptr; Calling fn_taking_int_ptr with an int* argument works because both int* and void* have the same size. To support this in Rust we need to create an int* -> void* adapter when casting from fn_taking_int_ptr to fn_taking_void_ptr: fn_taking_int_ptr.cast_fn::<fn(AnyPtr) -> i32>(Some( (|a0: AnyPtr| -> i32 { fn_taking_int_ptr(a0.cast::<i32>().unwrap()) }) as fn(AnyPtr) -> i32 )) The job of the adapter is to convert from AnyPtr to Ptr<i32>. Ptr::cast_fn is a new function that takes as type argument the type of the target function pointer and an optional adaptor. If cast_fn receives None, then there is no valid adaptor from source to target, matching the UB semantics of calling a function through an incompatible function pointer: int add(int a, int b) { return a + b; } void (*wrong)(void) = (void (*)(void))add; wrong() For 2., the provenance stack contains all casts performed on the pointer in the past. Compared to PtrKind::Reinterpreted, PtrKind::Fn has no backing byte storage through OriginalAlloc, so each cast must know its history in order to allow round-trip casts, such as: int (*)(int, int) -> void (*)(void) -> int (*)(int, int) (1) (2) For this specific case, where both (1) and (2) create non-compatible adaptors (because of non-compatible arguments), we cannot recover a call to the original function after (1) is performed. For this to work, save a stack of provenance, and when (2) is perfomed, cast_fn recovers the original function pointer. See test_roundtrip in fn_ptr_cast.cpp. A current limitation of this approach is that it only allows function pointer casts where the source is a direct declaration of a function. Accessing a function pointer through a member field for example, would create a capturing adapter which does not coerce in a fn inside Ptr<fn>.
1 parent 3b52da7 commit f35afcf

22 files changed

Lines changed: 518 additions & 163 deletions

cpp2rust/converter/converter.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,12 @@ bool Converter::VisitCallExpr(clang::CallExpr *expr) {
13751375
return false;
13761376
}
13771377

1378+
void Converter::EmitFnPtrCall(clang::Expr *callee) {
1379+
StrCat(token::kOpenParen);
1380+
Convert(callee);
1381+
StrCat(").unwrap()");
1382+
}
1383+
13781384
void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
13791385
clang::Expr *callee = expr->getCallee();
13801386
auto convert_param_ty = [&](clang::QualType param_type, clang::Expr *expr) {
@@ -1447,9 +1453,7 @@ void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
14471453
}
14481454

14491455
if (proto && !function) {
1450-
StrCat(token::kOpenParen);
1451-
Convert(callee);
1452-
StrCat(").unwrap()");
1456+
EmitFnPtrCall(callee);
14531457
} else {
14541458
PushExprKind push(*this, ExprKind::RValue);
14551459
Convert(StripFunctionPointerDecay(callee));

cpp2rust/converter/converter.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
199199

200200
void ConvertGenericCallExpr(clang::CallExpr *expr);
201201

202+
virtual void EmitFnPtrCall(clang::Expr *callee);
203+
202204
virtual void ConvertPrintf(clang::CallExpr *expr);
203205

204206
void ConvertVAArgCall(clang::CallExpr *expr);
@@ -332,7 +334,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
332334
virtual bool Convert(clang::Stmt *stmt);
333335
virtual bool Convert(clang::Expr *expr);
334336

335-
std::string GetFunctionPointerDefaultAsString(clang::QualType qual_type);
337+
virtual std::string
338+
GetFunctionPointerDefaultAsString(clang::QualType qual_type);
336339

337340
virtual std::string GetDefaultAsString(clang::QualType qual_type);
338341

cpp2rust/converter/models/converter_refcount.cpp

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,72 @@ bool ConverterRefCount::VisitLValueReferenceType(
165165
return false;
166166
}
167167

168+
std::string ConverterRefCount::BuildFnAdapter(
169+
const clang::FunctionDecl *src_fn,
170+
const clang::FunctionProtoType *src_proto,
171+
const clang::FunctionProtoType *target_proto) {
172+
173+
// UB: Incompatible arity
174+
if (src_proto->getNumParams() != target_proto->getNumParams()) {
175+
return "None";
176+
}
177+
178+
PushConversionKind push(*this, ConversionKind::Unboxed);
179+
180+
// Build adapter signature: |a0: T0, a1: T1, ...| -> Tr
181+
std::string closure = "(|";
182+
for (unsigned i = 0; i < target_proto->getNumParams(); ++i) {
183+
if (i > 0)
184+
closure += ", ";
185+
closure +=
186+
std::format("a{}: {}", i, ToString(target_proto->getParamType(i)));
187+
}
188+
closure += "|";
189+
if (!target_proto->getReturnType()->isVoidType()) {
190+
closure += std::format(" -> {} ", ToString(target_proto->getReturnType()));
191+
}
192+
closure += "{ ";
193+
194+
// Build adapter body: src_fn(convert(a0), convert(a1), ...)
195+
closure += GetNamedDeclAsString(src_fn->getCanonicalDecl()) + "(";
196+
for (unsigned i = 0; i < src_proto->getNumParams(); ++i) {
197+
auto src_pty = src_proto->getParamType(i);
198+
auto tgt_pty = target_proto->getParamType(i);
199+
if (ToString(src_pty) == ToString(tgt_pty)) {
200+
closure += std::format("a{}", i);
201+
} else if (src_pty->isPointerType() && tgt_pty->isVoidPointerType()) {
202+
closure += std::format("a{}.cast::<{}>().unwrap()", i,
203+
ToString(src_pty->getPointeeType()));
204+
} else if (src_pty->isVoidPointerType() && tgt_pty->isPointerType()) {
205+
closure += std::format("a{}.to_any()", i);
206+
} else {
207+
// UB: Incompatible types
208+
return "None";
209+
}
210+
closure += ", ";
211+
}
212+
closure += ") })";
213+
214+
return std::format("Some({} as {})", closure, GetFnTypeString(target_proto));
215+
}
216+
217+
std::string
218+
ConverterRefCount::GetFnTypeString(const clang::FunctionProtoType *proto) {
219+
PushConversionKind push(*this, ConversionKind::Unboxed);
220+
std::string result = "fn(";
221+
for (auto p_ty : proto->param_types()) {
222+
result += ToString(p_ty) + ",";
223+
}
224+
result += ")";
225+
if (!proto->getReturnType()->isVoidType()) {
226+
result += std::format(" -> {}", ToString(proto->getReturnType()));
227+
}
228+
return result;
229+
}
230+
168231
bool ConverterRefCount::VisitPointerType(clang::PointerType *type) {
169-
if (type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
170-
PushConversionKind push(*this, ConversionKind::Unboxed);
171-
ConvertFunctionPointerType(type);
232+
if (auto proto = type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
233+
StrCat(std::format("Ptr<{}>", GetFnTypeString(proto)));
172234
return false;
173235
}
174236

@@ -570,7 +632,9 @@ bool ConverterRefCount::VisitDeclRefExpr(clang::DeclRefExpr *expr) {
570632

571633
if (clang::isa<clang::FunctionDecl>(decl)) {
572634
if (isAddrOf()) {
573-
StrCat(std::format("Some({} as _)", str));
635+
auto proto = decl->getType()->getAs<clang::FunctionProtoType>();
636+
auto fn_type = GetFnTypeString(proto);
637+
StrCat(std::format("fn_ptr!({}, {})", str, fn_type));
574638
} else {
575639
StrCat(str);
576640
}
@@ -951,9 +1015,79 @@ bool ConverterRefCount::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
9511015
}
9521016
}
9531017

1018+
if (expr->getCastKind() == clang::CastKind::CK_NullToPointer &&
1019+
expr->getType()->isFunctionPointerType()) {
1020+
StrCat("Ptr::null()");
1021+
computed_expr_type_ = ComputedExprType::FreshPointer;
1022+
return false;
1023+
}
1024+
9541025
return Converter::VisitImplicitCastExpr(expr);
9551026
}
9561027

1028+
void ConverterRefCount::EmitFnPtrCall(clang::Expr *callee) {
1029+
Convert(callee);
1030+
StrCat(".call_fn()");
1031+
}
1032+
1033+
std::string ConverterRefCount::GetFunctionPointerDefaultAsString(
1034+
clang::QualType qual_type) {
1035+
return "Ptr::null()";
1036+
}
1037+
1038+
void ConverterRefCount::ConvertEqualsNullPtr(clang::Expr *expr) {
1039+
StrCat("(");
1040+
Convert(expr);
1041+
StrCat(").is_null()");
1042+
}
1043+
1044+
bool ConverterRefCount::VisitFunctionPointerCast(
1045+
clang::ExplicitCastExpr *expr) {
1046+
if (expr->getType()->isFunctionPointerType() ||
1047+
expr->getSubExpr()->getType()->isFunctionPointerType()) {
1048+
if (expr->getSubExpr()->getType()->isFunctionPointerType() &&
1049+
expr->getType()->isFunctionPointerType()) {
1050+
auto target_proto =
1051+
expr->getType()->getPointeeType()->getAs<clang::FunctionProtoType>();
1052+
auto src_proto = expr->getSubExpr()
1053+
->getType()
1054+
->getPointeeType()
1055+
->getAs<clang::FunctionProtoType>();
1056+
auto fn_type = GetFnTypeString(target_proto);
1057+
1058+
std::string adapter = "None";
1059+
// Only accept direct references to the casted function. Otherwise the
1060+
// closure would be capturing and would not coerce into a fn pointer.
1061+
if (auto *decl_ref = clang::dyn_cast<clang::DeclRefExpr>(
1062+
expr->getSubExpr()->IgnoreImplicit())) {
1063+
if (auto *fn_decl =
1064+
clang::dyn_cast<clang::FunctionDecl>(decl_ref->getDecl())) {
1065+
adapter = BuildFnAdapter(fn_decl, src_proto, target_proto);
1066+
}
1067+
}
1068+
1069+
StrCat(std::format("{}.cast_fn::<{}>({})", ToString(expr->getSubExpr()),
1070+
fn_type, adapter));
1071+
} else if (expr->getSubExpr()->getType()->isFunctionPointerType() ||
1072+
expr->getType()->isVoidPointerType()) {
1073+
Convert(expr->getSubExpr());
1074+
StrCat(".to_any()");
1075+
} else if (expr->getSubExpr()->getType()->isVoidPointerType() ||
1076+
expr->getType()->isFunctionPointerType()) {
1077+
auto target_proto =
1078+
expr->getType()->getPointeeType()->getAs<clang::FunctionProtoType>();
1079+
auto fn_type = GetFnTypeString(target_proto);
1080+
StrCat(std::format("{}.cast::<{}>().expect(\"ub:wrong fn type\")",
1081+
ToString(expr->getSubExpr()), fn_type));
1082+
} else {
1083+
assert(0 && "Unhandled function pointer cast");
1084+
}
1085+
return false;
1086+
}
1087+
1088+
return true;
1089+
}
1090+
9571091
bool ConverterRefCount::VisitExplicitCastExpr(clang::ExplicitCastExpr *expr) {
9581092
if (expr->getTypeAsWritten()->isVoidType()) {
9591093
return false;
@@ -968,7 +1102,9 @@ bool ConverterRefCount::VisitExplicitCastExpr(clang::ExplicitCastExpr *expr) {
9681102
return false;
9691103
case clang::Stmt::CStyleCastExprClass:
9701104
case clang::Stmt::CXXStaticCastExprClass:
971-
if (expr->getSubExpr()->getType()->isVoidPointerType()) {
1105+
if (!VisitFunctionPointerCast(expr)) {
1106+
return false;
1107+
} else if (expr->getSubExpr()->getType()->isVoidPointerType()) {
9721108
Convert(expr->getSubExpr());
9731109
PushConversionKind push(*this, ConversionKind::Unboxed);
9741110
StrCat(std::format(".cast::<{}>().expect(\"ub:wrong type\")",
@@ -1478,8 +1614,12 @@ void ConverterRefCount::ConvertVarInit(clang::QualType qual_type,
14781614
Buffer buf(*this);
14791615
PushConversionKind push(*this, ConversionKind::Unboxed);
14801616
if (qual_type->isFunctionPointerType() && lambda->capture_size() == 0) {
1481-
PushExprKind addr_of(*this, ExprKind::AddrOf);
1617+
StrCat(std::format("Ptr::from_fn(("));
14821618
VisitLambdaExpr(lambda);
1619+
StrCat(std::format(
1620+
") as {}, 0)",
1621+
GetFnTypeString(qual_type->getPointeeType()
1622+
->getAs<clang::FunctionProtoType>())));
14831623
} else {
14841624
VisitLambdaExpr(lambda);
14851625
}

cpp2rust/converter/models/converter_refcount.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,16 @@ class ConverterRefCount final : public Converter {
5959

6060
void ConvertPrintf(clang::CallExpr *expr) override;
6161

62+
void EmitFnPtrCall(clang::Expr *callee) override;
63+
6264
bool VisitCallExpr(clang::CallExpr *expr) override;
6365

6466
bool VisitStringLiteral(clang::StringLiteral *expr) override;
6567

6668
bool VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) override;
6769

70+
bool VisitFunctionPointerCast(clang::ExplicitCastExpr *expr);
71+
6872
bool VisitExplicitCastExpr(clang::ExplicitCastExpr *expr) override;
6973

7074
bool VisitBinaryOperator(clang::BinaryOperator *expr) override;
@@ -103,6 +107,11 @@ class ConverterRefCount final : public Converter {
103107

104108
std::string GetDefaultAsString(clang::QualType qual_type) override;
105109

110+
std::string
111+
GetFunctionPointerDefaultAsString(clang::QualType qual_type) override;
112+
113+
void ConvertEqualsNullPtr(clang::Expr *expr) override;
114+
106115
std::string GetDefaultAsStringFallback(clang::QualType qual_type) override;
107116

108117
std::string ConvertVarDefaultInit(clang::QualType qual_type) override;
@@ -171,6 +180,11 @@ class ConverterRefCount final : public Converter {
171180
const char *GetPointerDerefSuffix(clang::QualType pointee_type);
172181
const char *GetPointerDerefPrefix(clang::QualType pointee_type) override;
173182

183+
std::string GetFnTypeString(const clang::FunctionProtoType *proto);
184+
std::string BuildFnAdapter(const clang::FunctionDecl *src_fn,
185+
const clang::FunctionProtoType *src_proto,
186+
const clang::FunctionProtoType *target_proto);
187+
174188
void EmitSetOrAssign(clang::Expr *lhs, std::string_view rhs);
175189

176190
// Wraps a pointer expression with deref prefix/suffix: e.g.

libcc2rs/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
mod reinterpret;
55
pub use reinterpret::ByteRepr;
66

7+
#[macro_export]
8+
macro_rules! fn_ptr {
9+
($f:expr, $ty:ty) => {
10+
$crate::Ptr::from_fn($f as $ty, $f as *const () as usize)
11+
};
12+
}
13+
714
mod rc;
815
pub use rc::*;
916

0 commit comments

Comments
 (0)