[mlir][emitc] Add comparison operation

This adds a comparison operation to EmitC which supports ==, !=, <=, <, >=, >, <=>.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D158180
This commit is contained in:
Simon Camphausen
2023-08-29 14:45:35 +00:00
committed by Marius Brehler
parent 54784b1831
commit adea7e7032
8 changed files with 146 additions and 18 deletions

View File

@@ -2,6 +2,8 @@ add_mlir_dialect(EmitC emitc)
add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIREmitCAttributesIncGen)

View File

@@ -21,6 +21,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"

View File

@@ -27,8 +27,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class EmitC_Op<string mnemonic, list<Trait> traits = []>
: Op<EmitC_Dialect, mnemonic, traits>;
// Base class for binary arithmetic operations.
class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
// Base class for binary operations.
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let results = (outs AnyType);
@@ -39,7 +39,7 @@ class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
let summary = "Addition operation";
let description = [{
With the `add` operation the arithmetic operator + (addition) can
@@ -150,6 +150,37 @@ def EmitC_CastOp : EmitC_Op<"cast", [
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
let summary = "Comparison operation";
let description = [{
With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
can be applied.
Example:
```mlir
// Custom form of the cmp operation.
%0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1
%1 = emitc.cmp lt, %arg2, %arg3 :
(
!emitc.opaque<"std::valarray<float>">,
!emitc.opaque<"std::valarray<float>">
) -> !emitc.opaque<"std::valarray<bool>">
```
```c++
// Code emitted for the operations above.
bool v5 = v1 == v2;
std::valarray<bool> v6 = v3 < v4;
```
}];
let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
AnyType:$lhs,
AnyType:$rhs);
let results = (outs AnyType);
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
}
def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let summary = "Constant operation";
let description = [{
@@ -180,7 +211,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let hasVerifier = 1;
}
def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> {
def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let summary = "Division operation";
let description = [{
With the `div` operation the arithmetic operator / (division) can
@@ -248,7 +279,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
let assemblyFormat = "$value attr-dict `:` type($result)";
}
def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
let summary = "Multiplication operation";
let description = [{
With the `mul` operation the arithmetic operator * (multiplication) can
@@ -272,7 +303,7 @@ def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}
def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
let summary = "Remainder operation";
let description = [{
With the `rem` operation the arithmetic operator % (remainder) can
@@ -294,7 +325,7 @@ def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
let results = (outs IntegerIndexOrOpaqueType);
}
def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
let summary = "Subtraction operation";
let description = [{
With the `sub` operation the arithmetic operator - (subtraction) can

View File

@@ -15,6 +15,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
//===----------------------------------------------------------------------===//
@@ -26,6 +27,20 @@ class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let mnemonic = attrMnemonic;
}
def EmitC_CmpPredicateAttr : I64EnumAttr<
"CmpPredicate", "",
[
I64EnumAttrCase<"eq", 0>,
I64EnumAttrCase<"ne", 1>,
I64EnumAttrCase<"lt", 2>,
I64EnumAttrCase<"le", 3>,
I64EnumAttrCase<"gt", 4>,
I64EnumAttrCase<"ge", 5>,
I64EnumAttrCase<"three_way", 6>,
]> {
let cppNamespace = "::mlir::emitc";
}
def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
let summary = "An opaque attribute";

View File

@@ -257,6 +257,12 @@ LogicalResult emitc::VariableOp::verify() {
#define GET_OP_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
//===----------------------------------------------------------------------===//
// EmitC Enums
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
//===----------------------------------------------------------------------===//
// EmitC Attributes
//===----------------------------------------------------------------------===//

View File

@@ -246,15 +246,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryArithOperator) {
static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
raw_ostream &os = emitter.ostream();
if (failed(emitter.emitAssignPrefix(*operation)))
return failure();
os << emitter.getOrCreateName(operation->getOperand(0));
os << " " << binaryArithOperator;
os << " " << binaryOperator;
os << " " << emitter.getOrCreateName(operation->getOperand(1));
return success();
@@ -263,31 +263,65 @@ static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
Operation *operation = addOp.getOperation();
return printBinaryArithOperation(emitter, operation, "+");
return printBinaryOperation(emitter, operation, "+");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
Operation *operation = divOp.getOperation();
return printBinaryArithOperation(emitter, operation, "/");
return printBinaryOperation(emitter, operation, "/");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
Operation *operation = mulOp.getOperation();
return printBinaryArithOperation(emitter, operation, "*");
return printBinaryOperation(emitter, operation, "*");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
Operation *operation = remOp.getOperation();
return printBinaryArithOperation(emitter, operation, "%");
return printBinaryOperation(emitter, operation, "%");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
Operation *operation = subOp.getOperation();
return printBinaryArithOperation(emitter, operation, "-");
return printBinaryOperation(emitter, operation, "-");
}
static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
Operation *operation = cmpOp.getOperation();
StringRef binaryOperator;
switch (cmpOp.getPredicate()) {
case emitc::CmpPredicate::eq:
binaryOperator = "==";
break;
case emitc::CmpPredicate::ne:
binaryOperator = "!=";
break;
case emitc::CmpPredicate::lt:
binaryOperator = "<";
break;
case emitc::CmpPredicate::le:
binaryOperator = "<=";
break;
case emitc::CmpPredicate::gt:
binaryOperator = ">";
break;
case emitc::CmpPredicate::ge:
binaryOperator = ">=";
break;
case emitc::CmpPredicate::three_way:
binaryOperator = "<=>";
break;
default:
return cmpOp.emitError("unhandled comparison predicate");
}
return printBinaryOperation(emitter, operation, binaryOperator);
}
static LogicalResult printOperation(CppEmitter &emitter,
@@ -977,8 +1011,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, emitc::MulOp,
emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp,
emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(

View File

@@ -79,3 +79,21 @@ func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<
%4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
return
}
func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
%1 = "emitc.cmp" (%arg0, %arg0) {predicate = 0} : (i32, i32) -> i1
%2 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1
%3 = "emitc.cmp" (%arg1, %arg1) {predicate = 1} : (f32, f32) -> i1
%4 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1
%5 = "emitc.cmp" (%arg2, %arg2) {predicate = 2} : (i64, i64) -> i1
%6 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1
%7 = "emitc.cmp" (%arg3, %arg3) {predicate = 3} : (f64, f64) -> i1
%8 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
%9 = "emitc.cmp" (%arg4, %arg4) {predicate = 4} : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
%10 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
%11 = "emitc.cmp" (%arg5, %arg5) {predicate = 5} : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
%12 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
%13 = "emitc.cmp" (%arg6, %arg6) {predicate = 6} : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
%14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
return
}

View File

@@ -0,0 +1,21 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
%1 = emitc.cmp eq, %arg0, %arg2 : (i32, i64) -> i1
%2 = emitc.cmp ne, %arg1, %arg3 : (f32, f64) -> i1
%3 = emitc.cmp lt, %arg2, %arg4 : (i64, !emitc.opaque<"unsigned">) -> !emitc.opaque<"int">
%4 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
%5 = emitc.cmp gt, %arg6, %arg4 : (!emitc.opaque<"custom">, !emitc.opaque<"unsigned">) -> !emitc.opaque<"custom">
%6 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
%7 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
return
}
// CHECK-LABEL: void cmp
// CHECK-NEXT: bool [[V7:[^ ]*]] = [[V0:[^ ]*]] == [[V2:[^ ]*]];
// CHECK-NEXT: bool [[V8:[^ ]*]] = [[V1:[^ ]*]] != [[V3:[^ ]*]];
// CHECK-NEXT: int [[V9:[^ ]*]] = [[V2]] < [[V4:[^ ]*]];
// CHECK-NEXT: bool [[V10:[^ ]*]] = [[V3]] <= [[V3]];
// CHECK-NEXT: custom [[V11:[^ ]*]] = [[V6:[^ ]*]] > [[V4]];
// CHECK-NEXT: std::valarray<bool> [[V12:[^ ]*]] = [[V5:[^ ]*]] >= [[V5]];
// CHECK-NEXT: custom [[V13:[^ ]*]] = [[V6]] <=> [[V6]];