Extend loop tiling utility to handle non-constant loop bounds and bounds that
are a max/min of several expressions. - Extend loop tiling to handle non-constant loop bounds and bounds that are a max/min of several expressions, i.e., bounds using multi-result affine maps - also fix b/120630124 as a result (the IR was in an invalid state when tiled loop generation failed; SSA uses were created that weren't plugged into the IR). PiperOrigin-RevId: 224604460
This commit is contained in:
committed by
jpienaar
parent
dfc752e42b
commit
2d6478fa92
@@ -67,14 +67,14 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
|
||||
moveLoopBody(src, dest, dest->begin());
|
||||
}
|
||||
|
||||
/// Constructs/sets new loop bounds after tiling for the case of
|
||||
/// Constructs and sets new loop bounds after tiling for the case of
|
||||
/// hyper-rectangular index sets, where the bounds of one dimension do not
|
||||
/// depend on other dimensions. Bounds of each dimension can thus be treated
|
||||
/// independently, and deriving the new bounds is much simpler and faster
|
||||
/// than for the case of tiling arbitrary polyhedral shapes.
|
||||
static bool setTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
|
||||
ArrayRef<ForStmt *> newLoops,
|
||||
ArrayRef<unsigned> tileSizes) {
|
||||
static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
|
||||
ArrayRef<ForStmt *> newLoops,
|
||||
ArrayRef<unsigned> tileSizes) {
|
||||
assert(!origLoops.empty());
|
||||
assert(origLoops.size() == tileSizes.size());
|
||||
|
||||
@@ -95,41 +95,50 @@ static bool setTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
|
||||
}
|
||||
// Bounds for intra-tile loops.
|
||||
for (unsigned i = 0; i < width; i++) {
|
||||
// TODO(bondhugula): Keep it simple for now - constant upper bound.
|
||||
if (!origLoops[i]->hasConstantUpperBound())
|
||||
return false;
|
||||
|
||||
int64_t largestDiv = getLargestDivisorOfTripCount(*origLoops[i]);
|
||||
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
|
||||
AffineMap lbMap, ubMap;
|
||||
auto dim = b.getAffineDimExpr(0);
|
||||
lbMap = b.getAffineMap(1, 0, dim, {});
|
||||
newLoops[width + i]->setLowerBound(newLoops[i], lbMap);
|
||||
// The lower bound is just the tile-space loop.
|
||||
AffineMap lbMap = b.getDimIdentityMap();
|
||||
newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap);
|
||||
|
||||
// Set the upper bound.
|
||||
if (mayBeConstantCount.hasValue() &&
|
||||
mayBeConstantCount.getValue() < tileSizes[i]) {
|
||||
// Trip count is less than tile size; upper bound is the trip count.
|
||||
ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue());
|
||||
auto ubMap = b.getConstantAffineMap(mayBeConstantCount.getValue());
|
||||
newLoops[width + i]->setUpperBoundMap(ubMap);
|
||||
} else if (largestDiv % tileSizes[i] != 0) {
|
||||
// Intra-tile loop ii goes from i to min(i + tileSize, ub_i).
|
||||
auto ubMax =
|
||||
b.getAffineConstantExpr(origLoops[i]->getConstantUpperBound());
|
||||
ubMap = b.getAffineMap(1, 0, {dim + tileSizes[i], ubMax}, {});
|
||||
newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
|
||||
// Construct the upper bound map; the operands are the original operands
|
||||
// with 'i' (tile-space loop) appended to it. The new upper bound map is
|
||||
// the original one with an additional expression i + tileSize appended.
|
||||
SmallVector<MLValue *, 4> ubOperands(
|
||||
origLoops[i]->getUpperBoundOperands());
|
||||
ubOperands.push_back(newLoops[i]);
|
||||
|
||||
auto origUbMap = origLoops[i]->getUpperBoundMap();
|
||||
SmallVector<AffineExpr, 4> boundExprs;
|
||||
boundExprs.reserve(1 + origUbMap.getNumResults());
|
||||
auto dim = b.getAffineDimExpr(origUbMap.getNumInputs());
|
||||
// The new upper bound map is the original one with an additional
|
||||
// expression i + tileSize appended.
|
||||
boundExprs.push_back(dim + tileSizes[i]);
|
||||
boundExprs.insert(boundExprs.end(), origUbMap.getResults().begin(),
|
||||
origUbMap.getResults().end());
|
||||
auto ubMap =
|
||||
b.getAffineMap(origUbMap.getNumInputs() + 1, 0, boundExprs, {});
|
||||
newLoops[width + i]->setUpperBound(/*operands=*/ubOperands, ubMap);
|
||||
} else {
|
||||
// No need of the min expression.
|
||||
ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
|
||||
auto dim = b.getAffineDimExpr(0);
|
||||
auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
|
||||
newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Tiles the specified band of perfectly nested loops creating tile-space loops
|
||||
/// and intra-tile loops. A band is a contiguous set of loops.
|
||||
// TODO(bondhugula): handle non-constant bounds.
|
||||
// TODO(bondhugula): handle non hyper-rectangular spaces.
|
||||
UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
||||
ArrayRef<unsigned> tileSizes) {
|
||||
@@ -184,21 +193,19 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
|
||||
|
||||
FlatAffineConstraints cst(width, 0);
|
||||
addIndexSet(origLoopIVs, &cst);
|
||||
if (cst.isHyperRectangular(0, width)) {
|
||||
if (!setTiledIndexSetHyperRect(origLoops, newLoops, tileSizes)) {
|
||||
rootForStmt->emitError(
|
||||
"tiled code generation unimplemented for this case");
|
||||
return UtilResult::Failure;
|
||||
}
|
||||
// In this case, the point loop IVs just replace the original ones.
|
||||
for (unsigned i = 0; i < width; i++) {
|
||||
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
|
||||
}
|
||||
} else {
|
||||
rootForStmt->emitError("tiled code generation unimplemented for this case");
|
||||
|
||||
if (!cst.isHyperRectangular(0, width)) {
|
||||
rootForStmt->emitError("tiled code generation unimplemented for the"
|
||||
"non-hyperrectangular case");
|
||||
return UtilResult::Failure;
|
||||
}
|
||||
|
||||
constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
|
||||
// In this case, the point loop IVs just replace the original ones.
|
||||
for (unsigned i = 0; i < width; i++) {
|
||||
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
|
||||
}
|
||||
|
||||
// Erase the old loop nest.
|
||||
rootForStmt->erase();
|
||||
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
// RUN: mlir-opt %s -loop-tile -tile-size=32 | FileCheck %s
|
||||
|
||||
// CHECK: #map0 = (d0) -> (d0 + 32)
|
||||
// CHECK: #map1 = (d0) -> (d0 + 32, 50)
|
||||
// CHECK-DAG: #map0 = (d0) -> (d0 + 32)
|
||||
// CHECK-DAG: #map1 = (d0) -> (d0 + 32, 50)
|
||||
// CHECK-DAG: [[IDENTITY:#map[0-9]+]] = (d0) -> (d0)
|
||||
// CHECK-DAG: [[LB:#map[0-9]+]] = ()[s0] -> (0, s0)
|
||||
// CHECK-DAG: [[UB:#map[0-9]+]] = ()[s0, s1] -> (s0, 4096 floordiv s1)
|
||||
// CHECK-DAG: [[UB_INTRA_TILE:#map[0-9]+]] = (d0, d1, d2) -> (d2 + 32, s0, 4096 floordiv s1)
|
||||
|
||||
// CHECK-LABEL: mlfunc @loop_tiling()
|
||||
// CHECK-NEXT: for %i0 = 0 to 256 step 32 {
|
||||
// CHECK-NEXT: for %i1 = 0 to 512 step 32 {
|
||||
// CHECK-NEXT: for %i2 = 0 to 1024 step 32 {
|
||||
// CHECK-NEXT: for %i3 = (d0) -> (d0)(%i0) to #map0(%i0) {
|
||||
// CHECK-NEXT: for %i4 = (d0) -> (d0)(%i1) to #map0(%i1) {
|
||||
// CHECK-NEXT: for %i5 = (d0) -> (d0)(%i2) to #map0(%i2) {
|
||||
// CHECK-NEXT: for %i3 = [[IDENTITY]](%i0) to #map0(%i0) {
|
||||
// CHECK-NEXT: for %i4 = [[IDENTITY]](%i1) to #map0(%i1) {
|
||||
// CHECK-NEXT: for %i5 = [[IDENTITY]](%i2) to #map0(%i2) {
|
||||
// CHECK-NEXT: "foo"(%i3, %i4, %i5) : (index, index, index) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
@@ -17,12 +22,12 @@
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i6 = 0 to 50 step 32 {
|
||||
// CHECK-NEXT: for %i7 = (d0) -> (d0)(%i6) to min #map1(%i6) {
|
||||
// CHECK-NEXT: for %i7 = [[IDENTITY]](%i6) to min #map1(%i6) {
|
||||
// CHECK-NEXT: "bar"(%i7, %i7) : (index, index) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i8 = 0 to 21 step 32 {
|
||||
// CHECK-NEXT: for %i9 = (d0) -> (d0)(%i8) to 21 {
|
||||
// CHECK-NEXT: for %i9 = [[IDENTITY]](%i8) to 21 {
|
||||
// CHECK-NEXT: "foobar"(%i9) : (index) -> ()
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
@@ -47,3 +52,19 @@ mlfunc @loop_tiling() {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
#lb = ()[s0] -> (0, s0)
|
||||
#ub = ()[s0, s1] -> (s0, 4096 floordiv s1)
|
||||
// CHECK-LABEL: mlfunc @loop_max_min_bound(%arg0 : memref<?xi32>, %arg1 : index, %arg2 : index) {
|
||||
mlfunc @loop_max_min_bound(%A : memref<? x i32>, %L : index, %U : index) {
|
||||
%M = dim %A, 0 : memref<? x i32>
|
||||
for %iTT = max #lb()[%L] to min #ub()[%M, %U] {
|
||||
%out = affine_apply (d0) -> (d0) (%iTT)
|
||||
}
|
||||
return
|
||||
// CHECK: for %i0 = max [[LB]]()[%arg1] to min [[UB]]()[%0, %arg2] step 32 {
|
||||
// CHECK-NEXT: for %i1 = [[IDENTITY]](%i0) to min [[UB_INTRA_TILE]](%0, %arg2, %i0) {
|
||||
// CHECK-NEXT: %1 = affine_apply [[IDENTITY]](%i1)
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user