[mlir][spirv] Add definition for GL Length (#144041)
A canonicalization pattern from `spirv.GL.Length` to `spirv.GL.FAbs` for scalar operands is also added.
This commit is contained in:
@@ -1160,6 +1160,46 @@ def SPIRV_GLFMixOp :
|
||||
|
||||
// -----
|
||||
|
||||
def SPIRV_GLLengthOp : SPIRV_GLOp<"Length", 66, [
|
||||
Pure,
|
||||
TypesMatchWith<"result type must match operand element type",
|
||||
"operand", "result",
|
||||
"::mlir::getElementTypeOrSelf($_self)">
|
||||
]> {
|
||||
let summary = "Return the length of a vector x";
|
||||
|
||||
let description = [{
|
||||
Result is the length of vector x, i.e., sqrt(x[0]**2 + x[1]**2 + ...).
|
||||
|
||||
The operand x must be a scalar or vector whose component type is floating-point.
|
||||
|
||||
Result Type must be a scalar of the same type as the component type of x.
|
||||
|
||||
#### Example:
|
||||
|
||||
```mlir
|
||||
%2 = spirv.GL.Length %0 : vector<3xf32> -> f32
|
||||
%3 = spirv.GL.Length %1 : f32 -> f32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SPIRV_ScalarOrVectorOf<SPIRV_Float>:$operand
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPIRV_Float:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$operand attr-dict `:` type($operand) `->` type($result)
|
||||
}];
|
||||
|
||||
let hasVerifier = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [
|
||||
Pure,
|
||||
AllTypesMatch<["p0", "p1"]>,
|
||||
|
||||
@@ -75,3 +75,11 @@ def ConvertComparisonIntoClamp2_#CmpClampPair[0] : Pat<
|
||||
)),
|
||||
(CmpClampPair[1] $input, $min, $max)>;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GL.Length -> spirv.GL.FAbs
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertGLLengthToGLFAbs : Pat<
|
||||
(SPIRV_GLLengthOp SPIRV_Float:$operand),
|
||||
(SPIRV_GLFAbsOp $operand)>;
|
||||
|
||||
@@ -34,8 +34,8 @@ void populateSPIRVGLCanonicalizationPatterns(RewritePatternSet &results) {
|
||||
ConvertComparisonIntoClamp2_SPIRV_SLessThanOp,
|
||||
ConvertComparisonIntoClamp2_SPIRV_SLessThanEqualOp,
|
||||
ConvertComparisonIntoClamp2_SPIRV_ULessThanOp,
|
||||
ConvertComparisonIntoClamp2_SPIRV_ULessThanEqualOp>(
|
||||
results.getContext());
|
||||
ConvertComparisonIntoClamp2_SPIRV_ULessThanEqualOp,
|
||||
ConvertGLLengthToGLFAbs>(results.getContext());
|
||||
}
|
||||
} // namespace spirv
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1000,3 +1000,69 @@ func.func @unpack_half_2x16_scalar_out(%arg0 : i32) -> () {
|
||||
%0 = spirv.GL.UnpackHalf2x16 %arg0 : i32 -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GL.Length
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func.func @length(%arg0 : f32) -> () {
|
||||
// CHECK: spirv.GL.Length {{%.*}} : f32 -> f32
|
||||
%0 = spirv.GL.Length %arg0 : f32 -> f32
|
||||
return
|
||||
}
|
||||
|
||||
func.func @lengthvec(%arg0 : vector<3xf32>) -> () {
|
||||
// CHECK: spirv.GL.Length {{%.*}} : vector<3xf32> -> f32
|
||||
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @length_i32_in(%arg0 : i32) -> () {
|
||||
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
|
||||
%0 = spirv.GL.Length %arg0 : i32 -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @length_f16_in(%arg0 : f16) -> () {
|
||||
// expected-error @+1 {{op failed to verify that result type must match operand element type}}
|
||||
%0 = spirv.GL.Length %arg0 : f16 -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
|
||||
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
|
||||
%0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @length_f16vec_in(%arg0 : vector<3xf16>) -> () {
|
||||
// expected-error @+1 {{op failed to verify that result type must match operand element type}}
|
||||
%0 = spirv.GL.Length %arg0 : vector<3xf16> -> f32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @length_i32_out(%arg0 : vector<3xf32>) -> () {
|
||||
// expected-error @+1 {{op result #0 must be 16/32/64-bit float, but got 'i32'}}
|
||||
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> i32
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @length_vec_out(%arg0 : vector<3xf32>) -> () {
|
||||
// expected-error @+1 {{op result #0 must be 16/32/64-bit float, but got 'vector<3xf32>'}}
|
||||
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> vector<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -177,3 +177,25 @@ func.func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 {
|
||||
// CHECK-NEXT: spirv.ReturnValue [[RES]]
|
||||
spirv.ReturnValue %2 : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spirv.GL.Length
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: @convert_length_into_fabs_scalar
|
||||
func.func @convert_length_into_fabs_scalar(%arg0 : f32) -> f32 {
|
||||
//CHECK: spirv.GL.FAbs {{%.*}} : f32
|
||||
//CHECK-NOT: spirv.GL.Length
|
||||
%0 = spirv.GL.Length %arg0 : f32 -> f32
|
||||
spirv.ReturnValue %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @dont_convert_length_into_fabs_vec
|
||||
func.func @dont_convert_length_into_fabs_vec(%arg0 : vector<3xf32>) -> f32 {
|
||||
//CHECK: spirv.GL.Length {{%.*}} : vector<3xf32> -> f32
|
||||
//CHECK-NOT: spirv.GL.FAbs
|
||||
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> f32
|
||||
spirv.ReturnValue %0 : f32
|
||||
}
|
||||
|
||||
@@ -128,6 +128,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
|
||||
%8 = spirv.GL.FindSMsb %arg3 : vector<3xi32>
|
||||
// CHECK: {{%.*}} = spirv.GL.FindUMsb {{%.*}} : vector<3xi32>
|
||||
%9 = spirv.GL.FindUMsb %arg3 : vector<3xi32>
|
||||
// CHECK: {{%.*}} = spirv.GL.Length {{%.*}} : f32 -> f32
|
||||
%10 = spirv.GL.Length %arg0 : f32 -> f32
|
||||
// CHECK: {{%.*}} = spirv.GL.Length {{%.*}} : vector<3xf32> -> f32
|
||||
%11 = spirv.GL.Length %arg1 : vector<3xf32> -> f32
|
||||
spirv.Return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user