[mlir][ROCDL] Add fp4 and fp6 conversion intrinsics, fix fp8 immargs (#140801)
This PR adds support for the scaled conversion intrinsics for fp4 and fp6 types so that they can be targeted by a future amdgpu dialect op or used directly. Additionally, this patch refactors the copy-paste-heavy fp8 versions of these scaled conversion intrinsics with tablegen `foreach` loops, and fixes the fact that certain immargs weren't being stored as attributes. Note that some of the MLIR-level tests for those scaled fp8 intrinsics had incorrect return types, which have been fixed. (Note that while the operations have a known return type, the IR format still prints that type for clarity.)
This commit is contained in:
committed by
GitHub
parent
6212c199b1
commit
6c813e8a3c
@@ -709,20 +709,23 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
|
||||
BuildableType<"::mlir::VectorType::get("
|
||||
"{2},$_builder.getI16Type())">;
|
||||
// A fixed-length vector type constraint over a single element type that is
// also buildable: the builder call is assembled by pasting `length` and the
// element type's own `builderCall` into a `::mlir::VectorType::get(...)` call.
class ROCDL_ConcreteVector<Type elem, int length> :
  FixedVectorOfLengthAndType<[length], [elem]>,
  BuildableType<
    "::mlir::VectorType::get({" # length # "} ,"
    # elem.builderCall # ")">;
|
||||
|
||||
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
|
||||
BuildableType<"::mlir::VectorType::get("
|
||||
"{2},$_builder.getF16Type())">;
|
||||
def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
|
||||
def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
|
||||
def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
|
||||
def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
|
||||
def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
|
||||
def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
|
||||
def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
|
||||
def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
|
||||
def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
|
||||
def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
|
||||
|
||||
def ROCDL_V2BF16Type : FixedVectorOfLengthAndType<[2], [BF16]>,
|
||||
BuildableType<"::mlir::VectorType::get("
|
||||
"{2},$_builder.getBF16Type())">;
|
||||
|
||||
// TODO: The word and byte selectors are immarg in LLVM
|
||||
// update to be attributes in MLIR
|
||||
//===---------------------------------------------------------------------===//
|
||||
// 16-bit float intrinsics
|
||||
//===---------------------------------------------------------------------===//
|
||||
@@ -738,279 +741,12 @@ def ROCDL_CvtPkRtz:
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkFp8F16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f16", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert f16 to packed fp8";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed fp8.
|
||||
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkFp8Bf16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.bf16", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert packed bf16 to packed fp8";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed fp8.
|
||||
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def ROCDL_CvtScaleF32PkBf8F16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f16", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2F16Type: $src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert f16 to packed bf8";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed bf8.
|
||||
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def ROCDL_CvtScaleF32PkBf8Bf16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.bf16", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2I16Type: $old, ROCDL_V2BF16Type: $src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert bf16 to packed bf8";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed bf8.
|
||||
Store the result in low/high word of `old` based on $wordSel, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32SrFp8F16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f16", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert f16 to packed fp8 using stochastic rounding";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
|
||||
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
|
||||
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32SrBf8F16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f16", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$old, F16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert f16 to packed bf8 using stochastic rounding";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed bf8 with stochastic rounding
|
||||
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
|
||||
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32SrFp8Bf16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.bf16", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert packed bf16 to packed fp8 using stochastic rounding";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed fp8 with stochastic rounding
|
||||
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
|
||||
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32SrBf8Bf16Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.bf16", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$old, BF16:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert bf16 to packed fp8 using stochastic rounding";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale`, then convert to packed p8 with stochastic rounding
|
||||
using seed data in `seed`. Store into the `byteSel`th byte of `old`, preserving the others.
|
||||
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkF16Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Convert fp8 to packed f16 and scale";
|
||||
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
|
||||
the packed values by the exponent in `scale`.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkF16Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.f16.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "convert bf8 to packed f16 and scale";
|
||||
let description = [{ Convert `src` based on $wordSel to packed f16, then scale
|
||||
the packed values by exponent in `scale`.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkBf16Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Convert fp8 to packed bf16 and scale";
|
||||
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
|
||||
the packed values by the exponent in `scale`.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkBf16Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.bf16.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Convert bf8 to packed bf16 and scale";
|
||||
let description = [{ Convert `src` based on $wordSel to packed bf16, then scale
|
||||
the packed values by the exponent in `scale`.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF16Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.f16.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
|
||||
let summary = "Scale and convert fp8 to f16";
|
||||
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
|
||||
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
|
||||
preserving the others.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF16Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.f16.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2F16Type:$old, I32:$src, F32: $scale, I32:$byteSel, I1:$wordSel)> {
|
||||
let summary = "Scale and convert fp8 to f16";
|
||||
let description = [{ Convert `src` based on $wordSel to f16, then scale the value
|
||||
by the exponent in `scale`. Store the result into the `byteSel`th byte of `old`,
|
||||
preserving the others.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// 32-bit float intrinsics
|
||||
//===---------------------------------------------------------------------===//
|
||||
def ROCDL_CvtScaleF32PkF32Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert packed fp8 to packed f32";
|
||||
let description = [{
|
||||
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
|
||||
the exponent in `scale`. Store the result in a vector.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
def ROCDL_CvtScaleF32PkF32Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.f32.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert packed bf8 to packed f32";
|
||||
let description = [{
|
||||
Convert `src` based on $wordSel to packed fp32, then scale the packed values by
|
||||
the exponent in `scale`. Store the result in a vector.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $wordSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
//===---------------------------------------------------------------------===//
|
||||
// 8-bit float scale intrinsics
|
||||
//===---------------------------------------------------------------------===//
|
||||
def ROCDL_CvtScaleF32PkFp8F32Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.fp8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32:$scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert two f32's to packed fp8";
|
||||
let description = [{
|
||||
Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed fp8
|
||||
and store into the low/high word of `old`, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32PkBf8F32Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.pk.bf8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins ROCDL_V2I16Type:$old, F32:$srcA, F32:$srcB, F32: $scale, I1:$wordSel)> {
|
||||
let summary = "Scale and convert two f32's to packed bf8";
|
||||
let description = [{
|
||||
Scale `srcA` and `srcB` by the exponent in `scale` then convert to packed bf8
|
||||
and store into the low/high word of `old`, preserving the other word.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $srcA `,` $srcB `,` $scale `->` $old `[` $wordSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32SrFp8F32Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.sr.fp8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert f32 to fp8 using stochastic rounding";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale` then convert to fp8 with stochastic rounding
|
||||
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def ROCDL_CvtScaleF32SrBf8F32Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.sr.bf8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$old, F32:$src, I32:$seed, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert f32 to bf8 using stochastic rounding";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale` then convert to bf8 with stochastic rounding
|
||||
using seed data in `seed`. store into the `byteSel`th byte of `old`, preserving the others.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// 8-bit float intrinsics
|
||||
//===---------------------------------------------------------------------===//
|
||||
def ROCDL_CvtF32Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.f32.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$srcA, I32:$byteSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.f32.bf8", [Pure], 1, [1], ["byteSel"]>,
|
||||
Arguments<(ins I32:$srcA, I32Attr:$byteSel)> {
|
||||
let summary = "Convert bf8 to f32";
|
||||
let description = [{
|
||||
Convert 8-bit bf8 value from the `byteSel`th bit of `srcA` to fp32.
|
||||
@@ -1020,23 +756,9 @@ def ROCDL_CvtF32Bf8Op :
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtScaleF32Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.f32.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert bf8 to f32";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
|
||||
from the `byteSel`th bit of `src` to fp32.
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def ROCDL_CvtF32Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.f32.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$srcA, I32:$byteSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.f32.fp8", [Pure], 1, [1], ["byteSel"]>,
|
||||
Arguments<(ins I32:$srcA, I32Attr:$byteSel)> {
|
||||
let summary = "Convert fp8 to f32";
|
||||
let description = [{
|
||||
Convert 8-bit fp8 value from the `byteSel`th bit of `srcA` to fp32.
|
||||
@@ -1046,24 +768,9 @@ def ROCDL_CvtF32Fp8Op :
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
def ROCDL_CvtScaleF32Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.scalef32.f32.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
|
||||
let summary = "Scale and convert fp8 to f32";
|
||||
let description = [{
|
||||
Scale `src` by the exponent in `scale` then convert 8-bit fp8 value
|
||||
from the `byteSel`th bit of `src` to fp32.
|
||||
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
|
||||
}];
|
||||
}
|
||||
|
||||
def ROCDL_CvtPkF32Fp8Op :
|
||||
ROCDL_IntrOp<"cvt.pk.f32.fp8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, I1:$wordSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.fp8", [Pure], 1, [1], ["wordSel"]>,
|
||||
Arguments<(ins I32:$src, I1Attr:$wordSel)> {
|
||||
let summary = "Convert packed fp8 to packed f32";
|
||||
let description = [{
|
||||
Convert `src` based on $wordSel to packed fp32.
|
||||
@@ -1074,8 +781,8 @@ def ROCDL_CvtPkF32Fp8Op :
|
||||
}
|
||||
|
||||
def ROCDL_CvtPkF32Bf8Op :
|
||||
ROCDL_IntrOp<"cvt.pk.f32.bf8", [], [], [Pure], 1>,
|
||||
Arguments<(ins I32:$src, I1:$wordSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.f32.bf8", [Pure], 1, [1], ["wordSel"]>,
|
||||
Arguments<(ins I32:$src, I1Attr:$wordSel)> {
|
||||
let summary = "Convert packed bf8 to packed f32";
|
||||
let description = [{
|
||||
Convert `src` based on $wordSel to packed fp32,
|
||||
@@ -1086,8 +793,8 @@ def ROCDL_CvtPkF32Bf8Op :
|
||||
}
|
||||
|
||||
def ROCDL_CvtPkBf8F32Op :
|
||||
ROCDL_IntrOp<"cvt.pk.bf8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.bf8.f32", [Pure], 1, [3], ["wordSel"]>,
|
||||
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> {
|
||||
let summary = "Convert two f32's to bf8";
|
||||
let description = [{
|
||||
Convert `srcA` and `srcB` to bf8 and store into the low/high word of
|
||||
@@ -1099,8 +806,8 @@ def ROCDL_CvtPkBf8F32Op :
|
||||
}
|
||||
|
||||
def ROCDL_CvtPkFp8F32Op :
|
||||
ROCDL_IntrOp<"cvt.pk.fp8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1:$wordSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.pk.fp8.f32", [Pure], 1, [3], ["wordSel"]>,
|
||||
Arguments<(ins F32:$srcA, F32:$srcB, I32:$old, I1Attr:$wordSel)> {
|
||||
let summary = "Convert two f32's to fp8";
|
||||
let description = [{
|
||||
Convert `srcA` and `srcB` to fp8 and store into the low/high word of
|
||||
@@ -1112,8 +819,8 @@ def ROCDL_CvtPkFp8F32Op :
|
||||
}
|
||||
|
||||
def ROCDL_CvtSrBf8F32Op :
|
||||
ROCDL_IntrOp<"cvt.sr.bf8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.sr.bf8.f32", [Pure], 1, [3], ["byteSel"]>,
|
||||
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> {
|
||||
let summary = "Convert f32 to bf8, stochiastic rounding";
|
||||
let description = [{
|
||||
Convert `srcA` to bf8, adding the rounding factor from `srcB`,
|
||||
@@ -1125,8 +832,8 @@ def ROCDL_CvtSrBf8F32Op :
|
||||
}
|
||||
|
||||
def ROCDL_CvtSrFp8F32Op :
|
||||
ROCDL_IntrOp<"cvt.sr.fp8.f32", [], [], [Pure], 1>,
|
||||
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32:$byteSel)> {
|
||||
ROCDL_ConcreteNonMemIntrOp<"cvt.sr.fp8.f32", [Pure], 1, [3], ["byteSel"]>,
|
||||
Arguments<(ins F32:$srcA, I32:$srcB, I32:$old, I32Attr:$byteSel)> {
|
||||
let summary = "Convert f32 to fp8, stochiastic rounding";
|
||||
let description = [{
|
||||
Convert `srcA` to fp8, adding the rounding factor from `srcB`,
|
||||
@@ -1137,6 +844,335 @@ def ROCDL_CvtSrFp8F32Op :
|
||||
}];
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Scaled float conversion intrinsics
|
||||
//
|
||||
// These are using some tablegen trickery to avoid repetitive documentation
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
// Pair used so we can iterate over types.
|
||||
// Bundles a type constraint with two spellings of its name:
//   - `type`      : the constraint itself,
//   - `name`      : `typeName` lower-cased via `!tolower`,
//   - `nameForOp` : `typeName` exactly as given.
class ScaleArgInfo<TypeConstraint argTyVal, string typeName> {
  TypeConstraint type = argTyVal;
  string name = !tolower(typeName);
  string nameForOp = typeName;
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Scaled 32x6-bit float conversion intrinsics
|
||||
//===---------------------------------------------------------------------===//
|
||||
foreach smallT = [
|
||||
// MLIR f6E2M3FN
|
||||
ScaleArgInfo<ROCDL_V6I32Type, "Fp6">,
|
||||
// MLIR f6E3M2FN
|
||||
ScaleArgInfo<ROCDL_V6I32Type, "Bf6">
|
||||
] in {
|
||||
foreach largeT = [
|
||||
ScaleArgInfo<ROCDL_V32F16Type, "F16">,
|
||||
ScaleArgInfo<ROCDL_V32BF16Type, "Bf16">,
|
||||
ScaleArgInfo<ROCDL_V32F32Type, "F32">,
|
||||
] in {
|
||||
// Note: rounding down f32 values has a special case where
|
||||
// we have to use 2 16xf32 arguments.
|
||||
if !ne(largeT.name, "f32") then {
|
||||
// Scaled conversion of one packed `largeT` vector (plus an f32 scale) into
// one packed `smallT` vector. The def name and the intrinsic name are both
// pasted together from the two ScaleArgInfo records.
def ROCDL_CvtScaleF32Pk32 # smallT.nameForOp # largeT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # smallT.name # "." # largeT.name,
                               [Pure], 1>,
    Arguments<(ins largeT.type:$src, F32:$scale)> {
  let results = (outs smallT.type:$res);
  let summary = "Scale and convert packed "
    # largeT.name # " to packed " # smallT.name;
  let description = [{
    Convert 32 packed }] # largeT.name # [{ values to packed }]
    # smallT.name # [{, dividing by the exponent part of `scale`
    before doing so.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $scale `:` type($res)
  }];
}
|
||||
} // if
|
||||
|
||||
// Stochastic-rounding variant: takes one packed `largeT` vector, an i32 seed
// and an f32 scale, and yields one packed `smallT` vector.
def ROCDL_CvtScaleF32SrPk32 # smallT.nameForOp # largeT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk32." # smallT.name # "." # largeT.name,
                               [Pure], 1>,
    Arguments<(ins largeT.type:$src, I32:$seed, F32:$scale)> {
  let results = (outs smallT.type:$res);
  let summary = "Scale and convert packed "
    # largeT.name # " to packed " # smallT.name
    # " with stochastic rounding";
  let description = [{
    Convert 32 packed }] # largeT.name # [{ values to packed }]
    # smallT.name # [{, dividing by the exponent part of `scale`
    before doing so and applying random rounding derived from
    `seed`.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $seed `,` $scale `:` type($res)
  }];
}
|
||||
|
||||
// Widening direction: one packed `smallT` vector (plus an f32 scale) is
// converted into one packed `largeT` vector.
def ROCDL_CvtScaleF32Pk32 # largeT.nameForOp # smallT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk32." # largeT.name # "." # smallT.name,
                               [Pure], 1>,
    Arguments<(ins smallT.type:$src, F32:$scale)> {
  let results = (outs largeT.type:$res);
  let summary = "Scale and convert packed "
    # smallT.name # " to packed " # largeT.name;
  let description = [{
    Convert 32 packed }] # smallT.name # [{ values to packed }]
    # largeT.name # [{, multiplying by the exponent part of `scale`
    before doing so.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $scale `:` type($res)
  }];
}
|
||||
} // foreach largeT
|
||||
|
||||
// f32 -> `smallT` rounding conversion. Unlike the other large types, the f32
// source is passed as two 16-element vectors rather than one 32-element one.
def ROCDL_CvtScaleF322xPk16 # smallT.nameForOp # F32Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.2xpk16." # smallT.name # ".f32",
                               [Pure], 1>,
    Arguments<(ins ROCDL_V16F32Type:$src0, ROCDL_V16F32Type:$src1, F32:$scale)> {
  let results = (outs smallT.type:$res);
  let summary = "Scale and convert two vector<16xf32> to 32 packed " # smallT.name;
  let description = [{
    Convert 32 single-precision float values, packed into two length-16
    vectors that will be logically concatenated, to packed }]
    # smallT.name # [{, dividing by the exponent part of `scale`
    before doing so.
  }];
  let assemblyFormat = [{
    attr-dict $src0 `,` $src1 `,` $scale `:` type($res)
  }];
}
|
||||
} // foreach smallT
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Scaled conversions to/from fp8/bf8 (f8E4M3FN / f8E5M2)
|
||||
//===---------------------------------------------------------------------===//
|
||||
foreach smallTOp = ["Fp8", "Bf8"] in {
|
||||
defvar smallT = !tolower(smallTOp);
|
||||
|
||||
// Scaled conversion of a single 8-bit float (selected out of an i32) to f16,
// written into one element of a tied 2 x f16 vector. `srcSelIndex` and
// `dstLoHiSel` are attributes rather than SSA operands.
def ROCDL_CvtScaleF32F16 # smallTOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f16." # smallT,
                               [Pure], 1, [3, 4], ["srcSelIndex", "dstLoHiSel"]>,
    Arguments<(ins ROCDL_V2F16Type:$oldVdst, I32:$src, F32:$scale, I32Attr:$srcSelIndex, I1Attr:$dstLoHiSel)> {
  let results = (outs ROCDL_V2F16Type:$res);
  let summary = "Scaled convert " # smallT # " from packed vector to f16, updating tied result";
  let description = [{
    Convert a }] # smallT # [{ byte from `src`, selected by
    `srcSelIndex`, to f16 while multiplying it by the exponent of `scale`,
    and place it into the element of `oldVdst` selected by `dstLoHiSel`,
    preserving the other element of that vector in
    the return value.

    The bytes are stored as an `i32` and not a `<4 x i8>`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $srcSelIndex `]` `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
  }];
}
|
||||
|
||||
// Scaled conversion of a single 8-bit float (selected out of an i32 by the
// `srcSelIndex` attribute) to an f32 result.
def ROCDL_CvtScaleF32F32 # smallTOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.f32." # smallT,
                               [Pure], 1, [2], ["srcSelIndex"]>,
    Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> {
  let results = (outs F32:$res);
  let summary = "Scaled convert " # smallT # " from packed vector to f32";
  let description = [{
    Convert a }] # smallT # [{ byte from `src`, selected by
    `srcSelIndex`, to f32, multiplying it by the exponent of `scale`.

    The bytes are stored in an `i32`, not a `<4 x i8>`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res)
  }];
}
|
||||
|
||||
// Scaled conversion of two separate f32 operands into two 8-bit floats that
// are inserted into a tied 2 x i16 vector at the position chosen by the
// `dstLoHiSel` attribute.
def ROCDL_CvtScaleF32Pk # smallTOp # F32Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # ".f32",
                               [Pure], 1, [4], ["dstLoHiSel"]>,
    Arguments<(ins ROCDL_V2I16Type:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I1Attr:$dstLoHiSel)> {
  let results = (outs ROCDL_V2I16Type:$res);
  let summary = "Scaled convert two f32 to two " # smallT # ", updating packed vector";
  let description = [{
    Convert two f32 values in `src0` and `src1` to two }] # smallT # [{ bytes,
    dividing by the exponent in `scale`. The bytes are packed into
    a 16-bit value which is inserted into `oldVdst` at the `dstLoHiSel`
    position, with the entire updated vector being returned.
  }];
  let assemblyFormat = [{
    attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
  }];
}
|
||||
|
||||
foreach largeT = [
|
||||
ScaleArgInfo<ROCDL_V2F16Type, "F16">,
|
||||
ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
|
||||
] in {
|
||||
// Scaled conversion of a packed pair of `largeT` values into two 8-bit
// floats, inserted into a tied 2 x i16 vector at the position chosen by the
// `dstLoHiSel` attribute.
def ROCDL_CvtScaleF32Pk # smallTOp # largeT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # smallT # "." # largeT.name,
                               [Pure], 1, [3], ["dstLoHiSel"]>,
    Arguments<(ins ROCDL_V2I16Type:$oldVdst, largeT.type:$src0, F32:$scale, I1Attr:$dstLoHiSel)> {
  let results = (outs ROCDL_V2I16Type:$res);
  // Leading space before "to two" keeps the type name from running into it.
  let summary = "Scaled convert two " # largeT.name # " to two " # smallT # ", updating packed vector";
  let description = [{
    Convert two }] # largeT.name # [{ values in `src0` to two }]
    # smallT # [{ bytes, dividing by the exponent in `scale`. The bytes are
    packed into a 16-bit value which is inserted into `oldVdst` at the
    `dstLoHiSel` position, with the entire updated vector being returned.
  }];
  let assemblyFormat = [{
    attr-dict $src0 `,` $scale `->` $oldVdst `[` $dstLoHiSel `]` `:` type($res)
  }];
}
|
||||
} // foreach largeT
|
||||
|
||||
foreach largeT = [
|
||||
ScaleArgInfo<ROCDL_V2F16Type, "F16">,
|
||||
ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
|
||||
ScaleArgInfo<ROCDL_V2F32Type, "F32">
|
||||
] in {
|
||||
// Scaled conversion of two 8-bit floats, taken from the half of an i32
// chosen by the `srcLoHiSel` attribute, into a packed pair of `largeT`.
def ROCDL_CvtScaleF32Pk # largeT.nameForOp # smallTOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # "." # smallT,
                               [Pure], 1, [2], ["srcLoHiSel"]>,
    Arguments<(ins I32:$src, F32:$scale, I1Attr:$srcLoHiSel)> {
  let results = (outs largeT.type:$res);
  // Leading space before "to two" keeps the type name from running into it;
  // the dangling trailing paste operator has been dropped.
  let summary = "Scaled convert two " # smallT # " to two " # largeT.name;
  let description = [{
    Convert two packed }] # smallT # [{ values in `src` to two }]
    # largeT.name # [{ values, multiplying by the exponent in `scale`.
    The two values to be converted are selected from the low or high half
    of `src` (a packed vector represented as an `i32`)
    on the basis of `srcLoHiSel`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $srcLoHiSel `]` `,` $scale `:` type($res)
  }];
}
|
||||
} // foreach largeT
|
||||
|
||||
foreach largeT = [
|
||||
ScaleArgInfo<F32, "F32">,
|
||||
ScaleArgInfo<F16, "F16">,
|
||||
ScaleArgInfo<BF16, "BF16">
|
||||
] in {
|
||||
// Stochastic-rounding scaled conversion of one scalar `largeT` value into
// one 8-bit float, written into a tied i32 at the byte chosen by the
// `dstSelIndex` attribute.
def ROCDL_CvtScaleF32Sr # smallTOp # largeT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr." # smallT # "." # largeT.name,
                               [Pure], 1, [4], ["dstSelIndex"]>,
    Arguments<(ins I32:$oldVdst, largeT.type:$src0, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> {
  let results = (outs I32:$res);
  // Leading space before "to" keeps the type name from running into it.
  let summary = "Scaled convert " # largeT.name # " to " # smallT # " with stochastic rounding, updating packed vector";
  let description = [{
    Convert a }] # largeT.name # [{ value in `src0` to a }]
    # smallT # [{ byte, dividing by the exponent in `scale` and using `seed`
    for stochastic rounding. Place the resulting byte in the
    `dstSelIndex`th byte of `oldVdst` and return the entire packed vector,
    which is stored as an `i32`.
  }];
  let assemblyFormat = [{
    attr-dict $src0 `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
  }];
}
|
||||
} // foreach largeT
|
||||
} // foreach smallTOp
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Scaled conversions to/from fp4 (f4E2M1FN)
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
foreach largeT = [
|
||||
ScaleArgInfo<ROCDL_V2F16Type, "F16">,
|
||||
ScaleArgInfo<ROCDL_V2BF16Type, "Bf16">,
|
||||
ScaleArgInfo<ROCDL_V2F32Type, "F32">,
|
||||
] in {
|
||||
// Note: rounding down f32 values has a special case where
|
||||
// we have to use 2 float arguments.
|
||||
if !ne(largeT.name, "f32") then {
|
||||
// Scaled conversion of a packed pair of `largeT` values into one byte of
// packed fp4, written into a tied i32 at the byte chosen by the
// `dstSelIndex` attribute.
def ROCDL_CvtScaleF32PkFp4 # largeT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4." # largeT.name,
                               [Pure], 1, [3], ["dstSelIndex"]>,
    Arguments<(ins I32:$oldVdst, largeT.type:$src, F32:$scale, I32Attr:$dstSelIndex)> {
  let results = (outs I32:$res);
  let summary = "Scale and convert two "
    # largeT.name # " to packed fp4, updating tied vector";
  let description = [{
    Convert two packed }] # largeT.name # [{ values to packed
    fp4, dividing by the exponent part of `scale`
    before doing so.

    The two scaled values are packed into a byte.
    That byte is used to update the `dstSelIndex`th
    byte of `oldVdst`, which is returned in its entirety.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
  }];
}
|
||||
} // if
|
||||
|
||||
// Stochastically rounded packed conversion of two wide floats to two fp4
// values (one byte), inserted into the packed i32 `oldVdst` at byte
// `dstSelIndex`. `seed` supplies the random bits used for rounding.
def ROCDL_CvtScaleF32SrPkFp4 # largeT.nameForOp # Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.sr.pk.fp4." # largeT.name,
                               [Pure], 1, [4], ["dstSelIndex"]>,
    Arguments<(ins I32:$oldVdst, largeT.type:$src, I32:$seed, F32:$scale, I32Attr:$dstSelIndex)> {
  let results = (outs I32:$res);
  let summary = "Scale and convert two "
    # largeT.name # " to packed fp4 with stochastic rounding, updating tied vector";
  let description = [{
    Convert two packed }] # largeT.name # [{ values to packed
    fp4, dividing by the exponent part of `scale`
    before doing so and using `seed` as the random seed for
    stochastic rounding.

    The two scaled values are packed (little-endian)
    into a byte. That byte is used to update the `dstSelIndex`th
    byte of `oldVdst`, which is returned in its entirety.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $seed `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
  }];
}
|
||||
|
||||
// Packed extension from fp4 to a wider float type: reads one byte (two fp4
// values) of the i32 `src`, selected by `srcSelIndex`, and produces a
// two-element vector of `largeT`.
def ROCDL_CvtScaleF32Pk # largeT.nameForOp # Fp4Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk." # largeT.name # ".fp4",
                               [Pure], 1, [2], ["srcSelIndex"]>,
    Arguments<(ins I32:$src, F32:$scale, I32Attr:$srcSelIndex)> {
  let results = (outs largeT.type:$res);
  let summary = "Scale and convert two packed fp4 to packed " # largeT.name;
  // The type is spelled f4E2M1FN here to match the header of this section.
  let description = [{
    Convert two packed fp4 (f4E2M1FN) values stored as one byte of a 32-bit integer
    to packed }] # largeT.name # [{, multiplying by the exponent part of `scale`
    before doing so.

    The byte to convert is chosen by `srcSelIndex`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $srcSelIndex `]` `,` $scale `:` type($res)
  }];
}
|
||||
} // foreach largeT
|
||||
|
||||
// f32 -> packed fp4 conversion. Unlike the other source types, the two f32
// inputs are passed as separate scalar operands (`src0`, `src1`) rather than
// as a single two-element vector, so this op is defined outside the
// per-type loop.
def ROCDL_CvtScaleF32PkFp4F32Op :
    ROCDL_ConcreteNonMemIntrOp<"cvt.scalef32.pk.fp4.f32",
                               [Pure], 1, [4], ["dstSelIndex"]>,
    Arguments<(ins I32:$oldVdst, F32:$src0, F32:$src1, F32:$scale, I32Attr:$dstSelIndex)> {
  let results = (outs I32:$res);
  let summary = "Scale and convert two f32 values to two packed fp4, updating tied vector";
  let description = [{
    Convert two single-precision float values, passed in `src0` and `src1`
    into two fp4 values, dividing them by the exponent part of `scale`
    before doing so.

    The two scaled values are packed into a byte.
    That byte is used to update the `dstSelIndex`th
    byte of `oldVdst`, which is returned in its entirety.
  }];
  let assemblyFormat = [{
    attr-dict $src0 `,` $src1 `,` $scale `->` $oldVdst `[` $dstSelIndex `]` `:` type($res)
  }];
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ROCDL target attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -1210,22 +1210,20 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
|
||||
}
|
||||
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
|
||||
if (resultVecType) {
|
||||
Value wordSel = createI1Constant(rewriter, loc, op.getIndex());
|
||||
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
|
||||
wordSel);
|
||||
op.getIndex());
|
||||
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
|
||||
wordSel);
|
||||
op.getIndex());
|
||||
}
|
||||
} else {
|
||||
Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
|
||||
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
|
||||
byteSel);
|
||||
op.getIndex());
|
||||
} else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
|
||||
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
|
||||
byteSel);
|
||||
op.getIndex());
|
||||
}
|
||||
}
|
||||
return success();
|
||||
@@ -1253,15 +1251,14 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
|
||||
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
|
||||
else
|
||||
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
|
||||
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
|
||||
|
||||
Value result;
|
||||
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
|
||||
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
|
||||
existing, wordSel);
|
||||
existing, op.getWordIndex());
|
||||
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
|
||||
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
|
||||
existing, wordSel);
|
||||
existing, op.getWordIndex());
|
||||
|
||||
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
|
||||
op, getTypeConverter()->convertType(resultType), result);
|
||||
@@ -1288,15 +1285,14 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
|
||||
existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
|
||||
else
|
||||
existing = rewriter.create<LLVM::UndefOp>(loc, i32);
|
||||
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
|
||||
|
||||
Value result;
|
||||
if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
|
||||
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
|
||||
existing, byteSel);
|
||||
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(
|
||||
loc, i32, source, stoch, existing, op.getStoreIndex());
|
||||
else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
|
||||
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
|
||||
existing, byteSel);
|
||||
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(
|
||||
loc, i32, source, stoch, existing, op.getStoreIndex());
|
||||
|
||||
result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
|
||||
op, getTypeConverter()->convertType(resultType), result);
|
||||
|
||||
@@ -7,8 +7,7 @@
|
||||
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
|
||||
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
func.func @ext_scalar(%v: f8E5M2) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2 to f32
|
||||
@@ -25,8 +24,7 @@ func.func @ext_scalar(%v: f8E5M2) -> f32 {
|
||||
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
|
||||
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FN> to f32
|
||||
@@ -36,8 +34,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FN>) -> f32 {
|
||||
// CHECK-LABEL: func @ext_full_vec(
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
|
||||
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FN> to f32
|
||||
@@ -54,8 +51,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FN>) -> f32 {
|
||||
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
|
||||
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32>
|
||||
// CHECK: return [[EXT]]
|
||||
func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FN> to vector<2xf32>
|
||||
@@ -65,8 +61,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FN>) -> vector<2xf32> {
|
||||
// CHECK-LABEL: func @ext_packed_4xfp8
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FN> to vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
|
||||
// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32>
|
||||
// CHECK: return [[EXT]] : vector<2xf32>
|
||||
func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
|
||||
@@ -77,8 +72,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FN>) -> vector<2xf32> {
|
||||
// CHECK-SAME: ([[V:%.+]]: f32)
|
||||
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
|
||||
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
|
||||
@@ -89,8 +83,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FN> {
|
||||
// CHECK-LABEL: func @packed_truncx2
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
|
||||
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
|
||||
@@ -102,8 +95,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FN> {
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
|
||||
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
|
||||
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
|
||||
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
|
||||
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
|
||||
@@ -114,8 +106,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>) ->
|
||||
// CHECK-LABEL: func @packed_stoch_round
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
|
||||
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
|
||||
@@ -127,8 +118,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FN> {
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2>)
|
||||
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
|
||||
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
|
||||
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2>) -> vector<4xf8E5M2> {
|
||||
|
||||
@@ -6,8 +6,7 @@
|
||||
// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
|
||||
// CHECK: [[C0_2:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]]{{\[}}[[C0_2]]] : f32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.bf8 [[CAST]][0] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[0] : f8E5M2FNUZ to f32
|
||||
@@ -24,8 +23,7 @@ func.func @ext_scalar(%v: f8E5M2FNUZ) -> f32 {
|
||||
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
|
||||
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : f32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][1] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<2xf8E4M3FNUZ> to f32
|
||||
@@ -35,8 +33,7 @@ func.func @ext_short_vec(%v: vector<2xf8E4M3FNUZ>) -> f32 {
|
||||
// CHECK-LABEL: func @ext_full_vec
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
|
||||
// CHECK: [[C3:%.+]] = llvm.mlir.constant(3 : i32) : i32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]]{{\[}}[[C3]]] : f32
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.f32.fp8 [[CAST]][3] : f32
|
||||
// CHECK: return [[EXT]] : f32
|
||||
func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[3] : vector<4xf8E4M3FNUZ> to f32
|
||||
@@ -53,8 +50,7 @@ func.func @ext_full_vec(%v: vector<4xf8E4M3FNUZ>) -> f32 {
|
||||
// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
|
||||
// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1_2:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C1_2]]] : vector<2xf32>
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][false] : vector<2xf32>
|
||||
// CHECK: return [[EXT]]
|
||||
func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[0] : vector<2xf8E4M3FNUZ> to vector<2xf32>
|
||||
@@ -64,8 +60,7 @@ func.func @ext_packed_2xfp8(%v: vector<2xf8E4M3FNUZ>) -> vector<2xf32> {
|
||||
// CHECK-LABEL: func @ext_packed_4xfp8(
|
||||
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<4xf8E4M3FNUZ> to vector<4xi8>
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
|
||||
// CHECK: [[C3:%.+]] = llvm.mlir.constant(true) : i1
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]]{{\[}}[[C3]]] : vector<2xf32>
|
||||
// CHECK: [[EXT:%.+]] = rocdl.cvt.pk.f32.fp8 [[CAST]][true] : vector<2xf32>
|
||||
// CHECK: return [[EXT]] : vector<2xf32>
|
||||
func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
|
||||
%ret = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FNUZ> to vector<2xf32>
|
||||
@@ -76,8 +71,7 @@ func.func @ext_packed_4xfp8(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
|
||||
// CHECK-SAME: ([[V:%.+]]: f32)
|
||||
// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[V2]] -> [[EXISTING]][false] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
|
||||
func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
|
||||
@@ -88,8 +82,7 @@ func.func @packed_trunc(%v: f32) -> vector<4xf8E4M3FNUZ> {
|
||||
// CHECK-LABEL: func @packed_truncx2
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32)
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[FALSE:%.+]] = llvm.mlir.constant(false) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]]{{\[}}[[FALSE]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.fp8.f32 [[V]], [[W]] -> [[EXISTING]][false] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
|
||||
func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
|
||||
@@ -101,8 +94,7 @@ func.func @packed_truncx2(%v: f32, %w: f32) -> vector<4xf8E4M3FNUZ> {
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
|
||||
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
|
||||
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
|
||||
// CHECK: [[TRUE:%.+]] = llvm.mlir.constant(true) : i1
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]]{{\[}}[[TRUE]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.pk.bf8.f32 [[V]], [[W]] -> [[EXISTING_INT]][true] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
|
||||
func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
|
||||
@@ -113,8 +105,7 @@ func.func @packed_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2FNUZ>
|
||||
// CHECK-LABEL: func @packed_stoch_round
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32)
|
||||
// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : i32
|
||||
// CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]]{{\[}}[[C0]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.fp8.f32 [[V]], [[S]] -> [[EXISTING]][0] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
|
||||
func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
|
||||
@@ -126,8 +117,7 @@ func.func @packed_stoch_round(%v: f32, %s: i32) -> vector<4xf8E4M3FNUZ> {
|
||||
// CHECK-SAME: ([[V:%.+]]: f32, [[S:%.+]]: i32, [[EXISTING:%.+]]: vector<4xf8E5M2FNUZ>)
|
||||
// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2FNUZ> to vector<4xi8>
|
||||
// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to i32
|
||||
// CHECK: [[C1:%.+]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]]{{\[}}[[C1]]] : i32
|
||||
// CHECK: [[PACKED:%.+]] = rocdl.cvt.sr.bf8.f32 [[V]], [[S]] -> [[EXISTING_INT]][1] : i32
|
||||
// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : i32 to vector<4xi8>
|
||||
// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2FNUZ>
|
||||
func.func @packed_stoch_round_into(%v: f32, %s: i32, %existing: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
|
||||
|
||||
@@ -789,36 +789,32 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
|
||||
// CHECK: rocdl.cvt.scalef32.sr.bf8.bf16
|
||||
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8
|
||||
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8
|
||||
%c0 = llvm.mlir.constant(0 : i32) : i32
|
||||
%c2 = llvm.mlir.constant(2 : i32) : i32
|
||||
%c3 = llvm.mlir.constant(3 : i32) : i32
|
||||
%c4 = llvm.mlir.constant(1.0 : f32) : f32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
|
||||
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
|
||||
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
|
||||
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
|
||||
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : vector<2xf16>
|
||||
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : vector<2xf16>
|
||||
%v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : vector<2xbf16>
|
||||
%v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : vector<2xbf16>
|
||||
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
|
||||
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %v3_scaled[%c0] : f16
|
||||
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
|
||||
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
|
||||
%source2_ext = rocdl.cvt.pk.f32.bf8 %source[%false] : vector<2xf32>
|
||||
%source3_ext = rocdl.cvt.pk.f32.fp8 %source[%false] : vector<2xf32>
|
||||
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
|
||||
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
|
||||
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
|
||||
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
|
||||
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
|
||||
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
|
||||
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
|
||||
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
|
||||
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
|
||||
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
|
||||
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
|
||||
%v1 = rocdl.cvt.f32.bf8 %source[0] : f32
|
||||
%v2 = rocdl.cvt.f32.fp8 %source[0] : f32
|
||||
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32
|
||||
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
|
||||
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16>
|
||||
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16>
|
||||
%v3_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16>
|
||||
%v4_scaled_bf16 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16>
|
||||
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16>
|
||||
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %v3_scaled[false] : vector<2xf16>
|
||||
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32
|
||||
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32
|
||||
%source2_ext = rocdl.cvt.pk.f32.bf8 %source[false] : vector<2xf32>
|
||||
%source3_ext = rocdl.cvt.pk.f32.fp8 %source[false] : vector<2xf32>
|
||||
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32
|
||||
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32
|
||||
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32
|
||||
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32
|
||||
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32
|
||||
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32
|
||||
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32
|
||||
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32
|
||||
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32
|
||||
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32>
|
||||
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32>
|
||||
llvm.return %source5 : i32
|
||||
}
|
||||
|
||||
@@ -826,9 +822,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
|
||||
// CHECK-LABEL: @rocdl_8bit_packed_v2i16
|
||||
// CHECK: rocdl.cvt.scalef32.pk.fp8.f32
|
||||
%c0 = llvm.mlir.constant(1.0 : f32) : f32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
|
||||
llvm.return %source_scaled : vector<2xi16>
|
||||
}
|
||||
|
||||
@@ -836,14 +831,91 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
|
||||
// CHECK-LABEL: @rocdl_v2f16_v2i16
|
||||
// CHECK: rocdl.cvt.scalef32.pk.fp8.f16
|
||||
%c0 = llvm.mlir.constant(1.0 : f32) : f32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
|
||||
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16>
|
||||
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
|
||||
llvm.return %source_scaled : vector<2xi16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rocdl_6_bit_floats
|
||||
// CHECK-SAME: (%[[V32F6:.+]]: vector<6xi32>, %[[V16F32:.+]]: vector<16xf32>, %[[V32F32:.+]]: vector<32xf32>, %[[V32F16:.+]]: vector<32xf16>, %[[V32BF16:.+]]: vector<32xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32)
|
||||
llvm.func @rocdl_6_bit_floats(
|
||||
%v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>,
|
||||
%v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32,
|
||||
%scale: f32) {
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.bf6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]]
|
||||
%f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.2xpk16.fp6.f32 %[[V16F32]], %[[V16F32]], %[[SCALE]]
|
||||
%f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.f16 %[[V32F16]], %[[SCALE]]
|
||||
%f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.f16 %[[V32F16]], %[[SCALE]]
|
||||
%f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf6.bf16 %[[V32BF16]], %[[SCALE]]
|
||||
%bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.fp6.bf16 %[[V32BF16]], %[[SCALE]]
|
||||
%bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32>
|
||||
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.bf6 %[[V32F6]], %[[SCALE]]
|
||||
%bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f32.fp6 %[[V32F6]], %[[SCALE]]
|
||||
%fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.bf6 %[[V32F6]], %[[SCALE]]
|
||||
%bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.f16.fp6 %[[V32F6]], %[[SCALE]]
|
||||
%fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.bf6 %[[V32F6]], %[[SCALE]]
|
||||
%bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk32.bf16.fp6 %[[V32F6]], %[[SCALE]]
|
||||
%fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16>
|
||||
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]]
|
||||
%f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f32 %[[V32F32]], %[[SEED]], %[[SCALE]]
|
||||
%f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]]
|
||||
%f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.f16 %[[V32F16]], %[[SEED]], %[[SCALE]]
|
||||
%f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]]
|
||||
%bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %[[V32BF16]], %[[SEED]], %[[SCALE]]
|
||||
%bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// Exercises the scaled fp4 conversion ops (rocdl.cvt.scalef32.[sr.]pk.*fp4*).
// The FileCheck patterns below expect every op to show up in the output under
// the same rocdl.* spelling and operand order, with the function arguments
// re-captured by name in the signature pattern.
// NOTE(review): the RUN line is outside this excerpt; presumably this is a
// parse/print round-trip -- confirm.
// CHECK-LABEL: @rocdl_4_bit_floats
|
||||
// CHECK-SAME: (%[[V8F4:.+]]: i32, %[[F32:.+]]: f32, %[[V2F32:.+]]: vector<2xf32>, %[[V2F16:.+]]: vector<2xf16>, %[[V2BF16:.+]]: vector<2xbf16>, %[[SEED:.+]]: i32, %[[SCALE:.+]]: f32)
|
||||
llvm.func @rocdl_4_bit_floats(
|
||||
%v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>,
|
||||
%v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) {
|
||||
|
||||
// f32 / f16 / bf16 -> fp4. "-> %v8f4[i]" supplies an existing i32 value plus a
// literal index, and the result type is i32. The f32 form takes two scalar
// f32 operands; the f16 and bf16 forms take one 2-element vector.
// CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f32 %[[F32]], %[[F32]], %[[SCALE]] -> %[[V8F4]][0]
|
||||
%f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.f16 %[[V2F16]], %[[SCALE]] -> %[[V8F4]][1]
|
||||
%f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk.fp4.bf16 %[[V2BF16]], %[[SCALE]] -> %[[V8F4]][0]
|
||||
%bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32
|
||||
|
||||
// fp4 -> f32 / f16 / bf16. The source is %v8f4 with a literal index and the
// result is a 2-element vector of the destination element type.
// CHECK-NEXT: rocdl.cvt.scalef32.pk.f32.fp4 %[[V8F4]][0], %[[SCALE]]
|
||||
%fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk.f16.fp4 %[[V8F4]][1], %[[SCALE]]
|
||||
%fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16>
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.pk.bf16.fp4 %[[V8F4]][0], %[[SCALE]]
|
||||
%fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16>
|
||||
|
||||
// "sr" forms of the conversions to fp4: they add a %seed operand before
// %scale, and the f32 form takes a vector<2xf32> instead of two scalars.
// NOTE(review): "sr" presumably stands for stochastic rounding -- confirm.
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f32 %[[V2F32]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0]
|
||||
%f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.f16 %[[V2F16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][1]
|
||||
%f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32
|
||||
// CHECK-NEXT: rocdl.cvt.scalef32.sr.pk.fp4.bf16 %[[V2BF16]], %[[SEED]], %[[SCALE]] -> %[[V8F4]][0]
|
||||
%bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @rocdl.s.waitcnt() {
|
||||
// CHECK-LABEL: rocdl.s.waitcnt
|
||||
// CHECK: rocdl.s.waitcnt 0
|
||||
|
||||
@@ -1081,34 +1081,31 @@ llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf
|
||||
// CHECK: call i32 @llvm.amdgcn.cvt.scalef32.sr.bf8.bf16(i32 %{{.+}}, bfloat %{{.+}}, i32 %{{.+}}, float 1.000000e+00, i32 3)
|
||||
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp8(i32 %{{.+}}, float 1.000000e+00, i1 false)
|
||||
// CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.bf8(i32 %{{.+}}, float 1.000000e+00, i1 false)
|
||||
%c0 = llvm.mlir.constant(0 : i32) : i32
|
||||
%c2 = llvm.mlir.constant(2 : i32) : i32
|
||||
%c3 = llvm.mlir.constant(3 : i32) : i32
|
||||
%c4 = llvm.mlir.constant(1.0 : f32) : f32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%v1 = rocdl.cvt.f32.bf8 %source[%c0] : f32
|
||||
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[%c0], %c4 : f32
|
||||
%v2 = rocdl.cvt.f32.fp8 %source[%c0] : f32
|
||||
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[%c0], %c4 : f32
|
||||
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[%false], %c4 : i32
|
||||
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[%false], %c4 : i32
|
||||
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[%false], %c4 -> %source_packed[%c0] : f16
|
||||
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[%false], %c4 -> %source_packed[%c0] : f16
|
||||
%v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[%false], %c4 : i32
|
||||
%v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[%false], %c4 : i32
|
||||
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[%false] : i32
|
||||
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[%false] : i32
|
||||
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c2] : i32
|
||||
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[%c3] : i32
|
||||
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[%c3] : i32
|
||||
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[%c3] : i32
|
||||
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[%c3] : i32
|
||||
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[%c3] : i32
|
||||
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[%c3] : i32
|
||||
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[%c3] : i32
|
||||
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[%c3] : i32
|
||||
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[%false], %c4 : f32
|
||||
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[%false], %c4 : f32
|
||||
%v1 = rocdl.cvt.f32.bf8 %source[0] : f32
|
||||
%v1_scaled = rocdl.cvt.scalef32.f32.bf8 %source[0], %c4 : f32
|
||||
%v2 = rocdl.cvt.f32.fp8 %source[0] : f32
|
||||
%v2_scaled = rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
|
||||
%v3_scaled = rocdl.cvt.scalef32.pk.f16.bf8 %source[false], %c4 : vector<2xf16>
|
||||
%v4_scaled = rocdl.cvt.scalef32.pk.f16.fp8 %source[false], %c4 : vector<2xf16>
|
||||
%v5 = rocdl.cvt.scalef32.f16.fp8 %source[0], %c4 -> %source_packed[false] : vector<2xf16>
|
||||
%v6 = rocdl.cvt.scalef32.f16.bf8 %source[0], %c4 -> %source_packed[false] : vector<2xf16>
|
||||
%v7 = rocdl.cvt.scalef32.pk.bf16.bf8 %source[false], %c4 : vector<2xbf16>
|
||||
%v8 = rocdl.cvt.scalef32.pk.bf16.fp8 %source[false], %c4 : vector<2xbf16>
|
||||
%source2 = rocdl.cvt.pk.bf8.f32 %v1, %v2 -> %source[false] : i32
|
||||
%source3 = rocdl.cvt.pk.fp8.f32 %v1, %v2 -> %source2[false] : i32
|
||||
%source4 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[2] : i32
|
||||
%source5 = rocdl.cvt.sr.fp8.f32 %v2, %stoch -> %source4[3] : i32
|
||||
%source5_scaled = rocdl.cvt.scalef32.sr.fp8.f32 %v2, %stoch, %c4 -> %source4[3] : i32
|
||||
%source5_scaled_half = rocdl.cvt.scalef32.sr.fp8.f16 %source_half, %stoch, %c4 -> %source4[3] : i32
|
||||
%source5_scaled_bfloat = rocdl.cvt.scalef32.sr.fp8.bf16 %source_bfloat, %stoch, %c4 -> %source4[3] : i32
|
||||
%source6 = rocdl.cvt.sr.bf8.f32 %v1, %stoch -> %source3[3] : i32
|
||||
%source6_scaled = rocdl.cvt.scalef32.sr.bf8.f32 %v2, %stoch, %c4 -> %source3[3] : i32
|
||||
%source6_scaled_half = rocdl.cvt.scalef32.sr.bf8.f16 %source_half, %stoch, %c4 -> %source3[3] : i32
|
||||
%source6_scaled_bfloat = rocdl.cvt.scalef32.sr.bf8.bf16 %source_bfloat, %stoch, %c4 -> %source3[3] : i32
|
||||
%source7_scaled = rocdl.cvt.scalef32.pk.f32.fp8 %source[false], %c4 : vector<2xf32>
|
||||
%source8_scaled = rocdl.cvt.scalef32.pk.f32.bf8 %source[false], %c4 : vector<2xf32>
|
||||
llvm.return %source5 : i32
|
||||
}
|
||||
|
||||
@@ -1117,9 +1114,8 @@ llvm.func @rocdl_8bit_packed_v2i16(%sourceA: f32, %sourceB: f32, %old: vector<2x
|
||||
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.fp8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
|
||||
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f32(<2 x i16> %{{.+}}, float %{{.+}}, float %{{.+}}, float 1.000000e+00, i1 false)
|
||||
%c0 = llvm.mlir.constant(1.0 : f32) : f32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.bf8.f32 %sourceA, %sourceB, %c0 -> %old[false] : vector<2xi16>
|
||||
llvm.return %source_scaled : vector<2xi16>
|
||||
}
|
||||
|
||||
@@ -1130,11 +1126,10 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
|
||||
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.f16(<2 x i16> %2, <2 x half> %0, float 1.000000e+00, i1 false)
|
||||
// CHECK: call <2 x i16> @llvm.amdgcn.cvt.scalef32.pk.bf8.bf16(<2 x i16> %2, <2 x bfloat> %1, float 1.000000e+00, i1 false)
|
||||
%c0 = llvm.mlir.constant(1.0 : f32) : f32
|
||||
%false = llvm.mlir.constant(false) : i1
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[%false] : vector<2xi16>
|
||||
%source_scaled = rocdl.cvt.scalef32.pk.fp8.f16 %source, %c0 -> %old[false] : vector<2xi16>
|
||||
%source2_scaled = rocdl.cvt.scalef32.pk.fp8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
|
||||
%source3_scaled = rocdl.cvt.scalef32.pk.bf8.f16 %source, %c0 -> %old[false] : vector<2xi16>
|
||||
%source4_scaled = rocdl.cvt.scalef32.pk.bf8.bf16 %source2, %c0 -> %old[false] : vector<2xi16>
|
||||
llvm.return %source_scaled : vector<2xi16>
|
||||
}
|
||||
|
||||
@@ -1145,6 +1140,83 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
|
||||
llvm.return %source : vector<2xf16>
|
||||
}
|
||||
|
||||
// Checks that each scaled 6-bit (bf6 / fp6) conversion op is emitted as a call
// to the LLVM intrinsic of the same name under @llvm.amdgcn.*, with the
// operands passed in the order they are written on the op.
// The 6-bit data travels as vector<6xi32> (<6 x i32> in LLVM IR).
// NOTE(review): presumably that vector holds 32 packed 6-bit values
// (6 x 32 bits = 32 x 6 bits), as the %v32f6 name suggests -- confirm.
// CHECK-LABEL: @rocdl_6_bit_floats
|
||||
// CHECK-SAME: (<6 x i32> %[[V32F6:.+]], <16 x float> %[[V16F32:.+]], <32 x float> %[[V32F32:.+]], <32 x half> %[[V32F16:.+]], <32 x bfloat> %[[V32BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
|
||||
llvm.func @rocdl_6_bit_floats(
|
||||
%v32f6: vector<6xi32>, %v16f32: vector<16xf32>, %v32f32: vector<32xf32>,
|
||||
%v32f16: vector<32xf16>, %v32bf16: vector<32xbf16>, %seed: i32,
|
||||
%scale: f32) {
|
||||
// f32 / f16 / bf16 -> bf6 / fp6. The f32 forms ("2xpk16") take two
// vector<16xf32> operands; the f16 and bf16 forms ("pk32") take one
// 32-element vector. All of them return vector<6xi32>.
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.bf6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]])
|
||||
%f32_to_bf6 = rocdl.cvt.scalef32.2xpk16.bf6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.2xpk16.fp6.f32(<16 x float> %[[V16F32]], <16 x float> %[[V16F32]], float %[[SCALE]])
|
||||
%f32_to_fp6 = rocdl.cvt.scalef32.2xpk16.fp6.f32 %v16f32, %v16f32, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.f16(<32 x half> %[[V32F16]], float %[[SCALE]])
|
||||
%f16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.f16 %v32f16, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.f16(<32 x half> %[[V32F16]], float %[[SCALE]])
|
||||
%f16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.f16 %v32f16, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]])
|
||||
%bf16_to_bf6 = rocdl.cvt.scalef32.pk32.bf6.bf16 %v32bf16, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], float %[[SCALE]])
|
||||
%bf16_to_fp6 = rocdl.cvt.scalef32.pk32.fp6.bf16 %v32bf16, %scale : vector<6xi32>
|
||||
|
||||
// bf6 / fp6 -> f32 / f16 / bf16. Each op takes the vector<6xi32> source and
// the scale, and returns a 32-element vector of the destination type.
// CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
|
||||
%bf6_to_f32 = rocdl.cvt.scalef32.pk32.f32.bf6 %v32f6, %scale : vector<32xf32>
|
||||
// CHECK-NEXT: call <32 x float> @llvm.amdgcn.cvt.scalef32.pk32.f32.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
|
||||
%fp6_to_f32 = rocdl.cvt.scalef32.pk32.f32.fp6 %v32f6, %scale : vector<32xf32>
|
||||
// CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
|
||||
%bf6_to_f16 = rocdl.cvt.scalef32.pk32.f16.bf6 %v32f6, %scale : vector<32xf16>
|
||||
// CHECK-NEXT: call <32 x half> @llvm.amdgcn.cvt.scalef32.pk32.f16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
|
||||
%fp6_to_f16 = rocdl.cvt.scalef32.pk32.f16.fp6 %v32f6, %scale : vector<32xf16>
|
||||
// CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.bf6(<6 x i32> %[[V32F6]], float %[[SCALE]])
|
||||
%bf6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.bf6 %v32f6, %scale : vector<32xbf16>
|
||||
// CHECK-NEXT: call <32 x bfloat> @llvm.amdgcn.cvt.scalef32.pk32.bf16.fp6(<6 x i32> %[[V32F6]], float %[[SCALE]])
|
||||
%fp6_to_bf16 = rocdl.cvt.scalef32.pk32.bf16.fp6 %v32f6, %scale : vector<32xbf16>
|
||||
|
||||
// "sr" forms of the conversions to bf6 / fp6: one 32-element source vector,
// then %seed, then %scale. Here the f32 form also takes a single
// vector<32xf32> rather than two 16-element vectors.
// NOTE(review): "sr" presumably stands for stochastic rounding -- confirm.
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]])
|
||||
%f32_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f32 %v32f32, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f32(<32 x float> %[[V32F32]], i32 %[[SEED]], float %[[SCALE]])
|
||||
%f32_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f32 %v32f32, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]])
|
||||
%f16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.f16 %v32f16, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.f16(<32 x half> %[[V32F16]], i32 %[[SEED]], float %[[SCALE]])
|
||||
%f16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.f16 %v32f16, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.bf6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]])
|
||||
%bf16_to_bf6_sr = rocdl.cvt.scalef32.sr.pk32.bf6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
|
||||
// CHECK-NEXT: call <6 x i32> @llvm.amdgcn.cvt.scalef32.sr.pk32.fp6.bf16(<32 x bfloat> %[[V32BF16]], i32 %[[SEED]], float %[[SCALE]])
|
||||
%bf16_to_fp6_sr = rocdl.cvt.scalef32.sr.pk32.fp6.bf16 %v32bf16, %seed, %scale : vector<6xi32>
|
||||
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// Checks that each scaled fp4 conversion op is emitted as a call to the LLVM
// intrinsic of the same name under @llvm.amdgcn.*. In the expected calls the
// bracketed literal index on the op becomes a trailing constant i32 argument,
// and for the conversions to fp4 the existing %v8f4 value is passed first.
// CHECK-LABEL: @rocdl_4_bit_floats
|
||||
// CHECK-SAME: (i32 %[[V8F4:.+]], float %[[F32:.+]], <2 x float> %[[V2F32:.+]], <2 x half> %[[V2F16:.+]], <2 x bfloat> %[[V2BF16:.+]], i32 %[[SEED:.+]], float %[[SCALE:.+]])
|
||||
llvm.func @rocdl_4_bit_floats(
|
||||
%v8f4: i32, %f32: f32, %v2f32: vector<2xf32>, %v2f16: vector<2xf16>,
|
||||
%v2bf16: vector<2xbf16>, %seed: i32, %scale: f32) {
|
||||
|
||||
// f32 / f16 / bf16 -> fp4. Expected argument order: old i32 value, source
// operand(s), scale, constant index. The f32 form takes two scalar floats.
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %[[V8F4]], float %[[F32]], float %[[F32]], float %[[SCALE]], i32 0)
|
||||
%f32_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f32 %f32, %f32, %scale -> %v8f4[0] : i32
|
||||
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], float %[[SCALE]], i32 1)
|
||||
%f16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.f16 %v2f16, %scale -> %v8f4[1] : i32
|
||||
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], float %[[SCALE]], i32 0)
|
||||
%bf16_to_fp4 = rocdl.cvt.scalef32.pk.fp4.bf16 %v2bf16, %scale -> %v8f4[0] : i32
|
||||
|
||||
// fp4 -> f32 / f16 / bf16. Expected argument order: packed i32 source, scale,
// constant index. The result is a 2-element vector.
// CHECK-NEXT: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0)
|
||||
%fp4_to_f32 = rocdl.cvt.scalef32.pk.f32.fp4 %v8f4[0], %scale : vector<2xf32>
|
||||
// CHECK-NEXT: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 1)
|
||||
%fp4_to_f16 = rocdl.cvt.scalef32.pk.f16.fp4 %v8f4[1], %scale : vector<2xf16>
|
||||
// CHECK-NEXT: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %[[V8F4]], float %[[SCALE]], i32 0)
|
||||
%fp4_to_bf16 = rocdl.cvt.scalef32.pk.bf16.fp4 %v8f4[0], %scale : vector<2xbf16>
|
||||
|
||||
// "sr" forms of the conversions to fp4. Expected argument order: old i32
// value, 2-element source vector, seed, scale, constant index.
// NOTE(review): "sr" presumably stands for stochastic rounding -- confirm.
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %[[V8F4]], <2 x float> %[[V2F32]], i32 %[[SEED]], float %[[SCALE]], i32 0)
|
||||
%f32_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f32 %v2f32, %seed, %scale -> %v8f4[0] : i32
|
||||
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %[[V8F4]], <2 x half> %[[V2F16]], i32 %[[SEED]], float %[[SCALE]], i32 1)
|
||||
%f16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.f16 %v2f16, %seed, %scale -> %v8f4[1] : i32
|
||||
// CHECK-NEXT: call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %[[V8F4]], <2 x bfloat> %[[V2BF16]], i32 %[[SEED]], float %[[SCALE]], i32 0)
|
||||
%bf16_to_fp4_sr = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %v2bf16, %seed, %scale -> %v8f4[0] : i32
|
||||
|
||||
llvm.return
|
||||
}
|
||||
llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
|
||||
// CHECK-LABEL: @rocdl_atomic_attrs
|
||||
// CHECK: atomicrmw
|
||||
|
||||
Reference in New Issue
Block a user