[NFC][libclc] Simplify clc_dot and dot implementation (#142922)

llvm-diff shows no change to amdgcn--amdhsa.bc
This commit is contained in:
Wenju He
2025-06-06 00:09:53 +00:00
committed by GitHub
parent 16b0d2f910
commit de3a9ea510
3 changed files with 31 additions and 112 deletions

View File

@@ -7,59 +7,7 @@
//===----------------------------------------------------------------------===//
#include <clc/internal/clc.h>
#include <clc/math/clc_fma.h>
// Scalar float dot product: with a single component this is just p0 * p1.
_CLC_OVERLOAD _CLC_DEF float __clc_dot(float p0, float p1) {
  const float product = p0 * p1;
  return product;
}
// float2 dot product: sum of the two component-wise products.
// Components are addressed by numeric index (.s0/.s1 name the same lanes as
// .x/.y); the sum stays a single expression so evaluation is unchanged.
_CLC_OVERLOAD _CLC_DEF float __clc_dot(float2 p0, float2 p1) {
  return (p0.s0 * p1.s0) + (p0.s1 * p1.s1);
}
// float3 dot product: sum of the three component-wise products, accumulated
// left to right. Parentheses only spell out the default association.
_CLC_OVERLOAD _CLC_DEF float __clc_dot(float3 p0, float3 p1) {
  return ((p0.s0 * p1.s0) + (p0.s1 * p1.s1)) + (p0.s2 * p1.s2);
}
// float4 dot product: sum of the four component-wise products, accumulated
// left to right. Parentheses only spell out the default association.
_CLC_OVERLOAD _CLC_DEF float __clc_dot(float4 p0, float4 p1) {
  return (((p0.s0 * p1.s0) + (p0.s1 * p1.s1)) + (p0.s2 * p1.s2)) +
         (p0.s3 * p1.s3);
}
#ifdef cl_khr_fp64
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
// Scalar double dot product: with a single component this is just p0 * p1.
_CLC_OVERLOAD _CLC_DEF double __clc_dot(double p0, double p1) {
  const double product = p0 * p1;
  return product;
}
// double2 dot product: sum of the two component-wise products.
// .s0/.s1 name the same lanes as .x/.y.
_CLC_OVERLOAD _CLC_DEF double __clc_dot(double2 p0, double2 p1) {
  return (p0.s0 * p1.s0) + (p0.s1 * p1.s1);
}
// double3 dot product: sum of the three component-wise products, accumulated
// left to right. Parentheses only spell out the default association.
_CLC_OVERLOAD _CLC_DEF double __clc_dot(double3 p0, double3 p1) {
  return ((p0.s0 * p1.s0) + (p0.s1 * p1.s1)) + (p0.s2 * p1.s2);
}
// double4 dot product: sum of the four component-wise products, accumulated
// left to right. Parentheses only spell out the default association.
_CLC_OVERLOAD _CLC_DEF double __clc_dot(double4 p0, double4 p1) {
  return (((p0.s0 * p1.s0) + (p0.s1 * p1.s1)) + (p0.s2 * p1.s2)) +
         (p0.s3 * p1.s3);
}
#endif
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Scalar half dot product: with a single component this is just p0 * p1.
_CLC_OVERLOAD _CLC_DEF half __clc_dot(half p0, half p1) {
  const half product = p0 * p1;
  return product;
}
// half2 dot product: sum of the two component-wise products.
// .s0/.s1 name the same lanes as .x/.y.
_CLC_OVERLOAD _CLC_DEF half __clc_dot(half2 p0, half2 p1) {
  return (p0.s0 * p1.s0) + (p0.s1 * p1.s1);
}
// half3 dot product: sum of the three component-wise products, accumulated
// left to right. Parentheses only spell out the default association.
_CLC_OVERLOAD _CLC_DEF half __clc_dot(half3 p0, half3 p1) {
  return ((p0.s0 * p1.s0) + (p0.s1 * p1.s1)) + (p0.s2 * p1.s2);
}
// half4 dot product: sum of the four component-wise products, accumulated
// left to right. Parentheses only spell out the default association.
_CLC_OVERLOAD _CLC_DEF half __clc_dot(half4 p0, half4 p1) {
  return (((p0.s0 * p1.s0) + (p0.s1 * p1.s1)) + (p0.s2 * p1.s2)) +
         (p0.s3 * p1.s3);
}
#endif
// Generate the __clc_dot overloads from a single template instead of writing
// each one by hand: __CLC_BODY names the template file, and gentype.inc is
// then included to expand it.
// NOTE(review): presumably gentype.inc includes __CLC_BODY once per
// floating-point gentype (scalar and vector widths) — confirm against
// clc/math/gentype.inc.
#define __CLC_BODY <clc_dot.inc>
#include <clc/math/gentype.inc>

View File

@@ -0,0 +1,25 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Template body for __clc_dot, expanded once per generic type.
//
// Only emitted when the current type has 1, 2, 3 or 4 components; for any
// other __CLC_VECSIZE_OR_1 value this file expands to nothing.
#if (__CLC_VECSIZE_OR_1 == 1 || __CLC_VECSIZE_OR_1 == 2 ||                     \
     __CLC_VECSIZE_OR_1 == 3 || __CLC_VECSIZE_OR_1 == 4)

