From 894b27a74611acffcc24fe03c59be2a2af36ea7f Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Wed, 26 Mar 2025 16:51:26 +0800 Subject: [PATCH] [mlir][MemRefToLLVM] Fix crash with unconvertable memory space (#132323) This PR adds handling when the `memref.alloca` with unconvertable memory space to prevent a crash. Fixes #131439. --- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 7 +++++-- mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir | 11 +++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index fe0ee11d84ad..cb4317ef1bce 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -105,8 +105,11 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering { auto allocaOp = cast(op); auto elementType = typeConverter->convertType(allocaOp.getType().getElementType()); - unsigned addrSpace = - *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType()); + FailureOr maybeAddressSpace = + getTypeConverter()->getMemRefAddressSpace(allocaOp.getType()); + if (failed(maybeAddressSpace)) + return std::make_tuple(Value(), Value()); + unsigned addrSpace = *maybeAddressSpace; auto elementPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir index 523e894aaef8..12e93c96f743 100644 --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -654,3 +654,14 @@ func.func @store_non_temporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> func.return } + +// ----- + +// Ensure unconvertable memory space not cause a crash + +// CHECK-LABEL: @alloca_unconvertable_memory_space +func.func @alloca_unconvertable_memory_space() { + // CHECK: memref.alloca + %alloca = memref.alloca() : memref<1x32x33xi32, #spirv.storage_class> + func.return +}