//===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert standard dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "std-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { if (type.isInteger(1)) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isInteger(1); return false; } /// Converts the given `srcAttr` into a boolean attribute if it holds an /// integral value. Returns null attribute if conversion fails. 
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) { if (auto boolAttr = srcAttr.dyn_cast()) return boolAttr; if (auto intAttr = srcAttr.dyn_cast()) return builder.getBoolAttr(intAttr.getValue().getBoolValue()); return BoolAttr(); } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if conversion fails. static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder) { // If the source number uses less active bits than the target bitwidth, then // it should be safe to convert. if (srcAttr.getValue().isIntN(dstType.getWidth())) return builder.getIntegerAttr(dstType, srcAttr.getInt()); // XXX: Try again by interpreting the source number as a signed value. // Although integers in the standard dialect are signless, they can represent // a signed number. It's the operation decides how to interpret. This is // dangerous, but it seems there is no good way of handling this if we still // want to change the bitwidth. Emit a message at least. if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) { auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt()); LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '" << dstAttr << "' for type '" << dstType << "'\n"); return dstAttr; } LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' illegal: cannot fit into target type '" << dstType << "'\n"); return IntegerAttr(); } /// Converts the given `srcAttr` to a new attribute of the given `dstType`. /// Returns null attribute if `dstType` is not 32-bit or conversion fails. static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder) { // Only support converting to float for now. if (!dstType.isF32()) return FloatAttr(); // Try to convert the source floating-point number to single precision. 
APFloat dstVal = srcAttr.getValue(); bool losesInfo = false; APFloat::opStatus status = dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo); if (status != APFloat::opOK || losesInfo) { LLVM_DEBUG(llvm::dbgs() << srcAttr << " illegal: cannot fit into converted type '" << dstType << "'\n"); return FloatAttr(); } return builder.getF32FloatAttr(dstVal.convertToFloat()); } /// Returns signed remainder for `lhs` and `rhs` and lets the result follow /// the sign of `signOperand`. /// /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod /// if either operand can be negative. Emulate it via spv.UMod. static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder) { assert(lhs.getType() == rhs.getType()); assert(lhs == signOperand || rhs == signOperand); Type type = lhs.getType(); // Calculate the remainder with spv.UMod. Value lhsAbs = builder.create(loc, type, lhs); Value rhsAbs = builder.create(loc, type, rhs); Value abs = builder.create(loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) isPositive = builder.create(loc, lhs, lhsAbs); else isPositive = builder.create(loc, rhs, rhsAbs); Value absNegate = builder.create(loc, type, abs); return builder.create(loc, type, isPositive, abs, absNegate); } /// Returns the offset of the value in `targetBits` representation. /// /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`. /// It's assumed to be non-negative. /// /// When accessing an element in the array treating as having elements of /// `targetBits`, multiple values are loaded in the same time. The method /// returns the offset where the `srcIdx` locates in the value. 
For example, if /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is /// located at (x % 4) * 8. Because there are four elements in one i32, and one /// element has 8 bits. static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); IntegerType targetType = builder.getIntegerType(targetBits); IntegerAttr idxAttr = builder.getIntegerAttr(targetType, targetBits / sourceBits); auto idx = builder.create(loc, targetType, idxAttr); IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); auto srcBitsValue = builder.create(loc, targetType, srcBitsAttr); auto m = builder.create(loc, srcIdx, idx); return builder.create(loc, targetType, m, srcBitsValue); } /// Returns an adjusted spirv::AccessChainOp. Based on the /// extension/capabilities, certain integer bitwidths `sourceBits` might not be /// supported. During conversion if a memref of an unsupported type is used, /// load/stores to this memref need to be modified to use a supported higher /// bitwidth `targetBits` and extracting the required bits. For an accessing a /// 1D array (spv.array or spv.rt_array), the last index is modified to load the /// bits needed. The extraction of the actual bits needed are handled /// separately. Note that this only works for a 1-D tensor. static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter, spirv::AccessChainOp op, int sourceBits, int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); const auto loc = op.getLoc(); IntegerType targetType = builder.getIntegerType(targetBits); IntegerAttr attr = builder.getIntegerAttr(targetType, targetBits / sourceBits); auto idx = builder.create(loc, targetType, attr); auto lastDim = op->getOperand(op.getNumOperands() - 1); auto indices = llvm::to_vector<4>(op.indices()); // There are two elements if this is a 1-D tensor. 
assert(indices.size() == 2); indices.back() = builder.create(loc, lastDim, idx); Type t = typeConverter.convertType(op.component_ptr().getType()); return builder.create(loc, t, op.base_ptr(), indices); } /// Returns the shifted `targetBits`-bit value with the given offset. static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder) { Type targetType = builder.getIntegerType(targetBits); Value result = builder.create(loc, value, mask); return builder.create(loc, targetType, result, offset); } /// Returns true if the allocations of type `t` can be lowered to SPIR-V. static bool isAllocationSupported(MemRefType t) { // Currently only support workgroup local memory allocations with static // shape and int or float or vector of int or float element type. if (!(t.hasStaticShape() && SPIRVTypeConverter::getMemorySpaceForStorageClass( spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) return false; Type elementType = t.getElementType(); if (auto vecType = elementType.dyn_cast()) elementType = vecType.getElementType(); return elementType.isIntOrFloat(); } /// Returns the scope to use for atomic operations use for emulating store /// operations of unsupported integer bitwidths, based on the memref /// type. Returns None on failure. 
static Optional getAtomicOpScope(MemRefType t) { Optional storageClass = SPIRVTypeConverter::getStorageClassForMemorySpace( t.getMemorySpaceAsInt()); if (!storageClass) return {}; switch (*storageClass) { case spirv::StorageClass::StorageBuffer: return spirv::Scope::Device; case spirv::StorageClass::Workgroup: return spirv::Scope::Workgroup; default: { } } return {}; } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts an allocation operation to SPIR-V. Currently only supports lowering /// to Workgroup memory when the size is constant. Note that this pattern needs /// to be applied in a pass that runs at least at spv.module scope since it wil /// ladd global variables into the spv.module. class AllocOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AllocOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType allocType = operation.getType(); if (!isAllocationSupported(allocType)) return operation.emitError("unhandled allocation type"); // Get the SPIR-V type for the allocation. Type spirvType = getTypeConverter()->convertType(allocType); // Insert spv.GlobalVariable for this allocation. 
Operation *parent = SymbolTable::getNearestSymbolTable(operation->getParentOp()); if (!parent) return failure(); Location loc = operation.getLoc(); spirv::GlobalVariableOp varOp; { OpBuilder::InsertionGuard guard(rewriter); Block &entryBlock = *parent->getRegion(0).begin(); rewriter.setInsertionPointToStart(&entryBlock); auto varOps = entryBlock.getOps(); std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); varOp = rewriter.create(loc, spirvType, varName, /*initializer=*/nullptr); } // Get pointer to global variable at the current scope. rewriter.replaceOpWithNewOp(operation, varOp); return success(); } }; /// Removed a deallocation if it is a supported allocation. Currently only /// removes deallocation if the memory space is workgroup memory. class DeallocOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::DeallocOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType deallocType = operation.memref().getType().cast(); if (!isAllocationSupported(deallocType)) return operation.emitError("unhandled deallocation type"); rewriter.eraseOp(operation); return success(); } }; /// Converts unary and binary standard operations to SPIR-V operations. 
template class UnaryAndBinaryOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() <= 2); auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); if (SPIRVOp::template hasTrait() && dstType != operation.getType()) { return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } }; /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to /// these operations. class Log1pOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::Log1pOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.operand().getType()); auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); auto onePlus = rewriter.create(loc, one, operands[0]); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } }; /// Converts std.remi_signed to SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to /// Vulkan restrictions over spv.SRem and spv.SMod. class SignedRemIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SignedRemIOp remOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts bitwise standard operations to SPIR-V operations. 
This is a special /// pattern other than the BinaryOpPatternPattern because if the operands are /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. template class BitwiseOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 2); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (!dstType) return failure(); if (isBoolScalarOrVector(operands.front().getType())) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); } else { rewriter.template replaceOpWithNewOp(operation, dstType, operands); } return success(); } }; /// Converts composite std.constant operation to spv.Constant. class ConstantCompositeOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts scalar std.constant operation to spv.Constant. class ConstantScalarOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating-point comparison operations to SPIR-V ops. class CmpFOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating point NaN check to SPIR-V ops. This pattern requires /// Kernel capability. 
class CmpFOpNanKernelPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts floating point NaN check to SPIR-V ops. This pattern does not /// require additional capability. class CmpFOpNanNonePattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts integer compare operation on i1 type operands to SPIR-V ops. class BoolCmpIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts integer compare operation to SPIR-V ops. class CmpIOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.load to spv.Load. class IntLoadOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.load to spv.Load. class LoadOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.return to spv.Return. 
class ReturnOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.select to spv.Select. class SelectOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.splat to spv.CompositeConstruct. class SplatPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(SplatOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.store to spv.Store on integers. class IntStoreOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.store to spv.Store. class StoreOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.zexti to spv.Select if the type of source is i1 or vector of /// i1. 
class ZeroExtendI1Pattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ZeroExtendIOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = operands.front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( op, dstType, operands.front(), one, zero); return success(); } }; /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. class TensorExtractPattern final : public OpConversionPattern { public: TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context, int64_t threshold, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), byteCountThreshold(threshold) {} LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { TensorType tensorType = extractOp.tensor().getType().cast(); if (!tensorType.hasStaticShape()) return rewriter.notifyMatchFailure(extractOp, "non-static tensor"); if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() > byteCountThreshold * 8) return rewriter.notifyMatchFailure(extractOp, "exceeding byte count threshold"); Location loc = extractOp.getLoc(); tensor::ExtractOp::Adaptor adaptor(operands); int64_t rank = tensorType.getRank(); SmallVector strides(rank, 1); for (int i = rank - 2; i >= 0; --i) { strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1); } Type varType = spirv::PointerType::get(adaptor.tensor().getType(), spirv::StorageClass::Function); spirv::VariableOp varOp; if (adaptor.tensor().getDefiningOp()) { varOp = rewriter.create( loc, varType, 
spirv::StorageClass::Function, /*initializer=*/adaptor.tensor()); } else { // Need to store the value to the local variable. It's questionable // whether we want to support such case though. return failure(); } Value index = spirv::linearizeIndex(adaptor.indices(), strides, /*offset=*/0, loc, rewriter); auto acOp = rewriter.create(loc, varOp, index); rewriter.replaceOpWithNewOp(extractOp, acOp); return success(); } private: int64_t byteCountThreshold; }; /// Converts std.trunci to spv.Select if the type of result is i1 or vector of /// i1. class TruncI1Pattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TruncateIOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); if (!isBoolScalarOrVector(dstType)) return failure(); Location loc = op.getLoc(); auto srcType = operands.front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); Value maskedSrc = rewriter.create(loc, srcType, operands[0], mask); Value isOne = rewriter.create(loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.replaceOpWithNewOp(op, dstType, isOne, one, zero); return success(); } }; /// Converts std.uitofp to spv.Select if the type of source is i1 or vector of /// i1. 
class UIToFPI1Pattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(UIToFPOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = operands.front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( op, dstType, operands.front(), one, zero); return success(); } }; /// Converts type-casting standard operations to SPIR-V operations. template class TypeCastingOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1); auto srcType = operands.front().getType(); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) return failure(); if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. rewriter.replaceOp(operation, operands.front()); } else { rewriter.template replaceOpWithNewOp(operation, dstType, operands); } return success(); } }; /// Converts std.xor to SPIR-V operations. class XOrOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(XOrOp xorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; /// Converts std.xor to SPIR-V operations if the type of source is i1 or vector /// of i1. 
class BoolXOrOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(XOrOp xorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // SignedRemIOpPattern //===----------------------------------------------------------------------===// LogicalResult SignedRemIOpPattern::matchAndRewrite( SignedRemIOp remOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Value result = emulateSignedRemainder(remOp.getLoc(), operands[0], operands[1], operands[0], rewriter); rewriter.replaceOp(remOp, result); return success(); } //===----------------------------------------------------------------------===// // ConstantOp with composite type. //===----------------------------------------------------------------------===// // TODO: This probably should be split into the vector case and tensor case, // so that the tensor case can be moved to TensorToSPIRV conversion. But, // std.constant is for the standard dialect though. LogicalResult ConstantCompositeOpPattern::matchAndRewrite( ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto srcType = constOp.getType().dyn_cast(); if (!srcType) return failure(); // std.constant should only have vector or tenor types. assert((srcType.isa())); auto dstType = getTypeConverter()->convertType(srcType); if (!dstType) return failure(); auto dstElementsAttr = constOp.value().dyn_cast(); ShapedType dstAttrType = dstElementsAttr.getType(); if (!dstElementsAttr) return failure(); // If the composite type has more than one dimensions, perform linearization. if (srcType.getRank() > 1) { if (srcType.isa()) { dstAttrType = RankedTensorType::get(srcType.getNumElements(), srcType.getElementType()); dstElementsAttr = dstElementsAttr.reshape(dstAttrType); } else { // TODO: add support for large vectors. 
return failure(); } } Type srcElemType = srcType.getElementType(); Type dstElemType; // Tensor types are converted to SPIR-V array types; vector types are // converted to SPIR-V vector/array types. if (auto arrayType = dstType.dyn_cast()) dstElemType = arrayType.getElementType(); else dstElemType = dstType.cast().getElementType(); // If the source and destination element types are different, perform // attribute conversion. if (srcElemType != dstElemType) { SmallVector elements; if (srcElemType.isa()) { for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { FloatAttr dstAttr = convertFloatAttr( srcAttr.cast(), dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } else if (srcElemType.isInteger(1)) { return failure(); } else { for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) { IntegerAttr dstAttr = convertIntegerAttr(srcAttr.cast(), dstElemType.cast(), rewriter); if (!dstAttr) return failure(); elements.push_back(dstAttr); } } // Unfortunately, we cannot use dialect-specific types for element // attributes; element attributes only works with builtin types. So we need // to prepare another converted builtin types for the destination elements // attribute. if (dstAttrType.isa()) dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType); else dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType); dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements); } rewriter.replaceOpWithNewOp(constOp, dstType, dstElementsAttr); return success(); } //===----------------------------------------------------------------------===// // ConstantOp with scalar type. 
//===----------------------------------------------------------------------===//

