[mlir][linalg] Add quantized conv2d operator with FCHW,NCHW order (#107740)
This patch adds a quantized version of the `linalg.conv2d_nchw_fchw` Op. This is the "channel-first" ordering typically used by PyTorch and others.
This commit is contained in:
@@ -3114,6 +3114,143 @@ structured_op: !LinalgStructuredOpConfig
|
||||
- !ScalarExpression
|
||||
scalar_arg: KZp
|
||||
--- !LinalgOpConfig
# Quantized 2-D convolution, channel-first layout (input NCHW, kernel FCHW).
# Computes O += (cast(I) - cast(IZp)) * (cast(K) - cast(KZp)), with every
# operand cast (signed) to the output element type U before the arithmetic.
metadata: !LinalgOpMetadata
  name: conv_2d_nchw_fchw_q
  cpp_class_name: Conv2DNchwFchwQOp
  doc: |-
    Performs 2-D convolution with zero point offsets.

    Layout:
      * Input: NCHW.
      * Kernel: FCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
  implements:
  - LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
  # Symbol legend used by every shape/attribute map below:
  #   s0 = N,  s1 = C,  s2 = OH, s3 = SH, s4 = KH, s5 = DH,
  #   s6 = OW, s7 = SW, s8 = KW, s9 = DW, s10 = F
  args:
  # Input image: (N, C, OH * SH + KH * DH, OW * SW + KW * DW).
  - !LinalgOperandDefConfig
    name: I
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
      s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)>
  # Kernel: (F, C, KH, KW).
  - !LinalgOperandDefConfig
    name: K
    kind: input_tensor
    type_var: T2
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
      s1, s4, s8)>
  # Scalar zero point subtracted from every input element.
  - !LinalgOperandDefConfig
    name: IZp
    kind: scalar
    type_var: I32
  # Scalar zero point subtracted from every kernel element.
  - !LinalgOperandDefConfig
    name: KZp
    kind: scalar
    type_var: I32
  # Output / accumulator: (N, F, OH, OW).
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: U
    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
      s10, s2, s6)>
  # Strides attribute: (SH, SW), defaulting to [1, 1].
  - !LinalgOperandDefConfig
    name: strides
    kind: index_attr
    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
      (s3, s7)>
    default_indices:
    - 1
    - 1
  # Dilations attribute: (DH, DW), defaulting to [1, 1].
  - !LinalgOperandDefConfig
    name: dilations
    kind: index_attr
    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
      (s5, s9)>
    default_indices:
    - 1
    - 1
  # Iteration dims: d0 = n, d1 = f, d2 = oh, d3 = ow, d4 = c, d5 = kh, d6 = kw.
  # Maps are listed in operand order: I, K, IZp, KZp, O (scalars map to ()).
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> (d1, d4, d5, d6)>
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> ()>
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> ()>
    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
      s9, s10] -> (d0, d1, d2, d3)>
  # d0..d3 index the output (parallel); d4..d6 only index inputs (reduction).
  iterator_types:
  - parallel
  - parallel
  - parallel
  - parallel
  - reduction
  - reduction
  - reduction
  # O = add(O, mul(sub(cast(I), cast(IZp)), sub(cast(K), cast(KZp))))
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: O
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: mul
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: binary
                fn_name: sub
                operands:
                - !ScalarExpression
                  scalar_fn:
                    kind: type
                    fn_name: cast_signed
                    type_var: U
                    operands:
                    - !ScalarExpression
                      scalar_arg: I
                - !ScalarExpression
                  scalar_fn:
                    kind: type
                    fn_name: cast_signed
                    type_var: U
                    operands:
                    - !ScalarExpression
                      scalar_arg: IZp
            - !ScalarExpression
              scalar_fn:
                kind: binary
                fn_name: sub
                operands:
                - !ScalarExpression
                  scalar_fn:
                    kind: type
                    fn_name: cast_signed
                    type_var: U
                    operands:
                    - !ScalarExpression
                      scalar_arg: K
                - !ScalarExpression
                  scalar_fn:
                    kind: type
                    fn_name: cast_signed
                    type_var: U
                    operands:
                    - !ScalarExpression
                      scalar_arg: KZp
