Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 118 additions & 25 deletions cc_bindings_from_rs/generate_bindings/generate_struct_and_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use itertools::Itertools;
use proc_macro2::{Ident, Literal, TokenStream};
use query_compiler::post_analysis_typing_env;
use quote::{format_ident, quote};
use rustc_abi::{FieldsShape, VariantIdx, Variants};
use rustc_abi::{Endian, FieldsShape, VariantIdx, Variants};
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::ConstValue;
use rustc_middle::ty::{self, Ty, TyCtxt, TyKind, TypeFlags};
Expand Down Expand Up @@ -597,7 +597,7 @@ pub fn generate_adt<'tcx>(
let relocating_ctor_snippets = generate_relocating_ctor(db, &core.cc_short_name);

let mut member_function_names = HashSet::<String>::new();
let adt_based_ctors = generate_adt_based_ctors(db, core.clone());
let adt_based_ctors = generate_adt_based_ctors(db, core.clone(), &mut member_function_names);

let impl_items_snippets = tcx
.inherent_impls(core.def_id)
Expand Down Expand Up @@ -894,6 +894,7 @@ fn anonymous_field_ident(index: usize) -> Ident {
fn generate_adt_based_ctors<'tcx>(
db: &BindingsGenerator<'tcx>,
core: Rc<AdtCoreBindings<'tcx>>,
member_function_names: &mut HashSet<String>,
) -> ApiSnippets<'tcx> {
let TyKind::Adt(adt_def, _) = core.self_ty.kind() else {
panic!("Attempted to generate constructor for a non-ADT type: {:?}", core.self_ty)
Expand All @@ -907,22 +908,43 @@ fn generate_adt_based_ctors<'tcx>(

adt_def
.variants()
.iter()
.map(|variant| {
generate_variant_ctor(db, core.clone(), variant).unwrap_or_else(|err| {
if should_suppress_errors {
Default::default()
} else {
generate_unsupported_def(db, variant.def_id, err).into_main_api()
}
})
.iter_enumerated()
.map(|(variant_index, variant)| {
generate_variant_ctor(db, core.clone(), member_function_names, variant_index, variant)
.unwrap_or_else(|err| {
if should_suppress_errors {
Default::default()
} else {
generate_unsupported_def(db, variant.def_id, err).into_main_api()
}
})
})
.collect()
}

#[derive(Debug, PartialEq, Eq)]
enum EnumKind {
/// Enum with `#[repr(C)]` where bindings get `tag` and "payload" fields.
ReprC,
/// Enums (e.g. `#[repr(Rust)]` and `#[repr(u32)]`) that are represented as a blob of bytes
/// (i.e. bindings only have a single, private `__opaque_blob_of_bytes` field).
OpaqueBlobOfBytes,
}
fn get_enum_kind<'tcx>(db: &BindingsGenerator<'tcx>, ty: Ty<'tcx>) -> Option<EnumKind> {
let enum_adt_def = ty.ty_adt_def().filter(|adt_def| adt_def.is_enum())?;
let repr_attrs = db.repr_attrs(enum_adt_def.did());
if repr_attrs.contains(&rustc_hir::attrs::ReprC) {
Some(EnumKind::ReprC)
} else {
Some(EnumKind::OpaqueBlobOfBytes)
}
}

fn generate_variant_ctor<'tcx>(
db: &BindingsGenerator<'tcx>,
core: Rc<AdtCoreBindings<'tcx>>,
member_function_names: &mut HashSet<String>,
variant_index: VariantIdx,
variant: &'tcx ty::VariantDef,
) -> Result<ApiSnippets<'tcx>> {
let tcx = db.tcx();
Expand Down Expand Up @@ -1017,8 +1039,64 @@ fn generate_variant_ctor<'tcx>(
main_api_params.is_empty(),
"Constructing enum variants with payload is unsupported: b/487356976, b/487357254",
);
ensure!(
get_enum_kind(db, core.self_ty) == Some(EnumKind::OpaqueBlobOfBytes),
"Constructors of #[repr(C)] enums don't work (see b/487399481 and cl/877428937)",
);

