[mlir] use built-in vector types instead of LLVM dialect types when possible

Continue the convergence between LLVM dialect and built-in types by using the
built-in vector type whenever possible, that is for fixed vectors of built-in
integers and built-in floats. LLVM dialect vector type is still in use for
pointers, less frequent floating point types that do not have a built-in
equivalent, and scalable vectors. However, the top-level `LLVMVectorType` class
has been removed in favor of free functions capable of inspecting both built-in
and LLVM dialect vector types: `LLVM::getVectorElementType`,
`LLVM::getNumVectorElements` and `LLVM::getFixedVectorType`. Additional work is
necessary to design an implemented the extensions to built-in types so as to
remove the `LLVMFixedVectorType` entirely.

Note that the default output format for the built-in vectors does not have
whitespace around the `x` separator, e.g., `vector<4xf32>` as opposed to the
LLVM dialect vector type format that does, e.g., `!llvm.vec<4 x fp128>`. This
required changing the FileCheck patterns in several tests.

Reviewed By: mehdi_amini, silvas

Differential Revision: https://reviews.llvm.org/D94405
This commit is contained in:
Alex Zinenko
2021-01-11 13:58:05 +01:00
parent 7ab803095a
commit bd30a796fc
54 changed files with 1283 additions and 1219 deletions

View File

@@ -182,7 +182,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
return failure();
auto pType = MemRefDescriptor(memref).getElementPtrType();
auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0));
auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
return success();
}
@@ -192,8 +192,7 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
// used when source/dst memrefs are not on address space 0.
static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, MemRefType memRefType, Type vt) {
auto pType =
LLVM::LLVMPointerType::get(vt.template cast<LLVM::LLVMFixedVectorType>());
auto pType = LLVM::LLVMPointerType::get(vt);
if (memRefType.getMemorySpace() == 0)
return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
@@ -1226,7 +1225,7 @@ public:
//
// TODO: when the leaf transfer rank is k > 1, we need the last `k`
// dimensions here.
unsigned vecWidth = vtp.getNumElements();
unsigned vecWidth = LLVM::getVectorNumElements(vtp).getFixedValue();
unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
Value off = xferOp.indices()[lastIndex];
Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);