Files
clang-p2996/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
Adam Siemieniuk 0c2a6f2d62 [mlir][x86vector] Simplify intrinsic generation (#133692)
Replaces separate x86vector named intrinsic operations with direct calls
to LLVM intrinsic functions.
    
This rework reduces the number of named ops, leaving only high-level MLIR
equivalents of whole intrinsic classes, e.g., variants of AVX512 dot on
BF16 inputs. Dialect conversion applies LLVM intrinsic name mangling,
further simplifying lowering logic.
    
The separate conversion step translating x86vector intrinsics into LLVM
IR is also eliminated. Instead, this step is now performed by the
existing llvm dialect infrastructure.

RFC:
https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
2025-04-09 19:59:37 +02:00

79 lines
2.5 KiB
C++

//===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the X86Vector dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
using namespace mlir;
#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
#include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
/// Dialect initialization hook: registers every X86Vector operation with the
/// dialect. The operation type list passed to addOperations<> is not written
/// out by hand; it is expanded from the GET_OP_LIST section of
/// "mlir/Dialect/X86Vector/X86Vector.cpp.inc".
void x86vector::X86VectorDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
>();
}
/// Custom verifier for MaskCompressOp.
///
/// Rejects the op when both the `src` operand and the `constant_src`
/// attribute are set. Whichever one is set must have the same type as the
/// `dst` result.
LogicalResult x86vector::MaskCompressOp::verify() {
  auto dynamicSrc = getSrc();
  auto staticSrc = getConstantSrc();
  Type resultType = getDst().getType();

  // The two ways of supplying the source vector are mutually exclusive.
  if (dynamicSrc && staticSrc)
    return emitError("cannot use both src and constant_src");

  if (dynamicSrc && dynamicSrc.getType() != resultType)
    return emitError("failed to verify that src and dst have same type");

  if (staticSrc && staticSrc->getType() != resultType)
    return emitError(
        "failed to verify that constant_src and dst have same type");

  return success();
}
/// Builds the operand list for the intrinsic call that lowers MaskCompressOp.
///
/// The list is, in order:
/// - `a`;
/// - a source vector, chosen as follows:
///   - the `src` operand when it is present;
///   - otherwise an LLVM::ConstantOp of `a`'s type holding the
///     `constant_src` attribute;
///   - or holding an all-zero value when neither is given;
/// - the mask `k`.
///
/// Any constant is created through `rewriter` at this op's location.
SmallVector<Value>
x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
  Value input = getA();
  Type vecType = input.getType();

  Value source = getSrc();
  if (!source) {
    // No dynamic source operand: materialize one as a constant of the
    // input vector type.
    Attribute initValue;
    if (getConstantSrc())
      initValue = getConstantSrcAttr();
    else
      initValue = rewriter.getZeroAttr(vecType);
    source = rewriter.create<LLVM::ConstantOp>(getLoc(), vecType, initValue);
  }

  return SmallVector<Value>{input, source, getK()};
}
/// Builds the operand list for the intrinsic call that lowers DotOp.
///
/// The list is all of the op's own operands, in order, followed by an i8
/// constant 0xff created through `rewriter` at this op's location.
SmallVector<Value>
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
  // Immediate 0xff: dot product of all elements, broadcast to all elements.
  Value fullMask =
      rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);

  SmallVector<Value> callOperands(getOperands());
  callOperands.push_back(fullMask);
  return callOperands;
}
#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"