[mlir][Transforms] Dialect Conversion: Add replaceOpWithMultiple (#115816)
This commit adds a new function `ConversionPatternRewriter::replaceOpWithMultiple`. This function is similar to `replaceOp`, but it accepts multiple `ValueRange` replacements, one per op result. Note: This function is not an overload of `replaceOp` because of ambiguous overload resolution that would make the API difficult to use. This commit aligns "block signature conversions" with "op replacements": both support 1:N replacements now. Due to incomplete 1:N support in the dialect conversion driver, an argument materialization is inserted when an SSA value is replaced with multiple values; same as block signature conversions already work around the problem. These argument materializations are going to be removed in a subsequent commit that adds full 1:N support. The purpose of this PR is to add missing features gradually in small increments. This commit also updates two MLIR transformations that have their custom workarounds around missing 1:N support. These can already start using `replaceOpWithMultiple`. Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e5092c3019
commit
aed4356252
@@ -600,8 +600,8 @@ public:
|
||||
flattenOperands(adaptor.getOperands(), flattened);
|
||||
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
|
||||
finalRetTy, flattened);
|
||||
// (2) Create cast operation for sparse tensor returns.
|
||||
SmallVector<Value> castedRet;
|
||||
// (2) Gather sparse tensor returns.
|
||||
SmallVector<SmallVector<Value>> packedResultVals;
|
||||
// Tracks the offset of current return value (of the original call)
|
||||
// relative to the new call (after sparse tensor flattening);
|
||||
unsigned retOffset = 0;
|
||||
@@ -618,21 +618,22 @@ public:
|
||||
assert(!sparseFlat.empty());
|
||||
if (sparseFlat.size() > 1) {
|
||||
auto flatSize = sparseFlat.size();
|
||||
ValueRange fields(iterator_range<ResultRange::iterator>(
|
||||
newCall.result_begin() + retOffset,
|
||||
newCall.result_begin() + retOffset + flatSize));
|
||||
castedRet.push_back(genTuple(rewriter, loc, retType, fields));
|
||||
packedResultVals.emplace_back();
|
||||
llvm::append_range(packedResultVals.back(),
|
||||
newCall.getResults().slice(retOffset, flatSize));
|
||||
retOffset += flatSize;
|
||||
} else {
|
||||
// If this is an 1:1 conversion, no need for casting.
|
||||
castedRet.push_back(newCall.getResult(retOffset));
|
||||
packedResultVals.emplace_back();
|
||||
packedResultVals.back().push_back(newCall.getResult(retOffset));
|
||||
retOffset++;
|
||||
}
|
||||
sparseFlat.clear();
|
||||
}
|
||||
|
||||
assert(castedRet.size() == op.getNumResults());
|
||||
rewriter.replaceOp(op, castedRet);
|
||||
assert(packedResultVals.size() == op.getNumResults());
|
||||
rewriter.replaceOpWithMultiple(
|
||||
op, llvm::to_vector_of<ValueRange>(packedResultVals));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -776,7 +777,7 @@ public:
|
||||
// Reuses specifier.
|
||||
fields.push_back(desc.getSpecifier());
|
||||
assert(fields.size() == desc.getNumFields());
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
|
||||
rewriter.replaceOpWithMultiple(op, {fields});
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -796,7 +797,7 @@ public:
|
||||
sizeHint, lvlSizesValues, fields);
|
||||
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
|
||||
rewriter.replaceOpWithMultiple(op, {fields});
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -837,7 +838,7 @@ public:
|
||||
sizeHint, lvlSizesValues, fields);
|
||||
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
|
||||
rewriter.replaceOpWithMultiple(op, {fields});
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -893,7 +894,7 @@ public:
|
||||
if (op.getHasInserts())
|
||||
genEndInsert(rewriter, op.getLoc(), desc);
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
|
||||
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1006,7 +1007,6 @@ public:
|
||||
rewriter.create<scf::YieldOp>(loc, insertRet);
|
||||
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
|
||||
// Deallocate the buffers on exit of the full loop nest.
|
||||
Operation *parent = getTop(op);
|
||||
rewriter.setInsertionPointAfter(parent);
|
||||
@@ -1014,7 +1014,7 @@ public:
|
||||
rewriter.create<memref::DeallocOp>(loc, filled);
|
||||
rewriter.create<memref::DeallocOp>(loc, added);
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, result);
|
||||
rewriter.replaceOpWithMultiple(op, {loop->getResults()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1041,8 +1041,7 @@ public:
|
||||
params, /*genCall=*/true);
|
||||
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op,
|
||||
genTuple(rewriter, loc, op.getDest().getType(), ret));
|
||||
rewriter.replaceOpWithMultiple(op, {ret});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1215,8 +1214,7 @@ public:
|
||||
return true;
|
||||
});
|
||||
|
||||
rewriter.replaceOp(
|
||||
op, genTuple(rewriter, loc, op.getResult().getType(), fields));
|
||||
rewriter.replaceOpWithMultiple(op, {fields});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1271,8 +1269,7 @@ public:
|
||||
// NOTE: we can not generate tuples directly from descriptor here, as the
|
||||
// descriptor is holding the original type, yet we want the slice type
|
||||
// here (they shared every memref but with an updated specifier).
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
|
||||
desc.getFields()));
|
||||
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
}
|
||||
desc.setValMemSize(rewriter, loc, memSize);
|
||||
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
|
||||
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
|
||||
EmitCInterface::Off);
|
||||
|
||||
// Replace operation with resulting memrefs.
|
||||
rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
|
||||
rewriter.replaceOpWithMultiple(op, {fields});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user