libclc: clspv: fix fma, add vstore and fix inlining issues

https://reviews.llvm.org/D147773

Patch by Romaric Jodin <rjodin@google.com>
This commit is contained in:
Kévin Petit
2023-05-09 16:52:13 +01:00
parent f859835766
commit 21508fa769
6 changed files with 298 additions and 124 deletions

View File

@@ -271,11 +271,11 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
set( spvflags --spirv-max-version=1.1 )
elseif( ${ARCH} STREQUAL "clspv" )
set( t "spir--" )
set( build_flags )
set( build_flags "-Wno-unknown-assumption")
set( opt_flags -O3 )
elseif( ${ARCH} STREQUAL "clspv64" )
set( t "spir64--" )
set( build_flags )
set( build_flags "-Wno-unknown-assumption")
set( opt_flags -O3 )
else()
set( build_flags )

View File

@@ -1,5 +1,6 @@
math/fma.cl
math/nextafter.cl
shared/vstore_half.cl
subnormal_config.cl
../../generic/lib/geometric/distance.cl
../../generic/lib/geometric/length.cl
@@ -45,6 +46,12 @@ subnormal_config.cl
../../generic/lib/math/frexp.cl
../../generic/lib/math/half_cos.cl
../../generic/lib/math/half_divide.cl
../../generic/lib/math/half_exp.cl
../../generic/lib/math/half_exp10.cl
../../generic/lib/math/half_exp2.cl
../../generic/lib/math/half_log.cl
../../generic/lib/math/half_log10.cl
../../generic/lib/math/half_log2.cl
../../generic/lib/math/half_powr.cl
../../generic/lib/math/half_recip.cl
../../generic/lib/math/half_sin.cl

View File

