Files
clang-p2996/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
Matthias Springer 9340996706 [mlir][tensor] Add pattern to rewrite tensor.generate as a constant
Only ops with a static tensor type and a constant yield value are rewritten.

Differential Revision: https://reviews.llvm.org/D152511
2023-06-09 12:56:07 +02:00

54 lines
1.9 KiB
C++

//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::tensor;
namespace {
/// Rewrite tensor.generate with arith.constant if the yielded value is a
/// constant and the tensor type is static.
struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
using OpRewritePattern<GenerateOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenerateOp generateOp,
PatternRewriter &rewriter) const override {
auto tensorType =
llvm::cast<RankedTensorType>(generateOp.getResult().getType());
if (!tensorType.hasStaticShape())
return failure();
auto terminatorOp =
cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
Attribute attr;
if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
return failure();
Operation *constantOp =
rewriter.getContext()
->getLoadedDialect<TensorDialect>()
->materializeConstant(rewriter,
DenseElementsAttr::get(tensorType, attr),
tensorType, generateOp->getLoc());
if (!constantOp)
return failure();
rewriter.replaceOp(generateOp, constantOp->getResults());
return success();
}
};
} // namespace
/// Appends the patterns that rewrite tensor ops as constants to `patterns`.
/// At present this registers only GenerateToConstant, constructed with the
/// pattern set's own MLIR context.
void mlir::tensor::populateRewriteAsConstantPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<GenerateToConstant>(ctx);
}