Files
clang-p2996/llvm/test/Transforms/LowerMatrixIntrinsics/propagate-multiple-iterations.ll
Florian Hahn ccf24225e3 [Matrix] Update shape propagation to iterate until done.
This patch updates the shape propagation to iterate until no new shape
information is discovered.

As initial seed for the forward propagation, we use the matrix intrinsic
instructions. Both propagateShapeForward and propagateShapeBackward
return new work lists, with the instructions to be used for the next
iteration. When propagating forward, we record all instructions we added
new shape information for. When propagating backward, we record all
users of instructions we added new shape information for.

Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D70901
2020-01-09 10:52:52 +00:00

85 lines
6.4 KiB
LLVM

; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s
; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
; Make sure we propagate in multiple iterations. First, we back-propagate the
; shape information from the transpose to %A, in the next iteration we
; forward-propagate it to %Mul, and then back to %B.
define <16 x double> @backpropagation_iterations(<16 x double>* %A.Ptr, <16 x double>* %B.Ptr) {
; CHECK-LABEL: @backpropagation_iterations(
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <16 x double>* [[A_PTR:%.*]] to double*
; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* [[TMP1]] to <4 x double>*
; CHECK-NEXT: [[TMP3:%.*]] = load <4 x double>, <4 x double>* [[TMP2]], align 8
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr double, double* [[TMP1]], i32 4
; CHECK-NEXT: [[TMP6:%.*]] = bitcast double* [[TMP5]] to <4 x double>*
; CHECK-NEXT: [[TMP7:%.*]] = load <4 x double>, <4 x double>* [[TMP6]], align 8
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr double, double* [[TMP1]], i32 8
; CHECK-NEXT: [[TMP10:%.*]] = bitcast double* [[TMP9]] to <4 x double>*
; CHECK-NEXT: [[TMP11:%.*]] = load <4 x double>, <4 x double>* [[TMP10]], align 8
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr double, double* [[TMP1]], i32 12
; CHECK-NEXT: [[TMP14:%.*]] = bitcast double* [[TMP13]] to <4 x double>*
; CHECK-NEXT: [[TMP15:%.*]] = load <4 x double>, <4 x double>* [[TMP14]], align 8
; CHECK-NEXT: [[TMP16:%.*]] = extractelement <4 x double> [[TMP3]], i64 0
; CHECK-NEXT: [[TMP17:%.*]] = insertelement <4 x double> undef, double [[TMP16]], i64 0
; CHECK-NEXT: [[TMP18:%.*]] = extractelement <4 x double> [[TMP7]], i64 0
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <4 x double> [[TMP17]], double [[TMP18]], i64 1
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x double> [[TMP11]], i64 0
; CHECK-NEXT: [[TMP21:%.*]] = insertelement <4 x double> [[TMP19]], double [[TMP20]], i64 2
; CHECK-NEXT: [[TMP22:%.*]] = extractelement <4 x double> [[TMP15]], i64 0
; CHECK-NEXT: [[TMP23:%.*]] = insertelement <4 x double> [[TMP21]], double [[TMP22]], i64 3
; CHECK-NEXT: [[TMP24:%.*]] = extractelement <4 x double> [[TMP3]], i64 1
; CHECK-NEXT: [[TMP25:%.*]] = insertelement <4 x double> undef, double [[TMP24]], i64 0
; CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x double> [[TMP7]], i64 1
; CHECK-NEXT: [[TMP27:%.*]] = insertelement <4 x double> [[TMP25]], double [[TMP26]], i64 1
; CHECK-NEXT: [[TMP28:%.*]] = extractelement <4 x double> [[TMP11]], i64 1
; CHECK-NEXT: [[TMP29:%.*]] = insertelement <4 x double> [[TMP27]], double [[TMP28]], i64 2
; CHECK-NEXT: [[TMP30:%.*]] = extractelement <4 x double> [[TMP15]], i64 1
; CHECK-NEXT: [[TMP31:%.*]] = insertelement <4 x double> [[TMP29]], double [[TMP30]], i64 3
; CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x double> [[TMP3]], i64 2
; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x double> undef, double [[TMP32]], i64 0
; CHECK-NEXT: [[TMP34:%.*]] = extractelement <4 x double> [[TMP7]], i64 2
; CHECK-NEXT: [[TMP35:%.*]] = insertelement <4 x double> [[TMP33]], double [[TMP34]], i64 1
; CHECK-NEXT: [[TMP36:%.*]] = extractelement <4 x double> [[TMP11]], i64 2
; CHECK-NEXT: [[TMP37:%.*]] = insertelement <4 x double> [[TMP35]], double [[TMP36]], i64 2
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x double> [[TMP15]], i64 2
; CHECK-NEXT: [[TMP39:%.*]] = insertelement <4 x double> [[TMP37]], double [[TMP38]], i64 3
; CHECK-NEXT: [[TMP40:%.*]] = extractelement <4 x double> [[TMP3]], i64 3
; CHECK-NEXT: [[TMP41:%.*]] = insertelement <4 x double> undef, double [[TMP40]], i64 0
; CHECK-NEXT: [[TMP42:%.*]] = extractelement <4 x double> [[TMP7]], i64 3
; CHECK-NEXT: [[TMP43:%.*]] = insertelement <4 x double> [[TMP41]], double [[TMP42]], i64 1
; CHECK-NEXT: [[TMP44:%.*]] = extractelement <4 x double> [[TMP11]], i64 3
; CHECK-NEXT: [[TMP45:%.*]] = insertelement <4 x double> [[TMP43]], double [[TMP44]], i64 2
; CHECK-NEXT: [[TMP46:%.*]] = extractelement <4 x double> [[TMP15]], i64 3
; CHECK-NEXT: [[TMP47:%.*]] = insertelement <4 x double> [[TMP45]], double [[TMP46]], i64 3
; CHECK-NEXT: [[TMP48:%.*]] = bitcast <16 x double>* [[B_PTR:%.*]] to double*
; CHECK-NEXT: [[TMP49:%.*]] = bitcast double* [[TMP48]] to <4 x double>*
; CHECK-NEXT: [[TMP50:%.*]] = load <4 x double>, <4 x double>* [[TMP49]], align 8
; CHECK-NEXT: [[TMP52:%.*]] = getelementptr double, double* [[TMP48]], i32 4
; CHECK-NEXT: [[TMP53:%.*]] = bitcast double* [[TMP52]] to <4 x double>*
; CHECK-NEXT: [[TMP54:%.*]] = load <4 x double>, <4 x double>* [[TMP53]], align 8
; CHECK-NEXT: [[TMP56:%.*]] = getelementptr double, double* [[TMP48]], i32 8
; CHECK-NEXT: [[TMP57:%.*]] = bitcast double* [[TMP56]] to <4 x double>*
; CHECK-NEXT: [[TMP58:%.*]] = load <4 x double>, <4 x double>* [[TMP57]], align 8
; CHECK-NEXT: [[TMP60:%.*]] = getelementptr double, double* [[TMP48]], i32 12
; CHECK-NEXT: [[TMP61:%.*]] = bitcast double* [[TMP60]] to <4 x double>*
; CHECK-NEXT: [[TMP62:%.*]] = load <4 x double>, <4 x double>* [[TMP61]], align 8
; CHECK-NEXT: [[TMP63:%.*]] = fmul <4 x double> [[TMP3]], [[TMP50]]
; CHECK-NEXT: [[TMP64:%.*]] = fmul <4 x double> [[TMP7]], [[TMP54]]
; CHECK-NEXT: [[TMP65:%.*]] = fmul <4 x double> [[TMP11]], [[TMP58]]
; CHECK-NEXT: [[TMP66:%.*]] = fmul <4 x double> [[TMP15]], [[TMP62]]
; CHECK-NEXT: [[TMP67:%.*]] = shufflevector <4 x double> [[TMP63]], <4 x double> [[TMP64]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP68:%.*]] = shufflevector <4 x double> [[TMP65]], <4 x double> [[TMP66]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP69:%.*]] = shufflevector <8 x double> [[TMP67]], <8 x double> [[TMP68]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: ret <16 x double> [[TMP69]]
;
%A = load <16 x double>, <16 x double>* %A.Ptr
%A.trans = tail call <16 x double> @llvm.matrix.transpose.v16f64(<16 x double> %A, i32 4, i32 4)
%B = load <16 x double>, <16 x double>* %B.Ptr
%Mul = fmul <16 x double> %A, %B
ret <16 x double> %Mul
}
declare <16 x double> @llvm.matrix.multiply.v16f64.v16f64.v16f64(<16 x double>, <16 x double>, i32 immarg, i32 immarg, i32 immarg)
declare <16 x double> @llvm.matrix.transpose.v16f64(<16 x double>, i32 immarg, i32 immarg)