[mlir][vector] add support for printing f16 and bf16
Love or hate it, but the vector.print operation was the very first operation that actually made "end-to-end" CHECK integration testing possible for MLIR. This revision adds support for the -until recently- less common but important floating-point types f16 and bf16. This will become useful for accelerator specific testing (e.g. NVidia GPUs) Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D145207
This commit is contained in:
@@ -1466,16 +1466,20 @@ public:
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
VectorType vectorType = printType.dyn_cast<VectorType>();
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
Operation *printer;
|
||||
if (eltType.isF32()) {
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
|
||||
} else if (eltType.isF64()) {
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
|
||||
} else if (eltType.isF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
|
||||
} else if (eltType.isBF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
|
||||
} else if (eltType.isIndex()) {
|
||||
printer =
|
||||
LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
@@ -1485,8 +1489,7 @@ public:
|
||||
if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(
|
||||
printOp->getParentOfType<ModuleOp>());
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@@ -1499,8 +1502,7 @@ public:
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
else if (width < 64)
|
||||
conversion = PrintConversion::SignExt64;
|
||||
printer = LLVM::lookupOrCreatePrintI64Fn(
|
||||
printOp->getParentOfType<ModuleOp>());
|
||||
printer = LLVM::lookupOrCreatePrintI64Fn(parent);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
@@ -1515,8 +1517,7 @@ public:
|
||||
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, printOp->getLoc(),
|
||||
LLVM::lookupOrCreatePrintNewlineFn(
|
||||
printOp->getParentOfType<ModuleOp>()));
|
||||
LLVM::lookupOrCreatePrintNewlineFn(parent));
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
@@ -1526,7 +1527,8 @@ private:
|
||||
// clang-format off
|
||||
None,
|
||||
ZeroExt64,
|
||||
SignExt64
|
||||
SignExt64,
|
||||
Bitcast16
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -1546,6 +1548,10 @@ private:
|
||||
value = rewriter.create<arith::ExtSIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::Bitcast16:
|
||||
value = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 16), value);
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
@@ -1553,10 +1559,9 @@ private:
|
||||
return;
|
||||
}
|
||||
|
||||
emitCall(rewriter, loc,
|
||||
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
|
||||
Operation *printComma =
|
||||
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
|
||||
auto parent = op->getParentOfType<ModuleOp>();
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
|
||||
Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
|
||||
if (rank <= 1) {
|
||||
auto reducedType = vectorType.getElementType();
|
||||
@@ -1570,9 +1575,7 @@ private:
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(
|
||||
rewriter, loc,
|
||||
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1587,8 +1590,7 @@ private:
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc,
|
||||
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
|
||||
}
|
||||
|
||||
// Helper to emit a call.
|
||||
|
||||
Reference in New Issue
Block a user