// blob: 3938c0ec53df1390d8af19f304b2c8132f8e57fe
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_TYPING_WEIGHTING_H
#define LATINIME_TYPING_WEIGHTING_H
#include "defines.h"
#include "suggest_utils.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/policy/weighting.h"
#include "suggest/core/session/dic_traverse_session.h"
#include "suggest/policyimpl/typing/scoring_params.h"
namespace latinime {
// Forward declarations of types named in the method signatures below.
class DicNode;
struct DicNode_InputStateG;
class MultiBigramMap;
class TypingWeighting : public Weighting {
public:
static const TypingWeighting *getInstance() { return &sInstance; }
protected:
float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
float cost = 0.0f;
if (dicNode->hasMultipleWords()) {
cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
}
if (dicNode->getProximityCorrectionCount() > 0) {
cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
}
if (dicNode->getEditCorrectionCount() > 0) {
cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
}
return cost;
}
float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
// If the traversal omitted the first letter then the dicNode should now be on the second.
const bool isFirstLetterOmission = dicNode->getDepth() == 2;
float cost = 0.0f;
if (isZeroCostOmission) {
cost = 0.0f;
} else if (isFirstLetterOmission) {
cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
} else {
cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
: ScoringParams::OMISSION_COST;
}
return cost;
}
float getMatchedCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
const int pointIndex = dicNode->getInputIndex(0);
// Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on
// the keyboard (like accented letters)
const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint());
const float normalizedDistance = SuggestUtils::getSweetSpotFactor(
traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength);
const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
const bool isFirstChar = pointIndex == 0;
const bool isProximity = isProximityDicNode(traverseSession, dicNode);
float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST
: ScoringParams::PROXIMITY_COST) : 0.0f;
if (dicNode->getDepth() == 2) {
// At the second character of the current word, we check if the first char is uppercase
// and the word is a second or later word of a multiple word suggestion. We demote it
// if so.
const bool isSecondOrLaterWordFirstCharUppercase =
dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
if (isSecondOrLaterWordFirstCharUppercase) {
cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
}
}
return weightedDistance + cost;
}
bool isProximityDicNode(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
const int pointIndex = dicNode->getInputIndex(0);
const int primaryCodePoint = toBaseLowerCase(
traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
const int dicNodeChar = toBaseLowerCase(dicNode->getNodeCodePoint());
return primaryCodePoint != dicNodeChar;
}
float getTranspositionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint = parentDicNode->getNodeCodePoint();
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
parentPointIndex + 1, prevCodePoint);
const int codePoint = dicNode->getNodeCodePoint();
const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
parentPointIndex, codePoint);
const float distance = distance1 + distance2;
const float weightedLengthDistance =
distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
}
float getInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint =
traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex);
const int currentCodePoint = dicNode->getNodeCodePoint();
const bool sameCodePoint = prevCodePoint == currentCodePoint;
const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
parentPointIndex + 1, currentCodePoint);
const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
const bool singleChar = dicNode->getDepth() == 1;
const float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f)
+ (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR
: ScoringParams::INSERTION_COST);
return cost + weightedDistance;
}
float getNewWordCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier();
}
float getNewWordBigramCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode,
MultiBigramMap *const multiBigramMap) const {
return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(),
dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
float getCompletionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
// The auto completion starts when the input index is same as the input size
const bool firstCompletion = dicNode->getInputIndex(0)
== traverseSession->getInputSize();
// TODO: Change the cost for the first completion for the gesture?
const float cost = firstCompletion ? ScoringParams::COST_FIRST_LOOKAHEAD
: ScoringParams::COST_LOOKAHEAD;
return cost;
}
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
const float languageImprobability = (dicNode->isExactMatch()) ?
0.0f : dicNodeLanguageImprobability;
return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
return false;
}
AK_FORCE_INLINE float getAdditionalProximityCost() const {
return ScoringParams::ADDITIONAL_PROXIMITY_COST;
}
AK_FORCE_INLINE float getSubstitutionCost() const {
return ScoringParams::SUBSTITUTION_COST;
}
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
return cost * traverseSession->getMultiWordCostMultiplier();
}
ErrorType getErrorType(const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const;
private:
DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
static const TypingWeighting sInstance;
TypingWeighting() {}
~TypingWeighting() {}
};
} // namespace latinime
#endif // LATINIME_TYPING_WEIGHTING_H