Revert "[mlir][VectorOps] Use SCF for vector.print and allow scalable vectors"
This reverts commit 3875804a07.
This caused some test failures for the MLIR python bindings. Reverting
until those are addressed.
This commit is contained in:
@@ -28,6 +28,13 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
// Helper to reduce vector type by one rank at front.
|
||||
static VectorType reducedVectorTypeFront(VectorType tp) {
|
||||
assert((tp.getRank() > 1) && "unlowerable vector type");
|
||||
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
|
||||
tp.getScalableDims().drop_front());
|
||||
}
|
||||
|
||||
// Helper to reduce vector type by *all* but one rank at back.
|
||||
static VectorType reducedVectorTypeBack(VectorType tp) {
|
||||
assert((tp.getRank() > 1) && "unlowerable vector type");
|
||||
@@ -1409,89 +1416,45 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Lowering implementation that relies on a small runtime support library,
|
||||
// which only needs to provide a few printing methods (single value for all
|
||||
// data types, opening/closing bracket, comma, newline). The lowering splits
|
||||
// the vector into elementary printing operations. The advantage of this
|
||||
// approach is that the library can remain unaware of all low-level
|
||||
// implementation details of vectors while still supporting output of any
|
||||
// shaped and dimensioned vector.
|
||||
//
|
||||
// Note: This lowering only handles scalars, n-D vectors are broken into
|
||||
// printing scalars in loops in VectorToSCF.
|
||||
// Proof-of-concept lowering implementation that relies on a small
|
||||
// runtime support library, which only needs to provide a few
|
||||
// printing methods (single value for all data types, opening/closing
|
||||
// bracket, comma, newline). The lowering fully unrolls a vector
|
||||
// in terms of these elementary printing operations. The advantage
|
||||
// of this approach is that the library can remain unaware of all
|
||||
// low-level implementation details of vectors while still supporting
|
||||
// output of any shaped and dimensioned vector. Due to full unrolling,
|
||||
// this approach is less suited for very large vectors though.
|
||||
//
|
||||
// TODO: rely solely on libc in future? something else?
|
||||
//
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
auto loc = printOp->getLoc();
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (auto value = adaptor.getSource()) {
|
||||
Type printType = printOp.getPrintType();
|
||||
if (isa<VectorType>(printType)) {
|
||||
// Vectors should be broken into elementary print ops in VectorToSCF.
|
||||
return failure();
|
||||
}
|
||||
if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto punct = printOp.getPunctuation();
|
||||
if (punct != PrintPunctuation::NoPunctuation) {
|
||||
emitCall(rewriter, printOp->getLoc(), [&] {
|
||||
switch (punct) {
|
||||
case PrintPunctuation::Close:
|
||||
return LLVM::lookupOrCreatePrintCloseFn(parent);
|
||||
case PrintPunctuation::Open:
|
||||
return LLVM::lookupOrCreatePrintOpenFn(parent);
|
||||
case PrintPunctuation::Comma:
|
||||
return LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
case PrintPunctuation::NewLine:
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(parent);
|
||||
default:
|
||||
llvm_unreachable("unexpected punctuation");
|
||||
}
|
||||
}());
|
||||
}
|
||||
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Adjustment applied to a scalar operand before it is handed to the runtime
// print function.
enum class PrintConversion {
  // Pass the value through unchanged.
  None,
  // Zero-extend the value to a 64-bit integer.
  ZeroExt64,
  // Sign-extend the value to a 64-bit integer.
  SignExt64,
  // Bitcast the value to a 16-bit integer (selected for f16 and bf16 values).
  Bitcast16
};
|
||||
|
||||
LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
|
||||
ModuleOp parent, Location loc, Type printType,
|
||||
Value value) const {
|
||||
if (typeConverter->convertType(printType) == nullptr)
|
||||
return failure();
|
||||
|
||||
// Make sure element type has runtime support.
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
VectorType vectorType = dyn_cast<VectorType>(printType);
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
Operation *printer;
|
||||
if (printType.isF32()) {
|
||||
if (eltType.isF32()) {
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
|
||||
} else if (printType.isF64()) {
|
||||
} else if (eltType.isF64()) {
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
|
||||
} else if (printType.isF16()) {
|
||||
} else if (eltType.isF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
|
||||
} else if (printType.isBF16()) {
|
||||
} else if (eltType.isBF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
|
||||
} else if (printType.isIndex()) {
|
||||
} else if (eltType.isIndex()) {
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
// unsigned print method. Up to 64-bit is supported.
|
||||
@@ -1522,26 +1485,88 @@ private:
|
||||
return failure();
|
||||
}
|
||||
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<arith::ExtUIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::SignExt64:
|
||||
value = rewriter.create<arith::ExtSIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::Bitcast16:
|
||||
value = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 16), value);
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
// Unroll vector into elementary print calls.
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
Type type = vectorType ? vectorType : eltType;
|
||||
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, printOp->getLoc(),
|
||||
LLVM::lookupOrCreatePrintNewlineFn(parent));
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Adjustment applied to a scalar operand before it is handed to the runtime
// print function.
enum class PrintConversion {
  // Pass the value through unchanged.
  None,
  // Zero-extend the value to a 64-bit integer.
  ZeroExt64,
  // Sign-extend the value to a 64-bit integer.
  SignExt64,
  // Bitcast the value to a 16-bit integer (selected for f16 and bf16 values).
  Bitcast16
};
|
||||
|
||||
// Recursively emits the runtime calls that print `value` of type `type`,
// fully unrolling vectors at conversion time.
//
// A non-vector `type` is the base case: the scalar is adjusted as requested
// by `conversion` and then passed to `printer`. A vector is printed as an
// opening bracket, its comma-separated sub-values (each extracted with
// `extractOne` and printed through a recursive call), and a closing bracket.
// `rank` is the number of vector dimensions still to be unrolled and must be
// 0 in the scalar case. `op` supplies the location and the enclosing module.
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
               Value value, Type type, Operation *printer, int64_t rank,
               PrintConversion conversion) const {
  Location loc = op->getLoc();
  auto vecTy = dyn_cast<VectorType>(type);

  // Base case: convert the scalar if required and print it.
  if (!vecTy) {
    assert(rank == 0 && "The scalar case expects rank == 0");
    if (conversion == PrintConversion::ZeroExt64) {
      Type i64Ty = IntegerType::get(rewriter.getContext(), 64);
      value = rewriter.create<arith::ExtUIOp>(loc, i64Ty, value);
    } else if (conversion == PrintConversion::SignExt64) {
      Type i64Ty = IntegerType::get(rewriter.getContext(), 64);
      value = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
    } else if (conversion == PrintConversion::Bitcast16) {
      Type i16Ty = IntegerType::get(rewriter.getContext(), 16);
      value = rewriter.create<LLVM::BitcastOp>(loc, i16Ty, value);
    }
    // PrintConversion::None leaves the value untouched.
    emitCall(rewriter, loc, printer, value);
    return;
  }

  // Vector case: opening bracket first, then the elements, then the closing
  // bracket. The print helpers are looked up in the same order as the calls
  // are emitted.
  auto moduleOp = op->getParentOfType<ModuleOp>();
  emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(moduleOp));
  Operation *commaFn = LLVM::lookupOrCreatePrintCommaFn(moduleOp);

  if (rank <= 1) {
    // Innermost dimension: every extracted element is a scalar. A rank-0
    // vector is treated as holding exactly one element.
    auto scalarTy = vecTy.getElementType();
    auto llvmScalarTy = typeConverter->convertType(scalarTy);
    int64_t count = rank == 0 ? 1 : vecTy.getDimSize(0);
    for (int64_t idx = 0; idx < count; ++idx) {
      Value element = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmScalarTy, /*rank=*/0, /*pos=*/idx);
      emitRanks(rewriter, op, element, scalarTy, printer, /*rank=*/0,
                conversion);
      // Separate elements with a comma, but do not emit a trailing one.
      if (idx + 1 != count)
        emitCall(rewriter, loc, commaFn);
    }
  } else {
    // Outer dimension: peel off the leading dimension and recurse into each
    // sub-vector with one rank less.
    int64_t count = vecTy.getDimSize(0);
    for (int64_t idx = 0; idx < count; ++idx) {
      auto subVecTy = reducedVectorTypeFront(vecTy);
      auto llvmSubVecTy = typeConverter->convertType(subVecTy);
      Value subVector = extractOne(rewriter, *getTypeConverter(), loc, value,
                                   llvmSubVecTy, rank, idx);
      emitRanks(rewriter, op, subVector, subVecTy, printer, rank - 1,
                conversion);
      // Separate sub-vectors with a comma, but do not emit a trailing one.
      if (idx + 1 != count)
        emitCall(rewriter, loc, commaFn);
    }
  }

  emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(moduleOp));
}
|
||||
|
||||
// Helper to emit a call.
|
||||
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *ref, ValueRange params = ValueRange()) {
|
||||
|
||||
Reference in New Issue
Block a user