Emit inbounds and nuw attributes in memref. (#138984)

Now that MLIR accepts nuw and nusw in getelementptr, this patch emits
the inbounds and nuw attributes when lower memref to LLVM in load and
store operators.

This patch also strengthens the memref.load and memref.store spec about
undefined behaviour during lowering.

This patch also lifts the |rewriter| parameter in getStridedElementPtr
ahead so that LLVM::GEPNoWrapFlags can be added at the end with a
default value and grouped together with other operators' parameters.

Signed-off-by: Lin, Peiyong <linpyong@gmail.com>
This commit is contained in:
Peiyong Lin
2025-05-20 14:16:22 -07:00
committed by GitHub
parent 11db1285e4
commit 04ad8d4900
14 changed files with 103 additions and 77 deletions

View File

@@ -289,8 +289,8 @@ public:
// Resolve address.
auto vtype = cast<VectorType>(
this->typeConverter->convertType(loadOrStoreOp.getVectorType()));
Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value dataPtr = this->getStridedElementPtr(
rewriter, loc, memRefTy, adaptor.getBase(), adaptor.getIndices());
replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, dataPtr, align,
rewriter);
return success();
@@ -337,8 +337,8 @@ public:
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
Value base = adaptor.getBase();
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
@@ -393,8 +393,8 @@ public:
"could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
Value ptrs =
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
adaptor.getBase(), ptr, adaptor.getIndexVec(), vType);
@@ -428,8 +428,8 @@ public:
// Resolve address.
auto vtype = typeConverter->convertType(expand.getVectorType());
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
@@ -450,8 +450,8 @@ public:
MemRefType memRefType = compress.getMemRefType();
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
compress, adaptor.getValueToStore(), ptr, adaptor.getMask());