Files
clang-p2996/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
Jacques Pienaar 8934b10642 [mlir][arith] Add overflow flags support to arith ops (#78376)
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
This is similar to the existing LLVM dialect syntax:
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```

TableGen canonicalization patterns are updated to always drop the flags; proper
support, with tests, will be added later.

The LLVM IR translation is also updated in this commit, because it is currently
written in a way that would otherwise crash once new attributes are added to
arith ops.

Also lower `arith` overflow flags to corresponding SPIR-V op decorations

Discussion

https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward #77211, #77700 and #77714 while adding a
test to ensure the Python usage is not broken. More follow-up is needed, but it
is unrelated to the core change here. The changes here are minimal and just
correspond to "textual namespacing" on the ODS side; no C++ or Python changes
were needed.

---------

---------

Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>, Yi Wu <yi.wu2@arm.com>
2024-01-17 06:12:23 +03:00

58 lines
2.4 KiB
C++

//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
using namespace mlir;
// Translates an arith fast-math bit set into the LLVM dialect's fast-math bit
// set. The result starts empty; each of the arith bits nnan, ninf, nsz, arcp,
// contract, afn and reassoc that is present in `arithFMF` turns on the LLVM
// bit of the same name.
LLVM::FastmathFlags
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
  LLVM::FastmathFlags result{};

  // Copies a single bit across: when `from` is set in the input, OR `to` into
  // the accumulated result.
  auto translateBit = [&](arith::FastMathFlags from, LLVM::FastmathFlags to) {
    if (bitEnumContainsAny(arithFMF, from))
      result = result | to;
  };

  translateBit(arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan);
  translateBit(arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf);
  translateBit(arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz);
  translateBit(arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp);
  translateBit(arith::FastMathFlags::contract, LLVM::FastmathFlags::contract);
  translateBit(arith::FastMathFlags::afn, LLVM::FastmathFlags::afn);
  translateBit(arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc);
  return result;
}
// Builds an LLVM dialect fast-math attribute from an arith fast-math attribute.
// The flag bits are translated by convertArithFastMathFlagsToLLVM, and the new
// attribute is created in the same context as `fmfAttr`.
LLVM::FastmathFlagsAttr
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
  auto *ctx = fmfAttr.getContext();
  LLVM::FastmathFlags llvmFMF =
      convertArithFastMathFlagsToLLVM(fmfAttr.getValue());
  return LLVM::FastmathFlagsAttr::get(ctx, llvmFMF);
}
// Translates an arith integer-overflow bit set into the LLVM dialect's
// integer-overflow bit set. The result starts empty; the arith bits nsw and
// nuw, when present in `arithFlags`, turn on the LLVM bit of the same name.
LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
    arith::IntegerOverflowFlags arithFlags) {
  LLVM::IntegerOverflowFlags result{};

  // Copies a single bit across: when `from` is set in the input, OR `to` into
  // the accumulated result.
  auto translateBit = [&](arith::IntegerOverflowFlags from,
                          LLVM::IntegerOverflowFlags to) {
    if (bitEnumContainsAny(arithFlags, from))
      result = result | to;
  };

  translateBit(arith::IntegerOverflowFlags::nsw,
               LLVM::IntegerOverflowFlags::nsw);
  translateBit(arith::IntegerOverflowFlags::nuw,
               LLVM::IntegerOverflowFlags::nuw);
  return result;
}
// Builds an LLVM dialect integer-overflow attribute from an arith
// integer-overflow attribute. The flag bits are translated by
// convertArithOverflowFlagsToLLVM, and the new attribute is created in the same
// context as `flagsAttr`.
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
    arith::IntegerOverflowFlagsAttr flagsAttr) {
  auto *ctx = flagsAttr.getContext();
  LLVM::IntegerOverflowFlags llvmFlags =
      convertArithOverflowFlagsToLLVM(flagsAttr.getValue());
  return LLVM::IntegerOverflowFlagsAttr::get(ctx, llvmFlags);
}