[mlir][sparse] introduce sparse_tensor.iterate operation (#88955)

A `sparse_tensor.iterate` iterates over a sparse iteration space
extracted from `sparse_tensor.extract_iteration_space` operation
introduced in https://github.com/llvm/llvm-project/pull/88554.
This commit is contained in:
Peiming Liu
2024-06-10 10:20:24 -07:00
committed by GitHub
parent 1d96e4bc2d
commit e276cf0831
7 changed files with 538 additions and 8 deletions

View File

@@ -17,9 +17,13 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/bit.h"
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,42 @@ struct COOSegment {
}
};
/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
/// by `sparse_tensor.iterate` operation for the set of levels on which the
/// coordinates should be loaded.
class LevelSet {
uint64_t bits = 0;
public:
LevelSet() = default;
explicit LevelSet(uint64_t bits) : bits(bits) {}
operator uint64_t() const { return bits; }
LevelSet &set(unsigned i) {
assert(i < 64);
bits |= static_cast<uint64_t>(0x01u) << i;
return *this;
}
LevelSet &operator|=(LevelSet lhs) {
bits |= static_cast<uint64_t>(lhs);
return *this;
}
LevelSet &lshift(unsigned offset) {
bits = bits << offset;
return *this;
}
bool operator[](unsigned i) const {
assert(i < 64);
return (bits & (1 << i)) != 0;
}
unsigned count() const { return llvm::popcount(bits); }
bool empty() const { return bits == 0; }
};
} // namespace sparse_tensor
} // namespace mlir

View File

@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
//===----------------------------------------------------------------------===//
// A simple bitset attribute wrapped around a single int64_t to encode a set of
// sparse tensor levels.
//===----------------------------------------------------------------------===//
def LevelSetAttr :
TypedAttrBase<
I64, "IntegerAttr",
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
"LevelSet attribute"> {
let returnType = [{::mlir::sparse_tensor::LevelSet}];
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
}
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level

View File

