Extend loop tiling utility to handle non-constant loop bounds and bounds that

are a max/min of several expressions.

- Extend loop tiling to handle non-constant loop bounds and bounds that
  are a max/min of several expressions, i.e., bounds using multi-result affine
  maps

- also fix b/120630124 as a result (the IR was in an invalid state when tiled
  loop generation failed; SSA uses were created that weren't plugged into the IR).

PiperOrigin-RevId: 224604460
This commit is contained in:
Uday Bondhugula
2018-12-07 17:35:49 -08:00
committed by jpienaar
parent dfc752e42b
commit 2d6478fa92
2 changed files with 67 additions and 39 deletions

View File

@@ -67,14 +67,14 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
moveLoopBody(src, dest, dest->begin());
}
/// Constructs/sets new loop bounds after tiling for the case of
/// Constructs and sets new loop bounds after tiling for the case of
/// hyper-rectangular index sets, where the bounds of one dimension do not
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
static bool setTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
ArrayRef<ForStmt *> newLoops,
ArrayRef<unsigned> tileSizes) {
static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
ArrayRef<ForStmt *> newLoops,
ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
@@ -95,41 +95,50 @@ static bool setTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
}
// Bounds for intra-tile loops.
for (unsigned i = 0; i < width; i++) {
// TODO(bondhugula): Keep it simple for now - constant upper bound.
if (!origLoops[i]->hasConstantUpperBound())
return false;
int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]);
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
AffineMap lbMap, ubMap;
auto dim = b.getAffineDimExpr(0);
lbMap = b.getAffineMap(1, 0, dim, {});
newLoops[width + i]->setLowerBound(newLoops[i], lbMap);
// The lower bound is just the tile-space loop.
AffineMap lbMap = b.getDimIdentityMap();
newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap);
// Set the upper bound.
if (mayBeConstantCount.hasValue() &&
mayBeConstantCount.getValue() < tileSizes[i]) {
// Trip count is less than tile size; upper bound is the trip count.
ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue());
auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue());
newLoops[width + i]->setUpperBoundMap(ubMap);
} else if (largestDiv % tileSizes[i] != 0) {
// Intra-tile loop ii goes from i to min(i + tileSize, ub_i).
auto ubMax =
b.getAffineConstantExpr(origLoops[i]->getConstantUpperBound());
ubMap = b.getAffineMap(1, 0, {dim + tileSizes[i], ubMax}, {});
newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
// Construct the upper bound map; the operands are the original operands
// with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended.
SmallVector<MLValue *, 4> ubOperands(
origLoops[i]->getUpperBoundOperands());
ubOperands.push_back(newLoops[i]);
auto origUbMap = origLoops[i]->getUpperBoundMap();
SmallVector<AffineExpr, 4> boundExprs;
boundExprs.reserve(1 + origUbMap.getNumResults());
auto dim = b.getAffineDimExpr(origUbMap.getNumInputs());
// The new upper bound map is the original one with an additional
// expression i + tileSize appended.
boundExprs.push_back(dim + tileSizes[i]);
boundExprs.insert(boundExprs.end(), origUbMap.getResults().begin(),
origUbMap.getResults().end());
auto ubMap =
b.getAffineMap(origUbMap.getNumInputs() + 1, 0, boundExprs, {});
newLoops[width + i]->setUpperBound(/*operands=*/ubOperands, ubMap);
} else {
// No need of the min expression.
ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
auto dim = b.getAffineDimExpr(0);
auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
}
}
return true;
}
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non-constant bounds.
// TODO(bondhugula): handle non hyper-rectangular spaces.
UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
ArrayRef<unsigned> tileSizes) {
@@ -184,21 +193,19 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
FlatAffineConstraints cst(width, 0);
addIndexSet(origLoopIVs, &cst);
if (cst.isHyperRectangular(0, width)) {
if (!setTiledIndexSetHyperRect(origLoops, newLoops, tileSizes)) {
rootForStmt->emitError(
"tiled code generation unimplemented for this case");
return UtilResult::Failure;
}
// In this case, the point loop IVs just replace the original ones.
for (unsigned i = 0; i < width; i++) {
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
}
} else {
rootForStmt->emitError("tiled code generation unimplemented for this case");
if (!cst.isHyperRectangular(0, width)) {
rootForStmt->emitError("tiled code generation unimplemented for the"
"non-hyperrectangular case");
return UtilResult::Failure;
}
constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
// In this case, the point loop IVs just replace the original ones.
for (unsigned i = 0; i < width; i++) {
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
}
// Erase the old loop nest.
rootForStmt->erase();

View File

@@ -1,14 +1,19 @@
// RUN: mlir-opt %s -loop-tile -tile-size=32 | FileCheck %s
// CHECK: #map0 = (d0) -> (d0 + 32)
// CHECK: #map1 = (d0) -> (d0 + 32, 50)
// CHECK-DAG: #map0 = (d0) -> (d0 + 32)
// CHECK-DAG: #map1 = (d0) -> (d0 + 32, 50)
// CHECK-DAG: [[IDENTITY:#map[0-9]+]] = (d0) -> (d0)
// CHECK-DAG: [[LB:#map[0-9]+]] = ()[s0] -> (0, s0)
// CHECK-DAG: [[UB:#map[0-9]+]] = ()[s0, s1] -> (s0, 4096 floordiv s1)
// CHECK-DAG: [[UB_INTRA_TILE:#map[0-9]+]] = (d0, d1, d2) -> (d2 + 32, s0, 4096 floordiv s1)
// CHECK-LABEL: mlfunc @loop_tiling()
// CHECK-NEXT: for %i0 = 0 to 256 step 32 {
// CHECK-NEXT: for %i1 = 0 to 512 step 32 {
// CHECK-NEXT: for %i2 = 0 to 1024 step 32 {
// CHECK-NEXT: for %i3 = (d0) -> (d0)(%i0) to #map0(%i0) {
// CHECK-NEXT: for %i4 = (d0) -> (d0)(%i1) to #map0(%i1) {
// CHECK-NEXT: for %i5 = (d0) -> (d0)(%i2) to #map0(%i2) {
// CHECK-NEXT: for %i3 = [[IDENTITY]](%i0) to #map0(%i0) {
// CHECK-NEXT: for %i4 = [[IDENTITY]](%i1) to #map0(%i1) {
// CHECK-NEXT: for %i5 = [[IDENTITY]](%i2) to #map0(%i2) {
// CHECK-NEXT: "foo"(%i3, %i4, %i5) : (index, index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -17,12 +22,12 @@
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: for %i6 = 0 to 50 step 32 {
// CHECK-NEXT: for %i7 = (d0) -> (d0)(%i6) to min #map1(%i6) {
// CHECK-NEXT: for %i7 = [[IDENTITY]](%i6) to min #map1(%i6) {
// CHECK-NEXT: "bar"(%i7, %i7) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: for %i8 = 0 to 21 step 32 {
// CHECK-NEXT: for %i9 = (d0) -> (d0)(%i8) to 21 {
// CHECK-NEXT: for %i9 = [[IDENTITY]](%i8) to 21 {
// CHECK-NEXT: "foobar"(%i9) : (index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
@@ -47,3 +52,19 @@ mlfunc @loop_tiling() {
return
}
#lb = ()[s0] -> (0, s0)
#ub = ()[s0, s1] -> (s0, 4096 floordiv s1)
// CHECK-LABEL: mlfunc @loop_max_min_bound(%arg0 : memref<?xi32>, %arg1 : index, %arg2 : index) {
mlfunc @loop_max_min_bound(%A : memref<? x i32>, %L : index, %U : index) {
%M = dim %A, 0 : memref<? x i32>
for %iTT = max #lb()[%L] to min #ub()[%M, %U] {
%out = affine_apply (d0) -> (d0) (%iTT)
}
return
// CHECK: for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 {
// CHECK-NEXT: for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) {
// CHECK-NEXT: %1 = affine_apply [[IDENTITY]](%i1)
// CHECK-NEXT: }
// CHECK-NEXT: }
}