[mlir][ROCDL] Add fp4 and fp6 conversion intrinsics, fix fp8 immargs (#140801)

This PR adds support for the scaled conversion intrinsics for fp4 and
fp6 types so that they can be targetted by a future amdgpu dialect op or
used directly.

Additionally, this patch refactors the copy-paste-heavy fp8 versions of
these scaled conversion intrinsics with tablegen `foreach` loops, and
fixes the fact that certain immargs weren't being stored as attributes.

Note that some of the MLIR-level tests for those scaled fp8 intrinsics
had incorrect return types, which have been fixed.

(Note that while the operations have a known return type, the IR format
still prints that type for clarity).
This commit is contained in:
Krzysztof Drewniak
2025-05-21 13:50:02 -07:00
committed by GitHub
parent 6212c199b1
commit 6c813e8a3c
6 changed files with 605 additions and 449 deletions

View File

@@ -709,20 +709,23 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
}];
}
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getI16Type())">;
class ROCDL_ConcreteVector<Type elem, int length> :
FixedVectorOfLengthAndType<[length], [elem]>,
BuildableType<
"::mlir::VectorType::get({" # length # "} ,"
# elem.builderCall # ")">;
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getF16Type())">;
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getBF16Type())">;
// TODO: The word and byte selectors are immarg in LLVM
// update to be attributes in MLIR
//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
@@ -738,279 +741,12 @@ def ROCDL_CvtPkRtz:
}];
}
def ROCDL_CvtScaleF32PkFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed fp8";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed fp8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf16 to packed fp8";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed fp8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert f16 to packed bf8";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed bf8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert bf16 to packed bf8";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed bf8.
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32SrFp8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32SrBf8F16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32SrFp8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32SrBf8Bf16Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Convert fp8 to packed f16 and scale";
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "convert bf8 to packed f16 and scale";
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
the packed values by exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkBf16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Convert fp8 to packed bf16 and scale";
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkBf16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Convert bf8 to packed bf16 and scale";
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
the packed values by the exponent in `scale`.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF16Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF16Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
let summary = "Scale and convert fp8 to f16";
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
//===---------------------------------------------------------------------===//
// 32-bit float intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtScaleF32PkF32Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed fp8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkF32Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert packed bf8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
the exponent in `scale`. Store the result in a vector.
}];
let assemblyFormat = [{
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
}];
}
//===---------------------------------------------------------------------===//
// 8-bit float scale intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtScaleF32PkFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed fp8";
let description = [{
Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed fp8
and store into the low/high word of `old`, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32PkBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
let summary = "Scale and convert two f32's to packed bf8";
let description = [{
Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed bf8
and store into the low/high word of `old`, preserving the other word.
}];
let assemblyFormat = [{
attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32SrFp8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale` then convert to fp8 with stochastic rounding
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32SrBf8F32Op :
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
let description = [{
Scale `src` by the exponent in `scale` then convert to bf8 with stochastic rounding
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
}];
}
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
def ROCDL_CvtF32Bf8Op :
ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$srcA, I32:$byteSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.f32.bf8", [Pure], 1, [1], ["byteSel"]>,
Arguments<(ins I32:$srcA, I32Attr:$byteSel)> {
let summary = "Convert bf8 to f32";
let description = [{
Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32.
@@ -1020,23 +756,9 @@ def ROCDL_CvtF32Bf8Op :
}];
}
def ROCDL_CvtScaleF32Bf8Op :
ROCDL_IntrOp<"cvt.scalef32.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert bf8 to f32";
let description = [{
Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
from the `byteSel`th bit of `src` to fp32.
}];
let assemblyFormat = [{
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtF32Fp8Op :
ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$srcA, I32:$byteSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.f32.fp8", [Pure], 1, [1], ["byteSel"]>,
Arguments<(ins I32:$srcA, I32Attr:$byteSel)> {
let summary = "Convert fp8 to f32";
let description = [{
Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32.
@@ -1046,24 +768,9 @@ def ROCDL_CvtF32Fp8Op :
}];
}
def ROCDL_CvtScaleF32Fp8Op :
ROCDL_IntrOp<"cvt.scalef32.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
let summary = "Scale and convert fp8 to f32";
let description = [{
Scale `src` by the exponent in `scale` then convert 8-bit fp8 value
from the `byteSel`th bit of `src` to fp32.
}];
let assemblyFormat = [{
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtPkF32Fp8Op :
ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, I1:$wordSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.fp8", [Pure], 1, [1], ["wordSel"]>,
Arguments<(ins I32:$src, I1Attr:$wordSel)> {
let summary = "Convert packed fp8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32.
@@ -1074,8 +781,8 @@ def ROCDL_CvtPkF32Fp8Op :
}
def ROCDL_CvtPkF32Bf8Op :
ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
Arguments<(ins I32:$src, I1:$wordSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.bf8", [Pure], 1, [1], ["wordSel"]>,
Arguments<(ins I32:$src, I1Attr:$wordSel)> {
let summary = "Convert packed bf8 to packed f32";
let description = [{
Convert `src` based on $wordSel to packed fp32,
@@ -1086,8 +793,8 @@ def ROCDL_CvtPkF32Bf8Op :
}
def ROCDL_CvtPkBf8F32Op :
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.bf8.f32", [Pure], 1, [3], ["wordSel"]>,
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> {
let summary = "Convert two f32's to bf8";
let description = [{
Convert `srcA` and `srcB` to bf8 and store into the low/high word of
@@ -1099,8 +806,8 @@ def ROCDL_CvtPkBf8F32Op :
}
def ROCDL_CvtPkFp8F32Op :
ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.fp8.f32", [Pure], 1, [3], ["wordSel"]>,
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> {
let summary = "Convert two f32's to fp8";
let description = [{
Convert `srcA` and `srcB` to fp8 and store into the low/high word of
@@ -1112,8 +819,8 @@ def ROCDL_CvtPkFp8F32Op :
}
def ROCDL_CvtSrBf8F32Op :
ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.sr.bf8.f32", [Pure], 1, [3], ["byteSel"]>,
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> {
let summary = "Convert f32 to bf8, stochiastic rounding";
let description = [{
Convert `srcA` to bf8, adding the rounding factor from `srcB`,
@@ -1125,8 +832,8 @@ def ROCDL_CvtSrBf8F32Op :
}
def ROCDL_CvtSrFp8F32Op :
ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
ROCDL_ConcreteNonMemIntrOp<"cvt.sr.fp8.f32", [Pure], 1, [3], ["byteSel"]>,
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> {
let summary = "Convert f32 to fp8, stochiastic rounding";
let description = [{
Convert `srcA` to fp8, adding the rounding factor from `srcB`,
@@ -1137,6 +844,335 @@ def ROCDL_CvtSrFp8F32Op :
}];
}
//===---------------------------------------------------------------------===//
// Scaled float conversion intrinsics
//
// These are using some tablegen trickery to avoid repetitive documentation
//===---------------------------------------------------------------------===//
// Pair used so we can iterate over types..
class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
TypeConstraint type = argTyVal;
string name = !tolower(typeName);
string nameForOp = typeName;
}
//===---------------------------------------------------------------------===//
// Scaled 32x6-bit float float conversion intrinsics
//===---------------------------------------------------------------------===//
foreach smallT = [
// MLIR f6E2M3FN
ScaleArgInfo<ROCDL_V6I32Type, "Fp6">,
// MLIR f8E3M2FN
ScaleArgInfo<ROCDL_V6I32Type, "Bf6">
] in {
foreach largeT = [
ScaleArgInfo<ROCDL_V32F16Type, "F16">,
ScaleArgInfo<ROCDL_V32BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V32F32Type, "F32">,
] in {
// Note: rouding down f32 values has a special case where
// we have to use 2 16xf32 arguments.
if !ne(largeT.name, "f32") then {
def ROCDL_CvtScaleF32Pk32 # smallT.nameForOp # largeT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # smallT.name # "." # largeT.name,
[Pure], 1>,
Arguments<(ins largeT.type:$src, F32:$scale)> {
let results = (outs smallT.type:$res);
let summary = "Scale and convert packed "
# largeT.name # " to packed " # smallT.name;
let description = [{
Convert 32 packed }] # largeT.name # [{ values to packed }]
# smallT.name # [{, dividing by the exponent part of `scale`
before doing so.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `:` type($res)
}];
}
} // if
def ROCDL_CvtScaleF32SrPk32 # smallT.nameForOp # largeT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk32." # smallT.name # "." # largeT.name,
[Pure], 1>,
Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
let results = (outs smallT.type:$res);
let summary = "Scale and convert packed "
# largeT.name # " to packed " # smallT.name
# " with stochiastic rounding";
let description = [{
Convert 32 packed }] # largeT.name # [{ values to packed }]
# smallT.name # [{, dividing by the exponent part of `scale`
before doing so and applying random rounding derived from
`seed`.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF32Pk32 # largeT.nameForOp # smallT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # largeT.name # "." # smallT.name,
[Pure], 1>,
Arguments<(ins smallT.type:$src, F32:$scale)> {
let results = (outs largeT.type:$res);
let summary = "Scale and convert packed "
# smallT.name # " to packed " # largeT.name;
let description = [{
Convert 32 packed }] # smallT.name # [{ values to packed }]
# largeT.name # [{, multiplying by the exponent part of `scale`
before doing so.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `:` type($res)
}];
}
} // foreach largeT
def ROCDL_CvtScaleF322xPk16 # smallT.nameForOp # F32Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.2xpk16." # smallT.name # ".f32",
[Pure], 1>,
Arguments<(ins ROCDL_V16F32Type:$src0, ROCDL_V16F32Type:$src1, F32:$scale)> {
let results = (outs smallT.type:$res);
let summary = "Scale and convert two vector<16xf32> to 32 packed " # smallT.name;
let description = [{
Convert 32 single-precision float values, packed into two length-16
vectors that will be logically concanenated, to packed }]
# smallT.name # [{, dividing by the exponent part of `scale`
before doing so.
}];
let assemblyFormat = [{
attr-dict $src0 `,` $src1 `,` $scale `:` type($res)
}];
}
} // forach smallT
//===---------------------------------------------------------------------===//
// Scaled conversions to/from fp8/bf8 (f8E4M3FN / f8E5M2)
//===---------------------------------------------------------------------===//
foreach smallTOp = ["Fp8", "Bf8"] in {
defvar smallT = !tolower(smallTOp);
def ROCDL_CvtScaleF32F16 # smallTOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f16." # smallT,
[Pure], 1, [3, 4], ["srcSelIndex", "dstLoHiSel"]>,
Arguments<(ins ROCDL_V2F16Type:$oldVdst, I32:$src, F32:$scale, I32Attr:$srcSelIndex, I1Attr:$dstLoHiSel)> {
let results = (outs ROCDL_V2F16Type:$res);
let summary = "Scaled convert " # smallT # " from packed vector to f16, updating tied result";
let description = [{
Convert a }] # smallT # [{ byte from `src`, selected by
`srcSelIndex`, to f16 while multiplying it by the expontent of `scale`,
and place it into the `dstLoHiSel`th bit
of `oldVdst` preserving the other element of that vector in
the return value.
The bytes are stored as an `i32` and not a `<4 x i8>`.
}];
let assemblyFormat = [{
attr-dict $src `[` $srcSelIndex `]` `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32F32 # smallTOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f32." # smallT,
[Pure], 1, [2], ["srcSelIndex"]>,
Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> {
let results = (outs F32:$res);
let summary = "Scaled convert " # smallT # " from packed vector to f32";
let description = [{
Convert a }] # smallT # [{ byte from `src`, selected by
`srcSelIndex`, to f32, multiplying it by the exponent of `scale`.
The bytes are stored in an `i32`, not a `<4 x i8>`.
}];
let assemblyFormat = [{
attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res)
}];
}
def ROCDL_CvtScaleF32Pk # smallTOp # F32Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # ".f32",
[Pure], 1, [4], ["dstLoHiSel"]>,
Arguments<(ins ROCDL_V2I16Type:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I1Attr:$dstLoHiSel)> {
let results = (outs ROCDL_V2I16Type:$res);
let summary = "Scaled convert two f32 to two " # smallT # ", updating packed vector";
let description = [{
Convert two f32 values in `src0` and `src1` to two }] # smallT # [{ bytes,
dividing by the exponent in `scale`. The bytes are packed into
a 16-bit value which is inserted into `oldVdst` at the `dstLoHiSel`
position, with the entire updated vector being returned.
}];
let assemblyFormat = [{
attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
}];
}
foreach largeT = [
ScaleArgInfo<ROCDL_V2F16Type, "F16">,
ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
] in {
def ROCDL_CvtScaleF32Pk # smallTOp # largeT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # "." # largeT.name,
[Pure], 1, [3], ["dstLoHiSel"]>,
Arguments<(ins ROCDL_V2I16Type:$oldVdst, largeT.type:$src0, F32:$scale, I1Attr:$dstLoHiSel)> {
let results = (outs ROCDL_V2I16Type:$res);
let summary = "Scaled convert two " # largeT.name # "to two " # smallT # ", updating packed vector";
let description = [{
Convert two }] # largeT.name # [{ values in `src0` to two }]
# smallT # [{ bytes, dividing by the exponent in `scale`. The bytes are
packed into a 16-bit value which is inserted into `oldVdst` at the
`dstLoHiSel` position, with the entire updated vector being returned.
}];
let assemblyFormat = [{
attr-dict $src0 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
}];
}
} // foreach largeT
foreach largeT = [
ScaleArgInfo<ROCDL_V2F16Type, "F16">,
ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V2F32Type, "F32">
] in {
def ROCDL_CvtScaleF32Pk # largeT.nameForOp # smallTOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # "." # smallT,
[Pure], 1, [2], ["srcLoHiSel"]>,
Arguments<(ins I32:$src, F32:$scale, I1Attr:$srcLoHiSel)> {
let results = (outs largeT.type:$res);
let summary = "Scaled convert two " # smallT # "to two " # largeT.name #;
let description = [{
Convert two packed }] # smallT # [{ values in `src0` to two }]
# largeT.name # [{ values, multiplying by the exponent in `scale`.
The two values to be converted are selected from the low or high half
of `src` (a packed vector represented as an `i32`)
on the basis of `srcLoHiSel`.
}];
let assemblyFormat = [{
attr-dict $src `[` $srcLoHiSel `]` `,` $scale `:` type($res)
}];
}
} // foreach largeT
foreach largeT = [
ScaleArgInfo<F32, "F32">,
ScaleArgInfo<F16, "F16">,
ScaleArgInfo<BF16, "BF16">
] in {
def ROCDL_CvtScaleF32Sr # smallTOp # largeT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr." # smallT # "." # largeT.name,
[Pure], 1, [4], ["dstSelIndex"]>,
Arguments<(ins I32:$oldVdst, largeT.type:$src0, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> {
let results = (outs I32:$res);
let summary = "Scaled convert " # largeT.name # "to " # smallT # " with stochiastic rounding, updating packed vector";
let description = [{
Convert a }] # largeT.name # [{ value in `src0` to a }]
# smallT # [{ bytes, dividing by the exponent in `scale` and using `seed`
for stochiastic rounding. Place the resulting byte in the
`dstSelIndex`th bit of `oldVdst` and return the entire packed vector,
which is stored as an `i32`.
}];
let assemblyFormat = [{
attr-dict $src0 `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
}];
}
} // foreach largeT
} // foreach smallTOp
//===---------------------------------------------------------------------===//
// Scaled conversions to/from fp4 (f4E2M1FN)
//===---------------------------------------------------------------------===//
foreach largeT = [
ScaleArgInfo<ROCDL_V2F16Type, "F16">,
ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
ScaleArgInfo<ROCDL_V2F32Type, "F32">,
] in {
// Note: rouding down f32 values has a special case where
// we have to use 2 float arguments.
if !ne(largeT.name, "f32") then {
def ROCDL_CvtScaleF32PkFp4 # largeT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4." # largeT.name,
[Pure], 1, [3], ["dstSelIndex"]>,
Arguments<(ins I32:$oldVdst, largeT.type:$src, F32:$scale, I32Attr:$dstSelIndex)> {
let results = (outs I32:$res);
let summary = "Scale and convert two "
# largeT.name # " to packed fp4, updating tied vector";
let description = [{
Convert two packed }] # largeT.name # [{ values to packed
fp4, dividing by the exponent part of `scale`
before doing so.
The two scaled values are packed into a byte.
That byte is used to update the `dstSelIndex`th
byte of `oldVdst`, which is returned in its entirity.
}];
let assemblyFormat = [{
attr-dict $src `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
}];
}
} // if
def ROCDL_CvtScaleF32SrPkFp4 # largeT.nameForOp # Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk.fp4." # largeT.name,
[Pure], 1, [4], ["dstSelIndex"]>,
Arguments<(ins I32:$oldVdst, largeT.type:$src, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> {
let results = (outs I32:$res);
let summary = "Scale and convert two "
# largeT.name # " to packed fp4 with stochiastic rounding, updating tied vector";
let description = [{
Convert two packed }] # largeT.name # [{ values to packed
fp4, dividing by the exponent part of `scale`
before doing so and using `seed` as the random seed for
stochiastic rounding.
The two scaled values are packed (little-endian)
into a byte. That byte is used to update the `dstSelIndex`th
byte of `oldVdst`, which is returned in its entirity.
}];
let assemblyFormat = [{
attr-dict $src `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
}];
}
def ROCDL_CvtScaleF32Pk # largeT.nameForOp # Fp4Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # ".fp4",
[Pure], 1, [2], ["srcSelIndex"]>,
Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> {
let results = (outs largeT.type:$res);
let summary = "Scale and convert two packed fp4 to packed " # largeT.name;
let description = [{
Convert two packed fp4 (f4E2M1) values stored as one byte of a 32-bit integer
to packed }] # largeT.name # [{, multiplying by the exponent part of `scale`
before doing so.
The byte to convert is chosen by `srcSelIndex`.
}];
let assemblyFormat = [{
attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res)
}];
}
} // foreach largeT
def ROCDL_CvtScaleF32PkFp4F32Op :
ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4.f32",
[Pure], 1, [4], ["dstSelIndex"]>,
Arguments<(ins I32:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I32Attr:$dstSelIndex)> {
let results = (outs I32:$res);
let summary = "Scale and convert two f32 values to two packed fp4, updating tied vector";
let description = [{
Convert two single-precision float values, passed in `src0` and `src1`
into two fp4 values, dividing them by the expontent part of `scale`
before doing so.
The two scaled values are packed into a byte.
That byte is used to update the `dstSelIndex`th
byte of `oldVdst`, which is returned in its entirity.
}];
let assemblyFormat = [{
attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
}];
}
//===----------------------------------------------------------------------===//
// ROCDL target attribute.
//===----------------------------------------------------------------------===//

View File

@@ -1210,22 +1210,20 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
if (resultVecType) {
Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
wordSel);
op.getIndex());
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
wordSel);
op.getIndex());
}
} else {
Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
byteSel);
op.getIndex());
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
byteSel);
op.getIndex());
}
}
return success();
@@ -1253,15 +1251,14 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
existing, op.getWordIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
existing, wordSel);
existing, op.getWordIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);
@@ -1288,15 +1285,14 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
else
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
Value result;
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
existing, byteSel);
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
loc, i32, source, stoch, existing, op.getStoreIndex());
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
existing, byteSel);
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
loc, i32, source, stoch, existing, op.getStoreIndex());
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
op, getTypeConverter()->convertType(resultType), result);

View File

@@ -7,8 +7,7 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_scalar(%v: f8E5M2) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
@@ -25,8 +24,7 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
@@ -36,8 +34,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
// CHECK-LABEL: func @ext_full_vec(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
@@ -54,8 +51,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32>
// CHECK: return [[EXT]]
func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
@@ -65,8 +61,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
// CHECK-LABEL: func @ext_packed_4xfp8
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32>
// CHECK: return [[EXT]] : vector<2xf32>
func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
@@ -77,8 +72,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
@@ -89,8 +83,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
// CHECK-LABEL: func @packed_truncx2
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
@@ -102,8 +95,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
@@ -114,8 +106,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) ->
// CHECK-LABEL: func @packed_stoch_round
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
@@ -127,8 +118,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {

View File

@@ -6,8 +6,7 @@
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
@@ -24,8 +23,7 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
@@ -35,8 +33,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
// CHECK-LABEL: func @ext_full_vec
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32
// CHECK: return [[EXT]] : f32
func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
@@ -53,8 +50,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32>
// CHECK: return [[EXT]]
func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
@@ -64,8 +60,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
// CHECK-LABEL: func @ext_packed_4xfp8(
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32>
// CHECK: return [[EXT]] : vector<2xf32>
func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
@@ -76,8 +71,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
// CHECK-SAME: ([[V:%.+]]: f32)
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
@@ -88,8 +82,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
// CHECK-LABEL: func @packed_truncx2
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
@@ -101,8 +94,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
@@ -113,8 +105,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>
// CHECK-LABEL: func @packed_stoch_round
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
@@ -126,8 +117,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {

View File

@@ -789,36 +789,32 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
%c0 = llvm.mlir.constant(0 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%c3 = llvm.mlir.constant(3 : i32) : i32
%c4 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
%v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16>
%v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16>
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32>
%source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32>
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
%v1 = rocdl.cvt.f32.bf8 %source[0] : f32
%v2 = rocdl.cvt.f32.fp8 %source[0] : f32
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16>
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16>
%v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16>
%v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16>
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16>
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16>
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32
%source2_ext = rocdl.cvt.pk.f32.bf8 %source[false] : vector<2xf32>
%source3_ext = rocdl.cvt.pk.f32.fp8 %source[false] : vector<2xf32>
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32>
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32>
llvm.return %source5 : i32
}
@@ -826,9 +822,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
// CHECK-LABEL: @rocdl_8bit_packed_v2i16
// CHECK: rocdl.cvt.scalef32.pk.fp8.f32
%c0 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
@@ -836,14 +831,91 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
// CHECK-LABEL: @rocdl_v2f16_v2i16
// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
%c0 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16>
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
// CHECK-LABEL: @rocdl_6_bit_floats
// CHECK-SAME: (%[[V32F6:.+]]: vector<6xi32>, %[[V16F32:.+]]: vector<16xf32>, %[[V32F32:.+]]: vector<32xf32>, %[[V32F16:.+]]: vector<32xf16>, %[[V32BF16:.+]]: vector<32xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32)
llvm.func @rocdl_6_bit_floats(
%v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>,
%v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32,
%scale: f32) {
// CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.bf6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]]
%f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.fp6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]]
%f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.f16 %[[V32F16]], %[[SCALE]]
%f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.f16 %[[V32F16]], %[[SCALE]]
%f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.bf16 %[[V32BF16]], %[[SCALE]]
%bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.bf16 %[[V32BF16]], %[[SCALE]]
%bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.bf6 %[[V32F6]], %[[SCALE]]
%bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.fp6 %[[V32F6]], %[[SCALE]]
%fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.bf6 %[[V32F6]], %[[SCALE]]
%bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.fp6 %[[V32F6]], %[[SCALE]]
%fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.bf6 %[[V32F6]], %[[SCALE]]
%bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16>
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.fp6 %[[V32F6]], %[[SCALE]]
%fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]]
%f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]]
%f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]]
%f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]]
%f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]]
%bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]]
%bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
llvm.return
}
// CHECK-LABEL: @rocdl_4_bit_floats
// CHECK-SAME: (%[[V8F4:.+]]: i32, %[[F32:.+]]: f32, %[[V2F32:.+]]: vector<2xf32>, %[[V2F16:.+]]: vector<2xf16>, %[[V2BF16:.+]]: vector<2xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32)
llvm.func @rocdl_4_bit_floats(
%v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>,
%v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) {
// CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f32 %[[F32]], %[[F32]], %[[SCALE]] -> %[[V8F4]][0]
%f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32
// CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f16 %[[V2F16]], %[[SCALE]] -> %[[V8F4]][1]
%f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32
// CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.bf16 %[[V2BF16]], %[[SCALE]] -> %[[V8F4]][0]
%bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32
// CHECK-NEXT: rocdl.cvt.scalef32.pk.f32.fp4 %[[V8F4]][0], %[[SCALE]]
%fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32>
// CHECK-NEXT: rocdl.cvt.scalef32.pk.f16.fp4 %[[V8F4]][1], %[[SCALE]]
%fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16>
// CHECK-NEXT: rocdl.cvt.scalef32.pk.bf16.fp4 %[[V8F4]][0], %[[SCALE]]
%fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16>
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f32 %[[V2F32]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0]
%f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f16 %[[V2F16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][1]
%f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.bf16 %[[V2BF16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0]
%bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32
llvm.return
}
llvm.func @rocdl.s.waitcnt() {
// CHECK-LABEL: rocdl.s.waitcnt
// CHECK: rocdl.s.waitcnt 0

View File

@@ -1081,34 +1081,31 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
%c0 = llvm.mlir.constant(0 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%c3 = llvm.mlir.constant(3 : i32) : i32
%c4 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
%v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32
%v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
%v1 = rocdl.cvt.f32.bf8 %source[0] : f32
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32
%v2 = rocdl.cvt.f32.fp8 %source[0] : f32
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16>
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16>
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %source_packed[false] : vector<2xf16>
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %source_packed[false] : vector<2xf16>
%v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16>
%v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16>
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32>
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32>
llvm.return %source5 : i32
}
@@ -1117,9 +1114,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
%c0 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
@@ -1130,11 +1126,10 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
%c0 = llvm.mlir.constant(1.0 : f32) : f32
%false = llvm.mlir.constant(false) : i1
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16>
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16>
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
llvm.return %source_scaled : vector<2xi16>
}
@@ -1145,6 +1140,83 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
llvm.return %source : vector<2xf16>
}
// CHECK-LABEL: @rocdl_6_bit_floats
// CHECK-SAME: (<6 x i32> %[[V32F6:.+]], <16 x float> %[[V16F32:.+]], <32 x float> %[[V32F32:.+]], <32 x half> %[[V32F16:.+]], <32 x bfloat> %[[V32BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
llvm.func @rocdl_6_bit_floats(
%v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>,
%v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32,
%scale: f32) {
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.bf6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]])
%f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.fp6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]])
%f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.f16(<32 x half> %[[V32F16]], float %[[SCALE]])
%f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.f16(<32 x half> %[[V32F16]], float %[[SCALE]])
%f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]])
%bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]])
%bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32>
// CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
%bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32>
// CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
%fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32>
// CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
%bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16>
// CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
%fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16>
// CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
%bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16>
// CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
%fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]])
%f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]])
%f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]])
%f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]])
%f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]])
%bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]])
%bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
llvm.return
}
// CHECK-LABEL: @rocdl_4_bit_floats
// CHECK-SAME: (i32 %[[V8F4:.+]], float %[[F32:.+]], <2 x float> %[[V2F32:.+]], <2 x half> %[[V2F16:.+]], <2 x bfloat> %[[V2BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
llvm.func @rocdl_4_bit_floats(
%v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>,
%v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) {
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %[[V8F4]], float %[[F32]], float %[[F32]], float %[[SCALE]], i32 0)
%f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], float %[[SCALE]], i32 1)
%f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], float %[[SCALE]], i32 0)
%bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32
// CHECK-NEXT: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0)
%fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32>
// CHECK-NEXT: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 1)
%fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16>
// CHECK-NEXT: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0)
%fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16>
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %[[V8F4]], <2 x float> %[[V2F32]], i32 %[[SEED]], float %[[SCALE]], i32 0)
%f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], i32 %[[SEED]], float %[[SCALE]], i32 1)
%f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], i32 %[[SEED]], float %[[SCALE]], i32 0)
%bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32
llvm.return
}
llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
// CHECK-LABEL: @rocdl_atomic_attrs
// CHECK: atomicrmw