[mlir] NFC: Rename LLVMOpLowering::lowering to LLVMOpLowering::typeConverter

The existing name is an artifact dating back to the times when we did not have
a dedicated TypeConverter infrastructure. It is also confusing with with the
name of classes using it.

Differential revision: https://reviews.llvm.org/D74707
This commit is contained in:
Alex Zinenko
2020-02-18 15:49:13 +01:00
parent 4518aab289
commit 0f04384daf
7 changed files with 130 additions and 123 deletions

View File

@@ -36,8 +36,8 @@ using namespace mlir::vector;
template <typename T>
static LLVM::LLVMType getPtrToElementType(T containerType,
LLVMTypeConverter &lowering) {
return lowering.convertType(containerType.getElementType())
LLVMTypeConverter &typeConverter) {
return typeConverter.convertType(containerType.getElementType())
.template cast<LLVM::LLVMType>()
.getPointerTo();
}
@@ -56,12 +56,13 @@ static VectorType reducedVectorTypeBack(VectorType tp) {
// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &lowering, Location loc, Value val1,
Value val2, Type llvmType, int64_t rank, int64_t pos) {
LLVMTypeConverter &typeConverter, Location loc,
Value val1, Value val2, Type llvmType, int64_t rank,
int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
loc, lowering.convertType(idxType),
loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
constant);
@@ -83,12 +84,12 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &lowering, Location loc, Value val,
Type llvmType, int64_t rank, int64_t pos) {
LLVMTypeConverter &typeConverter, Location loc,
Value val, Type llvmType, int64_t rank, int64_t pos) {
if (rank == 1) {
auto idxType = rewriter.getIndexType();
auto constant = rewriter.create<LLVM::ConstantOp>(
loc, lowering.convertType(idxType),
loc, typeConverter.convertType(idxType),
rewriter.getIntegerAttr(idxType, pos));
return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
constant);
@@ -137,7 +138,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto broadcastOp = cast<vector::BroadcastOp>(op);
VectorType dstVectorType = broadcastOp.getVectorType();
if (lowering.convertType(dstVectorType) == nullptr)
if (typeConverter.convertType(dstVectorType) == nullptr)
return matchFailure();
// Rewrite when the full vector type can be lowered (which
// implies all 'reduced' types can be lowered too).
@@ -203,12 +204,12 @@ private:
Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
VectorType dstVectorType, int64_t rank, int64_t dim,
ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
Type llvmType = typeConverter.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
if (rank == 1) {
Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
Value expand =
insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
Value expand = insertOne(rewriter, typeConverter, loc, undef, value,
llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
@@ -217,8 +218,8 @@ private:
reducedVectorTypeFront(dstVectorType), rewriter);
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
for (int64_t d = 0; d < dim; ++d) {
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
rank, d);
}
return result;
}
@@ -243,31 +244,32 @@ private:
Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
VectorType dstVectorType, int64_t rank, int64_t dim,
ConversionPatternRewriter &rewriter) const {
Type llvmType = lowering.convertType(dstVectorType);
Type llvmType = typeConverter.convertType(dstVectorType);
assert((llvmType != nullptr) && "unlowerable vector type");
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
bool atStretch = dim != srcVectorType.getDimSize(0);
if (rank == 1) {
assert(atStretch);
Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
Type redLlvmType =
typeConverter.convertType(dstVectorType.getElementType());
Value one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
Value expand =
insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0);
Value expand = insertOne(rewriter, typeConverter, loc, result, one,
llvmType, rank, 0);
SmallVector<int32_t, 4> zeroValues(dim, 0);
return rewriter.create<LLVM::ShuffleVectorOp>(
loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
}
VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
VectorType redDstType = reducedVectorTypeFront(dstVectorType);
Type redLlvmType = lowering.convertType(redSrcType);
Type redLlvmType = typeConverter.convertType(redSrcType);
for (int64_t d = 0; d < dim; ++d) {
int64_t pos = atStretch ? 0 : d;
Value one =
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType,
rank, pos);
Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
result =
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
rank, d);
}
return result;
}
@@ -286,7 +288,7 @@ public:
auto reductionOp = cast<vector::ReductionOp>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = lowering.convertType(eltType);
Type llvmType = typeConverter.convertType(eltType);
if (eltType.isInteger(32) || eltType.isInteger(64)) {
// Integer reductions: add/mul/min/max/and/or/xor.
if (kind == "add")
@@ -353,7 +355,7 @@ public:
auto reductionOp = cast<vector::ReductionV2Op>(op);
auto kind = reductionOp.kind();
Type eltType = reductionOp.dest().getType();
Type llvmType = lowering.convertType(eltType);
Type llvmType = typeConverter.convertType(eltType);
if (kind == "add") {
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
op, llvmType, operands[1], operands[0]);
@@ -383,7 +385,7 @@ public:
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
auto vectorType = shuffleOp.getVectorType();
Type llvmType = lowering.convertType(vectorType);
Type llvmType = typeConverter.convertType(vectorType);
auto maskArrayAttr = shuffleOp.mask();
// Bail if result type cannot be lowered.
@@ -415,10 +417,10 @@ public:
extPos -= v1Dim;
value = adaptor.v2();
}
Value extract =
extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
rank, insPos++);
Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
rank, extPos);
insert = insertOne(rewriter, typeConverter, loc, insert, extract,
llvmType, rank, insPos++);
}
rewriter.replaceOp(op, insert);
return matchSuccess();
@@ -438,7 +440,7 @@ public:
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
auto extractEltOp = cast<vector::ExtractElementOp>(op);
auto vectorType = extractEltOp.getVectorType();
auto llvmType = lowering.convertType(vectorType.getElementType());
auto llvmType = typeConverter.convertType(vectorType.getElementType());
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -465,7 +467,7 @@ public:
auto extractOp = cast<vector::ExtractOp>(op);
auto vectorType = extractOp.getVectorType();
auto resultType = extractOp.getResult().getType();
auto llvmResultType = lowering.convertType(resultType);
auto llvmResultType = typeConverter.convertType(resultType);
auto positionArrayAttr = extractOp.position();
// Bail if result type cannot be lowered.
@@ -489,13 +491,13 @@ public:
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, lowering.convertType(oneDVectorType), extracted,
loc, typeConverter.convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
}
// Remaining extraction of element from 1-D LLVM vector
auto position = positionAttrs.back().cast<IntegerAttr>();
auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
extracted =
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
@@ -553,7 +555,7 @@ public:
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
auto insertEltOp = cast<vector::InsertElementOp>(op);
auto vectorType = insertEltOp.getDestVectorType();
auto llvmType = lowering.convertType(vectorType);
auto llvmType = typeConverter.convertType(vectorType);
// Bail if result type cannot be lowered.
if (!llvmType)
@@ -580,7 +582,7 @@ public:
auto insertOp = cast<vector::InsertOp>(op);
auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = lowering.convertType(destVectorType);
auto llvmResultType = typeConverter.convertType(destVectorType);
auto positionArrayAttr = insertOp.position();
// Bail if result type cannot be lowered.
@@ -607,16 +609,16 @@ public:
auto nMinusOnePositionAttrs =
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, lowering.convertType(oneDVectorType), extracted,
loc, typeConverter.convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
}
// Insertion of an element into a 1-D LLVM vector.
auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
Value inserted = rewriter.create<LLVM::InsertElementOp>(
loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
constant);
loc, typeConverter.convertType(oneDVectorType), extracted,
adaptor.source(), constant);
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
@@ -830,7 +832,7 @@ public:
auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
auto llvmArrayOfVectType = lowering.convertType(
auto llvmArrayOfVectType = typeConverter.convertType(
cast<vector::OuterProductOp>(op).getResult().getType());
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
Value a = adaptor.lhs(), b = adaptor.rhs();
@@ -893,7 +895,7 @@ public:
return matchFailure();
MemRefDescriptor sourceMemRef(operands[0]);
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return matchFailure();
@@ -916,7 +918,7 @@ public:
if (failed(successStrides) || !isContiguous)
return matchFailure();
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
@@ -979,7 +981,7 @@ public:
auto adaptor = vector::PrintOpOperandAdaptor(operands);
Type printType = printOp.getPrintType();
if (lowering.convertType(printType) == nullptr)
if (typeConverter.convertType(printType) == nullptr)
return matchFailure();
// Make sure element type has runtime support (currently just Float/Double).
@@ -1021,10 +1023,10 @@ private:
for (int64_t d = 0; d < dim; ++d) {
auto reducedType =
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
auto llvmType = lowering.convertType(
auto llvmType = typeConverter.convertType(
rank > 1 ? reducedType : vectorType.getElementType());
Value nestedVal =
extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
@@ -1055,36 +1057,36 @@ private:
// Helpers for method names.
Operation *getPrintI32(Operation *op) const {
LLVM::LLVMDialect *dialect = lowering.getDialect();
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_i32",
LLVM::LLVMType::getInt32Ty(dialect));
}
Operation *getPrintI64(Operation *op) const {
LLVM::LLVMDialect *dialect = lowering.getDialect();
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_i64",
LLVM::LLVMType::getInt64Ty(dialect));
}
Operation *getPrintFloat(Operation *op) const {
LLVM::LLVMDialect *dialect = lowering.getDialect();
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_f32",
LLVM::LLVMType::getFloatTy(dialect));
}
Operation *getPrintDouble(Operation *op) const {
LLVM::LLVMDialect *dialect = lowering.getDialect();
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
return getPrint(op, dialect, "print_f64",
LLVM::LLVMType::getDoubleTy(dialect));
}
Operation *getPrintOpen(Operation *op) const {
return getPrint(op, lowering.getDialect(), "print_open", {});
return getPrint(op, typeConverter.getDialect(), "print_open", {});
}
Operation *getPrintClose(Operation *op) const {
return getPrint(op, lowering.getDialect(), "print_close", {});
return getPrint(op, typeConverter.getDialect(), "print_close", {});
}
Operation *getPrintComma(Operation *op) const {
return getPrint(op, lowering.getDialect(), "print_comma", {});
return getPrint(op, typeConverter.getDialect(), "print_comma", {});
}
Operation *getPrintNewline(Operation *op) const {
return getPrint(op, lowering.getDialect(), "print_newline", {});
return getPrint(op, typeConverter.getDialect(), "print_newline", {});
}
};