// This diff completes switching TOSA to DenseArrayAttr.
// Test plan: ninja check-mlir check-all
// Differential revision: https://reviews.llvm.org/D141111
//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace tosa;
|
|
|
|
namespace {
|
|
|
|
/// Lowers tosa.slice to tensor.extract_slice.
///
/// Start offsets and strides are always static (strides are all 1). A
/// requested size of -1 means "through the end of that dimension" and is
/// lowered to a dynamic size computed as `tensor.dim(input, i) - start[i]`.
class SliceConverter : public OpRewritePattern<tosa::SliceOp> {
public:
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = sliceOp.getInput();
    ArrayRef<int64_t> starts = sliceOp.getStart();

    // tosa.slice never strides; extract_slice wants one entry per dim.
    int64_t rank = sliceOp.getType().cast<ShapedType>().getRank();
    SmallVector<int64_t> strides(rank, 1);

    // Map each requested size onto the extract_slice static-size list,
    // materializing SSA values only for the dynamic (-1) entries.
    SmallVector<int64_t> sizes;
    SmallVector<Value> dynSizes;
    ArrayRef<int64_t> requestedSizes = sliceOp.getSize();
    for (size_t index = 0, e = requestedSizes.size(); index < e; ++index) {
      int64_t size = requestedSizes[index];
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      // Dynamic size: the remainder of the dimension past the start offset.
      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(starts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};
|
|
|
|
/// Lowers tosa.pad to tensor.pad.
///
/// The per-dimension low/high pad amounts are read out of the rank-2
/// `padding` tensor operand (column 0 = low, column 1 = high). The fill
/// value comes from the optional pad_const operand when present, otherwise
/// a default is synthesized: 0.0 for floats, the input zero point for
/// quantized integers, and 0 for plain integers.
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Determine the scalar used to fill the padded region.
    Value padConstant;
    if (padOp.getPadConst()) {
      // Explicit pad constant: pull the scalar out of the 0-d tensor.
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      // Synthesize a default fill value based on the element type.
      Attribute constantAttr;
      if (elementTy.isa<FloatType>()) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (elementTy.isa<IntegerType>()) {
        if (auto quantizationInfo = padOp.getQuantizationInfo()) {
          // Quantized: pad with the input zero point, not literal zero.
          constantAttr = rewriter.getIntegerAttr(
              elementTy, quantizationInfo->getInputZp());
        } else {
          constantAttr = rewriter.getIntegerAttr(elementTy, 0);
        }
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    // Unsupported element type (no default could be synthesized).
    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    // Column indices into the padding tensor: 0 = low pad, 1 = high pad.
    Value lowIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;
    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int dim = 0; dim < rank; dim++) {
      Value dimIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, dim);
      // Read the (dim, low) and (dim, high) entries.
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({dimIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({dimIndex, highIndex}));

      // tensor.pad expects index-typed pad amounts.
      lowValues.push_back(rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), lowVal));
      highValues.push_back(rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), highVal));
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the Tosa-to-Tensor lowering patterns on the given set.
void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  MLIRContext *context = patterns->getContext();
  // Keep the original registration order: slice first, then pad.
  patterns->add<SliceConverter>(context);
  patterns->add<PadConverter>(context);
}
|