//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares the ranks and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = constant 0 : index
/// %0 = dim %arg0, %c0 : tensor<?xindex>
/// %1 = dim %arg1, %c0 : tensor<?xindex>
/// %2 = cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = constant 1 : index
///   %true = constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = extract_element %arg0[%arg2] : tensor<?xindex>
///     %6 = extract_element %arg1[%arg2] : tensor<?xindex>
///     %7 = cmpi "eq", %5, %6 : index
///     %8 = and %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (op.lhs().getType().isa<ShapeType>() ||
      op.rhs().getType().isa<ShapeType>()) {
    return failure();
  }

  ShapeEqOp::Adaptor transformed(operands);
  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value lhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.lhs(), zero);
  Value rhsRank = rewriter.create<DimOp>(loc, indexTy, transformed.rhs(), zero);
  Value eqRank =
      rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsRank, rhsRank);
  Type i1Ty = rewriter.getI1Type();
  rewriter.replaceOpWithNewOp<IfOp>(
      op, i1Ty, eqRank,
      [&](OpBuilder &b, Location loc) {
        Value one = b.create<ConstantIndexOp>(loc, 1);
        Value init = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
        auto loop = b.create<scf::ForOp>(
            loc, zero, lhsRank, one, ValueRange{init},
            [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
              Value conj = args[0];
              Value lhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.lhs(), iv);
              Value rhsExtent =
                  b.create<ExtractElementOp>(loc, transformed.rhs(), iv);
              Value eqExtent = b.create<CmpIOp>(loc, CmpIPredicate::eq,
                                                lhsExtent, rhsExtent);
              Value conjNext = b.create<AndOp>(loc, conj, eqExtent);
              b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
            });
        b.create<scf::YieldOp>(loc, loop.getResults());
      },
      [&](OpBuilder &b, Location loc) {
        Value result = b.create<ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
        b.create<scf::YieldOp>(loc, result);
      });
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
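///
/// As an illustrative sketch only (the `shape.reduce` assembly is abbreviated,
/// SSA names such as %shape and %init are invented, and the loop body is a
/// clone of whatever the reduce region contains), a reduction that multiplies
/// all extents of a `tensor<?xindex>` shape,
///
/// %result = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
///   ^bb0(%i : index, %extent : index, %acc : index):
///     %new_acc = muli %acc, %extent : index
///     shape.yield %new_acc : index
/// }
///
/// would lower to roughly
///
/// %c0 = constant 0 : index
/// %c1 = constant 1 : index
/// %rank = dim %shape, %c0 : tensor<?xindex>
/// %result = scf.for %i = %c0 to %rank step %c1
///     iter_args(%acc = %init) -> (index) {
///   %extent = extract_element %shape[%i] : tensor<?xindex>
///   %new_acc = muli %acc, %extent : index
///   scf.yield %new_acc : index
/// }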
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (op.shape().getType().isa<ShapeType>())
    return failure();

  auto loc = op.getLoc();
  shape::ReduceOp::Adaptor transformed(operands);

  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank = rewriter.create<DimOp>(loc, indexTy, transformed.shape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.initVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<ExtractElementOp>(loc, transformed.shape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        BlockAndValueMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape_of` to an `scf.for` loop for unranked tensors.
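///
/// For an unranked operand %arg of type `tensor<*xf32>`, the lowering produces
/// roughly the following IR (a sketch; SSA names are illustrative):
///
/// %rank = rank %arg : tensor<*xf32>
/// %mem = alloca(%rank) : memref<?xindex>
/// %c0 = constant 0 : index
/// %c1 = constant 1 : index
/// scf.for %i = %c0 to %rank step %c1 {
///   %dim = dim %arg, %i : tensor<*xf32>
///   store %dim, %mem[%i] : memref<?xindex>
/// }
/// %shape = tensor_load %mem : memref<?xindex>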
class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free arguments.
  if (op.getType().isa<ShapeType>())
    return failure();

  // For ranked tensors `shape_of` lowers to `std` and the pattern can be
  // found in the corresponding pass.
  ShapeOfOp::Adaptor transformed(operands);
  Value arg = transformed.arg();
  Type argTy = arg.getType();
  if (argTy.isa<RankedTensorType>())
    return failure();

  // Allocate stack memory.
  auto loc = op.getLoc();
  Value rank = rewriter.create<mlir::RankOp>(loc, arg);
  Type indexTy = rewriter.getIndexType();
  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});

  // Copy shape extents to stack-allocated memory.
  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  rewriter.create<scf::ForOp>(
      loc, zero, rank, one, llvm::None,
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value dim = rewriter.create<DimOp>(loc, arg, iv);
        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
        rewriter.create<scf::YieldOp>(loc);
      });

  // Load extents to tensor value.
  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
  return success();
}

namespace {
struct ConvertShapeToSCFPass
    : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
  void runOnFunction() override;
};
} // namespace

void ConvertShapeToSCFPass::runOnFunction() {
  MLIRContext &ctx = getContext();

  // Populate conversion patterns.
  OwningRewritePatternList patterns;
  populateShapeToSCFConversionPatterns(patterns, &ctx);

  // Setup target legality.
  ConversionTarget target(getContext());
  target.addLegalDialect<SCFDialect, StandardOpsDialect>();
  target.addLegalOp<ModuleOp, FuncOp>();

  // Apply conversion.
  if (failed(applyPartialConversion(getFunction(), target, patterns)))
    signalPassFailure();
}

void mlir::populateShapeToSCFConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  // clang-format off
  patterns.insert<
      ShapeEqOpConverter,
      ReduceOpConverter,
      ShapeOfOpConverter>(ctx);
  // clang-format on
}

std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
  return std::make_unique<ConvertShapeToSCFPass>();
}
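
// Usage sketch (illustrative; assumes the pass is registered under the
// "convert-shape-to-scf" flag in the conversion Passes.td):
//
//   mlir-opt --convert-shape-to-scf input.mlir
//
// or, from C++, on a hypothetical module-level pass pipeline `pm`:
//
//   pm.addNestedPass<FuncOp>(createConvertShapeToSCFPass());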