diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index 4d8516c1aa01..b5bf8f2f53f0 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -23,12 +23,229 @@ #ifndef MLIR_IR_AFFINE_EXPR_H #define MLIR_IR_AFFINE_EXPR_H +#include "mlir/Support/LLVM.h" + namespace mlir { -class AffineExpr { +class MLIRContext; + +/// A one-dimensional affine expression. +/// AffineExpression's are immutable (like Type's) +class AffineExpr { public: - AffineExpr(); - // TODO(andydavis,bondhugula) Implement affine expressions. + enum class Kind { + // Add. + Add, + // Mul. + Mul, + // Mod. + Mod, + // Floordiv + FloorDiv, + // Ceildiv + CeilDiv, + + /// This is a marker for the last affine binary op. The range of binary op's + /// is expected to be this element and earlier. + LAST_AFFINE_BINARY_OP = CeilDiv, + + // Unary op negation + Neg, + + // Constant integer. + Constant, + // Dimensional identifier. + DimId, + // Symbolic identifier. + SymbolId, + }; + + /// Return the classification for this type. + Kind getKind() const { return kind; } + + ~AffineExpr() = default; + + void print(raw_ostream &os) const; + void dump() const; + + protected: + explicit AffineExpr(Kind kind) : kind(kind) {} + + private: + /// Classification of the subclass + const Kind kind; +}; + +/// Binary affine expression. +class AffineBinaryOpExpr : public AffineExpr { + public: + static AffineBinaryOpExpr *get(Kind kind, AffineExpr *lhsOperand, + AffineExpr *rhsOperand, MLIRContext *context); + + AffineExpr *getLeftOperand() const { return lhsOperand; } + AffineExpr *getRightOperand() const { return rhsOperand; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() <= Kind::LAST_AFFINE_BINARY_OP; + } + + protected: + explicit AffineBinaryOpExpr(Kind kind, AffineExpr *lhsOperand, + AffineExpr *rhsOperand) + : AffineExpr(kind), lhsOperand(lhsOperand), rhsOperand(rhsOperand) {} + + AffineExpr *const lhsOperand; + AffineExpr *const rhsOperand; +}; + +/// Binary affine add expression. +class AffineAddExpr : public AffineBinaryOpExpr { + public: + static AffineAddExpr *get(AffineExpr *lhsOperand, AffineExpr *rhsOperand, + MLIRContext *context); + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::Add; + } + + private: + explicit AffineAddExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand) + : AffineBinaryOpExpr(Kind::Add, lhsOperand, rhsOperand) {} +}; + +/// Binary affine mul expression. +class AffineMulExpr : public AffineBinaryOpExpr { + public: + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::Mul; + } + + private: + explicit AffineMulExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand) + : AffineBinaryOpExpr(Kind::Mul, lhsOperand, rhsOperand) {} +}; + +/// Binary affine mod expression. +class AffineModExpr : public AffineBinaryOpExpr { + public: + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::Mod; + } + + private: + explicit AffineModExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand) + : AffineBinaryOpExpr(Kind::Mod, lhsOperand, rhsOperand) {} +}; + +/// Binary affine floordiv expression. +class AffineFloorDivExpr : public AffineBinaryOpExpr { + public: + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::FloorDiv; + } + + private: + explicit AffineFloorDivExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand) + : AffineBinaryOpExpr(Kind::FloorDiv, lhsOperand, rhsOperand) {} +}; + +/// Binary affine ceildiv expression. +class AffineCeilDivExpr : public AffineBinaryOpExpr { + public: + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::CeilDiv; + } + + private: + explicit AffineCeilDivExpr(AffineExpr *lhsOperand, AffineExpr *rhsOperand) + : AffineBinaryOpExpr(Kind::CeilDiv, lhsOperand, rhsOperand) {} +}; + +/// Unary affine expression. +class AffineUnaryOpExpr : public AffineExpr { + public: + static AffineUnaryOpExpr *get(const AffineExpr &operand, + MLIRContext *context); + + static AffineUnaryOpExpr *get(const AffineExpr &operand); + AffineExpr *getOperand() const { return operand; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::Neg; + } + + private: + explicit AffineUnaryOpExpr(Kind kind, AffineExpr *operand) + : AffineExpr(kind), operand(operand) {} + + AffineExpr *operand; +}; + +/// A argument identifier appearing in an affine expression +class AffineDimExpr : public AffineExpr { + public: + static AffineDimExpr *get(unsigned position, MLIRContext *context); + + unsigned getPosition() const { return position; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::DimId; + } + + private: + explicit AffineDimExpr(unsigned position) + : AffineExpr(Kind::DimId), position(position) {} + + /// Position of this identifier in the argument list. + unsigned position; +}; + +/// A symbolic identifier appearing in an affine expression +class AffineSymbolExpr : public AffineExpr { + public: + static AffineSymbolExpr *get(unsigned position, MLIRContext *context); + + unsigned getPosition() const { return position; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::SymbolId; + } + + private: + explicit AffineSymbolExpr(unsigned position) + : AffineExpr(Kind::SymbolId), position(position) {} + + /// Position of this identifier in the symbol list. + unsigned position; +}; + +/// An integer constant appearing in affine expression. +class AffineConstantExpr : public AffineExpr { + public: + static AffineConstantExpr *get(int64_t constant, MLIRContext *context); + + int64_t getValue() const { return constant; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const AffineExpr *expr) { + return expr->getKind() == Kind::Constant; + } + + private: + explicit AffineConstantExpr(int64_t constant) + : AffineExpr(Kind::Constant), constant(constant) {} + + // The constant. + int64_t constant; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index af1df9687fe0..626605059a89 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -26,30 +26,43 @@ #include #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" namespace mlir { +class MLIRContext; class AffineExpr; -class AffineMap { +/// A multi-dimensional affine map +/// Affine map's are immutable like Type's, and they are uniqued. +/// Eg: (d0, d1) -> (d0/128, d0 mod 128, d1) +/// The names used (d0, d1) don't matter - it's the mathematical function that +/// is unique to this affine map. +class AffineMap { public: - // Constructs an AffineMap with 'dimCount' dimension identifiers, and - // 'symbolCount' symbols. - // TODO(andydavis) Pass in ArrayRef to populate list of exprs. - AffineMap(unsigned dimCount, unsigned symbolCount); + static AffineMap *get(unsigned dimCount, unsigned symbolCount, + ArrayRef exprs, + MLIRContext *context); // Prints affine map to 'os'. void print(raw_ostream &os) const; + void dump() const; + + unsigned dimCount() const { return numDims; } + unsigned symbolCount() const { return numSymbols; } private: - // Number of dimensional indentifiers. - const unsigned dimCount; - // Number of symbols. - const unsigned symbolCount; - // TODO(andydavis) Do not use std::vector here (array size is not dynamic). - std::vector exprs; + AffineMap(unsigned dimCount, unsigned symbolCount, + ArrayRef exprs); + + const unsigned numDims; + const unsigned numSymbols; + + /// The affine expressions for this (multi-dimensional) map. + /// TODO: use trailing objects for these + ArrayRef exprs; }; -} // end namespace mlir +} // end namespace mlir #endif // MLIR_IR_AFFINE_MAP_H diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h index bf7707812a94..7d22efd96e20 100644 --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -22,10 +22,13 @@ #ifndef MLIR_IR_MODULE_H #define MLIR_IR_MODULE_H +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Function.h" -#include namespace mlir { + +class AffineMap; + class Module { public: explicit Module(); @@ -33,6 +36,9 @@ public: // FIXME: wrong representation and API. std::vector functionList; + // FIXME: wrong representation and API. + // These affine maps are immutable + std::vector affineMapList; void print(raw_ostream &os) const; void dump() const; diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 63d8d7df2c38..87901f03c77a 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -18,6 +18,3 @@ #include "mlir/IR/AffineExpr.h" using namespace mlir; - -AffineExpr::AffineExpr() { -} diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 83d1d2359cf7..863175124708 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -17,15 +17,13 @@ #include "mlir/IR/AffineMap.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; -AffineMap::AffineMap(unsigned dimCount, unsigned symbolCount) - : dimCount(dimCount), symbolCount(symbolCount) { -} - -void AffineMap::print(raw_ostream &os) const { - // TODO(andydavis) Print out affine map based on dimensionCount and - // symbolCount: (d0, d1) [S0, S1] -> (d0 + S0, d1 + S1) +// TODO(clattner): make this ctor take an LLVMContext. This will eventually +// copy the elements into the context. +AffineMap::AffineMap(unsigned dimCount, unsigned symbolCount, + ArrayRef exprs) + : numDims(dimCount), numSymbols(symbolCount), exprs(exprs) { + // TODO(bondhugula) } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 639000b53d30..067697cdf064 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -20,13 +20,15 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/CFGFunction.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/Module.h" #include "mlir/IR/Types.h" #include "mlir/Support/STLExtras.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -162,6 +164,15 @@ void Instruction::dump() const { print(llvm::errs()); } +void AffineExpr::print(raw_ostream &os) const { + // TODO(bondhugula): print out affine expression +} + +void AffineMap::print(raw_ostream &os) const { + // TODO(andydavis) Print out affine map based on dimensionCount and + // symbolCount: (d0, d1) [S0, S1] -> (d0 + S0, d1 + S1) +} + void BasicBlock::print(raw_ostream &os) const { CFGFunctionState state(getFunction(), os); state.print(); @@ -208,6 +219,8 @@ void MLFunction::print(raw_ostream &os) const { } void Module::print(raw_ostream &os) const { + for (auto *map : affineMapList) + map->print(os); for (auto *fn : functionList) fn->print(os); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 5f2bd8ec7640..7c1112bc2300 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -17,6 +17,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Identifier.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseSet.h" @@ -44,6 +46,23 @@ struct FunctionTypeKeyInfo : DenseMapInfo { return lhs == KeyTy(rhs->getInputs(), rhs->getResults()); } }; +struct AffineMapKeyInfo : DenseMapInfo { + // Affine maps are uniqued based on their arguments and affine expressions + using KeyTy = std::pair; + using DenseMapInfo::getHashValue; + using DenseMapInfo::isEqual; + + static unsigned getHashValue(KeyTy key) { + // FIXME(bondhugula): placeholder for now + return hash_combine(key.first, key.second); + } + + static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) { + // TODO(bondhugula) + return false; + } +}; + struct VectorTypeKeyInfo : DenseMapInfo { // Vectors are uniqued based on their element type and shape. using KeyTy = std::pair>; @@ -97,6 +116,10 @@ public: // Primitive type uniquing. PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr }; + // Affine map uniquing. + using AffineMapSet = DenseSet; + AffineMapSet affineMaps; + /// Function type uniquing. using FunctionTypeSet = DenseSet; FunctionTypeSet functions; @@ -316,3 +339,52 @@ UnrankedTensorType *UnrankedTensorType::get(Type *elementType) { // Cache and return it. return existing.first->second = result; } + +// TODO(bondhugula,andydavis): unique affine maps based on dim list, +// symbol list and all affine expressions contained +AffineMap *AffineMap::get(unsigned dimCount, + unsigned symbolCount, + ArrayRef exprs, + MLIRContext *context) { + // TODO(bondhugula) + return new AffineMap(dimCount, symbolCount, exprs); +} + +AffineBinaryOpExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind, + AffineExpr *lhsOperand, + AffineExpr *rhsOperand, + MLIRContext *context) { + // TODO(bondhugula): allocate this through context + // FIXME + return new AffineBinaryOpExpr(kind, lhsOperand, rhsOperand); +} + +AffineAddExpr *AffineAddExpr::get(AffineExpr *lhsOperand, + AffineExpr *rhsOperand, + MLIRContext *context) { + // TODO(bondhugula): allocate this through context + // FIXME + return new AffineAddExpr(lhsOperand, rhsOperand); +} + +// TODO(bondhugula): add functions for AffineMulExpr, mod, floordiv, ceildiv + +AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) { + // TODO(bondhugula): complete this + // FIXME: this should be POD + return new AffineDimExpr(position); +} + +AffineSymbolExpr *AffineSymbolExpr::get(unsigned position, + MLIRContext *context) { + // TODO(bondhugula): complete this + // FIXME: this should be POD + return new AffineSymbolExpr(position); +} + +AffineConstantExpr *AffineConstantExpr::get(int64_t constant, + MLIRContext *context) { + // TODO(bondhugula): complete this + // FIXME: this should be POD + return new AffineConstantExpr(constant); +} diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp index e3e5e9ec358d..17755e0291f9 100644 --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -78,9 +78,14 @@ Token Lexer::lexToken() { case ')': return formToken(Token::r_paren, tokStart); case '{': return formToken(Token::l_brace, tokStart); case '}': return formToken(Token::r_brace, tokStart); + case '[': return formToken(Token::l_bracket, tokStart); + case ']': return formToken(Token::r_bracket, tokStart); case '<': return formToken(Token::less, tokStart); case '>': return formToken(Token::greater, tokStart); + case '=': return formToken(Token::equal, tokStart); + case '+': return formToken(Token::plus, tokStart); + case '*': return formToken(Token::star, tokStart); case '-': if (*curPtr == '>') { ++curPtr; @@ -246,3 +251,4 @@ Token Lexer::lexString(const char *tokStart) { } } } + diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index c62ee5d14f63..692705018aa8 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -18,12 +18,14 @@ // This file implements the parser for the MLIR textual form. // //===----------------------------------------------------------------------===// +#include #include "mlir/Parser.h" #include "Lexer.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Module.h" #include "mlir/IR/CFGFunction.h" +#include "mlir/IR/Module.h" #include "mlir/IR/MLFunction.h" #include "mlir/IR/Types.h" #include "llvm/Support/SourceMgr.h" @@ -33,6 +35,7 @@ using llvm::SMLoc; namespace { class CFGFunctionParserState; +class AffineMapParserState; /// Simple enum to make code read better in cases that would otherwise return a /// bool value. Failure is "true" in a boolean context. @@ -128,8 +131,19 @@ private: Type *parseType(); ParseResult parseTypeList(SmallVectorImpl &elements); + // Identifiers + ParseResult parseDimIdList(SmallVectorImpl &dims, + SmallVectorImpl &symbols); + ParseResult parseSymbolIdList(SmallVectorImpl &dims, + SmallVectorImpl &symbols); + StringRef parseDimOrSymbolId(SmallVectorImpl &dims, + SmallVectorImpl &symbols, + bool symbol); + // Polyhedral structures ParseResult parseAffineMapDef(); + AffineMap *parseAffineMapInline(StringRef mapId); + AffineExpr *parseAffineExpr(AffineMapParserState &state); // Functions. ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type); @@ -476,6 +490,40 @@ ParseResult Parser::parseTypeList(SmallVectorImpl &elements) { return ParseSuccess; } +namespace { +/// This class represents the transient parser state while parsing an affine +/// expression. +class AffineMapParserState { + public: + explicit AffineMapParserState(ArrayRef dims, + ArrayRef symbols) : + dims_(dims), symbols_(symbols) {} + + unsigned dimCount() const { return dims_.size(); } + unsigned symbolCount() const { return symbols_.size(); } + + // Stack operations for affine expression parsing + // TODO(bondhugula): all of this will be improved/made more principled + void pushAffineExpr(AffineExpr *expr) { exprStack.push(expr); } + AffineExpr *popAffineExpr() { + auto *t = exprStack.top(); + exprStack.pop(); + return t; + } + AffineExpr *topAffineExpr() { return exprStack.top(); } + + ArrayRef getDims() const { return dims_; } + ArrayRef getSymbols() const { return symbols_; } + + private: + const ArrayRef dims_; + const ArrayRef symbols_; + + // TEMP: stack to hold affine expressions + std::stack exprStack; +}; +} // end anonymous namespace + //===----------------------------------------------------------------------===// // Polyhedral structures. //===----------------------------------------------------------------------===// @@ -483,24 +531,205 @@ ParseResult Parser::parseTypeList(SmallVectorImpl &elements) { /// Affine map declaration. /// /// affine-map-def ::= affine-map-id `=` affine-map-inline -/// affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr -/// ( `size` `(` dim-size (`,` dim-size)* `)` )? -/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)` /// ParseResult Parser::parseAffineMapDef() { assert(curToken.is(Token::affine_map_identifier)); StringRef affineMapId = curToken.getSpelling().drop_front(); + consumeToken(Token::affine_map_identifier); + // Check that 'affineMapId' is unique. // TODO(andydavis) Add a unit test for this case. if (affineMaps.count(affineMapId) > 0) return emitError("redefinition of affine map id '" + affineMapId + "'"); + // Parse the '=' + if (!consumeIf(Token::equal)) + return emitError("expected '=' in affine map outlined definition"); - consumeToken(Token::affine_map_identifier); + auto *affineMap = parseAffineMapInline(affineMapId); + affineMaps[affineMapId].reset(affineMap); + if (!affineMap) return ParseFailure; - // TODO(andydavis,bondhugula) Parse affine map definition. - affineMaps[affineMapId].reset(new AffineMap(1, 0)); - return ParseSuccess; + module->affineMapList.push_back(affineMap); + return affineMap ? ParseSuccess : ParseFailure; +} + +/// +/// Parse a multi-dimensional affine expression +/// affine-expr ::= `(` affine-expr `)` +/// | affine-expr `+` affine-expr +/// | affine-expr `-` affine-expr +/// | `-`? integer-literal `*` affine-expr +/// | `ceildiv` `(` affine-expr `,` integer-literal `)` +/// | `floordiv` `(` affine-expr `,` integer-literal `)` +/// | affine-expr `mod` integer-literal +/// | bare-id +/// | `-`? integer-literal +/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `) +/// +/// Use 'state' to check if valid identifiers appear. +/// +AffineExpr *Parser::parseAffineExpr(AffineMapParserState &state) { + // TODO(bondhugula): complete support for this + // The code below is all placeholder / it is wrong / not complete + // Operator precedence not considered; pure left to right associativity + if (curToken.is(Token::comma)) { + emitError("expecting affine expression"); + return nullptr; + } + + while (curToken.isNot(Token::comma, Token::r_paren, + Token::eof, Token::error)) { + switch (curToken.getKind()) { + case Token::bare_identifier: { + // TODO(bondhugula): look up state to see if it's a symbol or dim_id and + // get its position + AffineExpr *expr = AffineDimExpr::get(0, context); + state.pushAffineExpr(expr); + consumeToken(Token::bare_identifier); + break; + } + case Token::plus: { + consumeToken(Token::plus); + if (state.topAffineExpr()) { + auto lChild = state.popAffineExpr(); + auto rChild = parseAffineExpr(state); + if (rChild) { + auto binaryOpExpr = AffineAddExpr::get(lChild, rChild, context); + state.popAffineExpr(); + state.pushAffineExpr(binaryOpExpr); + } else { + emitError("right operand of + missing"); + } + } else { + emitError("left operand of + missing"); + } + break; + } + case Token::integer: { + AffineExpr *expr = AffineConstantExpr::get( + curToken.getUnsignedIntegerValue().getValue(), context); + state.pushAffineExpr(expr); + consumeToken(Token::integer); + break; + } + case Token::l_paren: { + consumeToken(Token::l_paren); + break; + } + case Token::r_paren: { + consumeToken(Token::r_paren); + break; + } + default: { + emitError("affine map expr parse impl incomplete/unexpected token"); + return nullptr; + } + } + } + if (!state.topAffineExpr()) { + // An error will be emitted by parse comma separated list on an empty list + return nullptr; + } + return state.topAffineExpr(); +} + +// Return empty string if no bare id was found +StringRef Parser::parseDimOrSymbolId(SmallVectorImpl &dims, + SmallVectorImpl &symbols, + bool symbol = false) { + if (curToken.isNot(Token::bare_identifier)) { + emitError("expected bare identifier"); + return StringRef(); + } + // TODO(bondhugula): check whether the id already exists in either + // state.symbols or state.dims; report error if it does; otherwise create a + // new one. + StringRef ref = curToken.getSpelling(); + consumeToken(Token::bare_identifier); + return ref; +} + +ParseResult Parser::parseSymbolIdList(SmallVectorImpl &dims, + SmallVectorImpl &symbols) { + if (!consumeIf(Token::l_bracket)) return emitError("expected '['"); + + auto parseElt = [&]() -> ParseResult { + auto elt = parseDimOrSymbolId(dims, symbols, true); + // FIXME(bondhugula): assuming dim arg for now + if (!elt.empty()) { + symbols.push_back(elt); + return ParseSuccess; + } + return ParseFailure; + }; + return parseCommaSeparatedList(Token::r_bracket, parseElt); +} + +// TODO(andy,bondhugula) +ParseResult Parser::parseDimIdList(SmallVectorImpl &dims, + SmallVectorImpl &symbols) { + if (!consumeIf(Token::l_paren)) + return emitError("expected '(' at start of dimensional identifiers list"); + + auto parseElt = [&]() -> ParseResult { + auto elt = parseDimOrSymbolId(dims, symbols, false); + if (!elt.empty()) { + dims.push_back(elt); + return ParseSuccess; + } + return ParseFailure; + }; + + return parseCommaSeparatedList(Token::r_paren, parseElt); +} + +/// Affine map definition. +/// +/// affine-map-inline ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr +/// ( `size` `(` dim-size (`,` dim-size)* `)` )? +/// dim-size ::= affine-expr | `min` `(` affine-expr ( `,` affine-expr)+ `)` +/// +AffineMap *Parser::parseAffineMapInline(StringRef mapId) { + SmallVector dims; + SmallVector symbols; + + // List of dimensional identifiers. + if (parseDimIdList(dims, symbols)) return nullptr; + + // Symbols are optional. + if (curToken.is(Token::l_bracket)) { + if (parseSymbolIdList(dims, symbols)) return nullptr; + } + if (!consumeIf(Token::arrow)) { + emitError("expected '->' or '['"); + return nullptr; + } + if (!consumeIf(Token::l_paren)) { + emitError("expected '(' at start of affine map range"); + return nullptr; + } + + AffineMapParserState affState(dims, symbols); + + SmallVector exprs; + auto parseElt = [&]() -> ParseResult { + auto elt = parseAffineExpr(affState); + ParseResult res = elt ? ParseSuccess : ParseFailure; + exprs.push_back(elt); + return res; + }; + + // Parse a multi-dimensional affine expression (a comma-separated list of 1-d + // affine expressions) + if (parseCommaSeparatedList(Token::r_paren, parseElt, false)) return nullptr; + + // Parsed a valid affine map + auto *affineMap = + AffineMap::get(affState.dimCount(), affState.symbolCount(), exprs, + context); + + return affineMap; } //===----------------------------------------------------------------------===// @@ -525,8 +754,8 @@ ParseResult Parser::parseFunctionSignature(StringRef &name, if (curToken.isNot(Token::l_paren)) return emitError("expected '(' in function signature"); - SmallVector arguments; - if (parseTypeList(arguments)) + SmallVector arguments; + if (parseTypeList(arguments)) return ParseFailure; // Parse the return type if present. @@ -563,7 +792,7 @@ namespace { /// function as we are parsing it, e.g. the names for basic blocks. It handles /// forward references. class CFGFunctionParserState { -public: + public: CFGFunction *function; llvm::StringMap> blocksByName; @@ -851,3 +1080,4 @@ Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context, const SMDiagnosticHandlerTy &errorReporter) { return Parser(sourceMgr, context, errorReporter).parseModule(); } + diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp index ca88b06ac41b..5563255b4444 100644 --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -65,6 +65,7 @@ StringRef Token::getTokenSpelling(Kind kind) { switch (kind) { default: assert(0 && "This token kind has no fixed spelling"); #define TOK_PUNCTUATION(NAME, SPELLING) case NAME: return SPELLING; +#define TOK_OPERATOR(NAME, SPELLING) case NAME: return SPELLING; #define TOK_KEYWORD(SPELLING) case kw_##SPELLING: return #SPELLING; #include "TokenKinds.def" } diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h index 9c4d4f96def9..e5e4fc41886e 100644 --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -32,6 +32,7 @@ public: #define TOK_IDENTIFIER(NAME) NAME, #define TOK_LITERAL(NAME) NAME, #define TOK_PUNCTUATION(NAME, SPELLING) NAME, +#define TOK_OPERATOR(NAME, SPELLING) NAME, #define TOK_KEYWORD(SPELLING) kw_##SPELLING, #include "TokenKinds.def" }; @@ -99,3 +100,4 @@ private: } // end namespace mlir #endif // MLIR_LIB_PARSER_TOKEN_H + diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index 7eae4708792e..72d769a90031 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -21,7 +21,7 @@ //===----------------------------------------------------------------------===// #if !defined(TOK_MARKER) && !defined(TOK_IDENTIFIER) && !defined(TOK_LITERAL)&&\ - !defined(TOK_PUNCTUATION) && !defined(TOK_KEYWORD) + !defined(TOK_PUNCTUATION) && !defined(TOK_OPERATOR) && !defined(TOK_KEYWORD) # error Must define one of the TOK_ macros. #endif @@ -37,6 +37,9 @@ #ifndef TOK_PUNCTUATION #define TOK_PUNCTUATION(NAME, SPELLING) #endif +#ifndef TOK_OPERATOR +#define TOK_OPERATOR(NAME, SPELLING) +#endif #ifndef TOK_KEYWORD #define TOK_KEYWORD(SPELLING) #endif @@ -66,10 +69,20 @@ TOK_PUNCTUATION(l_paren, "(") TOK_PUNCTUATION(r_paren, ")") TOK_PUNCTUATION(l_brace, "{") TOK_PUNCTUATION(r_brace, "}") +TOK_PUNCTUATION(l_bracket, "[") +TOK_PUNCTUATION(r_bracket, "]") TOK_PUNCTUATION(less, "<") TOK_PUNCTUATION(greater, ">") +TOK_PUNCTUATION(equal, "=") // TODO: More punctuation. +// Operators. +TOK_OPERATOR(plus, "+") +TOK_OPERATOR(star, "*") +TOK_OPERATOR(ceildiv, "ceildiv") +TOK_OPERATOR(floordiv, "floordiv") +// TODO: More operator tokens + // Keywords. These turn "foo" into Token::kw_foo enums. TOK_KEYWORD(bf16) TOK_KEYWORD(br) @@ -94,4 +107,5 @@ TOK_KEYWORD(vector) #undef TOK_IDENTIFIER #undef TOK_LITERAL #undef TOK_PUNCTUATION +#undef TOK_OPERATOR #undef TOK_KEYWORD diff --git a/mlir/test/IR/parser-affine-map.mlir b/mlir/test/IR/parser-affine-map.mlir new file mode 100644 index 000000000000..50dd9e2aa8d8 --- /dev/null +++ b/mlir/test/IR/parser-affine-map.mlir @@ -0,0 +1,7 @@ +#hello_world0 = (i, j) [s0] -> (i, j) +#hello_world1 = (i, j) -> (i, j) +#hello_world2 = () -> (0) +#hello_world3 = (i, j) [s0] -> (i + s0, j) +#hello_world4 = (i, j) [s0] -> (i + s0, j + 5) +#hello_world5 (i, j) [s0] -> i + s0, j) +#hello_world5 = (i, j) [s0] -> i + s0, j)