From 6c813e8a3c0f08b00a52f37b5468762f17de2258 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 21 May 2025 13:50:02 -0700 Subject: [PATCH] [mlir][ROCDL] Add fp4 and fp6 conversion intrinsics, fix fp8 immargs (#140801) This PR adds support for the scaled conversion intrinsics for fp4 and fp6 types so that they can be targetted by a future amdgpu dialect op or used directly. Additionally, this patch refactors the copy-paste-heavy fp8 versions of these scaled conversion intrinsics with tablegen `foreach` loops, and fixes the fact that certain immargs weren't being stored as attributes. Note that some of the MLIR-level tests for those scaled fp8 intrinsics had incorrect return types, which have been fixed. (Note that while the operations have a known return type, the IR format still prints that type for clarity). --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 684 +++++++++--------- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 24 +- .../AMDGPUToROCDL/8-bit-floats-ocp.mlir | 30 +- .../AMDGPUToROCDL/8-bit-floats.mlir | 30 +- mlir/test/Dialect/LLVMIR/rocdl.mlir | 146 +++- mlir/test/Target/LLVMIR/rocdl.mlir | 140 +++- 6 files changed, 605 insertions(+), 449 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 6fb9e3aba1f0..1dadb7d9e885 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -709,20 +709,23 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0], }]; } -def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>, - BuildableType<"::mlir::VectorType::get(" - "{2},$_builder.getI16Type())">; +class ROCDL_ConcreteVector : + FixedVectorOfLengthAndType<[length], [elem]>, + BuildableType< + "::mlir::VectorType::get({" # length # "} ," + # elem.builderCall # ")">; -def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>, - BuildableType<"::mlir::VectorType::get(" - "{2},$_builder.getF16Type())">; +def ROCDL_V2I16Type : ROCDL_ConcreteVector; +def ROCDL_V2F16Type : ROCDL_ConcreteVector; +def ROCDL_V2BF16Type : ROCDL_ConcreteVector; +def ROCDL_V2F32Type : ROCDL_ConcreteVector; +def ROCDL_V6I32Type : ROCDL_ConcreteVector; +def ROCDL_V8I32Type : ROCDL_ConcreteVector; +def ROCDL_V16F32Type : ROCDL_ConcreteVector; +def ROCDL_V32F16Type : ROCDL_ConcreteVector; +def ROCDL_V32BF16Type : ROCDL_ConcreteVector; +def ROCDL_V32F32Type : ROCDL_ConcreteVector; -def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>, - BuildableType<"::mlir::VectorType::get(" - "{2},$_builder.getBF16Type())">; - -// TODO: The word and byte selectors are immarg in LLVM -// update to be attributes in MLIR //===---------------------------------------------------------------------===// // 16-bit float intrinsics //===---------------------------------------------------------------------===// @@ -738,279 +741,12 @@ def ROCDL_CvtPkRtz: }]; } -def ROCDL_CvtScaleF32PkFp8F16Op : - ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert f16 to packed fp8"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed fp8. - Store the result in low/high word of `old` based on $wordSel, preserving the other word. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32PkFp8Bf16Op : - ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert packed bf16 to packed fp8"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed fp8. - Store the result in low/high word of `old` based on $wordSel, preserving the other word. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) - }]; -} - - -def ROCDL_CvtScaleF32PkBf8F16Op : - ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert f16 to packed bf8"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed bf8. - Store the result in low/high word of `old` based on $wordSel, preserving the other word. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) - }]; -} - - -def ROCDL_CvtScaleF32PkBf8Bf16Op : - ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert bf16 to packed bf8"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed bf8. - Store the result in low/high word of `old` based on $wordSel, preserving the other word. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32SrFp8F16Op : - ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>, - Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert f16 to packed fp8 using stochastic rounding"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding - using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. - - }]; - let assemblyFormat = [{ - attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32SrBf8F16Op : - ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>, - Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert f16 to packed bf8 using stochastic rounding"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding - using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. - - }]; - let assemblyFormat = [{ - attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32SrFp8Bf16Op : - ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>, - Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding - using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. - - }]; - let assemblyFormat = [{ - attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32SrBf8Bf16Op : - ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>, - Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding"; - let description = [{ - Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding - using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others. - - }]; - let assemblyFormat = [{ - attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32PkF16Fp8Op : - ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Convert fp8 to packed f16 and scale"; - let description = [{ Convert `src` based on $wordSel to packed f16, then scale - the packed values by the exponent in `scale`. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32PkF16Bf8Op : - ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "convert bf8 to packed f16 and scale"; - let description = [{ Convert `src` based on $wordSel to packed f16, then scale - the packed values by exponent in `scale`. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32PkBf16Fp8Op : - ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Convert fp8 to packed bf16 and scale"; - let description = [{ Convert `src` based on $wordSel to packed bf16, then scale - the packed values by the exponent in `scale`. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32PkBf16Bf8Op : - ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Convert bf8 to packed bf16 and scale"; - let description = [{ Convert `src` based on $wordSel to packed bf16, then scale - the packed values by the exponent in `scale`. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) - }]; -} - -def ROCDL_CvtScaleF16Fp8Op : - ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> { - let summary = "Scale and convert fp8 to f16"; - let description = [{ Convert `src` based on $wordSel to f16, then scale the value - by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`, - preserving the others. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF16Bf8Op : - ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> { - let summary = "Scale and convert fp8 to f16"; - let description = [{ Convert `src` based on $wordSel to f16, then scale the value - by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`, - preserving the others. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - -//===---------------------------------------------------------------------===// -// 32-bit float intrinsics -//===---------------------------------------------------------------------===// -def ROCDL_CvtScaleF32PkF32Fp8Op : - ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert packed fp8 to packed f32"; - let description = [{ - Convert `src` based on $wordSel to packed fp32, then scale the packed values by - the exponent in `scale`. Store the result in a vector. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) - }]; -} -def ROCDL_CvtScaleF32PkF32Bf8Op : - ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert packed bf8 to packed f32"; - let description = [{ - Convert `src` based on $wordSel to packed fp32, then scale the packed values by - the exponent in `scale`. Store the result in a vector. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res) - }]; -} -//===---------------------------------------------------------------------===// -// 8-bit float scale intrinsics -//===---------------------------------------------------------------------===// -def ROCDL_CvtScaleF32PkFp8F32Op : - ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> { - let summary = "Scale and convert two f32's to packed fp8"; - let description = [{ - Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed fp8 - and store into the low/high word of `old`, preserving the other word. - }]; - let assemblyFormat = [{ - attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32PkBf8F32Op : - ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>, - Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> { - let summary = "Scale and convert two f32's to packed bf8"; - let description = [{ - Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed bf8 - and store into the low/high word of `old`, preserving the other word. - }]; - let assemblyFormat = [{ - attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res) - }]; -} - -def ROCDL_CvtScaleF32SrFp8F32Op : - ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>, - Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert f32 to fp8 using stochastic rounding"; - let description = [{ - Scale `src` by the exponent in `scale` then convert to fp8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - - -def ROCDL_CvtScaleF32SrBf8F32Op : - ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>, - Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert f32 to bf8 using stochastic rounding"; - let description = [{ - Scale `src` by the exponent in `scale` then convert to bf8 with stochastic rounding - using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others. - }]; - let assemblyFormat = [{ - attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res) - }]; -} - //===---------------------------------------------------------------------===// // 8-bit float intrinsics //===---------------------------------------------------------------------===// def ROCDL_CvtF32Bf8Op : - ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>, - Arguments<(ins I32:$srcA, I32:$byteSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.f32.bf8", [Pure], 1, [1], ["byteSel"]>, + Arguments<(ins I32:$srcA, I32Attr:$byteSel)> { let summary = "Convert bf8 to f32"; let description = [{ Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32. @@ -1020,23 +756,9 @@ def ROCDL_CvtF32Bf8Op : }]; } -def ROCDL_CvtScaleF32Bf8Op : - ROCDL_IntrOp<"cvt.scalef32.f32.bf8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert bf8 to f32"; - let description = [{ - Scale `src` by the exponent in `scale` then convert 8-bit bf8 value - from the `byteSel`th bit of `src` to fp32. - }]; - let assemblyFormat = [{ - attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res) - }]; -} - - def ROCDL_CvtF32Fp8Op : - ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>, - Arguments<(ins I32:$srcA, I32:$byteSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.f32.fp8", [Pure], 1, [1], ["byteSel"]>, + Arguments<(ins I32:$srcA, I32Attr:$byteSel)> { let summary = "Convert fp8 to f32"; let description = [{ Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32. @@ -1046,24 +768,9 @@ def ROCDL_CvtF32Fp8Op : }]; } - -def ROCDL_CvtScaleF32Fp8Op : - ROCDL_IntrOp<"cvt.scalef32.f32.fp8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> { - let summary = "Scale and convert fp8 to f32"; - let description = [{ - Scale `src` by the exponent in `scale` then convert 8-bit fp8 value - from the `byteSel`th bit of `src` to fp32. - - }]; - let assemblyFormat = [{ - attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res) - }]; -} - def ROCDL_CvtPkF32Fp8Op : - ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, I1:$wordSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.fp8", [Pure], 1, [1], ["wordSel"]>, + Arguments<(ins I32:$src, I1Attr:$wordSel)> { let summary = "Convert packed fp8 to packed f32"; let description = [{ Convert `src` based on $wordSel to packed fp32. @@ -1074,8 +781,8 @@ def ROCDL_CvtPkF32Fp8Op : } def ROCDL_CvtPkF32Bf8Op : - ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>, - Arguments<(ins I32:$src, I1:$wordSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.bf8", [Pure], 1, [1], ["wordSel"]>, + Arguments<(ins I32:$src, I1Attr:$wordSel)> { let summary = "Convert packed bf8 to packed f32"; let description = [{ Convert `src` based on $wordSel to packed fp32, @@ -1086,8 +793,8 @@ def ROCDL_CvtPkF32Bf8Op : } def ROCDL_CvtPkBf8F32Op : - ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>, - Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.pk.bf8.f32", [Pure], 1, [3], ["wordSel"]>, + Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> { let summary = "Convert two f32's to bf8"; let description = [{ Convert `srcA` and `srcB` to bf8 and store into the low/high word of @@ -1099,8 +806,8 @@ def ROCDL_CvtPkBf8F32Op : } def ROCDL_CvtPkFp8F32Op : - ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>, - Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.pk.fp8.f32", [Pure], 1, [3], ["wordSel"]>, + Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> { let summary = "Convert two f32's to fp8"; let description = [{ Convert `srcA` and `srcB` to fp8 and store into the low/high word of @@ -1112,8 +819,8 @@ def ROCDL_CvtPkFp8F32Op : } def ROCDL_CvtSrBf8F32Op : - ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>, - Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.sr.bf8.f32", [Pure], 1, [3], ["byteSel"]>, + Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> { let summary = "Convert f32 to bf8, stochiastic rounding"; let description = [{ Convert `srcA` to bf8, adding the rounding factor from `srcB`, @@ -1125,8 +832,8 @@ def ROCDL_CvtSrBf8F32Op : } def ROCDL_CvtSrFp8F32Op : - ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>, - Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> { + ROCDL_ConcreteNonMemIntrOp<"cvt.sr.fp8.f32", [Pure], 1, [3], ["byteSel"]>, + Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> { let summary = "Convert f32 to fp8, stochiastic rounding"; let description = [{ Convert `srcA` to fp8, adding the rounding factor from `srcB`, @@ -1137,6 +844,335 @@ def ROCDL_CvtSrFp8F32Op : }]; } +//===---------------------------------------------------------------------===// +// Scaled float conversion intrinsics +// +// These are using some tablegen trickery to avoid repetitive documentation +//===---------------------------------------------------------------------===// + +// Pair used so we can iterate over types.. +class ScaleArgInfo { + TypeConstraint type = argTyVal; + string name = !tolower(typeName); + string nameForOp = typeName; +} + +//===---------------------------------------------------------------------===// +// Scaled 32x6-bit float float conversion intrinsics +//===---------------------------------------------------------------------===// +foreach smallT = [ + // MLIR f6E2M3FN + ScaleArgInfo, + // MLIR f8E3M2FN + ScaleArgInfo +] in { + foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo, + ] in { + // Note: rouding down f32 values has a special case where + // we have to use 2 16xf32 arguments. + if !ne(largeT.name, "f32") then { + def ROCDL_CvtScaleF32Pk32 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name; + let description = [{ + Convert 32 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, dividing by the exponent part of `scale` + before doing so. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $scale `:` type($res) + }]; + } + } // if + + def ROCDL_CvtScaleF32SrPk32 # smallT.nameForOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk32." # smallT.name # "." # largeT.name, + [Pure], 1>, + Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert packed " + # largeT.name # " to packed " # smallT.name + # " with stochiastic rounding"; + let description = [{ + Convert 32 packed }] # largeT.name # [{ values to packed }] + # smallT.name # [{, dividing by the exponent part of `scale` + before doing so and applying random rounding derived from + `seed`. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $seed `,` $scale `:` type($res) + }]; + } + + def ROCDL_CvtScaleF32Pk32 # largeT.nameForOp # smallT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # largeT.name # "." # smallT.name, + [Pure], 1>, + Arguments<(ins smallT.type:$src, F32:$scale)> { + let results = (outs largeT.type:$res); + let summary = "Scale and convert packed " + # smallT.name # " to packed " # largeT.name; + let description = [{ + Convert 32 packed }] # smallT.name # [{ values to packed }] + # largeT.name # [{, multiplying by the exponent part of `scale` + before doing so. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $scale `:` type($res) + }]; + } + } // foreach largeT + + def ROCDL_CvtScaleF322xPk16 # smallT.nameForOp # F32Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.2xpk16." # smallT.name # ".f32", + [Pure], 1>, + Arguments<(ins ROCDL_V16F32Type:$src0, ROCDL_V16F32Type:$src1, F32:$scale)> { + let results = (outs smallT.type:$res); + let summary = "Scale and convert two vector<16xf32> to 32 packed " # smallT.name; + let description = [{ + Convert 32 single-precision float values, packed into two length-16 + vectors that will be logically concanenated, to packed }] + # smallT.name # [{, dividing by the exponent part of `scale` + before doing so. + }]; + let assemblyFormat = [{ + attr-dict $src0 `,` $src1 `,` $scale `:` type($res) + }]; + } +} // forach smallT + +//===---------------------------------------------------------------------===// +// Scaled conversions to/from fp8/bf8 (f8E4M3FN / f8E5M2) +//===---------------------------------------------------------------------===// +foreach smallTOp = ["Fp8", "Bf8"] in { + defvar smallT = !tolower(smallTOp); + + def ROCDL_CvtScaleF32F16 # smallTOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f16." # smallT, + [Pure], 1, [3, 4], ["srcSelIndex", "dstLoHiSel"]>, + Arguments<(ins ROCDL_V2F16Type:$oldVdst, I32:$src, F32:$scale, I32Attr:$srcSelIndex, I1Attr:$dstLoHiSel)> { + let results = (outs ROCDL_V2F16Type:$res); + let summary = "Scaled convert " # smallT # " from packed vector to f16, updating tied result"; + let description = [{ + Convert a }] # smallT # [{ byte from `src`, selected by + `srcSelIndex`, to f16 while multiplying it by the expontent of `scale`, + and place it into the `dstLoHiSel`th bit + of `oldVdst` preserving the other element of that vector in + the return value. + + The bytes are stored as an `i32` and not a `<4 x i8>`. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $srcSelIndex `]` `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res) + }]; + } + + def ROCDL_CvtScaleF32F32 # smallTOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f32." # smallT, + [Pure], 1, [2], ["srcSelIndex"]>, + Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> { + let results = (outs F32:$res); + let summary = "Scaled convert " # smallT # " from packed vector to f32"; + let description = [{ + Convert a }] # smallT # [{ byte from `src`, selected by + `srcSelIndex`, to f32, multiplying it by the exponent of `scale`. + + The bytes are stored in an `i32`, not a `<4 x i8>`. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res) + }]; + } + + def ROCDL_CvtScaleF32Pk # smallTOp # F32Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # ".f32", + [Pure], 1, [4], ["dstLoHiSel"]>, + Arguments<(ins ROCDL_V2I16Type:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I1Attr:$dstLoHiSel)> { + let results = (outs ROCDL_V2I16Type:$res); + let summary = "Scaled convert two f32 to two " # smallT # ", updating packed vector"; + let description = [{ + Convert two f32 values in `src0` and `src1` to two }] # smallT # [{ bytes, + dividing by the exponent in `scale`. The bytes are packed into + a 16-bit value which is inserted into `oldVdst` at the `dstLoHiSel` + position, with the entire updated vector being returned. + }]; + let assemblyFormat = [{ + attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res) + }]; + } + + foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ] in { + def ROCDL_CvtScaleF32Pk # smallTOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # "." # largeT.name, + [Pure], 1, [3], ["dstLoHiSel"]>, + Arguments<(ins ROCDL_V2I16Type:$oldVdst, largeT.type:$src0, F32:$scale, I1Attr:$dstLoHiSel)> { + let results = (outs ROCDL_V2I16Type:$res); + let summary = "Scaled convert two " # largeT.name # "to two " # smallT # ", updating packed vector"; + let description = [{ + Convert two }] # largeT.name # [{ values in `src0` to two }] + # smallT # [{ bytes, dividing by the exponent in `scale`. The bytes are + packed into a 16-bit value which is inserted into `oldVdst` at the + `dstLoHiSel` position, with the entire updated vector being returned. + }]; + let assemblyFormat = [{ + attr-dict $src0 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res) + }]; + } + } // foreach largeT + + foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo + ] in { + def ROCDL_CvtScaleF32Pk # largeT.nameForOp # smallTOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # "." # smallT, + [Pure], 1, [2], ["srcLoHiSel"]>, + Arguments<(ins I32:$src, F32:$scale, I1Attr:$srcLoHiSel)> { + let results = (outs largeT.type:$res); + let summary = "Scaled convert two " # smallT # "to two " # largeT.name #; + let description = [{ + Convert two packed }] # smallT # [{ values in `src0` to two }] + # largeT.name # [{ values, multiplying by the exponent in `scale`. + The two values to be converted are selected from the low or high half + of `src` (a packed vector represented as an `i32`) + on the basis of `srcLoHiSel`. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $srcLoHiSel `]` `,` $scale `:` type($res) + }]; + } + } // foreach largeT + + foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo + ] in { + def ROCDL_CvtScaleF32Sr # smallTOp # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr." # smallT # "." # largeT.name, + [Pure], 1, [4], ["dstSelIndex"]>, + Arguments<(ins I32:$oldVdst, largeT.type:$src0, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> { + let results = (outs I32:$res); + let summary = "Scaled convert " # largeT.name # "to " # smallT # " with stochiastic rounding, updating packed vector"; + let description = [{ + Convert a }] # largeT.name # [{ value in `src0` to a }] + # smallT # [{ bytes, dividing by the exponent in `scale` and using `seed` + for stochiastic rounding. Place the resulting byte in the + `dstSelIndex`th bit of `oldVdst` and return the entire packed vector, + which is stored as an `i32`. + }]; + let assemblyFormat = [{ + attr-dict $src0 `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res) + }]; + } + } // foreach largeT +} // foreach smallTOp + +//===---------------------------------------------------------------------===// +// Scaled conversions to/from fp4 (f4E2M1FN) +//===---------------------------------------------------------------------===// + +foreach largeT = [ + ScaleArgInfo, + ScaleArgInfo, + ScaleArgInfo, +] in { + // Note: rouding down f32 values has a special case where + // we have to use 2 float arguments. + if !ne(largeT.name, "f32") then { + def ROCDL_CvtScaleF32PkFp4 # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4." # largeT.name, + [Pure], 1, [3], ["dstSelIndex"]>, + Arguments<(ins I32:$oldVdst, largeT.type:$src, F32:$scale, I32Attr:$dstSelIndex)> { + let results = (outs I32:$res); + let summary = "Scale and convert two " + # largeT.name # " to packed fp4, updating tied vector"; + let description = [{ + Convert two packed }] # largeT.name # [{ values to packed + fp4, dividing by the exponent part of `scale` + before doing so. + + The two scaled values are packed into a byte. + That byte is used to update the `dstSelIndex`th + byte of `oldVdst`, which is returned in its entirity. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res) + }]; + } + } // if + + def ROCDL_CvtScaleF32SrPkFp4 # largeT.nameForOp # Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk.fp4." # largeT.name, + [Pure], 1, [4], ["dstSelIndex"]>, + Arguments<(ins I32:$oldVdst, largeT.type:$src, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> { + let results = (outs I32:$res); + let summary = "Scale and convert two " + # largeT.name # " to packed fp4 with stochiastic rounding, updating tied vector"; + let description = [{ + Convert two packed }] # largeT.name # [{ values to packed + fp4, dividing by the exponent part of `scale` + before doing so and using `seed` as the random seed for + stochiastic rounding. + + The two scaled values are packed (little-endian) + into a byte. That byte is used to update the `dstSelIndex`th + byte of `oldVdst`, which is returned in its entirity. + }]; + let assemblyFormat = [{ + attr-dict $src `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res) + }]; + } + + def ROCDL_CvtScaleF32Pk # largeT.nameForOp # Fp4Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # ".fp4", + [Pure], 1, [2], ["srcSelIndex"]>, + Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> { + let results = (outs largeT.type:$res); + let summary = "Scale and convert two packed fp4 to packed " # largeT.name; + let description = [{ + Convert two packed fp4 (f4E2M1) values stored as one byte of a 32-bit integer + to packed }] # largeT.name # [{, multiplying by the exponent part of `scale` + before doing so. + + The byte to convert is chosen by `srcSelIndex`. + }]; + let assemblyFormat = [{ + attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res) + }]; + } +} // foreach largeT + +def ROCDL_CvtScaleF32PkFp4F32Op : + ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4.f32", + [Pure], 1, [4], ["dstSelIndex"]>, + Arguments<(ins I32:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I32Attr:$dstSelIndex)> { + let results = (outs I32:$res); + let summary = "Scale and convert two f32 values to two packed fp4, updating tied vector"; + let description = [{ + Convert two single-precision float values, passed in `src0` and `src1` + into two fp4 values, dividing them by the expontent part of `scale` + before doing so. + + The two scaled values are packed into a byte. + That byte is used to update the `dstSelIndex`th + byte of `oldVdst`, which is returned in its entirity. + }]; + let assemblyFormat = [{ + attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res) + }]; +} + //===----------------------------------------------------------------------===// // ROCDL target attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 0694cf27faff..c463b64b5f77 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1210,22 +1210,20 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( } Value i32Source = rewriter.create(loc, i32, source); if (resultVecType) { - Value wordSel = createI1Constant(rewriter, loc, op.getIndex()); if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, - wordSel); + op.getIndex()); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, - wordSel); + op.getIndex()); } } else { - Value byteSel = createI32Constant(rewriter, loc, op.getIndex()); if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, - byteSel); + op.getIndex()); } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, - byteSel); + op.getIndex()); } } return success(); @@ -1253,15 +1251,14 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( existing = rewriter.create(loc, i32, existing); else existing = rewriter.create(loc, i32); - Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex()); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) result = rewriter.create(loc, i32, sourceA, sourceB, - existing, wordSel); + existing, op.getWordIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) result = rewriter.create(loc, i32, sourceA, sourceB, - existing, wordSel); + existing, op.getWordIndex()); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); @@ -1288,15 +1285,14 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( existing = rewriter.create(loc, i32, existing); else existing = rewriter.create(loc, i32); - Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex()); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create(loc, i32, source, stoch, - existing, byteSel); + result = rewriter.create( + loc, i32, source, stoch, existing, op.getStoreIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create(loc, i32, source, stoch, - existing, byteSel); + result = rewriter.create( + loc, i32, source, stoch, existing, op.getStoreIndex()); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir index ea0c3afbd902..464d47216c81 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-ocp.mlir @@ -7,8 +7,7 @@ // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 -// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_scalar(%v: f8E5M2) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32 @@ -25,8 +24,7 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 -// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32 @@ -36,8 +34,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 { // CHECK-LABEL: func @ext_full_vec( // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32 @@ -54,8 +51,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 -// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32> // CHECK: return [[EXT]] func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32> @@ -65,8 +61,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> { // CHECK-LABEL: func @ext_packed_4xfp8 // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 -// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32> // CHECK: return [[EXT]] : vector<2xf32> func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32> @@ -77,8 +72,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> { // CHECK-SAME: ([[V:%.+]]: f32) // CHECK: [[V2:%.+]] = llvm.mlir.undef : f32 // CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 -// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN> func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> { @@ -89,8 +83,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> { // CHECK-LABEL: func @packed_truncx2 // CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32) // CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 -// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN> func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> { @@ -102,8 +95,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> { // CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>) // CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8> // CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32 -// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2> func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> { @@ -114,8 +106,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> // CHECK-LABEL: func @packed_stoch_round // CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32) // CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 -// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN> func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> { @@ -127,8 +118,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> { // CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>) // CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8> // CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32 -// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2> func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> { diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir index 219f822ca9a1..03fcb266a2e8 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats.mlir @@ -6,8 +6,7 @@ // CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32 -// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32 @@ -24,8 +23,7 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 -// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32 @@ -35,8 +33,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 { // CHECK-LABEL: func @ext_full_vec // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32 -// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32 +// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32 // CHECK: return [[EXT]] : f32 func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { %ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32 @@ -53,8 +50,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 { // CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8> // CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32 -// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32> +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32> // CHECK: return [[EXT]] func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32> @@ -64,8 +60,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> { // CHECK-LABEL: func @ext_packed_4xfp8( // CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8> // CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32 -// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1 -// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32> +// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32> // CHECK: return [[EXT]] : vector<2xf32> func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { %ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32> @@ -76,8 +71,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> { // CHECK-SAME: ([[V:%.+]]: f32) // CHECK: [[V2:%.+]] = llvm.mlir.undef : f32 // CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 -// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ> func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> { @@ -88,8 +82,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> { // CHECK-LABEL: func @packed_truncx2 // CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32) // CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 -// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ> func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> { @@ -101,8 +94,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> { // CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>) // CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8> // CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32 -// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ> func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { @@ -113,8 +105,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ> // CHECK-LABEL: func @packed_stoch_round // CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32) // CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32 -// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ> func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> { @@ -126,8 +117,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> { // CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>) // CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8> // CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32 -// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32 -// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32 +// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32 // CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8> // CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ> func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> { diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index fbde99389134..0503c2a15860 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -789,36 +789,32 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf // CHECK: rocdl.cvt.scalef32.sr.bf8.bf16 // CHECK: rocdl.cvt.scalef32.pk.f32.fp8 // CHECK: rocdl.cvt.scalef32.pk.f32.bf8 - %c0 = llvm.mlir.constant(0 : i32) : i32 - %c2 = llvm.mlir.constant(2 : i32) : i32 - %c3 = llvm.mlir.constant(3 : i32) : i32 %c4 = llvm.mlir.constant(1.0 : f32) : f32 - %false = llvm.mlir.constant(false) : i1 - %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32 - %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32 - %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32 - %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32 - %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16> - %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16> - %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16> - %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16> - %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16 - %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16 - %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 - %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 - %source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32> - %source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32> - %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 - %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32 - %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32 - %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32 - %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32 - %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32 - %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32 - %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32 - %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32 - %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32 - %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32 + %v1 = rocdl.cvt.f32.bf8 %source[0] : f32 + %v2 = rocdl.cvt.f32.fp8 %source[0] : f32 + %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32 + %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32 + %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16> + %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16> + %v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16> + %v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16> + %v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16> + %v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16> + %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32 + %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32 + %source2_ext = rocdl.cvt.pk.f32.bf8 %source[false] : vector<2xf32> + %source3_ext = rocdl.cvt.pk.f32.fp8 %source[false] : vector<2xf32> + %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32 + %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32 + %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32 + %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32 + %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32 + %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32 + %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32 + %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32 + %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32 + %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32> + %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32> llvm.return %source5 : i32 } @@ -826,9 +822,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x // CHECK-LABEL: @rocdl_8bit_packed_v2i16 // CHECK: rocdl.cvt.scalef32.pk.fp8.f32 %c0 = llvm.mlir.constant(1.0 : f32) : f32 - %false = llvm.mlir.constant(false) : i1 - %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16> - %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16> + %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16> + %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16> llvm.return %source_scaled : vector<2xi16> } @@ -836,14 +831,91 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, % // CHECK-LABEL: @rocdl_v2f16_v2i16 // CHECK: rocdl.cvt.scalef32.pk.fp8.f16 %c0 = llvm.mlir.constant(1.0 : f32) : f32 - %false = llvm.mlir.constant(false) : i1 - %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16> - %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16> - %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16> - %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16> + %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16> + %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16> + %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16> + %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16> llvm.return %source_scaled : vector<2xi16> } +// CHECK-LABEL: @rocdl_6_bit_floats +// CHECK-SAME: (%[[V32F6:.+]]: vector<6xi32>, %[[V16F32:.+]]: vector<16xf32>, %[[V32F32:.+]]: vector<32xf32>, %[[V32F16:.+]]: vector<32xf16>, %[[V32BF16:.+]]: vector<32xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32) +llvm.func @rocdl_6_bit_floats( + %v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>, + %v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32, + %scale: f32) { + // CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.bf6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]] + %f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.fp6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]] + %f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.f16 %[[V32F16]], %[[SCALE]] + %f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.f16 %[[V32F16]], %[[SCALE]] + %f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.bf16 %[[V32BF16]], %[[SCALE]] + %bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.bf16 %[[V32BF16]], %[[SCALE]] + %bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32> + + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.bf6 %[[V32F6]], %[[SCALE]] + %bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.fp6 %[[V32F6]], %[[SCALE]] + %fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.bf6 %[[V32F6]], %[[SCALE]] + %bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.fp6 %[[V32F6]], %[[SCALE]] + %fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.bf6 %[[V32F6]], %[[SCALE]] + %bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16> + // CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.fp6 %[[V32F6]], %[[SCALE]] + %fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16> + + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]] + %f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]] + %f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]] + %f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]] + %f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]] + %bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32> + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]] + %bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32> + + llvm.return +} + +// CHECK-LABEL: @rocdl_4_bit_floats +// CHECK-SAME: (%[[V8F4:.+]]: i32, %[[F32:.+]]: f32, %[[V2F32:.+]]: vector<2xf32>, %[[V2F16:.+]]: vector<2xf16>, %[[V2BF16:.+]]: vector<2xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32) +llvm.func @rocdl_4_bit_floats( + %v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>, + %v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) { + + // CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f32 %[[F32]], %[[F32]], %[[SCALE]] -> %[[V8F4]][0] + %f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32 + // CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f16 %[[V2F16]], %[[SCALE]] -> %[[V8F4]][1] + %f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32 + // CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.bf16 %[[V2BF16]], %[[SCALE]] -> %[[V8F4]][0] + %bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32 + + // CHECK-NEXT: rocdl.cvt.scalef32.pk.f32.fp4 %[[V8F4]][0], %[[SCALE]] + %fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32> + // CHECK-NEXT: rocdl.cvt.scalef32.pk.f16.fp4 %[[V8F4]][1], %[[SCALE]] + %fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16> + // CHECK-NEXT: rocdl.cvt.scalef32.pk.bf16.fp4 %[[V8F4]][0], %[[SCALE]] + %fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16> + + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f32 %[[V2F32]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0] + %f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32 + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f16 %[[V2F16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][1] + %f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32 + // CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.bf16 %[[V2BF16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0] + %bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32 + + llvm.return +} + llvm.func @rocdl.s.waitcnt() { // CHECK-LABEL: rocdl.s.waitcnt // CHECK: rocdl.s.waitcnt 0 diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index b37f0da36195..a6a03c586dd2 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1081,34 +1081,31 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf // CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3) // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false) // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false) - %c0 = llvm.mlir.constant(0 : i32) : i32 - %c2 = llvm.mlir.constant(2 : i32) : i32 - %c3 = llvm.mlir.constant(3 : i32) : i32 %c4 = llvm.mlir.constant(1.0 : f32) : f32 %false = llvm.mlir.constant(false) : i1 - %v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32 - %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32 - %v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32 - %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32 - %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32 - %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32 - %v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16 - %v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16 - %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32 - %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32 - %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32 - %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32 - %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32 - %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32 - %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32 - %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32 - %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32 - %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32 - %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32 - %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32 - %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32 - %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32 - %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32 + %v1 = rocdl.cvt.f32.bf8 %source[0] : f32 + %v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32 + %v2 = rocdl.cvt.f32.fp8 %source[0] : f32 + %v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32 + %v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16> + %v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16> + %v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %source_packed[false] : vector<2xf16> + %v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %source_packed[false] : vector<2xf16> + %v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16> + %v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16> + %source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32 + %source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32 + %source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32 + %source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32 + %source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32 + %source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32 + %source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32 + %source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32 + %source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32 + %source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32 + %source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32 + %source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32> + %source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32> llvm.return %source5 : i32 } @@ -1117,9 +1114,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x // CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false) // CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false) %c0 = llvm.mlir.constant(1.0 : f32) : f32 - %false = llvm.mlir.constant(false) : i1 - %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16> - %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16> + %source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16> + %source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16> llvm.return %source_scaled : vector<2xi16> } @@ -1130,11 +1126,10 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, % // CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false) // CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false) %c0 = llvm.mlir.constant(1.0 : f32) : f32 - %false = llvm.mlir.constant(false) : i1 - %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16> - %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16> - %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16> - %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16> + %source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16> + %source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16> + %source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16> + %source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16> llvm.return %source_scaled : vector<2xi16> } @@ -1145,6 +1140,83 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf llvm.return %source : vector<2xf16> } +// CHECK-LABEL: @rocdl_6_bit_floats +// CHECK-SAME: (<6 x i32> %[[V32F6:.+]], <16 x float> %[[V16F32:.+]], <32 x float> %[[V32F32:.+]], <32 x half> %[[V32F16:.+]], <32 x bfloat> %[[V32BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]]) +llvm.func @rocdl_6_bit_floats( + %v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>, + %v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32, + %scale: f32) { + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.bf6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]]) + %f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.fp6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]]) + %f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.f16(<32 x half> %[[V32F16]], float %[[SCALE]]) + %f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.f16(<32 x half> %[[V32F16]], float %[[SCALE]]) + %f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]]) + %bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]]) + %bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32> + + // CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]]) + %bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32> + // CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]]) + %fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32> + // CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]]) + %bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16> + // CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]]) + %fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16> + // CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]]) + %bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16> + // CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]]) + %fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16> + + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]]) + %f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]]) + %f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]]) + %f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]]) + %f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]]) + %bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32> + // CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]]) + %bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32> + + llvm.return +} + +// CHECK-LABEL: @rocdl_4_bit_floats +// CHECK-SAME: (i32 %[[V8F4:.+]], float %[[F32:.+]], <2 x float> %[[V2F32:.+]], <2 x half> %[[V2F16:.+]], <2 x bfloat> %[[V2BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]]) +llvm.func @rocdl_4_bit_floats( + %v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>, + %v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) { + + // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %[[V8F4]], float %[[F32]], float %[[F32]], float %[[SCALE]], i32 0) + %f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32 + // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], float %[[SCALE]], i32 1) + %f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32 + // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], float %[[SCALE]], i32 0) + %bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32 + + // CHECK-NEXT: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0) + %fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32> + // CHECK-NEXT: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 1) + %fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16> + // CHECK-NEXT: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0) + %fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16> + + // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %[[V8F4]], <2 x float> %[[V2F32]], i32 %[[SEED]], float %[[SCALE]], i32 0) + %f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32 + // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], i32 %[[SEED]], float %[[SCALE]], i32 1) + %f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32 + // CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], i32 %[[SEED]], float %[[SCALE]], i32 0) + %bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32 + + llvm.return +} llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) { // CHECK-LABEL: @rocdl_atomic_attrs // CHECK: atomicrmw