diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index 9432af2ddc0..dad83e62008 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -275,27 +275,19 @@ struct Analyzer { std::unordered_map> unreadStructFieldExprMap; - // Cached segment data. Each time we see a new indirect call, we must scan all - // the segments of the table it refers to, find the functions in that segment, - // and check their types. If the number of segments is immense, we may end up - // doing a massive amount of function lookups (N * M where N = number of - // unique indirect call forms and M = size of the table's segments). To avoid - // that, precompute the function lookups in advance by "flattening" the data. - struct FlatElemInfo { - // The name of the element segment. - Name name; - - // The data in the element segment. - struct Item { - // The function the element segment's item refers to. - Name func; - // The type of function. - Type type; - }; - std::vector data; + // Cached table data. Each time we see a new call_indirect form (a table and a + // type), we must find all the functions that might be called, and their + // element segments, as those are now reachable. We parse element segments + // once at the start to build an efficient "flat" data structure for later + // queries. + struct FlatTableInfo { + // Maps each heap type that is in this table to the items it can call: the + // functions, and their segments. This takes into account subtyping, that + // is, typeItemMap[foo] includes data for subtypes of foo, so that we just + // need to read one place. + std::unordered_map> typeFuncs; + std::unordered_map> typeElems; }; - // Each table tracks all its elems. - using FlatTableInfo = std::vector; std::unordered_map flatTableInfoMap; Analyzer(Module* module, @@ -320,18 +312,20 @@ struct Analyzer { if (!elem->table) { continue; } - FlatElemInfo elemInfo; - elemInfo.name = elem->name; - auto& data = elemInfo.data; + auto& flatTableInfo = flatTableInfoMap[elem->table]; for (auto* item : elem->data) { if (auto* refFunc = item->dynCast()) { auto* func = module->getFunction(refFunc->func); - data.emplace_back(FlatElemInfo::Item{func->name, func->type}); + std::optional type = func->type.getHeapType(); + // Add this function and element to all relevant types: each function + // might be called by its type, or a supertype. + while (type) { + flatTableInfo.typeFuncs[*type].insert(func->name); + flatTableInfo.typeElems[*type].insert(elem->name); + type = type->getSuperType(); + } } } - if (!elemInfo.data.empty()) { - flatTableInfoMap[elem->table].push_back(std::move(elemInfo)); - } } } @@ -421,18 +415,12 @@ struct Analyzer { auto [table, type] = call; - // Any function in the table of that signature may be called. - for (auto& elemInfo : flatTableInfoMap[table]) { - auto elemReferenced = false; - for (auto& [func, funcType] : elemInfo.data) { - if (HeapType::isSubType(funcType.getHeapType(), type)) { - use({ModuleElementKind::Function, func}); - elemReferenced = true; - } - } - if (elemReferenced) { - reference({ModuleElementKind::ElementSegment, elemInfo.name}); - } + // Find callable functions and segments. + for (auto& func : flatTableInfoMap[table].typeFuncs[type]) { + use({ModuleElementKind::Function, func}); + } + for (auto& elem : flatTableInfoMap[table].typeElems[type]) { + reference({ModuleElementKind::ElementSegment, elem}); } }