[mlir][Vector] Introduce the MaskingOpInterface

This MaskingOpInterface provides masking cababilitites to those
operations that implement it. For only is only implemented by the `vector.mask`
operation and it's used to break the dependency between the Vector
dialect (where the `vector.mask` op lives) and operations implementing
the MaskableOpInterface.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D136734
This commit is contained in:
Diego Caballero
2022-10-26 00:58:56 +00:00
parent 5e4eec98d3
commit b1bc1a1ed6
14 changed files with 265 additions and 88 deletions

View File

@@ -13,7 +13,8 @@
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"

View File

@@ -13,7 +13,8 @@
#ifndef VECTOR_OPS
#define VECTOR_OPS
include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
@@ -2140,13 +2141,15 @@ def Vector_CreateMaskOp :
}
def Vector_MaskOp : Vector_Op<"mask", [
SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveMemoryEffects,
NoRegionArguments
SingleBlockImplicitTerminator<"vector::YieldOp">,
DeclareOpInterfaceMethods<MaskingOpInterface>,
RecursiveMemoryEffects, NoRegionArguments
]> {
let summary = "Predicates a maskable vector operation";
let description = [{
The `vector.mask` operation predicates the execution of another operation.
It takes an `i1` vector mask and an optional pass-thru vector as arguments.
The `vector.mask` is a `MaskingOpInterface` operation that predicates the
execution of another operation. It takes an `i1` vector mask and an
optional passthru vector as arguments.
A `vector.yield`-terminated region encloses the operation to be masked.
Values used within the region are captured from above. Only one *maskable*
operation can be masked with a `vector.mask` operation at a time. An

View File

@@ -1 +1,2 @@
add_mlir_interface(MaskingInterfaces)
add_mlir_interface(MaskableOpInterface)
add_mlir_interface(MaskingOpInterface)

View File

@@ -0,0 +1,23 @@
//===- MaskableOpInterface.h ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the MaskableOpInterface.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h.inc"
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_

View File

@@ -0,0 +1,72 @@
//===- MaskableOpInterfaces.td - Masking Interfaces Decls -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the MaskableOpInterface.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD
include "mlir/IR/OpBase.td"
def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
let description = [{
The 'MaskableOpInterface' defines an operation that can be masked using a
MaskingOpInterface (e.g., `vector.mask`) and provides information about its
masking constraints and semantics.
}];
let cppNamespace = "::mlir::vector";
let methods = [
InterfaceMethod<
/*desc=*/"Returns true if the operation is masked by a "
"MaskingOpInterface.",
/*retTy=*/"bool",
/*methodName=*/"isMasked",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return mlir::isa<mlir::vector::MaskingOpInterface>($_op->getParentOp());
}]>,
InterfaceMethod<
/*desc=*/"Returns the MaskingOpInterface masking this operation.",
/*retTy=*/"mlir::vector::MaskingOpInterface",
/*methodName=*/"getMaskingOp",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return mlir::cast<mlir::vector::MaskingOpInterface>(
$_op->getParentOp());
}]>,
InterfaceMethod<
/*desc=*/"Returns true if the operation can have a passthru argument when"
" masked.",
/*retTy=*/"bool",
/*methodName=*/"supportsPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]>,
InterfaceMethod<
/*desc=*/"Returns the mask type expected by this operation. It requires "
"the operation to be vectorized.",
/*retTy=*/"mlir::VectorType",
/*methodName=*/"getExpectedMaskType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Default implementation is only aimed for operations that implement the
// `getVectorType()` method.
return $_op.getVectorType().cloneWith(/*shape=*/llvm::None,
IntegerType::get($_op.getContext(), /*width=*/1));
}]>,
];
}
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD

View File

@@ -1,52 +0,0 @@
//===- MaskingInterfaces.td - Masking Interfaces Decls === -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for vector masking related interfaces.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES
include "mlir/IR/OpBase.td"
def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
let description = [{
The 'MaskableOpInterface' define an operation that can be masked using the
`vector.mask` operation and provides information about its masking
constraints and semantics.
}];
let cppNamespace = "::mlir::vector";
let methods = [
InterfaceMethod<
/*desc=*/"Returns true if the operation may have a passthru argument when"
" masked.",
/*retTy=*/"bool",
/*methodName=*/"supportsPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
}]>,
InterfaceMethod<
/*desc=*/"Returns the mask type expected by this operation. It requires the"
" operation to be vectorized.",
/*retTy=*/"mlir::VectorType",
/*methodName=*/"getExpectedMaskType",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Default implementation is only aimed for operations that implement the
// `getVectorType()` method.
return $_op.getVectorType().cloneWith(
/*shape=*/llvm::None, IntegerType::get($_op.getContext(), /*width=*/1));
}]>,
];
}
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES

View File

@@ -1,4 +1,4 @@
//===- MaskingInterfaces.h - Masking interfaces ---------------------------===//
//===- MaskingOpInterface.h -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,17 +6,17 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements the interfaces for masking operations.
// This file implements the MaskingOpInterface.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
/// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc"
// Include the generated interface declarations.
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h.inc"
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_

View File

