[mlir] NFC: Rename LLVMOpLowering::lowering to LLVMOpLowering::typeConverter
The existing name is an artifact dating back to the times when we did not have a dedicated TypeConverter infrastructure. It is also confusing with with the name of classes using it. Differential revision: https://reviews.llvm.org/D74707
This commit is contained in:
@@ -36,8 +36,8 @@ using namespace mlir::vector;
|
||||
|
||||
template <typename T>
|
||||
static LLVM::LLVMType getPtrToElementType(T containerType,
|
||||
LLVMTypeConverter &lowering) {
|
||||
return lowering.convertType(containerType.getElementType())
|
||||
LLVMTypeConverter &typeConverter) {
|
||||
return typeConverter.convertType(containerType.getElementType())
|
||||
.template cast<LLVM::LLVMType>()
|
||||
.getPointerTo();
|
||||
}
|
||||
@@ -56,12 +56,13 @@ static VectorType reducedVectorTypeBack(VectorType tp) {
|
||||
|
||||
// Helper that picks the proper sequence for inserting.
|
||||
static Value insertOne(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &lowering, Location loc, Value val1,
|
||||
Value val2, Type llvmType, int64_t rank, int64_t pos) {
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
Value val1, Value val2, Type llvmType, int64_t rank,
|
||||
int64_t pos) {
|
||||
if (rank == 1) {
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, lowering.convertType(idxType),
|
||||
loc, typeConverter.convertType(idxType),
|
||||
rewriter.getIntegerAttr(idxType, pos));
|
||||
return rewriter.create<LLVM::InsertElementOp>(loc, llvmType, val1, val2,
|
||||
constant);
|
||||
@@ -83,12 +84,12 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
|
||||
|
||||
// Helper that picks the proper sequence for extracting.
|
||||
static Value extractOne(ConversionPatternRewriter &rewriter,
|
||||
LLVMTypeConverter &lowering, Location loc, Value val,
|
||||
Type llvmType, int64_t rank, int64_t pos) {
|
||||
LLVMTypeConverter &typeConverter, Location loc,
|
||||
Value val, Type llvmType, int64_t rank, int64_t pos) {
|
||||
if (rank == 1) {
|
||||
auto idxType = rewriter.getIndexType();
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, lowering.convertType(idxType),
|
||||
loc, typeConverter.convertType(idxType),
|
||||
rewriter.getIntegerAttr(idxType, pos));
|
||||
return rewriter.create<LLVM::ExtractElementOp>(loc, llvmType, val,
|
||||
constant);
|
||||
@@ -137,7 +138,7 @@ public:
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto broadcastOp = cast<vector::BroadcastOp>(op);
|
||||
VectorType dstVectorType = broadcastOp.getVectorType();
|
||||
if (lowering.convertType(dstVectorType) == nullptr)
|
||||
if (typeConverter.convertType(dstVectorType) == nullptr)
|
||||
return matchFailure();
|
||||
// Rewrite when the full vector type can be lowered (which
|
||||
// implies all 'reduced' types can be lowered too).
|
||||
@@ -203,12 +204,12 @@ private:
|
||||
Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
|
||||
VectorType dstVectorType, int64_t rank, int64_t dim,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Type llvmType = lowering.convertType(dstVectorType);
|
||||
Type llvmType = typeConverter.convertType(dstVectorType);
|
||||
assert((llvmType != nullptr) && "unlowerable vector type");
|
||||
if (rank == 1) {
|
||||
Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
||||
Value expand =
|
||||
insertOne(rewriter, lowering, loc, undef, value, llvmType, rank, 0);
|
||||
Value expand = insertOne(rewriter, typeConverter, loc, undef, value,
|
||||
llvmType, rank, 0);
|
||||
SmallVector<int32_t, 4> zeroValues(dim, 0);
|
||||
return rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
|
||||
@@ -217,8 +218,8 @@ private:
|
||||
reducedVectorTypeFront(dstVectorType), rewriter);
|
||||
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
result =
|
||||
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
|
||||
result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
|
||||
rank, d);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -243,31 +244,32 @@ private:
|
||||
Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
|
||||
VectorType dstVectorType, int64_t rank, int64_t dim,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Type llvmType = lowering.convertType(dstVectorType);
|
||||
Type llvmType = typeConverter.convertType(dstVectorType);
|
||||
assert((llvmType != nullptr) && "unlowerable vector type");
|
||||
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
|
||||
bool atStretch = dim != srcVectorType.getDimSize(0);
|
||||
if (rank == 1) {
|
||||
assert(atStretch);
|
||||
Type redLlvmType = lowering.convertType(dstVectorType.getElementType());
|
||||
Type redLlvmType =
|
||||
typeConverter.convertType(dstVectorType.getElementType());
|
||||
Value one =
|
||||
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, 0);
|
||||
Value expand =
|
||||
insertOne(rewriter, lowering, loc, result, one, llvmType, rank, 0);
|
||||
extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0);
|
||||
Value expand = insertOne(rewriter, typeConverter, loc, result, one,
|
||||
llvmType, rank, 0);
|
||||
SmallVector<int32_t, 4> zeroValues(dim, 0);
|
||||
return rewriter.create<LLVM::ShuffleVectorOp>(
|
||||
loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
|
||||
}
|
||||
VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
|
||||
VectorType redDstType = reducedVectorTypeFront(dstVectorType);
|
||||
Type redLlvmType = lowering.convertType(redSrcType);
|
||||
Type redLlvmType = typeConverter.convertType(redSrcType);
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
int64_t pos = atStretch ? 0 : d;
|
||||
Value one =
|
||||
extractOne(rewriter, lowering, loc, value, redLlvmType, rank, pos);
|
||||
Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType,
|
||||
rank, pos);
|
||||
Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
|
||||
result =
|
||||
insertOne(rewriter, lowering, loc, result, expand, llvmType, rank, d);
|
||||
result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
|
||||
rank, d);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -286,7 +288,7 @@ public:
|
||||
auto reductionOp = cast<vector::ReductionOp>(op);
|
||||
auto kind = reductionOp.kind();
|
||||
Type eltType = reductionOp.dest().getType();
|
||||
Type llvmType = lowering.convertType(eltType);
|
||||
Type llvmType = typeConverter.convertType(eltType);
|
||||
if (eltType.isInteger(32) || eltType.isInteger(64)) {
|
||||
// Integer reductions: add/mul/min/max/and/or/xor.
|
||||
if (kind == "add")
|
||||
@@ -353,7 +355,7 @@ public:
|
||||
auto reductionOp = cast<vector::ReductionV2Op>(op);
|
||||
auto kind = reductionOp.kind();
|
||||
Type eltType = reductionOp.dest().getType();
|
||||
Type llvmType = lowering.convertType(eltType);
|
||||
Type llvmType = typeConverter.convertType(eltType);
|
||||
if (kind == "add") {
|
||||
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
|
||||
op, llvmType, operands[1], operands[0]);
|
||||
@@ -383,7 +385,7 @@ public:
|
||||
auto v1Type = shuffleOp.getV1VectorType();
|
||||
auto v2Type = shuffleOp.getV2VectorType();
|
||||
auto vectorType = shuffleOp.getVectorType();
|
||||
Type llvmType = lowering.convertType(vectorType);
|
||||
Type llvmType = typeConverter.convertType(vectorType);
|
||||
auto maskArrayAttr = shuffleOp.mask();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
@@ -415,10 +417,10 @@ public:
|
||||
extPos -= v1Dim;
|
||||
value = adaptor.v2();
|
||||
}
|
||||
Value extract =
|
||||
extractOne(rewriter, lowering, loc, value, llvmType, rank, extPos);
|
||||
insert = insertOne(rewriter, lowering, loc, insert, extract, llvmType,
|
||||
rank, insPos++);
|
||||
Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType,
|
||||
rank, extPos);
|
||||
insert = insertOne(rewriter, typeConverter, loc, insert, extract,
|
||||
llvmType, rank, insPos++);
|
||||
}
|
||||
rewriter.replaceOp(op, insert);
|
||||
return matchSuccess();
|
||||
@@ -438,7 +440,7 @@ public:
|
||||
auto adaptor = vector::ExtractElementOpOperandAdaptor(operands);
|
||||
auto extractEltOp = cast<vector::ExtractElementOp>(op);
|
||||
auto vectorType = extractEltOp.getVectorType();
|
||||
auto llvmType = lowering.convertType(vectorType.getElementType());
|
||||
auto llvmType = typeConverter.convertType(vectorType.getElementType());
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
@@ -465,7 +467,7 @@ public:
|
||||
auto extractOp = cast<vector::ExtractOp>(op);
|
||||
auto vectorType = extractOp.getVectorType();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = lowering.convertType(resultType);
|
||||
auto llvmResultType = typeConverter.convertType(resultType);
|
||||
auto positionArrayAttr = extractOp.position();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
@@ -489,13 +491,13 @@ public:
|
||||
auto nMinusOnePositionAttrs =
|
||||
ArrayAttr::get(positionAttrs.drop_back(), context);
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, lowering.convertType(oneDVectorType), extracted,
|
||||
loc, typeConverter.convertType(oneDVectorType), extracted,
|
||||
nMinusOnePositionAttrs);
|
||||
}
|
||||
|
||||
// Remaining extraction of element from 1-D LLVM vector
|
||||
auto position = positionAttrs.back().cast<IntegerAttr>();
|
||||
auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
|
||||
auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
extracted =
|
||||
rewriter.create<LLVM::ExtractElementOp>(loc, extracted, constant);
|
||||
@@ -553,7 +555,7 @@ public:
|
||||
auto adaptor = vector::InsertElementOpOperandAdaptor(operands);
|
||||
auto insertEltOp = cast<vector::InsertElementOp>(op);
|
||||
auto vectorType = insertEltOp.getDestVectorType();
|
||||
auto llvmType = lowering.convertType(vectorType);
|
||||
auto llvmType = typeConverter.convertType(vectorType);
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
if (!llvmType)
|
||||
@@ -580,7 +582,7 @@ public:
|
||||
auto insertOp = cast<vector::InsertOp>(op);
|
||||
auto sourceType = insertOp.getSourceType();
|
||||
auto destVectorType = insertOp.getDestVectorType();
|
||||
auto llvmResultType = lowering.convertType(destVectorType);
|
||||
auto llvmResultType = typeConverter.convertType(destVectorType);
|
||||
auto positionArrayAttr = insertOp.position();
|
||||
|
||||
// Bail if result type cannot be lowered.
|
||||
@@ -607,16 +609,16 @@ public:
|
||||
auto nMinusOnePositionAttrs =
|
||||
ArrayAttr::get(positionAttrs.drop_back(), context);
|
||||
extracted = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, lowering.convertType(oneDVectorType), extracted,
|
||||
loc, typeConverter.convertType(oneDVectorType), extracted,
|
||||
nMinusOnePositionAttrs);
|
||||
}
|
||||
|
||||
// Insertion of an element into a 1-D LLVM vector.
|
||||
auto i64Type = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
|
||||
auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
|
||||
auto constant = rewriter.create<LLVM::ConstantOp>(loc, i64Type, position);
|
||||
Value inserted = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, lowering.convertType(oneDVectorType), extracted, adaptor.source(),
|
||||
constant);
|
||||
loc, typeConverter.convertType(oneDVectorType), extracted,
|
||||
adaptor.source(), constant);
|
||||
|
||||
// Potential insertion of resulting 1-D vector into array.
|
||||
if (positionAttrs.size() > 1) {
|
||||
@@ -830,7 +832,7 @@ public:
|
||||
auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
|
||||
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto llvmArrayOfVectType = lowering.convertType(
|
||||
auto llvmArrayOfVectType = typeConverter.convertType(
|
||||
cast<vector::OuterProductOp>(op).getResult().getType());
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
|
||||
Value a = adaptor.lhs(), b = adaptor.rhs();
|
||||
@@ -893,7 +895,7 @@ public:
|
||||
return matchFailure();
|
||||
MemRefDescriptor sourceMemRef(operands[0]);
|
||||
|
||||
auto llvmTargetDescriptorTy = lowering.convertType(targetMemRefType)
|
||||
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||
return matchFailure();
|
||||
@@ -916,7 +918,7 @@ public:
|
||||
if (failed(successStrides) || !isContiguous)
|
||||
return matchFailure();
|
||||
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect());
|
||||
|
||||
// Create descriptor.
|
||||
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
||||
@@ -979,7 +981,7 @@ public:
|
||||
auto adaptor = vector::PrintOpOperandAdaptor(operands);
|
||||
Type printType = printOp.getPrintType();
|
||||
|
||||
if (lowering.convertType(printType) == nullptr)
|
||||
if (typeConverter.convertType(printType) == nullptr)
|
||||
return matchFailure();
|
||||
|
||||
// Make sure element type has runtime support (currently just Float/Double).
|
||||
@@ -1021,10 +1023,10 @@ private:
|
||||
for (int64_t d = 0; d < dim; ++d) {
|
||||
auto reducedType =
|
||||
rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
|
||||
auto llvmType = lowering.convertType(
|
||||
auto llvmType = typeConverter.convertType(
|
||||
rank > 1 ? reducedType : vectorType.getElementType());
|
||||
Value nestedVal =
|
||||
extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
|
||||
extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d);
|
||||
emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
|
||||
if (d != dim - 1)
|
||||
emitCall(rewriter, loc, printComma);
|
||||
@@ -1055,36 +1057,36 @@ private:
|
||||
|
||||
// Helpers for method names.
|
||||
Operation *getPrintI32(Operation *op) const {
|
||||
LLVM::LLVMDialect *dialect = lowering.getDialect();
|
||||
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
|
||||
return getPrint(op, dialect, "print_i32",
|
||||
LLVM::LLVMType::getInt32Ty(dialect));
|
||||
}
|
||||
Operation *getPrintI64(Operation *op) const {
|
||||
LLVM::LLVMDialect *dialect = lowering.getDialect();
|
||||
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
|
||||
return getPrint(op, dialect, "print_i64",
|
||||
LLVM::LLVMType::getInt64Ty(dialect));
|
||||
}
|
||||
Operation *getPrintFloat(Operation *op) const {
|
||||
LLVM::LLVMDialect *dialect = lowering.getDialect();
|
||||
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
|
||||
return getPrint(op, dialect, "print_f32",
|
||||
LLVM::LLVMType::getFloatTy(dialect));
|
||||
}
|
||||
Operation *getPrintDouble(Operation *op) const {
|
||||
LLVM::LLVMDialect *dialect = lowering.getDialect();
|
||||
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
|
||||
return getPrint(op, dialect, "print_f64",
|
||||
LLVM::LLVMType::getDoubleTy(dialect));
|
||||
}
|
||||
Operation *getPrintOpen(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_open", {});
|
||||
return getPrint(op, typeConverter.getDialect(), "print_open", {});
|
||||
}
|
||||
Operation *getPrintClose(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_close", {});
|
||||
return getPrint(op, typeConverter.getDialect(), "print_close", {});
|
||||
}
|
||||
Operation *getPrintComma(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_comma", {});
|
||||
return getPrint(op, typeConverter.getDialect(), "print_comma", {});
|
||||
}
|
||||
Operation *getPrintNewline(Operation *op) const {
|
||||
return getPrint(op, lowering.getDialect(), "print_newline", {});
|
||||
return getPrint(op, typeConverter.getDialect(), "print_newline", {});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user