//===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert standard dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "std-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts std.return to spv.Return. class ReturnOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.select to spv.Select. class SelectOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.splat to spv.CompositeConstruct. class SplatPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.br to spv.Branch. struct BranchOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(BranchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.cond_br to spv.BranchConditional. struct CondBranchOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CondBranchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. class TensorExtractPattern final : public OpConversionPattern { public: TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context, int64_t threshold, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), byteCountThreshold(threshold) {} LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TensorType tensorType = extractOp.tensor().getType().cast(); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() > byteCountThreshold * 8) return rewriter.notifyMatchFailure(extractOp, "exceeding byte count threshold"); Location loc = extractOp.getLoc(); int64_t rank = tensorType.getRank(); SmallVector strides(rank, 1); for (int i = rank - 2; i >= 0; --i) { strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1); } Type varType = spirv::PointerType::get(adaptor.tensor().getType(), spirv::StorageClass::Function); spirv::VariableOp varOp; if (adaptor.tensor().getDefiningOp()) { varOp = rewriter.create( loc, varType, spirv::StorageClass::Function, /*initializer=*/adaptor.tensor()); } else { // Need to store the value to the local variable. It's questionable // whether we want to support such case though. return failure(); } auto &typeConverter = *getTypeConverter(); auto indexType = typeConverter.getIndexType(); Value index = spirv::linearizeIndex(adaptor.indices(), strides, /*offset=*/0, indexType, loc, rewriter); auto acOp = rewriter.create(loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); return success(); } private: int64_t byteCountThreshold; }; } // namespace //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands() > 1) return failure(); if (returnOp.getNumOperands() == 1) { rewriter.replaceOpWithNewOp(returnOp, adaptor.getOperands()[0]); } else { rewriter.replaceOpWithNewOp(returnOp); } return success(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// LogicalResult SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), adaptor.getTrueValue(), adaptor.getFalseValue()); return success(); } //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// LogicalResult SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto dstVecType = op.getType().dyn_cast(); if (!dstVecType || !spirv::CompositeType::isValid(dstVecType)) return failure(); SmallVector source(dstVecType.getNumElements(), adaptor.getInput()); rewriter.replaceOpWithNewOp(op, dstVecType, source); return success(); } //===----------------------------------------------------------------------===// // BranchOpPattern //===----------------------------------------------------------------------===// LogicalResult BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(op, op.getDest(), adaptor.getDestOperands()); return success(); } //===----------------------------------------------------------------------===// // CondBranchOpPattern //===----------------------------------------------------------------------===// LogicalResult CondBranchOpPattern::matchAndRewrite( CondBranchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(), op.getFalseDest(), adaptor.getFalseDestOperands()); return success(); } //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add< // Unary and binary patterns spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, context); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext(), byteCountThreshold); } } // namespace mlir