Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 53 additions & 48 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2997,12 +2997,64 @@ bool Converter::VisitCXXStdInitializerListExpr(
return false;
}

std::string Converter::GetArrayDefaultAsString(clang::QualType qual_type) {
if (auto *array_type = clang::dyn_cast<clang::ConstantArrayType>(qual_type)) {
auto size_as_string = GetNumAsString(array_type->getSize());
auto element_type = array_type->getElementType();
auto element_type_as_string = GetDefaultAsString(element_type);
return std::format("[{}; {}]", element_type_as_string,
size_as_string.c_str());
}
if (auto *array_type =
clang::dyn_cast<clang::IncompleteArrayType>(qual_type)) {
return GetDefaultAsString(array_type->getElementType());
}
if (Mapper::ToString(qual_type).contains("std::array")) {
assert(GetTemplateArgs(qual_type).has_value());
auto template_args = *GetTemplateArgs(qual_type);
assert(template_args.size() == 2);
auto array_size = template_args[1];
unsigned size = 0;
switch (array_size.getKind()) {
case clang::TemplateArgument::Expression: {
auto array_size_expr = array_size.getAsExpr();
assert(array_size_expr && !array_size_expr->isValueDependent());
clang::Expr::EvalResult result;
ENSURE(array_size_expr->EvaluateAsInt(result, ctx_));
size = result.Val.getInt().getZExtValue();
break;
}
case clang::TemplateArgument::Integral: {
size = array_size.getAsIntegral().getZExtValue();
break;
}
default:
assert(0 && "Unsupported array size kind");
break;
}
return std::format(
"std::array::from_fn::<_, {}, _>(|_| Default::default()).to_vec()",
size);
}
return {};
}

