Non-32-bit scalar types require special hardware support that may not exist on all Vulkan-capable GPUs. In SPIR-V this is reflected by non-32-bit scalar types requiring special capabilities or extensions before they can be used. This commit makes SPIRVTypeConverter target-environment aware so that it can properly convert standard types to what is accepted in the target environment. Right now, if a scalar type's bitwidth is not supported in the target environment, we use 32-bit unconditionally. This requires the Vulkan runtime to also feed in data with a matching bitwidth and layout, especially for interface types. The Vulkan runtime can do that by inspecting the SPIR-V module. Longer term, we might want to introduce a way to control how such cases are handled and to explicitly fail if desired. Differential Revision: https://reviews.llvm.org/D76244
174 lines
5.4 KiB
C++
174 lines
5.4 KiB
C++
//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/FunctionSupport.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TargetEnv
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the sets of extensions and capabilities available in the given
/// target environment, including those implied by the target's SPIR-V version
/// (for extensions) or by other requested capabilities (for capabilities).
spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
    : targetAttr(targetAttr) {
  // Extensions explicitly listed in the target environment attribute.
  for (spirv::Extension requested : targetAttr.getExtensions())
    givenExtensions.insert(requested);

  // Extensions that become available implicitly with the target SPIR-V version.
  spirv::Version targetVersion = targetAttr.getVersion();
  for (spirv::Extension implied : spirv::getImpliedExtensions(targetVersion))
    givenExtensions.insert(implied);

  // Each explicitly listed capability, plus every capability that
  // getRecursiveImpliedCapabilities reports as implied by it.
  for (spirv::Capability requested : targetAttr.getCapabilities()) {
    givenCapabilities.insert(requested);
    for (spirv::Capability implied :
         spirv::getRecursiveImpliedCapabilities(requested))
      givenCapabilities.insert(implied);
  }
}
|
|
|
|
/// Returns the SPIR-V version recorded in the target environment attribute.
// NOTE(review): this could be `const` like the other accessors, but the
// header declaration would have to change in lockstep.
spirv::Version spirv::TargetEnv::getVersion() {
  spirv::Version version = targetAttr.getVersion();
  return version;
}
|
|
|
|
/// Returns true if `capability` is available in this target environment,
/// either requested directly or implied by a requested capability.
bool spirv::TargetEnv::allows(spirv::Capability capability) const {
  return givenCapabilities.count(capability) != 0;
}
|
|
|
|
/// Returns the first capability in `caps` that this target environment
/// allows, or llvm::None if none of them is available.
Optional<spirv::Capability>
spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
  for (spirv::Capability candidate : caps) {
    if (givenCapabilities.count(candidate))
      return candidate;
  }
  return llvm::None;
}
|
|
|
|
/// Returns true if `extension` is available in this target environment,
/// either requested directly or implied by the target SPIR-V version.
bool spirv::TargetEnv::allows(spirv::Extension extension) const {
  return givenExtensions.count(extension) != 0;
}
|
|
|
|
/// Returns the first extension in `exts` that this target environment
/// allows, or llvm::None if none of them is available.
Optional<spirv::Extension>
spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
  for (spirv::Extension candidate : exts) {
    if (givenExtensions.count(candidate))
      return candidate;
  }
  return llvm::None;
}
|
|
|
|
/// Returns the MLIR context that owns the underlying target environment
/// attribute.
MLIRContext *spirv::TargetEnv::getContext() const {
  MLIRContext *ctx = targetAttr.getContext();
  return ctx;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the name of the attribute that carries interface variable ABI
/// information.
StringRef spirv::getInterfaceVarABIAttrName() { return "spv.interface_var_abi"; }
|
|
|
|
/// Builds an interface variable ABI attribute from a descriptor set, a
/// binding, and an optional storage class. All values are stored as i32
/// integer attributes; when `storageClass` is absent the storage class
/// attribute is left null.
spirv::InterfaceVarABIAttr
spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
                              Optional<spirv::StorageClass> storageClass,
                              MLIRContext *context) {
  Type i32 = IntegerType::get(32, context);

  IntegerAttr storageClassAttr;
  if (storageClass)
    storageClassAttr =
        IntegerAttr::get(i32, static_cast<int64_t>(*storageClass));

  IntegerAttr descriptorSetAttr = IntegerAttr::get(i32, descriptorSet);
  IntegerAttr bindingAttr = IntegerAttr::get(i32, binding);
  return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
                                         storageClassAttr, context);
}
|
|
|
|
/// Returns the name of the attribute that carries entry point ABI information.
StringRef spirv::getEntryPointABIAttrName() {
  return "spv.entry_point_abi";
}
|
|
|
|
/// Builds an entry point ABI attribute whose local workgroup size is the
/// given three-element list, stored as a dense vector<3xi32> attribute.
///
/// `localSize` must contain exactly three elements (x, y, z); this is
/// checked with an assertion.
spirv::EntryPointABIAttr
spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
  assert(localSize.size() == 3 &&
         "entry point local size must have exactly three dimensions");
  auto localSizeType = VectorType::get(3, IntegerType::get(32, context));
  auto localSizeAttr = DenseElementsAttr::get<int32_t>(localSizeType, localSize)
                           .cast<DenseIntElementsAttr>();
  return spirv::EntryPointABIAttr::get(localSizeAttr, context);
}
|
|
|
|
/// Returns the entry point ABI attribute attached to the nearest
/// function-like op enclosing (or equal to) `op`. Returns a null attribute
/// when there is no such op or it carries no entry point ABI attribute.
spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
  for (Operation *current = op; current; current = current->getParentOp()) {
    if (!current->hasTrait<OpTrait::FunctionLike>())
      continue;
    // Only the innermost function-like op is consulted.
    return current->getAttrOfType<spirv::EntryPointABIAttr>(
        spirv::getEntryPointABIAttrName());
  }
  return {};
}
|
|
|
|
/// Returns the local workgroup size recorded in the entry point ABI attribute
/// found by lookupEntryPointABI(op), or a null attribute if there is none.
DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) {
  spirv::EntryPointABIAttr entryPointABI = spirv::lookupEntryPointABI(op);
  return entryPointABI ? entryPointABI.local_size() : DenseIntElementsAttr();
}
|
|
|
|
/// Returns resource limits that every Vulkan implementation is guaranteed to
/// support: at most 128 compute workgroup invocations and a maximum compute
/// workgroup size of (128, 128, 64).
spirv::ResourceLimitsAttr
spirv::getDefaultResourceLimits(MLIRContext *context) {
  // These numbers are from "Table 46. Required Limits" of the Vulkan spec.
  auto i32 = IntegerType::get(32, context);
  auto vec3OfI32 = VectorType::get(3, i32);

  auto maxInvocations = IntegerAttr::get(i32, 128);
  auto maxWorkGroupSize =
      DenseIntElementsAttr::get<int32_t>(vec3OfI32, {128, 128, 64});
  return spirv::ResourceLimitsAttr::get(maxInvocations, maxWorkGroupSize,
                                        context);
}
|
|
|
|
/// Returns the name of the attribute that carries the SPIR-V target
/// environment.
StringRef spirv::getTargetEnvAttrName() {
  return "spv.target_env";
}
|
|
|
|
/// Returns the fallback target environment: SPIR-V 1.0 with only the Shader
/// capability, no extensions, and the default resource limits.
spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
  auto versionCapsExts = spirv::VerCapExtAttr::get(
      spirv::Version::V_1_0, {spirv::Capability::Shader},
      ArrayRef<spirv::Extension>(), context);
  auto limits = spirv::getDefaultResourceLimits(context);
  return spirv::TargetEnvAttr::get(versionCapsExts, limits);
}
|
|
|
|
/// Searches the symbol tables enclosing `op` (starting with the nearest one,
/// which may be `op` itself) for a target environment attribute and returns
/// the first one found. Returns a null attribute if `op` is null or no
/// enclosing symbol table carries one.
spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
  // getNearestSymbolTable must not be given a null op, hence the guards.
  Operation *symbolTable =
      op ? SymbolTable::getNearestSymbolTable(op) : nullptr;
  while (symbolTable) {
    auto envAttr = symbolTable->getAttrOfType<spirv::TargetEnvAttr>(
        spirv::getTargetEnvAttrName());
    if (envAttr)
      return envAttr;

    Operation *parent = symbolTable->getParentOp();
    symbolTable = parent ? SymbolTable::getNearestSymbolTable(parent) : nullptr;
  }
  return {};
}
|
|
|
|
/// Returns the target environment attribute found by lookupTargetEnv(op), or
/// the default target environment (built in `op`'s context) if none is found.
///
/// `op` must be non-null. lookupTargetEnv itself tolerates a null op, but the
/// default has to be created in the op's context, so a null op would be
/// dereferenced below.
spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
  assert(op && "expected a valid operation to look up the target env for");

  if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
    return attr;

  return spirv::getDefaultTargetEnv(op->getContext());
}
|