The "Dim" prefix is a legacy left-over that no longer makes sense, since we have a very strict "Dimension" vs. "Level" definition for sparse tensor types and their storage.
306 lines
12 KiB
C++
306 lines
12 KiB
C++
//===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "DimLvlMapParser.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::sparse_tensor;
|
|
using namespace mlir::sparse_tensor::ir_detail;
|
|
|
|
// Early-exit helper: when `failed(RES)` holds, returns `failure()` from the
// enclosing function; otherwise falls through to the next statement.
//
// NOTE: this expands to a complete braced `if` statement rather than an
// expression, which is why the call sites in this file are written without
// a trailing semicolon. `RES` is evaluated exactly once.
#define FAILURE_IF_FAILED(RES)                                                 \
  if (failed(RES)) {                                                           \
    return failure();                                                          \
  }
|
|
|
|
/// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
|
|
/// its `RES` parameter.
|
|
static inline bool didntSucceed(OptionalParseResult res) {
|
|
return !res.has_value() || failed(*res);
|
|
}
|
|
|
|
// Early-exit helper for `OptionalParseResult` values: returns `failure()`
// from the enclosing function when `didntSucceed(RES)` holds, i.e. when
// `RES` is either absent or a failure.
//
// NOTE: like `FAILURE_IF_FAILED`, this expands to a complete braced `if`
// statement, so call sites in this file do not follow it with a semicolon.
#define FAILURE_IF_NULLOPT_OR_FAILED(RES)                                      \
  if (didntSucceed(RES)) {                                                     \
    return failure();                                                          \
  }
