[mlir][tosa] Add custom operand getters for select op (#145921)

The select op has 3 inputs: input1, input2, input3 to according to the
tosa specification. However, use of getInput1(), getInput2() and
getInput3() in the codebase can be confusing and hinder readability.
This commit adds custom getters to help improve readability:
  - input1 -> getPred()
  - input2 -> getOnTrue()
  - input3 -> getOnFalse()

They should be preferred as they are more descriptive, however, the ODS
generated getters (getInputX()) may still be used.

Unfortunately the custom getters don't propagate to Adaptors such as
`FoldAdaptor`, so the ODS generated getters must be used.
This commit is contained in:
Luke Hutton
2025-06-30 10:11:09 +01:00
committed by GitHub
parent 473769ec9b
commit 2e7aa7ead6
5 changed files with 23 additions and 16 deletions

View File

@@ -1490,9 +1490,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
}];
let arguments = (ins
Tosa_I1Tensor:$input1,
Tosa_Tensor:$input2,
Tosa_Tensor:$input3
Tosa_I1Tensor:$input1, // pred
Tosa_Tensor:$input2, // on true
Tosa_Tensor:$input3 // on false
);
let results = (outs
@@ -1512,6 +1512,13 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
`)` `->` type($output)
}];
let extraClassDeclaration = [{
// Custom getters for readability
::mlir::TypedValue<::mlir::TensorType> getPred() { return getInput1(); }
::mlir::TypedValue<::mlir::TensorType> getOnTrue() { return getInput2(); }
::mlir::TypedValue<::mlir::TensorType> getOnFalse() { return getInput3(); }
}];
}
//===----------------------------------------------------------------------===//

View File

@@ -344,7 +344,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
return failure();
rewriter.modifyOpInPlace(op, [&]() {
op.getOperation()->setOperands(
{notOp.getInput1(), op.getInput3(), op.getInput2()});
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
});
return success();
}
@@ -1510,8 +1510,8 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (getInput2() == getInput3())
return getInput2();
if (getOnTrue() == getOnFalse())
return getOnTrue();
auto predicate =
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
@@ -1520,8 +1520,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
if (!predicate.isSplat())
return {};
return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
: getInput3();
return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
: getOnFalse();
}
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {

View File

@@ -3829,16 +3829,16 @@ LogicalResult ReverseOp::verify() {
LogicalResult tosa::SelectOp::verify() {
// verify input2 and input3 have same element type as output
if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
/* outType = */ getOutput().getType())
.failed() ||
verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
/* outType = */ getOutput().getType())
.failed()) {
return failure();
}
// verify input1 has element type of bool
auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
if (!predicateType) {
return emitOpError("expect shaped tensor for input1, got ")
<< getInput1().getType();

View File

@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
PatternRewriter &rewriter) const override {
Value input1 = tosaOp.getInput1();
Value input2 = tosaOp.getInput2();
Value input3 = tosaOp.getInput3();
Value input1 = tosaOp.getPred();
Value input2 = tosaOp.getOnTrue();
Value input3 = tosaOp.getOnFalse();
Value output = tosaOp.getResult();
auto outputType = dyn_cast<RankedTensorType>(output.getType());

View File

@@ -188,8 +188,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getInput2());
addValue(op.getInput3());
addValue(op.getOnTrue());
addValue(op.getOnFalse());
addValue(op.getOutput());
return success();
}