[mlir][sparse] Factoring out Transforms/CodegenUtils.{cpp,h}

This moves a bunch of helper functions from `Transforms/SparseTensorConversion.cpp` into `Transforms/CodegenUtils.{cpp,h}` so that they can be reused by `Transforms/Sparsification.cpp`, etc.

See also the dependent D115010 which cleans up some corner cases in this change.

Reviewed By: aartbik, rriddle

Differential Revision: https://reviews.llvm.org/D115008
This commit is contained in:
wren romano
2022-01-04 15:06:28 -08:00
parent c99b2c6316
commit 85b8d03e12
6 changed files with 335 additions and 188 deletions

View File

@@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -39,113 +40,6 @@ enum class EmitCInterface : bool { Off = false, On = true };
// Helper methods.
//===----------------------------------------------------------------------===//
/// Generates a constant zero of the given type.
inline static Value constantZero(ConversionPatternRewriter &rewriter,
Location loc, Type t) {
return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
}
/// Generates a constant of `index` type.
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
Location loc, int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i);
}
/// Generates a constant of `i32` type.
inline static Value constantI32(ConversionPatternRewriter &rewriter,
Location loc, int32_t i) {
return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
}
/// Generates a constant of `i8` type.
inline static Value constantI8(ConversionPatternRewriter &rewriter,
Location loc, int8_t i) {
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
}
/// Generates a constant of the given `Action`.
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
Action action) {
return constantI32(rewriter, loc, static_cast<uint32_t>(action));
}
/// Generates a constant of the internal type encoding for overhead storage.
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc, unsigned width) {
OverheadType sec;
switch (width) {
default:
sec = OverheadType::kU64;
break;
case 32:
sec = OverheadType::kU32;
break;
case 16:
sec = OverheadType::kU16;
break;
case 8:
sec = OverheadType::kU8;
break;
}
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
}
/// Generates a constant of the internal type encoding for pointer
/// overhead storage.
static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc,
SparseTensorEncodingAttr &enc) {
return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
}
/// Generates a constant of the internal type encoding for index overhead
/// storage.
static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc,
SparseTensorEncodingAttr &enc) {
return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
}
/// Generates a constant of the internal type encoding for primary storage.
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
Location loc, Type tp) {
PrimaryType primary;
if (tp.isF64())
primary = PrimaryType::kF64;
else if (tp.isF32())
primary = PrimaryType::kF32;
else if (tp.isInteger(64))
primary = PrimaryType::kI64;
else if (tp.isInteger(32))
primary = PrimaryType::kI32;
else if (tp.isInteger(16))
primary = PrimaryType::kI16;
else if (tp.isInteger(8))
primary = PrimaryType::kI8;
else
llvm_unreachable("Unknown element type");
return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
}
/// Generates a constant of the internal dimension level type encoding.
static Value
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
SparseTensorEncodingAttr::DimLevelType dlt) {
DimLevelType dlt2;
switch (dlt) {
case SparseTensorEncodingAttr::DimLevelType::Dense:
dlt2 = DimLevelType::kDense;
break;
case SparseTensorEncodingAttr::DimLevelType::Compressed:
dlt2 = DimLevelType::kCompressed;
break;
case SparseTensorEncodingAttr::DimLevelType::Singleton:
dlt2 = DimLevelType::kSingleton;
break;
}
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
}
/// Returns the equivalent of `void*` for opaque arguments to the
/// execution engine.
static Type getOpaquePointerType(PatternRewriter &rewriter) {
@@ -336,22 +230,6 @@ static void newParams(ConversionPatternRewriter &rewriter,
params.push_back(ptr);
}
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
Value v) {
Type t = v.getType();
Value zero = constantZero(rewriter, loc, t);
if (t.isa<FloatType>())
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
zero);
if (t.isIntOrIndex())
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
zero);
llvm_unreachable("Unknown element type");
}
/// Generates the code to read the value from tensor[ivs], and conditionally
/// stores the indices ivs to the memory in ind. The generated code looks like
/// the following and the insertion point after this routine is inside the