[mlir][sparse] Factoring out Transforms/CodegenUtils.{cpp,h}
This moves a bunch of helper functions from `Transforms/SparseTensorConversion.cpp` into `Transforms/CodegenUtils.{cpp,h}` so that they can be reused by `Transforms/Sparsification.cpp`, etc.
See also the dependent D115010 which cleans up some corner cases in this change.
Reviewed By: aartbik, rriddle
Differential Revision: https://reviews.llvm.org/D115008
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "CodegenUtils.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
@@ -39,113 +40,6 @@ enum class EmitCInterface : bool { Off = false, On = true };
|
||||
// Helper methods.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Generates a constant zero of the given type.
|
||||
inline static Value constantZero(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type t) {
|
||||
return rewriter.create<arith::ConstantOp>(loc, t, rewriter.getZeroAttr(t));
|
||||
}
|
||||
|
||||
/// Generates a constant of `index` type.
|
||||
inline static Value constantIndex(ConversionPatternRewriter &rewriter,
|
||||
Location loc, int64_t i) {
|
||||
return rewriter.create<arith::ConstantIndexOp>(loc, i);
|
||||
}
|
||||
|
||||
/// Generates a constant of `i32` type.
|
||||
inline static Value constantI32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, int32_t i) {
|
||||
return rewriter.create<arith::ConstantIntOp>(loc, i, 32);
|
||||
}
|
||||
|
||||
/// Generates a constant of `i8` type.
|
||||
inline static Value constantI8(ConversionPatternRewriter &rewriter,
|
||||
Location loc, int8_t i) {
|
||||
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
|
||||
}
|
||||
|
||||
/// Generates a constant of the given `Action`.
|
||||
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Action action) {
|
||||
return constantI32(rewriter, loc, static_cast<uint32_t>(action));
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal type encoding for overhead storage.
|
||||
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
|
||||
Location loc, unsigned width) {
|
||||
OverheadType sec;
|
||||
switch (width) {
|
||||
default:
|
||||
sec = OverheadType::kU64;
|
||||
break;
|
||||
case 32:
|
||||
sec = OverheadType::kU32;
|
||||
break;
|
||||
case 16:
|
||||
sec = OverheadType::kU16;
|
||||
break;
|
||||
case 8:
|
||||
sec = OverheadType::kU8;
|
||||
break;
|
||||
}
|
||||
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal type encoding for pointer
|
||||
/// overhead storage.
|
||||
static Value constantPointerTypeEncoding(ConversionPatternRewriter &rewriter,
|
||||
Location loc,
|
||||
SparseTensorEncodingAttr &enc) {
|
||||
return constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth());
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal type encoding for index overhead
|
||||
/// storage.
|
||||
static Value constantIndexTypeEncoding(ConversionPatternRewriter &rewriter,
|
||||
Location loc,
|
||||
SparseTensorEncodingAttr &enc) {
|
||||
return constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth());
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal type encoding for primary storage.
|
||||
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Type tp) {
|
||||
PrimaryType primary;
|
||||
if (tp.isF64())
|
||||
primary = PrimaryType::kF64;
|
||||
else if (tp.isF32())
|
||||
primary = PrimaryType::kF32;
|
||||
else if (tp.isInteger(64))
|
||||
primary = PrimaryType::kI64;
|
||||
else if (tp.isInteger(32))
|
||||
primary = PrimaryType::kI32;
|
||||
else if (tp.isInteger(16))
|
||||
primary = PrimaryType::kI16;
|
||||
else if (tp.isInteger(8))
|
||||
primary = PrimaryType::kI8;
|
||||
else
|
||||
llvm_unreachable("Unknown element type");
|
||||
return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
|
||||
}
|
||||
|
||||
/// Generates a constant of the internal dimension level type encoding.
|
||||
static Value
|
||||
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
|
||||
SparseTensorEncodingAttr::DimLevelType dlt) {
|
||||
DimLevelType dlt2;
|
||||
switch (dlt) {
|
||||
case SparseTensorEncodingAttr::DimLevelType::Dense:
|
||||
dlt2 = DimLevelType::kDense;
|
||||
break;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Compressed:
|
||||
dlt2 = DimLevelType::kCompressed;
|
||||
break;
|
||||
case SparseTensorEncodingAttr::DimLevelType::Singleton:
|
||||
dlt2 = DimLevelType::kSingleton;
|
||||
break;
|
||||
}
|
||||
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
|
||||
}
|
||||
|
||||
/// Returns the equivalent of `void*` for opaque arguments to the
|
||||
/// execution engine.
|
||||
static Type getOpaquePointerType(PatternRewriter &rewriter) {
|
||||
@@ -336,22 +230,6 @@ static void newParams(ConversionPatternRewriter &rewriter,
|
||||
params.push_back(ptr);
|
||||
}
|
||||
|
||||
/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
|
||||
/// For floating types, we use the "unordered" comparator (i.e., returns
|
||||
/// true if `v` is NaN).
|
||||
static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value v) {
|
||||
Type t = v.getType();
|
||||
Value zero = constantZero(rewriter, loc, t);
|
||||
if (t.isa<FloatType>())
|
||||
return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
|
||||
zero);
|
||||
if (t.isIntOrIndex())
|
||||
return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
|
||||
zero);
|
||||
llvm_unreachable("Unknown element type");
|
||||
}
|
||||
|
||||
/// Generates the code to read the value from tensor[ivs], and conditionally
|
||||
/// stores the indices ivs to the memory in ind. The generated code looks like
|
||||
/// the following and the insertion point after this routine is inside the
|
||||
|
||||
Reference in New Issue
Block a user