//===- SparseLowering.cpp - Lowers sparse primitives to library calls. ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" using namespace mlir; namespace { /// Returns function reference (first hit also inserts into module). static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, ValueRange operands) { MLIRContext *context = op->getContext(); auto module = op->getParentOfType(); auto func = module.lookupSymbol(name); if (!func) { OpBuilder moduleBuilder(module.getBodyRegion()); moduleBuilder .create(op->getLoc(), name, FunctionType::get(context, operands.getTypes(), result)) .setPrivate(); } return SymbolRefAttr::get(context, name); } /// Sparse conversion rule to remove opaque pointer cast. class TensorFromPointerConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(linalg::SparseTensorFromPointerOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOp(op, operands[0]); return success(); } }; /// Sparse conversion rule for dimension accesses. class TensorToDimSizeConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(DimOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!operands[0].getType().isa()) return failure(); Type resType = op.getType(); StringRef name = "sparseDimSize"; rewriter.replaceOpWithNewOp( op, resType, getFunc(op, name, resType, operands), operands); return success(); } }; /// Sparse conversion rule for pointer accesses. 
class TensorToPointersConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(linalg::SparseTensorToPointersMemRefOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); StringRef name; if (eltType.isIndex() || eltType.isInteger(64)) name = "sparsePtrsI64"; else return failure(); rewriter.replaceOpWithNewOp( op, resType, getFunc(op, name, resType, operands), operands); return success(); } }; /// Sparse conversion rule for index accesses. class TensorToIndicesConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(linalg::SparseTensorToIndicesMemRefOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); StringRef name; if (eltType.isIndex() || eltType.isInteger(64)) name = "sparseIndxsI64"; else return failure(); rewriter.replaceOpWithNewOp( op, resType, getFunc(op, name, resType, operands), operands); return success(); } }; /// Sparse conversion rule for value accesses. class TensorToValuesConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(linalg::SparseTensorToValuesMemRefOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); StringRef name; if (eltType.isF64()) name = "sparseValsF64"; else return failure(); rewriter.replaceOpWithNewOp( op, resType, getFunc(op, name, resType, operands), operands); return success(); } }; } // namespace /// Populates the given patterns list with conversion rules required for /// the sparsification of linear algebra operations. 
void linalg::populateSparsificationConversionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  // Register every conversion pattern defined in this file's anonymous
  // namespace; each one is constructed with `context`.
  patterns.insert<TensorFromPointerConverter, TensorToDimSizeConverter,
                  TensorToPointersConverter, TensorToIndicesConverter,
                  TensorToValuesConverter>(context);
}