[mlir][Vector] Support 0-D vectors in VectorPrintOpConversion
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D114549
This commit is contained in:
committed by
Nicolas Vasilache
parent
0796869e4e
commit
cc311a155a
@@ -57,8 +57,7 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
|
||||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
Value val, Type llvmType, int64_t rank, int64_t pos) {
|
||||
assert(rank > 0 && "0-D vector corner case should have been handled already");
|
||||
if (rank == 1) {
|
||||
if (rank <= 1) {
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.convertType(idxType),
|
||||
@@ -987,7 +986,8 @@ public:
|
||||
|
||||
// Unroll vector into elementary print calls.
|
||||
int64_t rank = vectorType ? vectorType.getRank() : 0;
|
||||
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
|
||||
Type type = vectorType ? vectorType : eltType;
|
||||
emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank,
|
||||
conversion);
|
||||
emitCall(rewriter, printOp->getLoc(),
|
||||
LLVM::lookupOrCreatePrintNewlineFn(
|
||||
@@ -1006,10 +1006,12 @@ private:
|
||||
};
|
||||
|
||||
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value value, VectorType vectorType, Operation *printer,
|
||||
int64_t rank, PrintConversion conversion) const {
|
||||
Value value, Type type, Operation *printer, int64_t rank,
|
||||
PrintConversion conversion) const {
|
||||
VectorType vectorType = type.dyn_cast<VectorType>();
|
||||
Location loc = op->getLoc();
|
||||
if (rank == 0) {
|
||||
if (!vectorType) {
|
||||
assert(rank == 0 && "The scalar case expects rank == 0");
|
||||
switch (conversion) {
|
||||
case PrintConversion::ZeroExt64:
|
||||
value = rewriter.create<arith::ExtUIOp>(
|
||||
@@ -1030,12 +1032,29 @@ private:
|
||||
LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
|
||||
Operation *printComma =
|
||||
LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
|
||||
|
||||
if (rank <= 1) {
|
||||
auto reducedType = vectorType.getElementType();
|
||||
auto llvmType = typeConverter->convertType(reducedType);
|
||||
int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||
llvmType, /*rank=*/0, /*pos=*/d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, /*rank=*/0,
|
||||
conversion);
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
}
|
||||
emitCall(
|
||||
rewriter, loc,
|
||||
LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t dim = vectorType.getDimSize(0);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
auto reducedType =
|
||||
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
|
||||
auto llvmType = typeConverter->convertType(
|
||||
rank > 1 ? reducedType : vectorType.getElementType());
|
||||
auto reducedType = reducedVectorTypeFront(vectorType);
|
||||
auto llvmType = typeConverter->convertType(reducedType);
|
||||
Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
|
||||
llvmType, rank, d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1,
|
||||
|
||||
Reference in New Issue
Block a user