[mlir][Linalg] Fix crash in LinalgToStandard

Properly handle `appendMangledType` failure instead of asserting.

Fixes #59986.
This commit is contained in:
Nicolas Vasilache
2023-01-20 00:06:34 -08:00
parent 02fb5aae11
commit ff94419a28
2 changed files with 23 additions and 9 deletions

View File

@@ -1795,7 +1795,7 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
return llvm::to_vector<4>(concatRanges);
}
static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
if (auto memref = t.dyn_cast<MemRefType>()) {
ss << "view";
for (auto size : memref.getShape())
@@ -1804,16 +1804,19 @@ static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
else
ss << size << "x";
appendMangledType(ss, memref.getElementType());
} else if (auto vec = t.dyn_cast<VectorType>()) {
return success();
}
if (auto vec = t.dyn_cast<VectorType>()) {
ss << "vector";
llvm::interleave(
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
appendMangledType(ss, vec.getElementType());
return success();
} else if (t.isSignlessIntOrIndexOrFloat()) {
ss << t;
} else {
llvm_unreachable("Invalid type for linalg library name mangling");
return success();
}
return failure();
}
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
@@ -1823,11 +1826,14 @@ std::string mlir::linalg::generateLibraryCallName(Operation *op) {
std::replace(name.begin(), name.end(), '.', '_');
llvm::raw_string_ostream ss(name);
ss << "_";
auto types = op->getOperandTypes();
llvm::interleave(
types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
[&]() { ss << "_"; });
return ss.str();
for (Type t : op->getOperandTypes()) {
if (failed(appendMangledType(ss, t)))
return std::string();
ss << "_";
}
std::string res = ss.str();
res.pop_back();
return res;
}
//===----------------------------------------------------------------------===//

View File

@@ -71,3 +71,11 @@ func.func @func(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) {
} -> tensor<?xf32>
return
}
// -----
func.func @func(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x8xf32> {
// expected-error @below {{failed to legalize}}
%0 = linalg.copy ins(%arg0 : tensor<4x8xf32>) outs(%arg1 : tensor<4x8xf32>) -> tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}