blob: 8902d31229b336ab1bea80a85d790dabfacc95b4 [file] [log] [blame]
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_DIC_NODE_STATE_SCORING_H
#define LATINIME_DIC_NODE_STATE_SCORING_H
#include <stdint.h>
#include "defines.h"
#include "digraph_utils.h"
namespace latinime {
class DicNodeStateScoring {
 public:
    // Starts with zero costs and counters, no double-letter information, and no digraph
    // being traversed.
    AK_FORCE_INLINE DicNodeStateScoring()
            : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
              mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
              mEditCorrectionCount(0),
              mProximityCorrectionCount(0),
              mNormalizedCompoundDistance(0.0f),
              mSpatialDistance(0.0f),
              mLanguageDistance(0.0f),
              mTotalPrevWordsLanguageCost(0.0f),
              mRawLength(0.0f) {
    }

    virtual ~DicNodeStateScoring() {}

    // Puts every field back to the value the default constructor gives it.
    void init() {
        mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
        mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
        mEditCorrectionCount = 0;
        mProximityCorrectionCount = 0;
        mNormalizedCompoundDistance = 0.0f;
        mSpatialDistance = 0.0f;
        mLanguageDistance = 0.0f;
        mTotalPrevWordsLanguageCost = 0.0f;
        mRawLength = 0.0f;
    }

    // Copies all scoring fields from |scoring|, which is dereferenced unconditionally and so
    // must not be null. The compiler-generated copy assignment performs exactly the
    // member-wise shallow copy required here.
    AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
        *this = *scoring;
    }

    // Accumulates the cost of one step.
    //   spatialCost, languageCost: added to the running spatial and language distances.
    //   doNormalization, totalInputIndex: control how the normalized compound distance is
    //       recomputed (see addDistance() below).
    //   inputSize: forwarded to addDistance(), which currently ignores it.
    //   isEditCorrection, isProximityCorrection: each increments its own counter by one.
    // A strictly positive languageCost is additionally added to the previous-words language
    // cost total.
    void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
            const int inputSize, const int totalInputIndex, const bool isEditCorrection,
            const bool isProximityCorrection) {
        addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
        if (languageCost > 0.0f) {
            setTotalPrevWordsLanguageCost(getTotalPrevWordsLanguageCost() + languageCost);
        }
        if (isEditCorrection) {
            ++mEditCorrectionCount;
        }
        if (isProximityCorrection) {
            ++mProximityCorrectionCount;
        }
    }

    // Adds |rawLength| to the accumulated raw length.
    void addRawLength(const float rawLength) {
        mRawLength += rawLength;
    }

    // Spatial distance plus language distance, with the language part unweighted.
    float getCompoundDistance() const {
        return getCompoundDistance(1.0f);
    }

    // Spatial distance plus the language distance scaled by |languageWeight|.
    float getCompoundDistance(const float languageWeight) const {
        return mSpatialDistance + mLanguageDistance * languageWeight;
    }

    // Value computed by the most recent addDistance() call, or 0 after construction/init().
    float getNormalizedCompoundDistance() const {
        return mNormalizedCompoundDistance;
    }

    float getSpatialDistance() const {
        return mSpatialDistance;
    }

    float getLanguageDistance() const {
        return mLanguageDistance;
    }

    // Number of addCost() calls made with isEditCorrection set since the last reset.
    int16_t getEditCorrectionCount() const {
        return mEditCorrectionCount;
    }

    // Number of addCost() calls made with isProximityCorrection set since the last reset.
    int16_t getProximityCorrectionCount() const {
        return mProximityCorrectionCount;
    }

    float getRawLength() const {
        return mRawLength;
    }

    DoubleLetterLevel getDoubleLetterLevel() const {
        return mDoubleLetterLevel;
    }

    // This setter only ever raises the level:
    //   A_STRONG_DOUBLE_LETTER always takes effect;
    //   A_DOUBLE_LETTER takes effect unless the level is already A_STRONG_DOUBLE_LETTER;
    //   NOT_A_DOUBLE_LETTER is ignored.
    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
        if (doubleLetterLevel == A_STRONG_DOUBLE_LETTER) {
            mDoubleLetterLevel = A_STRONG_DOUBLE_LETTER;
        } else if (doubleLetterLevel == A_DOUBLE_LETTER
                && mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) {
            mDoubleLetterLevel = A_DOUBLE_LETTER;
        }
    }

    DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
        return mDigraphIndex;
    }

    // Cycles NOT_A_DIGRAPH_INDEX -> FIRST_DIGRAPH_CODEPOINT -> SECOND_DIGRAPH_CODEPOINT and
    // then back to NOT_A_DIGRAPH_INDEX.
    void advanceDigraphIndex() {
        if (mDigraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
            mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
        } else if (mDigraphIndex == DigraphUtils::FIRST_DIGRAPH_CODEPOINT) {
            mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
        } else if (mDigraphIndex == DigraphUtils::SECOND_DIGRAPH_CODEPOINT) {
            mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
        }
    }

    // Sum of every strictly positive languageCost passed to addCost() since the last reset.
    // NOTE(review): despite the "PrevWords" name, nothing here restricts it to previous words.
    float getTotalPrevWordsLanguageCost() const {
        return mTotalPrevWordsLanguageCost;
    }

 private:
    // Caution!!!
    // This class relies on the compiler-generated copy constructor and copy assignment
    // operator: a shallow, member-wise copy is correct because every member is a plain value.
    DoubleLetterLevel mDoubleLetterLevel;
    DigraphUtils::DigraphCodePointIndex mDigraphIndex;
    int16_t mEditCorrectionCount;
    int16_t mProximityCorrectionCount;
    float mNormalizedCompoundDistance;
    float mSpatialDistance;
    float mLanguageDistance;
    float mTotalPrevWordsLanguageCost;
    float mRawLength;

    // Adds the step deltas to the running distances and recomputes the normalized compound
    // distance as (spatial + language), divided by max(1, totalInputIndex) when
    // doNormalization is set. The input-size argument is currently unused.
    AK_FORCE_INLINE void addDistance(float spatialDelta, float languageDelta,
            bool doNormalization, int /* inputSize */, int totalInputIndex) {
        mSpatialDistance += spatialDelta;
        mLanguageDistance += languageDelta;
        mNormalizedCompoundDistance = doNormalization
                ? (mSpatialDistance + mLanguageDistance)
                        / static_cast<float>(max(1, totalInputIndex))
                : mSpatialDistance + mLanguageDistance;
    }

    //TODO: remove
    AK_FORCE_INLINE void setTotalPrevWordsLanguageCost(float totalPrevWordsLanguageCost) {
        mTotalPrevWordsLanguageCost = totalPrevWordsLanguageCost;
    }
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_STATE_SCORING_H