[mlir][Transforms] Dialect Conversion: Add replaceOpWithMultiple (#115816)

This commit adds a new function
`ConversionPatternRewriter::replaceOpWithMultiple`. This function is
similar to `replaceOp`, but it accepts multiple `ValueRange`
replacements, one per op result.

Note: This function is not an overload of `replaceOp` because that would
cause ambiguous overload resolution, making the API difficult to use.

This commit aligns "block signature conversions" with "op replacements":
both support 1:N replacements now. Due to incomplete 1:N support in the
dialect conversion driver, an argument materialization is inserted when
an SSA value is replaced with multiple values; this is the same
workaround that block signature conversions already use. These argument
materializations are going to be removed in a subsequent commit that
adds full 1:N support. The purpose of this PR is to add missing features
gradually in small increments.

This commit also updates two MLIR transformations that have their custom
workarounds around missing 1:N support. These can already start using
`replaceOpWithMultiple`.

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
This commit is contained in:
Matthias Springer
2024-11-14 10:27:58 +09:00
committed by GitHub
parent e5092c3019
commit aed4356252
5 changed files with 164 additions and 99 deletions

View File

@@ -600,8 +600,8 @@ public:
flattenOperands(adaptor.getOperands(), flattened);
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
finalRetTy, flattened);
// (2) Create cast operation for sparse tensor returns.
SmallVector<Value> castedRet;
// (2) Gather sparse tensor returns.
SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
// relative to the new call (after sparse tensor flattening);
unsigned retOffset = 0;
@@ -618,21 +618,22 @@ public:
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
ValueRange fields(iterator_range<ResultRange::iterator>(
newCall.result_begin() + retOffset,
newCall.result_begin() + retOffset + flatSize));
castedRet.push_back(genTuple(rewriter, loc, retType, fields));
packedResultVals.emplace_back();
llvm::append_range(packedResultVals.back(),
newCall.getResults().slice(retOffset, flatSize));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
castedRet.push_back(newCall.getResult(retOffset));
packedResultVals.emplace_back();
packedResultVals.back().push_back(newCall.getResult(retOffset));
retOffset++;
}
sparseFlat.clear();
}
assert(castedRet.size() == op.getNumResults());
rewriter.replaceOp(op, castedRet);
assert(packedResultVals.size() == op.getNumResults());
rewriter.replaceOpWithMultiple(
op, llvm::to_vector_of<ValueRange>(packedResultVals));
return success();
}
};
@@ -776,7 +777,7 @@ public:
// Reuses specifier.
fields.push_back(desc.getSpecifier());
assert(fields.size() == desc.getNumFields());
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -796,7 +797,7 @@ public:
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -837,7 +838,7 @@ public:
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -893,7 +894,7 @@ public:
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1006,7 +1007,6 @@ public:
rewriter.create<scf::YieldOp>(loc, insertRet);
rewriter.setInsertionPointAfter(loop);
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1014,7 @@ public:
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, result);
rewriter.replaceOpWithMultiple(op, {loop->getResults()});
return success();
}
};
@@ -1041,8 +1041,7 @@ public:
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op,
genTuple(rewriter, loc, op.getDest().getType(), ret));
rewriter.replaceOpWithMultiple(op, {ret});
return success();
}
};
@@ -1215,8 +1214,7 @@ public:
return true;
});
rewriter.replaceOp(
op, genTuple(rewriter, loc, op.getResult().getType(), fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
@@ -1271,8 +1269,7 @@ public:
// NOTE: we can not generate tuples directly from descriptor here, as the
// descriptor is holding the original type, yet we want the slice type
// here (they shared every memref but with an updated specifier).
rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
desc.getFields()));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setValMemSize(rewriter, loc, memSize);
rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
EmitCInterface::Off);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};