[clang] Refactor AST printing tests to share more infrastructure

Differential Revision: https://reviews.llvm.org/D105457
This commit is contained in:
Nathan Ridge 2021-07-06 01:40:24 -04:00
parent c41e67f3f1
commit 20176bc7dd
4 changed files with 100 additions and 137 deletions

View File

@ -19,72 +19,88 @@
namespace clang { namespace clang {
using PolicyAdjusterType = using PrintingPolicyAdjuster = llvm::function_ref<void(PrintingPolicy &Policy)>;
Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
static void PrintStmt(raw_ostream &Out, const ASTContext *Context, template <typename NodeType>
const Stmt *S, PolicyAdjusterType PolicyAdjuster) { using NodePrinter =
assert(S != nullptr && "Expected non-null Stmt"); std::function<void(llvm::raw_ostream &Out, const ASTContext *Context,
PrintingPolicy Policy = Context->getPrintingPolicy(); const NodeType *Node,
if (PolicyAdjuster) PrintingPolicyAdjuster PolicyAdjuster)>;
(*PolicyAdjuster)(Policy);
S->printPretty(Out, /*Helper*/ nullptr, Policy);
}
template <typename NodeType>
using NodeFilter = std::function<bool(const NodeType *Node)>;
template <typename NodeType>
class PrintMatch : public ast_matchers::MatchFinder::MatchCallback { class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
using PrinterT = NodePrinter<NodeType>;
using FilterT = NodeFilter<NodeType>;
SmallString<1024> Printed; SmallString<1024> Printed;
unsigned NumFoundStmts; unsigned NumFoundNodes;
PolicyAdjusterType PolicyAdjuster; PrinterT Printer;
FilterT Filter;
PrintingPolicyAdjuster PolicyAdjuster;
public: public:
PrintMatch(PolicyAdjusterType PolicyAdjuster) PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster,
: NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {} FilterT Filter)
: NumFoundNodes(0), Printer(std::move(Printer)),
Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {}
void run(const ast_matchers::MatchFinder::MatchResult &Result) override { void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id"); const NodeType *N = Result.Nodes.getNodeAs<NodeType>("id");
if (!S) if (!N || !Filter(N))
return; return;
NumFoundStmts++; NumFoundNodes++;
if (NumFoundStmts > 1) if (NumFoundNodes > 1)
return; return;
llvm::raw_svector_ostream Out(Printed); llvm::raw_svector_ostream Out(Printed);
PrintStmt(Out, Result.Context, S, PolicyAdjuster); Printer(Out, Result.Context, N, PolicyAdjuster);
} }
StringRef getPrinted() const { return Printed; } StringRef getPrinted() const { return Printed; }
unsigned getNumFoundStmts() const { return NumFoundStmts; } unsigned getNumFoundNodes() const { return NumFoundNodes; }
}; };
template <typename T> template <typename NodeType, typename Matcher>
::testing::AssertionResult ::testing::AssertionResult PrintedNodeMatches(
PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args, StringRef Code, const std::vector<std::string> &Args,
const T &NodeMatch, StringRef ExpectedPrinted, const Matcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName,
PolicyAdjusterType PolicyAdjuster = None) { NodePrinter<NodeType> Printer,
PrintingPolicyAdjuster PolicyAdjuster = nullptr, bool AllowError = false,
NodeFilter<NodeType> Filter = [](const NodeType *) { return true; }) {
PrintMatch Printer(PolicyAdjuster); PrintMatch<NodeType> Callback(Printer, PolicyAdjuster, Filter);
ast_matchers::MatchFinder Finder; ast_matchers::MatchFinder Finder;
Finder.addMatcher(NodeMatch, &Printer); Finder.addMatcher(NodeMatch, &Callback);
std::unique_ptr<tooling::FrontendActionFactory> Factory( std::unique_ptr<tooling::FrontendActionFactory> Factory(
tooling::newFrontendActionFactory(&Finder)); tooling::newFrontendActionFactory(&Finder));
if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args)) bool ToolResult;
if (FileName.empty()) {
ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args);
} else {
ToolResult =
tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName);
}
if (!ToolResult && !AllowError)
return testing::AssertionFailure() return testing::AssertionFailure()
<< "Parsing error in \"" << Code.str() << "\""; << "Parsing error in \"" << Code.str() << "\"";
if (Printer.getNumFoundStmts() == 0) if (Callback.getNumFoundNodes() == 0)
return testing::AssertionFailure() << "Matcher didn't find any statements"; return testing::AssertionFailure() << "Matcher didn't find any nodes";
if (Printer.getNumFoundStmts() > 1) if (Callback.getNumFoundNodes() > 1)
return testing::AssertionFailure() return testing::AssertionFailure()
<< "Matcher should match only one statement (found " << "Matcher should match only one node (found "
<< Printer.getNumFoundStmts() << ")"; << Callback.getNumFoundNodes() << ")";
if (Printer.getPrinted() != ExpectedPrinted) if (Callback.getPrinted() != ExpectedPrinted)
return ::testing::AssertionFailure() return ::testing::AssertionFailure()
<< "Expected \"" << ExpectedPrinted.str() << "\", got \"" << "Expected \"" << ExpectedPrinted.str() << "\", got \""
<< Printer.getPrinted().str() << "\""; << Callback.getPrinted().str() << "\"";
return ::testing::AssertionSuccess(); return ::testing::AssertionSuccess();
} }