|
|
|
|
// Emits the diagnostic `MSG` at `loc` and returns it from the enclosing
// function whenever `COND` is true; otherwise falls through.
//
// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope
// at the point of use; both names are referenced unhygienically by the
// expansion. `MSG` is only evaluated when `COND` holds. The expansion is a
// complete braced `if` statement, so call sites omit the trailing semicolon.
#define ERROR_IF(COND, MSG)                                                    \
  if (COND) {                                                                  \
    return parser.emitError(loc, MSG);                                         \
  }
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// `DimLvlMapParser` implementation for variable parsing.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Our variation on `AffineParser::{parseBareIdExpr,parseIdentifierDefinition}`
|
|
OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
|
|
Policy creationPolicy,
|
|
VarInfo::ID &varID,
|
|
bool &didCreate) {
|
|
// Save the current location so that we can have error messages point to
|
|
// the right place.
|
|
const auto loc = parser.getCurrentLocation();
|
|
StringRef name;
|
|
if (failed(parser.parseOptionalKeyword(&name))) {
|
|
ERROR_IF(!isOptional, "expected bare identifier")
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
|
|
varID = res->first;
|
|
didCreate = res->second;
|
|
return success();
|
|
}
|
|
|
|
switch (creationPolicy) {
|
|
case Policy::MustNot:
|
|
return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
|
|
case Policy::May:
|
|
llvm_unreachable("got nullopt for Policy::May");
|
|
case Policy::Must:
|
|
return parser.emitError(loc, "redefinition of identifier '" + name + "'");
|
|
}
|
|
llvm_unreachable("unknown Policy");
|
|
}
|
|
|
|
/// Parses a mandatory occurrence of a variable of kind `vk` and returns its
/// `VarInfo::ID`.
///
/// When `requireKnown` is true the lookup runs under `Policy::MustNot`, so
/// nothing may be created by this call; otherwise `Policy::May` is used.
/// Fails if no identifier is present or if `parseVar` reports an error.
FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
                                                      bool requireKnown) {
  VarInfo::ID varID;
  bool created;
  const Policy policy = requireKnown ? Policy::MustNot : Policy::May;
  const auto parsed =
      parseVar(vk, /*isOptional=*/false, policy, varID, created);
  FAILURE_IF_NULLOPT_OR_FAILED(parsed)
  // A usage that demands a known variable must never have created one.
  assert(!requireKnown || !created);
  return varID;
}
|
|
|
|
/// Parses a mandatory variable of kind `vk`, binds it via `bindVar`, and
/// returns its `VarInfo::ID`.
///
/// When `requireKnown` is true the identifier is resolved under
/// `Policy::MustNot` (and must not be created here); otherwise it is
/// resolved under `Policy::Must` (and must be created here).
FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
                                                        bool requireKnown) {
  // Location of the identifier itself, forwarded to `bindVar`.
  const auto bindingLoc = parser.getCurrentLocation();
  VarInfo::ID varID;
  bool created;
  const Policy policy = requireKnown ? Policy::MustNot : Policy::Must;
  const auto parsed =
      parseVar(vk, /*isOptional=*/false, policy, varID, created);
  FAILURE_IF_NULLOPT_OR_FAILED(parsed)
  // Creation must have happened exactly when the variable wasn't required
  // to be known already.
  assert(created == !requireKnown);
  bindVar(bindingLoc, varID);
  return varID;
}
|
|
|
|
/// Parses an optional variable binding of kind `vk`.
///
/// If an identifier is present it is resolved (under `Policy::MustNot` when
/// `requireKnown` is true, `Policy::Must` otherwise), bound via `bindVar`,
/// and returned paired with `true`. If no identifier is present, a fresh
/// variable obtained from `env.bindUnusedVar(vk)` is returned paired with
/// `false`. Fails only when an identifier was present but `parseVar`
/// reported an error.
FailureOr<std::pair<Var, bool>>
DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
  const auto loc = parser.getCurrentLocation();
  VarInfo::ID id;
  // `parseVar` leaves its out-parameters untouched when it returns
  // `std::nullopt`, so this must be initialized for the assertion on the
  // "nothing parsed" path below to read a well-defined value.
  bool didCreate = false;
  const bool isOptional = true;
  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
  if (res.has_value()) {
    FAILURE_IF_FAILED(*res)
    // Same invariant as `parseVarBinding`: the variable was created by this
    // call exactly when it wasn't required to be known beforehand.
    assert(requireKnown ? !didCreate : didCreate);
    return std::make_pair(bindVar(loc, id), true);
  }
  // No identifier was present, hence nothing can have been created.
  assert(!didCreate);
  return std::make_pair(env.bindUnusedVar(vk), false);
}
|
|
|
|
/// Binds the variable `id` in the environment and registers its name for
/// subsequent affine-expression parsing, then returns the bound `Var`.
///
/// Symbols are registered in both `dimsAndSymbols` and `lvlsAndSymbols`
/// (as affine symbol expressions); dimension variables only in
/// `dimsAndSymbols` and level variables only in `lvlsAndSymbols` (both as
/// affine dim expressions). The `loc` parameter is currently unused.
Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
  MLIRContext *ctx = parser.getContext();
  // Bind first: the variable's number is only read from the info afterwards.
  const auto boundVar = env.bindVar(id);
  const auto &varInfo = std::as_const(env).access(id);
  const auto varName = varInfo.getName();
  const auto varNum = *varInfo.getNum();
  switch (varInfo.getKind()) {
  case VarKind::Symbol: {
    // Symbols are visible on both sides of the map.
    const auto symExpr = getAffineSymbolExpr(varNum, ctx);
    dimsAndSymbols.emplace_back(varName, symExpr);
    lvlsAndSymbols.emplace_back(varName, symExpr);
    return boundVar;
  }
  case VarKind::Dimension:
    dimsAndSymbols.emplace_back(varName, getAffineDimExpr(varNum, ctx));
    return boundVar;
  case VarKind::Level:
    lvlsAndSymbols.emplace_back(varName, getAffineDimExpr(varNum, ctx));
    return boundVar;
  }
  llvm_unreachable("unknown VarKind");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// `DimLvlMapParser` implementation for `DimLvlMap` per se.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Top-level entry point: parses a complete `DimLvlMap`.
///
/// The pieces are parsed strictly in this order, stopping at the first
/// failure: symbol binding list, level-variable binding list,
/// dimension-specifier list, the arrow, level-specifier list. Afterwards
/// the environment is asked to diagnose any variable that is still unbound.
FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
  // `||` short-circuits, so later stages never run after an earlier failure.
  if (failed(parseSymbolBindingList()) || failed(parseLvlVarBindingList()) ||
      failed(parseDimSpecList()) || failed(parser.parseArrow()) ||
      failed(parseLvlSpecList()))
    return failure();

  InFlightDiagnostic unboundDiag = env.emitErrorIfAnyUnbound(parser);
  if (failed(unboundDiag))
    return unboundDiag;

  const auto symRank = env.getRanks().getSymRank();
  return DimLvlMap(symRank, dimSpecs, lvlSpecs);
}
|
|
|
|
/// Parses the optional, square-bracket-delimited, comma-separated list of
/// symbol bindings. Each element is handled by `parseVarBinding` with
/// `VarKind::Symbol`.
ParseResult DimLvlMapParser::parseSymbolBindingList() {
  auto parseOneSymbol = [this]() -> ParseResult {
    return ParseResult(parseVarBinding(VarKind::Symbol));
  };
  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalSquare,
                                        parseOneSymbol,
                                        " in symbol binding list");
}
|
|
|
|
/// Parses the optional, brace-delimited, comma-separated list of
/// forward-declared level variables. Each element is handled by
/// `parseVarBinding` with `VarKind::Level`.
ParseResult DimLvlMapParser::parseLvlVarBindingList() {
  auto parseOneLvlVar = [this]() -> ParseResult {
    return ParseResult(parseVarBinding(VarKind::Level));
  };
  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalBraces,
                                        parseOneLvlVar,
                                        " in level declaration list");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// `DimLvlMapParser` implementation for `DimSpec`.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the mandatory, parenthesized, comma-separated list of dimension