@@ -0,0 +1,58 @@
//===- MaskingOpInterfaces.td - MaskingOpInterface Decls = -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the definition file for the MaskingOpInterface.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD
#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD
include "mlir/IR/OpBase.td"
def MaskingOpInterface : OpInterface<"MaskingOpInterface"> {
let description = [{
The 'MaskingOpInterface' defines an vector operation that can apply masking
to its own or other vector operations.
}];
let cppNamespace = "::mlir::vector";
let methods = [
InterfaceMethod<
/*desc=*/"Returns the mask value of this masking operation.",
/*retTy=*/"mlir::Value",
/*methodName=*/"getMask",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
InterfaceMethod<
/*desc=*/"Returns the operation masked by this masking operation.",
// TODO: Return a MaskableOpInterface when interface infra can handle
// dependences between interfaces.
/*retTy=*/"Operation *",
/*methodName=*/"getMaskableOp",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
InterfaceMethod<
/*desc=*/"Returns true if the masking operation has a passthru value.",
/*retTy=*/"bool",
/*methodName=*/"hasPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
InterfaceMethod<
/*desc=*/"Returns the passthru value of this masking operation.",
/*retTy=*/"mlir::Value",
/*methodName=*/"getPassthru",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/"">,
];
}
#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD

View File

@@ -15,7 +15,8 @@ add_mlir_dialect_library(MLIRVectorDialect
MLIRDestinationStyleOpInterface
MLIRDialectUtils
MLIRIR
MLIRMaskingInterfaces
MLIRMaskableOpInterface
MLIRMaskingOpInterface
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRTensorDialect

View File

@@ -5003,6 +5003,14 @@ LogicalResult MaskOp::verify() {
return success();
}
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
/// Returns true if 'vector.mask' has a passthru value.
bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//

View File

@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
MaskingInterfaces.cpp
MaskableOpInterface.cpp
MaskingOpInterface.cpp
)
function(add_mlir_interface_library name)
@@ -17,5 +18,5 @@ function(add_mlir_interface_library name)
)
endfunction(add_mlir_interface_library)
add_mlir_interface_library(MaskingInterfaces)
add_mlir_interface_library(MaskableOpInterface)
add_mlir_interface_library(MaskingOpInterface)

View File

@@ -1,4 +1,4 @@
//===- MaskingInterfaces.cpp - Masking interfaces ----------====-*- C++ -*-===//
//===- MaskableOpInterfaces.cpp - MaskableOpInterface Defs -====-*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,13 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
using namespace mlir;
using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// Masking Interfaces
// MaskableOpInterface Defs
//===----------------------------------------------------------------------===//
/// Include the definitions of the masking interfaces.
#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.cpp.inc"

View File

@@ -0,0 +1,18 @@
//===- MaskingOpInterface.cpp - MaskingOpInterface Defs -----====-*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
using namespace mlir;
using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// MaskingOpInterface Defs
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.cpp.inc"

View File

@@ -3250,7 +3250,8 @@ cc_library(
":DialectUtils",
":IR",
":InferTypeOpInterface",
":MaskingInterfaces",
":MaskableOpInterface",
":MaskingOpInterface",
":MemRefDialect",
":SideEffectInterfaces",
":Support",
@@ -8201,8 +8202,15 @@ cc_library(
##---------------------------------------------------------------------------##
td_library(
name = "MaskingInterfacesTdFiles",
srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"],
name = "MaskingOpInterfaceTdFiles",
srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"],
includes = ["include"],
deps = [":OpBaseTdFiles"],
)
td_library(
name = "MaskableOpInterfaceTdFiles",
srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"],
includes = ["include"],
deps = [":OpBaseTdFiles"],
)
@@ -8215,7 +8223,8 @@ td_library(
":ControlFlowInterfacesTdFiles",
":DestinationStyleOpInterfaceTdFiles",
":InferTypeOpInterfaceTdFiles",
":MaskingInterfacesTdFiles",
":MaskableOpInterfaceTdFiles",
":MaskingOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
":VectorInterfacesTdFiles",
@@ -8224,21 +8233,39 @@ td_library(
)
gentbl_cc_library(
name = "MaskingInterfacesIncGen",
name = "MaskableOpInterfaceIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-interface-decls"],
"include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc",
"include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h.inc",
),
(
["-gen-op-interface-defs"],
"include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc",
"include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td",
deps = [":MaskingInterfacesTdFiles"],
td_file = "include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td",
deps = [":MaskableOpInterfaceTdFiles"],
)
gentbl_cc_library(
name = "MaskingOpInterfaceIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-interface-decls"],
"include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h.inc",
),
(
["-gen-op-interface-defs"],
"include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td",
deps = [":MaskingOpInterfaceTdFiles"],
)
gentbl_cc_library(
@@ -8294,13 +8321,27 @@ gentbl_cc_library(
)
cc_library(
name = "MaskingInterfaces",
srcs = ["lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp"],
hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"],
name = "MaskableOpInterface",
srcs = ["lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp"],
hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"],
includes = ["include"],
deps = [
":IR",
":MaskingInterfacesIncGen",
":MaskableOpInterfaceIncGen",
":MaskingOpInterface",
":Support",
"//llvm:Support",
],
)
cc_library(
name = "MaskingOpInterface",
srcs = ["lib/Dialect/Vector/Interfaces/MaskingOpInterface.cpp"],
hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"],
includes = ["include"],
deps = [
":IR",
":MaskingOpInterfaceIncGen",
":Support",
"//llvm:Support",
],