//===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert standard dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "std-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { if (type.isInteger(1)) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isInteger(1); return false; } /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { if (auto boolAttr = srcAttr.dyn_cast()) return boolAttr; if (auto intAttr = srcAttr.dyn_cast()) return builder.getBoolAttr(intAttr.getValue().getBoolValue()); return BoolAttr(); } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if conversion fails. 
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
                                      Builder builder) {
  // If the source number uses less active bits than the target bitwidth, then
  // it should be safe to convert.
  if (srcAttr.getValue().isIntN(dstType.getWidth()))
    return builder.getIntegerAttr(dstType, srcAttr.getInt());

  // XXX: Try again by interpreting the source number as a signed value.
  // Although integers in the standard dialect are signless, they can represent
  // a signed number. It's the operation decides how to interpret. This is
  // dangerous, but it seems there is no good way of handling this if we still
  // want to change the bitwidth. Emit a message at least.
  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
    auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
    LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
                            << dstAttr << "' for type '" << dstType << "'\n");
    return dstAttr;
  }

  // The value cannot be represented in the target bitwidth at all; report
  // failure with a null attribute so the caller can bail out of conversion.
  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
                          << "' illegal: cannot fit into target type '"
                          << dstType << "'\n");
  return IntegerAttr();
}

/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
                                  Builder builder) {
  // Only support converting to float for now.
  if (!dstType.isF32())
    return FloatAttr();

  // Try to convert the source floating-point number to single precision.
  // Reject the conversion if it would be lossy (e.g. an f64 constant that has
  // no exact f32 representation).
  APFloat dstVal = srcAttr.getValue();
  bool losesInfo = false;
  APFloat::opStatus status =
      dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
  if (status != APFloat::opOK || losesInfo) {
    LLVM_DEBUG(llvm::dbgs()
               << srcAttr << " illegal: cannot fit into converted type '"
               << dstType << "'\n");
    return FloatAttr();
  }

  return builder.getF32FloatAttr(dstVal.convertToFloat());
}

