[Matrix] Fix a crash in VisitSelectInst due to iteration length mismatch
This commit is contained in:
@@ -2326,14 +2326,13 @@ public:
|
||||
MatrixTy A = getMatrix(OpA, Shape, Builder);
|
||||
MatrixTy B = getMatrix(OpB, Shape, Builder);
|
||||
|
||||
Value *CondV[2];
|
||||
SmallVector<Value*> CondV;
|
||||
if (isa<FixedVectorType>(Cond->getType())) {
|
||||
MatrixTy C = getMatrix(Cond, Shape, Builder);
|
||||
CondV[0] = C.getVector(0);
|
||||
CondV[1] = C.getVector(1);
|
||||
llvm::copy(C.vectors(), std::back_inserter(CondV));
|
||||
} else {
|
||||
CondV[0] = Cond;
|
||||
CondV[1] = Cond;
|
||||
CondV.resize(A.getNumVectors());
|
||||
std::fill(CondV.begin(), CondV.end(), Cond);
|
||||
}
|
||||
|
||||
for (auto [CV, AV, BV] : llvm::zip_equal(CondV, A.vectors(), B.vectors()))
|
||||
|
||||
@@ -144,3 +144,64 @@ define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
|
||||
store <4 x float> %op, ptr %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @select_2x2_vcond_shape4(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
|
||||
; CHECK-LABEL: @select_2x2_vcond_shape4(
|
||||
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <4 x float>, ptr [[LHS:%.*]], align 16
|
||||
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
|
||||
; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x float>, ptr [[RHS:%.*]], align 4
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[COL_LOAD1]], <4 x float> [[COL_LOAD]], <4 x float> [[COL_LOAD2]]
|
||||
; CHECK-NEXT: store <4 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%lhsv = load <4 x float>, ptr %lhs
|
||||
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
|
||||
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 4, i1 false, i32 4, i32 1)
|
||||
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
|
||||
store <4 x float> %op, ptr %out
|
||||
ret void
|
||||
}
|
||||
|
||||
define void @select_2x2_vcond_shape5(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
|
||||
; CHECK-LABEL: @select_2x2_vcond_shape5(
|
||||
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <1 x float>, ptr [[LHS:%.*]], align 16
|
||||
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 1
|
||||
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <1 x float>, ptr [[VEC_GEP]], align 4
|
||||
; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[LHS]], i64 2
|
||||
; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <1 x float>, ptr [[VEC_GEP2]], align 8
|
||||
; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[LHS]], i64 3
|
||||
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <1 x float>, ptr [[VEC_GEP4]], align 4
|
||||
; CHECK-NEXT: [[COL_LOAD6:%.*]] = load <1 x i1>, ptr [[COND:%.*]], align 1
|
||||
; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr i1, ptr [[COND]], i64 1
|
||||
; CHECK-NEXT: [[COL_LOAD8:%.*]] = load <1 x i1>, ptr [[VEC_GEP7]], align 1
|
||||
; CHECK-NEXT: [[VEC_GEP9:%.*]] = getelementptr i1, ptr [[COND]], i64 2
|
||||
; CHECK-NEXT: [[COL_LOAD10:%.*]] = load <1 x i1>, ptr [[VEC_GEP9]], align 1
|
||||
; CHECK-NEXT: [[VEC_GEP11:%.*]] = getelementptr i1, ptr [[COND]], i64 3
|
||||
; CHECK-NEXT: [[COL_LOAD12:%.*]] = load <1 x i1>, ptr [[VEC_GEP11]], align 1
|
||||
; CHECK-NEXT: [[COL_LOAD13:%.*]] = load <1 x float>, ptr [[RHS:%.*]], align 4
|
||||
; CHECK-NEXT: [[VEC_GEP14:%.*]] = getelementptr float, ptr [[RHS]], i64 1
|
||||
; CHECK-NEXT: [[COL_LOAD15:%.*]] = load <1 x float>, ptr [[VEC_GEP14]], align 4
|
||||
; CHECK-NEXT: [[VEC_GEP16:%.*]] = getelementptr float, ptr [[RHS]], i64 2
|
||||
; CHECK-NEXT: [[COL_LOAD17:%.*]] = load <1 x float>, ptr [[VEC_GEP16]], align 4
|
||||
; CHECK-NEXT: [[VEC_GEP18:%.*]] = getelementptr float, ptr [[RHS]], i64 3
|
||||
; CHECK-NEXT: [[COL_LOAD19:%.*]] = load <1 x float>, ptr [[VEC_GEP18]], align 4
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = select <1 x i1> [[COL_LOAD6]], <1 x float> [[COL_LOAD]], <1 x float> [[COL_LOAD13]]
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = select <1 x i1> [[COL_LOAD8]], <1 x float> [[COL_LOAD1]], <1 x float> [[COL_LOAD15]]
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = select <1 x i1> [[COL_LOAD10]], <1 x float> [[COL_LOAD3]], <1 x float> [[COL_LOAD17]]
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = select <1 x i1> [[COL_LOAD12]], <1 x float> [[COL_LOAD5]], <1 x float> [[COL_LOAD19]]
|
||||
; CHECK-NEXT: store <1 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
|
||||
; CHECK-NEXT: [[VEC_GEP20:%.*]] = getelementptr float, ptr [[OUT]], i64 1
|
||||
; CHECK-NEXT: store <1 x float> [[TMP2]], ptr [[VEC_GEP20]], align 4
|
||||
; CHECK-NEXT: [[VEC_GEP21:%.*]] = getelementptr float, ptr [[OUT]], i64 2
|
||||
; CHECK-NEXT: store <1 x float> [[TMP3]], ptr [[VEC_GEP21]], align 8
|
||||
; CHECK-NEXT: [[VEC_GEP22:%.*]] = getelementptr float, ptr [[OUT]], i64 3
|
||||
; CHECK-NEXT: store <1 x float> [[TMP4]], ptr [[VEC_GEP22]], align 4
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%lhsv = load <4 x float>, ptr %lhs
|
||||
%condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 1, i1 false, i32 1, i32 4)
|
||||
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 1, i1 false, i32 1, i32 4)
|
||||
%op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
|
||||
store <4 x float> %op, ptr %out
|
||||
ret void
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user