// Returns the sum of the component-wise products of x and y as a scalar of
// the element type. Products are accumulated left to right, starting from
// component s0.
_CLC_OVERLOAD _CLC_DEF __CLC_SCALAR_GENTYPE __clc_dot(__CLC_GENTYPE x,
                                                      __CLC_GENTYPE y) {
#if __CLC_VECSIZE_OR_1 == 4
  return (((x.s0 * y.s0) + (x.s1 * y.s1)) + (x.s2 * y.s2)) + (x.s3 * y.s3);
#elif __CLC_VECSIZE_OR_1 == 3
  return ((x.s0 * y.s0) + (x.s1 * y.s1)) + (x.s2 * y.s2);
#elif __CLC_VECSIZE_OR_1 == 2
  return (x.s0 * y.s0) + (x.s1 * y.s1);
#else
  // Only the scalar case (size 1) remains under the enclosing guard.
  return x * y;
#endif
}

#endif

View File

@@ -9,60 +9,6 @@
#include <clc/geometric/clc_dot.h>
#include <clc/opencl/clc.h>
// dot(float, float): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF float dot(float p0, float p1) {
  const float result = __clc_dot(p0, p1);
  return result;
}
// dot(float2, float2): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF float dot(float2 p0, float2 p1) {
  const float result = __clc_dot(p0, p1);
  return result;
}
// dot(float3, float3): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF float dot(float3 p0, float3 p1) {
  const float result = __clc_dot(p0, p1);
  return result;
}
// dot(float4, float4): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF float dot(float4 p0, float4 p1) {
  const float result = __clc_dot(p0, p1);
  return result;
}
#ifdef cl_khr_fp64
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
// dot(double, double): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF double dot(double p0, double p1) {
  const double result = __clc_dot(p0, p1);
  return result;
}
// dot(double2, double2): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF double dot(double2 p0, double2 p1) {
  const double result = __clc_dot(p0, p1);
  return result;
}
// dot(double3, double3): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF double dot(double3 p0, double3 p1) {
  const double result = __clc_dot(p0, p1);
  return result;
}
// dot(double4, double4): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF double dot(double4 p0, double4 p1) {
  const double result = __clc_dot(p0, p1);
  return result;
}
#endif
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// dot(half, half): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF half dot(half p0, half p1) {
  const half result = __clc_dot(p0, p1);
  return result;
}
// dot(half2, half2): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF half dot(half2 p0, half2 p1) {
  const half result = __clc_dot(p0, p1);
  return result;
}
// dot(half3, half3): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF half dot(half3 p0, half3 p1) {
  const half result = __clc_dot(p0, p1);
  return result;
}
// dot(half4, half4): forwards to the matching __clc_dot overload.
_CLC_OVERLOAD _CLC_DEF half dot(half4 p0, half4 p1) {
  const half result = __clc_dot(p0, p1);
  return result;
}
#endif
// Generate the public dot() overloads from a shared template: FUNCTION names
// the builtin being defined and __CLC_BODY names the template file that
// gentype.inc expands.
// NOTE(review): presumably binary_def.inc defines FUNCTION(x, y) for each
// gentype by forwarding to the __clc_-prefixed implementation (__clc_dot
// here) — confirm against clc/geometric/binary_def.inc.
#define FUNCTION dot
#define __CLC_BODY <clc/geometric/binary_def.inc>
#include <clc/math/gentype.inc>