[mlir] Remove Linalg fusion-on-memrefs.
PSA: https://discourse.llvm.org/t/psa-retire-tileandfuselinalgops-method/63850 Differential Revision: https://reviews.llvm.org/D141807
This commit is contained in:
@@ -1,275 +0,0 @@
|
||||
//===- DependenceAnalysis.h - Dependence analysis on SSA views --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
|
||||
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
namespace func {
|
||||
class FuncOp;
|
||||
} // namespace func
|
||||
|
||||
namespace linalg {
|
||||
|
||||
class LinalgOp;
|
||||
|
||||
/// A very primitive alias analysis which just records for each view, either:
|
||||
/// 1. The base buffer, or
|
||||
/// 2. The block argument view
|
||||
/// that it indexes into.
|
||||
/// This does not perform inter-block or inter-procedural analysis and assumes
|
||||
/// that different block argument views do not alias.
|
||||
class Aliases {
|
||||
public:
|
||||
/// Returns true if v1 and v2 alias.
|
||||
bool alias(Value v1, Value v2) { return find(v1) == find(v2); }
|
||||
|
||||
private:
|
||||
/// Returns the base buffer or block argument into which the view `v` aliases.
|
||||
/// This lazily records the new aliases discovered while walking back the
|
||||
/// use-def chain.
|
||||
Value find(Value v);
|
||||
|
||||
DenseMap<Value, Value> aliases;
|
||||
};
|
||||
|
||||
/// Data structure for holding a dependence graph that operates on LinalgOp and
|
||||
/// views as SSA values.
|
||||
class LinalgDependenceGraph {
|
||||
public:
|
||||
enum DependenceType { RAR = 0, RAW, WAR, WAW, NumTypes };
|
||||
// TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
|
||||
// need an extension to use OpResult.
|
||||
struct LinalgDependenceGraphElem {
|
||||
using OpView = PointerUnion<OpOperand *, Value>;
|
||||
// dependentOpView may be either:
|
||||
// 1. src in the case of dependencesIntoGraphs.
|
||||
// 2. dst in the case of dependencesFromDstGraphs.
|
||||
OpView dependentOpView;
|
||||
// View in the op that is used to index in the graph:
|
||||
// 1. src in the case of dependencesFromDstGraphs.
|
||||
// 2. dst in the case of dependencesIntoGraphs.
|
||||
OpView indexingOpView;
|
||||
// Type of the dependence.
|
||||
DependenceType dependenceType;
|
||||
|
||||
// Return the Operation that owns the operand or result represented in
|
||||
// `opView`.
|
||||
static Operation *getOwner(OpView opView) {
|
||||
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
|
||||
return operand->getOwner();
|
||||
return opView.get<Value>().cast<OpResult>().getOwner();
|
||||
}
|
||||
// Return the operand or the result Value represented by the `opView`.
|
||||
static Value getValue(OpView opView) {
|
||||
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
|
||||
return operand->get();
|
||||
return opView.get<Value>();
|
||||
}
|
||||
// Return the indexing map of the operand/result in `opView` specified in
|
||||
// the owning LinalgOp. If the owner is not a LinalgOp returns std::nullopt.
|
||||
static std::optional<AffineMap> getIndexingMap(OpView opView) {
|
||||
auto owner = dyn_cast<LinalgOp>(getOwner(opView));
|
||||
if (!owner)
|
||||
return std::nullopt;
|
||||
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
|
||||
return owner.getMatchingIndexingMap(operand);
|
||||
return owner.getMatchingIndexingMap(owner.getDpsInitOperand(
|
||||
opView.get<Value>().cast<OpResult>().getResultNumber()));
|
||||
}
|
||||
// Return the operand number if the `opView` is an OpOperand *. Otherwise
|
||||
// return std::nullopt.
|
||||
static std::optional<unsigned> getOperandNumber(OpView opView) {
|
||||
if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
|
||||
return operand->getOperandNumber();
|
||||
return std::nullopt;
|
||||
}
|
||||
// Return the result number if the `opView` is an OpResult. Otherwise return
|
||||
// std::nullopt.
|
||||
static std::optional<unsigned> getResultNumber(OpView opView) {
|
||||
if (OpResult result = opView.dyn_cast<Value>().cast<OpResult>())
|
||||
return result.getResultNumber();
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// Return the owner of the dependent OpView.
|
||||
Operation *getDependentOp() const { return getOwner(dependentOpView); }
|
||||
|
||||
// Return the owner of the indexing OpView.
|
||||
Operation *getIndexingOp() const { return getOwner(indexingOpView); }
|
||||
|
||||
// Return the operand or result stored in the dependentOpView.
|
||||
Value getDependentValue() const { return getValue(dependentOpView); }
|
||||
|
||||
// Return the operand or result stored in the indexingOpView.
|
||||
Value getIndexingValue() const { return getValue(indexingOpView); }
|
||||
|
||||
// If the dependent OpView is an operand, return operand number. Return
|
||||
// std::nullopt otherwise.
|
||||
std::optional<unsigned> getDependentOpViewOperandNum() const {
|
||||
return getOperandNumber(dependentOpView);
|
||||
}
|
||||
|
||||
// If the indexing OpView is an operand, return operand number. Return
|
||||
// std::nullopt otherwise.
|
||||
std::optional<unsigned> getIndexingOpViewOperandNum() const {
|
||||
return getOperandNumber(indexingOpView);
|
||||
}
|
||||
|
||||
// If the dependent OpView is a result value, return the result
|
||||
// number. Return std::nullopt otherwise.
|
||||
std::optional<unsigned> getDependentOpViewResultNum() const {
|
||||
return getResultNumber(dependentOpView);
|
||||
}
|
||||
|
||||
// If the dependent OpView is a result value, return the result
|
||||
// number. Return std::nullopt otherwise.
|
||||
std::optional<unsigned> getIndexingOpViewResultNum() const {
|
||||
return getResultNumber(indexingOpView);
|
||||
}
|
||||
|
||||
// Return the indexing map of the operand/result in the dependent OpView as
|
||||
// specified in the owner of the OpView.
|
||||
std::optional<AffineMap> getDependentOpViewIndexingMap() const {
|
||||
return getIndexingMap(dependentOpView);
|
||||
}
|
||||
|
||||
// Return the indexing map of the operand/result in the indexing OpView as
|
||||
// specified in the owner of the OpView.
|
||||
std::optional<AffineMap> getIndexingOpViewIndexingMap() const {
|
||||
return getIndexingMap(indexingOpView);
|
||||
}
|
||||
};
|
||||
using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
|
||||
using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
|
||||
using dependence_iterator = LinalgDependences::const_iterator;
|
||||
using dependence_range = iterator_range<dependence_iterator>;
|
||||
|
||||
static StringRef getDependenceTypeStr(DependenceType depType);
|
||||
|
||||
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
|
||||
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases,
|
||||
func::FuncOp f);
|
||||
LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
|
||||
|
||||
/// Returns the X such that op -> X is a dependence of type dt.
|
||||
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
|
||||
dependence_range getDependencesFrom(LinalgOp src, DependenceType dt) const;
|
||||
|
||||
/// Returns the X such that X -> op is a dependence of type dt.
|
||||
dependence_range getDependencesInto(Operation *dst, DependenceType dt) const;
|
||||
dependence_range getDependencesInto(LinalgOp dst, DependenceType dt) const;
|
||||
|
||||
/// Returns the operations that are interleaved between `srcLinalgOp` and
|
||||
/// `dstLinalgOp` and that are involved in any RAW, WAR or WAW dependence
|
||||
/// relation with `srcLinalgOp`, on any view.
|
||||
/// Any such operation prevents reordering.
|
||||
SmallVector<Operation *, 8>
|
||||
findCoveringDependences(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp) const;
|
||||
|
||||
/// Returns the operations that are interleaved between `srcLinalgOp` and
|
||||
/// `dstLinalgOp` and that are involved in a RAR or RAW with `srcLinalgOp`.
|
||||
/// Dependences are restricted to views aliasing `view`.
|
||||
SmallVector<Operation *, 8> findCoveringReads(LinalgOp srcLinalgOp,
|
||||
LinalgOp dstLinalgOp,
|
||||
Value view) const;
|
||||
|
||||
/// Returns the operations that are interleaved between `srcLinalgOp` and
|
||||
/// `dstLinalgOp` and that are involved in a WAR or WAW with `srcLinalgOp`.
|
||||
/// Dependences are restricted to views aliasing `view`.
|
||||
SmallVector<Operation *, 8> findCoveringWrites(LinalgOp srcLinalgOp,
|
||||
LinalgOp dstLinalgOp,
|
||||
Value view) const;
|
||||
|
||||
/// Returns true if the two operations have the specified dependence from
|
||||
/// `srcLinalgOp` to `dstLinalgOp`.
|
||||
bool hasDependenceFrom(LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
|
||||
ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW, DependenceType::WAW}) const;
|
||||
|
||||
/// Returns true if the `linalgOp` has dependences into it.
|
||||
bool hasDependentOperationsInto(LinalgOp linalgOp,
|
||||
ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW,
|
||||
DependenceType::WAW}) const;
|
||||
|
||||
/// Returns true if the `linalgOp` has dependences from it.
|
||||
bool hasDependentOperationsFrom(LinalgOp linalgOp,
|
||||
ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW,
|
||||
DependenceType::WAW}) const;
|
||||
|
||||
/// Returns true if the `linalgOp` has dependences into or from it.
|
||||
bool hasDependentOperations(LinalgOp linalgOp,
|
||||
ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW,
|
||||
DependenceType::WAW}) const;
|
||||
|
||||
/// Returns all operations that have a dependence into `linalgOp` of types
|
||||
/// listed in `depTypes`.
|
||||
SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsInto(
|
||||
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW, DependenceType::WAW}) const;
|
||||
|
||||
/// Returns all operations that have a dependence from `linalgOp` of types
|
||||
/// listed in `depTypes`.
|
||||
SmallVector<LinalgDependenceGraphElem, 2> getDependentOperationsFrom(
|
||||
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW, DependenceType::WAW}) const;
|
||||
|
||||
/// Returns all dependent operations (into and from) given `operation`.
|
||||
SmallVector<LinalgDependenceGraphElem, 2>
|
||||
getDependentOperations(LinalgOp linalgOp,
|
||||
ArrayRef<DependenceType> depTypes = {
|
||||
DependenceType::RAW, DependenceType::WAW}) const;
|
||||
|
||||
void print(raw_ostream &os) const;
|
||||
|
||||
void dump() const;
|
||||
|
||||
private:
|
||||
// Keep dependences in both directions, this is not just a performance gain
|
||||
// but it also reduces usage errors.
|
||||
// Dependence information is stored as a map of:
|
||||
// (source operation -> LinalgDependenceGraphElem)
|
||||
DependenceGraph dependencesFromGraphs[DependenceType::NumTypes];
|
||||
// Reverse dependence information is stored as a map of:
|
||||
// (destination operation -> LinalgDependenceGraphElem)
|
||||
DependenceGraph dependencesIntoGraphs[DependenceType::NumTypes];
|
||||
|
||||
/// Analyses the aliasing views between `src` and `dst` and inserts the proper
|
||||
/// dependences in the graph.
|
||||
void addDependencesBetween(LinalgOp src, LinalgOp dst);
|
||||
|
||||
// Adds an new dependence unit in the proper graph.
|
||||
// Uses std::pair to keep operations and view together and avoid usage errors
|
||||
// related to src/dst and producer/consumer terminology in the context of
|
||||
// dependences.
|
||||
void addDependenceElem(DependenceType dt,
|
||||
LinalgDependenceGraphElem::OpView indexingOpView,
|
||||
LinalgDependenceGraphElem::OpView dependentOpView);
|
||||
|
||||
/// Implementation detail for findCoveringxxx.
|
||||
SmallVector<Operation *, 8>
|
||||
findOperationsWithCoveringDependences(LinalgOp srcLinalgOp,
|
||||
LinalgOp dstLinalgOp, Value view,
|
||||
ArrayRef<DependenceType> types) const;
|
||||
|
||||
Aliases &aliases;
|
||||
SmallVector<LinalgOp, 8> linalgOps;
|
||||
DenseMap<Operation *, unsigned> linalgOpPositions;
|
||||
};
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
|
||||
@@ -9,10 +9,8 @@
|
||||
#ifndef MLIR_DIALECT_LINALG_UTILS_UTILS_H
|
||||
#define MLIR_DIALECT_LINALG_UTILS_UTILS_H
|
||||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include <optional>
|
||||
|
||||
@@ -27,7 +25,6 @@ class ExtractSliceOp;
|
||||
} // namespace tensor
|
||||
|
||||
namespace linalg {
|
||||
class LinalgDependenceGraph;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// General utilities
|
||||
@@ -153,19 +150,6 @@ enum class LinalgTilingLoopType {
|
||||
ParallelLoops = 2
|
||||
};
|
||||
|
||||
/// Checks whether the specific `producer` is the last write to exactly the
|
||||
/// whole `consumedView`. This checks structural dominance, that the dependence
|
||||
/// is a RAW without any interleaved write to any piece of `consumedView`.
|
||||
bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
|
||||
LinalgOp consumer, Value consumedView,
|
||||
LinalgOp producer);
|
||||
|
||||
/// Checks whether fusing the specific `producer` of the `consumedView` is
|
||||
/// feasible. This checks `producer` is the last write of `consumedView` and
|
||||
/// that no interleaved dependence would be violated (RAW, WAR or WAW).
|
||||
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
|
||||
Value consumedView, LinalgOp producer);
|
||||
|
||||
/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
|
||||
/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
|
||||
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
|
||||
@@ -268,13 +252,6 @@ void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
|
||||
void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
|
||||
ArrayRef<OpFoldResult> offests);
|
||||
|
||||
using FusableOpDependencesTy = llvm::MapVector<
|
||||
Operation *,
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>;
|
||||
FusableOpDependencesTy
|
||||
findAllFusableDependences(ArrayRef<LinalgOp> ops,
|
||||
const LinalgDependenceGraph &dependenceGraph);
|
||||
|
||||
/// A struct containing the Linalg producer before and after fusion.
|
||||
/// When operating on tensors, `fusedProducer` may feed into a `tensor.cast` op
|
||||
/// before the consumer Linalg op, until enough canonicalizations have applied.
|
||||
@@ -283,14 +260,6 @@ struct FusionInfo {
|
||||
LinalgOp fusedProducer;
|
||||
};
|
||||
|
||||
/// Fuses producer into consumer if the producer is structurally feasible and
|
||||
/// the fusion would not violate dependencies.
|
||||
/// Implements the fusion part of the "tileAndFuse on buffers" transformation
|
||||
/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
|
||||
/// obtained by applying the tiling transformation).
|
||||
FailureOr<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
|
||||
OpOperand &consumerOpOperand,
|
||||
const LinalgDependenceGraph &graph);
|
||||
/// Tensor counterpart of `fuseProducerOfBuffer`.
|
||||
/// This implements the fusion part of the "tileAndFuse on tensors"
|
||||
/// transformation and thus requires the `consumerOpOperand` to be a
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
add_mlir_dialect_library(MLIRLinalgAnalysis
|
||||
DependenceAnalysis.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRAffineAnalysis
|
||||
MLIRAnalysis
|
||||
MLIRIR
|
||||
MLIRLinalgDialect
|
||||
MLIRMemRefDialect
|
||||
)
|
||||
@@ -1,366 +0,0 @@
|
||||
//===- DependenceAnalysis.cpp - Dependence analysis on SSA views ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements view-based alias and dependence analyses.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define DEBUG_TYPE "linalg-dependence-analysis"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
using llvm::dbgs;
|
||||
|
||||
Value Aliases::find(Value v) {
|
||||
if (v.isa<BlockArgument>())
|
||||
return v;
|
||||
|
||||
auto it = aliases.find(v);
|
||||
if (it != aliases.end()) {
|
||||
assert(it->getSecond().getType().isa<BaseMemRefType>() &&
|
||||
"Memref expected");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (v.isa<BlockArgument>())
|
||||
return v;
|
||||
|
||||
Operation *defOp = v.getDefiningOp();
|
||||
if (!defOp)
|
||||
return v;
|
||||
|
||||
// Treat RegionBranchOpInterfaces like an allocate and don't try to follow
|
||||
// the aliasing further.
|
||||
if (isa<RegionBranchOpInterface>(defOp))
|
||||
return v;
|
||||
if (isa<bufferization::ToMemrefOp>(defOp))
|
||||
return v;
|
||||
|
||||
if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(defOp)) {
|
||||
// Collect all memory effects on `v`.
|
||||
SmallVector<MemoryEffects::EffectInstance, 1> effects;
|
||||
memEffect.getEffectsOnValue(v, effects);
|
||||
|
||||
// If we have the 'Allocate' memory effect on `v`, then `v` should be the
|
||||
// original buffer.
|
||||
if (llvm::any_of(
|
||||
effects, [](const MemoryEffects::EffectInstance &instance) {
|
||||
return isa<MemoryEffects::Allocate>(instance.getEffect());
|
||||
}))
|
||||
return v;
|
||||
}
|
||||
|
||||
if (auto viewLikeOp = dyn_cast<ViewLikeOpInterface>(defOp)) {
|
||||
auto it =
|
||||
aliases.insert(std::make_pair(v, find(viewLikeOp.getViewSource())));
|
||||
return it.first->second;
|
||||
}
|
||||
|
||||
llvm::errs() << "View alias analysis reduces to: " << v << "\n";
|
||||
llvm_unreachable("unsupported view alias case");
|
||||
}
|
||||
}
|
||||
|
||||
StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
|
||||
switch (depType) {
|
||||
case LinalgDependenceGraph::DependenceType::RAW:
|
||||
return "RAW";
|
||||
case LinalgDependenceGraph::DependenceType::RAR:
|
||||
return "RAR";
|
||||
case LinalgDependenceGraph::DependenceType::WAR:
|
||||
return "WAR";
|
||||
case LinalgDependenceGraph::DependenceType::WAW:
|
||||
return "WAW";
|
||||
default:
|
||||
break;
|
||||
}
|
||||
llvm_unreachable("Unexpected DependenceType");
|
||||
}
|
||||
|
||||
LinalgDependenceGraph
|
||||
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, func::FuncOp f) {
|
||||
SmallVector<LinalgOp, 8> linalgOps;
|
||||
f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
|
||||
return LinalgDependenceGraph(aliases, linalgOps);
|
||||
}
|
||||
|
||||
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
|
||||
ArrayRef<LinalgOp> ops)
|
||||
: aliases(aliases), linalgOps(ops.begin(), ops.end()) {
|
||||
for (const auto &en : llvm::enumerate(linalgOps)) {
|
||||
linalgOpPositions.insert(
|
||||
std::make_pair(en.value().getOperation(), en.index()));
|
||||
}
|
||||
for (unsigned i = 0, e = ops.size(); i < e; ++i) {
|
||||
for (unsigned j = i + 1; j < e; ++j) {
|
||||
addDependencesBetween(ops[i], ops[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LinalgDependenceGraph::addDependenceElem(
|
||||
DependenceType dt, LinalgDependenceGraphElem::OpView indexingOpView,
|
||||
LinalgDependenceGraphElem::OpView dependentOpView) {
|
||||
LLVM_DEBUG(dbgs() << "\nAdd dep type " << getDependenceTypeStr(dt) << ":\t ("
|
||||
<< LinalgDependenceGraphElem::getValue(indexingOpView)
|
||||
<< " @) -> \n\t\t("
|
||||
<< LinalgDependenceGraphElem::getValue(dependentOpView)
|
||||
<< " @)");
|
||||
dependencesFromGraphs[dt][LinalgDependenceGraphElem::getOwner(indexingOpView)]
|
||||
.push_back(
|
||||
LinalgDependenceGraphElem{dependentOpView, indexingOpView, dt});
|
||||
dependencesIntoGraphs[dt]
|
||||
[LinalgDependenceGraphElem::getOwner(dependentOpView)]
|
||||
.push_back(LinalgDependenceGraphElem{
|
||||
indexingOpView, dependentOpView, dt});
|
||||
}
|
||||
|
||||
LinalgDependenceGraph::dependence_range
|
||||
LinalgDependenceGraph::getDependencesFrom(
|
||||
LinalgOp src, LinalgDependenceGraph::DependenceType dt) const {
|
||||
return getDependencesFrom(src.getOperation(), dt);
|
||||
}
|
||||
|
||||
LinalgDependenceGraph::dependence_range
|
||||
LinalgDependenceGraph::getDependencesFrom(
|
||||
Operation *src, LinalgDependenceGraph::DependenceType dt) const {
|
||||
auto iter = dependencesFromGraphs[dt].find(src);
|
||||
if (iter == dependencesFromGraphs[dt].end())
|
||||
return llvm::make_range(nullptr, nullptr);
|
||||
return llvm::make_range(iter->second.begin(), iter->second.end());
|
||||
}
|
||||
|
||||
LinalgDependenceGraph::dependence_range
|
||||
LinalgDependenceGraph::getDependencesInto(
|
||||
LinalgOp dst, LinalgDependenceGraph::DependenceType dt) const {
|
||||
return getDependencesInto(dst.getOperation(), dt);
|
||||
}
|
||||
|
||||
LinalgDependenceGraph::dependence_range
|
||||
LinalgDependenceGraph::getDependencesInto(
|
||||
Operation *dst, LinalgDependenceGraph::DependenceType dt) const {
|
||||
auto iter = dependencesIntoGraphs[dt].find(dst);
|
||||
if (iter == dependencesIntoGraphs[dt].end())
|
||||
return llvm::make_range(nullptr, nullptr);
|
||||
return llvm::make_range(iter->second.begin(), iter->second.end());
|
||||
}
|
||||
|
||||
void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
|
||||
LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
|
||||
<< " and " << *dst.getOperation() << "\n");
|
||||
if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
|
||||
for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) {
|
||||
if (!dstOpOperand->get().getType().isa<RankedTensorType>())
|
||||
continue;
|
||||
// Check if the operand is defined by the src.
|
||||
auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
|
||||
if (definingOp && definingOp == src)
|
||||
addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
|
||||
dstOpOperand);
|
||||
}
|
||||
for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) {
|
||||
// Check if the operand is defined by the src.
|
||||
auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
|
||||
if (definingOp && definingOp == src) {
|
||||
if (dst.isInitTensor(dstOpOperand)) {
|
||||
addDependenceElem(DependenceType::RAW, dstOpOperand->get(),
|
||||
dstOpOperand);
|
||||
}
|
||||
addDependenceElem(DependenceType::WAW, dstOpOperand->get(),
|
||||
dstOpOperand);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
|
||||
"unhandled dependence tracking for mixed buffer/tensor operations");
|
||||
for (OpOperand *srcOpOperand : src.getDpsInitOperands()) { // W
|
||||
// RAW graph
|
||||
for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R
|
||||
if (!dstOpOperand->get().getType().isa<MemRefType>())
|
||||
continue;
|
||||
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
|
||||
addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
|
||||
}
|
||||
// WAW graph
|
||||
for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W
|
||||
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
|
||||
addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
|
||||
}
|
||||
for (OpOperand *srcOpOperand : src.getDpsInputOperands()) { // R
|
||||
if (!srcOpOperand->get().getType().isa<MemRefType>())
|
||||
continue;
|
||||
// RAR graph
|
||||
for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R
|
||||
if (!dstOpOperand->get().getType().isa<MemRefType>())
|
||||
continue;
|
||||
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
|
||||
addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
|
||||
}
|
||||
// WAR graph
|
||||
for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W
|
||||
if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
|
||||
addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Operation *, 8>
|
||||
LinalgDependenceGraph::findCoveringDependences(LinalgOp srcLinalgOp,
|
||||
LinalgOp dstLinalgOp) const {
|
||||
return findOperationsWithCoveringDependences(
|
||||
srcLinalgOp, dstLinalgOp, nullptr,
|
||||
{DependenceType::WAW, DependenceType::WAR, DependenceType::RAW});
|
||||
}
|
||||
|
||||
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringWrites(
|
||||
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
|
||||
return findOperationsWithCoveringDependences(
|
||||
srcLinalgOp, dstLinalgOp, view,
|
||||
{DependenceType::WAW, DependenceType::WAR});
|
||||
}
|
||||
|
||||
SmallVector<Operation *, 8> LinalgDependenceGraph::findCoveringReads(
|
||||
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view) const {
|
||||
return findOperationsWithCoveringDependences(
|
||||
srcLinalgOp, dstLinalgOp, view,
|
||||
{DependenceType::RAR, DependenceType::RAW});
|
||||
}
|
||||
|
||||
SmallVector<Operation *, 8>
|
||||
LinalgDependenceGraph::findOperationsWithCoveringDependences(
|
||||
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp, Value view,
|
||||
ArrayRef<DependenceType> types) const {
|
||||
auto *src = srcLinalgOp.getOperation();
|
||||
auto *dst = dstLinalgOp.getOperation();
|
||||
auto srcPos = linalgOpPositions.lookup(src);
|
||||
auto dstPos = linalgOpPositions.lookup(dst);
|
||||
assert(srcPos < dstPos && "expected dst after src in IR traversal order");
|
||||
|
||||
SmallVector<Operation *, 8> res;
|
||||
// Consider an intermediate interleaved `interim` op, look for any dependence
|
||||
// to an aliasing view on a src -> op -> dst path.
|
||||
// TODO: we are not considering paths yet, just interleaved positions.
|
||||
for (auto dt : types) {
|
||||
for (auto dependence : getDependencesFrom(src, dt)) {
|
||||
auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
|
||||
// Skip if not interleaved.
|
||||
if (interimPos >= dstPos || interimPos <= srcPos)
|
||||
continue;
|
||||
Value consumerView = dependence.getIndexingValue();
|
||||
if (view && !aliases.alias(view, consumerView))
|
||||
continue;
|
||||
auto *op = dependence.getDependentOp();
|
||||
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
|
||||
<< getDependenceTypeStr(dt) << ": " << *src << " -> "
|
||||
<< *op << " on " << consumerView);
|
||||
res.push_back(op);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
bool LinalgDependenceGraph::hasDependenceFrom(
|
||||
LinalgOp srcLinalgOp, LinalgOp dstLinalgOp,
|
||||
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
|
||||
for (auto dep : depTypes)
|
||||
for (auto dependence : getDependencesInto(dstLinalgOp, dep))
|
||||
if (dependence.getDependentOp() == srcLinalgOp)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LinalgDependenceGraph::hasDependentOperationsFrom(
|
||||
LinalgOp linalgOp,
|
||||
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
|
||||
for (auto dep : depTypes) {
|
||||
if (!getDependencesFrom(linalgOp, dep).empty())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LinalgDependenceGraph::hasDependentOperationsInto(
|
||||
LinalgOp linalgOp,
|
||||
ArrayRef<LinalgDependenceGraph::DependenceType> depTypes) const {
|
||||
for (auto dep : depTypes) {
|
||||
if (!getDependencesInto(linalgOp, dep).empty())
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool LinalgDependenceGraph::hasDependentOperations(
|
||||
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
|
||||
return hasDependentOperationsInto(linalgOp, depTypes) ||
|
||||
hasDependentOperationsFrom(linalgOp, depTypes);
|
||||
}
|
||||
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
|
||||
LinalgDependenceGraph::getDependentOperationsInto(
|
||||
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
|
||||
dependentOperations;
|
||||
for (auto dependenceType : depTypes) {
|
||||
auto dependencies = getDependencesInto(linalgOp, dependenceType);
|
||||
dependentOperations.append(dependencies.begin(), dependencies.end());
|
||||
}
|
||||
return dependentOperations;
|
||||
}
|
||||
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
|
||||
LinalgDependenceGraph::getDependentOperationsFrom(
|
||||
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
|
||||
dependentOperations;
|
||||
for (auto dependenceType : depTypes) {
|
||||
auto dependencies = getDependencesFrom(linalgOp, dependenceType);
|
||||
dependentOperations.append(dependencies.begin(), dependencies.end());
|
||||
}
|
||||
return dependentOperations;
|
||||
}
|
||||
|
||||
/// Returns all dependent operations (into and from) given `operation`.
|
||||
SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 2>
|
||||
LinalgDependenceGraph::getDependentOperations(
|
||||
LinalgOp linalgOp, ArrayRef<DependenceType> depTypes) const {
|
||||
SmallVector<LinalgDependenceGraphElem, 2> dependentOperations =
|
||||
getDependentOperationsInto(linalgOp, depTypes);
|
||||
SmallVector<LinalgDependenceGraphElem, 2> t =
|
||||
getDependentOperationsFrom(linalgOp, depTypes);
|
||||
dependentOperations.append(t.begin(), t.end());
|
||||
return dependentOperations;
|
||||
}
|
||||
|
||||
void LinalgDependenceGraph::print(raw_ostream &os) const {
|
||||
for (auto dt : {
|
||||
LinalgDependenceGraph::DependenceType::RAW,
|
||||
LinalgDependenceGraph::DependenceType::WAW,
|
||||
}) {
|
||||
const auto &fromGraph = dependencesFromGraphs[dt];
|
||||
for (const auto &it : fromGraph) {
|
||||
os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
|
||||
<< ":\n";
|
||||
for (const auto &dep : it.second) {
|
||||
os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LinalgDependenceGraph::dump() const { print(llvm::errs()); }
|
||||
@@ -1,4 +1,3 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(TransformOps)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
@@ -56,7 +56,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
||||
MLIRMemRefDialect
|
||||
MLIRMemRefTransforms
|
||||
MLIRLinalgDialect
|
||||
MLIRLinalgAnalysis
|
||||
MLIRLinalgUtils
|
||||
MLIRSCFDialect
|
||||
MLIRSCFTransforms
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
@@ -204,173 +203,6 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
|
||||
return fuse(b, producerOp, fusedLoopsAndRanges);
|
||||
}
|
||||
|
||||
// Encode structural fusion safety preconditions.
|
||||
// Some of these will be lifted in the future with better analysis.
|
||||
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
|
||||
LinalgOp consumer) {
|
||||
assert(producer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(consumer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
if (producer.getNumDpsInits() != 1) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
|
||||
return false;
|
||||
}
|
||||
// Only fuse when the producer block dominates.
|
||||
DominanceInfo dom(producer.getOperation());
|
||||
if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs()
|
||||
<< "\nNot structurally fusable (producer block does not dominate)");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
|
||||
LinalgOp consumer,
|
||||
Value consumedView,
|
||||
LinalgOp producer) {
|
||||
assert(producer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(consumer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
// Make some simple structural checks that alleviate the need for more
|
||||
// complex analyses.
|
||||
if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
|
||||
<< *producer.getOperation());
|
||||
return false;
|
||||
}
|
||||
// Check for any interleaved write to consumedView.
|
||||
if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
|
||||
<< *producer.getOperation());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
|
||||
LinalgOp consumer, Value consumedView,
|
||||
LinalgOp producer) {
|
||||
assert(producer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
assert(consumer.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
|
||||
return false;
|
||||
// Check for any fusion-preventing dependence to any shape read/written that
|
||||
// would violate dependences.
|
||||
if (!graph.findCoveringDependences(producer, consumer).empty()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "\n***Not fusable due to an interleaved dependence:\t"
|
||||
<< *producer.getOperation());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/// For `consumer` with buffer semantics, find the Linalg operation on buffers
|
||||
/// that is the last writer of `consumerOpOperand`. For now the fusable
|
||||
/// dependence is returned as an instance of the `dependenceGraph`.
|
||||
static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem>
|
||||
findFusableProducer(OpOperand &consumerOpOperand,
|
||||
const LinalgDependenceGraph &dependenceGraph) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
|
||||
<< consumerOpOperand.get() << " @"
|
||||
<< consumerOpOperand.getOperandNumber() << " in "
|
||||
<< *consumerOpOperand.getOwner() << "\n");
|
||||
LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
|
||||
if (!consumerOp)
|
||||
return failure();
|
||||
|
||||
// Only consider RAW and WAW atm.
|
||||
for (auto depType : {
|
||||
LinalgDependenceGraph::DependenceType::RAW,
|
||||
LinalgDependenceGraph::DependenceType::WAW,
|
||||
}) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Dependencies into: " << *consumerOp.getOperation() << "\n");
|
||||
for (auto dependence : llvm::make_filter_range(
|
||||
dependenceGraph.getDependencesInto(consumerOp, depType),
|
||||
[&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
|
||||
<< elem.getIndexingValue() << " and "
|
||||
<< elem.getDependentValue() << "\n");
|
||||
Value v = elem.getIndexingValue();
|
||||
std::optional<unsigned> operandNum =
|
||||
elem.getIndexingOpViewOperandNum();
|
||||
return isa<LinalgOp>(elem.getDependentOp()) &&
|
||||
v == consumerOpOperand.get() && operandNum &&
|
||||
*operandNum == consumerOpOperand.getOperandNumber();
|
||||
})) {
|
||||
// Consumer consumes this view, `isStructurallyFusableProducer` also
|
||||
// checks whether it is a strict subview of the producer view.
|
||||
auto producer = cast<LinalgOp>(dependence.getDependentOp());
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "\n"
|
||||
<< LinalgDependenceGraph::getDependenceTypeStr(depType)
|
||||
<< "producer: " << *dependence.getDependentOp()
|
||||
<< " view: " << dependence.getDependentValue() << "\n");
|
||||
|
||||
// If the producer and consumer have tensor semantics, the only dependence
|
||||
// between them is through a RAW dependence and they are fusable by
|
||||
// construction. For buffer semantics need additional checks.
|
||||
if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
|
||||
isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
|
||||
producer))
|
||||
return dependence;
|
||||
if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
|
||||
assert(dependence.dependenceType ==
|
||||
LinalgDependenceGraph::DependenceType::RAW);
|
||||
return dependence;
|
||||
}
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
FailureOr<FusionInfo>
|
||||
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
|
||||
const LinalgDependenceGraph &graph) {
|
||||
std::optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
|
||||
fusableDependence = findFusableProducer(consumerOpOperand, graph);
|
||||
if (!fusableDependence)
|
||||
return failure();
|
||||
|
||||
LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
|
||||
if (!producerOp)
|
||||
return failure();
|
||||
|
||||
// If producer is already in the same block as consumer, we are done.
|
||||
if (consumerOpOperand.get().getParentBlock() ==
|
||||
fusableDependence->getDependentValue().getParentBlock())
|
||||
return failure();
|
||||
|
||||
std::optional<AffineMap> producerMap =
|
||||
fusableDependence->getDependentOpViewIndexingMap();
|
||||
if (!producerMap)
|
||||
return failure();
|
||||
|
||||
// Must be a subview or an extract_slice to guarantee there are loops we can
|
||||
// fuse into.
|
||||
auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
|
||||
if (!subView) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Fuse `producer` just before `consumer`.
|
||||
OpBuilder::InsertionGuard g(b);
|
||||
b.setInsertionPoint(consumerOpOperand.getOwner());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
|
||||
<< *consumerOpOperand.getOwner() << "\n");
|
||||
|
||||
auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
|
||||
return FusionInfo{producerOp, fusedProducer};
|
||||
}
|
||||
|
||||
/// Walk back use-def chain through scf::For yields.
|
||||
/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
// RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s
|
||||
|
||||
func.func @f1(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>, %B: memref<?x?xf32, strided<[?, 1], offset: ?>>, %C: memref<?x?xf32, strided<[?, 1], offset: ?>>, %D: memref<?x?xf32, strided<[?, 1], offset: ?>>, %E: memref<?x?xf32, strided<[?, 1], offset: ?>>) -> memref<?x?xf32, strided<[?, 1], offset: ?>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c40 = arith.constant 40 : index
|
||||
%c30 = arith.constant 30 : index
|
||||
%c20 = arith.constant 20 : index
|
||||
%0 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
%1 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
linalg.matmul ins(%A, %B: memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>)
|
||||
outs(%C: memref<?x?xf32, strided<[?, 1], offset: ?>>)
|
||||
scf.for %arg5 = %c0 to %0 step %c20 {
|
||||
scf.for %arg6 = %c0 to %2 step %c30 {
|
||||
scf.for %arg7 = %c0 to %1 step %c40 {
|
||||
%5 = memref.subview %C[%arg5, %arg7][%c20, %c40][%c1, %c1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %D[%arg7, %arg6][%c40, %c30][%c1, %c1]: memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c20, %c40][%c1, %c1] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%9 = memref.dim %5, %c0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%10 = memref.dim %5, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%11 = memref.dim %7, %c1 : memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
scf.for %arg8 = %c0 to %9 step %c2 {
|
||||
scf.for %arg9 = %c0 to %11 step %c3 {
|
||||
scf.for %arg10 = %c0 to %10 step %c4 {
|
||||
%14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%14, %16: memref<?x?xf32, strided<[?, ?], offset: ?>>, memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%17: memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
}
|
||||
// CHECK-LABEL: func @f1
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
@@ -1,160 +0,0 @@
|
||||
// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
|
||||
|
||||
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {
|
||||
indexing_maps = [#id_2d, #id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func.func @fuse_indexed_consumer(%A: memref<?x?xf32>,
|
||||
%B: memref<?x?xf32>,
|
||||
%C: memref<?x?xf32>,
|
||||
%D: memref<?x?xf32>) {
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%C : memref<?x?xf32>) {
|
||||
^bb0(%e: f32, %arg5: f32, %arg6: f32):
|
||||
%2 = arith.addf %e, %arg5 : f32
|
||||
linalg.yield %2 : f32
|
||||
}
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c25 = arith.constant 25 : index
|
||||
%c10 = arith.constant 10 : index
|
||||
%0 = memref.dim %C, %c0 : memref<?x?xf32>
|
||||
%1 = memref.dim %C, %c1 : memref<?x?xf32>
|
||||
%2 = memref.dim %D, %c0 : memref<?x?xf32>
|
||||
%3 = memref.dim %D, %c1 : memref<?x?xf32>
|
||||
scf.for %arg2 = %c0 to %0 step %c10 {
|
||||
scf.for %arg3 = %c0 to %1 step %c25 {
|
||||
%4 = memref.subview %C[%arg2, %arg3][%c10, %c25][%c1, %c1] :
|
||||
memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%5 = memref.subview %D[%arg2, %arg3][%c10, %c25][%c1, %c1] :
|
||||
memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.generic {
|
||||
indexing_maps = [#id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%4 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%5 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
|
||||
^bb0(%arg4: f32, %arg5: f32):
|
||||
%idx0 = linalg.index 0 : index
|
||||
%idx1 = linalg.index 1 : index
|
||||
%6 = arith.addi %idx0, %arg2 : index
|
||||
%7 = arith.addi %idx1, %arg3 : index
|
||||
%8 = arith.index_cast %6 : index to i32
|
||||
%9 = arith.sitofp %8 : i32 to f32
|
||||
%10 = arith.index_cast %7 : index to i32
|
||||
%11 = arith.sitofp %10 : i32 to f32
|
||||
%12 = arith.addf %9, %11 : f32
|
||||
%13 = arith.addf %12, %arg4 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fuse_indexed_consumer
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NOT: affine.apply
|
||||
// CHECK: arith.addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.index_cast
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_indexed_producer(%A: memref<?x?xindex>,
|
||||
%B: memref<?x?xindex>) {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c25 = arith.constant 25 : index
|
||||
%c10 = arith.constant 10 : index
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (j, i)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
outs(%A : memref<?x?xindex>) {
|
||||
^bb0(%a: index):
|
||||
%idx0 = linalg.index 0 : index
|
||||
%idx1 = linalg.index 1 : index
|
||||
%0 = arith.addi %idx0, %idx1 : index
|
||||
linalg.yield %0 : index
|
||||
}
|
||||
%A_X = memref.dim %A, %c0 : memref<?x?xindex>
|
||||
%A_Y = memref.dim %A, %c1 : memref<?x?xindex>
|
||||
scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%A_X, %A_Y) step (%c10, %c25) {
|
||||
%A_view = memref.subview %A[%arg2, %arg3][%c10, %c25][%c1, %c1] :
|
||||
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
|
||||
%B_view = memref.subview %B[%arg2, %arg3][%c10, %c25][%c1, %c1] :
|
||||
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i, j)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%A_view : memref<?x?xindex, strided<[?, ?], offset: ?>>)
|
||||
outs(%B_view : memref<?x?xindex, strided<[?, ?], offset: ?>>) {
|
||||
^bb0(%a: index, %b: index):
|
||||
linalg.yield %a : index
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
|
||||
// CHECK-LABEL: func @fuse_indexed_producer
|
||||
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) =
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: [[idx0:%.*]] = linalg.index 0 : index
|
||||
// CHECK: [[i_new:%.*]] = affine.apply [[$MAP]]([[idx0]], [[J]])
|
||||
// CHECK: [[idx1:%.*]] = linalg.index 1 : index
|
||||
// CHECK: [[j_new:%.*]] = affine.apply [[$MAP]]([[idx1]], [[I]])
|
||||
// CHECK: [[sum:%.*]] = arith.addi [[i_new]], [[j_new]] : index
|
||||
// CHECK: linalg.yield [[sum]] : index
|
||||
// CHECK: linalg.generic
|
||||
|
||||
// -----
|
||||
|
||||
func.func @fuse_indexed_producer_tiled_second_dim_only(%A: memref<?x?xindex>,
|
||||
%B: memref<?x?xindex>) {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c25 = arith.constant 25 : index
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
outs(%A : memref<?x?xindex>) {
|
||||
^bb0(%a: index):
|
||||
%idx0 = linalg.index 0 : index
|
||||
%idx1 = linalg.index 1 : index
|
||||
%0 = arith.addi %idx0, %idx1 : index
|
||||
linalg.yield %0 : index
|
||||
}
|
||||
%A_X = memref.dim %A, %c0 : memref<?x?xindex>
|
||||
%A_Y = memref.dim %A, %c1 : memref<?x?xindex>
|
||||
scf.parallel (%arg3) = (%c0) to (%A_Y) step (%c25) {
|
||||
%A_view = memref.subview %A[%c0, %arg3][%A_X, %c25][%c1, %c1] :
|
||||
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
|
||||
%B_view = memref.subview %B[%c0, %arg3][%A_X, %c25][%c1, %c1] :
|
||||
memref<?x?xindex> to memref<?x?xindex, strided<[?, ?], offset: ?>>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i, j)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%A_view : memref<?x?xindex, strided<[?, ?], offset: ?>>)
|
||||
outs(%B_view : memref<?x?xindex, strided<[?, ?], offset: ?>>) {
|
||||
^bb0(%a: index, %b: index):
|
||||
linalg.yield %a : index
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
|
||||
// CHECK-LABEL: func @fuse_indexed_producer_tiled_second_dim_only
|
||||
// CHECK: scf.parallel ([[J:%.*]]) =
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: [[idx0:%.*]] = linalg.index 0 : index
|
||||
// CHECK: [[idx1:%.*]] = linalg.index 1 : index
|
||||
// CHECK: [[j_new:%.*]] = affine.apply [[$MAP]]([[idx1]], [[J]])
|
||||
// CHECK: [[sum:%.*]] = arith.addi [[idx0]], [[j_new]] : index
|
||||
// CHECK: linalg.yield [[sum]] : index
|
||||
// CHECK: linalg.generic
|
||||
|
||||
@@ -1,745 +0,0 @@
|
||||
// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
|
||||
|
||||
func.func @f1(%A: memref<?x?xf32, strided<[?, 1], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, 1], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, 1], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, 1], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, 1], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, 1], offset: 0>> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%0 = memref.dim %A, %c0 : memref<?x?xf32, strided<[?, 1], offset: 0>>
|
||||
%1 = memref.dim %A, %c1 : memref<?x?xf32, strided<[?, 1], offset: 0>>
|
||||
%2 = memref.dim %B, %c1 : memref<?x?xf32, strided<[?, 1], offset: 0>>
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, 1], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, 1], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, 1], offset: 0>>)
|
||||
scf.for %arg5 = %c0 to %0 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %1 step %c4 {
|
||||
%5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, 1], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, 1], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, 1], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8: memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, 1], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f1
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
func.func @f2(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C: memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
%0 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
scf.for %arg5 = %c0 to %0 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %1 step %c4 {
|
||||
%5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f2
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %c0{{[_0-9]*}} : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[C_1:.*]] = memref.dim %[[C]], %c1{{[_0-9]*}} : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %c1{{[_0-9]*}} : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
func.func @f3(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
%0 = memref.dim %D, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%2 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
scf.for %arg5 = %c0 to %0 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %1 step %c4 {
|
||||
%5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f3
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
func.func @f4(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%D : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
%0 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
scf.for %arg5 = %c0 to %0 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %1 step %c4 {
|
||||
%5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f4
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
|
||||
// Fuse D then fuse C, no false dependence prevent it.
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
func.func @f5(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%0 = memref.dim %B, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %D, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
linalg.matmul ins(%C, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%D : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
scf.for %arg5 = %c0 to %1 step %c2 {
|
||||
scf.for %arg6 = %c0 to %0 step %c3 {
|
||||
scf.for %arg7 = %c0 to %2 step %c4 {
|
||||
%5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
|
||||
// CHECK-DAG: #[[BOUND_2_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 2)>
|
||||
// CHECK-DAG: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
|
||||
// CHECK: func @f5
|
||||
// CHECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[A_0:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}}
|
||||
// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
|
||||
// CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]]
|
||||
// CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]]
|
||||
// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_2_MAP_2]](%[[I]])[%[[A_0]], %[[C_0]]]
|
||||
// CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
|
||||
// CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]]
|
||||
// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
|
||||
// CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]]
|
||||
// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
|
||||
// CHECK: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4]
|
||||
// CHECK: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
|
||||
// CHECK: %[[BOUND_4_B1:.*]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[B_1]]]
|
||||
// CHECK: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
|
||||
// CHECK: %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_C0]], %[[BOUND_4_B1]]]
|
||||
// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]]
|
||||
// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]]
|
||||
// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0) -> (d0 + 2)>
|
||||
#map1 = affine_map<(d0) -> (d0 + 4)>
|
||||
#map2 = affine_map<(d0) -> (d0 + 3)>
|
||||
|
||||
func.func @f6(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%0 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
linalg.matmul ins(%A, %C : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%E : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
%1 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
scf.for %arg5 = %c0 to %1 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %0 step %c4 {
|
||||
%3 = affine.apply #map0(%arg5)
|
||||
%4 = affine.apply #map1(%arg7)
|
||||
%5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%6 = affine.apply #map2(%arg6)
|
||||
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f6
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// Fuse the producer of E (WAW) then the producer of C (WAR).
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
func.func @f7(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%0 = memref.dim %A, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %A, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%2 = memref.dim %C, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%3 = memref.dim %C, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%4 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
linalg.matmul ins(%A, %C : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%E : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
scf.for %arg5 = %c0 to %0 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %1 step %c4 {
|
||||
%7 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%9 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%7, %9 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%10 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
scf.for %arg5 = %c0 to %3 step %c2 {
|
||||
scf.for %arg6 = %c0 to %4 step %c3 {
|
||||
scf.for %arg7 = %c0 to %2 step %c4 {
|
||||
%7 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%9 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%7, %9 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%10 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f7
|
||||
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[A_0:.*]] = memref.dim %[[A]], %[[C0:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[A_1:.*]] = memref.dim %[[A]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, strided<[?, ?]>>
|
||||
// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
|
||||
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-NOT: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0) -> (d0 + 2)>
|
||||
#map1 = affine_map<(d0) -> (d0 + 4)>
|
||||
#map2 = affine_map<(d0) -> (d0 + 3)>
|
||||
|
||||
func.func @f8(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%E: memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
) -> memref<?x?xf32, strided<[?, ?], offset: 0>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%0 = memref.dim %A, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %A, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
linalg.matmul ins(%A, %C : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%D : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
%2 = memref.dim %D, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
scf.for %arg5 = %c0 to %0 step %c2 {
|
||||
scf.for %arg6 = %c0 to %2 step %c3 {
|
||||
scf.for %arg7 = %c0 to %1 step %c4 {
|
||||
%3 = affine.apply #map0(%arg5)
|
||||
%4 = affine.apply #map1(%arg7)
|
||||
%5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%6 = affine.apply #map2(%arg6)
|
||||
%7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%5, %7 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%8 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return %E : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
}
|
||||
// CHECK-LABEL: func @f8
|
||||
// CHECK: (%[[A:.*]]: memref{{.*}}, %[[B:.*]]: memref{{.*}}, %[[C:.*]]: memref{{.*}}, %[[D:.*]]: memref{{.*}}, %[[E:.*]]: memref{{.*}})
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK-NOT: linalg.matmul
|
||||
|
||||
// -----
|
||||
|
||||
#id_2d = affine_map<(i, j) -> (i, j)>
|
||||
#pointwise_2d_trait = {
|
||||
indexing_maps = [#id_2d, #id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func.func @pointwise(%A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%B: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%C: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
%D: memref<?x?xf32, strided<[?, ?], offset: 0>>) {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%A, %A: memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%B : memref<?x?xf32, strided<[?, ?], offset: 0>>) {
|
||||
^bb0(%E: f32, %arg5: f32, %arg6: f32):
|
||||
%2 = arith.addf %E, %arg5 : f32
|
||||
linalg.yield %2 : f32
|
||||
}
|
||||
%0 = memref.dim %B, %c0 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%1 = memref.dim %B, %c1 : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
scf.for %arg4 = %c0 to %0 step %c2 {
|
||||
scf.for %arg5 = %c0 to %1 step %c3 {
|
||||
%4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%4, %5: memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%6 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
|
||||
^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
|
||||
%7 = arith.mulf %arg6, %arg7 : f32
|
||||
linalg.yield %7 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @pointwise
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.mulf
|
||||
|
||||
// -----
|
||||
|
||||
#id_2d = affine_map<(i, j) -> (i, j)>
|
||||
#pointwise_2d_trait = {
|
||||
indexing_maps = [#id_2d, #id_2d, #id_2d],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
func.func @pointwise_no_view(%M: index, %N: index) {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%A = memref.alloc (%M, %N): memref<?x?xf32>
|
||||
%B = memref.alloc (%M, %N): memref<?x?xf32>
|
||||
%C = memref.alloc (%M, %N): memref<?x?xf32>
|
||||
%D = memref.alloc (%M, %N): memref<?x?xf32>
|
||||
%E = memref.alloc (%M, %N): memref<?x?xf32>
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%A, %A : memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%B : memref<?x?xf32>) {
|
||||
^bb0(%e: f32, %arg5: f32, %arg6: f32):
|
||||
%2 = arith.addf %e, %arg5 : f32
|
||||
linalg.yield %2 : f32
|
||||
}
|
||||
%0 = memref.dim %B, %c0 : memref<?x?xf32>
|
||||
%1 = memref.dim %B, %c1 : memref<?x?xf32>
|
||||
scf.for %arg4 = %c0 to %0 step %c2 {
|
||||
scf.for %arg5 = %c0 to %1 step %c3 {
|
||||
%4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.generic #pointwise_2d_trait
|
||||
ins(%4, %5: memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%6 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
|
||||
^bb0(%arg6: f32, %arg7: f32, %arg8: f32):
|
||||
%7 = arith.mulf %arg6, %arg7 : f32
|
||||
linalg.yield %7 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @pointwise_no_view
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.mulf
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0)>
|
||||
#map1 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
|
||||
func.func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
%arg1: memref<100xf32>,
|
||||
%arg2: memref<100x10xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%0 = memref.alloc() {temp = true} : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [#map0, #map1],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg1 : memref<100xf32>)
|
||||
outs(%0 : memref<100x10xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32):
|
||||
linalg.yield %arg3 : f32
|
||||
}
|
||||
%1 = memref.alloc() {temp = true} : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [#map1, #map1, #map1],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0, %0: memref<100x10xf32>, memref<100x10xf32>)
|
||||
outs(%1 : memref<100x10xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
|
||||
%2 = arith.subf %arg3, %arg4 : f32
|
||||
linalg.yield %2 : f32
|
||||
}
|
||||
memref.dealloc %0 : memref<100x10xf32>
|
||||
%2 = memref.dim %1, %c0 : memref<100x10xf32>
|
||||
%3 = memref.dim %1, %c1 : memref<100x10xf32>
|
||||
%4 = memref.dim %arg2, %c0 : memref<100x10xf32>
|
||||
%5 = memref.dim %arg2, %c1 : memref<100x10xf32>
|
||||
scf.for %i = %c0 to %2 step %c1 {
|
||||
scf.for %j = %c0 to %3 step %c1 {
|
||||
%6 = memref.subview %1[%i, %j][%c1, %c1][%c1, %c1] :
|
||||
memref<100x10xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%7 = memref.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] :
|
||||
memref<100x10xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.generic {
|
||||
indexing_maps = [#map1, #map1],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%6 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%7 : memref<?x?xf32, strided<[?, ?], offset: ?>>) {
|
||||
^bb0(%arg3: f32, %arg4: f32):
|
||||
%8 = math.exp %arg3 : f32
|
||||
linalg.yield %8 : f32
|
||||
}
|
||||
}
|
||||
}
|
||||
memref.dealloc %1 : memref<100x10xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: linalg.yield
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.subf
|
||||
// CHECK: linalg.yield
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: exp
|
||||
// CHECK: linalg.yield
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
|
||||
#map1 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
|
||||
#map3 = affine_map<(d0)[s0, s1] -> (s0 + 1, -d0 + s0 + s1)>
|
||||
#map4 = affine_map<(d0)[s0, s1] -> (s0 + 2, -d0 + s0 + s1)>
|
||||
|
||||
func.func @fill_and_conv(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c2 = arith.constant 2 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
linalg.fill ins(%cst : f32) outs(%arg0 : memref<?x?xf32>)
|
||||
%2 = memref.dim %arg1, %c0 : memref<?x?xf32>
|
||||
%3 = memref.dim %arg1, %c1 : memref<?x?xf32>
|
||||
%4 = memref.dim %arg2, %c0 : memref<?x?xf32>
|
||||
%5 = memref.dim %arg2, %c1 : memref<?x?xf32>
|
||||
scf.for %arg3 = %c0 to %4 step %c2 {
|
||||
scf.for %arg4 = %c0 to %5 step %c3 {
|
||||
%6 = affine.min #map3(%arg3)[%2, %4]
|
||||
%7 = affine.min #map4(%arg4)[%3, %5]
|
||||
%8 = memref.subview %arg0[%arg3, %arg4] [%6, %7] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
%9 = affine.min #map0(%arg3)[%4]
|
||||
%10 = affine.min #map1(%arg4)[%5]
|
||||
%11 = memref.subview %arg2[%arg3, %arg4] [%9, %10] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
|
||||
linalg.conv_2d ins(%8, %arg1 : memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32>) outs(%11 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fill_and_conv
|
||||
// CHECK: scf.for
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.fill
|
||||
// CHECK: linalg.conv_2d
|
||||
|
||||
// -----
|
||||
|
||||
// Test that different allocation-like ops are recognized and properly handled.
|
||||
func.func @accept_different_alloc_ops(%dim: index, %s0 : index, %s1: index) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c3 = arith.constant 3 : index
|
||||
%c4 = arith.constant 4 : index
|
||||
|
||||
%A = memref.alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%B = memref.alloca(%dim, %dim)[%s0, %s1] : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
%C = memref.alloc(%dim, %dim)[%s0, %s1] : memref<?x?xf32, strided<[?, ?], offset: 0>>
|
||||
|
||||
linalg.matmul ins(%A, %B : memref<?x?xf32, strided<[?, ?], offset: 0>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
outs(%C : memref<?x?xf32, strided<[?, ?], offset: 0>>)
|
||||
|
||||
scf.for %i = %c0 to %dim step %c2 {
|
||||
scf.for %j = %c0 to %dim step %c3 {
|
||||
scf.for %k = %c0 to %dim step %c4 {
|
||||
%0 = memref.subview %A[%i, %k][%c2, %c4][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%1 = memref.subview %B[%k, %j][%c4, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
%2 = memref.subview %C[%i, %j][%c2, %c3][%c1, %c1] :
|
||||
memref<?x?xf32, strided<[?, ?], offset: 0>> to
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
linalg.matmul ins(%0, %1 : memref<?x?xf32, strided<[?, ?], offset: ?>>,
|
||||
memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
outs(%2 : memref<?x?xf32, strided<[?, ?], offset: ?>>)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @accept_different_alloc_ops
|
||||
// CHECK-COUNT-3: scf.for
|
||||
// CHECK-COUNT-2: linalg.matmul
|
||||
@@ -12,7 +12,6 @@
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -39,22 +38,9 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
|
||||
bool changed = false;
|
||||
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
|
||||
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
||||
if (opOperand.get().getType().isa<MemRefType>()) {
|
||||
// TODO: LinalgDependenceGraph should be able to update itself.
|
||||
// The current naive and expensive reconstruction of the graph should be
|
||||
// removed.
|
||||
linalg::Aliases aliases;
|
||||
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
|
||||
auto info = fuseProducerOfBuffer(b, opOperand, graph);
|
||||
if (failed(info))
|
||||
continue;
|
||||
auto *originalOp = info->originalProducer.getOperation();
|
||||
eraseSet.insert(originalOp);
|
||||
auto *originalOpInLinalgOpsVector =
|
||||
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
|
||||
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
|
||||
changed = true;
|
||||
} else if (opOperand.get().getType().isa<RankedTensorType>()) {
|
||||
if (opOperand.get().getType().isa<MemRefType>())
|
||||
continue;
|
||||
if (opOperand.get().getType().isa<RankedTensorType>()) {
|
||||
// Tile and Fuse tensor input.
|
||||
if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs())
|
||||
continue;
|
||||
|
||||
@@ -8384,25 +8384,6 @@ gentbl_cc_library(
|
||||
deps = [":PassBaseTdFiles"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "LinalgAnalysis",
|
||||
srcs = glob([
|
||||
"lib/Dialect/Linalg/Analysis/*.cpp",
|
||||
"lib/Dialect/Linalg/Analysis/*.h",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"include/mlir/Dialect/Linalg/Analysis/*.h",
|
||||
]),
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":BufferizationDialect",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
":LinalgDialect",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "LinalgUtils",
|
||||
srcs = glob([
|
||||
@@ -8422,7 +8403,6 @@ cc_library(
|
||||
":DialectUtils",
|
||||
":FuncDialect",
|
||||
":IR",
|
||||
":LinalgAnalysis",
|
||||
":LinalgDialect",
|
||||
":MemRefDialect",
|
||||
":Pass",
|
||||
@@ -8462,7 +8442,6 @@ cc_library(
|
||||
":FuncDialect",
|
||||
":FuncTransforms",
|
||||
":IR",
|
||||
":LinalgAnalysis",
|
||||
":LinalgDialect",
|
||||
":LinalgPassIncGen",
|
||||
":LinalgStructuredOpsIncGen",
|
||||
|
||||
@@ -564,7 +564,6 @@ cc_library(
|
||||
"//mlir:FuncTransforms",
|
||||
"//mlir:GPUDialect",
|
||||
"//mlir:IR",
|
||||
"//mlir:LinalgAnalysis",
|
||||
"//mlir:LinalgDialect",
|
||||
"//mlir:LinalgTransforms",
|
||||
"//mlir:LinalgUtils",
|
||||
|
||||
Reference in New Issue
Block a user