diff --git a/bolt/lib/Core/BinaryContext.cpp b/bolt/lib/Core/BinaryContext.cpp index 47eae964e816..7c2d8c52287b 100644 --- a/bolt/lib/Core/BinaryContext.cpp +++ b/bolt/lib/Core/BinaryContext.cpp @@ -555,6 +555,9 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address, const uint64_t NextJTAddress, JumpTable::AddressesType *EntriesAsAddress, bool *HasEntryInFragment) const { + // Target address of __builtin_unreachable. + const uint64_t UnreachableAddress = BF.getAddress() + BF.getSize(); + // Is one of the targets __builtin_unreachable? bool HasUnreachable = false; @@ -564,9 +567,15 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address, // Number of targets other than __builtin_unreachable. uint64_t NumRealEntries = 0; - auto addEntryAddress = [&](uint64_t EntryAddress) { - if (EntriesAsAddress) - EntriesAsAddress->emplace_back(EntryAddress); + // Size of the jump table without trailing __builtin_unreachable entries. + size_t TrimmedSize = 0; + + auto addEntryAddress = [&](uint64_t EntryAddress, bool Unreachable = false) { + if (!EntriesAsAddress) + return; + EntriesAsAddress->emplace_back(EntryAddress); + if (!Unreachable) + TrimmedSize = EntriesAsAddress->size(); }; ErrorOr<BinarySection &> Section = getSectionForAddress(Address); @@ -618,8 +627,8 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address, : *getPointerAtAddress(EntryAddress); // __builtin_unreachable() case. - if (Value == BF.getAddress() + BF.getSize()) { - addEntryAddress(Value); + if (Value == UnreachableAddress) { + addEntryAddress(Value, /*Unreachable*/ true); HasUnreachable = true; LLVM_DEBUG(dbgs() << formatv("OK: {0:x} __builtin_unreachable\n", Value)); continue; @@ -673,6 +682,13 @@ bool BinaryContext::analyzeJumpTable(const uint64_t Address, addEntryAddress(Value); } + // Trim direct/normal jump table to exclude trailing unreachable entries that + // can collide with a function address. 
+ if (Type == JumpTable::JTT_NORMAL && EntriesAsAddress && + TrimmedSize != EntriesAsAddress->size() && + getBinaryFunctionAtAddress(UnreachableAddress)) + EntriesAsAddress->resize(TrimmedSize); + // It's a jump table if the number of real entries is more than 1, or there's // one real entry and one or more special targets. If there are only multiple // special targets, then it's not a jump table. diff --git a/bolt/test/runtime/X86/jt-confusion.s b/bolt/test/runtime/X86/jt-confusion.s new file mode 100644 index 000000000000..f15c83b35b6a --- /dev/null +++ b/bolt/test/runtime/X86/jt-confusion.s @@ -0,0 +1,164 @@ +# REQUIRES: system-linux + +# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o +# RUN: llvm-strip --strip-unneeded %t.o +# RUN: %clang %cflags -no-pie -nostartfiles -nostdlib -lc %t.o -o %t.exe -Wl,-q + +# RUN: llvm-bolt %t.exe -o %t.exe.bolt --relocs=1 --lite=0 + +# RUN: %t.exe.bolt + +## Check that BOLT's jump table detection differentiates between +## __builtin_unreachable() targets and function pointers. + +## The test case was built from the following two source files and +## modified for standalone build. main became _start, etc. 
+## $ $(CC) a.c -O1 -S -o a.s +## $ $(CC) b.c -O0 -S -o b.s + +## a.c: + +## typedef int (*fptr)(int); +## void check_fptr(fptr, int); +## +## int foo(int a) { +## check_fptr(foo, 0); +## switch (a) { +## default: +## __builtin_unreachable(); +## case 0: +## return 3; +## case 1: +## return 5; +## case 2: +## return 7; +## case 3: +## return 11; +## case 4: +## return 13; +## case 5: +## return 17; +## } +## return 0; +## } +## +## int main(int argc) { +## check_fptr(main, 1); +## return foo(argc); +## } +## +## const fptr funcs[2] = {foo, main}; + +## b.c.: + +## typedef int (*fptr)(int); +## extern const fptr funcs[2]; +## +## #define assert(C) { if (!(C)) (*(unsigned long long *)0) = 0; } +## void check_fptr(fptr f, int i) { +## assert(f == funcs[i]); +## } + + + .text + .globl foo + .type foo, @function +foo: +.LFB0: + .cfi_startproc + pushq %rbx + .cfi_def_cfa_offset 16 + .cfi_offset 3, -16 + movl %edi, %ebx + movl $0, %esi + movl $foo, %edi + call check_fptr + movl %ebx, %ebx + jmp *.L4(,%rbx,8) +.L8: + movl $5, %eax + jmp .L1 +.L7: + movl $7, %eax + jmp .L1 +.L6: + movl $11, %eax + jmp .L1 +.L5: + movl $13, %eax + jmp .L1 +.L3: + movl $17, %eax + jmp .L1 +.L10: + movl $3, %eax +.L1: + popq %rbx + .cfi_def_cfa_offset 8 + ret + .cfi_endproc +.LFE0: + .size foo, .-foo + .globl _start + .type _start, @function +_start: +.LFB1: + .cfi_startproc + pushq %rbx + .cfi_def_cfa_offset 16 + .cfi_offset 3, -16 + movl %edi, %ebx + movl $1, %esi + movl $_start, %edi + call check_fptr + movl $1, %edi + call foo + popq %rbx + .cfi_def_cfa_offset 8 + callq exit@PLT + .cfi_endproc +.LFE1: + .size _start, .-_start + .globl check_fptr + .type check_fptr, @function +check_fptr: +.LFB2: + .cfi_startproc + pushq %rbp + .cfi_def_cfa_offset 16 + .cfi_offset 6, -16 + movq %rsp, %rbp + .cfi_def_cfa_register 6 + movq %rdi, -8(%rbp) + movl %esi, -12(%rbp) + movl -12(%rbp), %eax + cltq + movq funcs(,%rax,8), %rax + cmpq %rax, -8(%rbp) + je .L33 + movl $0, %eax + movq $0, (%rax) +.L33: + 
nop + popq %rbp + .cfi_def_cfa 7, 8 + ret + .cfi_endproc + + .section .rodata + .align 8 + .align 4 +.L4: + .quad .L10 + .quad .L8 + .quad .L7 + .quad .L6 + .quad .L5 + .quad .L3 + + .globl funcs + .type funcs, @object + .size funcs, 16 +funcs: + .quad foo + .quad _start