[NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related C++ code. (#116377)
This commit refactors part of the code in preparation for the migration of other *matmul* variants from OpDSL to ODS. Moves getDefaultIndexingmaps() helper into the MatmulOp class.
This commit is contained in:
committed by
GitHub
parent
b7ddb97ac2
commit
288f05f63e
@@ -622,7 +622,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
|
||||
attributes, MatmulOp::getRegionBuilder());
|
||||
attributes, MatmulOp::getRegionBuilder(),
|
||||
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
@@ -630,16 +631,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
buildMatmulOp($_builder, $_state, resultTensorTypes,
|
||||
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
$_state.addOperands(operands);
|
||||
$_state.addAttributes(attributes);
|
||||
$_state.addTypes(resultTensorTypes);
|
||||
(void)$_state.addRegion();
|
||||
inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
|
||||
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
@@ -648,7 +641,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
|
||||
[{
|
||||
$_state.addAttribute("cast", cast);
|
||||
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
|
||||
attributes, MatmulOp::getRegionBuilder());
|
||||
attributes, MatmulOp::getRegionBuilder(),
|
||||
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
|
||||
}]>
|
||||
|
||||
];
|
||||
@@ -664,7 +658,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
|
||||
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
|
||||
SmallVector<AffineMap> getDefaultIndexingMaps();
|
||||
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
|
||||
|
||||
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
|
||||
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
|
||||
|
||||
@@ -155,27 +155,6 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
||||
// iterator_types is an auto-generated method.
|
||||
}
|
||||
|
||||
/// Helper to create a typical indexing map for MatmulOp. Returns a list of
|
||||
/// AffineMap.
|
||||
static SmallVector<AffineMap, 3>
|
||||
getDefaultIndexingMapsForMatmul(MLIRContext *context) {
|
||||
AffineExpr d0, d1, d2;
|
||||
SmallVector<AffineMap, 3> indexingMaps;
|
||||
bindDims(context, d0, d1, d2);
|
||||
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
|
||||
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
|
||||
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
|
||||
return indexingMaps;
|
||||
}
|
||||
|
||||
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
|
||||
static SmallVector<Attribute>
|
||||
getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
|
||||
return llvm::map_to_vector(
|
||||
getDefaultIndexingMapsForMatmul(context),
|
||||
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
||||
}
|
||||
|
||||
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
|
||||
/// The result types are derived automatically if `resultTensorTypes` is none.
|
||||
/// The body of the operation is filled using `regionBuilder`. All ods-gen
|
||||
@@ -208,24 +187,18 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
|
||||
state.attributes.getAttrs(), regionBuilder);
|
||||
}
|
||||
|
||||
static void
|
||||
buildMatmulOp(OpBuilder &b, OperationState &state,
|
||||
std::optional<TypeRange> resultTensorTypes, ValueRange inputs,
|
||||
ValueRange outputs, ArrayRef<NamedAttribute> attributes,
|
||||
RegionBuilderFn regionBuilder,
|
||||
std::optional<ArrayRef<AffineMap>> indexingMaps = std::nullopt) {
|
||||
// Initialize indexingMaps, for MatmulOp.
|
||||
static void buildMatmulOp(OpBuilder &b, OperationState &state,
|
||||
std::optional<TypeRange> resultTensorTypes,
|
||||
ValueRange inputs, ValueRange outputs,
|
||||
ArrayRef<NamedAttribute> attributes,
|
||||
RegionBuilderFn regionBuilder,
|
||||
ArrayRef<AffineMap> indexingMaps) {
|
||||
// Initialize indexingMaps attribute, for MatmulOp.
|
||||
SmallVector<Attribute, 3> indexingMapsAttrVal;
|
||||
if (indexingMaps.has_value()) {
|
||||
for (mlir::AffineMap map : *indexingMaps) {
|
||||
// Convert each AffineMap to an AffineMapAttr
|
||||
indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
|
||||
}
|
||||
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
|
||||
} else {
|
||||
indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
|
||||
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
|
||||
}
|
||||
indexingMapsAttrVal = llvm::map_to_vector(
|
||||
MatmulOp::getDefaultIndexingMaps(b.getContext()),
|
||||
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
||||
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
|
||||
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
|
||||
attributes, regionBuilder);
|
||||
}
|
||||
@@ -3457,7 +3430,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
|
||||
unsigned opIndex) {
|
||||
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
|
||||
SmallVector<AffineMap, 3> defaultIndexingMaps =
|
||||
matmulOp.getDefaultIndexingMaps();
|
||||
matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
|
||||
|
||||
auto opIndexingMap = opIndexingMaps[opIndex];
|
||||
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
|
||||
@@ -3484,6 +3457,17 @@ namespace linalg {
|
||||
// MatMulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
|
||||
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
|
||||
AffineExpr d0, d1, d2;
|
||||
SmallVector<AffineMap, 3> indexingMaps;
|
||||
bindDims(context, d0, d1, d2);
|
||||
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
|
||||
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
|
||||
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
|
||||
return indexingMaps;
|
||||
}
|
||||
|
||||
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
|
||||
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
|
||||
utils::IteratorType::parallel,
|
||||
@@ -3501,7 +3485,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
|
||||
/// Check if the op has broadcast and/or transpose semantic. Returns true if
|
||||
/// the user defined indexing maps are not equal to default map.
|
||||
bool MatmulOp::hasUserDefinedMaps() {
|
||||
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
|
||||
SmallVector<AffineMap, 3> defaultMaps =
|
||||
getDefaultIndexingMaps(this->getContext());
|
||||
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
|
||||
return defaultMaps != explicitMaps;
|
||||
}
|
||||
@@ -3535,13 +3520,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
helper.yieldOutputs(yields);
|
||||
}
|
||||
|
||||
/// Returns a list of AffineMap with the typical matmul indexing
|
||||
/// charactristic.
|
||||
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
|
||||
MLIRContext *context = this->getContext();
|
||||
return getDefaultIndexingMapsForMatmul(context);
|
||||
}
|
||||
|
||||
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
|
||||
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
|
||||
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
|
||||
@@ -3578,7 +3556,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
}
|
||||
// Initialize indexingMaps, if not supplied explicitly.
|
||||
if (indexingMapsAttr.empty()) {
|
||||
indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
|
||||
indexingMapsAttr = llvm::map_to_vector(
|
||||
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
|
||||
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
||||
}
|
||||
result.addAttribute("indexing_maps",
|
||||
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
||||
@@ -3592,8 +3572,9 @@ void MatmulOp::print(OpAsmPrinter &p) {
|
||||
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
||||
elidedAttrs);
|
||||
|
||||
SmallVector<Attribute, 3> indexingMaps =
|
||||
getDefaultMatmulIndexingMapAttr(getContext());
|
||||
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
|
||||
MatmulOp::getDefaultIndexingMaps(getContext()),
|
||||
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
||||
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
|
||||
p << " indexing_maps = [";
|
||||
llvm::interleaveComma(getIndexingMaps(), p,
|
||||
|
||||
Reference in New Issue
Block a user