//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" #include "../PassDetail.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace mlir::shape; namespace { /// Converts `shape.reduce` to `scf.for`. struct ReduceOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReduceOp op, PatternRewriter &rewriter) const final; }; } // namespace LogicalResult ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp, PatternRewriter &rewriter) const { auto loc = reduceOp.getLoc(); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); Value extentTensor = rewriter.create( loc, RankedTensorType::get({ShapedType::kDynamicSize}, rewriter.getIndexType()), reduceOp.shape()); Value size = rewriter.create(loc, rewriter.getIndexType(), extentTensor, zero); auto loop = rewriter.create( loc, zero, size, one, reduceOp.initVals(), [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) { Value indexExtent = b.create(loc, extentTensor, iv); Value sizeExtent = b.create(loc, indexExtent); SmallVector mapped_values{iv, sizeExtent}; mapped_values.append(args.begin(), args.end()); BlockAndValueMapping mapping; Block *reduceBody = reduceOp.getBody(); mapping.map(reduceBody->getArguments(), mapped_values); for (auto &nested : reduceBody->without_terminator()) b.clone(nested, mapping); SmallVector mappedResults; for (auto result : reduceBody->getTerminator()->getOperands()) mappedResults.push_back(mapping.lookup(result)); b.create(loc, mappedResults); }); rewriter.replaceOp(reduceOp, loop.getResults()); return success(); } namespace { /// Converts `shape_of` to for loop for unranked tensors. class ShapeOfOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ShapeOfOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { ShapeOfOp::Adaptor transformed(operands); auto tensorVal = transformed.arg(); auto tensorTy = tensorVal.getType(); // For ranked tensors `shape_of` lowers to `std` and the pattern can be // found in the corresponding pass. if (tensorTy.isa()) return failure(); // Allocate stack memory. auto loc = op.getLoc(); auto rankVal = rewriter.create(loc, tensorVal); auto i64Ty = rewriter.getI64Type(); auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty); auto memVal = rewriter.create(loc, memTy, ValueRange({rankVal})); // Copy shape extents to stack-allocated memory. auto zeroVal = rewriter.create(loc, 0); auto oneVal = rewriter.create(loc, 1); rewriter.create( loc, zeroVal, rankVal, oneVal, ValueRange(), [&](OpBuilder &b, Location loc, Value iVal, ValueRange args) { auto dimVal = b.create(loc, tensorVal, iVal); auto dimIntVal = b.create(loc, dimVal, i64Ty); b.create(loc, dimIntVal, memVal, ValueRange({iVal})); b.create(loc); }); // Load extents to tensor value. auto shapeIntVal = rewriter.create(loc, memVal); auto indexTy = rewriter.getIndexType(); auto shapeTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); rewriter.replaceOpWithNewOp(op.getOperation(), shapeIntVal, shapeTy); return success(); } namespace { struct ConvertShapeToSCFPass : public ConvertShapeToSCFBase { void runOnFunction() override; }; } // namespace void ConvertShapeToSCFPass::runOnFunction() { MLIRContext &ctx = getContext(); // Populate conversion patterns. OwningRewritePatternList patterns; populateShapeToSCFConversionPatterns(patterns, &ctx); // Setup target legality. ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); // Apply conversion. if (failed(applyPartialConversion(getFunction(), target, patterns))) signalPassFailure(); } void mlir::populateShapeToSCFConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx); } std::unique_ptr mlir::createConvertShapeToSCFPass() { return std::make_unique(); }