Files
clang-p2996/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
Alexander Shaposhnikov 9e1a344155 [MLIR][TOSA] Switch Tosa to DenseArrayAttr
This diff completes switching Tosa to DenseArrayAttr.

Test plan: ninja check-mlir check-all

Differential revision: https://reviews.llvm.org/D141111
2023-01-06 22:57:14 +00:00

143 lines
5.1 KiB
C++

//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace tosa;
namespace {
/// Rewrites `tosa.slice` into `tensor.extract_slice`.
///
/// The start offsets are forwarded unchanged as static offsets and every
/// dimension uses a stride of one. A size entry of -1 is recorded as a dynamic
/// size; its runtime extent is materialized as `dim(input, i) - start[i]`.
class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
public:
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value source = sliceOp.getInput();
    ArrayRef<int64_t> offsets = sliceOp.getStart();

    // One unit stride per dimension of the result.
    int64_t resultRank = sliceOp.getType().cast<ShapedType>().getRank();
    SmallVector<int64_t> unitStrides(resultRank, 1);

    SmallVector<int64_t> staticSizes;
    SmallVector<Value> runtimeSizes;
    ArrayRef<int64_t> requestedSizes = sliceOp.getSize();
    for (size_t dimIdx = 0, numDims = requestedSizes.size(); dimIdx != numDims;
         ++dimIdx) {
      // -1 is the marker for "size not known statically".
      int64_t extent = requestedSizes[dimIdx];
      if (extent == -1)
        extent = ShapedType::kDynamic;
      staticSizes.push_back(extent);

      if (ShapedType::isDynamic(extent)) {
        // Remaining extent of the input past the start offset.
        Value fullDim = rewriter.create<tensor::DimOp>(loc, source, dimIdx);
        Value startOffset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexAttr(offsets[dimIdx]));
        runtimeSizes.push_back(
            rewriter.create<arith::SubIOp>(loc, fullDim, startOffset));
      }
    }

    // Offsets and strides are fully static, so only the sizes carry dynamic
    // operands.
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        sliceOp, sliceOp.getType(), source, ValueRange({}), runtimeSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(offsets),
        rewriter.getDenseI64ArrayAttr(staticSizes),
        rewriter.getDenseI64ArrayAttr(unitStrides));
    return success();
  }
};
/// Rewrites `tosa.pad` into `tensor.pad`.
///
/// The pad value is taken from the explicit pad-constant operand when one is
/// present. Otherwise it defaults to 0.0 for float element types and, for
/// integer element types, to the input zero point from the quantization info
/// (or 0 when there is no quantization info). Any other element type makes the
/// pattern fail to match.
///
/// For each input dimension `i`, the low and high pad amounts are read from
/// `padding[i, 0]` and `padding[i, 1]` and cast to `index`.
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    Location loc = padOp.getLoc();
    auto source = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType sourceTy = source.getType().cast<ShapedType>();
    Type elemTy = sourceTy.getElementType();
    int64_t rank = sourceTy.getRank();

    // Work out the scalar used to fill the padded region.
    Value padValue;
    if (auto explicitConst = padOp.getPadConst()) {
      padValue = rewriter.createOrFold<tensor::ExtractOp>(loc, explicitConst,
                                                          ValueRange({}));
    } else {
      Attribute defaultAttr;
      if (elemTy.isa<FloatType>()) {
        defaultAttr = rewriter.getFloatAttr(elemTy, 0.0);
      } else if (elemTy.isa<IntegerType>()) {
        // Quantized integers pad with the input zero point, plain ones with 0.
        int64_t fill = 0;
        if (auto quantInfo = padOp.getQuantizationInfo())
          fill = quantInfo->getInputZp();
        defaultAttr = rewriter.getIntegerAttr(elemTy, fill);
      }

      if (defaultAttr)
        padValue = rewriter.create<arith::ConstantOp>(loc, defaultAttr);
    }

    if (!padValue)
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");

    // Column selectors into the padding tensor: 0 -> low, 1 -> high.
    Value lowColumn =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highColumn =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowPads;
    SmallVector<OpFoldResult, 3> highPads;
    lowPads.reserve(rank);
    highPads.reserve(rank);

    Type indexTy = rewriter.getIndexType();
    for (int64_t dim = 0; dim < rank; ++dim) {
      Value row = rewriter.createOrFold<arith::ConstantIndexOp>(loc, dim);

      // Read both amounts first, then convert each one to `index`.
      Value lowAmount = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({row, lowColumn}));
      Value highAmount = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({row, highColumn}));
      lowAmount =
          rewriter.createOrFold<arith::IndexCastOp>(loc, indexTy, lowAmount);
      highAmount =
          rewriter.createOrFold<arith::IndexCastOp>(loc, indexTy, highAmount);

      lowPads.push_back(lowAmount);
      highPads.push_back(highAmount);
    }

    rewriter.replaceOpWithNewOp<tensor::PadOp>(padOp, padOp.getType(), source,
                                               lowPads, highPads, padValue);
    return success();
  }
};
} // namespace
/// Adds the TOSA-to-Tensor rewrite patterns defined in this file
/// (`SliceConverter` and `PadConverter`) to `patterns`.
void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  MLIRContext *context = patterns->getContext();
  patterns->add<SliceConverter>(context);
  patterns->add<PadConverter>(context);
}