Updates the linalg::vectorize function to return a `FailureOr<VectorizationResult>` containing the values to replace the original operation, instead of directly replacing the original operation. This aligns better with the style of transforms used with the TilingInterface, and gives more control to users over the lowering, since it allows for additional transformation of the IR before replacement. There was already a `VectorizationResult` defined, which was used for the internal vectorize implementation using `CustomVectorizationHook`s, so the old struct is renamed to `VectorizationHookResult`. Note for integration: The replacement of the original operation is now the responsibility of the caller, so wherever `linalg::vectorize` is used, the caller must also do `rewriter.replaceOp(vectorizeResults->replacements)`. --------- Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
4123 lines
167 KiB
C++
4123 lines
167 KiB
C++
//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the linalg dialect Vectorization transformations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
#include "mlir/Dialect/Affine/Utils.h"
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
|
|
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/Sequence.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/ADT/iterator_range.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/MathExtras.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <optional>
|
|
#include <type_traits>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
#define DEBUG_TYPE "linalg-vectorization"
|
|
|
|
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
|
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
|
|
|
|
/// Try to vectorize `convOp` as a convolution.
|
|
static FailureOr<Operation *>
|
|
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
|
|
ArrayRef<int64_t> inputVecSizes = {},
|
|
ArrayRef<bool> inputVecScalableFlags = {},
|
|
bool flatten1DDepthwiseConv = false);
|
|
|
|
/// Vectorize tensor::InsertSliceOp with:
|
|
/// * vector::TransferReadOp + vector::TransferWriteOp
|
|
/// The vector sizes are either:
|
|
/// * user-provided in `inputVectorSizes`, or
|
|
/// * inferred from the static dims in the input and output tensors.
|
|
/// Bails out if:
|
|
/// * vector sizes are not user-provided, and
|
|
/// * at least one dim is dynamic (in both the input and output tensors).
|
|
///
|
|
/// Before:
|
|
/// !t_in_type = tensor<1x2x3xf32>
|
|
/// !t_out_type = tensor<9x8x7x1x2x3xf32>
|
|
/// !v_type = vector<1x2x3xf32>
|
|
/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
|
|
/// into !t_out_type
|
|
/// After:
|
|
/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
|
|
/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
|
|
static LogicalResult
|
|
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
SmallVectorImpl<Value> &newResults);
|
|
|
|
/// Returns the effective Pad value for the input op, provided it's a scalar.
|
|
///
|
|
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
|
|
/// this Op performs padding, retrieve the padding value provided that it's
|
|
/// a scalar and static/fixed for all the padded values. Returns an empty value
|
|
/// otherwise.
|
|
static Value getStaticPadVal(Operation *op);
|
|
|
|
/// Return the unique instance of OpType in `block` if it is indeed unique.
|
|
/// Return null if none or more than 1 instances exist.
|
|
template <typename OpType>
|
|
static OpType getSingleOpOfType(Block &block) {
|
|
OpType res;
|
|
block.walk([&](OpType op) {
|
|
if (res) {
|
|
res = nullptr;
|
|
return WalkResult::interrupt();
|
|
}
|
|
res = op;
|
|
return WalkResult::advance();
|
|
});
|
|
return res;
|
|
}
|
|
|
|
/// Helper function to extract the input slices after filter is unrolled along
|
|
/// kw.
|
|
static SmallVector<Value>
|
|
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
|
|
int64_t nSize, int64_t wSize, int64_t cSize,
|
|
int64_t kwSize, int strideW, int dilationW,
|
|
int64_t wSizeStep, bool isSingleChanneled) {
|
|
SmallVector<Value> result;
|
|
if (isSingleChanneled) {
|
|
// Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
|
|
// convolution.
|
|
SmallVector<int64_t> sizes = {wSizeStep};
|
|
SmallVector<int64_t> strides = {1};
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides));
|
|
}
|
|
}
|
|
} else {
|
|
// Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
|
|
// for channeled convolution.
|
|
SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
|
|
SmallVector<int64_t> strides = {1, 1, 1};
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, input,
|
|
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
|
|
sizes, strides));
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function to extract the filter slices after filter is unrolled along
|
|
/// kw.
|
|
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
|
|
Location loc, Value filter,
|
|
int64_t kwSize) {
|
|
SmallVector<Value> result;
|
|
// Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
|
|
// non-chanelled convolution] @ [kw].
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
result.push_back(rewriter.create<vector::ExtractOp>(
|
|
loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function to extract the result slices after filter is unrolled along
|
|
/// kw.
|
|
static SmallVector<Value>
|
|
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
|
|
int64_t nSize, int64_t wSize, int64_t fSize,
|
|
int64_t wSizeStep, bool isSingleChanneled) {
|
|
SmallVector<Value> result;
|
|
if (isSingleChanneled) {
|
|
// Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
|
|
SmallVector<int64_t> sizes = {wSizeStep};
|
|
SmallVector<int64_t> strides = {1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
|
|
}
|
|
} else {
|
|
// Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
|
|
// convolution.
|
|
SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
|
|
SmallVector<int64_t> strides = {1, 1, 1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper function to insert the computed result slices.
|
|
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
|
|
Value res, int64_t wSize, int64_t wSizeStep,
|
|
SmallVectorImpl<Value> &resVals,
|
|
bool isSingleChanneled) {
|
|
|
|
if (isSingleChanneled) {
|
|
// Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
|
|
// This does not depend on kw.
|
|
SmallVector<int64_t> strides = {1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
res = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
|
|
}
|
|
} else {
|
|
// Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
|
|
// convolution. This does not depend on kw.
|
|
SmallVector<int64_t> strides = {1, 1, 1};
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
res = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
|
strides);
|
|
}
|
|
}
|
|
return res;
|
|
}
|
|
|
|
/// Contains the vectorization state and related methods used across the
|
|
/// vectorization process of a given operation.
|
|
struct VectorizationState {
|
|
VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
|
|
|
|
/// Initializes the vectorization state, including the computation of the
|
|
/// canonical vector shape for vectorization.
|
|
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims);
|
|
|
|
/// Returns the canonical vector shape used to vectorize the iteration space.
|
|
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
|
|
|
|
/// Returns the vector dimensions that are scalable in the canonical vector
|
|
/// shape.
|
|
ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
|
|
|
|
/// Returns a vector type of the provided `elementType` with the canonical
|
|
/// vector shape and the corresponding fixed/scalable dimensions bit. If
|
|
/// `dimPermutation` is provided, the canonical vector dimensions are permuted
|
|
/// accordingly.
|
|
VectorType getCanonicalVecType(
|
|
Type elementType,
|
|
std::optional<AffineMap> dimPermutation = std::nullopt) const {
|
|
SmallVector<int64_t> vectorShape;
|
|
SmallVector<bool> scalableDims;
|
|
if (dimPermutation.has_value()) {
|
|
vectorShape =
|
|
applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
|
|
scalableDims =
|
|
applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
|
|
} else {
|
|
vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
|
|
scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
|
|
}
|
|
|
|
return VectorType::get(vectorShape, elementType, scalableDims);
|
|
}
|
|
|
|
/// Masks an operation with the canonical vector mask if the operation needs
|
|
/// masking. Returns the masked operation or the original operation if masking
|
|
/// is not needed. If provided, the canonical mask for this operation is
|
|
/// permuted using `maybeIndexingMap`.
|
|
Operation *
|
|
maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeIndexingMap = std::nullopt);
|
|
|
|
private:
|
|
/// Initializes the iteration space static sizes using the Linalg op
|
|
/// information. This may become more complicated in the future.
|
|
void initIterSpaceStaticSizes(LinalgOp linalgOp) {
|
|
iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
|
|
}
|
|
|
|
/// Generates 'arith.constant' and 'tensor/memref.dim' operations for
|
|
/// all the static and dynamic dimensions of the iteration space to be
|
|
/// vectorized and store them in `iterSpaceValueSizes`.
|
|
LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
|
|
LinalgOp linalgOp);
|
|
|
|
/// Create or retrieve an existing mask value to mask `opToMask` in the
|
|
/// canonical vector iteration space. If `maybeMaskingMap` the mask is
|
|
/// permuted using that permutation map. If a new mask is created, it will be
|
|
/// cached for future users.
|
|
Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
|
|
LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeMaskingMap);
|
|
|
|
/// Check whether this permutation map can be used for masking. At the
|
|
/// moment we only make sure that there are no broadcast dimensions, but this
|
|
/// might change if indexing maps evolve.
|
|
bool isValidMaskingMap(AffineMap maskingMap) {
|
|
return maskingMap.getBroadcastDims().size() == 0;
|
|
}
|
|
|
|
/// Turn the input indexing map into a valid masking map.
|
|
///
|
|
/// The input indexing map may contain "zero" results, e.g.:
|
|
/// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
|
|
/// Applying such maps to canonical vector shapes like this one:
|
|
/// (1, 16, 16, 4)
|
|
/// would yield an invalid vector shape like this:
|
|
/// (16, 16, 1, 0)
|
|
/// Instead, drop the broadcasting dims that make no sense for masking perm.
|
|
/// maps:
|
|
/// (d0, d1, d2, d3) -> (d2, d1, d0)
|
|
/// This way, the corresponding vector/mask type will be:
|
|
/// vector<16x16x1xty>
|
|
/// rather than this invalid Vector type:
|
|
/// vector<16x16x1x0xty>
|
|
AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
|
|
return indexingMap.dropZeroResults();
|
|
}
|
|
|
|
// Holds the compile-time static sizes of the iteration space to vectorize.
|
|
// Dynamic dimensions are represented using ShapedType::kDynamic.
|
|
SmallVector<int64_t> iterSpaceStaticSizes;
|
|
|
|
/// Holds the value sizes of the iteration space to vectorize. Static
|
|
/// dimensions are represented by 'arith.constant' and dynamic
|
|
/// dimensions by 'tensor/memref.dim'.
|
|
SmallVector<Value> iterSpaceValueSizes;
|
|
|
|
/// Holds the canonical vector shape used to vectorize the iteration space.
|
|
SmallVector<int64_t> canonicalVecShape;
|
|
|
|
/// Holds the vector dimensions that are scalable in the canonical vector
|
|
/// shape.
|
|
SmallVector<bool> scalableVecDims;
|
|
|
|
/// Holds the active masks for permutations of the canonical vector iteration
|
|
/// space.
|
|
DenseMap<AffineMap, Value> activeMaskCache;
|
|
|
|
/// Global vectorization guard for the incoming rewriter. It's initialized
|
|
/// when the vectorization state is initialized.
|
|
OpBuilder::InsertionGuard rewriterGuard;
|
|
};
|
|
|
|
LogicalResult
|
|
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
|
|
LinalgOp linalgOp) {
|
|
// TODO: Support 0-d vectors.
|
|
for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
|
|
if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
|
|
// Create constant index op for static dimensions.
|
|
iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
|
|
linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
|
|
continue;
|
|
}
|
|
|
|
// Find an operand defined on this dimension of the iteration space to
|
|
// extract the runtime dimension size.
|
|
Value operand;
|
|
unsigned operandDimPos;
|
|
if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
|
|
operandDimPos)))
|
|
return failure();
|
|
|
|
Value dynamicDim = linalgOp.hasPureTensorSemantics()
|
|
? (Value)rewriter.create<tensor::DimOp>(
|
|
linalgOp.getLoc(), operand, operandDimPos)
|
|
: (Value)rewriter.create<memref::DimOp>(
|
|
linalgOp.getLoc(), operand, operandDimPos);
|
|
iterSpaceValueSizes.push_back(dynamicDim);
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Initializes the vectorization state, including the computation of the
|
|
/// canonical vector shape for vectorization.
|
|
// TODO: Move this to the constructor when we can remove the failure cases.
|
|
LogicalResult
|
|
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims) {
|
|
// Initialize the insertion point.
|
|
rewriter.setInsertionPoint(linalgOp);
|
|
|
|
if (!inputVectorSizes.empty()) {
|
|
// Get the canonical vector shape from the input vector sizes provided. This
|
|
// path should be taken to vectorize code with dynamic shapes and when using
|
|
// vector sizes greater than the iteration space sizes.
|
|
canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
|
|
scalableVecDims.append(inputScalableVecDims.begin(),
|
|
inputScalableVecDims.end());
|
|
} else {
|
|
// Compute the canonical vector shape from the operation shape. If there are
|
|
// dynamic shapes, the operation won't be vectorized. We assume all the
|
|
// vector dimensions are fixed.
|
|
canonicalVecShape = linalgOp.getStaticLoopRanges();
|
|
scalableVecDims.append(linalgOp.getNumLoops(), false);
|
|
}
|
|
|
|
LDBG("Canonical vector shape: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
LDBG("Scalable vector dims: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (ShapedType::isDynamicShape(canonicalVecShape))
|
|
return failure();
|
|
|
|
// Initialize iteration space static sizes.
|
|
initIterSpaceStaticSizes(linalgOp);
|
|
|
|
// Generate 'arith.constant' and 'tensor/memref.dim' operations for
|
|
// all the static and dynamic dimensions of the iteration space, needed to
|
|
// compute a mask during vectorization.
|
|
if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Create or retrieve an existing mask value to mask `opToMask` in the
|
|
/// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted
|
|
/// using that permutation map. If a new mask is created, it will be cached for
|
|
/// future users.
|
|
Value VectorizationState::getOrCreateMaskFor(
|
|
RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeMaskingMap) {
|
|
|
|
assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
|
|
"Ill-formed masking map.");
|
|
|
|
// No mask is needed if the operation is not maskable.
|
|
auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
|
|
if (!maskableOp)
|
|
return Value();
|
|
|
|
assert(!maskableOp.isMasked() &&
|
|
"Masking an operation that is already masked");
|
|
|
|
// If no masking map was provided, use an identity map with the loop dims.
|
|
assert((!maybeMaskingMap || *maybeMaskingMap) &&
|
|
"Unexpected null mask permutation map");
|
|
AffineMap maskingMap =
|
|
maybeMaskingMap ? *maybeMaskingMap
|
|
: AffineMap::getMultiDimIdentityMap(
|
|
linalgOp.getNumLoops(), rewriter.getContext());
|
|
|
|
LDBG("Masking map: " << maskingMap << "\n");
|
|
|
|
// Return the active mask for the masking map of this operation if it was
|
|
// already created.
|
|
auto activeMaskIt = activeMaskCache.find(maskingMap);
|
|
if (activeMaskIt != activeMaskCache.end()) {
|
|
Value mask = activeMaskIt->second;
|
|
LDBG("Reusing mask: " << mask << "\n");
|
|
return mask;
|
|
}
|
|
|
|
// Compute permuted projection of the iteration space to be masked and the
|
|
// corresponding mask shape. If the resulting iteration space dimensions are
|
|
// static and identical to the mask shape, masking is not needed for this
|
|
// operation.
|
|
// TODO: Improve this check. Only projected permutation indexing maps are
|
|
// supported.
|
|
SmallVector<int64_t> permutedStaticSizes =
|
|
applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
|
|
auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
|
|
auto maskShape = maskType.getShape();
|
|
|
|
LDBG("Mask shape: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (permutedStaticSizes == maskShape) {
|
|
LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
|
|
activeMaskCache[maskingMap] = Value();
|
|
return Value();
|
|
}
|
|
|
|
// Permute the iteration space value sizes to compute the mask upper bounds.
|
|
SmallVector<Value> upperBounds =
|
|
applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
|
|
assert(!maskShape.empty() && !upperBounds.empty() &&
|
|
"Masked 0-d vectors are not supported yet");
|
|
|
|
// Create the mask based on the dimension values.
|
|
Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
|
|
maskType, upperBounds);
|
|
LDBG("Creating new mask: " << mask << "\n");
|
|
activeMaskCache[maskingMap] = mask;
|
|
return mask;
|
|
}
|
|
|
|
Operation *
|
|
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
|
|
LinalgOp linalgOp,
|
|
std::optional<AffineMap> maybeIndexingMap) {
|
|
LDBG("Trying to mask: " << *opToMask << "\n");
|
|
|
|
std::optional<AffineMap> maybeMaskingMap = std::nullopt;
|
|
if (maybeIndexingMap)
|
|
maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
|
|
|
|
// Create or retrieve mask for this operation.
|
|
Value mask =
|
|
getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
|
|
|
|
if (!mask) {
|
|
LDBG("No mask required\n");
|
|
return opToMask;
|
|
}
|
|
|
|
// Wrap the operation with a new `vector.mask` and update D-U chain.
|
|
assert(opToMask && "Expected a valid operation to mask");
|
|
auto maskOp = cast<vector::MaskOp>(
|
|
mlir::vector::maskOperation(rewriter, opToMask, mask));
|
|
Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
|
|
|
|
for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
|
|
rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
|
|
maskOpTerminator);
|
|
|
|
LDBG("Masked operation: " << *maskOp << "\n");
|
|
return maskOp;
|
|
}
|
|
|
|
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
|
|
/// projectedPermutation, compress the unused dimensions to serve as a
|
|
/// permutation_map for a vector transfer operation.
|
|
/// For example, given a linalg op such as:
|
|
///
|
|
/// ```
|
|
/// %0 = linalg.generic {
|
|
/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
|
|
/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
|
|
/// }
|
|
/// ins(%0 : tensor<2x3x4xf32>)
|
|
/// outs(%1 : tensor<5x6xf32>)
|
|
/// ```
|
|
///
|
|
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
|
|
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
|
|
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
|
|
static AffineMap reindexIndexingMap(AffineMap map) {
|
|
assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
|
|
"expected projected permutation");
|
|
auto res = compressUnusedDims(map);
|
|
assert(res.getNumDims() ==
|
|
(res.getNumResults() - res.getNumOfZeroResults()) &&
|
|
"expected reindexed map with same number of dims and results");
|
|
return res;
|
|
}
|
|
|
|
/// Helper enum to represent conv1d input traversal order.
|
|
enum class Conv1DOpOrder {
|
|
W, // Corresponds to non-channeled 1D convolution operation.
|
|
Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
|
|
Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
|
|
};
|
|
|
|
/// Helper data structure to represent the result of vectorization for a single
|
|
/// operation. In certain specific cases, like terminators, we do not want to
|
|
/// propagate.
|
|
enum VectorizationHookStatus {
|
|
/// Op failed to vectorize.
|
|
Failure = 0,
|
|
/// Op vectorized and custom function took care of replacement logic
|
|
NoReplace,
|
|
/// Op vectorized into a new Op whose results will replace original Op's
|
|
/// results.
|
|
NewOp
|
|
// TODO: support values if Op vectorized to Many-Ops whose results we need to
|
|
// aggregate for replacement.
|
|
};
|
|
/// VectorizationHookResult contains the vectorized op returned from a
|
|
/// CustomVectorizationHook. This is an internal implementation detail of
|
|
/// linalg vectorization, not to be confused with VectorizationResult.
|
|
struct VectorizationHookResult {
|
|
/// Return status from vectorizing the current op.
|
|
enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
|
|
/// New vectorized operation to replace the current op.
|
|
/// Replacement behavior is specified by `status`.
|
|
Operation *newOp;
|
|
};
|
|
|
|
std::optional<vector::CombiningKind>
|
|
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
|
|
using ::mlir::vector::CombiningKind;
|
|
|
|
if (!combinerOp)
|
|
return std::nullopt;
|
|
return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
|
|
.Case<arith::AddIOp, arith::AddFOp>(
|
|
[&](auto op) { return CombiningKind::ADD; })
|
|
.Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
|
|
.Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
|
|
.Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
|
|
.Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
|
|
.Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
|
|
.Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
|
|
.Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
|
|
.Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
|
|
.Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
|
|
.Case<arith::MulIOp, arith::MulFOp>(
|
|
[&](auto op) { return CombiningKind::MUL; })
|
|
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
|
|
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
|
|
.Default([&](auto op) { return std::nullopt; });
|
|
}
|
|
|
|
/// Check whether `outputOperand` is a reduction with a single combiner
|
|
/// operation. Return the combiner operation of the reduction. Return
|
|
/// nullptr otherwise. Multiple reduction operations would impose an
|
|
/// ordering between reduction dimensions and is currently unsupported in
|
|
/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
|
|
/// max(min(X))
|
|
// TODO: use in LinalgOp verification, there is a circular dependency atm.
|
|
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
|
|
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
|
unsigned outputPos =
|
|
outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
|
|
// Only single combiner operations are supported for now.
|
|
SmallVector<Operation *, 4> combinerOps;
|
|
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
|
|
combinerOps.size() != 1)
|
|
return nullptr;
|
|
|
|
// Return the combiner operation.
|
|
return combinerOps[0];
|
|
}
|
|
|
|
/// Broadcast `value` to a vector of `shape` if possible. Return value
|
|
/// otherwise.
|
|
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
|
|
auto dstVecType = dyn_cast<VectorType>(dstType);
|
|
// If no shape to broadcast to, just return `value`.
|
|
if (dstVecType.getRank() == 0)
|
|
return value;
|
|
if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
|
|
vector::BroadcastableToResult::Success)
|
|
return value;
|
|
Location loc = b.getInsertionPoint()->getLoc();
|
|
return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
|
|
}
|
|
|
|
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
|
|
/// assumes that `reductionOp` has two operands and one of them is the reduction
|
|
/// initial value.buildMultiDimReduce
|
|
// Note: this is a true builder that notifies the OpBuilder listener.
|
|
// TODO: Consider moving as a static helper on the ReduceOp.
|
|
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
|
|
Value valueToReduce, Value acc,
|
|
ArrayRef<bool> dimsToMask) {
|
|
auto maybeKind = getCombinerOpKind(reduceOp);
|
|
assert(maybeKind && "Failed precondition: could not get reduction kind");
|
|
return b.create<vector::MultiDimReductionOp>(
|
|
reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
|
|
}
|
|
|
|
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
|
|
return llvm::to_vector(
|
|
llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
|
|
}
|
|
|
|
/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
|
|
/// reduction iterator.
|
|
static bool hasReductionIterator(LinalgOp &op) {
|
|
return isa<linalg::ReduceOp>(op) ||
|
|
(isa<linalg::GenericOp>(op) &&
|
|
llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
|
|
}
|
|
|
|
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
|
|
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
|
|
/// currently being vectorized. If `dest` has null rank, build an memref.store.
|
|
/// Return the produced value or null if no value is produced.
|
|
// Note: this is a true builder that notifies the OpBuilder listener.
|
|
// TODO: Consider moving as a static helper on the ReduceOp.
|
|
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
|
|
OpOperand *outputOperand,
|
|
VectorizationState &state) {
|
|
Location loc = value.getLoc();
|
|
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
|
|
AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
|
|
|
|
// Compute the vector type of the value to store. This type should be an
|
|
// identity or projection of the canonical vector type without any permutation
|
|
// applied, given that any permutation in a transfer write happens as part of
|
|
// the write itself.
|
|
AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
|
|
opOperandMap.getContext(), opOperandMap.getNumInputs(),
|
|
[&](AffineDimExpr dimExpr) -> bool {
|
|
return llvm::is_contained(opOperandMap.getResults(), dimExpr);
|
|
});
|
|
auto vectorType = state.getCanonicalVecType(
|
|
getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
|
|
|
|
Operation *write;
|
|
if (vectorType.getRank() > 0) {
|
|
AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
|
|
SmallVector<Value> indices(linalgOp.getRank(outputOperand),
|
|
rewriter.create<arith::ConstantIndexOp>(loc, 0));
|
|
value = broadcastIfNeeded(rewriter, value, vectorType);
|
|
assert(value.getType() == vectorType && "Incorrect type");
|
|
write = rewriter.create<vector::TransferWriteOp>(
|
|
loc, value, outputOperand->get(), indices, writeMap);
|
|
} else {
|
|
// 0-d case is still special: do not invert the reindexing writeMap.
|
|
if (!isa<VectorType>(value.getType()))
|
|
value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
|
|
assert(value.getType() == vectorType && "Incorrect type");
|
|
write = rewriter.create<vector::TransferWriteOp>(
|
|
loc, value, outputOperand->get(), ValueRange{});
|
|
}
|
|
|
|
write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
|
|
|
|
// If masked, set in-bounds to true. Masking guarantees that the access will
|
|
// be in-bounds.
|
|
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
|
|
auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
|
|
SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
|
|
maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
|
|
}
|
|
|
|
LDBG("vectorized op: " << *write << "\n");
|
|
if (!write->getResults().empty())
|
|
return write->getResult(0);
|
|
return Value();
|
|
}
|
|
|
|
// Custom vectorization precondition function type. This is intented to be used
|
|
// with CustomVectorizationHook. Returns success if the corresponding custom
|
|
// hook can vectorize the op.
|
|
using CustomVectorizationPrecondition =
|
|
std::function<LogicalResult(Operation *, bool)>;
|
|
|
|
// Custom vectorization function type. Produce a vector form of Operation*
|
|
// assuming all its vectorized operands are already in the IRMapping.
|
|
// Return nullptr if the Operation cannot be vectorized.
|
|
using CustomVectorizationHook =
|
|
std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
|
|
|
|
/// Helper function to vectorize the terminator of a `linalgOp`. New result
|
|
/// vector values are appended to `newResults`. Return
|
|
/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
|
|
/// that it should not try to map produced operations and instead return the
|
|
/// results using the `newResults` vector making them available to the
|
|
/// vectorization algorithm for RAUW. This function is meant to be used as a
|
|
/// CustomVectorizationHook.
|
|
static VectorizationHookResult
|
|
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
|
|
const IRMapping &bvm, VectorizationState &state,
|
|
LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
|
|
if (!yieldOp)
|
|
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
|
|
for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
|
|
// TODO: Scan for an opportunity for reuse.
|
|
// TODO: use a map.
|
|
Value vectorValue = bvm.lookup(output.value());
|
|
Value newResult =
|
|
buildVectorWrite(rewriter, vectorValue,
|
|
linalgOp.getDpsInitOperand(output.index()), state);
|
|
if (newResult)
|
|
newResults.push_back(newResult);
|
|
}
|
|
|
|
return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
|
|
}
|
|
|
|
/// Helper function to vectorize the index operations of a `linalgOp`. Return
|
|
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
|
|
/// should map the produced operations. This function is meant to be used as a
|
|
/// CustomVectorizationHook.
|
|
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
|
|
VectorizationState &state,
|
|
Operation *op,
|
|
LinalgOp linalgOp) {
|
|
IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
|
|
if (!indexOp)
|
|
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
|
|
auto loc = indexOp.getLoc();
|
|
// Compute the static loop sizes of the index op.
|
|
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
|
|
auto dim = indexOp.getDim();
|
|
// Compute a one-dimensional index vector for the index op dimension.
|
|
auto indexVectorType =
|
|
VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
|
|
state.getScalableVecDims()[dim]);
|
|
auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
|
|
// Return the one-dimensional index vector if it lives in the trailing
|
|
// dimension of the iteration space since the vectorization algorithm in this
|
|
// case can handle the broadcast.
|
|
if (dim == targetShape.size() - 1)
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
|
|
// Otherwise permute the targetShape to move the index dimension last,
|
|
// broadcast the one-dimensional index vector to the permuted shape, and
|
|
// finally transpose the broadcasted index vector to undo the permutation.
|
|
auto permPattern =
|
|
llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
|
|
std::swap(permPattern[dim], permPattern.back());
|
|
auto permMap =
|
|
AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
|
|
|
|
auto broadCastOp = rewriter.create<vector::BroadcastOp>(
|
|
loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
|
|
indexSteps);
|
|
SmallVector<int64_t> transposition =
|
|
llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
|
|
std::swap(transposition.back(), transposition[dim]);
|
|
auto transposeOp =
|
|
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
|
|
}
|
|
|
|
/// Helper function to check if the tensor.extract can be vectorized by the
|
|
/// custom hook vectorizeTensorExtract.
|
|
static LogicalResult
|
|
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
|
|
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
|
|
if (!extractOp)
|
|
return failure();
|
|
|
|
if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
|
|
return failure();
|
|
|
|
// Check the index type, but only for non 0-d tensors (for which we do need
|
|
// access indices).
|
|
if (not extractOp.getIndices().empty()) {
|
|
if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
|
|
return failure();
|
|
}
|
|
|
|
if (!llvm::all_of(extractOp->getResultTypes(),
|
|
VectorType::isValidElementType)) {
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
|
|
/// generated from `tensor.extract`. The offset is calculated as follows
|
|
/// (example using scalar values):
|
|
///
|
|
/// offset = extractOp.indices[0]
|
|
/// for (i = 1; i < numIndices; i++)
|
|
/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
|
|
///
|
|
/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
|
|
/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
|
|
static Value calculateGatherOffset(RewriterBase &rewriter,
|
|
VectorizationState &state,
|
|
tensor::ExtractOp extractOp,
|
|
const IRMapping &bvm) {
|
|
// The vector of indices for GatherOp should be shaped as the output vector.
|
|
auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
|
|
auto loc = extractOp.getLoc();
|
|
|
|
Value offset = broadcastIfNeeded(
|
|
rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
|
|
|
|
const size_t numIndices = extractOp.getIndices().size();
|
|
for (size_t i = 1; i < numIndices; i++) {
|
|
Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
|
|
|
|
auto dimSize = broadcastIfNeeded(
|
|
rewriter,
|
|
rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
|
|
indexVecType);
|
|
|
|
offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
|
|
|
|
auto extractOpIndex = broadcastIfNeeded(
|
|
rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
|
|
|
|
offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
|
|
}
|
|
|
|
return offset;
|
|
}
|
|
|
|
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
|
|
|
|
/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
|
|
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
|
|
/// represents a contiguous load operation.
|
|
///
|
|
/// Note that when calling this hook, it is assumed that the output vector is
|
|
/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
|
|
/// labelled as a gather load before entering this method.
|
|
///
|
|
/// Following on from the above, it is assumed that:
|
|
/// * for statically shaped loops, when no masks are used, only one dim is !=
|
|
/// 1 (that's what the shape of the output vector is based on).
|
|
/// * for dynamically shaped loops, there might be more non-unit dims
|
|
/// as the output vector type is user-specified.
|
|
///
|
|
/// TODO: Statically shaped loops + vector masking
|
|
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
|
|
SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
|
|
assert(
|
|
(linalgOp.hasDynamicShape() ||
|
|
llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
|
|
"For statically shaped Linalg Ops, only one "
|
|
"non-unit loop dim is expected");
|
|
assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");
|
|
|
|
size_t idx = loopRanges.size() - 1;
|
|
for (; idx != 0; idx--)
|
|
if (loopRanges[idx] != 1)
|
|
break;
|
|
|
|
return idx;
|
|
}
|
|
|
|
/// Checks whether `val` can be used for calculating a loop invariant index.
|
|
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
|
|
VectorType resType) {
|
|
|
|
assert(((llvm::count_if(resType.getShape(),
|
|
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
|
|
"n-D vectors are not yet supported");
|
|
|
|
// Blocks outside _this_ linalg.generic are effectively loop invariant.
|
|
// However, analysing block arguments for _this_ linalg.generic Op is a bit
|
|
// tricky. Just bail out in the latter case.
|
|
// TODO: We could try analysing the corresponding affine map here.
|
|
auto *block = linalgOp.getBlock();
|
|
if (isa<BlockArgument>(val))
|
|
return llvm::all_of(block->getArguments(),
|
|
[&val](Value v) { return (v != val); });
|
|
|
|
Operation *defOp = val.getDefiningOp();
|
|
assert(defOp && "This is neither a block argument nor an operation result");
|
|
|
|
// IndexOp is loop invariant as long as its result remains constant across
|
|
// iterations. Note that for dynamic shapes, the corresponding dim will also
|
|
// be conservatively treated as != 1.
|
|
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
|
|
return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
|
|
}
|
|
|
|
auto *ancestor = block->findAncestorOpInBlock(*defOp);
|
|
|
|
// Values define outside `linalgOp` are loop invariant.
|
|
if (!ancestor)
|
|
return true;
|
|
|
|
// Values defined inside `linalgOp`, which are constant, are loop invariant.
|
|
if (isa<arith::ConstantOp>(ancestor))
|
|
return true;
|
|
|
|
bool result = true;
|
|
for (auto op : ancestor->getOperands())
|
|
result &= isLoopInvariantIdx(linalgOp, op, resType);
|
|
|
|
return result;
|
|
}
|
|
|
|
/// Check whether `val` could be used for calculating the trailing index for a
|
|
/// contiguous load operation.
|
|
///
|
|
/// There are currently 3 types of values that are allowed here:
|
|
/// 1. loop-invariant values,
|
|
/// 2. values that increment by 1 with every loop iteration,
|
|
/// 3. results of basic arithmetic operations (linear and continuous)
|
|
/// involving 1., 2. and 3.
|
|
/// This method returns True if indeed only such values are used in calculating
|
|
/// `val.`
|
|
///
|
|
/// Additionally, the trailing index for a contiguous load operation should
|
|
/// increment by 1 with every loop iteration, i.e. be based on:
|
|
/// * `linalg.index <dim>` ,
|
|
/// where <dim> is the trailing non-unit dim of the iteration space (this way,
|
|
/// `linalg.index <dim>` increments by 1 with every loop iteration).
|
|
/// `foundIndexOp` is updated to `true` when such Op is found.
|
|
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
|
|
bool &foundIndexOp, VectorType resType) {
|
|
|
|
assert(((llvm::count_if(resType.getShape(),
|
|
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
|
|
"n-D vectors are not yet supported");
|
|
|
|
// Blocks outside _this_ linalg.generic are effectively loop invariant.
|
|
// However, analysing block arguments for _this_ linalg.generic Op is a bit
|
|
// tricky. Just bail out in the latter case.
|
|
// TODO: We could try analysing the corresponding affine map here.
|
|
auto *block = linalgOp.getBlock();
|
|
if (isa<BlockArgument>(val))
|
|
return llvm::all_of(block->getArguments(),
|
|
[&val](Value v) { return (v != val); });
|
|
|
|
Operation *defOp = val.getDefiningOp();
|
|
assert(defOp && "This is neither a block argument nor an operation result");
|
|
|
|
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
|
|
auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
|
|
|
|
foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
|
|
return true;
|
|
}
|
|
|
|
auto *ancestor = block->findAncestorOpInBlock(*defOp);
|
|
|
|
if (!ancestor)
|
|
return false;
|
|
|
|
// Conservatively reject Ops that could lead to indices with stride other
|
|
// than 1.
|
|
if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
|
|
return false;
|
|
|
|
bool result = false;
|
|
for (auto op : ancestor->getOperands())
|
|
result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
|
|
|
|
return result;
|
|
}
|
|
|
|
/// Infer the memory access pattern for the input ExtractOp
|
|
///
|
|
/// Based on the ExtratOp result shape and the access indices, decides whether
|
|
/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
|
|
/// or a gather load. When analysing the ExtractOp indices (to identify
|
|
/// contiguous laods), this method looks for "loop" invariant indices (e.g.
|
|
/// block arguments) and indices that change linearly (e.g. via `linalg.index`
|
|
/// Op).
|
|
///
|
|
/// Note that it is always safe to use gather load operations for contiguous
|
|
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
|
|
/// that `extractOp` is a gather load.
|
|
static VectorMemoryAccessKind
|
|
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
|
|
LinalgOp &linalgOp, VectorType resType) {
|
|
|
|
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
|
|
|
|
// 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
|
|
if (inputShape.getShape().empty())
|
|
return VectorMemoryAccessKind::ScalarBroadcast;
|
|
|
|
// True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
|
|
// otherwise.
|
|
bool isOutput1DVector =
|
|
(llvm::count_if(resType.getShape(),
|
|
[](int64_t dimSize) { return dimSize > 1; }) == 1);
|
|
// 1. Assume that it's a gather load when reading non-1D vector.
|
|
if (!isOutput1DVector)
|
|
return VectorMemoryAccessKind::Gather;
|
|
|
|
bool leadingIdxsLoopInvariant = true;
|
|
|
|
// 2. Analyze the leading indices of `extractOp`.
|
|
// Look at the way each index is calculated and decide whether it is suitable
|
|
// for a contiguous load, i.e. whether it's loop invariant. If not, it's a
|
|
// gather load.
|
|
auto indices = extractOp.getIndices();
|
|
auto leadIndices = indices.drop_back(1);
|
|
|
|
for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
|
|
if (inputShape.getShape()[i] == 1)
|
|
continue;
|
|
|
|
leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
|
|
}
|
|
|
|
if (!leadingIdxsLoopInvariant) {
|
|
LDBG("Found gather load: " << extractOp);
|
|
return VectorMemoryAccessKind::Gather;
|
|
}
|
|
|
|
// 3. Analyze the trailing index for `extractOp`.
|
|
// At this point we know that the leading indices are loop invariant. This
|
|
// means that is potentially a scalar or a contiguous load. We can decide
|
|
// based on the trailing idx.
|
|
auto extractOpTrailingIdx = indices.back();
|
|
|
|
// 3a. Scalar broadcast load
|
|
// If the trailing index is loop invariant then this is a scalar load.
|
|
if (leadingIdxsLoopInvariant &&
|
|
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
|
|
LDBG("Found scalar broadcast load: " << extractOp);
|
|
|
|
return VectorMemoryAccessKind::ScalarBroadcast;
|
|
}
|
|
|
|
// 3b. Contiguous loads
|
|
// The trailing `extractOp` index should increment with every loop iteration.
|
|
// This effectively means that it must be based on the trailing loop index.
|
|
// This is what the following bool captures.
|
|
bool foundIndexOp = false;
|
|
bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
|
|
foundIndexOp, resType);
|
|
// TODO: Support generating contiguous loads for column vectors - that will
|
|
// require adding a permutation map to tranfer_read Ops.
|
|
bool isRowVector = resType.getShape().back() != 1;
|
|
isContiguousLoad &= (foundIndexOp && isRowVector);
|
|
|
|
if (isContiguousLoad) {
|
|
LDBG("Found contigous load: " << extractOp);
|
|
return VectorMemoryAccessKind::Contiguous;
|
|
}
|
|
|
|
// 4. Fallback case - gather load.
|
|
LDBG("Found gather load: " << extractOp);
|
|
return VectorMemoryAccessKind::Gather;
|
|
}
|
|
|
|
/// Helper function to vectorize the tensor.extract operations. Returns
|
|
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
|
|
/// should map the produced operations. This function is meant to be used as a
|
|
/// CustomVectorizationHook.
|
|
static VectorizationHookResult
|
|
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
|
|
Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
|
|
tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
|
|
if (!extractOp)
|
|
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
|
|
auto loc = extractOp.getLoc();
|
|
|
|
// Compute the static loop sizes of the extract op.
|
|
auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
|
|
auto maskConstantOp = rewriter.create<arith::ConstantOp>(
|
|
loc,
|
|
DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
|
|
/*value=*/true));
|
|
auto passThruConstantOp =
|
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
|
|
|
|
// Base indices are currently set to 0. We will need to re-visit if more
|
|
// generic scenarios are to be supported.
|
|
SmallVector<Value> baseIndices(
|
|
extractOp.getIndices().size(),
|
|
rewriter.create<arith::ConstantIndexOp>(loc, 0));
|
|
|
|
VectorMemoryAccessKind memAccessKind =
|
|
getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
|
|
|
|
// 1. Handle gather access
|
|
if (memAccessKind == VectorMemoryAccessKind::Gather) {
|
|
Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
|
|
|
|
// Generate the gather load
|
|
Operation *gatherOp = rewriter.create<vector::GatherOp>(
|
|
loc, resultType, extractOp.getTensor(), baseIndices, offset,
|
|
maskConstantOp, passThruConstantOp);
|
|
gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
|
|
|
|
LDBG("Vectorised as gather load: " << extractOp << "\n");
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
|
|
}
|
|
|
|
// 2. Handle:
|
|
// a. scalar loads + broadcast,
|
|
// b. contiguous loads.
|
|
// Both cases use vector.transfer_read.
|
|
|
|
// Collect indices for `vector.transfer_read`. At this point, the indices will
|
|
// either be scalars or would have been broadcast to vectors matching the
|
|
// result type. For indices that are vectors, there are two options:
|
|
// * for non-trailing indices, all elements are identical (contiguous
|
|
// loads are identified by looking for non-trailing indices that are
|
|
// invariant with respect to the corresponding linalg.generic), or
|
|
// * for trailing indices, the index vector will contain values with stride
|
|
// one, but for `vector.transfer_read` only the first (i.e. 0th) index is
|
|
// needed.
|
|
// This means that
|
|
// * for scalar indices - just re-use it,
|
|
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
|
|
// (0th) element and use that.
|
|
SmallVector<Value> transferReadIdxs;
|
|
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
|
|
Value idx = bvm.lookup(extractOp.getIndices()[i]);
|
|
if (idx.getType().isIndex()) {
|
|
transferReadIdxs.push_back(idx);
|
|
continue;
|
|
}
|
|
|
|
auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
|
|
loc,
|
|
VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
|
|
resultType.getScalableDims().back()),
|
|
idx);
|
|
transferReadIdxs.push_back(
|
|
rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
|
|
}
|
|
|
|
// `tensor.extract_element` is always in-bounds, hence the following holds.
|
|
auto dstRank = resultType.getRank();
|
|
auto srcRank = extractOp.getTensor().getType().getRank();
|
|
SmallVector<bool> inBounds(dstRank, true);
|
|
|
|
// 2a. Handle scalar broadcast access.
|
|
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
|
|
MLIRContext *ctx = rewriter.getContext();
|
|
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
|
|
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
|
|
|
|
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
|
|
loc, resultType, extractOp.getTensor(), transferReadIdxs,
|
|
permutationMap, inBounds);
|
|
|
|
// Mask this broadcasting xfer_read here rather than relying on the generic
|
|
// path (the generic path assumes identity masking map, which wouldn't be
|
|
// valid here).
|
|
SmallVector<int64_t> readMaskShape = {1};
|
|
auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
|
|
auto allTrue = rewriter.create<vector::ConstantMaskOp>(
|
|
loc, readMaskType, vector::ConstantMaskKind::AllTrue);
|
|
auto *maskedReadOp =
|
|
mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
|
|
|
|
LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp,
|
|
maskedReadOp};
|
|
}
|
|
|
|
// 2b. Handle contiguous access.
|
|
auto permutationMap = AffineMap::getMinorIdentityMap(
|
|
srcRank, std::min(dstRank, srcRank), rewriter.getContext());
|
|
|
|
int32_t rankDiff = dstRank - srcRank;
|
|
// When dstRank > srcRank, broadcast the source tensor to the unitary leading
|
|
// dims so that the ranks match. This is done by extending the map with 0s.
|
|
// For example, for dstRank = 3, srcRank = 2, the following map created
|
|
// above:
|
|
// (d0, d1) --> (d0, d1)
|
|
// is extended as:
|
|
// (d0, d1) --> (0, d0, d1)
|
|
while (rankDiff > 0) {
|
|
permutationMap = permutationMap.insertResult(
|
|
mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
|
|
rankDiff--;
|
|
}
|
|
|
|
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
|
|
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
|
|
inBounds);
|
|
|
|
LDBG("Vectorised as contiguous load: " << extractOp);
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp,
|
|
transferReadOp};
|
|
}
|
|
|
|
/// Emit reduction operations if the shapes of the value to reduce is different
|
|
/// that the result shape.
|
|
// Note: this is a true builder that notifies the OpBuilder listener.
|
|
// TODO: Consider moving as a static helper on the ReduceOp.
|
|
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
|
|
Value reduceValue, Value initialValue,
|
|
const IRMapping &bvm) {
|
|
Value reduceVec = bvm.lookup(reduceValue);
|
|
Value outputVec = bvm.lookup(initialValue);
|
|
auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
|
|
auto outputType = dyn_cast<VectorType>(outputVec.getType());
|
|
// Reduce only if needed as the value may already have been reduce for
|
|
// contraction vectorization.
|
|
if (!reduceType ||
|
|
(outputType && reduceType.getShape() == outputType.getShape()))
|
|
return nullptr;
|
|
SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
|
|
return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
|
|
}
|
|
|
|
/// Generic vectorization for a single operation `op`, given already vectorized
|
|
/// operands carried by `bvm`. Vectorization occurs as follows:
|
|
/// 1. Try to apply any of the `customVectorizationHooks` and return its
|
|
/// result on success.
|
|
/// 2. Clone any constant in the current scope without vectorization: each
|
|
/// consumer of the constant will later determine the shape to which the
|
|
/// constant needs to be broadcast to.
|
|
/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
|
|
/// of the `customVectorizationHooks` to cover such cases.
|
|
/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
|
|
/// operand of maximal rank. Other operands have smaller rank and are
|
|
/// broadcast accordingly. It is assumed this broadcast is always legal,
|
|
/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
|
|
///
|
|
/// This function assumes all operands of `op` have been vectorized and are in
|
|
/// the `bvm` mapping. As a consequence, this function is meant to be called on
|
|
/// a topologically-sorted list of ops.
|
|
/// This function does not update `bvm` but returns a VectorizationHookStatus
|
|
/// that instructs the caller what `bvm` update needs to occur.
|
|
static VectorizationHookResult
|
|
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
|
|
LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
|
|
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
|
|
LDBG("vectorize op " << *op << "\n");
|
|
|
|
// 1. Try to apply any CustomVectorizationHook.
|
|
if (!customVectorizationHooks.empty()) {
|
|
for (auto &customFunc : customVectorizationHooks) {
|
|
VectorizationHookResult result = customFunc(op, bvm);
|
|
if (result.status == VectorizationHookStatus::Failure)
|
|
continue;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
// Clone so that the constant is not confined to the linalgOp block.
if (isa<arith::ConstantOp, func::ConstantOp>(op))
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp,
|
|
rewriter.clone(*op)};
|
|
|
|
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
|
|
if (!OpTrait::hasElementwiseMappableTraits(op))
|
|
return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
|
|
|
|
// 4. Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
|
|
for (Value operand : op->getOperands()) {
|
|
auto blockArg = dyn_cast<BlockArgument>(operand);
|
|
if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
|
|
blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
|
|
continue;
|
|
SmallVector<Operation *> reductionOps;
|
|
Value reduceValue = matchReduction(
|
|
linalgOp.getRegionOutputArgs(),
|
|
blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
|
|
if (!reduceValue)
|
|
continue;
|
|
reductionOperands.push_back(std::make_pair(reduceValue, operand));
|
|
}
|
|
if (!reductionOperands.empty()) {
|
|
assert(reductionOperands.size() == 1);
|
|
Operation *reduceOp =
|
|
reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
|
|
reductionOperands[0].second, bvm);
|
|
if (reduceOp)
|
|
return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
|
|
}
|
|
|
|
// 5. Generic vectorization path for ElementwiseMappable ops.
|
|
// a. Get the first max ranked shape.
|
|
VectorType firstMaxRankedType;
|
|
for (Value operand : op->getOperands()) {
|
|
auto vecOperand = bvm.lookup(operand);
|
|
assert(vecOperand && "Vector operand couldn't be found");
|
|
|
|
auto vecType = dyn_cast<VectorType>(vecOperand.getType());
|
|
if (vecType && (!firstMaxRankedType ||
|
|
firstMaxRankedType.getRank() < vecType.getRank()))
|
|
firstMaxRankedType = vecType;
|
|
}
|
|
// b. Broadcast each op if needed.
|
|
SmallVector<Value> vecOperands;
|
|
for (Value scalarOperand : op->getOperands()) {
|
|
Value vecOperand = bvm.lookup(scalarOperand);
|
|
assert(vecOperand && "Vector operand couldn't be found");
|
|
|
|
if (firstMaxRankedType) {
|
|
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
|
|
getElementTypeOrSelf(vecOperand.getType()),
|
|
firstMaxRankedType.getScalableDims());
|
|
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
|
|
} else {
|
|
vecOperands.push_back(vecOperand);
|
|
}
|
|
}
|
|
// c. For elementwise ops, the result is a vector with the firstMaxRankedShape.
SmallVector<Type> resultTypes;
|
|
for (Type resultType : op->getResultTypes()) {
|
|
resultTypes.push_back(
|
|
firstMaxRankedType
|
|
? VectorType::get(firstMaxRankedType.getShape(), resultType,
|
|
firstMaxRankedType.getScalableDims())
|
|
: resultType);
|
|
}
|
|
// d. Build and return the new op.
|
|
return VectorizationHookResult{
|
|
VectorizationHookStatus::NewOp,
|
|
rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
|
|
resultTypes, op->getAttrs())};
|
|
}
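
// Illustrative sketch (not from the original source; names and shapes are
// hypothetical): if `bvm` maps the operands of a scalar `arith.addf` to a
// vector<4x8xf32> and a plain f32 (e.g. a value defined above the region),
// step 5 above broadcasts the lower-ranked operand and clones the op on the
// vector operands:
//
//   %b_bc = vector.broadcast %b : f32 to vector<4x8xf32>
//   %res  = arith.addf %a_vec, %b_bc : vector<4x8xf32>
//
// The result type is the maximal-rank operand shape combined with the original
// scalar result element type.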

/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. Generic vectorization proceeds as follows:
///   1. Verify the `linalgOp` has one non-empty region.
///   2. Values defined above the region are mapped to themselves and will be
///      broadcasted on a per-need basis by their consumers.
///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
///      load).
///      TODO: Reuse opportunities for RAR dependencies.
///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
///   4b. Register CustomVectorizationHook for IndexOp to access the
///       iteration indices.
///   5. Iteratively call vectorizeOneOp on the region operations.
///
/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to determine where broadcasts, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
/// aggressively clean up the useless work.
static LogicalResult
|
|
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
|
|
LinalgOp linalgOp,
|
|
SmallVectorImpl<Value> &newResults) {
|
|
LDBG("Vectorizing operation as linalg generic\n");
|
|
Block *block = linalgOp.getBlock();
|
|
|
|
// 2. Values defined above the region can only be broadcast for now. Make them
|
|
// map to themselves.
|
|
IRMapping bvm;
|
|
SetVector<Value> valuesSet;
|
|
mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
|
|
bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
|
|
|
|
if (linalgOp.getNumDpsInits() == 0)
|
|
return failure();
|
|
|
|
// 3. Turn all BBArgs into vector.transfer_read / load.
|
|
Location loc = linalgOp.getLoc();
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
|
|
BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
|
|
if (linalgOp.isScalar(opOperand)) {
|
|
bvm.map(bbarg, opOperand->get());
|
|
continue;
|
|
}
|
|
|
|
// 3.a. Convert the indexing map for this input/output to a transfer read
|
|
// permutation map and masking map.
|
|
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
|
|
|
|
AffineMap readMap;
|
|
VectorType readType;
|
|
Type elemType = getElementTypeOrSelf(opOperand->get());
|
|
if (linalgOp.isDpsInput(opOperand)) {
|
|
// 3.a.i. For input reads we use the canonical vector shape.
|
|
readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
|
|
readType = state.getCanonicalVecType(elemType);
|
|
} else {
|
|
// 3.a.ii. For output reads (iteration-carried dependence, e.g.,
|
|
// reductions), the vector shape is computed by mapping the canonical
|
|
// vector shape to the output domain and back to the canonical domain.
|
|
readMap = inversePermutation(reindexIndexingMap(indexingMap));
|
|
readType =
|
|
state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
|
|
}
|
|
|
|
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
|
|
|
|
Operation *read = rewriter.create<vector::TransferReadOp>(
|
|
loc, readType, opOperand->get(), indices, readMap);
|
|
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
|
|
Value readValue = read->getResult(0);
|
|
|
|
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
|
|
// will be in-bounds.
|
|
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
|
|
SmallVector<bool> inBounds(readType.getRank(), true);
|
|
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
|
|
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
|
|
}
|
|
|
|
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
|
|
// TODO: remove this.
|
|
if (readType.getRank() == 0)
|
|
readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
|
|
ArrayRef<int64_t>());
|
|
|
|
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
|
|
<< "\n");
|
|
bvm.map(bbarg, readValue);
|
|
bvm.map(opOperand->get(), readValue);
|
|
}
|
|
|
|
SmallVector<CustomVectorizationHook> hooks;
|
|
// 4a. Register CustomVectorizationHook for yieldOp.
|
|
CustomVectorizationHook vectorizeYield =
|
|
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
|
|
return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
|
|
};
|
|
hooks.push_back(vectorizeYield);
|
|
|
|
// 4b. Register CustomVectorizationHook for indexOp.
|
|
CustomVectorizationHook vectorizeIndex =
|
|
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
|
|
return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
|
|
};
|
|
hooks.push_back(vectorizeIndex);
|
|
|
|
// 4c. Register CustomVectorizationHook for extractOp.
|
|
CustomVectorizationHook vectorizeExtract =
|
|
[&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
|
|
return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
|
|
};
|
|
hooks.push_back(vectorizeExtract);
|
|
|
|
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
|
|
for (Operation &op : block->getOperations()) {
|
|
VectorizationHookResult result =
|
|
vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
|
|
if (result.status == VectorizationHookStatus::Failure) {
|
|
LDBG("failed to vectorize: " << op << "\n");
|
|
return failure();
|
|
}
|
|
if (result.status == VectorizationHookStatus::NewOp) {
|
|
Operation *maybeMaskedOp =
|
|
state.maskOperation(rewriter, result.newOp, linalgOp);
|
|
LDBG("New vector op: " << *maybeMaskedOp << "\n");
|
|
bvm.map(op.getResults(), maybeMaskedOp->getResults());
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
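
// Illustrative end-to-end sketch (not from the original source; all values and
// shapes are hypothetical). An elementwise generic such as:
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   %r = linalg.generic {indexing_maps = [#map, #map, #map],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
//       outs(%init : tensor<4x8xf32>) {
//   ^bb0(%x: f32, %y: f32, %o: f32):
//     %s = arith.addf %x, %y : f32
//     linalg.yield %s : f32
//   } -> tensor<4x8xf32>
//
// is rewritten by the function above into, roughly:
//
//   %va  = vector.transfer_read %a[%c0, %c0], %pad ... : tensor<4x8xf32>, vector<4x8xf32>
//   %vb  = vector.transfer_read %b[%c0, %c0], %pad ... : tensor<4x8xf32>, vector<4x8xf32>
//   %vs  = arith.addf %va, %vb : vector<4x8xf32>
//   %res = vector.transfer_write %vs, %init[%c0, %c0] ... : vector<4x8xf32>, tensor<4x8xf32>
//
// with %res handed back to the caller through `newResults`.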
|
|
|
|
/// Given a linalg::PackOp, return the `dest` shape before any packing
|
|
/// permutations.
|
|
static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
|
|
ArrayRef<int64_t> destShape) {
|
|
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
|
|
}
|
|
|
|
/// Determines whether a mask for xfer_write is trivially "all true".
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
/// and an xfer_write operation (write indices and the destination tensor
/// shape), determines whether the corresponding mask would be trivially
/// foldable (i.e., trivially "all true").
///
/// Use this method to avoid generating spurious masks and relying on
/// vectorization post-processing to remove them.
///
/// Pre-conditions for a mask to be trivially foldable:
///   * All involved shapes (mask + destination tensor) are static.
///   * All write indices are constant.
///   * All mask sizes are constant (including `arith.constant`).
///
/// If the pre-conditions are met, the method checks for each destination
/// dimension `d`:
///   (1) maskShape[d] <= destDimSize[rankDiff + d]
///   (2) maskSize[d] + writeIndex[d] <= destDimSize[rankDiff + d]
///       (with maskSize[d] clamped to maskShape[d])
/// where:
///   rankDiff = rank(dest) - rank(mask).
///
/// This method takes a conservative view: it may return false even if the mask
/// is technically foldable.
///
/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
/// of the dest tensor):
///   %c0 = arith.constant 0 : index
///   %mask = vector.create_mask 5, 1
///   vector.mask %mask {
///     vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
///         {in_bounds = [true, true]}
///         : vector<5x1xi32>, tensor<5x1xi32>
///   }
///
/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
/// mask is required to avoid out-of-bounds write):
///   %c0 = arith.constant 0 : index
///   %mask = vector.create_mask 5, 1
///   vector.mask %mask {
///     vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
///         {in_bounds = [true, true]}
///         : vector<8x1xi32>, tensor<5x1xi32>
///   }
///
/// TODO: Re-use in createReadOrMaskedRead
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
|
|
SmallVector<Value> &writeIdxs,
|
|
ArrayRef<int64_t> destShape,
|
|
ArrayRef<int64_t> maskShape) {
|
|
// Masking is unavoidable in the case of dynamic tensors.
|
|
if (ShapedType::isDynamicShape(destShape))
|
|
return false;
|
|
|
|
// Collect all constant mask sizes.
|
|
SmallVector<int64_t, 4> cstMaskSizes;
|
|
for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
|
|
if (auto intSize = getConstantIntValue(dimSize)) {
|
|
cstMaskSizes.push_back(*intSize);
|
|
}
|
|
}
|
|
|
|
// If any of the mask sizes is non-constant, bail out.
|
|
if (cstMaskSizes.size() != maskShape.size())
|
|
return false;
|
|
|
|
// Collect all constant write indices.
|
|
SmallVector<int64_t, 4> cstWriteIdxs;
|
|
for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
|
|
APSInt intVal;
|
|
if (matchPattern(idx, m_ConstantInt(&intVal))) {
|
|
cstWriteIdxs.push_back(intVal.getSExtValue());
|
|
}
|
|
}
|
|
|
|
// If any of the write indices is non-constant, bail out.
|
|
if (cstWriteIdxs.size() != destShape.size())
|
|
return false;
|
|
|
|
// Go over all destination dims and check (1) and (2). Take into account that:
|
|
// * The number of mask sizes will match the rank of the vector to store.
|
|
// This could be lower than the rank of the destination tensor.
|
|
// * Mask sizes could be larger than the corresponding mask shape (hence
|
|
// `clamp`).
|
|
// TODO: The 2nd item should be rejected by the verifier.
|
|
int64_t rankDiff = destShape.size() - cstMaskSizes.size();
|
|
for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
|
|
if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
|
|
/*(2)*/ destShape[rankDiff + i] <
|
|
(std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
|
|
cstWriteIdxs[i]))
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
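
// Worked example (illustrative, not from the original source): take
// destShape = [8, 16], maskShape = [4, 16], constant mask sizes [8, 16] (the
// destination sizes) and write indices [4, 0]. Then rankDiff = 0, check (1)
// holds (4 <= 8, 16 <= 16) and check (2) holds with the clamped mask sizes
// (4 + 4 <= 8, 16 + 0 <= 16), so the mask would be all-true and the function
// returns true. Raising the d0 write index to 5 makes 4 + 5 > 8 and the
// function returns false.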

/// Creates an optionally masked TransferWriteOp.
///
/// Generates the following operation:
///   %res = vector.transfer_write %vecToStore into %dest
///
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore into %dest
///   }
///
/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
///   in_bounds_flags = (...)
///   %res = vector.transfer_write %vecToStore into %dest
///       {in_bounds = in_bounds_flags}
///
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
/// are set to 0.
static Operation *
|
|
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
|
|
Value dest, SmallVector<Value> writeIndices = {},
|
|
bool useInBoundsInsteadOfMasking = false) {
|
|
|
|
ShapedType destType = cast<ShapedType>(dest.getType());
|
|
int64_t destRank = destType.getRank();
|
|
auto destShape = destType.getShape();
|
|
|
|
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
|
|
int64_t vecToStoreRank = vecToStoreType.getRank();
|
|
auto vecToStoreShape = vecToStoreType.getShape();
|
|
|
|
// Compute the in_bounds attribute
|
|
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
|
|
if (useInBoundsInsteadOfMasking) {
|
|
// Update the inBounds attribute.
|
|
// FIXME: This computation is too weak - it ignores the write indices.
|
|
for (unsigned i = 0; i < vecToStoreRank; i++)
|
|
inBoundsVal[i] =
|
|
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
|
|
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
|
|
}
|
|
|
|
// If missing, initialize the write indices to 0.
|
|
assert(writeIndices.empty() ||
|
|
writeIndices.size() == static_cast<size_t>(destRank) &&
|
|
"Invalid number of write indices!");
|
|
if (writeIndices.empty()) {
|
|
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
|
writeIndices.assign(destRank, zero);
|
|
}
|
|
|
|
// Generate the xfer_write Op
|
|
Operation *write =
|
|
builder.create<vector::TransferWriteOp>(loc,
|
|
/*vector=*/vecToStore,
|
|
/*source=*/dest,
|
|
/*indices=*/writeIndices,
|
|
/*inBounds=*/inBoundsVal);
|
|
|
|
// If masking is disabled, exit.
|
|
if (useInBoundsInsteadOfMasking)
|
|
return write;
|
|
|
|
// Check if masking is needed. If not, exit.
|
|
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
|
|
return write;
|
|
|
|
// Compute the mask and mask the write Op.
|
|
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
|
|
|
|
SmallVector<OpFoldResult> destSizes =
|
|
tensor::getMixedSizes(builder, loc, dest);
|
|
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
|
|
destSizes.end());
|
|
|
|
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
|
|
vecToStoreShape))
|
|
return write;
|
|
|
|
Value maskForWrite =
|
|
builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
|
|
return mlir::vector::maskOperation(builder, write, maskForWrite);
|
|
}
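
// Illustrative sketch (not from the original source; shapes are hypothetical):
// storing a vector<8x8xf32> at offset [0, 0] into a tensor<6x8xf32> destination
// is not trivially in-bounds, so the write above gets masked roughly as:
//
//   %mask = vector.create_mask %c6, %c8 : vector<8x8xi1>
//   %res = vector.mask %mask {
//     vector.transfer_write %vecToStore, %dest[%c0, %c0]
//         : vector<8x8xf32>, tensor<6x8xf32>
//   } : vector<8x8xi1> -> tensor<6x8xf32>
//
// (With constant sizes the mask may fold further into a constant mask.)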

/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
/// padding value and (3) input vector sizes into:
///
///   masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
///
/// As in the following example:
///   %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///       into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
///   %load = vector.mask %mask {
///       vector.transfer_read %arg0[%c0, %c0, %c0], %cst
///           {in_bounds = [true, true, true]} :
///           tensor<32x7x16xf32>, vector<32x8x16xf32>
///   } : vector<32x8x16xi1> -> vector<32x8x16xf32>
///   %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
///       to vector<32x4x2x1x16xf32>
///   %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
///       : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
///   %write = vector.transfer_write %transpose,
///       %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
///       {in_bounds = [true, true, true, true, true]}
///       : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
///
/// If the (3) input vector sizes are not provided, the vector sizes are
/// determined by the result tensor shape and the `in_bounds`
/// attribute is used instead of masking to mark out-of-bounds accesses.
///
/// NOTE: The input vector sizes specify the dimensions corresponding to the
/// outer dimensions of the output tensor. The remaining dimensions are
/// computed based on, e.g., the static inner tiles.
/// Supporting dynamic inner tiles will require the user to specify the
/// missing vector sizes. This is left as a TODO.
static LogicalResult
|
|
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
SmallVectorImpl<Value> &newResults) {
|
|
// TODO: Introduce a parent class that will handle the insertion point update.
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(packOp);
|
|
|
|
Location loc = packOp.getLoc();
|
|
auto padValue = packOp.getPaddingValue();
|
|
if (!padValue) {
|
|
padValue = rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
|
|
}
|
|
ReifiedRankedShapedTypeDims reifiedReturnShapes;
|
|
LogicalResult status =
|
|
cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
|
|
.reifyResultShapes(rewriter, reifiedReturnShapes);
|
|
(void)status; // prevent unused variable warning on non-assert builds.
|
|
assert(succeeded(status) && "failed to reify result shapes");
|
|
|
|
// If the input vector sizes are not provided, then the vector sizes are
|
|
// determined by the result tensor shape. In case the vector sizes aren't
|
|
// provided, we update the inBounds attribute instead of masking.
|
|
bool useInBoundsInsteadOfMasking = false;
|
|
if (inputVectorSizes.empty()) {
|
|
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
|
|
inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
|
|
useInBoundsInsteadOfMasking = true;
|
|
}
|
|
|
|
// Create masked TransferReadOp.
|
|
SmallVector<int64_t> inputShape(inputVectorSizes);
|
|
auto innerTiles = packOp.getStaticInnerTiles();
|
|
auto innerDimsPos = packOp.getInnerDimsPos();
|
|
auto outerDimsPerm = packOp.getOuterDimsPerm();
|
|
if (!outerDimsPerm.empty())
|
|
applyPermutationToVector(inputShape,
|
|
invertPermutationVector(outerDimsPerm));
|
|
for (auto [idx, size] : enumerate(innerTiles))
|
|
inputShape[innerDimsPos[idx]] *= size;
|
|
auto maskedRead = vector::createReadOrMaskedRead(
|
|
rewriter, loc, packOp.getSource(), inputShape, padValue,
|
|
useInBoundsInsteadOfMasking);
|
|
|
|
// Create ShapeCastOp.
|
|
SmallVector<int64_t> destShape(inputVectorSizes);
|
|
destShape.append(innerTiles.begin(), innerTiles.end());
|
|
auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
|
|
packOp.getDestType().getElementType());
|
|
auto shapeCastOp =
|
|
rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
|
|
|
|
// Create TransposeOp.
|
|
auto destPermutation =
|
|
invertPermutationVector(getPackInverseDestPerm(packOp));
|
|
auto transposeOp = rewriter.create<vector::TransposeOp>(
|
|
loc, shapeCastOp.getResult(), destPermutation);
|
|
|
|
// Create TransferWriteOp.
|
|
Value dest = rewriter.create<tensor::EmptyOp>(
|
|
loc, reifiedReturnShapes[0],
|
|
transposeOp.getResult().getType().getElementType());
|
|
Operation *write =
|
|
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
|
|
newResults.push_back(write->getResult(0));
|
|
return success();
|
|
}
|
|
|
|
/// Vectorize a `linalg::UnPackOp` into these 4 Ops:
///   vector::TransferReadOp  - Reads a vector from the source tensor
///   vector::TransposeOp     - Transposes the read vector
///   vector::ShapeCastOp     - Reshapes the data based on the target
///   vector::TransferWriteOp - Writes the result vector back to the
///                             destination tensor
/// If the vector sizes are not provided:
///   * the vector sizes are determined by the input operand and attributes,
///   * update the inBounds attribute instead of masking.
static LogicalResult
|
|
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
SmallVectorImpl<Value> &newResults) {
|
|
|
|
// TODO: Introduce a parent class that will handle the insertion point update.
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(unpackOp);
|
|
|
|
RankedTensorType unpackTensorType = unpackOp.getSourceType();
|
|
|
|
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
|
|
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
|
|
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
|
|
bool useInBoundsInsteadOfMasking = false;
|
|
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
|
|
|
|
auto destSize = unpackOp.getDestRank();
|
|
|
|
if (!inputVectorSizes.empty())
|
|
assert(inputVectorSizes.size() == destSize &&
|
|
"Incorrect number of input vector sizes");
|
|
|
|
// vectorSizes is the shape of the vector that will be used to do the final
// write on the destination tensor. It is set like this: Let's say the
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
// Thus:
// 1. vectorSizes = sourceShape.take_front(N)
// 2. if outer_dims_perm is present: apply that permutation to vectorSizes.
// 3. multiply all the locations in vectorSizes pointed to by innerDimPos by
//    the innerTiles attribute value.
SmallVector<int64_t> vectorSizes(inputVectorSizes);
|
|
if (vectorSizes.empty()) {
|
|
llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
|
|
if (!outerDimsPerm.empty())
|
|
applyPermutationToVector(vectorSizes, outerDimsPerm);
|
|
for (auto [i, pos] : llvm::enumerate(innerDimPos))
|
|
vectorSizes[pos] *= innerTiles[i];
|
|
|
|
useInBoundsInsteadOfMasking = true;
|
|
}
|
|
|
|
// readVectorSizes is the size of the tensor used to read and apply the mask.
// It is set like this: Let's say the vectorSizes (VS) array is of size 'N',
// the sourceShape (SS) is of size 'M' where M >= N, and the InnerTileSizes
// (IT) are of size M-N.
// Thus:
// - initially: readVectorSizes = vectorSizes
// - Divide all the readVectorSizes locations pointed to by innerDimPos
//   by the innerTileSize attribute value.
// - if outer_dims_perm is present: apply that permutation to readVectorSizes.
// - Append the remaining shape from SS.
// E.g. let's say unpackTensorType.getShape() = <8x8x32x16>,
// inner_dims_pos = [0, 1], inner_tiles = [32, 16], vector_sizes are [512, 128]
// and outer_dims_perm is [1, 0]. Then the read shape is:
//   ReadVectorSizes(initial): [512, 128]
//   Final value (after innerDim adjustment): [512/32, 128/16] = [16, 8]
//   After applying outer_dims_perm: [8, 16]
//   After appending the rest of the sourceShape: [8, 16, 32, 16]

SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
|
|
|
|
for (auto [index, size] : enumerate(innerTiles)) {
|
|
readVectorSizes[innerDimPos[index]] =
|
|
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
|
|
}
|
|
if (!outerDimsPerm.empty()) {
|
|
applyPermutationToVector(readVectorSizes, outerDimsPerm);
|
|
}
|
|
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
|
|
sourceShape.end());
|
|
|
|
ReifiedRankedShapedTypeDims reifiedRetShapes;
|
|
LogicalResult status =
|
|
cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
|
|
.reifyResultShapes(rewriter, reifiedRetShapes);
|
|
if (status.failed()) {
|
|
LDBG("Unable to reify result shapes of " << unpackOp);
|
|
return failure();
|
|
}
|
|
Location loc = unpackOp->getLoc();
|
|
|
|
auto padValue = rewriter.create<arith::ConstantOp>(
|
|
loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
|
|
|
|
// Read result, mask if necessary. If transferReadOp shape is not equal
|
|
// to shape of source, then a mask is necessary.
|
|
Value readResult = vector::createReadOrMaskedRead(
|
|
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
|
|
/*useInBoundsInsteadOfMasking=*/false);
|
|
|
|
PackingMetadata packMetadata;
|
|
SmallVector<int64_t> lastDimToInsertPosPerm =
|
|
getUnPackInverseSrcPerm(unpackOp, packMetadata);
|
|
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
|
|
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
|
|
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
|
|
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
|
|
RankedTensorType stripMineTensorType =
|
|
RankedTensorType::get(stripMineShape, stripMineElemType);
|
|
// Transpose the appropriate rows to match output.
|
|
vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
|
|
loc, readResult, lastDimToInsertPosPerm);
|
|
|
|
// Collapse the vector to the size required by result.
|
|
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
|
|
stripMineTensorType, packMetadata.reassociations);
|
|
mlir::VectorType vecCollapsedType =
|
|
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
|
|
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
|
|
loc, vecCollapsedType, transposeOp->getResult(0));
|
|
|
|
// writeVectorSizes had to match the shapecast shape for dynamic sizes,
|
|
// otherwise the validator complains that the mask size is invalid.
|
|
SmallVector<int64_t> writeVectorSizes(
|
|
unpackOp.getDestType().hasStaticShape()
|
|
? vectorSizes
|
|
: shapeCastOp.getResultVectorType().getShape());
|
|
Value dest = rewriter.create<tensor::EmptyOp>(
|
|
loc, reifiedRetShapes[0],
|
|
shapeCastOp.getResult().getType().getElementType());
|
|
Operation *write = createWriteOrMaskedWrite(
|
|
rewriter, loc, shapeCastOp.getResult(), dest,
|
|
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
|
|
newResults.push_back(write->getResult(0));
|
|
return success();
|
|
}
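
// Illustrative sketch (not from the original source; shapes are hypothetical,
// and the transpose permutation shown assumes getUnPackInverseSrcPerm yields
// the usual interleaving): unpacking a tensor<2x4x16x8xf32> with
// inner_dims_pos = [0, 1] and inner_tiles = [16, 8] into a tensor<32x32xf32>
// builds roughly:
//
//   %read = vector.transfer_read %src ... : tensor<2x4x16x8xf32>, vector<2x4x16x8xf32>
//   %tr   = vector.transpose %read, [0, 2, 1, 3]
//             : vector<2x4x16x8xf32> to vector<2x16x4x8xf32>
//   %sc   = vector.shape_cast %tr : vector<2x16x4x8xf32> to vector<32x32xf32>
//   %res  = vector.transfer_write %sc, %empty ... : vector<32x32xf32>, tensor<32x32xf32>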
|
|
|
|
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
|
|
/// and (3) all-zero lowPad to
|
|
/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
|
|
static LogicalResult
|
|
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
SmallVectorImpl<Value> &newResults) {
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
Location loc = padOp.getLoc();
|
|
|
|
// TODO: Introduce a parent class that will handle the insertion point update.
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(padOp);
|
|
|
|
ReifiedRankedShapedTypeDims reifiedReturnShapes;
|
|
LogicalResult status =
|
|
cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
|
|
.reifyResultShapes(rewriter, reifiedReturnShapes);
|
|
(void)status; // prevent unused variable warning on non-assert builds
|
|
assert(succeeded(status) && "failed to reify result shapes");
|
|
auto maskedRead = vector::createReadOrMaskedRead(
|
|
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
|
|
/*useInBoundsInsteadOfMasking=*/false);
|
|
|
|
// Create Xfer write Op
|
|
Value dest = rewriter.create<tensor::EmptyOp>(
|
|
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
|
|
Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
|
|
newResults.push_back(write->getResult(0));
|
|
return success();
|
|
}
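
// Illustrative sketch (not from the original source; shapes are hypothetical):
// a pad with constant value %cst, zero low padding, source tensor<?x?xf32> and
// result tensor<4x8xf32>, vectorized with input vector sizes [4, 8], becomes
// roughly:
//
//   %read = vector.mask %m {
//     vector.transfer_read %src[%c0, %c0], %cst : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>
//   %init = tensor.empty() : tensor<4x8xf32>
//   %res  = vector.transfer_write %read, %init[%c0, %c0]
//             : vector<4x8xf32>, tensor<4x8xf32>
//
// Out-of-bounds lanes of the masked read pick up %cst, which reproduces the
// padding semantics.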
|
|
|
|
// TODO: probably need some extra checks for reduction followed by consumer
|
|
// ops that may not commute (e.g. linear reduction + non-linear instructions).
|
|
static LogicalResult reductionPreconditions(LinalgOp op) {
|
|
if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
|
|
LDBG("reduction precondition failed: no reduction iterator\n");
|
|
return failure();
|
|
}
|
|
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
|
|
AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
|
|
if (indexingMap.isPermutation())
|
|
continue;
|
|
|
|
Operation *reduceOp = matchLinalgReduction(&opOperand);
|
|
if (!reduceOp || !getCombinerOpKind(reduceOp)) {
|
|
LDBG("reduction precondition failed: reduction detection failed\n");
|
|
return failure();
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
|
|
bool flatten1DDepthwiseConv) {
|
|
if (flatten1DDepthwiseConv) {
|
|
LDBG("Vectorization of flattened convs with dynamic shapes is not "
|
|
"supported\n");
|
|
return failure();
|
|
}
|
|
|
|
if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
|
|
LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
|
|
return failure();
|
|
}
|
|
|
|
// Support dynamic shapes in 1D depthwise convolution, but only in the
|
|
// _channel_ dimension.
|
|
Value lhs = conv.getDpsInputOperand(0)->get();
|
|
ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
|
|
auto shapeWithoutCh = lhsShape.drop_back(1);
|
|
if (ShapedType::isDynamicShape(shapeWithoutCh)) {
|
|
LDBG("Dynamically-shaped op vectorization precondition failed: only "
|
|
"channel dim can be dynamic\n");
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
|
|
bool flatten1DDepthwiseConv) {
|
|
if (isa<ConvolutionOpInterface>(op.getOperation()))
|
|
return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
|
|
|
|
if (hasReductionIterator(op))
|
|
return reductionPreconditions(op);
|
|
|
|
// TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
|
|
// linalg.copy ops and ops that implement ContractionOpInterface for now.
|
|
if (!isElementwise(op) &&
|
|
!isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
|
|
op.getOperation()))
|
|
return failure();
|
|
|
|
LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
|
|
return success();
|
|
}
|
|
|
|
/// Need to check if the inner-tiles are static/constant.
|
|
static LogicalResult
|
|
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
|
|
ArrayRef<int64_t> inputVectorSizes) {
|
|
|
|
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
|
|
return !getConstantIntValue(res).has_value();
|
|
})) {
|
|
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
|
|
return failure();
|
|
}
|
|
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
|
|
bool satisfyEmptyCond = inputVectorSizes.empty() &&
|
|
unpackOp.getDestType().hasStaticShape() &&
|
|
unpackOp.getSourceType().hasStaticShape();
|
|
if (!satisfyEmptyCond &&
|
|
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
|
|
ArrayRef<int64_t> inputVectorSizes) {
|
|
|
|
TypedValue<RankedTensorType> source = sliceOp.getSource();
|
|
auto sourceType = source.getType();
|
|
if (!VectorType::isValidElementType(sourceType.getElementType()))
|
|
return failure();
|
|
|
|
// Get the pad value.
|
|
// TransferReadOp (which is used to vectorize InsertSliceOp), requires a
|
|
// scalar padding value. Note that:
|
|
// * for in-bounds accesses,
|
|
// the value is actually irrelevant. There are 2 cases in which xfer.read
|
|
// accesses are known to be in-bounds:
|
|
// 1. The source shape is static (output vector sizes would be based on
|
|
// the source shape and hence all memory accesses would be in-bounds),
|
|
// 2. Masking is used, i.e. the output vector sizes are user-provided. In
|
|
// this case it is safe to assume that all memory accesses are in-bounds.
|
|
//
|
|
// When the value is not known and not needed, use 0. Otherwise, bail out.
|
|
Value padValue = getStaticPadVal(sliceOp);
|
|
bool isOutOfBoundsRead =
|
|
!sourceType.hasStaticShape() && inputVectorSizes.empty();
|
|
|
|
if (!padValue && isOutOfBoundsRead) {
|
|
LDBG("Failed to get a pad value for out-of-bounds read access\n");
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
enum class ConvOperationKind { Conv, Pool };
|
|
} // namespace
|
|
|
|
static bool isCastOfBlockArgument(Operation *op) {
|
|
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
|
|
isa<BlockArgument>(op->getOperand(0));
|
|
}
|
|
|
|
// Returns the ConvOperationKind of the op using reduceOp of the generic
// payload. If it is neither a convolution nor a pooling, it returns
// std::nullopt.
//
// If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
// + yield) and rhs is not used, then it is the body of a pooling operation.
// If it is a convolution, check for a single `mul` predecessor. The `mul`
// operands must be block arguments or extensions of block arguments.
// Otherwise, check for one or zero `ext` predecessors. The `ext` operands
// must be block arguments or extensions of block arguments.
static std::optional<ConvOperationKind>
|
|
getConvOperationKind(Operation *reduceOp) {
|
|
int numBlockArguments =
|
|
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
|
|
|
|
switch (numBlockArguments) {
|
|
case 1: {
|
|
// Will be convolution if feeder is a MulOp.
|
|
// A strength reduced version of MulOp for i1 type is AndOp which is also
|
|
// supported. Otherwise, it can be pooling. This strength reduction logic
|
|
// is in `buildBinaryFn` helper in the Linalg dialect.
|
|
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
|
|
llvm::IsaPred<BlockArgument>);
|
|
assert(feedValIt != reduceOp->operand_end() &&
|
|
"Expected a non-block argument operand");
|
|
Operation *feedOp = (*feedValIt).getDefiningOp();
|
|
if (isCastOfBlockArgument(feedOp)) {
|
|
return ConvOperationKind::Pool;
|
|
}
|
|
|
|
if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
|
|
(isa<arith::AndIOp>(feedOp) &&
|
|
feedOp->getResultTypes()[0].isInteger(1))) &&
|
|
llvm::all_of(feedOp->getOperands(), [](Value v) {
|
|
if (isa<BlockArgument>(v))
|
|
return true;
|
|
if (Operation *op = v.getDefiningOp())
|
|
return isCastOfBlockArgument(op);
|
|
return false;
|
|
}))) {
|
|
return std::nullopt;
|
|
}
|
|
|
|
return ConvOperationKind::Conv;
|
|
}
|
|
case 2:
|
|
// Must be pooling
|
|
return ConvOperationKind::Pool;
|
|
default:
|
|
return std::nullopt;
|
|
}
|
|
}
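
// Illustrative payloads (not from the original source):
//
//   // Convolution: the reduction has one block-argument operand; the other
//   // operand is fed by a mul of (possibly extended) block arguments.
//   ^bb0(%in: f32, %flt: f32, %out: f32):
//     %m = arith.mulf %in, %flt : f32
//     %a = arith.addf %out, %m : f32
//     linalg.yield %a : f32
//
//   // Pooling: the reduction combines two block arguments directly.
//   ^bb0(%in: f32, %out: f32):
//     %mx = arith.maximumf %out, %in : f32
//     linalg.yield %mx : f32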
|
|
|
|
static bool isSupportedPoolKind(vector::CombiningKind kind) {
|
|
switch (kind) {
|
|
case vector::CombiningKind::ADD:
|
|
case vector::CombiningKind::MAXNUMF:
|
|
case vector::CombiningKind::MAXIMUMF:
|
|
case vector::CombiningKind::MAXSI:
|
|
case vector::CombiningKind::MAXUI:
|
|
case vector::CombiningKind::MINNUMF:
|
|
case vector::CombiningKind::MINIMUMF:
|
|
case vector::CombiningKind::MINSI:
|
|
case vector::CombiningKind::MINUI:
|
|
return true;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
|
|
auto getOperandType = [&](auto operand) {
|
|
return dyn_cast<ShapedType>((operand->get()).getType());
|
|
};
|
|
ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
|
|
ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
|
|
ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
|
|
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
|
|
// (non-channeled convolution -> LHS and RHS both have single dimensions).
|
|
// Note that this also ensures 2D and 3D convolutions are rejected.
|
|
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
|
|
(lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
|
|
return failure();
|
|
|
|
Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
|
|
if (!reduceOp)
|
|
return failure();
|
|
|
|
auto maybeOper = getConvOperationKind(reduceOp);
|
|
if (!maybeOper.has_value())
|
|
return failure();
|
|
|
|
auto maybeKind = getCombinerOpKind(reduceOp);
|
|
// Typically convolution will have a `Add` CombiningKind but for i1 type it
|
|
// can get strength reduced to `OR` which is also supported. This strength
|
|
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
|
|
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
|
|
*maybeKind != vector::CombiningKind::OR) &&
|
|
(*maybeOper != ConvOperationKind::Pool ||
|
|
!isSupportedPoolKind(*maybeKind)))) {
|
|
return failure();
|
|
}
|
|
|
|
auto rhsRank = rhsShapedType.getRank();
|
|
if (*maybeOper == ConvOperationKind::Pool) {
|
|
if (rhsRank != 1)
|
|
return failure();
|
|
} else {
|
|
if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult vectorizeLinalgOpPrecondition(
|
|
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
|
|
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
|
|
// Tensors with a dimension of size 0 cannot be vectorized.
if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
|
|
return llvm::is_contained(linalgOp.getShape(&operand), 0);
|
|
}))
|
|
return failure();
|
|
// Check API contract for input vector sizes.
|
|
if (!inputVectorSizes.empty() &&
|
|
failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
|
|
inputVectorSizes)))
|
|
return failure();
|
|
|
|
if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
|
|
linalgOp, flatten1DDepthwiseConv))) {
|
|
LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
|
|
return failure();
|
|
}
|
|
|
|
SmallVector<CustomVectorizationPrecondition> customPreconditions;
|
|
|
|
// Register CustomVectorizationPrecondition for extractOp.
|
|
customPreconditions.push_back(tensorExtractVectorizationPrecondition);
|
|
|
|
// All types in the body should be a supported element type for VectorType.
|
|
for (Operation &innerOp : linalgOp->getRegion(0).front()) {
|
|
// Check if any custom hook can vectorize the inner op.
|
|
if (llvm::any_of(
|
|
customPreconditions,
|
|
[&](const CustomVectorizationPrecondition &customPrecondition) {
|
|
return succeeded(
|
|
customPrecondition(&innerOp, vectorizeNDExtract));
|
|
})) {
|
|
continue;
|
|
}
|
|
if (!llvm::all_of(innerOp.getOperandTypes(),
|
|
VectorType::isValidElementType)) {
|
|
return failure();
|
|
}
|
|
if (!llvm::all_of(innerOp.getResultTypes(),
|
|
VectorType::isValidElementType)) {
|
|
return failure();
|
|
}
|
|
}
|
|
if (isElementwise(linalgOp))
|
|
return success();
|
|
|
|
// TODO: isaConvolutionOpInterface that can also infer from generic
|
|
// features. But we will still need stride/dilation attributes that will be
|
|
// annoying to reverse-engineer...
|
|
if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
|
|
return vectorizeConvOpPrecondition(linalgOp);
|
|
|
|
// TODO: the common vector shape is equal to the static loop sizes only when
|
|
// all indexing maps are projected permutations. For convs and stencils the
|
|
// logic will need to evolve.
|
|
if (!allIndexingsAreProjectedPermutation(linalgOp)) {
|
|
LDBG("precondition failed: not projected permutations\n");
|
|
return failure();
|
|
}
|
|
if (failed(reductionPreconditions(linalgOp))) {
|
|
LDBG("precondition failed: reduction preconditions\n");
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizePackOpPrecondition(linalg::PackOp packOp,
|
|
ArrayRef<int64_t> inputVectorSizes) {
|
|
auto padValue = packOp.getPaddingValue();
|
|
Attribute cstAttr;
|
|
if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
|
|
LDBG("pad value is not constant: " << packOp << "\n");
|
|
return failure();
|
|
}
|
|
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
|
|
bool satisfyEmptyCond = true;
|
|
if (inputVectorSizes.empty()) {
|
|
if (!packOp.getDestType().hasStaticShape() ||
|
|
!packOp.getSourceType().hasStaticShape())
|
|
satisfyEmptyCond = false;
|
|
}
|
|
|
|
if (!satisfyEmptyCond &&
|
|
failed(vector::isValidMaskedInputVector(
|
|
resultTensorShape.take_front(packOp.getSourceRank()),
|
|
inputVectorSizes)))
|
|
return failure();
|
|
|
|
if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
|
|
return !getConstantIntValue(v).has_value();
|
|
})) {
|
|
LDBG("inner_tiles must be constant: " << packOp << "\n");
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
static LogicalResult
|
|
vectorizePadOpPrecondition(tensor::PadOp padOp,
|
|
ArrayRef<int64_t> inputVectorSizes) {
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue) {
|
|
LDBG("pad value is not constant: " << padOp << "\n");
|
|
return failure();
|
|
}
|
|
|
|
ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
|
|
if (failed(vector::isValidMaskedInputVector(resultTensorShape,
|
|
inputVectorSizes)))
|
|
return failure();
|
|
|
|
// Padding with non-zero low pad values is not supported, unless the
|
|
// corresponding result dim is 1 as this would require shifting the results to
|
|
// the right for the low padded dims by the required amount of low padding.
|
|
// However, we do support low padding if the dims being low padded have result
|
|
// sizes of 1. The reason is when we have a low pad on a unit result dim, the
|
|
// input size of that dimension will be dynamically zero (as the sum of the
|
|
// low pad and input dim size has to be one) and hence we will create a zero
|
|
// mask as the lowering logic just makes the mask one for the input dim size -
|
|
// which is zero here. Hence we will load the pad value which is what we want
|
|
// in this case. If the low pad is dynamically zero then the lowering is
|
|
// correct as well as no shifts are necessary.
|
|
if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
|
|
Value padValue = en.value();
|
|
unsigned pos = en.index();
|
|
std::optional<int64_t> pad = getConstantIntValue(padValue);
|
|
return (!pad.has_value() || pad.value() != 0) &&
|
|
resultTensorShape[pos] != 1;
|
|
})) {
|
|
LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
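
// Illustrative examples (not from the original source), assuming the other
// preconditions hold: a pad with low = [0, 3] and result tensor<32x64xf32> is
// rejected (non-zero low pad on a non-unit result dim), whereas low = [1, 0]
// with result tensor<1x64xf32> is accepted because the low-padded dim has a
// unit result size.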
|
|
|
|
/// Preconditions for scalable vectors. This is quite restrictive - it models
|
|
/// the fact that in practice we would only make selected dimensions scalable.
|
|
static LogicalResult
|
|
vectorizeScalableVectorPrecondition(Operation *op,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims) {
|
|
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
|
|
"Number of input vector sizes and scalable dims doesn't match");
|
|
|
|
size_t numOfScalableDims =
|
|
llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
|
|
|
|
if (numOfScalableDims == 0)
|
|
return success();
|
|
|
|
auto linalgOp = dyn_cast<LinalgOp>(op);
|
|
|
|
// Cond 1: There's been no need for scalable vectorisation of
|
|
// non-linalg Ops so far
|
|
if (!linalgOp)
|
|
return failure();
|
|
|
|
// Cond 2: There's been no need for more than 2 scalable dims so far
|
|
if (numOfScalableDims > 2)
|
|
return failure();
|
|
|
|
// Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
|
|
// it matches one of the supported cases:
|
|
// 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
|
|
// (*).
|
|
// 2. Exactly 2 dims are scalable and those are the _last two adjacent_
|
|
// parallel dims.
|
|
// 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
|
|
// dim.
|
|
// The 2nd restriction above means that only Matmul-like Ops are supported
|
|
// when 2 dims are scalable, e.g. :
|
|
// * iterators = [parallel, parallel, reduction]
|
|
// * scalable flags = [true, true, false]
|
|
//
|
|
// (*) Non-unit dims get folded away in practice.
|
|
// TODO: Relax these conditions as good motivating examples are identified.
|
|
|
|
// Find the first scalable flag.
|
|
bool seenNonUnitParallel = false;
|
|
auto iterators = linalgOp.getIteratorTypesArray();
|
|
SmallVector<bool> scalableFlags(inputScalableVecDims);
|
|
int64_t idx = scalableFlags.size() - 1;
|
|
while (!scalableFlags[idx]) {
|
|
bool isNonUnitDim = (inputVectorSizes[idx] != 1);
|
|
seenNonUnitParallel |=
|
|
(iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
|
|
|
|
iterators.pop_back();
|
|
scalableFlags.pop_back();
|
|
--idx;
|
|
}
|
|
|
|
// Analyze the iterator corresponding to the first scalable dim.
|
|
switch (iterators.back()) {
|
|
case utils::IteratorType::reduction: {
|
|
// Check 3. above is met.
|
|
if (iterators.size() != inputVectorSizes.size()) {
|
|
LDBG("Non-trailing reduction dim requested for scalable "
|
|
"vectorization\n");
|
|
return failure();
|
|
}
|
|
if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
|
|
LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
|
|
"is not supported\n");
|
|
return failure();
|
|
}
|
|
break;
|
|
}
|
|
case utils::IteratorType::parallel: {
|
|
// Check 1. and 2. above are met.
|
|
if (seenNonUnitParallel) {
|
|
LDBG("Inner parallel dim not requested for scalable "
|
|
"vectorization\n");
|
|
return failure();
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
|
|
// If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
// supported, for which we expect the following config:
//    * iterators = [parallel, parallel, reduction]
//    * scalable flags = [true, true, false]
if (numOfScalableDims == 2) {
|
|
// Disallow below case which breaks 3. above:
|
|
// * iterators = [..., parallel, reduction]
|
|
// * scalable flags = [..., true, true]
|
|
if (iterators.back() == utils::IteratorType::reduction) {
|
|
LDBG("Higher dim than the trailing reduction dim requested for scalable "
|
|
"vectorization\n");
|
|
return failure();
|
|
}
|
|
scalableFlags.pop_back();
|
|
iterators.pop_back();
|
|
|
|
if (!scalableFlags.back() ||
|
|
(iterators.back() != utils::IteratorType::parallel))
|
|
return failure();
|
|
}
|
|
|
|
// Bail out on matmuls with extended semantics (i.e. user-defined indexing
// maps); they are not supported by this transform.
if (linalgOp.hasUserDefinedMaps())
|
|
return failure();
|
|
|
|
// Cond 4: Only the following ops are supported in the
|
|
// presence of scalable vectors
|
|
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
|
|
isa<linalg::MatmulTransposeAOp>(op) ||
|
|
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
|
|
isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
|
|
}
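
// Illustrative configurations (not from the original source), for a matmul
// with iterators = [parallel, parallel, reduction] and non-unit vector sizes:
//   * scalable flags = [false, true, false] pass: the single scalable dim is
//     the trailing non-unit parallel dim.
//   * scalable flags = [true, true, false] pass: the two scalable dims are the
//     last two adjacent parallel dims.
//   * scalable flags = [true, false, false] fail: a non-unit parallel dim sits
//     between the requested scalable dim and the end of the iteration space.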
|
|
|
|
LogicalResult mlir::linalg::vectorizeOpPrecondition(
|
|
Operation *op, ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
|
|
bool flatten1DDepthwiseConv) {
|
|
|
|
if (!hasVectorizationImpl(op))
|
|
return failure();
|
|
|
|
if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
|
|
inputScalableVecDims)))
|
|
return failure();
|
|
|
|
return TypeSwitch<Operation *, LogicalResult>(op)
|
|
.Case<linalg::LinalgOp>([&](auto linalgOp) {
|
|
return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
|
|
vectorizeNDExtract,
|
|
flatten1DDepthwiseConv);
|
|
})
|
|
.Case<tensor::PadOp>([&](auto padOp) {
|
|
return vectorizePadOpPrecondition(padOp, inputVectorSizes);
|
|
})
|
|
.Case<linalg::PackOp>([&](auto packOp) {
|
|
return vectorizePackOpPrecondition(packOp, inputVectorSizes);
|
|
})
|
|
.Case<linalg::UnPackOp>([&](auto unpackOp) {
|
|
return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
|
|
})
|
|
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
|
|
return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
|
|
})
|
|
.Default([](auto) { return failure(); });
|
|
}
|
|
|
|
/// Converts affine.apply Ops to arithmetic operations.
|
|
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
|
|
|
|
for (auto op : make_early_inc_range(toReplace)) {
|
|
rewriter.setInsertionPoint(op);
|
|
auto expanded = affine::expandAffineExpr(
|
|
rewriter, op->getLoc(), op.getAffineMap().getResult(0),
|
|
op.getOperands().take_front(op.getAffineMap().getNumDims()),
|
|
op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
|
|
rewriter.replaceOp(op, expanded);
|
|
}
|
|
}
|
|
|
|
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
|
|
return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
|
|
tensor::InsertSliceOp>(op);
|
|
}
|
|
|
|
FailureOr<VectorizationResult>
|
|
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
|
|
ArrayRef<int64_t> inputVectorSizes,
|
|
ArrayRef<bool> inputScalableVecDims,
|
|
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
|
|
LDBG("Attempting to vectorize:\n" << *op << "\n");
|
|
LDBG("Input vector sizes: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
LDBG("Input scalable vector dims: ");
|
|
LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
|
|
LLVM_DEBUG(llvm::dbgs() << "\n");
|
|
|
|
if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
|
|
vectorizeNDExtract,
|
|
flatten1DDepthwiseConv))) {
|
|
LDBG("Vectorization pre-conditions failed\n");
|
|
return failure();
|
|
}
|
|
|
|
// Initialize vectorization state.
|
|
VectorizationState state(rewriter);
|
|
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
|
|
if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
|
|
inputScalableVecDims))) {
|
|
LDBG("Vectorization state couldn't be initialized\n");
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
SmallVector<Value> results;
|
|
auto vectorizeResult =
|
|
TypeSwitch<Operation *, LogicalResult>(op)
|
|
.Case<linalg::LinalgOp>([&](auto linalgOp) {
|
|
// TODO: isaConvolutionOpInterface that can also infer from
|
|
// generic features. Will require stride/dilation attributes
|
|
// inference.
|
|
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
|
|
FailureOr<Operation *> convOr = vectorizeConvolution(
|
|
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
|
|
flatten1DDepthwiseConv);
|
|
if (succeeded(convOr)) {
|
|
llvm::append_range(results, (*convOr)->getResults());
|
|
return success();
|
|
}
|
|
|
|
LDBG("Unsupported convolution can't be vectorized.\n");
|
|
return failure();
|
|
}
|
|
|
|
LDBG("Vectorize generic by broadcasting to the canonical vector "
|
|
"shape\n");
|
|
|
|
// Pre-process before proceeding.
|
|
convertAffineApply(rewriter, linalgOp);
|
|
|
|
// TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
|
|
// to 'OpBuilder' when it is passed over to some methods like
|
|
// 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
|
|
// erase an op within these methods, the actual rewriter won't be
|
|
// notified and we will end up with read-after-free issues!
|
|
return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
|
|
})
|
|
.Case<tensor::PadOp>([&](auto padOp) {
|
|
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
|
|
results);
|
|
})
|
|
.Case<linalg::PackOp>([&](auto packOp) {
|
|
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
|
|
results);
|
|
})
|
|
.Case<linalg::UnPackOp>([&](auto unpackOp) {
|
|
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
|
|
inputVectorSizes, results);
|
|
})
|
|
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
|
|
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
|
|
results);
|
|
})
|
|
.Default([](auto) { return failure(); });
|
|
|
|
if (failed(vectorizeResult)) {
|
|
LDBG("Vectorization failed\n");
|
|
return failure();
|
|
}
|
|
|
|
return VectorizationResult{results};
|
|
}
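
// Sketch of typical caller-side usage (illustrative only; `candidateOp` is a
// hypothetical op that already passed vectorizeOpPrecondition):
//
//   FailureOr<VectorizationResult> vectorResults =
//       linalg::vectorize(rewriter, candidateOp, /*inputVectorSizes=*/{},
//                         /*inputScalableVecDims=*/{});
//   if (failed(vectorResults))
//     return failure();
//   // `vectorize` does not erase `candidateOp`; the caller may transform the
//   // produced IR further and must eventually perform the replacement:
//   rewriter.replaceOp(candidateOp, vectorResults->replacements);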
|
|
|
|
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
|
|
memref::CopyOp copyOp) {
|
|
auto srcType = cast<MemRefType>(copyOp.getSource().getType());
|
|
auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
|
|
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
|
|
return failure();
|
|
|
|
auto srcElementType = getElementTypeOrSelf(srcType);
|
|
auto dstElementType = getElementTypeOrSelf(dstType);
|
|
if (!VectorType::isValidElementType(srcElementType) ||
|
|
!VectorType::isValidElementType(dstElementType))
|
|
return failure();
|
|
|
|
auto readType = VectorType::get(srcType.getShape(), srcElementType);
|
|
auto writeType = VectorType::get(dstType.getShape(), dstElementType);
|
|
|
|
Location loc = copyOp->getLoc();
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
SmallVector<Value> indices(srcType.getRank(), zero);
|
|
|
|
Value readValue = rewriter.create<vector::TransferReadOp>(
|
|
loc, readType, copyOp.getSource(), indices,
|
|
rewriter.getMultiDimIdentityMap(srcType.getRank()));
|
|
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
|
|
readValue =
|
|
rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
|
|
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
|
|
}
|
|
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
|
|
loc, readValue, copyOp.getTarget(), indices,
|
|
rewriter.getMultiDimIdentityMap(srcType.getRank()));
|
|
rewriter.replaceOp(copyOp, writeValue->getResults());
|
|
return success();
|
|
}
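
// Illustrative sketch (not from the original source; shapes are hypothetical):
// a static memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32> is
// rewritten above into roughly:
//
//   %v = vector.transfer_read %src[%c0, %c0], %pad
//          : memref<4x8xf32>, vector<4x8xf32>
//   vector.transfer_write %v, %dst[%c0, %c0]
//          : vector<4x8xf32>, memref<4x8xf32>
//
// The 0-d special case additionally extracts the scalar and broadcasts it to
// the write type before the write.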
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Misc. vectorization patterns.
|
|
//----------------------------------------------------------------------------//
|
|
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
|
|
/// given operation type OpTy.
|
|
template <typename OpTy>
|
|
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
|
|
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
|
PatternRewriter &rewriter) const final {
|
|
bool changed = false;
|
|
// Collect the users into a vector first, because some of them may be
// replaced/removed while iterating.
for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
|
|
if (auto op = dyn_cast<OpTy>(user))
|
|
changed |= rewriteUser(rewriter, padOp, op).succeeded();
|
|
return success(changed);
|
|
}
|
|
|
|
protected:
|
|
virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
|
|
tensor::PadOp padOp, OpTy op) const = 0;
|
|
};
|
|
|
|
/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
|
|
/// ```
|
|
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
|
|
/// %r = vector.transfer_read %0[%c0, %c0], %cst
|
|
/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// %r = vector.transfer_read %src[%c0, %c0], %padding
|
|
/// {in_bounds = [true, true]}
|
|
/// : tensor<?x?xf32>, vector<17x5xf32>
|
|
/// ```
|
|
/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
|
|
/// sure that the original padding value %cst was never used.
|
|
///
|
|
/// This rewrite is possible if:
|
|
/// - `xferOp` has no out-of-bounds dims or mask.
|
|
/// - Low padding is static 0.
|
|
/// - Single, scalar padding value.
|
|
struct PadOpVectorizationWithTransferReadPattern
|
|
: public VectorizePadOpUserPattern<vector::TransferReadOp> {
|
|
using VectorizePadOpUserPattern<
|
|
vector::TransferReadOp>::VectorizePadOpUserPattern;
|
|
|
|
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
|
|
vector::TransferReadOp xferOp) const override {
|
|
// Low padding must be static 0.
|
|
if (!padOp.hasZeroLowPad())
|
|
return failure();
|
|
// Pad value must be a constant.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue)
|
|
return failure();
|
|
// Padding value of existing `xferOp` is unused.
|
|
if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
|
|
return failure();
|
|
|
|
rewriter.modifyOpInPlace(xferOp, [&]() {
|
|
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
|
xferOp->setAttr(xferOp.getInBoundsAttrName(),
|
|
rewriter.getBoolArrayAttr(inBounds));
|
|
xferOp.getBaseMutable().assign(padOp.getSource());
|
|
xferOp.getPaddingMutable().assign(padValue);
|
|
});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Rewrite use of tensor::PadOp result in TransferWriteOp.
|
|
/// This pattern rewrites TransferWriteOps that write to a padded tensor
|
|
/// value, where the same amount of padding is immediately removed again after
|
|
/// the write. In such cases, the TransferWriteOp can write to the non-padded
|
|
/// tensor value and apply out-of-bounds masking. E.g.:
|
|
/// ```
|
|
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
|
|
/// : tensor<...> to tensor<?x?xf32>
|
|
/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
|
|
/// %2 = vector.transfer_write %vec, %1[...]
|
|
/// : vector<17x5xf32>, tensor<17x5xf32>
|
|
/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
|
|
/// : tensor<17x5xf32> to tensor<?x?xf32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
|
|
/// : tensor<...> to tensor<?x?xf32>
|
|
/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
|
|
/// tensor<?x?xf32>
|
|
/// ```
|
|
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
|
|
/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
|
|
/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
|
|
/// from %r's old dimensions.
|
|
///
|
|
/// This rewrite is possible if:
|
|
/// - Low padding is static 0.
|
|
/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
|
|
/// ExtractSliceOp trims the same amount of padding that was added
|
|
/// beforehand.
|
|
/// - Single, scalar padding value.
|
|
struct PadOpVectorizationWithTransferWritePattern
|
|
: public VectorizePadOpUserPattern<vector::TransferWriteOp> {
|
|
using VectorizePadOpUserPattern<
|
|
vector::TransferWriteOp>::VectorizePadOpUserPattern;
|
|
|
|
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
|
|
vector::TransferWriteOp xferOp) const override {
|
|
// TODO: support 0-d corner case.
|
|
if (xferOp.getTransferRank() == 0)
|
|
return failure();
|
|
|
|
// Low padding must be static 0.
|
|
if (!padOp.hasZeroLowPad())
|
|
return failure();
|
|
// Pad value must be a constant.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue)
|
|
return failure();
|
|
// TransferWriteOp result must be directly consumed by an ExtractSliceOp.
|
|
if (!xferOp->hasOneUse())
|
|
return failure();
|
|
auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
|
|
if (!trimPadding)
|
|
return failure();
|
|
// Only static zero offsets supported when trimming padding.
|
|
if (!trimPadding.hasZeroOffset())
|
|
return failure();
|
|
// trimPadding must remove the amount of padding that was added earlier.
|
|
if (!hasSameTensorSize(padOp.getSource(), trimPadding))
|
|
return failure();
|
|
|
|
// Insert the new TransferWriteOp at position of the old TransferWriteOp.
|
|
rewriter.setInsertionPoint(xferOp);
|
|
|
|
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
|
auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
|
xferOp, padOp.getSource().getType(), xferOp.getVector(),
|
|
padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
|
|
xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
|
|
rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Check if `beforePadding` and `afterTrimming` have the same tensor size,
|
|
/// i.e., same dimensions.
|
|
///
|
|
/// Dimensions may be static, dynamic or mix of both. In case of dynamic
|
|
/// dimensions, this function tries to infer the (static) tensor size by
|
|
/// looking at the defining op and utilizing op-specific knowledge.
|
|
///
|
|
/// This is a conservative analysis. In case equal tensor sizes cannot be
|
|
/// proven statically, this analysis returns `false` even though the tensor
|
|
/// sizes may turn out to be equal at runtime.
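  ///
  /// For example, two `tensor<?x5xf32>` values are only treated as having the
  /// same size if their dynamic dimension can be traced back to the same SSA
  /// value, to matching extract_slice sizes, or to identical affine.min
  /// expressions.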
|
|
bool hasSameTensorSize(Value beforePadding,
|
|
tensor::ExtractSliceOp afterTrimming) const {
|
|
// If the input to tensor::PadOp is a CastOp, try with both CastOp
|
|
// result and CastOp operand.
|
|
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
|
|
if (hasSameTensorSize(castOp.getSource(), afterTrimming))
|
|
return true;
|
|
|
|
auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
|
|
auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
|
|
// Only RankedTensorType supported.
|
|
if (!t1 || !t2)
|
|
return false;
|
|
// Rank of both values must be the same.
|
|
if (t1.getRank() != t2.getRank())
|
|
return false;
|
|
|
|
// All static dimensions must be the same. Mixed cases (e.g., dimension
|
|
// static in `t1` but dynamic in `t2`) are not supported.
|
|
for (unsigned i = 0; i < t1.getRank(); ++i) {
|
|
if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
|
|
return false;
|
|
if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
|
|
return false;
|
|
}
|
|
|
|
// Nothing more to check if all dimensions are static.
|
|
if (t1.getNumDynamicDims() == 0)
|
|
return true;
|
|
|
|
// All dynamic sizes must be the same. The only supported case at the
|
|
// moment is when `beforePadding` is an ExtractSliceOp (or a cast
|
|
// thereof).
|
|
|
|
// Apart from CastOp, only ExtractSliceOp is supported.
|
|
auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
|
|
if (!beforeSlice)
|
|
return false;
|
|
|
|
assert(static_cast<size_t>(t1.getRank()) ==
|
|
beforeSlice.getMixedSizes().size());
|
|
assert(static_cast<size_t>(t2.getRank()) ==
|
|
afterTrimming.getMixedSizes().size());
|
|
|
|
for (unsigned i = 0; i < t1.getRank(); ++i) {
|
|
// Skip static dimensions.
|
|
if (!t1.isDynamicDim(i))
|
|
continue;
|
|
auto size1 = beforeSlice.getMixedSizes()[i];
|
|
auto size2 = afterTrimming.getMixedSizes()[i];
|
|
|
|
// Case 1: Same value or same constant int.
|
|
if (isEqualConstantIntOrValue(size1, size2))
|
|
continue;
|
|
|
|
// Other cases: Take a deeper look at defining ops of values.
|
|
auto v1 = llvm::dyn_cast_if_present<Value>(size1);
|
|
auto v2 = llvm::dyn_cast_if_present<Value>(size2);
|
|
if (!v1 || !v2)
|
|
return false;
|
|
|
|
// Case 2: Both values are identical AffineMinOps. (Should not happen if
|
|
// CSE is run.)
|
|
auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
|
|
auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
|
|
if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
|
|
minOp1.getOperands() == minOp2.getOperands())
|
|
continue;
|
|
|
|
// Add additional cases as needed.
|
|
}
|
|
|
|
// All tests passed.
|
|
return true;
|
|
}
|
|
};
|
|
|
|
/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty value
/// otherwise.
///
/// TODO: This is used twice (when checking vectorization pre-conditions and
/// when vectorizing). Cache results instead of re-running.
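///
/// For illustration, in IR like the following (names and shapes are
/// hypothetical):
/// ```
///   %cst = arith.constant 0.0 : f32
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<8x16xf32>) -> tensor<8x16xf32>
///   %res = tensor.insert_slice %src into %fill[0, 0] [4, 4] [1, 1]
///            : tensor<4x4xf32> into tensor<8x16xf32>
/// ```
/// calling this helper on the tensor.insert_slice walks to the linalg.fill
/// that produced the destination and returns %cst.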
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
  // being broadcast, provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value that used to fill the output
  // tensor.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generateOp - can't guarantee the value is fixed without
  // analysing, bail out.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    return {};
  }

  // 4. vector.transfer_write - inspect the vector that is being written. If
  // it contains a single value that has been broadcast (e.g. via
  // vector.broadcast), extract that value; fail otherwise.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
  // than the input tensor, then, provided it's constant, we'll extract the
  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
  // TODO: Clarify the semantics when the input tensor is larger than the
  // destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
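
/// As a rough, hedged sketch of what the helper below produces (shapes are
/// hypothetical and no user-provided vector sizes are assumed):
/// ```
///   %r = tensor.insert_slice %src into %dest[2, %off] [4, 8] [1, 1]
///          : tensor<4x8xf32> into tensor<8x16xf32>
/// ```
/// becomes, approximately:
/// ```
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///          : tensor<4x8xf32>, vector<4x8xf32>
///   %r = vector.transfer_write %v, %dest[%c2, %off]
///          : vector<4x8xf32>, tensor<8x16xf32>
/// ```
/// The pad value comes from getStaticPadVal when available, and the resulting
/// transfer_write is appended to `newResults` for the caller to use as the
/// replacement.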
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  // 2. Get the vector shape
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but result shape is.
      // Vectorize with size of result shape. This may be larger than the
      // source size.
      // FIXME: Using rankDiff implies that the source tensor is inserted at
      // the end of the destination tensor. However, that's not required.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither source nor result dim of padOp is static. Cannot vectorize
      // the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // 3. Generate TransferReadOp + TransferWriteOp
  auto loc = sliceOp.getLoc();

  // Create read
  SmallVector<Value> readIndices(
      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, source, vecType.getShape(), padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  // Create write
  auto writeIndices =
      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());

  // 4. Finalize
  newResults.push_back(write->getResult(0));

  return success();
}

/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
|
|
/// ```
|
|
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
|
|
/// %r = tensor.insert_slice %0
|
|
/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
|
|
/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
|
|
/// ```
|
|
/// is rewritten to:
|
|
/// ```
|
|
/// %0 = vector.transfer_read %src[%c0, %c0], %padding
|
|
/// : tensor<?x?xf32>, vector<17x5xf32>
|
|
/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
|
|
/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
|
|
/// ```
|
|
///
|
|
/// This rewrite is possible if:
|
|
/// - Low padding is static 0.
|
|
/// - `padOp` result shape is static.
|
|
/// - The entire padded tensor is inserted.
|
|
/// (Implies that sizes of `insertOp` are all static.)
|
|
/// - Only unit strides in `insertOp`.
|
|
/// - Single, scalar padding value.
|
|
/// - `padOp` result not used as destination.
|
|
struct PadOpVectorizationWithInsertSlicePattern
|
|
: public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
|
|
using VectorizePadOpUserPattern<
|
|
tensor::InsertSliceOp>::VectorizePadOpUserPattern;
|
|
|
|
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
|
|
tensor::InsertSliceOp insertOp) const override {
|
|
// Low padding must be static 0.
|
|
if (!padOp.hasZeroLowPad())
|
|
return failure();
|
|
// Only unit stride supported.
|
|
if (!insertOp.hasUnitStride())
|
|
return failure();
|
|
// Pad value must be a constant.
|
|
auto padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue)
|
|
return failure();
|
|
// Dynamic shapes not supported.
|
|
if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
|
|
return failure();
|
|
// Pad result not used as destination.
|
|
if (insertOp.getDest() == padOp.getResult())
|
|
return failure();
|
|
|
|
auto vecType = VectorType::get(padOp.getType().getShape(),
|
|
padOp.getType().getElementType());
|
|
unsigned vecRank = vecType.getRank();
|
|
unsigned tensorRank = insertOp.getType().getRank();
|
|
|
|
// Check if sizes match: Insert the entire tensor into most minor dims.
|
|
// (No permutations allowed.)
|
|
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
|
|
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
|
|
if (!llvm::all_of(
|
|
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
|
|
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
|
|
}))
|
|
return failure();
|
|
|
|
// Insert the TransferReadOp and TransferWriteOp at the position of the
|
|
// InsertSliceOp.
|
|
rewriter.setInsertionPoint(insertOp);
|
|
|
|
// Generate TransferReadOp: Read entire source tensor and add high
|
|
// padding.
|
|
SmallVector<Value> readIndices(
|
|
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
|
|
auto read = rewriter.create<vector::TransferReadOp>(
|
|
padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
|
|
|
|
// Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
|
|
    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
|
|
// source must fit into the destination at the specified offsets.
|
|
auto writeIndices = getValueOrCreateConstantIndexOp(
|
|
rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
|
|
SmallVector<bool> inBounds(vecRank, true);
|
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
|
insertOp, read, insertOp.getDest(), writeIndices,
|
|
ArrayRef<bool>{inBounds});
|
|
|
|
return success();
|
|
}
|
|
};

void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(
      patterns.getContext(), baseBenefit.getBenefit() + 1);
}
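
// Illustrative usage sketch (not part of this file's API surface): these
// patterns are typically collected into a RewritePatternSet and run through a
// greedy driver. The exact driver entry point varies across MLIR versions.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));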
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Forwarding patterns
|
|
//----------------------------------------------------------------------------//
|
|
|
|
/// Check whether there is any interleaved use of any `values` between
|
|
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
|
|
/// is in a different block.
|
|
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
|
|
ValueRange values) {
|
|
if (firstOp->getBlock() != secondOp->getBlock() ||
|
|
!firstOp->isBeforeInBlock(secondOp)) {
|
|
LDBG("interleavedUses precondition failed, firstOp: "
|
|
<< *firstOp << ", second op: " << *secondOp << "\n");
|
|
return true;
|
|
}
|
|
for (auto v : values) {
|
|
for (auto &u : v.getUses()) {
|
|
Operation *owner = u.getOwner();
|
|
if (owner == firstOp || owner == secondOp)
|
|
continue;
|
|
// TODO: this is too conservative, use dominance info in the future.
|
|
if (owner->getBlock() == firstOp->getBlock() &&
|
|
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
|
|
continue;
|
|
LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
|
|
<< ", second op: " << *secondOp << "\n");
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Return the unique subview use of `v` if it is indeed unique, null
|
|
/// otherwise.
|
|
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
|
|
memref::SubViewOp subViewOp;
|
|
for (auto &u : v.getUses()) {
|
|
if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
|
|
if (subViewOp)
|
|
return memref::SubViewOp();
|
|
subViewOp = newSubViewOp;
|
|
}
|
|
}
|
|
return subViewOp;
|
|
}
|
|
|
|
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
|
|
/// when available.
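///
/// Rough sketch of the IR this pattern forwards (names, shapes and layouts are
/// illustrative only):
/// ```
///   %alloc = memref.alloc() : memref<17x5xf32>
///   linalg.fill ins(%pad : f32) outs(%alloc : memref<17x5xf32>)
///   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1] : ...
///   memref.copy %in, %sv : ...
///   %v = vector.transfer_read %alloc[%c0, %c0], %pad
///          : memref<17x5xf32>, vector<17x5xf32>
/// ```
/// After the rewrite, the transfer_read reads from %in directly, and the fill
/// and copy are erased, provided no interleaved uses exist.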
|
|
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
|
|
|
|
// TODO: support mask.
|
|
if (xferOp.getMask())
|
|
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
|
|
|
// Transfer into `view`.
|
|
Value viewOrAlloc = xferOp.getBase();
|
|
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
|
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
|
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
|
|
|
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
|
|
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
|
|
if (!subViewOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no subview found");
|
|
Value subView = subViewOp.getResult();
|
|
|
|
// Find the copy into `subView` without interleaved uses.
|
|
memref::CopyOp copyOp;
|
|
for (auto &u : subView.getUses()) {
|
|
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
|
|
assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
|
|
if (newCopyOp.getTarget() != subView)
|
|
continue;
|
|
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
copyOp = newCopyOp;
|
|
break;
|
|
}
|
|
}
|
|
if (!copyOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no copy found");
|
|
|
|
// Find the fill into `viewOrAlloc` without interleaved uses before the
|
|
// copy.
|
|
FillOp maybeFillOp;
|
|
for (auto &u : viewOrAlloc.getUses()) {
|
|
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
|
|
assert(isa<MemRefType>(newFillOp.output().getType()));
|
|
if (newFillOp.output() != viewOrAlloc)
|
|
continue;
|
|
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
maybeFillOp = newFillOp;
|
|
break;
|
|
}
|
|
}
|
|
// Ensure padding matches.
|
|
if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
|
|
return rewriter.notifyMatchFailure(xferOp,
|
|
"padding value does not match fill");
|
|
|
|
// `in` is the subview that memref.copy reads. Replace it.
|
|
Value in = copyOp.getSource();
|
|
|
|
// memref.copy + linalg.fill can be used to create a padded local buffer.
|
|
// The `masked` attribute is only valid on this padded buffer.
|
|
// When forwarding to vector.transfer_read, the attribute must be reset
|
|
// conservatively.
|
|
auto vectorType = xferOp.getVectorType();
|
|
Value res = rewriter.create<vector::TransferReadOp>(
|
|
xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
|
|
xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
|
|
rewriter.getBoolArrayAttr(
|
|
SmallVector<bool>(vectorType.getRank(), false)));
|
|
|
|
if (maybeFillOp)
|
|
rewriter.eraseOp(maybeFillOp);
|
|
rewriter.eraseOp(copyOp);
|
|
rewriter.replaceOp(xferOp, res);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
|
|
/// when available.
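///
/// Mirror of the read-forwarding pattern above: a vector.transfer_write into a
/// view/alloc whose unique subview is then memref.copy'd out is rewritten so
/// that the transfer_write targets the copy's destination directly and the
/// copy is erased (informal summary; see the preconditions checked below).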
|
|
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
|
|
// TODO: support mask.
|
|
if (xferOp.getMask())
|
|
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
|
|
|
// Transfer into `viewOrAlloc`.
|
|
Value viewOrAlloc = xferOp.getBase();
|
|
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
|
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
|
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
|
|
|
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
|
|
memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
|
|
if (!subViewOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no subview found");
|
|
Value subView = subViewOp.getResult();
|
|
|
|
// Find the copy from `subView` without interleaved uses.
|
|
memref::CopyOp copyOp;
|
|
for (auto &u : subViewOp.getResult().getUses()) {
|
|
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
|
|
if (newCopyOp.getSource() != subView)
|
|
continue;
|
|
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
copyOp = newCopyOp;
|
|
break;
|
|
}
|
|
}
|
|
if (!copyOp)
|
|
return rewriter.notifyMatchFailure(xferOp, "no copy found");
|
|
|
|
// `out` is the subview copied into that we replace.
|
|
assert(isa<MemRefType>(copyOp.getTarget().getType()));
|
|
Value out = copyOp.getTarget();
|
|
|
|
// Forward vector.transfer into copy.
|
|
// memref.copy + linalg.fill can be used to create a padded local buffer.
|
|
// The `masked` attribute is only valid on this padded buffer.
|
|
// When forwarding to vector.transfer_write, the attribute must be reset
|
|
// conservatively.
|
|
auto vector = xferOp.getVector();
|
|
rewriter.create<vector::TransferWriteOp>(
|
|
xferOp.getLoc(), vector, out, xferOp.getIndices(),
|
|
xferOp.getPermutationMapAttr(), xferOp.getMask(),
|
|
rewriter.getBoolArrayAttr(SmallVector<bool>(
|
|
dyn_cast<VectorType>(vector.getType()).getRank(), false)));
|
|
|
|
rewriter.eraseOp(copyOp);
|
|
rewriter.eraseOp(xferOp);
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Convolution vectorization patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <int N>
|
|
static void bindShapeDims(ShapedType shapedType) {}
|
|
|
|
template <int N, typename IntTy, typename... IntTy2>
|
|
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
|
|
val = shapedType.getShape()[N];
|
|
bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
|
|
}
|
|
|
|
/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
|
|
template <typename... IntTy>
|
|
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
|
|
bindShapeDims<0>(shapedType, vals...);
|
|
}
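
/// For example, given a shaped type with shape {?, 32, 64}, calling
/// `bindShapeDims(resShapedType, nSize, wSize, fSize)` binds nSize to
/// ShapedType::kDynamic, wSize to 32 and fSize to 64. This mirrors the calls
/// made in the Conv1DGenerator below.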
|
|
|
|
namespace {
|
|
/// Generate a vector implementation for either:
|
|
/// ```
|
|
/// Op def: ( w, kw )
|
|
/// Iters: ({Par(), Red()})
|
|
/// Layout: {{w + kw}, {kw}, {w}}
|
|
/// ```
|
|
/// kw is unrolled.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw, f )
|
|
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
|
|
/// ```
|
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, c, w, f, kw )
|
|
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
|
|
/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
|
|
/// ```
|
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw )
|
|
/// Iters: ({Par(), Par(), Par(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
|
/// ```
|
|
/// kw is unrolled, w is unrolled iff dilationW > 1.
|
|
struct Conv1DGenerator
|
|
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
|
|
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
|
|
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
|
|
|
|
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
|
|
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
|
|
resShaped = linalgOp.getDpsInitOperand(0)->get();
|
|
lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
|
|
rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
|
|
resShapedType = dyn_cast<ShapedType>(resShaped.getType());
|
|
|
|
Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
|
|
redOp = reduceOp->getName().getIdentifier();
|
|
|
|
setConvOperationKind(reduceOp);
|
|
|
|
auto maybeKind = getCombinerOpKind(reduceOp);
|
|
reductionKind = maybeKind.value();
|
|
|
|
// The ConvolutionOpInterface gives us guarantees of existence for
|
|
// strides/dilations. However, we do not need to rely on those, we can
|
|
// simply use them if present, otherwise use the default and let the generic
|
|
// conv. matcher in the ConvGenerator succeed or fail.
|
|
auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
|
|
auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
|
|
strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
|
|
dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
|
|
}
|
|
|
|
/// Generate a vector implementation for:
|
|
/// ```
|
|
/// Op def: ( w, kw )
|
|
/// Iters: ({Par(), Red()})
|
|
/// Layout: {{w + kw}, {kw}, {w}}
|
|
/// ```
|
|
/// kw is always unrolled.
|
|
///
|
|
/// or
|
|
///
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw, f )
|
|
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
|
|
/// ```
|
|
/// kw is always unrolled.
|
|
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
|
|
/// > 1.
|
|
FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
|
|
int64_t nSize, wSize, cSize, kwSize, fSize;
|
|
SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
|
|
bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
|
|
switch (conv1DOpOrder) {
|
|
case Conv1DOpOrder::W:
|
|
// Initialize unused dimensions
|
|
nSize = fSize = cSize = 0;
|
|
// out{W}
|
|
bindShapeDims(resShapedType, wSize);
|
|
// kernel{kw}
|
|
bindShapeDims(rhsShapedType, kwSize);
|
|
lhsShape = {// iw = ow + kw - 1
|
|
// (i.e. 16 convolved with 3 -> 14)
|
|
(wSize + kwSize - 1)};
|
|
rhsShape = {kwSize};
|
|
resShape = {wSize};
|
|
break;
|
|
case Conv1DOpOrder::Nwc:
|
|
// out{n, w, f}
|
|
bindShapeDims(resShapedType, nSize, wSize, fSize);
|
|
switch (oper) {
|
|
case ConvOperationKind::Conv:
|
|
// kernel{kw, c, f}
|
|
bindShapeDims(rhsShapedType, kwSize, cSize);
|
|
break;
|
|
case ConvOperationKind::Pool:
|
|
// kernel{kw}
|
|
bindShapeDims(rhsShapedType, kwSize);
|
|
cSize = fSize;
|
|
break;
|
|
}
|
|
lhsShape = {nSize,
|
|
// iw = ow * sw + kw * dw - 1
|
|
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
|
|
// Perform the proper inclusive -> exclusive -> inclusive.
|
|
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
|
|
1,
|
|
cSize};
|
|
switch (oper) {
|
|
case ConvOperationKind::Conv:
|
|
rhsShape = {kwSize, cSize, fSize};
|
|
break;
|
|
case ConvOperationKind::Pool:
|
|
rhsShape = {kwSize};
|
|
break;
|
|
}
|
|
resShape = {nSize, wSize, fSize};
|
|
break;
|
|
case Conv1DOpOrder::Ncw:
|
|
// out{n, f, w}
|
|
bindShapeDims(resShapedType, nSize, fSize, wSize);
|
|
switch (oper) {
|
|
case ConvOperationKind::Conv:
|
|
// kernel{f, c, kw}
|
|
bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
|
|
break;
|
|
case ConvOperationKind::Pool:
|
|
// kernel{kw}
|
|
bindShapeDims(rhsShapedType, kwSize);
|
|
cSize = fSize;
|
|
break;
|
|
}
|
|
lhsShape = {nSize, cSize,
|
|
// iw = ow * sw + kw * dw - 1
|
|
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
|
|
// Perform the proper inclusive -> exclusive -> inclusive.
|
|
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
|
|
1};
|
|
switch (oper) {
|
|
case ConvOperationKind::Conv:
|
|
rhsShape = {fSize, cSize, kwSize};
|
|
break;
|
|
case ConvOperationKind::Pool:
|
|
rhsShape = {kwSize};
|
|
break;
|
|
}
|
|
resShape = {nSize, fSize, wSize};
|
|
break;
|
|
}
|
|
|
|
vector::TransferWriteOp write;
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
|
|
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
|
|
// When strideW == 1, we can batch the contiguous loads and avoid
|
|
// unrolling
|
|
int64_t wSizeStep = strideW == 1 ? wSize : 1;
|
|
|
|
Type lhsEltType = lhsShapedType.getElementType();
|
|
Type rhsEltType = rhsShapedType.getElementType();
|
|
Type resEltType = resShapedType.getElementType();
|
|
auto lhsType = VectorType::get(lhsShape, lhsEltType);
|
|
auto rhsType = VectorType::get(rhsShape, rhsEltType);
|
|
auto resType = VectorType::get(resShape, resEltType);
|
|
// Zero padding with the corresponding dimensions for lhs, rhs and res.
|
|
SmallVector<Value> lhsPadding(lhsShape.size(), zero);
|
|
SmallVector<Value> rhsPadding(rhsShape.size(), zero);
|
|
SmallVector<Value> resPadding(resShape.size(), zero);
|
|
|
|
// Read the whole lhs, rhs and res in one shot (with zero padding).
|
|
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
|
|
lhsPadding);
|
|
// This is needed only for Conv.
|
|
Value rhs = nullptr;
|
|
if (oper == ConvOperationKind::Conv)
|
|
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
|
|
rhsPadding);
|
|
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
|
|
resPadding);
|
|
|
|
// The base vectorization case for channeled convolution is input:
|
|
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
|
|
// vectorization case, we do pre transpose on input, weight, and output.
|
|
switch (conv1DOpOrder) {
|
|
case Conv1DOpOrder::W:
|
|
case Conv1DOpOrder::Nwc:
|
|
// Base case, so no transposes necessary.
|
|
break;
|
|
case Conv1DOpOrder::Ncw: {
|
|
// To match base vectorization case, we pre-transpose current case.
|
|
// ncw -> nwc
|
|
static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
|
|
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
|
|
// fcw -> wcf
|
|
static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
|
|
|
|
// This is needed only for Conv.
|
|
if (oper == ConvOperationKind::Conv)
|
|
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
|
|
// nfw -> nwf
|
|
static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
|
|
res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
|
|
break;
|
|
}
|
|
}
|
|
|
|
//===------------------------------------------------------------------===//
|
|
// Begin vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
// Unroll along kw and read slices of lhs and rhs.
|
|
SmallVector<Value> lhsVals, rhsVals, resVals;
|
|
lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
|
|
kwSize, strideW, dilationW, wSizeStep,
|
|
isSingleChanneled);
|
|
// Do not do for pooling.
|
|
if (oper == ConvOperationKind::Conv)
|
|
rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
|
|
resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
|
|
wSizeStep, isSingleChanneled);
|
|
|
|
auto linearIndex = [&](int64_t kw, int64_t w) {
|
|
return kw * (wSize / wSizeStep) + w;
|
|
};
|
|
|
|
// Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
|
|
// or perform outerproduct for non-channeled convolution or perform simple
|
|
// arith operation for pooling
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
switch (oper) {
|
|
case ConvOperationKind::Conv:
|
|
if (isSingleChanneled) {
|
|
resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
|
|
lhsVals[linearIndex(kw, w)],
|
|
rhsVals[kw], resVals[w]);
|
|
} else {
|
|
resVals[w] = conv1dSliceAsContraction(rewriter, loc,
|
|
lhsVals[linearIndex(kw, w)],
|
|
rhsVals[kw], resVals[w]);
|
|
}
|
|
break;
|
|
case ConvOperationKind::Pool:
|
|
resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
|
|
resVals[w]);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
|
|
isSingleChanneled);
|
|
//===------------------------------------------------------------------===//
|
|
// End vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
|
|
// The base vectorization case for channeled convolution is output:
|
|
    // {n,w,f}. To reuse the result from base pattern vectorization case, we
|
|
// post transpose the base case result.
|
|
switch (conv1DOpOrder) {
|
|
case Conv1DOpOrder::W:
|
|
case Conv1DOpOrder::Nwc:
|
|
// Base case, so no transposes necessary.
|
|
break;
|
|
case Conv1DOpOrder::Ncw: {
|
|
// nwf -> nfw
|
|
static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
|
|
res = rewriter.create<vector::TransposeOp>(loc, res, perm);
|
|
break;
|
|
}
|
|
}
|
|
|
|
return rewriter
|
|
.create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
|
|
.getOperation();
|
|
}
|
|
|
|
  // Take a value and widen it to have the same element type as `ty`.
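  // For example, promoting a vector<4xi8> value to match a vector<4xi32>
  // accumulator inserts arith.extsi; vector<4xf16> -> vector<4xf32> uses
  // arith.extf; an integer source with a float destination goes through
  // arith.sitofp. Narrowing conversions are not handled and assert.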
|
|
Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
|
|
const Type srcElementType = getElementTypeOrSelf(val.getType());
|
|
const Type dstElementType = getElementTypeOrSelf(ty);
|
|
assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
|
|
if (srcElementType == dstElementType)
|
|
return val;
|
|
|
|
const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
|
|
const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
|
|
const Type dstType =
|
|
cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
|
|
|
|
if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
|
|
return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
|
|
}
|
|
|
|
if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
|
|
srcWidth < dstWidth)
|
|
return rewriter.create<arith::ExtFOp>(loc, dstType, val);
|
|
|
|
if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
|
|
srcWidth < dstWidth)
|
|
return rewriter.create<arith::ExtSIOp>(loc, dstType, val);
|
|
|
|
assert(false && "unhandled promotion case");
|
|
return nullptr;
|
|
}
|
|
|
|
  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }
|
|
|
|
// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
|
|
// convolution.
|
|
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
|
|
Value lhs, Value rhs, Value res) {
|
|
return rewriter.create<vector::OuterProductOp>(
|
|
loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
|
|
}
|
|
|
|
// Create a reduction: lhs{n, w, c} -> res{n, w, c}
|
|
Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
|
|
Value res) {
|
|
if (isPoolExt)
|
|
lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
|
|
return rewriter
|
|
.create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
|
|
->getResult(0);
|
|
}
|
|
|
|
/// Generate a vector implementation for:
|
|
/// ```
|
|
/// Op def: ( n, w, c, kw)
|
|
/// Iters: ({Par(), Par(), Par(), Red()})
|
|
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
|
/// ```
|
|
/// kw is always unrolled.
|
|
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
|
|
/// > 1.
|
|
FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
|
|
bool channelDimScalableFlag,
|
|
bool flatten) {
|
|
bool scalableChDim = false;
|
|
bool useMasking = false;
|
|
int64_t nSize, wSize, cSize, kwSize;
|
|
// kernel{kw, c}
|
|
bindShapeDims(rhsShapedType, kwSize, cSize);
|
|
if (ShapedType::isDynamic(cSize)) {
|
|
assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
|
|
cSize = channelDimVecSize;
|
|
// Scalable vectors are only used when both conditions are met:
|
|
// 1. channel dim is dynamic
|
|
// 2. channelDimScalableFlag is set
|
|
scalableChDim = channelDimScalableFlag;
|
|
useMasking = true;
|
|
}
|
|
|
|
assert(!(useMasking && flatten) &&
|
|
"Unsupported flattened conv with dynamic shapes");
|
|
|
|
// out{n, w, c}
|
|
bindShapeDims(resShapedType, nSize, wSize);
|
|
|
|
vector::TransferWriteOp write;
|
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
|
|
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
|
|
// When strideW == 1, we can batch the contiguous loads and avoid
|
|
// unrolling
|
|
int64_t wSizeStep = strideW == 1 ? wSize : 1;
|
|
|
|
Type lhsEltType = lhsShapedType.getElementType();
|
|
Type rhsEltType = rhsShapedType.getElementType();
|
|
Type resEltType = resShapedType.getElementType();
|
|
VectorType lhsType = VectorType::get(
|
|
{nSize,
|
|
// iw = ow * sw + kw * dw - 1
|
|
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
|
|
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
|
|
cSize},
|
|
lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
|
|
VectorType rhsType =
|
|
VectorType::get({kwSize, cSize}, rhsEltType,
|
|
/*scalableDims=*/{false, scalableChDim});
|
|
VectorType resType =
|
|
VectorType::get({nSize, wSize, cSize}, resEltType,
|
|
/*scalableDims=*/{false, false, scalableChDim});
|
|
|
|
// Masks the input xfer Op along the channel dim, iff the corresponding
|
|
// scalable flag is set.
|
|
auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
|
|
ArrayRef<bool> scalableDims,
|
|
Operation *opToMask) {
|
|
if (!useMasking)
|
|
return opToMask;
|
|
auto maskType =
|
|
VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
|
|
|
|
SmallVector<bool> inBounds(maskShape.size(), true);
|
|
auto xferOp = cast<VectorTransferOpInterface>(opToMask);
|
|
xferOp->setAttr(xferOp.getInBoundsAttrName(),
|
|
rewriter.getBoolArrayAttr(inBounds));
|
|
|
|
SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
|
|
cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
|
|
|
|
Value maskOp =
|
|
rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
|
|
|
|
return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
|
|
};
|
|
|
|
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
|
|
// 0].
|
|
Value lhs = rewriter.create<vector::TransferReadOp>(
|
|
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
|
|
auto maybeMaskedLhs = maybeMaskXferOp(
|
|
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
|
|
|
|
// Read rhs slice of size {kw, c} @ [0, 0].
|
|
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
|
|
ValueRange{zero, zero});
|
|
auto maybeMaskedRhs = maybeMaskXferOp(
|
|
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
|
|
|
|
// Read res slice of size {n, w, c} @ [0, 0, 0].
|
|
Value res = rewriter.create<vector::TransferReadOp>(
|
|
loc, resType, resShaped, ValueRange{zero, zero, zero});
|
|
auto maybeMaskedRes = maybeMaskXferOp(
|
|
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
|
|
|
|
//===------------------------------------------------------------------===//
|
|
// Begin vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
// Unroll along kw and read slices of lhs and rhs.
|
|
SmallVector<Value> lhsVals, rhsVals, resVals;
|
|
SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
|
|
SmallVector<int64_t> inOutStrides = {1, 1, 1};
|
|
|
|
// Extract lhs slice of size {n, wSizeStep, c}
|
|
// @ [0, sw * w + dw * kw, 0].
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, maybeMaskedLhs->getResult(0),
|
|
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
|
|
inOutSliceSizes, inOutStrides));
|
|
}
|
|
}
|
|
// Extract rhs slice of size {c} @ [kw].
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
rhsVals.push_back(rewriter.create<vector::ExtractOp>(
|
|
loc, maybeMaskedRhs->getResult(0),
|
|
/*offsets=*/ArrayRef<int64_t>{kw}));
|
|
}
|
|
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
|
|
loc, maybeMaskedRes->getResult(0),
|
|
/*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
|
|
inOutStrides));
|
|
}
|
|
|
|
auto linearIndex = [&](int64_t kw, int64_t w) {
|
|
return kw * (wSize / wSizeStep) + w;
|
|
};
|
|
|
|
// Note - the scalable flags are ignored as flattening combined with
|
|
// scalable vectorization is not supported.
|
|
SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
|
|
auto lhsTypeAfterFlattening =
|
|
VectorType::get(inOutFlattenSliceSizes, lhsEltType);
|
|
auto resTypeAfterFlattening =
|
|
VectorType::get(inOutFlattenSliceSizes, resEltType);
|
|
|
|
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
|
|
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
Value lhsVal = lhsVals[linearIndex(kw, w)];
|
|
Value resVal = resVals[w];
|
|
if (flatten) {
|
|
// Flatten the input and output vectors (collapse the channel
|
|
// dimension)
|
|
lhsVal = rewriter.create<vector::ShapeCastOp>(
|
|
loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
|
|
resVal = rewriter.create<vector::ShapeCastOp>(
|
|
loc, resTypeAfterFlattening, resVals[w]);
|
|
}
|
|
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
|
|
rhsVals[kw], resVal, flatten);
|
|
if (flatten) {
|
|
// Un-flatten the output vector (restore the channel dimension)
|
|
resVals[w] = rewriter.create<vector::ShapeCastOp>(
|
|
loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
|
|
}
|
|
}
|
|
}
|
|
|
|
    // It's possible we failed to create the FMA.
|
|
if (!llvm::all_of(resVals, [](Value v) { return v; })) {
|
|
// Manually revert (in reverse order) to avoid leaving a bad IR state.
|
|
for (auto &collection :
|
|
{resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
|
|
for (Value v : collection)
|
|
rewriter.eraseOp(v.getDefiningOp());
|
|
return rewriter.notifyMatchFailure(op, "failed to create FMA");
|
|
}
|
|
|
|
// Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
|
|
// This does not depend on kw.
|
|
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
|
maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
|
|
loc, resVals[w], maybeMaskedRes->getResult(0),
|
|
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
|
/*strides=*/ArrayRef<int64_t>{1, 1, 1});
|
|
}
|
|
//===------------------------------------------------------------------===//
|
|
// End vector-only rewrite part
|
|
//===------------------------------------------------------------------===//
|
|
|
|
// Write back res slice of size {n, w, c} @ [0, 0, 0].
|
|
Operation *resOut = rewriter.create<vector::TransferWriteOp>(
|
|
loc, maybeMaskedRes->getResult(0), resShaped,
|
|
ValueRange{zero, zero, zero});
|
|
return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
|
|
resOut);
|
|
}
|
|
|
|
/// Lower:
|
|
/// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
|
|
/// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
|
|
/// to MulAcc.
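  ///
  /// In the flattened case the filter is repeated (via vector.shuffle) rather
  /// than shape_cast of a broadcast. Illustrative sketch with c = 4 and w = 2:
  /// a filter vector<4xf32> [f0, f1, f2, f3] is shuffled into vector<8xf32>
  /// [f0, f1, f2, f3, f0, f1, f2, f3] before being broadcast to the flattened
  /// {n, w * c} shape and multiplied with the flattened input slice.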
|
|
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
|
|
Value lhs, Value rhs, Value res,
|
|
bool flatten) {
|
|
auto rhsTy = cast<ShapedType>(rhs.getType());
|
|
auto resTy = cast<ShapedType>(res.getType());
|
|
|
|
// TODO(suderman): Change this to use a vector.ima intrinsic.
|
|
lhs = promote(rewriter, loc, lhs, resTy);
|
|
|
|
if (flatten) {
|
|
      // NOTE: The following logic won't work for scalable vectors. For this
      // reason, "flattening" is not supported when shapes are dynamic (this
      // should be captured by one of the pre-conditions).
|
|
|
|
// There are two options for handling the filter:
|
|
// * shape_cast(broadcast(filter))
|
|
// * broadcast(shuffle(filter))
|
|
// Opt for the option without shape_cast to simplify the codegen.
|
|
auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
|
|
auto resSize = cast<VectorType>(res.getType()).getShape()[1];
|
|
|
|
SmallVector<int64_t, 16> indices;
|
|
for (int i = 0; i < resSize / rhsSize; ++i) {
|
|
for (int j = 0; j < rhsSize; ++j)
|
|
indices.push_back(j);
|
|
}
|
|
|
|
rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
|
|
}
|
|
// Broadcast the filter to match the output vector
|
|
rhs = rewriter.create<vector::BroadcastOp>(
|
|
loc, resTy.clone(rhsTy.getElementType()), rhs);
|
|
|
|
rhs = promote(rewriter, loc, rhs, resTy);
|
|
|
|
if (!lhs || !rhs)
|
|
return nullptr;
|
|
|
|
if (isa<FloatType>(resTy.getElementType()))
|
|
return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
|
|
|
|
auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
|
|
return rewriter.create<arith::AddIOp>(loc, mul, res);
|
|
}
|
|
|
|
/// Entry point for non-channeled convolution:
|
|
/// {{w + kw}, {kw}, {w}}
|
|
FailureOr<Operation *> generateNonChanneledConv() {
|
|
AffineExpr w, kw;
|
|
bindDims(ctx, w, kw);
|
|
if (!iters({Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"failed to match conv::W 1-par 1-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {w + kw},
|
|
/*rhsIndex*/ {kw},
|
|
/*resIndex*/ {w}}))
|
|
return conv(Conv1DOpOrder::W);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a conv::W layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
|
|
FailureOr<Operation *> generateNwcConv() {
|
|
AffineExpr n, w, f, kw, c;
|
|
bindDims(ctx, n, w, f, kw, c);
|
|
if (!iters({Par(), Par(), Par(), Red(), Red()}))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to match conv::Nwc 3-par 2-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
|
/*rhsIndex*/ {kw, c, f},
|
|
/*resIndex*/ {n, w, f}}))
|
|
return conv(Conv1DOpOrder::Nwc);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
|
|
FailureOr<Operation *> generateNcwConv() {
|
|
AffineExpr n, w, f, kw, c;
|
|
bindDims(ctx, n, f, w, c, kw);
|
|
if (!iters({Par(), Par(), Par(), Red(), Red()}))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to match conv::Ncw 3-par 2-red");
|
|
|
|
if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
|
|
/*rhsIndex*/ {f, c, kw},
|
|
/*resIndex*/ {n, f, w}}))
|
|
return conv(Conv1DOpOrder::Ncw);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
|
|
FailureOr<Operation *> generateNwcPooling() {
|
|
AffineExpr n, w, c, kw;
|
|
bindDims(ctx, n, w, c, kw);
|
|
if (!iters({Par(), Par(), Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"failed to match pooling 3-par 1-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
|
/*rhsIndex*/ {kw},
|
|
/*resIndex*/ {n, w, c}}))
|
|
return conv(Conv1DOpOrder::Nwc);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
|
|
FailureOr<Operation *> generateNcwPooling() {
|
|
AffineExpr n, w, c, kw;
|
|
bindDims(ctx, n, c, w, kw);
|
|
if (!iters({Par(), Par(), Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(op,
|
|
"failed to match pooling 3-par 1-red");
|
|
|
|
if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
|
|
/*rhsIndex*/ {kw},
|
|
/*resIndex*/ {n, c, w}}))
|
|
return conv(Conv1DOpOrder::Ncw);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
|
|
}
|
|
|
|
/// Entry point that transposes into the common form:
|
|
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
|
|
FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
|
|
bool vecChDimScalableFlag = false,
|
|
bool flatten = false) {
|
|
AffineExpr n, w, c, kw;
|
|
bindDims(ctx, n, w, c, kw);
|
|
if (!iters({Par(), Par(), Par(), Red()}))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "failed to match depthwise::Nwc conv 3-par 1-red");
|
|
|
|
// No transposition needed.
|
|
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
|
/*rhsIndex*/ {kw, c},
|
|
/*resIndex*/ {n, w, c}}))
|
|
return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
|
|
|
|
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
|
|
}
|
|
|
|
private:
|
|
ConvOperationKind oper = ConvOperationKind::Conv;
|
|
StringAttr redOp;
|
|
StringAttr poolExtOp;
|
|
bool isPoolExt = false;
|
|
int strideW, dilationW;
|
|
Value lhsShaped, rhsShaped, resShaped;
|
|
ShapedType lhsShapedType, rhsShapedType, resShapedType;
|
|
vector::CombiningKind reductionKind;
|
|
|
|
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
|
|
void setConvOperationKind(Operation *reduceOp) {
|
|
int numBlockArguments =
|
|
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
|
|
if (numBlockArguments == 1) {
|
|
// Will be convolution if feeder is a MulOp.
|
|
// A strength reduced version of MulOp for i1 type is AndOp which is also
|
|
// supported. Otherwise, it can be pooling. This strength reduction logic
|
|
// is in `buildBinaryFn` helper in the Linalg dialect.
|
|
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
|
|
llvm::IsaPred<BlockArgument>);
|
|
Operation *feedOp = (*feedValIt).getDefiningOp();
|
|
if (isCastOfBlockArgument(feedOp)) {
|
|
oper = ConvOperationKind::Pool;
|
|
isPoolExt = true;
|
|
poolExtOp = feedOp->getName().getIdentifier();
|
|
return;
|
|
}
|
|
oper = ConvOperationKind::Conv;
|
|
return;
|
|
}
|
|
    // numBlockArguments == 2 and this is a pooling op.
|
|
oper = ConvOperationKind::Pool;
|
|
isPoolExt = false;
|
|
}
|
|
};
|
|
} // namespace

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using masks
  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
  // masked/scalable) is the channel dim (i.e. the trailing dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};

void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}