@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
"ForeachOp"]>]> {
"ForeachOp", "IterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
iteration space.
The type of returned the value is automatically inferred to
The type of returned the value is must be
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
The returned iteration space can then be iterated over by
`sparse_tensor.iterate` operations to visit every stored element
@@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
// Extracts a 1-D iteration space from a COO tensor at level 1.
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
->!sparse_tensor.iter_space<#COO, lvls = 1>
```
}];
@@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
return getHiLvl() - getLoLvl();
}
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
return getResultSpace().getType().getLvlTypes();
return getExtractedSpace().getType().getLvlTypes();
}
}];
let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);
let results = (outs AnySparseIterSpace:$resultSpace);
let results = (outs AnySparseIterSpace:$extractedSpace);
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
"`->` qualified(type($extractedSpace))";
let hasVerifier = 1;
}
def IterateOp : SparseTensor_Op<"iterate",
[RecursiveMemoryEffects, RecursivelySpeculatable,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getYieldedValuesMutable"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
let summary = "Iterates over a sparse iteration space";
let description = [{
The `sparse_tensor.iterate` operation represents a loop (nest) over
the provided iteration space extracted from a specific sparse tensor.
The operation defines an SSA value for a sparse iterator that points
to the current stored element in the sparse tensor and SSA values
for coordinates of the stored element. The coordinates are always
converted to `index` type despite of the underlying sparse tensor
storage. When coordinates are not used, the SSA values can be skipped
by `_` symbols, which usually leads to simpler generated code after
sparsification. For example:
```mlir
// The coordinate for level 0 is not used when iterating over a 2-D
// iteration space.
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
```
`sparse_tensor.iterate` can also operate on loop-carried variables.
It returns the final values after loop termination.
The initial values of the variables are passed as additional SSA operands
to the iterator SSA value and used coordinate SSA values mentioned
above. The operation region has an argument for the iterator, variadic
arguments for specified (used) coordiates and followed by one argument
for each loop-carried variable, representing the value of the variable
at the current iteration.
The body region must contain exactly one block that terminates with
`sparse_tensor.yield`.
The results of an `sparse_tensor.iterate` hold the final values after
the last iteration. If the `sparse_tensor.iterate` defines any values,
a yield must be explicitly present.
The number and types of the `sparse_tensor.iterate` results must match
the initial values in the iter_args binding and the yield operands.
A nested `sparse_tensor.iterate` example that prints all the coordinates
stored in the sparse input:
```mlir
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
// Iterates over the first level of %sp
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
: tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
%r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
// Iterates over the second level of %sp
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
%r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
vector.print %coord0 : index
vector.print %coord1 : index
}
}
}
```
}];
let arguments = (ins AnySparseIterSpace:$iterSpace,
Variadic<AnyType>:$initArgs,
LevelSetAttr:$crdUsedLvls);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
let extraClassDeclaration = [{
unsigned getSpaceDim() {
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
return getRegion().getArguments().front();
}
Block::BlockArgListType getCrds() {
// The first block argument is iterator, the remaining arguments are
// referenced coordinates.
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
}
}];
let hasVerifier = 1;
let hasRegionVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//

View File

@@ -2130,6 +2130,106 @@ static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
printLevelRange(p, lo, hi);
}
static ParseResult
parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
SmallVectorImpl<OpAsmParser::Argument> &iterators,
SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
SmallVector<OpAsmParser::UnresolvedOperand> spaces;
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
// Parse "%iters, ... in %spaces, ..."
if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
parser.parseOperandList(spaces))
return failure();
if (iterators.size() != spaces.size())
return parser.emitError(
parser.getNameLoc(),
"mismatch in number of sparse iterators and sparse spaces");
// Parse "at(%crd0, _, ...)"
LevelSet crdUsedLvlSet;
bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
unsigned lvlCrdCnt = 0;
if (hasUsedCrds) {
ParseResult crdList = parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
if (parser.parseOptionalKeyword("_")) {
if (parser.parseArgument(iterArgs.emplace_back()))
return failure();
// Always use IndexType for the coordinate.
crdUsedLvlSet.set(lvlCrdCnt);
iterArgs.back().type = parser.getBuilder().getIndexType();
}
lvlCrdCnt += 1;
return success();
});
if (failed(crdList)) {
return parser.emitError(
parser.getNameLoc(),
"expecting SSA value or \"_\" for level coordinates");
}
}
// Set the CrdUsedLvl bitset.
state.addAttribute("crdUsedLvls",
parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
// Parse "iter_args(%arg = %init, ...)"
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
if (hasIterArgs)
if (parser.parseAssignmentList(iterArgs, initArgs))
return failure();
SmallVector<Type> iterSpaceTps;
// parse ": sparse_tensor.iter_space -> ret"
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
return failure();
if (iterSpaceTps.size() != spaces.size())
return parser.emitError(parser.getNameLoc(),
"mismatch in number of iteration space operands "
"and iteration space types");
for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp);
if (!spaceTp)
return parser.emitError(parser.getNameLoc(),
"expected sparse_tensor.iter_space type for "
"iteration space operands");
if (hasUsedCrds && spaceTp.getSpaceDim() != lvlCrdCnt)
return parser.emitError(parser.getNameLoc(),
"mismatch in number of iteration space dimension "
"and specified coordinates");
it.type = spaceTp.getIteratorType();
}
if (hasIterArgs)
if (parser.parseArrowTypeList(state.types))
return failure();
// Resolves input operands.
if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(),
state.operands))
return failure();
if (hasIterArgs) {
unsigned numCrds = crdUsedLvlSet.count();
// Strip off leading args that used for coordinates.
MutableArrayRef args = MutableArrayRef(iterArgs).drop_front(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
"mismatch in number of iteration arguments and return values");
}
for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) {
it.type = tp;
if (parser.resolveOperand(init, tp, state.operands))
return failure();
}
}
return success();
}
LogicalResult ExtractIterSpaceOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
@@ -2153,7 +2253,7 @@ LogicalResult ExtractIterSpaceOp::verify() {
}
if (pIter) {
IterSpaceType spaceTp = getResultSpace().getType();
IterSpaceType spaceTp = getExtractedSpace().getType();
if (pIter.getType().getEncoding() != spaceTp.getEncoding())
return emitOpError(
"mismatch in parent iterator encoding and iteration space encoding.");
@@ -2166,6 +2266,161 @@ LogicalResult ExtractIterSpaceOp::verify() {
return success();
}
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument iterator;
OpAsmParser::UnresolvedOperand iterSpace;
SmallVector<OpAsmParser::Argument> iters, iterArgs;
if (parseSparseSpaceLoop(parser, result, iters, iterArgs))
return failure();
if (iters.size() != 1)
return parser.emitError(parser.getNameLoc(),
"expected only one iterator/iteration space");
iters.append(iterArgs);
Region *body = result.addRegion();
if (parser.parseRegion(*body, iters))
return failure();
IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
return success();
}
/// Prints the initialization list in the form of
/// <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
/// where 'inner' values are assumed to be region arguments and 'outer' values
/// are regular SSA values.
static void printInitializationList(OpAsmPrinter &p,
Block::BlockArgListType blocksArgs,
ValueRange initializers,
StringRef prefix = "") {
assert(blocksArgs.size() == initializers.size() &&
"expected same length of arguments and initializers");
if (initializers.empty())
return;
p << prefix << '(';
llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
p << std::get<0>(it) << " = " << std::get<1>(it);
});
p << ")";
}
static void printUsedCrdsList(OpAsmPrinter &p, unsigned spaceDim,
Block::BlockArgListType blocksArgs,
LevelSet crdUsedLvls) {
if (crdUsedLvls.empty())
return;
p << " at(";
for (unsigned i = 0; i < spaceDim; i++) {
if (crdUsedLvls[i]) {
p << blocksArgs.front();
blocksArgs = blocksArgs.drop_front();
} else {
p << "_";
}
if (i != spaceDim - 1)
p << ", ";
}
assert(blocksArgs.empty());
p << ")";
}
void IterateOp::print(OpAsmPrinter &p) {
p << " " << getIterator() << " in " << getIterSpace();
printUsedCrdsList(p, getSpaceDim(), getCrds(), getCrdUsedLvls());
printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args");
p << " : " << getIterSpace().getType() << " ";
if (!getInitArgs().empty())
p << "-> (" << getInitArgs().getTypes() << ") ";
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
}
LogicalResult IterateOp::verify() {
if (getInitArgs().size() != getNumResults()) {
return emitOpError(
"mismatch in number of loop-carried values and defined values");
}
return success();
}
LogicalResult IterateOp::verifyRegions() {
if (getIterator().getType() != getIterSpace().getType().getIteratorType())
return emitOpError("mismatch in iterator and iteration space type");
if (getNumRegionIterArgs() != getNumResults())
return emitOpError(
"mismatch in number of basic block args and defined values");
auto initArgs = getInitArgs();
auto iterArgs = getRegionIterArgs();
auto yieldVals = getYieldedValues();
auto opResults = getResults();
if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
opResults.size()})) {
return emitOpError() << "number mismatch between iter args and results.";
}
for (auto [i, init, iter, yield, ret] :
llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
if (init.getType() != ret.getType())
return emitOpError() << "types mismatch between " << i
<< "th iter operand and defined value";
if (iter.getType() != ret.getType())
return emitOpError() << "types mismatch between " << i
<< "th iter region arg and defined value";
if (yield.getType() != ret.getType())
return emitOpError() << "types mismatch between " << i
<< "th yield value and defined value";
}
return success();
}
/// OpInterfaces' methods implemented by IterateOp.
SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; }
MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
return getInitArgsMutable();
}
Block::BlockArgListType IterateOp::getRegionIterArgs() {
return getRegion().getArguments().take_back(getNumRegionIterArgs());
}
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
return cast<sparse_tensor::YieldOp>(
getRegion().getBlocks().front().getTerminator())
.getResultsMutable();
}
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
return getInitArgs();
}
void IterateOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// It is possible for loop not to enter the body.
regions.push_back(RegionSuccessor(getResults()));
}
//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

View File

@@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
-> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
return
}
@@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
-> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
-> !sparse_tensor.iter_space<#COO, lvls = 1>
return
}
@@ -1092,5 +1095,63 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
-> !sparse_tensor.iter_space<#COO, lvls = 2>
return
}
// -----
#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
%r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
sparse_tensor.yield %si : index
}
return %r1 : index
}
// -----
#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>
// expected-note@+1 {{prior use here}}
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
sparse_tensor.yield %outer : f32
}
return %r1 : f32
}
// -----
#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
%y = arith.constant 1.0 : f32
sparse_tensor.yield %y : f32
}
return %r1 : index
}

View File

@@ -758,8 +758,37 @@ func.func @sparse_has_runtime() -> i1 {
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
-> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
// Extracting the iteration space for the first level needs no parent iterator.
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
// Extracting the iteration space for the second level needs a parent iterator.
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
-> !sparse_tensor.iter_space<#COO, lvls = 1>
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
}
// -----
#COO = #sparse_tensor.encoding<{
map = (i, j) -> (
i : compressed(nonunique),
j : singleton(soa)
)
}>
// CHECK-LABEL: func.func @sparse_iterate(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf32, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: index) -> index {
// CHECK: %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 : tensor<4x8xf32, #sparse{{[0-9]*}}>
// CHECK: %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]]) iter_args(%[[VAL_7:.*]] = %[[VAL_1]]) : !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> -> (index) {
// CHECK: sparse_tensor.yield %[[VAL_7]] : index
// CHECK: }
// CHECK: return %[[VAL_4]] : index
// CHECK: }
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
sparse_tensor.yield %outer : index
}
return %r1 : index
}

View File

@@ -0,0 +1,27 @@
// RUN: mlir-opt %s --loop-invariant-code-motion | FileCheck %s
#CSR = #sparse_tensor.encoding<{
map = (i, j) -> (
i : dense,
j : compressed
)
}>
// Make sure that pure instructions are hoisted outside the loop.
//
// CHECK: sparse_tensor.values
// CHECK: sparse_tensor.positions
// CHECK: sparse_tensor.coordinate
// CHECK: sparse_tensor.iterate
func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
-> !sparse_tensor.iter_space<#CSR, lvls = 0>
sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
%0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
%1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
%2 = sparse_tensor.coordinates %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>
"test.op"(%0, %1, %2) : (memref<?xf64>, memref<?xindex>, memref<?xindex>) -> ()
}
return
}