Files
clang-p2996/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
Lei Zhang a1e78615fb [mlir][complex] Canonicalize re/im(neg(create))
We can just convert this to arith.negf.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D151633
2023-05-29 17:52:48 -07:00

257 lines
8.4 KiB
C++

//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::complex;
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
/// Folds a complex constant to its own `[real, imaginary]` array attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  ArrayAttr folded = getValue();
  return folded;
}
/// Suggests "cst" as the SSA name of the constant's result in printed IR.
void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value result = getResult();
  setNameFn(result, "cst");
}
/// Returns true if a `complex.constant` of type `type` can be built from
/// `value`.
///
/// `value` must be a two-element ArrayAttr of FloatAttrs whose types both
/// equal the element type of the ComplexType `type`. This mirrors exactly
/// what ConstantOp::verify() accepts. Integer elements were previously
/// accepted here even though verify() rejects them, which let callers
/// materialize an op that fails verification.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  auto arrAttr = llvm::dyn_cast<ArrayAttr>(value);
  auto complexTy = llvm::dyn_cast<ComplexType>(type);
  if (!arrAttr || !complexTy || arrAttr.size() != 2)
    return false;

  Type complexEltTy = complexTy.getElementType();
  auto re = llvm::dyn_cast<FloatAttr>(arrAttr[0]);
  auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
  return re && im && re.getType() == complexEltTy &&
         im.getType() == complexEltTy;
}
/// Checks that the constant's value is a `[real, imaginary]` pair of
/// FloatAttrs whose types both match the element type of the result's
/// ComplexType.
LogicalResult ConstantOp::verify() {
  ArrayAttr parts = getValue();
  if (parts.size() != 2)
    return emitOpError(
        "requires 'value' to be a complex constant, represented as array of "
        "two values");

  auto real = llvm::dyn_cast<FloatAttr>(parts[0]);
  auto imag = llvm::dyn_cast<FloatAttr>(parts[1]);
  if (!real || !imag)
    return emitOpError("requires attribute's elements to be float attributes");

  Type eltTy = getType().getElementType();
  bool typesMatch = real.getType() == eltTy && imag.getType() == eltTy;
  if (typesMatch)
    return success();

  return emitOpError() << "requires attribute's element types ("
                       << real.getType() << ", " << imag.getType()
                       << ") to match the element type of the op's return type ("
                       << eltTy << ")";
}
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
/// Folds `complex.create(complex.re(z), complex.im(z))` back to `z`.
/// Both components must be extracted from the very same value.
OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
  auto reOp = getOperand(0).getDefiningOp<ReOp>();
  if (!reOp)
    return {};

  auto imOp = getOperand(1).getDefiningOp<ImOp>();
  if (!imOp || imOp.getOperand() != reOp.getOperand())
    return {};

  return reOp.getOperand();
}
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
/// Folds `complex.im` in two cases:
///  - the operand is a constant `[re, im]` array: yields the `im` attribute;
///  - the operand is `complex.create(re, im)`: yields the `im` value.
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
  if (auto constParts =
          llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex())) {
    if (constParts.size() == 2)
      return constParts[1];
  }

  auto createOp = getOperand().getDefiningOp<CreateOp>();
  return createOp ? OpFoldResult(createOp.getOperand(1)) : OpFoldResult();
}
namespace {
/// Rewrites `OpKind(complex.neg(complex.create(re, im)))` into `arith.negf`
/// applied to one component of the create. Here OpKind is complex.re or
/// complex.im.
///
/// `ComponentIndex` selects the create operand to negate: 0 is the real
/// part and 1 is the imaginary part.
template <typename OpKind, int ComponentIndex>
struct FoldComponentNeg final : OpRewritePattern<OpKind> {
  using OpRewritePattern<OpKind>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpKind op,
                                PatternRewriter &rewriter) const override {
    NegOp neg = op.getOperand().template getDefiningOp<NegOp>();
    CreateOp create =
        neg ? neg.getComplex().template getDefiningOp<CreateOp>() : CreateOp();
    if (!create)
      return failure();

    // arith.negf only applies to floats; the element type is expected to be
    // a FloatType here.
    Type eltTy = create.getType().getElementType();
    assert(isa<FloatType>(eltTy));

    Value component = create.getOperand(ComponentIndex);
    rewriter.replaceOpWithNewOp<arith::NegFOp>(op, eltTy, component);
    return success();
  }
};
} // namespace
/// Registers the rewrite `im(neg(create(re, im))) -> arith.negf(im)`.
void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  using ImOfNegCreate = FoldComponentNeg<ImOp, /*ComponentIndex=*/1>;
  results.add<ImOfNegCreate>(context);
}
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
/// Folds `complex.re` in two cases:
///  - the operand is a constant `[re, im]` array: yields the `re` attribute;
///  - the operand is `complex.create(re, im)`: yields the `re` value.
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
  if (auto constParts =
          llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex())) {
    if (constParts.size() == 2)
      return constParts[0];
  }

  auto createOp = getOperand().getDefiningOp<CreateOp>();
  return createOp ? OpFoldResult(createOp.getOperand(0)) : OpFoldResult();
}
/// Registers the rewrite `re(neg(create(re, im))) -> arith.negf(re)`.
void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  using ReOfNegCreate = FoldComponentNeg<ReOp, /*ComponentIndex=*/0>;
  results.add<ReOfNegCreate>(context);
}
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
/// Folds `complex.add` in three cases:
///  - `(a - b) + b -> a`
///  - `b + (a - b) -> a`
///  - `a + <-0.0, -0.0> -> a`
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  // complex.add(complex.sub(a, b), b) -> a
  if (auto sub = getLhs().getDefiningOp<SubOp>())
    if (getRhs() == sub.getRhs())
      return sub.getLhs();

  // complex.add(b, complex.sub(a, b)) -> a
  if (auto sub = getRhs().getDefiningOp<SubOp>())
    if (getLhs() == sub.getRhs())
      return sub.getLhs();

  // complex.add(a, complex.constant<-0.0, -0.0>) -> a
  // Addition is componentwise. Under IEEE-754 only -0.0 is an exact additive
  // identity: x + (-0.0) == x for every x, whereas (-0.0) + (+0.0) == +0.0.
  // Folding a +0.0 addend would therefore flip negative-zero components, so
  // only negative zeros are folded away.
  // ConstantOp::verify guarantees both entries are FloatAttrs, so the casts
  // cannot fail.
  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
    ArrayAttr arrayAttr = constantOp.getValue();
    if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isNegZero() &&
        llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isNegZero())
      return getLhs();
  }

  return {};
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
/// Folds `complex.sub` in two cases:
///  - `(a + b) - b -> a`
///  - `a - <+0.0, +0.0> -> a`
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  // complex.sub(complex.add(a, b), b) -> a
  if (auto add = getLhs().getDefiningOp<AddOp>())
    if (getRhs() == add.getRhs())
      return add.getLhs();

  // complex.sub(a, complex.constant<+0.0, +0.0>) -> a
  // Subtraction is componentwise. Under IEEE-754, x - (+0.0) == x for every
  // x, but (-0.0) - (-0.0) == +0.0. Only positive zeros may therefore be
  // folded away without changing the sign of a zero component.
  // ConstantOp::verify guarantees both entries are FloatAttrs, so the casts
  // cannot fail.
  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
    ArrayAttr arrayAttr = constantOp.getValue();
    if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isPosZero() &&
        llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isPosZero())
      return getLhs();
  }

  return {};
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
/// Double negation cancels: `complex.neg(complex.neg(a)) -> a`.
OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
  auto inner = getOperand().getDefiningOp<NegOp>();
  return inner ? OpFoldResult(inner.getOperand()) : OpFoldResult();
}
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
// complex.log(complex.exp(a)) -> a
if (auto expOp = getOperand().getDefiningOp<ExpOp>())
return expOp.getOperand();
return {};
}
//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//
/// Folds `complex.exp(complex.log(a)) -> a`.
OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
  auto logOp = getOperand().getDefiningOp<LogOp>();
  return logOp ? OpFoldResult(logOp.getOperand()) : OpFoldResult();
}
//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
/// Conjugation is an involution: `complex.conj(complex.conj(a)) -> a`.
OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
  auto inner = getOperand().getDefiningOp<ConjOp>();
  return inner ? OpFoldResult(inner.getOperand()) : OpFoldResult();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"