From e6d62c910fdc26cda58d21db84c5ef01b910c81d Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Fri, 6 Jun 2025 18:06:46 +0100 Subject: [PATCH] [X86] IsElementEquivalent - pull out vector element count mismatch code. NFC. All cases rely on the ops having the same vector count as the masksize, and this is unlikely to change now that we handle bitcasts, so just early out. --- llvm/lib/Target/X86/X86ISelLowering.cpp | 45 ++++++++++++------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 13635305f6a8..34e3f52bf7ff 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -9782,20 +9782,23 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp, return false; EVT VT = Op.getValueType(); + EVT ExpectedVT = ExpectedOp.getValueType(); + + // Sources must be vectors and match the mask's element count. + if (!VT.isVector() || !ExpectedVT.isVector() || + (int)VT.getVectorNumElements() != MaskSize || + (int)ExpectedVT.getVectorNumElements() != MaskSize) + return false; + switch (Op.getOpcode()) { case ISD::BUILD_VECTOR: // If the values are build vectors, we can look through them to find // equivalent inputs that make the shuffles equivalent. - // TODO: Handle MaskSize != Op.getNumOperands()? - if (MaskSize == (int)Op.getNumOperands() && - MaskSize == (int)ExpectedOp.getNumOperands()) - return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx); - break; + return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx); case ISD::BITCAST: { SDValue Src = peekThroughBitcasts(Op); EVT SrcVT = Src.getValueType(); - if (Op == ExpectedOp && SrcVT.isVector() && - (int)VT.getVectorNumElements() == MaskSize) { + if (Op == ExpectedOp && SrcVT.isVector()) { if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) { unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits(); return (Idx % Scale) == (ExpectedIdx % Scale) && @@ -9816,23 +9819,21 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp, } case ISD::VECTOR_SHUFFLE: { auto *SVN = cast(Op); - return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize && + return Op == ExpectedOp && SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx); } case X86ISD::VBROADCAST: case X86ISD::VBROADCAST_LOAD: - // TODO: Handle MaskSize != VT.getVectorNumElements()? - return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize); + return Op == ExpectedOp; case X86ISD::SUBV_BROADCAST_LOAD: - // TODO: Handle MaskSize != VT.getVectorNumElements()? - if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) { + if (Op == ExpectedOp) { auto *MemOp = cast(Op); unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements(); return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts); } break; case X86ISD::VPERMI: { - if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) { + if (Op == ExpectedOp) { SmallVector Mask; DecodeVPERMMask(MaskSize, Op.getConstantOperandVal(1), Mask); SDValue Src = Op.getOperand(0); @@ -9849,20 +9850,16 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp, case X86ISD::PACKSS: case X86ISD::PACKUS: // HOP(X,X) can refer to the elt from the lower/upper half of a lane. - // TODO: Handle MaskSize != NumElts? // TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases. if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) { int NumElts = VT.getVectorNumElements(); - if (MaskSize == NumElts) { - int NumLanes = VT.getSizeInBits() / 128; - int NumEltsPerLane = NumElts / NumLanes; - int NumHalfEltsPerLane = NumEltsPerLane / 2; - bool SameLane = - (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane); - bool SameElt = - (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane); - return SameLane && SameElt; - } + int NumLanes = VT.getSizeInBits() / 128; + int NumEltsPerLane = NumElts / NumLanes; + int NumHalfEltsPerLane = NumEltsPerLane / 2; + bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane); + bool SameElt = + (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane); + return SameLane && SameElt; } break; }