Files
clang-p2996/mlir/lib/Dialect/Tensor/IR/TensorDialect.cpp
Aart Bik 23c9e8bc25 [mlir][tensors] Introduce attribute interface/attribute for tensor encoding
The new "encoding" field in tensor types has had no meaning so far. This revision introduces:

1. an encoding attribute interface in IR, used to verify tensors against their encodings in general;
2. an attribute in the Tensor dialect, #tensor.sparse<dict>, together with a concrete sparse tensor API.

Active discussion:
https://llvm.discourse.group/t/rfc-introduce-a-sparse-tensor-type-to-core-mlir/2944/

Reviewed By: silvas, penpornk, bixia

Differential Revision: https://reviews.llvm.org/D101008
2021-04-26 18:31:54 -07:00

199 lines
7.1 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// TableGen'd Attributes Methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
// Dictionary keys.
//===--- Keys accepted inside the #tensor.sparse<{...}> dictionary --------===//

/// Key for the per-dimension level type array.
static constexpr StringRef getSparseDimLevelTypeAttrName() {
  return StringRef("sparseDimLevelType");
}

/// Key for the dimension ordering affine map.
static constexpr StringRef getSparseDimOrderingAttrName() {
  return StringRef("sparseDimOrdering");
}

/// Key for the pointer bitwidth integer.
static constexpr StringRef getSparsePointerBitWidthAttrName() {
  return StringRef("sparsePointerBitWidth");
}

/// Key for the index bitwidth integer.
static constexpr StringRef getSparseIndexBitWidthAttrName() {
  return StringRef("sparseIndexBitWidth");
}

//===--- String values allowed in the dimension level type array ---------===//

/// Level type value "dense".
static constexpr StringRef getDenseDimLevelTypeVal() {
  return StringRef("dense");
}

/// Level type value "compressed".
static constexpr StringRef getCompressedDimLevelTypeVal() {
  return StringRef("compressed");
}

