[mlir] [VectorOps] generalize printing support for integers

This generalizes printing beyond just i1,i32,i64 and also accounts
for signed and unsigned interpretation in the output.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D88290
This commit is contained in:
Aart Bik
2020-09-25 03:32:05 -07:00
parent f330d9f163
commit b8880f5f97
4 changed files with 242 additions and 24 deletions

View File

@@ -1319,44 +1319,96 @@ public:
if (typeConverter.convertType(printType) == nullptr)
return failure();
// Make sure element type has runtime support (currently just Float/Double).
// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
VectorType vectorType = printType.dyn_cast<VectorType>();
Type eltType = vectorType ? vectorType.getElementType() : printType;
int64_t rank = vectorType ? vectorType.getRank() : 0;
Operation *printer;
if (eltType.isSignlessInteger(1) || eltType.isSignlessInteger(32))
printer = getPrintI32(op);
else if (eltType.isSignlessInteger(64))
printer = getPrintI64(op);
else if (eltType.isF32())
if (eltType.isF32()) {
printer = getPrintFloat(op);
else if (eltType.isF64())
} else if (eltType.isF64()) {
printer = getPrintDouble(op);
else
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
// unsigned print method. Up to 64-bit is supported.
unsigned width = intTy.getWidth();
if (intTy.isUnsigned()) {
if (width <= 32) {
if (width < 32)
conversion = PrintConversion::ZeroExt32;
printer = getPrintU32(op);
} else if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
printer = getPrintU64(op);
} else {
return failure();
}
} else {
assert(intTy.isSignless() || intTy.isSigned());
if (width <= 32) {
// Note that we *always* zero extend booleans (1-bit integers),
// so that true/false is printed as 1/0 rather than -1/0.
if (width == 1)
conversion = PrintConversion::ZeroExt32;
else if (width < 32)
conversion = PrintConversion::SignExt32;
printer = getPrintI32(op);
} else if (width <= 64) {
if (width < 64)
conversion = PrintConversion::SignExt64;
printer = getPrintI64(op);
} else {
return failure();
}
}
} else {
return failure();
}
// Unroll vector into elementary print calls.
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
int64_t rank = vectorType ? vectorType.getRank() : 0;
emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank,
conversion);
emitCall(rewriter, op->getLoc(), getPrintNewline(op));
rewriter.eraseOp(op);
return success();
}
private:
enum class PrintConversion {
None,
ZeroExt32,
SignExt32,
ZeroExt64,
SignExt64
};
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
Value value, VectorType vectorType, Operation *printer,
int64_t rank) const {
int64_t rank, PrintConversion conversion) const {
Location loc = op->getLoc();
if (rank == 0) {
if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) {
// Convert i1 (bool) to i32 so we can use the print_i32 method.
// This avoids the need for a print_i1 method with an unclear ABI.
auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
auto trueVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(1));
auto falseVal = rewriter.create<ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(0));
value = rewriter.create<SelectOp>(loc, value, trueVal, falseVal);
switch (conversion) {
case PrintConversion::ZeroExt32:
value = rewriter.create<ZeroExtendIOp>(
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
break;
case PrintConversion::SignExt32:
value = rewriter.create<SignExtendIOp>(
loc, value, LLVM::LLVMType::getInt32Ty(rewriter.getContext()));
break;
case PrintConversion::ZeroExt64:
value = rewriter.create<ZeroExtendIOp>(
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
break;
case PrintConversion::SignExt64:
value = rewriter.create<SignExtendIOp>(
loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext()));
break;
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
return;
@@ -1372,7 +1424,8 @@ private:
rank > 1 ? reducedType : vectorType.getElementType());
Value nestedVal =
extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
conversion);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
@@ -1410,6 +1463,14 @@ private:
return getPrint(op, "print_i64",
LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintU32(Operation *op) const {
return getPrint(op, "printU32",
LLVM::LLVMType::getInt32Ty(op->getContext()));
}
Operation *getPrintU64(Operation *op) const {
return getPrint(op, "printU64",
LLVM::LLVMType::getInt64Ty(op->getContext()));
}
Operation *getPrintFloat(Operation *op) const {
return getPrint(op, "print_f32",
LLVM::LLVMType::getFloatTy(op->getContext()));