@@ -34,6 +34,92 @@ struct fp {
uint sign;
};
static uint2 u2_set(uint hi, uint lo) {
uint2 res;
res.lo = lo;
res.hi = hi;
return res;
}
static uint2 u2_set_u(uint val) { return u2_set(0, val); }
static uint2 u2_mul(uint a, uint b) {
uint2 res;
res.hi = mul_hi(a, b);
res.lo = a * b;
return res;
}
static uint2 u2_sll(uint2 val, uint shift) {
if (shift == 0)
return val;
if (shift < 32) {
val.hi <<= shift;
val.hi |= val.lo >> (32 - shift);
val.lo <<= shift;
} else {
val.hi = val.lo << (shift - 32);
val.lo = 0;
}
return val;
}
static uint2 u2_srl(uint2 val, uint shift) {
if (shift == 0)
return val;
if (shift < 32) {
val.lo >>= shift;
val.lo |= val.hi << (32 - shift);
val.hi >>= shift;
} else {
val.lo = val.hi >> (shift - 32);
val.hi = 0;
}
return val;
}
static uint2 u2_or(uint2 a, uint b) {
a.lo |= b;
return a;
}
static uint2 u2_and(uint2 a, uint2 b) {
a.lo &= b.lo;
a.hi &= b.hi;
return a;
}
static uint2 u2_add(uint2 a, uint2 b) {
uint carry = (hadd(a.lo, b.lo) >> 31) & 0x1;
a.lo += b.lo;
a.hi += b.hi + carry;
return a;
}
static uint2 u2_add_u(uint2 a, uint b) { return u2_add(a, u2_set_u(b)); }
static uint2 u2_inv(uint2 a) {
a.lo = ~a.lo;
a.hi = ~a.hi;
return u2_add_u(a, 1);
}
static uint u2_clz(uint2 a) {
uint leading_zeroes = clz(a.hi);
if (leading_zeroes == 32) {
leading_zeroes += clz(a.lo);
}
return leading_zeroes;
}
static bool u2_eq(uint2 a, uint2 b) { return a.lo == b.lo && a.hi == b.hi; }
static bool u2_zero(uint2 a) { return u2_eq(a, u2_set_u(0)); }
static bool u2_gt(uint2 a, uint2 b) {
return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo);
}
_CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
/* special cases */
if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) {
@@ -63,12 +149,9 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127;
st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127;
st_a.mantissa.lo = a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000;
st_b.mantissa.lo = b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000;
st_c.mantissa.lo = c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000;
st_a.mantissa.hi = 0;
st_b.mantissa.hi = 0;
st_c.mantissa.hi = 0;
st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000);
st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000);
st_c.mantissa = u2_set_u(c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000);
st_a.sign = as_uint(a) & 0x80000000;
st_b.sign = as_uint(b) & 0x80000000;
@@ -81,15 +164,13 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
// add another bit to detect subtraction underflow
struct fp st_mul;
st_mul.sign = st_a.sign ^ st_b.sign;
st_mul.mantissa.hi = mul_hi(st_a.mantissa.lo, st_b.mantissa.lo);
st_mul.mantissa.lo = st_a.mantissa.lo * st_b.mantissa.lo;
uint upper_14bits = (st_mul.mantissa.lo >> 18) & 0x3fff;
st_mul.mantissa.lo <<= 14;
st_mul.mantissa.hi <<= 14;
st_mul.mantissa.hi |= upper_14bits;
st_mul.exponent = (st_mul.mantissa.lo != 0 || st_mul.mantissa.hi != 0)
? st_a.exponent + st_b.exponent
: 0;
st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14);
st_mul.exponent =
!u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0;
// FIXME: Detecting a == 0 || b == 0 above crashed GCN isel
if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa))
return c;
// Mantissa is 23 fractional bits, shift it the same way as product mantissa
#define C_ADJUST 37ul
@@ -97,146 +178,80 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
// both exponents are bias adjusted
int exp_diff = st_mul.exponent - st_c.exponent;
uint abs_exp_diff = abs(exp_diff);
st_c.mantissa.hi = (st_c.mantissa.lo << 5);
st_c.mantissa.lo = 0;
uint2 cutoff_bits = (uint2)(0, 0);
uint2 cutoff_mask = (uint2)(0, 0);
if (abs_exp_diff < 32) {
cutoff_mask.lo = (1u << abs(exp_diff)) - 1u;
} else if (abs_exp_diff < 64) {
cutoff_mask.lo = 0xffffffff;
uint remaining = abs_exp_diff - 32;
cutoff_mask.hi = (1u << remaining) - 1u;
st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST);
uint2 cutoff_bits = u2_set_u(0);
uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)),
u2_set(0xffffffff, 0xffffffff));
if (exp_diff > 0) {
cutoff_bits =
exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask);
st_c.mantissa =
exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff);
} else {
cutoff_mask = (uint2)(0, 0);
cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa
: u2_and(st_mul.mantissa, cutoff_mask);
st_mul.mantissa =
-exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff);
}
uint2 tmp = (exp_diff > 0) ? st_c.mantissa : st_mul.mantissa;
if (abs_exp_diff > 0) {
cutoff_bits = abs_exp_diff >= 64 ? tmp : (tmp & cutoff_mask);
if (abs_exp_diff < 32) {
// shift some of the hi bits into the shifted lo bits.
uint shift_mask = (1u << abs_exp_diff) - 1;
uint upper_saved_bits = tmp.hi & shift_mask;
upper_saved_bits = upper_saved_bits << (32 - abs_exp_diff);
tmp.hi >>= abs_exp_diff;
tmp.lo >>= abs_exp_diff;
tmp.lo |= upper_saved_bits;
} else if (abs_exp_diff < 64) {
tmp.lo = (tmp.hi >> (abs_exp_diff - 32));
tmp.hi = 0;
} else {
tmp = (uint2)(0, 0);
}
}
if (exp_diff > 0)
st_c.mantissa = tmp;
else
st_mul.mantissa = tmp;
struct fp st_fma;
st_fma.sign = st_mul.sign;
st_fma.exponent = max(st_mul.exponent, st_c.exponent);
st_fma.mantissa = (uint2)(0, 0);
if (st_c.sign == st_mul.sign) {
uint carry = (hadd(st_mul.mantissa.lo, st_c.mantissa.lo) >> 31) & 0x1;
st_fma.mantissa = st_mul.mantissa + st_c.mantissa;
st_fma.mantissa.hi += carry;
st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa);
} else {
// cutoff bits borrow one
uint cutoff_borrow = ((cutoff_bits.lo != 0 || cutoff_bits.hi != 0) &&
(st_mul.exponent > st_c.exponent))
? 1
: 0;
uint borrow = 0;
if (st_c.mantissa.lo > st_mul.mantissa.lo) {
borrow = 1;
} else if (st_c.mantissa.lo == UINT_MAX && cutoff_borrow == 1) {
borrow = 1;
} else if ((st_c.mantissa.lo + cutoff_borrow) > st_mul.mantissa.lo) {
borrow = 1;
}
st_fma.mantissa.lo = st_mul.mantissa.lo - st_c.mantissa.lo - cutoff_borrow;
st_fma.mantissa.hi = st_mul.mantissa.hi - st_c.mantissa.hi - borrow;
st_fma.mantissa =
u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)),
(!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent)
? u2_set(0xffffffff, 0xffffffff)
: u2_set_u(0)));
}
// underflow: st_c.sign != st_mul.sign, and magnitude switches the sign
if (st_fma.mantissa.hi > INT_MAX) {
st_fma.mantissa = ~st_fma.mantissa;
uint carry = (hadd(st_fma.mantissa.lo, 1u) >> 31) & 0x1;
st_fma.mantissa.lo += 1;
st_fma.mantissa.hi += carry;
if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) {
st_fma.mantissa = u2_inv(st_fma.mantissa);
st_fma.sign = st_mul.sign ^ 0x80000000;
}
// detect overflow/underflow
uint leading_zeroes = clz(st_fma.mantissa.hi);
if (leading_zeroes == 32) {
leading_zeroes += clz(st_fma.mantissa.lo);
}
int overflow_bits = 3 - leading_zeroes;
int overflow_bits = 3 - u2_clz(st_fma.mantissa);
// adjust exponent
st_fma.exponent += overflow_bits;
// handle underflow
if (overflow_bits < 0) {
uint shift = -overflow_bits;
if (shift < 32) {
uint shift_mask = (1u << shift) - 1;
uint saved_lo_bits = (st_fma.mantissa.lo >> (32 - shift)) & shift_mask;
st_fma.mantissa.lo <<= shift;
st_fma.mantissa.hi <<= shift;
st_fma.mantissa.hi |= saved_lo_bits;
} else if (shift < 64) {
st_fma.mantissa.hi = (st_fma.mantissa.lo << (64 - shift));
st_fma.mantissa.lo = 0;
} else {
st_fma.mantissa = (uint2)(0, 0);
}
st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits);
overflow_bits = 0;
}
// rounding
// overflow_bits is now in the range of [0, 3] making the shift greater than
// 32 bits.
uint2 trunc_mask;
uint trunc_shift = C_ADJUST + overflow_bits - 32;
trunc_mask.hi = (1u << trunc_shift) - 1;
trunc_mask.lo = UINT_MAX;
uint2 trunc_bits = st_fma.mantissa & trunc_mask;
trunc_bits.lo |= (cutoff_bits.hi != 0 || cutoff_bits.lo != 0) ? 1 : 0;
uint2 last_bit;
last_bit.lo = 0;
last_bit.hi = st_fma.mantissa.hi & (1u << trunc_shift);
uint grs_shift = C_ADJUST - 3 + overflow_bits - 32;
uint2 grs_bits;
grs_bits.lo = 0;
grs_bits.hi = 0x4u << grs_shift;
uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits),
u2_set(0xffffffff, 0xffffffff));
uint2 trunc_bits =
u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits));
uint2 last_bit =
u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits);
// round to nearest even
if ((trunc_bits.hi > grs_bits.hi ||
(trunc_bits.hi == grs_bits.hi && trunc_bits.lo > grs_bits.lo)) ||
(trunc_bits.hi == grs_bits.hi && trunc_bits.lo == grs_bits.lo &&
last_bit.hi != 0)) {
uint shift = C_ADJUST + overflow_bits - 32;
st_fma.mantissa.hi += 1u << shift;
if (u2_gt(trunc_bits, grs_bits) ||
(u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) {
st_fma.mantissa =
u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
}
// Shift mantissa back to bit 23
st_fma.mantissa.lo = (st_fma.mantissa.hi >> (C_ADJUST + overflow_bits - 32));
st_fma.mantissa.hi = 0;
// Shift mantissa back to bit 23
st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits);
// Detect rounding overflow
if (st_fma.mantissa.lo > 0xffffff) {
if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) {
++st_fma.exponent;
st_fma.mantissa.lo >>= 1;
st_fma.mantissa = u2_srl(st_fma.mantissa, 1);
}
if (st_fma.mantissa.lo == 0) {
if (u2_zero(st_fma.mantissa)) {
return 0.0f;
}

View File

@@ -0,0 +1,135 @@
#include <clc/clc.h>
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
#define ROUND_VEC1(out, in, ROUNDF) out = ROUNDF(in);
#define ROUND_VEC2(out, in, ROUNDF) \
ROUND_VEC1(out.lo, in.lo, ROUNDF); \
ROUND_VEC1(out.hi, in.hi, ROUNDF);
#define ROUND_VEC3(out, in, ROUNDF) \
ROUND_VEC1(out.s0, in.s0, ROUNDF); \
ROUND_VEC1(out.s1, in.s1, ROUNDF); \
ROUND_VEC1(out.s2, in.s2, ROUNDF);
#define ROUND_VEC4(out, in, ROUNDF) \
ROUND_VEC2(out.lo, in.lo, ROUNDF); \
ROUND_VEC2(out.hi, in.hi, ROUNDF);
#define ROUND_VEC8(out, in, ROUNDF) \
ROUND_VEC4(out.lo, in.lo, ROUNDF); \
ROUND_VEC4(out.hi, in.hi, ROUNDF);
#define ROUND_VEC16(out, in, ROUNDF) \
ROUND_VEC8(out.lo, in.lo, ROUNDF); \
ROUND_VEC8(out.hi, in.hi, ROUNDF);
#define __FUNC(SUFFIX, VEC_SIZE, TYPE, AS, ROUNDF) \
void _CLC_OVERLOAD vstore_half_##VEC_SIZE(TYPE, size_t, AS half *); \
_CLC_OVERLOAD _CLC_DEF void vstore_half##SUFFIX(TYPE vec, size_t offset, \
AS half *mem) { \
TYPE rounded_vec; \
ROUND_VEC##VEC_SIZE(rounded_vec, vec, ROUNDF); \
vstore_half_##VEC_SIZE(rounded_vec, offset, mem); \
} \
void _CLC_OVERLOAD vstorea_half_##VEC_SIZE(TYPE, size_t, AS half *); \
_CLC_OVERLOAD _CLC_DEF void vstorea_half##SUFFIX(TYPE vec, size_t offset, \
AS half *mem) { \
TYPE rounded_vec; \
ROUND_VEC##VEC_SIZE(rounded_vec, vec, ROUNDF); \
vstorea_half_##VEC_SIZE(rounded_vec, offset, mem); \
}
_CLC_DEF _CLC_OVERLOAD float __clc_rtz(float x) {
/* Handle nan corner case */
if (isnan(x))
return x;
/* RTZ does not produce Inf for large numbers */
if (fabs(x) > 65504.0f && !isinf(x))
return copysign(65504.0f, x);
const int exp = (as_uint(x) >> 23 & 0xff) - 127;
/* Manage range rounded to +- zero explicitely */
if (exp < -24)
return copysign(0.0f, x);
/* Remove lower 13 bits to make sure the number is rounded down */
int mask = 0xffffe000;
/* Denormals cannot be flushed, and they use different bit for rounding */
if (exp < -14)
mask <<= min(-(exp + 14), 10);
return as_float(as_uint(x) & mask);
}
_CLC_DEF _CLC_OVERLOAD float __clc_rti(float x) {
/* Handle nan corner case */
if (isnan(x))
return x;
const float inf = copysign(INFINITY, x);
uint ux = as_uint(x);
/* Manage +- infinity explicitely */
if (as_float(ux & 0x7fffffff) > 0x1.ffcp+15f) {
return inf;
}
/* Manage +- zero explicitely */
if ((ux & 0x7fffffff) == 0) {
return copysign(0.0f, x);
}
const int exp = (as_uint(x) >> 23 & 0xff) - 127;
/* Manage range rounded to smallest half denormal explicitely */
if (exp < -24) {
return copysign(0x1.0p-24f, x);
}
/* Set lower 13 bits */
int mask = (1 << 13) - 1;
/* Denormals cannot be flushed, and they use different bit for rounding */
if (exp < -14) {
mask = (1 << (13 + min(-(exp + 14), 10))) - 1;
}
const float next = nextafter(as_float(ux | mask), inf);
return ((ux & mask) == 0) ? as_float(ux) : next;
}
_CLC_DEF _CLC_OVERLOAD float __clc_rtn(float x) {
return ((as_uint(x) & 0x80000000) == 0) ? __clc_rtz(x) : __clc_rti(x);
}
_CLC_DEF _CLC_OVERLOAD float __clc_rtp(float x) {
return ((as_uint(x) & 0x80000000) == 0) ? __clc_rti(x) : __clc_rtz(x);
}
_CLC_DEF _CLC_OVERLOAD float __clc_rte(float x) {
/* Mantisa + implicit bit */
const uint mantissa = (as_uint(x) & 0x7fffff) | (1u << 23);
const int exp = (as_uint(x) >> 23 & 0xff) - 127;
int shift = 13;
if (exp < -14) {
/* The default assumes lower 13 bits are rounded,
* but it might be more for denormals.
* Shifting beyond last == 0b, and qr == 00b is not necessary */
shift += min(-(exp + 14), 15);
}
int mask = (1 << shift) - 1;
const uint grs = mantissa & mask;
const uint last = mantissa & (1 << shift);
/* IEEE round up rule is: grs > 101b or grs == 100b and last == 1.
* exp > 15 should round to inf. */
bool roundup = (grs > (1 << (shift - 1))) ||
(grs == (1 << (shift - 1)) && last != 0) || (exp > 15);
return roundup ? __clc_rti(x) : __clc_rtz(x);
}
#define __XFUNC(SUFFIX, VEC_SIZE, TYPE, AS) \
__FUNC(SUFFIX, VEC_SIZE, TYPE, AS, __clc_rte) \
__FUNC(SUFFIX##_rtz, VEC_SIZE, TYPE, AS, __clc_rtz) \
__FUNC(SUFFIX##_rtn, VEC_SIZE, TYPE, AS, __clc_rtn) \
__FUNC(SUFFIX##_rtp, VEC_SIZE, TYPE, AS, __clc_rtp) \
__FUNC(SUFFIX##_rte, VEC_SIZE, TYPE, AS, __clc_rte)
#define FUNC(SUFFIX, VEC_SIZE, TYPE, AS) __XFUNC(SUFFIX, VEC_SIZE, TYPE, AS)
#define __CLC_BODY "vstore_half.inc"
#include <clc/math/gentype.inc>
#undef __CLC_BODY
#undef FUNC
#undef __XFUNC
#undef __FUNC

View File

@@ -0,0 +1,15 @@
// This does exist only for fp32
#if __CLC_FPSIZE == 32
#ifdef __CLC_VECSIZE
FUNC(__CLC_VECSIZE, __CLC_VECSIZE, __CLC_GENTYPE, __private);
FUNC(__CLC_VECSIZE, __CLC_VECSIZE, __CLC_GENTYPE, __local);
FUNC(__CLC_VECSIZE, __CLC_VECSIZE, __CLC_GENTYPE, __global);
#undef __CLC_OFFSET
#else
FUNC(, 1, __CLC_GENTYPE, __private);
FUNC(, 1, __CLC_GENTYPE, __local);
FUNC(, 1, __CLC_GENTYPE, __global);
#endif
#endif

View File

@@ -4,9 +4,11 @@
// avoid inlines for SPIR-V related targets since we'll optimise later in the
// chain
#if defined(CLC_SPIRV) || defined(CLC_SPIRV64) || defined(CLC_CLSPV) || \
defined(CLC_CLSPV64)
#if defined(CLC_SPIRV) || defined(CLC_SPIRV64)
#define _CLC_DEF
#elif defined(CLC_CLSPV) || defined(CLC_CLSPV64)
#define _CLC_DEF \
__attribute__((noinline)) __attribute__((assume("clspv_libclc_builtin")))
#else
#define _CLC_DEF __attribute__((always_inline))
#endif