[DirectX] replace byte splitting via vector bitcast with scalar (#140167)

instructions
- Instead of bitcasting and using extractelement, let's use trunc (for the low
half) or a logical shift right followed by trunc (for the high half) to split.
- fixes #139020
This commit is contained in:
Farzon Lotfi
2025-06-04 21:28:43 -04:00
committed by GitHub
parent 9cd53787df
commit c1e0faecfc
2 changed files with 103 additions and 18 deletions

View File

@@ -8,6 +8,8 @@
#include "DXILLegalizePass.h"
#include "DirectX.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
@@ -510,6 +512,55 @@ static void updateFnegToFsub(Instruction &I,
ToRemove.push_back(&I);
}
/// Rewrite the "split an i64 into two i32 halves" idiom
///
///   %v  = bitcast i64 %x to <2 x i32>
///   %lo = extractelement <2 x i32> %v, i32 0
///   %hi = extractelement <2 x i32> %v, i32 1
///
/// into scalar instructions: a `trunc` for element 0 and an `lshr` by 32
/// followed by a `trunc` for element 1.
///
/// \param I              The instruction currently being visited.
/// \param ToRemove       Receives the instructions that were rewritten here and
///                       are expected to be erased by the caller.
/// \param ReplacedValues Maps a rewritten bitcast to its i64 source operand and
///                       a rewritten extractelement to its scalar replacement.
static void
legalizeGetHighLowi64Bytes(Instruction &I,
                           SmallVectorImpl<Instruction *> &ToRemove,
                           DenseMap<Value *, Value *> &ReplacedValues) {
  Type *Int32Ty = Type::getInt32Ty(I.getContext());

  if (auto *BitCast = dyn_cast<BitCastInst>(&I)) {
    if (BitCast->getDestTy() != FixedVectorType::get(Int32Ty, 2) ||
        !BitCast->getSrcTy()->isIntegerTy(64))
      return;

    // Only schedule the bitcast for removal when every user is an
    // extractelement with a constant index, i.e. something the branch below
    // rewrites. Erasing it while any other user is still alive would leave a
    // dangling use.
    for (User *U : BitCast->users()) {
      auto *UserExtract = dyn_cast<ExtractElementInst>(U);
      if (!UserExtract || !isa<ConstantInt>(UserExtract->getIndexOperand()))
        return;
    }

    ToRemove.push_back(BitCast);
    ReplacedValues[BitCast] = BitCast->getOperand(0);
    return;
  }

  auto *Extract = dyn_cast<ExtractElementInst>(&I);
  if (!Extract)
    return;

  auto *SrcCast = dyn_cast<BitCastInst>(Extract->getVectorOperand());
  if (!SrcCast)
    return;

  auto *VecTy = dyn_cast<FixedVectorType>(Extract->getVectorOperandType());
  if (!VecTy || !VecTy->getElementType()->isIntegerTy(32) ||
      VecTy->getNumElements() != 2)
    return;

  auto *Index = dyn_cast<ConstantInt>(Extract->getIndexOperand());
  if (!Index)
    return;

  // Use lookup() rather than operator[] so that a miss does not insert a null
  // entry into the shared map. A miss is legitimate: the bitcast may come from
  // a non-i64 source (e.g. double or <4 x i16>), which is not this idiom, or
  // it may simply not have been visited yet; in the latter case read the i64
  // source straight from the bitcast.
  Value *Wide = ReplacedValues.lookup(SrcCast);
  if (!Wide) {
    if (!SrcCast->getSrcTy()->isIntegerTy(64))
      return;
    Wide = SrcCast->getOperand(0);
  }

  IRBuilder<> Builder(Extract);
  Value *Half = nullptr;
  if (Index->isZero()) {
    // Element 0: the low 32 bits.
    Half = Builder.CreateTrunc(Wide, Int32Ty);
  } else if (Index->isOne()) {
    // Element 1: shift the high 32 bits down, then truncate.
    Value *Shifted =
        Builder.CreateLShr(Wide, ConstantInt::get(Wide->getType(), 32));
    Half = Builder.CreateTrunc(Shifted, Int32Ty);
  } else {
    // A constant index past the end of a fixed-length vector yields poison;
    // this is valid IR, so it must not trip an assertion.
    Half = PoisonValue::get(Int32Ty);
  }

  ReplacedValues[Extract] = Half;
  ToRemove.push_back(Extract);
  Extract->replaceAllUsesWith(Half);
}
namespace {
class DXILLegalizationPipeline {
@@ -517,33 +568,49 @@ public:
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
bool runLegalizationPipeline(Function &F) {
bool MadeChange = false;
SmallVector<Instruction *> ToRemove;
DenseMap<Value *, Value *> ReplacedValues;
for (auto &I : instructions(F)) {
for (auto &LegalizationFn : LegalizationPipeline)
LegalizationFn(I, ToRemove, ReplacedValues);
for (int Stage = 0; Stage < NumStages; ++Stage) {
ToRemove.clear();
ReplacedValues.clear();
for (auto &I : instructions(F)) {
for (auto &LegalizationFn : LegalizationPipeline[Stage])
LegalizationFn(I, ToRemove, ReplacedValues);
}
for (auto *Inst : reverse(ToRemove))
Inst->eraseFromParent();
MadeChange |= !ToRemove.empty();
}
for (auto *Inst : reverse(ToRemove))
Inst->eraseFromParent();
return !ToRemove.empty();
return MadeChange;
}
private:
SmallVector<
enum LegalizationStage { Stage1 = 0, Stage2 = 1, NumStages };
using LegalizationFnTy =
std::function<void(Instruction &, SmallVectorImpl<Instruction *> &,
DenseMap<Value *, Value *> &)>>
LegalizationPipeline;
DenseMap<Value *, Value *> &)>;
SmallVector<LegalizationFnTy> LegalizationPipeline[NumStages];
void initializeLegalizationPipeline() {
LegalizationPipeline.push_back(upcastI8AllocasAndUses);
LegalizationPipeline.push_back(fixI8UseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
LegalizationPipeline.push_back(legalizeFreeze);
LegalizationPipeline.push_back(legalizeMemCpy);
LegalizationPipeline.push_back(removeMemSet);
LegalizationPipeline.push_back(updateFnegToFsub);
LegalizationPipeline[Stage1].push_back(upcastI8AllocasAndUses);
LegalizationPipeline[Stage1].push_back(fixI8UseChain);
LegalizationPipeline[Stage1].push_back(legalizeGetHighLowi64Bytes);
LegalizationPipeline[Stage1].push_back(legalizeFreeze);
LegalizationPipeline[Stage1].push_back(legalizeMemCpy);
LegalizationPipeline[Stage1].push_back(removeMemSet);
LegalizationPipeline[Stage1].push_back(updateFnegToFsub);
// Note: legalizeGetHighLowi64Bytes and
// downcastI64toI32InsertExtractElements both modify extractelement, so they
// must run in separate stages. legalizeGetHighLowi64Bytes runs first because
// it removes extractelements, reducing the number that
// downcastI64toI32InsertExtractElements needs to handle.
LegalizationPipeline[Stage2].push_back(
downcastI64toI32InsertExtractElements);
}
};

View File

@@ -0,0 +1,18 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
; Splitting an i64 by bitcasting it to <2 x i32> and extracting both elements
; should be rewritten to scalar ops: a trunc for element 0 and an lshr by 32
; followed by a trunc for element 1, with no bitcast or extractelement left.
define void @split_via_extract(i64 noundef %a) {
; CHECK-LABEL: define void @split_via_extract(
; CHECK-SAME: i64 noundef [[A:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = trunc i64 [[A]] to i32
; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[A]], 32
; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
; CHECK-NEXT: ret void
;
entry:
%vecA = bitcast i64 %a to <2 x i32>
%low = extractelement <2 x i32> %vecA, i32 0 ; low 32 bits
%high = extractelement <2 x i32> %vecA, i32 1 ; high 32 bits
ret void
}