This revision takes advantage of recent extensions to vectorization to refactor contraction detection into a bona fide Linalg interface. The mlir-linalg-ods-gen parser is extended to support adding such interfaces. The detection that was originally enabling vectorization is refactored to serve as both a test on a generic LinalgOp as well as to verify ops that declare to conform to that interface. This is plugged through Linalg transforms and strategies but it quickly becomes evident that the complexity and rigidity of the C++ class based templating does not pay for itself. Therefore, this revision changes the API for vectorization patterns to get rid of templates as much as possible. Variadic templates are relegated to the internals of LinalgTransformationFilter as much as possible and away from the user-facing APIs. It is expected other patterns / transformations will follow the same path and drop as much C++ templating as possible from the class definition. Differential revision: https://reviews.llvm.org/D95973
812 lines
32 KiB
C++
812 lines
32 KiB
C++
//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the linalg dialect Vectorization transformations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <type_traits>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::edsc;
|
|
using namespace mlir::edsc::intrinsics;
|
|
using namespace mlir::linalg;
|
|
|
|
using llvm::dbgs;
|
|
|
|
#define DEBUG_TYPE "linalg-vectorization"
|
|
|
|
/// Return true if the use-def chain from `v` to `from` consists of 0 or more
|
|
/// unary single-operand operations.
|
|
// TODO: relax to multi-operands with constants, which are technically unary ops
|
|
// as needed (e.g. add5).
|
|
static bool isChainOfUnaryOpsFrom(Value v, Value from) {
|
|
while (v != from) {
|
|
Operation *op = v.getDefiningOp();
|
|
if (!op || op->getNumOperands() != 1)
|
|
return false;
|
|
v = op->getOperand(0);
|
|
};
|
|
return true;
|
|
}
|
|
|
|
/// Return the unique instance of OpType in `block` if it is indeed unique.
|
|
/// Return null if none or more than 1 instances exist.
|
|
template <typename OpType>
|
|
static OpType getSingleOpOfType(Block &block) {
|
|
OpType res;
|
|
block.walk([&](OpType op) {
|
|
if (res) {
|
|
res = nullptr;
|
|
return WalkResult::interrupt();
|
|
}
|
|
res = op;
|
|
return WalkResult::advance();
|
|
});
|
|
return res;
|
|
}
|
|
|
|
/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
|
|
/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
|
|
/// unary operations that may change the type.
|
|
template <typename AddOpType, typename MulOpType>
|
|
static bool isAddMul(Block &block) {
|
|
if (block.getNumArguments() != 3)
|
|
return false;
|
|
Operation *yieldOp = block.getTerminator();
|
|
if (yieldOp->getNumOperands() != 1)
|
|
return false;
|
|
|
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isAddMul: "; block.dump());
|
|
AddOpType addOp = getSingleOpOfType<AddOpType>(block);
|
|
MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
|
|
if (!addOp || !mulOp)
|
|
return false;
|
|
|
|
Value argA = block.getArgument(0), argB = block.getArgument(1);
|
|
Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
|
|
Value mul = mulOp->getResult(0);
|
|
Value argC = block.getArgument(2);
|
|
Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
|
|
Value add = addOp->getResult(0);
|
|
Value res = yieldOp->getOperand(0);
|
|
// Result traces back to add.
|
|
auto un = isChainOfUnaryOpsFrom;
|
|
bool success = un(res, add);
|
|
// One of the operands of add traces back to argC, the other to the mul.
|
|
success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
|
|
// One of the operands of mul traces back to argA, the other to argB.
|
|
success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
|
|
return success;
|
|
}
|
|
|
|
/// Status describing the outcome of vectorizing a single op.
/// In certain specific cases, like terminators, the vectorized form is handled
/// entirely by a custom function and must not be recorded as a replacement
/// for the original op's results: this is what `NoReplace` signals.
enum VectorizationStatus {
  /// Op failed to vectorize.
  Failure = 0,
  /// Op vectorized and custom function took care of replacement logic.
  NoReplace = 1,
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  NewOp = 2
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
|
|
struct VectorizationResult {
|
|
/// Return status from vectorizing the current op.
|
|
enum VectorizationStatus status = VectorizationStatus::Failure;
|
|
/// New vectorized operation to replace the current op.
|
|
/// Replacement behavior is specified by `status`.
|
|
Operation *newOp;
|
|
};
|
|
|
|
/// Return a vector type of the same shape and element type as the (assumed)
|
|
/// ShapedType of `v`.
|
|
static VectorType extractVectorTypeFromShapedValue(Value v) {
|
|
auto st = v.getType().cast<ShapedType>();
|
|
if (st.isa<MemRefType>() && st.getShape().empty())
|
|
return VectorType();
|
|
return VectorType::get(st.getShape(), st.getElementType());
|
|
}
|
|
|
|
/// Build a vector.transfer_read from `source` at indices set to all `0`.
|
|
/// If source has rank zero, build an std.load.
|
|
/// Return the produced value.
|
|
static Value buildVectorRead(OpBuilder &builder, Value source) {
|
|
edsc::ScopedContext scope(builder);
|
|
auto shapedType = source.getType().cast<ShapedType>();
|
|
if (VectorType vectorType = extractVectorTypeFromShapedValue(source)) {
|
|
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
|
|
return vector_transfer_read(vectorType, source, indices);
|
|
}
|
|
return std_load(source);
|
|
}
|
|
|
|
/// Build a vector.transfer_write of `value` into `dest` at indices set to all
|
|
/// `0`. If `dest` has null rank, build an std.store.
|
|
/// Return the produced value or null if no value is produced.
|
|
static Value buildVectorWrite(OpBuilder &builder, Value value, Value dest) {
|
|
edsc::ScopedContext scope(builder);
|
|
Operation *write;
|
|
auto shapedType = dest.getType().cast<ShapedType>();
|
|
if (VectorType vectorType = extractVectorTypeFromShapedValue(dest)) {
|
|
SmallVector<Value> indices(shapedType.getRank(), std_constant_index(0));
|
|
if (vectorType != value.getType())
|
|
value = vector_broadcast(vectorType, value);
|
|
write = vector_transfer_write(value, dest, indices);
|
|
} else {
|
|
write = std_store(value, dest);
|
|
}
|
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
|
|
if (!write->getResults().empty())
|
|
return write->getResult(0);
|
|
return Value();
|
|
}
|
|
|
|
/// If value of assumed VectorType has a shape different than `shape`, buil and
|
|
/// return a new vector.broadcast to `shape`.
|
|
/// Otherwise, just return value.
|
|
static Value broadcastIfNeeded(OpBuilder &builder, Value value,
|
|
ArrayRef<int64_t> shape) {
|
|
auto vecType = value.getType().dyn_cast<VectorType>();
|
|
if (shape.empty() || (vecType != nullptr && vecType.getShape() == shape))
|
|
return value;
|
|
auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
|
|
: value.getType());
|
|
return builder.create<vector::BroadcastOp>(
|
|
builder.getInsertionPoint()->getLoc(), newVecType, value);
|
|
}
|
|
|
|
// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the BlockAndValueMapping.
// Return a VectorizationResult whose status is VectorizationStatus::Failure if
// the Operation cannot be vectorized by this hook.
using CustomVectorizationHook = std::function<VectorizationResult(
    Operation *, const BlockAndValueMapping &)>;
