Precursor: https://reviews.llvm.org/D110200 Removed redundant ops from the standard dialect that were moved to the `arith` or `math` dialects. Renamed all instances of operations in the codebase and in tests. Reviewed By: rriddle, jpienaar Differential Revision: https://reviews.llvm.org/D110797
410 lines
16 KiB
C++
410 lines
16 KiB
C++
//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
|
|
//
|
|
/// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
|
/// Exceptions. See https://llvm.org/LICENSE.txt for license information.
|
|
/// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
/// This file implements target-independent rewrites of MultiDimReductionOp.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
|
#include "mlir/Dialect/Vector/VectorUtils.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
#define DEBUG_TYPE "vector-multi-reduction"
|
|
|
|
using namespace mlir;
|
|
|
|
/// This file implements the following transformations as composable atomic
|
|
/// patterns.
|
|
|
|
/// Converts vector.multi_reduction into inner-most/outer-most reduction form
|
|
/// by using vector.transpose
|
|
class InnerOuterDimReductionConversion
|
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
|
public:
|
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
|
|
|
explicit InnerOuterDimReductionConversion(MLIRContext *context,
|
|
bool useInnerDimsForReduction)
|
|
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
|
|
useInnerDimsForReduction(useInnerDimsForReduction) {}
|
|
|
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto src = multiReductionOp.source();
|
|
auto loc = multiReductionOp.getLoc();
|
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
|
|
|
// Separate reduction and parallel dims
|
|
auto reductionDimsRange =
|
|
multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
|
|
auto reductionDims = llvm::to_vector<4>(llvm::map_range(
|
|
reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
|
|
llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
|
|
reductionDims.end());
|
|
int64_t reductionSize = reductionDims.size();
|
|
SmallVector<int64_t, 4> parallelDims;
|
|
for (int64_t i = 0; i < srcRank; ++i)
|
|
if (!reductionDimsSet.contains(i))
|
|
parallelDims.push_back(i);
|
|
|
|
// Add transpose only if inner-most/outer-most dimensions are not parallel
|
|
if (useInnerDimsForReduction &&
|
|
(parallelDims ==
|
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
|
|
return failure();
|
|
|
|
if (!useInnerDimsForReduction &&
|
|
(parallelDims !=
|
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
|
|
return failure();
|
|
|
|
SmallVector<int64_t, 4> indices;
|
|
if (useInnerDimsForReduction) {
|
|
indices.append(parallelDims.begin(), parallelDims.end());
|
|
indices.append(reductionDims.begin(), reductionDims.end());
|
|
} else {
|
|
indices.append(reductionDims.begin(), reductionDims.end());
|
|
indices.append(parallelDims.begin(), parallelDims.end());
|
|
}
|
|
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
|
|
SmallVector<bool> reductionMask(srcRank, false);
|
|
for (int i = 0; i < reductionSize; ++i) {
|
|
if (useInnerDimsForReduction)
|
|
reductionMask[srcRank - i - 1] = true;
|
|
else
|
|
reductionMask[i] = true;
|
|
}
|
|
rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
|
|
multiReductionOp, transposeOp.result(), reductionMask,
|
|
multiReductionOp.kind());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
const bool useInnerDimsForReduction;
|
|
};
|
|
|
|
/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
|
|
/// dimensions are either inner most or outer most.
|
|
class ReduceMultiDimReductionRank
|
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
|
public:
|
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
|
|
|
explicit ReduceMultiDimReductionRank(MLIRContext *context,
|
|
bool useInnerDimsForReduction)
|
|
: mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
|
|
useInnerDimsForReduction(useInnerDimsForReduction) {}
|
|
|
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
|
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
|
|
auto loc = multiReductionOp.getLoc();
|
|
|
|
// If rank less than 2, nothing to do.
|
|
if (srcRank < 2)
|
|
return failure();
|
|
|
|
// If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
|
|
SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
|
|
if (srcRank == 2 && reductionMask.front() != reductionMask.back())
|
|
return failure();
|
|
|
|
// 1. Separate reduction and parallel dims.
|
|
SmallVector<int64_t, 4> parallelDims, parallelShapes;
|
|
SmallVector<int64_t, 4> reductionDims, reductionShapes;
|
|
for (auto it : llvm::enumerate(reductionMask)) {
|
|
int64_t i = it.index();
|
|
bool isReduction = it.value();
|
|
if (isReduction) {
|
|
reductionDims.push_back(i);
|
|
reductionShapes.push_back(srcShape[i]);
|
|
} else {
|
|
parallelDims.push_back(i);
|
|
parallelShapes.push_back(srcShape[i]);
|
|
}
|
|
}
|
|
|
|
// 2. Compute flattened parallel and reduction sizes.
|
|
int flattenedParallelDim = 0;
|
|
int flattenedReductionDim = 0;
|
|
if (parallelShapes.size() > 0) {
|
|
flattenedParallelDim = 1;
|
|
for (auto d : parallelShapes)
|
|
flattenedParallelDim *= d;
|
|
}
|
|
if (reductionShapes.size() > 0) {
|
|
flattenedReductionDim = 1;
|
|
for (auto d : reductionShapes)
|
|
flattenedReductionDim *= d;
|
|
}
|
|
// We must at least have some parallel or some reduction.
|
|
assert((flattenedParallelDim || flattenedReductionDim) &&
|
|
"expected at least one parallel or reduction dim");
|
|
|
|
// 3. Fail if reduction/parallel dims are not contiguous.
|
|
// Check parallelDims are exactly [0 .. size).
|
|
int64_t counter = 0;
|
|
if (useInnerDimsForReduction &&
|
|
llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
|
|
return failure();
|
|
// Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
|
|
counter = reductionDims.size();
|
|
if (!useInnerDimsForReduction &&
|
|
llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
|
|
return failure();
|
|
|
|
// 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
|
|
// a single parallel (resp. reduction) dim.
|
|
SmallVector<bool, 2> mask;
|
|
SmallVector<int64_t, 2> vectorShape;
|
|
if (flattenedParallelDim) {
|
|
mask.push_back(false);
|
|
vectorShape.push_back(flattenedParallelDim);
|
|
}
|
|
if (flattenedReductionDim) {
|
|
mask.push_back(true);
|
|
vectorShape.push_back(flattenedReductionDim);
|
|
}
|
|
if (!useInnerDimsForReduction && vectorShape.size() == 2) {
|
|
std::swap(mask.front(), mask.back());
|
|
std::swap(vectorShape.front(), vectorShape.back());
|
|
}
|
|
auto castedType = VectorType::get(
|
|
vectorShape, multiReductionOp.getSourceVectorType().getElementType());
|
|
Value cast = rewriter.create<vector::ShapeCastOp>(
|
|
loc, castedType, multiReductionOp.source());
|
|
|
|
// 5. Creates the flattened form of vector.multi_reduction with inner/outer
|
|
// most dim as reduction.
|
|
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
|
|
loc, cast, mask, multiReductionOp.kind());
|
|
|
|
// 6. If there are no parallel shapes, the result is a scalar.
|
|
// TODO: support 0-d vectors when available.
|
|
if (parallelShapes.empty()) {
|
|
rewriter.replaceOp(multiReductionOp, newOp.dest());
|
|
return success();
|
|
}
|
|
|
|
// 7. Creates shape cast for the output n-D -> 2-D
|
|
VectorType outputCastedType = VectorType::get(
|
|
parallelShapes,
|
|
multiReductionOp.getSourceVectorType().getElementType());
|
|
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
|
|
multiReductionOp, outputCastedType, newOp.dest());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
const bool useInnerDimsForReduction;
|
|
};
|
|
|
|
/// Unrolls vector.multi_reduction with outermost reductions
|
|
/// and combines results
|
|
struct TwoDimMultiReductionToElementWise
|
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
|
// Rank-2 ["parallel", "reduce"] or bail.
|
|
if (srcRank != 2)
|
|
return failure();
|
|
|
|
if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
|
|
return failure();
|
|
|
|
auto loc = multiReductionOp.getLoc();
|
|
ArrayRef<int64_t> srcShape =
|
|
multiReductionOp.getSourceVectorType().getShape();
|
|
|
|
Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
|
|
if (!elementType.isIntOrIndexOrFloat())
|
|
return failure();
|
|
|
|
Value condition;
|
|
Value result =
|
|
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
|
|
.getResult();
|
|
for (int64_t i = 1; i < srcShape[0]; i++) {
|
|
auto operand =
|
|
rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
|
|
switch (multiReductionOp.kind()) {
|
|
case vector::CombiningKind::ADD:
|
|
if (elementType.isIntOrIndex())
|
|
result = rewriter.create<arith::AddIOp>(loc, operand, result);
|
|
else
|
|
result = rewriter.create<arith::AddFOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MUL:
|
|
if (elementType.isIntOrIndex())
|
|
result = rewriter.create<arith::MulIOp>(loc, operand, result);
|
|
else
|
|
result = rewriter.create<arith::MulFOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MINUI:
|
|
result = rewriter.create<MinUIOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MINSI:
|
|
result = rewriter.create<MinSIOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MINF:
|
|
result = rewriter.create<MinFOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MAXUI:
|
|
result = rewriter.create<MaxUIOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MAXSI:
|
|
result = rewriter.create<MaxSIOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::MAXF:
|
|
result = rewriter.create<MaxFOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::AND:
|
|
result = rewriter.create<arith::AndIOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::OR:
|
|
result = rewriter.create<arith::OrIOp>(loc, operand, result);
|
|
break;
|
|
case vector::CombiningKind::XOR:
|
|
result = rewriter.create<arith::XOrIOp>(loc, operand, result);
|
|
break;
|
|
}
|
|
}
|
|
|
|
rewriter.replaceOp(multiReductionOp, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts 2d vector.multi_reduction with inner most reduction dimension into
|
|
/// a sequence of vector.reduction ops.
|
|
struct TwoDimMultiReductionToReduction
|
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
|
if (srcRank != 2)
|
|
return failure();
|
|
|
|
if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
|
|
return failure();
|
|
|
|
auto loc = multiReductionOp.getLoc();
|
|
Value result = rewriter.create<ConstantOp>(
|
|
loc, multiReductionOp.getDestType(),
|
|
rewriter.getZeroAttr(multiReductionOp.getDestType()));
|
|
int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
|
|
|
|
// TODO: Add vector::CombiningKind attribute instead of string to
|
|
// vector.reduction.
|
|
auto getKindStr = [](vector::CombiningKind kind) {
|
|
switch (kind) {
|
|
case vector::CombiningKind::ADD:
|
|
return "add";
|
|
case vector::CombiningKind::MUL:
|
|
return "mul";
|
|
case vector::CombiningKind::MINUI:
|
|
return "minui";
|
|
case vector::CombiningKind::MINSI:
|
|
return "minsi";
|
|
case vector::CombiningKind::MINF:
|
|
return "minf";
|
|
case vector::CombiningKind::MAXUI:
|
|
return "maxui";
|
|
case vector::CombiningKind::MAXSI:
|
|
return "maxsi";
|
|
case vector::CombiningKind::MAXF:
|
|
return "maxf";
|
|
case vector::CombiningKind::AND:
|
|
return "and";
|
|
case vector::CombiningKind::OR:
|
|
return "or";
|
|
case vector::CombiningKind::XOR:
|
|
return "xor";
|
|
}
|
|
llvm_unreachable("unknown combining kind");
|
|
};
|
|
|
|
for (int i = 0; i < outerDim; ++i) {
|
|
auto v = rewriter.create<vector::ExtractOp>(
|
|
loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
|
|
auto reducedValue = rewriter.create<vector::ReductionOp>(
|
|
loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
|
|
rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
|
|
ValueRange{});
|
|
result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
|
|
result, i);
|
|
}
|
|
rewriter.replaceOp(multiReductionOp, result);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
|
|
/// form with both a single parallel and reduction dimension.
|
|
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
|
|
/// The case with a single parallel dimension is a noop and folds away
|
|
/// separately.
|
|
struct OneDimMultiReductionToTwoDim
|
|
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
|
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
|
// Rank-1 or bail.
|
|
if (srcRank != 1)
|
|
return failure();
|
|
|
|
auto loc = multiReductionOp.getLoc();
|
|
auto srcVectorType = multiReductionOp.getSourceVectorType();
|
|
auto srcShape = srcVectorType.getShape();
|
|
auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
|
|
srcVectorType.getElementType());
|
|
assert(!multiReductionOp.getDestType().isa<VectorType>() &&
|
|
"multi_reduction with a single dimension expects a scalar result");
|
|
|
|
// If the unique dim is reduced and we insert a parallel in front, we need a
|
|
// {false, true} mask.
|
|
SmallVector<bool, 2> mask{false, true};
|
|
|
|
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
|
|
Value cast = rewriter.create<vector::ShapeCastOp>(
|
|
loc, castedType, multiReductionOp.source());
|
|
Value reduced = rewriter.create<vector::MultiDimReductionOp>(
|
|
loc, cast, mask, multiReductionOp.kind());
|
|
rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
|
|
ArrayRef<int64_t>{0});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Adds the vector.multi_reduction lowering patterns to `patterns`.
///
/// `useInnerDimsForReduction` selects the canonical layout the reductions are
/// brought into (reduction dims innermost vs. outermost) and, accordingly,
/// whether the final 2-D form is lowered to vector.reduction ops or to
/// elementwise ops.
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), useInnerDimsForReduction);
  // This pattern has no layout option. It is registered on its own so that the
  // flag is not forwarded to the inherited OpRewritePattern constructor, where
  // it would be implicitly converted into the pattern benefit.
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
  if (useInnerDimsForReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
}
|