diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt index a22f8332514b..0fe01824b824 100644 --- a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_dialect(X86Vector x86vector) add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector) -set(LLVM_TARGET_DEFINITIONS X86Vector.td) -mlir_tablegen(X86VectorConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRX86VectorConversionsIncGen) +add_mlir_interface(X86VectorInterfaces) +add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index 566013e73f4b..5be0d92db463 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/X86Vector/X86VectorInterfaces.td" //===----------------------------------------------------------------------===// // X86Vector dialect definition @@ -34,30 +35,12 @@ def X86Vector_Dialect : Dialect { class AVX512_Op traits = []> : Op {} -// Intrinsic operation used during lowering to LLVM IR. -class AVX512_IntrOp traits = [], - string extension = ""> : - LLVM_IntrOpBase; - -// Defined by first result overload. May have to be extended for other -// instructions in the future. -class AVX512_IntrOverloadedOp traits = [], - string extension = ""> : - LLVM_IntrOpBase overloadedResults=*/[0], - /*list overloadedOperands=*/[], - traits, /*numResults=*/1>; - //----------------------------------------------------------------------------// // MaskCompressOp //----------------------------------------------------------------------------// def MaskCompressOp : AVX512_Op<"mask.compress", [Pure, + DeclareOpInterfaceMethods, // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could // then be removed from assemblyFormat. AllTypesMatch<["a", "dst"]>, @@ -91,21 +74,17 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure, let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict" " `:` type($dst) (`,` type($src)^)?"; let hasVerifier = 1; -} -def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [ - Pure, - AllTypesMatch<["a", "src", "res"]>, - TypesMatchWith<"`k` has the same number of bits as elements in `res`", - "res", "k", - "VectorType::get({::llvm::cast($_self).getShape()[0]}, " - "IntegerType::get($_self.getContext(), 1))">]> { - let arguments = (ins VectorOfLengthAndType<[16, 8], - [F32, I32, F64, I64]>:$a, - VectorOfLengthAndType<[16, 8], - [F32, I32, F64, I64]>:$src, - VectorOfLengthAndType<[16, 8], - [I1]>:$k); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + // Call the baseline overloaded intrisic. + // Final overload name mangling is resolved by the created function call. 
+ return "llvm.x86.avx512.mask.compress"; + } + }]; + let extraClassDeclaration = [{ + SmallVector getIntrinsicOperands(::mlir::RewriterBase&); + }]; } //----------------------------------------------------------------------------// @@ -113,6 +92,7 @@ def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [ //----------------------------------------------------------------------------// def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure, + DeclareOpInterfaceMethods, AllTypesMatch<["src", "a", "dst"]>, TypesMatchWith<"imm has the same number of bits as elements in dst", "dst", "imm", @@ -142,26 +122,20 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure, let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst); let assemblyFormat = "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)"; -} -def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [ - Pure, - AllTypesMatch<["src", "a", "res"]>]> { - let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src, - I32:$k, - VectorOfLengthAndType<[16], [F32]>:$a, - I16:$imm, - LLVM_Type:$rounding); -} - -def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [ - Pure, - AllTypesMatch<["src", "a", "res"]>]> { - let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src, - I32:$k, - VectorOfLengthAndType<[8], [F64]>:$a, - I8:$imm, - LLVM_Type:$rounding); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.avx512.mask.rndscale"; + VectorType vecType = getSrc().getType(); + Type elemType = vecType.getElementType(); + intr += "."; + intr += elemType.isF32() ? "ps" : "pd"; + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." + std::to_string(opBitWidth); + return intr; + } + }]; } //----------------------------------------------------------------------------// @@ -169,6 +143,7 @@ def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [ //----------------------------------------------------------------------------// def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure, + DeclareOpInterfaceMethods, AllTypesMatch<["src", "a", "b", "dst"]>, TypesMatchWith<"k has the same number of bits as elements in dst", "dst", "k", @@ -199,26 +174,20 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure, // Fully specified by traits. let assemblyFormat = "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)"; -} -def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [ - Pure, - AllTypesMatch<["src", "a", "b", "res"]>]> { - let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src, - VectorOfLengthAndType<[16], [F32]>:$a, - VectorOfLengthAndType<[16], [F32]>:$b, - I16:$k, - LLVM_Type:$rounding); -} - -def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [ - Pure, - AllTypesMatch<["src", "a", "b", "res"]>]> { - let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src, - VectorOfLengthAndType<[8], [F64]>:$a, - VectorOfLengthAndType<[8], [F64]>:$b, - I8:$k, - LLVM_Type:$rounding); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.avx512.mask.scalef"; + VectorType vecType = getSrc().getType(); + Type elemType = vecType.getElementType(); + intr += "."; + intr += elemType.isF32() ? "ps" : "pd"; + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." 
+ std::to_string(opBitWidth); + return intr; + } + }]; } //----------------------------------------------------------------------------// @@ -226,6 +195,7 @@ def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [ //----------------------------------------------------------------------------// def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure, + DeclareOpInterfaceMethods, AllTypesMatch<["a", "b"]>, TypesMatchWith<"k1 has the same number of bits as elements in a", "a", "k1", @@ -260,18 +230,20 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure, ); let assemblyFormat = "$a `,` $b attr-dict `:` type($a)"; -} -def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [ - Pure]> { - let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a, - VectorOfLengthAndType<[16], [I32]>:$b); -} - -def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [ - Pure]> { - let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a, - VectorOfLengthAndType<[8], [I64]>:$b); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.avx512.vp2intersect"; + VectorType vecType = getA().getType(); + Type elemType = vecType.getElementType(); + intr += "."; + intr += elemType.isInteger(32) ? "d" : "q"; + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." + std::to_string(opBitWidth); + return intr; + } + }]; } //----------------------------------------------------------------------------// @@ -279,6 +251,7 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [ //----------------------------------------------------------------------------// def DotBF16Op : AVX512_Op<"dot", [Pure, + DeclareOpInterfaceMethods, AllTypesMatch<["a", "b"]>, AllTypesMatch<["src", "dst"]>, TypesMatchWith<"`a` has twice an many elements as `src`", @@ -299,7 +272,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure, Example: ```mlir - %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> + %dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> ``` }]; let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src, @@ -309,36 +282,17 @@ def DotBF16Op : AVX512_Op<"dot", [Pure, let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst); let assemblyFormat = "$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)"; -} -def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure, - AllTypesMatch<["a", "b"]>, - AllTypesMatch<["src", "res"]>], - /*extension=*/"bf16"> { - let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src, - VectorOfLengthAndType<[8], [BF16]>:$a, - VectorOfLengthAndType<[8], [BF16]>:$b); - let results = (outs VectorOfLengthAndType<[4], [F32]>:$res); -} - -def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure, - AllTypesMatch<["a", "b"]>, - AllTypesMatch<["src", "res"]>], - /*extension=*/"bf16"> { - let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src, - VectorOfLengthAndType<[16], [BF16]>:$a, - VectorOfLengthAndType<[16], [BF16]>:$b); - let results = (outs VectorOfLengthAndType<[8], [F32]>:$res); -} - -def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure, - AllTypesMatch<["a", "b"]>, - AllTypesMatch<["src", "res"]>], - /*extension=*/"bf16"> { - let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src, - VectorOfLengthAndType<[32], [BF16]>:$a, - VectorOfLengthAndType<[32], [BF16]>:$b); - let results = (outs VectorOfLengthAndType<[16], [F32]>:$res); + let 
extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.avx512bf16.dpbf16ps"; + VectorType vecType = getSrc().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." + std::to_string(opBitWidth); + return intr; + } + }]; } //----------------------------------------------------------------------------// @@ -346,6 +300,7 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure, //----------------------------------------------------------------------------// def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure, + DeclareOpInterfaceMethods, AllElementCountsMatch<["a", "dst"]>]> { let summary = "Convert packed F32 to packed BF16 Data."; let description = [{ @@ -367,18 +322,17 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure, let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst); let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst)"; -} -def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure], - /*extension=*/"bf16"> { - let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); - let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res); -} - -def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure], - /*extension=*/"bf16"> { - let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a); - let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16"; + VectorType vecType = getA().getType(); + unsigned elemBitWidth = vecType.getElementTypeBitWidth(); + unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth; + intr += "." + std::to_string(opBitWidth); + return intr; + } + }]; } //===----------------------------------------------------------------------===// @@ -395,33 +349,32 @@ class AVX_Op traits = []> : class AVX_LowOp traits = []> : Op {} -// Intrinsic operation used during lowering to LLVM IR. 
-class AVX_IntrOp traits = []> : - LLVM_IntrOpBase; - //----------------------------------------------------------------------------// // AVX Rsqrt //----------------------------------------------------------------------------// -def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> { +def RsqrtOp : AVX_Op<"rsqrt", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultType]> { let summary = "Rsqrt"; let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); let results = (outs VectorOfLengthAndType<[8], [F32]>:$b); let assemblyFormat = "$a attr-dict `:` type($a)"; -} -def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure, - SameOperandsAndResultType]> { - let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + return "llvm.x86.avx.rsqrt.ps.256"; + } + }]; } //----------------------------------------------------------------------------// // AVX Dot //----------------------------------------------------------------------------// -def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> { +def DotOp : AVX_LowOp<"dot", [Pure, + DeclareOpInterfaceMethods, + SameOperandsAndResultType]> { let summary = "Dot"; let description = [{ Computes the 4-way dot products of the lower and higher parts of the source @@ -443,13 +396,16 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> { VectorOfLengthAndType<[8], [F32]>:$b); let results = (outs VectorOfLengthAndType<[8], [F32]>:$res); let assemblyFormat = "$a `,` $b attr-dict `:` type($res)"; -} -def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure, - AllTypesMatch<["a", "b", "res"]>]> { - let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a, - VectorOfLengthAndType<[8], [F32]>:$b, I8:$c); - let results = (outs VectorOfLengthAndType<[8], [F32]>:$res); + let extraClassDefinition = [{ + std::string $cppClass::getIntrinsicName() { + // Only one variant is supported right now - no extra mangling. + return "llvm.x86.avx.dp.ps.256"; + } + }]; + let extraClassDeclaration = [{ + SmallVector getIntrinsicOperands(::mlir::RewriterBase&); + }]; } #endif // X86VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h index 4017bc1a917e..7bcf4c69b0a6 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -18,9 +18,13 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +/// Include the generated interface declarations. +#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc" + #include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc" #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td new file mode 100644 index 000000000000..98d5ca70b4a7 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td @@ -0,0 +1,68 @@ +//===- X86VectorInterfaces.td - X86Vector interfaces -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines interfaces for the X86Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_INTERFACES
+#define X86VECTOR_INTERFACES
+
+include "mlir/IR/Interfaces.td"
+
+//===----------------------------------------------------------------------===//
+// One-to-One Intrinsic Interface
+//===----------------------------------------------------------------------===//
+
+def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
+  let description = [{
+    Interface for 1-to-1 conversion of an operation into LLVM intrinsics.
+
+    An op implementing this interface can simply be replaced by a call
+    to a matching intrinsic function.
+    The op must ensure that the combinations of its arguments and results
+    have valid intrinsic counterparts.
+
+    For example, an operation supporting different vector widths:
+    ```mlir
+    %res_v8 = x86vector.op %value_v8 : vector<8xf32>
+    %res_v16 = x86vector.op %value_v16 : vector<16xf32>
+    ```
+    can be converted to the following intrinsic calls:
+    ```mlir
+    %res_v8 = llvm.call_intrinsic "llvm.x86.op.intr.256"(%value_v8)
+    %res_v16 = llvm.call_intrinsic "llvm.x86.op.intr.512"(%value_v16)
+    ```
+  }];
+  let cppNamespace = "::mlir::x86vector";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the mangled LLVM intrinsic function name matching the
+        operation variant.
+      }],
+      /*retType=*/"std::string",
+      /*methodName=*/"getIntrinsicName"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns operands for a corresponding LLVM intrinsic.
+
+        Additional operations may be created to facilitate mapping
+        between the source operands and the target intrinsic.
+      }],
+      /*retType=*/"SmallVector<Value>",
+      /*methodName=*/"getIntrinsicOperands",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
+    >,
+  ];
+}
+
+#endif // X86VECTOR_INTERFACES
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
index de9d5872cc45..e043ff2f6825 100644
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h
@@ -29,7 +29,6 @@
 #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h"
 
 namespace mlir {
 class DialectRegistry;
@@ -50,7 +49,6 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerROCDLDialectTranslation(registry);
   registerSPIRVDialectTranslation(registry);
   registerVCIXDialectTranslation(registry);
-  registerX86VectorDialectTranslation(registry);
 
   // Extension required for translating GPU offloading Ops.
 gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h
deleted file mode 100644
index a215bcf625ae..000000000000
--- a/mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h
+++ /dev/null
@@ -1,32 +0,0 @@
-//===- X86VectorToLLVMIRTranslation.h - X86Vector to LLVM IR ----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This provides registration calls for X86Vector dialect to LLVM IR
-// translation.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
-#define MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
-
-namespace mlir {
-
-class DialectRegistry;
-class MLIRContext;
-
-/// Register the X86Vector dialect and the translation from it to the LLVM IR
-/// in the given registry;
-void registerX86VectorDialectTranslation(DialectRegistry &registry);
-
-/// Register the X86Vector dialect and the translation from it in the registry
-/// associated with the given context.
-void registerX86VectorDialectTranslation(MLIRContext &context);
-
-} // namespace mlir
-
-#endif // MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index ac21f1714689..5bb4dcfd60d8 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
@@ -19,6 +20,8 @@
 
 using namespace mlir;
 
+#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
+
 #include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
 
 void x86vector::X86VectorDialect::initialize() {
@@ -42,5 +45,34 @@ LogicalResult x86vector::MaskCompressOp::verify() {
   return success();
 }
 
+SmallVector<Value>
+x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
+  auto loc = getLoc();
+
+  auto opType = getA().getType();
+  Value src;
+  if (getSrc()) {
+    src = getSrc();
+  } else if (getConstantSrc()) {
+    src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
+  } else {
+    auto zeroAttr = rewriter.getZeroAttr(opType);
+    src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
+  }
+
+  return SmallVector<Value>{getA(), src, getK()};
+}
+
+SmallVector<Value>
+x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
+  SmallVector<Value> operands(getOperands());
+  // Dot product of all elements, broadcasted to all elements.
+  Value scale =
+      rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
+  operands.push_back(scale);
+
+  return operands;
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
index f1fb64ffa3f4..c51266afe9e8 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -2,9 +2,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
   AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
 
-  DEPENDS
-  MLIRX86VectorConversionsIncGen
-
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRX86VectorDialect
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index f1fbb39b97fc..c0c7f61f55f8 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -10,7 +10,6 @@
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -19,242 +18,103 @@
 
 using namespace mlir;
 using namespace mlir::x86vector;
 
-/// Extracts the "main" vector element type from the given X86Vector operation.
-template <typename OpTy>
-static Type getSrcVectorElementType(OpTy op) {
-  return cast<VectorType>(op.getSrc().getType()).getElementType();
-}
-template <>
-Type getSrcVectorElementType(Vp2IntersectOp op) {
-  return cast<VectorType>(op.getA().getType()).getElementType();
-}
-
 namespace {
 
-/// Base conversion for AVX512 ops that can be lowered to one of the two
-/// intrinsics based on the bitwidth of their "main" vector element type. This
-/// relies on the to-LLVM-dialect conversion helpers to correctly pack the
-/// results of multi-result intrinsic ops.
-template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
-struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
-  explicit LowerToIntrinsic(const LLVMTypeConverter &converter)
-      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
+/// Replaces an operation with a call to an LLVM intrinsic with the specified
+/// name and operands.
+///
+/// The rewrite performs a simple one-to-one matching between the op and LLVM
+/// intrinsic. For example:
+///
+/// ```mlir
+/// %res = x86vector.op %val : vector<16xf32>
+/// ```
+///
+/// can be converted to
+///
+/// ```mlir
+/// %res = llvm.call_intrinsic "intrinsic"(%val)
+/// ```
+///
+/// The provided operands must be LLVM-compatible.
+///
+/// Upholds a convention that multi-result operations get converted into an
+/// operation returning the LLVM IR structure type, in which case individual
+/// values are first extracted before replacing the original results.
+LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic, + ValueRange operands, + const LLVMTypeConverter &typeConverter, + PatternRewriter &rewriter) { + auto loc = op->getLoc(); - const LLVMTypeConverter &getTypeConverter() const { - return *static_cast( - OpConversionPattern::getTypeConverter()); - } + if (!llvm::all_of(operands, [](Value value) { + return LLVM::isCompatibleType(value.getType()); + })) + return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types."); - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type elementType = getSrcVectorElementType(op); - unsigned bitwidth = elementType.getIntOrFloatBitWidth(); - if (bitwidth == 32) - return LLVM::detail::oneToOneRewrite( - op, Intr32OpTy::getOperationName(), adaptor.getOperands(), - op->getAttrs(), getTypeConverter(), rewriter); - if (bitwidth == 64) - return LLVM::detail::oneToOneRewrite( - op, Intr64OpTy::getOperationName(), adaptor.getOperands(), - op->getAttrs(), getTypeConverter(), rewriter); - return rewriter.notifyMatchFailure( - op, "expected 'src' to be either f32 or f64"); - } -}; + unsigned numResults = op->getNumResults(); + Type resType; + if (numResults != 0) + resType = typeConverter.packOperationResults(op->getResultTypes()); -struct MaskCompressOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto opType = adaptor.getA().getType(); - - Value src; - if (op.getSrc()) { - src = adaptor.getSrc(); - } else if (op.getConstantSrc()) { - src = rewriter.create(op.getLoc(), opType, - op.getConstantSrcAttr()); - } else { - auto zeroAttr = rewriter.getZeroAttr(opType); - src = rewriter.create(op->getLoc(), opType, zeroAttr); - } - - rewriter.replaceOpWithNewOp(op, opType, adaptor.getA(), - src, adaptor.getK()); + auto callIntrOp = + rewriter.create(loc, resType, intrinsic, operands); + // Propagate attributes. + callIntrOp->setAttrs(op->getAttrDictionary()); + if (numResults <= 1) { + // Directly replace the original op. + rewriter.replaceOp(op, callIntrOp); return success(); } -}; -struct DotBF16OpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(DotBF16Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto typeA = dyn_cast(op.getA().getType()); - unsigned elemBitWidth = typeA.getElementTypeBitWidth(); - unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth; - - auto opType = adaptor.getSrc().getType(); - auto opSrc = adaptor.getSrc(); - auto opA = adaptor.getA(); - auto opB = adaptor.getB(); - - switch (opBitWidth) { - case 128: { - rewriter.replaceOpWithNewOp(op, opType, opSrc, opA, - opB); - break; - } - case 256: { - rewriter.replaceOpWithNewOp(op, opType, opSrc, opA, - opB); - break; - } - case 512: { - rewriter.replaceOpWithNewOp(op, opType, opSrc, opA, - opB); - break; - } - default: { - return rewriter.notifyMatchFailure(op, - "unsupported AVX512-BF16 dot variant"); - } - } - - return success(); + // Extract individual results from packed structure and use them as + // replacements. 
+ SmallVector results; + results.reserve(numResults); + Value intrRes = callIntrOp.getResults(); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create(loc, intrRes, i)); } -}; + rewriter.replaceOp(op, results); -struct CvtPackedF32ToBF16Conversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + return success(); +} - LogicalResult - matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto typeA = dyn_cast(op.getA().getType()); - unsigned elemBitWidth = typeA.getElementTypeBitWidth(); - unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth; +/// Generic one-to-one conversion of simply mappable operations into calls +/// to their respective LLVM intrinsics. +struct OneToOneIntrinsicOpConversion + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern< + x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern; - auto opType = op.getDst().getType(); - auto opA = op.getA(); + OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit), + typeConverter(typeConverter) {} - switch (opBitWidth) { - case 256: { - rewriter.replaceOpWithNewOp(op, opType, opA); - break; - } - case 512: { - rewriter.replaceOpWithNewOp(op, opType, opA); - break; - } - default: { - return rewriter.notifyMatchFailure( - op, "unsupported AVX512-BF16 packed f32 to bf16 variant"); - } - } - - return success(); - } -}; - -struct RsqrtOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto opType = adaptor.getA().getType(); - rewriter.replaceOpWithNewOp(op, opType, adaptor.getA()); - return success(); - } -}; - -struct DotOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto opType = adaptor.getA().getType(); - Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); - // Dot product of all elements, broadcasted to all elements. - auto attr = rewriter.getI8IntegerAttr(static_cast(0xff)); - Value scale = - rewriter.create(op.getLoc(), llvmIntType, attr); - rewriter.replaceOpWithNewOp(op, opType, adaptor.getA(), - adaptor.getB(), scale); - return success(); - } -}; - -/// An entry associating the "main" AVX512 op with its instantiations for -/// vectors of 32-bit and 64-bit elements. -template -struct RegEntry { - using MainOp = OpTy; - using Intr32Op = Intr32OpTy; - using Intr64Op = Intr64OpTy; -}; - -/// A container for op association entries facilitating the configuration of -/// dialect conversion. -template -struct RegistryImpl { - /// Registers the patterns specializing the "main" op to one of the - /// "intrinsic" ops depending on elemental type. - static void registerPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns) { - patterns - .add...>(converter); + LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op, + PatternRewriter &rewriter) const override { + return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()), + op.getIntrinsicOperands(rewriter), typeConverter, + rewriter); } - /// Configures the conversion target to lower out "main" ops. 
-  static void configureTarget(LLVMConversionTarget &target) {
-    target.addIllegalOp<typename Args::MainOp...>();
-    target.addLegalOp<typename Args::Intr32Op...>();
-    target.addLegalOp<typename Args::Intr64Op...>();
-  }
+private:
+  const LLVMTypeConverter &typeConverter;
 };
 
-using Registry = RegistryImpl<
-    RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
-    RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
-    RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
 } // namespace
 
 /// Populate the given list with patterns that convert from X86Vector to LLVM.
 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  Registry::registerPatterns(converter, patterns);
-  patterns
-      .add<MaskCompressOpConversion, DotBF16OpConversion,
-           CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
-          converter);
+  patterns.add<OneToOneIntrinsicOpConversion>(converter);
 }
 
 void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
-  Registry::configureTarget(target);
-  target.addLegalOp<MaskCompressIntrOp>();
-  target.addIllegalOp<MaskCompressOp>();
-  target.addLegalOp<DotBF16Ps128IntrOp>();
-  target.addLegalOp<DotBF16Ps256IntrOp>();
-  target.addLegalOp<DotBF16Ps512IntrOp>();
-  target.addIllegalOp<DotBF16Op>();
-  target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
-  target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
-  target.addIllegalOp<CvtPackedF32ToBF16Op>();
-  target.addLegalOp<RsqrtIntrOp>();
-  target.addIllegalOp<RsqrtOp>();
-  target.addLegalOp<DotIntrOp>();
-  target.addIllegalOp<DotOp>();
+  target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
+                      Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
+                      RsqrtOp, DotOp>();
 }
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index f59f1d51093e..4ace3964e8ae 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -54,7 +54,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
   MLIRAMXToLLVMIRTranslation
   MLIRBuiltinToLLVMIRTranslation
   MLIRGPUToLLVMIRTranslation
-  MLIRX86VectorToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation
   MLIRNVVMToLLVMIRTranslation
   MLIROpenACCToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index a88e8b1fd833..40df6e3f4b64 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -11,4 +11,3 @@ add_subdirectory(OpenMP)
 add_subdirectory(ROCDL)
 add_subdirectory(SPIRV)
 add_subdirectory(VCIX)
-add_subdirectory(X86Vector)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/CMakeLists.txt
deleted file mode 100644
index 3910982c0cdc..000000000000
--- a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/CMakeLists.txt
+++ /dev/null
@@ -1,16 +0,0 @@
-add_mlir_translation_library(MLIRX86VectorToLLVMIRTranslation
-  X86VectorToLLVMIRTranslation.cpp
-
-  DEPENDS
-  MLIRX86VectorConversionsIncGen
-
-  LINK_COMPONENTS
-  Core
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  MLIRX86VectorDialect
-  MLIRLLVMDialect
-  MLIRSupport
-  MLIRTargetLLVMIRExport
-  )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
deleted file mode 100644
index fa5f61420ee8..000000000000
--- a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-//===- X86VectorToLLVMIRTranslation.cpp - Translate X86Vector to LLVM IR---===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a translation between the MLIR X86Vector dialect and
-// LLVM IR.
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" -#include "mlir/Dialect/X86Vector/X86VectorDialect.h" -#include "mlir/IR/Operation.h" -#include "mlir/Target/LLVMIR/ModuleTranslation.h" - -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsX86.h" - -using namespace mlir; -using namespace mlir::LLVM; - -namespace { -/// Implementation of the dialect interface that converts operations belonging -/// to the X86Vector dialect to LLVM IR. -class X86VectorDialectLLVMIRTranslationInterface - : public LLVMTranslationDialectInterface { -public: - using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; - - /// Translates the given operation to LLVM IR using the provided IR builder - /// and saving the state in `moduleTranslation`. - LogicalResult - convertOperation(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const final { - Operation &opInst = *op; -#include "mlir/Dialect/X86Vector/X86VectorConversions.inc" - - return failure(); - } -}; -} // namespace - -void mlir::registerX86VectorDialectTranslation(DialectRegistry ®istry) { - registry.insert(); - registry.addExtension( - +[](MLIRContext *ctx, x86vector::X86VectorDialect *dialect) { - dialect->addInterfaces(); - }); -} - -void mlir::registerX86VectorDialectTranslation(MLIRContext &context) { - DialectRegistry registry; - registerX86VectorDialectTranslation(registry); - context.appendDialectRegistry(registry); -} diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir index 59be7dd75b3b..df0be7bce83b 100644 --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -1,33 +1,40 @@ // RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" | mlir-opt | FileCheck %s // CHECK-LABEL: func @avx512_mask_rndscale -func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) +func.func @avx512_mask_rndscale( + %src: vector<16xf32>, %a: vector<16xf32>, %b: vector<8xf64>, + %imm_i16: i16, %imm_i8: i8, %scale_k_i16: i16, %scale_k_i8: i8) -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>) { - // CHECK: x86vector.avx512.intr.mask.rndscale.ps.512 - %0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32> - // CHECK: x86vector.avx512.intr.mask.rndscale.pd.512 - %1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64> + %rnd_k = arith.constant 15 : i32 + %rnd = arith.constant 42 : i32 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512" + %0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32> + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512" + %1 = x86vector.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64> - // CHECK: x86vector.avx512.intr.mask.scalef.ps.512 - %2 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> - // CHECK: x86vector.avx512.intr.mask.scalef.pd.512 - %3 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512" + %2 = x86vector.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32> + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512" + %3 = x86vector.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64> // Keep results alive. 
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64> } // CHECK-LABEL: func @avx512_mask_compress -func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>, - %k2: vector<8xi1>, %a2: vector<8xi64>) +func.func @avx512_mask_compress( + %k1: vector<16xi1>, %a1: vector<16xf32>, %k2: vector<8xi1>, %a2: vector<8xi64>) -> (vector<16xf32>, vector<16xf32>, vector<8xi64>) { - // CHECK: x86vector.avx512.intr.mask.compress + // CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>) + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress" %0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32> - // CHECK: x86vector.avx512.intr.mask.compress - %1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> - // CHECK: x86vector.avx512.intr.mask.compress + // CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>) + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress" + %1 = x86vector.avx512.mask.compress %k1, %a1 + {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress" %2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> } @@ -36,9 +43,9 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>, func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>) { - // CHECK: x86vector.avx512.intr.vp2intersect.d.512 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512" %0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32> - // CHECK: x86vector.avx512.intr.vp2intersect.q.512 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512" %2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> } @@ -47,7 +54,7 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>) -> (vector<4xf32>) { - // CHECK: x86vector.avx512.intr.dpbf16ps.128 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128" %0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32> return %0 : vector<4xf32> } @@ -56,7 +63,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>, func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>) -> (vector<8xf32>) { - // CHECK: x86vector.avx512.intr.dpbf16ps.256 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256" %0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32> return %0 : vector<8xf32> } @@ -65,7 +72,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>, func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>) -> (vector<16xf32>) { - // CHECK: x86vector.avx512.intr.dpbf16ps.512 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512" %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> return %0 : vector<16xf32> } @@ -74,7 +81,7 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>, func.func @avx512bf16_cvt_packed_f32_to_bf16_256( %a: vector<8xf32>) -> (vector<8xbf16>) { - // CHECK: x86vector.avx512.intr.cvtneps2bf16.256 + // CHECK: llvm.call_intrinsic 
"llvm.x86.avx512bf16.cvtneps2bf16.256" %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16> return %0 : vector<8xbf16> } @@ -83,7 +90,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256( func.func @avx512bf16_cvt_packed_f32_to_bf16_512( %a: vector<16xf32>) -> (vector<16xbf16>) { - // CHECK: x86vector.avx512.intr.cvtneps2bf16.512 + // CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512" %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16> return %0 : vector<16xbf16> } @@ -91,7 +98,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512( // CHECK-LABEL: func @avx_rsqrt func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) { - // CHECK: x86vector.avx.intr.rsqrt.ps.256 + // CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256" %0 = x86vector.avx.rsqrt %a : vector<8xf32> return %0 : vector<8xf32> } @@ -99,7 +106,8 @@ func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) // CHECK-LABEL: func @avx_dot func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>) { - // CHECK: x86vector.avx.intr.dp.ps.256 + // CHECK: llvm.mlir.constant(-1 : i8) + // CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256" %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> return %0 : vector<8xf32> } diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir index db1c10cd5cd3..85dad36334b1 100644 --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -1,133 +1,128 @@ -// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s +// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm \ +// RUN: | mlir-translate --mlir-to-llvmir \ +// RUN: | FileCheck %s // CHECK-LABEL: define <16 x float> @LLVM_x86_avx512_mask_ps_512 -llvm.func @LLVM_x86_avx512_mask_ps_512(%a: vector<16 x f32>, - %c: i16) - -> (vector<16 x f32>) +func.func @LLVM_x86_avx512_mask_ps_512( + %src: vector<16xf32>, %a: vector<16xf32>, %b: vector<16xf32>, + %imm: i16, %scale_k: i16) + -> (vector<16xf32>) { - %b = llvm.mlir.constant(42 : i32) : i32 + %rnd_k = arith.constant 15 : i32 + %rnd = arith.constant 42 : i32 // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> - %0 = "x86vector.avx512.intr.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) : - (vector<16 x f32>, i32, vector<16 x f32>, i16, i32) -> vector<16 x f32> + %0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32> // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float> - %1 = "x86vector.avx512.intr.mask.scalef.ps.512"(%a, %a, %a, %c, %b) : - (vector<16 x f32>, vector<16 x f32>, vector<16 x f32>, i16, i32) -> vector<16 x f32> - llvm.return %1: vector<16 x f32> + %1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32> + return %1 : vector<16xf32> } // CHECK-LABEL: define <8 x double> @LLVM_x86_avx512_mask_pd_512 -llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>, - %c: i8) +func.func @LLVM_x86_avx512_mask_pd_512( + %src: vector<8xf64>, %a: vector<8xf64>, %b: vector<8xf64>, + %imm: i8, %scale_k: i8) -> (vector<8xf64>) { - %b = llvm.mlir.constant(42 : i32) : i32 + %rnd_k = arith.constant 15 : i32 + %rnd = arith.constant 42 : i32 // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> - %0 = "x86vector.avx512.intr.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) : - (vector<8xf64>, i32, vector<8xf64>, i8, i32) -> vector<8xf64> + %0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : 
vector<8xf64> // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> - %1 = "x86vector.avx512.intr.mask.scalef.pd.512"(%a, %a, %a, %c, %b) : - (vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64> - llvm.return %1: vector<8xf64> + %1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64> + return %1 : vector<8xf64> } // CHECK-LABEL: define <16 x float> @LLVM_x86_mask_compress -llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>) +func.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>) -> vector<16xf32> { // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32( - %0 = "x86vector.avx512.intr.mask.compress"(%a, %a, %k) : - (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32> - llvm.return %0 : vector<16xf32> + %0 = x86vector.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32> + return %0 : vector<16xf32> } // CHECK-LABEL: define { <16 x i1>, <16 x i1> } @LLVM_x86_vp2intersect_d_512 -llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>) - -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> +func.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>) + -> (vector<16xi1>, vector<16xi1>) { // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32> - %0 = "x86vector.avx512.intr.vp2intersect.d.512"(%a, %b) : - (vector<16xi32>, vector<16xi32>) -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> - llvm.return %0 : !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> + %0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<16xi32> + return %0, %1 : vector<16xi1>, vector<16xi1> } // CHECK-LABEL: define { <8 x i1>, <8 x i1> } @LLVM_x86_vp2intersect_q_512 -llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>) - -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> +func.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>) + -> (vector<8 x i1>, vector<8 x i1>) { // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64> - %0 = "x86vector.avx512.intr.vp2intersect.q.512"(%a, %b) : - (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> - llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> + %0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<8xi64> + return %0, %1 : vector<8 x i1>, vector<8 x i1> } // CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128 -llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128( +func.func @LLVM_x86_avx512bf16_dpbf16ps_128( %src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16> ) -> vector<4xf32> { // CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128( - %0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b) - : (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32> - llvm.return %0 : vector<4xf32> + %0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32> + return %0 : vector<4xf32> } // CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256 -llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256( +func.func @LLVM_x86_avx512bf16_dpbf16ps_256( %src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16> ) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256( - %0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b) - : (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32> - llvm.return %0 : vector<8xf32> + %0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32> + return %0 : 
vector<8xf32> } // CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512 -llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512( +func.func @LLVM_x86_avx512bf16_dpbf16ps_512( %src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16> ) -> vector<16xf32> { // CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512( - %0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b) - : (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32> - llvm.return %0 : vector<16xf32> + %0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32> + return %0 : vector<16xf32> } // CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256 -llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256( +func.func @LLVM_x86_avx512bf16_cvtneps2bf16_256( %a: vector<8xf32>) -> vector<8xbf16> { // CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256( - %0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a) - : (vector<8xf32>) -> vector<8xbf16> - llvm.return %0 : vector<8xbf16> + %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a + : vector<8xf32> -> vector<8xbf16> + return %0 : vector<8xbf16> } // CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512 -llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512( +func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512( %a: vector<16xf32>) -> vector<16xbf16> { // CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512( - %0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a) - : (vector<16xf32>) -> vector<16xbf16> - llvm.return %0 : vector<16xbf16> + %0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a + : vector<16xf32> -> vector<16xbf16> + return %0 : vector<16xbf16> } // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256 -llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32> +func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float> - %0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>) - llvm.return %0 : vector<8xf32> + %0 = x86vector.avx.rsqrt %a : vector<8xf32> + return %0 : vector<8xf32> } // CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256 -llvm.func @LLVM_x86_avx_dp_ps_256( +func.func @LLVM_x86_avx_dp_ps_256( %a: vector<8xf32>, %b: vector<8xf32> ) -> vector<8xf32> { // CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256( - %c = llvm.mlir.constant(-1 : i8) : i8 - %1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32> - llvm.return %1 : vector<8xf32> + %0 = x86vector.avx.intr.dot %a, %b : vector<8xf32> + return %0 : vector<8xf32> } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index c08508e166fa..e69fc0e50a5b 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2298,14 +2298,6 @@ cc_library( ], ) -gentbl_cc_library( - name = "X86VectorConversionIncGen", - tbl_outs = {"include/mlir/Dialect/X86Vector/X86VectorConversions.inc": ["-gen-llvmir-conversions"]}, - tblgen = ":mlir-tblgen", - td_file = "include/mlir/Dialect/X86Vector/X86Vector.td", - deps = [":X86VectorTdFiles"], -) - ##---------------------------------------------------------------------------## # IRDL dialect. ##---------------------------------------------------------------------------##
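
With the dedicated X86Vector-to-LLVM-IR translation library removed, x86vector ops now reach LLVM IR through the legalization patterns above followed by the regular LLVM dialect translation, as the updated RUN lines exercise (`mlir-opt -convert-vector-to-llvm="enable-x86vector" --convert-to-llvm | mlir-translate --mlir-to-llvmir`). Below is a minimal sketch of driving the same hooks from a custom C++ pipeline; only `populateX86VectorLegalizeForLLVMExportPatterns` and `configureX86VectorLegalizeForExportTarget` come from this patch, while the wrapper function and the exact header paths are illustrative assumptions.

```cpp
// Sketch only: lower x86vector ops to llvm.call_intrinsic from a custom
// driver. The two populate/configure hooks are the ones touched by this
// patch; the function name and header paths are assumptions.
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult lowerX86VectorOps(Operation *root) {
  MLIRContext *ctx = root->getContext();
  LLVMTypeConverter converter(ctx);

  RewritePatternSet patterns(ctx);
  // Adds the interface-driven OneToOneIntrinsicOpConversion pattern.
  populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);

  LLVMConversionTarget target(*ctx);
  // Marks the x86vector ops implementing OneToOneIntrinsicOpInterface as
  // illegal so they must be rewritten into llvm.call_intrinsic.
  configureX86VectorLegalizeForExportTarget(target);

  // In a real pipeline this runs alongside the other to-LLVM patterns,
  // e.g. as part of -convert-vector-to-llvm="enable-x86vector".
  return applyPartialConversion(root, target, std::move(patterns));
}
```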