[mlir][sparse] move tensor expression builder into Merger utility
Rationale: Follow-up on migrating lattice and tensor expression related methods into the new utility. This also prepares the next step of generalizing the op kinds that are handled. Reviewed By: gussmith23 Differential Revision: https://reviews.llvm.org/D105219
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
|
||||
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
|
||||
@@ -148,11 +149,6 @@ public:
|
||||
/// Returns true if any set bit corresponds to queried dim.
|
||||
bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
|
||||
|
||||
/// Builds the iteration lattices in a bottom-up traversal given the remaining
|
||||
/// tensor (sub)expression and the next loop index in the iteration graph.
|
||||
/// Returns index of the root expression.
|
||||
unsigned buildLattices(unsigned exp, unsigned idx);
|
||||
|
||||
/// Setter
|
||||
void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
|
||||
|
||||
@@ -169,7 +165,19 @@ public:
|
||||
void dumpBits(const llvm::BitVector &bits) const;
|
||||
#endif
|
||||
|
||||
/// Builds the iteration lattices in a bottom-up traversal given the remaining
|
||||
/// tensor (sub)expression and the next loop index in the iteration graph.
|
||||
/// Returns index of the root expression.
|
||||
unsigned buildLattices(unsigned exp, unsigned idx);
|
||||
|
||||
/// Builds a tensor expression from the given Linalg operation.
|
||||
/// Returns index of the root expression on success.
|
||||
Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
|
||||
|
||||
private:
|
||||
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
|
||||
Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
|
||||
|
||||
const unsigned outTensor;
|
||||
const unsigned syntheticTensor;
|
||||
const unsigned numTensors;
|
||||
|
||||
@@ -208,51 +208,6 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
|
||||
/// This simplifies constructing (sub)expressions during iteration lattice
|
||||
/// building (compared to using the SSA representation everywhere).
|
||||
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
|
||||
Value val) {
|
||||
if (auto arg = val.dyn_cast<BlockArgument>()) {
|
||||
unsigned argN = arg.getArgNumber();
|
||||
// Any argument of the generic op that is not marked as a scalar
|
||||
// argument is considered a tensor, indexed by the implicit loop
|
||||
// bounds. This includes rank-0 tensor arguments.
|
||||
if (arg.getOwner()->getParentOp() == op) {
|
||||
OpOperand *t = op.getInputAndOutputOperands()[argN];
|
||||
if (!op.isScalar(t))
|
||||
return merger.addExp(Kind::kTensor, argN);
|
||||
val = t->get(); // get scalar value
|
||||
}
|
||||
// Any other argument (marked as scalar argument for the generic op
|
||||
// or belonging to an enveloping op) is considered invariant.
|
||||
return merger.addExp(Kind::kInvariant, val);
|
||||
}
|
||||
Operation *def = val.getDefiningOp();
|
||||
if (def->getBlock() != &op.region().front()) {
|
||||
// Something defined outside is invariant.
|
||||
return merger.addExp(Kind::kInvariant, val);
|
||||
} else if (def->getNumOperands() == 2) {
|
||||
// Construct binary operations if subexpressions could be built.
|
||||
auto x = buildTensorExp(merger, op, def->getOperand(0));
|
||||
auto y = buildTensorExp(merger, op, def->getOperand(1));
|
||||
if (x.hasValue() && y.hasValue()) {
|
||||
unsigned e0 = x.getValue();
|
||||
unsigned e1 = y.getValue();
|
||||
if (isa<MulFOp>(def))
|
||||
return merger.addExp(Kind::kMulF, e0, e1);
|
||||
if (isa<MulIOp>(def))
|
||||
return merger.addExp(Kind::kMulI, e0, e1);
|
||||
if (isa<AddFOp>(def))
|
||||
return merger.addExp(Kind::kAddF, e0, e1);
|
||||
if (isa<AddIOp>(def))
|
||||
return merger.addExp(Kind::kAddI, e0, e1);
|
||||
}
|
||||
}
|
||||
// Cannot build (yet).
|
||||
return None;
|
||||
}
|
||||
|
||||
/// Returns true if given tensor co-iterates with conjunction only.
|
||||
/// For the output tensor, this defines a "simply dynamic" operation.
|
||||
/// For instance: A(I) = A(I) * B(I) * C(I)
|
||||
@@ -1224,14 +1179,12 @@ public:
|
||||
!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
|
||||
return failure();
|
||||
|
||||
// Finds the terminating yield statement and builds the tensor
|
||||
// expression for the Linalg operation in SSA form.
|
||||
Operation *yield = op.region().front().getTerminator();
|
||||
Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
|
||||
// Builds the tensor expression for the Linalg operation in SSA form.
|
||||
Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
|
||||
if (!exp.hasValue())
|
||||
return failure(); // build failure
|
||||
return failure();
|
||||
|
||||
// Reject an inadmissable tensor expression.
|
||||
// Rejects an inadmissable tensor expression.
|
||||
if (!isAdmissableTensorExp(merger, op, exp.getValue()))
|
||||
return failure();
|
||||
|
||||
|
||||
@@ -6,4 +6,5 @@ add_mlir_dialect_library(MLIRSparseTensorUtils
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
)
|
||||
|
||||
@@ -14,6 +14,10 @@
|
||||
namespace mlir {
|
||||
namespace sparse_tensor {
|
||||
|
||||
//
|
||||
// Lattice methods.
|
||||
//
|
||||
|
||||
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
|
||||
unsigned e = tensorExps.size();
|
||||
tensorExps.push_back(TensorExp(k, e0, e1, v));
|
||||
@@ -68,7 +72,7 @@ unsigned Merger::optimizeSet(unsigned s0) {
|
||||
if (p0 != p1) {
|
||||
// Is this a straightforward copy?
|
||||
unsigned e = latPoints[p1].exp;
|
||||
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
|
||||
if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
|
||||
continue;
|
||||
// Conjunction already covered?
|
||||
for (unsigned p2 : latSets[s]) {
|
||||
@@ -137,33 +141,6 @@ bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
|
||||
Kind kind = exp(e).kind;
|
||||
if (kind == Kind::kTensor || kind == Kind::kInvariant) {
|
||||
// Either the index is really used in the tensor expression, or it is
|
||||
// set to the undefined index in that dimension. An invariant expression
|
||||
// is set to a synthetic tensor with undefined indices only.
|
||||
unsigned s = addSet();
|
||||
unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
|
||||
set(s).push_back(addLat(t, idx, e));
|
||||
return s;
|
||||
}
|
||||
unsigned s0 = buildLattices(exp(e).e0, idx);
|
||||
unsigned s1 = buildLattices(exp(e).e1, idx);
|
||||
switch (kind) {
|
||||
case Kind::kTensor:
|
||||
case Kind::kInvariant:
|
||||
llvm_unreachable("handled above");
|
||||
case Kind::kMulF:
|
||||
case Kind::kMulI:
|
||||
return takeConj(kind, s0, s1);
|
||||
case Kind::kAddF:
|
||||
case Kind::kAddI:
|
||||
return takeDisj(kind, s0, s1);
|
||||
}
|
||||
llvm_unreachable("unexpected expression kind");
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
|
||||
//
|
||||
@@ -173,6 +150,10 @@ unsigned Merger::buildLattices(unsigned e, unsigned idx) {
|
||||
void Merger::dumpExp(unsigned e) const {
|
||||
switch (tensorExps[e].kind) {
|
||||
case Kind::kTensor:
|
||||
if (tensorExps[e].e0 == syntheticTensor)
|
||||
llvm::dbgs() << "synthetic_";
|
||||
else if (tensorExps[e].e0 == outTensor)
|
||||
llvm::dbgs() << "output_";
|
||||
llvm::dbgs() << "tensor_" << tensorExps[e].e0;
|
||||
break;
|
||||
case Kind::kInvariant:
|
||||
@@ -242,5 +223,82 @@ void Merger::dumpBits(const llvm::BitVector &bits) const {
|
||||
|
||||
#endif // NDEBUG
|
||||
|
||||
//
|
||||
// Builder methods.
|
||||
//
|
||||
|
||||
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
|
||||
Kind kind = tensorExps[e].kind;
|
||||
if (kind == Kind::kTensor || kind == Kind::kInvariant) {
|
||||
// Either the index is really used in the tensor expression, or it is
|
||||
// set to the undefined index in that dimension. An invariant expression
|
||||
// is set to a synthetic tensor with undefined indices only.
|
||||
unsigned s = addSet();
|
||||
unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
|
||||
latSets[s].push_back(addLat(t, idx, e));
|
||||
return s;
|
||||
}
|
||||
unsigned s0 = buildLattices(tensorExps[e].e0, idx);
|
||||
unsigned s1 = buildLattices(tensorExps[e].e1, idx);
|
||||
switch (kind) {
|
||||
case Kind::kTensor:
|
||||
case Kind::kInvariant:
|
||||
llvm_unreachable("handled above");
|
||||
case Kind::kMulF:
|
||||
case Kind::kMulI:
|
||||
return takeConj(kind, s0, s1);
|
||||
case Kind::kAddF:
|
||||
case Kind::kAddI:
|
||||
return takeDisj(kind, s0, s1);
|
||||
}
|
||||
llvm_unreachable("unexpected expression kind");
|
||||
}
|
||||
|
||||
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
|
||||
Operation *yield = op.region().front().getTerminator();
|
||||
return buildTensorExp(op, yield->getOperand(0));
|
||||
}
|
||||
|
||||
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
|
||||
if (auto arg = val.dyn_cast<BlockArgument>()) {
|
||||
unsigned argN = arg.getArgNumber();
|
||||
// Any argument of the generic op that is not marked as a scalar
|
||||
// argument is considered a tensor, indexed by the implicit loop
|
||||
// bounds. This includes rank-0 tensor arguments.
|
||||
if (arg.getOwner()->getParentOp() == op) {
|
||||
OpOperand *t = op.getInputAndOutputOperands()[argN];
|
||||
if (!op.isScalar(t))
|
||||
return addExp(Kind::kTensor, argN);
|
||||
val = t->get(); // get scalar value
|
||||
}
|
||||
// Any other argument (marked as scalar argument for the generic op
|
||||
// or belonging to an enveloping op) is considered invariant.
|
||||
return addExp(Kind::kInvariant, val);
|
||||
}
|
||||
// Something defined outside is invariant.
|
||||
Operation *def = val.getDefiningOp();
|
||||
if (def->getBlock() != &op.region().front())
|
||||
return addExp(Kind::kInvariant, val);
|
||||
// Construct binary operations if subexpressions could be built.
|
||||
if (def->getNumOperands() == 2) {
|
||||
auto x = buildTensorExp(op, def->getOperand(0));
|
||||
auto y = buildTensorExp(op, def->getOperand(1));
|
||||
if (x.hasValue() && y.hasValue()) {
|
||||
unsigned e0 = x.getValue();
|
||||
unsigned e1 = y.getValue();
|
||||
if (isa<MulFOp>(def))
|
||||
return addExp(Kind::kMulF, e0, e1);
|
||||
if (isa<MulIOp>(def))
|
||||
return addExp(Kind::kMulI, e0, e1);
|
||||
if (isa<AddFOp>(def))
|
||||
return addExp(Kind::kAddF, e0, e1);
|
||||
if (isa<AddIOp>(def))
|
||||
return addExp(Kind::kAddI, e0, e1);
|
||||
}
|
||||
}
|
||||
// Cannot build.
|
||||
return None;
|
||||
}
|
||||
|
||||
} // namespace sparse_tensor
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user