//===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert SPIR-V dialect to LLVM dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "spirv-to-llvm-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Returns true if the given type is a signed integer or vector type. static bool isSignedIntegerOrVector(Type type) { if (type.isSignedInteger()) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isSignedInteger(); return false; } /// Returns true if the given type is an unsigned integer or vector type static bool isUnsignedIntegerOrVector(Type type) { if (type.isUnsignedInteger()) return true; if (auto vecType = type.dyn_cast()) return vecType.getElementType().isUnsignedInteger(); return false; } /// Returns the bit width of integer, float or vector of float or integer values static unsigned getBitWidth(Type type) { assert((type.isIntOrFloat() || type.isa()) && "bitwidth is not supported for this type"); if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); auto vecType = type.dyn_cast(); auto elementType = vecType.getElementType(); assert(elementType.isIntOrFloat() && "only integers and floats have a bitwidth"); return elementType.getIntOrFloatBitWidth(); } /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { return type.isVectorTy() ? type.getVectorElementType() .getUnderlyingType() ->getIntegerBitWidth() : type.getUnderlyingType()->getIntegerBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { if (auto vecType = type.dyn_cast()) { auto integerType = vecType.getElementType().cast(); return builder.getIntegerAttr(integerType, -1); } auto integerType = type.cast(); return builder.getIntegerAttr(integerType, -1); } /// Creates `llvm.mlir.constant` with all bits set for the given type. static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { if (srcType.isa()) return rewriter.create( loc, dstType, SplatElementsAttr::get(srcType.cast(), minusOneIntegerAttribute(srcType, rewriter))); return rewriter.create( loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); } /// Utility function for bitfiled ops: /// - `BitFieldInsert` /// - `BitFieldSExtract` /// - `BitFieldUExtract` /// Truncates or extends the value. If the bitwidth of the value is the same as /// `dstType` bitwidth, the value remains unchanged. static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType, PatternRewriter &rewriter) { auto srcType = value.getType(); auto llvmType = dstType.cast(); unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); unsigned valueBitWidth = srcType.isa() ? getLLVMTypeBitWidth(srcType.cast()) : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) return rewriter.create(loc, llvmType, value); // If the bit widths of `Count` and `Offset` are greater than the bit width // of the target type, they are truncated. Truncation is safe since `Count` // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, // both values can be expressed in 8 bits. if (valueBitWidth > targetBitWidth) return rewriter.create(loc, llvmType, value); return value; } /// Broadcasts the value to vector with `numElements` number of elements static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); Value broadcasted = rewriter.create(loc, llvmVectorType); for (unsigned i = 0; i < numElements; ++i) { auto index = rewriter.create( loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); broadcasted = rewriter.create( loc, llvmVectorType, broadcasted, toBroadcast, index); } return broadcasted; } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { class BitFieldInsertPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = op.getLoc(); // Broadcast `Offset` and `Count` to match the type of `Base` and `Insert`. // If `Base` is of a vector type, construct a vector that has: // - same number of elements as `Base` // - each element has the type that is the same as the type of `Offset` or // `Count` // - each element has the same value as `Offset` or `Count` Value offset; Value count; if (auto vectorType = srcType.dyn_cast()) { unsigned numElements = vectorType.getNumElements(); offset = broadcast(loc, op.offset(), numElements, typeConverter, rewriter); count = broadcast(loc, op.count(), numElements, typeConverter, rewriter); } else { offset = op.offset(); count = op.count(); } // Create a mask with all bits set of the same type as `srcType` Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); // Need to cast `Offset` and `Count` if their bit width is different // from `Base` bit width. Value optionallyCastedCount = optionallyTruncateOrExtend(loc, count, dstType, rewriter); Value optionallyCastedOffset = optionallyTruncateOrExtend(loc, offset, dstType, rewriter); // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value maskShiftedByCount = rewriter.create( loc, dstType, minusOne, optionallyCastedCount); Value negated = rewriter.create(loc, dstType, maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = rewriter.create( loc, dstType, negated, optionallyCastedOffset); Value mask = rewriter.create( loc, dstType, maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. Value baseAndMask = rewriter.create(loc, dstType, op.base(), mask); Value insertShiftedByOffset = rewriter.create( loc, dstType, op.insert(), optionallyCastedOffset); rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, insertShiftedByOffset); return success(); } }; /// Converts SPIR-V ConstantOp with scalar or vector type. class ConstantScalarAndVectorPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ConstantOp constOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); if (!srcType.isa() && !srcType.isIntOrFloat()) return failure(); auto dstType = typeConverter.convertType(srcType); if (!dstType) return failure(); // SPIR-V constant can be a signed/unsigned integer, which has to be // casted to signless integer when converting to LLVM dialect. Removing the // sign bit may have unexpected behaviour. However, it is better to handle // it case-by-case, given that the purpose of the conversion is not to // cover all possible corner cases. if (isSignedIntegerOrVector(srcType) || isUnsignedIntegerOrVector(srcType)) { auto *context = rewriter.getContext(); auto signlessType = IntegerType::get(getBitWidth(srcType), context); if (srcType.isa()) { auto dstElementsAttr = constOp.value().cast(); rewriter.replaceOpWithNewOp( constOp, dstType, dstElementsAttr.mapValues( signlessType, [&](const APInt &value) { return value; })); return success(); } auto srcAttr = constOp.value().cast(); auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } rewriter.replaceOpWithNewOp(constOp, dstType, operands, constOp.getAttrs()); return success(); } }; /// Converts SPIR-V operations that have straightforward LLVM equivalent /// into LLVM dialect operations. template class DirectConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp(operation, dstType, operands, operation.getAttrs()); return success(); } }; /// Converts SPIR-V cast ops that do not have straightforward LLVM /// equivalent in LLVM dialect. template class IndirectCastPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Type fromType = operation.operand().getType(); Type toType = operation.getType(); auto dstType = this->typeConverter.convertType(toType); if (!dstType) return failure(); if (getBitWidth(fromType) < getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } if (getBitWidth(fromType) > getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } return failure(); } }; class FunctionCallPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (callOp.getNumResults() == 0) { rewriter.replaceOpWithNewOp(callOp, llvm::None, operands, callOp.getAttrs()); return success(); } // Function returns a single result. auto dstType = this->typeConverter.convertType(callOp.getType(0)); rewriter.replaceOpWithNewOp(callOp, dstType, operands, callOp.getAttrs()); return success(); } }; /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" template class FComparePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2()); return success(); } }; /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" template class IComparePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); rewriter.template replaceOpWithNewOp( operation, dstType, rewriter.getI64IntegerAttr(static_cast(predicate)), operation.operand1(), operation.operand2()); return success(); } }; /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect. template class NotPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp notOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = notOp.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) return failure(); Location loc = notOp.getLoc(); IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = srcType.template isa() ? rewriter.create( loc, dstType, SplatElementsAttr::get( srcType.template cast(), minusOne)) : rewriter.create(loc, dstType, minusOne); rewriter.template replaceOpWithNewOp(notOp, dstType, notOp.operand(), mask); return success(); } }; class ReturnPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnOp, ArrayRef(), ArrayRef()); return success(); } }; class ReturnValuePattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnValueOp, ArrayRef(), operands); return success(); } }; /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect /// puts a restriction on `Shift` and `Base` to have the same bit width, /// `Shift` is zero or sign extended to match this specification. Cases when /// `Shift` bit width > `Base` bit width are considered to be illegal. template class ShiftPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(SPIRVOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); Type op1Type = operation.operand1().getType(); Type op2Type = operation.operand2().getType(); if (op1Type == op2Type) { rewriter.template replaceOpWithNewOp(operation, dstType, operands); return success(); } Location loc = operation.getLoc(); Value extended; if (isUnsignedIntegerOrVector(op2Type)) { extended = rewriter.template create(loc, dstType, operation.operand2()); } else { extended = rewriter.template create(loc, dstType, operation.operand2()); } Value result = rewriter.template create( loc, dstType, operation.operand1(), extended); rewriter.replaceOp(operation, result); return success(); } }; //===----------------------------------------------------------------------===// // FuncOp conversion //===----------------------------------------------------------------------===// class FuncConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Convert function signature. At the moment LLVMType converter is enough // for currently supported types. auto funcType = funcOp.getType(); TypeConverter::SignatureConversion signatureConverter( funcType.getNumInputs()); auto llvmType = this->typeConverter.convertFunctionSignature( funcOp.getType(), /*isVariadic=*/false, signatureConverter); // Create a new `LLVMFuncOp` Location loc = funcOp.getLoc(); StringRef name = funcOp.getName(); auto newFuncOp = rewriter.create(loc, name, llvmType); // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); switch (funcOp.function_control()) { #define DISPATCH(functionControl, llvmAttr) \ case functionControl: \ newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ break; DISPATCH(spirv::FunctionControl::Inline, StringAttr::get("alwaysinline", context)); DISPATCH(spirv::FunctionControl::DontInline, StringAttr::get("noinline", context)); DISPATCH(spirv::FunctionControl::Pure, StringAttr::get("readonly", context)); DISPATCH(spirv::FunctionControl::Const, StringAttr::get("readnone", context)); #undef DISPATCH // Default: if `spirv::FunctionControl::None`, then no attributes are // needed. default: break; } rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); rewriter.eraseOp(funcOp); return success(); } }; //===----------------------------------------------------------------------===// // ModuleOp conversion //===----------------------------------------------------------------------===// class ModuleConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto newModuleOp = rewriter.create(spvModuleOp.getLoc()); rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); rewriter.eraseOp(spvModuleOp); return success(); } }; class ModuleEndConversionPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(moduleEndOp); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// void mlir::populateSPIRVToLLVMConversionPatterns( MLIRContext *context, LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert< // Arithmetic ops DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, // Bitwise ops BitFieldInsertPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, NotPattern, // Cast ops DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, IndirectCastPattern, IndirectCastPattern, IndirectCastPattern, // Comparison ops IComparePattern, IComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, FComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, IComparePattern, // Constant op ConstantScalarAndVectorPattern, // Function Call op FunctionCallPattern, // Logical ops DirectConversionPattern, DirectConversionPattern, IComparePattern, IComparePattern, NotPattern, // Shift ops ShiftPattern, ShiftPattern, ShiftPattern, // Return ops ReturnPattern, ReturnValuePattern>(context, typeConverter); } void mlir::populateSPIRVToLLVMFunctionConversionPatterns( MLIRContext *context, LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert(context, typeConverter); } void mlir::populateSPIRVToLLVMModuleConversionPatterns( MLIRContext *context, LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) { patterns.insert( context, typeConverter); }