Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 119 additions & 196 deletions src/main/java/org/openrewrite/prethink/ExportContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Stream;

import static java.util.Collections.emptyList;
import static org.openrewrite.prethink.Prethink.CONTEXT_DIR;
Expand Down Expand Up @@ -79,14 +80,102 @@ public boolean causesAnotherCycle() {
return true;
}

@Value
public static class Accumulator {
Set<Path> existingContextPaths;
private final Set<Path> existingContextPaths = new HashSet<>();

public Set<Path> getExistingContextPaths() {
return existingContextPaths;
}

// Rendered once and reused; safe to cache because producers stop writing after cycle 1.
@Nullable
volatile Map<String, String> csvByFilename;
@Nullable
volatile String markdown;
}

@Override
public Accumulator getInitialValue(ExecutionContext ctx) {
return new Accumulator(new HashSet<>());
return new Accumulator();
}

/** Aggregate and render each context's tables once per run; later calls are no-ops. */
private void renderOnce(Accumulator acc, ExecutionContext ctx) {
if (acc.csvByFilename != null) {
return;
}
synchronized (acc) {
if (acc.csvByFilename != null) {
return;
}
DataTableStore store = DataTableExecutionContextView.view(ctx).getDataTableStore();

// Multiple recipes can write the same table type, so collect every instance to concatenate its rows.
Map<String, List<DataTable<?>>> instancesByFqn = new HashMap<>();
for (DataTable<?> dt : store.getDataTables()) {
String tableFqn = dt.getClass().getName();
if (dataTables.contains(tableFqn)) {
instancesByFqn.computeIfAbsent(tableFqn, k -> new ArrayList<>()).add(dt);
}
}

Map<String, String> rendered = new LinkedHashMap<>();
List<DataTableInfo> exportedTables = new ArrayList<>();
// Iterate in the declared dataTables order for deterministic output.
for (String tableFqn : dataTables) {
List<DataTable<?>> instances = instancesByFqn.get(tableFqn);
if (instances == null || instances.isEmpty()) {
continue;
}
DataTable<?> representative = instances.get(0);
rendered.put(tableToFilename(tableFqn), streamToCsv(store, representative, instances));
exportedTables.add(new DataTableInfo(
representative.getDisplayName(),
representative.getDescription(),
tableToFilename(tableFqn),
getColumnInfo(representative)
));
}
acc.markdown = exportedTables.isEmpty() ? null : generateMarkdown(exportedTables);
// Publish the map last so readers see it (and markdown) fully built — volatile happens-before.
acc.csvByFilename = rendered;
}
}

/** Stream each row straight to the writer so a full table is never held in memory. */
@SuppressWarnings("unchecked")
private String streamToCsv(DataTableStore store, DataTable<?> representative, List<DataTable<?>> instances) {
List<Field> columnFields = getColumnFields(representative.getType());
String[] headers = columnFields.stream()
.map(f -> f.getAnnotation(Column.class).displayName())
.toArray(String[]::new);

StringWriter stringWriter = new StringWriter();
CsvWriter writer = new CsvWriter(stringWriter, new CsvWriterSettings());
writer.writeHeaders(headers);

String[] values = new String[columnFields.size()];
for (DataTable<?> instance : instances) {
Class<? extends DataTable<Object>> dtClass = (Class<? extends DataTable<Object>>) instance.getClass();
try (Stream<Object> rows = store.getRows(dtClass, instance.getGroup())) {
rows.forEach(row -> {
for (int i = 0; i < columnFields.size(); i++) {
Field field = columnFields.get(i);
try {
field.setAccessible(true);
Object value = field.get(row);
values[i] = value == null ? "" : value.toString();
} catch (IllegalAccessException e) {
values[i] = "";
}
}
writer.writeRow((Object[]) values);
});
}
}

writer.close();
return stringWriter.toString();
}

@Override
Expand Down Expand Up @@ -114,63 +203,35 @@ public Collection<SourceFile> generate(Accumulator acc, ExecutionContext ctx) {
return emptyList();
}

List<SourceFile> contextFiles = new ArrayList<>();

// Access DataTableStore here - after preceding recipes have populated it
DataTableStore store = DataTableExecutionContextView.view(ctx).getDataTableStore();
// Aggregate + render exactly once; reused by getVisitor() and any later cycle.
renderOnce(acc, ctx);

// Aggregate rows from all instances of the same DataTable class,
// since multiple recipes may populate the same data table type
Map<String, DataTable<?>> tablesByFqn = new LinkedHashMap<>();
Map<String, List<Object>> rowsByFqn = new LinkedHashMap<>();
aggregateMatchingTables(store, tablesByFqn, rowsByFqn);

