[SandboxVec][Legality] Fix mask on diamond reuse with shuffle (#126963)
This patch fixes a bug in the creation of shuffle masks when vectorizing vectors in case of a diamond reuse with shuffle. The mask needs to enumerate all elements of a vector, not treat the original vector value as a single element. That is: if vectorizing two <2 x float> vectors into a <4 x float> the mask needs to have 4 indices, not just 2.
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
#include "llvm/SandboxIR/Value.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
|
||||
#include <algorithm>
|
||||
|
||||
namespace llvm::sandboxir {
|
||||
@@ -85,11 +86,13 @@ public:
|
||||
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
|
||||
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
|
||||
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
|
||||
for (auto [Lane, Orig] : enumerate(Origs)) {
|
||||
unsigned Lane = 0;
|
||||
for (Value *Orig : Origs) {
|
||||
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
|
||||
assert(Pair.second && "Orig already exists in the map!");
|
||||
(void)Pair;
|
||||
OrigToLaneMap[Orig] = Lane;
|
||||
Lane += VecUtils::getNumLanes(Orig);
|
||||
}
|
||||
}
|
||||
void clear() {
|
||||
|
||||
@@ -202,14 +202,20 @@ CollectDescr
|
||||
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
|
||||
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
|
||||
Vec.reserve(Bndl.size());
|
||||
for (auto [Lane, V] : enumerate(Bndl)) {
|
||||
uint32_t LaneAccum;
|
||||
for (auto [Elm, V] : enumerate(Bndl)) {
|
||||
uint32_t VLanes = VecUtils::getNumLanes(V);
|
||||
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
|
||||
// If there is a vector containing `V`, then get the lane it came from.
|
||||
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
|
||||
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
|
||||
// This could be a vector, like <2 x float> in which case the mask needs
|
||||
// to enumerate all lanes.
|
||||
for (int Ln = 0; Ln != VLanes; ++Ln)
|
||||
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
|
||||
} else {
|
||||
Vec.emplace_back(V);
|
||||
}
|
||||
LaneAccum += VLanes;
|
||||
}
|
||||
return CollectDescr(std::move(Vec));
|
||||
}
|
||||
|
||||
@@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
|
||||
const ShuffleMask &Mask =
|
||||
cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
|
||||
NewVec = createShuffle(VecOp, Mask, UserBB);
|
||||
assert(NewVec->getType() == VecOp->getType() &&
|
||||
"Expected same type! Bad mask ?");
|
||||
break;
|
||||
}
|
||||
case LegalityResultID::DiamondReuseMultiInput: {
|
||||
|
||||
@@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
|
||||
ret void
|
||||
}
|
||||
|
||||
; Same but with <2 x float> elements instead of scalars.
|
||||
define void @diamondWithShuffleFromVec(ptr %ptr) {
|
||||
; CHECK-LABEL: define void @diamondWithShuffleFromVec(
|
||||
; CHECK-SAME: ptr [[PTR:%.*]]) {
|
||||
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
|
||||
; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
|
||||
; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
|
||||
; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
|
||||
; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
|
||||
%ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
|
||||
%ld0 = load <2 x float>, ptr %ptr0
|
||||
%ld1 = load <2 x float>, ptr %ptr1
|
||||
%sub0 = fsub <2 x float> %ld0, %ld1
|
||||
%sub1 = fsub <2 x float> %ld1, %ld0
|
||||
store <2 x float> %sub0, ptr %ptr0
|
||||
store <2 x float> %sub1, ptr %ptr1
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
|
||||
; CHECK-LABEL: define void @diamondMultiInput(
|
||||
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
|
||||
|
||||
@@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
|
||||
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
|
||||
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(InstrMapsTest, VectorLanes) {
|
||||
parseIR(C, R"IR(
|
||||
define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
|
||||
%vadd0 = add <2 x i8> %v0, %v1
|
||||
%vadd1 = add <2 x i8> %v0, %v1
|
||||
%vadd2 = add <4 x i8> %v2, %v3
|
||||
ret void
|
||||
}
|
||||
)IR");
|
||||
llvm::Function *LLVMF = &*M->getFunction("foo");
|
||||
sandboxir::Context Ctx(C);
|
||||
auto *F = Ctx.createFunction(LLVMF);
|
||||
auto *BB = &*F->begin();
|
||||
auto It = BB->begin();
|
||||
|
||||
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
|
||||
auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
|
||||
auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);
|
||||
|
||||
sandboxir::InstrMaps IMaps(Ctx);
|
||||
|
||||
// Check that the vector lanes are calculated correctly.
|
||||
IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
|
||||
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
|
||||
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user