//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { namespace xegpu { void XeGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST #include >(); addOperations< #define GET_OP_LIST #include >(); addAttributes< #define GET_ATTRDEF_LIST #include >(); } //===----------------------------------------------------------------------===// // XeGPU_TensorDescAttr //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // XeGPU_TensorDescType //===----------------------------------------------------------------------===// mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { llvm::SmallVector shape; mlir::Type elementType; mlir::FailureOr encoding; // Parse literal '<' if (parser.parseLess()) return {}; auto shapeLoc = parser.getCurrentLocation(); if (mlir::failed(parser.parseDimensionList(shape))) { parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); return {}; } auto elemTypeLoc = parser.getCurrentLocation(); if (mlir::failed(parser.parseType(elementType))) { parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); return {}; } // parse optional attributes if (mlir::succeeded(parser.parseOptionalComma())) { encoding = mlir::FieldParser::parse(parser); if (mlir::failed(encoding)) { parser.emitError( parser.getCurrentLocation(), "Failed to parse the attribute field for TensorDescType.\n"); return {}; } } // Parse literal '>' if (parser.parseGreater()) return {}; return TensorDescType::get(parser.getContext(), shape, elementType, encoding.value_or(mlir::Attribute())); } void TensorDescType::print(::mlir::AsmPrinter &printer) const { printer << "<"; auto shape = getShape(); for (int64_t dim : shape) { if (mlir::ShapedType::isDynamic(dim)) printer << '?'; else printer << dim; printer << 'x'; } printer << getElementType(); if (auto encoding = getEncoding()) printer << ", " << encoding; printer << ">"; } } // namespace xegpu } // namespace mlir #include #define GET_ATTRDEF_CLASSES #include #define GET_TYPEDEF_CLASSES #include