[mlir][x86vector] Simplify intrinsic generation (#133692)
Replaces separate x86vector named intrinsic operations with direct calls
to LLVM intrinsic functions.
This rework reduces the number of named ops leaving only high-level MLIR
equivalents of whole intrinsic classes e.g., variants of AVX512 dot on
BF16 inputs. Dialect conversion applies LLVM intrinsic name mangling
further simplifying lowering logic.
The separate conversion step translating x86vector intrinsics into LLVM
IR is also eliminated. Instead, this step is now performed by the
existing llvm dialect infrastructure.
RFC:
https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
add_mlir_dialect(X86Vector x86vector)
|
||||
add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS X86Vector.td)
|
||||
mlir_tablegen(X86VectorConversions.inc -gen-llvmir-conversions)
|
||||
add_public_tablegen_target(MLIRX86VectorConversionsIncGen)
|
||||
add_mlir_interface(X86VectorInterfaces)
|
||||
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Dialect/X86Vector/X86VectorInterfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// X86Vector dialect definition
|
||||
@@ -34,30 +35,12 @@ def X86Vector_Dialect : Dialect {
|
||||
class AVX512_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
|
||||
|
||||
// Intrinsic operation used during lowering to LLVM IR.
|
||||
class AVX512_IntrOp<string mnemonic, int numResults,
|
||||
list<Trait> traits = [],
|
||||
string extension = ""> :
|
||||
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
|
||||
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
|
||||
[], [], traits, numResults>;
|
||||
|
||||
// Defined by first result overload. May have to be extended for other
|
||||
// instructions in the future.
|
||||
class AVX512_IntrOverloadedOp<string mnemonic,
|
||||
list<Trait> traits = [],
|
||||
string extension = ""> :
|
||||
LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
|
||||
!subst("EXT", extension, "x86_avx512EXT_") # !subst(".", "_", mnemonic),
|
||||
/*list<int> overloadedResults=*/[0],
|
||||
/*list<int> overloadedOperands=*/[],
|
||||
traits, /*numResults=*/1>;
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// MaskCompressOp
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
// TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
|
||||
// then be removed from assemblyFormat.
|
||||
AllTypesMatch<["a", "dst"]>,
|
||||
@@ -91,21 +74,17 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
|
||||
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
|
||||
" `:` type($dst) (`,` type($src)^)?";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
|
||||
Pure,
|
||||
AllTypesMatch<["a", "src", "res"]>,
|
||||
TypesMatchWith<"`k` has the same number of bits as elements in `res`",
|
||||
"res", "k",
|
||||
"VectorType::get({::llvm::cast<VectorType>($_self).getShape()[0]}, "
|
||||
"IntegerType::get($_self.getContext(), 1))">]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[16, 8],
|
||||
[F32, I32, F64, I64]>:$a,
|
||||
VectorOfLengthAndType<[16, 8],
|
||||
[F32, I32, F64, I64]>:$src,
|
||||
VectorOfLengthAndType<[16, 8],
|
||||
[I1]>:$k);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
// Call the baseline overloaded intrisic.
|
||||
// Final overload name mangling is resolved by the created function call.
|
||||
return "llvm.x86.avx512.mask.compress";
|
||||
}
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
|
||||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
@@ -113,6 +92,7 @@ def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
AllTypesMatch<["src", "a", "dst"]>,
|
||||
TypesMatchWith<"imm has the same number of bits as elements in dst",
|
||||
"dst", "imm",
|
||||
@@ -142,26 +122,20 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [Pure,
|
||||
let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)";
|
||||
}
|
||||
|
||||
def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [
|
||||
Pure,
|
||||
AllTypesMatch<["src", "a", "res"]>]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
|
||||
I32:$k,
|
||||
VectorOfLengthAndType<[16], [F32]>:$a,
|
||||
I16:$imm,
|
||||
LLVM_Type:$rounding);
|
||||
}
|
||||
|
||||
def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
|
||||
Pure,
|
||||
AllTypesMatch<["src", "a", "res"]>]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
|
||||
I32:$k,
|
||||
VectorOfLengthAndType<[8], [F64]>:$a,
|
||||
I8:$imm,
|
||||
LLVM_Type:$rounding);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.avx512.mask.rndscale";
|
||||
VectorType vecType = getSrc().getType();
|
||||
Type elemType = vecType.getElementType();
|
||||
intr += ".";
|
||||
intr += elemType.isF32() ? "ps" : "pd";
|
||||
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
|
||||
intr += "." + std::to_string(opBitWidth);
|
||||
return intr;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
@@ -169,6 +143,7 @@ def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
AllTypesMatch<["src", "a", "b", "dst"]>,
|
||||
TypesMatchWith<"k has the same number of bits as elements in dst",
|
||||
"dst", "k",
|
||||
@@ -199,26 +174,20 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [Pure,
|
||||
// Fully specified by traits.
|
||||
let assemblyFormat =
|
||||
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
|
||||
}
|
||||
|
||||
def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
|
||||
Pure,
|
||||
AllTypesMatch<["src", "a", "b", "res"]>]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
|
||||
VectorOfLengthAndType<[16], [F32]>:$a,
|
||||
VectorOfLengthAndType<[16], [F32]>:$b,
|
||||
I16:$k,
|
||||
LLVM_Type:$rounding);
|
||||
}
|
||||
|
||||
def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
|
||||
Pure,
|
||||
AllTypesMatch<["src", "a", "b", "res"]>]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
|
||||
VectorOfLengthAndType<[8], [F64]>:$a,
|
||||
VectorOfLengthAndType<[8], [F64]>:$b,
|
||||
I8:$k,
|
||||
LLVM_Type:$rounding);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.avx512.mask.scalef";
|
||||
VectorType vecType = getSrc().getType();
|
||||
Type elemType = vecType.getElementType();
|
||||
intr += ".";
|
||||
intr += elemType.isF32() ? "ps" : "pd";
|
||||
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
|
||||
intr += "." + std::to_string(opBitWidth);
|
||||
return intr;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
@@ -226,6 +195,7 @@ def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
AllTypesMatch<["a", "b"]>,
|
||||
TypesMatchWith<"k1 has the same number of bits as elements in a",
|
||||
"a", "k1",
|
||||
@@ -260,18 +230,20 @@ def Vp2IntersectOp : AVX512_Op<"vp2intersect", [Pure,
|
||||
);
|
||||
let assemblyFormat =
|
||||
"$a `,` $b attr-dict `:` type($a)";
|
||||
}
|
||||
|
||||
def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [
|
||||
Pure]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a,
|
||||
VectorOfLengthAndType<[16], [I32]>:$b);
|
||||
}
|
||||
|
||||
def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
|
||||
Pure]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a,
|
||||
VectorOfLengthAndType<[8], [I64]>:$b);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.avx512.vp2intersect";
|
||||
VectorType vecType = getA().getType();
|
||||
Type elemType = vecType.getElementType();
|
||||
intr += ".";
|
||||
intr += elemType.isInteger(32) ? "d" : "q";
|
||||
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
|
||||
intr += "." + std::to_string(opBitWidth);
|
||||
return intr;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
@@ -279,6 +251,7 @@ def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def DotBF16Op : AVX512_Op<"dot", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
AllTypesMatch<["a", "b"]>,
|
||||
AllTypesMatch<["src", "dst"]>,
|
||||
TypesMatchWith<"`a` has twice an many elements as `src`",
|
||||
@@ -299,7 +272,7 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||
%dst = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||
```
|
||||
}];
|
||||
let arguments = (ins VectorOfLengthAndType<[4, 8, 16], [F32]>:$src,
|
||||
@@ -309,36 +282,17 @@ def DotBF16Op : AVX512_Op<"dot", [Pure,
|
||||
let results = (outs VectorOfLengthAndType<[4, 8, 16], [F32]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$src `,` $a `,` $b attr-dict `:` type($a) `->` type($src)";
|
||||
}
|
||||
|
||||
def DotBF16Ps128IntrOp : AVX512_IntrOp<"dpbf16ps.128", 1, [Pure,
|
||||
AllTypesMatch<["a", "b"]>,
|
||||
AllTypesMatch<["src", "res"]>],
|
||||
/*extension=*/"bf16"> {
|
||||
let arguments = (ins VectorOfLengthAndType<[4], [F32]>:$src,
|
||||
VectorOfLengthAndType<[8], [BF16]>:$a,
|
||||
VectorOfLengthAndType<[8], [BF16]>:$b);
|
||||
let results = (outs VectorOfLengthAndType<[4], [F32]>:$res);
|
||||
}
|
||||
|
||||
def DotBF16Ps256IntrOp : AVX512_IntrOp<"dpbf16ps.256", 1, [Pure,
|
||||
AllTypesMatch<["a", "b"]>,
|
||||
AllTypesMatch<["src", "res"]>],
|
||||
/*extension=*/"bf16"> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$src,
|
||||
VectorOfLengthAndType<[16], [BF16]>:$a,
|
||||
VectorOfLengthAndType<[16], [BF16]>:$b);
|
||||
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
|
||||
}
|
||||
|
||||
def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
|
||||
AllTypesMatch<["a", "b"]>,
|
||||
AllTypesMatch<["src", "res"]>],
|
||||
/*extension=*/"bf16"> {
|
||||
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
|
||||
VectorOfLengthAndType<[32], [BF16]>:$a,
|
||||
VectorOfLengthAndType<[32], [BF16]>:$b);
|
||||
let results = (outs VectorOfLengthAndType<[16], [F32]>:$res);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.avx512bf16.dpbf16ps";
|
||||
VectorType vecType = getSrc().getType();
|
||||
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
|
||||
intr += "." + std::to_string(opBitWidth);
|
||||
return intr;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
@@ -346,6 +300,7 @@ def DotBF16Ps512IntrOp : AVX512_IntrOp<"dpbf16ps.512", 1, [Pure,
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
AllElementCountsMatch<["a", "dst"]>]> {
|
||||
let summary = "Convert packed F32 to packed BF16 Data.";
|
||||
let description = [{
|
||||
@@ -367,18 +322,17 @@ def CvtPackedF32ToBF16Op : AVX512_Op<"cvt.packed.f32_to_bf16", [Pure,
|
||||
let results = (outs VectorOfLengthAndType<[8, 16], [BF16]>:$dst);
|
||||
let assemblyFormat =
|
||||
"$a attr-dict `:` type($a) `->` type($dst)";
|
||||
}
|
||||
|
||||
def CvtNeF32ToBF16Ps256IntrOp : AVX512_IntrOp<"cvtneps2bf16.256", 1, [Pure],
|
||||
/*extension=*/"bf16"> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
|
||||
let results = (outs VectorOfLengthAndType<[8], [BF16]>:$res);
|
||||
}
|
||||
|
||||
def CvtNeF32ToBF16Ps512IntrOp : AVX512_IntrOp<"cvtneps2bf16.512", 1, [Pure],
|
||||
/*extension=*/"bf16"> {
|
||||
let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$a);
|
||||
let results = (outs VectorOfLengthAndType<[16], [BF16]>:$res);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
std::string intr = "llvm.x86.avx512bf16.cvtneps2bf16";
|
||||
VectorType vecType = getA().getType();
|
||||
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
|
||||
intr += "." + std::to_string(opBitWidth);
|
||||
return intr;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -395,33 +349,32 @@ class AVX_Op<string mnemonic, list<Trait> traits = []> :
|
||||
class AVX_LowOp<string mnemonic, list<Trait> traits = []> :
|
||||
Op<X86Vector_Dialect, "avx.intr." # mnemonic, traits> {}
|
||||
|
||||
// Intrinsic operation used during lowering to LLVM IR.
|
||||
class AVX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
|
||||
LLVM_IntrOpBase<X86Vector_Dialect, "avx.intr." # mnemonic,
|
||||
"x86_avx_" # !subst(".", "_", mnemonic),
|
||||
[], [], traits, numResults>;
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// AVX Rsqrt
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def RsqrtOp : AVX_Op<"rsqrt", [Pure, SameOperandsAndResultType]> {
|
||||
def RsqrtOp : AVX_Op<"rsqrt", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Rsqrt";
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
|
||||
let results = (outs VectorOfLengthAndType<[8], [F32]>:$b);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a)";
|
||||
}
|
||||
|
||||
def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [Pure,
|
||||
SameOperandsAndResultType]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
return "llvm.x86.avx.rsqrt.ps.256";
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// AVX Dot
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
|
||||
def DotOp : AVX_LowOp<"dot", [Pure,
|
||||
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Dot";
|
||||
let description = [{
|
||||
Computes the 4-way dot products of the lower and higher parts of the source
|
||||
@@ -443,13 +396,16 @@ def DotOp : AVX_LowOp<"dot", [Pure, SameOperandsAndResultType]> {
|
||||
VectorOfLengthAndType<[8], [F32]>:$b);
|
||||
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
|
||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($res)";
|
||||
}
|
||||
|
||||
def DotIntrOp : AVX_IntrOp<"dp.ps.256", 1, [Pure,
|
||||
AllTypesMatch<["a", "b", "res"]>]> {
|
||||
let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a,
|
||||
VectorOfLengthAndType<[8], [F32]>:$b, I8:$c);
|
||||
let results = (outs VectorOfLengthAndType<[8], [F32]>:$res);
|
||||
let extraClassDefinition = [{
|
||||
std::string $cppClass::getIntrinsicName() {
|
||||
// Only one variant is supported right now - no extra mangling.
|
||||
return "llvm.x86.avx.dp.ps.256";
|
||||
}
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // X86VECTOR_OPS
|
||||
|
||||
@@ -18,9 +18,13 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
/// Include the generated interface declarations.
|
||||
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.h.inc"
|
||||
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
||||
68
mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
Normal file
68
mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
Normal file
@@ -0,0 +1,68 @@
|
||||
//===- X86VectorInterfaces.td - X86Vector interfaces -------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines interfaces for the X86Vector dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef X86VECTOR_INTERFACES
|
||||
#define X86VECTOR_INTERFACES
|
||||
|
||||
include "mlir/IR/Interfaces.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// One-to-One Intrinsic Interface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
|
||||
let description = [{
|
||||
Interface for 1-to-1 conversion of an operation into LLVM intrinsics.
|
||||
|
||||
An op implementing this interface can be simply replaced by a call
|
||||
to a matching intrinsic function.
|
||||
The op must ensure that the combinations of their arguments and results
|
||||
have valid intrinsic counterparts.
|
||||
|
||||
For example, an operation supporting different vector widths:
|
||||
```mlir
|
||||
%res_v8 = x86vector.op %value_v8 : vector<8xf32>
|
||||
%res_v16 = x86vector.op %value_v16 : vector<16xf32>
|
||||
```
|
||||
can be converted to the following intrinsic calls:
|
||||
```mlir
|
||||
%res_v8 = llvm.call_intrinsic "llvm.x86.op.intr.256"(%value_v8)
|
||||
%res_v16 = llvm.call_intrinsic "llvm.x86.op.intr.512"(%value_v16)
|
||||
```
|
||||
}];
|
||||
let cppNamespace = "::mlir::x86vector";
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Returns mangled LLVM intrinsic function name matching the operation
|
||||
variant.
|
||||
}],
|
||||
/*retType=*/"std::string",
|
||||
/*methodName=*/"getIntrinsicName"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Returns operands for a corresponding LLVM intrinsic.
|
||||
|
||||
Additional operations may be created to facilitate mapping
|
||||
between the source operands and the target intrinsic.
|
||||
}],
|
||||
/*retType=*/"SmallVector<Value>",
|
||||
/*methodName=*/"getIntrinsicOperands",
|
||||
/*args=*/(ins "::mlir::RewriterBase &":$rewriter),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
|
||||
>,
|
||||
];
|
||||
}
|
||||
|
||||
#endif // X86VECTOR_INTERFACES
|
||||
@@ -29,7 +29,6 @@
|
||||
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/SPIRV/SPIRVToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/VCIX/VCIXToLLVMIRTranslation.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h"
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
@@ -50,7 +49,6 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) {
|
||||
registerROCDLDialectTranslation(registry);
|
||||
registerSPIRVDialectTranslation(registry);
|
||||
registerVCIXDialectTranslation(registry);
|
||||
registerX86VectorDialectTranslation(registry);
|
||||
|
||||
// Extension required for translating GPU offloading Ops.
|
||||
gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
//===- X86VectorToLLVMIRTranslation.h - X86Vector to LLVM IR ----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This provides registration calls for X86Vector dialect to LLVM IR
|
||||
// translation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
|
||||
#define MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class DialectRegistry;
|
||||
class MLIRContext;
|
||||
|
||||
/// Register the X86Vector dialect and the translation from it to the LLVM IR
|
||||
/// in the given registry;
|
||||
void registerX86VectorDialectTranslation(DialectRegistry ®istry);
|
||||
|
||||
/// Register the X86Vector dialect and the translation from it in the registry
|
||||
/// associated with the given context.
|
||||
void registerX86VectorDialectTranslation(MLIRContext &context);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
|
||||
@@ -11,6 +11,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
@@ -19,6 +20,8 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
|
||||
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
|
||||
|
||||
void x86vector::X86VectorDialect::initialize() {
|
||||
@@ -42,5 +45,34 @@ LogicalResult x86vector::MaskCompressOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
|
||||
auto loc = getLoc();
|
||||
|
||||
auto opType = getA().getType();
|
||||
Value src;
|
||||
if (getSrc()) {
|
||||
src = getSrc();
|
||||
} else if (getConstantSrc()) {
|
||||
src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
|
||||
} else {
|
||||
auto zeroAttr = rewriter.getZeroAttr(opType);
|
||||
src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
|
||||
}
|
||||
|
||||
return SmallVector<Value>{getA(), src, getK()};
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
|
||||
SmallVector<Value> operands(getOperands());
|
||||
// Dot product of all elements, broadcasted to all elements.
|
||||
Value scale =
|
||||
rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
|
||||
operands.push_back(scale);
|
||||
|
||||
return operands;
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
|
||||
|
||||
@@ -2,9 +2,6 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
|
||||
AVXTranspose.cpp
|
||||
LegalizeForLLVMExport.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRX86VectorConversionsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithDialect
|
||||
MLIRX86VectorDialect
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -19,242 +18,103 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::x86vector;
|
||||
|
||||
/// Extracts the "main" vector element type from the given X86Vector operation.
|
||||
template <typename OpTy>
|
||||
static Type getSrcVectorElementType(OpTy op) {
|
||||
return cast<VectorType>(op.getSrc().getType()).getElementType();
|
||||
}
|
||||
template <>
|
||||
Type getSrcVectorElementType(Vp2IntersectOp op) {
|
||||
return cast<VectorType>(op.getA().getType()).getElementType();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Base conversion for AVX512 ops that can be lowered to one of the two
|
||||
/// intrinsics based on the bitwidth of their "main" vector element type. This
|
||||
/// relies on the to-LLVM-dialect conversion helpers to correctly pack the
|
||||
/// results of multi-result intrinsic ops.
|
||||
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
|
||||
struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
|
||||
explicit LowerToIntrinsic(const LLVMTypeConverter &converter)
|
||||
: OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
|
||||
/// Replaces an operation with a call to an LLVM intrinsic with the specified
|
||||
/// name and operands.
|
||||
///
|
||||
/// The rewrite performs a simple one-to-one matching between the op and LLVM
|
||||
/// intrinsic. For example:
|
||||
///
|
||||
/// ```mlir
|
||||
/// %res = x86vector.op %val : vector<16xf32>
|
||||
/// ```
|
||||
///
|
||||
/// can be converted to
|
||||
///
|
||||
/// ```mlir
|
||||
/// %res = llvm.call_intrinsic "intrinsic"(%val)
|
||||
/// ```
|
||||
///
|
||||
/// The provided operands must be LLVM-compatible.
|
||||
///
|
||||
/// Upholds a convention that multi-result operations get converted into an
|
||||
/// operation returning the LLVM IR structure type, in which case individual
|
||||
/// values are first extracted before replacing the original results.
|
||||
LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
|
||||
ValueRange operands,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
PatternRewriter &rewriter) {
|
||||
auto loc = op->getLoc();
|
||||
|
||||
const LLVMTypeConverter &getTypeConverter() const {
|
||||
return *static_cast<const LLVMTypeConverter *>(
|
||||
OpConversionPattern<OpTy>::getTypeConverter());
|
||||
}
|
||||
if (!llvm::all_of(operands, [](Value value) {
|
||||
return LLVM::isCompatibleType(value.getType());
|
||||
}))
|
||||
return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types.");
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type elementType = getSrcVectorElementType<OpTy>(op);
|
||||
unsigned bitwidth = elementType.getIntOrFloatBitWidth();
|
||||
if (bitwidth == 32)
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
|
||||
op->getAttrs(), getTypeConverter(), rewriter);
|
||||
if (bitwidth == 64)
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
|
||||
op->getAttrs(), getTypeConverter(), rewriter);
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expected 'src' to be either f32 or f64");
|
||||
}
|
||||
};
|
||||
unsigned numResults = op->getNumResults();
|
||||
Type resType;
|
||||
if (numResults != 0)
|
||||
resType = typeConverter.packOperationResults(op->getResultTypes());
|
||||
|
||||
struct MaskCompressOpConversion
|
||||
: public ConvertOpToLLVMPattern<MaskCompressOp> {
|
||||
using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto opType = adaptor.getA().getType();
|
||||
|
||||
Value src;
|
||||
if (op.getSrc()) {
|
||||
src = adaptor.getSrc();
|
||||
} else if (op.getConstantSrc()) {
|
||||
src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
|
||||
op.getConstantSrcAttr());
|
||||
} else {
|
||||
auto zeroAttr = rewriter.getZeroAttr(opType);
|
||||
src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
|
||||
src, adaptor.getK());
|
||||
auto callIntrOp =
|
||||
rewriter.create<LLVM::CallIntrinsicOp>(loc, resType, intrinsic, operands);
|
||||
// Propagate attributes.
|
||||
callIntrOp->setAttrs(op->getAttrDictionary());
|
||||
|
||||
if (numResults <= 1) {
|
||||
// Directly replace the original op.
|
||||
rewriter.replaceOp(op, callIntrOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DotBF16OpConversion : public ConvertOpToLLVMPattern<DotBF16Op> {
|
||||
using ConvertOpToLLVMPattern<DotBF16Op>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(DotBF16Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto typeA = dyn_cast<VectorType>(op.getA().getType());
|
||||
unsigned elemBitWidth = typeA.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
|
||||
|
||||
auto opType = adaptor.getSrc().getType();
|
||||
auto opSrc = adaptor.getSrc();
|
||||
auto opA = adaptor.getA();
|
||||
auto opB = adaptor.getB();
|
||||
|
||||
switch (opBitWidth) {
|
||||
case 128: {
|
||||
rewriter.replaceOpWithNewOp<DotBF16Ps128IntrOp>(op, opType, opSrc, opA,
|
||||
opB);
|
||||
break;
|
||||
}
|
||||
case 256: {
|
||||
rewriter.replaceOpWithNewOp<DotBF16Ps256IntrOp>(op, opType, opSrc, opA,
|
||||
opB);
|
||||
break;
|
||||
}
|
||||
case 512: {
|
||||
rewriter.replaceOpWithNewOp<DotBF16Ps512IntrOp>(op, opType, opSrc, opA,
|
||||
opB);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"unsupported AVX512-BF16 dot variant");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
// Extract individual results from packed structure and use them as
|
||||
// replacements.
|
||||
SmallVector<Value, 4> results;
|
||||
results.reserve(numResults);
|
||||
Value intrRes = callIntrOp.getResults();
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
|
||||
}
|
||||
};
|
||||
rewriter.replaceOp(op, results);
|
||||
|
||||
struct CvtPackedF32ToBF16Conversion
|
||||
: public ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op> {
|
||||
using ConvertOpToLLVMPattern<CvtPackedF32ToBF16Op>::ConvertOpToLLVMPattern;
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CvtPackedF32ToBF16Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto typeA = dyn_cast<VectorType>(op.getA().getType());
|
||||
unsigned elemBitWidth = typeA.getElementTypeBitWidth();
|
||||
unsigned opBitWidth = typeA.getShape()[0] * elemBitWidth;
|
||||
/// Generic one-to-one conversion of simply mappable operations into calls
|
||||
/// to their respective LLVM intrinsics.
|
||||
struct OneToOneIntrinsicOpConversion
|
||||
: public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
|
||||
using OpInterfaceRewritePattern<
|
||||
x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
|
||||
|
||||
auto opType = op.getDst().getType();
|
||||
auto opA = op.getA();
|
||||
OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
|
||||
typeConverter(typeConverter) {}
|
||||
|
||||
switch (opBitWidth) {
|
||||
case 256: {
|
||||
rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps256IntrOp>(op, opType, opA);
|
||||
break;
|
||||
}
|
||||
case 512: {
|
||||
rewriter.replaceOpWithNewOp<CvtNeF32ToBF16Ps512IntrOp>(op, opType, opA);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported AVX512-BF16 packed f32 to bf16 variant");
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
|
||||
using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto opType = adaptor.getA().getType();
|
||||
rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
|
||||
using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto opType = adaptor.getA().getType();
|
||||
Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
|
||||
// Dot product of all elements, broadcasted to all elements.
|
||||
auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
|
||||
Value scale =
|
||||
rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
|
||||
rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
|
||||
adaptor.getB(), scale);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// An entry associating the "main" AVX512 op with its instantiations for
|
||||
/// vectors of 32-bit and 64-bit elements.
|
||||
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
|
||||
struct RegEntry {
|
||||
using MainOp = OpTy;
|
||||
using Intr32Op = Intr32OpTy;
|
||||
using Intr64Op = Intr64OpTy;
|
||||
};
|
||||
|
||||
/// A container for op association entries facilitating the configuration of
|
||||
/// dialect conversion.
|
||||
template <typename... Args>
|
||||
struct RegistryImpl {
|
||||
/// Registers the patterns specializing the "main" op to one of the
|
||||
/// "intrinsic" ops depending on elemental type.
|
||||
static void registerPatterns(const LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns
|
||||
.add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
|
||||
typename Args::Intr64Op>...>(converter);
|
||||
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
|
||||
op.getIntrinsicOperands(rewriter), typeConverter,
|
||||
rewriter);
|
||||
}
|
||||
|
||||
/// Configures the conversion target to lower out "main" ops.
|
||||
static void configureTarget(LLVMConversionTarget &target) {
|
||||
target.addIllegalOp<typename Args::MainOp...>();
|
||||
target.addLegalOp<typename Args::Intr32Op...>();
|
||||
target.addLegalOp<typename Args::Intr64Op...>();
|
||||
}
|
||||
private:
|
||||
const LLVMTypeConverter &typeConverter;
|
||||
};
|
||||
|
||||
using Registry = RegistryImpl<
|
||||
RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
|
||||
RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
|
||||
RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Populate the given list with patterns that convert from X86Vector to LLVM.
|
||||
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
|
||||
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
Registry::registerPatterns(converter, patterns);
|
||||
patterns
|
||||
.add<MaskCompressOpConversion, DotBF16OpConversion,
|
||||
CvtPackedF32ToBF16Conversion, RsqrtOpConversion, DotOpConversion>(
|
||||
converter);
|
||||
patterns.add<OneToOneIntrinsicOpConversion>(converter);
|
||||
}
|
||||
|
||||
void mlir::configureX86VectorLegalizeForExportTarget(
|
||||
LLVMConversionTarget &target) {
|
||||
Registry::configureTarget(target);
|
||||
target.addLegalOp<MaskCompressIntrOp>();
|
||||
target.addIllegalOp<MaskCompressOp>();
|
||||
target.addLegalOp<DotBF16Ps128IntrOp>();
|
||||
target.addLegalOp<DotBF16Ps256IntrOp>();
|
||||
target.addLegalOp<DotBF16Ps512IntrOp>();
|
||||
target.addIllegalOp<DotBF16Op>();
|
||||
target.addLegalOp<CvtNeF32ToBF16Ps256IntrOp>();
|
||||
target.addLegalOp<CvtNeF32ToBF16Ps512IntrOp>();
|
||||
target.addIllegalOp<CvtPackedF32ToBF16Op>();
|
||||
target.addLegalOp<RsqrtIntrOp>();
|
||||
target.addIllegalOp<RsqrtOp>();
|
||||
target.addLegalOp<DotIntrOp>();
|
||||
target.addIllegalOp<DotOp>();
|
||||
target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
|
||||
Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, RsqrtOp,
|
||||
DotOp>();
|
||||
}
|
||||
|
||||
@@ -54,7 +54,6 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
|
||||
MLIRAMXToLLVMIRTranslation
|
||||
MLIRBuiltinToLLVMIRTranslation
|
||||
MLIRGPUToLLVMIRTranslation
|
||||
MLIRX86VectorToLLVMIRTranslation
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
MLIRNVVMToLLVMIRTranslation
|
||||
MLIROpenACCToLLVMIRTranslation
|
||||
|
||||
@@ -11,4 +11,3 @@ add_subdirectory(OpenMP)
|
||||
add_subdirectory(ROCDL)
|
||||
add_subdirectory(SPIRV)
|
||||
add_subdirectory(VCIX)
|
||||
add_subdirectory(X86Vector)
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
add_mlir_translation_library(MLIRX86VectorToLLVMIRTranslation
|
||||
X86VectorToLLVMIRTranslation.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRX86VectorConversionsIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRX86VectorDialect
|
||||
MLIRLLVMDialect
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
)
|
||||
@@ -1,58 +0,0 @@
|
||||
//===- X86VectorToLLVMIRTranslation.cpp - Translate X86Vector to LLVM IR---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a translation between the MLIR X86Vector dialect and
|
||||
// LLVM IR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h"
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/IntrinsicsX86.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::LLVM;
|
||||
|
||||
namespace {
|
||||
/// Implementation of the dialect interface that converts operations belonging
|
||||
/// to the X86Vector dialect to LLVM IR.
|
||||
class X86VectorDialectLLVMIRTranslationInterface
|
||||
: public LLVMTranslationDialectInterface {
|
||||
public:
|
||||
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
|
||||
|
||||
/// Translates the given operation to LLVM IR using the provided IR builder
|
||||
/// and saving the state in `moduleTranslation`.
|
||||
LogicalResult
|
||||
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) const final {
|
||||
Operation &opInst = *op;
|
||||
#include "mlir/Dialect/X86Vector/X86VectorConversions.inc"
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::registerX86VectorDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<x86vector::X86VectorDialect>();
|
||||
registry.addExtension(
|
||||
+[](MLIRContext *ctx, x86vector::X86VectorDialect *dialect) {
|
||||
dialect->addInterfaces<X86VectorDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerX86VectorDialectTranslation(MLIRContext &context) {
|
||||
DialectRegistry registry;
|
||||
registerX86VectorDialectTranslation(registry);
|
||||
context.appendDialectRegistry(registry);
|
||||
}
|
||||
@@ -1,33 +1,40 @@
|
||||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" | mlir-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @avx512_mask_rndscale
|
||||
func.func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
|
||||
func.func @avx512_mask_rndscale(
|
||||
%src: vector<16xf32>, %a: vector<16xf32>, %b: vector<8xf64>,
|
||||
%imm_i16: i16, %imm_i8: i8, %scale_k_i16: i16, %scale_k_i8: i8)
|
||||
-> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.mask.rndscale.ps.512
|
||||
%0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32>
|
||||
// CHECK: x86vector.avx512.intr.mask.rndscale.pd.512
|
||||
%1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64>
|
||||
%rnd_k = arith.constant 15 : i32
|
||||
%rnd = arith.constant 42 : i32
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.ps.512"
|
||||
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm_i16, %rnd : vector<16xf32>
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.rndscale.pd.512"
|
||||
%1 = x86vector.avx512.mask.rndscale %b, %rnd_k, %b, %imm_i8, %rnd : vector<8xf64>
|
||||
|
||||
// CHECK: x86vector.avx512.intr.mask.scalef.ps.512
|
||||
%2 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
|
||||
// CHECK: x86vector.avx512.intr.mask.scalef.pd.512
|
||||
%3 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.ps.512"
|
||||
%2 = x86vector.avx512.mask.scalef %a, %a, %a, %scale_k_i16, %rnd : vector<16xf32>
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.scalef.pd.512"
|
||||
%3 = x86vector.avx512.mask.scalef %b, %b, %b, %scale_k_i8, %rnd : vector<8xf64>
|
||||
|
||||
// Keep results alive.
|
||||
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @avx512_mask_compress
|
||||
func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
|
||||
%k2: vector<8xi1>, %a2: vector<8xi64>)
|
||||
func.func @avx512_mask_compress(
|
||||
%k1: vector<16xi1>, %a1: vector<16xf32>, %k2: vector<8xi1>, %a2: vector<8xi64>)
|
||||
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.mask.compress
|
||||
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<16xf32>)
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
||||
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
|
||||
// CHECK: x86vector.avx512.intr.mask.compress
|
||||
%1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||
// CHECK: x86vector.avx512.intr.mask.compress
|
||||
// CHECK: llvm.mlir.constant(dense<5.000000e+00> : vector<16xf32>)
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
||||
%1 = x86vector.avx512.mask.compress %k1, %a1
|
||||
{constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.mask.compress"
|
||||
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
||||
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
||||
}
|
||||
@@ -36,9 +43,9 @@ func.func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
|
||||
func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
||||
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.vp2intersect.d.512
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.d.512"
|
||||
%0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32>
|
||||
// CHECK: x86vector.avx512.intr.vp2intersect.q.512
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512.vp2intersect.q.512"
|
||||
%2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
|
||||
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
|
||||
}
|
||||
@@ -47,7 +54,7 @@ func.func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
||||
func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
||||
%b: vector<8xbf16>) -> (vector<4xf32>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.dpbf16ps.128
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.128"
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
@@ -56,7 +63,7 @@ func.func @avx512bf16_dot_128(%src: vector<4xf32>, %a: vector<8xbf16>,
|
||||
func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
||||
%b: vector<16xbf16>) -> (vector<8xf32>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.dpbf16ps.256
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.256"
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
@@ -65,7 +72,7 @@ func.func @avx512bf16_dot_256(%src: vector<8xf32>, %a: vector<16xbf16>,
|
||||
func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
||||
%b: vector<32xbf16>) -> (vector<16xf32>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.dpbf16ps.512
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.dpbf16ps.512"
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
@@ -74,7 +81,7 @@ func.func @avx512bf16_dot_512(%src: vector<16xf32>, %a: vector<32xbf16>,
|
||||
func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
||||
%a: vector<8xf32>) -> (vector<8xbf16>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.cvtneps2bf16.256
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.256"
|
||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<8xf32> -> vector<8xbf16>
|
||||
return %0 : vector<8xbf16>
|
||||
}
|
||||
@@ -83,7 +90,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_256(
|
||||
func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
||||
%a: vector<16xf32>) -> (vector<16xbf16>)
|
||||
{
|
||||
// CHECK: x86vector.avx512.intr.cvtneps2bf16.512
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx512bf16.cvtneps2bf16.512"
|
||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a : vector<16xf32> -> vector<16xbf16>
|
||||
return %0 : vector<16xbf16>
|
||||
}
|
||||
@@ -91,7 +98,7 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
|
||||
// CHECK-LABEL: func @avx_rsqrt
|
||||
func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
|
||||
{
|
||||
// CHECK: x86vector.avx.intr.rsqrt.ps.256
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx.rsqrt.ps.256"
|
||||
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
@@ -99,7 +106,8 @@ func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
|
||||
// CHECK-LABEL: func @avx_dot
|
||||
func.func @avx_dot(%a: vector<8xf32>, %b: vector<8xf32>) -> (vector<8xf32>)
|
||||
{
|
||||
// CHECK: x86vector.avx.intr.dp.ps.256
|
||||
// CHECK: llvm.mlir.constant(-1 : i8)
|
||||
// CHECK: llvm.call_intrinsic "llvm.x86.avx.dp.ps.256"
|
||||
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
@@ -1,133 +1,128 @@
|
||||
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
|
||||
// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86vector" --convert-to-llvm \
|
||||
// RUN: | mlir-translate --mlir-to-llvmir \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512_mask_ps_512
|
||||
llvm.func @LLVM_x86_avx512_mask_ps_512(%a: vector<16 x f32>,
|
||||
%c: i16)
|
||||
-> (vector<16 x f32>)
|
||||
func.func @LLVM_x86_avx512_mask_ps_512(
|
||||
%src: vector<16xf32>, %a: vector<16xf32>, %b: vector<16xf32>,
|
||||
%imm: i16, %scale_k: i16)
|
||||
-> (vector<16xf32>)
|
||||
{
|
||||
%b = llvm.mlir.constant(42 : i32) : i32
|
||||
%rnd_k = arith.constant 15 : i32
|
||||
%rnd = arith.constant 42 : i32
|
||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
|
||||
%0 = "x86vector.avx512.intr.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) :
|
||||
(vector<16 x f32>, i32, vector<16 x f32>, i16, i32) -> vector<16 x f32>
|
||||
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<16xf32>
|
||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
|
||||
%1 = "x86vector.avx512.intr.mask.scalef.ps.512"(%a, %a, %a, %c, %b) :
|
||||
(vector<16 x f32>, vector<16 x f32>, vector<16 x f32>, i16, i32) -> vector<16 x f32>
|
||||
llvm.return %1: vector<16 x f32>
|
||||
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<16xf32>
|
||||
return %1 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <8 x double> @LLVM_x86_avx512_mask_pd_512
|
||||
llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
|
||||
%c: i8)
|
||||
func.func @LLVM_x86_avx512_mask_pd_512(
|
||||
%src: vector<8xf64>, %a: vector<8xf64>, %b: vector<8xf64>,
|
||||
%imm: i8, %scale_k: i8)
|
||||
-> (vector<8xf64>)
|
||||
{
|
||||
%b = llvm.mlir.constant(42 : i32) : i32
|
||||
%rnd_k = arith.constant 15 : i32
|
||||
%rnd = arith.constant 42 : i32
|
||||
// CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
|
||||
%0 = "x86vector.avx512.intr.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) :
|
||||
(vector<8xf64>, i32, vector<8xf64>, i8, i32) -> vector<8xf64>
|
||||
%0 = x86vector.avx512.mask.rndscale %src, %rnd_k, %a, %imm, %rnd : vector<8xf64>
|
||||
// CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
|
||||
%1 = "x86vector.avx512.intr.mask.scalef.pd.512"(%a, %a, %a, %c, %b) :
|
||||
(vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64>
|
||||
llvm.return %1: vector<8xf64>
|
||||
%1 = x86vector.avx512.mask.scalef %0, %a, %b, %scale_k, %rnd : vector<8xf64>
|
||||
return %1 : vector<8xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <16 x float> @LLVM_x86_mask_compress
|
||||
llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
|
||||
func.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
|
||||
-> vector<16xf32>
|
||||
{
|
||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
|
||||
%0 = "x86vector.avx512.intr.mask.compress"(%a, %a, %k) :
|
||||
(vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32>
|
||||
llvm.return %0 : vector<16xf32>
|
||||
%0 = x86vector.avx512.mask.compress %k, %a, %a : vector<16xf32>, vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define { <16 x i1>, <16 x i1> } @LLVM_x86_vp2intersect_d_512
|
||||
llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
|
||||
-> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
|
||||
func.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
|
||||
-> (vector<16xi1>, vector<16xi1>)
|
||||
{
|
||||
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
|
||||
%0 = "x86vector.avx512.intr.vp2intersect.d.512"(%a, %b) :
|
||||
(vector<16xi32>, vector<16xi32>) -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
|
||||
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<16xi32>
|
||||
return %0, %1 : vector<16xi1>, vector<16xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define { <8 x i1>, <8 x i1> } @LLVM_x86_vp2intersect_q_512
|
||||
llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
|
||||
-> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
|
||||
func.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
|
||||
-> (vector<8 x i1>, vector<8 x i1>)
|
||||
{
|
||||
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
|
||||
%0 = "x86vector.avx512.intr.vp2intersect.q.512"(%a, %b) :
|
||||
(vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
|
||||
llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)>
|
||||
%0, %1 = x86vector.avx512.vp2intersect %a, %b : vector<8xi64>
|
||||
return %0, %1 : vector<8 x i1>, vector<8 x i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <4 x float> @LLVM_x86_avx512bf16_dpbf16ps_128
|
||||
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_128(
|
||||
func.func @LLVM_x86_avx512bf16_dpbf16ps_128(
|
||||
%src: vector<4xf32>, %a: vector<8xbf16>, %b: vector<8xbf16>
|
||||
) -> vector<4xf32>
|
||||
{
|
||||
// CHECK: call <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128(
|
||||
%0 = "x86vector.avx512.intr.dpbf16ps.128"(%src, %a, %b)
|
||||
: (vector<4xf32>, vector<8xbf16>, vector<8xbf16>) -> vector<4xf32>
|
||||
llvm.return %0 : vector<4xf32>
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<8xbf16> -> vector<4xf32>
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512bf16_dpbf16ps_256
|
||||
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_256(
|
||||
func.func @LLVM_x86_avx512bf16_dpbf16ps_256(
|
||||
%src: vector<8xf32>, %a: vector<16xbf16>, %b: vector<16xbf16>
|
||||
) -> vector<8xf32>
|
||||
{
|
||||
// CHECK: call <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256(
|
||||
%0 = "x86vector.avx512.intr.dpbf16ps.256"(%src, %a, %b)
|
||||
: (vector<8xf32>, vector<16xbf16>, vector<16xbf16>) -> vector<8xf32>
|
||||
llvm.return %0 : vector<8xf32>
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<16xbf16> -> vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512bf16_dpbf16ps_512
|
||||
llvm.func @LLVM_x86_avx512bf16_dpbf16ps_512(
|
||||
func.func @LLVM_x86_avx512bf16_dpbf16ps_512(
|
||||
%src: vector<16xf32>, %a: vector<32xbf16>, %b: vector<32xbf16>
|
||||
) -> vector<16xf32>
|
||||
{
|
||||
// CHECK: call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(
|
||||
%0 = "x86vector.avx512.intr.dpbf16ps.512"(%src, %a, %b)
|
||||
: (vector<16xf32>, vector<32xbf16>, vector<32xbf16>) -> vector<16xf32>
|
||||
llvm.return %0 : vector<16xf32>
|
||||
%0 = x86vector.avx512.dot %src, %a, %b : vector<32xbf16> -> vector<16xf32>
|
||||
return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <8 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_256
|
||||
llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
|
||||
func.func @LLVM_x86_avx512bf16_cvtneps2bf16_256(
|
||||
%a: vector<8xf32>) -> vector<8xbf16>
|
||||
{
|
||||
// CHECK: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(
|
||||
%0 = "x86vector.avx512.intr.cvtneps2bf16.256"(%a)
|
||||
: (vector<8xf32>) -> vector<8xbf16>
|
||||
llvm.return %0 : vector<8xbf16>
|
||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a
|
||||
: vector<8xf32> -> vector<8xbf16>
|
||||
return %0 : vector<8xbf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <16 x bfloat> @LLVM_x86_avx512bf16_cvtneps2bf16_512
|
||||
llvm.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
|
||||
func.func @LLVM_x86_avx512bf16_cvtneps2bf16_512(
|
||||
%a: vector<16xf32>) -> vector<16xbf16>
|
||||
{
|
||||
// CHECK: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(
|
||||
%0 = "x86vector.avx512.intr.cvtneps2bf16.512"(%a)
|
||||
: (vector<16xf32>) -> vector<16xbf16>
|
||||
llvm.return %0 : vector<16xbf16>
|
||||
%0 = x86vector.avx512.cvt.packed.f32_to_bf16 %a
|
||||
: vector<16xf32> -> vector<16xbf16>
|
||||
return %0 : vector<16xbf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256
|
||||
llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
|
||||
func.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32>
|
||||
{
|
||||
// CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float>
|
||||
%0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>)
|
||||
llvm.return %0 : vector<8xf32>
|
||||
%0 = x86vector.avx.rsqrt %a : vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_dp_ps_256
|
||||
llvm.func @LLVM_x86_avx_dp_ps_256(
|
||||
func.func @LLVM_x86_avx_dp_ps_256(
|
||||
%a: vector<8xf32>, %b: vector<8xf32>
|
||||
) -> vector<8xf32>
|
||||
{
|
||||
// CHECK: call <8 x float> @llvm.x86.avx.dp.ps.256(
|
||||
%c = llvm.mlir.constant(-1 : i8) : i8
|
||||
%1 = "x86vector.avx.intr.dp.ps.256"(%a, %b, %c) : (vector<8xf32>, vector<8xf32>, i8) -> vector<8xf32>
|
||||
llvm.return %1 : vector<8xf32>
|
||||
%0 = x86vector.avx.intr.dot %a, %b : vector<8xf32>
|
||||
return %0 : vector<8xf32>
|
||||
}
|
||||
|
||||
@@ -2298,14 +2298,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "X86VectorConversionIncGen",
|
||||
tbl_outs = {"include/mlir/Dialect/X86Vector/X86VectorConversions.inc": ["-gen-llvmir-conversions"]},
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/X86Vector/X86Vector.td",
|
||||
deps = [":X86VectorTdFiles"],
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# IRDL dialect.
|
||||
##---------------------------------------------------------------------------##
|
||||
|
||||
Reference in New Issue
Block a user