//===- LegalityTest.cpp ---------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/Instruction.h" #include "llvm/Support/SourceMgr.h" #include "gtest/gtest.h" using namespace llvm; struct LegalityTest : public testing::Test { LLVMContext C; std::unique_ptr M; std::unique_ptr DT; std::unique_ptr TLII; std::unique_ptr TLI; std::unique_ptr AC; std::unique_ptr LI; std::unique_ptr SE; ScalarEvolution &getSE(llvm::Function &LLVMF) { DT = std::make_unique(LLVMF); TLII = std::make_unique(); TLI = std::make_unique(*TLII); AC = std::make_unique(LLVMF); LI = std::make_unique(*DT); SE = std::make_unique(LLVMF, *TLI, *AC, *DT, *LI); return *SE; } void parseIR(LLVMContext &C, const char *IR) { SMDiagnostic Err; M = parseAssemblyString(IR, Err, C); if (!M) Err.print("LegalityTest", errs()); } }; TEST_F(LegalityTest, Legality) { parseIR(C, R"IR( define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { %gep0 = getelementptr float, ptr %ptr, i32 0 %gep1 = getelementptr float, ptr %ptr, i32 1 %gep3 = getelementptr float, ptr %ptr, i32 3 %ld0 = load float, ptr %gep0 %ld0b = load float, ptr %gep0 %ld1 = load float, ptr %gep1 %ld3 = load float, ptr %gep3 store float %ld0, ptr %gep0 store float %ld1, ptr %gep1 store <2 x float> %vec2, ptr %gep1 store <3 x float> %vec3, ptr %gep3 store i8 %arg, ptr %gep1 %fadd0 = fadd float %farg0, %farg0 %fadd1 = fadd fast float %farg1, %farg1 %trunc0 = trunc nuw nsw i64 %v0 to i8 %trunc1 = trunc nsw i64 %v1 to i8 %trunc64to8 = trunc i64 %v0 to i8 %trunc32to8 = trunc i32 %v2 to i8 %cmpSLT = icmp slt i64 %v0, %v1 %cmpSGT = icmp sgt i64 %v0, %v1 ret void } )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); auto &SE = getSE(*LLVMF); const auto &DL = M->getDataLayout(); sandboxir::Context Ctx(C); auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); auto It = BB->begin(); [[maybe_unused]] auto *Gep0 = cast(&*It++); [[maybe_unused]] auto *Gep1 = cast(&*It++); [[maybe_unused]] auto *Gep3 = cast(&*It++); auto *Ld0 = cast(&*It++); auto *Ld0b = cast(&*It++); auto *Ld1 = cast(&*It++); auto *Ld3 = cast(&*It++); auto *St0 = cast(&*It++); auto *St1 = cast(&*It++); auto *StVec2 = cast(&*It++); auto *StVec3 = cast(&*It++); auto *StI8 = cast(&*It++); auto *FAdd0 = cast(&*It++); auto *FAdd1 = cast(&*It++); auto *Trunc0 = cast(&*It++); auto *Trunc1 = cast(&*It++); auto *Trunc64to8 = cast(&*It++); auto *Trunc32to8 = cast(&*It++); auto *CmpSLT = cast(&*It++); auto *CmpSGT = cast(&*It++); sandboxir::LegalityAnalysis Legality(SE, DL); const auto &Result = Legality.canVectorize({St0, St1}); EXPECT_TRUE(isa(Result)); { // Check NotInstructions auto &Result = Legality.canVectorize({F, St0}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::NotInstructions); } { // Check DiffOpcodes const auto &Result = Legality.canVectorize({St0, Ld0}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffOpcodes); } { // Check DiffTypes EXPECT_TRUE(isa(Legality.canVectorize({St0, StVec2}))); EXPECT_TRUE(isa(Legality.canVectorize({StVec2, StVec3}))); const auto &Result = Legality.canVectorize({St0, StI8}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffTypes); } { // Check DiffMathFlags const auto &Result = Legality.canVectorize({FAdd0, FAdd1}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffMathFlags); } { // Check DiffWrapFlags const auto &Result = Legality.canVectorize({Trunc0, Trunc1}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffWrapFlags); } { // Check DiffTypes for unary operands that have a different type. const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffTypes); } { // Check DiffOpcodes for CMPs with different predicates. const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::DiffOpcodes); } { // Check NotConsecutive Ld0,Ld0b const auto &Result = Legality.canVectorize({Ld0, Ld0b}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::NotConsecutive); } { // Check NotConsecutive Ld0,Ld3 const auto &Result = Legality.canVectorize({Ld0, Ld3}); EXPECT_TRUE(isa(Result)); EXPECT_EQ(cast(Result).getReason(), sandboxir::ResultReason::NotConsecutive); } { // Check Widen Ld0,Ld1 const auto &Result = Legality.canVectorize({Ld0, Ld1}); EXPECT_TRUE(isa(Result)); } } #ifndef NDEBUG TEST_F(LegalityTest, LegalityResultDump) { parseIR(C, R"IR( define void @foo() { ret void } )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); auto &SE = getSE(*LLVMF); const auto &DL = M->getDataLayout(); auto Matches = [](const sandboxir::LegalityResult &Result, const std::string &ExpectedStr) -> bool { std::string Buff; raw_string_ostream OS(Buff); Result.print(OS); return Buff == ExpectedStr; }; sandboxir::LegalityAnalysis Legality(SE, DL); EXPECT_TRUE( Matches(Legality.createLegalityResult(), "Widen")); EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::NotInstructions), "Pack Reason: NotInstructions")); EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::DiffOpcodes), "Pack Reason: DiffOpcodes")); EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::DiffTypes), "Pack Reason: DiffTypes")); EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::DiffMathFlags), "Pack Reason: DiffMathFlags")); EXPECT_TRUE(Matches(Legality.createLegalityResult( sandboxir::ResultReason::DiffWrapFlags), "Pack Reason: DiffWrapFlags")); } #endif // NDEBUG