Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of the new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
This is similar to the existing LLVM dialect syntax:
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```
Tablegen canonicalization patterns are updated to always drop the flags; proper support, with tests, will be added later.

The LLVM IR translation is updated as part of this commit because, as currently written, it would otherwise crash when new attributes are added to arith ops.

This commit also lowers `arith` overflow flags to the corresponding SPIR-V op decorations.

Discussion: https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward #77211, #77700 and #77714 while adding a test to ensure that Python usage is not broken. More follow-up is needed, but it is unrelated to the core change here. The changes here are minimal and only correspond to "textual namespacing" on the ODS side; no C++ or Python changes were needed.

Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>, Yi Wu <yi.wu2@arm.com>
58 lines
2.4 KiB
C++
58 lines
2.4 KiB
C++
//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
|
|
|
|
using namespace mlir;
|
|
|
|
// Translates arith dialect fast-math flags into the LLVM dialect's
// FastmathFlags.
//
// The result starts value-initialized (no bits set). For every entry in the
// mapping table below whose arith flag is present in `arithFMF`, the paired
// LLVM flag is OR-ed into the result. Arith bits that do not appear in the
// table contribute nothing.
LLVM::FastmathFlags
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
  using FlagMapping = std::pair<arith::FastMathFlags, LLVM::FastmathFlags>;

  // One-to-one correspondence between the arith and LLVM flag enumerators.
  const FlagMapping mapping[] = {
      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};

  LLVM::FastmathFlags result{};
  for (const FlagMapping &entry : mapping) {
    if (bitEnumContainsAny(arithFMF, entry.first))
      result = result | entry.second;
  }
  return result;
}
|
|
|
|
// Attribute-level wrapper around convertArithFastMathFlagsToLLVM.
//
// Reads the flag value out of `fmfAttr`, translates it, and builds an
// LLVM::FastmathFlagsAttr in the same context that `fmfAttr` belongs to.
LLVM::FastmathFlagsAttr
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
  auto llvmFlags = convertArithFastMathFlagsToLLVM(fmfAttr.getValue());
  return LLVM::FastmathFlagsAttr::get(fmfAttr.getContext(), llvmFlags);
}
|
|
|
|
// Translates arith dialect integer overflow flags into the LLVM dialect's
// IntegerOverflowFlags.
//
// The result starts value-initialized (no bits set). Each arith flag from the
// mapping table that is present in `arithFlags` has its LLVM counterpart OR-ed
// into the result.
LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
    arith::IntegerOverflowFlags arithFlags) {
  using FlagMapping =
      std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>;

  // nsw and nuw map one-to-one between the two dialects.
  const FlagMapping mapping[] = {
      {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
      {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};

  LLVM::IntegerOverflowFlags result{};
  for (const FlagMapping &entry : mapping) {
    if (bitEnumContainsAny(arithFlags, entry.first))
      result = result | entry.second;
  }
  return result;
}
|
|
|
|
// Attribute-level wrapper around convertArithOverflowFlagsToLLVM.
//
// Reads the flag value out of `flagsAttr`, translates it, and builds an
// LLVM::IntegerOverflowFlagsAttr in the same context that `flagsAttr`
// belongs to.
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
    arith::IntegerOverflowFlagsAttr flagsAttr) {
  auto llvmFlags = convertArithOverflowFlagsToLLVM(flagsAttr.getValue());
  return LLVM::IntegerOverflowFlagsAttr::get(flagsAttr.getContext(),
                                             llvmFlags);
}
|