This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering:
```
func.func @tosa(
%valuesIn : tensor<3x7x5xi32>,
%indices : tensor<3x6xi32>,
%input : tensor<3x6x5xi32>) ->
tensor<3x7x5xi32> {
%0 = "tosa.scatter"(%valuesIn, %indices, %input) :
(tensor<3x7x5xi32>,
tensor<3x6xi32>,
tensor<3x6x5xi32>) ->
(tensor<3x7x5xi32>)
return %0 : tensor<3x7x5xi32>
}
```
translates to:
```
func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> {
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%c6 = arith.constant 6 : index
%c2 = arith.constant 2 : index
%c5 = arith.constant 5 : index
%c0_0 = arith.constant 0 : index
%c1_1 = arith.constant 1 : index
%0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) {
%1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) {
%extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32>
%2 = arith.index_cast %extracted : i32 to index
%extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32>
%inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32>
scf.yield %inserted_slice : tensor<3x7x5xi32>
}
scf.yield %1 : tensor<3x7x5xi32>
}
return %0 : tensor<3x7x5xi32>
}
```
We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons:
- The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).
- The `tosa.scatter` and `tensor.scatter` ops have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified.
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D151117
58 lines
1.9 KiB
C++
58 lines
1.9 KiB
C++
//===- TosaToSCFPass.cpp - Lowering Tosa to SCF Dialect -------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This transformation pass legalizes Tosa operations to the SCF dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_TOSATOSCF
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace tosa;
|
|
|
|
namespace {
|
|
/// Pass implementation that legalizes `tosa::IfOp`, `tosa::ScatterOp` and
/// `tosa::WhileOp` through a partial dialect conversion, using the patterns
/// registered by `populateTosaToSCFConversionPatterns`.
struct TosaToSCF : public impl::TosaToSCFBase<TosaToSCF> {
  /// Builds the conversion target and pattern set, then runs the partial
  /// conversion on the operation this pass is scheduled on. The pass is
  /// marked as failed if the conversion reports failure.
  void runOnOperation() override {
    auto &context = getContext();

    // Legality rules: the three TOSA ops below must be rewritten, tensor and
    // scf ops are legal, and every op without an explicit rule is accepted
    // (the callback unconditionally returns true).
    ConversionTarget conversionTarget(context);
    conversionTarget.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
    conversionTarget.addIllegalOp<tosa::IfOp, tosa::ScatterOp, tosa::WhileOp>();
    conversionTarget.markUnknownOpDynamicallyLegal(
        [](Operation *) { return true; });

    RewritePatternSet loweringPatterns(&context);
    mlir::tosa::populateTosaToSCFConversionPatterns(&loweringPatterns);

    const auto status = applyPartialConversion(
        getOperation(), conversionTarget, std::move(loweringPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::tosa::createTosaToSCF() {
|
|
return std::make_unique<TosaToSCF>();
|
|
}
|
|
|
|
/// Appends the TOSA-to-SCF lowering to `pm` as a pass nested on
/// `func::FuncOp`.
void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
  std::unique_ptr<Pass> scfLowering = createTosaToSCF();
  pm.addNestedPass<func::FuncOp>(std::move(scfLowering));
}
|