if (tablesByFqn.isEmpty()) {
List<SourceFile> contextFiles = new ArrayList<>();
Map<String, String> csvByFilename = acc.csvByFilename;
if (csvByFilename == null || csvByFilename.isEmpty()) {
return contextFiles;
}

// Collect the data tables we're exporting for the markdown file
List<DataTableInfo> exportedTables = new ArrayList<>();

for (Map.Entry<String, DataTable<?>> tableEntry : tablesByFqn.entrySet()) {
String tableFqn = tableEntry.getKey();
DataTable<?> table = tableEntry.getValue();
List<?> rows = rowsByFqn.getOrDefault(tableFqn, emptyList());

String filename = tableToFilename(tableFqn);
String csvContent = exportToCsv(table, rows);
Path filePath = CONTEXT_DIR.resolve(filename);

// Collect table info for markdown
exportedTables.add(new DataTableInfo(
table.getDisplayName(),
table.getDescription(),
filename,
getColumnInfo(table, rows)
));

for (Map.Entry<String, String> entry : csvByFilename.entrySet()) {
Path filePath = CONTEXT_DIR.resolve(entry.getKey());
// Only generate if file doesn't already exist
if (!acc.getExistingContextPaths().contains(filePath)) {
PlainText csvFile = PlainText.builder()
.text(csvContent)
contextFiles.add(PlainText.builder()
.text(entry.getValue())
.sourcePath(filePath)
.build();
contextFiles.add(csvFile);
.build());
}
}

// Generate the markdown description file
if (!exportedTables.isEmpty()) {
String mdFilename = toKebabCase(displayName) + ".md";
Path mdPath = CONTEXT_DIR.resolve(mdFilename);

String markdown = acc.markdown;
if (markdown != null) {
Path mdPath = CONTEXT_DIR.resolve(toKebabCase(displayName) + ".md");
if (!acc.getExistingContextPaths().contains(mdPath)) {
String mdContent = generateMarkdown(exportedTables);
PlainText mdFile = PlainText.builder()
.text(mdContent)
contextFiles.add(PlainText.builder()
.text(markdown)
.sourcePath(mdPath)
.build();
contextFiles.add(mdFile);
.build());
}
}

Expand All @@ -194,42 +255,22 @@ public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
Path path = pt.getSourcePath();

if (path.startsWith(CONTEXT_DIR)) {
// Reuse the once-rendered output instead of re-aggregating per file.
renderOnce(acc, ctx);
String filename = path.getFileName().toString();

// Update CSV files
if (filename.endsWith(".csv")) {
String newContent = getCsvContentForFile(filename, ctx);
Map<String, String> csvByFilename = acc.csvByFilename;
String newContent = csvByFilename == null ? null : csvByFilename.get(filename);
if (newContent != null && !newContent.equals(pt.getText())) {
return pt.withText(newContent);
}
}

// Update markdown file
String expectedMdFilename = toKebabCase(displayName) + ".md";
if (filename.equals(expectedMdFilename)) {
DataTableStore store2 = DataTableExecutionContextView.view(ctx).getDataTableStore();
{
Map<String, DataTable<?>> tablesByFqn = new LinkedHashMap<>();
Map<String, List<Object>> rowsByFqn = new LinkedHashMap<>();
aggregateMatchingTables(store2, tablesByFqn, rowsByFqn);
List<DataTableInfo> exportedTables = new ArrayList<>();
for (Map.Entry<String, DataTable<?>> tableEntry : tablesByFqn.entrySet()) {
String tableFqn = tableEntry.getKey();
DataTable<?> table = tableEntry.getValue();
List<?> rows = rowsByFqn.getOrDefault(tableFqn, emptyList());
exportedTables.add(new DataTableInfo(
table.getDisplayName(),
table.getDescription(),
tableToFilename(tableFqn),
getColumnInfo(table, rows)
));
}
if (!exportedTables.isEmpty()) {
String newContent = generateMarkdown(exportedTables);
if (!newContent.equals(pt.getText())) {
return pt.withText(newContent);
}
}
} else if (filename.equals(toKebabCase(displayName) + ".md")) {
// Update markdown file
String markdown = acc.markdown;
if (markdown != null && !markdown.equals(pt.getText())) {
return pt.withText(markdown);
}
}
}
Expand Down Expand Up @@ -280,77 +321,17 @@ private String generateMarkdown(List<DataTableInfo> tables) {
return sb.toString();
}

private List<ColumnInfo> getColumnInfo(DataTable<?> table, List<?> rows) {
private List<ColumnInfo> getColumnInfo(DataTable<?> table) {
List<ColumnInfo> columns = new ArrayList<>();

Class<?> rowClass;
if (!rows.isEmpty()) {
rowClass = rows.get(0).getClass();
} else {
try {
rowClass = Class.forName(table.getClass().getName() + "$Row");
} catch (ClassNotFoundException e) {
return columns;
}
}

for (Field field : rowClass.getDeclaredFields()) {
for (Field field : table.getType().getDeclaredFields()) {
Column columnAnnotation = field.getAnnotation(Column.class);
if (columnAnnotation != null) {
columns.add(new ColumnInfo(columnAnnotation.displayName(), columnAnnotation.description()));
}
}

return columns;
}

private @Nullable String getCsvContentForFile(String filename, ExecutionContext ctx) {
DataTableStore store = DataTableExecutionContextView.view(ctx).getDataTableStore();

Map<String, DataTable<?>> tablesByFqn = new LinkedHashMap<>();
Map<String, List<Object>> rowsByFqn = new LinkedHashMap<>();
aggregateMatchingTables(store, tablesByFqn, rowsByFqn);

for (Map.Entry<String, DataTable<?>> entry : tablesByFqn.entrySet()) {
String expectedFilename = tableToFilename(entry.getKey());
if (expectedFilename.equals(filename)) {
return exportToCsv(entry.getValue(),
rowsByFqn.getOrDefault(entry.getKey(), emptyList()));
}
}
return null;
}

/**
* Aggregate rows from all DataTable instances of the same class into a single list per class.
* When multiple recipes produce the same DataTable type (e.g., FindNodeTestCoverage and
* FindTestCoverage both produce TestMapping), this ensures all rows are combined.
*/
@SuppressWarnings({"unchecked"})
private void aggregateMatchingTables(DataTableStore store,
Map<String, DataTable<?>> tablesByFqn,
Map<String, List<Object>> rowsByFqn) {
// First pass: collect all matching tables (order from store is non-deterministic)
Map<String, DataTable<?>> unordered = new HashMap<>();
Map<String, List<Object>> unorderedRows = new HashMap<>();
for (DataTable<?> dt : store.getDataTables()) {
String tableFqn = dt.getClass().getName();
if (dataTables.contains(tableFqn)) {
unordered.putIfAbsent(tableFqn, dt);
List<Object> rows = unorderedRows.computeIfAbsent(tableFqn, k -> new ArrayList<>());
Class<? extends DataTable<Object>> dtClass = (Class<? extends DataTable<Object>>) dt.getClass();
store.getRows(dtClass, dt.getGroup()).forEach(rows::add);
}
}
// Second pass: insert in dataTables list order for deterministic output
for (String fqn : dataTables) {
if (unordered.containsKey(fqn)) {
tablesByFqn.put(fqn, unordered.get(fqn));
rowsByFqn.put(fqn, unorderedRows.getOrDefault(fqn, new ArrayList<>()));
}
}
}

private String tableToFilename(String tableFqn) {
// org.openrewrite.prethink.table.MethodDescriptions -> method-descriptions.csv
String simpleName = tableFqn.substring(tableFqn.lastIndexOf('.') + 1);
Expand All @@ -377,64 +358,6 @@ private String toKebabCase(String input) {
return result.toString();
}

private String exportToCsv(DataTable<?> table, List<?> rows) {
if (rows.isEmpty()) {
return getHeadersFromTable(table);
}

Class<?> rowClass = rows.get(0).getClass();
List<Field> columnFields = getColumnFields(rowClass);

StringWriter stringWriter = new StringWriter();
CsvWriter writer = new CsvWriter(stringWriter, new CsvWriterSettings());

// Write headers
String[] headers = columnFields.stream()
.map(f -> f.getAnnotation(Column.class).displayName())
.toArray(String[]::new);
writer.writeHeaders(headers);

// Write rows
for (Object row : rows) {
String[] values = new String[columnFields.size()];
for (int i = 0; i < columnFields.size(); i++) {
Field field = columnFields.get(i);
try {
field.setAccessible(true);
Object value = field.get(row);
values[i] = value == null ? "" : value.toString();
} catch (IllegalAccessException e) {
values[i] = "";
}
}
writer.writeRow((Object[]) values);
}

writer.close();
return stringWriter.toString();
}

private String getHeadersFromTable(DataTable<?> table) {
// Get the Row class from the DataTable's generic type
try {
Class<?> rowClass = Class.forName(table.getClass().getName() + "$Row");
List<Field> columnFields = getColumnFields(rowClass);

StringWriter stringWriter = new StringWriter();
CsvWriter writer = new CsvWriter(stringWriter, new CsvWriterSettings());

String[] headers = columnFields.stream()
.map(f -> f.getAnnotation(Column.class).displayName())
.toArray(String[]::new);
writer.writeHeaders(headers);
writer.close();

return stringWriter.toString();
} catch (ClassNotFoundException e) {
return "";
}
}

private List<Field> getColumnFields(Class<?> rowClass) {
List<Field> columnFields = new ArrayList<>();
for (Field field : rowClass.getDeclaredFields()) {
Expand Down
Loading
Loading