LLVM itself is generally moving away from using `undef` and towards
using `poison`, to the point of having a lint that catches new uses of
`undef` in tests.
In order to not trip the lint on new patterns and to conform to the
evolution of LLVM:
- Rename various ::undef() methods on StructBuilder subclasses to
::poison()
- Audit the uses of UndefOp in the MLIR libraries and replace almost all
of them with PoisonOp
The remaining uses of `undef` are initializing `uninitialized` memrefs,
explicit conversions to undef from SPIR-V, and a few cases in
AMDGPUToROCDL where usage like
%v = insertelement <M x iN> undef, iN %v, i32 0
%arg = bitcast <M x iN> %v to i(M * N)
is used to handle "i32" arguments that are really packed vectors of
smaller types that won't always be fully initialized.
359 lines
14 KiB
C++
359 lines
14 KiB
C++
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Utils/CodegenUtils.h"
|
|
|
|
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
|
|
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
|
|
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
using namespace sparse_tensor;
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helper methods.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Collects the field types of the lowered storage specifier struct, in
/// order: the level-size array, the memory-size array and, only when the
/// encoding is a slice, the dimension-offset and dimension-stride arrays.
/// Every array holds 64-bit integers.
static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
  MLIRContext *ctx = tp.getContext();
  auto enc = tp.getEncoding();
  const Level lvlRank = enc.getLvlRank();

  // TODO: how can we get the lowering type for index type in the later pipeline
  // to be consistent? LLVM::StructureType does not allow index fields.
  auto i64Type = IntegerType::get(ctx, 64);

  // One i64 slot per level; this same array type is shared by the level
  // sizes and by the slice offsets/strides.
  auto perLevelArrayType = LLVM::LLVMArrayType::get(ctx, i64Type, lvlRank);
  // One i64 slot per data field of the encoding.
  auto memSizesArrayType = LLVM::LLVMArrayType::get(
      ctx, i64Type, getNumDataFieldsFromEncoding(enc));

  SmallVector<Type, 4> fields;
  fields.push_back(perLevelArrayType); // level sizes
  fields.push_back(memSizesArrayType); // memory sizes

  // Extra fields are required for the slice information.
  if (enc.isSlice()) {
    fields.push_back(perLevelArrayType); // dimension offsets
    fields.push_back(perLevelArrayType); // dimension strides
  }

  return fields;
}
|
|
|
|
/// Lowers a storage specifier type to a literal LLVM struct whose members are
/// the types returned by getSpecifierFields.
static Type convertSpecifier(StorageSpecifierType tp) {
  SmallVector<Type, 4> fields = getSpecifierFields(tp);
  return LLVM::LLVMStructType::getLiteral(tp.getContext(), fields);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Specifier struct builder.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Positions of the member arrays inside the lowered specifier struct. They
// mirror the order in which getSpecifierFields pushes the field types.
// Position of the level-size array.
constexpr uint64_t kLvlSizePosInSpecifier = 0;
|
|
// Position of the memory-size array.
constexpr uint64_t kMemSizePosInSpecifier = 1;
|
|
// Position of the dimension-offset array (only pushed for slice encodings).
constexpr uint64_t kDimOffsetPosInSpecifier = 2;
|
|
// Position of the dimension-stride array (only pushed for slice encodings).
constexpr uint64_t kDimStridePosInSpecifier = 3;
|
|
|
|
class SpecifierStructBuilder : public StructBuilder {
|
|
private:
|
|
Value extractField(OpBuilder &builder, Location loc,
|
|
ArrayRef<int64_t> indices) const {
|
|
return genCast(builder, loc,
|
|
builder.create<LLVM::ExtractValueOp>(loc, value, indices),
|
|
builder.getIndexType());
|
|
}
|
|
|
|
void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
|
|
Value v) {
|
|
value = builder.create<LLVM::InsertValueOp>(
|
|
loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
|
|
indices);
|
|
}
|
|
|
|
public:
|
|
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
|
|
assert(value);
|
|
}
|
|
|
|
// Undef value for dimension sizes, all zero value for memory sizes.
|
|
static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
|
|
Value source);
|
|
|
|
Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
|
|
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
|
|
|
|
Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
|
|
void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
|
|
Value size);
|
|
|
|
Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
|
|
void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
|
|
Value size);
|
|
|
|
Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
|
|
void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
|
|
Value size);
|
|
|
|
Value memSizeArray(OpBuilder &builder, Location loc) const;
|
|
void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
|
|
};
|
|
|
|
/// Builds the initial specifier struct of type `structType`. The struct
/// starts out as poison. When `source` is given, its memory-size array is
/// copied over; otherwise every memory size is set to zero. All other fields
/// are left as poison.
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
                                           Type structType, Value source) {
  Value poisonStruct = builder.create<LLVM::PoisonOp>(loc, structType);
  SpecifierStructBuilder md(poisonStruct);

  if (source) {
    // We copy non-slice information (memory sizes array) from source
    SpecifierStructBuilder sourceMd(source);
    md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
    return md;
  }

  // No source: look up the memory-size array type to learn its element type
  // and length.
  auto llvmStructType = cast<LLVM::LLVMStructType>(structType);
  auto memSizeArrayType = cast<LLVM::LLVMArrayType>(
      llvmStructType.getBody()[kMemSizePosInSpecifier]);

  Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
  // Fill memSizes array with zero.
  for (int fidx = 0, numFields = memSizeArrayType.getNumElements();
       fidx < numFields; ++fidx)
    md.setMemSize(builder, loc, fidx, zero);

  return md;
}
|
|
|
|
/// Builds IR extracting the pos-th offset from the descriptor.
|
|
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
|
|
Dimension dim) const {
|
|
return extractField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
|
|
}
|
|
|
|
/// Builds IR inserting the pos-th offset into the descriptor.
|
|
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
|
|
Dimension dim, Value size) {
|
|
insertField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
|
|
size);
|
|
}
|
|
|
|
/// Builds IR extracting the `lvl`-th level-size from the descriptor.
|
|
Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
|
|
Level lvl) const {
|
|
// This static_cast makes the narrowing of `lvl` explicit, as required
|
|
// by the braces notation for the ctor.
|
|
return extractField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)});
|
|
}
|
|
|
|
/// Builds IR inserting the `lvl`-th level-size into the descriptor.
|
|
void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
|
|
Level lvl, Value size) {
|
|
// This static_cast makes the narrowing of `lvl` explicit, as required
|
|
// by the braces notation for the ctor.
|
|
insertField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)},
|
|
size);
|
|
}
|
|
|
|
/// Builds IR extracting the pos-th stride from the descriptor.
|
|
Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
|
|
Dimension dim) const {
|
|
return extractField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
|
|
}
|
|
|
|
/// Builds IR inserting the pos-th stride into the descriptor.
|
|
void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
|
|
Dimension dim, Value size) {
|
|
insertField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
|
|
size);
|
|
}
|
|
|
|
/// Builds IR extracting the pos-th memory size into the descriptor.
|
|
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
|
|
FieldIndex fidx) const {
|
|
return extractField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
|
|
}
|
|
|
|
/// Builds IR inserting the `fidx`-th memory-size into the descriptor.
|
|
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
|
|
FieldIndex fidx, Value size) {
|
|
insertField(
|
|
builder, loc,
|
|
ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
|
|
size);
|
|
}
|
|
|
|
/// Builds IR extracting the memory size array from the descriptor.
|
|
Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
|
|
Location loc) const {
|
|
return builder.create<LLVM::ExtractValueOp>(loc, value,
|
|
kMemSizePosInSpecifier);
|
|
}
|
|
|
|
/// Builds IR inserting the memory size array into the descriptor.
|
|
void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
|
|
Value array) {
|
|
value = builder.create<LLVM::InsertValueOp>(loc, value, array,
|
|
kMemSizePosInSpecifier);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// The sparse storage specifier type converter (defined in Passes.h).
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registers the two conversions of this converter, in this order: an
/// identity conversion, then the storage-specifier-to-LLVM-struct conversion.
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
  // Identity conversion: hands back the type it is given.
  addConversion([](Type type) { return type; });
  // Storage specifier types are lowered by convertSpecifier.
  addConversion(
      [](StorageSpecifierType tp) -> Type { return convertSpecifier(tp); });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Storage specifier conversion rules.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename Base, typename SourceOp>
|
|
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
|
|
public:
|
|
using OpAdaptor = typename SourceOp::Adaptor;
|
|
using OpConversionPattern<SourceOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
SpecifierStructBuilder spec(adaptor.getSpecifier());
|
|
switch (op.getSpecifierKind()) {
|
|
case StorageSpecifierKind::LvlSize: {
|
|
Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
|
|
rewriter.replaceOp(op, v);
|
|
return success();
|
|
}
|
|
case StorageSpecifierKind::DimOffset: {
|
|
Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
|
|
rewriter.replaceOp(op, v);
|
|
return success();
|
|
}
|
|
case StorageSpecifierKind::DimStride: {
|
|
Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
|
|
rewriter.replaceOp(op, v);
|
|
return success();
|
|
}
|
|
case StorageSpecifierKind::CrdMemSize:
|
|
case StorageSpecifierKind::PosMemSize:
|
|
case StorageSpecifierKind::ValMemSize: {
|
|
auto enc = op.getSpecifier().getType().getEncoding();
|
|
StorageLayout layout(enc);
|
|
std::optional<unsigned> lvl;
|
|
if (op.getLevel())
|
|
lvl = (*op.getLevel());
|
|
unsigned idx =
|
|
layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
|
|
Value v = Base::onMemSize(rewriter, op, spec, idx);
|
|
rewriter.replaceOp(op, v);
|
|
return success();
|
|
}
|
|
}
|
|
llvm_unreachable("unrecognized specifer kind");
|
|
}
|
|
};
|
|
|
|
struct StorageSpecifierSetOpConverter
|
|
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
|
|
SetStorageSpecifierOp> {
|
|
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
|
|
|
|
static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
|
|
SpecifierStructBuilder &spec, Level lvl) {
|
|
spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
|
|
return spec;
|
|
}
|
|
|
|
static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
|
|
SpecifierStructBuilder &spec, Dimension d) {
|
|
spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
|
|
return spec;
|
|
}
|
|
|
|
static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
|
|
SpecifierStructBuilder &spec, Dimension d) {
|
|
spec.setDimStride(builder, op.getLoc(), d, op.getValue());
|
|
return spec;
|
|
}
|
|
|
|
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
|
|
SpecifierStructBuilder &spec, FieldIndex fidx) {
|
|
spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
|
|
return spec;
|
|
}
|
|
};
|
|
|
|
struct StorageSpecifierGetOpConverter
|
|
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
|
|
GetStorageSpecifierOp> {
|
|
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
|
|
|
|
static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
|
|
SpecifierStructBuilder &spec, Level lvl) {
|
|
return spec.lvlSize(builder, op.getLoc(), lvl);
|
|
}
|
|
|
|
static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
|
|
const SpecifierStructBuilder &spec, Dimension d) {
|
|
return spec.dimOffset(builder, op.getLoc(), d);
|
|
}
|
|
|
|
static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
|
|
const SpecifierStructBuilder &spec, Dimension d) {
|
|
return spec.dimStride(builder, op.getLoc(), d);
|
|
}
|
|
|
|
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
|
|
SpecifierStructBuilder &spec, FieldIndex fidx) {
|
|
return spec.memSize(builder, op.getLoc(), fidx);
|
|
}
|
|
};
|
|
|
|
struct StorageSpecifierInitOpConverter
|
|
: public OpConversionPattern<StorageSpecifierInitOp> {
|
|
public:
|
|
using OpConversionPattern::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
|
|
rewriter.replaceOp(
|
|
op, SpecifierStructBuilder::getInitValue(
|
|
rewriter, op.getLoc(), llvmType, adaptor.getSource()));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Public method for populating conversion rules.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlir::populateStorageSpecifierToLLVMPatterns(
|
|
const TypeConverter &converter, RewritePatternSet &patterns) {
|
|
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
|
|
StorageSpecifierInitOpConverter>(converter,
|
|
patterns.getContext());
|
|
}
|