[mlir] [VectorOps] Add expand/compress operations to Vector dialect
Introduces the expand and compress operations to the Vector dialect (important memory operations for sparse computations), together with a first reference implementation that lowers to the LLVM IR dialect to enable running on CPU (and other targets that support the corresponding LLVM IR intrinsics). Reviewed By: reidtatge Differential Revision: https://reviews.llvm.org/D84888
This commit is contained in:
@@ -134,11 +134,9 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns vector of pointers given a base and an index vector.
|
||||
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
Value memref, Value indices, MemRefType memRefType,
|
||||
VectorType vType, Type iType, Value &ptrs) {
|
||||
// Helper that returns the base address of a memref.
|
||||
LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value memref, MemRefType memRefType, Value &base) {
|
||||
// Inspect stride and offset structure.
|
||||
//
|
||||
// TODO: flat memory only for now, generalize
|
||||
@@ -149,13 +147,31 @@ LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
|
||||
if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
|
||||
offset != 0 || memRefType.getMemorySpace() != 0)
|
||||
return failure();
|
||||
base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Create a vector of pointers from base and indices.
|
||||
MemRefDescriptor memRefDescriptor(memref);
|
||||
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
|
||||
int64_t size = vType.getDimSize(0);
|
||||
auto pType = memRefDescriptor.getElementType();
|
||||
auto ptrsType = LLVM::LLVMType::getVectorTy(pType, size);
|
||||
// Helper that returns a pointer given a memref base.
|
||||
LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value memref, MemRefType memRefType, Value &ptr) {
|
||||
Value base;
|
||||
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
|
||||
return failure();
|
||||
auto pType = MemRefDescriptor(memref).getElementType();
|
||||
ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Helper that returns vector of pointers given a memref base and an index
|
||||
// vector.
|
||||
LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value memref, Value indices, MemRefType memRefType,
|
||||
VectorType vType, Type iType, Value &ptrs) {
|
||||
Value base;
|
||||
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
|
||||
return failure();
|
||||
auto pType = MemRefDescriptor(memref).getElementType();
|
||||
auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
|
||||
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
|
||||
return success();
|
||||
}
|
||||
@@ -305,9 +321,8 @@ public:
|
||||
VectorType vType = gather.getResultVectorType();
|
||||
Type iType = gather.getIndicesVectorType().getElementType();
|
||||
Value ptrs;
|
||||
if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
|
||||
adaptor.indices(), gather.getMemRefType(), vType,
|
||||
iType, ptrs)))
|
||||
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
|
||||
gather.getMemRefType(), vType, iType, ptrs)))
|
||||
return failure();
|
||||
|
||||
// Replace with the gather intrinsic.
|
||||
@@ -344,9 +359,8 @@ public:
|
||||
VectorType vType = scatter.getValueVectorType();
|
||||
Type iType = scatter.getIndicesVectorType().getElementType();
|
||||
Value ptrs;
|
||||
if (failed(getIndexedPtrs(rewriter, typeConverter, loc, adaptor.base(),
|
||||
adaptor.indices(), scatter.getMemRefType(), vType,
|
||||
iType, ptrs)))
|
||||
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
|
||||
scatter.getMemRefType(), vType, iType, ptrs)))
|
||||
return failure();
|
||||
|
||||
// Replace with the scatter intrinsic.
|
||||
@@ -357,6 +371,60 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.expandload.
|
||||
class VectorExpandLoadOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorExpandLoadOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto expand = cast<vector::ExpandLoadOp>(op);
|
||||
auto adaptor = vector::ExpandLoadOpAdaptor(operands);
|
||||
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
|
||||
ptr)))
|
||||
return failure();
|
||||
|
||||
auto vType = expand.getResultVectorType();
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
|
||||
op, typeConverter.convertType(vType), ptr, adaptor.mask(),
|
||||
adaptor.pass_thru());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.compressstore.
|
||||
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorCompressStoreOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(),
|
||||
context, typeConverter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op->getLoc();
|
||||
auto compress = cast<vector::CompressStoreOp>(op);
|
||||
auto adaptor = vector::CompressStoreOpAdaptor(operands);
|
||||
|
||||
Value ptr;
|
||||
if (failed(getBasePtr(rewriter, loc, adaptor.base(),
|
||||
compress.getMemRefType(), ptr)))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
|
||||
op, adaptor.value(), ptr, adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for all vector reductions.
|
||||
class VectorReductionOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
@@ -1274,7 +1342,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorTransferConversion<TransferWriteOp>,
|
||||
VectorTypeCastOpConversion,
|
||||
VectorGatherOpConversion,
|
||||
VectorScatterOpConversion>(ctx, converter);
|
||||
VectorScatterOpConversion,
|
||||
VectorExpandLoadOpConversion,
|
||||
VectorCompressStoreOpConversion>(ctx, converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user