Files
clang-p2996/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
Rafael Ubal Tena 6b4b63a832 Lowering for 'tosa.scatter'
This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering:

```
func.func @tosa(
                %valuesIn : tensor<3x7x5xi32>,
                %indices : tensor<3x6xi32>,
                %input : tensor<3x6x5xi32>) ->
                tensor<3x7x5xi32> {
        %0 = "tosa.scatter"(%valuesIn, %indices, %input) :
                        (tensor<3x7x5xi32>,
                        tensor<3x6xi32>,
                        tensor<3x6x5xi32>) ->
                        (tensor<3x7x5xi32>)
        return %0 : tensor<3x7x5xi32>
}
```

translates to

```
  func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
    %c0 = arith.constant 0 : index
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c6 = arith.constant 6 : index
    %c2 = arith.constant 2 : index
    %c5 = arith.constant 5 : index
    %c0_0 = arith.constant 0 : index
    %c1_1 = arith.constant 1 : index
    %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
      %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
        %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
        %2 = arith.index_cast %extracted : i32 to index
        %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
        %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
        scf.yield %inserted_slice : tensor<3x7x5xi32>
      }
      scf.yield %1 : tensor<3x7x5xi32>
    }
    return %0 : tensor<3x7x5xi32>
  }
```

We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons:

- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified.

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151117
2023-05-30 14:28:52 -07:00

181 lines
6.4 KiB
C++

//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the SCF dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace tosa;
/// Moves one branch of a conditional into `dstRegion`.
///
/// The contents of `srcRegion` are cloned in front of the first block of
/// `dstRegion`, and the block that ends up last in `dstRegion` is erased
/// (presumably the single placeholder block created with the destination op
/// -- TODO confirm against callers). Every use of an argument of the cloned
/// head block is redirected to the matching entry of `operands`, the head
/// block's `YieldOp` terminator is replaced by an `scf.yield` carrying the
/// same values, and finally the now-unused block arguments are removed.
///
/// \param srcRegion Region whose blocks are cloned.
/// \param dstRegion Region receiving the clone; must already hold a block.
/// \param operands  Values substituted for the head block's arguments,
///                  matched by position.
/// \param rewriter  Rewriter used for all IR creation and erasure.
static void inlineIfCase(Region &srcRegion, Region &dstRegion,
                         OperandRange operands, PatternRewriter &rewriter) {
  // Clone ahead of the existing block, then drop that existing block.
  rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
  rewriter.eraseBlock(&dstRegion.back());

  Block *entry = &dstRegion.front();

  // Forward the outer operands directly to the users of the block arguments.
  for (auto [blockArg, outerValue] : llvm::zip(entry->getArguments(), operands))
    blockArg.replaceAllUsesWith(outerValue);

  // Swap the terminator for its SCF counterpart.
  auto oldYield = cast<YieldOp>(entry->getTerminator());
  auto yieldLoc = oldYield.getLoc();
  rewriter.setInsertionPoint(oldYield);
  rewriter.create<scf::YieldOp>(yieldLoc, oldYield.getInputs());
  rewriter.eraseOp(oldYield);

  // The arguments have no remaining uses, so they can be dropped wholesale.
  entry->eraseArguments(0, entry->getNumArguments());
}
/// Moves one region of a while loop into `dstRegion` and rewrites its
/// terminator.
///
/// The contents of `srcRegion` are cloned in front of the last block of
/// `dstRegion`, after which that last block is erased. The `YieldOp` that
/// terminates the first block of `dstRegion` is then replaced:
///  - when `isCond` is true, its first operand is read with `tensor.extract`
///    (no indices are supplied) and an `scf.condition` is emitted that
///    forwards the head block's arguments;
///  - otherwise an `scf.yield` carrying the same values is emitted.
///
/// \param srcRegion Region whose blocks are cloned.
/// \param dstRegion Region receiving the clone; must already hold a block.
/// \param rewriter  Rewriter used for all IR creation and erasure.
/// \param isCond    Selects the `scf.condition` form of the terminator.
static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
                            PatternRewriter &rewriter, bool isCond) {
  // Clone ahead of the existing block, then drop that existing block.
  rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
  rewriter.eraseBlock(&dstRegion.back());

  Block *entry = &dstRegion.front();
  auto oldYield = cast<YieldOp>(entry->getTerminator());
  auto yieldLoc = oldYield.getLoc();

  // Everything below is inserted immediately before the old terminator.
  rewriter.setInsertionPoint(oldYield);

  if (!isCond) {
    rewriter.create<scf::YieldOp>(yieldLoc, oldYield.getInputs());
  } else {
    auto keepGoing =
        rewriter.create<tensor::ExtractOp>(yieldLoc, oldYield.getOperand(0));
    rewriter.create<scf::ConditionOp>(yieldLoc, keepGoing,
                                      entry->getArguments());
  }

  rewriter.eraseOp(oldYield);
}
namespace {
/// Rewrites `tosa::IfOp` as `scf::IfOp`.
///
/// The condition operand is read with `tensor.extract` (no indices are
/// supplied) and used as the condition of a new `scf.if` with the same result
/// types. Both branches of the original op are then inlined into the new op
/// via `inlineIfCase`, with the op's inputs substituted for the branch block
/// arguments.
class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
public:
  using OpRewritePattern<tosa::IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::IfOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();

    // Pull the condition out of its tensor wrapper.
    auto scalarCond = rewriter.create<tensor::ExtractOp>(loc, op.getCond());

    // Both regions are filled in below, so the trailing `true` is passed to
    // the builder (presumably requesting the else region -- TODO confirm
    // against the scf::IfOp builder signature).
    auto scfIf = rewriter.create<scf::IfOp>(loc, op.getResultTypes(),
                                            scalarCond, true);

