[mlir][Linalg] Add a structured.pack_transpose transform op
This transform is complementary to the `structured.pack` op which allows packing a whole op but does not allow transposes on the individual operands. `structured.pack_transpose` allows transposing single operands connected to pack or unpack ops after the fact. This makes the system overall more composable than e.g. a giant transform op with all permutations specified at once. Differential Revision: https://reviews.llvm.org/D142053
This commit is contained in:
@@ -773,6 +773,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||
LogicalResult reifyResultShapes(OpBuilder &b,
|
||||
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
|
||||
|
||||
/// Return the index in the indexingMaps vector that corresponds to this `opOperand`
|
||||
int64_t getIndexingMapIndex(OpOperand *opOperand);
|
||||
|
||||
//========================================================================//
|
||||
// Forwarding functions to access interface methods from the
|
||||
// DestinationStyleOpInterface.
|
||||
|
||||
@@ -363,8 +363,11 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
def PackOp : Op<Transform_Dialect, "structured.pack", [
|
||||
TransformOpInterface,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
|
||||
let description = [{
|
||||
Pack a LinalgOp by applying a data tiling transformation on the op and
|
||||
@@ -439,14 +442,73 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::DiagnosedSilenceableFailure apply(
|
||||
transform::TransformResults &transformResults,
|
||||
transform::TransformState &state);
|
||||
|
||||
::llvm::SmallVector<::mlir::OpFoldResult> getMixedPackedSizes();
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
// PackTransposeOp
//===----------------------------------------------------------------------===//
def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
                         FunctionalStyleTransformOpTrait,
                         MemoryEffectsOpInterface,
                         DeclareOpInterfaceMethods<TransformOpInterface>]> {
  let description = [{
    Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and
    update the `linalg.generic` op that consumes (resp. produces) the operation.

    This transform allows composing a simple `structured.pack` with additional
    transpositions to e.g. match the data format required by a specific library
    call or ISA instruction.

    The transpose spec must specify at least one of `outer_perm` or `inner_perm`
    attributes, which will act upon the `outer_dims_perm` or `inner_dims_pos` of
    the specified `tensor.pack` or `tensor.unpack` op.

    If the `target` of this op is a `tensor.pack` then a new `tensor.empty` will
    be created along with transposed versions of the `tensor.pack` and the
    consuming `linalg.generic`, which is expected to be the sole consumer.

    If the `target` of this op is a `tensor.unpack` then the whole pack / compute
    / unpack chain will be transposed and transposed clones of `tensor.pack`,
    the consuming `linalg.generic` and the tail `tensor.unpack` will be created.

    #### Return modes

    This operation targets a single `tensor.pack` / `tensor.unpack` op and a
    single matching `linalg.generic` that consumes / produces the op. Otherwise,
    it produces a silenceableFailure.

    This operation may produce a silenceableFailure if the transpose spec is
    ill-formed (i.e. `outer_perm` or `inner_perm` are not permutations of the
    proper rank) or if the transposition of all involved operations fails for any
    reason.

    This operation returns 3 handles, one to the transformed LinalgOp, one to
    the transformed `tensor.pack` and one to the transformed `tensor.unpack`.
    The last handle for `tensor.unpack` is empty if `target_pack_or_un_pack_op`
    was not itself a `tensor.unpack`.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target_pack_or_un_pack_op,
                       TransformHandleTypeInterface:$target_linalg_op,
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_perm,
                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$inner_perm);
  let results = (outs TransformHandleTypeInterface:$packed_op,
                      TransformHandleTypeInterface:$pack_op,
                      TransformHandleTypeInterface:$un_pack_op);
  let assemblyFormat = [{
    $target_pack_or_un_pack_op
    `with_compute_op` `(` $target_linalg_op `)`
    (`outer_perm` `=` $outer_perm^ )?
    (`inner_perm` `=` $inner_perm^ )?
    attr-dict
    `:` functional-type(operands, results)
  }];

  let hasVerifier = 1;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1776,6 +1776,21 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
|
||||
static Value createDestinationTensor(OpBuilder &b, Location loc,
|
||||
Value source, ArrayRef<OpFoldResult> innerTileSizes,
|
||||
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
|
||||
|
||||
/// Build and return a new PackOp that is a clone of the current PackOp with
|
||||
/// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
|
||||
/// innerPermutation (resp. outerPermutation).
|
||||
/// A new `tensor.empty` of the proper shape is built in the process.
|
||||
/// Asserts that:
|
||||
/// - At least one of innerPermutation or outerPermutation is non-empty.
|
||||
/// - If not empty, innerPermutation is a valid permutation of size
|
||||
/// matching innerDimPos.
|
||||
/// - If not empty, outerPermutation is a valid permutation of size
|
||||
/// matching outerDimsPerm.
|
||||
PackOp createTransposedClone(OpBuilder &b,
|
||||
Location loc,
|
||||
ArrayRef<int64_t> innerPermutation,
|
||||
ArrayRef<int64_t> outerPermutation);
|
||||
}];
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
@@ -1832,7 +1847,23 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
|
||||
CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = commonExtraClassDeclaration;
|
||||
let extraClassDeclaration = commonExtraClassDeclaration # [{
|
||||
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
|
||||
/// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
|
||||
/// innerPermutation (resp. outerPermutation).
|
||||
/// Asserts that:
|
||||
/// - At least one of innerPermutation or outerPermutation is non-empty.
|
||||
/// - If not empty, innerPermutation is a valid permutation of size
|
||||
/// matching innerDimPos.
|
||||
/// - If not empty, outerPermutation is a valid permutation of size
|
||||
/// matching outerDimsPerm.
|
||||
UnPackOp createTransposedClone(OpBuilder &b,
|
||||
Location loc,
|
||||
Value transposedSource,
|
||||
ArrayRef<int64_t> innerPermutation,
|
||||
ArrayRef<int64_t> outerPermutation);
|
||||
}];
|
||||
|
||||
let hasCanonicalizeMethod = 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -621,6 +621,22 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Return the index in the indexingMaps vector that corresponds to this
|
||||
/// `opOperand`.
|
||||
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
|
||||
auto operandNumber = opOperand->getOperandNumber();
|
||||
auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
|
||||
if (!dpsIface.isDpsInput(opOperand))
|
||||
return operandNumber;
|
||||
auto [start, end] = dpsIface.getDpsInitsPositionRange();
|
||||
assert(!dpsIface.isDpsInit(opOperand));
|
||||
// Account for potential inputs that are not DPS and may not appear in
|
||||
// `indexingMaps`.
|
||||
return cast<DestinationStyleOpInterface>(*this->getOperation())
|
||||
.getNumDpsInputs() +
|
||||
operandNumber - start;
|
||||
}
|
||||
|
||||
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
|
||||
|
||||
@@ -17,17 +17,21 @@
|
||||
#include "mlir/Dialect/PDL/IR/PDL.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
|
||||
#include "mlir/Dialect/Transform/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
@@ -1161,16 +1165,12 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
|
||||
// Fail on multi-op handles.
|
||||
auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
|
||||
if (targetOps.size() != 1 || !linalgOp) {
|
||||
// TODO: remove this unnecessary set to empty once crashes are fixed.
|
||||
transformResults.set(getPackedOp().cast<OpResult>(), {});
|
||||
return emitSilenceableError()
|
||||
<< "requires target to map to exactly 1 LinalgOp (got "
|
||||
<< targetOps.size() << ")";
|
||||
}
|
||||
// Fail on mismatched number of pack sizes.
|
||||
if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
|
||||
// TODO: remove this unnecessary set to empty once crashes are fixed.
|
||||
transformResults.set(getPackedOp().cast<OpResult>(), {});
|
||||
return emitSilenceableError()
|
||||
<< "requires number of packed sizes match the number of loops ("
|
||||
<< getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
|
||||
@@ -1194,6 +1194,263 @@ void transform::PackOp::getEffects(
|
||||
transform::consumesHandle(getTarget(), effects);
|
||||
transform::onlyReadsHandle(getPackedSizes(), effects);
|
||||
transform::producesHandle(getPackedOp(), effects);
|
||||
transform::modifiesPayload(effects);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// PackTransposeOp
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
namespace {
/// Selects which permutation field of a pack/unpack op a permutation vector
/// applies to: the outer dims permutation or the inner dims positions.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
|
||||
|
||||
/// Return true if `permutation` is a valid permutation of the `outer_dims_perm`
|
||||
/// (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner)
|
||||
/// of the `tensor.pack` or `tensor.unpack` `op.
|
||||
/// This is the case when the `permutation` rank matches the rank expected by
|
||||
/// `op` and `permutation` is itself a permutation vector.
|
||||
/// Return true if either `op` or `permutation` are empty to allow a simpler
|
||||
/// polymorphic implementation.
|
||||
template <typename RelayoutOpTy>
|
||||
bool isValidPackingPermutation(
|
||||
RelayoutOpTy op, ArrayRef<int64_t> permutation,
|
||||
OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
|
||||
static_assert(
|
||||
llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
|
||||
"applies to only pack or unpack operations");
|
||||
if (!op || permutation.empty())
|
||||
return true;
|
||||
int64_t innerRank = op.getInnerDimsPos().size();
|
||||
if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
|
||||
return permutation.size() == innerRank && isPermutationVector(permutation);
|
||||
// op.getOuterDimsPerm() may be empty, in which case it is identity.
|
||||
// Don't rely on it.
|
||||
if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
|
||||
return permutation.size() == op.getSourceRank() &&
|
||||
isPermutationVector(permutation);
|
||||
}
|
||||
return permutation.size() == op.getDestRank() &&
|
||||
isPermutationVector(permutation);
|
||||
}
|
||||
|
||||
/// Return a copy of `tensorType` after permutation by `permutationVector`.
|
||||
// Note: Should be a new method in of MemRef/RankedTensor/VectorType::Builder
|
||||
// but this would introduce a dependence on Dialect in IR.
|
||||
// TODO: Restructure.
|
||||
static RankedTensorType permuteShape(RankedTensorType tensorType,
|
||||
ArrayRef<int64_t> permutationVector) {
|
||||
SmallVector<int64_t> shape(tensorType.getShape());
|
||||
applyPermutationToVector(shape, permutationVector);
|
||||
return RankedTensorType::Builder(tensorType).setShape(shape);
|
||||
}
|
||||
|
||||
/// Return a new GenericOp obtained by transposing opOperand by the permutation
|
||||
/// vector:
|
||||
/// - the corresponding indexing map is transposed by `permutation`
|
||||
/// - the corresponding operand value is replaced by `transposedValue`
|
||||
/// `linalgOp` is replaced by the return op in the process.
|
||||
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
|
||||
static LinalgOp transposeOneLinalgOperandAndReplace(
|
||||
RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
|
||||
ArrayRef<int64_t> permutation, Value transposedValue) {
|
||||
// Sanity check the operand.
|
||||
assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
|
||||
|
||||
// Sanity check of the expected transposed tensor type.
|
||||
auto tensorType = permuteShape(
|
||||
opOperand.get().getType().cast<RankedTensorType>(), permutation);
|
||||
assert(tensorType == transposedValue.getType() &&
|
||||
"expected tensor type mismatch");
|
||||
|
||||
// Compute the transposed indexing map.
|
||||
// Sigh unsigned pollution.
|
||||
SmallVector<unsigned> tmpTransposition = llvm::to_vector(
|
||||
llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
|
||||
AffineMap permutationMap =
|
||||
AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
|
||||
AffineMap transposedMap =
|
||||
permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
|
||||
|
||||
// Set the transposed indexing map in the proper position.
|
||||
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
|
||||
indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
|
||||
// Set the transposedValue in the proper operand position.
|
||||
SmallVector<Value> operands = linalgOp->getOperands();
|
||||
operands[opOperand.getOperandNumber()] = transposedValue;
|
||||
|
||||
ValueRange operandsRef(operands);
|
||||
auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
|
||||
/*location=*/linalgOp->getLoc(),
|
||||
/*resultTensorTypes=*/
|
||||
operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
|
||||
/*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
|
||||
/*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/linalgOp.getIteratorTypesArray());
|
||||
transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
|
||||
rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
|
||||
|
||||
return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
|
||||
}
|
||||
|
||||
/// Verify the transpose spec: each provided permutation attribute must be a
/// valid permutation vector, and at least one of them must be present.
LogicalResult transform::PackTransposeOp::verify() {
  ArrayRef<int64_t> innerPerm = getInnerPerm();
  ArrayRef<int64_t> outerPerm = getOuterPerm();
  if (!isPermutationVector(innerPerm))
    return emitOpError() << getInnerPermAttrName()
                         << " is not a valid permutation";
  if (!isPermutationVector(outerPerm))
    return emitOpError() << getOuterPermAttrName()
                         << " is not a valid permutation";
  // An empty spec would make the transform a no-op; reject it.
  if (innerPerm.empty() && outerPerm.empty())
    return emitOpError() << " at least one of " << getInnerPermAttrName()
                         << " or " << getOuterPermAttrName()
                         << " must be specified";
  return success();
}
|
||||
|
||||
/// Transpose a single tensor.pack (or the pack / compute / unpack chain rooted
/// at a tensor.unpack) together with the linalg.generic tied to it, according
/// to the op's `inner_perm` / `outer_perm` attributes. Returns three result
/// handles: the transposed LinalgOp, the transposed pack, and the transposed
/// unpack (empty when the target was not an unpack).
DiagnosedSilenceableFailure
transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
  ArrayRef<Operation *> packOrUnpackOps =
      state.getPayloadOps(getTargetPackOrUnPackOp());
  ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
  // Step 1. If nothing to pack, propagate success.
  // All three results must still be set (to empty) before returning.
  if (packOrUnpackOps.empty()) {
    transformResults.set(getPackedOp().cast<OpResult>(), {});
    transformResults.set(getPackOp().cast<OpResult>(), {});
    transformResults.set(getUnPackOp().cast<OpResult>(), {});
    return DiagnosedSilenceableFailure::success();
  }

  // Step 2. Bunch of runtime sanity check and error messages.
  // Step 2.1. Fail on multi-op handles.
  if (packOrUnpackOps.size() != 1 || linalgOps.size() != 1) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 packing op and 1 packed op ("
           << "got " << packOrUnpackOps.size() << " and " << linalgOps.size()
           << ")";
  }

  // Step 2.2. Fail on wrong type.
  // At most one of these dyn_casts succeeds; they discriminate pack vs unpack.
  auto packOp = dyn_cast<tensor::PackOp>(packOrUnpackOps.front());
  auto unPackOp = dyn_cast<tensor::UnPackOp>(packOrUnpackOps.front());
  if ((!packOp && !unPackOp)) {
    return emitSilenceableError() << "requires target to map to a "
                                     "tensor.pack or tensor.unpack";
  }
  LinalgOp linalgOpTarget = dyn_cast<linalg::LinalgOp>(linalgOps.front());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
  // For a pack: its sole user must be the target LinalgOp.
  // For an unpack: its source must be produced by the target LinalgOp.
  LinalgOp linalgOp;
  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
  else if (unPackOp)
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
    auto errorMsg =
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;
  }

  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical PackOp.
  // NOTE(review): `unPackOp.getSource().cast<OpResult>()` assumes the unpack
  // source is an op result (guaranteed here: Step 2.3 verified it is produced
  // by `linalgOp`).
  if (unPackOp) {
    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        unPackOp.getSource().cast<OpResult>().getResultNumber());
    packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";
  }

  // Step 2.5. Fail if any permutation does not validate.
  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
    ArrayRef<int64_t> perm =
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
    // isValidPackingPermutation accepts null ops / empty perms vacuously, so
    // both calls are safe even though only one of packOp/unPackOp may be set.
    if (!isValidPackingPermutation(packOp, perm, permType) ||
        !isValidPackingPermutation(unPackOp, perm, permType)) {
      Operation *packOrUnpackOp =
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
    }
  }

  // From here on, packOp and linalgOp are always present, unPackOp may or may
  // not be present.
  assert(packOp && linalgOp && "unexpected null op");

  // Step 3. Actually transpose the ops.
  Location loc = linalgOp.getLoc();
  IRRewriter rewriter(getContext());

  // Step 3.a. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp = packOp.createTransposedClone(
      rewriter, loc, getInnerPerm(), getOuterPerm());

  // Step 3.b. Transpose linalgOp.
  assert(packOp.getResult().hasOneUse() && "expect single use");
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 3.b.i. Compute the permutation on the whole operand.
  // Leading part just reuse the outerPerm.
  SmallVector<int64_t> permutation(getOuterPerm());
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Trailing part needs to reindex positions by `numLeadingDims`.
  if (getInnerPerm().empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(getInnerPerm(), [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  assert(isPermutationVector(permutation) && "invalid permutation");
  // Step 3.b.ii. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  OpOperand &packUse = *(packOp.getResult().getUses().begin());
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 3.b.iii. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3.c. Maybe transpose unPackOp.
  if (unPackOp) is handled below; note `transposedLinalgOp` must be queried
  tensor::UnPackOp transposedUnPackOp;
  if (unPackOp) {
    // Re-fetch the operand by its saved number: `linalgOp` has been replaced
    // above, so the original OpOperand is no longer valid.
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(unPackOp);
    transposedUnPackOp = unPackOp.createTransposedClone(
        rewriter, loc, transposedResult, getInnerPerm(), getOuterPerm());
  }

  // Step 4. Replace and return results.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());
  transformResults.set(getPackOp().cast<OpResult>(), {transposedPackOp});
  // transposedLinalgOp was replaced in `transposeOneLinalgOperandAndReplace`.
  transformResults.set(getPackedOp().cast<OpResult>(), {transposedLinalgOp});
  if (unPackOp) {
    rewriter.replaceOp(unPackOp, transposedUnPackOp->getResults());
    transformResults.set(getUnPackOp().cast<OpResult>(), {transposedUnPackOp});
  } else {
    // No unpack in the chain: the third result handle is deliberately empty.
    transformResults.set(getUnPackOp().cast<OpResult>(), {});
  }
  return DiagnosedSilenceableFailure::success();
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
@@ -1359,7 +1616,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
|
||||
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
|
||||
target->getNumRegions() > 0)
|
||||
return emitDefiniteFailure()
|
||||
<< "expected target that is isloated from above";
|
||||
<< "expected target that is isolated from above";
|
||||
}
|
||||
|
||||
// Clone and replace.
|
||||
@@ -1907,32 +2164,31 @@ transform::TileOp::apply(TransformResults &transformResults,
|
||||
scf::SCFTilingOptions tilingOptions;
|
||||
unsigned index = en.index();
|
||||
if (!tileSizes.empty()) {
|
||||
tilingOptions.setTileSizeComputationFunction(
|
||||
[&, index](OpBuilder &b, Operation *) {
|
||||
SmallVector<Value, 4> sizes;
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
for (OpFoldResult ofr : getMixedSizes()) {
|
||||
if (auto attr = ofr.dyn_cast<Attribute>()) {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), attr.cast<IntegerAttr>().getInt()));
|
||||
continue;
|
||||
}
|
||||
ArrayRef<Operation *> dynamicSizes =
|
||||
dynamicSizeProducers[dynamicIdx];
|
||||
ArrayRef<int64_t> params = paramSizes[dynamicIdx];
|
||||
++dynamicIdx;
|
||||
assert((dynamicSizes.empty() ^ params.empty()) &&
|
||||
"expected either dynamic sizes or parameters");
|
||||
if (!params.empty()) {
|
||||
sizes.push_back(
|
||||
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
|
||||
} else {
|
||||
sizes.push_back(dynamicSizes[index]->getResult(0));
|
||||
}
|
||||
}
|
||||
return sizes;
|
||||
});
|
||||
tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
|
||||
Operation *) {
|
||||
SmallVector<Value, 4> sizes;
|
||||
sizes.reserve(tileSizes.size());
|
||||
unsigned dynamicIdx = 0;
|
||||
for (OpFoldResult ofr : getMixedSizes()) {
|
||||
if (auto attr = ofr.dyn_cast<Attribute>()) {
|
||||
sizes.push_back(b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), attr.cast<IntegerAttr>().getInt()));
|
||||
continue;
|
||||
}
|
||||
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
|
||||
ArrayRef<int64_t> params = paramSizes[dynamicIdx];
|
||||
++dynamicIdx;
|
||||
assert((dynamicSizes.empty() ^ params.empty()) &&
|
||||
"expected either dynamic sizes or parameters");
|
||||
if (!params.empty()) {
|
||||
sizes.push_back(
|
||||
b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
|
||||
} else {
|
||||
sizes.push_back(dynamicSizes[index]->getResult(0));
|
||||
}
|
||||
}
|
||||
return sizes;
|
||||
});
|
||||
}
|
||||
|
||||
tilingOptions.setInterchange(getInterchange());
|
||||
@@ -2149,27 +2405,27 @@ DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
|
||||
|
||||
// Transform all targets one by one.
|
||||
for (Operation *target : targets) {
|
||||
auto tilableOp = dyn_cast<TilingInterface>(target);
|
||||
if (!tilableOp) {
|
||||
auto tileableOp = dyn_cast<TilingInterface>(target);
|
||||
if (!tileableOp) {
|
||||
DiagnosedSilenceableFailure diag =
|
||||
transformOp.emitSilenceableError()
|
||||
<< "only TilingInterface ops are supported";
|
||||
diag.attachNote(target->getLoc()) << "target op";
|
||||
return diag;
|
||||
}
|
||||
rewriter.setInsertionPoint(tilableOp);
|
||||
rewriter.setInsertionPoint(tileableOp);
|
||||
FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
|
||||
if (!mixedNumThreads.empty()) {
|
||||
tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
|
||||
tilingResult = linalg::tileToForeachThreadOp(rewriter, tileableOp,
|
||||
mixedNumThreads, mapping);
|
||||
} else {
|
||||
tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
|
||||
rewriter, tilableOp, mixedTileSizes, mapping);
|
||||
rewriter, tileableOp, mixedTileSizes, mapping);
|
||||
}
|
||||
|
||||
if (failed(tilingResult))
|
||||
return transformOp.emitDefaultSilenceableFailure(tilableOp);
|
||||
rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
|
||||
return transformOp.emitDefaultSilenceableFailure(tileableOp);
|
||||
rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
|
||||
|
||||
tileOps.push_back(tilingResult->tileOp);
|
||||
tiledOps.push_back(tilingResult->tiledOp);
|
||||
|
||||
@@ -3231,7 +3231,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
|
||||
return true;
|
||||
}
|
||||
return shape == constTileSize.value();
|
||||
|
||||
})) {
|
||||
return op->emitError("mismatch in inner tile sizes specified and shaped of "
|
||||
"tiled dimension in the packed type");
|
||||
@@ -3239,6 +3238,57 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  // Permuted inner_dims_pos of the op.
  SmallVector<int64_t> innerDimsPos;
  // Inner tile sizes, permuted in lockstep with innerDimsPos.
  SmallVector<OpFoldResult> innerTiles;
  // Permuted outer_dims_perm (materialized as identity when the op's own
  // outer_dims_perm was empty).
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace
|
||||
|
||||
/// Compute the (innerDimsPos, innerTiles, outerDimsPerm) metadata of
/// `packOrUnPackOp` after applying `innerPermutation` and/or
/// `outerPermutation`. At least one permutation must be non-empty; each
/// non-empty permutation must be a valid permutation of the corresponding
/// rank (enforced by asserts).
template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  // The outer rank differs by op kind: a PackOp's outer dims follow its
  // source rank, an UnPackOp's follow its dest rank.
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  // An empty outer_dims_perm means identity; materialize it so the
  // permutation below always has something to act on.
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    // innerDimsPos and innerTiles are parallel arrays; permute both.
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -3386,6 +3436,19 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
|
||||
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
|
||||
}
|
||||
|
||||
/// Build and return a new PackOp whose (innerDimsPos, innerTiles) and
/// outerDimsPerm are the current op's permuted by `innerPermutation` /
/// `outerPermutation`. A fresh destination tensor (tensor.empty) of the
/// transposed shape is created in the process; padding value and source are
/// carried over unchanged. Permutation validity is asserted inside
/// commonPermutationOfPackAndUnPackOp.
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
|
||||
|
||||
/// Returns true if the tiles and the tiled dims are constant.
|
||||
template <typename OpTy>
|
||||
bool areTilesAndTiledDimsAllConstant(OpTy op) {
|
||||
@@ -3508,6 +3571,17 @@ void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
|
||||
builder.getDenseI64ArrayAttr(staticTileSizes));
|
||||
}
|
||||
|
||||
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
|
||||
Value transposedSource,
|
||||
ArrayRef<int64_t> innerPermutation,
|
||||
ArrayRef<int64_t> outerPermutation) {
|
||||
PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
|
||||
*this, innerPermutation, outerPermutation);
|
||||
return b.create<UnPackOp>(loc, transposedSource, getDest(),
|
||||
metadata.innerDimsPos, metadata.innerTiles,
|
||||
metadata.outerDimsPerm);
|
||||
}
|
||||
|
||||
/// pack(unpack(x)) -> x
|
||||
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
|
||||
PatternRewriter &rewriter) {
|
||||
|
||||
@@ -49,21 +49,21 @@ transform.sequence failures(propagate) {
|
||||
iterator_types = ["reduction", "parallel"]
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$PACKED_MAP_0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2) -> (d1)>
|
||||
|
||||
// CHECK-LABEL: @col_reduction_2d_static
|
||||
// CHECK-SAME: %[[T0:.+]]: tensor<7x3xf16>,
|
||||
// CHECK-SAME: %[[T1:.+]]: tensor<3xf16>
|
||||
func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) -> tensor<3xf16> {
|
||||
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x4xf16>
|
||||
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<3x2x4xf16>
|
||||
// CHECK: %[[PACKED:.*]] = tensor.pack %[[T0]] padding_value(%{{.*}} : f16)
|
||||
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<2x3x4xf16>
|
||||
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<7x3xf16> -> tensor<3x2x4xf16>
|
||||
// CHECK-NOT: tensor.pack
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]]]
|
||||
// CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
|
||||
// CHECK-SAME: ins(%{{.*}} : tensor<2x3x4xf16>)
|
||||
// CHECK-SAME: ins(%{{.*}} : tensor<3x2x4xf16>)
|
||||
// CHECK-SAME: outs(%{{.*}} : tensor<3xf16>)
|
||||
%2 = linalg.generic #col_reduction_2d_trait ins(%t0 : tensor<7x3xf16>) outs(%t1 : tensor<3xf16>) {
|
||||
^bb0(%in: f16, %out: f16):
|
||||
@@ -78,8 +78,15 @@ func.func @col_reduction_2d_static(%t0: tensor<7x3xf16>, %t1: tensor<3xf16>) ->
|
||||
transform.sequence failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
|
||||
transform.structured.pack %0 packed_sizes = [4, 0]
|
||||
%1 = transform.structured.pack %0 packed_sizes = [4, 0]
|
||||
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
|
||||
%pack = transform.get_producer_of_operand %1[0]
|
||||
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.pack">)
|
||||
%2, %pack_2, %empty_unpack_2 =
|
||||
transform.structured.pack_transpose %pack with_compute_op(%1)
|
||||
outer_perm = [1, 0]
|
||||
: (!transform.op<"tensor.pack">, !transform.op<"linalg.generic">)
|
||||
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -183,7 +190,7 @@ transform.sequence failures(propagate) {
|
||||
// K N n k
|
||||
// CHECK-DAG: #[[$PACKED_MAP_1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
|
||||
// M N m n
|
||||
// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
|
||||
// CHECK-DAG: #[[$PACKED_MAP_2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d4, d3)>
|
||||
|
||||
// CHECK-LABEL: @matmul
|
||||
// CHECK-SAME: %[[A:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
|
||||
@@ -196,19 +203,19 @@ func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
|
||||
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x4xf32>
|
||||
// CHECK: %[[PACK_B:.*]] = tensor.pack %{{.*}} inner_dims_pos = [1, 0] inner_tiles = [3, 4]
|
||||
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x4xf32>
|
||||
// CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
|
||||
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x2x3xf32>
|
||||
// CHECK: %[[PACK_C:.*]] = tensor.pack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
|
||||
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?x3x2xf32>
|
||||
|
||||
// CHECK: linalg.generic {indexing_maps = [#[[$PACKED_MAP_0]], #[[$PACKED_MAP_1]], #[[$PACKED_MAP_2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
|
||||
// CHECK-SAME: ins(%{{.*}} : tensor<?x?x2x4xf32>, tensor<?x?x3x4xf32>)
|
||||
// CHECK-SAME: outs(%{{.*}} : tensor<?x?x2x3xf32>)
|
||||
// CHECK-SAME: outs(%{{.*}} : tensor<?x?x3x2xf32>)
|
||||
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%C: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32>
|
||||
|
||||
// CHECK: tensor.unpack %{{.*}} inner_dims_pos = [0, 1] inner_tiles = [2, 3]
|
||||
// CHECK-SAME: : tensor<?x?x2x3xf32> -> tensor<?x?xf32>
|
||||
// CHECK: tensor.unpack %{{.*}} outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [3, 2]
|
||||
// CHECK-SAME: : tensor<?x?x3x2xf32> -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
@@ -218,6 +225,14 @@ transform.sequence failures(propagate) {
|
||||
// M N K
|
||||
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
|
||||
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
|
||||
|
||||
%unpack = transform.get_consumers_of_result %1[0]
|
||||
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
|
||||
%2, %pack_2, %unpack_2 =
|
||||
transform.structured.pack_transpose %unpack with_compute_op(%1)
|
||||
outer_perm = [1, 0] inner_perm = [1, 0]
|
||||
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
|
||||
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -404,3 +419,177 @@ transform.sequence failures(propagate) {
|
||||
%1 = transform.structured.pack %0 packed_sizes = [2, 3]
|
||||
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_single_packing_op(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
|
||||
%0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
|
||||
%1 = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
|
||||
%2 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1
|
||||
%1 = transform.structured.match ops{["tensor.unpack"]} in %arg1
|
||||
// expected-error @below {{requires target to map to exactly 1 packing op and 1 packed op (got 2 and 1)}}
|
||||
transform.structured.pack_transpose %0 with_compute_op(%1)
|
||||
inner_perm = [0]
|
||||
: (!pdl.operation, !pdl.operation)
|
||||
-> (!pdl.operation, !pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_single_pack_unpack(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
|
||||
%0 = arith.constant 0 : index
|
||||
%1 = tensor.empty() : tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["arith.constant"]} in %arg1
|
||||
%1 = transform.structured.match ops{["tensor.empty"]} in %arg1
|
||||
// expected-error @below {{requires target to map to a tensor.pack or tensor.unpack}}
|
||||
transform.structured.pack_transpose %0 with_compute_op(%1)
|
||||
inner_perm = [0]
|
||||
: (!pdl.operation, !pdl.operation)
|
||||
-> (!pdl.operation, !pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_linalg_target(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
|
||||
%0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
|
||||
%1 = arith.constant 0 : index
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1
|
||||
%1 = transform.structured.match ops{["arith.constant"]} in %arg1
|
||||
// expected-error @below {{requires a LinalgOp target}}
|
||||
transform.structured.pack_transpose %0 with_compute_op(%1)
|
||||
inner_perm = [0]
|
||||
: (!pdl.operation, !pdl.operation)
|
||||
-> (!pdl.operation, !pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_single_use_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
|
||||
%0 = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%1 = tensor.empty() : tensor<f32>
|
||||
%2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1
|
||||
%1 = transform.structured.match ops{["linalg.fill"]} in %arg1
|
||||
// expected-error @below {{not a single use by the LinalgOp target}}
|
||||
transform.structured.pack_transpose %0 with_compute_op(%1)
|
||||
inner_perm = [0]
|
||||
: (!pdl.operation, !pdl.operation)
|
||||
-> (!pdl.operation, !pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @not_produced_by_linalg(%source: tensor<128x256xf32>, %dest: tensor<4x16x32x16xf32>) {
|
||||
%a = tensor.pack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
|
||||
%b = tensor.unpack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %source : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%1 = tensor.empty() : tensor<f32>
|
||||
%2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<f32>) -> tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
|
||||
%1 = transform.structured.match ops{["linalg.fill"]} in %arg1
|
||||
// expected-error @below {{not produced by the LinalgOp target}}
|
||||
transform.structured.pack_transpose %0 with_compute_op(%1)
|
||||
inner_perm = [0]
|
||||
: (!pdl.operation, !pdl.operation)
|
||||
-> (!pdl.operation, !pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @no_matching_pack(%source: tensor<16xf32>) {
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%1 = tensor.empty() : tensor<4x4xf32>
|
||||
%2 = linalg.fill ins(%f0: f32) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%b = tensor.unpack %2 inner_dims_pos = [0] inner_tiles = [4] into %source : tensor<4x4xf32> -> tensor<16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
|
||||
%1 = transform.structured.match ops{["linalg.fill"]} in %arg1
|
||||
// expected-error @below {{could not find matching pack op}}
|
||||
transform.structured.pack_transpose %0 with_compute_op(%1)
|
||||
inner_perm = [0]
|
||||
: (!pdl.operation, !pdl.operation)
|
||||
-> (!pdl.operation, !pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_outer_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32> {
|
||||
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%C: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
|
||||
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
|
||||
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
|
||||
|
||||
%unpack = transform.get_consumers_of_result %1[0]
|
||||
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
|
||||
%2, %pack_2, %unpack_2 =
|
||||
// expected-error @below {{invalid outer_perm}}
|
||||
transform.structured.pack_transpose %unpack with_compute_op(%1)
|
||||
outer_perm = [1]
|
||||
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
|
||||
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_inner_perm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32> {
|
||||
%0 = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%C: tensor<?x?xf32>)
|
||||
-> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
|
||||
%1 = transform.structured.pack %0 packed_sizes = [2, 3, 4]
|
||||
: (!pdl.operation) -> (!transform.op<"linalg.generic">)
|
||||
|
||||
%unpack = transform.get_consumers_of_result %1[0]
|
||||
: (!transform.op<"linalg.generic">) -> (!transform.op<"tensor.unpack">)
|
||||
%2, %pack_2, %unpack_2 =
|
||||
// expected-error @below {{invalid inner_perm}}
|
||||
transform.structured.pack_transpose %unpack with_compute_op(%1)
|
||||
inner_perm = [1]
|
||||
: (!transform.op<"tensor.unpack">, !transform.op<"linalg.generic">)
|
||||
-> (!transform.op<"linalg.generic">, !transform.op<"tensor.pack">, !transform.op<"tensor.unpack">)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user