Files
clang-p2996/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp
Finn Plummer 9e0186d925 [HLSL][RootSignature] Implement ResourceRange as an IntervalMap (#140957)
A resource range consists of a closed interval, `[a;b]`, denoting which
shader registers it is bound to.

For instance:
 - `CBV(b1)`  corresponds to the resource range of `[1;1]`
 - `CBV(b0, numDescriptors = 3)` likewise to `[0;2]`

We want to provide an error diagnostic when there is an overlap in the
required registers (an overlap in the resource ranges).

The goal of this PR is to implement a structure to model a set of
resource ranges and provide an API to detect any overlap over a set of
resource ranges.

`ResourceRange` models this by implementing an `IntervalMap` to denote a
mapping from an interval of registers back to a resource range. It
allows for a new `ResourceRange` to be added to the mapping and it will
report if and what the first overlap is.

For the context of how this will be used in validation of a
`RootSignatureDecl`, please see the follow-up pull request here:
https://github.com/llvm/llvm-project/pull/140962.

- Implements `ResourceRange` as an `IntervalMap`
- Adds unit testing of the various `insert` scenarios

Note: it was also considered to implement this as an `IntervalTree`,
this would allow reporting of a diagnostic for each overlap that is
encountered, as opposed to just the first. However, error generation of
just reporting the first error is already rather verbose, and adding the
additional diagnostics only made this worse.

Part 1 of https://github.com/llvm/llvm-project/issues/129942
2025-06-17 10:24:57 -07:00

422 lines
15 KiB
C++

//===- HLSLRootSignatureUtils.cpp - HLSL Root Signature helpers -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helpers for working with HLSL Root Signatures.
///
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/HLSLRootSignatureUtils.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/bit.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/ScopedPrinter.h"
namespace llvm {
namespace hlsl {
namespace rootsig {
template <typename T>
static std::optional<StringRef> getEnumName(const T Value,
ArrayRef<EnumEntry<T>> Enums) {
for (const auto &EnumItem : Enums)
if (EnumItem.Value == Value)
return EnumItem.Name;
return std::nullopt;
}
/// Writes the name registered for \p Value in \p Enums to \p OS.
///
/// Nothing is written when \p Value has no entry in \p Enums.
/// \returns \p OS, to allow chaining.
template <typename T>
static raw_ostream &printEnum(raw_ostream &OS, const T Value,
                              ArrayRef<EnumEntry<T>> Enums) {
  if (const std::optional<StringRef> Name = getEnumName(Value, Enums))
    OS << *Name;
  return OS;
}
/// Writes the individual flags set in \p Value to \p OS, separated by " | ".
///
/// Bits are visited from least to most significant. A set bit that has no
/// entry in \p Flags is written as "invalid: <bit value>". When no bit is set
/// at all, "None" is written.
/// \returns \p OS, to allow chaining.
template <typename T>
static raw_ostream &printFlags(raw_ostream &OS, const T Value,
                               ArrayRef<EnumEntry<T>> Flags) {
  unsigned Pending = llvm::to_underlying(Value);
  if (Pending == 0) {
    OS << "None";
    return OS;
  }

  const char *Separator = "";
  do {
    // Isolate the lowest bit that is still set.
    const unsigned Bit = 1u << llvm::countr_zero(Pending);
    OS << Separator;
    if (const std::optional<StringRef> Name = getEnumName(T(Bit), Flags))
      OS << *Name;
    else
      OS << "invalid: " << Bit;
    Separator = " | ";
    Pending &= ~Bit;
  } while (Pending != 0);
  return OS;
}
/// Name table pairing each RegisterType with the single-letter prefix that is
/// printed in front of the register number (see operator<< for Register).
static const EnumEntry<RegisterType> RegisterNames[] = {
{"b", RegisterType::BReg},
{"t", RegisterType::TReg},
{"u", RegisterType::UReg},
{"s", RegisterType::SReg},
};
/// Prints \p Reg as its register-type prefix immediately followed by its
/// number. The prefix is omitted if the type has no entry in RegisterNames.
static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
  return printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)) << Reg.Number;
}
/// Name table pairing each ShaderVisibility value with the text printed for
/// it (see operator<< for ShaderVisibility).
static const EnumEntry<ShaderVisibility> VisibilityNames[] = {
{"All", ShaderVisibility::All},
{"Vertex", ShaderVisibility::Vertex},
{"Hull", ShaderVisibility::Hull},
{"Domain", ShaderVisibility::Domain},
{"Geometry", ShaderVisibility::Geometry},
{"Pixel", ShaderVisibility::Pixel},
{"Amplification", ShaderVisibility::Amplification},
{"Mesh", ShaderVisibility::Mesh},
};
/// Prints the name of \p Visibility as listed in VisibilityNames; prints
/// nothing if the value has no entry there.
static raw_ostream &operator<<(raw_ostream &OS,
                               const ShaderVisibility &Visibility) {
  return printEnum(OS, Visibility, ArrayRef(VisibilityNames));
}
/// Name table pairing each dxil::ResourceClass with its short textual name.
/// Used both when printing a clause type and when building the metadata
/// strings for root descriptors and descriptor table clauses.
static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
{"CBV", dxil::ResourceClass::CBuffer},
{"SRV", dxil::ResourceClass::SRV},
{"UAV", dxil::ResourceClass::UAV},
{"Sampler", dxil::ResourceClass::Sampler},
};
/// Prints \p Type using the name of the dxil::ResourceClass that shares its
/// underlying value; prints nothing if that value is not in
/// ResourceClassNames.
static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
  const auto Class =
      static_cast<dxil::ResourceClass>(llvm::to_underlying(Type));
  return printEnum(OS, Class, ArrayRef(ResourceClassNames));
}
/// Name table pairing each individual DescriptorRangeFlags bit with the text
/// printed for it (see operator<< for DescriptorRangeFlags).
static const EnumEntry<DescriptorRangeFlags> DescriptorRangeFlagNames[] = {
{"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile},
{"DataVolatile", DescriptorRangeFlags::DataVolatile},
{"DataStaticWhileSetAtExecute",
DescriptorRangeFlags::DataStaticWhileSetAtExecute},
{"DataStatic", DescriptorRangeFlags::DataStatic},
{"DescriptorsStaticKeepingBufferBoundsChecks",
DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks},
};
/// Prints the flags set in \p Flags as a " | "-separated list of names taken
/// from DescriptorRangeFlagNames, or "None" when no bit is set.
static raw_ostream &operator<<(raw_ostream &OS,
                               const DescriptorRangeFlags &Flags) {
  return printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames));
}
/// Name table pairing each individual RootFlags bit with the text printed for
/// it (see operator<< for RootFlags).
static const EnumEntry<RootFlags> RootFlagNames[] = {
{"AllowInputAssemblerInputLayout",
RootFlags::AllowInputAssemblerInputLayout},
{"DenyVertexShaderRootAccess", RootFlags::DenyVertexShaderRootAccess},
{"DenyHullShaderRootAccess", RootFlags::DenyHullShaderRootAccess},
{"DenyDomainShaderRootAccess", RootFlags::DenyDomainShaderRootAccess},
{"DenyGeometryShaderRootAccess", RootFlags::DenyGeometryShaderRootAccess},
{"DenyPixelShaderRootAccess", RootFlags::DenyPixelShaderRootAccess},
{"AllowStreamOutput", RootFlags::AllowStreamOutput},
{"LocalRootSignature", RootFlags::LocalRootSignature},
{"DenyAmplificationShaderRootAccess",
RootFlags::DenyAmplificationShaderRootAccess},
{"DenyMeshShaderRootAccess", RootFlags::DenyMeshShaderRootAccess},
{"CBVSRVUAVHeapDirectlyIndexed", RootFlags::CBVSRVUAVHeapDirectlyIndexed},
{"SamplerHeapDirectlyIndexed", RootFlags::SamplerHeapDirectlyIndexed},
};
/// Prints \p Flags as "RootFlags(<list>)", where <list> is the
/// " | "-separated names of the set flags, or "None" when no bit is set.
raw_ostream &operator<<(raw_ostream &OS, const RootFlags &Flags) {
  OS << "RootFlags(";
  return printFlags(OS, Flags, ArrayRef(RootFlagNames)) << ")";
}
/// Prints \p Constants as
/// "RootConstants(num32BitConstants = N, <reg>, space = S, visibility = V)".
raw_ostream &operator<<(raw_ostream &OS, const RootConstants &Constants) {
  OS << "RootConstants(num32BitConstants = " << Constants.Num32BitConstants;
  OS << ", " << Constants.Reg;
  OS << ", space = " << Constants.Space;
  OS << ", visibility = " << Constants.Visibility;
  OS << ")";
  return OS;
}
/// Prints \p Table as "DescriptorTable(numClauses = N, visibility = V)".
raw_ostream &operator<<(raw_ostream &OS, const DescriptorTable &Table) {
  OS << "DescriptorTable(numClauses = " << Table.NumClauses;
  OS << ", visibility = " << Table.Visibility;
  OS << ")";
  return OS;
}
/// Prints \p Clause as
/// "<type>(<reg>, numDescriptors = N, space = S, offset = O, flags = F)".
///
/// An offset equal to DescriptorTableOffsetAppend is printed symbolically as
/// "DescriptorTableOffsetAppend" rather than as a number.
raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
  OS << Clause.Type << "(" << Clause.Reg;
  OS << ", numDescriptors = " << Clause.NumDescriptors;
  OS << ", space = " << Clause.Space;
  OS << ", offset = ";
  if (Clause.Offset != DescriptorTableOffsetAppend)
    OS << Clause.Offset;
  else
    OS << "DescriptorTableOffsetAppend";
  return OS << ", flags = " << Clause.Flags << ")";
}
/// Dumps \p Elements to \p OS as "RootElements{ <e0>, <e1>, ...}".
///
/// Only DescriptorTableClause and DescriptorTable elements are printed; any
/// other kind of element contributes just its separator and no text.
void dumpRootElements(raw_ostream &OS, ArrayRef<RootElement> Elements) {
  OS << "RootElements{";
  const char *Separator = " ";
  for (const RootElement &Element : Elements) {
    OS << Separator;
    if (const auto *Clause = std::get_if<DescriptorTableClause>(&Element))
      OS << *Clause;
    else if (const auto *Table = std::get_if<DescriptorTable>(&Element))
      OS << *Table;
    Separator = ", ";
  }
  OS << "}";
}
namespace {

// OverloadedBuild merges a set of callables into a single overload set. We use
// it with std::visit so that the compiler catches the case where a new
// RootElement variant type is added but its metadata generation isn't handled.
template <typename... Fns> struct OverloadedBuild : Fns... {
  using Fns::operator()...;
};

// Deduction guide: lets OverloadedBuild{f, g, ...} infer its template
// arguments from the callables it is initialized with.
template <typename... Fns> OverloadedBuild(Fns...) -> OverloadedBuild<Fns...>;

} // namespace
/// Builds the metadata node for the whole root signature.
///
/// Each element of Elements is converted, in order, by the Build* method that
/// matches its type and the result is appended to GeneratedMetadata.
/// \returns a node whose operands are the contents of GeneratedMetadata.
MDNode *MetadataBuilder::BuildRootSignature() {
  // One overload per element type; the visit dispatches on the alternative
  // that the element currently holds.
  const auto Dispatch = OverloadedBuild{
      [this](const StaticSampler &S) -> MDNode * {
        return BuildStaticSampler(S);
      },
      [this](const DescriptorTable &T) -> MDNode * {
        return BuildDescriptorTable(T);
      },
      [this](const DescriptorTableClause &C) -> MDNode * {
        return BuildDescriptorTableClause(C);
      },
      [this](const RootDescriptor &D) -> MDNode * {
        return BuildRootDescriptor(D);
      },
      [this](const RootConstants &C) -> MDNode * {
        return BuildRootConstants(C);
      },
      [this](const RootFlags &F) -> MDNode * { return BuildRootFlags(F); },
  };

  for (const RootElement &Element : Elements) {
    // The node must be built before it is appended: building a descriptor
    // table consumes entries already present in GeneratedMetadata.
    MDNode *const Node = std::visit(Dispatch, Element);
    assert(Node != nullptr && "Root Element must be initialized and validated");
    GeneratedMetadata.push_back(Node);
  }

  return MDNode::get(Ctx, GeneratedMetadata);
}
/// Builds the metadata node for \p Flags.
///
/// Operands, in order: the string "RootFlags", then the flag bits as an i32.
MDNode *MetadataBuilder::BuildRootFlags(const RootFlags &Flags) {
  IRBuilder<> Builder(Ctx);
  const auto I32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };

  Metadata *Operands[] = {
      MDString::get(Ctx, "RootFlags"),
      I32(llvm::to_underlying(Flags)),
  };
  return MDNode::get(Ctx, Operands);
}
/// Builds the metadata node for \p Constants.
///
/// Operands, in order: the string "RootConstants", then as i32 values the
/// shader visibility, register number, register space and number of 32-bit
/// constants.
MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
  IRBuilder<> Builder(Ctx);
  const auto I32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };

  Metadata *Operands[] = {
      MDString::get(Ctx, "RootConstants"),
      I32(llvm::to_underlying(Constants.Visibility)),
      I32(Constants.Reg.Number),
      I32(Constants.Space),
      I32(Constants.Num32BitConstants),
  };
  return MDNode::get(Ctx, Operands);
}
/// Builds the metadata node for \p Descriptor.
///
/// Operands, in order: the string "Root" followed by the resource class name
/// of the descriptor's type, then as i32 values the shader visibility,
/// register number, register space and descriptor flags.
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
  IRBuilder<> Builder(Ctx);
  const auto I32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };

  const auto Class =
      static_cast<dxil::ResourceClass>(llvm::to_underlying(Descriptor.Type));
  const std::optional<StringRef> TypeName =
      getEnumName(Class, ArrayRef(ResourceClassNames));
  assert(TypeName && "Provided an invalid Resource Class");

  llvm::SmallString<7> Name({"Root", *TypeName});
  Metadata *Operands[] = {
      MDString::get(Ctx, Name),
      I32(llvm::to_underlying(Descriptor.Visibility)),
      I32(Descriptor.Reg.Number),
      I32(Descriptor.Space),
      I32(llvm::to_underlying(Descriptor.Flags)),
  };
  return MDNode::get(Ctx, Operands);
}
/// Builds the metadata node for \p Table.
///
/// Operands, in order: the string "DescriptorTable", the shader visibility as
/// an i32, then the last Table.NumClauses nodes of GeneratedMetadata. Those
/// nodes are removed from GeneratedMetadata, so they end up referenced only
/// through the returned table node.
MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
  IRBuilder<> Builder(Ctx);

  // Mandatory leading operands.
  Metadata *const Tag = MDString::get(Ctx, "DescriptorTable");
  Metadata *const Visibility = ConstantAsMetadata::get(
      Builder.getInt32(llvm::to_underlying(Table.Visibility)));

  SmallVector<Metadata *> TableOperands;
  TableOperands.push_back(Tag);
  TableOperands.push_back(Visibility);

  // The in-memory representation of the Root Elements created from parsing
  // ensures that the previous N elements are the clauses of this table, so
  // their nodes sit at the tail of GeneratedMetadata.
  assert(Table.NumClauses <= GeneratedMetadata.size() &&
         "Table expected all owned clauses to be generated already");
  const auto ClausesEnd = GeneratedMetadata.end();
  const auto ClausesBegin = ClausesEnd - Table.NumClauses;

  // Reference each clause from the table, then drop the clauses from the
  // general list of Root Elements.
  TableOperands.append(ClausesBegin, ClausesEnd);
  GeneratedMetadata.pop_back_n(Table.NumClauses);

  return MDNode::get(Ctx, TableOperands);
}
/// Builds the metadata node for \p Clause.
///
/// Operands, in order: the resource class name of the clause's type, then as
/// i32 values the number of descriptors, register number, register space,
/// offset and clause flags.
MDNode *MetadataBuilder::BuildDescriptorTableClause(
    const DescriptorTableClause &Clause) {
  IRBuilder<> Builder(Ctx);
  const auto I32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };

  const auto Class =
      static_cast<dxil::ResourceClass>(llvm::to_underlying(Clause.Type));
  const std::optional<StringRef> Name =
      getEnumName(Class, ArrayRef(ResourceClassNames));
  assert(Name && "Provided an invalid Resource Class");

  Metadata *Operands[] = {
      MDString::get(Ctx, *Name),
      I32(Clause.NumDescriptors),
      I32(Clause.Reg.Number),
      I32(Clause.Space),
      I32(Clause.Offset),
      I32(llvm::to_underlying(Clause.Flags)),
  };
  return MDNode::get(Ctx, Operands);
}
/// Builds the metadata node for \p Sampler.
///
/// Operands, in order: the string "StaticSampler", filter, address U/V/W,
/// mip LOD bias, max anisotropy, comparison function, border color, min LOD,
/// max LOD, register number, register space and shader visibility. The LOD
/// bias and LOD limits are emitted as float constants; every other value is
/// emitted as an i32.
MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
  IRBuilder<> Builder(Ctx);
  llvm::Type *const FloatTy = llvm::Type::getFloatTy(Ctx);

  const auto I32 = [&Builder](uint32_t V) -> Metadata * {
    return ConstantAsMetadata::get(Builder.getInt32(V));
  };
  const auto F32 = [FloatTy](double V) -> Metadata * {
    return ConstantAsMetadata::get(llvm::ConstantFP::get(FloatTy, V));
  };

  Metadata *Operands[] = {
      MDString::get(Ctx, "StaticSampler"),
      I32(llvm::to_underlying(Sampler.Filter)),
      I32(llvm::to_underlying(Sampler.AddressU)),
      I32(llvm::to_underlying(Sampler.AddressV)),
      I32(llvm::to_underlying(Sampler.AddressW)),
      F32(Sampler.MipLODBias),
      I32(Sampler.MaxAnisotropy),
      I32(llvm::to_underlying(Sampler.CompFunc)),
      I32(llvm::to_underlying(Sampler.BorderColor)),
      F32(Sampler.MinLOD),
      F32(Sampler.MaxLOD),
      I32(Sampler.Reg.Number),
      I32(Sampler.Space),
      I32(llvm::to_underlying(Sampler.Visibility)),
  };
  return MDNode::get(Ctx, Operands);
}
std::optional<const RangeInfo *>
ResourceRange::getOverlapping(const RangeInfo &Info) const {
MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
if (!Interval.valid() || Info.UpperBound < Interval.start())
return std::nullopt;
return Interval.value();
}
/// Returns the RangeInfo of the stored interval that contains \p X, or
/// nullptr when no stored interval contains it.
const RangeInfo *ResourceRange::lookup(uint32_t X) const {
  const RangeInfo *const NotFound = nullptr;
  return Intervals.lookup(X, NotFound);
}
/// Inserts the closed interval [Info.LowerBound;Info.UpperBound] into the map,
/// mapped to the address of \p Info.
///
/// Existing intervals that are entirely covered by the new one are erased.
/// Where an existing interval only partially overlaps the new one, the
/// existing interval is kept and the new one is trimmed to the part that is
/// not yet covered. If the new interval is entirely covered by an existing
/// one, nothing is inserted.
///
/// NOTE(review): the map stores &Info, so the caller presumably has to keep
/// \p Info alive for as long as the map may hand that pointer back -- confirm.
///
/// \returns the RangeInfo of the first existing interval found to overlap
/// \p Info, or std::nullopt if there was no overlap.
std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
// [LowerBound;UpperBound] is the part of the new interval that still has to
// be placed; it shrinks as overlaps with existing intervals are resolved.
uint32_t LowerBound = Info.LowerBound;
uint32_t UpperBound = Info.UpperBound;
std::optional<const RangeInfo *> Res = std::nullopt;
MapT::iterator Interval = Intervals.begin();
while (true) {
// Nothing left to place. The trimming below never empties the interval, so
// this is only reached when Info itself has UpperBound < LowerBound, in
// which case the assert after the loop fires.
if (UpperBound < LowerBound)
break;
Interval.advanceTo(LowerBound);
if (!Interval.valid()) // No interval found
break;
// Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
// a <= y implicitly from Interval.advanceTo(LowerBound)
if (UpperBound < Interval.start())
break; // found interval does not overlap with inserted one
if (!Res.has_value()) // Update to be the first found intersection
Res = Interval.value();
if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
// x <= a <= b <= y implies that [a;b] is covered by [x;y]
// -> so we don't need to insert this, report an overlap
return Res;
} else if (LowerBound <= Interval.start() &&
Interval.stop() <= UpperBound) {
// a <= x <= y <= b implies that [x;y] is covered by [a;b]
// -> so remove the existing interval that we will cover with the
// overwrite
Interval.erase();
} else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
// a < x <= b <= y implies that [a;x-1] is not covered but [x;b] is
// -> so set b = x - 1 such that [a;x-1] is now the interval to insert
// (x - 1 cannot wrap around because a < x)
UpperBound = Interval.start() - 1;
} else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
// x <= a <= y < b implies that [y+1;b] is not covered but [a;y] is
// -> so set a = y + 1 such that [y+1;b] is now the interval to insert
// (y + 1 cannot wrap around because y < b)
LowerBound = Interval.stop() + 1;
}
}
assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
// Whatever remains of the new interval no longer overlaps any stored one.
Intervals.insert(LowerBound, UpperBound, &Info);
return Res;
}
} // namespace rootsig
} // namespace hlsl
} // namespace llvm