diff --git a/cpp2rust/cpp_rule_preprocessor.cpp b/cpp2rust/cpp_rule_preprocessor.cpp index 22329aee..a612f020 100644 --- a/cpp2rust/cpp_rule_preprocessor.cpp +++ b/cpp2rust/cpp_rule_preprocessor.cpp @@ -86,6 +86,26 @@ class Callback : public clang::ast_matchers::MatchFinder::MatchCallback { void run(const clang::ast_matchers::MatchFinder::MatchResult &R) override { assert(sema_); Mapper::PushASTContext scoped(*R.Context); + if (auto func = R.Nodes.getNodeAs("validate_func")) { + const char *err = nullptr; + if (auto body = + clang::dyn_cast_or_null(func->getBody())) { + if (body->size() != 1) { + err = "body must contain exactly one statement (a return)"; + } else if (!clang::isa(*body->body_begin())) { + err = "body must be a return statement"; + } + } else { + err = "body cannot be empty"; + } + + if (err) { + llvm::errs() << "ERROR: " << func->getQualifiedNameAsString() << ": " + << err << '\n'; + std::exit(EXIT_FAILURE); + } + return; + } if (auto var = R.Nodes.getNodeAs("tvar")) { clang::QualType type; if (auto *tdecl = var->getDescribedAliasTemplate()) { @@ -687,6 +707,12 @@ class ActionFactory : public clang::tooling::FrontendActionFactory { typeAliasDecl(matchesName("(^|::)t[0-9]+$"), isExpansionInMainFile()) .bind("tvar"), &cb_); + + finder_.addMatcher(functionDecl(isDefinition(), + matchesName("(^|::)f[0-9]+$"), + isExpansionInMainFile()) + .bind("validate_func"), + &cb_); } std::unique_ptr create() override {