Skip to content

Commit 81ca488

Browse files
authored
Translate unions in unsafe (#32)
C/C++ unions are translated as Rust unions in unsafe. They manually implement `Default` with `std::mem::zeroed`, since unions cannot derive `Default`. Each struct and union is decorated with the `repr(C)` attribute to make sure that the C layout is preserved. We need this because Rust is free to reorder fields for optimal struct layout.
1 parent 64aa98b commit 81ca488

85 files changed

Lines changed: 828 additions & 466 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cpp2rust/compat/platform_flags.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ static inline std::vector<std::string> getPlatformClangFlags() {
1010
std::vector<std::string> flags = {
1111
"-resource-dir=" CLANG_RESOURCE_DIR,
1212
"-I" COMPAT_INCLUDE_DIR,
13+
"-D_FORTIFY_SOURCE=0",
1314
};
1415
#ifdef MACOS_SDK_PATH
1516
flags.push_back("-isysroot" MACOS_SDK_PATH);

cpp2rust/converter/converter.cpp

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -600,12 +600,12 @@ bool Converter::VisitRecordDecl(clang::RecordDecl *decl) {
600600
}
601601

602602
Mapper::AddRuleForUserDefinedType(decl);
603-
EmitRustStruct(decl);
603+
EmitRustStructOrUnion(decl);
604604

605605
return false;
606606
}
607607

608-
void Converter::EmitRustStruct(clang::RecordDecl *decl) {
608+
void Converter::EmitRustStructOrUnion(clang::RecordDecl *decl) {
609609
// Enums and static variables. In rust they live outside the record
610610
for (auto *d : decl->decls()) {
611611
if (auto *enum_decl = llvm::dyn_cast<clang::EnumDecl>(d)) {
@@ -631,6 +631,9 @@ void Converter::EmitRustStruct(clang::RecordDecl *decl) {
631631
}
632632

633633
// Derived traits
634+
if (EmitsReprCForRecords()) {
635+
StrCat("#[repr(C)]");
636+
}
634637
StrCat("#[derive(");
635638
for (auto *attr : GetStructAttributes(decl)) {
636639
StrCat(attr, ",");
@@ -641,7 +644,8 @@ void Converter::EmitRustStruct(clang::RecordDecl *decl) {
641644
auto access = clang::dyn_cast<clang::CXXRecordDecl>(decl)
642645
? AccessSpecifierAsString(decl->getAccess())
643646
: keyword::kPub;
644-
StrCat(access, keyword::kStruct, GetRecordName(decl));
647+
StrCat(access, decl->isUnion() ? keyword::kUnion : keyword::kStruct,
648+
GetRecordName(decl));
645649
{
646650
PushBrace brace(*this);
647651
for (auto *field : decl->fields()) {
@@ -682,8 +686,8 @@ void Converter::EmitRustStruct(clang::RecordDecl *decl) {
682686
AddOrdTrait(cxx);
683687
AddCloneTrait(cxx);
684688
AddDropTrait(cxx);
685-
AddDefaultTrait(cxx);
686689
}
690+
AddDefaultTrait(decl);
687691
AddByteReprTrait(decl);
688692
}
689693

@@ -729,10 +733,10 @@ bool Converter::VisitCXXRecordDecl(clang::CXXRecordDecl *decl) {
729733
}
730734
}
731735

732-
EmitRustStruct(decl);
736+
EmitRustStructOrUnion(decl);
733737
} else {
734738
// FIXME: improve error handling
735-
assert(0 && "unsupported union");
739+
assert(0 && "unsupported record kind");
736740
}
737741

738742
return false;
@@ -2889,6 +2893,10 @@ std::string Converter::GetRecordName(const clang::NamedDecl *decl) const {
28892893

28902894
std::vector<const char *>
28912895
Converter::GetStructAttributes(const clang::RecordDecl *decl) {
2896+
if (decl->isUnion()) {
2897+
return {"Copy", "Clone"};
2898+
}
2899+
28922900
std::vector<const char *> struct_attrs = {};
28932901

28942902
if (recordDerivesCopy(decl)) {
@@ -3208,7 +3216,21 @@ void Converter::AddCloneTrait(const clang::CXXRecordDecl *decl) {}
32083216

32093217
void Converter::AddDropTrait(const clang::CXXRecordDecl *decl) {}
32103218

3211-
void Converter::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
3219+
void Converter::AddDefaultTraitForUnion(const clang::RecordDecl *decl) {
3220+
StrCat(std::format("impl Default for {}", GetRecordName(decl)));
3221+
PushBrace impl_brace(*this);
3222+
StrCat("fn default() -> Self");
3223+
PushBrace fn_brace(*this);
3224+
StrCat("unsafe");
3225+
PushBrace unsafe_brace(*this);
3226+
StrCat("std::mem::zeroed()");
3227+
}
3228+
3229+
void Converter::AddDefaultTrait(const clang::RecordDecl *decl) {
3230+
if (decl->isUnion()) {
3231+
AddDefaultTraitForUnion(decl);
3232+
return;
3233+
}
32123234
if (RecordDerivesDefault(decl)) {
32133235
return;
32143236
}
@@ -3217,20 +3239,26 @@ void Converter::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
32173239
PushBrace impl_brace(*this);
32183240
StrCat("fn default() -> Self");
32193241
PushBrace fn_brace(*this);
3220-
if (auto *default_ctor = GetUserDefinedDefaultConstructor(decl)) {
3221-
StrCat(keyword_unsafe_);
3222-
PushBrace unsafe_brace(*this);
3223-
Convert(clang::CXXConstructExpr::Create(
3224-
ctx_, ctx_.getCanonicalTagType(decl), clang::SourceLocation(),
3225-
default_ctor,
3226-
/*Elidable=*/false, llvm::ArrayRef<clang::Expr *>(),
3227-
/*HadMultipleCandidates=*/false,
3228-
/*ListInitialization=*/false,
3229-
/*StdInitListInitialization=*/false,
3230-
/*ZeroInitialization=*/false, clang::CXXConstructionKind::Complete,
3231-
clang::SourceRange()));
3232-
} else {
3233-
StrCat(struct_name);
3242+
3243+
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
3244+
if (auto *default_ctor = GetUserDefinedDefaultConstructor(cxx)) {
3245+
StrCat(keyword_unsafe_);
3246+
PushBrace unsafe_brace(*this);
3247+
Convert(clang::CXXConstructExpr::Create(
3248+
ctx_, ctx_.getCanonicalTagType(decl), clang::SourceLocation(),
3249+
default_ctor,
3250+
/*Elidable=*/false, llvm::ArrayRef<clang::Expr *>(),
3251+
/*HadMultipleCandidates=*/false,
3252+
/*ListInitialization=*/false,
3253+
/*StdInitListInitialization=*/false,
3254+
/*ZeroInitialization=*/false, clang::CXXConstructionKind::Complete,
3255+
clang::SourceRange()));
3256+
return;
3257+
}
3258+
}
3259+
3260+
StrCat(struct_name);
3261+
{
32343262
PushBrace struct_brace(*this);
32353263
for (auto *field : decl->fields()) {
32363264
StrCat(GetNamedDeclAsString(field), token::kColon,

cpp2rust/converter/converter.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
9999

100100
virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl);
101101

102-
void EmitRustStruct(clang::RecordDecl *decl);
102+
virtual void EmitRustStructOrUnion(clang::RecordDecl *decl);
103+
104+
virtual bool EmitsReprCForRecords() const { return true; }
103105

104106
virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *decl);
105107
virtual std::string GetSelfMaybeWithMut(const clang::CXXMethodDecl *decl);
@@ -441,7 +443,9 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
441443

442444
virtual void AddDropTrait(const clang::CXXRecordDecl *decl);
443445

444-
virtual void AddDefaultTrait(const clang::CXXRecordDecl *decl);
446+
virtual void AddDefaultTrait(const clang::RecordDecl *decl);
447+
448+
virtual void AddDefaultTraitForUnion(const clang::RecordDecl *decl);
445449

446450
virtual void AddByteReprTrait(const clang::RecordDecl *decl);
447451

cpp2rust/converter/lex.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ inline constexpr const char *kReturn = "return";
5151
inline constexpr const char *kSelfValue = "self";
5252
inline constexpr const char *kStatic = "static";
5353
inline constexpr const char *kStruct = "struct";
54+
inline constexpr const char *kUnion = "union";
5455
inline constexpr const char *kTrue = "true";
5556
inline constexpr const char *kWhile = "while";
5657
inline constexpr const char *kFor = "for";

cpp2rust/converter/models/converter_refcount.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,11 +459,14 @@ void ConverterRefCount::AddCloneTrait(const clang::CXXRecordDecl *decl) {
459459
StrCat("}");
460460
}
461461

462-
void ConverterRefCount::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
462+
void ConverterRefCount::AddDefaultTrait(const clang::RecordDecl *decl) {
463463
PushConversionKind push(*this, ConversionKind::FullRefCount);
464464
Converter::AddDefaultTrait(decl);
465465
}
466466

467+
void ConverterRefCount::AddDefaultTraitForUnion(const clang::RecordDecl *decl) {
468+
}
469+
467470
void ConverterRefCount::AddDropTrait(const clang::CXXRecordDecl *decl) {
468471
if (!decl->hasUserDeclaredDestructor()) {
469472
return;

cpp2rust/converter/models/converter_refcount.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class ConverterRefCount final : public Converter {
2828

2929
bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl) override;
3030

31+
bool EmitsReprCForRecords() const override { return false; }
32+
3133
void ConvertOrdAndPartialOrdTraits(const clang::CXXRecordDecl *decl,
3234
const clang::FunctionDecl *op) override;
3335

@@ -37,7 +39,9 @@ class ConverterRefCount final : public Converter {
3739

3840
void AddByteReprTrait(const clang::RecordDecl *decl) override;
3941

40-
void AddDefaultTrait(const clang::CXXRecordDecl *decl) override;
42+
void AddDefaultTrait(const clang::RecordDecl *decl) override;
43+
44+
void AddDefaultTraitForUnion(const clang::RecordDecl *decl) override;
4145

4246
std::string GetSelfMaybeWithMut(const clang::CXXMethodDecl *decl) override;
4347

tests/benchmarks/out/unsafe/bfs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::collections::BTreeMap;
66
use std::io::{Read, Seek, Write};
77
use std::os::fd::{AsFd, FromRawFd, IntoRawFd};
88
use std::rc::Rc;
9+
#[repr(C)]
910
#[derive(Copy, Clone, Default)]
1011
pub struct Queue {
1112
pub elems: *mut u32,
@@ -35,11 +36,13 @@ impl Queue {
3536
return ((self.back) == (0_u64));
3637
}
3738
}
39+
#[repr(C)]
3840
#[derive(Copy, Clone, Default)]
3941
pub struct GraphNode {
4042
pub vertex: u32,
4143
pub next: *mut GraphNode,
4244
}
45+
#[repr(C)]
4346
#[derive(Copy, Clone, Default)]
4447
pub struct Graph {
4548
pub V: u32,

tests/benchmarks/out/unsafe/bst.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::collections::BTreeMap;
66
use std::io::{Read, Seek, Write};
77
use std::os::fd::{AsFd, FromRawFd, IntoRawFd};
88
use std::rc::Rc;
9+
#[repr(C)]
910
#[derive(Copy, Clone, Default)]
1011
pub struct node_t {
1112
pub left: *mut node_t,

tests/ub/out/unsafe/ub6.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::collections::BTreeMap;
66
use std::io::{Read, Seek, Write};
77
use std::os::fd::{AsFd, FromRawFd, IntoRawFd};
88
use std::rc::Rc;
9+
#[repr(C)]
910
#[derive(Copy, Clone, Default)]
1011
pub struct Pair {
1112
pub x1: *mut i32,
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
extern crate libcc2rs;
2+
use libcc2rs::*;
3+
use std::cell::RefCell;
4+
use std::collections::BTreeMap;
5+
use std::io::prelude::*;
6+
use std::io::{Read, Seek, Write};
7+
use std::os::fd::AsFd;
8+
use std::rc::{Rc, Weak};
9+
#[repr(C)]
10+
#[derive()]
11+
pub struct record {
12+
pub code: Value<u16>,
13+
pub lo: Value<u16>,
14+
pub hi: Value<u32>,
15+
pub pad: Value<Box<[u8]>>,
16+
}
17+
impl ByteRepr for record {}
18+
#[repr(C)]
19+
#[derive(Copy, Clone)]
20+
pub union Container_anon_15_3 {
21+
pub h: Value<record>,
22+
pub raw_: Value<Box<[u8]>>,
23+
}
24+
impl Default for Container_anon_15_3 {
25+
fn default() -> Self {
26+
unsafe { std::mem::zeroed() }
27+
}
28+
}
29+
#[repr(C)]
30+
#[derive(Default)]
31+
pub struct Container {
32+
pub view: Value<Container_anon_15_3>,
33+
}
34+
impl ByteRepr for Container {}
35+
pub fn fill_0(out: AnyPtr, cap: u64) {
36+
let out: Value<AnyPtr> = Rc::new(RefCell::new(out));
37+
let cap: Value<u64> = Rc::new(RefCell::new(cap));
38+
let src: Value<Box<[u8]>> = Rc::new(RefCell::new(Box::new([
39+
0_u8,
40+
<u8>::default(),
41+
<u8>::default(),
42+
<u8>::default(),
43+
<u8>::default(),
44+
<u8>::default(),
45+
<u8>::default(),
46+
<u8>::default(),
47+
<u8>::default(),
48+
<u8>::default(),
49+
<u8>::default(),
50+
<u8>::default(),
51+
<u8>::default(),
52+
<u8>::default(),
53+
<u8>::default(),
54+
<u8>::default(),
55+
])));
56+
(*src.borrow_mut())[(0) as usize] = 2_u8;
57+
(*src.borrow_mut())[(1) as usize] = 0_u8;
58+
(*src.borrow_mut())[(2) as usize] = 0_u8;
59+
(*src.borrow_mut())[(3) as usize] = 80_u8;
60+
(*src.borrow_mut())[(4) as usize] = 127_u8;
61+
(*src.borrow_mut())[(5) as usize] = 0_u8;
62+
(*src.borrow_mut())[(6) as usize] = 0_u8;
63+
(*src.borrow_mut())[(7) as usize] = 1_u8;
64+
let n: Value<u64> = Rc::new(RefCell::new(
65+
if ((::std::mem::size_of::<[u8; 16]>() as u64) < (*cap.borrow())) {
66+
::std::mem::size_of::<[u8; 16]>() as u64
67+
} else {
68+
(*cap.borrow())
69+
},
70+
));
71+
{
72+
(*out.borrow()).memcpy(
73+
&((src.as_pointer() as Ptr<u8>) as Ptr<u8>).to_any(),
74+
(*n.borrow()) as usize,
75+
);
76+
(*out.borrow()).clone()
77+
};
78+
}
79+
pub fn main() {
80+
std::process::exit(main_0());
81+
}
82+
fn main_0() -> i32 {
83+
let c: Value<Container> = <Value<Container>>::default();
84+
{
85+
((c.as_pointer()) as Ptr<Container>).to_any().memset(
86+
(0) as u8,
87+
::std::mem::size_of::<Container>() as u64 as usize,
88+
);
89+
((c.as_pointer()) as Ptr<Container>).to_any().clone()
90+
};
91+
({
92+
let _out: AnyPtr = (((*c.borrow()).view.as_pointer()).to_strong().as_pointer() as AnyPtr);
93+
let _cap: u64 = ::std::mem::size_of::<Container_anon_15_3>() as u64;
94+
fill_0(_out, _cap)
95+
});
96+
assert!((((*(*(*(*c.borrow()).view.borrow()).h.borrow()).code.borrow()) as i32) == (2)));
97+
assert!(
98+
((((((*(*(*c.borrow()).view.borrow()).h.borrow()).lo.as_pointer())
99+
.to_strong()
100+
.as_pointer() as Ptr::<u8>)
101+
.offset((0) as isize)
102+
.read()) as i32)
103+
== (0))
104+
);
105+
assert!(
106+
((((((*(*(*c.borrow()).view.borrow()).h.borrow()).lo.as_pointer())
107+
.to_strong()
108+
.as_pointer() as Ptr::<u8>)
109+
.offset((1) as isize)
110+
.read()) as i32)
111+
== (80))
112+
);
113+
assert!((((*(*(*c.borrow()).view.borrow()).raw_.borrow())[(0) as usize] as i32) == (2)));
114+
assert!(
115+
((((*(*(*c.borrow()).view.borrow()).raw_.borrow())[(3) as usize] as u8) as i32) == (80))
116+
);
117+
return 0;
118+
}

0 commit comments

Comments
 (0)