[mlir] [VectorOps] generalize printing support for integers
This generalizes printing beyond just i1,i32,i64 and also accounts for signed and unsigned interpretation in the output. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D88290
This commit is contained in:
@@ -1319,44 +1319,96 @@ public:
|
||||
if (typeConverter.convertType(printType) == nullptr)
|
||||
return failure();
|
||||
|
||||
// Make sure element type has runtime support (currently just Float/Double).
|
||||
// Make sure element type has runtime support.
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
VectorType vectorType = printType.dyn_cast<VectorType>();
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
Operation *printer;
|
||||
if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
|
||||
printer = getPrintI32(op);
|
||||
else if (eltType.isSignlessInteger(64))
|
||||
printer = getPrintI64(op);
|
||||
else if (eltType.isF32())
|
||||
if (eltType.isF32()) {
|
||||
printer = getPrintFloat(op);
|
||||
else if (eltType.isF64())
|
||||
} else if (eltType.isF64()) {
|
||||
printer = getPrintDouble(op);
|
||||
else
|
||||
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
// unsigned print method. Up to 64-bit is supported.
|
||||
unsigned width = intTy.getWidth();
|
||||
if (intTy.isUnsigned()) {
|
||||
if (width <= 32) {
|
||||
if (width < 32)
|
||||
conversion = PrintConversion::ZeroExt32;
|
||||
printer = getPrintU32(op);
|
||||
} else if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::ZeroExt64;
|
||||
printer = getPrintU64(op);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
assert(intTy.isSignless() || intTy.isSigned());
|
||||
if (width <= 32) {
|
||||
// Note that we *always* zero extend booleans (1-bit integers),
|
||||
// so that true/false is printed as 1/0 rather than -1/0.
|
||||
if (width == 1)
|
||||
conversion = PrintConversion::ZeroExt32;
|
||||
else if (width < 32)
|
||||
conversion = PrintConversion::SignExt32;
|
||||
printer = getPrintI32(op);
|
||||
} else if (width <= 64) {
|
||||
if (width < 64)
|
||||
conversion = PrintConversion::SignExt64;
|
||||
printer = getPrintI64(op);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
enum class PrintConversion {
|
||||
None,
|
||||
ZeroExt32,
|
||||
SignExt32,
|
||||
ZeroExt64,
|
||||
SignExt64
|
||||
};
|
||||
|
||||
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value value, VectorType vectorType, Operation *printer,
|
||||
int64_t rank) const {
|
||||
int64_t rank, PrintConversion conversion) const {
|
||||
Location loc = op->getLoc();
|
||||
if (rank == 0) {
|
||||
if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
|
||||
// Convert i1 (bool) to i32 so we can use the print_i32 method.
|
||||
// This avoids the need for a print_i1 method with an unclear ABI.
|
||||
auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
|
||||
auto trueVal = rewriter.create<ConstantOp>(
|
||||
loc, i32Type, rewriter.getI32IntegerAttr(1));
|
||||
auto falseVal = rewriter.create<ConstantOp>(
|
||||
loc, i32Type, rewriter.getI32IntegerAttr(0));
|
||||
value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt32:
|
||||
value = rewriter.create<ZeroExtendIOp>(
|
||||
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
|
||||
break;
|
||||
case PrintConversion::SignExt32:
|
||||
value = rewriter.create<SignExtendIOp>(
|
||||
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
|
||||
break;
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<ZeroExtendIOp>(
|
||||
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
|
||||
break;
|
||||
case PrintConversion::SignExt64:
|
||||
value = rewriter.create<SignExtendIOp>(
|
||||
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
return;
|
||||
@@ -1372,7 +1424,8 @@ private:
|
||||
rank > 1 ? reducedType : vectorType.getElementType());
|
||||
Value nestedVal =
|
||||
extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
|
||||
conversion);
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
@@ -1410,6 +1463,14 @@ private:
|
||||
return getPrint(op, "print_i64",
|
||||
LLVM::LLVMType::getInt64Ty(op->getContext()));
|
||||
}
|
||||
Operation *getPrintU32(Operation *op) const {
|
||||
return getPrint(op, "printU32",
|
||||
LLVM::LLVMType::getInt32Ty(op->getContext()));
|
||||
}
|
||||
Operation *getPrintU64(Operation *op) const {
|
||||
return getPrint(op, "printU64",
|
||||
LLVM::LLVMType::getInt64Ty(op->getContext()));
|
||||
}
|
||||
Operation *getPrintFloat(Operation *op) const {
|
||||
return getPrint(op, "print_f32",
|
||||
LLVM::LLVMType::getFloatTy(op->getContext()));
|
||||
|
||||
Reference in New Issue
Block a user