From 29dde5466ef82957521ad7abcda87c3476953d73 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Sun, 6 Nov 2022 11:27:46 +0100 Subject: [PATCH 1/8] Move `query-multiplier` to main repo --- CMakeLists.txt | 1 + bin/CMakeLists.txt | 1 + bin/Query/CMakeLists.txt | 34 ++ bin/Query/PredicateExample.cpp | 177 +++++++ bin/Query/SyntexQuery.cpp | 151 ++++++ include/multiplier/Syntex.h | 159 ++++++ lib/CMakeLists.txt | 1 + lib/Query/AST.cpp | 204 ++++++++ lib/Query/AST.h | 126 +++++ lib/Query/CMakeLists.txt | 52 ++ lib/Query/Grammar.cpp | 319 ++++++++++++ lib/Query/Grammar.h | 95 ++++ lib/Query/NodeKind.h | 133 +++++ lib/Query/Query.cpp | 918 +++++++++++++++++++++++++++++++++ lib/Query/Query.h | 251 +++++++++ 15 files changed, 2622 insertions(+) create mode 100644 bin/Query/CMakeLists.txt create mode 100644 bin/Query/PredicateExample.cpp create mode 100644 bin/Query/SyntexQuery.cpp create mode 100644 include/multiplier/Syntex.h create mode 100644 lib/Query/AST.cpp create mode 100644 lib/Query/AST.h create mode 100644 lib/Query/CMakeLists.txt create mode 100644 lib/Query/Grammar.cpp create mode 100644 lib/Query/Grammar.h create mode 100644 lib/Query/NodeKind.h create mode 100644 lib/Query/Query.cpp create mode 100644 lib/Query/Query.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 53308c3fb..e40a18465 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,7 @@ find_package(SQLite3 3.35 REQUIRED) find_package(reproc++ REQUIRED) find_package(pasta CONFIG REQUIRED) find_package(Python3 COMPONENTS Interpreter REQUIRED) +find_package(absl CONFIG REQUIRED) if(PLATFORM_MACOS) set(CMAKE_INSTALL_RPATH "@executable_path/../${CMAKE_INSTALL_LIBDIR}") diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 68b175c27..3b06ede72 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory("Examples") add_subdirectory("GenJSON") add_subdirectory("Index") +add_subdirectory("Query") diff --git a/bin/Query/CMakeLists.txt 
b/bin/Query/CMakeLists.txt new file mode 100644 index 000000000..c495e8176 --- /dev/null +++ b/bin/Query/CMakeLists.txt @@ -0,0 +1,34 @@ +# +# Copyright (c) 2022-present, Trail of Bits, Inc. +# All rights reserved. +# +# This source code is licensed in accordance with the terms specified in +# the LICENSE file found in the root directory of this source tree. +# + +add_executable("syntex-query" "SyntexQuery.cpp") + +target_link_libraries("syntex-query" + PRIVATE + "mx-api" + "mx-syntex" +) + +install( + TARGETS + "syntex-query" + EXPORT + "${PROJECT_NAME}Targets" + RUNTIME + DESTINATION + "${CMAKE_INSTALL_BINDIR}" +) + +add_executable("predicate-example" "PredicateExample.cpp") + +target_link_libraries("predicate-example" + PRIVATE + gflags + "mx-api" + "mx-syntex" +) diff --git a/bin/Query/PredicateExample.cpp b/bin/Query/PredicateExample.cpp new file mode 100644 index 000000000..91a0549d3 --- /dev/null +++ b/bin/Query/PredicateExample.cpp @@ -0,0 +1,177 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +// +// Example utility that uses syntex predicates to locate float to integer casts +// + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define ANSI_RED "\x1b[1;31m" +#define ANSI_RESET "\x1b[1;0m" + +DECLARE_bool(help); +DEFINE_string(db, "", "Path to Multiplier database."); +DEFINE_bool(long_is_32_bits, false, "Is 'long' a 32-bit type?"); + +static std::optional IntegralTypeWidth(const mx::Type &type) { + auto builtin_type = mx::BuiltinType::from(type); + if (!builtin_type) { + return std::nullopt; + } + + switch (builtin_type->builtin_kind()) { + case mx::BuiltinTypeKind::S_CHAR: + case mx::BuiltinTypeKind::CHARACTER_U: + case mx::BuiltinTypeKind::CHARACTER_S: + case mx::BuiltinTypeKind::BOOLEAN: + case mx::BuiltinTypeKind::CHAR8: + case mx::BuiltinTypeKind::U_CHAR: + return 8u; + case mx::BuiltinTypeKind::W_CHAR_S: + case mx::BuiltinTypeKind::W_CHAR_U: + case mx::BuiltinTypeKind::CHAR16: + case mx::BuiltinTypeKind::SHORT: + case mx::BuiltinTypeKind::U_SHORT: + return 16u; + case mx::BuiltinTypeKind::CHAR32: + case mx::BuiltinTypeKind::INT: + case mx::BuiltinTypeKind::U_INT: + return 32u; + case mx::BuiltinTypeKind::U_LONG: + case mx::BuiltinTypeKind::LONG: + return FLAGS_long_is_32_bits ? 
32u : 64u; + case mx::BuiltinTypeKind::U_LONG_LONG: + case mx::BuiltinTypeKind::LONG_LONG: + return 64u; + case mx::BuiltinTypeKind::U_INT128: + case mx::BuiltinTypeKind::INT128: + return 128u; + default: + return std::nullopt; + } +} + +static std::optional IntegralTypeWidth(const mx::ValueDecl &decl) { + return IntegralTypeWidth(decl.type()); +} + +static void HighlightMatch(std::ostream &os, mx::syntex::Match m) { + auto ref = mx::DeclRefExpr::from(std::get(m.MetavarMatch(0).Entity())); + if (!ref) { + return; + } + + auto var = mx::VarDecl::from(ref->referenced_declaration()); + if (!var) { + return; + } + + auto type_size = IntegralTypeWidth(var.value()); + if (!type_size) { + return; + } + + auto lit = mx::IntegerLiteral::from( + std::get(m.MetavarMatch(1).Entity())); + if (!lit) { + return; + } + + auto lit_data = lit->token().data(); + + // Strip off the suffixes of things like `0ull`. + while (lit_data.ends_with('u') || lit_data.ends_with('U') || + lit_data.ends_with('l') || lit_data.ends_with('L') || + lit_data.ends_with(' ')) { + lit_data = lit_data.substr(0, lit_data.size() - 1u); + } + + int64_t lit_val{-1}; + + std::stringstream ss; + ss << lit_data; + ss >> lit_val; + + if (0 >= lit_val) { + return; + } + + if (type_size.value() > static_cast(lit_val)) { + return; + } + + auto builtin_type = mx::BuiltinType::from(var->type()); + os << "File ID: " << mx::File::containing(m.Fragment()).id() << '\n' + << "Fragment ID: " << m.Fragment().id().Pack() << '\n' + << "Token ID: " << m.FirstTokenId() << '\n' + << "Literal value: " << lit_val << '\n' + << "Type size: " << type_size.value() << '\n' + << "Type kind: " << mx::EnumeratorName(builtin_type->builtin_kind()) + << "\nExpression:"; + + for (mx::Token tok : m.TokenRange()) { + os << ' ' << tok.data(); + } + + os << "\n\n"; +} + +extern "C" int main(int argc, char *argv[]) { + std::stringstream ss; + ss + << "Usage: " << argv[0] + << " [--db DATABASE]\n"; + + google::SetUsageMessage(ss.str()); + 
google::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + if (FLAGS_help) { + std::cerr << google::ProgramUsage() << std::endl; + return EXIT_FAILURE; + } + + if (FLAGS_db.empty()) { + std::cerr << "Need to specify a database using --db" << std::endl; + return EXIT_FAILURE; + } + + // Setup index and grammar + + mx::Index index = mx::EntityProvider::in_memory_cache( + mx::EntityProvider::from_database(FLAGS_db)); + mx::syntex::Grammar grammar(index, FLAGS_db); + + // Setup query + + mx::syntex::ParsedQuery parsed_query(grammar, "$var:DECL_REF_EXPR << $num:INTEGER_LITERAL"); + if (!parsed_query.IsValid()) { + return EXIT_FAILURE; + } + + // Match fragments + + parsed_query.ForEachMatch([] (mx::syntex::Match match) { + HighlightMatch(std::cout, std::move(match)); + return true; + }); + + return EXIT_SUCCESS; +} diff --git a/bin/Query/SyntexQuery.cpp b/bin/Query/SyntexQuery.cpp new file mode 100644 index 000000000..ff2dc1851 --- /dev/null +++ b/bin/Query/SyntexQuery.cpp @@ -0,0 +1,151 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +#include +#include +#include +#include +#include +#include +#include + +#define ANSI_RED "\x1b[1;31m" +#define ANSI_RESET "\x1b[1;0m" + +DECLARE_bool(help); +DEFINE_string(db, "", "Path to Multiplier database."); +DEFINE_string(query, "", "Use argument value as query"); +DEFINE_uint64(threads, 0, "Use this number of threads"); +DEFINE_bool(suppress_output, false, "Don't print matches to stdout"); + +// Serializes writes to stdout so that matches reported by different worker +// threads are not interleaved. +static std::mutex gMatchPrintingMutex; + +// Print a single match: the id of the fragment it was found in, every parsed +// token of that fragment with the matched span wrapped in ANSI red, and the +// name of each matched metavariable. Does nothing when --suppress_output is +// set. +static void PrintMatch(const mx::syntex::Match &match) +{ + if (FLAGS_suppress_output) { + return; + } + + { + std::lock_guard guard(gMatchPrintingMutex); + + // Print matching fragment ID + std::cout << "Match in " << match.Fragment().id() << ":\n"; + + for (auto token : match.Fragment().parsed_tokens()) { + if (token.id() == match.FirstTokenId()) { + // Switch to ANSI red for the first matching token + std::cout << ANSI_RED; + } + + std::cout << token.data() << " "; + + if (token.id() == match.LastTokenId()) { + // Reset color after last matching token + std::cout << ANSI_RESET; + } + } + + std::cout << "\n"; + + for (auto &metavar : match.MetavarMatches()) { + std::cout << "Matching metavar " << metavar.Name() << "\n"; + } + } +} + +// Worker body: run `parsed_query` against every fragment id in the half-open +// range [begin, end) and print each match that is found. +static void ProcessFragmentRange(const mx::syntex::ParsedQuery &parsed_query, + const mx::RawEntityId *begin, + const mx::RawEntityId *end) +{ + for (; begin < end; ++begin) { + std::vector<mx::syntex::Match> matches = + parsed_query.FindInFragment(*begin); + + for (const mx::syntex::Match &match : matches) { + PrintMatch(match); + } + } +} + +// Entry point: parse the command line, load the index and grammar from --db, +// parse --query, then match it against every fragment of the index using a +// pool of worker threads. Returns EXIT_FAILURE on --help, on a missing --db, +// or when the query has no valid parse. +extern "C" int main(int argc, char *argv[]) { + std::stringstream ss; + ss + << "Usage: " << argv[0] + << " [--db DATABASE]\n"; + + google::SetUsageMessage(ss.str()); + google::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + if (FLAGS_help) { + std::cerr << google::ProgramUsage() << std::endl; + return EXIT_FAILURE; + } + + if (FLAGS_db.empty()) { + std::cerr << "Need to specify a database using --db" << std::endl; + return
EXIT_FAILURE; + } + + // Setup index and grammar + + mx::Index index = mx::EntityProvider::from_database(FLAGS_db); + mx::syntex::Grammar grammar(index, FLAGS_db); + + // Parse query + + mx::syntex::ParsedQuery parsed_query(grammar, FLAGS_query); + + if (!parsed_query.IsValid()) { + std::cerr << "Query `" << FLAGS_query << "` has no valid parses\n"; + return EXIT_FAILURE; + } + + // Choose number of threads. `hardware_concurrency()` may return 0 when the + // value cannot be determined, so always fall back to at least one worker; + // this also keeps the division below well defined. + + size_t threads = FLAGS_threads ? FLAGS_threads : std::thread::hardware_concurrency(); + if (threads == 0) { + threads = 1; + } + std::cout << "starting matcher with " << threads << " threads\n"; + + // Collect all fragments to process + + std::vector<mx::RawEntityId> fragment_ids; + + for (const mx::File &file : mx::File::in(index)) { + for (mx::RawEntityId fragment_id : file.fragment_ids()) { + fragment_ids.push_back(fragment_id); + } + } + + // Find the number of fragments per thread, rounding up so that every + // fragment is assigned, the chunk size is never zero for a non-empty + // index, and at most `threads` workers are created. + + size_t fragments_per_thread = (fragment_ids.size() + threads - 1) / threads; + + // Create workers. `cur` and `last` delimit the half-open range of ids that + // still need a worker. `data()` plus `size()` is used rather than + // `front()`/`back()` so that an empty index is well defined and `last` is a + // true one-past-the-end pointer (the final fragment is included). + + std::vector<std::thread> thread_pool; + + const mx::RawEntityId *cur = fragment_ids.data(); + const mx::RawEntityId *const last = cur + fragment_ids.size(); + + while (cur < last) { + // Clamp the chunk to what is left so `end` never points past `last`. + size_t remaining = static_cast<size_t>(last - cur); + auto end = cur + (remaining < fragments_per_thread ? remaining : fragments_per_thread); + thread_pool.emplace_back(ProcessFragmentRange, parsed_query, cur, end); + cur = end; + } + + // Wait for all workers to finish + + for (auto &thread : thread_pool) { + thread.join(); + } + + return EXIT_SUCCESS; +} diff --git a/include/multiplier/Syntex.h b/include/multiplier/Syntex.h new file mode 100644 index 000000000..d95f634d0 --- /dev/null +++ b/include/multiplier/Syntex.h @@ -0,0 +1,159 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree.
+ +#pragma once + +#include +#include +#include +#include +#include "Entities/Attr.h" +#include "Entities/Decl.h" +#include "Entities/Designator.h" +#include "Entities/Stmt.h" +#include "Entities/Type.h" +#include "File.h" +#include "Index.h" +#include "Token.h" + +namespace mx { +namespace syntex { + +class Grammar; +class GrammarImpl; +class ParsedQuery; +class ParsedQueryImpl; +class Match; + +// +// Handle to a persistent grammar +// + +class Grammar { +private: + friend class ParsedQuery; + + std::shared_ptr impl; + Grammar() = delete; + +public: + explicit Grammar(const mx::Index &index, std::filesystem::path grammar_dir); +}; + +// +// Chunk of a fragment (potentially) matching a metavariable +// + +class MetavarMatch { +private: + std::string_view name; + mx::VariantEntity entity; + mx::TokenRange token_range; + +public: + MetavarMatch(std::string_view name_, mx::VariantEntity entity_, + mx::TokenRange token_range_) + : name(std::move(name_)), + entity(std::move(entity_)), + token_range(std::move(token_range_)) {} + + const std::string_view &Name(void) const { + return name; + } + + const mx::VariantEntity &Entity(void) const { + return entity; + } + + const mx::TokenRange &TokenRange(void) const { + return token_range; + } +}; + + +// +// Result of parsing a query +// + +class ParsedQuery { + private: + std::shared_ptr impl; + ParsedQuery(void) = delete; + + public: + explicit ParsedQuery(const Grammar &grammar, std::string_view query); + + bool IsValid() const; + + bool AddMetavarPredicate(const std::string_view &name, + std::function predicate); + + void ForEachMatch(const mx::Fragment &frag, + std::function pred) const; + void ForEachMatch(const mx::File &file, + std::function pred) const; + void ForEachMatch(std::function pred) const; + + std::vector Find(const mx::Fragment &frag) const; + std::vector Find(const mx::File &file) const; + std::vector Find(void) const; + + std::vector FindInFragment(mx::RawEntityId fragment_id) const; +}; + +// +// Chunk 
of a ParsedQuery that matched against a part of a fragment +// + +class Match { +private: + friend class ParsedQuery; + + mx::Fragment fragment; + mx::VariantEntity entity; + mx::TokenRange token_range; + + std::vector metavars; + +public: + Match(mx::Fragment fragment_, mx::VariantEntity entity_, + mx::TokenRange token_range_, std::vector matevars_) + : fragment(std::move(fragment_)), + entity(std::move(entity_)), + token_range(std::move(token_range_)), + metavars(std::move(matevars_)) {} + + const mx::Fragment &Fragment(void) const { + return fragment; + } + + const mx::VariantEntity &Entity(void) const { + return entity; + } + + const mx::TokenRange &TokenRange(void) const { + return token_range; + } + + mx::RawEntityId FirstTokenId(void) const { + return TokenRange().front().id(); + } + + mx::RawEntityId LastTokenId(void) const { + return TokenRange().back().id(); + } + + const std::vector &MetavarMatches(void) const { + return metavars; + } + + const MetavarMatch &MetavarMatch(size_t i) const { + return metavars[i]; + } +}; + +} // namespace syntex +} // namespace mx \ No newline at end of file diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c377655aa..5ba4576ce 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -9,3 +9,4 @@ add_subdirectory(Common) add_subdirectory(API) add_subdirectory(Util) +add_subdirectory(Query) diff --git a/lib/Query/AST.cpp b/lib/Query/AST.cpp new file mode 100644 index 000000000..75db80703 --- /dev/null +++ b/lib/Query/AST.cpp @@ -0,0 +1,204 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +#include + +#include "AST.h" + +#include + +namespace mx { +namespace syntex { + +ASTNode::ASTNode(NodeKind kind_, + mx::VariantEntity entity_, + mx::TokenRange token_range_) + : kind(kind_), + entity(std::move(entity_)), + token_range(std::move(token_range_)), + child_vector() + {} + +ASTNode::ASTNode(mx::TokenKind kind_, + mx::VariantEntity entity_, + mx::TokenRange token_range_, + std::string spelling) + : kind(kind_), + entity(std::move(entity_)), + token_range(std::move(token_range_)), + spelling(std::move(spelling)) + {} + +ASTNode::~ASTNode() { + // Destruct correct union variant + if (kind.IsToken()) { + spelling.std::string::~string(); + } else { + child_vector.std::vector::~vector(); + } +} + +AST AST::Build(const mx::Fragment &fragment) { + AST self; + absl::flat_hash_map ctx_to_node; + + for (mx::Token tok : mx::Token::in(fragment)) { + // Skip whitespaces + switch (tok.kind()) { + case mx::TokenKind::UNKNOWN: + case mx::TokenKind::WHITESPACE: + case mx::TokenKind::COMMENT: + continue; + default: + if (tok.data().empty()) { + continue; + } + break; + } + + + // Start with the token node + ASTNode *node = &self.nodes.emplace_back( + tok.kind(), tok, tok, std::string(tok.data().data(), tok.data().size())); + node->prev = self.index[node->Kind().Serialize()]; + self.index[node->Kind().Serialize()] = node; + + for (auto ctx = mx::TokenContext::of(tok); ctx; ctx = ctx->parent()) { + auto it = ctx_to_node.find(ctx->id()); + + // Add to parent node's children if it already exists + + if (it != ctx_to_node.end()) { + it->second->child_vector.push_back(node); + node = nullptr; + break; + } + + // Otherwise we need to create a new parent node + + if (auto decl = mx::Decl::from(*ctx)) { + ASTNode *parent = &self.nodes.emplace_back(decl->kind(), *decl, decl->tokens()); + // Add it to the index + parent->prev = self.index[parent->Kind().Serialize()]; + self.index[parent->Kind().Serialize()] = parent; + ctx_to_node[ctx->id()] = parent; + 
parent->child_vector.push_back(node); + node = parent; + continue; + } + + if (auto stmt = mx::Stmt::from(*ctx)) { + ASTNode *parent = &self.nodes.emplace_back(stmt->kind(), *stmt, stmt->tokens()); + parent->prev = self.index[parent->Kind().Serialize()]; + self.index[parent->Kind().Serialize()] = parent; + ctx_to_node[ctx->id()] = parent; + parent->child_vector.push_back(node); + node = parent; + continue; + } + } + + // If we didn't add the token to a pre-existing parent, add it to the root + + if (node != nullptr) { + self.root.push_back(node); + } + } + + return self; +} + +#ifndef NDEBUG + +namespace { + +static std::string Data(const std::string &data) { + std::stringstream ss; + for (auto ch : data) { + switch (ch) { + // To keep xdot happy + case '[': ss << " ["; break; + case ']': ss << "]"; break; + // HTML escapes + case '<': ss << "<"; break; + case '>': ss << ">"; break; + case '"': ss << """; break; + case '\'': ss << "'"; break; + case '\n': ss << "
"; break; + case '&': ss << "&"; break; + case '\t': ss << "  "; break; + case '\r': break; + default: ss << ch; break; + } + } + return ss.str(); +} + +} // namespace + +void AST::PrintDOT(std::ostream &os) const { + os << "digraph {\n" + << "node [shape=none margin=0 nojustify=false labeljust=l font=courier];\n"; + + // Root node + os << "root [label=<
>];\n"; + for (const ASTNode *child : root) { + os << "root -> x" << std::hex << reinterpret_cast(child) + << std::dec << ";\n"; + } + + for (const ASTNode &node : nodes) { + os << "x" << std::hex << reinterpret_cast(&node) << std::dec + << " [label=<(&node) + << " -> x" << std::hex << reinterpret_cast(child) + << std::dec << ";\n"; + } + }; + + node.kind.Visit(Visitor { + [&] (mx::DeclKind kind) { + os + << " bgcolor=\"aquamarine\">" + << mx::EnumeratorName(kind) + << "
>];\n"; + PrintChildren(); + }, + [&] (mx::StmtKind kind) { + os + << " bgcolor=\"darkolivegreen3\">" + << mx::EnumeratorName(kind) + << ">];\n"; + PrintChildren(); + }, + [&] (mx::TokenKind kind) { + os + << " bgcolor=\"cornsilk2\">" + << mx::EnumeratorName(kind) + << "" + << Data(node.spelling) + << ">];\n"; + }, + [&] () { + assert(false); + abort(); + }, + }); + } + + os << "}\n"; +} + +#endif + +} // namespace syntex +} // namespace mx \ No newline at end of file diff --git a/lib/Query/AST.h b/lib/Query/AST.h new file mode 100644 index 000000000..935fb1ea6 --- /dev/null +++ b/lib/Query/AST.h @@ -0,0 +1,126 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include "NodeKind.h" + +#include +#include +#include +#include + +// +// AST: In-memory tree representation of a multiplier fragment +// + +namespace pasta { +class TokenRange; +} + +namespace mx { +namespace syntex { + +class ASTNode { +public: + friend class AST; + + mutable const ASTNode *prev {nullptr}; + + ASTNode(NodeKind kind, + mx::VariantEntity entity, + mx::TokenRange token_range); + + ASTNode(mx::TokenKind kind, + mx::VariantEntity entity, + mx::TokenRange token_range, + std::string spelling); + + ~ASTNode(); + + NodeKind Kind() const { + return kind; + } + + const std::vector &ChildVector() const { + assert(!kind.IsToken()); + return child_vector; + }; + + const mx::VariantEntity &Entity() const { + return entity; + } + + const mx::TokenRange &TokenRange() const { + return token_range; + } + + const std::string &Spelling() const { + assert(kind.IsToken()); + return spelling; + } + +private: + NodeKind kind; + mx::VariantEntity entity; + mx::TokenRange token_range; + + union { + mutable std::vector child_vector; + std::string spelling; + }; +}; + + +// An AST. 
+class AST { +private: + friend class ASTNode; + + // Allocation arena for AST nodes + std::deque nodes; + + // Nodes at the root of the AST + std::vector root; + + // Nodes of the same kind are linked together in a chain + // This is the root of the chain for each kind + std::vector index; + + AST() { + index.resize(NodeKind::UpperLimit()); + } + +public: + // All nodes + const std::deque &AllNodes() const { + return nodes; + } + + // Nodes at the root of this AST + const std::vector &RootNodes(void) const { + return root; + } + + // Get indexed node of kind + const ASTNode *WithKind(NodeKind kind) const { + return index[kind.Serialize()]; + } + + // Build an AST from a multiplier fragment + static AST Build(const mx::Fragment &fragment); + + // NOTE: this actually lives in the PASTA grammar builder's cpp file + // do not call from anything else + static AST Build(const pasta::TokenRange &tokens); + +#ifndef NDEBUG + void PrintDOT(std::ostream &os) const; +#endif +}; + +} // namespace syntex +} // namespace mx \ No newline at end of file diff --git a/lib/Query/CMakeLists.txt b/lib/Query/CMakeLists.txt new file mode 100644 index 000000000..6886a1732 --- /dev/null +++ b/lib/Query/CMakeLists.txt @@ -0,0 +1,52 @@ +# +# Copyright (c) 2022-present, Trail of Bits, Inc. +# All rights reserved. +# +# This source code is licensed in accordance with the terms specified in +# the LICENSE file found in the root directory of this source tree. 
+# + +string(TOLOWER "${PROJECT_NAME}" lower_project_name) + +add_library("mx-syntex" + "AST.h" + "AST.cpp" + "Grammar.h" + "Grammar.cpp" + "NodeKind.h" + "Query.h" + "Query.cpp" +) + +target_include_directories("mx-syntex" + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" +) + +target_link_libraries("mx-syntex" + PRIVATE + "absl::hash" + "absl::raw_hash_set" + PUBLIC + "mx-api" + "mx-util" +) + +if(MX_ENABLE_INSTALL) + install( + TARGETS + "mx-syntex" + EXPORT "${PROJECT_NAME}Targets" + RUNTIME + DESTINATION + "${CMAKE_INSTALL_BINDIR}" + LIBRARY + DESTINATION + "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE + DESTINATION + "${CMAKE_INSTALL_LIBDIR}" + PUBLIC_HEADER + DESTINATION + "${CMAKE_INSTALL_INCLUDEDIR}/${lower_project_name}" + ) +endif(MX_ENABLE_INSTALL) diff --git a/lib/Query/Grammar.cpp b/lib/Query/Grammar.cpp new file mode 100644 index 000000000..ef97fe0ed --- /dev/null +++ b/lib/Query/Grammar.cpp @@ -0,0 +1,319 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#include "AST.h" +#include "Grammar.h" +#include +#include +#include + +namespace mx { +namespace syntex { + +GrammarImpl::GrammarImpl(const mx::Index &index_, std::filesystem::path db_path_) + : index(index_), db_path(db_path_) +{ + { + sqlite::Connection db(db_path); + DeserializeRules(db); + DeserializeTokens(db); + } + + for (auto file : mx::File::in(index)) { + for (auto fragment_id : file.fragment_ids()) { + Import(fragment_id); + } + } +} + +GrammarImpl::~GrammarImpl() +{ + sqlite::Connection db(db_path); + SerializeRules(db); + SerializeTokens(db); +} + +// Import a fragment into the grammar. 
+void GrammarImpl::Import(mx::RawEntityId fragment_id) +{ + auto fragment = index.fragment(fragment_id).value(); + auto ast = AST::Build(fragment); + +/* + + // Debug graphs + std::stringstream name; + name << "dot/ast_" << fragment.id() << ".dot"; + std::fstream fs(name.str(), std::fstream::out | std::fstream::trunc); + ast.PrintDOT(fs); + fs.close(); + +*/ + Import(ast); + +} + +void GrammarImpl::Import(const AST &ast) +{ + std::vector nodes(ast.RootNodes()); + + // Make a production rule for every node and its children. + while (!nodes.empty()) { + const ASTNode *node = nodes.back(); + nodes.pop_back(); + + if (node->Kind().IsToken()) { + // This is a token kind node, and represents a terminal. We want to map + // the contents of the token to the actual kind of the token. + + tokens.insert({ node->Spelling(), node->Kind().AsToken() }); + } else { + // This is an internal or root node. E.g. given the following: + // + // A + // / | \ + // B C D + // + // We want to make a rule of the form `B C D A`, i.e. if you match `B C D` + // then you have matched an `A`. This "backward" syntax enables us to prefix + // scan for left corners (`B` in this case) and find all rules starting with + // `B`. + + auto &child_vector = node->ChildVector(); + assert(child_vector.size() >= 1); + + // FIXME: do something else with long grammar rules. PHP has + // some generated initializer lists with 100s of elements that + // blows up our stack when serializing a grammar. + if (child_vector.size() > 100) { + continue; + } + + // Add the child nodes to the work list. 
+ nodes.insert(nodes.end(), child_vector.begin(), child_vector.end()); + + // Walk the trie + GrammarLeaves *leaves = &root; + for (const ASTNode *child : child_vector) { + leaves = &leaves->operator[](child->Kind()).leaves; + } + // Save pointer to rule head + GrammarNode *head = &leaves->operator[](node->Kind()); + + // Avoid creating cyclic CFGs + bool allow_production = true; + + if (child_vector.size() == 1) { + std::vector queue = { node->Kind() }; + while (!queue.empty()) { + auto nt = queue.back(); + queue.pop_back(); + + // Check if we can reach our own left corner + if (nt == child_vector[0]->Kind()) { + allow_production = false; + break; + } + + // Queue result of matching trivial productions + auto it = root.find(nt); + if (it != root.end()) { + for (auto &[left, rest] : it->second.leaves) { + if (rest.is_production) { + queue.push_back(left); + } + } + } + } + } + + // Mark the head as a production if appropriate + head->is_production = allow_production; + } + } +} + +template +static void IterateRulesRecursive(const GrammarLeaves &leaves, + std::vector &stack, + F cb) +{ + for (const auto &[left, rest] : leaves) { + if (rest.is_production) { + cb(stack, left); + } + stack.push_back(left); + IterateRulesRecursive(rest.leaves, stack, cb); + stack.pop_back(); + } +} + +void GrammarImpl::DebugRules(std::ostream &os) +{ + std::vector stack; + IterateRulesRecursive(root, stack, [&] (const std::vector &body, NodeKind head) { + for (NodeKind kind : body) { + os << kind << " "; + } + os << "-> " << head << "\n"; + }); +} + +// NOTE: this is a simplistic serialization format + +inline void verify(bool condition) { + if (!condition) { + assert(false); + abort(); + } +} + +static constexpr const char *grammar_root_schema = + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::GrammarRoot'(kind, node, PRIMARY KEY(kind))"; + +static constexpr const char *grammar_nodes_schema = + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::GrammarNodes'(id, is_production, PRIMARY 
KEY(id))"; + +static constexpr const char *grammar_children_schema = + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::GrammarChildren'(parent, kind, child, PRIMARY KEY(parent, kind))"; + +void GrammarImpl::SerializeRules(sqlite::Connection& db) +{ + sqlite::Transaction tx(db); + std::scoped_lock lock(tx); + + db.Execute(grammar_root_schema); + db.Execute("DELETE FROM 'mx::syntex::GrammarRoot'"); + db.Execute(grammar_root_schema); + db.Execute("DELETE FROM 'mx::syntex::GrammarNodes'"); + db.Execute(grammar_children_schema); + db.Execute("DELETE FROM 'mx::syntex::GrammarChildren'"); + + auto root_stmt = db.Prepare( + "INSERT OR REPLACE INTO " + "'mx::syntex::GrammarRoot'(kind, node) VALUES (?1, ?2)"); + auto node_stmt = db.Prepare( + "INSERT OR REPLACE INTO " + "'mx::syntex::GrammarNodes'(id, is_production) VALUES (?1, ?2)"); + auto child_stmt = db.Prepare( + "INSERT OR REPLACE INTO " + "'mx::syntex::GrammarChildren'(parent, kind, child) VALUES (?1, ?2, ?3)"); + std::vector to_insert; + + auto GetId = [](const GrammarNode* node) { + return static_cast(reinterpret_cast(node)); + }; + + for(auto &[kind, node] : root) { + auto kind_value = kind.Serialize(); + root_stmt->BindValues(kind_value, GetId(&node)); + root_stmt->Execute(); + to_insert.push_back(&node); + } + + while(!to_insert.empty()) { + auto node = to_insert.back(); + to_insert.pop_back(); + + for(auto &[kind, child] : node->leaves) { + to_insert.push_back(&child); + child_stmt->BindValues(GetId(node), kind.Serialize(), GetId(&child)); + child_stmt->Execute(); + } + + node_stmt->BindValues(GetId(node), int{node->is_production}); + node_stmt->Execute(); + } +} + +void GrammarImpl::DeserializeRules(sqlite::Connection& db) +{ + db.Execute(grammar_root_schema); + db.Execute(grammar_nodes_schema); + db.Execute(grammar_children_schema); + auto root_stmt = db.Prepare( + "SELECT node, kind, is_production FROM 'mx::syntex::GrammarRoot' " + "JOIN 'mx::syntex::GrammarNodes' ON id = node"); + auto children_stmt = 
db.Prepare( + "SELECT child, kind, is_production FROM 'mx::syntex::GrammarChildren' " + "JOIN 'mx::syntex::GrammarNodes' ON id = child " + "WHERE parent = ?1"); + std::vector> to_load; + while(root_stmt->ExecuteStep()) { + std::uint64_t id; + unsigned short kind; + int is_production; + auto res = root_stmt->GetResult(); + res.Columns(id, kind, is_production); + auto &node = root[NodeKind::Deserialize(kind)]; + node.is_production = is_production; + to_load.emplace_back(id, &node); + } + + while(!to_load.empty()) { + auto pair = to_load.back(); + to_load.pop_back(); + auto id = std::get<0>(pair); + auto &node = *std::get<1>(pair); + + children_stmt->BindValues(id); + while(children_stmt->ExecuteStep()) { + std::uint64_t child_id; + unsigned short kind; + int is_production; + auto res = children_stmt->GetResult(); + res.Columns(child_id, kind, is_production); + auto &child_node = node.leaves[NodeKind::Deserialize(kind)]; + child_node.is_production = is_production; + to_load.emplace_back(child_id, &child_node); + } + } +} + +static constexpr const char* tokens_schema = + "CREATE TABLE IF NOT EXISTS 'mx::syntex::Tokens'(spelling, kind, PRIMARY KEY(spelling))"; + +void GrammarImpl::SerializeTokens(sqlite::Connection& db) +{ + db.Execute(tokens_schema); + auto stmt = db.Prepare( + "INSERT OR IGNORE INTO 'mx::syntex::Tokens'(spelling, kind) VALUES (?1, ?2)"); + for (auto &[spelling, kind] : tokens) { + stmt->BindValues(spelling, static_cast(kind)); + stmt->Execute(); + } +} + +void GrammarImpl::DeserializeTokens(sqlite::Connection& db) +{ + db.Execute(tokens_schema); + auto stmt = db.Prepare("SELECT spelling, kind FROM 'mx::syntex::Tokens'"); + while(stmt->ExecuteStep()) { + std::string spelling; + unsigned short kind; + auto res = stmt->GetResult(); + res.Columns(spelling, kind); + tokens[spelling] = static_cast(kind); + } +} + +// Determine the kind of an identifier based on its spelling +std::optional GrammarImpl::TokenKindOf(std::string_view spelling) const { + auto it 
= tokens.find(std::string(spelling)); + if (it != tokens.end()) { + return it->second; + } + return std::nullopt; +} + +Grammar::Grammar(const mx::Index &index, std::filesystem::path grammar_dir) + : impl(std::make_shared(index, grammar_dir)) {} + +} // namespace syntex +} // namespace mx \ No newline at end of file diff --git a/lib/Query/Grammar.h b/lib/Query/Grammar.h new file mode 100644 index 000000000..3dbbe503a --- /dev/null +++ b/lib/Query/Grammar.h @@ -0,0 +1,95 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include "NodeKind.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace mx { + +class Index; + +namespace syntex { + +class AST; + +struct GrammarNode; + +// +// One set of grammar leaves +// FIXME(frabert): Deserialization crashes if this is turned into +// an `absl::flat_hash_map` +// +using GrammarLeaves = std::unordered_map; + +// +// Node in the grammar tree +// + +struct GrammarNode { + // Does this node correspond to the head of a production + bool is_production; + // Further leaves + GrammarLeaves leaves; +}; + +// +// Persistent CFG synthesized from a set of multiplier fragments +// + +class GrammarImpl { +private: + friend class Item; + friend class ParsedQuery; + friend class ParsedQueryImpl; + + // Multiplier index corresponding to this grammar + const mx::Index &index; + + // Grammar storage directory + std::filesystem::path db_path; + + // Mapping of spellings to token kinds + absl::flat_hash_map tokens; + + // Root of the grammar tree + GrammarLeaves root; + +public: + GrammarImpl(const mx::Index &index, std::filesystem::path db_path); + + ~GrammarImpl(void); + + // Import a fragment into the grammar. This extends the persisted grammar with + // the features from this fragment. 
+ void Import(mx::RawEntityId fragment_id); + + void Import(const AST &ast); + + // Determine the kind of an identifier based on its spelling + std::optional TokenKindOf(std::string_view spelling) const; + + // Pretty print rules for debugging + void DebugRules(std::ostream &os); + + // Database grammar serialization + void SerializeRules(sqlite::Connection& db); + void DeserializeRules(sqlite::Connection& db); + + void SerializeTokens(sqlite::Connection& db); + void DeserializeTokens(sqlite::Connection& db); +}; + +} // namespace syntex +} // namespace mx \ No newline at end of file diff --git a/lib/Query/NodeKind.h b/lib/Query/NodeKind.h new file mode 100644 index 000000000..7e4ef7d04 --- /dev/null +++ b/lib/Query/NodeKind.h @@ -0,0 +1,133 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +namespace mx { +namespace syntex { + +// +// NodeKind: Core class of Syntex, represents the following things: +// - An entry in a grammar rule +// - Kind of node in a multiplier AST +// - Kind of node in a query AST +// + +class NodeKind { +private: + unsigned short val; + + NodeKind(unsigned short val_) : val(val_) {} + +public: + static NodeKind Any() { + return NodeKind(UpperLimit()); + } + + NodeKind(mx::DeclKind kind) + : val(static_cast(kind)) {} + + NodeKind(mx::StmtKind kind) + : val(static_cast(kind) + + mx::NumEnumerators(mx::DeclKind{})) {} + + NodeKind(mx::TokenKind kind) + : val(static_cast(kind) + + mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{})) {} + + template + auto Visit(T visitor) const { + if (val < mx::NumEnumerators(mx::DeclKind{})) { + return visitor(static_cast(val)); + } else if (val < mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{})) { + return visitor(static_cast(val + - 
mx::NumEnumerators(mx::DeclKind{}))); + } else if (val < UpperLimit()) { + return visitor(static_cast(val + - mx::NumEnumerators(mx::DeclKind{}) + - mx::NumEnumerators(mx::StmtKind{}))); + } else { + return visitor(); + } + } + + bool IsToken() const { + return val >= mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{}); + } + + mx::TokenKind AsToken() const { + assert(IsToken()); + return static_cast(val + - mx::NumEnumerators(mx::DeclKind{}) + - mx::NumEnumerators(mx::StmtKind{})); + } + + bool operator==(const NodeKind &other) const { + return val == other.val; + } + + static NodeKind Deserialize(unsigned short val) { + return val; + } + + unsigned short Serialize() const { + return val; + } + + static constexpr unsigned short UpperLimit() { + return mx::NumEnumerators(mx::DeclKind{}) + + mx::NumEnumerators(mx::StmtKind{}) + + mx::NumEnumerators(mx::TokenKind{}); + } +}; + +// +// This template (and deduction guide) allows for the easy generation of nice +// looking visitors from a set of lambdas. See the operator<< implementation +// below as an example usecase. +// + +template +struct Visitor : F ... { + using F::operator() ...; +}; + +template Visitor(F...) 
-> Visitor; + +// +// Pretty print a NodeKind to an output stream +// + +inline std::ostream& operator<<(std::ostream &os, const NodeKind &kind) { + kind.Visit(Visitor { + [&] (mx::DeclKind kind) { os << "DeclKind::" << EnumeratorName(kind); }, + [&] (mx::StmtKind kind) { os << "StmtKind::" << EnumeratorName(kind); }, + [&] (mx::TokenKind kind) { os << "TokenKind::" << EnumeratorName(kind); }, + [&] () { os << "NodeKind::Any"; }, + }); + return os; +} + +} // namespace syntex +} // namespace mx + +namespace std { + +template<> +struct hash { + size_t operator()(const mx::syntex::NodeKind &kind) const { + return kind.Serialize(); + } +}; + +} // namespace std \ No newline at end of file diff --git a/lib/Query/Query.cpp b/lib/Query/Query.cpp new file mode 100644 index 000000000..4294eb32f --- /dev/null +++ b/lib/Query/Query.cpp @@ -0,0 +1,918 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +#include "AST.h" +#include "Query.h" + +#include +#include +#include + +namespace mx { +namespace syntex { + +template +void Tokenize(TokenCallback token_callback, MetavarCallback metavar_callback, + VarargCallback vararg_callback, std::string_view input, size_t index) { + size_t end = index; + + auto Look = [&] (size_t i) -> int { + if (end + i < input.size()) + return input[end + i]; + else + return -1; + }; + + auto Eat = [&] (size_t cnt) { + end += cnt; + }; + + auto Get = [&] () { + int ch = Look(0); + if (ch != -1) + Eat(1); + return ch; + }; + + auto Match = [&] (char ch) { + if (Look(0) == ch) { + Eat(1); + return true; + } + return false; + }; + + auto MatchSpace = [&] () { + switch (Look(0)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchIdent = [&] () { + switch (Look(0)) { + case '_': + case 'a' ... 'z': + case 'A' ... 'Z': + case '0' ... '9': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchDecimal = [&] () { + switch (Look(0)) { + case '0' ... '9': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchHex = [&] () { + switch (Look(0)) { + case '0' ... '9': + case 'a' ... 'f': + case 'A' ... 'F': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchOct = [&] () { + switch (Look(0)) { + case '0' ... 
'7': + Eat(1); + return true; + default: + return false; + } + }; + + auto MatchDecimalExponent = [&] () { + if (Match('e') || Match('E')) { + Match('+') || Match('-'); + return true; + } + return false; + }; + + auto MatchHexExponent = [&] () { + if (Match('p') || Match('P')) { + Match('+') || Match('-'); + return true; + } + return false; + }; + + auto MatchIntegerSuffix = [&] () { + if (Match('l') || Match('L')) { + if (Match('l') || Match('L')) { + if (Match('u') || Match('U')) { // llu + // unsigned long long + } else { // ll + // long long + } + } else if (Match('u') || Match('U')) { // lu + // unsigned long + } else { // l + // long + } + } else if (Match('u') || Match('U')) { + if (Match('l') || Match('L')) { + if (Match('l') || Match('L')) { // ull + // unsigned long long + } else { // ul + // unsigned long + } + } else { // u + // unsigned int + } + } else { // + // int + } + }; + + auto MatchFloatingSuffix = [&] () { + if (Match('f') || Match('F')) { // f + // float + } else if (Match('l') || Match('L')) { // l + // long double + } else { // + // double + } + }; + + // Skip all whitespaces that might preceed the token + while (MatchSpace()) + ; + + // The spelling starts after skipping preceding whitespaces + size_t begin = end; + + // Add a token to the output + auto AddToken = [&] (size_t len, mx::TokenKind kind) { + size_t next = begin + len; + for (;;) + switch (Look(next - end)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + ++next; + break; + default: + token_callback(kind, input.substr(begin, len), next); + return; + } + }; + + auto AddMetavar = [&] (std::string_view name, NodeKind filter) { + size_t next = end; + for (;;) + switch (Look(next - end)) { + case ' ': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + ++next; + break; + default: + metavar_callback(name, filter, next); + return; + } + }; + + auto AddVararg = [&] () { + size_t next = end; + for (;;) + switch (Look(next - end)) { + case ' 
': + case '\f': + case '\n': + case '\r': + case '\t': + case '\v': + ++next; + break; + default: + vararg_callback(next); + return; + } + }; + + switch (Get()) { + // End of input + case -1: + break; + + // + // For identifiers and constants, the longest match is always consumed + // + + // Metavariable + case '$': + { + // Check for variable argument + if (Look(0) == '.' && Look(1) == '.' && Look(2) == '.') { + Eat(3); + AddVararg(); + break; + } + + // Skip over the name + while (MatchIdent()) + ; + + auto name = input.substr(begin + 1, end - begin - 1); + NodeKind filter = NodeKind::Any(); + + // Skip over filter if present + if (Match(':')) { + size_t filter_begin = end; + while (MatchIdent()) + ; + + auto filter_str = input.substr(filter_begin, end - filter_begin); + + // Try to parse it as a DeclKind + // FIXME: this should probably be done with some kind of LUT over this + // slow mess + for (int i = 0; i < NumEnumerators(mx::DeclKind::TYPE); ++i) { + auto kind = static_cast(i); + if (EnumeratorName(kind) == filter_str) { + filter = NodeKind(kind); + goto done_filters; + } + } + for (int i = 0; i < NumEnumerators(mx::StmtKind::NULL_STMT); ++i) { + auto kind = static_cast(i); + if (EnumeratorName(kind) == filter_str) { + filter = NodeKind(kind); + goto done_filters; + } + } + for (int i = 0; i < NumEnumerators(mx::TokenKind::UNKNOWN); ++i) { + auto kind = static_cast(i); + if (EnumeratorName(kind) == filter_str) { + filter = NodeKind(kind); + goto done_filters; + } + } + + assert("FIXME: return proper error for invalid filter" && false); + +done_filters:; + } + + AddMetavar(name, filter); + break; + } + + // Identifiers + case '_': + case 'a' ... 'z': + case 'A' ... 
'Z': + while (MatchIdent()) + ; + + AddToken(end - begin, mx::TokenKind::IDENTIFIER); + break; + + // Numeric constants + case '0': + if (Match('.')) { + while (MatchDecimal()) + ; + + if (MatchDecimalExponent()) { + while (MatchDecimal()) + ; + } + + MatchFloatingSuffix(); + } else if (Match('x') || Match('X')) { + while (MatchHex()) + ; + + if (Match('.')) { + while (MatchHex()) + ; + + if (MatchHexExponent()) { + while(MatchHex()) + ; + } + MatchFloatingSuffix(); + } else if (MatchHexExponent()) { + while (MatchHex()) + ; + + MatchFloatingSuffix(); + } else { + MatchIntegerSuffix(); + } + } else { + while (MatchOct()) + ; + + MatchIntegerSuffix(); + } + + AddToken(end - begin, mx::TokenKind::NUMERIC_CONSTANT); + break; + + case '1' ... '9': + while (MatchDecimal()) + ; + + if (Match('.')) { +FractionalConstant: + while (MatchDecimal()) + ; + + if (MatchDecimalExponent()) { + while (MatchDecimal()) + ; + } + + MatchFloatingSuffix(); + } else if (MatchDecimalExponent()) { + while (MatchDecimal()) + ; + MatchFloatingSuffix(); + } else { + MatchIntegerSuffix(); + } + + AddToken(end - begin, mx::TokenKind::NUMERIC_CONSTANT); + break; + + // Character constants + case '\'': + for (;;) { + auto ch = Get(); + if (ch == '\\') + Get(); + else if (ch == -1 || ch == '\'') + break; + } + + AddToken(end - begin, mx::TokenKind::CHARACTER_CONSTANT); + break; + + // String literals + case '"': + for (;;) { + auto ch = Get(); + if (ch == '\\') + Get(); + else if (ch == -1 || ch == '"') + break; + } + AddToken(end - begin, mx::TokenKind::STRING_LITERAL); + break; + + // + // For punctuators only the first character is consumed, and all possible + // matches at the current position are added + // + + case '[': + AddToken(1, mx::TokenKind::L_SQUARE); + break; + case ']': + AddToken(1, mx::TokenKind::R_SQUARE); + break; + case '(': + AddToken(1, mx::TokenKind::L_PARENTHESIS); + break; + case ')': + AddToken(1, mx::TokenKind::R_PARENTHESIS); + break; + case '{': + AddToken(1, 
mx::TokenKind::L_BRACE_TOKEN); + break; + case '}': + AddToken(1, mx::TokenKind::R_BRACE_TOKEN); + break; + case '~': + AddToken(1, mx::TokenKind::TILDE); + break; + case '?': + AddToken(1, mx::TokenKind::QUESTION); + break; + case ':': + AddToken(1, mx::TokenKind::COLON); + break; + case ';': + AddToken(1, mx::TokenKind::SEMI); + break; + case ',': + AddToken(1, mx::TokenKind::COMMA); + break; + case '.': + if (MatchDecimal()) { + goto FractionalConstant; + } else if (Look(0) == '.' && Look(1) == '.') { + AddToken(3, mx::TokenKind::ELLIPSIS); + } else { + AddToken(1, mx::TokenKind::PERIOD); + } + break; + case '-': + AddToken(1, mx::TokenKind::MINUS); + if (Look(0) == '>') + AddToken(2, mx::TokenKind::ARROW); + else if (Look(0) == '-') + AddToken(2, mx::TokenKind::MINUS_MINUS); + else if (Look(0) == '=') + AddToken(2, mx::TokenKind::MINUS_EQUAL); + break; + case '+': + AddToken(1, mx::TokenKind::PLUS); + if (Look(0) == '+') + AddToken(2, mx::TokenKind::PLUS_PLUS); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::PLUS_EQUAL); + break; + case '&': + AddToken(1, mx::TokenKind::AMP); + if (Look(0) == '&') + AddToken(2, mx::TokenKind::AMP_AMP); + else if (Look(0) == '=') + AddToken(2, mx::TokenKind::AMP_EQUAL); + break; + case '*': + AddToken(1, mx::TokenKind::STAR); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::STAR_EQUAL); + break; + case '!': + AddToken(1, mx::TokenKind::EXCLAIM); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::EXCLAIM_EQUAL); + break; + case '/': + AddToken(1, mx::TokenKind::SLASH); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::SLASH_EQUAL); + break; + case '%': + AddToken(1, mx::TokenKind::PERCENT); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::PERCENT_EQUAL); + break; + case '<': + AddToken(1, mx::TokenKind::LESS); + if (Look(0) == '<') { + AddToken(2, mx::TokenKind::LESS_LESS); + if (Look(1) == '=') + AddToken(3, mx::TokenKind::LESS_LESS_EQUAL); + } else if (Look(0) == '=') { + AddToken(1, mx::TokenKind::LESS_EQUAL); + } + 
break; + case '>': + AddToken(1, mx::TokenKind::GREATER); + if (Look(0) == '>') { + AddToken(2, mx::TokenKind::GREATER_GREATER); + if (Look(1) == '=') + AddToken(3, mx::TokenKind::GREATER_GREATER_EQUAL); + } else if (Look(0) == '=') { + AddToken(2, mx::TokenKind::GREATER_EQUAL); + } + break; + case '=': + AddToken(1, mx::TokenKind::EQUAL); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::EQUAL_EQUAL); + break; + case '^': + AddToken(1, mx::TokenKind::CARET); + if (Look(0) == '=') + AddToken(2, mx::TokenKind::CARET_EQUAL); + break; + case '|': + AddToken(1, mx::TokenKind::PIPE); + if (Look(0) == '|') + AddToken(2, mx::TokenKind::PIPE_PIPE); + else if (Look(0) == '=') + AddToken(2, mx::TokenKind::PIPE_EQUAL); + break; + case '#': + AddToken(1, mx::TokenKind::HASH); + if (Look(0) == '#') + AddToken(2, mx::TokenKind::HASH_HASH); + break; + default: + AddToken(1, mx::TokenKind::UNKNOWN); + } +} + +ParsedQueryImpl::ParsedQueryImpl(const GrammarImpl &grammar, std::string_view input) + : m_grammar(grammar), m_input(input) {} + +void ParsedQueryImpl::MatchGlob(TableEntry &result, + const std::unordered_set &follow, + Item &item, + size_t next) { + + for (auto &[left, rest] : *item.m_leaves) { + // If: + // a) we reach the end of a production and have an empty follow set + // b) or the next non-terminal is contained in the follow set + // we should continue parsing normally. + if ((rest.is_production && follow.empty()) || follow.contains(left)) { + MatchRule(result, item, next); + } + + // If the next entry is a terminator, we don't need to glob further + // NOTE: for most usecases of $... 
this makes sense and improves performance + // but if some weird case doesn't match it might be necessary to remove this + if (left == mx::TokenKind::R_PARENTHESIS || + left == mx::TokenKind::R_BRACE_TOKEN) { + continue; + } + + // Otherwise the rest of the grammar rule is a candiate for more globbing + if (rest.leaves.size() > 0) { + const GrammarLeaves *old_leaves = item.m_leaves; + item.m_leaves = &rest.leaves; + item.m_children.emplace_back(NodeKind::Any(), next, Glob::YES); + MatchGlob(result, follow, item, next); + item.m_leaves = old_leaves; + item.m_children.pop_back(); + } + } +} + +void ParsedQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { + // Iterate shifts + for (auto &[key, _] : ParsesAtIndex(next)) { + NodeKind kind = key.first; + size_t next = key.second; + item.IterateShifts(kind, next, Glob::NO, [&] (Item &item) { + MatchRule(result, item, next); + }); + } + + // Iterate glob shifts + if (auto it = m_globs.find(next); it != m_globs.end()) { + // Compute set of node kinds that can follow $... 
+ std::unordered_set follow; + for (auto &[key, _] : ParsesAtIndex(it->second)) { + follow.insert(key.first); + } + MatchGlob(result, follow, item, it->second); + } + + // Iterate reductions + item.IterateReductions([&] (NodeKind kind, const auto &children) { + result[{kind, next}].emplace(children); + MatchPrefix(result, kind, next); + }); +} + +void ParsedQueryImpl::MatchPrefix(TableEntry &result, NodeKind kind, size_t next) { + Item(&m_grammar.root).IterateShifts(kind, next, Glob::NO, [&] (Item &item) { + MatchRule(result, item, next); + }); +} + +const ParsedQueryImpl::TableEntry &ParsedQueryImpl::ParsesAtIndex(size_t index) { + // Lookup memoized parses at this index + auto it = m_parses.find(index); + if (it != m_parses.end()) { + return it->second; + } + + // And only do computation if the lookup found nothing + auto &result = m_parses[index]; + + auto TokenCallback = [&] (mx::TokenKind lex_kind, std::string_view spelling, size_t next) { + if (auto grm_kind = m_grammar.TokenKindOf(spelling)) { + result[{*grm_kind, next}].emplace(spelling); + MatchPrefix(result, *grm_kind, next); + } else { + std::cerr << "Warning: token `" << spelling << "` not present in grammar\n"; + } + }; + + auto MetavarCallback = [&] (std::string_view name, NodeKind filter, size_t next) { + if (name == "") { + result[{filter, next}].emplace(nullptr); + } else { + auto [it, added] = m_metavars.emplace(name, Metavar(name, {})); + if (!added) { + std::cerr << "Error: duplicate metavariable name `" << name << "`\n"; + abort(); + } + result[{filter, next}].emplace(&it->second); + } + MatchPrefix(result, filter, next); + }; + + auto VarargCallback = [&] (size_t next) { + m_globs[index] = next; + }; + + Tokenize(TokenCallback, MetavarCallback, VarargCallback, m_input, index); + + return result; +} + +std::pair> ParsedQueryImpl::MatchMarker( + const TableEntry &entry, const ParseMarker &marker, const ASTNode *node) { + + std::vector metavar_matches; + + switch (marker.m_kind) { + case 
ParseMarker::METAVAR: + if (marker.m_metavar) { + MetavarMatch mv_match(marker.m_metavar->m_name, + node->Entity(), node->TokenRange()); + if (auto &predicate = marker.m_metavar->m_predicate) { + if (!(*predicate)(mv_match)) { + return {false, {}}; + } + } + metavar_matches.push_back(std::move(mv_match)); + } + return {true, metavar_matches}; + case ParseMarker::TERMINAL: + return {node->Kind().IsToken() && node->Spelling() == marker.m_spelling, + {}}; + case ParseMarker::NONTERMINAL: + if (node->Kind().IsToken() || + node->ChildVector().size() != marker.m_children.size()) { + return {false, {}}; + } + + auto child_entry = &entry; + auto child_it = marker.m_children.begin(); + + for (const ASTNode *child_node : node->ChildVector()) { + auto &[kind, next, glob] = *child_it; + + if (kind != NodeKind::Any() && kind != child_node->Kind()) { + return {false, {}}; + } + + if (glob == Glob::NO) { + auto markers = child_entry->find({ kind, next }); + if (markers == child_entry->end()) { + return {false, metavar_matches}; + } + for (auto &marker : markers->second) { + auto [child_ok, child_metavar_matches] = + MatchMarker(*child_entry, marker, child_node); + if (child_ok) { + metavar_matches.insert( + metavar_matches.end(), + child_metavar_matches.begin(), + child_metavar_matches.end()); + goto ok; + } + } + return {false, {}}; + } + ok: + child_entry = &ParsesAtIndex(std::get<1>(*child_it)); + ++child_it; + } + return {true, metavar_matches}; + } +} + +void ParsedQueryImpl::DebugParseTable(std::ostream &os) { + // Make sure the DP table was actually filled in + ParsesAtIndex(0); + + // Find all possible indices then sort them + std::vector indices; + for (auto &[index, _] : m_parses) { + indices.push_back(index); + } + std::sort(indices.begin(), indices.end()); + + // Then print all the parses at every index in the table + for (size_t index : indices) { + os << "Parses at " << index << ":\n"; + for (auto &[key, markers] : m_parses.at(index)) { + for (auto &marker : 
markers) { + // Print head + std::stringstream ss; + ss << " (" << key.first << ", " << key.second << ")"; + + std::cout << std::left << std::setw(40) << std::setfill(' ') + << ss.str() << " <- "; + + // Print body + switch (marker.m_kind) { + case ParseMarker::METAVAR: + std::cout << "$" << (marker.m_metavar ? marker.m_metavar->m_name : ""); + break; + case ParseMarker::TERMINAL: + std::cout << "`" << marker.m_spelling << "`"; + break; + case ParseMarker::NONTERMINAL: + for (auto &[kind, next, glob] : marker.m_children) { + if (glob == Glob::YES) { + std::cout << "(" << kind << ", " << next << ", ..." << ") "; + } else { + std::cout << "(" << kind << ", " << next << ") "; + } + } + break; + } + + std::cout << "\n"; + } + } + } +} + +ParsedQuery::ParsedQuery(const Grammar &grammar, std::string_view query) + : impl(std::make_shared(*grammar.impl, query)) {} + +bool ParsedQuery::IsValid(void) const { + for (auto &[key, markers] : impl->ParsesAtIndex(0)) { + if (key.second == impl->m_input.size()) { + return true; + } + } + return false; +} + +bool ParsedQuery::AddMetavarPredicate( + const std::string_view &name, + std::function predicate) { + + // Find metavariable name + auto it = impl->m_metavars.find(name); + if (it == impl->m_metavars.end()) { + return false; + } + + // Overwrite predicate + if (it->second.m_predicate) { + it->second.m_predicate = + [old_pred = std::move(it->second.m_predicate.value()), + new_pred = std::move(predicate)] (const MetavarMatch &mvm) -> bool { + return old_pred(mvm) && new_pred(mvm); + }; + + } else { + it->second.m_predicate = std::move(predicate); + } + + return true; +} + +std::vector ParsedQuery::FindInFragment( + mx::RawEntityId fragment_id) const { + if (auto frag = impl->m_grammar.index.fragment(fragment_id)) { + return Find(frag.value()); + } else { + return {}; + } +} + +void ParsedQuery::ForEachMatch(const mx::File &file, + std::function pred) const { + bool done = false; + auto real_pred = [sub_pred = std::move(pred), 
&done] (Match m) -> bool { + if (sub_pred(std::move(m))) { + return true; + } else { + done = true; + return false; + } + }; + + for (mx::Fragment frag : mx::Fragment::in(file)) { + ForEachMatch(frag, real_pred); + if (done) { + break; + } + } +} + +void ParsedQuery::ForEachMatch(std::function pred) const { + bool done = false; + auto real_pred = [sub_pred = std::move(pred), &done] (Match m) -> bool { + if (sub_pred(std::move(m))) { + return true; + } else { + done = true; + return false; + } + }; + + for (mx::File file : mx::File::in(impl->m_grammar.index)) { + for (mx::Fragment frag : mx::Fragment::in(file)) { + ForEachMatch(frag, real_pred); + if (done) { + break; + } + } + if (done) { + break; + } + } +} + +std::vector ParsedQuery::Find(const mx::Fragment &frag) const { + std::vector ret; + ForEachMatch(frag, [&ret] (Match m) -> bool { + ret.emplace_back(std::move(m)); + return true; + }); + return ret; +} + +std::vector ParsedQuery::Find(const mx::File &file) const { + std::vector ret; + ForEachMatch(file, [&ret] (Match m) -> bool { + ret.emplace_back(std::move(m)); + return true; + }); + return ret; +} + +std::vector ParsedQuery::Find(void) const { + std::vector ret; + for (mx::File file : mx::File::in(impl->m_grammar.index)) { + for (mx::Fragment frag : mx::Fragment::in(file)) { + ForEachMatch(frag, [&ret] (Match m) -> bool { + ret.emplace_back(std::move(m)); + return true; + }); + } + } + return ret; +} + +void ParsedQuery::ForEachMatch(const mx::Fragment &frag, + std::function pred) const { + + // Create AST for fragment + auto frag_ast = AST::Build(frag); + + // Find matching AST node + auto &entry = impl->ParsesAtIndex(0); + for (auto &[key, markers] : entry) { + if (key.second != impl->m_input.size()) { + continue; + } + if (key.first == NodeKind::Any()) { + for (auto &ast_node : frag_ast.AllNodes()) { + for (auto &marker : markers) { + auto [ok, metavar_matches] = impl->MatchMarker( + entry, marker, &ast_node); + if (ok && !pred(Match(frag, 
ast_node.Entity(), + ast_node.TokenRange(), + metavar_matches))) { + return; + } + } + } + } else { + for (auto *ast_node = frag_ast.WithKind(key.first); + ast_node; ast_node = ast_node->prev) { + for (auto &marker : markers) { + auto [ok, metavar_matches] = impl->MatchMarker(entry, marker, ast_node); + if (ok && !pred(Match(frag, ast_node->Entity(), + ast_node->TokenRange(), + metavar_matches))) { + return; + } + } + } + } + } +} + +} // namespace syntex +} // namespace mx diff --git a/lib/Query/Query.h b/lib/Query/Query.h new file mode 100644 index 000000000..30ffc2e24 --- /dev/null +++ b/lib/Query/Query.h @@ -0,0 +1,251 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. + +#pragma once + +#include "AST.h" +#include "NodeKind.h" +#include "Grammar.h" +#include +#include + +template +inline void hash_combine(size_t &h, const T& v) +{ + std::hash hasher; + h ^= hasher(v) + 0x9e3779b9 + (h << 6) + (h >> 2); +} + +template<> +struct std::hash> { + size_t operator()(const std::pair &self) const { + size_t hash = 0; + hash_combine(hash, self.first); + hash_combine(hash, self.second); + return hash; + } +}; + +namespace mx { +namespace syntex { + +struct Metavar { + std::string_view m_name; + std::optional> m_predicate; + + explicit Metavar(std::string_view name, + std::optional> predicate) + : m_name(name), m_predicate(std::move(predicate)) {} +}; + +enum class Glob { + NO, + YES +}; + +struct ParseMarker { + + // Node category + enum { + METAVAR, + TERMINAL, + NONTERMINAL, + } m_kind; + + // Associated data + union { + Metavar *m_metavar; + std::string_view m_spelling; + std::vector> m_children; + }; + + explicit ParseMarker(Metavar *metavar) + : m_kind(METAVAR), m_metavar(metavar) {} + + explicit ParseMarker(std::string_view spelling) + : m_kind(TERMINAL), m_spelling(spelling) {} + + explicit 
ParseMarker(const std::vector> &children) + : m_kind(NONTERMINAL), m_children(children) {} + + ParseMarker(ParseMarker &&other) + : m_kind(other.m_kind) + { + switch (m_kind) { + case METAVAR: + m_metavar = other.m_metavar; + break; + case TERMINAL: + new (&m_spelling) std::string_view(other.m_spelling); + break; + case NONTERMINAL: + new (&m_children) std::vector>(std::move(other.m_children)); + break; + } + } + + ~ParseMarker() { + switch (m_kind) { + case METAVAR: + break; + case TERMINAL: + m_spelling.std::string_view::~string_view(); + break; + case NONTERMINAL: + m_children.std::vector>::~vector(); + break; + } + } + + bool operator==(const ParseMarker &other) const { + if (m_kind != other.m_kind) { + return false; + } + + switch (m_kind) { + case METAVAR: + // NOTE: it's impossible to have two metavariables + // at the same input location, thus this is never called + assert(false); + abort(); + case TERMINAL: + return m_spelling == other.m_spelling; + case NONTERMINAL: + return m_children == other.m_children; + } + } +}; + +} // namespace syntex +} // namespace mx + +template<> +struct std::hash { + size_t operator()(const mx::syntex::ParseMarker &self) const { + size_t hash = 0; + hash_combine(hash, self.m_kind); + switch (self.m_kind) { + case mx::syntex::ParseMarker::METAVAR: + break; + case mx::syntex::ParseMarker::TERMINAL: + hash_combine(hash, self.m_spelling); + break; + case mx::syntex::ParseMarker::NONTERMINAL: + for (auto &[kind, next, glob] : self.m_children) { + hash_combine(hash, kind); + hash_combine(hash, next); + hash_combine(hash, glob); + } + break; + } + return hash; + } +}; + +namespace mx { +namespace syntex { + +// +// Parser state (e.g. 
a pointer into the grammar trie) +// + +struct Item { + const GrammarLeaves *m_leaves; + std::vector> m_children; + + explicit Item(const GrammarLeaves *leaves) + : m_leaves(leaves) + {} + + template + void IterateShifts(NodeKind kind, size_t next, Glob glob, F cb) { + if (kind == NodeKind::Any()) { + const GrammarLeaves *old_leaves = m_leaves; + m_children.emplace_back(kind, next, glob); + + for (auto &[kind, rest] : *m_leaves) { + if (rest.leaves.empty()) { + continue; + } + + m_leaves = &rest.leaves; + cb(*this); + } + + m_leaves = old_leaves; + m_children.pop_back(); + } else { + auto it = m_leaves->find(kind); + if (it == m_leaves->end() || it->second.leaves.empty()) { + return; + } + + // Morph ourselves into the shifted state + const GrammarLeaves *old_leaves = m_leaves; + m_leaves = &it->second.leaves; + m_children.emplace_back(kind, next, glob); + + // Fire callback with morphed item + cb(*this); + + // Restore item + m_leaves = old_leaves; + m_children.pop_back(); + } + } + + template + void IterateReductions(F cb) const { + for (auto &[left, rest] : *m_leaves) { + if (rest.is_production) { + cb(left, m_children); + } + } + } +}; + +// +// Wrapper around parsing functions +// + +struct ParsedQueryImpl { + // GrammarImpl to be processed + const GrammarImpl &m_grammar; + + // Input string + std::string_view m_input; + + // Main DP parse table + using TableEntry = std::unordered_map, + std::unordered_set>; + + std::unordered_map m_parses; + + // Metavariables + std::unordered_map m_metavars; + + // Globs + std::unordered_map m_globs; + + void MatchGlob(TableEntry &result, const std::unordered_set &follow, + Item &item, size_t next); + + void MatchRule(TableEntry &result, Item &item, size_t next); + + void MatchPrefix(TableEntry &result, NodeKind kind, size_t next); + + const TableEntry &ParsesAtIndex(size_t index); + + explicit ParsedQueryImpl(const GrammarImpl &grammar, std::string_view input); + + void DebugParseTable(std::ostream &os); + + std::pair> 
MatchMarker( + const TableEntry &entry, const ParseMarker &marker, const ASTNode *node); +}; + +} // namespace syntex +} // namespace mx \ No newline at end of file From 296e21804f29d84bfbe214562a4e39ba72303182 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Mon, 7 Nov 2022 12:57:27 +0100 Subject: [PATCH 2/8] Move query AST to database --- bin/Query/CMakeLists.txt | 5 +- bin/Query/PredicateExample.cpp | 37 ++-- bin/Query/SyntexQuery.cpp | 105 +++------- include/multiplier/Index.h | 27 +++ include/multiplier/IndexStorage.h | 67 +++++- include/multiplier/PersistentMap.h | 184 ++++++++++++++++- include/multiplier/SQLiteStore.h | 61 ++++-- include/multiplier/Syntex.h | 91 +------- lib/API/CMakeLists.txt | 4 + lib/API/CachingEntityProvider.cpp | 86 ++++++++ lib/API/CachingEntityProvider.h | 36 ++++ lib/API/Grammar.h | 37 ++++ lib/API/Index.cpp | 17 ++ lib/API/InvalidEntityProvider.cpp | 31 +++ lib/API/InvalidEntityProvider.h | 11 + lib/{Query => API}/NodeKind.h | 0 lib/{Query => API}/Query.cpp | 128 ++++-------- lib/{Query => API}/Query.h | 40 +++- lib/API/SQLiteEntityProvider.cpp | 62 ++++++ lib/API/SQLiteEntityProvider.h | 11 + lib/CMakeLists.txt | 1 - lib/Common/IndexStorage.cpp | 121 +++++++++++ lib/Common/SQLiteStore.cpp | 8 +- lib/Query/AST.cpp | 204 ------------------ lib/Query/AST.h | 126 ------------ lib/Query/CMakeLists.txt | 52 ----- lib/Query/Grammar.cpp | 319 ----------------------------- lib/Query/Grammar.h | 95 --------- 28 files changed, 872 insertions(+), 1094 deletions(-) create mode 100644 lib/API/Grammar.h rename lib/{Query => API}/NodeKind.h (100%) rename lib/{Query => API}/Query.cpp (86%) rename lib/{Query => API}/Query.h (84%) delete mode 100644 lib/Query/AST.cpp delete mode 100644 lib/Query/AST.h delete mode 100644 lib/Query/CMakeLists.txt delete mode 100644 lib/Query/Grammar.cpp delete mode 100644 lib/Query/Grammar.h diff --git a/bin/Query/CMakeLists.txt b/bin/Query/CMakeLists.txt index c495e8176..353e157c6 100644 --- 
a/bin/Query/CMakeLists.txt +++ b/bin/Query/CMakeLists.txt @@ -10,8 +10,9 @@ add_executable("syntex-query" "SyntexQuery.cpp") target_link_libraries("syntex-query" PRIVATE + gflags + glog::glog "mx-api" - "mx-syntex" ) install( @@ -29,6 +30,6 @@ add_executable("predicate-example" "PredicateExample.cpp") target_link_libraries("predicate-example" PRIVATE gflags + glog::glog "mx-api" - "mx-syntex" ) diff --git a/bin/Query/PredicateExample.cpp b/bin/Query/PredicateExample.cpp index 91a0549d3..d7e72a7a8 100644 --- a/bin/Query/PredicateExample.cpp +++ b/bin/Query/PredicateExample.cpp @@ -72,8 +72,9 @@ static std::optional IntegralTypeWidth(const mx::ValueDecl &decl) { return IntegralTypeWidth(decl.type()); } -static void HighlightMatch(std::ostream &os, mx::syntex::Match m) { - auto ref = mx::DeclRefExpr::from(std::get(m.MetavarMatch(0).Entity())); +static void HighlightMatch(std::ostream &os, mx::Index index, mx::syntex::Match m) { + auto stmt = std::get(index.entity(m.MetavarMatch(0).Entity())); + auto ref = mx::DeclRefExpr::from(stmt); if (!ref) { return; } @@ -89,7 +90,7 @@ static void HighlightMatch(std::ostream &os, mx::syntex::Match m) { } auto lit = mx::IntegerLiteral::from( - std::get(m.MetavarMatch(1).Entity())); + std::get(index.entity(m.MetavarMatch(1).Entity()))); if (!lit) { return; } @@ -117,16 +118,25 @@ static void HighlightMatch(std::ostream &os, mx::syntex::Match m) { return; } + auto entity = index.entity(m.Entity()); + auto fragment = index.fragment_containing(m.Entity()); auto builtin_type = mx::BuiltinType::from(var->type()); - os << "File ID: " << mx::File::containing(m.Fragment()).id() << '\n' - << "Fragment ID: " << m.Fragment().id().Pack() << '\n' - << "Token ID: " << m.FirstTokenId() << '\n' + mx::TokenRange tok_range; + if(std::holds_alternative(entity)) { + tok_range = std::get(entity); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } else if(std::holds_alternative(entity)) { + tok_range = 
std::get(entity).tokens(); + } + os << "File ID: " << mx::File::containing(*fragment).id() << '\n' + << "Fragment ID: " << fragment->id().Pack() << '\n' << "Literal value: " << lit_val << '\n' << "Type size: " << type_size.value() << '\n' << "Type kind: " << mx::EnumeratorName(builtin_type->builtin_kind()) << "\nExpression:"; - for (mx::Token tok : m.TokenRange()) { + for (mx::Token tok : tok_range) { os << ' ' << tok.data(); } @@ -157,21 +167,18 @@ extern "C" int main(int argc, char *argv[]) { mx::Index index = mx::EntityProvider::in_memory_cache( mx::EntityProvider::from_database(FLAGS_db)); - mx::syntex::Grammar grammar(index, FLAGS_db); // Setup query - mx::syntex::ParsedQuery parsed_query(grammar, "$var:DECL_REF_EXPR << $num:INTEGER_LITERAL"); - if (!parsed_query.IsValid()) { + auto res = index.query_syntex("$var:DECL_REF_EXPR << $num:INTEGER_LITERAL"); + if (!res.has_value()) { return EXIT_FAILURE; } // Match fragments - - parsed_query.ForEachMatch([] (mx::syntex::Match match) { - HighlightMatch(std::cout, std::move(match)); - return true; - }); + for(auto match : res.value()) { + HighlightMatch(std::cout, index, std::move(match)); + } return EXIT_SUCCESS; } diff --git a/bin/Query/SyntexQuery.cpp b/bin/Query/SyntexQuery.cpp index ff2dc1851..ebf6463e8 100644 --- a/bin/Query/SyntexQuery.cpp +++ b/bin/Query/SyntexQuery.cpp @@ -21,53 +21,43 @@ DEFINE_string(query, "", "Use argument value as query"); DEFINE_uint64(threads, 0, "Use this number of threads"); DEFINE_bool(suppress_output, false, "Don't print matches to stdout"); -static std::mutex gMatchPrintingMutex; - -static void PrintMatch(const mx::syntex::Match &match) +static void PrintMatch(mx::Index index, const mx::syntex::Match &match) { if (FLAGS_suppress_output) { return; } - { - std::lock_guard guard(gMatchPrintingMutex); - - // Print matching fragment ID - std::cout << "Match in " << match.Fragment().id() << ":\n"; - - for (auto token : match.Fragment().parsed_tokens()) { - if (token.id() == 
match.FirstTokenId()) { - // Switch to ANSI red for the first matching token - std::cout << ANSI_RED; - } - - std::cout << token.data() << " "; + auto entity = index.entity(match.Entity()); + auto fragment = *index.fragment_containing(match.Entity()); + mx::TokenRange tok_range; + if(std::holds_alternative(entity)) { + tok_range = std::get(entity); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } else if(std::holds_alternative(entity)) { + tok_range = std::get(entity).tokens(); + } - if (token.id() == match.LastTokenId()) { - // Reset color after last matching token - std::cout << ANSI_RESET; - } + // Print matching fragment ID + std::cout << "Match in " << fragment.id() << ":\n"; + for (auto token : fragment.parsed_tokens()) { + if (token.id() == tok_range.front().id()) { + // Switch to ANSI red for the first matching token + std::cout << ANSI_RED; } - std::cout << "\n"; + std::cout << token.data() << " "; - for (auto &metavar : match.MetavarMatches()) { - std::cout << "Matching metavar " << metavar.Name() << "\n"; + if (token.id() == tok_range.back().id()) { + // Reset color after last matching token + std::cout << ANSI_RESET; } } -} -static void ProcessFragmentRange(const mx::syntex::ParsedQuery &parsed_query, - const mx::RawEntityId *begin, - const mx::RawEntityId *end) -{ - for (; begin < end; ++begin) { - std::vector matches = - parsed_query.FindInFragment(*begin); + std::cout << "\n"; - for (const mx::syntex::Match &match : matches) { - PrintMatch(match); - } + for (auto &metavar : match.MetavarMatches()) { + std::cout << "Matching metavar " << metavar.Name() << "\n"; } } @@ -94,57 +84,18 @@ extern "C" int main(int argc, char *argv[]) { // Setup index and grammar mx::Index index = mx::EntityProvider::from_database(FLAGS_db); - mx::syntex::Grammar grammar(index, FLAGS_db); // Parse query - mx::syntex::ParsedQuery parsed_query(grammar, FLAGS_query); + auto res = index.query_syntex(FLAGS_query); - if 
(!parsed_query.IsValid()) { + if (!res.has_value()) { std::cerr << "Query `" << FLAGS_query << "` has no valid parses\n"; return EXIT_FAILURE; } - // Choose number of threads - - size_t threads = FLAGS_threads ?: std::thread::hardware_concurrency(); - std::cout << "starting matcher with " << threads << " threads\n"; - - // Collect all fragments to process - - std::vector fragment_ids; - - for (const mx::File &file : mx::File::in(index)) { - for (mx::RawEntityId fragment_id : file.fragment_ids()) { - fragment_ids.push_back(fragment_id); - } - } - - // Find the ideal number of fragments per thread - - size_t fragments_per_thread = fragment_ids.size() / threads; - - - // Create workers - - std::vector thread_pool; - - auto cur = &fragment_ids.front(); - auto last = &fragment_ids.back(); - - while (cur < last) { - auto end = cur + fragments_per_thread; - if (end > last) { - end = last; - } - thread_pool.emplace_back(ProcessFragmentRange, parsed_query, cur, end); - cur = end; - } - - // Wait for all workers to finish - - for (auto &thread : thread_pool) { - thread.join(); + for(auto match : *res) { + PrintMatch(index, match); } return EXIT_SUCCESS; diff --git a/include/multiplier/Index.h b/include/multiplier/Index.h index b046a56b8..329aa7788 100644 --- a/include/multiplier/Index.h +++ b/include/multiplier/Index.h @@ -43,6 +43,17 @@ class WeggliQueryMatch; class WeggliQueryResultIterator; class WeggliQueryResult; class WeggliQueryResultImpl; +struct ASTNode; + +namespace syntex { +class Match; +class NodeKind; +class GrammarNode; +class ParsedQuery; +class ParsedQueryImpl; + +using GrammarLeaves = std::unordered_map; +} using DeclUse = Use; using StmtUse = Use; @@ -102,6 +113,8 @@ class EntityProvider { friend class UseIteratorImpl; friend class WeggliQueryResultImpl; friend class WeggliQueryResultIterator; + friend class syntex::ParsedQuery; + friend class syntex::ParsedQueryImpl; protected: @@ -172,6 +185,17 @@ class EntityProvider { virtual void FindSymbol(const Ptr 
&, std::string name, mx::DeclCategory category, std::vector &ids_out) = 0; + + virtual std::optional + TokenKindOf(std::string_view spelling) = 0; + + virtual void LoadGrammarRoot(syntex::GrammarLeaves &root) = 0; + + virtual std::vector GetFragmentsInAST(void) = 0; + virtual ASTNode GetASTNode(std::uint64_t id) = 0; + virtual std::vector GetASTNodeChildren(std::uint64_t id) = 0; + virtual std::vector GetASTNodesInFragment(RawEntityId frag) = 0; + virtual std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) = 0; }; using VariantEntity = std::variant> query_syntex(std::string_view query) const; + std::optional> query_syntex(FragmentId frag, std::string_view query) const; }; } // namespace mx diff --git a/include/multiplier/IndexStorage.h b/include/multiplier/IndexStorage.h index b32ab0ba3..10c520e72 100644 --- a/include/multiplier/IndexStorage.h +++ b/include/multiplier/IndexStorage.h @@ -28,7 +28,11 @@ enum : char { kEntityIdToMangledName, kMangledNameToEntityId, kEntityIdUseToFragmentId, - kEntityIdReference + kEntityIdReference, + kSpellingToTokenKind, + kGrammarRoot, + kGrammarNodes, + kGrammarChildren, }; enum MetadataName : char { @@ -58,6 +62,52 @@ enum MetadataName : char { kIndexingVersion, }; +struct ASTNode { + std::optional prev; + unsigned short kind; + RawEntityId entity; + std::optional spelling; +}; + +class PersistentAST final { + sqlite::Connection &db; + std::shared_ptr get_root_stmt; + std::shared_ptr create_node_stmt; + std::shared_ptr add_root_stmt; + std::shared_ptr get_node_stmt; + std::shared_ptr get_index_stmt; + std::shared_ptr set_index_stmt; + std::shared_ptr get_fragments_stmt; + std::shared_ptr get_children_stmt; + std::shared_ptr add_child_stmt; + + public: + PersistentAST(sqlite::Connection &db); + + std::vector Root(RawEntityId fragment); + + std::uint64_t AddNode(const ASTNode& node); + + void AddNodeToRoot(RawEntityId fragment, std::uint64_t node_id); + + ASTNode GetNode(std::uint64_t node_id); + + std::optional 
GetNodeInIndex( + RawEntityId fragment, + unsigned short kind); + + void SetNodeInIndex( + RawEntityId fragment, + unsigned short kind, + std::uint64_t node_id); + + std::vector GetFragments(); + + std::vector GetChildren(std::uint64_t parent); + + void AddChild(std::uint64_t parent, std::uint64_t child); +}; + class IndexStorage final { sqlite::Connection &db; @@ -160,6 +210,21 @@ class IndexStorage final { mx::PersistentSet entity_id_reference; + mx::PersistentMap + spelling_to_token_kind; + + mx::PersistentMap + grammar_root; + + mx::PersistentMap + grammar_nodes; + + mx::PersistentMap2 + grammar_children; + + PersistentAST ast; + // SQLite database. Used for things like symbol searches. SymbolDatabase database; diff --git a/include/multiplier/PersistentMap.h b/include/multiplier/PersistentMap.h index e50c87293..99ec88677 100644 --- a/include/multiplier/PersistentMap.h +++ b/include/multiplier/PersistentMap.h @@ -34,6 +34,10 @@ static constexpr const char* table_names[] = { "'mx::MangledNameToEntityId'", "'mx::EntityIdUseToFragmentId'", "'mx::EntityIdReference'", + "'mx::syntex::Tokens'", + "'mx::syntex::GrammarRoot'", + "'mx::syntex::GrammarNodes'", + "'mx::syntex::GrammarChildren'", }; template @@ -106,7 +110,8 @@ class PersistentSet { db.Execute(ss.str()); ss = {}; - ss << "INSERT OR IGNORE INTO " << table_names[kId] << '(' << table_desc.str() << ") VALUES("; + ss << "INSERT OR IGNORE INTO " << table_names[kId] + << '(' << table_desc.str() << ") VALUES("; for(size_t i = 0; i < sizeof...(Keys); ++i) { ss << "?" << (i + 1); if(i != sizeof...(Keys) - 1) { @@ -139,7 +144,8 @@ class PersistentSet { for(size_t i = 0; i < sizeof...(Keys); ++i) { ss = {}; - ss << "SELECT " << table_desc.str() << " FROM " << table_names[kId] << " WHERE "; + ss << "SELECT " << table_desc.str() + << " FROM " << table_names[kId] << " WHERE "; for(size_t j = 0; j <= i; j++) { ss << "key" << j << " = ?" 
<< (j + 1); if(j != i) { @@ -188,31 +194,87 @@ class PersistentSet { } }; +template +class Iterator { + private: + std::shared_ptr stmt; + std::tuple value; + + template + void Read(std::index_sequence) { + auto res = stmt->GetResult(); + res.Columns(std::get(value)...); + } + + public: + Iterator(std::shared_ptr stmt) + : stmt(std::move(stmt)) { + this->operator++(); + } + + bool operator==(const Iterator& b) const { + return stmt == b.stmt; + } + + bool operator!=(const Iterator& b) const { + return stmt != b.stmt; + } + + Iterator& operator++(void) { + if(!stmt->ExecuteStep()) { + stmt = nullptr; + return *this; + } + + Read(std::make_index_sequence()); + return *this; + } + + const std::tuple &operator*(void) const { + return value; + } + + const std::tuple *operator->(void) const { + return &value; + } +}; + // Persistent mapping from keys to values. template class PersistentMap { private: sqlite::Connection &db; - std::shared_ptr set_stmt, get_stmt, get_or_set_stmt; + std::shared_ptr set_stmt; + std::shared_ptr get_stmt; + std::shared_ptr get_or_set_stmt; + std::shared_ptr enum_stmt; public: PersistentMap(sqlite::Connection &db) : db(db) { std::stringstream ss; - ss << "CREATE TABLE IF NOT EXISTS " << table_names[kId] << "(key, value, PRIMARY KEY(key))"; + ss << "CREATE TABLE IF NOT EXISTS " << table_names[kId] + << "(key, value, PRIMARY KEY(key))"; db.Execute(ss.str()); ss = {}; - ss << "INSERT OR REPLACE INTO " << table_names[kId] << "(key, value) VALUES (?1, ?2)"; + ss << "INSERT OR REPLACE INTO " << table_names[kId] + << "(key, value) VALUES (?1, ?2)"; set_stmt = db.Prepare(ss.str()); ss = {}; - ss << "SELECT key, value FROM " << table_names[kId] << " WHERE key = ?1"; + ss << "SELECT key, value FROM " << table_names[kId] + << " WHERE key = ?1"; get_stmt = db.Prepare(ss.str()); ss = {}; ss << "INSERT INTO " << table_names[kId] - << "(key, value) VALUES(?1, ?2) ON CONFLICT DO UPDATE SET value=value RETURNING key, value"; + << "(key, value) VALUES(?1, ?2) " 
+ << "ON CONFLICT DO UPDATE SET value=value RETURNING key, value"; get_or_set_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "SELECT key, value FROM " << table_names; + enum_stmt = db.Prepare(ss.str()); } V GetOrSet(K key, V val) const { @@ -244,6 +306,114 @@ class PersistentMap { return std::nullopt; } + + Iterator begin() { + return Iterator(enum_stmt); + } + + Iterator end() { + return Iterator(nullptr); + } +}; + +template +class PersistentMap2 { + private: + sqlite::Connection &db; + std::shared_ptr set_stmt; + std::shared_ptr get_stmt; + std::shared_ptr get_or_set_stmt; + std::shared_ptr enum_stmt; + std::shared_ptr enum_k1_stmt; + std::shared_ptr enum_k2_stmt; + + public: + PersistentMap2(sqlite::Connection &db) : db(db) { + std::stringstream ss; + ss << "CREATE TABLE IF NOT EXISTS " + << table_names[kId] << "(key1, key2, value, PRIMARY KEY(key1, key2))"; + db.Execute(ss.str()); + + ss = {}; + ss << "INSERT OR REPLACE INTO " + << table_names[kId] << "(key1, key2, value) VALUES (?1, ?2, ?3)"; + set_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "SELECT key1, key2, value FROM " + << table_names[kId] << " WHERE key1 = ?1 AND key2 = ?2"; + get_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "INSERT INTO " << table_names[kId] + << "(key1, key2, value) VALUES(?1, ?2, ?3) " + << "ON CONFLICT DO UPDATE SET value=value RETURNING key1, key2, value"; + get_or_set_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "SELECT key1, key2, value FROM " << table_names; + enum_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "SELECT key1, key2, value FROM " << table_names[kId] + << " WHERE key1 = ?1"; + enum_k1_stmt = db.Prepare(ss.str()); + + ss = {}; + ss << "SELECT key1, key2, value FROM " << table_names[kId] + << " WHERE key2 = ?1"; + enum_k1_stmt = db.Prepare(ss.str()); + } + + V GetOrSet(K1 key1, K2 key2, V val) const { + get_or_set_stmt->BindValues(key1, key2, val); + get_or_set_stmt->ExecuteStep(); + auto res = get_or_set_stmt->GetResult(); + K1 stored_key1; + 
K2 stored_key2; + V stored_value; + res.Columns(stored_key1, stored_key2, stored_value); + get_or_set_stmt->ExecuteStep(); + return stored_value; + } + + void Set(K1 key1, K2 key2, V val) const { + set_stmt->BindValues(key1, key2, val); + set_stmt->Execute(); + } + + std::optional TryGet(K1 key1, K2 key2) const { + get_stmt->BindValues(key1, key2); + if(get_stmt->ExecuteStep()) { + K1 stored_key1; + K2 stored_key2; + V stored_value; + auto res = get_stmt->GetResult(); + res.Columns(stored_key1, stored_key2, stored_value); + get_stmt->ExecuteStep(); + return stored_value; + } + + return std::nullopt; + } + + Iterator begin() { + return Iterator(enum_stmt); + } + + Iterator key1_equals(K1 key) { + enum_k1_stmt->BindValues(key); + return Iterator(enum_k1_stmt); + } + + Iterator key2_equals(K2 key) { + enum_k2_stmt->BindValues(key); + return Iterator(enum_k1_stmt); + } + + Iterator end() { + return Iterator(nullptr); + } }; } // namespace mx diff --git a/include/multiplier/SQLiteStore.h b/include/multiplier/SQLiteStore.h index 27655a1e7..7b460b1d0 100644 --- a/include/multiplier/SQLiteStore.h +++ b/include/multiplier/SQLiteStore.h @@ -34,6 +34,39 @@ class Error : public std::runtime_error { }; class QueryResult { + private: + void column_dispatcher(int& idx, std::string& arg) { + arg = getText(idx); + idx++; + } + + void column_dispatcher(int& idx, std::string_view& arg) { + arg = getBlob(idx); + idx++; + } + + void column_dispatcher(int& idx, std::nullopt_t& arg) { + idx++; + } + + template || std::is_enum_v>> + void column_dispatcher(int& idx, T& arg) { + arg = static_cast(getInt64(idx)); + idx++; + } + + template + void column_dispatcher(int& idx, std::optional& arg) { + if(isNull(idx)) { + arg = {}; + idx++; + return; + } + T value; + column_dispatcher(idx, value); + arg = value; + } + public: ~QueryResult() = default; @@ -50,23 +83,7 @@ class QueryResult { } int idx = 0; - auto column_dispatcher = [this, &idx] (auto &&arg) { - using arg_t = std::decay_t; - if 
constexpr (std::is_integral_v) { - arg = static_cast(getInt64(idx)); - } else if (std::is_same_v) { - arg = getText(idx); - } else if (std::is_same_v) { - arg = getBlob(idx); - } else if constexpr (std::is_same_v) { - ; - } else { - throw Error("Can't read column data; Type not supported!"); - } - idx++; - }; - - (column_dispatcher(std::forward(args)), ...); + (column_dispatcher(idx, std::forward(args)), ...); } private: @@ -81,6 +98,7 @@ class QueryResult { std::string getText(int32_t idx); std::string_view getBlob(int32_t idx); + bool isNull(int32_t idx); std::shared_ptr stmt; }; @@ -145,6 +163,15 @@ class Statement : public std::enable_shared_from_this { void bind(const size_t i, const std::string_view &value); + template + void bind(const size_t i, const std::optional &value) { + if(value.has_value()) { + bind(i, value.value()); + } else { + bind(i, nullptr); + } + } + void reset(); template diff --git a/include/multiplier/Syntex.h b/include/multiplier/Syntex.h index d95f634d0..7c8cabbc9 100644 --- a/include/multiplier/Syntex.h +++ b/include/multiplier/Syntex.h @@ -26,22 +26,6 @@ class Grammar; class GrammarImpl; class ParsedQuery; class ParsedQueryImpl; -class Match; - -// -// Handle to a persistent grammar -// - -class Grammar { -private: - friend class ParsedQuery; - - std::shared_ptr impl; - Grammar() = delete; - -public: - explicit Grammar(const mx::Index &index, std::filesystem::path grammar_dir); -}; // // Chunk of a fragment (potentially) matching a metavariable @@ -50,58 +34,20 @@ class Grammar { class MetavarMatch { private: std::string_view name; - mx::VariantEntity entity; - mx::TokenRange token_range; + mx::EntityId entity; public: - MetavarMatch(std::string_view name_, mx::VariantEntity entity_, - mx::TokenRange token_range_) + MetavarMatch(std::string_view name_, mx::EntityId entity_) : name(std::move(name_)), - entity(std::move(entity_)), - token_range(std::move(token_range_)) {} + entity(std::move(entity_)) {} const std::string_view &Name(void) 
const { return name; } - const mx::VariantEntity &Entity(void) const { + mx::EntityId Entity(void) const { return entity; } - - const mx::TokenRange &TokenRange(void) const { - return token_range; - } -}; - - -// -// Result of parsing a query -// - -class ParsedQuery { - private: - std::shared_ptr impl; - ParsedQuery(void) = delete; - - public: - explicit ParsedQuery(const Grammar &grammar, std::string_view query); - - bool IsValid() const; - - bool AddMetavarPredicate(const std::string_view &name, - std::function predicate); - - void ForEachMatch(const mx::Fragment &frag, - std::function pred) const; - void ForEachMatch(const mx::File &file, - std::function pred) const; - void ForEachMatch(std::function pred) const; - - std::vector Find(const mx::Fragment &frag) const; - std::vector Find(const mx::File &file) const; - std::vector Find(void) const; - - std::vector FindInFragment(mx::RawEntityId fragment_id) const; }; // @@ -112,40 +58,19 @@ class Match { private: friend class ParsedQuery; - mx::Fragment fragment; - mx::VariantEntity entity; - mx::TokenRange token_range; + mx::EntityId entity; std::vector metavars; public: - Match(mx::Fragment fragment_, mx::VariantEntity entity_, - mx::TokenRange token_range_, std::vector matevars_) - : fragment(std::move(fragment_)), - entity(std::move(entity_)), - token_range(std::move(token_range_)), + Match(mx::EntityId entity_, std::vector matevars_) + : entity(std::move(entity_)), metavars(std::move(matevars_)) {} - const mx::Fragment &Fragment(void) const { - return fragment; - } - - const mx::VariantEntity &Entity(void) const { + const mx::EntityId &Entity(void) const { return entity; } - const mx::TokenRange &TokenRange(void) const { - return token_range; - } - - mx::RawEntityId FirstTokenId(void) const { - return TokenRange().front().id(); - } - - mx::RawEntityId LastTokenId(void) const { - return TokenRange().back().id(); - } - const std::vector &MetavarMatches(void) const { return metavars; } diff --git 
a/lib/API/CMakeLists.txt b/lib/API/CMakeLists.txt index b0dfd7566..be68e7e99 100644 --- a/lib/API/CMakeLists.txt +++ b/lib/API/CMakeLists.txt @@ -47,12 +47,16 @@ add_library("mx-api" "Fragment.cpp" "FragmentImpl.cpp" "Fragment.h" + "Grammar.h" "Index.cpp" "InvalidEntityProvider.cpp" "InvalidEntityProvider.h" + "NodeKind.h" "PackedFileImpl.cpp" "PackedFragmentImpl.cpp" "PackedReaderState.cpp" + "Query.cpp" + "Query.h" "Re2.cpp" "Re2.h" "SQLiteEntityProvider.cpp" diff --git a/lib/API/CachingEntityProvider.cpp b/lib/API/CachingEntityProvider.cpp index 47865f110..a4e4ae505 100644 --- a/lib/API/CachingEntityProvider.cpp +++ b/lib/API/CachingEntityProvider.cpp @@ -4,8 +4,12 @@ // This source code is licensed in accordance with the terms specified in // the LICENSE file found in the root directory of this source tree. +#include "NodeKind.h" +#include "Grammar.h" #include "CachingEntityProvider.h" +#include + #include #include #include @@ -40,6 +44,13 @@ void CachingEntityProvider::ClearCacheLocked(unsigned new_version_number) { references.clear(); has_file_list = false; version_number = new_version_number; + spelling_to_token_kind.clear(); + grammar_root.clear(); + fragments_in_ast.clear(); + node_contents.clear(); + node_children.clear(); + fragment_nodes.clear(); + node_index.clear(); next->VersionNumberChanged(new_version_number); } @@ -260,4 +271,79 @@ EntityProvider::Ptr EntityProvider::in_memory_cache( return ret; } +std::optional +CachingEntityProvider::TokenKindOf(std::string_view spelling) { + std::string str{spelling.data(), spelling.size()}; + auto it = spelling_to_token_kind.find(str); + if(it == spelling_to_token_kind.end()) { + auto kind = next->TokenKindOf(spelling); + if(kind) { + spelling_to_token_kind[str] = *kind; + } + return kind; + } + return it->second; +} + +void CachingEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves& root) { + if(grammar_root.empty()) { + next->LoadGrammarRoot(grammar_root); + } + root = grammar_root; +} + +std::vector 
CachingEntityProvider::GetFragmentsInAST(void) { + if(fragments_in_ast.empty()) { + fragments_in_ast = next->GetFragmentsInAST(); + } + return fragments_in_ast; +} + +ASTNode CachingEntityProvider::GetASTNode(std::uint64_t id) { + auto it = node_contents.find(id); + if(it == node_contents.end()) { + node_contents[id] = next->GetASTNode(id); + } + return node_contents[id]; +} + +std::vector CachingEntityProvider::GetASTNodeChildren(std::uint64_t id) { + if(node_children.find(id) == node_children.end()) { + for(auto child : next->GetASTNodeChildren(id)) { + node_children.insert({id, child}); + } + } + std::vector res; + for(auto [it, end] = node_children.equal_range(id); it != end; ++it) { + res.push_back(it->second); + } + return res; +} + +std::vector CachingEntityProvider::GetASTNodesInFragment(RawEntityId frag) { + if(fragment_nodes.find(frag) == fragment_nodes.end()) { + for(auto child : next->GetASTNodesInFragment(frag)) { + fragment_nodes.insert({frag, child}); + } + } + std::vector res; + for(auto [it, end] = fragment_nodes.equal_range(frag); it != end; ++it) { + res.push_back(it->second); + } + return res; +} + +std::optional CachingEntityProvider::GetASTNodeWithKind(RawEntityId frag, unsigned short kind) { + auto it = node_index.find({frag, kind}); + if(it == node_index.end()) { + auto value = next->GetASTNodeWithKind(frag, kind); + if(value.has_value()) { + node_index[{frag, kind}] = value.value(); + return value; + } + return {}; + } + return it->second; +} + } // namespace mx diff --git a/lib/API/CachingEntityProvider.h b/lib/API/CachingEntityProvider.h index 59439eb2a..3b4c719f5 100644 --- a/lib/API/CachingEntityProvider.h +++ b/lib/API/CachingEntityProvider.h @@ -14,6 +14,23 @@ #include #include +template +inline void hash_combine(size_t &h, const T& v) +{ + std::hash hasher; + h ^= hasher(v) + 0x9e3779b9 + (h << 6) + (h >> 2); +} + +template<> +struct std::hash> { + size_t operator()(const std::pair &self) const { + size_t hash = 0; + 
hash_combine(hash, self.first); + hash_combine(hash, self.second); + return hash; + } +}; + namespace mx { class CachingEntityProvider final : public EntityProvider { @@ -50,6 +67,15 @@ class CachingEntityProvider final : public EntityProvider { std::unordered_map>> references; + std::unordered_map spelling_to_token_kind; + syntex::GrammarLeaves grammar_root; + + std::vector fragments_in_ast; + std::unordered_map node_contents; + std::unordered_multimap node_children; + std::unordered_multimap fragment_nodes; + std::unordered_map, std::uint64_t> node_index; + void ClearCacheLocked(unsigned new_version_number); inline CachingEntityProvider(Ptr next_) @@ -98,6 +124,16 @@ class CachingEntityProvider final : public EntityProvider { mx::DeclCategory category, std::vector &ids_out) final; + std::optional + TokenKindOf(std::string_view spelling) final; + + void LoadGrammarRoot(syntex::GrammarLeaves &root) final; + + std::vector GetFragmentsInAST(void) final; + ASTNode GetASTNode(std::uint64_t id) final; + std::vector GetASTNodeChildren(std::uint64_t id) final; + std::vector GetASTNodesInFragment(RawEntityId frag) final; + std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) final; }; } // namespace mx diff --git a/lib/API/Grammar.h b/lib/API/Grammar.h new file mode 100644 index 000000000..3b4c4f9cb --- /dev/null +++ b/lib/API/Grammar.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +#pragma once + +#include "NodeKind.h" + +#include + +namespace mx { +namespace syntex { + +struct GrammarNode; + +// +// One set of grammar leaves +// FIXME(frabert): Deserialization crashes if this is turned into +// an `absl::flat_hash_map` +// +using GrammarLeaves = std::unordered_map; + +// +// Node in the grammar tree +// + +struct GrammarNode { + // Does this node correspond to the head of a production + bool is_production; + // Further leaves + GrammarLeaves leaves; +}; + +} // namespace syntex +} // namespace mx \ No newline at end of file diff --git a/lib/API/Index.cpp b/lib/API/Index.cpp index a2160ca73..7a912bf1e 100644 --- a/lib/API/Index.cpp +++ b/lib/API/Index.cpp @@ -6,6 +6,7 @@ #include "File.h" #include "Fragment.h" +#include "Query.h" #include #include #include @@ -329,4 +330,20 @@ NamedDeclList Index::query_entities( return decls; } +std::optional> Index::query_syntex(std::string_view query) const { + syntex::ParsedQuery parsed_query(impl, query); + if(!parsed_query.IsValid()) { + return std::nullopt; + } + return parsed_query.Find(); +} + +std::optional> Index::query_syntex(FragmentId frag, std::string_view query) const { + syntex::ParsedQuery parsed_query(impl, query); + if(!parsed_query.IsValid()) { + return std::nullopt; + } + return parsed_query.Find(frag.fragment_id); +} + } // namespace mx diff --git a/lib/API/InvalidEntityProvider.cpp b/lib/API/InvalidEntityProvider.cpp index 6764432b8..1071be68c 100644 --- a/lib/API/InvalidEntityProvider.cpp +++ b/lib/API/InvalidEntityProvider.cpp @@ -6,6 +6,10 @@ #include "InvalidEntityProvider.h" +#include "NodeKind.h" +#include "Grammar.h" +#include + namespace mx { InvalidEntityProvider::~InvalidEntityProvider(void) noexcept {} @@ -76,6 +80,33 @@ void InvalidEntityProvider::FindSymbol( ids_out.clear(); } +std::optional +InvalidEntityProvider::TokenKindOf(std::string_view spelling) { + return {}; +} + +void InvalidEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &) {} + +std::vector 
InvalidEntityProvider::GetFragmentsInAST(void) { + return {}; +} + +ASTNode InvalidEntityProvider::GetASTNode(std::uint64_t id) { + return {}; +} + +std::vector InvalidEntityProvider::GetASTNodeChildren(std::uint64_t id) { + return {}; +} + +std::vector InvalidEntityProvider::GetASTNodesInFragment(RawEntityId frag) { + return {}; +} + +std::optional InvalidEntityProvider::GetASTNodeWithKind(RawEntityId frag, unsigned short kind) { + return {}; +} + Index::Index(void) : impl(std::make_shared()) {} diff --git a/lib/API/InvalidEntityProvider.h b/lib/API/InvalidEntityProvider.h index 37a3087d8..5bfc4fbff 100644 --- a/lib/API/InvalidEntityProvider.h +++ b/lib/API/InvalidEntityProvider.h @@ -57,6 +57,17 @@ class InvalidEntityProvider final : public EntityProvider { void FindSymbol(const Ptr &, std::string name, mx::DeclCategory category, std::vector &ids_out) final; + + std::optional + TokenKindOf(std::string_view spelling) final; + + void LoadGrammarRoot(syntex::GrammarLeaves &root) final; + + std::vector GetFragmentsInAST(void) final; + ASTNode GetASTNode(std::uint64_t id) final; + std::vector GetASTNodeChildren(std::uint64_t id) final; + std::vector GetASTNodesInFragment(RawEntityId frag) final; + std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) final; }; } // namespace mx diff --git a/lib/Query/NodeKind.h b/lib/API/NodeKind.h similarity index 100% rename from lib/Query/NodeKind.h rename to lib/API/NodeKind.h diff --git a/lib/Query/Query.cpp b/lib/API/Query.cpp similarity index 86% rename from lib/Query/Query.cpp rename to lib/API/Query.cpp index 4294eb32f..bfc34ad14 100644 --- a/lib/Query/Query.cpp +++ b/lib/API/Query.cpp @@ -4,12 +4,12 @@ // This source code is licensed in accordance with the terms specified in // the LICENSE file found in the root directory of this source tree. 
-#include "AST.h" #include "Query.h" #include #include #include +#include namespace mx { namespace syntex { @@ -530,8 +530,10 @@ done_filters:; } } -ParsedQueryImpl::ParsedQueryImpl(const GrammarImpl &grammar, std::string_view input) - : m_grammar(grammar), m_input(input) {} +ParsedQueryImpl::ParsedQueryImpl(std::shared_ptr ep, std::string_view input) + : m_ep(std::move(ep)), m_input(input) { + ep->LoadGrammarRoot(grammar_root); +} void ParsedQueryImpl::MatchGlob(TableEntry &result, const std::unordered_set &follow, @@ -595,7 +597,7 @@ void ParsedQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { } void ParsedQueryImpl::MatchPrefix(TableEntry &result, NodeKind kind, size_t next) { - Item(&m_grammar.root).IterateShifts(kind, next, Glob::NO, [&] (Item &item) { + Item(&grammar_root).IterateShifts(kind, next, Glob::NO, [&] (Item &item) { MatchRule(result, item, next); }); } @@ -611,7 +613,7 @@ const ParsedQueryImpl::TableEntry &ParsedQueryImpl::ParsesAtIndex(size_t index) auto &result = m_parses[index]; auto TokenCallback = [&] (mx::TokenKind lex_kind, std::string_view spelling, size_t next) { - if (auto grm_kind = m_grammar.TokenKindOf(spelling)) { + if (auto grm_kind = m_ep->TokenKindOf(spelling)) { result[{*grm_kind, next}].emplace(spelling); MatchPrefix(result, *grm_kind, next); } else { @@ -643,15 +645,17 @@ const ParsedQueryImpl::TableEntry &ParsedQueryImpl::ParsesAtIndex(size_t index) } std::pair> ParsedQueryImpl::MatchMarker( - const TableEntry &entry, const ParseMarker &marker, const ASTNode *node) { + const TableEntry &entry, const ParseMarker &marker, std::uint64_t node_id) { std::vector metavar_matches; + auto node = m_ep->GetASTNode(node_id); + auto kind = NodeKind::Deserialize(node.kind); + auto children = m_ep->GetASTNodeChildren(node_id); switch (marker.m_kind) { case ParseMarker::METAVAR: if (marker.m_metavar) { - MetavarMatch mv_match(marker.m_metavar->m_name, - node->Entity(), node->TokenRange()); + MetavarMatch 
mv_match(marker.m_metavar->m_name, node.entity); if (auto &predicate = marker.m_metavar->m_predicate) { if (!(*predicate)(mv_match)) { return {false, {}}; @@ -661,21 +665,22 @@ std::pair> ParsedQueryImpl::MatchMarker( } return {true, metavar_matches}; case ParseMarker::TERMINAL: - return {node->Kind().IsToken() && node->Spelling() == marker.m_spelling, - {}}; + return {kind.IsToken() && node.spelling == marker.m_spelling, {}}; case ParseMarker::NONTERMINAL: - if (node->Kind().IsToken() || - node->ChildVector().size() != marker.m_children.size()) { + if (kind.IsToken() || + children.size() != marker.m_children.size()) { return {false, {}}; } auto child_entry = &entry; auto child_it = marker.m_children.begin(); - for (const ASTNode *child_node : node->ChildVector()) { + for (std::uint64_t child_node_id : children) { auto &[kind, next, glob] = *child_it; + auto child_node = m_ep->GetASTNode(child_node_id); + auto child_node_kind = NodeKind::Deserialize(child_node.kind); - if (kind != NodeKind::Any() && kind != child_node->Kind()) { + if (kind != NodeKind::Any() && kind != child_node_kind) { return {false, {}}; } @@ -686,7 +691,7 @@ std::pair> ParsedQueryImpl::MatchMarker( } for (auto &marker : markers->second) { auto [child_ok, child_metavar_matches] = - MatchMarker(*child_entry, marker, child_node); + MatchMarker(*child_entry, marker, child_node_id); if (child_ok) { metavar_matches.insert( metavar_matches.end(), @@ -753,8 +758,8 @@ void ParsedQueryImpl::DebugParseTable(std::ostream &os) { } } -ParsedQuery::ParsedQuery(const Grammar &grammar, std::string_view query) - : impl(std::make_shared(*grammar.impl, query)) {} +ParsedQuery::ParsedQuery(std::shared_ptr ep, std::string_view query) + : impl(std::make_shared(std::move(ep), query)) {} bool ParsedQuery::IsValid(void) const { for (auto &[key, markers] : impl->ParsesAtIndex(0)) { @@ -790,35 +795,6 @@ bool ParsedQuery::AddMetavarPredicate( return true; } -std::vector ParsedQuery::FindInFragment( - mx::RawEntityId 
fragment_id) const { - if (auto frag = impl->m_grammar.index.fragment(fragment_id)) { - return Find(frag.value()); - } else { - return {}; - } -} - -void ParsedQuery::ForEachMatch(const mx::File &file, - std::function pred) const { - bool done = false; - auto real_pred = [sub_pred = std::move(pred), &done] (Match m) -> bool { - if (sub_pred(std::move(m))) { - return true; - } else { - done = true; - return false; - } - }; - - for (mx::Fragment frag : mx::Fragment::in(file)) { - ForEachMatch(frag, real_pred); - if (done) { - break; - } - } -} - void ParsedQuery::ForEachMatch(std::function pred) const { bool done = false; auto real_pred = [sub_pred = std::move(pred), &done] (Match m) -> bool { @@ -829,21 +805,15 @@ void ParsedQuery::ForEachMatch(std::function pred) const { return false; } }; - - for (mx::File file : mx::File::in(impl->m_grammar.index)) { - for (mx::Fragment frag : mx::Fragment::in(file)) { - ForEachMatch(frag, real_pred); - if (done) { - break; - } - } + for(auto frag_id : impl->m_ep->GetFragmentsInAST()) { + ForEachMatch(frag_id, real_pred); if (done) { break; } } } -std::vector ParsedQuery::Find(const mx::Fragment &frag) const { +std::vector ParsedQuery::Find(mx::RawEntityId frag) const { std::vector ret; ForEachMatch(frag, [&ret] (Match m) -> bool { ret.emplace_back(std::move(m)); @@ -852,33 +822,20 @@ std::vector ParsedQuery::Find(const mx::Fragment &frag) const { return ret; } -std::vector ParsedQuery::Find(const mx::File &file) const { - std::vector ret; - ForEachMatch(file, [&ret] (Match m) -> bool { - ret.emplace_back(std::move(m)); - return true; - }); - return ret; -} - std::vector ParsedQuery::Find(void) const { std::vector ret; - for (mx::File file : mx::File::in(impl->m_grammar.index)) { - for (mx::Fragment frag : mx::Fragment::in(file)) { - ForEachMatch(frag, [&ret] (Match m) -> bool { - ret.emplace_back(std::move(m)); - return true; - }); - } + for (auto frag_id : impl->m_ep->GetFragmentsInAST()) { + ForEachMatch(frag_id, [&ret] (Match 
m) -> bool { + ret.emplace_back(std::move(m)); + return true; + }); } return ret; } -void ParsedQuery::ForEachMatch(const mx::Fragment &frag, +void ParsedQuery::ForEachMatch(mx::RawEntityId frag_id, std::function pred) const { - - // Create AST for fragment - auto frag_ast = AST::Build(frag); + auto frag = impl->m_ep->FragmentFor(impl->m_ep, frag_id); // Find matching AST node auto &entry = impl->ParsesAtIndex(0); @@ -887,28 +844,27 @@ void ParsedQuery::ForEachMatch(const mx::Fragment &frag, continue; } if (key.first == NodeKind::Any()) { - for (auto &ast_node : frag_ast.AllNodes()) { + for (auto ast_node_id : impl->m_ep->GetASTNodesInFragment(frag_id)) { + auto ast_node = impl->m_ep->GetASTNode(ast_node_id); for (auto &marker : markers) { auto [ok, metavar_matches] = impl->MatchMarker( - entry, marker, &ast_node); - if (ok && !pred(Match(frag, ast_node.Entity(), - ast_node.TokenRange(), - metavar_matches))) { + entry, marker, ast_node_id); + if (ok && !pred(Match(ast_node.entity, metavar_matches))) { return; } } } } else { - for (auto *ast_node = frag_ast.WithKind(key.first); - ast_node; ast_node = ast_node->prev) { + auto ast_node_id = impl->m_ep->GetASTNodeWithKind(frag_id, key.first.Serialize()); + while (ast_node_id.has_value()) { + auto ast_node = impl->m_ep->GetASTNode(*ast_node_id); for (auto &marker : markers) { - auto [ok, metavar_matches] = impl->MatchMarker(entry, marker, ast_node); - if (ok && !pred(Match(frag, ast_node->Entity(), - ast_node->TokenRange(), - metavar_matches))) { + auto [ok, metavar_matches] = impl->MatchMarker(entry, marker, *ast_node_id); + if (ok && !pred(Match(ast_node.entity, metavar_matches))) { return; } } + ast_node_id = ast_node.prev; } } } diff --git a/lib/Query/Query.h b/lib/API/Query.h similarity index 84% rename from lib/Query/Query.h rename to lib/API/Query.h index 30ffc2e24..5e1b47037 100644 --- a/lib/Query/Query.h +++ b/lib/API/Query.h @@ -6,17 +6,18 @@ #pragma once -#include "AST.h" #include "NodeKind.h" #include 
"Grammar.h" #include #include +#include +#include template inline void hash_combine(size_t &h, const T& v) { - std::hash hasher; - h ^= hasher(v) + 0x9e3779b9 + (h << 6) + (h >> 2); + std::hash hasher; + h ^= hasher(v) + 0x9e3779b9 + (h << 6) + (h >> 2); } template<> @@ -31,6 +32,30 @@ struct std::hash> { namespace mx { namespace syntex { +// +// Result of parsing a query +// + +class ParsedQuery { + private: + std::shared_ptr impl; + ParsedQuery(void) = delete; + + public: + explicit ParsedQuery(std::shared_ptr ep, std::string_view query); + + bool IsValid() const; + + bool AddMetavarPredicate(const std::string_view &name, + std::function predicate); + + void ForEachMatch(mx::RawEntityId frag_id, + std::function pred) const; + void ForEachMatch(std::function pred) const; + + std::vector Find(mx::RawEntityId frag_id) const; + std::vector Find(void) const; +}; struct Metavar { std::string_view m_name; @@ -212,12 +237,13 @@ struct Item { // struct ParsedQueryImpl { - // GrammarImpl to be processed - const GrammarImpl &m_grammar; + std::shared_ptr m_ep; // Input string std::string_view m_input; + GrammarLeaves grammar_root; + // Main DP parse table using TableEntry = std::unordered_map, std::unordered_set>; @@ -239,12 +265,12 @@ struct ParsedQueryImpl { const TableEntry &ParsesAtIndex(size_t index); - explicit ParsedQueryImpl(const GrammarImpl &grammar, std::string_view input); + explicit ParsedQueryImpl(std::shared_ptr ep, std::string_view input); void DebugParseTable(std::ostream &os); std::pair> MatchMarker( - const TableEntry &entry, const ParseMarker &marker, const ASTNode *node); + const TableEntry &entry, const ParseMarker &marker, std::uint64_t node_id); }; } // namespace syntex diff --git a/lib/API/SQLiteEntityProvider.cpp b/lib/API/SQLiteEntityProvider.cpp index 88db18ee4..fac3b5ac1 100644 --- a/lib/API/SQLiteEntityProvider.cpp +++ b/lib/API/SQLiteEntityProvider.cpp @@ -5,9 +5,11 @@ // the LICENSE file found in the root directory of this source tree. 
#include "SQLiteEntityProvider.h" +#include "NodeKind.h" #include "API.h" #include "Compress.h" #include "Re2.h" +#include "Grammar.h" #include #include #include @@ -334,6 +336,66 @@ void SQLiteEntityProvider::FindSymbol(const Ptr &, std::string symbol, }); } +std::optional +SQLiteEntityProvider::TokenKindOf(std::string_view spelling) { + auto &storage = d->GetStorage(); + return storage.spelling_to_token_kind.TryGet(spelling); +} + +void SQLiteEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &root) { + auto &storage = d->GetStorage(); + std::vector> to_load; + + for(auto [id, kind] : storage.grammar_root) { + auto is_production = storage.grammar_nodes.TryGet(id).value_or(0); + auto &node = root[syntex::NodeKind::Deserialize(kind)]; + node.is_production = is_production; + to_load.emplace_back(id, &node); + } + + while(!to_load.empty()) { + auto pair = to_load.back(); + to_load.pop_back(); + auto id = std::get<0>(pair); + auto &node = *std::get<1>(pair); + + for(auto it = storage.grammar_children.key1_equals(id); + it != storage.grammar_children.end(); + ++it) { + auto [parent_id, kind, child_id] = *it; + auto is_production = storage.grammar_nodes.TryGet(child_id).value_or(0); + auto &child_node = node.leaves[syntex::NodeKind::Deserialize(kind)]; + child_node.is_production = is_production; + to_load.emplace_back(child_id, &child_node); + } + } +} + +std::vector SQLiteEntityProvider::GetFragmentsInAST(void) { + auto &storage = d->GetStorage(); + return storage.ast.GetFragments(); +} + +ASTNode SQLiteEntityProvider::GetASTNode(std::uint64_t id) { + auto &storage = d->GetStorage(); + return storage.ast.GetNode(id); +} + +std::vector SQLiteEntityProvider::GetASTNodeChildren(std::uint64_t id) { + auto &storage = d->GetStorage(); + return storage.ast.GetChildren(id); +} + +std::vector SQLiteEntityProvider::GetASTNodesInFragment(RawEntityId frag) { + auto &storage = d->GetStorage(); + return storage.ast.Root(frag); +} + +std::optional 
SQLiteEntityProvider::GetASTNodeWithKind(RawEntityId frag, unsigned short kind) { + auto &storage = d->GetStorage(); + return storage.ast.GetNodeInIndex(frag, kind); +} + EntityProvider::Ptr EntityProvider::from_database(std::filesystem::path path) { return std::make_shared(path); } diff --git a/lib/API/SQLiteEntityProvider.h b/lib/API/SQLiteEntityProvider.h index 89cd5fbb2..38b9d56b3 100644 --- a/lib/API/SQLiteEntityProvider.h +++ b/lib/API/SQLiteEntityProvider.h @@ -64,6 +64,17 @@ class SQLiteEntityProvider final : public EntityProvider { void FindSymbol(const Ptr &, std::string name, mx::DeclCategory category, std::vector &ids_out) final; + + std::optional + TokenKindOf(std::string_view spelling) final; + + void LoadGrammarRoot(syntex::GrammarLeaves &root) final; + + std::vector GetFragmentsInAST(void) final; + ASTNode GetASTNode(std::uint64_t id) final; + std::vector GetASTNodeChildren(std::uint64_t id) final; + std::vector GetASTNodesInFragment(RawEntityId frag) final; + std::optional GetASTNodeWithKind(RawEntityId frag, unsigned short kind) final; }; } // namespace mx diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 5ba4576ce..c377655aa 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -9,4 +9,3 @@ add_subdirectory(Common) add_subdirectory(API) add_subdirectory(Util) -add_subdirectory(Query) diff --git a/lib/Common/IndexStorage.cpp b/lib/Common/IndexStorage.cpp index 20e1cdce3..34be08ba8 100644 --- a/lib/Common/IndexStorage.cpp +++ b/lib/Common/IndexStorage.cpp @@ -12,6 +12,122 @@ #include namespace mx { +PersistentAST::PersistentAST(sqlite::Connection &db) : db(db) { + db.Execute( + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::ASTNode'(prev, kind, entity, spelling)"); + db.Execute( + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::ASTChildren'(parent, child, PRIMARY KEY(parent, child))"); + db.Execute( + "CREATE TABLE IF NOT EXISTS " + "'mx::syntex::ASTIndex'(fragment, kind, node, PRIMARY KEY(fragment, kind))"); + db.Execute( + "CREATE 
TABLE IF NOT EXISTS 'mx::syntex::ASTRoot'(fragment, node)"); + + get_root_stmt = db.Prepare( + "SELECT node FROM 'mx::syntex::ASTRoot' WHERE fragment = ?1"); + create_node_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::ASTNode'(prev, kind, entity, spelling) " + "VALUES (?1, ?2, ?3, ?4) RETURNING rowid"); + add_root_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::ASTRoot'(fragment, node) VALUES (?1, ?2)"); + get_node_stmt = db.Prepare( + "SELECT prev, kind, entity, spelling " + "FROM 'mx::syntex::ASTNode' WHERE rowid = ?1"); + get_index_stmt = db.Prepare( + "SELECT node FROM 'mx::syntex::ASTIndex' " + "WHERE fragment = ?1 AND kind = ?2"); + set_index_stmt = db.Prepare( + "INSERT OR REPLACE INTO 'mx::syntex::ASTIndex'(fragment, kind, node) " + "VALUES(?1, ?2, ?3)" + ); + get_fragments_stmt = db.Prepare( + "SELECT DISTINCT fragment FROM 'mx::syntex::ASTRoot'" + ); + get_children_stmt = db.Prepare( + "SELECT child FROM 'mx::syntex::ASTChildren' WHERE parent = ?1" + ); + add_child_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::ASTChildren'(parent, child) VALUES (?1, ?2)" + ); +} + +std::vector PersistentAST::Root(RawEntityId fragment) { + std::vector results; + get_root_stmt->BindValues(fragment); + while(get_root_stmt->ExecuteStep()) { + get_root_stmt->GetResult().Columns(results.emplace_back()); + } + return results; +} + +std::uint64_t PersistentAST::AddNode(const ASTNode& node) { + create_node_stmt->BindValues(node.prev, node.kind, + node.entity, node.spelling); + create_node_stmt->ExecuteStep(); + std::uint64_t rowid; + create_node_stmt->GetResult().Columns(rowid); + return rowid; +} + +void PersistentAST::AddNodeToRoot(RawEntityId fragment, std::uint64_t node_id) { + add_root_stmt->BindValues(fragment, node_id); + add_root_stmt->Execute(); +} + +ASTNode PersistentAST::GetNode(std::uint64_t node_id) { + ASTNode node; + get_node_stmt->BindValues(node_id); + get_node_stmt->ExecuteStep(); + get_node_stmt->GetResult().Columns(node.prev, node.kind, + node.entity, 
node.spelling); + get_node_stmt->ExecuteStep(); + return node; +} + +std::optional PersistentAST::GetNodeInIndex( + RawEntityId fragment, + unsigned short kind) { + get_index_stmt->BindValues(fragment, kind); + if(get_index_stmt->ExecuteStep()) { + std::uint64_t rowid; + get_index_stmt->GetResult().Columns(rowid); + return rowid; + } + return {}; +} + +void PersistentAST::SetNodeInIndex( + RawEntityId fragment, + unsigned short kind, + std::uint64_t node_id) { + set_index_stmt->BindValues(fragment, kind, node_id); + set_index_stmt->Execute(); +} + +std::vector PersistentAST::GetFragments() { + std::vector fragments; + while(get_fragments_stmt->ExecuteStep()) { + get_fragments_stmt->GetResult().Columns(fragments.emplace_back()); + } + return fragments; +} + +std::vector PersistentAST::GetChildren(std::uint64_t parent) { + std::vector children; + get_children_stmt->BindValues(parent); + while(get_children_stmt->ExecuteStep()) { + get_children_stmt->GetResult().Columns(children.emplace_back()); + } + return children; +} + +void PersistentAST::AddChild(std::uint64_t parent, std::uint64_t child) { + add_child_stmt->BindValues(parent, child); + add_child_stmt->Execute(); +} + IndexStorage::IndexStorage(sqlite::Connection& db) : db(db) , version_number(db) @@ -30,6 +146,11 @@ IndexStorage::IndexStorage(sqlite::Connection& db) , mangled_name_to_entity_id(db) , entity_id_use_to_fragment_id(db) , entity_id_reference(db) + , spelling_to_token_kind(db) + , grammar_root(db) + , grammar_nodes(db) + , grammar_children(db) + , ast(db) , database(db) {} IndexStorage::~IndexStorage() {} diff --git a/lib/Common/SQLiteStore.cpp b/lib/Common/SQLiteStore.cpp index ca7d638ce..5b4cb7de6 100644 --- a/lib/Common/SQLiteStore.cpp +++ b/lib/Common/SQLiteStore.cpp @@ -78,6 +78,11 @@ std::string_view QueryResult::getBlob(int32_t idx) { return std::string_view(ptr, len); } +bool QueryResult::isNull(int32_t idx) { + auto prepared_stmt = stmt->prepareStatement(); + return 
sqlite3_column_type(prepared_stmt, idx) == SQLITE_NULL; +} + Statement::Statement(Connection &conn, const std::string &stmt) : db(conn), query(stmt) { @@ -90,8 +95,7 @@ Statement::Statement(Connection &conn, const std::string &stmt) static_cast(query.size()), &stmt, const_cast(&tail)); if (SQLITE_OK != ret) { - assert(0); - throw Error("Failed to prepare statement"); + throw Error("Failed to prepare statement", db.GetHandler()); } return std::shared_ptr( diff --git a/lib/Query/AST.cpp b/lib/Query/AST.cpp deleted file mode 100644 index 75db80703..000000000 --- a/lib/Query/AST.cpp +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright (c) 2022-present, Trail of Bits, Inc. -// All rights reserved. -// -// This source code is licensed in accordance with the terms specified in -// the LICENSE file found in the root directory of this source tree. - -#include - -#include "AST.h" - -#include - -namespace mx { -namespace syntex { - -ASTNode::ASTNode(NodeKind kind_, - mx::VariantEntity entity_, - mx::TokenRange token_range_) - : kind(kind_), - entity(std::move(entity_)), - token_range(std::move(token_range_)), - child_vector() - {} - -ASTNode::ASTNode(mx::TokenKind kind_, - mx::VariantEntity entity_, - mx::TokenRange token_range_, - std::string spelling) - : kind(kind_), - entity(std::move(entity_)), - token_range(std::move(token_range_)), - spelling(std::move(spelling)) - {} - -ASTNode::~ASTNode() { - // Destruct correct union variant - if (kind.IsToken()) { - spelling.std::string::~string(); - } else { - child_vector.std::vector::~vector(); - } -} - -AST AST::Build(const mx::Fragment &fragment) { - AST self; - absl::flat_hash_map ctx_to_node; - - for (mx::Token tok : mx::Token::in(fragment)) { - // Skip whitespaces - switch (tok.kind()) { - case mx::TokenKind::UNKNOWN: - case mx::TokenKind::WHITESPACE: - case mx::TokenKind::COMMENT: - continue; - default: - if (tok.data().empty()) { - continue; - } - break; - } - - - // Start with the token node - ASTNode *node = 
&self.nodes.emplace_back( - tok.kind(), tok, tok, std::string(tok.data().data(), tok.data().size())); - node->prev = self.index[node->Kind().Serialize()]; - self.index[node->Kind().Serialize()] = node; - - for (auto ctx = mx::TokenContext::of(tok); ctx; ctx = ctx->parent()) { - auto it = ctx_to_node.find(ctx->id()); - - // Add to parent node's children if it already exists - - if (it != ctx_to_node.end()) { - it->second->child_vector.push_back(node); - node = nullptr; - break; - } - - // Otherwise we need to create a new parent node - - if (auto decl = mx::Decl::from(*ctx)) { - ASTNode *parent = &self.nodes.emplace_back(decl->kind(), *decl, decl->tokens()); - // Add it to the index - parent->prev = self.index[parent->Kind().Serialize()]; - self.index[parent->Kind().Serialize()] = parent; - ctx_to_node[ctx->id()] = parent; - parent->child_vector.push_back(node); - node = parent; - continue; - } - - if (auto stmt = mx::Stmt::from(*ctx)) { - ASTNode *parent = &self.nodes.emplace_back(stmt->kind(), *stmt, stmt->tokens()); - parent->prev = self.index[parent->Kind().Serialize()]; - self.index[parent->Kind().Serialize()] = parent; - ctx_to_node[ctx->id()] = parent; - parent->child_vector.push_back(node); - node = parent; - continue; - } - } - - // If we didn't add the token to a pre-existing parent, add it to the root - - if (node != nullptr) { - self.root.push_back(node); - } - } - - return self; -} - -#ifndef NDEBUG - -namespace { - -static std::string Data(const std::string &data) { - std::stringstream ss; - for (auto ch : data) { - switch (ch) { - // To keep xdot happy - case '[': ss << " ["; break; - case ']': ss << "]"; break; - // HTML escapes - case '<': ss << "<"; break; - case '>': ss << ">"; break; - case '"': ss << """; break; - case '\'': ss << "'"; break; - case '\n': ss << "
"; break; - case '&': ss << "&"; break; - case '\t': ss << "  "; break; - case '\r': break; - default: ss << ch; break; - } - } - return ss.str(); -} - -} // namespace - -void AST::PrintDOT(std::ostream &os) const { - os << "digraph {\n" - << "node [shape=none margin=0 nojustify=false labeljust=l font=courier];\n"; - - // Root node - os << "root [label=<
>];\n"; - for (const ASTNode *child : root) { - os << "root -> x" << std::hex << reinterpret_cast(child) - << std::dec << ";\n"; - } - - for (const ASTNode &node : nodes) { - os << "x" << std::hex << reinterpret_cast(&node) << std::dec - << " [label=<(&node) - << " -> x" << std::hex << reinterpret_cast(child) - << std::dec << ";\n"; - } - }; - - node.kind.Visit(Visitor { - [&] (mx::DeclKind kind) { - os - << " bgcolor=\"aquamarine\">" - << mx::EnumeratorName(kind) - << "
>];\n"; - PrintChildren(); - }, - [&] (mx::StmtKind kind) { - os - << " bgcolor=\"darkolivegreen3\">" - << mx::EnumeratorName(kind) - << ">];\n"; - PrintChildren(); - }, - [&] (mx::TokenKind kind) { - os - << " bgcolor=\"cornsilk2\">" - << mx::EnumeratorName(kind) - << "" - << Data(node.spelling) - << ">];\n"; - }, - [&] () { - assert(false); - abort(); - }, - }); - } - - os << "}\n"; -} - -#endif - -} // namespace syntex -} // namespace mx \ No newline at end of file diff --git a/lib/Query/AST.h b/lib/Query/AST.h deleted file mode 100644 index 935fb1ea6..000000000 --- a/lib/Query/AST.h +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) 2022-present, Trail of Bits, Inc. -// All rights reserved. -// -// This source code is licensed in accordance with the terms specified in -// the LICENSE file found in the root directory of this source tree. - -#pragma once - -#include "NodeKind.h" - -#include -#include -#include -#include - -// -// AST: In-memory tree representation of a multiplier fragment -// - -namespace pasta { -class TokenRange; -} - -namespace mx { -namespace syntex { - -class ASTNode { -public: - friend class AST; - - mutable const ASTNode *prev {nullptr}; - - ASTNode(NodeKind kind, - mx::VariantEntity entity, - mx::TokenRange token_range); - - ASTNode(mx::TokenKind kind, - mx::VariantEntity entity, - mx::TokenRange token_range, - std::string spelling); - - ~ASTNode(); - - NodeKind Kind() const { - return kind; - } - - const std::vector &ChildVector() const { - assert(!kind.IsToken()); - return child_vector; - }; - - const mx::VariantEntity &Entity() const { - return entity; - } - - const mx::TokenRange &TokenRange() const { - return token_range; - } - - const std::string &Spelling() const { - assert(kind.IsToken()); - return spelling; - } - -private: - NodeKind kind; - mx::VariantEntity entity; - mx::TokenRange token_range; - - union { - mutable std::vector child_vector; - std::string spelling; - }; -}; - - -// An AST. 
-class AST { -private: - friend class ASTNode; - - // Allocation arena for AST nodes - std::deque nodes; - - // Nodes at the root of the AST - std::vector root; - - // Nodes of the same kind are linked together in a chain - // This is the root of the chain for each kind - std::vector index; - - AST() { - index.resize(NodeKind::UpperLimit()); - } - -public: - // All nodes - const std::deque &AllNodes() const { - return nodes; - } - - // Nodes at the root of this AST - const std::vector &RootNodes(void) const { - return root; - } - - // Get indexed node of kind - const ASTNode *WithKind(NodeKind kind) const { - return index[kind.Serialize()]; - } - - // Build an AST from a multiplier fragment - static AST Build(const mx::Fragment &fragment); - - // NOTE: this actually lives in the PASTA grammar builder's cpp file - // do not call from anything else - static AST Build(const pasta::TokenRange &tokens); - -#ifndef NDEBUG - void PrintDOT(std::ostream &os) const; -#endif -}; - -} // namespace syntex -} // namespace mx \ No newline at end of file diff --git a/lib/Query/CMakeLists.txt b/lib/Query/CMakeLists.txt deleted file mode 100644 index 6886a1732..000000000 --- a/lib/Query/CMakeLists.txt +++ /dev/null @@ -1,52 +0,0 @@ -# -# Copyright (c) 2022-present, Trail of Bits, Inc. -# All rights reserved. -# -# This source code is licensed in accordance with the terms specified in -# the LICENSE file found in the root directory of this source tree. 
-# - -string(TOLOWER "${PROJECT_NAME}" lower_project_name) - -add_library("mx-syntex" - "AST.h" - "AST.cpp" - "Grammar.h" - "Grammar.cpp" - "NodeKind.h" - "Query.h" - "Query.cpp" -) - -target_include_directories("mx-syntex" - PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" -) - -target_link_libraries("mx-syntex" - PRIVATE - "absl::hash" - "absl::raw_hash_set" - PUBLIC - "mx-api" - "mx-util" -) - -if(MX_ENABLE_INSTALL) - install( - TARGETS - "mx-syntex" - EXPORT "${PROJECT_NAME}Targets" - RUNTIME - DESTINATION - "${CMAKE_INSTALL_BINDIR}" - LIBRARY - DESTINATION - "${CMAKE_INSTALL_LIBDIR}" - ARCHIVE - DESTINATION - "${CMAKE_INSTALL_LIBDIR}" - PUBLIC_HEADER - DESTINATION - "${CMAKE_INSTALL_INCLUDEDIR}/${lower_project_name}" - ) -endif(MX_ENABLE_INSTALL) diff --git a/lib/Query/Grammar.cpp b/lib/Query/Grammar.cpp deleted file mode 100644 index ef97fe0ed..000000000 --- a/lib/Query/Grammar.cpp +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright (c) 2022-present, Trail of Bits, Inc. -// All rights reserved. -// -// This source code is licensed in accordance with the terms specified in -// the LICENSE file found in the root directory of this source tree. - -#include "AST.h" -#include "Grammar.h" -#include -#include -#include - -namespace mx { -namespace syntex { - -GrammarImpl::GrammarImpl(const mx::Index &index_, std::filesystem::path db_path_) - : index(index_), db_path(db_path_) -{ - { - sqlite::Connection db(db_path); - DeserializeRules(db); - DeserializeTokens(db); - } - - for (auto file : mx::File::in(index)) { - for (auto fragment_id : file.fragment_ids()) { - Import(fragment_id); - } - } -} - -GrammarImpl::~GrammarImpl() -{ - sqlite::Connection db(db_path); - SerializeRules(db); - SerializeTokens(db); -} - -// Import a fragment into the grammar. 
-void GrammarImpl::Import(mx::RawEntityId fragment_id) -{ - auto fragment = index.fragment(fragment_id).value(); - auto ast = AST::Build(fragment); - -/* - - // Debug graphs - std::stringstream name; - name << "dot/ast_" << fragment.id() << ".dot"; - std::fstream fs(name.str(), std::fstream::out | std::fstream::trunc); - ast.PrintDOT(fs); - fs.close(); - -*/ - Import(ast); - -} - -void GrammarImpl::Import(const AST &ast) -{ - std::vector nodes(ast.RootNodes()); - - // Make a production rule for every node and its children. - while (!nodes.empty()) { - const ASTNode *node = nodes.back(); - nodes.pop_back(); - - if (node->Kind().IsToken()) { - // This is a token kind node, and represents a terminal. We want to map - // the contents of the token to the actual kind of the token. - - tokens.insert({ node->Spelling(), node->Kind().AsToken() }); - } else { - // This is an internal or root node. E.g. given the following: - // - // A - // / | \ - // B C D - // - // We want to make a rule of the form `B C D A`, i.e. if you match `B C D` - // then you have matched an `A`. This "backward" syntax enables us to prefix - // scan for left corners (`B` in this case) and find all rules starting with - // `B`. - - auto &child_vector = node->ChildVector(); - assert(child_vector.size() >= 1); - - // FIXME: do something else with long grammar rules. PHP has - // some generated initializer lists with 100s of elements that - // blows up our stack when serializing a grammar. - if (child_vector.size() > 100) { - continue; - } - - // Add the child nodes to the work list. 
- nodes.insert(nodes.end(), child_vector.begin(), child_vector.end()); - - // Walk the trie - GrammarLeaves *leaves = &root; - for (const ASTNode *child : child_vector) { - leaves = &leaves->operator[](child->Kind()).leaves; - } - // Save pointer to rule head - GrammarNode *head = &leaves->operator[](node->Kind()); - - // Avoid creating cyclic CFGs - bool allow_production = true; - - if (child_vector.size() == 1) { - std::vector queue = { node->Kind() }; - while (!queue.empty()) { - auto nt = queue.back(); - queue.pop_back(); - - // Check if we can reach our own left corner - if (nt == child_vector[0]->Kind()) { - allow_production = false; - break; - } - - // Queue result of matching trivial productions - auto it = root.find(nt); - if (it != root.end()) { - for (auto &[left, rest] : it->second.leaves) { - if (rest.is_production) { - queue.push_back(left); - } - } - } - } - } - - // Mark the head as a production if appropriate - head->is_production = allow_production; - } - } -} - -template -static void IterateRulesRecursive(const GrammarLeaves &leaves, - std::vector &stack, - F cb) -{ - for (const auto &[left, rest] : leaves) { - if (rest.is_production) { - cb(stack, left); - } - stack.push_back(left); - IterateRulesRecursive(rest.leaves, stack, cb); - stack.pop_back(); - } -} - -void GrammarImpl::DebugRules(std::ostream &os) -{ - std::vector stack; - IterateRulesRecursive(root, stack, [&] (const std::vector &body, NodeKind head) { - for (NodeKind kind : body) { - os << kind << " "; - } - os << "-> " << head << "\n"; - }); -} - -// NOTE: this is a simplistic serialization format - -inline void verify(bool condition) { - if (!condition) { - assert(false); - abort(); - } -} - -static constexpr const char *grammar_root_schema = - "CREATE TABLE IF NOT EXISTS " - "'mx::syntex::GrammarRoot'(kind, node, PRIMARY KEY(kind))"; - -static constexpr const char *grammar_nodes_schema = - "CREATE TABLE IF NOT EXISTS " - "'mx::syntex::GrammarNodes'(id, is_production, PRIMARY 
KEY(id))"; - -static constexpr const char *grammar_children_schema = - "CREATE TABLE IF NOT EXISTS " - "'mx::syntex::GrammarChildren'(parent, kind, child, PRIMARY KEY(parent, kind))"; - -void GrammarImpl::SerializeRules(sqlite::Connection& db) -{ - sqlite::Transaction tx(db); - std::scoped_lock lock(tx); - - db.Execute(grammar_root_schema); - db.Execute("DELETE FROM 'mx::syntex::GrammarRoot'"); - db.Execute(grammar_root_schema); - db.Execute("DELETE FROM 'mx::syntex::GrammarNodes'"); - db.Execute(grammar_children_schema); - db.Execute("DELETE FROM 'mx::syntex::GrammarChildren'"); - - auto root_stmt = db.Prepare( - "INSERT OR REPLACE INTO " - "'mx::syntex::GrammarRoot'(kind, node) VALUES (?1, ?2)"); - auto node_stmt = db.Prepare( - "INSERT OR REPLACE INTO " - "'mx::syntex::GrammarNodes'(id, is_production) VALUES (?1, ?2)"); - auto child_stmt = db.Prepare( - "INSERT OR REPLACE INTO " - "'mx::syntex::GrammarChildren'(parent, kind, child) VALUES (?1, ?2, ?3)"); - std::vector to_insert; - - auto GetId = [](const GrammarNode* node) { - return static_cast(reinterpret_cast(node)); - }; - - for(auto &[kind, node] : root) { - auto kind_value = kind.Serialize(); - root_stmt->BindValues(kind_value, GetId(&node)); - root_stmt->Execute(); - to_insert.push_back(&node); - } - - while(!to_insert.empty()) { - auto node = to_insert.back(); - to_insert.pop_back(); - - for(auto &[kind, child] : node->leaves) { - to_insert.push_back(&child); - child_stmt->BindValues(GetId(node), kind.Serialize(), GetId(&child)); - child_stmt->Execute(); - } - - node_stmt->BindValues(GetId(node), int{node->is_production}); - node_stmt->Execute(); - } -} - -void GrammarImpl::DeserializeRules(sqlite::Connection& db) -{ - db.Execute(grammar_root_schema); - db.Execute(grammar_nodes_schema); - db.Execute(grammar_children_schema); - auto root_stmt = db.Prepare( - "SELECT node, kind, is_production FROM 'mx::syntex::GrammarRoot' " - "JOIN 'mx::syntex::GrammarNodes' ON id = node"); - auto children_stmt = 
db.Prepare( - "SELECT child, kind, is_production FROM 'mx::syntex::GrammarChildren' " - "JOIN 'mx::syntex::GrammarNodes' ON id = child " - "WHERE parent = ?1"); - std::vector> to_load; - while(root_stmt->ExecuteStep()) { - std::uint64_t id; - unsigned short kind; - int is_production; - auto res = root_stmt->GetResult(); - res.Columns(id, kind, is_production); - auto &node = root[NodeKind::Deserialize(kind)]; - node.is_production = is_production; - to_load.emplace_back(id, &node); - } - - while(!to_load.empty()) { - auto pair = to_load.back(); - to_load.pop_back(); - auto id = std::get<0>(pair); - auto &node = *std::get<1>(pair); - - children_stmt->BindValues(id); - while(children_stmt->ExecuteStep()) { - std::uint64_t child_id; - unsigned short kind; - int is_production; - auto res = children_stmt->GetResult(); - res.Columns(child_id, kind, is_production); - auto &child_node = node.leaves[NodeKind::Deserialize(kind)]; - child_node.is_production = is_production; - to_load.emplace_back(child_id, &child_node); - } - } -} - -static constexpr const char* tokens_schema = - "CREATE TABLE IF NOT EXISTS 'mx::syntex::Tokens'(spelling, kind, PRIMARY KEY(spelling))"; - -void GrammarImpl::SerializeTokens(sqlite::Connection& db) -{ - db.Execute(tokens_schema); - auto stmt = db.Prepare( - "INSERT OR IGNORE INTO 'mx::syntex::Tokens'(spelling, kind) VALUES (?1, ?2)"); - for (auto &[spelling, kind] : tokens) { - stmt->BindValues(spelling, static_cast(kind)); - stmt->Execute(); - } -} - -void GrammarImpl::DeserializeTokens(sqlite::Connection& db) -{ - db.Execute(tokens_schema); - auto stmt = db.Prepare("SELECT spelling, kind FROM 'mx::syntex::Tokens'"); - while(stmt->ExecuteStep()) { - std::string spelling; - unsigned short kind; - auto res = stmt->GetResult(); - res.Columns(spelling, kind); - tokens[spelling] = static_cast(kind); - } -} - -// Determine the kind of an identifier based on its spelling -std::optional GrammarImpl::TokenKindOf(std::string_view spelling) const { - auto it 
= tokens.find(std::string(spelling)); - if (it != tokens.end()) { - return it->second; - } - return std::nullopt; -} - -Grammar::Grammar(const mx::Index &index, std::filesystem::path grammar_dir) - : impl(std::make_shared(index, grammar_dir)) {} - -} // namespace syntex -} // namespace mx \ No newline at end of file diff --git a/lib/Query/Grammar.h b/lib/Query/Grammar.h deleted file mode 100644 index 3dbbe503a..000000000 --- a/lib/Query/Grammar.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2022-present, Trail of Bits, Inc. -// All rights reserved. -// -// This source code is licensed in accordance with the terms specified in -// the LICENSE file found in the root directory of this source tree. - -#pragma once - -#include "NodeKind.h" - -#include -#include -#include -#include -#include -#include -#include - -namespace mx { - -class Index; - -namespace syntex { - -class AST; - -struct GrammarNode; - -// -// One set of grammar leaves -// FIXME(frabert): Deserialization crashes if this is turned into -// an `absl::flat_hash_map` -// -using GrammarLeaves = std::unordered_map; - -// -// Node in the grammar tree -// - -struct GrammarNode { - // Does this node correspond to the head of a production - bool is_production; - // Further leaves - GrammarLeaves leaves; -}; - -// -// Persistent CFG synthesized from a set of multiplier fragments -// - -class GrammarImpl { -private: - friend class Item; - friend class ParsedQuery; - friend class ParsedQueryImpl; - - // Multiplier index corresponding to this grammar - const mx::Index &index; - - // Grammar storage directory - std::filesystem::path db_path; - - // Mapping of spellings to token kinds - absl::flat_hash_map tokens; - - // Root of the grammar tree - GrammarLeaves root; - -public: - GrammarImpl(const mx::Index &index, std::filesystem::path db_path); - - ~GrammarImpl(void); - - // Import a fragment into the grammar. This extends the persisted grammar with - // the features from this fragment. 
- void Import(mx::RawEntityId fragment_id); - - void Import(const AST &ast); - - // Determine the kind of an identifier based on its spelling - std::optional TokenKindOf(std::string_view spelling) const; - - // Pretty print rules for debugging - void DebugRules(std::ostream &os); - - // Database grammar serialization - void SerializeRules(sqlite::Connection& db); - void DeserializeRules(sqlite::Connection& db); - - void SerializeTokens(sqlite::Connection& db); - void DeserializeTokens(sqlite::Connection& db); -}; - -} // namespace syntex -} // namespace mx \ No newline at end of file From 0667df9cebfdb09611fd722284bb9c2b64aaa9ff Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 9 Nov 2022 11:35:42 +0100 Subject: [PATCH 3/8] Serialize AST and grammar after indexing --- bin/Index/BuildAST.cpp | 188 +++++++++++++++++++++ bin/Index/BuildAST.h | 17 ++ bin/Index/CMakeLists.txt | 3 + bin/Index/Main.cpp | 4 + include/multiplier/IndexStorage.h | 49 +++--- {lib/API => include/multiplier}/NodeKind.h | 0 include/multiplier/PersistentMap.h | 5 +- include/multiplier/SQLiteStore.h | 5 + lib/API/CMakeLists.txt | 1 - lib/API/CachingEntityProvider.cpp | 2 +- lib/API/Grammar.h | 2 +- lib/API/InvalidEntityProvider.cpp | 2 +- lib/API/Query.cpp | 4 +- lib/API/Query.h | 2 +- lib/API/SQLiteEntityProvider.cpp | 18 +- lib/Common/IndexStorage.cpp | 180 +++++++++++++++++--- 16 files changed, 410 insertions(+), 72 deletions(-) create mode 100644 bin/Index/BuildAST.cpp create mode 100644 bin/Index/BuildAST.h rename {lib/API => include/multiplier}/NodeKind.h (100%) diff --git a/bin/Index/BuildAST.cpp b/bin/Index/BuildAST.cpp new file mode 100644 index 000000000..12ade8c50 --- /dev/null +++ b/bin/Index/BuildAST.cpp @@ -0,0 +1,188 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +#include "BuildAST.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace indexer { + +static void SerializeAST(mx::Fragment fragment, ServerContext &ctx) { + auto &ast = ctx->ast; + std::unordered_map ctx_to_node_id; + + for (mx::Token tok : mx::Token::in(fragment)) { + // Skip whitespaces + switch (tok.kind()) { + case mx::TokenKind::UNKNOWN: + case mx::TokenKind::WHITESPACE: + case mx::TokenKind::COMMENT: + continue; + default: + if (tok.data().empty()) { + continue; + } + break; + } + ctx->spelling_to_token_kind.Set(tok.data(), tok.kind()); + + // Start with the token node + mx::ASTNode node{}; + node.kind = mx::syntex::NodeKind{tok.kind()}.Serialize(); + node.entity = tok.id(); + node.spelling = std::string(tok.data().data(), tok.data().size()); + node.prev = ast.GetNodeInIndex(fragment.id(), node.kind); + std::optional node_id = ast.AddNode(node); + ast.SetNodeInIndex(fragment.id(), node.kind, *node_id); + + for (auto ctx = mx::TokenContext::of(tok); ctx; ctx = ctx->parent()) { + auto it = ctx_to_node_id.find(ctx->id()); + + // Add to parent node's children if it already exists + + if (it != ctx_to_node_id.end()) { + ast.AddChild(it->second, *node_id); + node_id = std::nullopt; + break; + } + + // Otherwise we need to create a new parent node + + if (auto decl = mx::Decl::from(*ctx)) { + mx::ASTNode parent{}; + parent.kind = mx::syntex::NodeKind{decl->kind()}.Serialize(); + parent.entity = decl->id(); + parent.prev = ast.GetNodeInIndex(fragment.id(), parent.kind); + auto parent_id = ast.AddNode(parent); + // Add it to the index + ast.SetNodeInIndex(fragment.id(), parent.kind, parent_id); + ctx_to_node_id[ctx->id()] = parent_id; + ast.AddChild(parent_id, *node_id); + node_id = parent_id; + continue; + } + + if (auto stmt = mx::Stmt::from(*ctx)) { + mx::ASTNode parent{}; + parent.kind = mx::syntex::NodeKind{stmt->kind()}.Serialize(); + parent.entity = stmt->id(); + parent.prev = 
ast.GetNodeInIndex(fragment.id(), parent.kind); + auto parent_id = ast.AddNode(parent); + // Add it to the index + ast.SetNodeInIndex(fragment.id(), parent.kind, parent_id); + ctx_to_node_id[ctx->id()] = parent_id; + ast.AddChild(parent_id, *node_id); + node_id = parent_id; + continue; + } + } + + // If we didn't add the token to a pre-existing parent, add it to the root + + if (node_id.has_value()) { + ast.AddNodeToRoot(fragment.id(), *node_id); + } + } +} + +static void ImportGrammar(mx::Fragment fragment, ServerContext& ctx) { + auto &ast = ctx->ast; + auto &grammar = ctx->grammar; + auto nodes = ast.Root(fragment.id()); + + // Make a production rule for every node and its children. + while (!nodes.empty()) { + auto node_id = nodes.back(); + nodes.pop_back(); + + auto node = ast.GetNode(node_id); + auto node_kind = mx::syntex::NodeKind::Deserialize(node.kind); + + if (!node_kind.IsToken()) { + // This is an internal or root node. E.g. given the following: + // + // A + // / | \ + // B C D + // + // We want to make a rule of the form `B C D A`, i.e. if you match `B C D` + // then you have matched an `A`. This "backward" syntax enables us to prefix + // scan for left corners (`B` in this case) and find all rules starting with + // `B`. + + auto child_vector = ast.GetChildren(node_id); + assert(child_vector.size() >= 1); + + // FIXME: do something else with long grammar rules. PHP has + // some generated initializer lists with 100s of elements that + // blows up our stack when serializing a grammar. + if (child_vector.size() > 100) { + continue; + } + + // Add the child nodes to the work list. 
+ nodes.insert(nodes.end(), child_vector.begin(), child_vector.end()); + + // Walk the trie + std::uint64_t leaves_id = 0; + for (auto child_id : child_vector) { + auto child = ast.GetNode(child_id); + leaves_id = grammar.GetChild(leaves_id, child.kind); + } + // Save pointer to rule head + auto head_id = grammar.GetChild(leaves_id, node.kind); + + // Avoid creating cyclic CFGs + bool allow_production = true; + + if (child_vector.size() == 1) { + std::vector queue = { node.kind }; + while (!queue.empty()) { + auto nt = queue.back(); + queue.pop_back(); + + // Check if we can reach our own left corner + auto child = ast.GetNode(child_vector[0]); + if (nt == child.kind) { + allow_production = false; + break; + } + + // Queue result of matching trivial productions + for(auto [left, rest] : grammar.GetChildLeaves(0, nt)) { + auto node = grammar.GetNode(rest); + if(node.is_production) { + queue.push_back(left); + } + } + } + } + + // Mark the head as a production if appropriate + grammar.UpdateNode(head_id, {allow_production}); + } + } +} + +void BuildAST(mx::Index index, ServerContext &context) { + for(auto file : mx::File::in(index)) { + for(auto fragment : mx::Fragment::in(file)) { + sqlite::Transaction tx(context.db); + std::scoped_lock lock(tx); + SerializeAST(fragment, context); + ImportGrammar(fragment, context); + } + } +} +} // namespace indexer \ No newline at end of file diff --git a/bin/Index/BuildAST.h b/bin/Index/BuildAST.h new file mode 100644 index 000000000..2cdabcf9e --- /dev/null +++ b/bin/Index/BuildAST.h @@ -0,0 +1,17 @@ +// Copyright (c) 2022-present, Trail of Bits, Inc. +// All rights reserved. +// +// This source code is licensed in accordance with the terms specified in +// the LICENSE file found in the root directory of this source tree. 
+ +#pragma once + +#include +#include +#include "Context.h" + +namespace indexer { + +void BuildAST(mx::Index index, ServerContext& context); + +} // namespace indexer diff --git a/bin/Index/CMakeLists.txt b/bin/Index/CMakeLists.txt index 46be997d7..d156e3adc 100644 --- a/bin/Index/CMakeLists.txt +++ b/bin/Index/CMakeLists.txt @@ -9,6 +9,8 @@ set(exe_name "mx-index") add_executable("${exe_name}" + "BuildAST.cpp" + "BuildAST.h" "BuildPendingFragment.cpp" "Compress.cpp" "Compress.h" @@ -57,6 +59,7 @@ target_link_libraries("${exe_name}" PRIVATE ${MX_BEGIN_FORCE_LOAD_GROUP} "mx-util" + "mx-api" "concurrentqueue" ${MX_BEGIN_FORCE_LOAD_LIB} pasta::pasta ${MX_END_FORCE_LOAD_LIB} ${MX_END_FORCE_LOAD_GROUP} diff --git a/bin/Index/Main.cpp b/bin/Index/Main.cpp index 5f6b23180..89a0f1d53 100644 --- a/bin/Index/Main.cpp +++ b/bin/Index/Main.cpp @@ -24,6 +24,7 @@ #include "Context.h" #include "Parser.h" #include "Importer.h" +#include "BuildAST.h" // Should we show a help message? DECLARE_bool(help); @@ -158,5 +159,8 @@ extern "C" int main(int argc, char *argv[]) { executor.Start(); executor.Wait(); + auto index = mx::Index(mx::EntityProvider::from_database(FLAGS_db)); + indexer::BuildAST(index, ic->server_context[0]); + return EXIT_SUCCESS; } diff --git a/include/multiplier/IndexStorage.h b/include/multiplier/IndexStorage.h index 10c520e72..6f3ecbdfb 100644 --- a/include/multiplier/IndexStorage.h +++ b/include/multiplier/IndexStorage.h @@ -30,9 +30,6 @@ enum : char { kEntityIdUseToFragmentId, kEntityIdReference, kSpellingToTokenKind, - kGrammarRoot, - kGrammarNodes, - kGrammarChildren, }; enum MetadataName : char { @@ -70,16 +67,8 @@ struct ASTNode { }; class PersistentAST final { - sqlite::Connection &db; - std::shared_ptr get_root_stmt; - std::shared_ptr create_node_stmt; - std::shared_ptr add_root_stmt; - std::shared_ptr get_node_stmt; - std::shared_ptr get_index_stmt; - std::shared_ptr set_index_stmt; - std::shared_ptr get_fragments_stmt; - std::shared_ptr 
get_children_stmt; - std::shared_ptr add_child_stmt; + struct Impl; + std::unique_ptr impl; public: PersistentAST(sqlite::Connection &db); @@ -108,6 +97,28 @@ class PersistentAST final { void AddChild(std::uint64_t parent, std::uint64_t child); }; +struct GrammarNode { + bool is_production; +}; + +class PersistentGrammar final { + struct Impl; + std::unique_ptr impl; + + public: + PersistentGrammar(sqlite::Connection &db); + + std::vector> GetChildren(std::uint64_t parent); + + std::uint64_t GetChild(std::uint64_t parent, unsigned short kind); + + std::vector> GetChildLeaves(std::uint64_t parent, unsigned short kind); + + void UpdateNode(std::uint64_t id, const GrammarNode &node); + + GrammarNode GetNode(std::uint64_t id); +}; + class IndexStorage final { sqlite::Connection &db; @@ -213,18 +224,10 @@ class IndexStorage final { mx::PersistentMap spelling_to_token_kind; - mx::PersistentMap - grammar_root; - - mx::PersistentMap - grammar_nodes; - - mx::PersistentMap2 - grammar_children; - PersistentAST ast; + PersistentGrammar grammar; + // SQLite database. Used for things like symbol searches. 
SymbolDatabase database; diff --git a/lib/API/NodeKind.h b/include/multiplier/NodeKind.h similarity index 100% rename from lib/API/NodeKind.h rename to include/multiplier/NodeKind.h diff --git a/include/multiplier/PersistentMap.h b/include/multiplier/PersistentMap.h index 99ec88677..b6a576cb5 100644 --- a/include/multiplier/PersistentMap.h +++ b/include/multiplier/PersistentMap.h @@ -35,9 +35,6 @@ static constexpr const char* table_names[] = { "'mx::EntityIdUseToFragmentId'", "'mx::EntityIdReference'", "'mx::syntex::Tokens'", - "'mx::syntex::GrammarRoot'", - "'mx::syntex::GrammarNodes'", - "'mx::syntex::GrammarChildren'", }; template @@ -273,7 +270,7 @@ class PersistentMap { get_or_set_stmt = db.Prepare(ss.str()); ss = {}; - ss << "SELECT key, value FROM " << table_names; + ss << "SELECT key, value FROM " << table_names[kId]; enum_stmt = db.Prepare(ss.str()); } diff --git a/include/multiplier/SQLiteStore.h b/include/multiplier/SQLiteStore.h index 7b460b1d0..92af9525a 100644 --- a/include/multiplier/SQLiteStore.h +++ b/include/multiplier/SQLiteStore.h @@ -172,6 +172,11 @@ class Statement : public std::enable_shared_from_this { } } + template>> + void bind(const size_t i, const T &value) { + bind(i, static_cast(value)); + } + void reset(); template diff --git a/lib/API/CMakeLists.txt b/lib/API/CMakeLists.txt index be68e7e99..80f821325 100644 --- a/lib/API/CMakeLists.txt +++ b/lib/API/CMakeLists.txt @@ -51,7 +51,6 @@ add_library("mx-api" "Index.cpp" "InvalidEntityProvider.cpp" "InvalidEntityProvider.h" - "NodeKind.h" "PackedFileImpl.cpp" "PackedFragmentImpl.cpp" "PackedReaderState.cpp" diff --git a/lib/API/CachingEntityProvider.cpp b/lib/API/CachingEntityProvider.cpp index a4e4ae505..5d669a046 100644 --- a/lib/API/CachingEntityProvider.cpp +++ b/lib/API/CachingEntityProvider.cpp @@ -4,7 +4,7 @@ // This source code is licensed in accordance with the terms specified in // the LICENSE file found in the root directory of this source tree. 
-#include "NodeKind.h" +#include #include "Grammar.h" #include "CachingEntityProvider.h" diff --git a/lib/API/Grammar.h b/lib/API/Grammar.h index 3b4c4f9cb..f9f865302 100644 --- a/lib/API/Grammar.h +++ b/lib/API/Grammar.h @@ -6,7 +6,7 @@ #pragma once -#include "NodeKind.h" +#include #include diff --git a/lib/API/InvalidEntityProvider.cpp b/lib/API/InvalidEntityProvider.cpp index 1071be68c..45a56843b 100644 --- a/lib/API/InvalidEntityProvider.cpp +++ b/lib/API/InvalidEntityProvider.cpp @@ -6,7 +6,7 @@ #include "InvalidEntityProvider.h" -#include "NodeKind.h" +#include #include "Grammar.h" #include diff --git a/lib/API/Query.cpp b/lib/API/Query.cpp index bfc34ad14..695880ff2 100644 --- a/lib/API/Query.cpp +++ b/lib/API/Query.cpp @@ -531,7 +531,7 @@ done_filters:; } ParsedQueryImpl::ParsedQueryImpl(std::shared_ptr ep, std::string_view input) - : m_ep(std::move(ep)), m_input(input) { + : m_ep(ep), m_input(input) { ep->LoadGrammarRoot(grammar_root); } @@ -759,7 +759,7 @@ void ParsedQueryImpl::DebugParseTable(std::ostream &os) { } ParsedQuery::ParsedQuery(std::shared_ptr ep, std::string_view query) - : impl(std::make_shared(std::move(ep), query)) {} + : impl(std::make_shared(ep, query)) {} bool ParsedQuery::IsValid(void) const { for (auto &[key, markers] : impl->ParsesAtIndex(0)) { diff --git a/lib/API/Query.h b/lib/API/Query.h index 5e1b47037..2bf4d4672 100644 --- a/lib/API/Query.h +++ b/lib/API/Query.h @@ -6,7 +6,7 @@ #pragma once -#include "NodeKind.h" +#include #include "Grammar.h" #include #include diff --git a/lib/API/SQLiteEntityProvider.cpp b/lib/API/SQLiteEntityProvider.cpp index fac3b5ac1..661d8c9c6 100644 --- a/lib/API/SQLiteEntityProvider.cpp +++ b/lib/API/SQLiteEntityProvider.cpp @@ -5,7 +5,7 @@ // the LICENSE file found in the root directory of this source tree. 
#include "SQLiteEntityProvider.h" -#include "NodeKind.h" +#include #include "API.h" #include "Compress.h" #include "Re2.h" @@ -344,12 +344,13 @@ SQLiteEntityProvider::TokenKindOf(std::string_view spelling) { void SQLiteEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &root) { auto &storage = d->GetStorage(); + auto &grammar = storage.grammar; std::vector> to_load; - for(auto [id, kind] : storage.grammar_root) { - auto is_production = storage.grammar_nodes.TryGet(id).value_or(0); + for(auto [kind, id] : grammar.GetChildren(0)) { + auto data = grammar.GetNode(id); auto &node = root[syntex::NodeKind::Deserialize(kind)]; - node.is_production = is_production; + node.is_production = data.is_production; to_load.emplace_back(id, &node); } @@ -359,13 +360,10 @@ void SQLiteEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &root) { auto id = std::get<0>(pair); auto &node = *std::get<1>(pair); - for(auto it = storage.grammar_children.key1_equals(id); - it != storage.grammar_children.end(); - ++it) { - auto [parent_id, kind, child_id] = *it; - auto is_production = storage.grammar_nodes.TryGet(child_id).value_or(0); + for(auto [kind, child_id] : grammar.GetChildren(id)) { + auto data = grammar.GetNode(child_id); auto &child_node = node.leaves[syntex::NodeKind::Deserialize(kind)]; - child_node.is_production = is_production; + child_node.is_production = data.is_production; to_load.emplace_back(child_id, &child_node); } } diff --git a/lib/Common/IndexStorage.cpp b/lib/Common/IndexStorage.cpp index 34be08ba8..9db662369 100644 --- a/lib/Common/IndexStorage.cpp +++ b/lib/Common/IndexStorage.cpp @@ -12,7 +12,22 @@ #include namespace mx { -PersistentAST::PersistentAST(sqlite::Connection &db) : db(db) { +struct PersistentAST::Impl { + sqlite::Connection &db; + std::shared_ptr get_root_stmt; + std::shared_ptr create_node_stmt; + std::shared_ptr add_root_stmt; + std::shared_ptr get_node_stmt; + std::shared_ptr get_index_stmt; + std::shared_ptr set_index_stmt; + std::shared_ptr 
get_fragments_stmt; + std::shared_ptr get_children_stmt; + std::shared_ptr add_child_stmt; + + Impl(sqlite::Connection &db); +}; + +PersistentAST::Impl::Impl(sqlite::Connection &db) : db(db) { db.Execute( "CREATE TABLE IF NOT EXISTS " "'mx::syntex::ASTNode'(prev, kind, entity, spelling)"); @@ -53,46 +68,52 @@ PersistentAST::PersistentAST(sqlite::Connection &db) : db(db) { ); } +PersistentAST::PersistentAST(sqlite::Connection &db) + : impl(std::make_unique(db)) {} + std::vector PersistentAST::Root(RawEntityId fragment) { std::vector results; - get_root_stmt->BindValues(fragment); - while(get_root_stmt->ExecuteStep()) { - get_root_stmt->GetResult().Columns(results.emplace_back()); + impl->get_root_stmt->BindValues(fragment); + while(impl->get_root_stmt->ExecuteStep()) { + impl->get_root_stmt->GetResult().Columns(results.emplace_back()); } + impl->get_root_stmt->Reset(); return results; } std::uint64_t PersistentAST::AddNode(const ASTNode& node) { - create_node_stmt->BindValues(node.prev, node.kind, + impl->create_node_stmt->BindValues(node.prev, node.kind, node.entity, node.spelling); - create_node_stmt->ExecuteStep(); + impl->create_node_stmt->ExecuteStep(); std::uint64_t rowid; - create_node_stmt->GetResult().Columns(rowid); + impl->create_node_stmt->GetResult().Columns(rowid); + impl->create_node_stmt->Reset(); return rowid; } void PersistentAST::AddNodeToRoot(RawEntityId fragment, std::uint64_t node_id) { - add_root_stmt->BindValues(fragment, node_id); - add_root_stmt->Execute(); + impl->add_root_stmt->BindValues(fragment, node_id); + impl->add_root_stmt->Execute(); } ASTNode PersistentAST::GetNode(std::uint64_t node_id) { ASTNode node; - get_node_stmt->BindValues(node_id); - get_node_stmt->ExecuteStep(); - get_node_stmt->GetResult().Columns(node.prev, node.kind, + impl->get_node_stmt->BindValues(node_id); + impl->get_node_stmt->ExecuteStep(); + impl->get_node_stmt->GetResult().Columns(node.prev, node.kind, node.entity, node.spelling); - 
get_node_stmt->ExecuteStep(); + impl->get_node_stmt->Reset(); return node; } std::optional PersistentAST::GetNodeInIndex( RawEntityId fragment, unsigned short kind) { - get_index_stmt->BindValues(fragment, kind); - if(get_index_stmt->ExecuteStep()) { + impl->get_index_stmt->BindValues(fragment, kind); + if(impl->get_index_stmt->ExecuteStep()) { std::uint64_t rowid; - get_index_stmt->GetResult().Columns(rowid); + impl->get_index_stmt->GetResult().Columns(rowid); + impl->get_root_stmt->Reset(); return rowid; } return {}; @@ -102,30 +123,135 @@ void PersistentAST::SetNodeInIndex( RawEntityId fragment, unsigned short kind, std::uint64_t node_id) { - set_index_stmt->BindValues(fragment, kind, node_id); - set_index_stmt->Execute(); + impl->set_index_stmt->BindValues(fragment, kind, node_id); + impl->set_index_stmt->Execute(); } std::vector PersistentAST::GetFragments() { std::vector fragments; - while(get_fragments_stmt->ExecuteStep()) { - get_fragments_stmt->GetResult().Columns(fragments.emplace_back()); + while(impl->get_fragments_stmt->ExecuteStep()) { + impl->get_fragments_stmt->GetResult().Columns(fragments.emplace_back()); } return fragments; } std::vector PersistentAST::GetChildren(std::uint64_t parent) { std::vector children; - get_children_stmt->BindValues(parent); - while(get_children_stmt->ExecuteStep()) { - get_children_stmt->GetResult().Columns(children.emplace_back()); + impl->get_children_stmt->BindValues(parent); + while(impl->get_children_stmt->ExecuteStep()) { + impl->get_children_stmt->GetResult().Columns(children.emplace_back()); } return children; } void PersistentAST::AddChild(std::uint64_t parent, std::uint64_t child) { - add_child_stmt->BindValues(parent, child); - add_child_stmt->Execute(); + impl->add_child_stmt->BindValues(parent, child); + impl->add_child_stmt->Execute(); +} + +struct PersistentGrammar::Impl { + sqlite::Connection &db; + std::shared_ptr get_children_stmt; + std::shared_ptr get_child_stmt; + std::shared_ptr 
get_child_leaves_stmt; + std::shared_ptr update_node_stmt; + std::shared_ptr get_node_stmt; + std::shared_ptr add_node_stmt; + std::shared_ptr add_child_stmt; + + Impl(sqlite::Connection &db); +}; + +PersistentGrammar::Impl::Impl(sqlite::Connection &db) + : db(db) { + db.Execute( + "CREATE TABLE IF NOT EXISTS 'mx::syntex::GrammarNodes'(is_production)" + ); + db.Execute( + "CREATE TABLE IF NOT EXISTS 'mx::syntex::GrammarChildren'(parent, kind, child, PRIMARY KEY(parent, kind))" + ); + + get_children_stmt = db.Prepare( + "SELECT kind, child FROM 'mx::syntex::GrammarChildren' WHERE parent = ?1" + ); + get_child_stmt = db.Prepare( + "SELECT child FROM 'mx::syntex::GrammarChildren' WHERE parent = ?1 AND kind = ?2" + ); + get_child_leaves_stmt = db.Prepare( + "SELECT child_node.kind, child_node.child" + " FROM 'mx::syntex::GrammarChildren' AS parent_node," + " 'mx::syntex::GrammarChildren' AS child_node" + " WHERE parent_node.parent = ?1" + " AND parent_node.kind = ?2" + " AND child_node.parent = parent_node.child" + ); + update_node_stmt = db.Prepare( + "UPDATE 'mx::syntex::GrammarNodes' SET is_production = ?2 WHERE rowid = ?1" + ); + get_node_stmt = db.Prepare( + "SELECT is_production FROM 'mx::syntex::GrammarNodes' WHERE rowid = ?1" + ); + add_node_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::GrammarNodes'(is_production) VALUES (?1) RETURNING rowid" + ); + add_child_stmt = db.Prepare( + "INSERT INTO 'mx::syntex::GrammarChildren'(parent, kind, child) VALUES (?1, ?2, ?3)" + ); +} + +PersistentGrammar::PersistentGrammar(sqlite::Connection &db) : impl(std::make_unique(db)) {} + +std::vector> +PersistentGrammar::GetChildren(std::uint64_t parent) { + std::vector> result; + impl->get_children_stmt->BindValues(parent); + while(impl->get_children_stmt->ExecuteStep()) { + auto &[kind, id] = result.emplace_back(); + impl->get_children_stmt->GetResult().Columns(kind, id); + } + return result; +} + +std::uint64_t PersistentGrammar::GetChild(std::uint64_t parent, unsigned short 
kind) { + impl->get_child_stmt->BindValues(parent, kind); + std::uint64_t child; + if(impl->get_child_stmt->ExecuteStep()) { + impl->get_child_stmt->GetResult().Columns(child); + impl->get_child_stmt->Reset(); + return child; + } + impl->add_node_stmt->BindValues(0); + impl->add_node_stmt->ExecuteStep(); + impl->add_node_stmt->GetResult().Columns(child); + impl->add_node_stmt->Reset(); + impl->add_child_stmt->BindValues(parent, kind, child); + impl->add_child_stmt->Execute(); + return child; +} + +std::vector> +PersistentGrammar::GetChildLeaves(std::uint64_t parent, unsigned short kind) { + std::vector> result; + impl->get_child_leaves_stmt->BindValues(parent, kind); + while(impl->get_child_leaves_stmt->ExecuteStep()) { + auto &[kind, id] = result.emplace_back(); + impl->get_child_leaves_stmt->GetResult().Columns(kind, id); + } + return result; +} + +void PersistentGrammar::UpdateNode(std::uint64_t id, const GrammarNode &node) { + impl->update_node_stmt->BindValues(id, node.is_production); + impl->update_node_stmt->Execute(); +} + +GrammarNode PersistentGrammar::GetNode(std::uint64_t id) { + GrammarNode node; + impl->get_node_stmt->BindValues(id); + impl->get_node_stmt->ExecuteStep(); + impl->get_node_stmt->GetResult().Columns(node.is_production); + impl->get_node_stmt->Reset(); + return node; } IndexStorage::IndexStorage(sqlite::Connection& db) @@ -147,10 +273,8 @@ IndexStorage::IndexStorage(sqlite::Connection& db) , entity_id_use_to_fragment_id(db) , entity_id_reference(db) , spelling_to_token_kind(db) - , grammar_root(db) - , grammar_nodes(db) - , grammar_children(db) , ast(db) + , grammar(db) , database(db) {} IndexStorage::~IndexStorage() {} From ccd5d3549e5fb89d9e9e871e693774d76e0b4fc3 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 9 Nov 2022 11:39:09 +0100 Subject: [PATCH 4/8] Remove `absl` dependency --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e40a18465..53308c3fb 100644 --- 
a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,7 +59,6 @@ find_package(SQLite3 3.35 REQUIRED) find_package(reproc++ REQUIRED) find_package(pasta CONFIG REQUIRED) find_package(Python3 COMPONENTS Interpreter REQUIRED) -find_package(absl CONFIG REQUIRED) if(PLATFORM_MACOS) set(CMAKE_INSTALL_RPATH "@executable_path/../${CMAKE_INSTALL_LIBDIR}") From 41d704d384927ae903cf445cafdb86ba6edc64e2 Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 9 Nov 2022 12:18:11 +0100 Subject: [PATCH 5/8] Remove unused bits --- include/multiplier/PersistentMap.h | 153 ----------------------------- lib/API/Grammar.h | 3 - 2 files changed, 156 deletions(-) diff --git a/include/multiplier/PersistentMap.h b/include/multiplier/PersistentMap.h index b6a576cb5..79fc6ee7f 100644 --- a/include/multiplier/PersistentMap.h +++ b/include/multiplier/PersistentMap.h @@ -191,51 +191,6 @@ class PersistentSet { } }; -template -class Iterator { - private: - std::shared_ptr stmt; - std::tuple value; - - template - void Read(std::index_sequence) { - auto res = stmt->GetResult(); - res.Columns(std::get(value)...); - } - - public: - Iterator(std::shared_ptr stmt) - : stmt(std::move(stmt)) { - this->operator++(); - } - - bool operator==(const Iterator& b) const { - return stmt == b.stmt; - } - - bool operator!=(const Iterator& b) const { - return stmt != b.stmt; - } - - Iterator& operator++(void) { - if(!stmt->ExecuteStep()) { - stmt = nullptr; - return *this; - } - - Read(std::make_index_sequence()); - return *this; - } - - const std::tuple &operator*(void) const { - return value; - } - - const std::tuple *operator->(void) const { - return &value; - } -}; - // Persistent mapping from keys to values. 
template class PersistentMap { @@ -303,114 +258,6 @@ class PersistentMap { return std::nullopt; } - - Iterator begin() { - return Iterator(enum_stmt); - } - - Iterator end() { - return Iterator(nullptr); - } -}; - -template -class PersistentMap2 { - private: - sqlite::Connection &db; - std::shared_ptr set_stmt; - std::shared_ptr get_stmt; - std::shared_ptr get_or_set_stmt; - std::shared_ptr enum_stmt; - std::shared_ptr enum_k1_stmt; - std::shared_ptr enum_k2_stmt; - - public: - PersistentMap2(sqlite::Connection &db) : db(db) { - std::stringstream ss; - ss << "CREATE TABLE IF NOT EXISTS " - << table_names[kId] << "(key1, key2, value, PRIMARY KEY(key1, key2))"; - db.Execute(ss.str()); - - ss = {}; - ss << "INSERT OR REPLACE INTO " - << table_names[kId] << "(key1, key2, value) VALUES (?1, ?2, ?3)"; - set_stmt = db.Prepare(ss.str()); - - ss = {}; - ss << "SELECT key1, key2, value FROM " - << table_names[kId] << " WHERE key1 = ?1 AND key2 = ?2"; - get_stmt = db.Prepare(ss.str()); - - ss = {}; - ss << "INSERT INTO " << table_names[kId] - << "(key1, key2, value) VALUES(?1, ?2, ?3) " - << "ON CONFLICT DO UPDATE SET value=value RETURNING key1, key2, value"; - get_or_set_stmt = db.Prepare(ss.str()); - - ss = {}; - ss << "SELECT key1, key2, value FROM " << table_names; - enum_stmt = db.Prepare(ss.str()); - - ss = {}; - ss << "SELECT key1, key2, value FROM " << table_names[kId] - << " WHERE key1 = ?1"; - enum_k1_stmt = db.Prepare(ss.str()); - - ss = {}; - ss << "SELECT key1, key2, value FROM " << table_names[kId] - << " WHERE key2 = ?1"; - enum_k1_stmt = db.Prepare(ss.str()); - } - - V GetOrSet(K1 key1, K2 key2, V val) const { - get_or_set_stmt->BindValues(key1, key2, val); - get_or_set_stmt->ExecuteStep(); - auto res = get_or_set_stmt->GetResult(); - K1 stored_key1; - K2 stored_key2; - V stored_value; - res.Columns(stored_key1, stored_key2, stored_value); - get_or_set_stmt->ExecuteStep(); - return stored_value; - } - - void Set(K1 key1, K2 key2, V val) const { - 
set_stmt->BindValues(key1, key2, val); - set_stmt->Execute(); - } - - std::optional TryGet(K1 key1, K2 key2) const { - get_stmt->BindValues(key1, key2); - if(get_stmt->ExecuteStep()) { - K1 stored_key1; - K2 stored_key2; - V stored_value; - auto res = get_stmt->GetResult(); - res.Columns(stored_key1, stored_key2, stored_value); - get_stmt->ExecuteStep(); - return stored_value; - } - - return std::nullopt; - } - - Iterator begin() { - return Iterator(enum_stmt); - } - - Iterator key1_equals(K1 key) { - enum_k1_stmt->BindValues(key); - return Iterator(enum_k1_stmt); - } - - Iterator key2_equals(K2 key) { - enum_k2_stmt->BindValues(key); - return Iterator(enum_k1_stmt); - } - - Iterator end() { - return Iterator(nullptr); - } }; } // namespace mx diff --git a/lib/API/Grammar.h b/lib/API/Grammar.h index f9f865302..4135d4b24 100644 --- a/lib/API/Grammar.h +++ b/lib/API/Grammar.h @@ -7,7 +7,6 @@ #pragma once #include - #include namespace mx { @@ -17,8 +16,6 @@ struct GrammarNode; // // One set of grammar leaves -// FIXME(frabert): Deserialization crashes if this is turned into -// an `absl::flat_hash_map` // using GrammarLeaves = std::unordered_map; From 6d62e0164274b310bf03a3ea174190bf12fd385b Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Wed, 9 Nov 2022 18:27:52 +0100 Subject: [PATCH 6/8] Store metavariable name in match --- include/multiplier/Syntex.h | 6 +++--- lib/API/Query.cpp | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/include/multiplier/Syntex.h b/include/multiplier/Syntex.h index 7c8cabbc9..1c74f12c5 100644 --- a/include/multiplier/Syntex.h +++ b/include/multiplier/Syntex.h @@ -33,15 +33,15 @@ class ParsedQueryImpl; class MetavarMatch { private: - std::string_view name; + std::string name; mx::EntityId entity; public: - MetavarMatch(std::string_view name_, mx::EntityId entity_) + MetavarMatch(const std::string& name_, mx::EntityId entity_) : name(std::move(name_)), entity(std::move(entity_)) {} - const std::string_view 
&Name(void) const { + const std::string &Name(void) const { return name; } diff --git a/lib/API/Query.cpp b/lib/API/Query.cpp index 695880ff2..56ec9ce33 100644 --- a/lib/API/Query.cpp +++ b/lib/API/Query.cpp @@ -655,7 +655,9 @@ std::pair> ParsedQueryImpl::MatchMarker( switch (marker.m_kind) { case ParseMarker::METAVAR: if (marker.m_metavar) { - MetavarMatch mv_match(marker.m_metavar->m_name, node.entity); + MetavarMatch mv_match( + {marker.m_metavar->m_name.data(), marker.m_metavar->m_name.size()}, + node.entity); if (auto &predicate = marker.m_metavar->m_predicate) { if (!(*predicate)(mv_match)) { return {false, {}}; From 1f4b33cb7389de95158d29f4e067b4be3b002feb Mon Sep 17 00:00:00 2001 From: Peter Goodman Date: Thu, 10 Nov 2022 11:58:32 -0500 Subject: [PATCH 7/8] Update Syntex.h --- include/multiplier/Syntex.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/multiplier/Syntex.h b/include/multiplier/Syntex.h index 1c74f12c5..34a08b055 100644 --- a/include/multiplier/Syntex.h +++ b/include/multiplier/Syntex.h @@ -37,7 +37,7 @@ class MetavarMatch { mx::EntityId entity; public: - MetavarMatch(const std::string& name_, mx::EntityId entity_) + MetavarMatch(const std::string &name_, mx::EntityId entity_) : name(std::move(name_)), entity(std::move(entity_)) {} @@ -81,4 +81,4 @@ class Match { }; } // namespace syntex -} // namespace mx \ No newline at end of file +} // namespace mx From f3a21dd11f82d4b3742b0e8630190f4fe4e1373a Mon Sep 17 00:00:00 2001 From: Francesco Bertolaccini Date: Fri, 18 Nov 2022 12:34:30 +0100 Subject: [PATCH 8/8] Refactor API --- bin/Index/BuildAST.cpp | 8 +-- bin/Query/PredicateExample.cpp | 11 ++-- bin/Query/SyntexQuery.cpp | 13 ++-- include/multiplier/Index.h | 24 +++---- include/multiplier/NodeKind.h | 28 ++++---- include/multiplier/Query.h | 25 +++++++ include/multiplier/Syntex.h | 24 +++---- lib/API/CachingEntityProvider.cpp | 2 +- lib/API/CachingEntityProvider.h | 4 +- lib/API/Grammar.h | 10 ++- 
lib/API/Index.cpp | 16 +---- lib/API/InvalidEntityProvider.cpp | 2 +- lib/API/InvalidEntityProvider.h | 2 +- lib/API/Query.cpp | 105 +++++++++++++++--------------- lib/API/Query.h | 99 ++++++++++------------------ lib/API/SQLiteEntityProvider.cpp | 8 +-- lib/API/SQLiteEntityProvider.h | 2 +- 17 files changed, 180 insertions(+), 203 deletions(-) diff --git a/bin/Index/BuildAST.cpp b/bin/Index/BuildAST.cpp index 12ade8c50..e5df35cb2 100644 --- a/bin/Index/BuildAST.cpp +++ b/bin/Index/BuildAST.cpp @@ -39,7 +39,7 @@ static void SerializeAST(mx::Fragment fragment, ServerContext &ctx) { // Start with the token node mx::ASTNode node{}; - node.kind = mx::syntex::NodeKind{tok.kind()}.Serialize(); + node.kind = mx::SyntexNodeKind{tok.kind()}.Serialize(); node.entity = tok.id(); node.spelling = std::string(tok.data().data(), tok.data().size()); node.prev = ast.GetNodeInIndex(fragment.id(), node.kind); @@ -61,7 +61,7 @@ static void SerializeAST(mx::Fragment fragment, ServerContext &ctx) { if (auto decl = mx::Decl::from(*ctx)) { mx::ASTNode parent{}; - parent.kind = mx::syntex::NodeKind{decl->kind()}.Serialize(); + parent.kind = mx::SyntexNodeKind{decl->kind()}.Serialize(); parent.entity = decl->id(); parent.prev = ast.GetNodeInIndex(fragment.id(), parent.kind); auto parent_id = ast.AddNode(parent); @@ -75,7 +75,7 @@ static void SerializeAST(mx::Fragment fragment, ServerContext &ctx) { if (auto stmt = mx::Stmt::from(*ctx)) { mx::ASTNode parent{}; - parent.kind = mx::syntex::NodeKind{stmt->kind()}.Serialize(); + parent.kind = mx::SyntexNodeKind{stmt->kind()}.Serialize(); parent.entity = stmt->id(); parent.prev = ast.GetNodeInIndex(fragment.id(), parent.kind); auto parent_id = ast.AddNode(parent); @@ -107,7 +107,7 @@ static void ImportGrammar(mx::Fragment fragment, ServerContext& ctx) { nodes.pop_back(); auto node = ast.GetNode(node_id); - auto node_kind = mx::syntex::NodeKind::Deserialize(node.kind); + auto node_kind = mx::SyntexNodeKind::Deserialize(node.kind); if 
(!node_kind.IsToken()) { // This is an internal or root node. E.g. given the following: diff --git a/bin/Query/PredicateExample.cpp b/bin/Query/PredicateExample.cpp index d7e72a7a8..24662df7d 100644 --- a/bin/Query/PredicateExample.cpp +++ b/bin/Query/PredicateExample.cpp @@ -72,7 +72,7 @@ static std::optional IntegralTypeWidth(const mx::ValueDecl &decl) { return IntegralTypeWidth(decl.type()); } -static void HighlightMatch(std::ostream &os, mx::Index index, mx::syntex::Match m) { +static void HighlightMatch(std::ostream &os, mx::Index index, mx::SyntexMatch m) { auto stmt = std::get(index.entity(m.MetavarMatch(0).Entity())); auto ref = mx::DeclRefExpr::from(stmt); if (!ref) { @@ -170,15 +170,16 @@ extern "C" int main(int argc, char *argv[]) { // Setup query - auto res = index.query_syntex("$var:DECL_REF_EXPR << $num:INTEGER_LITERAL"); - if (!res.has_value()) { + auto res = index.parse_syntex_query("$var:DECL_REF_EXPR << $num:INTEGER_LITERAL"); + if (!res.IsValid()) { return EXIT_FAILURE; } // Match fragments - for(auto match : res.value()) { + res.ForEachMatch([&](auto match) { HighlightMatch(std::cout, index, std::move(match)); - } + return true; + }); return EXIT_SUCCESS; } diff --git a/bin/Query/SyntexQuery.cpp b/bin/Query/SyntexQuery.cpp index ebf6463e8..ed4eae29e 100644 --- a/bin/Query/SyntexQuery.cpp +++ b/bin/Query/SyntexQuery.cpp @@ -21,7 +21,7 @@ DEFINE_string(query, "", "Use argument value as query"); DEFINE_uint64(threads, 0, "Use this number of threads"); DEFINE_bool(suppress_output, false, "Don't print matches to stdout"); -static void PrintMatch(mx::Index index, const mx::syntex::Match &match) +static void PrintMatch(mx::Index index, const mx::SyntexMatch &match) { if (FLAGS_suppress_output) { return; @@ -87,16 +87,17 @@ extern "C" int main(int argc, char *argv[]) { // Parse query - auto res = index.query_syntex(FLAGS_query); + auto res = index.parse_syntex_query(FLAGS_query); - if (!res.has_value()) { + if (!res.IsValid()) { std::cerr << "Query `" 
<< FLAGS_query << "` has no valid parses\n"; return EXIT_FAILURE; } - for(auto match : *res) { - PrintMatch(index, match); - } + res.ForEachMatch([&](auto match) { + PrintMatch(index, std::move(match)); + return true; + }); return EXIT_SUCCESS; } diff --git a/include/multiplier/Index.h b/include/multiplier/Index.h index 329aa7788..ecc79b82f 100644 --- a/include/multiplier/Index.h +++ b/include/multiplier/Index.h @@ -45,15 +45,12 @@ class WeggliQueryResult; class WeggliQueryResultImpl; struct ASTNode; -namespace syntex { -class Match; -class NodeKind; -class GrammarNode; -class ParsedQuery; -class ParsedQueryImpl; - -using GrammarLeaves = std::unordered_map; -} +class SyntexNodeKind; +class SyntexGrammarNode; +class SyntexQuery; +class SyntexQueryImpl; + +using SyntexGrammarLeaves = std::unordered_map; using DeclUse = Use; using StmtUse = Use; @@ -113,8 +110,8 @@ class EntityProvider { friend class UseIteratorImpl; friend class WeggliQueryResultImpl; friend class WeggliQueryResultIterator; - friend class syntex::ParsedQuery; - friend class syntex::ParsedQueryImpl; + friend class SyntexQuery; + friend class SyntexQueryImpl; protected: @@ -189,7 +186,7 @@ class EntityProvider { virtual std::optional TokenKindOf(std::string_view spelling) = 0; - virtual void LoadGrammarRoot(syntex::GrammarLeaves &root) = 0; + virtual void LoadGrammarRoot(SyntexGrammarLeaves &root) = 0; virtual std::vector GetFragmentsInAST(void) = 0; virtual ASTNode GetASTNode(std::uint64_t id) = 0; @@ -273,8 +270,7 @@ class Index { NamedDeclList query_entities(std::string name, mx::DeclCategory category) const; - std::optional> query_syntex(std::string_view query) const; - std::optional> query_syntex(FragmentId frag, std::string_view query) const; + SyntexQuery parse_syntex_query(std::string_view query); }; } // namespace mx diff --git a/include/multiplier/NodeKind.h b/include/multiplier/NodeKind.h index 7e4ef7d04..0df85d852 100644 --- a/include/multiplier/NodeKind.h +++ 
b/include/multiplier/NodeKind.h @@ -11,34 +11,33 @@ #include namespace mx { -namespace syntex { // -// NodeKind: Core class of Syntex, represents the following things: +// SyntexNodeKind: Core class of Syntex, represents the following things: // - An entry in a grammar rule // - Kind of node in a multiplier AST // - Kind of node in a query AST // -class NodeKind { +class SyntexNodeKind { private: unsigned short val; - NodeKind(unsigned short val_) : val(val_) {} + SyntexNodeKind(unsigned short val_) : val(val_) {} public: - static NodeKind Any() { - return NodeKind(UpperLimit()); + static SyntexNodeKind Any() { + return SyntexNodeKind(UpperLimit()); } - NodeKind(mx::DeclKind kind) + SyntexNodeKind(mx::DeclKind kind) : val(static_cast(kind)) {} - NodeKind(mx::StmtKind kind) + SyntexNodeKind(mx::StmtKind kind) : val(static_cast(kind) + mx::NumEnumerators(mx::DeclKind{})) {} - NodeKind(mx::TokenKind kind) + SyntexNodeKind(mx::TokenKind kind) : val(static_cast(kind) + mx::NumEnumerators(mx::DeclKind{}) + mx::NumEnumerators(mx::StmtKind{})) {} @@ -72,11 +71,11 @@ class NodeKind { - mx::NumEnumerators(mx::StmtKind{})); } - bool operator==(const NodeKind &other) const { + bool operator==(const SyntexNodeKind &other) const { return val == other.val; } - static NodeKind Deserialize(unsigned short val) { + static SyntexNodeKind Deserialize(unsigned short val) { return val; } @@ -108,7 +107,7 @@ template Visitor(F...) 
-> Visitor; // Pretty print a NodeKind to an output stream // -inline std::ostream& operator<<(std::ostream &os, const NodeKind &kind) { +inline std::ostream& operator<<(std::ostream &os, const SyntexNodeKind &kind) { kind.Visit(Visitor { [&] (mx::DeclKind kind) { os << "DeclKind::" << EnumeratorName(kind); }, [&] (mx::StmtKind kind) { os << "StmtKind::" << EnumeratorName(kind); }, @@ -118,14 +117,13 @@ inline std::ostream& operator<<(std::ostream &os, const NodeKind &kind) { return os; } -} // namespace syntex } // namespace mx namespace std { template<> -struct hash { - size_t operator()(const mx::syntex::NodeKind &kind) const { +struct hash { + size_t operator()(const mx::SyntexNodeKind &kind) const { return kind.Serialize(); } }; diff --git a/include/multiplier/Query.h b/include/multiplier/Query.h index 01fad218c..7e9514819 100644 --- a/include/multiplier/Query.h +++ b/include/multiplier/Query.h @@ -37,6 +37,10 @@ class WeggliQueryMatch; class WeggliQueryResultIterator; class WeggliQueryResult; class WeggliQueryResultImpl; +class SyntexQuery; +class SyntexQueryImpl; +class SyntexMatch; +class SyntexMetavarMatch; // The range of tokens of a match. 
class WeggliQueryMatch : public TokenRange { @@ -325,4 +329,25 @@ class RegexQueryResult { } }; +class SyntexQuery { + private: + std::shared_ptr impl; + SyntexQuery(void) = delete; + + public: + explicit SyntexQuery(std::shared_ptr ep, std::string_view query); + + bool IsValid() const; + + bool AddMetavarPredicate(const std::string_view &name, + std::function predicate); + + void ForEachMatch(mx::RawEntityId frag_id, + std::function pred) const; + void ForEachMatch(std::function pred) const; + + std::vector Find(mx::RawEntityId frag_id) const; + std::vector Find(void) const; +}; + } // namespace mx diff --git a/include/multiplier/Syntex.h b/include/multiplier/Syntex.h index 34a08b055..a846786cc 100644 --- a/include/multiplier/Syntex.h +++ b/include/multiplier/Syntex.h @@ -20,24 +20,21 @@ #include "Token.h" namespace mx { -namespace syntex { -class Grammar; -class GrammarImpl; -class ParsedQuery; -class ParsedQueryImpl; +class SyntexQuery; +class SyntexQueryImpl; // // Chunk of a fragment (potentially) matching a metavariable // -class MetavarMatch { +class SyntexMetavarMatch { private: std::string name; mx::EntityId entity; public: - MetavarMatch(const std::string &name_, mx::EntityId entity_) + SyntexMetavarMatch(const std::string &name_, mx::EntityId entity_) : name(std::move(name_)), entity(std::move(entity_)) {} @@ -54,16 +51,16 @@ class MetavarMatch { // Chunk of a ParsedQuery that matched against a part of a fragment // -class Match { +class SyntexMatch { private: - friend class ParsedQuery; + friend class SyntexQuery; mx::EntityId entity; - std::vector metavars; + std::vector metavars; public: - Match(mx::EntityId entity_, std::vector matevars_) + SyntexMatch(mx::EntityId entity_, std::vector matevars_) : entity(std::move(entity_)), metavars(std::move(matevars_)) {} @@ -71,14 +68,13 @@ class Match { return entity; } - const std::vector &MetavarMatches(void) const { + const std::vector &MetavarMatches(void) const { return metavars; } - const MetavarMatch 
&MetavarMatch(size_t i) const { + const SyntexMetavarMatch &MetavarMatch(size_t i) const { return metavars[i]; } }; -} // namespace syntex } // namespace mx diff --git a/lib/API/CachingEntityProvider.cpp b/lib/API/CachingEntityProvider.cpp index 5d669a046..ceb11af1c 100644 --- a/lib/API/CachingEntityProvider.cpp +++ b/lib/API/CachingEntityProvider.cpp @@ -285,7 +285,7 @@ CachingEntityProvider::TokenKindOf(std::string_view spelling) { return it->second; } -void CachingEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves& root) { +void CachingEntityProvider::LoadGrammarRoot(SyntexGrammarLeaves& root) { if(grammar_root.empty()) { next->LoadGrammarRoot(grammar_root); } diff --git a/lib/API/CachingEntityProvider.h b/lib/API/CachingEntityProvider.h index 3b4c719f5..194263750 100644 --- a/lib/API/CachingEntityProvider.h +++ b/lib/API/CachingEntityProvider.h @@ -68,7 +68,7 @@ class CachingEntityProvider final : public EntityProvider { references; std::unordered_map spelling_to_token_kind; - syntex::GrammarLeaves grammar_root; + SyntexGrammarLeaves grammar_root; std::vector fragments_in_ast; std::unordered_map node_contents; @@ -127,7 +127,7 @@ class CachingEntityProvider final : public EntityProvider { std::optional TokenKindOf(std::string_view spelling) final; - void LoadGrammarRoot(syntex::GrammarLeaves &root) final; + void LoadGrammarRoot(SyntexGrammarLeaves &root) final; std::vector GetFragmentsInAST(void) final; ASTNode GetASTNode(std::uint64_t id) final; diff --git a/lib/API/Grammar.h b/lib/API/Grammar.h index 4135d4b24..a49b6b0d0 100644 --- a/lib/API/Grammar.h +++ b/lib/API/Grammar.h @@ -10,25 +10,23 @@ #include namespace mx { -namespace syntex { -struct GrammarNode; +struct SyntexGrammarNode; // // One set of grammar leaves // -using GrammarLeaves = std::unordered_map; +using SyntexGrammarLeaves = std::unordered_map; // // Node in the grammar tree // -struct GrammarNode { +struct SyntexGrammarNode { // Does this node correspond to the head of a production bool 
is_production; // Further leaves - GrammarLeaves leaves; + SyntexGrammarLeaves leaves; }; -} // namespace syntex } // namespace mx \ No newline at end of file diff --git a/lib/API/Index.cpp b/lib/API/Index.cpp index 7a912bf1e..463942942 100644 --- a/lib/API/Index.cpp +++ b/lib/API/Index.cpp @@ -330,20 +330,8 @@ NamedDeclList Index::query_entities( return decls; } -std::optional> Index::query_syntex(std::string_view query) const { - syntex::ParsedQuery parsed_query(impl, query); - if(!parsed_query.IsValid()) { - return std::nullopt; - } - return parsed_query.Find(); -} - -std::optional> Index::query_syntex(FragmentId frag, std::string_view query) const { - syntex::ParsedQuery parsed_query(impl, query); - if(!parsed_query.IsValid()) { - return std::nullopt; - } - return parsed_query.Find(frag.fragment_id); +SyntexQuery Index::parse_syntex_query(std::string_view query) { + return SyntexQuery(impl, query); } } // namespace mx diff --git a/lib/API/InvalidEntityProvider.cpp b/lib/API/InvalidEntityProvider.cpp index 45a56843b..185ab7f04 100644 --- a/lib/API/InvalidEntityProvider.cpp +++ b/lib/API/InvalidEntityProvider.cpp @@ -85,7 +85,7 @@ InvalidEntityProvider::TokenKindOf(std::string_view spelling) { return {}; } -void InvalidEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &) {} +void InvalidEntityProvider::LoadGrammarRoot(SyntexGrammarLeaves &) {} std::vector InvalidEntityProvider::GetFragmentsInAST(void) { return {}; diff --git a/lib/API/InvalidEntityProvider.h b/lib/API/InvalidEntityProvider.h index 5bfc4fbff..76cbfd329 100644 --- a/lib/API/InvalidEntityProvider.h +++ b/lib/API/InvalidEntityProvider.h @@ -61,7 +61,7 @@ class InvalidEntityProvider final : public EntityProvider { std::optional TokenKindOf(std::string_view spelling) final; - void LoadGrammarRoot(syntex::GrammarLeaves &root) final; + void LoadGrammarRoot(SyntexGrammarLeaves &root) final; std::vector GetFragmentsInAST(void) final; ASTNode GetASTNode(std::uint64_t id) final; diff --git 
a/lib/API/Query.cpp b/lib/API/Query.cpp index 56ec9ce33..76bcb2203 100644 --- a/lib/API/Query.cpp +++ b/lib/API/Query.cpp @@ -5,6 +5,9 @@ // the LICENSE file found in the root directory of this source tree. #include "Query.h" +#include +#include +#include #include #include @@ -12,7 +15,6 @@ #include namespace mx { -namespace syntex { template void Tokenize(TokenCallback token_callback, MetavarCallback metavar_callback, @@ -185,7 +187,7 @@ void Tokenize(TokenCallback token_callback, MetavarCallback metavar_callback, } }; - auto AddMetavar = [&] (std::string_view name, NodeKind filter) { + auto AddMetavar = [&] (std::string_view name, SyntexNodeKind filter) { size_t next = end; for (;;) switch (Look(next - end)) { @@ -245,7 +247,7 @@ void Tokenize(TokenCallback token_callback, MetavarCallback metavar_callback, ; auto name = input.substr(begin + 1, end - begin - 1); - NodeKind filter = NodeKind::Any(); + auto filter = SyntexNodeKind::Any(); // Skip over filter if present if (Match(':')) { @@ -261,21 +263,21 @@ void Tokenize(TokenCallback token_callback, MetavarCallback metavar_callback, for (int i = 0; i < NumEnumerators(mx::DeclKind::TYPE); ++i) { auto kind = static_cast(i); if (EnumeratorName(kind) == filter_str) { - filter = NodeKind(kind); + filter = SyntexNodeKind(kind); goto done_filters; } } for (int i = 0; i < NumEnumerators(mx::StmtKind::NULL_STMT); ++i) { auto kind = static_cast(i); if (EnumeratorName(kind) == filter_str) { - filter = NodeKind(kind); + filter = SyntexNodeKind(kind); goto done_filters; } } for (int i = 0; i < NumEnumerators(mx::TokenKind::UNKNOWN); ++i) { auto kind = static_cast(i); if (EnumeratorName(kind) == filter_str) { - filter = NodeKind(kind); + filter = SyntexNodeKind(kind); goto done_filters; } } @@ -530,13 +532,13 @@ done_filters:; } } -ParsedQueryImpl::ParsedQueryImpl(std::shared_ptr ep, std::string_view input) +SyntexQueryImpl::SyntexQueryImpl(std::shared_ptr ep, std::string_view input) : m_ep(ep), m_input(input) { 
ep->LoadGrammarRoot(grammar_root); } -void ParsedQueryImpl::MatchGlob(TableEntry &result, - const std::unordered_set &follow, +void SyntexQueryImpl::MatchGlob(TableEntry &result, + const std::unordered_set &follow, Item &item, size_t next) { @@ -559,9 +561,9 @@ void ParsedQueryImpl::MatchGlob(TableEntry &result, // Otherwise the rest of the grammar rule is a candiate for more globbing if (rest.leaves.size() > 0) { - const GrammarLeaves *old_leaves = item.m_leaves; + const SyntexGrammarLeaves *old_leaves = item.m_leaves; item.m_leaves = &rest.leaves; - item.m_children.emplace_back(NodeKind::Any(), next, Glob::YES); + item.m_children.emplace_back(SyntexNodeKind::Any(), next, Glob::YES); MatchGlob(result, follow, item, next); item.m_leaves = old_leaves; item.m_children.pop_back(); @@ -569,10 +571,10 @@ void ParsedQueryImpl::MatchGlob(TableEntry &result, } } -void ParsedQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { +void SyntexQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { // Iterate shifts for (auto &[key, _] : ParsesAtIndex(next)) { - NodeKind kind = key.first; + SyntexNodeKind kind = key.first; size_t next = key.second; item.IterateShifts(kind, next, Glob::NO, [&] (Item &item) { MatchRule(result, item, next); @@ -582,7 +584,7 @@ void ParsedQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { // Iterate glob shifts if (auto it = m_globs.find(next); it != m_globs.end()) { // Compute set of node kinds that can follow $... 
- std::unordered_set follow; + std::unordered_set follow; for (auto &[key, _] : ParsesAtIndex(it->second)) { follow.insert(key.first); } @@ -590,19 +592,19 @@ void ParsedQueryImpl::MatchRule(TableEntry &result, Item &item, size_t next) { } // Iterate reductions - item.IterateReductions([&] (NodeKind kind, const auto &children) { + item.IterateReductions([&] (SyntexNodeKind kind, const auto &children) { result[{kind, next}].emplace(children); MatchPrefix(result, kind, next); }); } -void ParsedQueryImpl::MatchPrefix(TableEntry &result, NodeKind kind, size_t next) { +void SyntexQueryImpl::MatchPrefix(TableEntry &result, SyntexNodeKind kind, size_t next) { Item(&grammar_root).IterateShifts(kind, next, Glob::NO, [&] (Item &item) { MatchRule(result, item, next); }); } -const ParsedQueryImpl::TableEntry &ParsedQueryImpl::ParsesAtIndex(size_t index) { +const SyntexQueryImpl::TableEntry &SyntexQueryImpl::ParsesAtIndex(size_t index) { // Lookup memoized parses at this index auto it = m_parses.find(index); if (it != m_parses.end()) { @@ -621,7 +623,7 @@ const ParsedQueryImpl::TableEntry &ParsedQueryImpl::ParsesAtIndex(size_t index) } }; - auto MetavarCallback = [&] (std::string_view name, NodeKind filter, size_t next) { + auto MetavarCallback = [&] (std::string_view name, SyntexNodeKind filter, size_t next) { if (name == "") { result[{filter, next}].emplace(nullptr); } else { @@ -644,18 +646,18 @@ const ParsedQueryImpl::TableEntry &ParsedQueryImpl::ParsesAtIndex(size_t index) return result; } -std::pair> ParsedQueryImpl::MatchMarker( - const TableEntry &entry, const ParseMarker &marker, std::uint64_t node_id) { +std::pair> SyntexQueryImpl::MatchMarker( + const TableEntry &entry, const SyntexParseMarker &marker, std::uint64_t node_id) { - std::vector metavar_matches; + std::vector metavar_matches; auto node = m_ep->GetASTNode(node_id); - auto kind = NodeKind::Deserialize(node.kind); + auto kind = SyntexNodeKind::Deserialize(node.kind); auto children = 
m_ep->GetASTNodeChildren(node_id); switch (marker.m_kind) { - case ParseMarker::METAVAR: + case SyntexParseMarker::METAVAR: if (marker.m_metavar) { - MetavarMatch mv_match( + SyntexMetavarMatch mv_match( {marker.m_metavar->m_name.data(), marker.m_metavar->m_name.size()}, node.entity); if (auto &predicate = marker.m_metavar->m_predicate) { @@ -666,9 +668,9 @@ std::pair> ParsedQueryImpl::MatchMarker( metavar_matches.push_back(std::move(mv_match)); } return {true, metavar_matches}; - case ParseMarker::TERMINAL: + case SyntexParseMarker::TERMINAL: return {kind.IsToken() && node.spelling == marker.m_spelling, {}}; - case ParseMarker::NONTERMINAL: + case SyntexParseMarker::NONTERMINAL: if (kind.IsToken() || children.size() != marker.m_children.size()) { return {false, {}}; @@ -680,9 +682,9 @@ std::pair> ParsedQueryImpl::MatchMarker( for (std::uint64_t child_node_id : children) { auto &[kind, next, glob] = *child_it; auto child_node = m_ep->GetASTNode(child_node_id); - auto child_node_kind = NodeKind::Deserialize(child_node.kind); + auto child_node_kind = SyntexNodeKind::Deserialize(child_node.kind); - if (kind != NodeKind::Any() && kind != child_node_kind) { + if (kind != SyntexNodeKind::Any() && kind != child_node_kind) { return {false, {}}; } @@ -712,7 +714,7 @@ std::pair> ParsedQueryImpl::MatchMarker( } } -void ParsedQueryImpl::DebugParseTable(std::ostream &os) { +void SyntexQueryImpl::DebugParseTable(std::ostream &os) { // Make sure the DP table was actually filled in ParsesAtIndex(0); @@ -737,13 +739,13 @@ void ParsedQueryImpl::DebugParseTable(std::ostream &os) { // Print body switch (marker.m_kind) { - case ParseMarker::METAVAR: + case SyntexParseMarker::METAVAR: std::cout << "$" << (marker.m_metavar ? 
marker.m_metavar->m_name : ""); break; - case ParseMarker::TERMINAL: + case SyntexParseMarker::TERMINAL: std::cout << "`" << marker.m_spelling << "`"; break; - case ParseMarker::NONTERMINAL: + case SyntexParseMarker::NONTERMINAL: for (auto &[kind, next, glob] : marker.m_children) { if (glob == Glob::YES) { std::cout << "(" << kind << ", " << next << ", ..." << ") "; @@ -760,10 +762,10 @@ void ParsedQueryImpl::DebugParseTable(std::ostream &os) { } } -ParsedQuery::ParsedQuery(std::shared_ptr ep, std::string_view query) - : impl(std::make_shared(ep, query)) {} +SyntexQuery::SyntexQuery(std::shared_ptr ep, std::string_view query) + : impl(std::make_shared(ep, query)) {} -bool ParsedQuery::IsValid(void) const { +bool SyntexQuery::IsValid(void) const { for (auto &[key, markers] : impl->ParsesAtIndex(0)) { if (key.second == impl->m_input.size()) { return true; @@ -772,9 +774,9 @@ bool ParsedQuery::IsValid(void) const { return false; } -bool ParsedQuery::AddMetavarPredicate( +bool SyntexQuery::AddMetavarPredicate( const std::string_view &name, - std::function predicate) { + std::function predicate) { // Find metavariable name auto it = impl->m_metavars.find(name); @@ -786,7 +788,7 @@ bool ParsedQuery::AddMetavarPredicate( if (it->second.m_predicate) { it->second.m_predicate = [old_pred = std::move(it->second.m_predicate.value()), - new_pred = std::move(predicate)] (const MetavarMatch &mvm) -> bool { + new_pred = std::move(predicate)] (const SyntexMetavarMatch &mvm) -> bool { return old_pred(mvm) && new_pred(mvm); }; @@ -797,9 +799,9 @@ bool ParsedQuery::AddMetavarPredicate( return true; } -void ParsedQuery::ForEachMatch(std::function pred) const { +void SyntexQuery::ForEachMatch(std::function pred) const { bool done = false; - auto real_pred = [sub_pred = std::move(pred), &done] (Match m) -> bool { + auto real_pred = [sub_pred = std::move(pred), &done] (SyntexMatch m) -> bool { if (sub_pred(std::move(m))) { return true; } else { @@ -815,19 +817,19 @@ void 
ParsedQuery::ForEachMatch(std::function pred) const { } } -std::vector ParsedQuery::Find(mx::RawEntityId frag) const { - std::vector ret; - ForEachMatch(frag, [&ret] (Match m) -> bool { +std::vector SyntexQuery::Find(mx::RawEntityId frag) const { + std::vector ret; + ForEachMatch(frag, [&ret] (SyntexMatch m) -> bool { ret.emplace_back(std::move(m)); return true; }); return ret; } -std::vector ParsedQuery::Find(void) const { - std::vector ret; +std::vector SyntexQuery::Find(void) const { + std::vector ret; for (auto frag_id : impl->m_ep->GetFragmentsInAST()) { - ForEachMatch(frag_id, [&ret] (Match m) -> bool { + ForEachMatch(frag_id, [&ret] (SyntexMatch m) -> bool { ret.emplace_back(std::move(m)); return true; }); @@ -835,8 +837,8 @@ std::vector ParsedQuery::Find(void) const { return ret; } -void ParsedQuery::ForEachMatch(mx::RawEntityId frag_id, - std::function pred) const { +void SyntexQuery::ForEachMatch(mx::RawEntityId frag_id, + std::function pred) const { auto frag = impl->m_ep->FragmentFor(impl->m_ep, frag_id); // Find matching AST node @@ -845,13 +847,13 @@ void ParsedQuery::ForEachMatch(mx::RawEntityId frag_id, if (key.second != impl->m_input.size()) { continue; } - if (key.first == NodeKind::Any()) { + if (key.first == SyntexNodeKind::Any()) { for (auto ast_node_id : impl->m_ep->GetASTNodesInFragment(frag_id)) { auto ast_node = impl->m_ep->GetASTNode(ast_node_id); for (auto &marker : markers) { auto [ok, metavar_matches] = impl->MatchMarker( entry, marker, ast_node_id); - if (ok && !pred(Match(ast_node.entity, metavar_matches))) { + if (ok && !pred(SyntexMatch(ast_node.entity, metavar_matches))) { return; } } @@ -862,7 +864,7 @@ void ParsedQuery::ForEachMatch(mx::RawEntityId frag_id, auto ast_node = impl->m_ep->GetASTNode(*ast_node_id); for (auto &marker : markers) { auto [ok, metavar_matches] = impl->MatchMarker(entry, marker, *ast_node_id); - if (ok && !pred(Match(ast_node.entity, metavar_matches))) { + if (ok && !pred(SyntexMatch(ast_node.entity, 
metavar_matches))) { return; } } @@ -872,5 +874,4 @@ void ParsedQuery::ForEachMatch(mx::RawEntityId frag_id, } } -} // namespace syntex } // namespace mx diff --git a/lib/API/Query.h b/lib/API/Query.h index 2bf4d4672..50cb1d7d5 100644 --- a/lib/API/Query.h +++ b/lib/API/Query.h @@ -8,6 +8,7 @@ #include #include "Grammar.h" +#include #include #include #include @@ -21,8 +22,8 @@ inline void hash_combine(size_t &h, const T& v) } template<> -struct std::hash> { - size_t operator()(const std::pair &self) const { +struct std::hash> { + size_t operator()(const std::pair &self) const { size_t hash = 0; hash_combine(hash, self.first); hash_combine(hash, self.second); @@ -31,38 +32,13 @@ struct std::hash> { }; namespace mx { -namespace syntex { -// -// Result of parsing a query -// - -class ParsedQuery { - private: - std::shared_ptr impl; - ParsedQuery(void) = delete; - - public: - explicit ParsedQuery(std::shared_ptr ep, std::string_view query); - - bool IsValid() const; - - bool AddMetavarPredicate(const std::string_view &name, - std::function predicate); - - void ForEachMatch(mx::RawEntityId frag_id, - std::function pred) const; - void ForEachMatch(std::function pred) const; - - std::vector Find(mx::RawEntityId frag_id) const; - std::vector Find(void) const; -}; struct Metavar { std::string_view m_name; - std::optional> m_predicate; + std::optional> m_predicate; explicit Metavar(std::string_view name, - std::optional> predicate) + std::optional> predicate) : m_name(name), m_predicate(std::move(predicate)) {} }; @@ -71,7 +47,7 @@ enum class Glob { YES }; -struct ParseMarker { +struct SyntexParseMarker { // Node category enum { @@ -84,19 +60,19 @@ struct ParseMarker { union { Metavar *m_metavar; std::string_view m_spelling; - std::vector> m_children; + std::vector> m_children; }; - explicit ParseMarker(Metavar *metavar) + explicit SyntexParseMarker(Metavar *metavar) : m_kind(METAVAR), m_metavar(metavar) {} - explicit ParseMarker(std::string_view spelling) + explicit 
SyntexParseMarker(std::string_view spelling) : m_kind(TERMINAL), m_spelling(spelling) {} - explicit ParseMarker(const std::vector> &children) + explicit SyntexParseMarker(const std::vector> &children) : m_kind(NONTERMINAL), m_children(children) {} - ParseMarker(ParseMarker &&other) + SyntexParseMarker(SyntexParseMarker &&other) : m_kind(other.m_kind) { switch (m_kind) { @@ -107,12 +83,12 @@ struct ParseMarker { new (&m_spelling) std::string_view(other.m_spelling); break; case NONTERMINAL: - new (&m_children) std::vector>(std::move(other.m_children)); + new (&m_children) std::vector>(std::move(other.m_children)); break; } } - ~ParseMarker() { + ~SyntexParseMarker() { switch (m_kind) { case METAVAR: break; @@ -120,12 +96,12 @@ struct ParseMarker { m_spelling.std::string_view::~string_view(); break; case NONTERMINAL: - m_children.std::vector>::~vector(); + m_children.std::vector>::~vector(); break; } } - bool operator==(const ParseMarker &other) const { + bool operator==(const SyntexParseMarker &other) const { if (m_kind != other.m_kind) { return false; } @@ -144,21 +120,20 @@ struct ParseMarker { } }; -} // namespace syntex } // namespace mx template<> -struct std::hash { - size_t operator()(const mx::syntex::ParseMarker &self) const { +struct std::hash { + size_t operator()(const mx::SyntexParseMarker &self) const { size_t hash = 0; hash_combine(hash, self.m_kind); switch (self.m_kind) { - case mx::syntex::ParseMarker::METAVAR: + case mx::SyntexParseMarker::METAVAR: break; - case mx::syntex::ParseMarker::TERMINAL: + case mx::SyntexParseMarker::TERMINAL: hash_combine(hash, self.m_spelling); break; - case mx::syntex::ParseMarker::NONTERMINAL: + case mx::SyntexParseMarker::NONTERMINAL: for (auto &[kind, next, glob] : self.m_children) { hash_combine(hash, kind); hash_combine(hash, next); @@ -171,24 +146,23 @@ struct std::hash { }; namespace mx { -namespace syntex { // // Parser state (e.g. 
a pointer into the grammar trie) // struct Item { - const GrammarLeaves *m_leaves; - std::vector> m_children; + const SyntexGrammarLeaves *m_leaves; + std::vector> m_children; - explicit Item(const GrammarLeaves *leaves) + explicit Item(const SyntexGrammarLeaves *leaves) : m_leaves(leaves) {} template - void IterateShifts(NodeKind kind, size_t next, Glob glob, F cb) { - if (kind == NodeKind::Any()) { - const GrammarLeaves *old_leaves = m_leaves; + void IterateShifts(SyntexNodeKind kind, size_t next, Glob glob, F cb) { + if (kind == SyntexNodeKind::Any()) { + const SyntexGrammarLeaves *old_leaves = m_leaves; m_children.emplace_back(kind, next, glob); for (auto &[kind, rest] : *m_leaves) { @@ -209,7 +183,7 @@ struct Item { } // Morph ourselves into the shifted state - const GrammarLeaves *old_leaves = m_leaves; + const SyntexGrammarLeaves *old_leaves = m_leaves; m_leaves = &it->second.leaves; m_children.emplace_back(kind, next, glob); @@ -236,17 +210,17 @@ struct Item { // Wrapper around parsing functions // -struct ParsedQueryImpl { +struct SyntexQueryImpl { std::shared_ptr m_ep; // Input string std::string_view m_input; - GrammarLeaves grammar_root; + SyntexGrammarLeaves grammar_root; // Main DP parse table - using TableEntry = std::unordered_map, - std::unordered_set>; + using TableEntry = std::unordered_map, + std::unordered_set>; std::unordered_map m_parses; @@ -256,22 +230,21 @@ struct ParsedQueryImpl { // Globs std::unordered_map m_globs; - void MatchGlob(TableEntry &result, const std::unordered_set &follow, + void MatchGlob(TableEntry &result, const std::unordered_set &follow, Item &item, size_t next); void MatchRule(TableEntry &result, Item &item, size_t next); - void MatchPrefix(TableEntry &result, NodeKind kind, size_t next); + void MatchPrefix(TableEntry &result, SyntexNodeKind kind, size_t next); const TableEntry &ParsesAtIndex(size_t index); - explicit ParsedQueryImpl(std::shared_ptr ep, std::string_view input); + explicit 
SyntexQueryImpl(std::shared_ptr ep, std::string_view input); void DebugParseTable(std::ostream &os); - std::pair> MatchMarker( - const TableEntry &entry, const ParseMarker &marker, std::uint64_t node_id); + std::pair> MatchMarker( + const TableEntry &entry, const SyntexParseMarker &marker, std::uint64_t node_id); }; -} // namespace syntex } // namespace mx \ No newline at end of file diff --git a/lib/API/SQLiteEntityProvider.cpp b/lib/API/SQLiteEntityProvider.cpp index 661d8c9c6..4983c3652 100644 --- a/lib/API/SQLiteEntityProvider.cpp +++ b/lib/API/SQLiteEntityProvider.cpp @@ -342,14 +342,14 @@ SQLiteEntityProvider::TokenKindOf(std::string_view spelling) { return storage.spelling_to_token_kind.TryGet(spelling); } -void SQLiteEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &root) { +void SQLiteEntityProvider::LoadGrammarRoot(SyntexGrammarLeaves &root) { auto &storage = d->GetStorage(); auto &grammar = storage.grammar; - std::vector> to_load; + std::vector> to_load; for(auto [kind, id] : grammar.GetChildren(0)) { auto data = grammar.GetNode(id); - auto &node = root[syntex::NodeKind::Deserialize(kind)]; + auto &node = root[SyntexNodeKind::Deserialize(kind)]; node.is_production = data.is_production; to_load.emplace_back(id, &node); } @@ -362,7 +362,7 @@ void SQLiteEntityProvider::LoadGrammarRoot(syntex::GrammarLeaves &root) { for(auto [kind, child_id] : grammar.GetChildren(id)) { auto data = grammar.GetNode(child_id); - auto &child_node = node.leaves[syntex::NodeKind::Deserialize(kind)]; + auto &child_node = node.leaves[SyntexNodeKind::Deserialize(kind)]; child_node.is_production = data.is_production; to_load.emplace_back(child_id, &child_node); } diff --git a/lib/API/SQLiteEntityProvider.h b/lib/API/SQLiteEntityProvider.h index 38b9d56b3..3ae1e5346 100644 --- a/lib/API/SQLiteEntityProvider.h +++ b/lib/API/SQLiteEntityProvider.h @@ -68,7 +68,7 @@ class SQLiteEntityProvider final : public EntityProvider { std::optional TokenKindOf(std::string_view spelling) final; 
- void LoadGrammarRoot(syntex::GrammarLeaves &root) final; + void LoadGrammarRoot(SyntexGrammarLeaves &root) final; std::vector GetFragmentsInAST(void) final; ASTNode GetASTNode(std::uint64_t id) final;