View File

@ -18,6 +18,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "ASTPrint.h"
#include "clang/AST/ASTContext.h" #include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h" #include "clang/ASTMatchers/ASTMatchers.h"
@ -32,10 +33,8 @@ using namespace tooling;
namespace { namespace {
using PrintingPolicyModifier = void (*)(PrintingPolicy &policy);
void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D, void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D,
PrintingPolicyModifier PolicyModifier) { PrintingPolicyAdjuster PolicyModifier) {
PrintingPolicy Policy = Context->getPrintingPolicy(); PrintingPolicy Policy = Context->getPrintingPolicy();
Policy.TerseOutput = true; Policy.TerseOutput = true;
Policy.Indentation = 0; Policy.Indentation = 0;
@ -44,74 +43,23 @@ void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D,
D->print(Out, Policy, /*Indentation*/ 0, /*PrintInstantiation*/ false); D->print(Out, Policy, /*Indentation*/ 0, /*PrintInstantiation*/ false);
} }
class PrintMatch : public MatchFinder::MatchCallback {
SmallString<1024> Printed;
unsigned NumFoundDecls;
PrintingPolicyModifier PolicyModifier;
public:
PrintMatch(PrintingPolicyModifier PolicyModifier)
: NumFoundDecls(0), PolicyModifier(PolicyModifier) {}
void run(const MatchFinder::MatchResult &Result) override {
const Decl *D = Result.Nodes.getNodeAs<Decl>("id");
if (!D || D->isImplicit())
return;
NumFoundDecls++;
if (NumFoundDecls > 1)
return;
llvm::raw_svector_ostream Out(Printed);
PrintDecl(Out, Result.Context, D, PolicyModifier);
}
StringRef getPrinted() const {
return Printed;
}
unsigned getNumFoundDecls() const {
return NumFoundDecls;
}
};
::testing::AssertionResult ::testing::AssertionResult
PrintedDeclMatches(StringRef Code, const std::vector<std::string> &Args, PrintedDeclMatches(StringRef Code, const std::vector<std::string> &Args,
const DeclarationMatcher &NodeMatch, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted, StringRef FileName, StringRef ExpectedPrinted, StringRef FileName,
PrintingPolicyModifier PolicyModifier = nullptr, PrintingPolicyAdjuster PolicyModifier = nullptr,
bool AllowError = false) { bool AllowError = false) {
PrintMatch Printer(PolicyModifier); return PrintedNodeMatches<Decl>(
MatchFinder Finder; Code, Args, NodeMatch, ExpectedPrinted, FileName, PrintDecl,
Finder.addMatcher(NodeMatch, &Printer); PolicyModifier, AllowError,
std::unique_ptr<FrontendActionFactory> Factory( // Filter out implicit decls
newFrontendActionFactory(&Finder)); [](const Decl *D) { return !D->isImplicit(); });
if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName) &&
!AllowError)
return testing::AssertionFailure()
<< "Parsing error in \"" << Code.str() << "\"";
if (Printer.getNumFoundDecls() == 0)
return testing::AssertionFailure()
<< "Matcher didn't find any declarations";
if (Printer.getNumFoundDecls() > 1)
return testing::AssertionFailure()
<< "Matcher should match only one declaration "
"(found " << Printer.getNumFoundDecls() << ")";
if (Printer.getPrinted() != ExpectedPrinted)
return ::testing::AssertionFailure()
<< "Expected \"" << ExpectedPrinted.str() << "\", "
"got \"" << Printer.getPrinted().str() << "\"";
return ::testing::AssertionSuccess();
} }
::testing::AssertionResult ::testing::AssertionResult
PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName, PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PrintingPolicyModifier PolicyModifier = nullptr) { PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args(1, "-std=c++98"); std::vector<std::string> Args(1, "-std=c++98");
return PrintedDeclMatches(Code, Args, namedDecl(hasName(DeclName)).bind("id"), return PrintedDeclMatches(Code, Args, namedDecl(hasName(DeclName)).bind("id"),
ExpectedPrinted, "input.cc", PolicyModifier); ExpectedPrinted, "input.cc", PolicyModifier);
@ -120,7 +68,7 @@ PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName,
::testing::AssertionResult ::testing::AssertionResult
PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch, PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PrintingPolicyModifier PolicyModifier = nullptr) { PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args(1, "-std=c++98"); std::vector<std::string> Args(1, "-std=c++98");
return PrintedDeclMatches(Code, return PrintedDeclMatches(Code,
Args, Args,
@ -165,7 +113,7 @@ PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
::testing::AssertionResult ::testing::AssertionResult
PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch, PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PrintingPolicyModifier PolicyModifier = nullptr) { PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args{"-std=c++17", "-fno-delayed-template-parsing"}; std::vector<std::string> Args{"-std=c++17", "-fno-delayed-template-parsing"};
return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.cc", return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.cc",
PolicyModifier); PolicyModifier);
@ -174,7 +122,7 @@ PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
::testing::AssertionResult ::testing::AssertionResult
PrintedDeclC11Matches(StringRef Code, const DeclarationMatcher &NodeMatch, PrintedDeclC11Matches(StringRef Code, const DeclarationMatcher &NodeMatch,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PrintingPolicyModifier PolicyModifier = nullptr) { PrintingPolicyAdjuster PolicyModifier = nullptr) {
std::vector<std::string> Args(1, "-std=c11"); std::vector<std::string> Args(1, "-std=c11");
return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.c", return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.c",
PolicyModifier); PolicyModifier);

View File

@ -15,6 +15,7 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "ASTPrint.h"
#include "clang/AST/ASTContext.h" #include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h" #include "clang/AST/Decl.h"
#include "clang/AST/PrettyPrinter.h" #include "clang/AST/PrettyPrinter.h"
@ -66,31 +67,11 @@ public:
const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted,
StringRef FileName, StringRef FileName,
std::function<void(llvm::raw_ostream &, const NamedDecl *)> Print) { std::function<void(llvm::raw_ostream &, const NamedDecl *)> Print) {
PrintMatch Printer(std::move(Print)); return PrintedNodeMatches<NamedDecl>(
MatchFinder Finder; Code, Args, NodeMatch, ExpectedPrinted, FileName,
Finder.addMatcher(NodeMatch, &Printer); [Print](llvm::raw_ostream &Out, const ASTContext *Context,
std::unique_ptr<FrontendActionFactory> Factory = const NamedDecl *ND,
newFrontendActionFactory(&Finder); PrintingPolicyAdjuster PolicyAdjuster) { Print(Out, ND); });
if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
return testing::AssertionFailure()
<< "Parsing error in \"" << Code.str() << "\"";
if (Printer.getNumFoundDecls() == 0)
return testing::AssertionFailure()
<< "Matcher didn't find any named declarations";
if (Printer.getNumFoundDecls() > 1)
return testing::AssertionFailure()
<< "Matcher should match only one named declaration "
"(found " << Printer.getNumFoundDecls() << ")";
if (Printer.getPrinted() != ExpectedPrinted)
return ::testing::AssertionFailure()
<< "Expected \"" << ExpectedPrinted.str() << "\", "
"got \"" << Printer.getPrinted().str() << "\"";
return ::testing::AssertionSuccess();
} }
::testing::AssertionResult ::testing::AssertionResult

