//===- UniformConstraints.cpp - Constraints for uniform quant -------------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "mlir/Quantizer/Support/UniformConstraints.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Quantizer/Support/Configuration.h" #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" #include "mlir/Quantizer/Support/Metadata.h" #include "mlir/Quantizer/Support/Rules.h" #include "mlir/Quantizer/Support/TypeUtils.h" #include "mlir/Quantizer/Support/UniformSolvers.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::quantizer; using namespace mlir::quant; namespace { struct ClusteredFacts { ExpandingMinMaxFact requiredRange; DiscreteScaleZeroPointFact explicitScaleZeroPoint; }; } // end anonymous namespace static QuantizedType solveUniformType(SolverContext &solverContext, const ClusteredFacts &clusteredFacts, const CandidateQuantizedType &ct, Type originalElementType, Location loc) { switch (ct.scheme) { default: solverContext.getMlirContext().emitError( loc, "unsupported scheme for uniform type conversion"); return nullptr; case CandidateQuantizedType::Scheme::UniformPerLayer: { if (!clusteredFacts.requiredRange.hasValue()) { // TODO: Issue some kind of diagnostic. This is not an error. return nullptr; } uint64_t numLevels = ct.quantizedType.getStorageTypeMax() - ct.quantizedType.getStorageTypeMin(); UniformStorageParams params{numLevels, ct.quantizedType.getStorageTypeMin()}; UniformParamsFromMinMaxSolver solver( params, clusteredFacts.requiredRange.getValue().first, clusteredFacts.requiredRange.getValue().second); if (!solver.compute()) { solverContext.getMlirContext().emitWarning(loc) << "unable to solve uniform type with " << "UniformParamsFromMinMaxSolver"; return nullptr; } return UniformQuantizedType::getChecked( ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(), originalElementType, solver.getScale(), solver.getZp(), ct.quantizedType.getStorageTypeMin(), ct.quantizedType.getStorageTypeMax(), loc); } case CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale: { if (!clusteredFacts.explicitScaleZeroPoint.hasValue()) { solverContext.getMlirContext().emitRemark(loc) << "unable to solve uniform type with UniformExplicitFixedPointScale " << "(no explicitScaleZeroPoint)"; return nullptr; } const auto &scaleZp = clusteredFacts.explicitScaleZeroPoint.getValue(); assert(scaleZp.value && "optional value not set on fact"); if (scaleZp.conflict) { solverContext.getMlirContext().emitWarning(loc) << "conflicting explicit scale/zeroPoint on node cluster: " << "an arbitrary scale/zeroPoint will be used"; } return UniformQuantizedType::getChecked( ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(), originalElementType, scaleZp.value->first, // scale 0, // zeroPoint (fixed point solutions only for this scheme) ct.quantizedType.getStorageTypeMin(), ct.quantizedType.getStorageTypeMax(), loc); return nullptr; } } } namespace { class PropagateExplicitScale : public CAGConstraintNode { public: PropagateExplicitScale() : CAGConstraintNode(Kind::UniformPropagateExplicitScale) {} static bool classof(const CAGNode *n) { return n->getKind() == Kind::Constraint || n->getKind() == Kind::UniformPropagateExplicitScale; } private: void printLabel(llvm::raw_ostream &os) const override { os << "PropagateExplicitScale"; } void propagate(SolverContext &solverContext, const TargetConfiguration &config) override { DiscreteScaleZeroPointFact scaleZp; // Get scale/zp from all parents. for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { auto parentAnchor = llvm::cast(*it); auto selectedType = parentAnchor->getUniformMetadata().selectedType; if (auto uqType = selectedType.dyn_cast_or_null()) { scaleZp.assertValue( CAGUniformMetadata::SalienceRequired, std::make_pair(uqType.getScale(), static_cast(0))); } } // Propagate to children. if (scaleZp.hasValue()) { for (auto it = begin(), e = end(); it != e; ++it) { auto childAnchor = llvm::cast(*it); if (modified(childAnchor->getUniformMetadata() .explicitScaleZeroPoint.mergeFrom(scaleZp))) { childAnchor->markDirty(); } } } } }; /// A constraint node which will solve uniform quantization for all parents /// of the constraint, assuming that they are coupled. class SolveUniformConstraintNode : public CAGConstraintNode { public: SolveUniformConstraintNode() : CAGConstraintNode(Kind::SolveUniformConstraint) { markDirty(); } static bool classof(const CAGNode *n) { return n->getKind() == Kind::Constraint || n->getKind() == Kind::SolveUniformConstraint; } private: void printLabel(llvm::raw_ostream &os) const override { os << "SolveUniform"; } void propagate(SolverContext &solverContext, const TargetConfiguration &config) override { // First determine the required min/max range and type constraints. Location fusedLoc = UnknownLoc::get(&solverContext.getMlirContext()); llvm::SmallBitVector enabledCandidateTypesMask( config.getAllCandidateTypesMask()); ClusteredFacts clusteredFacts; Type originalElementType; for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { auto parentAnchor = llvm::cast(*it); auto metadata = parentAnchor->getUniformMetadata(); // TODO: Possibly use a location that fuses all involved parents. fusedLoc = parentAnchor->getOp()->getLoc(); // Shared element type. auto parentOriginalElementType = getElementOrPrimitiveType(parentAnchor->getOriginalType()); if (!originalElementType) { originalElementType = parentOriginalElementType; } else { if (originalElementType != parentOriginalElementType) { parentAnchor->getOp()->emitError() << "cannot compute uniform type: parent element types mismatch"; return; } } // Range. clusteredFacts.requiredRange.mergeFrom(metadata.requiredRange); // Explicit scale and zero point. clusteredFacts.explicitScaleZeroPoint.mergeFrom( metadata.explicitScaleZeroPoint); // Shared candidate types. enabledCandidateTypesMask.reset(metadata.disabledCandidateTypes); } // Find the first enabled candidate type. const CandidateQuantizedType *bestCandidateType = nullptr; for (auto &ct : config.getCandidateTypes()) { if (enabledCandidateTypesMask.test(ct.ordinal)) { bestCandidateType = &ct; break; } } if (!bestCandidateType || !originalElementType) { solverContext.getMlirContext().emitRemark(fusedLoc) << "not solving uniform type (no viable candidate type)"; return; } // Solve for the type. QuantizedType selectedType = solveUniformType(solverContext, clusteredFacts, *bestCandidateType, originalElementType, fusedLoc); // Apply it to all parents. for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { auto parentAnchor = llvm::cast(*it); auto &metadata = parentAnchor->getUniformMetadata(); if (metadata.selectedType != selectedType) { metadata.selectedType = selectedType; // And mark all children of the parent dirty (except us). for (auto child : *parentAnchor) { if (child != this) { child->markDirty(); } } } } } }; } // end anonymous namespace void UniformConstraintsBuilder::coupleAnchors(CAGAnchorNode *a, CAGAnchorNode *b) { slice.addClusteredConstraint({a, b}); } void UniformConstraintsBuilder::applyStats(CAGAnchorNode *a, TensorAxisStatistics stats) { a->getUniformMetadata().requiredRange.assertValue( CAGUniformMetadata::SalienceDefault, {stats.minValue, stats.maxValue}); } void UniformConstraintsBuilder::clamp(CAGAnchorNode *a, APFloat minValue, APFloat maxValue) { a->getUniformMetadata().requiredRange.assertValue( CAGUniformMetadata::SalienceDefault, {minValue.convertToDouble(), maxValue.convertToDouble()}); } void UniformConstraintsBuilder::propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to) { slice.addUnidirectionalConstraint(from, {to}); }