    auto branchInputs = op.getInputs();
    inlineIfCase(op.getThenBranch(), scfIf.getThenRegion(), branchInputs,
                 rewriter);
    inlineIfCase(op.getElseBranch(), scfIf.getElseRegion(), branchInputs,
                 rewriter);

    rewriter.replaceOp(op, scfIf.getResults());
    return success();
  }
};
/// Rewrites `tosa::ScatterOp` as a loop nest built with `scf::buildLoopNest`.
///
/// The nest iterates over the first two dimensions of `input` and threads the
/// output tensor through as the single loop-carried value, seeded with
/// `values_in`. Each iteration reads one entry of `indices`, casts it to
/// `index`, extracts the matching 1 x 1 x C slice of `input`, and inserts it
/// into the carried tensor at the row selected by that entry.
class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
  /// Returns the size of dimension `dim` of `tensor`, created with
  /// `createOrFold` so the builder may fold it.
  static Value dimSizeOf(OpBuilder &builder, Location loc, Value tensor,
                         int64_t dim) {
    return builder.createOrFold<tensor::DimOp>(loc, tensor, dim);
  }

  /// Creates an `arith` index constant holding `value`.
  static Value indexConstant(OpBuilder &builder, Location loc, int64_t value) {
    return builder.create<arith::ConstantIndexOp>(loc, value);
  }

public:
  using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ScatterOp scatter,
                                PatternRewriter &rewriter) const final {
    auto loc = scatter.getLoc();
    auto valuesIn = scatter.getValuesIn();
    auto indices = scatter.getIndices();
    auto input = scatter.getInput();

    // Dimension names N, W, C follow the TOSA specification.
    Value dimN = dimSizeOf(rewriter, loc, input, 0);
    Value dimW = dimSizeOf(rewriter, loc, input, 1);
    Value dimC = dimSizeOf(rewriter, loc, input, 2);

    Value zero = indexConstant(rewriter, loc, 0);
    Value one = indexConstant(rewriter, loc, 1);

    // Two loops: [0, N) and [0, W), both with unit step.
    llvm::SmallVector<Value> lowerBounds(2, zero);
    llvm::SmallVector<Value> upperBounds = {dimN, dimW};
    llvm::SmallVector<Value> unitSteps(2, one);

    // Builds the innermost body; `carried[0]` is the output accumulated so
    // far and the returned value becomes the next iteration's accumulator.
    auto emitBody = [&](OpBuilder &builder, Location bodyLoc, ValueRange ivs,
                        ValueRange carried) -> scf::ValueVector {
      // Fetch the destination row for this (n, w) and convert it to `index`.
      auto rawRow = builder.create<tensor::ExtractOp>(bodyLoc, indices, ivs);
      Value dstRow = builder.create<arith::IndexCastOp>(
          bodyLoc, builder.getIndexType(), rawRow);

      // The slice covers a single (n, w) position and the whole C dimension.
      llvm::SmallVector<Value> sliceSizes = {one, one, dimC};
      llvm::SmallVector<Value> sliceStrides = {one, one, one};

      // Source offsets are the induction variables followed by 0 along C.
      llvm::SmallVector<Value> srcOffsets(ivs.begin(), ivs.end());
      srcOffsets.push_back(zero);
      auto srcSlice = builder.create<tensor::ExtractSliceOp>(
          bodyLoc, input, srcOffsets, sliceSizes, sliceStrides);

      // Destination offsets keep the batch index but use the looked-up row.
      llvm::SmallVector<Value> dstOffsets = {ivs[0], dstRow, zero};
      Value nextOutput = builder.create<tensor::InsertSliceOp>(
          bodyLoc, srcSlice, carried[0], dstOffsets, sliceSizes, sliceStrides);

      return {nextOutput};
    };

    auto nest = scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds,
                                   unitSteps, ValueRange{valuesIn}, emitBody);
    rewriter.replaceOp(scatter, nest.results);
    return success();
  }
};
/// Rewrites `tosa::WhileOp` as `scf::WhileOp`.
///
/// A new `scf.while` is created with the original result types and inputs.
/// A block is created in each of its `before` and `after` regions so that
/// `inlineWhileCase` has a block to clone in front of and then erase. The
/// original condition region goes into `before` (terminated by
/// `scf.condition`) and the original body region goes into `after`
/// (terminated by `scf.yield`).
class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
public:
  using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::WhileOp op,
                                PatternRewriter &rewriter) const final {
    auto scfWhile = rewriter.create<scf::WhileOp>(
        op.getLoc(), op.getResultTypes(), op.getInputs());

    // Placeholder blocks; each one is erased again by inlineWhileCase.
    auto &beforeRegion = scfWhile.getBefore();
    auto &afterRegion = scfWhile.getAfter();
    rewriter.createBlock(&beforeRegion);
    rewriter.createBlock(&afterRegion);

    inlineWhileCase(op.getCond(), beforeRegion, rewriter, /*isCond=*/true);
    inlineWhileCase(op.getBody(), afterRegion, rewriter, /*isCond=*/false);

    rewriter.replaceOp(op, scfWhile.getResults());
    return success();
  }
};
} // namespace
/// Adds the patterns defined in this file (`IfOpConverter`,
/// `ScatterOpConverter`, `WhileOpConverter`) to `patterns`, constructing each
/// with the context owned by the pattern set.
///
/// \param patterns Pattern set to extend; dereferenced unconditionally.
void mlir::tosa::populateTosaToSCFConversionPatterns(
    RewritePatternSet *patterns) {
  auto *context = patterns->getContext();
  patterns->add<IfOpConverter, ScatterOpConverter, WhileOpConverter>(context);
}