Skip to content

Commit 9811d05

Browse files
authored
Add support for translating C structs (#10)
1 parent 9e0c43f commit 9811d05

10 files changed

Lines changed: 375 additions & 109 deletions

File tree

cpp2rust/converter/converter.cpp

Lines changed: 123 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -519,9 +519,11 @@ bool IsPointerType(clang::QualType qual_type) {
519519
->getCanonicalTypeInternal()));
520520
}
521521

522-
bool Converter::RecordDerivesDefault(const clang::CXXRecordDecl *decl) {
523-
if (GetUserDefinedDefaultConstructor(decl)) {
524-
return false;
522+
bool Converter::RecordDerivesDefault(const clang::RecordDecl *decl) {
523+
if (auto cxx_decl = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
524+
if (GetUserDefinedDefaultConstructor(cxx_decl)) {
525+
return false;
526+
}
525527
}
526528

527529
for (auto f : decl->fields()) {
@@ -546,7 +548,7 @@ bool Converter::RecordDerivesDefault(const clang::CXXRecordDecl *decl) {
546548
return true;
547549
}
548550

549-
static bool recordDerivesCopy(const clang::CXXRecordDecl *decl) {
551+
static bool recordDerivesCopy(const clang::RecordDecl *decl) {
550552
for (auto f : decl->fields()) {
551553
// Records that contain std::vector, std::array, std::string or anything
552554
// that is translated to Vec<>, do not derive Copy
@@ -569,8 +571,8 @@ static bool recordDerivesCopy(const clang::CXXRecordDecl *decl) {
569571
}
570572
}
571573

572-
// Look recursively into fields that are CXXRecordDecl
573-
if (auto field_record = f->getType()->getAsCXXRecordDecl()) {
574+
// Look recursively into fields that are RecordDecl
575+
if (auto field_record = f->getType()->getAsRecordDecl()) {
574576
if (!recordDerivesCopy(field_record)) {
575577
return false;
576578
}
@@ -580,6 +582,109 @@ static bool recordDerivesCopy(const clang::CXXRecordDecl *decl) {
580582
return true;
581583
}
582584

585+
bool Converter::VisitRecordDecl(clang::RecordDecl *decl) {
586+
decl->dumpColor();
587+
588+
// VisitCXXRecordDecl already visited the record
589+
if (clang::isa<clang::CXXRecordDecl>(decl)) {
590+
return true;
591+
}
592+
593+
if (!decl->isCompleteDefinition()) {
594+
return false;
595+
}
596+
597+
if (!record_decls_.insert(GetID(decl)).second) {
598+
return false;
599+
}
600+
601+
Mapper::AddRuleForUserDefinedType(decl);
602+
EmitRustStruct(decl);
603+
604+
return false;
605+
}
606+
607+
void Converter::EmitRustStruct(clang::RecordDecl *decl) {
608+
// Enums and static variables. In rust they live outside the record
609+
for (auto *d : decl->decls()) {
610+
if (auto *enum_decl = llvm::dyn_cast<clang::EnumDecl>(d)) {
611+
VisitEnumDecl(enum_decl);
612+
}
613+
if (auto *var_decl = clang::dyn_cast<clang::VarDecl>(d)) {
614+
VisitVarDecl(var_decl);
615+
}
616+
}
617+
618+
// Inner records. In rust they live outside the record
619+
for (auto *d : decl->decls()) {
620+
if (auto *nested = clang::dyn_cast<clang::RecordDecl>(d)) {
621+
if (!nested->isImplicit()) {
622+
inner_structs_[GetID(nested)] = GetRecordName(nested);
623+
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(nested)) {
624+
VisitCXXRecordDecl(cxx);
625+
} else {
626+
VisitRecordDecl(nested);
627+
}
628+
}
629+
}
630+
}
631+
632+
// Derived traits
633+
StrCat("#[derive(");
634+
for (auto *attr : GetStructAttributes(decl)) {
635+
StrCat(attr, ",");
636+
}
637+
StrCat(")]");
638+
639+
// Fields
640+
auto access = clang::dyn_cast<clang::CXXRecordDecl>(decl)
641+
? AccessSpecifierAsString(decl->getAccess())
642+
: keyword::kPub;
643+
StrCat(access, keyword::kStruct, GetRecordName(decl),
644+
token::kOpenCurlyBracket);
645+
for (auto *field : decl->fields()) {
646+
VisitFieldDecl(field);
647+
}
648+
StrCat(token::kCloseCurlyBracket);
649+
650+
// C++ method decls
651+
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
652+
auto struct_name = GetRecordName(cxx);
653+
654+
ConvertCXXMethodDecls(
655+
cxx, std::string(keyword::kImpl) + ' ' + struct_name,
656+
[](const auto *method) {
657+
return !method->isImplicit() &&
658+
!(method->getDefinition() &&
659+
method->getDefinition()->isDefaulted()) &&
660+
(method->isThisDeclarationADefinition() ||
661+
clang::isa<clang::CXXConstructorDecl>(method)) &&
662+
!method->isVirtual() &&
663+
!clang::isa<clang::CXXDestructorDecl>(method);
664+
});
665+
666+
if (cxx->bases_begin() != cxx->bases_end()) {
667+
ConvertCXXMethodDecls(
668+
cxx,
669+
std::format("{} impl {} for {}", keyword_unsafe_,
670+
GetUnsafeTypeAsString(cxx->bases_begin()->getType()),
671+
struct_name),
672+
[](const auto *method) {
673+
return !method->isImplicit() && method->isVirtual();
674+
});
675+
}
676+
}
677+
678+
// Traits
679+
if (auto *cxx = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
680+
AddOrdTrait(cxx);
681+
AddCloneTrait(cxx);
682+
AddDropTrait(cxx);
683+
AddDefaultTrait(cxx);
684+
}
685+
AddByteReprTrait(decl);
686+
}
687+
583688
bool Converter::VisitCXXRecordDecl(clang::CXXRecordDecl *decl) {
584689
if (clang::isa<clang::ClassTemplateSpecializationDecl>(decl)) {
585690
materializeTemplateSpecialization(decl);
@@ -623,74 +728,7 @@ bool Converter::VisitCXXRecordDecl(clang::CXXRecordDecl *decl) {
623728
}
624729
}
625730

626-
auto struct_name = GetRecordName(decl);
627-
628-
// First visit the nested enums
629-
for (auto d : decl->decls()) {
630-
if (auto enum_decl = llvm::dyn_cast<clang::EnumDecl>(d)) {
631-
VisitEnumDecl(enum_decl);
632-
}
633-
}
634-
635-
for (auto *decl : decl->decls()) {
636-
if (auto var_decl = clang::dyn_cast<clang::VarDecl>(decl)) {
637-
VisitVarDecl(var_decl);
638-
}
639-
}
640-
641-
auto nested = GetNestedStructs(decl);
642-
for (auto *record_decl : nested) {
643-
auto ID = GetID(record_decl);
644-
inner_structs_[ID] = GetRecordName(record_decl);
645-
VisitCXXRecordDecl(record_decl);
646-
}
647-
648-
StrCat(token::kHash, token::kOpenBracket, "derive", token::kOpenParen);
649-
bool derives_default = RecordDerivesDefault(decl);
650-
651-
for (auto *struct_attr : GetStructAttributes(decl, derives_default)) {
652-
StrCat(struct_attr, token::kComma);
653-
}
654-
StrCat(token::kCloseParen, token::kCloseBracket);
655-
656-
auto access_specifier = decl->getAccess();
657-
StrCat(AccessSpecifierAsString(access_specifier), keyword::kStruct,
658-
struct_name, token::kOpenCurlyBracket);
659-
for (auto *field : decl->fields()) {
660-
VisitFieldDecl(field);
661-
}
662-
StrCat(token::kCloseCurlyBracket);
663-
664-
ConvertCXXMethodDecls(
665-
decl, std::string(keyword::kImpl) + ' ' + struct_name,
666-
[](const auto *method) {
667-
return !method->isImplicit() &&
668-
!(method->getDefinition() &&
669-
method->getDefinition()->isDefaulted()) &&
670-
(method->isThisDeclarationADefinition() ||
671-
clang::isa<clang::CXXConstructorDecl>(method)) &&
672-
!method->isVirtual() &&
673-
!clang::isa<clang::CXXDestructorDecl>(method);
674-
});
675-
676-
AddOrdTrait(decl);
677-
AddCloneTrait(decl);
678-
AddDropTrait(decl);
679-
if (!derives_default) {
680-
AddDefaultTrait(decl);
681-
}
682-
AddByteReprTrait(decl);
683-
684-
if (decl->bases_begin() != decl->bases_end()) {
685-
ConvertCXXMethodDecls(
686-
decl,
687-
std::format("{} impl {} for {}", keyword_unsafe_,
688-
GetUnsafeTypeAsString(decl->bases_begin()->getType()),
689-
struct_name),
690-
[](const auto *method) {
691-
return !method->isImplicit() && method->isVirtual();
692-
});
693-
}
731+
EmitRustStruct(decl);
694732
} else {
695733
// FIXME: improve error handling
696734
assert(0 && "unsupported union");
@@ -2797,15 +2835,18 @@ std::string Converter::GetRecordName(const clang::NamedDecl *decl) const {
27972835
}
27982836

27992837
std::vector<const char *>
2800-
Converter::GetStructAttributes(const clang::CXXRecordDecl *decl,
2801-
bool &out_impl_default) {
2838+
Converter::GetStructAttributes(const clang::RecordDecl *decl) {
28022839
std::vector<const char *> struct_attrs = {};
28032840

28042841
if (recordDerivesCopy(decl)) {
28052842
struct_attrs.emplace_back("Copy");
28062843
}
28072844

2808-
if (!decl->defaultedCopyConstructorIsDeleted()) {
2845+
if (auto cxx_decl = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
2846+
if (!cxx_decl->defaultedCopyConstructorIsDeleted()) {
2847+
struct_attrs.emplace_back("Clone");
2848+
}
2849+
} else /* RecordDecl */ {
28092850
struct_attrs.emplace_back("Clone");
28102851
}
28112852

@@ -3106,11 +3147,14 @@ void Converter::AddCloneTrait(const clang::CXXRecordDecl *decl) {}
31063147
void Converter::AddDropTrait(const clang::CXXRecordDecl *decl) {}
31073148

31083149
void Converter::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
3150+
if (RecordDerivesDefault(decl)) {
3151+
return;
3152+
}
31093153
auto struct_name = GetRecordName(decl);
31103154
StrCat(std::format("impl Default for {}", struct_name),
31113155
token::kOpenCurlyBracket, "fn default() -> Self",
31123156
token::kOpenCurlyBracket);
3113-
if (auto default_ctor = GetUserDefinedDefaultConstructor(decl)) {
3157+
if (auto *default_ctor = GetUserDefinedDefaultConstructor(decl)) {
31143158
StrCat(keyword_unsafe_, token::kOpenCurlyBracket);
31153159
Convert(clang::CXXConstructExpr::Create(
31163160
ctx_, ctx_.getCanonicalTagType(decl), clang::SourceLocation(),
@@ -3133,7 +3177,7 @@ void Converter::AddDefaultTrait(const clang::CXXRecordDecl *decl) {
31333177
StrCat(token::kCloseCurlyBracket, token::kCloseCurlyBracket);
31343178
}
31353179

3136-
void Converter::AddByteReprTrait(const clang::CXXRecordDecl *decl) {}
3180+
void Converter::AddByteReprTrait(const clang::RecordDecl *decl) {}
31373181

31383182
void Converter::ConvertUnsignedArithBinaryOperator(clang::BinaryOperator *op,
31393183
clang::Expr *expr) {

cpp2rust/converter/converter.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,12 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
9595

9696
virtual bool ConvertLambdaVarDecl(clang::VarDecl *decl);
9797

98+
bool VisitRecordDecl(clang::RecordDecl *decl);
99+
98100
virtual bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl);
99101

102+
void EmitRustStruct(clang::RecordDecl *decl);
103+
100104
virtual bool VisitCXXMethodDecl(clang::CXXMethodDecl *decl);
101105
virtual std::string GetSelfMaybeWithMut(const clang::CXXMethodDecl *decl);
102106

@@ -355,7 +359,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
355359
virtual std::string GetRecordName(const clang::NamedDecl *decl) const;
356360

357361
virtual std::vector<const char *>
358-
GetStructAttributes(const clang::CXXRecordDecl *decl, bool &out_impl_default);
362+
GetStructAttributes(const clang::RecordDecl *decl);
359363

360364
virtual std::string GetUnsafeTypeAsString(clang::QualType qual_type) const;
361365

@@ -410,7 +414,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
410414

411415
virtual void AddDefaultTrait(const clang::CXXRecordDecl *decl);
412416

413-
virtual void AddByteReprTrait(const clang::CXXRecordDecl *decl);
417+
virtual void AddByteReprTrait(const clang::RecordDecl *decl);
414418

415419
virtual void
416420
ConvertUnsignedArithBinaryOperator(clang::BinaryOperator *binary_operator,
@@ -453,7 +457,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
453457

454458
virtual bool IsReferenceType(const clang::Expr *expr) const;
455459

456-
virtual bool RecordDerivesDefault(const clang::CXXRecordDecl *decl);
460+
virtual bool RecordDerivesDefault(const clang::RecordDecl *decl);
457461

458462
std::string *rs_code_;
459463
clang::ASTContext &ctx_;

cpp2rust/converter/models/converter_refcount.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ void ConverterRefCount::AddDropTrait(const clang::CXXRecordDecl *decl) {
494494
StrCat("}");
495495
}
496496

497-
void ConverterRefCount::AddByteReprTrait(const clang::CXXRecordDecl *decl) {
497+
void ConverterRefCount::AddByteReprTrait(const clang::RecordDecl *decl) {
498498
auto struct_name = GetRecordName(decl);
499499
StrCat(std::format("impl ByteRepr for {}", struct_name),
500500
token::kOpenCurlyBracket, token::kCloseCurlyBracket);
@@ -1604,11 +1604,10 @@ ConverterRefCount::ConvertVarDefaultInit(clang::QualType qual_type) {
16041604
}
16051605

16061606
std::vector<const char *>
1607-
ConverterRefCount::GetStructAttributes(const clang::CXXRecordDecl *decl,
1608-
bool &out_impl_default) {
1607+
ConverterRefCount::GetStructAttributes(const clang::RecordDecl *decl) {
16091608
std::vector<const char *> attrs = {};
16101609

1611-
if (out_impl_default) {
1610+
if (RecordDerivesDefault(decl)) {
16121611
attrs.emplace_back("Default");
16131612
}
16141613
return attrs;

cpp2rust/converter/models/converter_refcount.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ConverterRefCount final : public Converter {
3535

3636
void AddDropTrait(const clang::CXXRecordDecl *decl) override;
3737

38-
void AddByteReprTrait(const clang::CXXRecordDecl *decl) override;
38+
void AddByteReprTrait(const clang::RecordDecl *decl) override;
3939

4040
void AddDefaultTrait(const clang::CXXRecordDecl *decl) override;
4141

@@ -121,8 +121,7 @@ class ConverterRefCount final : public Converter {
121121
std::string ConvertVarDefaultInit(clang::QualType qual_type) override;
122122

123123
std::vector<const char *>
124-
GetStructAttributes(const clang::CXXRecordDecl *decl,
125-
bool &out_impl_default) override;
124+
GetStructAttributes(const clang::RecordDecl *decl) override;
126125

127126
bool MayCauseBorrowMutError(const clang::Expr *lhs, const clang::Expr *rhs);
128127

0 commit comments

Comments
 (0)