|
|
|
|
/// Helper function to vectorize the terminator of a `linalgOp`. New result
|
|
/// vector values are appended to `results`.
|
|
/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
|
|
/// that it should not try to map produced operations: this is the purpose of
|
|
/// the `results` argument to capture such values and make them available for
|
|
/// RAUW to the vectorization algorithm.
|
|
/// This function is meant to be used as a CustomVectorizationHook.
|
|
static VectorizationResult
|
|
vectorizeLinalgYield(OpBuilder &builder, Operation *op,
|
|
const BlockAndValueMapping &bvm, LinalgOp linalgOp,
|
|
SmallVectorImpl<Value> &results) {
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(op);
|
|
if (!yieldOp)
|
|
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
|
for (auto outputs : llvm::enumerate(yieldOp.values())) {
|
|
// TODO: Scan for an opportunity for reuse.
|
|
// TODO: use a map.
|
|
Value vectorValue = bvm.lookup(outputs.value());
|
|
Value result = buildVectorWrite(builder, vectorValue,
|
|
linalgOp.getOutput(outputs.index()));
|
|
if (result)
|
|
results.push_back(result);
|
|
}
|
|
return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
|
|
}
|
|
|
|
/// Generic vectorization for a single operation `op`, given already vectorized
|
|
/// operands carried by `bvm`. Vectorization occurs as follows:
|
|
/// 1. Try to apply any of the `customVectorizationHooks` and return its
|
|
/// result on success.
|
|
/// 2. Clone any constant in the current scope without vectorization: each
|
|
/// consumer of the constant will later determine the shape to which the
|
|
/// constant needs to be broadcast to.
|
|
/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
|
|
/// of the `customVectorizationHooks` to cover such cases.
|
|
/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
|
|
/// operand of maximal rank. Other operands have smaller rank and are
|
|
/// broadcast accordingly. It is assumed this broadcast is always legal,
|
|
/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
|
|
///
|
|
/// This function assumes all operands of `op` have been vectorized and are in
|
|
/// the `bvm` mapping. As a consequence, this function is meant to be called on
|
|
/// a topologically-sorted list of ops.
|
|
/// This function does not update `bvm` but returns a VectorizationStatus that
|
|
/// instructs the caller what `bvm` update needs to occur.
|
|
static VectorizationResult
|
|
vectorizeOneOp(OpBuilder &builder, Operation *op,
|
|
const BlockAndValueMapping &bvm,
|
|
ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
|
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorize op " << *op);
|
|
|
|
// 1. Try to apply any CustomVectorizationHook.
|
|
if (!customVectorizationHooks.empty()) {
|
|
for (auto &customFunc : customVectorizationHooks) {
|
|
VectorizationResult result = customFunc(op, bvm);
|
|
if (result.status == VectorizationStatus::Failure)
|
|
continue;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
// 2. Constant ops don't get vectorized but rather broadcasted at their users.
|
|
// Clone so that the constant is not confined to the linalgOp block .
|
|
if (isa<ConstantOp>(op))
|
|
return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
|
|
|
|
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
|
|
if (!op->hasTrait<OpTrait::ElementwiseMappable>())
|
|
return VectorizationResult{VectorizationStatus::Failure, nullptr};
|
|
|
|
// 4. Generic vectorization path for ElementwiseMappable ops.
|
|
// a. first get the first max ranked shape.
|
|
SmallVector<int64_t, 4> firstMaxRankedShape;
|
|
for (Value operand : op->getOperands()) {
|
|
auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
|
|
if (vt && firstMaxRankedShape.size() < vt.getShape().size())
|
|
firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
|
|
}
|
|
// b. broadcast each op if needed.
|
|
auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
|
|
return firstMaxRankedShape.empty()
|
|
? bvm.lookup(v)
|
|
: broadcastIfNeeded(builder, bvm.lookup(v), firstMaxRankedShape);
|
|
});
|
|
// c. for elementwise, the result is the vector with the firstMaxRankedShape
|
|
auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
|
|
return firstMaxRankedShape.empty()
|
|
? t
|
|
: VectorType::get(firstMaxRankedShape, t);
|
|
});
|
|
|
|
// Build and return the new op.
|
|
OperationState state(op->getLoc(), op->getName());
|
|
state.addAttributes(op->getAttrs());
|
|
state.addOperands(llvm::to_vector<4>(vectorizedOperands));
|
|
state.addTypes(llvm::to_vector<4>(returnTypes));
|
|
return VectorizationResult{VectorizationStatus::NewOp,
|
|
builder.createOperation(state)};
|
|
}
|
|
|
|
/// Generic vectorization function that rewrites the body of a `linalgOp` into
|
|
/// vector form. Generic vectorization proceeds as follows:
|
|
/// 1. The region for the linalg op is created if necessary.
|
|
/// 2. Values defined above the region are mapped to themselves and will be
|
|
/// broadcasted on a per-need basis by their consumers.
|
|
/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
|
|
/// load).
|
|
/// TODO: Reuse opportunities for RAR dependencies.
|
|
/// 4. Register CustomVectorizationHook for YieldOp to capture the results.
|
|
/// 5. Iteratively call vectorizeOneOp on the region operations.
|
|
/// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
|
|
static LogicalResult vectorizeAsLinalgGeneric(
|
|
OpBuilder &builder, LinalgOp linalgOp,
|
|
ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
|
|
// 1. Certain Linalg ops do not have a region but only a region builder.
|
|
// If so, build the region so we can vectorize.
|
|
std::unique_ptr<Region> owningRegion;
|
|
Region *region;
|
|
if (linalgOp->getNumRegions() > 0) {
|
|
region = &linalgOp->getRegion(0);
|
|
} else {
|
|
// RAII avoid remaining in block.
|
|
OpBuilder::InsertionGuard g(builder);
|
|
owningRegion = std::make_unique<Region>();
|
|
region = owningRegion.get();
|
|
Block *block = builder.createBlock(region);
|
|
auto elementTypes = llvm::to_vector<4>(
|
|
llvm::map_range(linalgOp.getShapedOperandTypes(),
|
|
[](ShapedType t) { return t.getElementType(); }));
|
|
block->addArguments(elementTypes);
|
|
linalgOp.getRegionBuilder()(*block);
|
|
}
|
|
Block *block = ®ion->front();
|
|
|
|
BlockAndValueMapping bvm;
|
|
// 2. Values defined above the region can only be broadcast for now. Make them
|
|
// map to themselves.
|
|
llvm::SetVector<Value> valuesSet;
|
|
mlir::getUsedValuesDefinedAbove(*region, valuesSet);
|
|
bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
|
|
|
|
// 3. Turn all BBArgs into vector.transfer_read / load.
|
|
SmallVector<AffineMap> indexings;
|
|
for (auto bbarg : block->getArguments()) {
|
|
Value vectorArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
|
|
Value vectorRead = buildVectorRead(builder, vectorArg);
|
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
|
|
<< bbarg.getArgNumber() << "): " << vectorRead);
|
|
bvm.map(bbarg, vectorRead);
|
|
bvm.map(vectorArg, vectorRead);
|
|
}
|
|
|
|
// 4. Register CustomVectorizationHook for yieldOp.
|
|
SmallVector<Value> results;
|
|
CustomVectorizationHook vectorizeYield =
|
|
[&](Operation *op,
|
|
const BlockAndValueMapping &bvm) -> VectorizationResult {
|
|
return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
|
|
};
|
|
// Append the vectorizeYield hook.
|
|
auto hooks = llvm::to_vector<4>(customVectorizationHooks);
|
|
hooks.push_back(vectorizeYield);
|
|
|
|
// 5. Iteratively call `vectorizeOneOp` to each op in the slice.
|
|
for (Operation &op : block->getOperations()) {
|
|
VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
|
|
if (result.status == VectorizationStatus::Failure) {
|
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
|
|
return failure();
|
|
}
|
|
if (result.status == VectorizationStatus::NewOp) {
|
|
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
|
|
<< *result.newOp;);
|
|
bvm.map(op.getResults(), result.newOp->getResults());
|
|
}
|
|
}
|
|
|
|
// 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
|
|
if (!results.empty())
|
|
linalgOp->replaceAllUsesWith(results);
|
|
return success();
|
|
}
|
|
|
|
/// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
|
|
static bool hasOnlyScalarElementwiseOp(Region &r) {
|
|
if (!llvm::hasSingleElement(r))
|
|
return false;
|
|
for (Operation &op : r.front()) {
|
|
if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
|
|
op.hasTrait<OpTrait::ElementwiseMappable>()) ||
|
|
llvm::any_of(op.getResultTypes(),
|
|
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Return true if the op is an element-wise linalg op.
|
|
static bool isElementwise(Operation *op) {
|
|
auto genericOp = dyn_cast<linalg::GenericOp>(op);
|
|
if (!genericOp)
|
|
return false;
|
|
if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
|
|
return false;
|
|
// TODO: relax the restrictions on indexing map.
|
|
for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
|
|
if (!genericOp.getOutputIndexingMap(i).isIdentity())
|
|
return false;
|
|
}
|
|
// Currently bound the input indexing map to minor identity as other
|
|
// permutations might require adding transpose ops to convert the vector read
|
|
// to the right shape.
|
|
for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
|
|
if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
|
|
return false;
|
|
}
|
|
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
|
|
}
|
|
|
|
/// Vectorize `linalgOp`, which must satisfy `isaContractionOpInterface`, by
/// running the generic vectorization with a custom hook that turns the
/// multiplication of the body into a vector.contract with a zero accumulator.
/// Asserts that the generic vectorization succeeds.
static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
  assert(isaContractionOpInterface(linalgOp) &&
         "expected vectorizeContraction preconditions to be met");
  Location loc = linalgOp.getLoc();
  // Vectorize other ops as vector contraction.
  // TODO: interface.
  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                    << "Rewrite linalg op as vector.contract: ";
             linalgOp.dump());

  // Hook describing how to vectorize the multiplication op of a linalg
  // contraction: any other op is declined with a Failure status.
  CustomVectorizationHook mulToContractHook =
      [&](Operation *op,
          const BlockAndValueMapping &bvm) -> VectorizationResult {
    if (!isa<MulIOp, MulFOp>(op))
      return VectorizationResult{VectorizationStatus::Failure, nullptr};

    // The accumulator has the shape of the first output, or is a scalar when
    // that shape is empty.
    Type scalarType = op->getResult(0).getType();
    ArrayRef<int64_t> outShape = linalgOp.getOutputShapedType(0).getShape();
    Type accType = scalarType;
    if (!outShape.empty())
      accType = VectorType::get(outShape, scalarType);
    auto zeroAcc =
        builder.create<ConstantOp>(loc, accType, builder.getZeroAttr(accType));

    Value lhs = bvm.lookup(op->getOperand(0));
    Value rhs = bvm.lookup(op->getOperand(1));
    Operation *contractOp = builder.create<vector::ContractionOp>(
        loc, lhs, rhs, zeroAcc, linalgOp.indexing_maps(),
        linalgOp.iterator_types());
    return VectorizationResult{VectorizationStatus::NewOp, contractOp};
  };

  LogicalResult status =
      vectorizeAsLinalgGeneric(builder, linalgOp, {mulToContractHook});
  (void)status;
  assert(succeeded(status) &&
         "Unexpected vectorization failed despite preconditions");
}
|
|
|
|
/// Check whether `op` (which must be a linalg::LinalgOp) can be vectorized.
/// Fails if any shaped operand or output tensor type is not statically shaped.
/// Otherwise succeeds for linalg.fill, linalg.copy, element-wise generic ops
/// and ops satisfying `isaContractionOpInterface`.
LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
  auto linalgOp = cast<linalg::LinalgOp>(op);

  // All types must be static shape to go to vector.
  auto hasDynamicShape = [](Type type) {
    return !type.cast<ShapedType>().hasStaticShape();
  };
  for (Value operand : linalgOp.getShapedOperands())
    if (hasDynamicShape(operand.getType()))
      return failure();
  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
    if (hasDynamicShape(outputTensorType))
      return failure();

  if (isa<linalg::FillOp, linalg::CopyOp>(op) || isElementwise(op))
    return success();
  return success(isaContractionOpInterface(linalgOp));
}
|
|
|
|
/// Emit the vectorized form of `op` using `builder`.
/// `op` must satisfy `vectorizeLinalgOpPrecondition` (asserted). Dispatches to:
///   - a vector write of the fill value for linalg.fill;
///   - a vector read followed by a vector write for linalg.copy;
///   - the generic vectorization for element-wise ops;
///   - the contraction vectorization for everything else.
/// `op` itself is not erased here.
void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
  assert(succeeded(vectorizeLinalgOpPrecondition(op)));

  edsc::ScopedContext scope(builder, op->getLoc());
  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
    // Vectorize fill as a vector.broadcast (emitted by buildVectorWrite when
    // the types differ) written to the output.
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
    buildVectorWrite(builder, fillOp.value(), fillOp.output());
  } else if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                      << "Rewrite linalg.copy as vector.transfer_read + "
                         "vector.transfer_write: "
                      << *op);
    Value readValue = buildVectorRead(builder, copyOp.input());
    buildVectorWrite(builder, readValue, copyOp.output());
  } else if (isElementwise(op)) {
    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                      << "Rewrite linalg op as vector.transfer_read + " << *op);
    LogicalResult status =
        vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
    (void)status;
    assert(succeeded(status) &&
           "Unexpected vectorization failed despite preconditions");
  } else {
    vectorizeContraction(builder, cast<LinalgOp>(op));
  }
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Misc. conv vectorization patterns.
|
|
//----------------------------------------------------------------------------//
|
|
// TODO: cleanup all this.
|
|
template <class ConvOp, int N>
|
|
LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
|
|
ConvOp op, PatternRewriter &rewriter) const {
|
|
Location loc = op.getLoc();
|
|
MLIRContext *context = op.getContext();
|
|
edsc::ScopedContext scope(rewriter, loc);
|
|
|
|
ShapedType inShapeType = op.getInputShapedType(0);
|
|
ShapedType kShapeType = op.getInputShapedType(1);
|
|
|
|
ArrayRef<int64_t> inShape = inShapeType.getShape();
|
|
ArrayRef<int64_t> kShape = kShapeType.getShape();
|
|
|
|
if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
|
|
return failure();
|
|
|
|
SmallVector<AffineExpr, 4> mapping;
|
|
SmallVector<int64_t, 4> vectorDims;
|
|
// Fail to apply when the size of not vectorized dimension is not 1.
|
|
for (unsigned i = 0; i < N; i++) {
|
|
if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1))
|
|
return failure();
|
|
|
|
if (mask[i] && inShape[i] != kShape[i])
|
|
return failure();
|
|
|
|
if (mask[i]) {
|
|
mapping.push_back(getAffineDimExpr(i, context));
|
|
vectorDims.push_back(inShape[i]);
|
|
}
|
|
}
|
|
|
|
Value input = op.getInput(0);
|
|
Value kernel = op.getInput(1);
|
|
Value output = op.getOutputBuffer(0);
|
|
|
|
unsigned rank = inShapeType.getRank();
|
|
unsigned numDims = mapping.size();
|
|
Type elemType = inShapeType.getElementType();
|
|
|
|
auto map = AffineMap::get(rank, 0, mapping, context);
|
|
SmallVector<Value, 4> zeros(rank, std_constant_index(0));
|
|
auto vecType = VectorType::get(vectorDims, elemType);
|
|
|
|
auto inputVec = vector_transfer_read(vecType, input, zeros, map);
|
|
auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map);
|
|
|
|
auto acc = std_constant(elemType, rewriter.getZeroAttr(elemType));
|
|
|
|
std::array<AffineMap, 3> indexingMaps{
|
|
AffineMap::getMultiDimIdentityMap(numDims, context),
|
|
AffineMap::getMultiDimIdentityMap(numDims, context),
|
|
AffineMap::get(numDims, 0, {}, context)};
|
|
|
|
std::vector<StringRef> iteratorTypes(numDims, "reduction");
|
|
|
|
auto result = rewriter.create<vector::ContractionOp>(
|
|
loc, inputVec, kernelVec, acc,
|
|
rewriter.getAffineMapArrayAttr(indexingMaps),
|
|
rewriter.getStrArrayAttr(iteratorTypes));
|
|
|
|
rewriter.create<StoreOp>(loc, result, output, ValueRange(zeros));
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
|
|
// Shorthand for the ConvOpVectorization pattern instantiated on ConvWOp with
// N = 1.
using ConvOpConst = ConvOpVectorization<ConvWOp, 1>;
|
|
|
|
/// Inserts tiling, promotion and vectorization pattern for ConvOp
|
|
/// conversion into corresponding pattern lists.
|
|
template <typename ConvOp, unsigned N>
|
|
static void
|
|
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
|
|
OwningRewritePatternList &promotionPatterns,
|
|
OwningRewritePatternList &vectorizationPatterns,
|
|
ArrayRef<int64_t> tileSizes,
|
|
MLIRContext *context) {
|
|
if (tileSizes.size() < N)
|
|
return;
|
|
|
|
constexpr static StringRef kTiledMarker = "TILED";
|
|
constexpr static StringRef kPromotedMarker = "PROMOTED";
|
|
tilingPatterns.insert<LinalgTilingPattern<ConvOp>>(
|
|
context, LinalgTilingOptions().setTileSizes(tileSizes),
|
|
LinalgTransformationFilter(ArrayRef<Identifier>{},
|
|
Identifier::get(kTiledMarker, context)));
|
|
|
|
promotionPatterns.insert<LinalgPromotionPattern<ConvOp>>(
|
|
context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
|
|
LinalgTransformationFilter(Identifier::get(kTiledMarker, context),
|
|
Identifier::get(kPromotedMarker, context)));
|
|
|
|
SmallVector<bool, 4> mask(N);
|
|
int offset = tileSizes.size() - N;
|
|
std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(),
|
|
[](int64_t i) -> bool { return i > 1; });
|
|
|
|
vectorizationPatterns.insert<ConvOpVectorization<ConvOp, N>>(context, mask);
|
|
}
|
|
|
|
/// Populate the tiling, promotion and vectorization patterns for all the
/// supported convolution ops and append the three resulting pattern lists to
/// `patterns`, in that order.
/// `tileSizes` is forwarded to each per-op population; ops requiring more
/// dimensions than there are tile sizes are skipped by the callee.
void mlir::linalg::populateConvVectorizationPatterns(
    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
    ArrayRef<int64_t> tileSizes) {
  OwningRewritePatternList tiling, promotion, vectorization;

  // 1-D convolutions.
  populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
                                            tileSizes, context);
  populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);
  populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);

  // 2-D convolutions.
  populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
                                             tileSizes, context);
  populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);
  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
                                               tileSizes, context);

  // 3-D convolutions.
  populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
                                              tileSizes, context);
  populateVectorizationPatterns<ConvNDHWCOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);
  populateVectorizationPatterns<ConvNCDHWOp, 5>(
      tiling, promotion, vectorization, tileSizes, context);

  // Hand the three lists over to the caller: tiling, promotion, vectorization.
  patterns.emplace_back(std::move(tiling));
  patterns.emplace_back(std::move(promotion));
  patterns.emplace_back(std::move(vectorization));
}
|
|
|
|
//----------------------------------------------------------------------------//
|
|
// Forwarding patterns
|
|
//----------------------------------------------------------------------------//
|
|
|
|
/// Check whether there is any interleaved use of any `values` between `firstOp`
|
|
/// and `secondOp`. Conservatively return `true` if any op or value is in a
|
|
/// different block.
|
|
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
|
|
ValueRange values) {
|
|
if (firstOp->getBlock() != secondOp->getBlock() ||
|
|
!firstOp->isBeforeInBlock(secondOp)) {
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
|
<< "interleavedUses precondition failed, firstOp: "
|
|
<< *firstOp << ", second op: " << *secondOp);
|
|
return true;
|
|
}
|
|
for (auto v : values) {
|
|
for (auto &u : v.getUses()) {
|
|
Operation *owner = u.getOwner();
|
|
if (owner == firstOp || owner == secondOp)
|
|
continue;
|
|
// TODO: this is too conservative, use dominance info in the future.
|
|
if (owner->getBlock() == firstOp->getBlock() &&
|
|
(owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
|
|
continue;
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "\n[" DEBUG_TYPE "]: "
|
|
<< " found interleaved op " << *owner
|
|
<< ", firstOp: " << *firstOp << ", second op: " << *secondOp);
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
|
|
static SubViewOp getSubViewUseIfUnique(Value v) {
|
|
SubViewOp subViewOp;
|
|
for (auto &u : v.getUses()) {
|
|
if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
|
|
if (subViewOp)
|
|
return SubViewOp();
|
|
subViewOp = newSubViewOp;
|
|
}
|
|
}
|
|
return subViewOp;
|
|
}
|
|
|
|
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
|
|
/// when available.
|
|
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|
vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
|
|
|
|
// Transfer into `view`.
|
|
Value viewOrAlloc = xferOp.source();
|
|
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
|
|
!viewOrAlloc.getDefiningOp<AllocOp>())
|
|
return failure();
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
|
|
|
|
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
|
|
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
|
|
if (!subViewOp)
|
|
return failure();
|
|
Value subView = subViewOp.getResult();
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
|
<< "with subView " << subView);
|
|
|
|
// Find the copy into `subView` without interleaved uses.
|
|
CopyOp copyOp;
|
|
for (auto &u : subView.getUses()) {
|
|
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
|
|
if (newCopyOp.getOutputBuffer(0) != subView)
|
|
continue;
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
|
<< "copy candidate " << *newCopyOp);
|
|
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
copyOp = newCopyOp;
|
|
break;
|
|
}
|
|
}
|
|
if (!copyOp)
|
|
return failure();
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
|
<< "with copy " << *copyOp);
|
|
|
|
// Find the fill into `viewOrAlloc` without interleaved uses before the copy.
|
|
FillOp maybeFillOp;
|
|
for (auto &u : viewOrAlloc.getUses()) {
|
|
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
|
|
if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
|
|
continue;
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
|
<< "fill candidate " << *newFillOp);
|
|
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
maybeFillOp = newFillOp;
|
|
break;
|
|
}
|
|
}
|
|
// Ensure padding matches.
|
|
if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
|
|
return failure();
|
|
if (maybeFillOp)
|
|
LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
|
|
<< "with maybeFillOp " << *maybeFillOp);
|
|
|
|
// `in` is the subview that linalg.copy reads. Replace it.
|
|
Value in = copyOp.getInput(0);
|
|
|
|
// linalg.copy + linalg.fill can be used to create a padded local buffer.
|
|
// The `masked` attribute is only valid on this padded buffer.
|
|
// When forwarding to vector.transfer_read, the attribute must be reset
|
|
// conservatively.
|
|
Value res = rewriter.create<vector::TransferReadOp>(
|
|
xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
|
|
xferOp.permutation_map(), xferOp.padding(), ArrayAttr());
|
|
|
|
if (maybeFillOp)
|
|
rewriter.eraseOp(maybeFillOp);
|
|
rewriter.eraseOp(copyOp);
|
|
rewriter.replaceOp(xferOp, res);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
|
|
/// when available.
|
|
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|
vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
|
|
// Transfer into `viewOrAlloc`.
|
|
Value viewOrAlloc = xferOp.source();
|
|
if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
|
|
!viewOrAlloc.getDefiningOp<AllocOp>())
|
|
return failure();
|
|
|
|
// Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
|
|
SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
|
|
if (!subViewOp)
|
|
return failure();
|
|
Value subView = subViewOp.getResult();
|
|
|
|
// Find the copy from `subView` without interleaved uses.
|
|
CopyOp copyOp;
|
|
for (auto &u : subViewOp.getResult().getUses()) {
|
|
if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
|
|
if (newCopyOp.getInput(0) != subView)
|
|
continue;
|
|
if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
|
|
continue;
|
|
copyOp = newCopyOp;
|
|
break;
|
|
}
|
|
}
|
|
if (!copyOp)
|
|
return failure();
|
|
|
|
// `out` is the subview copied into that we replace.
|
|
Value out = copyOp.getOutputBuffer(0);
|
|
|
|
// Forward vector.transfer into copy.
|
|
// linalg.copy + linalg.fill can be used to create a padded local buffer.
|
|
// The `masked` attribute is only valid on this padded buffer.
|
|
// When forwarding to vector.transfer_write, the attribute must be reset
|
|
// conservatively.
|
|
rewriter.create<vector::TransferWriteOp>(
|
|
xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
|
|
xferOp.permutation_map(), ArrayAttr());
|
|
|
|
rewriter.eraseOp(copyOp);
|
|
rewriter.eraseOp(xferOp);
|
|
|
|
return success();
|
|
}
|