//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert SPIR-V dialect to LLVM dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "spirv-to-llvm-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given type is a signed integer or vector type. 
static bool isSignedIntegerOrVector(Type type) { if (type.isSignedInteger()) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isSignedInteger(); return false; } /// Returns true if the given type is an unsigned integer or vector type static bool isUnsignedIntegerOrVector(Type type) { if (type.isUnsignedInteger()) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isUnsignedInteger(); return false; } /// Returns the bit width of integer, float or vector of float or integer values static unsigned getBitWidth(Type type) { assert((type.isIntOrFloat() || type.isa()) && "bitwidth is not supported for this type"); if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); auto vecType = type.dyn_cast(); auto elementType = vecType.getElementType(); assert(elementType.isIntOrFloat() && "only integers and floats have a bitwidth"); return elementType.getIntOrFloatBitWidth(); } /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth() : type.getIntegerBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { if (auto vecType = type.dyn_cast()) { auto integerType = vecType.getElementType().cast(); return builder.getIntegerAttr(integerType, -1); } auto integerType = type.cast(); return builder.getIntegerAttr(integerType, -1); } /// Creates `llvm.mlir.constant` with all bits set for the given type. 
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
                                      PatternRewriter &rewriter) {
  IntegerAttr allOnes = minusOneIntegerAttribute(srcType, rewriter);
  // A vector constant needs the scalar all-ones value splatted over its shape.
  Attribute value = allOnes;
  if (srcType.isa<VectorType>())
    value = SplatElementsAttr::get(srcType.cast<ShapedType>(), allOnes);
  return rewriter.create<LLVM::ConstantOp>(loc, dstType, value);
}

/// Helper shared by the bitfield op conversions (`BitFieldInsert`,
/// `BitFieldSExtract` and `BitFieldUExtract`).
/// Zero-extends or truncates `value` so that its bit width matches the bit
/// width of `dstType`. A value that already has that bit width is returned
/// unchanged.
static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
                                        PatternRewriter &rewriter) {
  Type valueType = value.getType();
  auto targetType = dstType.cast<LLVM::LLVMType>();
  unsigned targetWidth = getLLVMTypeBitWidth(targetType);

  // `value` may already be an LLVM dialect value (e.g. after broadcasting) or
  // still carry its original SPIR-V side type.
  unsigned sourceWidth;
  if (auto llvmValueType = valueType.dyn_cast<LLVM::LLVMType>())
    sourceWidth = getLLVMTypeBitWidth(llvmValueType);
  else
    sourceWidth = getBitWidth(valueType);

  if (sourceWidth == targetWidth)
    return value;
  if (sourceWidth < targetWidth)
    return rewriter.create<LLVM::ZExtOp>(loc, targetType, value);
  // `Count` and `Offset` wider than the target type are truncated. This is
  // safe because the op's behaviour is only defined when both are no more
  // than 64, so their values fit in 8 bits.
  return rewriter.create<LLVM::TruncOp>(loc, targetType, value);
}

/// Builds a vector of `numElements` elements, every one equal to `toBroadcast`.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
                       LLVMTypeConverter &typeConverter,
                       ConversionPatternRewriter &rewriter) {
  auto llvmVectorType = typeConverter.convertType(
      VectorType::get(numElements, toBroadcast.getType()));
  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));

  // Start from an undefined vector and fill it one lane at a time.
  Value result = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
  for (unsigned lane = 0; lane != numElements; ++lane) {
    Value laneIndex = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32Type, rewriter.getI32IntegerAttr(lane));
    result = rewriter.create<LLVM::InsertElementOp>(
        loc, llvmVectorType, result, toBroadcast, laneIndex);
  }
  return result;
}

/// Broadcasts `value` when `srcType` is a vector type; for scalar `srcType`
/// the value is returned untouched.
static Value optionallyBroadcast(Location loc, Value value, Type srcType,
                                 LLVMTypeConverter &typeConverter,
                                 ConversionPatternRewriter &rewriter) {
  auto vectorType = srcType.dyn_cast<VectorType>();
  if (!vectorType)
    return value;
  return broadcast(loc, value, vectorType.getNumElements(), typeConverter,
                   rewriter);
}

