Revert "[mlir][VectorOps] Use SCF for vector.print and allow scalable vectors"
This reverts commit 490dae26cb.
The build bot is broken; there seems to be an ambiguity problem in the parser.
This commit is contained in:
@@ -28,6 +28,13 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
||||
// Helper to reduce vector type by one rank at front: builds a vector type
// with the same element type whose shape and scalable-dims list both have
// their first entry dropped. Asserts that the input has rank > 1.
|
||||
static VectorType reducedVectorTypeFront(VectorType tp) {
|
||||
assert((tp.getRank() > 1) && "unlowerable vector type");
|
||||
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
|
||||
tp.getScalableDims().drop_front());
|
||||
}
|
||||
|
||||
// Helper to reduce vector type by *all* but one rank at back.
|
||||
static VectorType reducedVectorTypeBack(VectorType tp) {
|
||||
assert((tp.getRank() > 1) && "unlowerable vector type");
|
||||
@@ -1409,89 +1416,45 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
|
||||
public:
|
||||
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
// Lowering implementation that relies on a small runtime support library,
|
||||
// which only needs to provide a few printing methods (single value for all
|
||||
// data types, opening/closing bracket, comma, newline). The lowering splits
|
||||
// the vector into elementary printing operations. The advantage of this
|
||||
// approach is that the library can remain unaware of all low-level
|
||||
// implementation details of vectors while still supporting output of any
|
||||
// shaped and dimensioned vector.
|
||||
//
|
||||
// Note: This lowering only handles scalars, n-D vectors are broken into
|
||||
// printing scalars in loops in VectorToSCF.
|
||||
// Proof-of-concept lowering implementation that relies on a small
|
||||
// runtime support library, which only needs to provide a few
|
||||
// printing methods (single value for all data types, opening/closing
|
||||
// bracket, comma, newline). The lowering fully unrolls a vector
|
||||
// in terms of these elementary printing operations. The advantage
|
||||
// of this approach is that the library can remain unaware of all
|
||||
// low-level implementation details of vectors while still supporting
|
||||
// output of any shaped and dimensioned vector. Due to full unrolling,
|
||||
// this approach is less suited for very large vectors though.
|
||||
//
|
||||
// TODO: rely solely on libc in future? something else?
|
||||
//
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
auto loc = printOp->getLoc();
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (auto value = adaptor.getSource()) {
|
||||
Type printType = printOp.getPrintType();
|
||||
if (isa<VectorType>(printType)) {
|
||||
// Vectors should be broken into elementary print ops in VectorToSCF.
|
||||
return failure();
|
||||
}
|
||||
if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto punct = printOp.getPunctuation();
|
||||
if (punct != PrintPunctuation::NoPunctuation) {
|
||||
emitCall(rewriter, printOp->getLoc(), [&] {
|
||||
switch (punct) {
|
||||
case PrintPunctuation::Close:
|
||||
return LLVM::lookupOrCreatePrintCloseFn(parent);
|
||||
case PrintPunctuation::Open:
|
||||
return LLVM::lookupOrCreatePrintOpenFn(parent);
|
||||
case PrintPunctuation::Comma:
|
||||
return LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
case PrintPunctuation::NewLine:
|
||||
return LLVM::lookupOrCreatePrintNewlineFn(parent);
|
||||
default:
|
||||
llvm_unreachable("unexpected punctuation");
|
||||
}
|
||||
}());
|
||||
}
|
||||
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
enum class PrintConversion {
|
||||
// clang-format off
|
||||
None,
|
||||
ZeroExt64,
|
||||
SignExt64,
|
||||
Bitcast16
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
|
||||
ModuleOp parent, Location loc, Type printType,
|
||||
Value value) const {
|
||||
if (typeConverter->convertType(printType) == nullptr)
|
||||
return failure();
|
||||
|
||||
// Make sure element type has runtime support.
|
||||
PrintConversion conversion = PrintConversion::None;
|
||||
VectorType vectorType = dyn_cast<VectorType>(printType);
|
||||
Type eltType = vectorType ? vectorType.getElementType() : printType;
|
||||
auto parent = printOp->getParentOfType<ModuleOp>();
|
||||
Operation *printer;
|
||||
if (printType.isF32()) {
|
||||
if (eltType.isF32()) {
|
||||
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
|
||||
} else if (printType.isF64()) {
|
||||
} else if (eltType.isF64()) {
|
||||
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
|
||||
} else if (printType.isF16()) {
|
||||
} else if (eltType.isF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
|
||||
} else if (printType.isBF16()) {
|
||||
} else if (eltType.isBF16()) {
|
||||
conversion = PrintConversion::Bitcast16; // bits!
|
||||
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
|
||||
} else if (printType.isIndex()) {
|
||||
} else if (eltType.isIndex()) {
|
||||
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
|
||||
} else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
|
||||
// Integers need a zero or sign extension on the operand
|
||||
// (depending on the source type) as well as a signed or
|
||||
// unsigned print method. Up to 64-bit is supported.
|
||||
@@ -1522,26 +1485,88 @@ private:
|
||||
return failure();
|
||||
}
|
||||
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<arith::ExtUIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::SignExt64:
|
||||
value = rewriter.create<arith::ExtSIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::Bitcast16:
|
||||
value = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 16), value);
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
// Unroll vector into elementary print calls.
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
Type type = vectorType ? vectorType : eltType;
|
||||
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, printOp->getLoc(),
|
||||
LLVM::lookupOrCreatePrintNewlineFn(parent));
|
||||
rewriter.eraseOp(printOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Conversion applied to a scalar value right before it is passed to the
// print function (see the switch in emitRanks):
//   None      - value is passed through unchanged.
//   ZeroExt64 - value is zero-extended (arith::ExtUIOp) to a 64-bit integer.
//   SignExt64 - value is sign-extended (arith::ExtSIOp) to a 64-bit integer.
//   Bitcast16 - value is bitcast (LLVM::BitcastOp) to a 16-bit integer.
enum class PrintConversion {
|
||||
// clang-format off
|
||||
None,
|
||||
ZeroExt64,
|
||||
SignExt64,
|
||||
Bitcast16
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
// Recursively emits the print calls for `value` of type `type`.
//
// Scalar case (`type` is not a vector): applies `conversion` to the value and
// emits a single call to `printer` with it.
// Vector case: emits an "open" call, then one recursive emitRanks per element
// along the leading dimension with a "comma" call between consecutive
// elements, then a "close" call. Since the loops run at pattern-rewrite time,
// every element gets its own sequence of emitted calls.
//
// `op` supplies the location and the enclosing ModuleOp used to look up the
// punctuation print functions; `rank` is the rank of `type` (0 for scalars).
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value value, Type type, Operation *printer, int64_t rank,
|
||||
PrintConversion conversion) const {
|
||||
VectorType vectorType = dyn_cast<VectorType>(type);
|
||||
Location loc = op->getLoc();
|
||||
// Base case of the recursion: a scalar is converted and printed directly.
if (!vectorType) {
|
||||
assert(rank == 0 && "The scalar case expects rank == 0");
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<arith::ExtUIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::SignExt64:
|
||||
value = rewriter.create<arith::ExtSIOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 64), value);
|
||||
break;
|
||||
case PrintConversion::Bitcast16:
|
||||
value = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, IntegerType::get(rewriter.getContext(), 16), value);
|
||||
break;
|
||||
case PrintConversion::None:
|
||||
break;
|
||||
}
|
||||
emitCall(rewriter, loc, printer, value);
|
||||
return;
|
||||
}
|
||||
|
||||
// Vector case: every vector level is wrapped in open/close calls.
auto parent = op->getParentOfType<ModuleOp>();
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
|
||||
Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
|
||||
|
||||
// Rank 0 or 1: the elements are scalars, so extract each one and recurse
// into the scalar case.
if (rank <= 1) {
|
||||
auto reducedType = vectorType.getElementType();
|
||||
auto llvmType = typeConverter->convertType(reducedType);
|
||||
// A rank-0 vector is treated as holding exactly one element.
int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||
llvmType, /*rank=*/0, /*pos=*/d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
|
||||
conversion);
|
||||
// Separate elements with a comma, but not after the last one.
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
|
||||
return;
|
||||
}
|
||||
|
||||
// Rank > 1: peel off the leading dimension and recurse on each sub-vector
// with the rank reduced by one.
int64_t dim = vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
auto reducedType = reducedVectorTypeFront(vectorType);
|
||||
auto llvmType = typeConverter->convertType(reducedType);
|
||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||
llvmType, rank, d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
|
||||
conversion);
|
||||
// Separate sub-vectors with a comma, but not after the last one.
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
|
||||
}
|
||||
|
||||
// Helper to emit a call.
|
||||
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Operation *ref, ValueRange params = ValueRange()) {
|
||||
|
||||
Reference in New Issue
Block a user