[DirectX] Implement typedBufferLoad_checkbit (#108087)

This represents a typedBufferLoad that's followed by
"CheckAccessFullyMapped". It returns an extra `i1` representing that
value.

Fixes #108085
This commit is contained in:
Justin Bogner
2024-09-11 16:24:38 -07:00
committed by GitHub
parent 93e45a69dd
commit 34e20f18f0
5 changed files with 96 additions and 11 deletions

View File

@@ -361,6 +361,12 @@ Examples:
- ``i32``
- Index into the buffer
.. code-block:: llvm
%ret = call {<4 x float>, i1}
@llvm.dx.typedBufferLoad.checkbit.v4f32.tdx.TypedBuffer_v4f32_0_0_0t(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 %index)
Texture and Typed Buffer Stores
-------------------------------

View File

@@ -32,6 +32,9 @@ def int_dx_handle_fromBinding
def int_dx_typedBufferLoad
: DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty]>;
def int_dx_typedBufferLoad_checkbit
: DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
[llvm_any_ty, llvm_i32_ty]>;
def int_dx_typedBufferStore
: DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty]>;

View File

@@ -719,6 +719,15 @@ def BufferStore : DXILOp<69, bufferStore> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
def CheckAccessFullyMapped : DXILOp<71, checkAccessFullyMapped> {
let Doc = "checks whether a Sample, Gather, or Load operation "
"accessed mapped tiles in a tiled resource";
let arguments = [OverloadTy];
let result = Int1Ty;
let overloads = [Overloads<DXIL1_0, [Int32Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
}
def ThreadId : DXILOp<93, threadId> {
let Doc = "Reads the thread ID";
let LLVMIntrinsic = int_dx_thread_id;

View File

@@ -265,16 +265,50 @@ public:
/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
Error replaceResRetUses(CallInst *Intrin, CallInst *Op) {
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Instruction *OldResult = Intrin;
Type *OldTy = Intrin->getType();
if (HasCheckBit) {
auto *ST = cast<StructType>(OldTy);
Value *CheckOp = nullptr;
Type *Int32Ty = IRB.getInt32Ty();
for (Use &U : make_early_inc_range(OldResult->uses())) {
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
ArrayRef<unsigned> Indices = EVI->getIndices();
assert(Indices.size() == 1);
// We're only interested in uses of the check bit for now.
if (Indices[0] != 1)
continue;
if (!CheckOp) {
Value *NewEVI = IRB.CreateExtractValue(Op, 4);
Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
OpCode::CheckAccessFullyMapped, {NewEVI}, Int32Ty);
if (Error E = OpCall.takeError())
return E;
CheckOp = *OpCall;
}
EVI->replaceAllUsesWith(CheckOp);
EVI->eraseFromParent();
}
}
OldResult = cast<Instruction>(IRB.CreateExtractValue(Op, 0));
OldTy = ST->getElementType(0);
}
// For scalars, we just extract the first element.
if (!isa<FixedVectorType>(OldTy)) {
Value *EVI = IRB.CreateExtractValue(Op, 0);
Intrin->replaceAllUsesWith(EVI);
Intrin->eraseFromParent();
OldResult->replaceAllUsesWith(EVI);
OldResult->eraseFromParent();
if (OldResult != Intrin) {
assert(Intrin->use_empty() && "Intrinsic still has uses?");
Intrin->eraseFromParent();
}
return Error::success();
}
@@ -283,7 +317,7 @@ public:
// The users of the operation should all be scalarized, so we attempt to
// replace the extractelements with extractvalues directly.
for (Use &U : make_early_inc_range(Intrin->uses())) {
for (Use &U : make_early_inc_range(OldResult->uses())) {
if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
size_t IndexVal = IndexOp->getZExtValue();
@@ -331,7 +365,7 @@ public:
// If we still have uses, then we're not fully scalarized and need to
// recreate the vector. This should only happen for things like exported
// functions from libraries.
if (!Intrin->use_empty()) {
if (!OldResult->use_empty()) {
for (int I = 0, E = N; I != E; ++I)
if (!Extracts[I])
Extracts[I] = IRB.CreateExtractValue(Op, I);
@@ -339,14 +373,19 @@ public:
Value *Vec = UndefValue::get(OldTy);
for (int I = 0, E = N; I != E; ++I)
Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
Intrin->replaceAllUsesWith(Vec);
OldResult->replaceAllUsesWith(Vec);
}
OldResult->eraseFromParent();
if (OldResult != Intrin) {
assert(Intrin->use_empty() && "Intrinsic still has uses?");
Intrin->eraseFromParent();
}
Intrin->eraseFromParent();
return Error::success();
}
[[nodiscard]] bool lowerTypedBufferLoad(Function &F) {
[[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();
@@ -358,14 +397,17 @@ public:
Value *Index0 = CI->getArgOperand(1);
Value *Index1 = UndefValue::get(Int32Ty);
Type *NewRetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
Type *OldTy = CI->getType();
if (HasCheckBit)
OldTy = cast<StructType>(OldTy)->getElementType(0);
Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
std::array<Value *, 3> Args{Handle, Index0, Index1};
Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, NewRetTy);
if (Error E = OpCall.takeError())
return E;
if (Error E = replaceResRetUses(CI, *OpCall))
if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
return E;
return Error::success();
@@ -434,7 +476,10 @@ public:
HasErrors |= lowerHandleFromBinding(F);
break;
case Intrinsic::dx_typedBufferLoad:
HasErrors |= lowerTypedBufferLoad(F);
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
break;
case Intrinsic::dx_typedBufferLoad_checkbit:
HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
break;
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);

View File

@@ -4,6 +4,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"
declare void @scalar_user(float)
declare void @vector_user(<4 x float>)
declare void @check_user(i1)
define void @loadv4f32() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
@@ -128,6 +129,27 @@ define void @loadv2f32() {
ret void
}
define void @loadv4f32_checkbit() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
%buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
i32 0, i32 0, i32 1, i32 0, i1 false)
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call {<4 x float>, i1} @llvm.dx.typedBufferLoad.checkbit.f32(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
; CHECK: [[STATUS:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 4
; CHECK: [[MAPPED:%.*]] = call i1 @dx.op.checkAccessFullyMapped.i32(i32 71, i32 [[STATUS]]
%check = extractvalue {<4 x float>, i1} %data0, 1
; CHECK: call void @check_user(i1 [[MAPPED]])
call void @check_user(i1 %check)
ret void
}
define void @loadv4i32() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]