std::string Converter::GetDefaultAsString(clang::QualType qual_type) {
if (IsVaListType(qual_type)) {
computed_expr_type_ = ComputedExprType::FreshValue;
return "VaList::default()";
}

if (auto arr = GetArrayDefaultAsString(qual_type); !arr.empty()) {
computed_expr_type_ = ComputedExprType::FreshValue;
return arr;
}

if (auto init = Mapper::MapInitializer(qual_type); !init.empty()) {
computed_expr_type_ = ComputedExprType::FreshValue;
return init;
}

if (qual_type->isPointerType()) {
auto pointee = qual_type->getPointeeType();
if (pointee->isFunctionType()) {
Expand All @@ -3014,54 +3066,7 @@ std::string Converter::GetDefaultAsString(clang::QualType qual_type) {
}

computed_expr_type_ = ComputedExprType::FreshValue;

if (auto *array_type = clang::dyn_cast<clang::ConstantArrayType>(qual_type)) {
auto size_as_string = GetNumAsString(array_type->getSize());
auto element_type = array_type->getElementType();
auto element_type_as_string = GetDefaultAsString(element_type);
return std::format("[{}; {}]", element_type_as_string,
size_as_string.c_str());
} else if (auto *array_type =
clang::dyn_cast<clang::IncompleteArrayType>(qual_type)) {
return GetDefaultAsString(array_type->getElementType());
} else {
auto qual_type_str = Mapper::ToString(qual_type);
if (qual_type_str == "struct std::pair") {
auto template_args = *GetTemplateArgs(qual_type);
auto first_type = template_args[0].getAsType();
auto second_type = template_args[1].getAsType();
return std::format("({}, {})", GetDefaultAsString(first_type),
GetDefaultAsString(second_type));
} else if (qual_type_str.contains("std::array")) {
assert(GetTemplateArgs(qual_type).has_value());
auto template_args = *GetTemplateArgs(qual_type);
assert(template_args.size() == 2);
auto array_size = template_args[1];
unsigned size = 0;
switch (array_size.getKind()) {
case clang::TemplateArgument::Expression: {
auto array_size_expr = array_size.getAsExpr();
assert(array_size_expr && !array_size_expr->isValueDependent());
clang::Expr::EvalResult result;
ENSURE(array_size_expr->EvaluateAsInt(result, ctx_));
size = result.Val.getInt().getZExtValue();
break;
}
case clang::TemplateArgument::Integral: {
size = array_size.getAsIntegral().getZExtValue();
break;
}
default:
assert(0 && "Unsupported array size kind");
break;
}
return std::format(
"std::array::from_fn::<_, {}, _>(|_| Default::default()).to_vec()",
size);
} else {
return GetDefaultAsStringFallback(qual_type);
}
}
return GetDefaultAsStringFallback(qual_type);
}

std::string Converter::GetDefaultAsStringFallback(clang::QualType qual_type) {
Expand Down
2 changes: 2 additions & 0 deletions cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual std::string GetDefaultAsString(clang::QualType qual_type);

virtual std::string GetArrayDefaultAsString(clang::QualType qual_type);

virtual std::string GetDefaultAsStringFallback(clang::QualType qual_type);

virtual std::string ConvertVarDefaultInit(clang::QualType qual_type);
Expand Down
12 changes: 12 additions & 0 deletions cpp2rust/converter/mapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,18 @@ std::string Map(clang::QualType qual_type) {
return {};
}

std::string MapInitializer(clang::QualType qual_type) {
auto type_str = ToString(qual_type);
auto [rule, subs] = search(types_, type_str, GetTypeMapKey(type_str));
if (rule && !rule->initializer.empty()) {
for (auto &ty : subs) {
ty = mapTypeStringRecursive(ty);
}
return instantiateTgt(subs, rule->initializer);
}
return {};
}

bool MapsToPointer(clang::QualType qual_type) {
auto rule = search(qual_type);
return rule && rule->type_info.is_pointer();
Expand Down
1 change: 1 addition & 0 deletions cpp2rust/converter/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ bool Contains(clang::QualType qual_type);
bool Contains(const clang::Expr *expr);

std::string Map(clang::QualType qual_type);
std::string MapInitializer(clang::QualType qual_type);
const TranslationRule::ExprRule *GetExprRule(const clang::Expr *expr);
std::string MapFunctionName(const clang::FunctionDecl *decl);
std::string InstantiateTemplate(const clang::Expr *expr, unsigned n);
Expand Down
45 changes: 25 additions & 20 deletions cpp2rust/converter/models/converter_refcount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1623,11 +1623,36 @@ bool ConverterRefCount::VisitCXXDefaultArgExpr(clang::CXXDefaultArgExpr *expr) {
return Converter::VisitCXXDefaultArgExpr(expr);
}

std::string
ConverterRefCount::GetArrayDefaultAsString(clang::QualType qual_type) {
if (auto *array_type = clang::dyn_cast<clang::ConstantArrayType>(qual_type)) {
const auto &size = array_type->getSize();
auto size_as_string = GetNumAsString(size);
auto element_type = array_type->getElementType();
PushConversionKind push(*this, ConversionKind::Unboxed);
auto element_type_as_string = ToString(element_type);
auto default_as_string = GetDefaultAsString(element_type);
return std::format("(0..{}).map(|_| {}).collect::<Box<[{}]>>()",
size_as_string.c_str(), default_as_string,
element_type_as_string);
}
return Converter::GetArrayDefaultAsString(qual_type);
}

std::string ConverterRefCount::GetDefaultAsString(clang::QualType qual_type) {
if (IsVaListType(qual_type)) {
return BoxValue("VaList::default()");
}

if (auto arr = GetArrayDefaultAsString(qual_type); !arr.empty()) {
return BoxValue(std::move(arr));
}

if (auto init = Mapper::MapInitializer(qual_type); !init.empty()) {
computed_expr_type_ = ComputedExprType::FreshValue;
return BoxValue(std::move(init));
}

std::string ret;
if (qual_type->isPointerType()) {
auto pointee_type = qual_type->getPointeeType();
Expand All @@ -1641,26 +1666,6 @@ std::string ConverterRefCount::GetDefaultAsString(clang::QualType qual_type) {
ret = std::format("Ptr::<{}>::null()", ConvertPointeeType(qual_type));
}
}
} else if (auto *array_type =
clang::dyn_cast<clang::ConstantArrayType>(qual_type)) {
const auto &size = array_type->getSize();
auto size_as_string = GetNumAsString(size);
auto element_type = array_type->getElementType();
PushConversionKind push(*this, ConversionKind::Unboxed);
auto element_type_as_string = ToString(element_type);
auto default_as_string = GetDefaultAsString(element_type);
ret = std::format("(0..{}).map(|_| {}).collect::<Box<[{}]>>()",
size_as_string.c_str(), default_as_string,
element_type_as_string);
} else if (Mapper::ToString(qual_type) == "struct std::pair") {
auto template_args = *GetTemplateArgs(qual_type);
auto first_type = template_args[0].getAsType();
auto second_type = template_args[1].getAsType();
ret = std::format("(Rc::new(RefCell::new({})), Rc::new(RefCell::new({})))",
GetDefaultAsString(first_type),
GetDefaultAsString(second_type));
} else if (Mapper::ToString(qual_type).contains("std::array")) {
ret = Converter::GetDefaultAsString(qual_type);
} else {
return Converter::GetDefaultAsString(qual_type);
}
Expand Down
2 changes: 2 additions & 0 deletions cpp2rust/converter/models/converter_refcount.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class ConverterRefCount final : public Converter {

std::string GetDefaultAsString(clang::QualType qual_type) override;

std::string GetArrayDefaultAsString(clang::QualType qual_type) override;

void ConvertEqualsNullPtr(clang::Expr *expr) override;

std::string GetDefaultAsStringFallback(clang::QualType qual_type) override;
Expand Down
2 changes: 1 addition & 1 deletion rules/pair/ir_unsafe.json
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@
}
},
"t1": {
"init": "(T1::default(), T2::default())",
"init": "<(T1, T2)>::default()",
"type": "(T1, T2)"
}
}
2 changes: 1 addition & 1 deletion rules/pair/tgt_unsafe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ struct T1;
struct T2;

fn types() {
let t1: (T1, T2) = (T1::default(), T2::default());
let t1: (T1, T2) = <(T1, T2)>::default();
}

unsafe fn f1<T1, T2>(a0: (T1, T2)) -> T2 {
Expand Down
2 changes: 1 addition & 1 deletion rules/stdio/ir_unsafe.json
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@
}
},
"t1": {
"init": "Default::default()",
"init": "std::ptr::null_mut()",
"type": "*mut ::std::fs::File",
"is_unsafe_pointer": true
}
Expand Down
2 changes: 1 addition & 1 deletion rules/stdio/tgt_unsafe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use libcc2rs::*;
use std::io::prelude::*;

fn types() -> Result<(), Box<dyn std::error::Error>> {
let t1: *mut ::std::fs::File = Default::default();
let t1: *mut ::std::fs::File = std::ptr::null_mut();
Ok(())
}

Expand Down
3 changes: 1 addition & 2 deletions tests/unit/out/refcount/fflush_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ pub fn main() {
std::process::exit(main_0());
}
fn main_0() -> i32 {
let file_ptr: Value<Ptr<::std::fs::File>> =
Rc::new(RefCell::new(Ptr::<::std::fs::File>::null()));
let file_ptr: Value<Ptr<::std::fs::File>> = Rc::new(RefCell::new(Ptr::null()));
return if !(*file_ptr.borrow()).is_null() {
match (*file_ptr.borrow()).with_mut(|v| v.sync_all()) {
Ok(_) => 0,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/out/refcount/fn_ptr_stdlib_compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn main_0() -> i32 {
let _arg0: AnyPtr = AnyPtr::default();
let _arg1: u64 = 0_u64;
let _arg2: u64 = 0_u64;
let _arg3: Ptr<::std::fs::File> = Ptr::<::std::fs::File>::null();
let _arg3: Ptr<::std::fs::File> = Ptr::null();
(*(*f3.borrow()))(_arg0, _arg1, _arg2, _arg3)
}) == 22_u64)
);
Expand Down Expand Up @@ -234,7 +234,7 @@ fn main_0() -> i32 {
let _arg0: AnyPtr = AnyPtr::default();
let _arg1: u64 = 0_u64;
let _arg2: u64 = 0_u64;
let _arg3: Ptr<::std::fs::File> = Ptr::<::std::fs::File>::null();
let _arg3: Ptr<::std::fs::File> = Ptr::null();
(*(*g3.borrow()))(_arg0, _arg1, _arg2, _arg3)
}) == 33_u64)
);
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/out/refcount/global_without_initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ thread_local!(
pub static s: Value<Ptr<S>> = Rc::new(RefCell::new(Ptr::<S>::null()));
);
thread_local!(
pub static file: Value<Ptr<::std::fs::File>> =
Rc::new(RefCell::new(Ptr::<::std::fs::File>::null()));
pub static file: Value<Ptr<::std::fs::File>> = Rc::new(RefCell::new(Ptr::null()));
);
thread_local!(
pub static size: Value<u64> = <Value<u64>>::default();
Expand Down
Loading