This CL adds a pass that lowers VectorTransferReadOp and VectorTransferWriteOp
to a simple loop nest via local buffer allocations.
This is an MLIR->MLIR lowering based on builders.
A few TODOs are left to address in particular:
1. invert the permutation map so that accesses to the remote memref are coalesced;
2. pad the alloc to avoid bank conflicts in local memory (e.g. GPU shared memory);
3. support broadcast / avoid copies when permutation_map is not of full column rank;
4. add a proper "element_type_cast" op.
One notable limitation is that this lowering does not plan to support boundary conditions.
Such padding should be significantly easier to handle with pre-baked MLIR functions.
This is left for future consideration.
Therefore, the current CL only works properly for full-tile cases for now.
This CL also adds 2 simple tests:
```mlir
for %i0 = 0 to %M step 3 {
for %i1 = 0 to %N step 4 {
for %i2 = 0 to %O {
for %i3 = 0 to %P step 5 {
vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : vector<5x4x3xf32>, memref<?x?x?x?xf32, 0>, index, index, index, index
```
lowers into:
```mlir
for %i0 = 0 to %arg0 step 3 {
for %i1 = 0 to %arg1 step 4 {
for %i2 = 0 to %arg2 {
for %i3 = 0 to %arg3 step 5 {
%1 = alloc() : memref<5x4x3xf32>
%2 = "element_type_cast"(%1) : (memref<5x4x3xf32>) -> memref<1xvector<5x4x3xf32>>
store %cst, %2[%c0] : memref<1xvector<5x4x3xf32>>
for %i4 = 0 to 5 {
%3 = affine_apply (d0, d1) -> (d0 + d1) (%i3, %i4)
for %i5 = 0 to 4 {
%4 = affine_apply (d0, d1) -> (d0 + d1) (%i1, %i5)
for %i6 = 0 to 3 {
%5 = affine_apply (d0, d1) -> (d0 + d1) (%i0, %i6)
%6 = load %1[%i4, %i5, %i6] : memref<5x4x3xf32>
store %6, %0[%5, %4, %i2, %3] : memref<?x?x?x?xf32>
dealloc %1 : memref<5x4x3xf32>
```
and
```mlir
for %i0 = 0 to %M step 3 {
for %i1 = 0 to %N {
for %i2 = 0 to %O {
for %i3 = 0 to %P step 5 {
%f = vector_transfer_read %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, 0, d0)} : (memref<?x?x?x?xf32, 0>, index, index, index, index) -> vector<5x4x3xf32>
```
lowers into:
```mlir
for %i0 = 0 to %arg0 step 3 {
for %i1 = 0 to %arg1 {
for %i2 = 0 to %arg2 {
for %i3 = 0 to %arg3 step 5 {
%1 = alloc() : memref<5x4x3xf32>
%2 = "element_type_cast"(%1) : (memref<5x4x3xf32>) -> memref<1xvector<5x4x3xf32>>
for %i4 = 0 to 5 {
%3 = affine_apply (d0, d1) -> (d0 + d1) (%i3, %i4)
for %i5 = 0 to 4 {
for %i6 = 0 to 3 {
%4 = affine_apply (d0, d1) -> (d0 + d1) (%i0, %i6)
%5 = load %0[%4, %i1, %i2, %3] : memref<?x?x?x?xf32>
store %5, %1[%i4, %i5, %i6] : memref<5x4x3xf32>
%6 = load %2[%c0] : memref<1xvector<5x4x3xf32>>
dealloc %1 : memref<5x4x3xf32>
```
PiperOrigin-RevId: 224552717
262 lines
10 KiB
C++
262 lines
10 KiB
C++
//===- LowerVectorTransfers.cpp - LowerVectorTransfers Pass Impl *- C++ -*-===//
|
|
//
|
|
// Copyright 2019 The MLIR Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
// =============================================================================
|
|
//
|
|
// This file implements target-dependent lowering of vector transfer operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/AffineAnalysis.h"
|
|
#include "mlir/Analysis/MLFunctionMatcher.h"
|
|
#include "mlir/Analysis/Utils.h"
|
|
#include "mlir/Analysis/VectorAnalysis.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Location.h"
|
|
#include "mlir/IR/MLValue.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/SSAValue.h"
|
|
#include "mlir/IR/Types.h"
|
|
#include "mlir/Pass.h"
|
|
#include "mlir/StandardOps/StandardOps.h"
|
|
#include "mlir/Support/Functional.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <type_traits>
|
|
|
|
///
|
|
/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
|
|
/// proper abstraction for the hardware.
|
|
///
|
|
/// For now only a simple loop nest is emitted.
|
|
///
|
|
|
|
using llvm::dbgs;
|
|
using llvm::SetVector;
|
|
|
|
using namespace mlir;
|
|
|
|
#define DEBUG_TYPE "lower-vector-transfers"
|
|
|
|
/// Creates and returns a memoized `constant 0 : index` at the top level of the
|
|
/// each function `f` on which it is called.
|
|
static SSAValue *getZero(MLFunction *f) {
|
|
static thread_local llvm::DenseMap<MLFunction *, SSAValue *> zeros;
|
|
auto it = zeros.find(f);
|
|
if (it == zeros.end()) {
|
|
MLFuncBuilder b(f);
|
|
b.setInsertionPointToStart(f);
|
|
zeros.insert(
|
|
std::make_pair(f, b.create<ConstantIndexOp>(b.getUnknownLoc(), 0)));
|
|
it = zeros.find(f);
|
|
}
|
|
return it->second;
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct LowerVectorTransfersPass : public FunctionPass {
|
|
LowerVectorTransfersPass()
|
|
: FunctionPass(&LowerVectorTransfersPass::passID) {}
|
|
|
|
PassResult runOnMLFunction(MLFunction *f) override;
|
|
|
|
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
|
|
MLFunctionMatcherContext mlContext;
|
|
|
|
static char passID;
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
char LowerVectorTransfersPass::passID = 0;
|
|
|
|
/// Creates the SSAValue for the sum of `a` and `b` without building a
|
|
/// full-fledged AffineMap for all indices.
|
|
///
|
|
/// Prerequisites:
|
|
/// `a` and `b` must be of IndexType.
|
|
static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) {
|
|
assert(v->getType().isa<IndexType>() && "v must be of IndexType");
|
|
assert(w->getType().isa<IndexType>() && "w must be of IndexType");
|
|
auto *context = b->getContext();
|
|
auto d0 = getAffineDimExpr(0, context);
|
|
auto d1 = getAffineDimExpr(1, context);
|
|
auto map = AffineMap::get(2, 0, {d0 + d1}, {});
|
|
return b->create<AffineApplyOp>(loc, map, ArrayRef<SSAValue *>{v, w})
|
|
->getResult(0);
|
|
}
|
|
|
|
/// Lowers `transfer`, a VectorTransferReadOp or VectorTransferWriteOp, into a
/// combination of:
|
|
/// 1. a local buffer allocation with the shape and element type of the
///    transferred vector, plus a "vector_type_cast" view of that buffer as a
///    memref holding a single vector;
|
|
/// 2. a whole-vector store (write case) or load (read case) to/from the local
///    buffer through that view;
|
|
/// 3. a perfect loop nest, one loop per vector dimension, of scalar
///    loads/stores between the local buffer and the remote memref.
|
|
/// The local buffer is then deallocated and `transfer` is erased.
|
|
/// This is a simple sketch for now but does the job. No bounds checks or
/// padding are emitted, so only full-tile transfers are lowered correctly.
|
|
// TODO(ntv): This function has a lot of code conditioned on the template
|
|
// argument being one of the two types. Extract the common behavior into helper
|
|
// functions and detemplatize it.
|
|
template <typename VectorTransferOpTy>
|
|
static void lowerAsLoops(VectorTransferOpTy *transfer) {
|
|
static_assert(
|
|
std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value ||
|
|
std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value,
|
|
"Must be called on either VectorTransferReadOp or VectorTransferWriteOp");
|
|
auto vectorType = transfer->getVectorType();
|
|
auto vectorShape = vectorType.getShape();
|
|
// tmpMemRefType is used for staging the transfer in a local scalar buffer.
|
|
auto tmpMemRefType =
|
|
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0);
|
|
// vectorMemRefType is a view of tmpMemRefType as one vector.
|
|
auto vectorMemRefType = MemRefType::get({1}, vectorType, {}, 0);
|
|
|
|
// The builder starts positioned at `transfer`, so the ops created below are
// placed before it until the insertion point is moved.
MLFuncBuilder b(cast<OperationStmt>(transfer->getOperation()));
|
|
// Memoized `constant 0 : index` at the top of the enclosing function; it is
// used to index the single element of the vector view.
auto *zero = getZero(b.getFunction());
|
|
|
|
// 1. First allocate the local buffer, intended to live in fast memory (the
// memref types above currently use memory space 0).
|
|
// TODO(ntv): CL memory space.
|
|
// TODO(ntv): Allocation padding for potential bank conflicts (e.g. GPUs).
|
|
auto tmpScalarAlloc = b.create<AllocOp>(transfer->getLoc(), tmpMemRefType);
|
|
// TODO(ntv): Proper OperationStmt.
|
|
// Generic, string-named "vector_type_cast" op reinterpreting the scalar
// buffer as a memref<1xvector<...>>.
OperationState state(b.getContext(), transfer->getLoc(), "vector_type_cast",
|
|
ArrayRef<SSAValue *>{tmpScalarAlloc->getResult()},
|
|
ArrayRef<Type>{vectorMemRefType});
|
|
auto vecView = b.createOperation(state);
|
|
|
|
// 2. Store the vector to local storage in case of a vector_transfer_write.
|
|
// TODO(ntv): This vector_store operation should be further lowered in the
|
|
// case of GPUs.
|
|
// Note: this is an ordinary `if` on a compile-time constant (not
// `if constexpr`), so its body is compiled for both instantiations and must
// type-check for either op type.
if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) {
|
|
b.create<StoreOp>(vecView->getLoc(), transfer->getVector(),
|
|
vecView->getResult(0), ArrayRef<SSAValue *>{zero});
|
|
}
|
|
|
|
// 3. Emit the loop-nest.
|
|
// TODO(ntv): Invert the mapping and indexing contiguously in the remote
|
|
// memory.
|
|
// TODO(ntv): Handle broadcast / slice properly.
|
|
auto permutationMap = transfer->getPermutationMap();
|
|
// Loops in creation order, i.e. outermost first; their induction variables
// later index the local buffer.
SetVector<ForStmt *> loops;
|
|
// Remote indices, starting from the transfer's own indices and offset below
// by the induction variable of the loop that maps onto each of them.
SmallVector<SSAValue *, 8> accessIndices(transfer->getIndices());
|
|
// One loop per vector dimension i, with trip count vectorShape[i].
// `composed` is the permutation map's result for vector dimension i; e.g.
// for (d0, d1, d2, d3) -> (d3, d1, d0), vector dimension 0 maps to d3.
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
|
|
auto composed = composeWithUnboundedMap(
|
|
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
|
|
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
|
|
loops.insert(forStmt);
|
|
// Setting the insertion point to the innermost loop achieves nesting.
|
|
b.setInsertionPointToStart(loops.back());
|
|
// A constant-0 result marks a broadcast dimension: the loop is still
// emitted, but no remote index depends on it, so the same remote element is
// copied repeatedly.
if (composed == getAffineConstantExpr(0, b.getContext())) {
|
|
transfer->emitWarning(
|
|
"Redundant copy can be implemented as a vector broadcast");
|
|
} else {
|
|
// Otherwise the result is expected to be a dim expression d_k: offset
// remote index k by this loop's induction variable.
auto dim = composed.template cast<AffineDimExpr>();
|
|
assert(accessIndices.size() > dim.getPosition());
|
|
accessIndices[dim.getPosition()] =
|
|
::add(&b, transfer->getLoc(), accessIndices[dim.getPosition()],
|
|
loops.back());
|
|
}
|
|
}
|
|
|
|
// 4. Emit memory operations within the loops.
|
|
// TODO(ntv): SelectOp + padding value for load out-of-bounds.
|
|
// The insertion point is now the start of the innermost loop. The local
// buffer is indexed by the loop induction variables (outermost first) and
// the remote memref by `accessIndices`. As in step 2, both branches are
// compiled for both instantiations.
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
|
|
// VectorTransferReadOp.
|
|
// a. read scalar from remote;
|
|
// b. write scalar to local.
|
|
auto scalarLoad = b.create<LoadOp>(transfer->getLoc(),
|
|
transfer->getMemRef(), accessIndices);
|
|
b.create<StoreOp>(
|
|
transfer->getLoc(), scalarLoad->getResult(),
|
|
tmpScalarAlloc->getResult(),
|
|
functional::map([](SSAValue *val) { return val; }, loops));
|
|
} else {
|
|
// VectorTransferWriteOp.
|
|
// a. read scalar from local;
|
|
// b. write scalar to remote.
|
|
auto scalarLoad = b.create<LoadOp>(
|
|
transfer->getLoc(), tmpScalarAlloc->getResult(),
|
|
functional::map([](SSAValue *val) { return val; }, loops));
|
|
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
|
|
transfer->getMemRef(), accessIndices);
|
|
}
|
|
|
|
// 5. Read the vector from local storage in case of a vector_transfer_read.
|
|
// TODO(ntv): This vector_load operation should be further lowered in the
|
|
// case of GPUs.
|
|
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
|
|
// Insert before `transfer`, i.e. after the loop nest emitted above.
b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
|
|
auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(0),
|
|
ArrayRef<SSAValue *>{zero})
|
|
->getResult();
|
|
// Redirect every user of the transfer's result to the loaded vector.
transfer->getVector()->replaceAllUsesWith(vector);
|
|
}
|
|
|
|
// 6. Free the local buffer.
|
|
// Insert before `transfer`: after the loop nest and, for reads, after the
// vector load.
b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
|
|
b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);
|
|
|
|
// 7. It is now safe to erase the statement.
|
|
transfer->erase();
|
|
}
|
|
|
|
/// Lowers every vector transfer in `f` to loops. Reads are lowered in a first
/// sweep and writes in a second one; the writes are only matched once all
/// reads are gone. This avoids any read/write ordering considerations.
PassResult LowerVectorTransfersPass::runOnMLFunction(MLFunction *f) {
  using matcher::Op;
  LLVM_DEBUG(dbgs() << "\nLowerVectorTransfersPass on MLFunction\n");
  LLVM_DEBUG(f->print(dbgs()));

  // Sweep 1: vector_transfer_read.
  auto isTransferRead = [](const Statement &stmt) {
    return cast<OperationStmt>(stmt).isa<VectorTransferReadOp>();
  };
  for (const auto &match : Op(isTransferRead).match(f)) {
    auto readOp =
        cast<OperationStmt>(match.first)->cast<VectorTransferReadOp>();
    // TODO(ntv): Drop &* once lowerAsLoops is detemplatized.
    lowerAsLoops(&*readOp);
  }

  // Sweep 2: vector_transfer_write.
  auto isTransferWrite = [](const Statement &stmt) {
    return cast<OperationStmt>(stmt).isa<VectorTransferWriteOp>();
  };
  for (const auto &match : Op(isTransferWrite).match(f)) {
    auto writeOp =
        cast<OperationStmt>(match.first)->cast<VectorTransferWriteOp>();
    // TODO(ntv): Drop &* once lowerAsLoops is detemplatized.
    lowerAsLoops(&*writeOp);
  }

  return PassResult::Success;
}
|
|
|
|
FunctionPass *mlir::createLowerVectorTransfersPass() {
|
|
return new LowerVectorTransfersPass();
|
|
}
|
|
|
|
static PassRegistration<LowerVectorTransfersPass>
|
|
pass("lower-vector-transfers", "Materializes vector transfer ops to a "
|
|
"proper abstraction for the hardware");
|
|
|
|
#undef DEBUG_TYPE
|