// NOTE(review): throughout the rest of this file, calls such as `cast()`,
// `dyn_cast()`, `isa()`, `rewriter.create(...)`,
// `rewriter.replaceOpWithNewOp(...)` and `patterns.add(...)` appear without
// explicit template arguments, and `ArrayRef` / `SmallVector` / `Optional`
// without an element type. Presumably these were lost before this text was
// captured -- confirm against the upstream file before building.

/// Rewrites a scalar `ConstantOp`.
///
/// Fails if the constant's type is not an integer, index or float type, if the
/// type converter returns a null type for it, or if the constant's attribute
/// cannot be converted to the destination type. Otherwise replaces `constOp`
/// with a new op built from the converted type and the (possibly converted)
/// value attribute.
LogicalResult ConstantScalarOpPattern::matchAndRewrite(
    ConstantOp constOp, ArrayRef operands,
    ConversionPatternRewriter &rewriter) const {
  Type srcType = constOp.getType();
  if (!srcType.isIntOrIndexOrFloat())
    return failure();

  Type dstType = getTypeConverter()->convertType(srcType);
  if (!dstType)
    return failure();

  // Floating-point types.
  if (srcType.isa()) {
    auto srcAttr = constOp.value().cast();
    auto dstAttr = srcAttr;

    // Floating-point types not supported in the target environment are all
    // converted to float type. The attribute is only re-created when the type
    // converter actually changed the type.
    if (srcType != dstType) {
      dstAttr = convertFloatAttr(srcAttr, dstType.cast(), rewriter);
      if (!dstAttr)
        return failure();
    }

    rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr);
    return success();
  }

  // Bool type.
  if (srcType.isInteger(1)) {
    // std.constant can use 0/1 instead of true/false for i1 values. We need to
    // handle that here.
    auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
    if (!dstAttr)
      return failure();
    rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr);
    return success();
  }

  // IndexType or IntegerType. Index values are converted to 32-bit integer
  // values when converting to SPIR-V.
  auto srcAttr = constOp.value().cast();
  auto dstAttr = convertIntegerAttr(srcAttr, dstType.cast(), rewriter);
  if (!dstAttr)
    return failure();
  rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr);
  return success();
}

