[SandboxVec][Legality] Fix mask on diamond reuse with shuffle (#126963)

This patch fixes a bug in the creation of shuffle masks when vectorizing
vectors in case of a diamond reuse with shuffle. The mask needs to
enumerate all elements of a vector, not treat the original vector value
as a single element. That is: if vectorizing two <2 x float> vectors
into a <4 x float> the mask needs to have 4 indices, not just 2.
This commit is contained in:
vporpo
2025-02-12 12:29:09 -08:00
committed by GitHub
parent 9478822f4f
commit 7a7f9190d0
5 changed files with 63 additions and 3 deletions

View File

@@ -18,6 +18,7 @@
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
#include <algorithm>
namespace llvm::sandboxir {
@@ -85,11 +86,13 @@ public:
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
for (auto [Lane, Orig] : enumerate(Origs)) {
unsigned Lane = 0;
for (Value *Orig : Origs) {
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
assert(Pair.second && "Orig already exists in the map!");
(void)Pair;
OrigToLaneMap[Orig] = Lane;
Lane += VecUtils::getNumLanes(Orig);
}
}
void clear() {

View File

@@ -202,14 +202,20 @@ CollectDescr
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
Vec.reserve(Bndl.size());
for (auto [Lane, V] : enumerate(Bndl)) {
uint32_t LaneAccum;
for (auto [Elm, V] : enumerate(Bndl)) {
uint32_t VLanes = VecUtils::getNumLanes(V);
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
// If there is a vector containing `V`, then get the lane it came from.
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
// This could be a vector, like <2 x float> in which case the mask needs
// to enumerate all lanes.
for (int Ln = 0; Ln != VLanes; ++Ln)
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
} else {
Vec.emplace_back(V);
}
LaneAccum += VLanes;
}
return CollectDescr(std::move(Vec));
}

View File

@@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
const ShuffleMask &Mask =
cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
NewVec = createShuffle(VecOp, Mask, UserBB);
assert(NewVec->getType() == VecOp->getType() &&
"Expected same type! Bad mask ?");
break;
}
case LegalityResultID::DiamondReuseMultiInput: {

View File

@@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
ret void
}
; Same but with <2 x float> elements instead of scalars.
define void @diamondWithShuffleFromVec(ptr %ptr) {
; CHECK-LABEL: define void @diamondWithShuffleFromVec(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
%ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
%ld0 = load <2 x float>, ptr %ptr0
%ld1 = load <2 x float>, ptr %ptr1
%sub0 = fsub <2 x float> %ld0, %ld1
%sub1 = fsub <2 x float> %ld1, %ld0
store <2 x float> %sub0, ptr %ptr0
store <2 x float> %sub1, ptr %ptr1
ret void
}
define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
; CHECK-LABEL: define void @diamondMultiInput(
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {

View File

@@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
TEST_F(InstrMapsTest, VectorLanes) {
parseIR(C, R"IR(
define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
%vadd0 = add <2 x i8> %v0, %v1
%vadd1 = add <2 x i8> %v0, %v1
%vadd2 = add <4 x i8> %v2, %v3
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);
sandboxir::InstrMaps IMaps(Ctx);
// Check that the vector lanes are calculated correctly.
IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
}