/// Level type value "singleton".
static constexpr StringRef getSingletonDimLevelTypeVal() {
  return StringRef("singleton");
}
/// Parses the body of a sparse encoding, `<` dictionary-attribute `>`; the
/// `sparse` mnemonic itself is consumed by the dialect-level parser.
/// Returns a null attribute as soon as any of the three pieces fails to parse
/// (reporting of that failure is left to the parser calls themselves).
Attribute SparseTensorEncodingAttr::parse(MLIRContext *context,
                                          DialectAsmParser &parser, Type type) {
  DictionaryAttr dict;
  // Short-circuit evaluation keeps the original order and stops at the first
  // failing step.
  const bool malformed = failed(parser.parseLess()) ||
                         failed(parser.parseAttribute(dict)) ||
                         failed(parser.parseGreater());
  if (malformed)
    return {};
  return SparseTensorEncodingAttr::get(context, dict);
}
/// Prints the attribute as `sparse<{...}>`. The `<{...}>` portion is what
/// SparseTensorEncodingAttr::parse() reads back.
void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
  auto dict = getDict();
  printer << "sparse<" << dict << ">";
}
/// Verifies that this sparse encoding is consistent with a tensor of the
/// given `shape`. `elementType` is currently not inspected. Every dictionary
/// entry must use one of the recognized keys:
///   - sparseDimLevelType: an array of `shape.size()` strings, each one of
///     "dense", "compressed", or "singleton";
///   - sparseDimOrdering: a permutation affine map with `shape.size()`
///     results;
///   - sparsePointerBitWidth / sparseIndexBitWidth: a signless (or index)
///     integer equal to 0, 8, 16, 32, or 64.
/// Any other key is rejected. Failures are reported through `emitError`.
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    llvm::ArrayRef<int64_t> shape, Type elementType,
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
  unsigned size = shape.size();
  for (const NamedAttribute &attr : getDict()) {
    if (attr.first == getSparseDimLevelTypeAttrName()) {
      // Dimension level type verification: exactly one entry per dimension.
      auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
      // Both sides are unsigned sizes, so no signed cast is needed.
      if (!arrayAttr || arrayAttr.size() != size)
        return emitError() << "expected an array of size " << size
                           << " for dimension level types";
      for (Attribute levelAttr : arrayAttr.getValue()) {
        auto strAttr = levelAttr.dyn_cast<StringAttr>();
        if (!strAttr)
          return emitError()
                 << "expected string value in dimension level types";
        StringRef strVal = strAttr.getValue();
        if (strVal != getDenseDimLevelTypeVal() &&
            strVal != getCompressedDimLevelTypeVal() &&
            strVal != getSingletonDimLevelTypeVal())
          return emitError() << "unexpected dimension level type: " << strAttr;
      }
    } else if (attr.first == getSparseDimOrderingAttrName()) {
      // Dimension order verification.
      auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
      if (!affineAttr)
        return emitError() << "expected an affine map for dimension ordering";
      AffineMap map = affineAttr.getValue();
      if (size != map.getNumResults() || !map.isPermutation())
        return emitError() << "expected a permutation affine map of size "
                           << size << " for dimension ordering";
    } else if (attr.first == getSparsePointerBitWidthAttrName() ||
               attr.first == getSparseIndexBitWidthAttrName()) {
      // Pointer or index bitwidth verification.
      auto intAttr = attr.second.dyn_cast<IntegerAttr>();
      // IntegerAttr::getInt() is used below and by getPointerBitWidth() and
      // getIndexBitWidth(). It requires a signless integer or index type, so
      // explicitly signed/unsigned integers (e.g. `32 : ui32`) are rejected
      // here with a diagnostic instead of tripping an assertion later.
      if (!intAttr || !intAttr.getType().isSignlessIntOrIndex())
        return emitError() << "expected an integral bitwidth";
      int64_t width = intAttr.getInt();
      switch (width) {
      case 0:
      case 8:
      case 16:
      case 32:
      case 64:
        continue;
      default:
        return emitError() << "unexpected bitwidth: " << width;
      }
    } else {
      return emitError() << "unexpected key: " << attr.first.str();
    }
  }
  return success();
}
/// Returns the level type of dimension `dim`. When the dictionary has no
/// sparseDimLevelType entry, every dimension is reported as Dense. Any string
/// other than "compressed" or "singleton" also maps to Dense.
SparseTensorEncodingAttr::DimLevelType
SparseTensorEncodingAttr::getDimLevelType(unsigned dim) const {
  Attribute value = getDict().get(getSparseDimLevelTypeAttrName());
  if (!value)
    return DimLevelType::Dense;
  // For an encoding that passed verifyEncoding(), this entry is an ArrayAttr
  // of StringAttrs. cast<> asserts that invariant; the previous dyn_cast<>
  // result was indexed without a null check.
  StringRef strVal =
      value.cast<ArrayAttr>()[dim].cast<StringAttr>().getValue();
  if (strVal == getCompressedDimLevelTypeVal())
    return DimLevelType::Compressed;
  if (strVal == getSingletonDimLevelTypeVal())
    return DimLevelType::Singleton;
  return DimLevelType::Dense;
}
/// Returns the sparseDimOrdering affine map, or a null AffineMap when the
/// dictionary has no such entry.
AffineMap SparseTensorEncodingAttr::getDimOrdering() const {
  Attribute ordering = getDict().get(getSparseDimOrderingAttrName());
  if (!ordering)
    return AffineMap();
  return ordering.cast<AffineMapAttr>().getValue();
}
unsigned SparseTensorEncodingAttr::getPointerBitWidth() const {
if (auto value = getDict().get(getSparsePointerBitWidthAttrName()))
return value.cast<IntegerAttr>().getInt();
return 0;
}
unsigned SparseTensorEncodingAttr::getIndexBitWidth() const {
if (auto value = getDict().get(getSparseIndexBitWidthAttrName()))
return value.cast<IntegerAttr>().getInt();
return 0;
}
//===----------------------------------------------------------------------===//
// TensorDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
/// Inliner hooks for the tensor dialect. Both hooks unconditionally report
/// inlining as legal, regardless of their arguments.
struct TensorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// Any single operation may be inlined into any destination region.
  bool isLegalToInline(Operation * /*op*/, Region * /*dest*/,
                       bool /*wouldBeCloned*/,
                       BlockAndValueMapping & /*mapping*/) const final {
    return true;
  }

  /// Any source region may be inlined into any destination region.
  bool isLegalToInline(Region * /*dest*/, Region * /*src*/,
                       bool /*wouldBeCloned*/,
                       BlockAndValueMapping & /*valueMapping*/) const final {
    return true;
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TensorDialect Methods
//===----------------------------------------------------------------------===//
// Registers everything the tensor dialect provides: the TableGen-generated
// attribute classes (TensorAttrDefs.cpp.inc), the TableGen-generated
// operations (TensorOps.cpp.inc), and the TensorInlinerInterface defined in
// this file.
void TensorDialect::initialize() {
// The attribute list is spliced in from TableGen output via GET_ATTRDEF_LIST.
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
>();
// The operation list is spliced in from TableGen output via GET_OP_LIST.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
>();
addInterfaces<TensorInlinerInterface>();
}
/// Parses a tensor-dialect attribute. A leading keyword (the attribute
/// mnemonic) is read first and dispatched to the TableGen-generated parser.
/// If no generated attribute claims the mnemonic, "unknown tensor attribute"
/// is emitted. Returns a null attribute on any failure.
Attribute TensorDialect::parseAttribute(DialectAsmParser &parser,
                                        Type type) const {
  StringRef mnemonic;
  if (failed(parser.parseKeyword(&mnemonic)))
    return {};

  Attribute result;
  auto dispatched =
      generatedAttributeParser(getContext(), parser, mnemonic, type, result);
  if (!dispatched.hasValue()) {
    parser.emitError(parser.getNameLoc(), "unknown tensor attribute");
    return {};
  }
  // The mnemonic was recognized. `result` holds the parsed attribute, or
  // stays null if the generated parser reported a failure.
  return result;
}
/// Prints a tensor-dialect attribute by delegating to the TableGen-generated
/// printer. If the printer does not handle `attr`, nothing is printed.
void TensorDialect::printAttribute(::mlir::Attribute attr,
                                   ::mlir::DialectAsmPrinter &printer) const {
  (void)generatedAttributePrinter(attr, printer);
}