//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements patterns to convert Standard Ops to the SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "llvm/ADT/SetVector.h" using namespace mlir; //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { /// Convert constant operation with IndexType return to SPIR-V constant /// operation. Since IndexType is not used within SPIR-V dialect, this needs /// special handling to make sure the result type and the type of the value /// attribute are consistent. class ConstantIndexOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!constIndexOp.getResult()->getType().isa()) { return matchFailure(); } // The attribute has index type which is not directly supported in // SPIR-V. Get the integer value and create a new IntegerAttr. auto constAttr = constIndexOp.value().dyn_cast(); if (!constAttr) { return matchFailure(); } // Use the bitwidth set in the value attribute to decide the result type // of the SPIR-V constant operation since SPIR-V does not support index // types. auto constVal = constAttr.getValue(); auto constValType = constAttr.getType().dyn_cast(); if (!constValType) { return matchFailure(); } auto spirvConstType = typeConverter.convertBasicType(constIndexOp.getResult()->getType()); auto spirvConstVal = rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); rewriter.replaceOpWithNewOp(constIndexOp, spirvConstType, spirvConstVal); return matchSuccess(); } }; /// Convert compare operation to SPIR-V dialect. class CmpIOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { CmpIOpOperandAdaptor cmpIOpOperands(operands); switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp( \ cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \ cmpIOpOperands.rhs()); \ return matchSuccess(); DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp); DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp); DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp); DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp); #undef DISPATCH default: break; } return matchFailure(); } }; /// Convert integer binary operations to SPIR-V operations. Cannot use /// tablegen for this. If the integer operation is on variables of IndexType, /// the type of the return value of the replacement operation differs from /// that of the replaced operation. This is not handled in tablegen-based /// pattern specification. template class IntegerOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = this->typeConverter.convertBasicType(operation.getResult()->getType()); rewriter.template replaceOpWithNewOp( operation, resultType, operands, ArrayRef()); return this->matchSuccess(); } }; /// Convert load -> spv.LoadOp. The operands of the replaced operation are of /// IndexType while that of the replacement operation are of type i32. This is /// not supported in tablegen based pattern specification. // TODO(ravishankarm) : These could potentially be templated on the operation // being converted, since the same logic should work for linalg.load. class LoadOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { LoadOpOperandAdaptor loadOperands(operands); auto basePtr = loadOperands.memref(); auto ptrType = basePtr->getType().dyn_cast(); if (!ptrType) { return matchFailure(); } auto loadPtr = rewriter.create( loadOp.getLoc(), basePtr, loadOperands.indices()); auto loadPtrType = loadPtr.getType().cast(); rewriter.replaceOpWithNewOp( loadOp, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr, /*alignment =*/nullptr); return matchSuccess(); } }; /// Convert return -> spv.Return. class ReturnToSPIRVConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (returnOp.getNumOperands()) { return matchFailure(); } rewriter.replaceOpWithNewOp(returnOp); return matchSuccess(); } }; /// Convert select -> spv.Select class SelectOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { SelectOpOperandAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), selectOperands.true_value(), selectOperands.false_value()); return matchSuccess(); } }; /// Convert store -> spv.StoreOp. The operands of the replaced operation are /// of IndexType while that of the replacement operation are of type i32. This /// is not supported in tablegen based pattern specification. // TODO(ravishankarm) : These could potentially be templated on the operation // being converted, since the same logic should work for linalg.store. class StoreOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; PatternMatchResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { StoreOpOperandAdaptor storeOperands(operands); auto value = storeOperands.value(); auto basePtr = storeOperands.memref(); auto ptrType = basePtr->getType().dyn_cast(); if (!ptrType) { return matchFailure(); } auto storePtr = rewriter.create( storeOp.getLoc(), basePtr, storeOperands.indices()); rewriter.replaceOpWithNewOp(storeOp, storePtr, value, /*memory_access =*/nullptr, /*alignment =*/nullptr); return matchSuccess(); } }; } // namespace namespace { /// Import the Standard Ops to SPIR-V Patterns. #include "StandardToSPIRV.cpp.inc" } // namespace namespace mlir { void populateStandardToSPIRVPatterns(MLIRContext *context, SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) { populateWithGenerated(context, &patterns); // Add the return op conversion. patterns .insert, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, IntegerOpConversion, LoadOpConversion, ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>( context, typeConverter); } } // namespace mlir