Files
clang-p2996/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp
Lei Zhang c312ce0a1d [mlir][spirv][complex] Convert complex ops to SPIR-V ops
This commit adds conversion from complex construction and
extraction ops to SPIR-V. Other arithmetic ops can be handled
via the ComplexToStandard patterns.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D147193
2023-03-30 09:13:56 -07:00

93 lines
3.4 KiB
C++

//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Complex dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "complex-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type spirvType = getTypeConverter()->convertType(createOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(createOp,
"unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
createOp, spirvType, adaptor.getOperands());
return success();
}
};
struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type spirvType = getTypeConverter()->convertType(reOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
reOp, adaptor.getComplex(), llvm::ArrayRef(0));
return success();
}
};
struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type spirvType = getTypeConverter()->convertType(imOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
imOp, adaptor.getComplex(), llvm::ArrayRef(1));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<CreateOpPattern, ReOpPattern, ImOpPattern>(typeConverter,
context);
}