diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 13635305f6a8..34e3f52bf7ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -9782,20 +9782,23 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
     return false;
 
   EVT VT = Op.getValueType();
+  EVT ExpectedVT = ExpectedOp.getValueType();
+
+  // Sources must be vectors and match the mask's element count.
+  if (!VT.isVector() || !ExpectedVT.isVector() ||
+      (int)VT.getVectorNumElements() != MaskSize ||
+      (int)ExpectedVT.getVectorNumElements() != MaskSize)
+    return false;
+
   switch (Op.getOpcode()) {
   case ISD::BUILD_VECTOR:
     // If the values are build vectors, we can look through them to find
     // equivalent inputs that make the shuffles equivalent.
-    // TODO: Handle MaskSize != Op.getNumOperands()?
-    if (MaskSize == (int)Op.getNumOperands() &&
-        MaskSize == (int)ExpectedOp.getNumOperands())
-      return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
-    break;
+    return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
   case ISD::BITCAST: {
     SDValue Src = peekThroughBitcasts(Op);
     EVT SrcVT = Src.getValueType();
-    if (Op == ExpectedOp && SrcVT.isVector() &&
-        (int)VT.getVectorNumElements() == MaskSize) {
+    if (Op == ExpectedOp && SrcVT.isVector()) {
       if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
         unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
         return (Idx % Scale) == (ExpectedIdx % Scale) &&
@@ -9816,23 +9819,21 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
   }
   case ISD::VECTOR_SHUFFLE: {
     auto *SVN = cast<ShuffleVectorSDNode>(Op);
-    return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
+    return Op == ExpectedOp &&
            SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
   }
   case X86ISD::VBROADCAST:
   case X86ISD::VBROADCAST_LOAD:
-    // TODO: Handle MaskSize != VT.getVectorNumElements()?
-    return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
+    return Op == ExpectedOp;
   case X86ISD::SUBV_BROADCAST_LOAD:
-    // TODO: Handle MaskSize != VT.getVectorNumElements()?
-    if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
+    if (Op == ExpectedOp) {
       auto *MemOp = cast<MemSDNode>(Op);
       unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
       return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
     }
     break;
   case X86ISD::VPERMI: {
-    if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
+    if (Op == ExpectedOp) {
       SmallVector<int, 8> Mask;
       DecodeVPERMMask(MaskSize, Op.getConstantOperandVal(1), Mask);
       SDValue Src = Op.getOperand(0);
@@ -9849,20 +9850,16 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
   case X86ISD::PACKSS:
   case X86ISD::PACKUS:
     // HOP(X,X) can refer to the elt from the lower/upper half of a lane.
-    // TODO: Handle MaskSize != NumElts?
     // TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
     if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
       int NumElts = VT.getVectorNumElements();
-      if (MaskSize == NumElts) {
-        int NumLanes = VT.getSizeInBits() / 128;
-        int NumEltsPerLane = NumElts / NumLanes;
-        int NumHalfEltsPerLane = NumEltsPerLane / 2;
-        bool SameLane =
-            (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
-        bool SameElt =
-            (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
-        return SameLane && SameElt;
-      }
+      int NumLanes = VT.getSizeInBits() / 128;
+      int NumEltsPerLane = NumElts / NumLanes;
+      int NumHalfEltsPerLane = NumEltsPerLane / 2;
+      bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
+      bool SameElt =
+          (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
+      return SameLane && SameElt;
     }
     break;
   }
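
Reviewer note on the last hunk: with both operands equal, a horizontal op (HOP, i.e. HADD/HSUB/PACK) duplicates each half-lane result into the lower and upper half of its 128-bit lane, so two result indices are interchangeable exactly when they share a lane and an offset within a half-lane. A minimal standalone sketch of that arithmetic (not LLVM code; isHopEltEquivalent is a hypothetical helper, and VTBits/NumElts stand in for VT.getSizeInBits()/VT.getVectorNumElements()):

// hop_equiv.cpp - standalone illustration of the HOP(X,X) lane math.
#include <cassert>

static bool isHopEltEquivalent(int VTBits, int NumElts, int Idx,
                               int ExpectedIdx) {
  int NumLanes = VTBits / 128;
  int NumEltsPerLane = NumElts / NumLanes;
  int NumHalfEltsPerLane = NumEltsPerLane / 2;
  bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
  bool SameElt =
      (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
  return SameLane && SameElt;
}

int main() {
  // 256-bit HADD(X,X) on v8i32: lane 0 holds {x0+x1, x2+x3, x0+x1, x2+x3},
  // so index 0 aliases index 2, but not index 1 (different half-lane
  // offset) and not index 4 (different 128-bit lane).
  assert(isHopEltEquivalent(256, 8, 0, 2));
  assert(!isHopEltEquivalent(256, 8, 0, 1));
  assert(!isHopEltEquivalent(256, 8, 0, 4));
  return 0;
}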
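
Similarly, for the X86ISD::SUBV_BROADCAST_LOAD case simplified in the second hunk: the loaded subvector of NumMemElts elements is repeated across the destination register, so two indices read the same loaded element exactly when they are congruent modulo NumMemElts. A sketch under the same caveats (hypothetical helper; NumMemElts plays the role of MemOp->getMemoryVT().getVectorNumElements()):

// subv_broadcast_equiv.cpp - standalone illustration, not LLVM code.
#include <cassert>

static bool isSubvBroadcastEltEquivalent(int NumMemElts, int Idx,
                                         int ExpectedIdx) {
  return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
}

int main() {
  // A 2-element load broadcast into an 8-element register produces the
  // repeating pattern m0,m1,m0,m1,m0,m1,m0,m1.
  assert(isSubvBroadcastEltEquivalent(2, 0, 6));  // both read m0
  assert(!isSubvBroadcastEltEquivalent(2, 0, 3)); // m0 vs m1
  return 0;
}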