[mlir] [VectorOps] Add expand/compress operations to Vector dialect

Introduces the vector.expandload and vector.compressstore operations
to the Vector dialect (important memory operations for sparse
computations), together with a first reference implementation that
lowers them to the LLVM IR dialect so they can run on CPU (and on any
other target that supports the corresponding LLVM IR intrinsics).
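
For illustration, a rough sketch of how the two new ops read at the
Vector dialect level; the element type, vector length, and value names
below are chosen for this example, and the exact assembly format is
the one defined in this revision:

  // Read consecutive values starting at %base into the enabled lanes
  // of the result; disabled lanes take their value from %pass_thru.
  %l = vector.expandload %base, %mask, %pass_thru
      : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>

  // Write the enabled lanes of %value to consecutive locations
  // starting at %base.
  vector.compressstore %base, %mask, %value
      : memref<?xf32>, vector<16xi1>, vector<16xf32>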

Reviewed By: reidtatge

Differential Revision: https://reviews.llvm.org/D84888
Author: aartbik
Date:   2020-07-31 12:47:25 -07:00
Parent: 3c0f347002
Commit: e8dcf5f87d

11 changed files with 505 additions and 60 deletions


@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
   return success();
 }

-// Helper that returns vector of pointers given a base and an index vector.
-LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
-                             LLVMTypeConverter &typeConverter, Location loc,
-                             Value memref, Value indices, MemRefType memRefType,
-                             VectorType vType, Type iType, Value &ptrs) {
+// Helper that returns the base address of a memref.
+LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
+                      Value memref, MemRefType memRefType, Value &base) {
   // Inspect stride and offset structure.
   //
   // TODO: flat memory only for now, generalize
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
   if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
       offset != 0 || memRefType.getMemorySpace() != 0)
     return failure();
+  base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
+  return success();
+}

-  // Create a vector of pointers from base and indices.
-  MemRefDescriptor memRefDescriptor(memref);
-  Value base = memRefDescriptor.alignedPtr(rewriter, loc);
-  int64_t size = vType.getDimSize(0);
-  auto pType = memRefDescriptor.getElementType();
-  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
+// Helper that returns a pointer given a memref base.
+LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value memref, MemRefType memRefType, Value &ptr) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
+  return success();
+}
+
+// Helper that returns vector of pointers given a memref base and an index
+// vector.
+LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
+                             Value memref, Value indices, MemRefType memRefType,
+                             VectorType vType, Type iType, Value &ptrs) {
+  Value base;
+  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
+    return failure();
+  auto pType = MemRefDescriptor(memref).getElementType();
+  auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
 }
@@ -305,9 +321,8 @@ public:
     VectorType vType = gather.getResultVectorType();
     Type iType = gather.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), gather.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              gather.getMemRefType(), vType, iType, ptrs)))
       return failure();

     // Replace with the gather intrinsic.
@@ -344,9 +359,8 @@ public:
     VectorType vType = scatter.getValueVectorType();
     Type iType = scatter.getIndicesVectorType().getElementType();
     Value ptrs;
-    if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
-                              adaptor.indices(), scatter.getMemRefType(), vType,
-                              iType, ptrs)))
+    if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
+                              scatter.getMemRefType(), vType, iType, ptrs)))
       return failure();

     // Replace with the scatter intrinsic.
@@ -357,6 +371,60 @@ public:
   }
 };

+/// Conversion pattern for a vector.expandload.
+class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorExpandLoadOpConversion(MLIRContext *context,
+                                        LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto expand = cast<vector::ExpandLoadOp>(op);
+    auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
+                          ptr)))
+      return failure();
+
+    auto vType = expand.getResultVectorType();
+    rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
+        op, typeConverter.convertType(vType), ptr, adaptor.mask(),
+        adaptor.pass_thru());
+    return success();
+  }
+};
+
+/// Conversion pattern for a vector.compressstore.
+class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
+public:
+  explicit VectorCompressStoreOpConversion(MLIRContext *context,
+                                           LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
+                             context, typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto compress = cast<vector::CompressStoreOp>(op);
+    auto adaptor = vector::CompressStoreOpAdaptor(operands);
+
+    Value ptr;
+    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
+                          compress.getMemRefType(), ptr)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
+        op, adaptor.value(), ptr, adaptor.mask());
+    return success();
+  }
+};
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
                   VectorTransferConversion<TransferWriteOp>,
                   VectorTypeCastOpConversion,
                   VectorGatherOpConversion,
-                  VectorScatterOpConversion>(ctx, converter);
+                  VectorScatterOpConversion,
+                  VectorExpandLoadOpConversion,
+                  VectorCompressStoreOpConversion>(ctx, converter);
   // clang-format on
 }
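
As a usage note, the new patterns are registered in the standard
Vector-to-LLVM lowering, so a function like the sketch below (types
and names chosen for illustration; note that the current helpers only
accept trivially strided, offset-free 1-D memrefs, per the TODO above)
should now lower through -convert-vector-to-llvm:

  func @compress16(%base: memref<?xf32>,
                   %mask: vector<16xi1>, %value: vector<16xf32>) {
    vector.compressstore %base, %mask, %value
      : memref<?xf32>, vector<16xi1>, vector<16xf32>
    return
  }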