|
20 | 20 | ArrayDecl, |
21 | 21 | Cast, |
22 | 22 | Typename, |
| 23 | + Typedef, |
23 | 24 | ) |
24 | 25 | import os |
25 | 26 | import sys |
@@ -72,6 +73,24 @@ def __init__(self, ast, function_name, hexvalues): |
72 | 73 | self.new_main = None |
73 | 74 | self.arguments = [] |
74 | 75 | self.hexvalues = hexvalues |
| 76 | + self.typedef_map = self._build_typedef_map(ast) |
| 77 | + |
| 78 | + def _build_typedef_map(self, ast): |
| 79 | + """Build a map from typedef names to their resolved type nodes.""" |
| 80 | + typedef_map = {} |
| 81 | + for node in ast.ext: |
| 82 | + if isinstance(node, Typedef): |
| 83 | + typedef_map[node.name] = node.type |
| 84 | + return typedef_map |
| 85 | + |
| 86 | + def _resolve_typedef(self, type_node): |
| 87 | + """If type_node is a typedef'd IdentifierType, resolve it to the underlying type.""" |
| 88 | + if (isinstance(type_node, TypeDecl) and |
| 89 | + isinstance(type_node.type, IdentifierType)): |
| 90 | + type_name = " ".join(type_node.type.names) |
| 91 | + if type_name in self.typedef_map: |
| 92 | + return self.typedef_map[type_name] |
| 93 | + return type_node |
75 | 94 |
|
76 | 95 | def visit_func(self, node): |
77 | 96 | """ |
@@ -403,6 +422,35 @@ def generate_array_declaration(self, name, array_type_node, values): |
403 | 422 | values_str = ", ".join(values_chunk) |
404 | 423 | return f"{element_type_str} {name}[] = {{ {values_str} }};" |
405 | 424 |
|
| 425 | + def generate_typedef_array_declaration(self, name, orig_type_node, resolved_array_node, value): |
| 426 | + """ |
| 427 | + Generates the variable declaration for a typedef'd array type. |
| 428 | +
|
| 429 | + Uses the typedef name for the declaration but generates a flat initializer |
| 430 | + list based on the resolved array element type and size. |
| 431 | + """ |
| 432 | + # Get the typedef name from the original type node |
| 433 | + typedef_name = " ".join(orig_type_node.type.names) |
| 434 | + # Get the base element type from the resolved array |
| 435 | + element_type_str = self.get_element_type_str(resolved_array_node) |
| 436 | + elem_size = TYPE_SIZES.get(element_type_str, 4) |
| 437 | + hex_digits_per_elem = elem_size * 2 |
| 438 | + |
| 439 | + # Strip 0x prefix and trailing whitespace |
| 440 | + values_hex = value.strip() |
| 441 | + if values_hex.startswith("0x"): |
| 442 | + values_hex = values_hex[2:] |
| 443 | + |
| 444 | + # Chunk by element size and reverse bytes within each element |
| 445 | + values_chunk = [] |
| 446 | + for i in range(0, len(values_hex), hex_digits_per_elem): |
| 447 | + chunk = values_hex[i : i + hex_digits_per_elem] |
| 448 | + if len(chunk) == hex_digits_per_elem: |
| 449 | + values_chunk.append("0x" + _mem_bytes_to_literal_hex(chunk)) |
| 450 | + |
| 451 | + values_str = ", ".join(values_chunk) |
| 452 | + return f"{typedef_name} {name} = {{{values_str}}};" |
| 453 | + |
406 | 454 | def get_element_type_str(self, array_type_node): |
407 | 455 | """ |
408 | 456 | Get the base type of array type. |
@@ -441,7 +489,12 @@ def gen_arguments(self, arg_types, arg_names, hex_values): |
441 | 489 | declarations = [] |
442 | 490 |
|
443 | 491 | for type_node, name, value in zip(arg_types, arg_names, hex_values): |
444 | | - if self.is_primitive(type_node): |
| 492 | + resolved = self._resolve_typedef(type_node) |
| 493 | + if self.is_array(resolved): |
| 494 | + decl = self.generate_typedef_array_declaration( |
| 495 | + name, type_node, resolved, value |
| 496 | + ) |
| 497 | + elif self.is_primitive(type_node): |
445 | 498 | decl = self.generate_primitive_declaration(name, type_node, value) |
446 | 499 | elif self.is_struct(type_node): |
447 | 500 | decl = self.generate_struct_declaration(name, type_node, value) |
@@ -511,11 +564,31 @@ def generate_executable( |
511 | 564 | with open(input_file, "r") as file: |
512 | 565 | original_c_content = file.read() |
513 | 566 |
|
| 567 | + # Remove any existing main function definition to avoid conflicts |
| 568 | + import re as _re |
| 569 | + main_def_pat = r'(?:^[^\n\S]*(?:int|void)\s+main\s*\([^)]*\)\s*\{)' |
| 570 | + main_match = _re.search(main_def_pat, original_c_content, flags=_re.MULTILINE) |
| 571 | + if main_match: |
| 572 | + brace_start = original_c_content.index('{', main_match.start()) |
| 573 | + depth = 0 |
| 574 | + i = brace_start |
| 575 | + while i < len(original_c_content): |
| 576 | + if original_c_content[i] == '{': |
| 577 | + depth += 1 |
| 578 | + elif original_c_content[i] == '}': |
| 579 | + depth -= 1 |
| 580 | + if depth == 0: |
| 581 | + original_c_content = original_c_content[:main_match.start()] + original_c_content[i+1:] |
| 582 | + break |
| 583 | + i += 1 |
| 584 | + # Also remove forward declarations of main |
| 585 | + original_c_content = _re.sub(r'^[^\n\S]*int\s+main\s*\([^)]*\)\s*;\s*\n?', '', original_c_content, flags=_re.MULTILINE) |
| 586 | + |
514 | 587 | # This must be included in we want to run flexpret backend (for printf) |
515 | 588 | if include_flexpret: |
516 | 589 | original_c_content = "#include <flexpret/flexpret.h> \n" + original_c_content |
517 | 590 | else: |
518 | | - original_c_content = "#include <time.h> \n" + original_c_content |
| 591 | + original_c_content = "#include <stdio.h>\n#include <stdint.h>\n#include <time.h>\n" + original_c_content |
519 | 592 |
|
520 | 593 | # TODO: generate global variables, add the global timing function |
521 | 594 | original_c_content += timing_function_body |
|
0 commit comments