From a522c227a1d7d5dd4cd855a5fe4460193faf0856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Wed, 26 Feb 2025 07:48:01 +0000 Subject: [PATCH] [mlir][vector] Move tests for `rewriteAlignedSubByteInt{Ext|Trunc}` (nfc) (#126416) Moves tests for `rewriteAlignedSubByteIntExt` and `rewriteAlignedSubByteIntTrunc` into a dedicated files. Also adds + fixes some comments. This is merely for better organisation and so that it's easier to identify the patterns and edge cases being tested. --- .../Transforms/VectorEmulateNarrowType.cpp | 10 +- .../Vector/vector-rewrite-narrow-types.mlir | 377 ---------------- ...vector-rewrite-subbyte-ext-and-trunci.mlir | 415 ++++++++++++++++++ 3 files changed, 420 insertions(+), 382 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 5d8a525ac87f..51e72753ff16 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to emulate -// narrow types that are not supported by the target hardware, e.g. i4, using -// wider types, e.g. i8. +// narrow types that are not supported by the target hardware, e.g. i4 +// ("emulated type"), using wider types, e.g. i8 ("container type"). // /// Currently, only power-of-two integer types are supported. These are /// converted to wider integers that are either 8 bits wide or wider. @@ -2063,19 +2063,19 @@ void vector::populateVectorNarrowTypeRewritePatterns( // Patterns for aligned cases. We set higher priority as they are expected to // generate better performance for aligned cases. - // The emulated type is always i8. + // The container type is always i8. patterns.add, RewriteAlignedSubByteIntExt, RewriteAlignedSubByteIntTrunc>(patterns.getContext(), benefit.getBenefit() + 1); - // The emulated type is always i8. + // The container type is always i8. patterns .add, RewriteAlignedSubByteIntExt>( patterns.getContext(), benefit.getBenefit() + 1); } -// The emulated type is always i8. +// The container type is always i8. void vector::populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir index 8d28f248e392..a4af307b15da 100644 --- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -193,382 +193,6 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> { return %1 : vector<8xi17> } - -// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4). -// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8( -func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> { - // CHECK-NOT: arith.bitcast - // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8> - %0 = arith.extsi %a : vector<1xi4> to vector<1xi8> - return %0 : vector<1xi8> -} - -// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2). -// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8( -func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> { - // CHECK-NOT: arith.bitcast - // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8> - %0 = arith.extsi %a : vector<2xi2> to vector<2xi8> - return %0 : vector<2xi8> -} - -// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8( -func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> - %0 = arith.extsi %a : vector<8xi4> to vector<8xi8> - return %0 : vector<8xi8> -} - -// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8( -func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> { -// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> -// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> -// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> -// Extract bits 0-1 -// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8> -// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8> -// Extract bits 2-3 -// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8> -// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8> -// Extract bits 4-5 -// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8> -// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8> -// Extract bits 6-7 -// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8> -// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> -// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> -// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> - %0 = arith.extsi %a : vector<8xi2> to vector<8xi8> - return %0 : vector<8xi8> -} - -// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32( -func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> - %0 = arith.extsi %a : vector<8xi4> to vector<8xi32> - return %0 : vector<8xi32> -} - -// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32( -func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> { -// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> -// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> -// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> -// Extract bits 0-1 -// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8> -// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8> -// Extract bits 2-3 -// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8> -// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8> -// Extract bits 4-5 -// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8> -// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8> -// Extract bits 6-7 -// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8> -// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> -// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> -// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> - %0 = arith.extsi %a : vector<8xi2> to vector<8xi32> - return %0 : vector<8xi32> -} - -// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d( -func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> - %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32> - return %0 : vector<8x32xi32> -} - -// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d( -func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> { -// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8> -// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8> -// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8> -// Extract bits 0-1 -// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> -// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8> -// Extract bits 2-3 -// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8> -// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8> -// Extract bits 4-5 -// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8> -// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8> -// Extract bits 6-7 -// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> -// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8> -// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8> -// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> - %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32> - return %0 : vector<8x32xi32> -} - - -// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4( -func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> { -// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8> -// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8> -// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8> -// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8> -// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4> - %0 = arith.trunci %a : vector<8xi8> to vector<8xi4> - return %0 : vector<8xi4> -} - -// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4( -func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> { -// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8> -// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8> -// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8> -// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8> -// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8> -// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4> - %0 = arith.trunci %a : vector<8xi32> to vector<8xi4> - return %0 : vector<8xi4> -} - -// CHECK-LABEL: func.func @aligned_trunci_2d( -func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> { -// CHECK-NOT: vector.shuffle -// CHECK-NOT: vector.andi -// CHECK-NOT: vector.shli -// CHECK-NOT: vector.ori -// CHECK: arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8> -// CHECK-NOT: arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4> -// CHECK: vector.deinterleave - %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4> - return %0 : vector<8x32xi4> -} - -// CHECK-LABEL: func.func @aligned_trunci_nd( -// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> { -func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> { - // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8> - // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8> - // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8> - // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8> - // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8> - // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8> - // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8> - // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> - %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4> - return %0 : vector<3x8x32xi4> -} - -func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> { - // CHECK-NOT: arith.bitcast - // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2> - %0 = arith.trunci %a : vector<8xi8> to vector<8xi2> - return %0 : vector<8xi2> -} - -// CHECK-LABEL: func.func @aligned_extui_i4_to_i8( -func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> - %0 = arith.extui %a : vector<8xi4> to vector<8xi8> - return %0 : vector<8xi8> -} - -// CHECK-LABEL: func.func @aligned_extui_i2_to_i8( -func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> { -// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> -// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> -// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> -// Extract bits 0-1 -// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8> -// Extract bits 2-3 -// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8> -// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8> -// Extract bits 4-5 -// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8> -// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8> -// Extract bits 6-7 -// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8> -// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> -// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> -// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> - %0 = arith.extui %a : vector<8xi2> to vector<8xi8> - return %0 : vector<8xi8> -} - -// CHECK-LABEL: func.func @aligned_extui_i4_to_i32( -func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> - %0 = arith.extui %a : vector<8xi4> to vector<8xi32> - return %0 : vector<8xi32> -} - -// CHECK-LABEL: func.func @aligned_extui_i2_to_i32( -func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> { -// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> -// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> -// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> -// Extract bits 0-1 -// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8> -// Extract bits 2-3 -// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8> -// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8> -// Extract bits 4-5 -// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8> -// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8> -// Extract bits 6-7 -// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8> -// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> -// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> -// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> - %0 = arith.extui %a : vector<8xi2> to vector<8xi32> - return %0 : vector<8xi32> -} - -// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d( -func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { -// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> - %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32> - return %0 : vector<8x32xi32> -} - -// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d( -func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> { -// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8> -// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8> -// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8> -// Extract bits 0-1 -// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8> -// Extract bits 2-3 -// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8> -// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8> -// Extract bits 4-5 -// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8> -// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8> -// Extract bits 6-7 -// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> -// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8> -// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8> -// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> - %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32> - return %0 : vector<8x32xi32> -} - -// CHECK-LABEL: func.func @aligned_sitofp( -func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> - %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32> - return %0 : vector<8xf32> -} - -// CHECK-LABEL: func.func @aligned_sitofp_2d( -func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> - %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32> - return %0 : vector<8x32xf32> -} - -// CHECK-LABEL: func.func @aligned_uitofp( -func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> - %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32> - return %0 : vector<8xf32> -} - -// CHECK-LABEL: func.func @aligned_uitofp_2d( -func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> - %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32> - return %0 : vector<8x32xf32> -} - // CHECK-LABEL: func.func @i4_transpose( func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> { // CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> { @@ -589,7 +213,6 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> { return %0 : vector<16x8xi7> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %f = transform.structured.match ops{["func.func"]} in %module_op diff --git a/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir b/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir new file mode 100644 index 000000000000..aa75e0200525 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-rewrite-subbyte-ext-and-trunci.mlir @@ -0,0 +1,415 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s + +///---------------------------------------------------------------------------------------- +/// arith.extsi +/// +/// [Pattern: RewriteAlignedSubByteIntExt] +///---------------------------------------------------------------------------------------- +// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4). +// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8( +func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> { + // CHECK-NOT: arith.bitcast + // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8> + %0 = arith.extsi %a : vector<1xi4> to vector<1xi8> + return %0 : vector<1xi8> +} + +// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2). +// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8( +func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> { + // CHECK-NOT: arith.bitcast + // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8> + %0 = arith.extsi %a : vector<2xi2> to vector<2xi8> + return %0 : vector<2xi8> +} + +// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8( +func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> + %0 = arith.extsi %a : vector<8xi4> to vector<8xi8> + return %0 : vector<8xi8> +} + +// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8( +func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> { +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> + %0 = arith.extsi %a : vector<8xi2> to vector<8xi8> + return %0 : vector<8xi8> +} + +// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32( +func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extsi %a : vector<8xi4> to vector<8xi32> + return %0 : vector<8xi32> +} + +// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32( +func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> { +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> +// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extsi %a : vector<8xi2> to vector<8xi32> + return %0 : vector<8xi32> +} + +// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d( +func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + +// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d( +func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8> +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8> +// Extract bits 0-1 +// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> +// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8> +// Extract bits 2-3 +// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8> +// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8> +// Extract bits 4-5 +// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8> +// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8> +// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + +///---------------------------------------------------------------------------------------- +/// arith.trunci +/// +/// [Pattern: RewriteAlignedSubByteIntTrunc] +///---------------------------------------------------------------------------------------- +// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4( +func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> { +// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8> +// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8> +// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8> +// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4> + %0 = arith.trunci %a : vector<8xi8> to vector<8xi4> + return %0 : vector<8xi4> +} + +// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4( +func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> { +// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8> +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8> +// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8> +// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8> +// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4> + %0 = arith.trunci %a : vector<8xi32> to vector<8xi4> + return %0 : vector<8xi4> +} + +// CHECK-LABEL: func.func @aligned_trunci_2d( +func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> { +// CHECK-NOT: vector.shuffle +// CHECK-NOT: vector.andi +// CHECK-NOT: vector.shli +// CHECK-NOT: vector.ori +// CHECK: arith.trunci {{.*}} : vector<8x32xi32> to vector<8x32xi8> +// CHECK-NOT: arith.trunci {{.*}} : vector<8x32xi8> to vector<8x32xi4> +// CHECK: vector.deinterleave + %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4> + return %0 : vector<8x32xi4> +} + +// CHECK-LABEL: func.func @aligned_trunci_nd( +// CHECK-SAME: %[[IN:.*]]: vector<3x8x32xi32>) -> vector<3x8x32xi4> { +func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> { + // CHECK: %[[LEFT_SHIFT_BITS:.*]] = arith.constant dense<4> : vector<3x8x16xi8> + // CHECK: %[[I4_MASK:.*]] = arith.constant dense<15> : vector<3x8x16xi8> + // CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<3x8x32xi32> to vector<3x8x32xi8> + // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<3x8x32xi8> -> vector<3x8x16xi8> + // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8> + // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8> + // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8> + // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> + %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4> + return %0 : vector<3x8x32xi4> +} + +func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> { + // CHECK-NOT: arith.bitcast + // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2> + %0 = arith.trunci %a : vector<8xi8> to vector<8xi2> + return %0 : vector<8xi2> +} + +///---------------------------------------------------------------------------------------- +/// arith.extui +/// +/// [Pattern: RewriteAlignedSubByteIntExt] +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func.func @aligned_extui_i4_to_i8( +func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> + %0 = arith.extui %a : vector<8xi4> to vector<8xi8> + return %0 : vector<8xi8> +} + +// CHECK-LABEL: func.func @aligned_extui_i2_to_i8( +func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> { +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> + %0 = arith.extui %a : vector<8xi2> to vector<8xi8> + return %0 : vector<8xi8> +} + +// CHECK-LABEL: func.func @aligned_extui_i4_to_i32( +func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extui %a : vector<8xi4> to vector<8xi32> + return %0 : vector<8xi32> +} + +// CHECK-LABEL: func.func @aligned_extui_i2_to_i32( +func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> { +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8> +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8> +// Extract bits 0-1 +// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 2-3 +// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8> +// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 4-5 +// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8> +// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8> +// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extui %a : vector<8xi2> to vector<8xi32> + return %0 : vector<8xi32> +} + +// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d( +func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + +// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d( +func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> { +// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8> +// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8> +// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8> +// Extract bits 0-1 +// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8> +// Extract bits 2-3 +// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8> +// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8> +// Extract bits 4-5 +// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8> +// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8> +// Extract bits 6-7 +// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8> +// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32> + return %0 : vector<8x32xi32> +} + +///---------------------------------------------------------------------------------------- +/// arith.sitofp +/// +/// [Pattern: RewriteAlignedSubByteIntExt] +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func.func @aligned_sitofp( +func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> + %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func.func @aligned_sitofp_2d( +func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> + %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32> + return %0 : vector<8x32xf32> +} + +///---------------------------------------------------------------------------------------- +/// arith.uitofp +/// +/// [Pattern: RewriteAlignedSubByteIntExt] +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func.func @aligned_uitofp( +func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> + %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func.func @aligned_uitofp_2d( +func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> + %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32> + return %0 : vector<8x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.rewrite_narrow_types + } : !transform.any_op + transform.yield + } +}