//===- LegalityTest.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
|
|
#include "llvm/Analysis/AssumptionCache.h"
|
|
#include "llvm/Analysis/LoopInfo.h"
|
|
#include "llvm/Analysis/ScalarEvolution.h"
|
|
#include "llvm/Analysis/TargetLibraryInfo.h"
|
|
#include "llvm/AsmParser/Parser.h"
|
|
#include "llvm/IR/DataLayout.h"
|
|
#include "llvm/IR/Dominators.h"
|
|
#include "llvm/SandboxIR/Function.h"
|
|
#include "llvm/SandboxIR/Instruction.h"
|
|
#include "llvm/Support/SourceMgr.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
using namespace llvm;
|
|
|
|
struct LegalityTest : public testing::Test {
|
|
LLVMContext C;
|
|
std::unique_ptr<Module> M;
|
|
std::unique_ptr<DominatorTree> DT;
|
|
std::unique_ptr<TargetLibraryInfoImpl> TLII;
|
|
std::unique_ptr<TargetLibraryInfo> TLI;
|
|
std::unique_ptr<AssumptionCache> AC;
|
|
std::unique_ptr<LoopInfo> LI;
|
|
std::unique_ptr<ScalarEvolution> SE;
|
|
|
|
ScalarEvolution &getSE(llvm::Function &LLVMF) {
|
|
DT = std::make_unique<DominatorTree>(LLVMF);
|
|
TLII = std::make_unique<TargetLibraryInfoImpl>();
|
|
TLI = std::make_unique<TargetLibraryInfo>(*TLII);
|
|
AC = std::make_unique<AssumptionCache>(LLVMF);
|
|
LI = std::make_unique<LoopInfo>(*DT);
|
|
SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
|
|
return *SE;
|
|
}
|
|
|
|
void parseIR(LLVMContext &C, const char *IR) {
|
|
SMDiagnostic Err;
|
|
M = parseAssemblyString(IR, Err, C);
|
|
if (!M)
|
|
Err.print("LegalityTest", errs());
|
|
}
|
|
};
|
|
|
|
/// Exercises LegalityAnalysis::canVectorize() on hand-picked bundles of
/// instructions from @foo and checks the kind of result (Widen or Pack) and,
/// for Pack results, the reported reason.
///
/// Every `cast<sandboxir::Pack>` is guarded by a fatal ASSERT_TRUE(isa<...>):
/// casting a result of the wrong kind would be an invalid cast rather than a
/// clean test failure.
TEST_F(LegalityTest, Legality) {
  parseIR(C, R"IR(
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
  %gep0 = getelementptr float, ptr %ptr, i32 0
  %gep1 = getelementptr float, ptr %ptr, i32 1
  %gep3 = getelementptr float, ptr %ptr, i32 3
  %ld0 = load float, ptr %gep0
  %ld0b = load float, ptr %gep0
  %ld1 = load float, ptr %gep1
  %ld3 = load float, ptr %gep3
  store float %ld0, ptr %gep0
  store float %ld1, ptr %gep1
  store <2 x float> %vec2, ptr %gep1
  store <3 x float> %vec3, ptr %gep3
  store i8 %arg, ptr %gep1
  %fadd0 = fadd float %farg0, %farg0
  %fadd1 = fadd fast float %farg1, %farg1
  %trunc0 = trunc nuw nsw i64 %v0 to i8
  %trunc1 = trunc nsw i64 %v1 to i8
  %trunc64to8 = trunc i64 %v0 to i8
  %trunc32to8 = trunc i32 %v2 to i8
  %cmpSLT = icmp slt i64 %v0, %v1
  %cmpSGT = icmp sgt i64 %v0, %v1
  ret void
}
)IR");
  // parseIR() only prints a diagnostic on failure; stop here instead of
  // dereferencing a null module or function below.
  ASSERT_TRUE(M != nullptr);
  llvm::Function *LLVMF = M->getFunction("foo");
  ASSERT_TRUE(LLVMF != nullptr);
  auto &SE = getSE(*LLVMF);
  const auto &DL = M->getDataLayout();

  sandboxir::Context Ctx(C);
  auto *F = Ctx.createFunction(LLVMF);
  auto *BB = &*F->begin();
  // Pull the instructions out in exactly the order they appear in the IR
  // above; the casts double as a check that the order has not drifted.
  auto It = BB->begin();
  [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
  [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
  [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++);
  auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
  auto *Ld0b = cast<sandboxir::LoadInst>(&*It++);
  auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
  auto *Ld3 = cast<sandboxir::LoadInst>(&*It++);
  auto *St0 = cast<sandboxir::StoreInst>(&*It++);
  auto *St1 = cast<sandboxir::StoreInst>(&*It++);
  auto *StVec2 = cast<sandboxir::StoreInst>(&*It++);
  auto *StVec3 = cast<sandboxir::StoreInst>(&*It++);
  auto *StI8 = cast<sandboxir::StoreInst>(&*It++);
  auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
  auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
  auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++);
  auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++);
  auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++);
  auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++);
  auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
  auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);

  sandboxir::LegalityAnalysis Legality(SE, DL);
  // Two float stores to %gep0 and %gep1 (float indices 0 and 1).
  const auto &Result = Legality.canVectorize({St0, St1});
  EXPECT_TRUE(isa<sandboxir::Widen>(Result));

  {
    // Check NotInstructions: the bundle contains the function itself.
    const auto &Result = Legality.canVectorize({F, St0});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::NotInstructions);
  }
  {
    // Check DiffOpcodes: a store bundled with a load.
    const auto &Result = Legality.canVectorize({St0, Ld0});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::DiffOpcodes);
  }
  {
    // Check DiffTypes: float scalar/vector mixes are expected to widen, while
    // a float store bundled with an i8 store is not.
    EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
    EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));

    const auto &Result = Legality.canVectorize({St0, StI8});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::DiffTypes);
  }
  {
    // Check DiffMathFlags: plain fadd vs. `fadd fast`.
    const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::DiffMathFlags);
  }
  {
    // Check DiffWrapFlags: `trunc nuw nsw` vs. `trunc nsw`.
    const auto &Result = Legality.canVectorize({Trunc0, Trunc1});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::DiffWrapFlags);
  }
  {
    // Check DiffTypes for unary operands that have a different type
    // (i64 -> i8 vs. i32 -> i8).
    const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::DiffTypes);
  }
  {
    // Check DiffOpcodes for CMPs with different predicates (slt vs. sgt).
    const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::DiffOpcodes);
  }
  {
    // Check NotConsecutive Ld0,Ld0b: both load from %gep0.
    const auto &Result = Legality.canVectorize({Ld0, Ld0b});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::NotConsecutive);
  }
  {
    // Check NotConsecutive Ld0,Ld3: float indices 0 and 3 leave a gap.
    const auto &Result = Legality.canVectorize({Ld0, Ld3});
    ASSERT_TRUE(isa<sandboxir::Pack>(Result));
    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
              sandboxir::ResultReason::NotConsecutive);
  }
  {
    // Check Widen Ld0,Ld1: float indices 0 and 1.
    const auto &Result = Legality.canVectorize({Ld0, Ld1});
    EXPECT_TRUE(isa<sandboxir::Widen>(Result));
  }
}
|
|
|
|
#ifndef NDEBUG
|
|
/// Checks the text that LegalityResult::print() produces for a Widen result
/// and for Pack results carrying several different reasons.
TEST_F(LegalityTest, LegalityResultDump) {
  // A trivial function is enough: it is only needed to construct the analyses
  // that LegalityAnalysis is built from.
  parseIR(C, R"IR(
define void @foo() {
  ret void
}
)IR");
  llvm::Function *LLVMF = &*M->getFunction("foo");
  const auto &DL = M->getDataLayout();
  auto &SE = getSE(*LLVMF);
  sandboxir::LegalityAnalysis Legality(SE, DL);

  // Returns true iff printing `Result` yields exactly `ExpectedStr`.
  auto PrintsAs = [](const sandboxir::LegalityResult &Result,
                     const std::string &ExpectedStr) -> bool {
    std::string Printed;
    raw_string_ostream OS(Printed);
    Result.print(OS);
    return Printed == ExpectedStr;
  };

  // Widen carries no reason.
  EXPECT_TRUE(
      PrintsAs(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));

  // Pack prints its reason after a "Pack Reason: " prefix.
  EXPECT_TRUE(PrintsAs(Legality.createLegalityResult<sandboxir::Pack>(
                           sandboxir::ResultReason::NotInstructions),
                       "Pack Reason: NotInstructions"));
  EXPECT_TRUE(PrintsAs(Legality.createLegalityResult<sandboxir::Pack>(
                           sandboxir::ResultReason::DiffOpcodes),
                       "Pack Reason: DiffOpcodes"));
  EXPECT_TRUE(PrintsAs(Legality.createLegalityResult<sandboxir::Pack>(
                           sandboxir::ResultReason::DiffTypes),
                       "Pack Reason: DiffTypes"));
  EXPECT_TRUE(PrintsAs(Legality.createLegalityResult<sandboxir::Pack>(
                           sandboxir::ResultReason::DiffMathFlags),
                       "Pack Reason: DiffMathFlags"));
  EXPECT_TRUE(PrintsAs(Legality.createLegalityResult<sandboxir::Pack>(
                           sandboxir::ResultReason::DiffWrapFlags),
                       "Pack Reason: DiffWrapFlags"));
}
|
|
#endif // NDEBUG
|