[mlir][VectorOps] Use SCF for vector.print and allow scalable vectors
Reland of the original patch after updating the Python binding tests and a few CUDA/GPU MLIR tests. This patch splits the lowering of vector.print into two steps: first, an n-D print is converted into loops of scalar prints of the elements; then, a second pass converts those scalar prints into the runtime calls. The former is done in VectorToSCF and the latter in VectorToLLVM. The main reason for this is to allow printing scalable vector types, which cannot be fully unrolled at compile time; it also avoids fully unrolling very large vectors. To allow VectorToSCF to add the necessary punctuation between vectors and elements, a "punctuation" attribute has been added to vector.print. This abstracts calling the runtime functions such as printNewline(), without leaking the LLVM details into the higher abstraction levels. For example, `vector.print <comma>` lowers to `llvm.call @printComma() : () -> ()`. The output format and runtime functions remain the same, which avoids the need to alter a large number of tests (aside from the pipelines). Reviewed By: awarzynski, c-rhodes, aartbik Differential Revision: https://reviews.llvm.org/D156519
This commit is contained in:
@@ -28,13 +28,6 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
// Helper to reduce vector type by one rank at front.
//
// Returns a vector type with the same element type as `tp` whose shape and
// scalable-dimension flags are those of `tp` with the leading entry dropped.
// NOTE(review): the hunk header (-28,13 +28,6) suggests this helper is among
// the rows removed by the patch; the +/- markers are lost in this rendering,
// so confirm against the original diff.
|
||||
static VectorType reducedVectorTypeFront(VectorType tp) {
|
||||
// Only vectors of rank 2 or higher are accepted.
assert((tp.getRank() > 1) && "unlowerable vector type");
|
||||
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
|
||||
tp.getScalableDims().drop_front());
|
||||
}
|
||||
|
||||
// Helper to reduce vector type by *all* but one rank at back.
|
||||
static VectorType reducedVectorTypeBack(VectorType tp) {
|
||||
assert((tp.getRank() > 1) && "unlowerable vector type");
|
||||
@@ -1416,45 +1409,89 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Proof-of-concept lowering implementation that relies on a small
|
||||
// runtime support library, which only needs to provide a few
|
||||
// printing methods (single value for all data types, opening/closing
|
||||
// bracket, comma, newline). The lowering fully unrolls a vector
|
||||
// in terms of these elementary printing operations. The advantage
|
||||
// of this approach is that the library can remain unaware of all
|
||||
// low-level implementation details of vectors while still supporting
|
||||
// output of any shaped and dimensioned vector. Due to full unrolling,
|
||||
// this approach is less suited for very large vectors though.
|
||||
// Lowering implementation that relies on a small runtime support library,
|
||||
// which only needs to provide a few printing methods (single value for all
|
||||
// data types, opening/closing bracket, comma, newline). The lowering splits
|
||||
// the vector into elementary printing operations. The advantage of this
|
||||
// approach is that the library can remain unaware of all low-level
|
||||
// implementation details of vectors while still supporting output of any
|
||||
// shaped and dimensioned vector.
|
||||
//
|
||||
// Note: This lowering only handles scalars, n-D vectors are broken into
|
||||
// printing scalars in loops in VectorToSCF.
|
||||
//
|
||||
// TODO: rely solely on libc in future? something else?
|
||||
//
|
||||
// Rewrites one vector.print op into runtime print calls.
//
// If the op has a source operand, it must be a non-vector value: it is printed
// through emitScalarPrint(). Vector-typed sources make the pattern fail.
// Afterwards, any punctuation other than NoPunctuation is emitted as a call to
// the matching runtime function, and the original op is erased.
//
// Returns failure() when the source is a vector or emitScalarPrint() fails,
// success() otherwise.
LogicalResult
|
||||
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// NOTE(review): this declaration is repeated inside the `if` below; it looks
// like the pre-patch row of the diff (the +/- markers are lost in this
// rendering) -- confirm against the original diff.
Type printType = printOp.getPrintType();
|
||||
// Enclosing module, passed to the LLVM::lookupOrCreatePrint*Fn helpers.
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
auto loc = printOp->getLoc();
|
||||
|
||||
// The source operand may be absent; in that case only punctuation is emitted.
if (auto value = adaptor.getSource()) {
|
||||
Type printType = printOp.getPrintType();
|
||||
if (isa<VectorType>(printType)) {
|
||||
// Vectors should be broken into elementary print ops in VectorToSCF.
|
||||
return failure();
|
||||
}
|
||||
if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Map the punctuation kind to its runtime function and emit a call to it.
auto punct = printOp.getPunctuation();
|
||||
if (punct != PrintPunctuation::NoPunctuation) {
|
||||
// The lambda is invoked immediately; its result is the function to call.
emitCall(rewriter, printOp->getLoc(), [&] {
|
||||
switch (punct) {
|
||||
case PrintPunctuation::Close:
|
||||
return LLVM::lookupOrCreatePrintCloseFn(parent);
|
||||
case PrintPunctuation::Open:
|
||||
return LLVM::lookupOrCreatePrintOpenFn(parent);
|
||||
case PrintPunctuation::Comma:
|
||||
return LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
case PrintPunctuation::NewLine:
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(parent);
|
||||
default:
|
||||
// NoPunctuation is excluded by the enclosing `if`.
llvm_unreachable("unexpected punctuation");
|
||||
}
|
||||
}());
|
||||
}
|
||||
|
||||
// The print op has been replaced by the calls emitted above.
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Conversion applied to a scalar value before it is passed to the runtime
// print function.
enum class PrintConversion {
|
||||
// clang-format off
|
||||
// Pass the value through unchanged.
None,
|
||||
// Zero-extend to a 64-bit integer (arith::ExtUIOp).
ZeroExt64,
|
||||
// Sign-extend to a 64-bit integer (arith::ExtSIOp).
SignExt64,
|
||||
// Bitcast to a 16-bit integer (LLVM::BitcastOp); used for f16 and bf16.
Bitcast16
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
|
||||
ModuleOp parent, Location loc, Type printType,
|
||||
Value value) const {
|
||||
if (typeConverter->convertType(printType) == nullptr)
|
||||
return failure();
|
||||
|
||||
// Make sure element type has runtime support.
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
VectorType vectorType = dyn_cast<VectorType>(printType);
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
Operation *printer;
|
||||
if (eltType.isF32()) {
|
||||
if (printType.isF32()) {
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
|
||||
} else if (eltType.isF64()) {
|
||||
} else if (printType.isF64()) {
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
|
||||
} else if (eltType.isF16()) {
|
||||
} else if (printType.isF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
|
||||
} else if (eltType.isBF16()) {
|
||||
} else if (printType.isBF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
|
||||
} else if (eltType.isIndex()) {
|
||||
} else if (printType.isIndex()) {
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
// unsigned print method. Up to 64-bit is supported.
|
||||
@@ -1485,88 +1522,26 @@ public:
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
Type type = vectorType ? vectorType : eltType;
|
||||
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, printOp->getLoc(),
|
||||
LLVM::lookupOrCreatePrintNewlineFn(parent));
|
||||
rewriter.eraseOp(printOp);
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<arith::ExtUIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::SignExt64:
|
||||
value = rewriter.create<arith::ExtSIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::Bitcast16:
|
||||
value = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 16), value);
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Conversion applied to a scalar value before it is passed to the runtime
// print function.
// NOTE(review): this enum appears twice in this rendering; this occurrence
// looks like the pre-patch position of the declaration -- confirm against the
// original diff.
enum class PrintConversion {
|
||||
// clang-format off
|
||||
// Pass the value through unchanged.
None,
|
||||
// Zero-extend to a 64-bit integer (arith::ExtUIOp).
ZeroExt64,
|
||||
// Sign-extend to a 64-bit integer (arith::ExtSIOp).
SignExt64,
|
||||
// Bitcast to a 16-bit integer (LLVM::BitcastOp).
Bitcast16
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// Recursively emits the print calls for `value`, one call per scalar element.
//
// `type` is the type of `value` and `rank` its vector rank (0 for a scalar).
// `printer` is the runtime function called for each scalar, after the scalar
// has been adjusted according to `conversion`. A vector is wrapped in calls to
// the open/close functions, with a call to the comma function between
// consecutive entries. The loops below run at rewrite time, so one call is
// emitted per element.
// NOTE(review): VectorToSCF now breaks n-D prints into scalar prints, so this
// helper looks like pre-patch code removed by the diff -- confirm.
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value value, Type type, Operation *printer, int64_t rank,
|
||||
PrintConversion conversion) const {
|
||||
VectorType vectorType = dyn_cast<VectorType>(type);
|
||||
Location loc = op->getLoc();
|
||||
// Base case: a scalar is converted if required and printed with one call.
if (!vectorType) {
|
||||
assert(rank == 0 && "The scalar case expects rank == 0");
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<arith::ExtUIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::SignExt64:
|
||||
value = rewriter.create<arith::ExtSIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::Bitcast16:
|
||||
value = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 16), value);
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
return;
|
||||
}
|
||||
|
||||
// Vector case: emit the opening bracket and fetch the comma function once.
auto parent = op->getParentOfType<ModuleOp>();
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
|
||||
Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
|
||||
// Rank 0 or 1: each extracted entry is a scalar of the element type.
if (rank <= 1) {
|
||||
auto reducedType = vectorType.getElementType();
|
||||
auto llvmType = typeConverter->convertType(reducedType);
|
||||
// A rank-0 vector is treated as holding exactly one element.
int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||
llvmType, /*rank=*/0, /*pos=*/d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
|
||||
conversion);
|
||||
// Separate entries with a comma, but not after the last one.
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
|
||||
return;
|
||||
}
|
||||
|
||||
// Rank 2 or higher: peel off the leading dimension and recurse on each
// sub-vector of rank - 1.
int64_t dim = vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
auto reducedType = reducedVectorTypeFront(vectorType);
|
||||
auto llvmType = typeConverter->convertType(reducedType);
|
||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||
llvmType, rank, d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
|
||||
conversion);
|
||||
// Separate sub-vectors with a comma, but not after the last one.
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
|
||||
}
|
||||
|
||||
// Helper to emit a call.
|
||||
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *ref, ValueRange params = ValueRange()) {
|
||||
|
||||
Reference in New Issue
Block a user