[mlir][PartialReductionTilingInterface] Add support for ReductionTilingStrategy::PartialReductionOuterParallel in tileUsingSCF. (#143988)

Following up from https://github.com/llvm/llvm-project/pull/143467,
this PR adds support for
`ReductionTilingStrategy::PartialReductionOuterParallel` to
`tileUsingSCF`. The implementation of
`PartialReductionTilingInterface` for `Linalg` ops has been updated to
support this strategy as well. This brings `tileUsingSCF` on par with
`linalg::tileReductionUsingForall`, which will subsequently be
deprecated.
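
For reference, the new path can be driven directly through
`scf::SCFTilingOptions`. A minimal sketch, mirroring how the updated
`transform::TileReductionUsingForallOp::applyToOne` in this patch sets
things up (`rewriter`, `tilingInterfaceOp`, `numThreads`, `tileSizes`, and
`reductionDims` are placeholders assumed to be supplied by the caller):

```c++
// Sketch only: mirrors the option setup used by the updated
// TileReductionUsingForallOp::applyToOne in this patch.
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setReductionTilingStrategy(
    ReductionTilingStrategy::PartialReductionOuterParallel);
// Either a thread count or a tile size can be used to split the reduction.
options.setNumThreads(numThreads); // or: options.setTileSizes(tileSizes);
options.setReductionDims(reductionDims);

FailureOr<scf::SCFTilingResult> result =
    scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
if (failed(result))
  return failure();
// Replace the original op with the merged (full) reduction results.
rewriter.replaceOp(tilingInterfaceOp, result->replacements);
```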

Summary of changes
- `PartialReductionTilingInterface` changes:
  - The `tileToPartialReduction` method now takes the induction variables
    of the generated tile loops. This is needed to keep the generated code
    similar to `linalg::tileReductionUsingForall`, specifically to create a
    simplified access when slicing the intermediate partial-results tensor
    in `num_threads` mode.
  - The `getPartialResultTilePosition` method needs the induction variables
    of the generated tile loops for the same reason, and also needs the
    `tilingStrategy` to be passed in to generate correct code (see the
    sketch after this summary).
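
To make the induction-variable plumbing concrete, here is a rough sketch of
how the `splitReductionIvs` values handed to these methods are derived
(simplified from the `getSplitReductionIvs` helper added in this patch;
only the `tile_sizes` case is shown, and `rewriter`, `loc`, `ivs`,
`tileSizes`, and `reductionDims` are assumed to come from the surrounding
loop-generation code):

```c++
// In num_threads mode the scf.forall induction variable is used directly;
// in tile_sizes mode the "iteration number" is iv ceildiv tileSize.
SmallVector<OpFoldResult> splitReductionIvs(reductionDims.size(),
                                            rewriter.getIndexAttr(0));
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
int ivIndex = 0;
for (auto [index, reductionDim] : llvm::enumerate(reductionDims))
  splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
      rewriter, loc, s0.ceilDiv(s1),
      ArrayRef<OpFoldResult>{ivs[ivIndex++], tileSizes[reductionDim]});
```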

The tests in `transform-tile-reduction.mlir` that exercised
`linalg::tileReductionUsingForall` have been moved over to test
`scf::tileUsingSCF` with the
`ReductionTilingStrategy::PartialReductionOuterParallel`
strategy. Some of the tests that additionally performed a cyclic
distribution of the code produced by tiling have been removed. Those were
effectively two separate transformations merged into one; ideally the
distribution should happen when resolving the `scf.forall`, rather than
during tiling.

Please review only the top commit. Depends on
https://github.com/llvm/llvm-project/pull/143467

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar
2025-06-23 12:27:26 -07:00
committed by GitHub
parent 6c232f440f
commit 7bc956d3d6
9 changed files with 535 additions and 265 deletions

View File

@@ -2019,6 +2019,7 @@ def TileReductionUsingForallOp :
   // TODO: support mixed static-dynamic (see TileUsingForallOp).
   let arguments = (ins TransformHandleTypeInterface:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
                    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
@@ -2036,10 +2037,11 @@ def TileReductionUsingForallOp :
   let assemblyFormat = [{
     $target
+    (`reduction_dims` `=` $reduction_dims^)?
     `by`
     (`num_threads` `=` $num_threads^)?
-    (`,` `tile_sizes` `=` $tile_sizes^)?
-    (`,` `mapping` `=` $mapping^)?
+    (`tile_sizes` `=` $tile_sizes^)?
+    (`mapping` `=` $mapping^)?
     attr-dict
     `:` functional-type(operands, results)
   }];

View File

@@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// corresponding pair of arrays. This is the inverse function of
 /// `getMixedValues`.
 std::pair<SmallVector<int64_t>, SmallVector<Value>>
-decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
+decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
 
 /// Helper to sort `values` according to matching `keys`.
 SmallVector<Value>

View File

@@ -367,15 +367,20 @@ def PartialReductionOpInterface :
     OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
   let description = [{
     Interface for allowing operations to expose information needed to
-    tile reductions using partial reduction followed by merge. This is
-    complementary to TilingInterface to tile reductions.
+    tile reductions using partial reduction followed by merge. This
+    extends the `TilingInterface` to allow splitting a reduction
+    dimension into a parallel dimension and reduction dimension.
+    The materialized inter-tile loop could either be the reduction dimension
+    (i.e. `ReductionTilingStrategy::PartialReductionOuterReduction`) or
+    the parallel dimension (i.e
+    `ReductionTilingStrategy::PartialReductionOuterReduction`).
   }];
   let cppNamespace = "::mlir";

   let methods = [
     InterfaceMethod<
       /*desc=*/[{
         Method to generate a tensor initalized with the identity value of the
-        operation reduction. The tensor shape is equal to operation result
+        reduction operator. The tensor shape is equal to operation result
         shape with new dimension for each non zero tile size.
       }],
       /*retType=*/"::mlir::FailureOr<SmallVector<Value>>",
@@ -383,7 +388,7 @@ def PartialReductionOpInterface :
       /*args=*/(ins
           "::mlir::OpBuilder &":$b,
           "Location":$loc,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
+          "::mlir::ArrayRef<::mlir::OpFoldResult>":$tileSizes,
          "const ::mlir::SetVector<unsigned> &":$reductionDims),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
@@ -396,6 +401,11 @@ def PartialReductionOpInterface :
         reduction dimension are converted to parallel dimensions with a size
         less or equal to the tile size. This is meant to be used with
         `mergeReductions` method which will combine the partial reductions.
+        The method recieves the `offset` and `sizes` for all iteration space
+        dimensions, as well as the iteration number of the tiled reduction
+        dimensions (which is the induction variable of the inter-tile loop
+        for the reduction dimension divided by the step of the loop) in
+        `splitReductionIvs`.
       }],
       /*retType=*/"::mlir::FailureOr<TilingResult>",
       /*methodName=*/"tileToPartialReduction",
@@ -406,7 +416,8 @@ def PartialReductionOpInterface :
           "ValueRange":$init,
           "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
           "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
-          "const ::llvm::SetVector<unsigned> &":$reductionDims),
+          "const ::llvm::SetVector<unsigned> &":$reductionDims,
+          "::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         return failure();
@@ -436,15 +447,22 @@ def PartialReductionOpInterface :
        the tiled operation. This is same as
        TilingInterface:::getResultTilePosition, but determines the result
        tile position for partial reduction.
+        The method recieves the `offset` and `sizes` for all iteration space
+        dimensions, as well as the iteration number of the tiled reduction
+        dimensions (which is the induction variable of the inter-tile loop
+        for the reduction dimension divided by the tile size specified) in
+        `splitReductionIvs`.
       }],
       /*retType=*/"::llvm::LogicalResult",
       /*methodName=*/"getPartialResultTilePosition",
       /*args=*/(ins
          "::mlir::OpBuilder &":$b,
          "unsigned":$resultNumber,
+         "ReductionTilingStrategy":$tilingStrategy,
          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
          "const ::mlir::SetVector<unsigned> &":$reductionDims,
+         "::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs,
          "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
          "::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
       /*methodBody=*/"",

View File

@@ -3022,6 +3022,7 @@ void transform::TileReductionUsingForallOp::build(
   build(builder, result,
         /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
         /*target=*/target,
+        /*reduction_dims=*/{},
         /*num_threads=*/staticNumThreadsAttr,
         /*tile_sizes=*/staticTileSizesAttr,
         /*mapping=*/mapping);
@@ -3036,23 +3037,45 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
   SmallVector<OpFoldResult> tileSizes =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
-  FailureOr<linalg::ForallReductionTilingResult> result =
-      linalg::tileReductionUsingForall(
-          rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-          numThreads, tileSizes, getMapping());
+
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
+  options.setReductionTilingStrategy(
+      ReductionTilingStrategy::PartialReductionOuterParallel);
+  if (!getNumThreads().empty()) {
+    options.setNumThreads(numThreads);
+  } else {
+    options.setTileSizes(tileSizes);
+  }
+  if (auto mapping = getMapping()) {
+    options.setMapping(mapping.value().getValue());
+  }
+  SmallVector<unsigned> reductionDims =
+      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
+  if (reductionDims.empty()) {
+    for (auto [idx, iteratorType] :
+         llvm::enumerate(target.getIteratorTypesArray())) {
+      if (iteratorType == utils::IteratorType::reduction)
+        reductionDims.push_back(idx);
+    }
+  }
+  options.setReductionDims(reductionDims);
+  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
+      rewriter, cast<TilingInterface>(target.getOperation()), options);

   if (failed(result)) {
     auto diag = emitSilenceableError() << "could not tile reduction";
-    diag.attachNote(target.getLoc()) << "target operation";
     return diag;
   }
+  rewriter.replaceOp(target, result->replacements);
+
   for (Value initValue : result->initialValues)
     results.push_back(initValue.getDefiningOp());
-  for (auto parallelTiledOp : result->parallelTiledOps)
+  for (auto parallelTiledOp : result->tiledOps)
     results.push_back(parallelTiledOp);
   for (auto mergeOp : result->mergeOps)
     results.push_back(mergeOp);
-  results.push_back(result->loops);
+  results.push_back(result->loops.front());
   return DiagnosedSilenceableFailure::success();
 }

View File

@@ -328,6 +328,17 @@ struct LinalgOpTilingInterface
 // External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//

+/// In a given set vector, get the position of a particular element.
+std::optional<int> getPositionIn(const llvm::SetVector<unsigned> &reductionDims,
+                                 unsigned value) {
+  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
+    if (reductionDim == value) {
+      return index;
+    }
+  }
+  return std::nullopt;
+}
+
 /// Return an AffineMaps to use for the `outs` operands of the linalg op
 /// generated for partial results. The new AffineMap is the AffineMap of the
 /// untiled op with reduction dimensions appended at end in order in which they
@@ -348,28 +359,86 @@ getPartialResultAffineMaps(LinalgOp linalgOp,
   return partialReductionMaps;
 }

-/// Return the slice of the `initValue` to use as input to the partial reduction
-/// op generated.
-static Operation *getInitSliceForOuterReduction(
-    OpBuilder &b, Location loc, Value initValue, ArrayRef<OpFoldResult> offsets,
+struct InitSliceInfo {
+  SmallVector<int64_t> resultShape;
+  SmallVector<OpFoldResult> offsets;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+};
+
+/// Return the result shape, offsets, sizes and strides of the slice of the
+/// `initValue` to use as the destination of the partial reduction op generated
+/// with outer reduction strategy.
+static InitSliceInfo getInitSliceInfoForOuterReduction(
+    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
-    AffineMap partialReductionMap) {
+    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
   int64_t initRank = partialReductionMap.getNumResults();
   SmallVector<OpFoldResult> initOffsets, initSizes;
-  SmallVector<OpFoldResult> initStrides(initRank, b.getIndexAttr(1));
+  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
+  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  SmallVector<OpFoldResult> initStrides(initRank, one);
   for (AffineExpr dimExpr : partialReductionMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
     if (reductionDims.contains(dim)) {
-      initOffsets.push_back(b.getIndexAttr(0));
+      initOffsets.push_back(zero);
     } else {
       initOffsets.push_back(offsets[dim]);
     }
     initSizes.push_back(sizes[dim]);
   }
-  // TODO: Use SubsetExtractOpInterface here once available.
-  auto extractSlice = b.create<tensor::ExtractSliceOp>(
-      loc, initValue, initOffsets, initSizes, initStrides);
-  return extractSlice;
+  SmallVector<int64_t> resultShape;
+  std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
+  return {resultShape, initOffsets, initSizes, initStrides};
+}
+
+/// Return the result shape, offsets, sizes and strides of the slice of the
+/// `initValue` to use as destination of the partial reduction op generated with
+/// outer parallel strategy.
+static InitSliceInfo getInitSliceInfoForOuterParallel(
+    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
+  int64_t initRank = partialReductionMap.getNumResults();
+  SmallVector<OpFoldResult> initOffsets, initSizes;
+  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
+  SmallVector<OpFoldResult> initStrides(initRank, one);
+  SmallVector<OpFoldResult> resultShape;
+  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
+    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
+    if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
+      initOffsets.push_back(splitReductionIvs[dimPos.value()]);
+      initSizes.push_back(one);
+    } else {
+      initOffsets.push_back(offsets[dim]);
+      initSizes.push_back(sizes[dim]);
+      resultShape.push_back(sizes[dim]);
+    }
+  }
+  SmallVector<int64_t> staticShapes;
+  std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
+  return {staticShapes, initOffsets, initSizes, initStrides};
+}
+
+/// Return the result shape, offsets, sizes and strides of the slice of the
+/// `initValue` to use as destination of the partial reduction op.
+static InitSliceInfo getInitSliceInfo(MLIRContext *context,
+                                      ReductionTilingStrategy strategy,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes,
+                                      const SetVector<unsigned> &reductionDims,
+                                      ArrayRef<OpFoldResult> splitReductionIvs,
+                                      AffineMap partialReductionMap) {
+  if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
+    return getInitSliceInfoForOuterReduction(context, offsets, sizes,
+                                             reductionDims, splitReductionIvs,
+                                             partialReductionMap);
+  }
+  assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
+         "unexpected ReductionTilingStrategy");
+  return getInitSliceInfoForOuterParallel(context, offsets, sizes,
+                                          reductionDims, splitReductionIvs,
+                                          partialReductionMap);
 }

 /// External model implementation of PartialReductionInterface for
@@ -390,21 +459,6 @@ struct LinalgOpPartialReductionInterface
     SmallVector<AffineMap> partialResultMaps =
         getPartialResultAffineMaps(linalgOp, reductionDims);

-    // LinalgOp implements TilingInterface.
-    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
-    SmallVector<OpFoldResult> shape =
-        llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
-                            [](Range x) { return x.size; });
-
-    SmallVector<OpFoldResult> tiledShape;
-    for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
-      if (isZeroInteger(tileSize)) {
-        tiledShape.push_back(dimSize);
-      } else {
-        tiledShape.push_back(tileSize);
-      }
-    }
-
     SmallVector<Value> inits;
     for (auto [initIdx, result, partialMap] :
          llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
@@ -424,7 +478,7 @@ struct LinalgOpPartialReductionInterface
       SmallVector<OpFoldResult> partialResultShape;
       for (AffineExpr dimExpr : partialMap.getResults()) {
         auto dim = cast<AffineDimExpr>(dimExpr);
-        partialResultShape.push_back(tiledShape[dim.getPosition()]);
+        partialResultShape.push_back(sizes[dim.getPosition()]);
       }

       Type elType = getElementTypeOrSelf(result.getType());
@@ -444,13 +498,8 @@ struct LinalgOpPartialReductionInterface
                          ReductionTilingStrategy tilingStrategy,
                          ValueRange init, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
-                         const SetVector<unsigned> &reductionDims) const {
-    if (tilingStrategy !=
-        ReductionTilingStrategy::PartialReductionOuterReduction) {
-      // TODO: Add support for `PartialReductionOuterParallel` strategy.
-      return op->emitOpError("unsupported partial reduction tiling with "
-                             "`PartialReductionOuterParallel` strategy");
-    }
+                         const SetVector<unsigned> &reductionDims,
+                         ArrayRef<OpFoldResult> splitReductionIvs) const {
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);

@@ -459,7 +508,16 @@ struct LinalgOpPartialReductionInterface

     // Step 1. Extend init maps to have reduction dimension dims, since we
     // are converting them to parallel dimensions.
-    SmallVector<AffineMap> newInitMaps = partialReductionMaps;
+    SmallVector<AffineMap> newInitMaps;
+    if (tilingStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      newInitMaps = llvm::to_vector(partialReductionMaps);
+    } else {
+      newInitMaps = llvm::map_to_vector(
+          linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
+            return linalgOp.getMatchingIndexingMap(&opOperand);
+          });
+    }

     // Step 2a: Extract a slice of the input operands.
     SmallVector<Value> tiledInputs = makeTiledShapes(
@@ -473,10 +531,17 @@ struct LinalgOpPartialReductionInterface
     SmallVector<Value, 1> tiledInits;
     for (auto [partialReductionMap, valueToTile] :
          llvm::zip_equal(partialReductionMaps, init)) {
-      Operation *sliceOp =
-          getInitSliceForOuterReduction(b, loc, valueToTile, offsets, sizes,
-                                        reductionDims, partialReductionMap);
-      tiledInits.push_back(sliceOp->getResult(0));
+      InitSliceInfo sliceInfo = getInitSliceInfo(
+          b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
+          splitReductionIvs, partialReductionMap);
+      auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
+      RankedTensorType sliceResultType = RankedTensorType::get(
+          sliceInfo.resultShape, valueToTileType.getElementType(),
+          valueToTileType.getEncoding());
+      auto sliceOp = b.create<tensor::ExtractSliceOp>(
+          loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes,
+          sliceInfo.strides);
+      tiledInits.push_back(sliceOp.getResult());
       generatedSlices.push_back(sliceOp);
     }

@@ -491,19 +556,31 @@ struct LinalgOpPartialReductionInterface
     // Step 3. Change the reduction dim iterator types.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
-    for (int dim : reductionDims)
-      newIteratorTypes[dim] = utils::IteratorType::parallel;
+    if (tilingStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      for (int dim : reductionDims)
+        newIteratorTypes[dim] = utils::IteratorType::parallel;
+    }

     // Step 4. Create the new generic op.
+    Operation *partialReductionOp;
     auto resultTypes = ValueRange(tiledInits).getTypes();
-    auto genericOp = b.create<GenericOp>(loc, resultTypes, tiledInputs,
-                                         tiledInits, newMaps, newIteratorTypes);
-    IRMapping mapping;
-    op->getRegion(0).cloneInto(&genericOp.getRegion(),
-                               genericOp.getRegion().begin(), mapping);
+    if (tilingStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      auto genericOp = b.create<GenericOp>(
+          loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes);
+      IRMapping mapping;
+      op->getRegion(0).cloneInto(&genericOp.getRegion(),
+                                 genericOp.getRegion().begin(), mapping);
+      partialReductionOp = genericOp.getOperation();
+    } else {
+      SmallVector<Value> operands = std::move(tiledInputs);
+      llvm::append_range(operands, tiledInits);
+      partialReductionOp = mlir::clone(b, op, resultTypes, operands);
+    }
     return TilingResult{
-        {genericOp.getOperation()},
-        llvm::map_to_vector(genericOp->getResults(),
+        {partialReductionOp},
+        llvm::map_to_vector(partialReductionOp->getResults(),
                             [](OpResult r) -> Value { return r; }),
         generatedSlices};
   }
@@ -558,26 +635,19 @@ struct LinalgOpPartialReductionInterface

   LogicalResult getPartialResultTilePosition(
       Operation *op, OpBuilder &b, unsigned resultNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
-      const SetVector<unsigned> &reductionDims,
+      ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
+      ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
+      ArrayRef<OpFoldResult> splitReductionIvs,
      SmallVector<OpFoldResult> &resultOffsets,
      SmallVector<OpFoldResult> &resultSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
     SmallVector<AffineMap> partialReductionMaps =
         getPartialResultAffineMaps(linalgOp, reductionDims);
-
-    for (AffineExpr dimExpr : partialReductionMaps[resultNumber].getResults()) {
-      unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
-      resultSizes.push_back(sizes[dim]);
-
-      if (llvm::is_contained(reductionDims, dim)) {
-        // Reduction dims are reduced, and are always outputed in the same
-        // place. So use offset 0 for them.
-        resultOffsets.push_back(b.getIndexAttr(0));
-      } else {
-        resultOffsets.push_back(offsets[dim]);
-      }
-    }
-
+    InitSliceInfo sliceInfo = getInitSliceInfo(
+        b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
+        splitReductionIvs, partialReductionMaps[resultNumber]);
+    std::swap(resultOffsets, sliceInfo.offsets);
+    std::swap(resultSizes, sliceInfo.sizes);
     return success();
   }

View File

@@ -166,12 +166,11 @@ static LogicalResult checkTileSizes(TilingInterface op,
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");

-  bool isParallelTiling = false, isReductionTiling = false;
+  bool isParallelTiling = false;
   for (auto [index, iterator, tileSize] :
        llvm::enumerate(iterators, tileSizes)) {
     if (!isConstantIntValue(tileSize, 0)) {
       isParallelTiling |= iterator == utils::IteratorType::parallel;
-      isReductionTiling |= iterator == utils::IteratorType::reduction;
     }

     if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
@@ -199,15 +198,29 @@ static LogicalResult checkTileSizes(TilingInterface op,
     }
   }

-  if (isParallelTiling && isReductionTiling &&
-      reductionStrategy != ReductionTilingStrategy::FullReduction) {
-    return op->emitOpError(
-        "combined parallel and reduction tiling is not supported with partial "
-        "reduction tiling strategies");
+  if (reductionStrategy != ReductionTilingStrategy::FullReduction) {
+    if (isParallelTiling) {
+      return op->emitOpError("tiling parallel dimensions is not supported with "
+                             "partial reduction tiling strategies");
+    }
   }
   return success();
 }

+/// Get the reduction dims that are tiled. This accounts for reduction dims
+/// that are specified as tiled, but the tile size is 0.
+static SetVector<unsigned>
+getSanitizedReductionDims(ArrayRef<OpFoldResult> tileSizes,
+                          const scf::SCFTilingOptions &options) {
+  SetVector<unsigned> reductionDims;
+  for (auto dim : options.reductionDims) {
+    if (isConstantIntValue(tileSizes[dim], 0))
+      continue;
+    reductionDims.insert(dim);
+  }
+  return reductionDims;
+}
+
 /// Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
@@ -264,10 +277,12 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
 /// `offset`s and `size`s of the tile of the iteration space that the
 /// innermost loop body of the generated tiled loops corresponds to.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc,
+                      ReductionTilingStrategy strategy, ValueRange ivs,
                       ArrayRef<Range> iterationDomain,
                       ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<OpFoldResult> numThreads) {
+                      ArrayRef<OpFoldResult> numThreads,
+                      const llvm::SetVector<unsigned> &reductionDims) {
   SmallVector<OpFoldResult> offsets, sizes;
   int materializedLoopNum = 0;
@@ -279,8 +294,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
     offsetExpr = d0 + d1 * s0;
     residualTileSizeExpr = s1 - (d0 + d1 * s0);

-    for (auto [nt, tileSize, loopRange] :
-         llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
+    for (auto [index, nt, tileSize, loopRange] :
+         llvm::enumerate(numThreads, tileSizes, iterationDomain)) {

       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
@@ -564,9 +579,10 @@ static LogicalResult generateLoopNestUsingForallOp(
 /// - `loops` is an in-out parameter into which the generated loops are
 ///   populated.
 static LogicalResult generateLoopNest(
-    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
-    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
-    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
+    RewriterBase &rewriter, Location loc,
+    scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
+    ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
@@ -576,25 +592,26 @@ static LogicalResult generateLoopNest(
     return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
                        tiledResults, resultOffsets, resultSizes);
   }
-  if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
+  if (loopType == scf::SCFTilingOptions::LoopType::ForOp) {
     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
                                       destinationTensors, tiledBodyFn, loops);
   }
-  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+  if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
-        rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
+        rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
         destinationTensors, tiledBodyFn, loops);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }

-static FailureOr<SmallVector<Value>>
-createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
-                              ArrayRef<OpFoldResult> tileSizes,
-                              const scf::SCFTilingOptions &options) {
+static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
+    RewriterBase &rewriter, TilingInterface op,
+    ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
+    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
+    const SetVector<unsigned> &reductionDims) {
   SmallVector<Value> initTensors;
   Location loc = op->getLoc();
-  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
+  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
     if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
       return failure();
     return initTensors;
@@ -602,20 +619,94 @@ createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
   auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
   if (!redOp) {
-    return rewriter.notifyMatchFailure(
-        op, "PartialReductionOuterReduction tiling strategy is only supported"
-            "for operations implementing PartialReductionOpInterface");
+    return op->emitOpError(
+        "PartialReductionOuterReduction tiling strategy is only supported for "
+        "operations implementing PartialReductionOpInterface");
   }
-  return redOp.generateInitialTensorForPartialReduction(
-      rewriter, loc, tileSizes, options.reductionDims);
+  SmallVector<OpFoldResult> sizes(iterationDomain.size());
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
+  AffineExpr divExpr = s0.ceilDiv(s1);
+  for (auto [index, domain, tileSize] :
+       llvm::enumerate(iterationDomain, tileSizes)) {
+    if (!numThreads.empty()) {
+      // Untiled case.
+      if (isConstantIntValue(numThreads[index], 0)) {
+        sizes[index] = affine::makeComposedFoldedAffineApply(
+            rewriter, op.getLoc(), sizeExpr,
+            {domain.size, domain.offset, domain.stride});
+        continue;
+      }
+      sizes[index] = numThreads[index];
+      continue;
+    }
+
+    // Non reduction dimensions/non-tiled dimensions.
+    if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) {
+      sizes[index] = affine::makeComposedFoldedAffineApply(
+          rewriter, op.getLoc(), sizeExpr,
+          {domain.size, domain.offset, domain.stride});
+      continue;
+    }
+
+    if (reductionStrategy ==
+        ReductionTilingStrategy::PartialReductionOuterReduction) {
+      sizes[index] = tileSize;
+      continue;
+    }
+
+    assert(reductionStrategy ==
+           ReductionTilingStrategy::PartialReductionOuterParallel);
+    OpFoldResult normalizedRange = affine::makeComposedFoldedAffineApply(
+        rewriter, op.getLoc(), sizeExpr,
+        {domain.size, domain.offset, domain.stride});
+    sizes[index] = affine::makeComposedFoldedAffineApply(
+        rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
+  }
+  return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
+                                                        reductionDims);
+}
+
+/// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel`
+/// the `PartialReductionOpInterface` methods need the index of the parallel
+/// split reduction being executed.
+static SmallVector<OpFoldResult>
+getSplitReductionIvs(RewriterBase &rewriter, Location loc,
+                     ReductionTilingStrategy reductionStrategy, ValueRange ivs,
+                     ArrayRef<OpFoldResult> numThreads,
+                     ArrayRef<OpFoldResult> tileSizes,
+                     const SetVector<unsigned> &reductionDims) {
+  SmallVector<OpFoldResult> splitReductionIvs;
+  splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
+  AffineExpr s0, s1;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  AffineExpr divExpr = s0.ceilDiv(s1);
+  int ivIndex = 0;
+  if (reductionStrategy ==
+      ReductionTilingStrategy::PartialReductionOuterParallel) {
+    for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
+      if (!numThreads.empty()) {
+        splitReductionIvs[index] = ivs[ivIndex++];
+        continue;
+      }
+      splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, divExpr,
+          ArrayRef<OpFoldResult>{ivs[ivIndex++], tileSizes[reductionDim]});
+    }
+  }
+  return splitReductionIvs;
 }

 static FailureOr<TilingResult>
 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
+                       ReductionTilingStrategy reductionStrategy,
                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
-                       ArrayRef<OpFoldResult> sizes,
-                       const scf::SCFTilingOptions &options) {
-  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
+                       ArrayRef<OpFoldResult> sizes, ValueRange ivs,
+                       ArrayRef<OpFoldResult> numThreads,
+                       ArrayRef<OpFoldResult> tileSizes,
+                       const SetVector<unsigned> &reductionDims) {
+  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getTiledImplementation(rewriter, offsets, sizes);
   }

@@ -626,20 +717,25 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                 "supported for operations "
                 "implementing PartialReductionOpInterface");
   }
-  return redOp.tileToPartialReduction(rewriter, op.getLoc(),
-                                      options.reductionStrategy, regionIterArg,
-                                      offsets, sizes, options.reductionDims);
+
+  SmallVector<OpFoldResult> splitReductionIvs =
+      getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
+                           numThreads, tileSizes, reductionDims);
+  return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
+                                      regionIterArg, offsets, sizes,
+                                      reductionDims, splitReductionIvs);
 }

-static LogicalResult
-getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
-                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
-                      ArrayRef<OpFoldResult> sizes,
-                      SmallVector<OpFoldResult> &resultOffset,
-                      SmallVector<OpFoldResult> &resultSize,
-                      const scf::SCFTilingOptions &options) {
-  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
+static LogicalResult getResultTilePosition(
+    RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy,
+    int64_t index, Value tiledResult, TilingInterface op,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
+    ArrayRef<OpFoldResult> tileSizes, const SetVector<unsigned> &reductionDims,
+    SmallVector<OpFoldResult> &resultOffset,
+    SmallVector<OpFoldResult> &resultSize) {
+  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getResultTilePosition(rewriter, index, offsets, sizes,
                                     resultOffset, resultSize);
   }
@@ -649,16 +745,20 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
         op, "PartialReductionOuterReduction tiling strategy is only supported"
             "for operations implementing PartialReductionOpInterface");
   }
-  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
-                                            options.reductionDims, resultOffset,
-                                            resultSize);
+  SmallVector<OpFoldResult> splitReductionIvs =
+      getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
+                           numThreads, tileSizes, reductionDims);
+  return redOp.getPartialResultTilePosition(
+      rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
+      splitReductionIvs, resultOffset, resultSize);
 }

 static FailureOr<MergeResult>
 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
-                   ValueRange partialResults,
-                   const scf::SCFTilingOptions &options) {
-  assert(options.reductionStrategy != ReductionTilingStrategy::FullReduction &&
+                   ReductionTilingStrategy reductionStrategy,
+                   const SetVector<unsigned> &reductionDims,
+                   ValueRange partialResults) {
+  assert(reductionStrategy != ReductionTilingStrategy::FullReduction &&
          "expected merge to be called for only partial reduction cases");

   auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
@@ -669,7 +769,7 @@ mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                 "implementing PartialReductionOpInterface");
   }
   return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
-                               options.reductionDims);
+                               reductionDims);
 }

 /// Append the specified additional `newInitOperands` operands to the
@@ -911,6 +1011,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     return failure();
   }

+  // Get the reduction dims
+  SetVector<unsigned> reductionDims =
+      getSanitizedReductionDims(tileSizes, options);
+
   // 3. If there is an interchange specified, permute the iteration domain and
   // the tile sizes.
   SmallVector<int64_t> interchangeVector;
@@ -938,7 +1042,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
     std::tie(offsets, sizes) = getTileOffsetAndSizes(
-        rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
+        rewriter, loc, options.reductionStrategy, ivs, iterationDomain,
+        tileSizes, numThreads, reductionDims);

     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
@@ -966,8 +1071,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     }

     // 5c. Tile the cloned operation.
-    tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
-                                          offsets, sizes, options);
+    tilingResult = getTiledImplementation(
+        rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets,
+        sizes, ivs, numThreads, tileSizes, reductionDims);
     if (failed(tilingResult)) {
       rewriter.eraseOp(clonedOp);
       return op.emitOpError("faild to tile operation");
@@ -982,9 +1088,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
          llvm::enumerate(tilingResult->tiledValues)) {
       tiledResults.push_back(tiledValue);
       SmallVector<OpFoldResult> resultOffset, resultSize;
-      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
-                                       sizes, resultOffset, resultSize,
-                                       options))) {
+      if (failed(getResultTilePosition(
+              rewriter, options.reductionStrategy, index, tiledValue, op,
+              offsets, sizes, ivs, numThreads, tileSizes, reductionDims,
+              resultOffset, resultSize))) {
         for (auto op : tilingResult->tiledOps) {
           rewriter.eraseOp(op);
         }
@@ -999,8 +1106,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   };

   // 6. Find the destination tensors to use for the operation.
-  FailureOr<SmallVector<Value>> maybeInits =
-      createInitialTensorsForTiling(rewriter, op, tileSizes, options);
+  FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
+      rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
+      tileSizes, reductionDims);
   if (failed(maybeInits)) {
     return rewriter.notifyMatchFailure(
         op, "unable to create initial tensors for tiling");
@@ -1009,8 +1117,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,

   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
-  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, numThreads, initTensors,
+  if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType,
+                              iterationDomain, tileSizes, numThreads,
+                              initTensors, options.mappingVector,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
@@ -1038,8 +1147,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   }

   // The results of the loop needs to be merged.
-  FailureOr<MergeResult> mergeResult =
-      mergeTilingResults(rewriter, op, loopResults, options);
+  FailureOr<MergeResult> mergeResult = mergeTilingResults(
+      rewriter, op, options.reductionStrategy, reductionDims, loopResults);
   if (failed(mergeResult)) {
     return rewriter.notifyMatchFailure(
         op, "Failed to merge partial results from tiling");

View File

@@ -2315,13 +2315,13 @@ RankedTensorType ExtractSliceOp::inferResultType(
 RankedTensorType ExtractSliceOp::inferResultType(
     RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
-  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
-  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
-                                         staticSizes, staticStrides);
+  SmallVector<int64_t> staticSizes;
+  std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
+  assert(static_cast<int64_t>(staticSizes.size()) ==
+             sourceTensorType.getRank() &&
+         "unexpected staticSizes not equal to rank of source");
+  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
+                               sourceTensorType.getEncoding());
 }

 /// If the rank is reduced (i.e. the desiredResultRank is smaller than the

View File

@@ -208,7 +208,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
 /// Decompose a vector of mixed static or dynamic values into the corresponding
 /// pair of arrays. This is the inverse function of `getMixedValues`.
 std::pair<SmallVector<int64_t>, SmallVector<Value>>
-decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues) {
+decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues) {
   SmallVector<int64_t> staticValues;
   SmallVector<Value> dynamicValues;
   for (const auto &it : mixedValues) {

View File

@@ -112,7 +112,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
by num_threads = [0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) by num_threads = [0, 5] tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield transform.yield
} }
} }
@@ -134,10 +134,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) // CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32> // CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32> // CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[ET]] : tensor<?xf32>) {
// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
// CHECK: arith.mulf // CHECK: arith.mulf
// CHECK: arith.addf // CHECK: arith.addf
// CHECK: linalg.yield // CHECK: linalg.yield
@@ -166,7 +165,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0 %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
by num_threads = [0, 0, 5], tile_sizes = [] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) by num_threads = [0, 0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield transform.yield
} }
} }
@@ -187,11 +186,10 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) // CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32> // CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] // CHECK-DAG: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
// CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK-DAG: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK-DAG: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32> // CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ET]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: scf.forall.in_parallel { // CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32> // CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
// CHECK: } // CHECK: }
@@ -204,113 +202,9 @@ module attributes {transform.with_named_sequence} {
// ----- // -----
func.func @reduction_tile_parallel_cyclic_dist(
%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<?x?xf32>)
outs(%out : tensor<?xf32>) {
^bb0(%arg7: f32, %arg9: f32):
%1 = arith.mulf %arg7, %arg7 : f32
%2 = arith.addf %1, %arg9 : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
return %red : tensor<?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[E:.*]] = tensor.empty(%[[D0]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) in (5) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
// CHECK: %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor<?xf32>) {
// CHECK: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
// CHECK: %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor<?xf32>
// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: } -> tensor<?xf32>
// CHECK: %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> into tensor<?xf32>
// CHECK: scf.yield %[[INS]] : tensor<?xf32>
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
func.func @reduction_tile_parallel_cyclic_dist(
%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<?x?xf32>)
outs(%out : tensor<?xf32>) {
^bb0(%arg7: f32, %arg9: f32):
%1 = arith.mulf %arg7, %arg7 : f32
%2 = arith.addf %1, %arg9 : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
return %red : tensor<?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
// CHECK: expecting fill
// CHECK-NEXT: linalg.fill
transform.print %1 {name = "expecting fill"} : !transform.any_op
// CHECK: expecting parallel reduction
// CHECK-NEXT: linalg.generic
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
// CHECK: expecting parallel reduction
// CHECK-NEXT: linalg.reduce
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
transform.yield
}
}
// -----
func.func @reduction_untiled_forall(
    %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
  // expected-note @below {{target operation}}
  // expected-error @below {{tiling parallel dimensions is not supported with partial reduction tiling strategies}}
  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                          affine_map<(d0, d1) -> (d0)>],
                         iterator_types = ["parallel", "reduction"]}
@@ -329,9 +223,8 @@ module attributes {transform.with_named_sequence} {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{could not tile reduction}}
    %1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
      by num_threads = [5], tile_sizes = [3], mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
      by num_threads = [5] tile_sizes = [3] mapping = [#gpu.thread<x>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
@@ -643,3 +536,158 @@ module {
// CHECK-SAME: outs(%[[INIT]] :
// CHECK-SAME: dimensions = [1, 2]
// CHECK: return %[[REDUCE]]
// -----
func.func @reduction_tile_parallel_using_tile_sizes(
%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
%red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>],
iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<?x?xf32>)
outs(%out : tensor<?xf32>) {
^bb0(%arg7: f32, %arg9: f32):
%1 = arith.mulf %arg7, %arg7 : f32
%2 = arith.addf %1, %arg9 : f32
linalg.yield %2 : f32
} -> tensor<?xf32>
return %red : tensor<?xf32>
}
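// The checks below reflect the expected lowering: with tile_sizes = [0, 5]
// the reduction dimension of extent D1 is split into ceildiv(D1, 5) chunks,
// so the partial-result tensor created by tensor.empty has shape
// D0 x ceildiv(D1, 5). Each scf.forall iteration at offset IV reduces its
// D0 x min(D1 - IV, 5) slice of %arg0 into column IV ceildiv 5, and the
// trailing linalg.reduce folds the partial-results dimension back into the
// original tensor<?xf32> result.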
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK: func @reduction_tile_parallel_using_tile_sizes(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[PARALLEL_DIM:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
// CHECK: %[[E:.*]] = tensor.empty(%[[D0]], %[[PARALLEL_DIM]]) : tensor<?x?xf32>
// CHECK: %[[F:.*]] = linalg.fill
// CHECK-SAME: outs(%[[E]] :
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (%[[D1]]) step (5) shared_outs(%[[ARG3:.+]] = %[[F]])
// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[D1]]]
// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
// CHECK-DAG: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [%[[D0]], %[[TS0]]] [1, 1]
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1]
// CHECK: %[[PARTIAL:.+]] = linalg.generic
// CHECK-SAME: ins(%[[INCHUNK]] :
// CHECK-SAME: outs(%[[ET]] :
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [%[[D0]], 1] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
// CHECK-SAME: outs(%[[ARG1]] :
// CHECK: return %[[R]] : tensor<?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %2, %3, %loop = transform.structured.tile_reduction_using_forall %0
by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
// -----
// Check that only one of the reduction dimensions can be tiled (in this case the inner one).
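// With tile_sizes = [0, 0, 64] only the inner reduction dimension (extent 128)
// is tiled, so the partial-result tensor is tensor<4096x2xf32>
// (2 = ceildiv(128, 64)), and each scf.forall iteration at offset IV writes
// its partial reduction into column IV ceildiv 64.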
#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0)>
module {
func.func @reduction_using_forall_tile_single_of_multiple_reduction_inner(
%arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
%0 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "reduction", "reduction"]}
ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.mulf %in, %in_0 : f32
%2 = arith.addf %1, %out : f32
linalg.yield %2 : f32
} -> tensor<4096xf32>
return %0 : tensor<4096xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
transform.structured.tile_reduction_using_forall %0 reduction_dims = [2] by tile_sizes = [0, 0, 64]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
// CHECK: func @reduction_using_forall_tile_single_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
// CHECK: %[[F:.*]] = linalg.fill
// CHECK-SAME: outs(%[[E]] :
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (128) step (64) shared_outs(%[[ARG3:.+]] = %[[F]])
// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [86, 64] [1, 1]
// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, 0, %[[IV]]] [4096, 86, 64] [1, 1, 1]
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1]
// CHECK: %[[PARTIAL:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] :
// CHECK-SAME: outs(%[[ET]] :
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK: return %[[R]]
// -----
// Check that specifying both reduction dimensions but setting the tile size to 0 for one of them behaves consistently with specifying a single reduction dimension.
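// The expected IR is the same as in the previous test: the zero tile size
// leaves d1 untiled, so only d2 (extent 128, tile size 64) contributes a
// ceildiv(128, 64) = 2 sized partial-results dimension.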
#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0)>
module {
func.func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(
%arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
%0 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "reduction", "reduction"]}
ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.mulf %in, %in_0 : f32
%2 = arith.addf %1, %out : f32
linalg.yield %2 : f32
} -> tensor<4096xf32>
return %0 : tensor<4096xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%fill_op, %split_linalg_op, %combining_linalg_op, %for_op =
transform.structured.tile_reduction_using_forall %0 reduction_dims = [1, 2] by tile_sizes = [0, 0, 64]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
// CHECK: func @reduction_using_forall_tilesize_0_of_multiple_reduction_inner(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>)
// CHECK: %[[E:.*]] = tensor.empty() : tensor<4096x2xf32>
// CHECK: %[[F:.*]] = linalg.fill
// CHECK-SAME: outs(%[[E]] :
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (128) step (64) shared_outs(%[[ARG3:.+]] = %[[F]])
// CHECK-DAG: %[[INIT_OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [86, 64] [1, 1]
// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, 0, %[[IV]]] [4096, 86, 64] [1, 1, 1]
// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1]
// CHECK: %[[PARTIAL:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] :
// CHECK-SAME: outs(%[[ET]] :
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[INIT_OFFSET]]] [4096, 1] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]]
// CHECK-SAME: outs(%[[ARG2]] :
// CHECK: return %[[R]]