[mlir] Remove Linalg fusion-on-memrefs.

PSA: https://discourse.llvm.org/t/psa-retire-tileandfuselinalgops-method/63850

Differential Revision: https://reviews.llvm.org/D141807
Alexander Belyaev
2023-01-30 14:03:38 +01:00
parent db59654e23
commit dc37dc824a
15 changed files with 3 additions and 1852 deletions


@@ -1,275 +0,0 @@
//===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
namespace func {
class FuncOp;
} // namespace func
namespace linalg {
class LinalgOp;
/// A very primitive alias analysis which just records for each view, either:
/// 1. The base buffer, or
/// 2. The block argument view
/// that it indexes into.
/// This does not perform inter-block or inter-procedural analysis and assumes
/// that different block argument views do not alias.
class Aliases {
public:
/// Returns true if v1 and v2 alias.
bool alias(Value v1, Value v2) { return find(v1) == find(v2); }
private:
/// Returns the base buffer or block argument into which the view `v` aliases.
/// This lazily records the new aliases discovered while walking back the
/// use-def chain.
Value find(Value v);
DenseMap<Value, Value> aliases;
};
/// Data structure for holding a dependence graph that operates on LinalgOp and
/// views as SSA values.
class LinalgDependenceGraph {
public:
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
// TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
// need an extension to use OpResult.
struct LinalgDependenceGraphElem {
using OpView = PointerUnion<OpOperand *, Value>;
// dependentOpView may be either:
// 1. src in the case of dependencesIntoGraphs.
// 2. dst in the case of dependencesFromGraphs.
OpView dependentOpView;
// View in the op that is used to index in the graph:
// 1. src in the case of dependencesFromGraphs.
// 2. dst in the case of dependencesIntoGraphs.
OpView indexingOpView;
// Type of the dependence.
DependenceType dependenceType;
// Return the Operation that owns the operand or result represented in
// `opView`.
static Operation *getOwner(OpView opView) {
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
return operand->getOwner();
return opView.get<Value>().cast<OpResult>().getOwner();
}
// Return the operand or the result Value represented by the `opView`.
static Value getValue(OpView opView) {
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
return operand->get();
return opView.get<Value>();
}
// Return the indexing map of the operand/result in `opView` specified in
// the owning LinalgOp. If the owner is not a LinalgOp returns std::nullopt.
static std::optional<AffineMap> getIndexingMap(OpView opView) {
auto owner = dyn_cast<LinalgOp>(getOwner(opView));
if (!owner)
return std::nullopt;
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
return owner.getMatchingIndexingMap(operand);
return owner.getMatchingIndexingMap(owner.getDpsInitOperand(
opView.get<Value>().cast<OpResult>().getResultNumber()));
}
// Return the operand number if the `opView` is an OpOperand *. Otherwise
// return std::nullopt.
static std::optional<unsigned> getOperandNumber(OpView opView) {
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
return operand->getOperandNumber();
return std::nullopt;
}
// Return the result number if the `opView` is an OpResult. Otherwise return
// std::nullopt.
static std::optional<unsigned> getResultNumber(OpView opView) {
if (OpResult result = opView.dyn_cast<Value>().cast<OpResult>())
return result.getResultNumber();
return std::nullopt;
}
// Return the owner of the dependent OpView.
Operation *getDependentOp() const { return getOwner(dependentOpView); }
// Return the owner of the indexing OpView.
Operation *getIndexingOp() const { return getOwner(indexingOpView); }
// Return the operand or result stored in the dependentOpView.
Value getDependentValue() const { return getValue(dependentOpView); }
// Return the operand or result stored in the indexingOpView.
Value getIndexingValue() const { return getValue(indexingOpView); }
// If the dependent OpView is an operand, return operand number. Return
// std::nullopt otherwise.
std::optional<unsigned> getDependentOpViewOperandNum() const {
return getOperandNumber(dependentOpView);
}
// If the indexing OpView is an operand, return operand number. Return
// std::nullopt otherwise.
std::optional<unsigned> getIndexingOpViewOperandNum() const {
return getOperandNumber(indexingOpView);
}
// If the dependent OpView is a result value, return the result
// number. Return std::nullopt otherwise.
std::optional<unsigned> getDependentOpViewResultNum() const {
return getResultNumber(dependentOpView);
}
// If the dependent OpView is a result value, return the result
// number. Return std::nullopt otherwise.
std::optional<unsigned> getIndexingOpViewResultNum() const {
return getResultNumber(indexingOpView);
}
// Return the indexing map of the operand/result in the dependent OpView as
// specified in the owner of the OpView.
std::optional<AffineMap> getDependentOpViewIndexingMap() const {
return getIndexingMap(dependentOpView);
}
// Return the indexing map of the operand/result in the indexing OpView as
// specified in the owner of the OpView.
std::optional<AffineMap> getIndexingOpViewIndexingMap() const {
return getIndexingMap(indexingOpView);
}
};
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
using dependence_iterator = LinalgDependences::const_iterator;
using dependence_range = iterator_range<dependence_iterator>;
static StringRef getDependenceTypeStr(DependenceType depType);
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases,
func::FuncOp f);
LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
/// Returns the X such that op -> X is a dependence of type dt.
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const;
/// Returns the X such that X -> op is a dependence of type dt.
dependence_range getDependencesInto(Operation *dst, DependenceType dt) const;
dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence
/// relation with `srcLinalgOp`, on any view.
/// Any such operation prevents reordering.
SmallVector<Operation *, 8>
findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`.
/// Dependences are restricted to views aliasing `view`.
SmallVector<Operation *, 8> findCoveringReads(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp,
Value view) const;
/// Returns the operations that are interleaved between `srcLinalgOp` and
/// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
/// Dependences are restricted to views aliasing `view`.
SmallVector<Operation *, 8> findCoveringWrites(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp,
Value view) const;
/// Returns true if the two operations have the specified dependence from
/// `srcLinalgOp` to `dstLinalgOp`.
bool hasDependenceFrom(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
ArrayRef<DependenceType> depTypes = {
DependenceType::RAW, DependenceType::WAW}) const;
/// Returns true if the `linalgOp` has dependences into it.
bool hasDependentOperationsInto(LinalgOp linalgOp,
ArrayRef<DependenceType> depTypes = {
DependenceType::RAW,
DependenceType::WAW}) const;
/// Returns true if the `linalgOp` has dependences from it.
bool hasDependentOperationsFrom(LinalgOp linalgOp,
ArrayRef<DependenceType> depTypes = {
DependenceType::RAW,
DependenceType::WAW}) const;
/// Returns true if the `linalgOp` has dependences into or from it.
bool hasDependentOperations(LinalgOp linalgOp,
ArrayRef<DependenceType> depTypes = {
DependenceType::RAW,
DependenceType::WAW}) const;
/// Returns all operations that have a dependence into `linalgOp` of types
/// listed in `depTypes`.
SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsInto(
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
DependenceType::RAW, DependenceType::WAW}) const;
/// Returns all operations that have a dependence from `linalgOp` of types
/// listed in `depTypes`.
SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsFrom(
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
DependenceType::RAW, DependenceType::WAW}) const;
/// Returns all dependent operations (into and from) given `operation`.
SmallVector<LinalgDependenceGraphElem, 2>
getDependentOperations(LinalgOp linalgOp,
ArrayRef<DependenceType> depTypes = {
DependenceType::RAW, DependenceType::WAW}) const;
void print(raw_ostream &os) const;
void dump() const;
private:
// Keep dependences in both directions: this is not just a performance gain,
// it also reduces usage errors.
// Dependence information is stored as a map of:
// (source operation -> LinalgDependenceGraphElem)
DependenceGraph dependencesFromGraphs[DependenceType::NumTypes];
// Reverse dependence information is stored as a map of:
// (destination operation -> LinalgDependenceGraphElem)
DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes];
/// Analyses the aliasing views between `src` and `dst` and inserts the proper
/// dependences in the graph.
void addDependencesBetween(LinalgOp src, LinalgOp dst);
// Adds a new dependence unit in the proper graph.
// Uses std::pair to keep operations and view together and avoid usage errors
// related to src/dst and producer/consumer terminology in the context of
// dependences.
void addDependenceElem(DependenceType dt,
LinalgDependenceGraphElem::OpView indexingOpView,
LinalgDependenceGraphElem::OpView dependentOpView);
/// Implementation detail for findCoveringxxx.
SmallVector<Operation *, 8>
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp, Value view,
ArrayRef<DependenceType> types) const;
Aliases &aliases;
SmallVector<LinalgOp, 8> linalgOps;
DenseMap<Operation *, unsigned> linalgOpPositions;
};
} // namespace linalg
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
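For reference, a minimal C++ sketch (not part of this commit) of how the removed analysis was typically driven: build the graph once per function, then query it per op. `funcOp` is assumed to be a func::FuncOp containing linalg ops; everything else uses only the API declared in the header above.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::linalg;

// Build the dependence graph once for `funcOp`, then list the RAW consumers
// of every linalg op it contains.
static void inspectRawDependences(func::FuncOp funcOp) {
  // The alias analysis is populated lazily while the graph walks use-def
  // chains back to base buffers / block arguments.
  Aliases aliases;
  LinalgDependenceGraph graph =
      LinalgDependenceGraph::buildDependenceGraph(aliases, funcOp);
  funcOp.walk([&](LinalgOp op) {
    for (const auto &dep : graph.getDependencesFrom(
             op, LinalgDependenceGraph::DependenceType::RAW))
      llvm::dbgs() << "RAW consumer: " << *dep.getDependentOp() << "\n";
  });
}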


@@ -9,10 +9,8 @@
#ifndef MLIR_DIALECT_LINALG_UTILS_UTILS_H
#define MLIR_DIALECT_LINALG_UTILS_UTILS_H
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringSet.h"
#include <optional>
@@ -27,7 +25,6 @@ class ExtractSliceOp;
} // namespace tensor
namespace linalg {
class LinalgDependenceGraph;
//===----------------------------------------------------------------------===//
// General utilities
@@ -153,19 +150,6 @@ enum class LinalgTilingLoopType {
ParallelLoops = 2
};
/// Checks whether the specific `producer` is the last write to exactly the
/// whole `consumedView`. This checks structural dominance, that the dependence
/// is a RAW without any interleaved write to any piece of `consumedView`.
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
LinalgOp consumer, Value consumedView,
LinalgOp producer);
/// Checks whether fusing the specific `producer` of the `consumedView` is
/// feasible. This checks `producer` is the last write of `consumedView` and
/// that no interleaved dependence would be violated (RAW, WAR or WAW).
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
Value consumedView, LinalgOp producer);
/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
@@ -268,13 +252,6 @@ void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
ArrayRef<OpFoldResult> offests);
using FusableOpDependencesTy = llvm::MapVector<
Operation *,
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
FusableOpDependencesTy
findAllFusableDependences(ArrayRef<LinalgOp> ops,
const LinalgDependenceGraph &dependenceGraph);
/// A struct containing the Linalg producer before and after fusion.
/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
/// before the consumer Linalg op, until enough canonicalizations have applied.
@@ -283,14 +260,6 @@ struct FusionInfo {
LinalgOp fusedProducer;
};
/// Fuses producer into consumer if the producer is structurally feasible and
/// the fusion would not violate dependencies.
/// Implements the fusion part of the "tileAndFuse on buffers" transformation
/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
/// obtained by applying the tiling transformation).
FailureOr<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
OpOperand &consumerOpOperand,
const LinalgDependenceGraph &graph);
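For reference, a minimal C++ sketch (not part of this commit) of how the removed `fuseProducerOfBuffer` entry point was typically driven. `b`, `consumerOp`, and `graph` are assumed to be supplied by the caller, and the helper name is hypothetical.

#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::linalg;

// Try each operand of `consumerOp`; fuseProducerOfBuffer bails out unless the
// operand is produced through a memref.subview (i.e. the consumer was already
// tiled) and its producer is the last write to the consumed view.
static LogicalResult fuseFirstFusableProducer(
    OpBuilder &b, LinalgOp consumerOp, const LinalgDependenceGraph &graph) {
  for (OpOperand &opOperand : consumerOp->getOpOperands()) {
    FailureOr<FusionInfo> info = fuseProducerOfBuffer(b, opOperand, graph);
    if (failed(info))
      continue;
    llvm::dbgs() << "fused producer: " << *info->fusedProducer.getOperation()
                 << "\n";
    return success();
  }
  return failure();
}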
/// Tensor counterpart of `fuseProducerOfBuffer`.
/// This implements the fusion part of the "tileAndFuse on tensors"
/// transformation and thus requires the `consumerOpOperand` to be a


@@ -1,13 +0,0 @@
add_mlir_dialect_library(MLIRLinalgAnalysis
DependenceAnalysis.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
LINK_LIBS PUBLIC
MLIRAffineAnalysis
MLIRAnalysis
MLIRIR
MLIRLinalgDialect
MLIRMemRefDialect
)


@@ -1,366 +0,0 @@
//===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements view-based alias and dependence analyses.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-dependence-analysis"
using namespace mlir;
using namespace mlir::linalg;
using llvm::dbgs;
Value Aliases::find(Value v) {
if (v.isa<BlockArgument>())
return v;
auto it = aliases.find(v);
if (it != aliases.end()) {
assert(it->getSecond().getType().isa<BaseMemRefType>() &&
"Memref expected");
return it->getSecond();
}
while (true) {
if (v.isa<BlockArgument>())
return v;
Operation *defOp = v.getDefiningOp();
if (!defOp)
return v;
// Treat RegionBranchOpInterfaces like an allocate and don't try to follow
// the aliasing further.
if (isa<RegionBranchOpInterface>(defOp))
return v;
if (isa<bufferization::ToMemrefOp>(defOp))
return v;
if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
// Collect all memory effects on `v`.
SmallVector<MemoryEffects::EffectInstance, 1> effects;
memEffect.getEffectsOnValue(v, effects);
// If we have the 'Allocate' memory effect on `v`, then `v` should be the
// original buffer.
if (llvm::any_of(
effects, [](const MemoryEffects::EffectInstance &instance) {
return isa<MemoryEffects::Allocate>(instance.getEffect());
}))
return v;
}
if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
auto it =
aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
return it.first->second;
}
llvm::errs() << "View alias analysis reduces to: " << v << "\n";
llvm_unreachable("unsupported view alias case");
}
}
StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
switch (depType) {
case LinalgDependenceGraph::DependenceType::RAW:
return "RAW";
case LinalgDependenceGraph::DependenceType::RAR:
return "RAR";
case LinalgDependenceGraph::DependenceType::WAR:
return "WAR";
case LinalgDependenceGraph::DependenceType::WAW:
return "WAW";
default:
break;
}
llvm_unreachable("Unexpected DependenceType");
}
LinalgDependenceGraph
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, func::FuncOp f) {
SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
return LinalgDependenceGraph(aliases, linalgOps);
}
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
ArrayRef<LinalgOp> ops)
: aliases(aliases), linalgOps(ops.begin(), ops.end()) {
for (const auto &en : llvm::enumerate(linalgOps)) {
linalgOpPositions.insert(
std::make_pair(en.value().getOperation(), en.index()));
}
for (unsigned i = 0, e = ops.size(); i < e; ++i) {
for (unsigned j = i + 1; j < e; ++j) {
addDependencesBetween(ops[i], ops[j]);
}
}
}
void LinalgDependenceGraph::addDependenceElem(
DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
LinalgDependenceGraphElem::OpView dependentOpView) {
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
<< LinalgDependenceGraphElem::getValue(indexingOpView)
<< " @) -> \n\t\t("
<< LinalgDependenceGraphElem::getValue(dependentOpView)
<< " @)");
dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
.push_back(
LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
dependencesIntoGraphs[dt]
[LinalgDependenceGraphElem::getOwner(dependentOpView)]
.push_back(LinalgDependenceGraphElem{
indexingOpView, dependentOpView, dt});
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesFrom(
LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
return getDependencesFrom(src.getOperation(), dt);
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesFrom(
Operation *src, LinalgDependenceGraph::DependenceType dt) const {
auto iter = dependencesFromGraphs[dt].find(src);
if (iter == dependencesFromGraphs[dt].end())
return llvm::make_range(nullptr, nullptr);
return llvm::make_range(iter->second.begin(), iter->second.end());
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesInto(
LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
return getDependencesInto(dst.getOperation(), dt);
}
LinalgDependenceGraph::dependence_range
LinalgDependenceGraph::getDependencesInto(
Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
auto iter = dependencesIntoGraphs[dt].find(dst);
if (iter == dependencesIntoGraphs[dt].end())
return llvm::make_range(nullptr, nullptr);
return llvm::make_range(iter->second.begin(), iter->second.end());
}
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
<< " and " << *dst.getOperation() << "\n");
if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) {
if (!dstOpOperand->get().getType().isa<RankedTensorType>())
continue;
// Check if the operand is defined by the src.
auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
if (definingOp && definingOp == src)
addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
dstOpOperand);
}
for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) {
// Check if the operand is defined by the src.
auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
if (definingOp && definingOp == src) {
if (dst.isInitTensor(dstOpOperand)) {
addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
dstOpOperand);
}
addDependenceElem(DependenceType::WAW, dstOpOperand->get(),
dstOpOperand);
}
}
return;
}
assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
"unhandled dependence tracking for mixed buffer/tensor operations");
for (OpOperand *srcOpOperand : src.getDpsInitOperands()) { // W
// RAW graph
for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R
if (!dstOpOperand->get().getType().isa<MemRefType>())
continue;
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
}
// WAW graph
for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
}
for (OpOperand *srcOpOperand : src.getDpsInputOperands()) { // R
if (!srcOpOperand->get().getType().isa<MemRefType>())
continue;
// RAR graph
for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R
if (!dstOpOperand->get().getType().isa<MemRefType>())
continue;
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
}
// WAR graph
for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
}
}
SmallVector<Operation *, 8>
LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
LinalgOp dstLinalgOp) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, nullptr,
{DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
}
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, view,
{DependenceType::WAW, DependenceType::WAR});
}
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
return findOperationsWithCoveringDependences(
srcLinalgOp, dstLinalgOp, view,
{DependenceType::RAR, DependenceType::RAW});
}
SmallVector<Operation *, 8>
LinalgDependenceGraph::findOperationsWithCoveringDependences(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
ArrayRef<DependenceType> types) const {
auto *src = srcLinalgOp.getOperation();
auto *dst = dstLinalgOp.getOperation();
auto srcPos = linalgOpPositions.lookup(src);
auto dstPos = linalgOpPositions.lookup(dst);
assert(srcPos < dstPos && "expected dst after src in IR traversal order");
SmallVector<Operation *, 8> res;
// Consider an intermediate interleaved `interim` op, look for any dependence
// to an aliasing view on a src -> op -> dst path.
// TODO: we are not considering paths yet, just interleaved positions.
for (auto dt : types) {
for (auto dependence : getDependencesFrom(src, dt)) {
auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
// Skip if not interleaved.
if (interimPos >= dstPos || interimPos <= srcPos)
continue;
Value consumerView = dependence.getIndexingValue();
if (view && !aliases.alias(view, consumerView))
continue;
auto *op = dependence.getDependentOp();
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
<< getDependenceTypeStr(dt) << ": " << *src << " -> "
<< *op << " on " << consumerView);
res.push_back(op);
}
}
return res;
}
bool LinalgDependenceGraph::hasDependenceFrom(
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
for (auto dep : depTypes)
for (auto dependence : getDependencesInto(dstLinalgOp, dep))
if (dependence.getDependentOp() == srcLinalgOp)
return true;
return false;
}
bool LinalgDependenceGraph::hasDependentOperationsFrom(
LinalgOp linalgOp,
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
for (auto dep : depTypes) {
if (!getDependencesFrom(linalgOp, dep).empty())
return true;
}
return false;
}
bool LinalgDependenceGraph::hasDependentOperationsInto(
LinalgOp linalgOp,
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
for (auto dep : depTypes) {
if (!getDependencesInto(linalgOp, dep).empty())
return true;
}
return false;
}
bool LinalgDependenceGraph::hasDependentOperations(
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
return hasDependentOperationsInto(linalgOp, depTypes) ||
hasDependentOperationsFrom(linalgOp, depTypes);
}
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
LinalgDependenceGraph::getDependentOperationsInto(
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
dependentOperations;
for (auto dependenceType : depTypes) {
auto dependencies = getDependencesInto(linalgOp, dependenceType);
dependentOperations.append(dependencies.begin(), dependencies.end());
}
return dependentOperations;
}
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
LinalgDependenceGraph::getDependentOperationsFrom(
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
dependentOperations;
for (auto dependenceType : depTypes) {
auto dependencies = getDependencesFrom(linalgOp, dependenceType);
dependentOperations.append(dependencies.begin(), dependencies.end());
}
return dependentOperations;
}
/// Returns all dependent operations (into and from) given `operation`.
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
LinalgDependenceGraph::getDependentOperations(
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
getDependentOperationsInto(linalgOp, depTypes);
SmallVector<LinalgDependenceGraphElem, 2> t =
getDependentOperationsFrom(linalgOp, depTypes);
dependentOperations.append(t.begin(), t.end());
return dependentOperations;
}
void LinalgDependenceGraph::print(raw_ostream &os) const {
for (auto dt : {
LinalgDependenceGraph::DependenceType::RAW,
LinalgDependenceGraph::DependenceType::WAW,
}) {
const auto &fromGraph = dependencesFromGraphs[dt];
for (const auto &it : fromGraph) {
os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
<< ":\n";
for (const auto &dep : it.second) {
os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
}
}
}
}
void LinalgDependenceGraph::dump() const { print(llvm::errs()); }


@@ -1,4 +1,3 @@
add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(TransformOps)
add_subdirectory(Transforms)


@@ -56,7 +56,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRLinalgDialect
MLIRLinalgAnalysis
MLIRLinalgUtils
MLIRSCFDialect
MLIRSCFTransforms


@@ -12,7 +12,6 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -204,173 +203,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
return fuse(b, producerOp, fusedLoopsAndRanges);
}
// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
LinalgOp consumer) {
assert(producer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
if (producer.getNumDpsInits() != 1) {
LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
return false;
}
// Only fuse when the producer block dominates.
DominanceInfo dom(producer.getOperation());
if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
LLVM_DEBUG(
llvm::dbgs()
<< "\nNot structurally fusable (producer block does not dominate)");
return false;
}
return true;
}
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
LinalgOp consumer,
Value consumedView,
LinalgOp producer) {
assert(producer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
// Make some simple structural checks that alleviate the need for more
// complex analyses.
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
<< *producer.getOperation());
return false;
}
// Check for any interleaved write to consumedView.
if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
<< *producer.getOperation());
return false;
}
return true;
}
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
LinalgOp consumer, Value consumedView,
LinalgOp producer) {
assert(producer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(consumer.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
return false;
// Check for any fusion-preventing dependence to any shape read/written that
// would violate dependences.
if (!graph.findCoveringDependences(producer, consumer).empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "\n***Not fusable due to an interleaved dependence:\t"
<< *producer.getOperation());
return false;
}
return true;
}
/// For `consumer` with buffer semantics, find the Linalg operation on buffers
/// that is the last writer of `consumerOpOperand`. For now the fusable
/// dependence is returned as an instance of the `dependenceGraph`.
static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
const LinalgDependenceGraph &dependenceGraph) {
LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
<< consumerOpOperand.get() << " @"
<< consumerOpOperand.getOperandNumber() << " in "
<< *consumerOpOperand.getOwner() << "\n");
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
if (!consumerOp)
return failure();
// Only consider RAW and WAW atm.
for (auto depType : {
LinalgDependenceGraph::DependenceType::RAW,
LinalgDependenceGraph::DependenceType::WAW,
}) {
LLVM_DEBUG(llvm::dbgs()
<< "Dependencies into: " << *consumerOp.getOperation() << "\n");
for (auto dependence : llvm::make_filter_range(
dependenceGraph.getDependencesInto(consumerOp, depType),
[&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
<< elem.getIndexingValue() << " and "
<< elem.getDependentValue() << "\n");
Value v = elem.getIndexingValue();
std::optional<unsigned> operandNum =
elem.getIndexingOpViewOperandNum();
return isa<LinalgOp>(elem.getDependentOp()) &&
v == consumerOpOperand.get() && operandNum &&
*operandNum == consumerOpOperand.getOperandNumber();
})) {
// Consumer consumes this view, `isStructurallyFusableProducer` also
// checks whether it is a strict subview of the producer view.
auto producer = cast<LinalgOp>(dependence.getDependentOp());
LLVM_DEBUG(llvm::dbgs()
<< "\n"
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
<< "producer: " << *dependence.getDependentOp()
<< " view: " << dependence.getDependentValue() << "\n");
// If the producer and consumer have tensor semantics, the only dependence
// between them is through a RAW dependence and they are fusable by
// construction. For buffer semantics, additional checks are needed.
if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
producer))
return dependence;
if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
assert(dependence.dependenceType ==
LinalgDependenceGraph::DependenceType::RAW);
return dependence;
}
}
}
return failure();
}
FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
const LinalgDependenceGraph &graph) {
std::optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
fusableDependence = findFusableProducer(consumerOpOperand, graph);
if (!fusableDependence)
return failure();
LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
if (!producerOp)
return failure();
// If producer is already in the same block as consumer, we are done.
if (consumerOpOperand.get().getParentBlock() ==
fusableDependence->getDependentValue().getParentBlock())
return failure();
std::optional<AffineMap> producerMap =
fusableDependence->getDependentOpViewIndexingMap();
if (!producerMap)
return failure();
// Must be a subview or an extract_slice to guarantee there are loops we can
// fuse into.
auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
if (!subView) {
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
return failure();
}
// Fuse `producer` just before `consumer`.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(consumerOpOperand.getOwner());
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
<< *consumerOpOperand.getOwner() << "\n");
auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
return FusionInfo{producerOp, fusedProducer};
}
/// Walk back use-def chain through scf::For yields.
/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp


@@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"


@@ -15,7 +15,6 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"


@@ -1,51 +0,0 @@
// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s
func.func @f1(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>, %B: memref<?x?xf32, strided<[?, 1], offset: ?>>, %C: memref<?x?xf32, strided<[?, 1], offset: ?>>, %D: memref<?x?xf32, strided<[?, 1], offset: ?>>, %E: memref<?x?xf32, strided<[?, 1], offset: ?>>) -> memref<?x?xf32, strided<[?, 1], offset: ?>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c40 = arith.constant 40 : index
%c30 = arith.constant 30 : index
%c20 = arith.constant 20 : index
%0 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
%1 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
linalg.matmul ins(%A, %B: memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>)
outs(%C: memref<?x?xf32, strided<[?, 1], offset: ?>>)
scf.for %arg5 = %c0 to %0 step %c20 {
scf.for %arg6 = %c0 to %2 step %c30 {
scf.for %arg7 = %c0 to %1 step %c40 {
%5 = memref.subview %C[%arg5, %arg7][%c20, %c40][%c1, %c1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %D[%arg7, %arg6][%c40, %c30][%c1, %c1]: memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c20, %c40][%c1, %c1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%9 = memref.dim %5, %c0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%10 = memref.dim %5, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
%11 = memref.dim %7, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
scf.for %arg8 = %c0 to %9 step %c2 {
scf.for %arg9 = %c0 to %11 step %c3 {
scf.for %arg10 = %c0 to %10 step %c4 {
%14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%14, %16: memref<?x?xf32, strided<[?, ?], offset: ?>>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%17: memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
}
}
}
return %E : memref<?x?xf32, strided<[?, 1], offset: ?>>
}
// CHECK-LABEL: func @f1
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: linalg.matmul
// CHECK: linalg.matmul


@@ -1,160 +0,0 @@
// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
func.func @fuse_indexed_consumer(%A: memref<?x?xf32>,
%B: memref<?x?xf32>,
%C: memref<?x?xf32>,
%D: memref<?x?xf32>) {
linalg.generic #pointwise_2d_trait
ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C : memref<?x?xf32>) {
^bb0(%e: f32, %arg5: f32, %arg6: f32):
%2 = arith.addf %e, %arg5 : f32
linalg.yield %2 : f32
}
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c25 = arith.constant 25 : index
%c10 = arith.constant 10 : index
%0 = memref.dim %C, %c0 : memref<?x?xf32>
%1 = memref.dim %C, %c1 : memref<?x?xf32>
%2 = memref.dim %D, %c0 : memref<?x?xf32>
%3 = memref.dim %D, %c1 : memref<?x?xf32>
scf.for %arg2 = %c0 to %0 step %c10 {
scf.for %arg3 = %c0 to %1 step %c25 {
%4 = memref.subview %C[%arg2, %arg3][%c10, %c25][%c1, %c1] :
memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%5 = memref.subview %D[%arg2, %arg3][%c10, %c25][%c1, %c1] :
memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.generic {
indexing_maps = [#id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]}
ins(%4 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%5 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
^bb0(%arg4: f32, %arg5: f32):
%idx0 = linalg.index 0 : index
%idx1 = linalg.index 1 : index
%6 = arith.addi %idx0, %arg2 : index
%7 = arith.addi %idx1, %arg3 : index
%8 = arith.index_cast %6 : index to i32
%9 = arith.sitofp %8 : i32 to f32
%10 = arith.index_cast %7 : index to i32
%11 = arith.sitofp %10 : i32 to f32
%12 = arith.addf %9, %11 : f32
%13 = arith.addf %12, %arg4 : f32
linalg.yield %13 : f32
}
}
}
return
}
// CHECK-LABEL: func @fuse_indexed_consumer
// CHECK: scf.for
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK-NOT: affine.apply
// CHECK: arith.addf
// CHECK: linalg.generic
// CHECK: arith.index_cast
// -----
func.func @fuse_indexed_producer(%A: memref<?x?xindex>,
%B: memref<?x?xindex>) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c25 = arith.constant 25 : index
%c10 = arith.constant 10 : index
linalg.generic {
indexing_maps = [affine_map<(i, j) -> (j, i)>],
iterator_types = ["parallel", "parallel"]}
outs(%A : memref<?x?xindex>) {
^bb0(%a: index):
%idx0 = linalg.index 0 : index
%idx1 = linalg.index 1 : index
%0 = arith.addi %idx0, %idx1 : index
linalg.yield %0 : index
}
%A_X = memref.dim %A, %c0 : memref<?x?xindex>
%A_Y = memref.dim %A, %c1 : memref<?x?xindex>
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%A_X, %A_Y) step (%c10, %c25) {
%A_view = memref.subview %A[%arg2, %arg3][%c10, %c25][%c1, %c1] :
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
%B_view = memref.subview %B[%arg2, %arg3][%c10, %c25][%c1, %c1] :
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
ins(%A_view : memref<?x?xindex, strided<[?, ?], offset: ?>>)
outs(%B_view : memref<?x?xindex, strided<[?, ?], offset: ?>>) {
^bb0(%a: index, %b: index):
linalg.yield %a : index
}
}
return
}
// CHECK: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: func @fuse_indexed_producer
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) =
// CHECK: linalg.generic
// CHECK: [[idx0:%.*]] = linalg.index 0 : index
// CHECK: [[i_new:%.*]] = affine.apply [[$MAP]]([[idx0]], [[J]])
// CHECK: [[idx1:%.*]] = linalg.index 1 : index
// CHECK: [[j_new:%.*]] = affine.apply [[$MAP]]([[idx1]], [[I]])
// CHECK: [[sum:%.*]] = arith.addi [[i_new]], [[j_new]] : index
// CHECK: linalg.yield [[sum]] : index
// CHECK: linalg.generic
// -----
func.func @fuse_indexed_producer_tiled_second_dim_only(%A: memref<?x?xindex>,
%B: memref<?x?xindex>) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c25 = arith.constant 25 : index
linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
outs(%A : memref<?x?xindex>) {
^bb0(%a: index):
%idx0 = linalg.index 0 : index
%idx1 = linalg.index 1 : index
%0 = arith.addi %idx0, %idx1 : index
linalg.yield %0 : index
}
%A_X = memref.dim %A, %c0 : memref<?x?xindex>
%A_Y = memref.dim %A, %c1 : memref<?x?xindex>
scf.parallel (%arg3) = (%c0) to (%A_Y) step (%c25) {
%A_view = memref.subview %A[%c0, %arg3][%A_X, %c25][%c1, %c1] :
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
%B_view = memref.subview %B[%c0, %arg3][%A_X, %c25][%c1, %c1] :
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
linalg.generic {
indexing_maps = [affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>],
iterator_types = ["parallel", "parallel"]}
ins(%A_view : memref<?x?xindex, strided<[?, ?], offset: ?>>)
outs(%B_view : memref<?x?xindex, strided<[?, ?], offset: ?>>) {
^bb0(%a: index, %b: index):
linalg.yield %a : index
}
}
return
}
// CHECK: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-LABEL: func @fuse_indexed_producer_tiled_second_dim_only
// CHECK: scf.parallel ([[J:%.*]]) =
// CHECK: linalg.generic
// CHECK: [[idx0:%.*]] = linalg.index 0 : index
// CHECK: [[idx1:%.*]] = linalg.index 1 : index
// CHECK: [[j_new:%.*]] = affine.apply [[$MAP]]([[idx1]], [[J]])
// CHECK: [[sum:%.*]] = arith.addi [[idx0]], [[j_new]] : index
// CHECK: linalg.yield [[sum]] : index
// CHECK: linalg.generic


@@ -1,745 +0,0 @@
// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
func.func @f1(%A: memref<?x?xf32, strided<[?, 1], offset: 0>>,
%B: memref<?x?xf32, strided<[?, 1], offset: 0>>,
%C: memref<?x?xf32, strided<[?, 1], offset: 0>>,
%D: memref<?x?xf32, strided<[?, 1], offset: 0>>,
%E: memref<?x?xf32, strided<[?, 1], offset: 0>>
) -> memref<?x?xf32, strided<[?, 1], offset: 0>> {
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.dim %A, %c0 : memref<?x?xf32, strided<[?, 1], offset: 0>>
%1 = memref.dim %A, %c1 : memref<?x?xf32, strided<[?, 1], offset: 0>>
%2 = memref.dim %B, %c1 : memref<?x?xf32, strided<[?, 1], offset: 0>>
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, 1], offset: 0>>,
memref<?x?xf32, strided<[?, 1], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, 1], offset: 0>>)
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
%5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, 1], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, 1], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, 1], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8: memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, 1], offset: 0>>
}
// CHECK-LABEL: func @f1
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// -----
func.func @f2(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C: memref<?x?xf32, strided<[?, ?], offset: 0>>)
%0 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
%5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-LABEL: func @f2
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %c0{{[_0-9]*}} : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[C_1:.*]] = memref.dim %[[C]], %c1{{[_0-9]*}} : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %c1{{[_0-9]*}} : memref<?x?xf32, strided<[?, ?]>>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// -----
func.func @f3(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
%0 = memref.dim %D, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%2 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
%5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-LABEL: func @f3
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// -----
func.func @f4(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%D : memref<?x?xf32, strided<[?, ?], offset: 0>>)
%0 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
%5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-LABEL: func @f4
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// Fuse D then fuse C; no false dependence prevents it.
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// -----
func.func @f5(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%0 = memref.dim %B, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %D, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
linalg.matmul ins(%C, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%D : memref<?x?xf32, strided<[?, ?], offset: 0>>)
scf.for %arg5 = %c0 to %1 step %c2 {
scf.for %arg6 = %c0 to %0 step %c3 {
scf.for %arg7 = %c0 to %2 step %c4 {
%5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-DAG: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
// CHECK-DAG: #[[BOUND_2_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 2)>
// CHECK-DAG: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
// CHECK: func @f5
// CHECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A_0:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}}
// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
// CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]]
// CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]]
// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_2_MAP_2]](%[[I]])[%[[A_0]], %[[C_0]]]
// CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
// CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]]
// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
// CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]]
// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4]
// CHECK: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
// CHECK: %[[BOUND_4_B1:.*]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[B_1]]]
// CHECK: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
// CHECK: %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_C0]], %[[BOUND_4_B1]]]
// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]]
// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]]
// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----
#map0 = affine_map<(d0) -> (d0 + 2)>
#map1 = affine_map<(d0) -> (d0 + 4)>
#map2 = affine_map<(d0) -> (d0 + 3)>
func.func @f6(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%0 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
linalg.matmul ins(%A, %C : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%E : memref<?x?xf32, strided<[?, ?], offset: 0>>)
%1 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
scf.for %arg5 = %c0 to %1 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %0 step %c4 {
%3 = affine.apply #map0(%arg5)
%4 = affine.apply #map1(%arg7)
%5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%6 = affine.apply #map2(%arg6)
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-LABEL: func @f6
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// Fuse the producer of E (WAW) then the producer of C (WAR).
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// -----
func.func @f7(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%0 = memref.dim %A, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %A, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%2 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%3 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%4 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
linalg.matmul ins(%A, %C : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%E : memref<?x?xf32, strided<[?, ?], offset: 0>>)
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
%7 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%9 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%7, %9 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%10 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
scf.for %arg5 = %c0 to %3 step %c2 {
scf.for %arg6 = %c0 to %4 step %c3 {
scf.for %arg7 = %c0 to %2 step %c4 {
%7 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%9 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%7, %9 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%10 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-LABEL: func @f7
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[A_0:.*]] = memref.dim %[[A]], %[[C0:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[A_1:.*]] = memref.dim %[[A]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: linalg.matmul
// CHECK-NOT: linalg.matmul
// -----
#map0 = affine_map<(d0) -> (d0 + 2)>
#map1 = affine_map<(d0) -> (d0 + 4)>
#map2 = affine_map<(d0) -> (d0 + 3)>
func.func @f8(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%0 = memref.dim %A, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %A, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
linalg.matmul ins(%A, %C : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%D : memref<?x?xf32, strided<[?, ?], offset: 0>>)
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
%3 = affine.apply #map0(%arg5)
%4 = affine.apply #map1(%arg7)
%5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%6 = affine.apply #map2(%arg6)
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
}
// CHECK-LABEL: func @f8
// CHECK: (%[[A:.*]]: memref{{.*}}, %[[B:.*]]: memref{{.*}}, %[[C:.*]]: memref{{.*}}, %[[D:.*]]: memref{{.*}}, %[[E:.*]]: memref{{.*}})
// CHECK: linalg.matmul
// CHECK: linalg.matmul
// CHECK: scf.for
// CHECK: scf.for
// CHECK: scf.for
// CHECK: linalg.matmul
// CHECK-NOT: linalg.matmul
// -----
#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
func.func @pointwise(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
linalg.generic #pointwise_2d_trait
ins(%A, %A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%B : memref<?x?xf32, strided<[?, ?], offset: 0>>) {
^bb0(%E: f32, %arg5: f32, %arg6: f32):
%2 = arith.addf %E, %arg5 : f32
linalg.yield %2 : f32
}
%0 = memref.dim %B, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
%1 = memref.dim %B, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
scf.for %arg4 = %c0 to %0 step %c2 {
scf.for %arg5 = %c0 to %1 step %c3 {
%4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.generic #pointwise_2d_trait
ins(%4, %5: memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%6 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
%7 = arith.mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
}
}
}
return
}
// CHECK-LABEL: func @pointwise
// CHECK: scf.for
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.generic
// CHECK: arith.mulf
// -----
#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
func.func @pointwise_no_view(%M: index, %N: index) {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c3 = arith.constant 3 : index
%c2 = arith.constant 2 : index
%A = memref.alloc (%M, %N): memref<?x?xf32>
%B = memref.alloc (%M, %N): memref<?x?xf32>
%C = memref.alloc (%M, %N): memref<?x?xf32>
%D = memref.alloc (%M, %N): memref<?x?xf32>
%E = memref.alloc (%M, %N): memref<?x?xf32>
linalg.generic #pointwise_2d_trait
ins(%A, %A : memref<?x?xf32>, memref<?x?xf32>)
outs(%B : memref<?x?xf32>) {
^bb0(%e: f32, %arg5: f32, %arg6: f32):
%2 = arith.addf %e, %arg5 : f32
linalg.yield %2 : f32
}
%0 = memref.dim %B, %c0 : memref<?x?xf32>
%1 = memref.dim %B, %c1 : memref<?x?xf32>
scf.for %arg4 = %c0 to %0 step %c2 {
scf.for %arg5 = %c0 to %1 step %c3 {
%4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.generic #pointwise_2d_trait
ins(%4, %5: memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%6 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
%7 = arith.mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
}
}
}
return
}
// CHECK-LABEL: func @pointwise_no_view
// CHECK: scf.for
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: linalg.generic
// CHECK: arith.mulf
// -----
#map0 = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>,
%arg2: memref<100x10xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.alloc() {temp = true} : memref<100x10xf32>
linalg.generic {
indexing_maps = [#map0, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%arg1 : memref<100xf32>)
outs(%0 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32):
linalg.yield %arg3 : f32
}
%1 = memref.alloc() {temp = true} : memref<100x10xf32>
linalg.generic {
indexing_maps = [#map1, #map1, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %0: memref<100x10xf32>, memref<100x10xf32>)
outs(%1 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%2 = arith.subf %arg3, %arg4 : f32
linalg.yield %2 : f32
}
memref.dealloc %0 : memref<100x10xf32>
%2 = memref.dim %1, %c0 : memref<100x10xf32>
%3 = memref.dim %1, %c1 : memref<100x10xf32>
%4 = memref.dim %arg2, %c0 : memref<100x10xf32>
%5 = memref.dim %arg2, %c1 : memref<100x10xf32>
scf.for %i = %c0 to %2 step %c1 {
scf.for %j = %c0 to %3 step %c1 {
%6 = memref.subview %1[%i, %j][%c1, %c1][%c1, %c1] :
memref<100x10xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
%7 = memref.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] :
memref<100x10xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.generic {
indexing_maps = [#map1, #map1],
iterator_types = ["parallel", "parallel"]}
ins(%6 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%7 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
^bb0(%arg3: f32, %arg4: f32):
%8 = math.exp %arg3 : f32
linalg.yield %8 : f32
}
}
}
memref.dealloc %1 : memref<100x10xf32>
return
}
// CHECK-LABEL: func @fusion
// CHECK-NOT: linalg.generic
// CHECK: scf.for
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: linalg.yield
// CHECK: linalg.generic
// CHECK: arith.subf
// CHECK: linalg.yield
// CHECK: linalg.generic
// CHECK: exp
// CHECK: linalg.yield
// -----
#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
#map3 = affine_map<(d0)[s0, s1] -> (s0 + 1, -d0 + s0 + s1)>
#map4 = affine_map<(d0)[s0, s1] -> (s0 + 2, -d0 + s0 + s1)>
func.func @fill_and_conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
%cst = arith.constant 0.000000e+00 : f32
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
linalg.fill ins(%cst : f32) outs(%arg0 : memref<?x?xf32>)
%2 = memref.dim %arg1, %c0 : memref<?x?xf32>
%3 = memref.dim %arg1, %c1 : memref<?x?xf32>
%4 = memref.dim %arg2, %c0 : memref<?x?xf32>
%5 = memref.dim %arg2, %c1 : memref<?x?xf32>
scf.for %arg3 = %c0 to %4 step %c2 {
scf.for %arg4 = %c0 to %5 step %c3 {
%6 = affine.min #map3(%arg3)[%2, %4]
%7 = affine.min #map4(%arg4)[%3, %5]
%8 = memref.subview %arg0[%arg3, %arg4] [%6, %7] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%9 = affine.min #map0(%arg3)[%4]
%10 = affine.min #map1(%arg4)[%5]
%11 = memref.subview %arg2[%arg3, %arg4] [%9, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
linalg.conv_2d ins(%8, %arg1 : memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32>) outs(%11 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
}
}
return
}
// CHECK-LABEL: func @fill_and_conv
// CHECK: scf.for
// CHECK: scf.for
// CHECK: linalg.fill
// CHECK: linalg.conv_2d
// -----
// Test that different allocation-like ops are recognized and properly handled.
func.func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%A = memref.alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, strided<[?, ?], offset: 0>>
%B = memref.alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, strided<[?, ?], offset: 0>>
%C = memref.alloc(%dim, %dim)[%s0, %s1] : memref<?x?xf32, strided<[?, ?], offset: 0>>
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
memref<?x?xf32, strided<[?, ?], offset: 0>>)
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
scf.for %i = %c0 to %dim step %c2 {
scf.for %j = %c0 to %dim step %c3 {
scf.for %k = %c0 to %dim step %c4 {
%0 = memref.subview %A[%i, %k][%c2, %c4][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%1 = memref.subview %B[%k, %j][%c4, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
%2 = memref.subview %C[%i, %j][%c2, %c3][%c1, %c1] :
memref<?x?xf32, strided<[?, ?], offset: 0>> to
memref<?x?xf32, strided<[?, ?], offset: ?>>
linalg.matmul ins(%0, %1 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
memref<?x?xf32, strided<[?, ?], offset: ?>>)
outs(%2 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
}
}
}
return
}
// CHECK-LABEL: func @accept_different_alloc_ops
// CHECK-COUNT-3: scf.for
// CHECK-COUNT-2: linalg.matmul

View File

@@ -12,7 +12,6 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
@@ -39,22 +38,9 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
     for (OpOperand &opOperand : linalgOp->getOpOperands()) {
-      if (opOperand.get().getType().isa<MemRefType>()) {
-        // TODO: LinalgDependenceGraph should be able to update itself.
-        // The current naive and expensive reconstruction of the graph should be
-        // removed.
-        linalg::Aliases aliases;
-        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        auto info = fuseProducerOfBuffer(b, opOperand, graph);
-        if (failed(info))
-          continue;
-        auto *originalOp = info->originalProducer.getOperation();
-        eraseSet.insert(originalOp);
-        auto *originalOpInLinalgOpsVector =
-            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
-        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
-        changed = true;
-      } else if (opOperand.get().getType().isa<RankedTensorType>()) {
+      if (opOperand.get().getType().isa<MemRefType>())
+        continue;
+      if (opOperand.get().getType().isa<RankedTensorType>()) {
         // Tile and Fuse tensor input.
         if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
           continue;

View File

@@ -8384,25 +8384,6 @@ gentbl_cc_library(
deps = [":PassBaseTdFiles"],
)
cc_library(
name = "LinalgAnalysis",
srcs = glob([
"lib/Dialect/Linalg/Analysis/*.cpp",
"lib/Dialect/Linalg/Analysis/*.h",
]),
hdrs = glob([
"include/mlir/Dialect/Linalg/Analysis/*.h",
]),
includes = ["include"],
deps = [
":BufferizationDialect",
":FuncDialect",
":IR",
":LinalgDialect",
"//llvm:Support",
],
)
cc_library(
name = "LinalgUtils",
srcs = glob([
@@ -8422,7 +8403,6 @@ cc_library(
":DialectUtils",
":FuncDialect",
":IR",
":LinalgAnalysis",
":LinalgDialect",
":MemRefDialect",
":Pass",
@@ -8462,7 +8442,6 @@ cc_library(
":FuncDialect",
":FuncTransforms",
":IR",
":LinalgAnalysis",
":LinalgDialect",
":LinalgPassIncGen",
":LinalgStructuredOpsIncGen",

View File

@@ -564,7 +564,6 @@ cc_library(
"//mlir:FuncTransforms",
"//mlir:GPUDialect",
"//mlir:IR",
"//mlir:LinalgAnalysis",
"//mlir:LinalgDialect",
"//mlir:LinalgTransforms",
"//mlir:LinalgUtils",