bail!("Constructing enum variants with no payload is not supported yet: b/487357254")
let tag_offset = {
let layout = get_layout(tcx, core.self_ty).expect("Should verify layout earlier");
use rustc_abi::Variants::*;
let tag_field = match layout.variants() {
Empty => unreachable!("Uninhabited types should be rejected earlier"),
Single { .. } => unreachable!("Single+NoPayload=ZST get rejected earlier"),
Multiple { tag_field, .. } => tag_field,
};
layout.fields().offset(tag_field.as_usize()).bytes() as usize
};
let adt_size = core.size_in_bytes as usize;
let (tag_value, tag_size): (u128, usize) = {
let ty::util::Discr { val, ty } = core
.self_ty
.discriminant_for_variant(tcx, variant_index)
.expect("Invalid VariantIdx");
let (size, _signed) = ty.int_size_and_signed(tcx);
let size = size.bytes() as usize;
let size = size.min(adt_size - tag_offset);
(val, size)
};
let tag_bytes = match tcx.sess.target.endian {
Endian::Little => &tag_value.to_le_bytes()[..tag_size],
Endian::Big => &tag_value.to_be_bytes()[std::mem::size_of::<u128>() - tag_size..],
};
let bytes = {
let mut bytes = vec![0; adt_size];
bytes[tag_offset..tag_offset + tag_bytes.len()].copy_from_slice(tag_bytes);
bytes.into_iter().map(Literal::u8_unsuffixed).collect_vec()
};
let method_name = format_cc_ident(db, &format!("Make{}", variant.name.as_str()))?;
let was_inserted = member_function_names.insert(method_name.to_string());
assert!(was_inserted, "No conflicts expected for 'constructor' names: {method_name}");
let doc_comment = generate_doc_comment(db, variant.def_id);
Ok(ApiSnippets {
main_api: CcSnippet {
prereqs,
tokens: quote! {
__NEWLINE__ #doc_comment
static #adt_cc_name #method_name();
__NEWLINE__
},
},
cc_details: CcSnippet::new(quote! {
__NEWLINE__
__COMMENT__ "`static` constructor"
inline #adt_cc_name #adt_cc_name::#method_name() {
return #adt_cc_name(PrivateBytesTag{}, { #( #bytes ),* });
}
__NEWLINE__
}),
..Default::default()
})
}
ty::AdtKind::Union => bail!("Crubit doesn't provide bindings for constructing unions"),
}
Expand Down Expand Up @@ -1092,10 +1170,11 @@ pub(crate) fn generate_fields<'tcx>(
};

// Used for generating enum bindings.
let is_supported_enum = adt_def.is_enum() && repr_attrs.contains(&rustc_hir::attrs::ReprC);

let tag_size_with_padding =
if is_supported_enum { get_tag_size_with_padding(layout) } else { 0 };
let enum_kind = get_enum_kind(db, self_ty);
let tag_size_with_padding = match enum_kind {
Some(EnumKind::ReprC) => get_tag_size_with_padding(layout),
None | Some(EnumKind::OpaqueBlobOfBytes) => 0,
};

