//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // These rewriters lower from the Tosa to the Tensor dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace tosa; namespace { class SliceConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, PatternRewriter &rewriter) const final { Location loc = sliceOp.getLoc(); Value input = sliceOp.getInput(); SmallVector strides, sizes; ArrayRef starts = sliceOp.getStart(); strides.resize(sliceOp.getType().template cast().getRank(), 1); SmallVector dynSizes; for (const auto &i : llvm::enumerate(sliceOp.getSize())) { int64_t size = i.value(); size_t index = i.index(); sizes.push_back(size == -1 ? ShapedType::kDynamic : size); if (!ShapedType::isDynamic(sizes.back())) continue; auto dim = rewriter.create(loc, input, index); auto offset = rewriter.create( loc, rewriter.getIndexAttr(starts[index])); dynSizes.push_back(rewriter.create(loc, dim, offset)); } auto newSliceOp = rewriter.create( sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(starts), rewriter.getDenseI64ArrayAttr(sizes), rewriter.getDenseI64ArrayAttr(strides)); rewriter.replaceOp(sliceOp, newSliceOp.getResult()); return success(); } }; class PadConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::PadOp padOp, PatternRewriter &rewriter) const final { auto loc = padOp.getLoc(); auto input = padOp.getInput1(); auto padding = padOp.getPadding(); ShapedType inputTy = input.getType().cast(); Type elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); // Setup the default constantAttr. Value padConstant; if (padOp.getPadConst()) { padConstant = rewriter.createOrFold( loc, padOp.getPadConst(), ValueRange({})); } else { Attribute constantAttr; if (elementTy.isa()) { constantAttr = rewriter.getFloatAttr(elementTy, 0.0); } else if (elementTy.isa() && !padOp.getQuantizationInfo()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); } else if (elementTy.isa() && padOp.getQuantizationInfo()) { int64_t value = padOp.getQuantizationInfo()->getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } if (constantAttr) padConstant = rewriter.create(loc, constantAttr); } if (!padConstant) { return rewriter.notifyMatchFailure( padOp, "tosa.pad was unable to determine the pad constant value."); } Value lowIndex = rewriter.create(loc, rewriter.getIndexAttr(0)); Value highIndex = rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector lowValues; SmallVector highValues; lowValues.reserve(rank); highValues.reserve(rank); for (int i = 0; i < rank; i++) { Value inputIndex = rewriter.createOrFold(loc, i); Value lowVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, lowIndex})); Value highVal = rewriter.createOrFold( loc, padding, ValueRange({inputIndex, highIndex})); lowVal = rewriter.createOrFold( loc, rewriter.getIndexType(), lowVal); highVal = rewriter.createOrFold( loc, rewriter.getIndexType(), highVal); lowValues.push_back(lowVal); highValues.push_back(highVal); } auto newPadOp = rewriter.create( loc, padOp.getType(), input, lowValues, highValues, padConstant); rewriter.replaceOp(padOp, newPadOp.getResult()); return success(); } }; } // namespace void mlir::tosa::populateTosaToTensorConversionPatterns( RewritePatternSet *patterns) { patterns->add(patterns->getContext()); }