clang-p2996/mlir/lib/Transforms/LowerVectorTransfers.cpp
Nicolas Vasilache d9b6420fc9 [MLIR] Add LowerVectorTransfersPass
This CL adds a pass that lowers VectorTransferReadOp and VectorTransferWriteOp
to a simple loop nest via local buffer allocations.

This is an MLIR->MLIR lowering based on builders.

A few TODOs are left to address, in particular:
1. invert the permutation map so that accesses to the remote memref are coalesced;
2. pad the alloc to avoid bank conflicts in local memory (e.g. GPU shared_memory);
3. support broadcast / avoid copies when the permutation_map is not of full column rank;
4. add a proper "element_cast" op.

One notable limitation is that this pass does not plan on supporting boundary conditions:
it should be significantly easier to handle such padding with pre-baked MLIR functions,
which is left for future consideration. The current CL therefore only works properly for
full-tile cases.

This CL also adds 2 simple tests:

```mlir
  for %i0 = 0 to %M step 3 {
    for %i1 = 0 to %N step 4 {
      for %i2 = 0 to %O {
        for %i3 = 0 to %P step 5 {
          vector_transfer_write %f1, %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d0)} : vector<5x4x3xf32>, memref<?x?x?x?xf32, 0>, index, index, index, index
```

lowers into:
```mlir
for %i0 = 0 to %arg0 step 3 {
  for %i1 = 0 to %arg1 step 4 {
    for %i2 = 0 to %arg2 {
      for %i3 = 0 to %arg3 step 5 {
        %1 = alloc() : memref<5x4x3xf32>
        %2 = "element_type_cast"(%1) : (memref<5x4x3xf32>) -> memref<1xvector<5x4x3xf32>>
        store %cst, %2[%c0] : memref<1xvector<5x4x3xf32>>
        for %i4 = 0 to 5 {
          %3 = affine_apply (d0, d1) -> (d0 + d1) (%i3, %i4)
          for %i5 = 0 to 4 {
            %4 = affine_apply (d0, d1) -> (d0 + d1) (%i1, %i5)
            for %i6 = 0 to 3 {
              %5 = affine_apply (d0, d1) -> (d0 + d1) (%i0, %i6)
              %6 = load %1[%i4, %i5, %i6] : memref<5x4x3xf32>
              store %6, %0[%5, %4, %i2, %3] : memref<?x?x?x?xf32>
        dealloc %1 : memref<5x4x3xf32>
```

and
```mlir
  for %i0 = 0 to %M step 3 {
    for %i1 = 0 to %N {
      for %i2 = 0 to %O {
        for %i3 = 0 to %P step 5 {
          %f = vector_transfer_read %A, %i0, %i1, %i2, %i3 {permutation_map: (d0, d1, d2, d3) -> (d3, 0, d0)} : (memref<?x?x?x?xf32, 0>, index, index, index, index) -> vector<5x4x3xf32>

```

lowers into:
```mlir
for %i0 = 0 to %arg0 step 3 {
  for %i1 = 0 to %arg1 {
    for %i2 = 0 to %arg2 {
      for %i3 = 0 to %arg3 step 5 {
        %1 = alloc() : memref<5x4x3xf32>
        %2 = "element_type_cast"(%1) : (memref<5x4x3xf32>) -> memref<1xvector<5x4x3xf32>>
        for %i4 = 0 to 5 {
          %3 = affine_apply (d0, d1) -> (d0 + d1) (%i3, %i4)
          for %i5 = 0 to 4 {
            for %i6 = 0 to 3 {
              %4 = affine_apply (d0, d1) -> (d0 + d1) (%i0, %i6)
              %5 = load %0[%4, %i1, %i2, %3] : memref<?x?x?x?xf32>
              store %5, %1[%i4, %i5, %i6] : memref<5x4x3xf32>
        %6 = load %2[%c0] : memref<1xvector<5x4x3xf32>>
        dealloc %1 : memref<5x4x3xf32>
```

PiperOrigin-RevId: 224552717

//===- LowerVectorTransfers.cpp - LowerVectorTransfers Pass Impl -*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements target-dependent lowering of vector transfer operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SSAValue.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

///
/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
/// proper abstraction for the hardware.
///
/// For now only a simple loop nest is emitted.
///
using llvm::dbgs;
using llvm::SetVector;

using namespace mlir;

#define DEBUG_TYPE "lower-vector-transfers"

/// Creates and returns a memoized `constant 0 : index` at the top level of
/// each function `f` on which it is called.
static SSAValue *getZero(MLFunction *f) {
  static thread_local llvm::DenseMap<MLFunction *, SSAValue *> zeros;
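  // The cache is thread_local so concurrently processed functions do not
  // race; the constant is created lazily at the start of `f` so that it
  // dominates all uses within the function.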
  auto it = zeros.find(f);
  if (it == zeros.end()) {
    MLFuncBuilder b(f);
    b.setInsertionPointToStart(f);
    zeros.insert(
        std::make_pair(f, b.create<ConstantIndexOp>(b.getUnknownLoc(), 0)));
    it = zeros.find(f);
  }
  return it->second;
}

namespace {
struct LowerVectorTransfersPass : public FunctionPass {
  LowerVectorTransfersPass()
      : FunctionPass(&LowerVectorTransfersPass::passID) {}

  PassResult runOnMLFunction(MLFunction *f) override;

  // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
  MLFunctionMatcherContext mlContext;

  static char passID;
};
} // end anonymous namespace

char LowerVectorTransfersPass::passID = 0;

/// Creates the SSAValue for the sum of `v` and `w` without building a
/// full-fledged AffineMap for all indices.
///
/// Prerequisites:
/// `v` and `w` must be of IndexType.
static SSAValue *add(MLFuncBuilder *b, Location loc, SSAValue *v, SSAValue *w) {
  assert(v->getType().isa<IndexType>() && "v must be of IndexType");
  assert(w->getType().isa<IndexType>() && "w must be of IndexType");
  auto *context = b->getContext();
  auto d0 = getAffineDimExpr(0, context);
  auto d1 = getAffineDimExpr(1, context);
  auto map = AffineMap::get(2, 0, {d0 + d1}, {});
  return b->create<AffineApplyOp>(loc, map, ArrayRef<SSAValue *>{v, w})
      ->getResult(0);
}

/// Performs simple lowering into a combination of:
///   1. local memory allocation,
///   2. vector_load/vector_store from/to the local buffer,
///   3. a perfect loop nest over scalar loads/stores from/to remote memory.
///
/// This is a simple sketch for now but does the job.
// TODO(ntv): This function has a lot of code conditioned on the template
// argument being one of the two types. Extract the common behavior into helper
// functions and detemplatize it.
template <typename VectorTransferOpTy>
static void lowerAsLoops(VectorTransferOpTy *transfer) {
  static_assert(
      std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value ||
          std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value,
      "Must be called on either VectorTransferReadOp or VectorTransferWriteOp");

  auto vectorType = transfer->getVectorType();
  auto vectorShape = vectorType.getShape();
  // tmpMemRefType is used for staging the transfer in a local scalar buffer.
  auto tmpMemRefType =
      MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0);
  // vectorMemRefType is a view of tmpMemRefType as one vector.
  auto vectorMemRefType = MemRefType::get({1}, vectorType, {}, 0);

  MLFuncBuilder b(cast<OperationStmt>(transfer->getOperation()));
  auto *zero = getZero(b.getFunction());

  // 1. First allocate the local buffer in fast memory.
  // TODO(ntv): CL memory space.
  // TODO(ntv): Allocation padding for potential bank conflicts (e.g. GPUs).
  auto tmpScalarAlloc = b.create<AllocOp>(transfer->getLoc(), tmpMemRefType);

  // TODO(ntv): Proper OperationStmt.
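  // The cast is built as a generic operation named "vector_type_cast" via
  // OperationState; a dedicated cast op does not exist yet (see the TODO
  // above and TODO 4 in the commit message).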
  OperationState state(b.getContext(), transfer->getLoc(), "vector_type_cast",
                       ArrayRef<SSAValue *>{tmpScalarAlloc->getResult()},
                       ArrayRef<Type>{vectorMemRefType});
  auto vecView = b.createOperation(state);

  // 2. Store the vector to local storage in case of a vector_transfer_write.
  // TODO(ntv): This vector_store operation should be further lowered in the
  // case of GPUs.
  if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) {
    b.create<StoreOp>(vecView->getLoc(), transfer->getVector(),
                      vecView->getResult(0), ArrayRef<SSAValue *>{zero});
  }

  // 3. Emit the loop-nest.
  // TODO(ntv): Invert the mapping and indexing contiguously in the remote
  // memory.
  // TODO(ntv): Handle broadcast / slice properly.
  auto permutationMap = transfer->getPermutationMap();
  SetVector<ForStmt *> loops;
  SmallVector<SSAValue *, 8> accessIndices(transfer->getIndices());
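  // Start from the transfer's base indices; the loop below adds each loop
  // induction variable along the memref dimension selected by the
  // permutation map.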
  for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
    auto composed = composeWithUnboundedMap(
        getAffineDimExpr(it.index(), b.getContext()), permutationMap);
    auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
    loops.insert(forStmt);
    // Setting the insertion point to the innermost loop achieves nesting.
    b.setInsertionPointToStart(loops.back());
    if (composed == getAffineConstantExpr(0, b.getContext())) {
      transfer->emitWarning(
          "Redundant copy can be implemented as a vector broadcast");
    } else {
      auto dim = composed.template cast<AffineDimExpr>();
      assert(accessIndices.size() > dim.getPosition());
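      // A ForStmt is itself an SSA value of index type (its induction
      // variable), so it can be added directly to the base access index.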
      accessIndices[dim.getPosition()] =
          ::add(&b, transfer->getLoc(), accessIndices[dim.getPosition()],
                loops.back());
    }
  }

  // 4. Emit memory operations within the loops.
  // TODO(ntv): SelectOp + padding value for load out-of-bounds.
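  // The functional::map calls below turn the SetVector of ForStmts into the
  // list of induction-variable SSA values used to index the local buffer.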
  if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
    // VectorTransferReadOp.
    // a. read scalar from remote;
    // b. write scalar to local.
    auto scalarLoad = b.create<LoadOp>(transfer->getLoc(),
                                       transfer->getMemRef(), accessIndices);
    b.create<StoreOp>(
        transfer->getLoc(), scalarLoad->getResult(),
        tmpScalarAlloc->getResult(),
        functional::map([](SSAValue *val) { return val; }, loops));
  } else {
    // VectorTransferWriteOp.
    // a. read scalar from local;
    // b. write scalar to remote.
    auto scalarLoad = b.create<LoadOp>(
        transfer->getLoc(), tmpScalarAlloc->getResult(),
        functional::map([](SSAValue *val) { return val; }, loops));
    b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
                      transfer->getMemRef(), accessIndices);
  }

  // 5. Read the vector from local storage in case of a vector_transfer_read.
  // TODO(ntv): This vector_load operation should be further lowered in the
  // case of GPUs.
  if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
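    // Load the whole vector back from the casted view and redirect all users
    // of the original transfer result to it before the transfer is erased.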
    b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
    auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(0),
                                    ArrayRef<SSAValue *>{zero})
                       ->getResult();
    transfer->getVector()->replaceAllUsesWith(vector);
  }

  // 6. Free the local buffer.
  b.setInsertionPoint(cast<OperationStmt>(transfer->getOperation()));
  b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);

  // 7. It is now safe to erase the statement.
  transfer->erase();
}

PassResult LowerVectorTransfersPass::runOnMLFunction(MLFunction *f) {
  using matcher::Op;
  LLVM_DEBUG(dbgs() << "\nLowerVectorTransfersPass on MLFunction\n");
  LLVM_DEBUG(f->print(dbgs()));

  // Avoid any read/write ordering considerations: do it in 2 steps.
  // 1. vector_transfer_reads;
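  // matcher::Op walks the statements of `f` and returns those for which the
  // filter predicate is true.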
  auto filterReads = [](const Statement &stmt) {
    const auto &opStmt = cast<OperationStmt>(stmt);
    return opStmt.isa<VectorTransferReadOp>();
  };
  for (auto m : Op(filterReads).match(f)) {
    auto read = cast<OperationStmt>(m.first)->cast<VectorTransferReadOp>();
    // TODO(ntv): Drop &* once lowerAsLoops is detemplatized.
    lowerAsLoops(&*read);
  }

  // 2. vector_transfer_writes;
  auto filterWrites = [](const Statement &stmt) {
    const auto &opStmt = cast<OperationStmt>(stmt);
    return opStmt.isa<VectorTransferWriteOp>();
  };
  for (auto m : Op(filterWrites).match(f)) {
    auto write = cast<OperationStmt>(m.first)->cast<VectorTransferWriteOp>();
    // TODO(ntv): Drop &* once lowerAsLoops is detemplatized.
    lowerAsLoops(&*write);
  }

  return PassResult::Success;
}

FunctionPass *mlir::createLowerVectorTransfersPass() {
  return new LowerVectorTransfersPass();
}

static PassRegistration<LowerVectorTransfersPass>
    pass("lower-vector-transfers", "Materializes vector transfer ops to a "
                                   "proper abstraction for the hardware");

#undef DEBUG_TYPE