//===- Serializer.cpp - MLIR SPIR-V Serialization -------------------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file defines the MLIR SPIR-V module to SPIR-V binary seralization. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StringExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/bit.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; /// Returns the word-count-prefixed opcode for an SPIR-V instruction. static inline uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode) { assert(((wordCount >> 16) == 0) && "word count out of range!"); return (wordCount << 16) | static_cast(opcode); } /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into /// the given `binary` vector. static LogicalResult encodeInstructionInto(SmallVectorImpl &binary, spirv::Opcode op, ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); binary.push_back(getPrefixedOpcode(wordCount, op)); if (!operands.empty()) { binary.append(operands.begin(), operands.end()); } return success(); } /// Encodes an SPIR-V `literal` string into the given `binary` vector. static LogicalResult encodeStringLiteralInto(SmallVectorImpl &binary, StringRef literal) { // We need to encode the literal and the null termination. auto encodingSize = literal.size() / 4 + 1; auto bufferStartSize = binary.size(); binary.resize(bufferStartSize + encodingSize, 0); std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size()); return success(); } namespace { /// A SPIR-V module serializer. /// /// A SPIR-V binary module is a single linear stream of instructions; each /// instruction is composed of 32-bit words with the layout: /// /// | | | | | ... | /// | <------ word -------> | <-- word --> | <-- word --> | ... | /// /// For the first word, the 16 high-order bits are the word count of the /// instruction, the 16 low-order bits are the opcode enumerant. The /// instructions then belong to different sections, which must be laid out in /// the particular order as specified in "2.4 Logical Layout of a Module" of /// the SPIR-V spec. class Serializer { public: /// Creates a serializer for the given SPIR-V `module`. explicit Serializer(spirv::ModuleOp module); /// Serializes the remembered SPIR-V module. LogicalResult serialize(); /// Collects the final SPIR-V `binary`. void collect(SmallVectorImpl &binary); private: // Note that there are two main categories of methods in this class: // * process*() methods are meant to fully serialize a SPIR-V module entity // (header, type, op, etc.). They update internal vectors containing // different binary sections. They are not meant to be called except the // top-level serialization loop. // * prepare*() methods are meant to be helpers that prepare for serializing // certain entity. They may or may not update internal vectors containing // different binary sections. They are meant to be called among themselves // or by other process*() methods for subtasks. //===--------------------------------------------------------------------===// // //===--------------------------------------------------------------------===// // Note that it is illegal to use id <0> in SPIR-V binary module. Various // methods in this class, if using SPIR-V word (uint32_t) as interface, // check or return id <0> to indicate error in processing. /// Consumes the next unused . This method will never return 0. uint32_t getNextID() { return nextID++; } //===--------------------------------------------------------------------===// // Module structure //===--------------------------------------------------------------------===// LogicalResult processMemoryModel(); LogicalResult processConstantOp(spirv::ConstantOp op); uint32_t findFunctionID(StringRef fnName) const { return funcIDMap.lookup(fnName); } /// Processes a SPIR-V function op. LogicalResult processFuncOp(FuncOp op); /// Process attributes that translate to decorations on the result LogicalResult processDecoration(Location loc, uint32_t resultID, NamedAttribute attr); //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// uint32_t findTypeID(Type type) const { return typeIDMap.lookup(type); } Type getVoidType() { return mlirBuilder.getNoneType(); } bool isVoidType(Type type) const { return type.isa(); } /// Main dispatch method for serializing a type. The result of the /// serialized type will be returned as `typeID`. LogicalResult processType(Location loc, Type type, uint32_t &typeID); /// Method for preparing basic SPIR-V type serialization. Returns the type's /// opcode and operands for the instruction via `typeEnum` and `operands`. LogicalResult prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum, SmallVectorImpl &operands); LogicalResult prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, SmallVectorImpl &operands); //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// uint32_t findConstantID(Attribute value) const { return constIDMap.lookup(value); } /// Main dispatch method for processing a constant with the given `constType` /// and `valueAttr`. `constType` is needed here because we can interpret the /// `valueAttr` as a different type than the type of `valueAttr` itself; for /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType /// constants. If `isSpec` is true, then the constant will be serialized as /// a specialization constant. uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr, bool isSpec); /// Prepares bool ElementsAttr serialization. This method updates `opcode` /// with a proper OpConstant* instruction and pushes literal values for the /// constant to `operands`. LogicalResult prepareBoolVectorConstant(Location loc, DenseIntElementsAttr elementsAttr, bool isSpec, spirv::Opcode &opcode, SmallVectorImpl &operands); /// Prepares int ElementsAttr serialization. This method updates `opcode` with /// a proper OpConstant* instruction and pushes literal values for the /// constant to `operands`. LogicalResult prepareIntVectorConstant(Location loc, DenseIntElementsAttr elementsAttr, bool isSpec, spirv::Opcode &opcode, SmallVectorImpl &operands); /// Prepares float ElementsAttr serialization. This method updates `opcode` /// with a proper OpConstant* instruction and pushes literal values for the /// constant to `operands`. LogicalResult prepareFloatVectorConstant(Location loc, DenseFPElementsAttr elementsAttr, bool isSpec, spirv::Opcode &opcode, SmallVectorImpl &operands); uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec); uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec); uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec); //===--------------------------------------------------------------------===// // Operations //===--------------------------------------------------------------------===// uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); } /// Main dispatch method for serializing an operation. LogicalResult processOperation(Operation *op); /// Method to dispatch to the serialization function for an operation in /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. /// This is auto-generated from ODS. Dispatch is handled for all operations /// in SPIR-V dialect that have hasOpcode == 1. LogicalResult dispatchToAutogenSerialization(Operation *op); /// Method to serialize an operation in the SPIR-V dialect that is a mirror of /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == /// 1 and autogenSerialization == 1 in ODS. template LogicalResult processOp(OpTy op) { return op.emitError("unsupported op serialization"); } private: /// The SPIR-V module to be serialized. spirv::ModuleOp module; /// An MLIR builder for getting MLIR constructs. mlir::Builder mlirBuilder; /// The next available result . uint32_t nextID = 1; // The following are for different SPIR-V instruction sections. They follow // the logical layout of a SPIR-V module. SmallVector capabilities; SmallVector extensions; SmallVector extendedSets; SmallVector memoryModel; SmallVector entryPoints; SmallVector executionModes; // TODO(antiagainst): debug instructions SmallVector names; SmallVector decorations; SmallVector typesGlobalValues; SmallVector functions; /// Map from type used in SPIR-V module to their s DenseMap typeIDMap; /// Map from constant values to their s DenseMap constIDMap; /// Map from FuncOps name to s. llvm::StringMap funcIDMap; /// Map from results of normal operations to their s DenseMap valueIDMap; }; } // namespace Serializer::Serializer(spirv::ModuleOp module) : module(module), mlirBuilder(module.getContext()) {} LogicalResult Serializer::serialize() { if (failed(module.verify())) return failure(); // TODO(antiagainst): handle the other sections processMemoryModel(); // Iterate over the module body to serialze it. Assumptions are that there is // only one basic block in the moduleOp for (auto &op : module.getBlock()) { if (failed(processOperation(&op))) { return failure(); } } return success(); } void Serializer::collect(SmallVectorImpl &binary) { auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + extensions.size() + extendedSets.size() + memoryModel.size() + entryPoints.size() + executionModes.size() + decorations.size() + typesGlobalValues.size() + functions.size(); binary.clear(); binary.reserve(moduleSize); spirv::appendModuleHeader(binary, nextID); binary.append(capabilities.begin(), capabilities.end()); binary.append(extensions.begin(), extensions.end()); binary.append(extendedSets.begin(), extendedSets.end()); binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); binary.append(functions.begin(), functions.end()); } //===----------------------------------------------------------------------===// // Module structure //===----------------------------------------------------------------------===// LogicalResult Serializer::processMemoryModel() { uint32_t mm = module.getAttrOfType("memory_model").getInt(); uint32_t am = module.getAttrOfType("addressing_model").getInt(); return encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); } LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(), op.is_spec_const())) { valueIDMap[op.getResult()] = resultID; return success(); } return failure(); } LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { auto attrName = attr.first.strref(); auto decorationName = mlir::convertToCamelCase(attrName, true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { return emitError( loc, "non-argument attributes expected to have snake-case-ified " "decoration name, unhandled attribute with name : ") << attrName; } SmallVector args; args.push_back(resultID); args.push_back(static_cast(decoration.getValue())); switch (decoration.getValue()) { case spirv::Decoration::DescriptorSet: case spirv::Decoration::Binding: if (auto intAttr = attr.second.dyn_cast()) { args.push_back(intAttr.getValue().getZExtValue()); break; } return emitError(loc, "expected integer attribute for ") << attrName; default: return emitError(loc, "unhandled decoration ") << decorationName; } return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); } LogicalResult Serializer::processFuncOp(FuncOp op) { uint32_t fnTypeID = 0; // Generate type of the function. processType(op.getLoc(), op.getType(), fnTypeID); // Add the function definition. SmallVector operands; uint32_t resTypeID = 0; auto resultTypes = op.getType().getResults(); if (resultTypes.size() > 1) { return emitError(op.getLoc(), "cannot serialize function with multiple return types"); } if (failed(processType(op.getLoc(), (resultTypes.empty() ? getVoidType() : resultTypes[0]), resTypeID))) { return failure(); } operands.push_back(resTypeID); auto funcID = getNextID(); funcIDMap[op.getName()] = funcID; operands.push_back(funcID); // TODO : Support other function control options. operands.push_back(static_cast(spirv::FunctionControl::None)); operands.push_back(fnTypeID); encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands); // Add function name. SmallVector nameOperands; nameOperands.push_back(funcID); encodeStringLiteralInto(nameOperands, op.getName()); encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); // Declare the parameters. for (auto arg : op.getArguments()) { uint32_t argTypeID = 0; if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) { return failure(); } auto argValueID = getNextID(); valueIDMap[arg] = argValueID; encodeInstructionInto(functions, spirv::Opcode::OpFunctionParameter, {argTypeID, argValueID}); } // Process the body. if (op.isExternal()) { return emitError(op.getLoc(), "external function is unhandled"); } for (auto &b : op) { for (auto &op : b) { if (failed(processOperation(&op))) { return failure(); } } } // Insert Function End. return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {}); } //===----------------------------------------------------------------------===// // Type //===----------------------------------------------------------------------===// LogicalResult Serializer::processType(Location loc, Type type, uint32_t &typeID) { typeID = findTypeID(type); if (typeID) { return success(); } typeID = getNextID(); SmallVector operands; operands.push_back(typeID); auto typeEnum = spirv::Opcode::OpTypeVoid; if ((type.isa() && succeeded(prepareFunctionType(loc, type.cast(), typeEnum, operands))) || succeeded(prepareBasicType(loc, type, typeEnum, operands))) { typeIDMap[type] = typeID; return encodeInstructionInto(typesGlobalValues, typeEnum, operands); } return failure(); } LogicalResult Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum, SmallVectorImpl &operands) { if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; return success(); } if (auto intType = type.dyn_cast()) { if (intType.getWidth() == 1) { typeEnum = spirv::Opcode::OpTypeBool; return success(); } typeEnum = spirv::Opcode::OpTypeInt; operands.push_back(intType.getWidth()); // TODO(antiagainst): support unsigned integers operands.push_back(1); return success(); } if (auto floatType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeFloat; operands.push_back(floatType.getWidth()); return success(); } if (auto vectorType = type.dyn_cast()) { uint32_t elementTypeID = 0; if (failed(processType(loc, vectorType.getElementType(), elementTypeID))) { return failure(); } typeEnum = spirv::Opcode::OpTypeVector; operands.push_back(elementTypeID); operands.push_back(vectorType.getNumElements()); return success(); } if (auto arrayType = type.dyn_cast()) { typeEnum = spirv::Opcode::OpTypeArray; uint32_t elementTypeID = 0; if (failed(processType(loc, arrayType.getElementType(), elementTypeID))) { return failure(); } operands.push_back(elementTypeID); if (auto elementCountID = prepareConstantInt( loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()), /*isSpec=*/false)) { operands.push_back(elementCountID); return success(); } return failure(); } if (auto ptrType = type.dyn_cast()) { uint32_t pointeeTypeID = 0; if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) { return failure(); } typeEnum = spirv::Opcode::OpTypePointer; operands.push_back(static_cast(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); return success(); } // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } LogicalResult Serializer::prepareFunctionType(Location loc, FunctionType type, spirv::Opcode &typeEnum, SmallVectorImpl &operands) { typeEnum = spirv::Opcode::OpTypeFunction; assert(type.getNumResults() <= 1 && "Serialization supports only a single return value"); uint32_t resultID = 0; if (failed(processType( loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), resultID))) { return failure(); } operands.push_back(resultID); for (auto &res : type.getInputs()) { uint32_t argTypeID = 0; if (failed(processType(loc, res, argTypeID))) { return failure(); } operands.push_back(argTypeID); } return success(); } //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// uint32_t Serializer::prepareConstant(Location loc, Type constType, Attribute valueAttr, bool isSpec) { if (auto floatAttr = valueAttr.dyn_cast()) { return prepareConstantFp(loc, floatAttr, isSpec); } if (auto intAttr = valueAttr.dyn_cast()) { return prepareConstantInt(loc, intAttr, isSpec); } if (auto boolAttr = valueAttr.dyn_cast()) { return prepareConstantBool(loc, boolAttr, isSpec); } // This is a composite literal. We need to handle each component separately // and then emit an OpConstantComposite for the whole. if (auto id = findConstantID(valueAttr)) { return id; } uint32_t typeID = 0; if (failed(processType(loc, constType, typeID))) { return 0; } auto resultID = getNextID(); spirv::Opcode opcode = spirv::Opcode::OpNop; SmallVector operands; operands.push_back(typeID); operands.push_back(resultID); if (auto vectorAttr = valueAttr.dyn_cast()) { if (vectorAttr.getType().getElementType().isInteger(1)) { if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode, operands))) return 0; } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode, operands))) return 0; } else if (auto vectorAttr = valueAttr.dyn_cast()) { if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode, operands))) return 0; } else if (auto arrayAttr = valueAttr.dyn_cast()) { opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; operands.reserve(arrayAttr.size() + 2); auto elementType = constType.cast().getElementType(); for (Attribute elementAttr : arrayAttr) if (auto elementID = prepareConstant(loc, elementType, elementAttr, isSpec)) { operands.push_back(elementID); } else { return 0; } } else { emitError(loc, "cannot serialize attribute: ") << valueAttr; return 0; } encodeInstructionInto(typesGlobalValues, opcode, operands); constIDMap[valueAttr] = resultID; return resultID; } LogicalResult Serializer::prepareBoolVectorConstant( Location loc, DenseIntElementsAttr elementsAttr, bool isSpec, spirv::Opcode &opcode, SmallVectorImpl &operands) { auto type = elementsAttr.getType(); assert(type.hasRank() && type.getRank() == 1 && "spv.constant should have verified only vector literal uses " "ElementsAttr"); assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr"); auto count = type.getNumElements(); // Operands for constructing the SPIR-V OpConstant* instruction operands.reserve(count + 2); // For splat cases, we don't need to loop over all elements, especially when // the splat value is zero. if (Attribute splatAttr = elementsAttr.getSplatValue()) { // We can use OpConstantNull if this bool ElementsAttr is splatting false. if (!isSpec && !splatAttr.cast().getValue()) { opcode = spirv::Opcode::OpConstantNull; return success(); } if (auto id = prepareConstantBool(loc, splatAttr.cast(), isSpec)) { opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; operands.append(count, id); return success(); } return failure(); } // Otherwise, we need to process each element and compose them with // OpConstantComposite. opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; for (APInt intValue : elementsAttr) { // We are constructing an BoolAttr for each APInt here. But given that // we only use ElementsAttr for vectors with no more than 4 elements, it // should be fine here. auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue()); if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) { operands.push_back(elementID); } else { return failure(); } } return success(); } LogicalResult Serializer::prepareIntVectorConstant( Location loc, DenseIntElementsAttr elementsAttr, bool isSpec, spirv::Opcode &opcode, SmallVectorImpl &operands) { auto type = elementsAttr.getType(); assert(type.hasRank() && type.getRank() == 1 && "spv.constant should have verified only vector literal uses " "ElementsAttr"); auto elementType = type.getElementType(); assert(!elementType.isInteger(1) && "must be non-bool ElementsAttr"); auto count = type.getNumElements(); // Operands for constructing the SPIR-V OpConstant* instruction operands.reserve(count + 2); // For splat cases, we don't need to loop over all elements, especially when // the splat value is zero. if (Attribute splatAttr = elementsAttr.getSplatValue()) { // We can use OpConstantNull if this int ElementsAttr is splatting 0. if (!isSpec && splatAttr.cast().getValue().isNullValue()) { opcode = spirv::Opcode::OpConstantNull; return success(); } if (auto id = prepareConstantInt(loc, splatAttr.cast(), isSpec)) { opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; operands.append(count, id); return success(); } return failure(); } // Otherwise, we need to process each element and compose them with // OpConstantComposite. opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; for (APInt intValue : elementsAttr) { // We are constructing an IntegerAttr for each APInt here. But given that // we only use ElementsAttr for vectors with no more than 4 elements, it // should be fine here. // TODO(antiagainst): revisit this if special extensions enabling large // vectors are supported. auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue); if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) { operands.push_back(elementID); } else { return failure(); } } return success(); } LogicalResult Serializer::prepareFloatVectorConstant( Location loc, DenseFPElementsAttr elementsAttr, bool isSpec, spirv::Opcode &opcode, SmallVectorImpl &operands) { auto type = elementsAttr.getType(); assert(type.hasRank() && type.getRank() == 1 && "spv.constant should have verified only vector literal uses " "ElementsAttr"); auto count = type.getNumElements(); auto elementType = type.getElementType(); operands.reserve(count + 2); if (Attribute splatAttr = elementsAttr.getSplatValue()) { if (!isSpec && splatAttr.cast().getValue().isZero()) { opcode = spirv::Opcode::OpConstantNull; return success(); } if (auto id = prepareConstantFp(loc, splatAttr.cast(), isSpec)) { opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; operands.append(count, id); return success(); } return failure(); } opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite : spirv::Opcode::OpConstantComposite; for (APFloat floatValue : elementsAttr) { auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue); if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) { operands.push_back(elementID); } else { return failure(); } } return success(); } uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec) { if (auto id = findConstantID(boolAttr)) { return id; } // Process the type for this bool literal uint32_t typeID = 0; if (failed(processType(loc, boolAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); auto opcode = boolAttr.getValue() ? (isSpec ? spirv::Opcode::OpSpecConstantTrue : spirv::Opcode::OpConstantTrue) : (isSpec ? spirv::Opcode::OpSpecConstantFalse : spirv::Opcode::OpConstantFalse); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); return constIDMap[boolAttr] = resultID; } uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec) { if (auto id = findConstantID(intAttr)) { return id; } // Process the type for this integer literal uint32_t typeID = 0; if (failed(processType(loc, intAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APInt value = intAttr.getValue(); unsigned bitwidth = value.getBitWidth(); bool isSigned = value.isSignedIntN(bitwidth); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; // According to SPIR-V spec, "When the type's bit width is less than 32-bits, // the literal's value appears in the low-order bits of the word, and the // high-order bits must be 0 for a floating-point type, or 0 for an integer // type with Signedness of 0, or sign extended when Signedness is 1." if (bitwidth == 32 || bitwidth == 16) { uint32_t word = 0; if (isSigned) { word = static_cast(value.getSExtValue()); } else { word = static_cast(value.getZExtValue()); } encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } // According to SPIR-V spec: "When the type's bit width is larger than one // word, the literal’s low-order words appear first." else if (bitwidth == 64) { struct DoubleWord { uint32_t word1; uint32_t word2; } words; if (isSigned) { words = llvm::bit_cast(value.getSExtValue()); } else { words = llvm::bit_cast(value.getZExtValue()); } encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss, /*isSigned*/ false); emitError(loc, "cannot serialize ") << bitwidth << "-bit integer literal: " << rss.str(); return 0; } return constIDMap[intAttr] = resultID; } uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec) { if (auto id = findConstantID(floatAttr)) { return id; } // Process the type for this float literal uint32_t typeID = 0; if (failed(processType(loc, floatAttr.getType(), typeID))) { return 0; } auto resultID = getNextID(); APFloat value = floatAttr.getValue(); APInt intValue = value.bitcastToAPInt(); auto opcode = isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; if (&value.getSemantics() == &APFloat::IEEEsingle()) { uint32_t word = llvm::bit_cast(value.convertToFloat()); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { struct DoubleWord { uint32_t word1; uint32_t word2; } words = llvm::bit_cast(value.convertToDouble()); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { uint32_t word = static_cast(value.bitcastToAPInt().getZExtValue()); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); value.print(rss); emitError(loc, "cannot serialize ") << floatAttr.getType() << "-typed float literal: " << rss.str(); return 0; } return constIDMap[floatAttr] = resultID; } //===----------------------------------------------------------------------===// // Operation //===----------------------------------------------------------------------===// LogicalResult Serializer::processOperation(Operation *op) { // First dispatch the methods that do not directly mirror an operation from // the SPIR-V spec if (auto constOp = dyn_cast(op)) { return processConstantOp(constOp); } if (auto fnOp = dyn_cast(op)) { return processFuncOp(fnOp); } if (isa(op)) { return success(); } return dispatchToAutogenSerialization(op); } namespace { template <> LogicalResult Serializer::processOp(spirv::EntryPointOp op) { SmallVector operands; // Add the ExectionModel. operands.push_back(static_cast(op.execution_model())); // Add the function . auto funcID = findFunctionID(op.fn()); if (!funcID) { return op.emitError("missing for function ") << op.fn() << "; function needs to be defined before spv.EntryPoint is " "serialized"; } operands.push_back(funcID); // Add the name of the function. encodeStringLiteralInto(operands, op.fn()); // Add the interface values. for (auto val : op.interface()) { auto id = findValueID(val); if (!id) { return op.emitError("referencing unintialized variable . " "spv.EntryPoint is at the end of spv.module. All " "referenced variables should already be defined"); } operands.push_back(id); } return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands); } template <> LogicalResult Serializer::processOp(spirv::ExecutionModeOp op) { SmallVector operands; // Add the function . auto funcID = findFunctionID(op.fn()); if (!funcID) { return op.emitError("missing for function ") << op.fn() << "; function needs to be serialized before ExecutionModeOp is " "serialized"; } operands.push_back(funcID); // Add the ExecutionMode. operands.push_back(static_cast(op.execution_mode())); // Serialize values if any. auto values = op.values(); if (values) { for (auto &intVal : values.getValue()) { operands.push_back(static_cast( intVal.cast().getValue().getZExtValue())); } } return encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode, operands); } // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. #define GET_SERIALIZATION_FNS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // namespace LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl &binary) { Serializer serializer(module); if (failed(serializer.serialize())) return failure(); serializer.collect(binary); return success(); }