//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//

/// Rewrites a `CmpFOp` whose predicate is one of the ordered or unordered
/// comparisons listed in the dispatch table below.
///
/// Each listed predicate replaces `cmpFOp` with a new op built from the
/// original result type and the converted lhs/rhs operands. Any predicate that
/// is not listed falls into `default` and the pattern returns failure.
LogicalResult
CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpFOpAdaptor cmpFOpOperands(operands);

  switch (cmpFOp.getPredicate()) {
// Expands to one `case` per predicate; the macro is local to this switch and
// is undefined again right after the table.
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(),          \
                                cmpFOpOperands.lhs(),                          \
                                cmpFOpOperands.rhs());                         \
    return success();

    // Ordered.
    DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered.
    DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef DISPATCH

  default:
    break;
  }
  return failure();
}

/// Rewrites a `CmpFOp` with the `ORD` or `UNO` predicate into a single new op
/// that takes the converted lhs and rhs operands. All other predicates are
/// rejected with failure.
///
/// This pattern is registered with a higher benefit than the other patterns in
/// `populateStandardToSPIRVPatterns`.
LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
    CmpFOp cmpFOp, ArrayRef operands,
    ConversionPatternRewriter &rewriter) const {
  CmpFOpAdaptor cmpFOpOperands(operands);

  if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
    rewriter.replaceOpWithNewOp(cmpFOp, cmpFOpOperands.lhs(),
                                cmpFOpOperands.rhs());
    return success();
  }

  if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
    rewriter.replaceOpWithNewOp(
        cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
    return success();
  }

  return failure();
}

