[mlir][sparse] Add layout to the memref for the indices buffers to prepare for the AOS storage optimization for COO regions.

Fix relevant FileCheck tests.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140742
This commit is contained in:
bixia1
2023-01-03 15:16:12 -08:00
parent 8383da1583
commit 90aa436291
10 changed files with 168 additions and 81 deletions

View File

@@ -878,60 +878,65 @@ public:
}
};
/// Base class for getter-like operations, e.g., to_indices, to_pointers.
template <typename SourceOp, typename Base>
class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename ToPointersOp::Adaptor;
using OpConversionPattern<ToPointersOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Value field = Base::getFieldForOp(desc, op);
rewriter.replaceOp(op, field);
uint64_t dim = op.getDimension().getZExtValue();
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
return success();
}
};
/// Sparse codegen rule for pointer accesses.
class SparseToPointersConverter
: public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
static Value getFieldForOp(const SparseTensorDescriptor &desc,
ToPointersOp op) {
uint64_t dim = op.getDimension().getZExtValue();
return desc.getPtrMemRef(dim);
}
};
/// Sparse codegen rule for index accesses.
class SparseToIndicesConverter
: public SparseGetterOpConverter<ToIndicesOp, SparseToIndicesConverter> {
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
static Value getFieldForOp(const SparseTensorDescriptor &desc,
ToIndicesOp op) {
using OpAdaptor = typename ToIndicesOp::Adaptor;
using OpConversionPattern<ToIndicesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
uint64_t dim = op.getDimension().getZExtValue();
return desc.getIdxMemRef(dim);
Value field = desc.getIdxMemRef(dim);
// Insert a cast to bridge the actual type to the user expected type. If the
// actual type and the user expected type aren't compatible, the compiler or
// the runtime will issue an error.
Type resType = op.getResult().getType();
if (resType != field.getType())
field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
rewriter.replaceOp(op, field);
return success();
}
};
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter
: public SparseGetterOpConverter<ToValuesOp, SparseToValuesConverter> {
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
using SparseGetterOpConverter::SparseGetterOpConverter;
// Callback for SparseGetterOpConverter.
static Value getFieldForOp(const SparseTensorDescriptor &desc,
ToValuesOp /*op*/) {
return desc.getValMemRef();
using OpAdaptor = typename ToValuesOp::Adaptor;
using OpConversionPattern<ToValuesOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Replace the requested pointer access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
rewriter.replaceOp(op, desc.getValMemRef());
return success();
}
};