[mlir][emitc] Add comparison operation
This adds a comparison operation to EmitC which supports ==, !=, <=, <, >=, >, <=>. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D158180
This commit is contained in:
committed by
Marius Brehler
parent
54784b1831
commit
adea7e7032
@@ -2,6 +2,8 @@ add_mlir_dialect(EmitC emitc)
|
||||
add_mlir_doc(EmitC EmitC Dialects/ -gen-dialect-doc)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS EmitCAttributes.td)
|
||||
mlir_tablegen(EmitCEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(EmitCEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(EmitCAttributes.h.inc -gen-attrdef-decls)
|
||||
mlir_tablegen(EmitCAttributes.cpp.inc -gen-attrdef-defs)
|
||||
add_public_tablegen_target(MLIREmitCAttributesIncGen)
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
|
||||
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"
|
||||
|
||||
@@ -27,8 +27,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
class EmitC_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<EmitC_Dialect, mnemonic, traits>;
|
||||
|
||||
// Base class for binary arithmetic operations.
|
||||
class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
|
||||
// Base class for binary operations.
|
||||
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
|
||||
EmitC_Op<mnemonic, traits> {
|
||||
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
|
||||
let results = (outs AnyType);
|
||||
@@ -39,7 +39,7 @@ class EmitC_BinaryArithOp<string mnemonic, list<Trait> traits = []> :
|
||||
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
|
||||
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
|
||||
|
||||
def EmitC_AddOp : EmitC_BinaryArithOp<"add", []> {
|
||||
def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
|
||||
let summary = "Addition operation";
|
||||
let description = [{
|
||||
With the `add` operation the arithmetic operator + (addition) can
|
||||
@@ -150,6 +150,37 @@ def EmitC_CastOp : EmitC_Op<"cast", [
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
|
||||
}
|
||||
|
||||
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
|
||||
let summary = "Comparison operation";
|
||||
let description = [{
|
||||
With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
|
||||
can be applied.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// Custom form of the cmp operation.
|
||||
%0 = emitc.cmp eq, %arg0, %arg1 : (i32, i32) -> i1
|
||||
%1 = emitc.cmp lt, %arg2, %arg3 :
|
||||
(
|
||||
!emitc.opaque<"std::valarray<float>">,
|
||||
!emitc.opaque<"std::valarray<float>">
|
||||
) -> !emitc.opaque<"std::valarray<bool>">
|
||||
```
|
||||
```c++
|
||||
// Code emitted for the operations above.
|
||||
bool v5 = v1 == v2;
|
||||
std::valarray<bool> v6 = v3 < v4;
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
|
||||
AnyType:$lhs,
|
||||
AnyType:$rhs);
|
||||
let results = (outs AnyType);
|
||||
|
||||
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
|
||||
let summary = "Constant operation";
|
||||
let description = [{
|
||||
@@ -180,7 +211,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def EmitC_DivOp : EmitC_BinaryArithOp<"div", []> {
|
||||
def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
|
||||
let summary = "Division operation";
|
||||
let description = [{
|
||||
With the `div` operation the arithmetic operator / (division) can
|
||||
@@ -248,7 +279,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
|
||||
let assemblyFormat = "$value attr-dict `:` type($result)";
|
||||
}
|
||||
|
||||
def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
|
||||
def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
|
||||
let summary = "Multiplication operation";
|
||||
let description = [{
|
||||
With the `mul` operation the arithmetic operator * (multiplication) can
|
||||
@@ -272,7 +303,7 @@ def EmitC_MulOp : EmitC_BinaryArithOp<"mul", []> {
|
||||
let results = (outs FloatIntegerIndexOrOpaqueType);
|
||||
}
|
||||
|
||||
def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
|
||||
def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
|
||||
let summary = "Remainder operation";
|
||||
let description = [{
|
||||
With the `rem` operation the arithmetic operator % (remainder) can
|
||||
@@ -294,7 +325,7 @@ def EmitC_RemOp : EmitC_BinaryArithOp<"rem", []> {
|
||||
let results = (outs IntegerIndexOrOpaqueType);
|
||||
}
|
||||
|
||||
def EmitC_SubOp : EmitC_BinaryArithOp<"sub", []> {
|
||||
def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
|
||||
let summary = "Subtraction operation";
|
||||
let description = [{
|
||||
With the `sub` operation the arithmetic operator - (subtraction) can
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -26,6 +27,20 @@ class EmitC_Attr<string name, string attrMnemonic, list<Trait> traits = []>
|
||||
let mnemonic = attrMnemonic;
|
||||
}
|
||||
|
||||
def EmitC_CmpPredicateAttr : I64EnumAttr<
|
||||
"CmpPredicate", "",
|
||||
[
|
||||
I64EnumAttrCase<"eq", 0>,
|
||||
I64EnumAttrCase<"ne", 1>,
|
||||
I64EnumAttrCase<"lt", 2>,
|
||||
I64EnumAttrCase<"le", 3>,
|
||||
I64EnumAttrCase<"gt", 4>,
|
||||
I64EnumAttrCase<"ge", 5>,
|
||||
I64EnumAttrCase<"three_way", 6>,
|
||||
]> {
|
||||
let cppNamespace = "::mlir::emitc";
|
||||
}
|
||||
|
||||
def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
|
||||
let summary = "An opaque attribute";
|
||||
|
||||
|
||||
@@ -257,6 +257,12 @@ LogicalResult emitc::VariableOp::verify() {
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EmitC Enums
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EmitC Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -246,15 +246,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
|
||||
return printConstantOp(emitter, operation, value);
|
||||
}
|
||||
|
||||
static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
|
||||
Operation *operation,
|
||||
StringRef binaryArithOperator) {
|
||||
static LogicalResult printBinaryOperation(CppEmitter &emitter,
|
||||
Operation *operation,
|
||||
StringRef binaryOperator) {
|
||||
raw_ostream &os = emitter.ostream();
|
||||
|
||||
if (failed(emitter.emitAssignPrefix(*operation)))
|
||||
return failure();
|
||||
os << emitter.getOrCreateName(operation->getOperand(0));
|
||||
os << " " << binaryArithOperator;
|
||||
os << " " << binaryOperator;
|
||||
os << " " << emitter.getOrCreateName(operation->getOperand(1));
|
||||
|
||||
return success();
|
||||
@@ -263,31 +263,65 @@ static LogicalResult printBinaryArithOperation(CppEmitter &emitter,
|
||||
static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
|
||||
Operation *operation = addOp.getOperation();
|
||||
|
||||
return printBinaryArithOperation(emitter, operation, "+");
|
||||
return printBinaryOperation(emitter, operation, "+");
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
|
||||
Operation *operation = divOp.getOperation();
|
||||
|
||||
return printBinaryArithOperation(emitter, operation, "/");
|
||||
return printBinaryOperation(emitter, operation, "/");
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
|
||||
Operation *operation = mulOp.getOperation();
|
||||
|
||||
return printBinaryArithOperation(emitter, operation, "*");
|
||||
return printBinaryOperation(emitter, operation, "*");
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
|
||||
Operation *operation = remOp.getOperation();
|
||||
|
||||
return printBinaryArithOperation(emitter, operation, "%");
|
||||
return printBinaryOperation(emitter, operation, "%");
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
|
||||
Operation *operation = subOp.getOperation();
|
||||
|
||||
return printBinaryArithOperation(emitter, operation, "-");
|
||||
return printBinaryOperation(emitter, operation, "-");
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
|
||||
Operation *operation = cmpOp.getOperation();
|
||||
|
||||
StringRef binaryOperator;
|
||||
|
||||
switch (cmpOp.getPredicate()) {
|
||||
case emitc::CmpPredicate::eq:
|
||||
binaryOperator = "==";
|
||||
break;
|
||||
case emitc::CmpPredicate::ne:
|
||||
binaryOperator = "!=";
|
||||
break;
|
||||
case emitc::CmpPredicate::lt:
|
||||
binaryOperator = "<";
|
||||
break;
|
||||
case emitc::CmpPredicate::le:
|
||||
binaryOperator = "<=";
|
||||
break;
|
||||
case emitc::CmpPredicate::gt:
|
||||
binaryOperator = ">";
|
||||
break;
|
||||
case emitc::CmpPredicate::ge:
|
||||
binaryOperator = ">=";
|
||||
break;
|
||||
case emitc::CmpPredicate::three_way:
|
||||
binaryOperator = "<=>";
|
||||
break;
|
||||
default:
|
||||
return cmpOp.emitError("unhandled comparison predicate");
|
||||
}
|
||||
|
||||
return printBinaryOperation(emitter, operation, binaryOperator);
|
||||
}
|
||||
|
||||
static LogicalResult printOperation(CppEmitter &emitter,
|
||||
@@ -977,8 +1011,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
|
||||
[&](auto op) { return printOperation(*this, op); })
|
||||
// EmitC ops.
|
||||
.Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
|
||||
emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp, emitc::MulOp,
|
||||
emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
|
||||
emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp,
|
||||
emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
|
||||
[&](auto op) { return printOperation(*this, op); })
|
||||
// Func ops.
|
||||
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
|
||||
|
||||
@@ -79,3 +79,21 @@ func.func @sub_pointer(%arg0: !emitc.ptr<f32>, %arg1: i32, %arg2: !emitc.opaque<
|
||||
%4 = "emitc.sub" (%arg0, %arg3) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> i32
|
||||
return
|
||||
}
|
||||
|
||||
func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
|
||||
%1 = "emitc.cmp" (%arg0, %arg0) {predicate = 0} : (i32, i32) -> i1
|
||||
%2 = emitc.cmp eq, %arg0, %arg0 : (i32, i32) -> i1
|
||||
%3 = "emitc.cmp" (%arg1, %arg1) {predicate = 1} : (f32, f32) -> i1
|
||||
%4 = emitc.cmp ne, %arg1, %arg1 : (f32, f32) -> i1
|
||||
%5 = "emitc.cmp" (%arg2, %arg2) {predicate = 2} : (i64, i64) -> i1
|
||||
%6 = emitc.cmp lt, %arg2, %arg2 : (i64, i64) -> i1
|
||||
%7 = "emitc.cmp" (%arg3, %arg3) {predicate = 3} : (f64, f64) -> i1
|
||||
%8 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
|
||||
%9 = "emitc.cmp" (%arg4, %arg4) {predicate = 4} : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
|
||||
%10 = emitc.cmp gt, %arg4, %arg4 : (!emitc.opaque<"unsigned">, !emitc.opaque<"unsigned">) -> i1
|
||||
%11 = "emitc.cmp" (%arg5, %arg5) {predicate = 5} : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
|
||||
%12 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
|
||||
%13 = "emitc.cmp" (%arg6, %arg6) {predicate = 6} : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
|
||||
%14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
|
||||
return
|
||||
}
|
||||
|
||||
21
mlir/test/Target/Cpp/comparison_operators.mlir
Normal file
21
mlir/test/Target/Cpp/comparison_operators.mlir
Normal file
@@ -0,0 +1,21 @@
|
||||
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
|
||||
|
||||
func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emitc.opaque<"unsigned">, %arg5 : !emitc.opaque<"std::valarray<int>">, %arg6 : !emitc.opaque<"custom">) {
|
||||
%1 = emitc.cmp eq, %arg0, %arg2 : (i32, i64) -> i1
|
||||
%2 = emitc.cmp ne, %arg1, %arg3 : (f32, f64) -> i1
|
||||
%3 = emitc.cmp lt, %arg2, %arg4 : (i64, !emitc.opaque<"unsigned">) -> !emitc.opaque<"int">
|
||||
%4 = emitc.cmp le, %arg3, %arg3 : (f64, f64) -> i1
|
||||
%5 = emitc.cmp gt, %arg6, %arg4 : (!emitc.opaque<"custom">, !emitc.opaque<"unsigned">) -> !emitc.opaque<"custom">
|
||||
%6 = emitc.cmp ge, %arg5, %arg5 : (!emitc.opaque<"std::valarray<int>">, !emitc.opaque<"std::valarray<int>">) -> !emitc.opaque<"std::valarray<bool>">
|
||||
%7 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
|
||||
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: void cmp
|
||||
// CHECK-NEXT: bool [[V7:[^ ]*]] = [[V0:[^ ]*]] == [[V2:[^ ]*]];
|
||||
// CHECK-NEXT: bool [[V8:[^ ]*]] = [[V1:[^ ]*]] != [[V3:[^ ]*]];
|
||||
// CHECK-NEXT: int [[V9:[^ ]*]] = [[V2]] < [[V4:[^ ]*]];
|
||||
// CHECK-NEXT: bool [[V10:[^ ]*]] = [[V3]] <= [[V3]];
|
||||
// CHECK-NEXT: custom [[V11:[^ ]*]] = [[V6:[^ ]*]] > [[V4]];
|
||||
// CHECK-NEXT: std::valarray<bool> [[V12:[^ ]*]] = [[V5:[^ ]*]] >= [[V5]];
|
||||
// CHECK-NEXT: custom [[V13:[^ ]*]] = [[V6]] <=> [[V6]];
|
||||
Reference in New Issue
Block a user