let variant_sizes = match layout_variants {
Variants::Multiple { tag: _, tag_encoding: _, tag_field: _, variants } => {
Expand All @@ -1117,7 +1196,7 @@ pub(crate) fn generate_fields<'tcx>(
};
let variants_fields: Vec<Vec<Field<'tcx>>> = match adt_def.adt_kind() {
// Handle cases of unsupported ADTs.
ty::AdtKind::Enum if !is_supported_enum => {
ty::AdtKind::Enum if enum_kind != Some(EnumKind::ReprC) => {
vec![err_fields(anyhow!("No support for bindings of individual non-repr(C) `enum`s"))]
}

Expand Down Expand Up @@ -1224,7 +1303,7 @@ pub(crate) fn generate_fields<'tcx>(
// `def_span`.
variants_fields[variant_index][index].offset = offset.bytes();

if is_supported_enum {
if enum_kind == Some(EnumKind::ReprC) {
// Find the offset for the variant, and take it into
// account.
variants_fields[variant_index][index].offset -=
Expand Down Expand Up @@ -1283,9 +1362,10 @@ pub(crate) fn generate_fields<'tcx>(
})
.collect()
}
ty::AdtKind::Enum => {
// Check if each variant has the tag (and appropriate padding) in the front.
if !is_supported_enum {
// Check if each variant has the tag (and appropriate padding) in the front.
ty::AdtKind::Enum => match enum_kind {
None => unreachable!("ty::AdtKind::Enum with no enum_kind is impossible"),
Some(EnumKind::OpaqueBlobOfBytes) => {
variants_fields
.iter()
.flatten()
Expand All @@ -1296,7 +1376,8 @@ pub(crate) fn generate_fields<'tcx>(
quote! { static_assert(#offset == offsetof(#adt_cc_name, #cc_name)); }
})
.collect()
} else {
}
Some(EnumKind::ReprC) => {
let variant_offset_assertions: TokenStream = adt_def.variants().iter_enumerated().map(|(variant_index, variant_def)| {
let cc_variant_struct_name = format_cc_ident(db, variant_def.ident(tcx).as_str())
.unwrap_or_else(|_err| format_ident!("err_field"));
Expand Down Expand Up @@ -1330,7 +1411,7 @@ pub(crate) fn generate_fields<'tcx>(
}).collect();
quote! {#variant_offset_assertions #variant_field_assertions }
}
}
},
};

CcSnippet::with_include(
Expand All @@ -1343,7 +1424,7 @@ pub(crate) fn generate_fields<'tcx>(
)
};

let rs_details: RsSnippet = if is_supported_enum {
let rs_details: RsSnippet = if enum_kind == Some(EnumKind::ReprC) {
// Offsets for enums is an experimental feature.
// TODO(b/355642210): Add these assertions once they're not
// experiemtnal. let adt_rs_name =
Expand Down Expand Up @@ -1562,6 +1643,7 @@ pub(crate) fn generate_fields<'tcx>(

// For structs and unions, we can just flatten the fields variant. For enums, we
// need to handle each variant separately.
let adt_size = Literal::u64_unsuffixed(layout.size().bytes());
let fields = match adt_def.adt_kind() {
ty::AdtKind::Struct | ty::AdtKind::Union => {
let mut current_visibility = CcFieldVisState::public();
Expand All @@ -1571,7 +1653,7 @@ pub(crate) fn generate_fields<'tcx>(
.map(|field| get_field_tokens(field, &mut prereqs, &mut current_visibility))
.collect()
}
ty::AdtKind::Enum if !is_supported_enum => variants_fields
ty::AdtKind::Enum if enum_kind != Some(EnumKind::ReprC) => variants_fields
.into_iter()
.flatten()
.map(|field| get_field_tokens(field, &mut prereqs, &mut Default::default()))
Expand Down Expand Up @@ -1743,10 +1825,21 @@ pub(crate) fn generate_fields<'tcx>(
}
}
};
let enum_opaque_bytes_ctor = match enum_kind {
Some(EnumKind::OpaqueBlobOfBytes) => quote! {
private:
struct PrivateBytesTag {};
constexpr #cc_short_name(PrivateBytesTag,
std::array<unsigned char, #adt_size> bytes)
: __opaque_blob_of_bytes(bytes) {}
},
_ => quote! {},
};
CcSnippet {
prereqs,
tokens: quote! {
#fields
#enum_opaque_bytes_ctor
#assertions_method_decl
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,24 @@ fn test_format_item_struct_with_tuple_fields() {
});
}

#[test]
fn test_format_item_enum_with_one_byte_size() {
let test_src = r#"
#[derive(Clone, Copy)]
pub enum StringType {
Literal,
NotLiteral,
}
"#;
// The input above used to crash (tag size is 8 bytes - size of an `usize`, but
// the ADT size is only 1 byte). This test just makes sure that this input no
// longer triggers a crash (i.e. there is intentionally no other
// verification).
test_format_item(test_src, "StringType", |result| {
result.unwrap().unwrap();
});
}

#[test]
fn test_format_item_unsupported_struct_with_name_that_is_reserved_keyword() {
let test_src = r#"
Expand Down
49 changes: 49 additions & 0 deletions cc_bindings_from_rs/test/enums/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,52 @@ pub mod repr_c_clone_active_variant {
matches!(e, CloneActiveVariant::C(_))
}
}

pub mod repr_rust {
/// Doc comment of RustReprEnumWithNoPayload.
pub enum RustReprEnumWithNoPayload {
/// Doc comment of Variant1.
Variant1,
Variant2,
Variant3,
}

impl RustReprEnumWithNoPayload {
pub fn get_variant_number(&self) -> i32 {
match self {
Self::Variant1 => 1,
Self::Variant2 => 2,
Self::Variant3 => 3,
}
}
}

pub enum RustReprWithSingleNoPayloadVariant {
SingleVariant,
}

pub enum RustReprWithSingleTuplePayloadVariant {
SingleVariant(i32),
}
}

pub mod repr_int {
/// Two `NoPayloadX` variants to test that the tag is correctly set
/// (`NoPayload1` should have a tag of 0 and therefore `NoPayload2` is a
/// slightly better test for things like encoding the tag value with the
/// proper endianness, especially given that the tag is 4 bytes wide).
#[repr(u32)]
pub enum IntReprEnumWithNoPayload {
NoPayload1,
NoPayload2 = 1234,
}

impl IntReprEnumWithNoPayload {
pub fn is_no_payload1(&self) -> bool {
matches!(self, Self::NoPayload1)
}
pub fn is_no_payload2(&self) -> bool {
matches!(self, Self::NoPayload2)
}
}
}
Loading