/// Rewrites a `CmpFOp` with the `ORD` or `UNO` predicate into a short sequence
/// of ops: one op per operand (named `lhsIsNan` / `rhsIsNan` below), one op
/// combining the two, and -- for `ORD` only -- one more op applied to the
/// combined value. All other predicates are rejected with failure.
LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
    CmpFOp cmpFOp, ArrayRef operands,
    ConversionPatternRewriter &rewriter) const {
  if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
      cmpFOp.getPredicate() != CmpFPredicate::UNO)
    return failure();

  CmpFOpAdaptor cmpFOpOperands(operands);
  Location loc = cmpFOp.getLoc();

  Value lhsIsNan = rewriter.create(loc, cmpFOpOperands.lhs());
  Value rhsIsNan = rewriter.create(loc, cmpFOpOperands.rhs());

  // `UNO` uses the combined value directly; `ORD` wraps it in one extra op.
  Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan);
  if (cmpFOp.getPredicate() == CmpFPredicate::ORD)
    replace = rewriter.create(loc, replace);

  rewriter.replaceOp(cmpFOp, replace);
  return success();
}

//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//

/// Rewrites a `CmpIOp` whose lhs operand type is a boolean scalar or vector.
///
/// Only the `eq` and `ne` predicates are handled; any other predicate, or a
/// non-boolean operand type, makes the pattern return failure.
LogicalResult
BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands,
                                   ConversionPatternRewriter &rewriter) const {
  CmpIOpAdaptor cmpIOpOperands(operands);

  // The check is made on the original (unconverted) lhs type.
  Type operandType = cmpIOp.lhs().getType();
  if (!isBoolScalarOrVector(operandType))
    return failure();

  switch (cmpIOp.getPredicate()) {
// Expands to one `case` per predicate; local to this switch.
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(),          \
                                cmpIOpOperands.lhs(),                          \
                                cmpIOpOperands.rhs());                         \
    return success();

    DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
    DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);

