For consumer fusion cases of this form
```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {
tensor.parallel_insert_slice ... into %arg0
tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```
the current consumer fusion, which handles one slice at a time, cannot fuse
the consumer into the loop, since fusing along one slice would create an
SSA violation on the other use of the `scf.forall` results. The solution is
to let consumer fusion consider multiple slices at once. This
PR changes the `TilingInterface` methods related to consumer fusion,
i.e.
- `getTiledImplementationFromOperandTile`
- `getIterationDomainFromOperandTile`
to allow fusion while considering multiple operands. It is up to the
`TilingInterface` implementation to return an error if a list of tiles
of the operands cannot result in a consistent implementation of the
tiled operation.
The Linalg operation implementation of `TilingInterface` has been
modified to account for these changes, so that operand tiles which
result in a consistent tiling implementation are handled.
---------
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
130 lines
5.4 KiB
TableGen
130 lines
5.4 KiB
TableGen
//===- TestTilingInterfaceTransformOps.td -----------------*- tablegen -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef TEST_TILINGINTERFACE_TRANSFORM_OPS
|
|
#define TEST_TILINGINTERFACE_TRANSFORM_OPS
|
|
|
|
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
|
|
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
|
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
|
include "mlir/Dialect/Transform/IR/TransformTypes.td"
|
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
|
include "mlir/IR/OpBase.td"
|
|
|
|
// The operations in this file are meant for testing the tiling interface
// transformations using scf operations. Over time these test operations
// might become useful transformations in their own right. Move them over
// as transform ops in the main repo (and find a proper place for them).
|
|
|
|
// Test transform op: tile the target, greedily fuse its producers, and yield
// some of the fused producers from the generated loops.
//
// NOTE(review): the result name `transfomed` (sic) is intentionally left
// unchanged; ODS derives the C++ accessor name from it, so a rename would
// presumably require matching changes on the C++ side.
def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     DeclareOpInterfaceMethods<TransformOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let arguments = (ins
    TransformHandleTypeInterface:$target,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
    DefaultValuedAttr<BoolAttr, "false">:$use_forall
  );

  let results = (outs
    TransformHandleTypeInterface:$transfomed,
    Variadic<TransformHandleTypeInterface>:$loops
  );

  let assemblyFormat = [{
    $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
    (`use_forall` $use_forall^)? attr-dict
    `:` functional-type(operands, results)
  }];

  let description = [{
    Tiles every operation referenced by the `target` handle and then greedily
    fuses their producers. The `tile_sizes`, `tile_interchange` and
    `use_forall` attributes configure the transformation. For testing
    purposes, some of the fused producers are additionally yielded from the
    generated loops.

    On success, returns handles to the tiled operations and to the generated
    loops. Emits a definite failure if tiling fails.
  }];
}
|
|
|
|
// Test transform op for consumer fusion. It accepts several target handles so
// that a consumer using multiple results of a loop can be fused while all the
// corresponding slices are considered at once.
//
// NOTE(review): `targets` are presumably the tensor.(parallel_)insert_slice
// ops inside `loops` that produce the consumed values -- confirm against the
// C++ `apply()` implementation.
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
    [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Fuses the consumer of the slices pointed to by the `targets` handles into
    the loop nest pointed to by the `loops` handles, using the options
    provided as attributes.

    All targets are considered together. This allows fusing a consumer that
    uses several results of the same loop, for example both results of an
    `scf.forall`. In that case, fusing along one slice at a time would leave
    the other use outside the loop and violate SSA dominance.

    The `num_consumer_to_fuse` attribute (default: 1) sets the number of
    consumers to fuse. Returns handles to the consumer (`consumer`) and to
    its fused counterpart (`fused_consumer`).
  }];

  // `targets` and `loops` are both variadic, hence AttrSizedOperandSegments.
  let arguments = (ins
    Variadic<TransformHandleTypeInterface>:$targets,
    Variadic<TransformHandleTypeInterface>:$loops,
    DefaultValuedAttr<I32Attr, "1">:$num_consumer_to_fuse);
  let results = (outs TransformHandleTypeInterface:$consumer,
                      TransformHandleTypeInterface:$fused_consumer);

  let assemblyFormat = [{
    $targets `in` `(` $loops `)`
    (`num_consumer_to_fuse` `=` $num_consumer_to_fuse^)?
    attr-dict `:` functional-type(operands, results)
  }];
}
|
|
|
|
// Test transform op: tile the target using TilingInterface with `scf.forall`
// as the loop construct.
def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Test operation used to test tiling using TilingInterface and scf.forall
    for the loop constructs. This is similar to
    `transform.structured.tile_using_forall`. Use of this operation is an
    intermediate state and will be replaced in due course with either
    `transform.structured.tile_using_for` or
    `transform.structured.tile_using_forall`.

    The `tile_sizes` attribute gives the tile sizes. The optional
    `interchange` attribute gives the loop interchange. The optional
    `mapping` attribute gives the device mapping for the generated loop.

    On success returns the tiled operations as well as generated loops. Emits
    a definite failure if tiling fails.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target,
                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
                   DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$interchange,
                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
  let results = (outs TransformHandleTypeInterface:$tiled_op,
                      Variadic<TransformHandleTypeInterface>:$loops);

  let assemblyFormat = [{
    $target ($tile_sizes^)? (`interchange` `=` $interchange^)?
    (`mapping` `=` $mapping^)?
    attr-dict `:` functional-type(operands, results)
  }];
}
|
|
|
|
// Test transform op: tile the root op with `scf.forall` and greedily fuse its
// producers into the generated loop.
def TestFuseUsingForallOp : Op<Transform_Dialect, "test.fuse_using_forall",
    [DeclareOpInterfaceMethods<TransformOpInterface>,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Test operation to tile the operation pointed to by the `root_op` handle
    and fuse its producers greedily using the options provided as attributes
    (`tile_sizes`, `interchange` and `mapping`). This operation uses
    scf.forall for the loop construct.

    Returns handles to the tiled operations and to the generated loops.
  }];
  let arguments = (ins TransformHandleTypeInterface:$root_op,
                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
                   DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$interchange,
                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
  let results = (outs TransformHandleTypeInterface:$tiled_ops,
                      Variadic<TransformHandleTypeInterface>:$loops);

  // Unlike `test.tile_using_forall`, the `interchange` keyword is NOT
  // followed by `=` here. The format is kept as-is so that existing test
  // inputs continue to parse.
  let assemblyFormat = [{
    $root_op ($tile_sizes^)? (`interchange` $interchange^)?
    (`mapping` `=` $mapping^)?
    attr-dict `:` functional-type(operands, results)
  }];
}
|
|
|
|
#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS
|