Revert "[mlir][VectorOps] Use SCF for vector.print and allow scalable vectors"

This reverts commit 490dae26cb.

Bot is broken, seems like there is a problem of ambiguity in the parser.
This commit is contained in:
Mehdi Amini
2023-08-09 17:30:05 -07:00
parent 5033ec0a9e
commit 1b272d21c8
64 changed files with 243 additions and 468 deletions

View File

@@ -28,6 +28,13 @@
using namespace mlir;
using namespace mlir::vector;
// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
tp.getScalableDims().drop_front());
}
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
@@ -1409,89 +1416,45 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
// Lowering implementation that relies on a small runtime support library,
// which only needs to provide a few printing methods (single value for all
// data types, opening/closing bracket, comma, newline). The lowering splits
// the vector into elementary printing operations. The advantage of this
// approach is that the library can remain unaware of all low-level
// implementation details of vectors while still supporting output of any
// shaped and dimensioned vector.
//
// Note: This lowering only handles scalars, n-D vectors are broken into
// printing scalars in loops in VectorToSCF.
// Proof-of-concept lowering implementation that relies on a small
// runtime support library, which only needs to provide a few
// printing methods (single value for all data types, opening/closing
// bracket, comma, newline). The lowering fully unrolls a vector
// in terms of these elementary printing operations. The advantage
// of this approach is that the library can remain unaware of all
// low-level implementation details of vectors while still supporting
// output of any shaped and dimensioned vector. Due to full unrolling,
// this approach is less suited for very large vectors though.
//
// TODO: rely solely on libc in future? something else?
//
LogicalResult
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto parent = printOp->getParentOfType<ModuleOp>();
auto loc = printOp->getLoc();
Type printType = printOp.getPrintType();
if (auto value = adaptor.getSource()) {
Type printType = printOp.getPrintType();
if (isa<VectorType>(printType)) {
// Vectors should be broken into elementary print ops in VectorToSCF.
return failure();
}
if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
return failure();
}
auto punct = printOp.getPunctuation();
if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(parent);
case PrintPunctuation::Open:
return LLVM::lookupOrCreatePrintOpenFn(parent);
case PrintPunctuation::Comma:
return LLVM::lookupOrCreatePrintCommaFn(parent);
case PrintPunctuation::NewLine:
return LLVM::lookupOrCreatePrintNewlineFn(parent);
default:
llvm_unreachable("unexpected punctuation");
}
}());
}
rewriter.eraseOp(printOp);
return success();
}
private:
enum class PrintConversion {
// clang-format off
None,
ZeroExt64,
SignExt64,
Bitcast16
// clang-format on
};
LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
ModuleOp parent, Location loc, Type printType,
Value value) const {
if (typeConverter->convertType(printType) == nullptr)
return failure();
// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
VectorType vectorType = dyn_cast<VectorType>(printType);
Type eltType = vectorType ? vectorType.getElementType() : printType;
auto parent = printOp->getParentOfType<ModuleOp>();
Operation *printer;
if (printType.isF32()) {
if (eltType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (printType.isF64()) {
} else if (eltType.isF64()) {
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
} else if (printType.isF16()) {
} else if (eltType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
} else if (printType.isBF16()) {
} else if (eltType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
} else if (printType.isIndex()) {
} else if (eltType.isIndex()) {
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
} else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
// unsigned print method. Up to 64-bit is supported.
@@ -1522,26 +1485,88 @@ private:
return failure();
}
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::Bitcast16:
value = rewriter.create<LLVM::BitcastOp>(
loc, IntegerType::get(rewriter.getContext(), 16), value);
break;
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
Type type = vectorType ? vectorType : eltType;
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(parent));
rewriter.eraseOp(printOp);
return success();
}
private:
enum class PrintConversion {
// clang-format off
None,
ZeroExt64,
SignExt64,
Bitcast16
// clang-format on
};
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
Value value, Type type, Operation *printer, int64_t rank,
PrintConversion conversion) const {
VectorType vectorType = dyn_cast<VectorType>(type);
Location loc = op->getLoc();
if (!vectorType) {
assert(rank == 0 && "The scalar case expects rank == 0");
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::Bitcast16:
value = rewriter.create<LLVM::BitcastOp>(
loc, IntegerType::get(rewriter.getContext(), 16), value);
break;
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
return;
}
auto parent = op->getParentOfType<ModuleOp>();
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);
if (rank <= 1) {
auto reducedType = vectorType.getElementType();
auto llvmType = typeConverter->convertType(reducedType);
int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
llvmType, /*rank=*/0, /*pos=*/d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
conversion);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
return;
}
int64_t dim = vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
auto reducedType = reducedVectorTypeFront(vectorType);
auto llvmType = typeConverter->convertType(reducedType);
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
conversion);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
}
// Helper to emit a call.
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
Operation *ref, ValueRange params = ValueRange()) {