Files
clang-p2996/mlir/lib/Dialect/Shape/IR/Shape.cpp
Sean Silva 569e4f9bc9 shape dialect: add some ops
- add `to_extent_tensor`
- rename `create_shape` to `from_extent_tensor` for symmetry
- add `split_at` and `concat` ops for basic shape manipulations

This set of ops is inspired by the requirements of lowering a dynamic-shape-aware batch matmul op. For such an op, the "matrix" dimensions aren't subject to broadcasting but the others are, so we need to slice, broadcast, and reconstruct the final output shape. Furthermore, the op that actually performs the broadcasting downstream takes a tensor of extents as its preferred shape interface.

However, this functionality is quite general. It's obvious that `to_extent_tensor` is needed long-term to support many common patterns that involve computations on shapes. We can evolve the shape manipulation ops introduced here. The specific choices made here took into consideration the potentially unranked nature of the !shape.shape type, which means that a simple listing of dimensions to extract isn't possible in general.

Differential Revision: https://reviews.llvm.org/D76817
2020-03-27 16:38:42 -07:00

144 lines
4.5 KiB
C++

//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::shape;
/// Constructs the shape dialect: registers its operations and its types
/// (component, element, shape, size, value_shape) with `context`.
ShapeDialect::ShapeDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The op list comes from the generated ShapeOps.cpp.inc; defining
  // GET_OP_LIST presumably selects the comma-separated list of op classes
  // (standard ODS convention).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving, this makes it simple to start with unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
}
/// Parses one of this dialect's keyword-only types: `component`, `element`,
/// `shape`, `size` or `value_shape`. If the keyword cannot be read, a null
/// Type is returned. If the keyword is not one of the above, an
/// "unknown shape type" error is emitted at the parser's name location and a
/// null Type is returned.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
  StringRef typeName;
  if (failed(parser.parseKeyword(&typeName)))
    return Type();

  MLIRContext *ctx = getContext();
  if (typeName == "component")
    return ComponentType::get(ctx);
  if (typeName == "element")
    return ElementType::get(ctx);
  if (typeName == "shape")
    return ShapeType::get(ctx);
  if (typeName == "size")
    return SizeType::get(ctx);
  if (typeName == "value_shape")
    return ValueShapeType::get(ctx);

  // None of the known keywords matched.
  parser.emitError(parser.getNameLoc(), "unknown shape type: ") << typeName;
  return Type();
}
/// Prints a type owned by this dialect as its bare keyword. The keywords match
/// the ones accepted by ShapeDialect::parseType. Any other type kind is a
/// programming error.
void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
  StringRef keyword;
  switch (type.getKind()) {
  case ShapeTypes::Component:
    keyword = "component";
    break;
  case ShapeTypes::Element:
    keyword = "element";
    break;
  case ShapeTypes::Size:
    keyword = "size";
    break;
  case ShapeTypes::Shape:
    keyword = "shape";
    break;
  case ShapeTypes::ValueShape:
    keyword = "value_shape";
    break;
  default:
    llvm_unreachable("unexpected 'shape' type kind");
  }
  os << keyword;
}
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
/// Prints a shape.constant op in the custom form
///   shape.constant [attr-dict] <value> : <result-type>
/// The `value` attribute is elided from the attribute dictionary and printed
/// without its type.
static void print(OpAsmPrinter &p, ConstantOp &op) {
  auto attrs = op.getAttrs();
  p << "shape.constant ";
  p.printOptionalAttrDict(attrs, /*elidedAttrs=*/{"value"});
  // Assumes `value` is always present (it is printed unconditionally below).
  // More than one attribute therefore means the dictionary was printed and
  // needs a separator before the value.
  if (attrs.size() > 1)
    p << ' ';
  p.printAttributeWithoutType(op.value());
  p << " : " << op.getType();
}
/// Parses the custom form of shape.constant. The form consists of an optional
/// attribute dictionary, the `value` attribute (typed as i64 when it carries
/// no explicit type), then `:` and the result type. The result type is added
/// to `result.types`.
static ParseResult parseConstantOp(OpAsmParser &parser,
                                   OperationState &result) {
  Builder &builder = parser.getBuilder();
  Type defaultValueType = builder.getIntegerType(64);
  Attribute valueAttr;
  Type resultType;
  // Each step short-circuits on the first parse failure.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseAttribute(valueAttr, defaultValueType, "value",
                            result.attributes) ||
      parser.parseColonType(resultType))
    return failure();
  return parser.addTypeToList(resultType, result.types);
}
/// Verifier for shape.constant. No additional checks are performed here, so
/// every ConstantOp that reaches this hook is accepted.
static LogicalResult verify(ConstantOp &op) {
  return success();
}
//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//
/// Infers the result types of shape.split_at. The op always produces exactly
/// two results, both of the dialect's ShapeType. Operands, attributes and
/// regions are not consulted. Inference always succeeds.
LogicalResult SplitAtOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.append(/*NumInputs=*/2, ShapeType::get(context));
  return success();
}
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
/// Infers the result type of shape.concat. The op produces a single result of
/// the dialect's ShapeType, regardless of operands, attributes or regions.
/// Inference always succeeds.
LogicalResult ConcatOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.push_back(ShapeType::get(context));
  return success();
}
// Pull in the generated op class definitions inside mlir::shape.
// NOTE(review): this include is placed after the file-static print/parse/verify
// helpers, which the generated code presumably calls (e.g. parseConstantOp).
// Keep it at the end of the file.
namespace mlir {
namespace shape {
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
} // namespace shape
} // namespace mlir