[mlir][Linalg] Add a structured.pack_transpose transform op

This transform is complementary to the `structured.pack` op which
allows packing a whole op but does not allow transposes on the individual
operands.

`structured.pack_transpose` allows transposing single operands connected to
pack or unpack ops after the fact.

This makes the system overall more composable than e.g. a giant transform
op with all permutations specified at once.

Differential Revision: https://reviews.llvm.org/D142053
This commit is contained in:
Nicolas Vasilache
2023-01-18 12:26:13 -08:00
parent ff94419a28
commit 790f237012
7 changed files with 687 additions and 56 deletions

View File

@@ -773,6 +773,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
/// Return the index in the indexingMaps vector that corresponds to this `opOperand`
int64_t getIndexingMapIndex(OpOperand *opOperand);
//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.

View File

@@ -363,8 +363,11 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
}];
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
def PackOp : Op<Transform_Dialect, "structured.pack", [
TransformOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
let description = [{
Pack a LinalgOp by applying a data tiling transformation on the op and
@@ -439,14 +442,73 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
}];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure apply(
transform::TransformResults &transformResults,
transform::TransformState &state);
::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
}];
}
//===----------------------------------------------------------------------===//
// PackTransposeOp
//===----------------------------------------------------------------------===//
def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
                         FunctionalStyleTransformOpTrait,
                         MemoryEffectsOpInterface,
                         DeclareOpInterfaceMethods<TransformOpInterface>]> {
  let description = [{
    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
    update the `linalg.generic` op that consumes (resp. produces) the operation.

    This transform allows composing a simple `structured.pack` with additional
    transpositions to e.g. match the data format required by a specific library
    call or ISA instruction.

    The transpose spec must specify at least one of `outer_perm` or `inner_perm`
    attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
    the specified `tensor.pack` or `tensor.unpack` op.

    If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
    be created along with transposed versions of the `tensor.pack` and the
    consuming `linalg.generic`, which is expected to be the sole consumer.

    If the `target` of this op is a `tensor.unpack` then the whole pack / compute
    / unpack chain will be transposed and transposed clones of `tensor.pack`,
    the consuming `linalg.generic` and the tail `tensor.unpack` will be created.

    #### Return modes

    This operation targets a single `tensor.pack` / `tensor.unpack` op and a
    single matching `linalg.generic` that consumes / produces the op. Otherwise,
    it produces a silenceableFailure.

    This operation may produce a silenceableFailure if the transpose spec is
    ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the
    proper rank) or if the transposition of all involved operations fails for any
    reason.

    This operation returns 3 handles, one to the transformed LinalgOp, one to
    the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
    The last handle for `tensor.unpack` is empty if `target_pack_or_un_pack_op`
    was not itself a `tensor.unpack`.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target_pack_or_un_pack_op,
                       TransformHandleTypeInterface:$target_linalg_op,
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_perm,
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$inner_perm);
  let results = (outs TransformHandleTypeInterface:$packed_op,
                      TransformHandleTypeInterface:$pack_op,
                      TransformHandleTypeInterface:$un_pack_op);

  let assemblyFormat = [{
    $target_pack_or_un_pack_op
    `with_compute_op` `(` $target_linalg_op `)`
    (`outer_perm` `=` $outer_perm^ )?
    (`inner_perm` `=` $inner_perm^ )?
    attr-dict
    `:` functional-type(operands, results)
  }];

  let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

View File

@@ -1776,6 +1776,21 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
/// Build and return a new PackOp that is a clone of the current PackOp with
/// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
/// innerPermutation (resp. outerPermutation).
/// A new `tensor.empty` of the proper shape is built in the process.
/// Asserts that:
/// - At least one of innerPermutation or outerPermutation is non-empty.
/// - If not empty, innerPermutation is a valid permutation of size
/// matching innerDimPos.
/// - If not empty, outerPermutation is a valid permutation of size
/// matching outerDimsPerm.
PackOp createTransposedClone(OpBuilder &b,
Location loc,
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);
}];
let hasCanonicalizeMethod = 1;
@@ -1832,7 +1847,23 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
];
let extraClassDeclaration = commonExtraClassDeclaration;
let extraClassDeclaration = commonExtraClassDeclaration # [{
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
/// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
/// innerPermutation (resp. outerPermutation).
/// Asserts that:
/// - At least one of innerPermutation or outerPermutation is non-empty.
/// - If not empty, innerPermutation is a valid permutation of size
/// matching innerDimPos.
/// - If not empty, outerPermutation is a valid permutation of size
/// matching outerDimsPerm.
UnPackOp createTransposedClone(OpBuilder &b,
Location loc,
Value transposedSource,
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);
}];
let hasCanonicalizeMethod = 1;
}

