[mlir][vector] add support for printing f16 and bf16

Love or hate it, but the vector.print operation was the very
first operation that actually made "end-to-end" CHECK integration
testing possible for MLIR. This revision adds support for
the -until recently- less common but important floating-point
types f16 and bf16.

This will become useful for accelerator specific testing (e.g. NVidia GPUs)

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D145207
This commit is contained in:
Aart Bik
2023-03-02 17:37:43 -08:00
parent 637ce0f713
commit 657f60a07b
6 changed files with 81 additions and 22 deletions

View File

@@ -1466,16 +1466,20 @@ public:
PrintConversion conversion = PrintConversion::None;
VectorType vectorType = printType.dyn_cast<VectorType>();
Type eltType = vectorType ? vectorType.getElementType() : printType;
auto parent = printOp->getParentOfType<ModuleOp>();
Operation *printer;
if (eltType.isF32()) {
printer =
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (eltType.isF64()) {
printer =
LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
} else if (eltType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
} else if (eltType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
} else if (eltType.isIndex()) {
printer =
LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1485,8 +1489,7 @@ public:
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = LLVM::lookupOrCreatePrintU64Fn(
printOp->getParentOfType<ModuleOp>());
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
} else {
return failure();
}
@@ -1499,8 +1502,7 @@ public:
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
printer = LLVM::lookupOrCreatePrintI64Fn(
printOp->getParentOfType<ModuleOp>());
printer = LLVM::lookupOrCreatePrintI64Fn(parent);
} else {
return failure();
}
@@ -1515,8 +1517,7 @@ public:
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(
printOp->getParentOfType<ModuleOp>()));
LLVM::lookupOrCreatePrintNewlineFn(parent));
rewriter.eraseOp(printOp);
return success();
}
@@ -1526,7 +1527,8 @@ private:
// clang-format off
None,
ZeroExt64,
SignExt64
SignExt64,
Bitcast16
// clang-format on
};
@@ -1546,6 +1548,10 @@ private:
value = rewriter.create<arith::ExtSIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::Bitcast16:
value = rewriter.create<LLVM::BitcastOp>(
loc, IntegerType::get(rewriter.getContext(), 16), value);
break;
case PrintConversion::None:
break;
}
@@ -1553,10 +1559,9 @@ private:
return;
}
emitCall(rewriter, loc,
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
Operation *printComma =
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
auto parent = op->getParentOfType<ModuleOp>();
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
if (rank <= 1) {
auto reducedType = vectorType.getElementType();
@@ -1570,9 +1575,7 @@ private:
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
emitCall(
rewriter, loc,
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
return;
}
@@ -1587,8 +1590,7 @@ private:
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
emitCall(rewriter, loc,
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
}
// Helper to emit a call.