[mlir][sparse] Add layout to the memref for the indices buffers to prepare for the AOS storage optimization for COO regions.
Fix relevant FileCheck tests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D140742
This commit is contained in:
@@ -878,60 +878,65 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Base class for getter-like operations, e.g., to_indices, to_pointers.
|
||||
template <typename SourceOp, typename Base>
|
||||
class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
|
||||
/// Sparse codegen rule for pointer accesses.
|
||||
class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
|
||||
public:
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
using OpConversionPattern<SourceOp>::OpConversionPattern;
|
||||
using OpAdaptor = typename ToPointersOp::Adaptor;
|
||||
using OpConversionPattern<ToPointersOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(ToPointersOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
Value field = Base::getFieldForOp(desc, op);
|
||||
rewriter.replaceOp(op, field);
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
rewriter.replaceOp(op, desc.getPtrMemRef(dim));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for pointer accesses.
|
||||
class SparseToPointersConverter
|
||||
: public SparseGetterOpConverter<ToPointersOp, SparseToPointersConverter> {
|
||||
public:
|
||||
using SparseGetterOpConverter::SparseGetterOpConverter;
|
||||
// Callback for SparseGetterOpConverter.
|
||||
static Value getFieldForOp(const SparseTensorDescriptor &desc,
|
||||
ToPointersOp op) {
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
return desc.getPtrMemRef(dim);
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for index accesses.
|
||||
class SparseToIndicesConverter
|
||||
: public SparseGetterOpConverter<ToIndicesOp, SparseToIndicesConverter> {
|
||||
class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> {
|
||||
public:
|
||||
using SparseGetterOpConverter::SparseGetterOpConverter;
|
||||
// Callback for SparseGetterOpConverter.
|
||||
static Value getFieldForOp(const SparseTensorDescriptor &desc,
|
||||
ToIndicesOp op) {
|
||||
using OpAdaptor = typename ToIndicesOp::Adaptor;
|
||||
using OpConversionPattern<ToIndicesOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
uint64_t dim = op.getDimension().getZExtValue();
|
||||
return desc.getIdxMemRef(dim);
|
||||
Value field = desc.getIdxMemRef(dim);
|
||||
|
||||
// Insert a cast to bridge the actual type to the user expected type. If the
|
||||
// actual type and the user expected type aren't compatible, the compiler or
|
||||
// the runtime will issue an error.
|
||||
Type resType = op.getResult().getType();
|
||||
if (resType != field.getType())
|
||||
field = rewriter.create<memref::CastOp>(op.getLoc(), resType, field);
|
||||
rewriter.replaceOp(op, field);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Sparse codegen rule for value accesses.
|
||||
class SparseToValuesConverter
|
||||
: public SparseGetterOpConverter<ToValuesOp, SparseToValuesConverter> {
|
||||
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
|
||||
public:
|
||||
using SparseGetterOpConverter::SparseGetterOpConverter;
|
||||
// Callback for SparseGetterOpConverter.
|
||||
static Value getFieldForOp(const SparseTensorDescriptor &desc,
|
||||
ToValuesOp /*op*/) {
|
||||
return desc.getValMemRef();
|
||||
using OpAdaptor = typename ToValuesOp::Adaptor;
|
||||
using OpConversionPattern<ToValuesOp>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Replace the requested pointer access with corresponding field.
|
||||
// The cast_op is inserted by type converter to intermix 1:N type
|
||||
// conversion.
|
||||
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
|
||||
rewriter.replaceOp(op, desc.getValMemRef());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user