/// Helper shared by the bitfield op conversions (`BitFieldInsert`,
/// `BitFieldSExtract` and `BitFieldUExtract`).
/// Makes `Offset` or `Count` usable against `Base`:
///  1. if `Base` is a vector, the scalar is splatted into a vector with the
///     same number of elements as `Base` (element type unchanged);
///  2. the result is then zero-extended or truncated to the bit width of
///     `dstType`.
static Value processCountOrOffset(Location loc, Value value, Type srcType,
                                  Type dstType, LLVMTypeConverter &converter,
                                  ConversionPatternRewriter &rewriter) {
  return optionallyTruncateOrExtend(
      loc, optionallyBroadcast(loc, value, srcType, converter, rewriter),
      dstType, rewriter);
}

/// Converts a SPIR-V struct without member offsets to a packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter) { auto elementsVector = llvm::to_vector<8>( llvm::map_range(type.getElementTypes(), [&](Type elementType) { return converter.convertType(elementType).cast(); })); return LLVM::LLVMType::getStructTy(converter.getDialect(), elementsVector, /*isPacked=*/true); } /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, LLVMTypeConverter &converter, unsigned value) { return rewriter.create( loc, LLVM::LLVMType::getInt32Ty(converter.getDialect()), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } /// Utility for `spv.Load` and `spv.Store` conversion. static LogicalResult replaceWithLoadOrStore(Operation *op, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal) { if (auto loadOp = dyn_cast(op)) { auto dstType = typeConverter.convertType(loadOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp( loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal); return success(); } auto storeOp = cast(op); rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), storeOp.ptr(), alignment, isVolatile, isNonTemporal); return success(); } //===----------------------------------------------------------------------===// // Type conversion //===----------------------------------------------------------------------===// /// Converts SPIR-V array type to LLVM array. There is no modelling of array /// stride at the moment. static Optional convertArrayType(spirv::ArrayType type, TypeConverter &converter) { if (type.getArrayStride() != 0) return llvm::None; auto elementType = converter.convertType(type.getElementType()).cast(); unsigned numElements = type.getNumElements(); return LLVM::LLVMType::getArrayTy(elementType, numElements); } /// Converts SPIR-V pointer type to LLVM pointer. 
Pointer's storage class is not /// modelled at the moment. static Type convertPointerType(spirv::PointerType type, TypeConverter &converter) { auto pointeeType = converter.convertType(type.getPointeeType()).cast(); return pointeeType.getPointerTo(); } /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is /// no modelling of array stride at the moment. static Optional convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter) { if (type.getArrayStride() != 0) return llvm::None; auto elementType = converter.convertType(type.getElementType()).cast(); return LLVM::LLVMType::getArrayTy(elementType, 0); } /// Converts SPIR-V struct to LLVM struct. There is no support of structs with /// member decorations or with offset. static Optional convertStructType(spirv::StructType type, LLVMTypeConverter &converter) { SmallVector memberDecorations; type.getMemberDecorations(memberDecorations); if (type.hasOffset() || !memberDecorations.empty()) return llvm::None; return convertStructTypePacked(type, converter); } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { class BitFieldInsertPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 
Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter); Value count = processCountOrOffset(loc, op.count(), srcType, dstType, typeConverter, rewriter); // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = rewriter.create(loc, dstType, minusOne, count); Value negated = rewriter.create(loc, dstType, maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = rewriter.create(loc, dstType, negated, offset); Value mask = rewriter.create( loc, dstType, maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. Value baseAndMask = rewriter.create(loc, dstType, op.base(), mask); Value insertShiftedByOffset = rewriter.create(loc, dstType, op.insert(), offset); rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, insertShiftedByOffset); return success(); } }; /// Converts SPIR-V ConstantOp with scalar or vector type. class ConstantScalarAndVectorPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); if (!srcType.isa() && !srcType.isIntOrFloat()) return failure(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); // SPIR-V constant can be a signed/unsigned integer, which has to be // casted to signless integer when converting to LLVM dialect. Removing the // sign bit may have unexpected behaviour. However, it is better to handle // it case-by-case, given that the purpose of the conversion is not to // cover all possible corner cases. 
if (isSignedIntegerOrVector(srcType) || isUnsignedIntegerOrVector(srcType)) { auto *context = rewriter.getContext(); auto signlessType = IntegerType::get(getBitWidth(srcType), context); if (srcType.isa()) { auto dstElementsAttr = constOp.value().cast(); rewriter.replaceOpWithNewOp( constOp, dstType, dstElementsAttr.mapValues( signlessType, [&](const APInt &value) { return value; })); return success(); } auto srcAttr = constOp.value().cast(); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } rewriter.replaceOpWithNewOp(constOp, dstType, operands, constOp.getAttrs()); return success(); } }; class BitFieldSExtractPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter); Value count = processCountOrOffset(loc, op.count(), srcType, dstType, typeConverter, rewriter); // Create a constant that holds the size of the `Base`. IntegerType integerType; if (auto vecType = srcType.dyn_cast()) integerType = vecType.getElementType().cast(); else integerType = srcType.cast(); auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = srcType.isa() ? rewriter.create( loc, dstType, SplatElementsAttr::get(srcType.cast(), baseSize)) : rewriter.create(loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit // at Offset + Count - 1 is the most significant bit now. 
Value countPlusOffset = rewriter.create(loc, dstType, count, offset); Value amountToShiftLeft = rewriter.create(loc, dstType, size, countPlusOffset); Value baseShiftedLeft = rewriter.create( loc, dstType, op.base(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = rewriter.create(loc, dstType, offset, amountToShiftLeft); rewriter.replaceOpWithNewOp(op, dstType, baseShiftedLeft, amountToShiftRight); return success(); } }; class BitFieldUExtractPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType, typeConverter, rewriter); Value count = processCountOrOffset(loc, op.count(), srcType, dstType, typeConverter, rewriter); // Create a mask with bits set at [0, Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = rewriter.create(loc, dstType, minusOne, count); Value mask = rewriter.create(loc, dstType, maskShiftedByCount, minusOne); // Shift `Base` by `Offset` and apply the mask on it. 
Value shiftedBase = rewriter.create(loc, dstType, op.base(), offset); rewriter.replaceOpWithNewOp(op, dstType, shiftedBase, mask); return success(); } }; class BranchConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BranchOp branchOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(branchOp, operands, branchOp.getTarget()); return success(); } }; class BranchConditionalConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion< spirv::BranchConditionalOp>::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. ElementsAttr branchWeights = nullptr; if (auto weights = op.branch_weights()) { VectorType weightType = VectorType::get(2, rewriter.getI32Type()); branchWeights = DenseElementsAttr::get(weightType, weights.getValue().getValue()); } rewriter.replaceOpWithNewOp( op, op.condition(), op.getTrueBlockArguments(), op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), op.getFalseBlock()); return success(); } }; /// Converts SPIR-V operations that have straightforward LLVM equivalent /// into LLVM dialect operations. template class DirectConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp(operation, dstType, operands, operation.getAttrs()); return success(); } }; /// Converts SPIR-V cast ops that do not have straightforward LLVM /// equivalent in LLVM dialect. 
template class IndirectCastPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type fromType = operation.operand().getType(); Type toType = operation.getType(); auto dstType = this->typeConverter.convertType(toType); if (!dstType) return failure(); if (getBitWidth(fromType) < getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } if (getBitWidth(fromType) > getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } return failure(); } }; class FunctionCallPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (callOp.getNumResults() == 0) { rewriter.replaceOpWithNewOp(callOp, llvm::None, operands, callOp.getAttrs()); return success(); } // Function returns a single result. 
auto dstType = typeConverter.convertType(callOp.getType(0)); rewriter.replaceOpWithNewOp(callOp, dstType, operands, callOp.getAttrs()); return success(); } }; /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" template class FComparePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2()); return success(); } }; /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" template class IComparePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2()); return success(); } }; /// Converts `spv.Load` and `spv.Store` to LLVM dialect. 
template class LoadStorePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVop op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (!op.memory_access().hasValue()) { replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0, /*isVolatile=*/false, /*isNonTemporal=*/ false); return success(); } auto memoryAccess = op.memory_access().getValue(); switch (memoryAccess) { case spirv::MemoryAccess::Aligned: case spirv::MemoryAccess::None: case spirv::MemoryAccess::Nontemporal: case spirv::MemoryAccess::Volatile: { unsigned alignment = memoryAccess == spirv::MemoryAccess::Aligned ? op.alignment().getValue().getZExtValue() : 0; bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment, isVolatile, isNonTemporal); return success(); } default: // There is no support of other memory access attributes. return failure(); } } }; /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. template class NotPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp notOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = notOp.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = notOp.getLoc(); IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = srcType.template isa() ? 
rewriter.create( loc, dstType, SplatElementsAttr::get( srcType.template cast(), minusOne)) : rewriter.create(loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.operand(), mask); return success(); } }; class ReturnPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnOp, ArrayRef(), ArrayRef()); return success(); } }; class ReturnValuePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnValueOp, ArrayRef(), operands); return success(); } }; class MergePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::MergeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } }; /// Converts `spv.selection` with `spv.BranchConditional` in its header block. /// All blocks within selection should be reachable for conversion to succeed. class SelectionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::SelectionOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // There is no support for `Flatten` or `DontFlatten` selection control at // the moment. This are just compiler hints and can be performed during the // optimization passes. if (op.selection_control() != spirv::SelectionControl::None) return failure(); // `spv.selection` should have at least two blocks: one selection header // block and one merge block. 
If no blocks are present, or control flow // branches straight to merge block (two blocks are present), the op is // redundant and it is erased. if (op.body().getBlocks().size() <= 2) { rewriter.eraseOp(op); return success(); } Location loc = op.getLoc(); // Split the current block after `spv.selection`. The remaing ops will be // used in `continueBlock`. auto *currentBlock = rewriter.getInsertionBlock(); rewriter.setInsertionPointAfter(op); auto position = rewriter.getInsertionPoint(); auto *continueBlock = rewriter.splitBlock(currentBlock, position); // Extract conditional branch information from the header block. By SPIR-V // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch` // op. Note that `spv.Switch op` is not supported at the moment in the // SPIR-V dialect. Remove this block when finished. auto *headerBlock = op.getHeaderBlock(); assert(headerBlock->getOperations().size() == 1); auto condBrOp = dyn_cast( headerBlock->getOperations().front()); if (!condBrOp) return failure(); rewriter.eraseBlock(headerBlock); // Branch from merge block to continue block. auto *mergeBlock = op.getMergeBlock(); Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); rewriter.create(loc, terminatorOperands, continueBlock); // Link current block to `true` and `false` blocks within the selection. Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); rewriter.create(loc, condBrOp.condition(), trueBlock, condBrOp.trueTargetOperands(), falseBlock, condBrOp.falseTargetOperands()); rewriter.inlineRegionBefore(op.body(), continueBlock); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } }; /// Converts SPIR-V shift ops to LLVM shift ops. 
Since LLVM dialect /// puts a restriction on `Shift` and `Base` to have the same bit width, /// `Shift` is zero or sign extended to match this specification. Cases when /// `Shift` bit width > `Base` bit width are considered to be illegal. template class ShiftPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); Type op1Type = operation.operand1().getType(); Type op2Type = operation.operand2().getType(); if (op1Type == op2Type) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } Location loc = operation.getLoc(); Value extended; if (isUnsignedIntegerOrVector(op2Type)) { extended = rewriter.template create(loc, dstType, operation.operand2()); } else { extended = rewriter.template create(loc, dstType, operation.operand2()); } Value result = rewriter.template create( loc, dstType, operation.operand1(), extended); rewriter.replaceOp(operation, result); return success(); } }; class TanPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(tanOp.getType()); if (!dstType) return failure(); Location loc = tanOp.getLoc(); Value sin = rewriter.create(loc, dstType, tanOp.operand()); Value cos = rewriter.create(loc, dstType, tanOp.operand()); rewriter.replaceOpWithNewOp(tanOp, dstType, sin, cos); return success(); } }; class VariablePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::VariableOp varOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto 
srcType = varOp.getType(); // Initialization is supported for scalars and vectors only. auto pointerTo = srcType.cast().getPointeeType(); auto init = varOp.initializer(); if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa()) return failure(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = varOp.getLoc(); Value size = createI32ConstantOf(loc, rewriter, typeConverter, 1); if (!init) { rewriter.replaceOpWithNewOp(varOp, dstType, size); return success(); } Value allocated = rewriter.create(loc, dstType, size); rewriter.create(loc, init, allocated); rewriter.replaceOp(varOp, allocated); return success(); } }; //===----------------------------------------------------------------------===// // FuncOp conversion //===----------------------------------------------------------------------===// class FuncConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Convert function signature. At the moment LLVMType converter is enough // for currently supported types. 
auto funcType = funcOp.getType(); TypeConverter::SignatureConversion signatureConverter( funcType.getNumInputs()); auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), /*isVariadic=*/false, signatureConverter); if (!llvmType) return failure(); // Create a new `LLVMFuncOp` Location loc = funcOp.getLoc(); StringRef name = funcOp.getName(); auto newFuncOp = rewriter.create(loc, name, llvmType); // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); switch (funcOp.function_control()) { #define DISPATCH(functionControl, llvmAttr) \ case functionControl: \ newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ break; DISPATCH(spirv::FunctionControl::Inline, StringAttr::get("alwaysinline", context)); DISPATCH(spirv::FunctionControl::DontInline, StringAttr::get("noinline", context)); DISPATCH(spirv::FunctionControl::Pure, StringAttr::get("readonly", context)); DISPATCH(spirv::FunctionControl::Const, StringAttr::get("readnone", context)); #undef DISPATCH // Default: if `spirv::FunctionControl::None`, then no attributes are // needed. 
default: break; } rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, &signatureConverter))) { return failure(); } rewriter.eraseOp(funcOp); return success(); } }; //===----------------------------------------------------------------------===// // ModuleOp conversion //===----------------------------------------------------------------------===// class ModuleConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto newModuleOp = rewriter.create(spvModuleOp.getLoc()); rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); rewriter.eraseOp(spvModuleOp); return success(); } }; class ModuleEndConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(moduleEndOp); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) { typeConverter.addConversion([&](spirv::ArrayType type) { return convertArrayType(type, typeConverter); }); typeConverter.addConversion([&](spirv::PointerType type) { return convertPointerType(type, typeConverter); }); typeConverter.addConversion([&](spirv::RuntimeArrayType type) { return convertRuntimeArrayType(type, typeConverter); }); 
typeConverter.addConversion([&](spirv::StructType type) { return convertStructType(type, typeConverter); }); } void mlir::populateSPIRVToLLVMConversionPatterns( MLIRContext *context, LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< // Arithmetic ops DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, // Bitwise ops BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, NotPattern, // Cast ops DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, IndirectCastPattern, IndirectCastPattern, IndirectCastPattern, // Comparison ops IComparePattern, IComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, // Constant op ConstantScalarAndVectorPattern, // Control Flow ops BranchConversionPattern, BranchConditionalConversionPattern, SelectionPattern, MergePattern, // Function Call op FunctionCallPattern, // GLSL extended instruction set ops DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, TanPattern, // Logical ops DirectConversionPattern, DirectConversionPattern, IComparePattern, IComparePattern, NotPattern, // Memory ops LoadStorePattern, 
LoadStorePattern, VariablePattern, // Miscellaneous ops DirectConversionPattern, DirectConversionPattern, // Shift ops ShiftPattern, ShiftPattern, ShiftPattern, // Return ops ReturnPattern, ReturnValuePattern>(context, typeConverter); } void mlir::populateSPIRVToLLVMFunctionConversionPatterns( MLIRContext *context, LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert(context, typeConverter); } void mlir::populateSPIRVToLLVMModuleConversionPatterns( MLIRContext *context, LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert( context, typeConverter); }