View File

@@ -621,6 +621,22 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
return success();
}
/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  // Operands that are not DPS inputs map at their raw operand number.
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  auto [start, end] = dpsIface.getDpsInitsPositionRange();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`. Reuse the already-computed `dpsIface` instead of
  // re-casting the operation a second time.
  return dpsIface.getNumDpsInputs() + operandNumber - start;
}
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);

View File

@@ -17,17 +17,21 @@
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
@@ -1161,16 +1165,12 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
// Fail on multi-op handles.
auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
if (targetOps.size() != 1 || !linalgOp) {
// TODO: remove this unnecessary set to empty once crashes are fixed.
transformResults.set(getPackedOp().cast<OpResult>(), {});
return emitSilenceableError()
<< "requires target to map to exactly 1 LinalgOp (got "
<< targetOps.size() << ")";
}
// Fail on mismatched number of pack sizes.
if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
// TODO: remove this unnecessary set to empty once crashes are fixed.
transformResults.set(getPackedOp().cast<OpResult>(), {});
return emitSilenceableError()
<< "requires number of packed sizes match the number of loops ("
<< getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
@@ -1194,6 +1194,263 @@ void transform::PackOp::getEffects(
transform::consumesHandle(getTarget(), effects);
transform::onlyReadsHandle(getPackedSizes(), effects);
transform::producesHandle(getPackedOp(), effects);
transform::modifiesPayload(effects);
}
//===---------------------------------------------------------------------===//
// PackTransposeOp
//===---------------------------------------------------------------------===//
namespace {
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
/// of the `tensor.pack` or `tensor.unpack` `op.
/// This is the case when the `permutation` rank matches the rank expected by
/// `op` and `permutation` is itself a permutation vector.
/// Return true if either `op` or `permutation` are empty to allow a simpler
/// polymorphic implementation.
template <typename RelayoutOpTy>
bool isValidPackingPermutation(
RelayoutOpTy op, ArrayRef<int64_t> permutation,
OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
static_assert(
llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
"applies to only pack or unpack operations");
if (!op || permutation.empty())
return true;
int64_t innerRank = op.getInnerDimsPos().size();
if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
return permutation.size() == innerRank && isPermutationVector(permutation);
// op.getOuterDimsPerm() may be empty, in which case it is identity.
// Don't rely on it.
if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
return permutation.size() == op.getSourceRank() &&
isPermutationVector(permutation);
}
return permutation.size() == op.getDestRank() &&
isPermutationVector(permutation);
}
/// Return a tensor type identical to `tensorType` except that its shape has
/// been permuted by `permutationVector`.
// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> permutedShape(tensorType.getShape());
  applyPermutationToVector(permutedShape, permutationVector);
  RankedTensorType::Builder builder(tensorType);
  return builder.setShape(permutedShape);
}
/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
/// - the corresponding indexing map is transposed by `permutation`
/// - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      opOperand.get().getType().cast<RankedTensorType>(), permutation);
  // Fixed assert message: the original said "expected tensor type mismatch",
  // i.e. the opposite of the condition being checked.
  assert(tensorType == transposedValue.getType() &&
         "expected transposed value type to match the permuted operand type");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  // The payload computation is unchanged: only the iteration-space-to-data
  // mapping of one operand was transposed, so the original body is stolen.
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
/// Static verifier for structured.pack_transpose: each supplied permutation
/// attribute must itself be a permutation vector, and at least one of the two
/// must be present. Rank checks against the payload ops happen in apply().
LogicalResult transform::PackTransposeOp::verify() {
  if (!isPermutationVector(getInnerPerm())) {
    return emitOpError() << getInnerPermAttrName()
                         << " is not a valid permutation";
  }
  if (!isPermutationVector(getOuterPerm())) {
    return emitOpError() << getOuterPermAttrName()
                         << " is not a valid permutation";
  }
  // Note: emitOpError() already ends with a space; the original leading space
  // in "at least one of" produced a double space in the diagnostic.
  if (getInnerPerm().empty() && getOuterPerm().empty()) {
    return emitOpError() << "at least one of " << getInnerPermAttrName()
                         << " or " << getOuterPermAttrName()
                         << " must be specified";
  }
  return success();
}
/// Transpose a single tensor.pack (resp. tensor.unpack) op and rewire the
/// linalg.generic op that consumes (resp. produces) it. On success, sets the
/// 3 result handles: transformed LinalgOp, transformed pack, transformed
/// unpack (empty when no unpack is involved).
DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
  ArrayRef<Operation *> packOrUnpackOps =
      state.getPayloadOps(getTargetPackOrUnPackOp());
  ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());

  // Step 1. If nothing to pack, propagate success.
  if (packOrUnpackOps.empty()) {
    transformResults.set(getPackedOp().cast<OpResult>(), {});
    transformResults.set(getPackOp().cast<OpResult>(), {});
    transformResults.set(getUnPackOp().cast<OpResult>(), {});
    return DiagnosedSilenceableFailure::success();
  }

  // Step 2. Bunch of runtime sanity check and error messages.
  // Step 2.1. Fail on multi-op handles.
  if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 packing op and 1 packed op ("
           << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
           << ")";
  }

  // Step 2.2. Fail on wrong type.
  auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
  auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
  if ((!packOp && !unPackOp)) {
    return emitSilenceableError() << "requires target to map to a "
                                     "tensor.pack or tensor.unpack";
  }
  LinalgOp linalgOpTarget = dyn_cast<linalg::LinalgOp>(linalgOps.front());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
  // For a pack target the LinalgOp must be the sole consumer; for an unpack
  // target it must be the producer of the unpacked value.
  LinalgOp linalgOp;
  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
  else if (unPackOp)
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
    auto errorMsg =
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;
  }

  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
  if (unPackOp) {
    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
    // Walk back through the DPS init tied to the unpacked result to find the
    // pack op feeding the compute op.
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        unPackOp.getSource().cast<OpResult>().getResultNumber());
    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";
  }

  // Step 2.5. Fail if any permutation does not validate.
  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
    ArrayRef<int64_t> perm =
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
    if (!isValidPackingPermutation(packOp, perm, permType) ||
        !isValidPackingPermutation(unPackOp, perm, permType)) {
      Operation *packOrUnpackOp =
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
    }
  }

  // From here on, packOp and linalgOp are always present, unPackOp may or may
  // not be present.
  assert(packOp && linalgOp && "unexpected null op");

  // Step 3. Actually transpose the ops.
  Location loc = linalgOp.getLoc();
  IRRewriter rewriter(getContext());
  // Step 3.a. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
      rewriter, loc, getInnerPerm(), getOuterPerm());

  // Step 3.b. Transpose linalgOp.
  assert(packOp.getResult().hasOneUse() && "expect single use");
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 3.b.i. Compute the permutation on the whole operand.
  // Leading part just reuse the outerPerm.
  SmallVector<int64_t> permutation(getOuterPerm());
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Trailing part needs to reindex positions by `numLeadingDims`.
  if (getInnerPerm().empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(getInnerPerm(), [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  assert(isPermutationVector(permutation) && "invalid permutation");

  // Step 3.b.ii. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  OpOperand &packUse = *(packOp.getResult().getUses().begin());
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 3.b.iii. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3.c. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (unPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(unPackOp);
    transposedUnPackOp = unPackOp.createTransposedClone(
        rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
  }

  // Step 4. Replace and return results.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());
  transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
  // transposedLinalgOp was replaced in `transposeOneLinalgOperandAndReplace`.
  transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
  if (unPackOp) {
    rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
    transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
  } else {
    // No unpack involved: the third result handle is intentionally empty.
    transformResults.set(getUnPackOp().cast<OpResult>(), {});
  }
  return DiagnosedSilenceableFailure::success();
}
//===---------------------------------------------------------------------===//
@@ -1359,7 +1616,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
target->getNumRegions() > 0)
return emitDefiniteFailure()
<< "expected target that is isloated from above";
<< "expected target that is isolated from above";
}
// Clone and replace.
@@ -1907,32 +2164,31 @@ transform::TileOp::apply(TransformResults &transformResults,
scf::SCFTilingOptions tilingOptions;
unsigned index = en.index();
if (!tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction(
[&, index](OpBuilder &b, Operation *) {
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), attr.cast<IntegerAttr>().getInt()));
continue;
}
ArrayRef<Operation *> dynamicSizes =
dynamicSizeProducers[dynamicIdx];
ArrayRef<int64_t> params = paramSizes[dynamicIdx];
++dynamicIdx;
assert((dynamicSizes.empty() ^ params.empty()) &&
"expected either dynamic sizes or parameters");
if (!params.empty()) {
sizes.push_back(
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
} else {
sizes.push_back(dynamicSizes[index]->getResult(0));
}
}
return sizes;
});
tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
Operation *) {
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), attr.cast<IntegerAttr>().getInt()));
continue;
}
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
ArrayRef<int64_t> params = paramSizes[dynamicIdx];
++dynamicIdx;
assert((dynamicSizes.empty() ^ params.empty()) &&
"expected either dynamic sizes or parameters");
if (!params.empty()) {
sizes.push_back(
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
} else {
sizes.push_back(dynamicSizes[index]->getResult(0));
}
}
return sizes;
});
}
tilingOptions.setInterchange(getInterchange());
@@ -2149,27 +2405,27 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
// Transform all targets one by one.
for (Operation *target : targets) {
auto tilableOp = dyn_cast<TilingInterface>(target);
if (!tilableOp) {
auto tileableOp = dyn_cast<TilingInterface>(target);
if (!tileableOp) {
DiagnosedSilenceableFailure diag =
transformOp.emitSilenceableError()
<< "only TilingInterface ops are supported";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
rewriter.setInsertionPoint(tilableOp);
rewriter.setInsertionPoint(tileableOp);
FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
if (!mixedNumThreads.empty()) {
tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp,
mixedNumThreads, mapping);
} else {
tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
rewriter, tilableOp, mixedTileSizes, mapping);
rewriter, tileableOp, mixedTileSizes, mapping);
}
if (failed(tilingResult))
return transformOp.emitDefaultSilenceableFailure(tilableOp);
rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
return transformOp.emitDefaultSilenceableFailure(tileableOp);
rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
tileOps.push_back(tilingResult->tileOp);
tiledOps.push_back(tilingResult->tiledOp);