#undef DISPATCH
  default:;
  }
  return failure();
}

/// Rewrites a `CmpIOp` whose lhs operand type is NOT a boolean scalar or
/// vector (those are left to `BoolCmpIOpPattern`).
///
/// For each predicate in the table below the op is replaced by a new op built
/// from the original result type and the converted operands. Before doing so,
/// the macro emits an error on `cmpIOp` when the target op has the trait
/// queried via `hasTrait` and the operand type differs from its converted
/// type, i.e. when bitwidth emulation would be required.
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpIOpAdaptor cmpIOpOperands(operands);

  Type operandType = cmpIOp.lhs().getType();
  if (isBoolScalarOrVector(operandType))
    return failure();

  switch (cmpIOp.getPredicate()) {
// Expands to one `case` per predicate; local to this switch.
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait() &&                                        \
        operandType != this->getTypeConverter()->convertType(operandType)) {   \
      return cmpIOp.emitError(                                                 \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(),          \
                                cmpIOpOperands.lhs(),                          \
                                cmpIOpOperands.rhs());                         \
    return success();

    DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  return failure();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// Rewrites a `memref::LoadOp` whose memref element type is a signless
/// integer; any other element type makes the pattern return failure.
///
/// When the source bit width (`boolNumBits` from the type converter options
/// for i1, otherwise the element bit width) equals the bit width of the
/// converted storage element type, the load is replaced by a new op on the
/// access chain result. Otherwise the narrower value is extracted from the
/// wider storage element by shifting and masking, followed by a shift-left /
/// shift-right pair by `dstBits - srcBits`, and the now unused access chain op
/// is erased.
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
                                  ArrayRef operands,
                                  ConversionPatternRewriter &rewriter) const {
  memref::LoadOpAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto &typeConverter = *getTypeConverter();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);

  // For i1 the source width is taken from the converter options rather than
  // from the element type itself.
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  // Dig the storage element type out of the converted memref type: element 0
  // of the pointee type, which is then either an array-like type (first
  // branch) or another type exposing getElementType().
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast()
                         .getPointeeType();
  Type structElemType = pointeeType.cast().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp(loadOp, accessChainOp.getResult());
    return success();
  }

  // Assume that getElementPtr() produces a linearized access. If it's a
  // scalar, the method still returns a linearized access. If the access is not
  // linearized, there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType(
          spirv::attributeName()),
      loadOp->getAttrOfType("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.create(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried by the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create(loc, dstType, shiftValueAttr);
  result = rewriter.create(loc, dstType, result, shiftValue);
  result = rewriter.create(loc, dstType, result, shiftValue);

  if (isBool) {
    // Compare the extracted value against one and select between the converted
    // type's one/zero constants.
    dstType = typeConverter.convertType(loadOp.getType());
    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
    Value isOne = rewriter.create(loc, result, mask);
    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
    result = rewriter.create(loc, dstType, isOne, one, zero);
  } else if (result.getType().getIntOrFloatBitWidth() !=
             static_cast(dstBits)) {
    result = rewriter.create(loc, dstType, result);
  }
  rewriter.replaceOp(loadOp, result);

  // The original access chain was only an intermediate; it must have no users
  // left before it is erased.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

/// Rewrites a `memref::LoadOp` whose memref element type is NOT a signless
/// integer (those are left to `IntLoadOpPattern`): computes the element
/// pointer and replaces the load with a new op on that pointer.
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands,
                               ConversionPatternRewriter &rewriter) const {
  memref::LoadOpAdaptor loadOperands(operands);
  auto memrefType = loadOp.memref().getType().cast();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto loadPtr = spirv::getElementPtr(
      *getTypeConverter(), memrefType, loadOperands.memref(),
      loadOperands.indices(), loadOp.getLoc(), rewriter);
  rewriter.replaceOpWithNewOp(loadOp, loadPtr);
  return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

/// Rewrites a `ReturnOp` with zero or one operand. Returns failure when the
/// op has more than one operand. With exactly one operand the replacement op
/// receives the converted operand; with none it is built without arguments.
LogicalResult
ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef operands,
                                 ConversionPatternRewriter &rewriter) const {
  if (returnOp.getNumOperands() > 1)
    return failure();

  if (returnOp.getNumOperands() == 1) {
    rewriter.replaceOpWithNewOp(returnOp, operands[0]);
  } else {
    rewriter.replaceOpWithNewOp(returnOp);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

/// Rewrites a `SelectOp` into a new op taking the converted condition, true
/// value and false value, in that order. Always succeeds.
LogicalResult
SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands,
                                 ConversionPatternRewriter &rewriter) const {
  SelectOpAdaptor selectOperands(operands);
  rewriter.replaceOpWithNewOp(op, selectOperands.condition(),
                              selectOperands.true_value(),
                              selectOperands.false_value());
  return success();
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

/// Rewrites a `SplatOp` into a new op that receives the converted input
/// repeated once per element of the result type.
///
/// Fails when the `dyn_cast` of the result type yields null or when
/// `spirv::CompositeType::isValid` rejects that type.
LogicalResult
SplatPattern::matchAndRewrite(SplatOp op, ArrayRef operands,
                              ConversionPatternRewriter &rewriter) const {
  auto dstVecType = op.getType().dyn_cast();
  if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
    return failure();
  SplatOp::Adaptor adaptor(operands);
  // One copy of the input per destination element.
  SmallVector source(dstVecType.getNumElements(), adaptor.input());
  rewriter.replaceOpWithNewOp(op, dstVecType, source);
  return success();
}

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

/// Rewrites a `memref::StoreOp` whose memref element type is a signless
/// integer; any other element type makes the pattern return failure.
///
/// When the source bit width (`boolNumBits` from the type converter options
/// for i1, otherwise the element bit width) equals the bit width of the
/// converted storage element type, the store is replaced by a new op on the
/// access chain result and the converted value. Otherwise the store is
/// emulated with two ops on the adjusted pointer: one applying a mask that
/// clears the destination bits and one applying the shifted value. Fails if
/// `getAtomicOpScope` yields no scope for the memref type.
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
                                   ArrayRef operands,
                                   ConversionPatternRewriter &rewriter) const {
  memref::StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast();
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                           storeOperands.indices(), loc, rewriter);

  // For i1 the source width is taken from the converter options rather than
  // from the element type itself.
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  // Dig the storage element type out of the converted memref type: element 0
  // of the pointee type, which is then either an array-like type (first
  // branch) or another type exposing getElementType().
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast()
                         .getPointeeType();
  Type structElemType = pointeeType.cast().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // Same width on both sides: a plain store through the access chain suffices.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp(
        storeOp, accessChainOp.getResult(), storeOperands.value());
    return success();
  }

  // Since there are multiple threads in the processing, the emulation is done
  // with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
  // are done by AtomicOr as another atomic step.
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.create(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create(loc, dstType, clearBitsMask);

  Value storeVal = storeOperands.value();
  if (isBool) {
    // Turn the boolean into the storage type's one/zero constant first.
    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
    storeVal = rewriter.create(loc, dstType, storeVal, one, zero);
  }
  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Optional scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();
  Value result = rewriter.create(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The AtomicOrOp has no side effect. Since it is already inserted, we can
  // just remove the original StoreOp. Note that rewriter.replaceOp()
  // doesn't work because it requires the number of results to be the same.
  rewriter.eraseOp(storeOp);

  // The original access chain was only an intermediate; it must have no users
  // left before it is erased.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

/// Rewrites a `memref::StoreOp` whose memref element type is NOT a signless
/// integer (those are left to `IntStoreOpPattern`): computes the element
/// pointer and replaces the store with a new op on that pointer and the
/// converted value.
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
                                ArrayRef operands,
                                ConversionPatternRewriter &rewriter) const {
  memref::StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast();
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  auto storePtr =
      spirv::getElementPtr(*getTypeConverter(), memrefType,
                           storeOperands.memref(), storeOperands.indices(),
                           storeOp.getLoc(), rewriter);
  rewriter.replaceOpWithNewOp(storeOp, storePtr,
                              storeOperands.value());
  return success();
}

//===----------------------------------------------------------------------===//
// XorOp
//===----------------------------------------------------------------------===//

/// Rewrites an `XOrOp` whose first converted operand is NOT a boolean scalar
/// or vector (those are left to `BoolXOrOpPattern`). Fails if the type
/// converter returns a null type for the result; otherwise replaces the op
/// with a new op built from the converted result type and both operands.
LogicalResult
XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands,
                              ConversionPatternRewriter &rewriter) const {
  assert(operands.size() == 2);

  if (isBoolScalarOrVector(operands.front().getType()))
    return failure();

  auto dstType = getTypeConverter()->convertType(xorOp.getType());
  if (!dstType)
    return failure();
  rewriter.replaceOpWithNewOp(xorOp, dstType, operands);

  return success();
}