/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
/// the sign of `signOperand`.
/// /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod /// if either operand can be negative. Emulate it via spv.UMod. static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder) { assert(lhs.getType() == rhs.getType()); assert(lhs == signOperand || rhs == signOperand); Type type = lhs.getType(); // Calculate the remainder with spv.UMod. Value lhsAbs = builder.create(loc, type, lhs); Value rhsAbs = builder.create(loc, type, rhs); Value abs = builder.create(loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) isPositive = builder.create(loc, lhs, lhsAbs); else isPositive = builder.create(loc, rhs, rhsAbs); Value absNegate = builder.create(loc, type, abs); return builder.create(loc, type, isPositive, abs, absNegate); } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts unary and binary standard operations to SPIR-V operations. 
template class UnaryAndBinaryOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() <= 2); auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); if (SPIRVOp::template hasTrait() && dstType != operation.getType()) { return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } rewriter.template replaceOpWithNewOp(operation, dstType, adaptor.getOperands()); return success(); } }; /// Converts std.remi_signed to SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to /// Vulkan restrictions over spv.SRem and spv.SMod. class SignedRemIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts bitwise standard operations to SPIR-V operations. This is a special /// pattern other than the BinaryOpPatternPattern because if the operands are /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. 
template class BitwiseOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 2); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (!dstType) return failure(); if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { rewriter.template replaceOpWithNewOp( operation, dstType, adaptor.getOperands()); } else { rewriter.template replaceOpWithNewOp( operation, dstType, adaptor.getOperands()); } return success(); } }; /// Converts composite std.constant operation to spv.Constant. class ConstantCompositeOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts scalar std.constant operation to spv.Constant. class ConstantScalarOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating-point comparison operations to SPIR-V ops. class CmpFOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating point NaN check to SPIR-V ops. This pattern requires /// Kernel capability. 
class CmpFOpNanKernelPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating point NaN check to SPIR-V ops. This pattern does not /// require additional capability. class CmpFOpNanNonePattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts integer compare operation on i1 type operands to SPIR-V ops. class BoolCmpIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.return to spv.Return. class ReturnOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.select to spv.Select. class SelectOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.splat to spv.CompositeConstruct. 
class SplatPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.zexti to spv.Select if the type of source is i1 or vector of /// i1. class ZeroExtendI1Pattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. 
class TensorExtractPattern final : public OpConversionPattern { public: TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context, int64_t threshold, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), byteCountThreshold(threshold) {} LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TensorType tensorType = extractOp.tensor().getType().cast(); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() > byteCountThreshold * 8) return rewriter.notifyMatchFailure(extractOp, "exceeding byte count threshold"); Location loc = extractOp.getLoc(); int64_t rank = tensorType.getRank(); SmallVector strides(rank, 1); for (int i = rank - 2; i >= 0; --i) { strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1); } Type varType = spirv::PointerType::get(adaptor.tensor().getType(), spirv::StorageClass::Function); spirv::VariableOp varOp; if (adaptor.tensor().getDefiningOp()) { varOp = rewriter.create( loc, varType, spirv::StorageClass::Function, /*initializer=*/adaptor.tensor()); } else { // Need to store the value to the local variable. It's questionable // whether we want to support such case though. return failure(); } auto &typeConverter = *getTypeConverter(); auto indexType = typeConverter.getIndexType(); Value index = spirv::linearizeIndex(adaptor.indices(), strides, /*offset=*/0, indexType, loc, rewriter); auto acOp = rewriter.create(loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); return success(); } private: int64_t byteCountThreshold; }; /// Converts std.trunci to spv.Select if the type of result is i1 or vector of /// i1. 
class TruncI1Pattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TruncateIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!isBoolScalarOrVector(dstType)) return failure(); Location loc = op.getLoc(); auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); Value maskedSrc = rewriter.create( loc, srcType, adaptor.getOperands()[0], mask); Value isOne = rewriter.create(loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); return success(); } }; /// Converts std.uitofp to spv.Select if the type of source is i1 or vector of /// i1. class UIToFPI1Pattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(UIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; /// Converts type-casting standard operations to SPIR-V operations. 
template class TypeCastingOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); auto srcType = adaptor.getOperands().front().getType(); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) return failure(); if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. rewriter.replaceOp(operation, adaptor.getOperands().front()); } else { rewriter.template replaceOpWithNewOp(operation, dstType, adaptor.getOperands()); } return success(); } }; /// Converts std.xor to SPIR-V operations. class XOrOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.xor to SPIR-V operations if the type of source is i1 or vector /// of i1. 
class BoolXOrOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // SignedRemIOpPattern //===----------------------------------------------------------------------===// LogicalResult SignedRemIOpPattern::matchAndRewrite( SignedRemIOp remOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value result = emulateSignedRemainder( remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], adaptor.getOperands()[0], rewriter); rewriter.replaceOp(remOp, result); return success(); } //===----------------------------------------------------------------------===// // ConstantOp with composite type. //===----------------------------------------------------------------------===// // TODO: This probably should be split into the vector case and tensor case, // so that the tensor case can be moved to TensorToSPIRV conversion. But, // std.constant is for the standard dialect though. LogicalResult ConstantCompositeOpPattern::matchAndRewrite( ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto srcType = constOp.getType().dyn_cast(); if (!srcType) return failure(); // std.constant should only have vector or tenor types. assert((srcType.isa())); auto dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); auto dstElementsAttr = constOp.value().dyn_cast(); ShapedType dstAttrType = dstElementsAttr.getType(); if (!dstElementsAttr) return failure(); // If the composite type has more than one dimensions, perform linearization. 
if (srcType.getRank() > 1) { if (srcType.isa()) { dstAttrType = RankedTensorType::get(srcType.getNumElements(), srcType.getElementType()); dstElementsAttr = dstElementsAttr.reshape(dstAttrType); } else { // TODO: add support for large vectors. return failure(); } } Type srcElemType = srcType.getElementType(); Type dstElemType; // Tensor types are converted to SPIR-V array types; vector types are // converted to SPIR-V vector/array types. if (auto arrayType = dstType.dyn_cast()) dstElemType = arrayType.getElementType(); else dstElemType = dstType.cast().getElementType(); // If the source and destination element types are different, perform // attribute conversion. if (srcElemType != dstElemType) { SmallVector elements; if (srcElemType.isa()) { for (FloatAttr srcAttr : dstElementsAttr.getValues()) { FloatAttr dstAttr = convertFloatAttr(srcAttr, dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } else if (srcElemType.isInteger(1)) { return failure(); } else { for (IntegerAttr srcAttr : dstElementsAttr.getValues()) { IntegerAttr dstAttr = convertIntegerAttr( srcAttr, dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } // Unfortunately, we cannot use dialect-specific types for element // attributes; element attributes only works with builtin types. So we need // to prepare another converted builtin types for the destination elements // attribute. if (dstAttrType.isa()) dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); else dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } rewriter.replaceOpWithNewOp(constOp, dstType, dstElementsAttr); return success(); } //===----------------------------------------------------------------------===// // ConstantOp with scalar type. 
//===----------------------------------------------------------------------===// LogicalResult ConstantScalarOpPattern::matchAndRewrite( ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); if (!srcType.isIntOrIndexOrFloat()) return failure(); Type dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); // Floating-point types. if (srcType.isa()) { auto srcAttr = constOp.value().cast(); auto dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) return failure(); } rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } // Bool type. if (srcType.isInteger(1)) { // std.constant can use 0/1 instead of true/false for i1 values. We need to // handle that here. auto dstAttr = convertBoolAttr(constOp.value(), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } // IndexType or IntegerType. Index values are converted to 32-bit integer // values when converting to SPIR-V. auto srcAttr = constOp.value().cast(); auto dstAttr = convertIntegerAttr(srcAttr, dstType.cast(), rewriter); if (!dstAttr) return failure(); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// LogicalResult CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { switch (cmpFOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ adaptor.lhs(), adaptor.rhs()); \ return success(); // Ordered. 
DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); // Unordered. DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); #undef DISPATCH default: break; } return failure(); } LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), adaptor.rhs()); return success(); } if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), adaptor.rhs()); return success(); } return failure(); } LogicalResult CmpFOpNanNonePattern::matchAndRewrite( CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (cmpFOp.getPredicate() != CmpFPredicate::ORD && cmpFOp.getPredicate() != CmpFPredicate::UNO) return failure(); Location loc = cmpFOp.getLoc(); Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); if (cmpFOp.getPredicate() == CmpFPredicate::ORD) replace = rewriter.create(loc, replace); rewriter.replaceOp(cmpFOp, replace); return success(); } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// LogicalResult BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, 
OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = cmpIOp.lhs().getType(); if (!isBoolScalarOrVector(operandType)) return failure(); switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ adaptor.lhs(), adaptor.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp); #undef DISPATCH default:; } return failure(); } LogicalResult CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = cmpIOp.lhs().getType(); if (isBoolScalarOrVector(operandType)) return failure(); switch (cmpIOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (spirvOp::template hasTrait() && \ operandType != this->getTypeConverter()->convertType(operandType)) { \ return cmpIOp.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ adaptor.lhs(), adaptor.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp); DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp); DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp); DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp); DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp); DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp); DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp); DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp); #undef DISPATCH } return failure(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor 
adaptor, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands() > 1) return failure(); if (returnOp.getNumOperands() == 1) { rewriter.replaceOpWithNewOp(returnOp, adaptor.getOperands()[0]); } else { rewriter.replaceOpWithNewOp(returnOp); } return success(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// LogicalResult SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); return success(); } //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// LogicalResult SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto dstVecType = op.getType().dyn_cast(); if (!dstVecType || !spirv::CompositeType::isValid(dstVecType)) return failure(); SmallVector source(dstVecType.getNumElements(), adaptor.input()); rewriter.replaceOpWithNewOp(op, dstVecType, source); return success(); } //===----------------------------------------------------------------------===// // XorOp //===----------------------------------------------------------------------===// LogicalResult XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { assert(adaptor.getOperands().size() == 2); if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(xorOp, dstType, adaptor.getOperands()); return success(); } LogicalResult BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { 
assert(adaptor.getOperands().size() == 2); if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(xorOp, dstType, adaptor.getOperands()); return success(); } //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add< // Unary and binary patterns BitwiseOpPattern, BitwiseOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern, // Comparison patterns BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern, // Constant patterns ConstantCompositeOpPattern, ConstantScalarOpPattern, ReturnOpPattern, SelectOpPattern, SplatPattern, // Type cast patterns UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern>(typeConverter, context); // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel // 
capability is available. patterns.add(typeConverter, context, /*benefit=*/2); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, int64_t byteCountThreshold, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext(), byteCountThreshold); } } // namespace mlir