View File

@@ -3231,7 +3231,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return true;
}
return shape == constTileSize.value();
})) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
@@ -3239,6 +3238,57 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return success();
}
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace

/// Compute the (innerDimsPos, innerTiles, outerDimsPerm) metadata of a
/// transposed clone of `packOrUnPackOp`: the inner fields are permuted by
/// `innerPermutation` and the outer dims by `outerPermutation`. Either
/// permutation may be empty (meaning identity), but not both.
template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");

  PackOrUnPackTransposeResult result;
  // Start from the op's current inner metadata and permute it if requested.
  result.innerDimsPos = SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  result.innerTiles = SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == result.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(result.innerDimsPos, innerPermutation);
    applyPermutationToVector(result.innerTiles, innerPermutation);
  }

  // An absent outer_dims_perm denotes identity; materialize it explicitly so
  // the outer permutation can be applied uniformly. The outer rank is the
  // source rank for pack and the dest rank for unpack.
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  result.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == result.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(result.outerDimsPerm, outerPermutation);
  }
  return result;
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
@@ -3386,6 +3436,19 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
/// Build and return a new PackOp equivalent to this one, with
/// (inner_dims_pos, inner_tiles) permuted by `innerPermutation` and
/// outer_dims_perm permuted by `outerPermutation` (empty means identity;
/// commonPermutationOfPackAndUnPackOp asserts that not both are empty).
/// A fresh destination of the transposed shape is materialized via
/// createDestinationTensor; the source and padding value are reused.
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  // Destination must match the transposed layout, hence a new tensor.empty.
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -3508,6 +3571,17 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
builder.getDenseI64ArrayAttr(staticTileSizes));
}
/// Build and return a new UnPackOp equivalent to this one, with
/// (inner_dims_pos, inner_tiles) permuted by `innerPermutation` and
/// outer_dims_perm permuted by `outerPermutation` (empty means identity;
/// commonPermutationOfPackAndUnPackOp asserts that not both are empty).
/// `transposedSource` is the already-transposed packed value to unpack; the
/// existing dest is reused — presumably its (unpacked) type is unaffected by
/// transposing the packed layout.
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}
/// pack(unpack(x)) -> x
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {

View File

@@ -49,21 +49,21 @@ transform.sequence failures(propagate) {
iterator_types = ["reduction", "parallel"]
}
// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: @col_reduction_2d_static
// CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>,
// CHECK-SAME: %[[T1:.+]]: tensor<3xf16>
func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
// CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<3x2x4xf16>
// CHECK-NOT: tensor.pack
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
// CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>)
// CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>)
// CHECK-SAME: outs(%{{.*}} : tensor<3xf16>)
%2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) {
^bb0(%in: f16, %out: f16):
@@ -78,8 +78,15 @@ func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) ->
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
transform.structured.pack %0 packed_sizes = [4, 0]
%1 = transform.structured.pack %0 packed_sizes = [4, 0]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
%pack = transform.get_producer_of_operand %1[0]
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
%2, %pack_2, %empty_unpack_2 =
transform.structured.pack_transpose %pack with_compute_op(%1)
outer_perm = [1, 0]
: (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation)
}
// -----
@@ -183,7 +190,7 @@ transform.sequence failures(propagate) {
// K N n k
// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
// M N m n
// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d4, d3)>
// CHECK-LABEL: @matmul
// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
@@ -196,19 +203,19 @@ func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x4xf32>
// CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x4xf32>
// CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x3xf32>
// CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x2xf32>
// CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{.*}} : tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<?x?x2x3xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<?x?x3x2xf32>)
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>)
-> tensor<?x?xf32>
// CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
// CHECK-SAME: : tensor<?x?x2x3xf32> -> tensor<?x?xf32>
// CHECK: tensor.unpack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
// CHECK-SAME: : tensor<?x?x3x2xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -218,6 +225,14 @@ transform.sequence failures(propagate) {
// M N K
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
%unpack = transform.get_consumers_of_result %1[0]
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
%2, %pack_2, %unpack_2 =
transform.structured.pack_transpose %unpack with_compute_op(%1)
outer_perm = [1, 0] inner_perm = [1, 0]
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
}
// -----
@@ -404,3 +419,177 @@ transform.sequence failures(propagate) {
%1 = transform.structured.pack %0 packed_sizes = [2, 3]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
}
// -----
// Negative test: pack_transpose requires its first operand to resolve to
// exactly one packing op. The match below yields two tensor.pack ops for a
// single handle, so the transform must fail with the diagnostic checked below.
func.func @no_single_packing_op(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
%0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
%1 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
%2 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
return
}
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
// %0 maps to BOTH tensor.pack ops above (2 payload ops), while %1 maps to
// the single tensor.unpack — hence "(got 2 and 1)" in the diagnostic.
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1
%1 = transform.structured.match ops{["tensor.unpack"]} in %arg1
// expected-error @below {{requires target to map to exactly 1 packing op and 1 packed op (got 2 and 1)}}
transform.structured.pack_transpose %0 with_compute_op(%1)
inner_perm = [0]
: (!pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
// Negative test: the first handle must point at a tensor.pack or
// tensor.unpack. Handing pack_transpose an arith.constant instead must
// trigger the diagnostic checked below.
func.func @no_single_pack_unpack(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
%cst = arith.constant 0 : index
%empty = tensor.empty() : tensor<f32>
return
}
transform.sequence failures(propagate) {
^bb0(%root: !pdl.operation):
// Deliberately match ops that are neither tensor.pack nor tensor.unpack.
%const_op = transform.structured.match ops{["arith.constant"]} in %root
%empty_op = transform.structured.match ops{["tensor.empty"]} in %root
// expected-error @below {{requires target to map to a tensor.pack or tensor.unpack}}
transform.structured.pack_transpose %const_op with_compute_op(%empty_op)
inner_perm = [0]
: (!pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
// Negative test: the with_compute_op handle must resolve to a LinalgOp.
// Here it resolves to an arith.constant, so pack_transpose must fail with
// the diagnostic checked below.
func.func @no_linalg_target(%src: tensor<128x256xf32>, %dst: tensor<4x16x32x16xf32>) {
%packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dst : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
%idx = arith.constant 0 : index
return
}
transform.sequence failures(propagate) {
^bb0(%root: !pdl.operation):
%pack_op = transform.structured.match ops{["tensor.pack"]} in %root
// The alleged compute op is a constant, i.e. not a LinalgOp.
%const_op = transform.structured.match ops{["arith.constant"]} in %root
// expected-error @below {{requires a LinalgOp target}}
transform.structured.pack_transpose %pack_op with_compute_op(%const_op)
inner_perm = [0]
: (!pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
// Negative test: for a tensor.pack target, pack_transpose requires the pack
// result to have a single use by the matched LinalgOp. Here the pack result
// %0 is unused — the linalg.fill operates on an unrelated tensor.empty — so
// the transform must fail with the diagnostic checked below.
func.func @no_single_use_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
%0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
%f0 = arith.constant 0.0 : f32
%1 = tensor.empty() : tensor<f32>
%2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
return
}
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1
%1 = transform.structured.match ops{["linalg.fill"]} in %arg1
// expected-error @below {{not a single use by the LinalgOp target}}
transform.structured.pack_transpose %0 with_compute_op(%1)
inner_perm = [0]
: (!pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
// Negative test: for a tensor.unpack target, pack_transpose requires the
// unpack's source to be produced by the matched LinalgOp. Here %b unpacks
// %a, which comes from a tensor.pack rather than the matched linalg.fill,
// so the transform must fail with the diagnostic checked below.
func.func @not_produced_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
%a = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
%b = tensor.unpack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
%f0 = arith.constant 0.0 : f32
%1 = tensor.empty() : tensor<f32>
%2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
return
}
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
%1 = transform.structured.match ops{["linalg.fill"]} in %arg1
// expected-error @below {{not produced by the LinalgOp target}}
transform.structured.pack_transpose %0 with_compute_op(%1)
inner_perm = [0]
: (!pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
// Negative test: when targeting a tensor.unpack, pack_transpose looks for
// the matching tensor.pack feeding the LinalgOp producer. The linalg.fill
// below is initialized from a plain tensor.empty (no pack anywhere), so the
// transform must fail with the diagnostic checked below.
func.func @no_matching_pack(%source: tensor<16xf32>) {
%f0 = arith.constant 0.0 : f32
%1 = tensor.empty() : tensor<4x4xf32>
%2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
%b = tensor.unpack %2 inner_dims_pos = [0] inner_tiles = [4] into %source : tensor<4x4xf32> -> tensor<16xf32>
return
}
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
%1 = transform.structured.match ops{["linalg.fill"]} in %arg1
// expected-error @below {{could not find matching pack op}}
transform.structured.pack_transpose %0 with_compute_op(%1)
inner_perm = [0]
: (!pdl.operation, !pdl.operation)
-> (!pdl.operation, !pdl.operation, !pdl.operation)
}
// -----
// Negative test: outer_perm must be a valid permutation of the packed op's
// outer dimensions. After packing the matmul with packed_sizes = [2, 3, 4]
// the unpack has 2 outer dims, so the size-1 list [1] is rejected with the
// diagnostic checked below.
func.func @invalid_outer_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
-> tensor<?x?xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>)
-> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// Pack the matmul first so a tensor.unpack exists to target.
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
%unpack = transform.get_consumers_of_result %1[0]
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
%2, %pack_2, %unpack_2 =
// expected-error @below {{invalid outer_perm}}
transform.structured.pack_transpose %unpack with_compute_op(%1)
outer_perm = [1]
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
}
// -----
// Negative test: inner_perm must be a valid permutation of the packed op's
// inner (tiled) dimensions. After packing the matmul with packed_sizes =
// [2, 3, 4] the unpack has 2 inner dims, so the size-1 list [1] is rejected
// with the diagnostic checked below. Mirrors @invalid_outer_perm.
func.func @invalid_inner_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
-> tensor<?x?xf32> {
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>)
-> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
// Pack the matmul first so a tensor.unpack exists to target.
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
%unpack = transform.get_consumers_of_result %1[0]
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
%2, %pack_2, %unpack_2 =
// expected-error @below {{invalid inner_perm}}
transform.structured.pack_transpose %unpack with_compute_op(%1)
inner_perm = [1]
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
}