View File

@ -38,11 +38,29 @@ DeclarationMatcher FunctionBodyMatcher(StringRef ContainingFunction) {
has(compoundStmt(has(stmt().bind("id"))))); has(compoundStmt(has(stmt().bind("id")))));
} }
static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
const Stmt *S, PrintingPolicyAdjuster PolicyAdjuster) {
assert(S != nullptr && "Expected non-null Stmt");
PrintingPolicy Policy = Context->getPrintingPolicy();
if (PolicyAdjuster)
PolicyAdjuster(Policy);
S->printPretty(Out, /*Helper*/ nullptr, Policy);
}
template <typename Matcher>
::testing::AssertionResult
PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
const Matcher &NodeMatch, StringRef ExpectedPrinted,
PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
return PrintedNodeMatches<Stmt>(Code, Args, NodeMatch, ExpectedPrinted, "",
PrintStmt, PolicyAdjuster);
}
template <typename T> template <typename T>
::testing::AssertionResult ::testing::AssertionResult
PrintedStmtCXXMatches(StdVer Standard, StringRef Code, const T &NodeMatch, PrintedStmtCXXMatches(StdVer Standard, StringRef Code, const T &NodeMatch,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PolicyAdjusterType PolicyAdjuster = None) { PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
const char *StdOpt; const char *StdOpt;
switch (Standard) { switch (Standard) {
case StdVer::CXX98: StdOpt = "-std=c++98"; break; case StdVer::CXX98: StdOpt = "-std=c++98"; break;
@ -64,7 +82,7 @@ template <typename T>
::testing::AssertionResult ::testing::AssertionResult
PrintedStmtMSMatches(StringRef Code, const T &NodeMatch, PrintedStmtMSMatches(StringRef Code, const T &NodeMatch,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PolicyAdjusterType PolicyAdjuster = None) { PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
std::vector<std::string> Args = { std::vector<std::string> Args = {
"-std=c++98", "-std=c++98",
"-target", "i686-pc-win32", "-target", "i686-pc-win32",
@ -79,7 +97,7 @@ template <typename T>
::testing::AssertionResult ::testing::AssertionResult
PrintedStmtObjCMatches(StringRef Code, const T &NodeMatch, PrintedStmtObjCMatches(StringRef Code, const T &NodeMatch,
StringRef ExpectedPrinted, StringRef ExpectedPrinted,
PolicyAdjusterType PolicyAdjuster = None) { PrintingPolicyAdjuster PolicyAdjuster = nullptr) {
std::vector<std::string> Args = { std::vector<std::string> Args = {
"-ObjC", "-ObjC",
"-fobjc-runtime=macosx-10.12.0", "-fobjc-runtime=macosx-10.12.0",
@ -202,10 +220,10 @@ class A {
}; };
)"; )";
// No implicit 'this'. // No implicit 'this'.
ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11, ASSERT_TRUE(PrintedStmtCXXMatches(
CPPSource, memberExpr(anything()).bind("id"), "field", StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "field",
PolicyAdjusterType(
[](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }))); [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }));
// Print implicit 'this'. // Print implicit 'this'.
ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11, ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11,
CPPSource, memberExpr(anything()).bind("id"), "this->field")); CPPSource, memberExpr(anything()).bind("id"), "this->field"));
@ -222,11 +240,10 @@ class A {
@end @end
)"; )";
// No implicit 'self'. // No implicit 'self'.
ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"), ASSERT_TRUE(PrintedStmtObjCMatches(
"return ivar;\n", ObjCSource, returnStmt().bind("id"), "return ivar;\n",
PolicyAdjusterType([](PrintingPolicy &PP) {
PP.SuppressImplicitBase = true; [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }));
})));
// Print implicit 'self'. // Print implicit 'self'.
ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"), ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"),
"return self->ivar;\n")); "return self->ivar;\n"));
@ -243,5 +260,6 @@ TEST(StmtPrinter, TerseOutputWithLambdas) {
// body not printed when TerseOutput is on. // body not printed when TerseOutput is on.
ASSERT_TRUE(PrintedStmtCXXMatches( ASSERT_TRUE(PrintedStmtCXXMatches(
StdVer::CXX11, CPPSource, lambdaExpr(anything()).bind("id"), "[] {}", StdVer::CXX11, CPPSource, lambdaExpr(anything()).bind("id"), "[] {}",
PolicyAdjusterType([](PrintingPolicy &PP) { PP.TerseOutput = true; })));
[](PrintingPolicy &PP) { PP.TerseOutput = true; }));
} }