Skip to content

Commit 370e554

Browse files
authored
Handle comparisons against translated functions (#3)
This PR adds support for translating comparisons against functions that have translation rules `fn_ptr == translated_rule`, for example in `fn1 == fread` in fn_ptr_stdlib_compare.cpp: ```rs fn1 == FnPtr::new<fn(AnyPtr, u64, u64, Ptr<::std::fs::File>) -> u64>( libcc2rs::fread_refcount) ```
1 parent 7850660 commit 370e554

16 files changed

Lines changed: 589 additions & 367 deletions

cpp2rust/converter/converter.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,7 @@ void Converter::EmitFnPtrCall(clang::Expr *callee) {
14341434

14351435
void Converter::ConvertFunctionToFunctionPointer(
14361436
const clang::FunctionDecl *fn_decl) {
1437-
StrCat(std::format("Some({})", GetNamedDeclAsString(fn_decl)));
1437+
StrCat(std::format("Some({})", Mapper::MapFunctionName(fn_decl)));
14381438
}
14391439

14401440
void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
@@ -2074,7 +2074,7 @@ std::string Converter::ConvertDeclRefExpr(clang::DeclRefExpr *expr) {
20742074
}
20752075

20762076
auto *decl = expr->getDecl();
2077-
if (Mapper::Contains(expr)) {
2077+
if (ShouldReplaceWithMappedBody(expr)) {
20782078
return GetMappedAsString(expr);
20792079
} else if (auto *function = decl->getAsFunction()) {
20802080
if (auto method = clang::dyn_cast<clang::CXXMethodDecl>(function)) {
@@ -3475,6 +3475,13 @@ bool Converter::isCallee() const {
34753475
return !curr_expr_kind_.empty() && curr_expr_kind_.back() == ExprKind::Callee;
34763476
}
34773477

3478+
bool Converter::ShouldReplaceWithMappedBody(clang::DeclRefExpr *expr) const {
3479+
if (clang::isa<clang::FunctionDecl>(expr->getDecl()) && isAddrOf()) {
3480+
return false;
3481+
}
3482+
return Mapper::Contains(expr);
3483+
}
3484+
34783485
void Converter::SetFresh() {
34793486
switch (computed_expr_type_) {
34803487
case ComputedExprType::Value:

cpp2rust/converter/converter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
459459

460460
virtual bool RecordDerivesDefault(const clang::RecordDecl *decl);
461461

462+
bool ShouldReplaceWithMappedBody(clang::DeclRefExpr *expr) const;
463+
462464
std::string *rs_code_;
463465
clang::ASTContext &ctx_;
464466
clang::FunctionDecl *curr_function_ = nullptr;

cpp2rust/converter/converter_lib.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
#include <clang/AST/ParentMapContext.h>
99
#include <clang/Basic/SourceManager.h>
1010

11+
#include <algorithm>
1112
#include <array>
13+
#include <cctype>
1214
#include <filesystem>
1315
#include <unordered_set>
1416

@@ -660,4 +662,22 @@ clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) {
660662
/*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride());
661663
}
662664

665+
static std::string_view Trim(std::string_view s) {
666+
auto is_space = [](unsigned char c) { return std::isspace(c); };
667+
auto b = std::find_if_not(s.begin(), s.end(), is_space);
668+
auto e = std::find_if_not(s.rbegin(), s.rend(), is_space).base();
669+
return {b, e};
670+
}
671+
672+
void Unwrap(std::string &s, std::string_view prefix, std::string_view suffix) {
673+
auto trimmed = Trim(s);
674+
if (trimmed.starts_with(prefix) && trimmed.ends_with(suffix)) {
675+
assert(trimmed.size() >= prefix.size() + suffix.size() &&
676+
"prefix and suffix overlap in s");
677+
trimmed.remove_prefix(prefix.size());
678+
trimmed.remove_suffix(suffix.size());
679+
s = std::string(trimmed);
680+
}
681+
}
682+
663683
} // namespace cpp2rust

cpp2rust/converter/converter_lib.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <optional>
1313
#include <regex>
1414
#include <string>
15+
#include <string_view>
1516
#include <vector>
1617

1718
namespace cpp2rust {
@@ -154,4 +155,6 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt);
154155

155156
clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx);
156157

158+
void Unwrap(std::string &s, std::string_view prefix, std::string_view suffix);
159+
157160
} // namespace cpp2rust

cpp2rust/converter/mapper.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <llvm/Support/ThreadPool.h>
99

1010
#include <atomic>
11+
#include <format>
1112
#include <mutex>
1213
#include <regex>
1314
#include <utility>
@@ -581,6 +582,15 @@ const TranslationRule::ExprTgt *GetExprTgt(const clang::Expr *expr) {
581582
return nullptr;
582583
}
583584

585+
std::string MapFunctionName(const clang::FunctionDecl *decl) {
586+
assert(decl);
587+
if (exprs_.contains(ToString(decl))) {
588+
return std::format("libcc2rs::{}_{}", decl->getNameAsString(),
589+
model_ == Model::kRefCount ? "refcount" : "unsafe");
590+
}
591+
return GetNamedDeclAsString(decl->getCanonicalDecl());
592+
}
593+
584594
std::string InstantiateTemplate(const clang::Expr *expr,
585595
const std::string &text) {
586596
auto it = search(expr);

cpp2rust/converter/mapper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ bool Contains(const clang::Expr *expr);
3131

3232
std::string Map(clang::QualType qual_type);
3333
const TranslationRule::ExprTgt *GetExprTgt(const clang::Expr *expr);
34+
std::string MapFunctionName(const clang::FunctionDecl *decl);
3435
std::string InstantiateTemplate(const clang::Expr *expr,
3536
const std::string &text);
3637
bool ReturnsPointer(const clang::Expr *expr);

cpp2rust/converter/models/converter_refcount.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ std::string ConverterRefCount::BuildFnAdapter(
195195
closure += "{ ";
196196

197197
// Build adapter body: src_fn(convert(a0), convert(a1), ...)
198-
closure += GetNamedDeclAsString(src_fn) + "(";
198+
closure += Mapper::MapFunctionName(src_fn) + '(';
199199
for (unsigned i = 0; i < src_proto->getNumParams(); ++i) {
200200
auto src_pty = src_proto->getParamType(i);
201201
auto tgt_pty = target_proto->getParamType(i);
@@ -204,12 +204,12 @@ std::string ConverterRefCount::BuildFnAdapter(
204204
} else if (src_pty->isPointerType() && tgt_pty->isPointerType()) {
205205
if (tgt_pty->isVoidPointerType()) {
206206
closure += std::format("a{}.cast::<{}>().unwrap()", i,
207-
ToString(src_pty->getPointeeType()));
207+
ConvertPointeeType(src_pty));
208208
} else if (src_pty->isVoidPointerType()) {
209209
closure += std::format("a{}.to_any()", i);
210210
} else if (tgt_pty->getPointeeType()->isCharType()) {
211211
closure += std::format("a{}.reinterpret_cast::<{}>()", i,
212-
ToString(src_pty->getPointeeType()));
212+
ConvertPointeeType(src_pty));
213213
} else if (src_pty->getPointeeType()->isCharType()) {
214214
closure += std::format("a{}.reinterpret_cast::<u8>()", i);
215215
}
@@ -625,7 +625,7 @@ bool ConverterRefCount::VisitDeclRefExpr(clang::DeclRefExpr *expr) {
625625
}
626626
}
627627

628-
if (Mapper::Contains(expr)) {
628+
if (ShouldReplaceWithMappedBody(expr)) {
629629
StrCat(GetMappedAsString(expr));
630630
return false;
631631
}
@@ -1037,7 +1037,7 @@ void ConverterRefCount::ConvertFunctionToFunctionPointer(
10371037
StrCat(std::format("FnPtr::<{}>::new({})",
10381038
ConvertFunctionPointerType(
10391039
fn_decl->getType()->getAs<clang::FunctionProtoType>()),
1040-
GetNamedDeclAsString(fn_decl)));
1040+
Mapper::MapFunctionName(fn_decl)));
10411041
}
10421042

10431043
void ConverterRefCount::ConvertEqualsNullPtr(clang::Expr *expr) {
@@ -2155,4 +2155,18 @@ std::string ConverterRefCount::ConvertMappedMethodCall(
21552155
return std::format("{}.with_mut(|__v: {}| __v{})", ptr, param_type, body);
21562156
}
21572157

2158+
std::string ConverterRefCount::ConvertPointeeType(clang::QualType ptr_type) {
2159+
if (ptr_type->getPointeeType()->isIntegerType()) {
2160+
return ToString(ptr_type->getPointeeType());
2161+
}
2162+
2163+
// Pointee of a pointer to incomplete type is an incomplete type that does
2164+
// not have a translation rule. Hence ToString(ptr_type->getPointeeType()) is
2165+
// not enough
2166+
assert(ptr_type->isPointerType());
2167+
auto str = ToString(ptr_type);
2168+
Unwrap(str, "Ptr<", ">");
2169+
return str;
2170+
}
2171+
21582172
} // namespace cpp2rust

cpp2rust/converter/models/converter_refcount.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class ConverterRefCount final : public Converter {
204204
std::string ConvertFreshPointer(clang::Expr *expr) override;
205205

206206
std::string ConvertPtrType(clang::QualType type);
207+
std::string ConvertPointeeType(clang::QualType ptr_type);
207208

208209
/// The kind of conversion that should be performed.
209210
enum class ConversionKind : uint8_t {

libcc2rs/src/io.rs

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Copyright (c) 2022-present INESC-ID.
22
// Distributed under the MIT license that can be found in the LICENSE file.
33

4-
use crate::{AsPointer, Ptr, Value};
4+
use crate::{AnyPtr, AsPointer, Ptr, Value};
55
use std::cell::{RefCell, UnsafeCell};
66
use std::os::fd::{AsFd, FromRawFd, IntoRawFd};
77
use std::rc::Rc;
@@ -58,3 +58,83 @@ pub unsafe fn cout_unsafe() -> *mut std::fs::File {
5858
pub unsafe fn cerr_unsafe() -> *mut std::fs::File {
5959
UNSAFE_STDERR.with(UnsafeCell::get)
6060
}
61+
62+
pub fn fread_refcount(a0: AnyPtr, a1: u64, a2: u64, a3: Ptr<::std::fs::File>) -> u64 {
63+
let total = a1.saturating_mul(a2) as usize;
64+
let mut dst = a0
65+
.cast::<u8>()
66+
.expect("fread: only supporting u8 pointers")
67+
.clone();
68+
69+
let f = (*a3.upgrade().deref())
70+
.try_clone()
71+
.expect("try_clone failed");
72+
let mut reader = std::io::BufReader::with_capacity(64 * 1024, f);
73+
74+
let mut read_bytes: usize = 0;
75+
let mut buffer: [u8; 8192] = [0; 8192];
76+
77+
while read_bytes < total {
78+
let remaining = total - read_bytes;
79+
let to_read = std::cmp::min(buffer.len(), remaining);
80+
81+
let n = match std::io::Read::read(&mut reader, &mut buffer[..to_read]) {
82+
Ok(0) => break,
83+
Ok(n) => n,
84+
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
85+
Err(e) => panic!("Unhandled error in fread: {e}"),
86+
};
87+
88+
for &byte in &buffer[..n] {
89+
dst.write(byte);
90+
dst = dst.offset(1);
91+
}
92+
93+
read_bytes += n;
94+
}
95+
96+
(read_bytes / a1 as usize) as u64
97+
}
98+
99+
/// # Safety
100+
///
101+
/// `a0` must point to a writable buffer of at least `a1 * a2` bytes, and `a3`
102+
/// must point to a valid, open `std::fs::File`.
103+
pub unsafe fn fread_unsafe(
104+
a0: *mut ::std::ffi::c_void,
105+
a1: u64,
106+
a2: u64,
107+
a3: *mut ::std::fs::File,
108+
) -> u64 {
109+
let total = a1.saturating_mul(a2) as usize;
110+
let mut dst = a0 as *mut u8;
111+
112+
let f = unsafe { (*a3).try_clone().expect("try_clone failed") };
113+
let mut reader = std::io::BufReader::with_capacity(64 * 1024, f);
114+
115+
let mut read_bytes: usize = 0;
116+
let mut buffer: [u8; 8192] = [0; 8192];
117+
118+
while read_bytes < total {
119+
let remaining = total - read_bytes;
120+
let to_read = std::cmp::min(buffer.len(), remaining);
121+
122+
let n = match std::io::Read::read(&mut reader, &mut buffer[..to_read]) {
123+
Ok(0) => break,
124+
Ok(n) => n,
125+
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
126+
Err(e) => panic!("Unhandled error in fread: {e}"),
127+
};
128+
129+
for &byte in &buffer[..n] {
130+
unsafe {
131+
*dst = byte;
132+
dst = dst.offset(1);
133+
}
134+
}
135+
136+
read_bytes += n;
137+
}
138+
139+
(read_bytes / a1 as usize) as u64
140+
}

0 commit comments

Comments
 (0)