/// Rewrites an `XOrOp` whose first converted operand IS a boolean scalar or
/// vector. Fails if the type converter returns a null type for the result;
/// otherwise replaces the op with a new op built from the converted result
/// type and both operands.
LogicalResult
BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands,
                                  ConversionPatternRewriter &rewriter) const {
  assert(operands.size() == 2);

  if (!isBoolScalarOrVector(operands.front().getType()))
    return failure();

  auto dstType = getTypeConverter()->convertType(xorOp.getType());
  if (!dstType)
    return failure();
  rewriter.replaceOpWithNewOp(xorOp, dstType, operands);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
/// Adds the conversion patterns listed below to `patterns`, each constructed
/// with `typeConverter` and the context obtained from `patterns`.
/// `CmpFOpNanKernelPattern` is added separately with benefit 2.
void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                     RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  patterns.add<
      // Math dialect operations.
      // TODO: Move to separate pass.
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,

      // Unary and binary patterns
      BitwiseOpPattern,
      BitwiseOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      UnaryAndBinaryOpPattern,
      Log1pOpPattern,
      SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,

      // Comparison patterns
      BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,

      // Constant patterns
      ConstantCompositeOpPattern, ConstantScalarOpPattern,

      // Memory patterns
      AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
      LoadOpPattern, StoreOpPattern,

      ReturnOpPattern, SelectOpPattern, SplatPattern,

      // Type cast patterns
      UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern,
      TypeCastingOpPattern>(typeConverter, context);

  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
  // capability is available.
  patterns.add(typeConverter, context,
               /*benefit=*/2);
}

/// Adds a pattern to `patterns`, constructed with `typeConverter`, the context
/// obtained from `patterns`, and `byteCountThreshold`.
void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   int64_t byteCountThreshold,
                                   RewritePatternSet &patterns) {
  patterns.add(typeConverter, patterns.getContext(),
               byteCountThreshold);
}

} // namespace mlir