|
||||
--- !LinalgOpConfig
|
||||
metadata: !LinalgOpMetadata
|
||||
name: conv_2d_nchw_fchw
|
||||
cpp_class_name: Conv2DNchwFchwOp
|
||||
|
||||
@@ -876,6 +876,35 @@ def conv_2d_nhwc_fhwc_q(
|
||||
) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
|
||||
|
||||
|
||||
@linalg_structured_op
def conv_2d_nchw_fchw_q(
    # Input image, NCHW. Spatial extents are expressed through the output size,
    # stride, kernel size and dilation symbols.
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    # Kernel, FCHW.
    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
    # Scalar zero points for the input and the kernel respectively.
    IZp=ScalarDef(I32),
    KZp=ScalarDef(I32),
    # Output / accumulator, NFHW.
    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
    """Performs 2-D convolution with zero point offsets.

    Layout:
      * Input: NCHW.
      * Kernel: FCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output. This includes the zero
    point offsets common to quantized operations.
    """
    implements(ConvolutionOpInterface)
    # n, f, oh, ow index the output; c, kh, kw appear only on the inputs and
    # are therefore accumulated over by the `+=` below.
    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
    # Both factors have their zero point removed after being cast to U, so the
    # subtraction and the multiply happen in the accumulator type.
    O[D.n, D.f, D.oh, D.ow] += (
        TypeFn.cast_signed(
            U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
        )
        - TypeFn.cast_signed(U, IZp)
    ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp))
|
||||
|
||||
@linalg_structured_op
|
||||
def conv_2d_nchw_fchw(
|
||||
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
|
||||
|
||||
@@ -664,3 +664,33 @@ func.func @winograd_output_dyn(%arg0: tensor<6x6x?x?x?x?xf32>, %arg1: tensor<?x?
|
||||
|
||||
// CHECK-LABEL: func @winograd_output_dyn
|
||||
// CHECK: linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x?x?x?x?xf32>) outs(%arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip test for linalg.conv_2d_nchw_fchw_q where the image, filter,
// zero points and result all use i32 elements.
func.func @conv2d_channel_first_q(%img: tensor<100x3x224x224xi32>, %filt: tensor<64x3x5x5xi32>, %a: i32, %b: i32) -> tensor<100x64x220x220xi32> {
  %init = arith.constant dense<0> : tensor<100x64x220x220xi32>
  %1 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>,
                                   strides = dense<1> : tensor<2xi64>}
    ins(%img, %filt, %a, %b : tensor<100x3x224x224xi32>, tensor<64x3x5x5xi32>, i32, i32)
    outs(%init : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
  return %1 : tensor<100x64x220x220xi32>
}

// The argument captures use [a-zA-Z0-9]; the previous `A-z` range also
// matched the punctuation that sits between 'Z' and 'a' in ASCII.
// CHECK-LABEL: func @conv2d_channel_first_q(
// CHECK: %[[arg0:[a-zA-Z0-9]*]]: tensor<100x3x224x224xi32>, %[[arg1:[a-zA-Z0-9]*]]: tensor<64x3x5x5xi32>, %[[arg2:[a-zA-Z0-9]*]]: i32, %[[arg3:[a-zA-Z0-9]*]]: i32)
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi32>, tensor<64x3x5x5xi32>, i32, i32) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
|
||||
|
||||
// -----
|
||||
|
||||
// Round-trip test for linalg.conv_2d_nchw_fchw_q with i8 image, filter and
// zero points accumulating into an i32 result (operand promotion case).
func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt: tensor<64x3x5x5xi8>, %a: i8, %b: i8) -> tensor<100x64x220x220xi32> {
  %init = arith.constant dense<0> : tensor<100x64x220x220xi32>
  %1 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>,
                                   strides = dense<1> : tensor<2xi64>}
    ins(%img, %filt, %a, %b : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8)
    outs(%init : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
  return %1 : tensor<100x64x220x220xi32>
}

// The argument captures use [a-zA-Z0-9]; the previous `A-z` range also
// matched the punctuation that sits between 'Z' and 'a' in ASCII.
// CHECK-LABEL: func @conv2d_channel_first_q_promote(
// CHECK: %[[arg0:[a-zA-Z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-Z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-Z0-9]*]]: i8, %[[arg3:[a-zA-Z0-9]*]]: i8)
// CHECK: linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
|
||||
|
||||
Reference in New Issue
Block a user