/// specifiers, delegating each element to `parseDimSpec`.
ParseResult DimLvlMapParser::parseDimSpecList() {
  auto parseOneDimSpec = [this]() -> ParseResult { return parseDimSpec(); };
  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
                                        parseOneDimSpec,
                                        " in dimension-specifier list");
}
|
|
|
|
/// Parses a single dimension specifier and appends it to `dimSpecs`.
///
/// A specifier consists of a mandatory dim-var binding, optionally followed
/// by `=` and an affine expression, optionally followed by `:` and an
/// attribute that must be a `SparseTensorDimSliceAttr`.
ParseResult DimLvlMapParser::parseDimSpec() {
  // The dim-var binding is required.
  const auto dimVarID = parseVarBinding(VarKind::Dimension);
  FAILURE_IF_FAILED(dimVarID)
  const DimVar dimVar = env.getVar(*dimVarID).cast<DimVar>();

  // Optional `= expr`. The expression is parsed against `lvlsAndSymbols`,
  // so only level variables and symbols are in scope. If absent, the
  // affine expression stays default-constructed.
  AffineExpr dimAffine;
  if (succeeded(parser.parseOptionalEqual())) {
    FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, dimAffine))
  }

  // Optional `: slice`. If absent, the slice attribute stays
  // default-constructed.
  SparseTensorDimSliceAttr sliceAttr;
  if (succeeded(parser.parseOptionalColon())) {
    const auto attrLoc = parser.getCurrentLocation();
    Attribute parsedAttr;
    FAILURE_IF_FAILED(parser.parseAttribute(parsedAttr))
    sliceAttr = llvm::dyn_cast<SparseTensorDimSliceAttr>(parsedAttr);
    if (!sliceAttr)
      return parser.emitError(attrLoc, "expected SparseTensorDimSliceAttr");
  }

  dimSpecs.emplace_back(dimVar, DimExpr{dimAffine}, sliceAttr);
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// `DimLvlMapParser` implementation for `LvlSpec`.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parses the mandatory, parenthesized, comma-separated list of level
/// specifiers, and checks it against any forward-declared level variables.
ParseResult DimLvlMapParser::parseLvlSpecList() {
  // Only two syntaxes are currently supported:
  //
  // (1) No forward-declarations and no lvl-var bindings, e.g.
  //       (d0, d1) -> (d0 : dense, d1 : compressed)
  //     Here `parseLvlVarBindingList` bound no lvl-vars, so each lvl-spec
  //     must obtain its variable from `VarEnv::bindUnusedVar` in order for
  //     the level-rank to come out right once parsing is done.
  //
  // (2) Forward-declarations are present and every lvl-spec carries a
  //     lvl-var binding, e.g.
  //       {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
  //     This is redundant, because the order of the forward-declared
  //     lvl-vars has to match the order of the lvl-specs. The variables
  //     were already bound by `parseLvlVarBindingList`, so the lvl-specs
  //     must not bind them a second time; instead the two orderings are
  //     validated against each other.
  //
  // A non-zero declared level-rank is what distinguishes (2) from (1).
  const auto declaredLvlRank = env.getRanks().getLvlRank();
  const bool requireLvlVarBinding = declaredLvlRank != 0;

  // Diagnostics about the list as a whole point at its start.
  const auto listLoc = parser.getCurrentLocation();
  auto parseOneLvlSpec = [this, requireLvlVarBinding]() -> ParseResult {
    return parseLvlSpec(requireLvlVarBinding);
  };
  const auto listResult = parser.parseCommaSeparatedList(
      mlir::OpAsmParser::Delimiter::Paren, parseOneLvlSpec,
      " in level-specifier list");
  FAILURE_IF_FAILED(listResult)

  // The rank check only applies to syntax (2).
  const auto specLvlRank = lvlSpecs.size();
  if (!requireLvlVarBinding || specLvlRank == declaredLvlRank)
    return success();

  return parser.emitError(
      listLoc,
      "Level-rank mismatch between forward-declarations and specifiers. "
      "Declared " +
          Twine(declaredLvlRank) + " level-variables; but got " +
          Twine(specLvlRank) + " level-specifiers.");
}
|
|
|
|
/// Renders `n` as an English ordinal for use in diagnostics: "0th", "1st",
/// "2nd", "3rd", "4th", ..., "11th", "12th", "13th", ..., "21st", "22nd",
/// and so on.
///
/// The returned `Twine` is meant to be consumed immediately, within the
/// same full-expression as the call (as the diagnostic-building call sites
/// in this file do); do not store it.
static inline Twine nth(Var::Num n) {
  const char *suffix = "th";
  // Numbers ending in 11, 12 or 13 take "th" regardless of their last
  // digit; every other number is decided by its last digit alone.
  const auto lastTwoDigits = n % 100;
  if (lastTwoDigits < 11 || lastTwoDigits > 13) {
    switch (n % 10) {
    case 1:
      suffix = "st";
      break;
    case 2:
      suffix = "nd";
      break;
    case 3:
      suffix = "rd";
      break;
    default:
      break;
    }
  }
  // Both operands are leaf twines: the number is held by value and
  // `suffix` always points at a string literal.
  return Twine(n) + suffix;
}
|
|
|
|
/// Produces the level variable for the lvl-spec currently being parsed.
///
/// When `requireLvlVarBinding` is false nothing is consumed from the input
/// and a fresh variable from `env.bindUnusedVar` is returned. Otherwise an
/// already-declared level variable followed by `=` is parsed, and its
/// forward-declared number is checked against the number of lvl-specs
/// parsed so far.
FailureOr<LvlVar>
DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
  // Without forward-declarations there is no syntax to consume; just bind
  // an unnamed variable.
  if (!requireLvlVarBinding)
    return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();

  const auto loc = parser.getCurrentLocation();
  // NOTE: Syntactically this position holds a variable *binding*, not a
  // variable *use*, so going through `parseVarUsage` looks wrong. It is
  // nevertheless the operationally correct call: the variable was already
  // bound (through `parseVarBinding`, and thus `bindVar`) while
  // `parseLvlVarBindingList` ran, so it must not be bound again here.
  const auto lvlVarID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
  FAILURE_IF_FAILED(lvlVarID)

  const auto &info = std::as_const(env).access(*lvlVarID);
  const auto var = info.getVar().cast<LvlVar>();
  // The forward-declared number must equal this spec's position in the
  // list, which is the count of lvl-specs recorded so far.
  const auto forwardNum = var.getNum();
  const auto specNum = lvlSpecs.size();
  if (forwardNum != specNum) {
    return parser.emitError(
        loc, "Level-variable ordering mismatch. The variable '" +
                 info.getName() + "' was forward-declared as the " +
                 nth(forwardNum) + " level; but is bound by the " +
                 nth(specNum) + " specification.");
  }

  FAILURE_IF_FAILED(parser.parseEqual())
  return var;
}
|
|
|
|
/// Parses a single level specifier and appends it to `lvlSpecs`.
///
/// A specifier consists of the level-variable part (handled by
/// `parseLvlVarBinding`, where `requireLvlVarBinding` decides whether a
/// named binding is mandatory or forbidden), an affine expression, a `:`,
/// and a level type.
ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
  // The "optional" lvl-var binding is really Must or MustNot, as selected
  // by `requireLvlVarBinding`.
  const auto boundLvlVar = parseLvlVarBinding(requireLvlVarBinding);
  if (failed(boundLvlVar))
    return failure();
  const LvlVar lvlVar = *boundLvlVar;

  // The level expression is parsed against `dimsAndSymbols`, so only
  // dimension variables and symbols are in scope. The `:` separating it
  // from the level type follows immediately.
  AffineExpr lvlAffine;
  if (failed(parser.parseAffineExpr(dimsAndSymbols, lvlAffine)) ||
      failed(parser.parseColon()))
    return failure();

  const auto lvlType = lvlTypeParser.parseLvlType(parser);
  if (failed(lvlType))
    return failure();

  lvlSpecs.emplace_back(lvlVar, LvlExpr{lvlAffine},
                        static_cast<LevelType>(*lvlType));
  return success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|