Files
clang-p2996/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
Lei Zhang 223be5f630 [mlir][spirv] Perform partial conversion in VectorToSPIRVPass
This allows the pass to participate in progressive lowering
and it also allows us to write tests better.

Along the way, cleaned up the tests.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D115756
2021-12-16 09:35:56 -05:00

63 lines
2.2 KiB
C++

//===- VectorToSPIRVPass.cpp - Vector to SPIR-V Passes --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the Vector dialect to the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "../PassDetail.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Pass that lowers Vector dialect operations to the SPIR-V dialect.
///
/// It derives from the ConvertVectorToSPIRVBase pass base class and only
/// overrides the entry point; all of the work happens in runOnOperation().
struct ConvertVectorToSPIRVPass
    : public ConvertVectorToSPIRVBase<ConvertVectorToSPIRVPass> {
  /// Runs the Vector to SPIR-V conversion on the operation this pass is
  /// scheduled on.
  void runOnOperation() override;
};
} // namespace
void ConvertVectorToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull in
// patterns for other dialects.
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Location loc) {
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return Optional<Value>(cast.getResult(0));
};
typeConverter.addSourceMaterialization(addUnrealizedCast);
typeConverter.addTargetMaterialization(addUnrealizedCast);
target->addLegalOp<UnrealizedConversionCastOp>();
RewritePatternSet patterns(context);
populateVectorToSPIRVPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
}
/// Creates a new instance of the Vector to SPIR-V conversion pass and returns
/// ownership of it to the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToSPIRVPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertVectorToSPIRVPass>();
  return pass;
}