[mlir][sparse] Refactoring: remove dependence on tuple type when lowering sparse tensors.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133390
This commit is contained in:
@@ -482,65 +482,6 @@ LogicalResult YieldOp::verify() {
|
||||
"expected parent op to be sparse_tensor unary, binary, or reduce");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sparse Tensor Storage Operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult StorageOp::verify() {
|
||||
auto retTypes = getResult().getType().getTypes();
|
||||
if (retTypes.size() != getInputs().size())
|
||||
return emitError("The number of inputs is inconsistent with output tuple");
|
||||
|
||||
for (auto pair : llvm::zip(getInputs(), retTypes)) {
|
||||
auto input = std::get<0>(pair);
|
||||
auto retTy = std::get<1>(pair);
|
||||
|
||||
if (input.getType() != retTy)
|
||||
return emitError(llvm::formatv("Type mismatch between input (type={0}) "
|
||||
"and output tuple element (type={1})",
|
||||
input.getType(), retTy));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult StorageGetOp::verify() {
|
||||
uint64_t extractIdx = getIdx().getZExtValue();
|
||||
auto innerTypeArray = getStorage().getType().getTypes();
|
||||
if (extractIdx >= innerTypeArray.size())
|
||||
return emitError(llvm::formatv(
|
||||
"Out-of-bound access with index={0} on tuple with length={1}",
|
||||
extractIdx, innerTypeArray.size()));
|
||||
|
||||
auto expectedTy = getStorage().getType().getType(extractIdx);
|
||||
auto returnTy = getResult().getType();
|
||||
if (expectedTy != returnTy)
|
||||
return emitError(llvm::formatv(
|
||||
"Type mismatch between the returning type (type={0}) and the "
|
||||
"corresponding element type at index {1} (type={2})",
|
||||
expectedTy, extractIdx, returnTy));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult StorageSetOp::verify() {
|
||||
uint64_t setIdx = getIdx().getZExtValue();
|
||||
SmallVector<Type, 8> expectedElemTy(getStorage().getType().getTypes());
|
||||
if (setIdx >= expectedElemTy.size())
|
||||
return emitError(llvm::formatv(
|
||||
"Out-of-bound access with index = {0} on tuple with length={1}", setIdx,
|
||||
expectedElemTy.size()));
|
||||
|
||||
// Updates the element type after storage_set.
|
||||
expectedElemTy[setIdx] = getValue().getType();
|
||||
auto expectedTy = TupleType::get(getContext(), expectedElemTy);
|
||||
auto returnTy = getResult().getType();
|
||||
if (expectedTy != returnTy)
|
||||
return emitError(
|
||||
llvm::formatv("Type mismatch between the returning type "
|
||||
"(type={0}) and the expected type (type={1})",
|
||||
returnTy, expectedTy));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorDialect Methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Reference in New Issue
Block a user