Files
clang-p2996/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
Adam Siemieniuk 0fa3ba7c39 [mlir][amx] Simplify intrinsic generation (#140559)
Replaces separate amx named intrinsic operations with direct calls to
LLVM intrinsic functions.
The existing amx tests are updated and expanded.

The separate conversion step translating amx intrinsics into LLVM IR is
eliminated. Instead, this step is now performed by the existing llvm
dialect infrastructure.

Related RFC:
https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581/7
2025-05-23 14:16:09 +02:00

63 lines
2.2 KiB
C++

//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::amx;
namespace {
/// Conversion pattern shared by every operation that implements the
/// `amx::AMXIntrinsicOp` interface. The operation itself supplies the
/// intrinsic name and the operand list; the actual rewrite is delegated to
/// `LLVM::detail::intrinsicRewrite`.
struct AMXIntrinsicOpConversion
    : public OpInterfaceConversionPattern<amx::AMXIntrinsicOp> {
  using OpInterfaceConversionPattern<
      amx::AMXIntrinsicOp>::OpInterfaceConversionPattern;

  /// Builds the pattern for `typeConverter`, using the converter's context
  /// and the given `benefit`. A reference to the converter is kept so that it
  /// can be handed to the interface methods during rewriting.
  AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
                           PatternBenefit benefit = 1)
      : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
                                     benefit),
        typeConverter(typeConverter) {}

  /// Queries `op` for its intrinsic name and intrinsic operands (derived from
  /// the already converted `operands`) and forwards both to the generic
  /// intrinsic rewrite helper, returning that helper's result.
  LogicalResult
  matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    StringAttr intrinsicName = rewriter.getStringAttr(op.getIntrinsicName());
    auto intrinsicOperands =
        op.getIntrinsicOperands(operands, typeConverter, rewriter);
    return LLVM::detail::intrinsicRewrite(op, intrinsicName, intrinsicOperands,
                                          typeConverter, rewriter);
  }

private:
  /// Type converter passed on to the interface methods and rewrite helper.
  const LLVMTypeConverter &typeConverter;
};
} // namespace
/// Registers what is needed to lower AMX operations to the LLVM dialect.
///
/// Adds the `AMXIntrinsicOpConversion` pattern to `patterns` and registers a
/// type conversion on `converter` that maps `amx::TileType` to
/// `LLVM::LLVMX86AMXType`.
void mlir::populateAMXLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<AMXIntrinsicOpConversion>(converter);
  // The callback is stored inside `converter` itself, so it deliberately
  // captures nothing: a by-reference capture of `converter` would point at the
  // original object if the converter were ever copied. The context is taken
  // from the type being converted instead.
  converter.addConversion([](amx::TileType type) {
    return LLVM::LLVMX86AMXType::get(type.getContext());
  });
}
/// Marks the whole AMX dialect as illegal on `target`, so that a conversion
/// using this target must rewrite every AMX operation.
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
  target.addIllegalDialect<amx::AMXDialect>();
}