[mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133390
This commit is contained in:
Peiming Liu
2022-09-07 00:49:44 +00:00
parent 300155911a
commit edca72f5bc
12 changed files with 317 additions and 876 deletions

View File

@@ -482,65 +482,6 @@ LogicalResult YieldOp::verify() {
"expected parent op to be sparse_tensor unary, binary, or reduce");
}
//===----------------------------------------------------------------------===//
// Sparse Tensor Storage Operation.
//===----------------------------------------------------------------------===//
LogicalResult StorageOp::verify() {
auto retTypes = getResult().getType().getTypes();
if (retTypes.size() != getInputs().size())
return emitError("The number of inputs is inconsistent with output tuple");
for (auto pair : llvm::zip(getInputs(), retTypes)) {
auto input = std::get<0>(pair);
auto retTy = std::get<1>(pair);
if (input.getType() != retTy)
return emitError(llvm::formatv("Type mismatch between input (type={0}) "
"and output tuple element (type={1})",
input.getType(), retTy));
}
return success();
}
LogicalResult StorageGetOp::verify() {
uint64_t extractIdx = getIdx().getZExtValue();
auto innerTypeArray = getStorage().getType().getTypes();
if (extractIdx >= innerTypeArray.size())
return emitError(llvm::formatv(
"Out-of-bound access with index={0} on tuple with length={1}",
extractIdx, innerTypeArray.size()));
auto expectedTy = getStorage().getType().getType(extractIdx);
auto returnTy = getResult().getType();
if (expectedTy != returnTy)
return emitError(llvm::formatv(
"Type mismatch between the returning type (type={0}) and the "
"corresponding element type at index {1} (type={2})",
expectedTy, extractIdx, returnTy));
return success();
}
LogicalResult StorageSetOp::verify() {
uint64_t setIdx = getIdx().getZExtValue();
SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
if (setIdx >= expectedElemTy.size())
return emitError(llvm::formatv(
"Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
expectedElemTy.size()));
// Updates the element type after storage_set.
expectedElemTy[setIdx] = getValue().getType();
auto expectedTy = TupleType::get(getContext(), expectedElemTy);
auto returnTy = getResult().getType();
if (expectedTy != returnTy)
return emitError(
llvm::formatv("Type mismatch between the returning type "
"(type={0}) and the expected type (type={1})",
returnTy, expectedTy));
return success();
}
//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//