Files
clang-p2996/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
Quinn Dawkins f310a5d2c1 [mlir][tensor] Add a tensor.concat operation (#72779)
This adds an operation for concatenating ranked tensors along a static
dimension, as well as a decomposition mirroring the existing lowering
from TOSA to Tensor. This offers a convergence point for "input"-like
dialects that include various lowerings for concatenation operations,
easing later analysis. In the future, this op can implement the
necessary interfaces for tiling, as well as potentially add conversions
to some kind of linalg and/or memref counterpart.

This patch adds the op, the decomposition, and some basic
folding/canonicalization. Replacing lowerings with the op (such as the
TOSA lowering) will come as a follow-up.

See
https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
2023-12-01 15:05:29 -05:00

94 lines
3.5 KiB
C++

//===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::tensor;
namespace {
/// Rewrites `tensor.concat` as a destination tensor that is filled by a
/// sequence of `tensor.insert_slice` ops, one per concatenated input.
///
/// For example,
///
///   %concat = tensor.concat dim(1) %0, %1 :
///       (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
///
/// is rewritten to
///
///   %empty = tensor.empty() : tensor<2x7xf32>
///   %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
///   %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
///
/// Along the concatenation dimension each input is inserted at an offset equal
/// to the summed sizes of the inputs that precede it; every other offset is 0
/// and every stride is 1. The last value of the chain is cast to the result
/// type of the original op.
struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    Location loc = concatOp.getLoc();

    // Obtain the destination tensor for the result; give up unless it is
    // produced by a `tensor.empty`.
    FailureOr<Value> dest =
        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
    if (failed(dest) || !dest->getDefiningOp<tensor::EmptyOp>())
      return failure();

    auto inputs = concatOp.getInputs();
    const int64_t concatDim = concatOp.getDim();
    Value concatDimValue = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(concatDim));

    // Build one affine expression per input: d0, d0 + d1, d0 + d1 + d2, ...
    // Operand 0 is the constant 0 and operand i + 1 is the size of input i
    // along the concatenation dimension, so expression i evaluates to the
    // offset at which input i starts. The last input contributes no operand
    // because nothing is placed after it.
    SmallVector<AffineExpr> prefixSumExprs = {rewriter.getAffineDimExpr(0)};
    SmallVector<OpFoldResult> prefixSumOperands = {rewriter.getIndexAttr(0)};
    for (auto [position, operand] : llvm::enumerate(inputs.drop_back())) {
      prefixSumExprs.push_back(prefixSumExprs.back() +
                               rewriter.getAffineDimExpr(position + 1));
      prefixSumOperands.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, operand, concatDimValue));
    }
    auto prefixSumMap = AffineMap::get(inputs.size(), 0, prefixSumExprs,
                                       rewriter.getContext());
    SmallVector<OpFoldResult> concatOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, prefixSumMap, prefixSumOperands);

    // Offsets start at zero everywhere; only the entry for the concatenation
    // dimension is overwritten for each input. Strides are always one.
    const int64_t rank = concatOp.getResultType().getRank();
    SmallVector<OpFoldResult> sliceOffsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sliceStrides(rank, rewriter.getIndexAttr(1));

    // Thread the destination through one insert_slice per input.
    Value accumulated = *dest;
    for (auto [operand, concatOffset] :
         llvm::zip_equal(inputs, concatOffsets)) {
      SmallVector<OpFoldResult> sliceSizes =
          tensor::getMixedSizes(rewriter, loc, operand);
      sliceOffsets[concatDim] = concatOffset;
      accumulated = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, operand, accumulated, sliceOffsets, sliceSizes, sliceStrides);
    }

    // Replace the concat with a cast of the filled tensor to the concat's
    // result type.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        concatOp, concatOp.getResultType(), accumulated);
    return success();
  }
};
} // namespace
/// Adds the pattern that decomposes `tensor.concat` to `patterns`, using the
/// context that `patterns` was created with.
void mlir::tensor::populateDecomposeTensorConcatPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<DecomposeTensorConcatOp>(context);
}