From 56548e1d9b2ed4f5d2fe3913c27af770cf0e06e5 Mon Sep 17 00:00:00 2001
From: Jon Roelofs
Date: Thu, 12 Jun 2025 09:19:58 -0700
Subject: [PATCH] [Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch

VisitSelectInst kept the per-vector condition values in a fixed two-element
array, but the loop that emits the selects walks CondV together with
A.vectors() and B.vectors() via llvm::zip_equal. A select whose operands are
split into a number of vectors other than two therefore hit the length
mismatch named in the subject.

Collect the condition values in a SmallVector instead: copy every vector of
the condition matrix when the condition is itself a vector, otherwise repeat
the scalar condition A.getNumVectors() times.

The new tests cover a 4x1 shape (a single <4 x i1> select) and a 1x4 shape
(four <1 x i1> selects).
---
 .../Scalar/LowerMatrixIntrinsics.cpp          |  9 ++-
 .../LowerMatrixIntrinsics/select.ll           | 61 +++++++++++++++++++
 2 files changed, 65 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index b32160ff275b..1e37f40fa9d5 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -2326,14 +2326,13 @@ public:
     MatrixTy A = getMatrix(OpA, Shape, Builder);
     MatrixTy B = getMatrix(OpB, Shape, Builder);
 
-    Value *CondV[2];
+    SmallVector<Value *> CondV;
     if (isa<FixedVectorType>(Cond->getType())) {
       MatrixTy C = getMatrix(Cond, Shape, Builder);
-      CondV[0] = C.getVector(0);
-      CondV[1] = C.getVector(1);
+      llvm::copy(C.vectors(), std::back_inserter(CondV));
     } else {
-      CondV[0] = Cond;
-      CondV[1] = Cond;
+      CondV.resize(A.getNumVectors());
+      std::fill(CondV.begin(), CondV.end(), Cond);
     }
 
     for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
index 70b0dfdb3e7e..bd97915759aa 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -144,3 +144,64 @@ define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
   store <4 x float> %op, ptr %out
   ret void
 }
+
+define void @select_2x2_vcond_shape4(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape4(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <4 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <4 x float>, ptr [[RHS:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select <4 x i1> [[COL_LOAD1]], <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    store <4 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 4, i1 false, i32 4, i32 1)
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}
+
+define void @select_2x2_vcond_shape5(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape5(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <1 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 1
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <1 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <1 x float>, ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr float, ptr [[LHS]], i64 3
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <1 x float>, ptr [[VEC_GEP4]], align 4
+; CHECK-NEXT:    [[COL_LOAD6:%.*]] = load <1 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr i1, ptr [[COND]], i64 1
+; CHECK-NEXT:    [[COL_LOAD8:%.*]] = load <1 x i1>, ptr [[VEC_GEP7]], align 1
+; CHECK-NEXT:    [[VEC_GEP9:%.*]] = getelementptr i1, ptr [[COND]], i64 2
+; CHECK-NEXT:    [[COL_LOAD10:%.*]] = load <1 x i1>, ptr [[VEC_GEP9]], align 1
+; CHECK-NEXT:    [[VEC_GEP11:%.*]] = getelementptr i1, ptr [[COND]], i64 3
+; CHECK-NEXT:    [[COL_LOAD12:%.*]] = load <1 x i1>, ptr [[VEC_GEP11]], align 1
+; CHECK-NEXT:    [[COL_LOAD13:%.*]] = load <1 x float>, ptr [[RHS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP14:%.*]] = getelementptr float, ptr [[RHS]], i64 1
+; CHECK-NEXT:    [[COL_LOAD15:%.*]] = load <1 x float>, ptr [[VEC_GEP14]], align 4
+; CHECK-NEXT:    [[VEC_GEP16:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD17:%.*]] = load <1 x float>, ptr [[VEC_GEP16]], align 4
+; CHECK-NEXT:    [[VEC_GEP18:%.*]] = getelementptr float, ptr [[RHS]], i64 3
+; CHECK-NEXT:    [[COL_LOAD19:%.*]] = load <1 x float>, ptr [[VEC_GEP18]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select <1 x i1> [[COL_LOAD6]], <1 x float> [[COL_LOAD]], <1 x float> [[COL_LOAD13]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select <1 x i1> [[COL_LOAD8]], <1 x float> [[COL_LOAD1]], <1 x float> [[COL_LOAD15]]
+; CHECK-NEXT:    [[TMP3:%.*]] = select <1 x i1> [[COL_LOAD10]], <1 x float> [[COL_LOAD3]], <1 x float> [[COL_LOAD17]]
+; CHECK-NEXT:    [[TMP4:%.*]] = select <1 x i1> [[COL_LOAD12]], <1 x float> [[COL_LOAD5]], <1 x float> [[COL_LOAD19]]
+; CHECK-NEXT:    store <1 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP20:%.*]] = getelementptr float, ptr [[OUT]], i64 1
+; CHECK-NEXT:    store <1 x float> [[TMP2]], ptr [[VEC_GEP20]], align 4
+; CHECK-NEXT:    [[VEC_GEP21:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <1 x float> [[TMP3]], ptr [[VEC_GEP21]], align 8
+; CHECK-NEXT:    [[VEC_GEP22:%.*]] = getelementptr float, ptr [[OUT]], i64 3
+; CHECK-NEXT:    store <1 x float> [[TMP4]], ptr [[VEC_GEP22]], align 4
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 1, i1 false, i32 1, i32 4)
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 1, i1 false, i32 1, i32 4)
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}