//===- Async.cpp - MLIR Async Operations ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::async; #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; void AsyncDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" >(); } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// static LogicalResult verify(YieldOp op) { // Get the underlying value types from async values returned from the // parent `async.execute` operation. auto executeOp = op->getParentOfType(); auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { return result.getType().cast().getValueType(); }); if (op.getOperandTypes() != types) return op.emitOpError("operand types do not match the types returned from " "the parent ExecuteOp"); return success(); } MutableOperandRange YieldOp::getMutableSuccessorOperands(Optional index) { assert(!index.hasValue()); return operandsMutable(); } //===----------------------------------------------------------------------===// /// ExecuteOp //===----------------------------------------------------------------------===// constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; void ExecuteOp::getNumRegionInvocations( ArrayRef, SmallVectorImpl &countPerRegion) { assert(countPerRegion.empty()); countPerRegion.push_back(1); } OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) { assert(index == 0 && "invalid region index"); return operands(); } void ExecuteOp::getSuccessorRegions(Optional index, ArrayRef, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. if (index.hasValue()) { assert(*index == 0 && "invalid region index"); regions.push_back(RegionSuccessor(results())); return; } // Otherwise the successor is the body region. regions.push_back(RegionSuccessor(&body(), body().getArguments())); } void ExecuteOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange dependencies, ValueRange operands, BodyBuilderFn bodyBuilder) { result.addOperands(dependencies); result.addOperands(operands); // Add derived `operand_segment_sizes` attribute based on parsed operands. int32_t numDependencies = dependencies.size(); int32_t numOperands = operands.size(); auto operandSegmentSizes = DenseIntElementsAttr::get( VectorType::get({2}, builder.getIntegerType(32)), {numDependencies, numOperands}); result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); // First result is always a token, and then `resultTypes` wrapped into // `async.value`. result.addTypes({TokenType::get(result.getContext())}); for (Type type : resultTypes) result.addTypes(ValueType::get(type)); // Add a body region with block arguments as unwrapped async value operands. Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); for (Value operand : operands) { auto valueType = operand.getType().dyn_cast(); bodyBlock.addArgument(valueType ? valueType.getValueType() : operand.getType()); } // Create the default terminator if the builder is not provided and if the // expected result is empty. Otherwise, leave this to the caller // because we don't know which values to return from the execute op. if (resultTypes.empty() && !bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); builder.create(result.location, ValueRange()); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); bodyBuilder(builder, result.location, bodyBlock.getArguments()); } } static void print(OpAsmPrinter &p, ExecuteOp op) { // [%tokens,...] if (!op.dependencies().empty()) p << " [" << op.dependencies() << "]"; // (%value as %unwrapped: !async.value, ...) if (!op.operands().empty()) { p << " ("; Block *entry = op.body().empty() ? nullptr : &op.body().front(); llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { Value argument = entry ? entry->getArgument(n++) : Value(); p << operand << " as " << argument << ": " << operand.getType(); }); p << ")"; } // -> (!async.value, ...) p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); p.printOptionalAttrDictWithKeyword(op->getAttrs(), {kOperandSegmentSizesAttr}); p.printRegion(op.body(), /*printEntryBlockArgs=*/false); } static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { MLIRContext *ctx = result.getContext(); // Sizes of parsed variadic operands, will be updated below after parsing. int32_t numDependencies = 0; int32_t numOperands = 0; auto tokenTy = TokenType::get(ctx); // Parse dependency tokens. if (succeeded(parser.parseOptionalLSquare())) { SmallVector tokenArgs; if (parser.parseOperandList(tokenArgs) || parser.resolveOperands(tokenArgs, tokenTy, result.operands) || parser.parseRSquare()) return failure(); numDependencies = tokenArgs.size(); } // Parse async value operands (%value as %unwrapped : !async.value). SmallVector valueArgs; SmallVector unwrappedArgs; SmallVector valueTypes; SmallVector unwrappedTypes; if (succeeded(parser.parseOptionalLParen())) { auto argsLoc = parser.getCurrentLocation(); // Parse a single instance of `%value as %unwrapped : !async.value`. auto parseAsyncValueArg = [&]() -> ParseResult { if (parser.parseOperand(valueArgs.emplace_back()) || parser.parseKeyword("as") || parser.parseOperand(unwrappedArgs.emplace_back()) || parser.parseColonType(valueTypes.emplace_back())) return failure(); auto valueTy = valueTypes.back().dyn_cast(); unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); return success(); }; // If the next token is `)` skip async value arguments parsing. if (failed(parser.parseOptionalRParen())) { do { if (parseAsyncValueArg()) return failure(); } while (succeeded(parser.parseOptionalComma())); if (parser.parseRParen() || parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands)) return failure(); } numOperands = valueArgs.size(); } // Add derived `operand_segment_sizes` attribute based on parsed operands. auto operandSegmentSizes = DenseIntElementsAttr::get( VectorType::get({2}, parser.getBuilder().getI32Type()), {numDependencies, numOperands}); result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); // Parse the types of results returned from the async execute op. SmallVector resultTypes; if (parser.parseOptionalArrowTypeList(resultTypes)) return failure(); // Async execute first result is always a completion token. parser.addTypeToList(tokenTy, result.types); parser.addTypesToList(resultTypes, result.types); // Parse operation attributes. NamedAttrList attrs; if (parser.parseOptionalAttrDictWithKeyword(attrs)) return failure(); result.addAttributes(attrs); // Parse asynchronous region. Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, /*argTypes=*/{unwrappedTypes}, /*enableNameShadowing=*/false)) return failure(); return success(); } static LogicalResult verify(ExecuteOp op) { // Unwrap async.execute value operands types. auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { return operand.getType().cast().getValueType(); }); // Verify that unwrapped argument types matches the body region arguments. if (op.body().getArgumentTypes() != unwrappedTypes) return op.emitOpError("async body region argument types do not match the " "execute operation arguments types"); return success(); } //===----------------------------------------------------------------------===// /// CreateGroupOp //===----------------------------------------------------------------------===// LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, PatternRewriter &rewriter) { // Find all `await_all` users of the group. llvm::SmallVector awaitAllUsers; auto isAwaitAll = [&](Operation *op) -> bool { if (AwaitAllOp awaitAll = dyn_cast(op)) { awaitAllUsers.push_back(awaitAll); return true; } return false; }; // Check if all users of the group are `await_all` operations. if (!llvm::all_of(op->getUsers(), isAwaitAll)) return failure(); // If group is only awaited without adding anything to it, we can safely erase // the create operation and all users. for (AwaitAllOp awaitAll : awaitAllUsers) rewriter.eraseOp(awaitAll); rewriter.eraseOp(op); return success(); } //===----------------------------------------------------------------------===// /// AwaitOp //===----------------------------------------------------------------------===// void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, ArrayRef attrs) { result.addOperands({operand}); result.attributes.append(attrs.begin(), attrs.end()); // Add unwrapped async.value type to the returned values types. if (auto valueType = operand.getType().dyn_cast()) result.addTypes(valueType.getValueType()); } static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, Type &resultType) { if (parser.parseType(operandType)) return failure(); // Add unwrapped async.value type to the returned values types. if (auto valueType = operandType.dyn_cast()) resultType = valueType.getValueType(); return success(); } static void printAwaitResultType(OpAsmPrinter &p, Operation *op, Type operandType, Type resultType) { p << operandType; } static LogicalResult verify(AwaitOp op) { Type argType = op.operand().getType(); // Awaiting on a token does not have any results. if (argType.isa() && !op.getResultTypes().empty()) return op.emitOpError("awaiting on a token must have empty result"); // Awaiting on a value unwraps the async value type. if (auto value = argType.dyn_cast()) { if (*op.getResultType() != value.getValueType()) return op.emitOpError() << "result type " << *op.getResultType() << " does not match async value type " << value.getValueType(); } return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" //===----------------------------------------------------------------------===// // TableGen'd type method definitions //===----------------------------------------------------------------------===// #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" void ValueType::print(DialectAsmPrinter &printer) const { printer << getMnemonic(); printer << "<"; printer.printType(getValueType()); printer << '>'; } Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { Type ty; if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { parser.emitError(parser.getNameLoc(), "failed to parse async value type"); return Type(); } return ValueType::get(ty); } /// Print a type registered to this dialect. void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { if (failed(generatedTypePrinter(type, os))) llvm_unreachable("unexpected 'async' type kind"); } /// Parse a type registered to this dialect. Type AsyncDialect::parseType(DialectAsmParser &parser) const { StringRef typeTag; if (parser.parseKeyword(&typeTag)) return Type(); Type genType; auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), parser, typeTag, genType); if (parseResult.hasValue()) return genType; parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag; return {}; }