/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstring>
#include <vector>

#include "binary_format.h"
#include "dic_node.h"
#include "dic_node_utils.h"
#include "dic_node_vector.h"
#include "proximity_info.h"
#include "proximity_info_state.h"

namespace latinime {

///////////////////////////////
// Node initialization utils //
///////////////////////////////

/* static */ void DicNodeUtils::initAsRoot(const int rootPos, const uint8_t *const dicRoot,
        const int prevWordNodePos, DicNode *newRootNode) {
    // Reading the group count at rootPos advances this cursor to the first child group.
    int childrenPos = rootPos;
    const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos);
    newRootNode->initAsRoot(rootPos, childrenPos, childrenCount, prevWordNodePos);
}

/* static */ void DicNodeUtils::initAsRootWithPreviousWord(const int rootPos,
        const uint8_t *const dicRoot, DicNode *prevWordLastNode, DicNode *newRootNode) {
    // Same header parsing as initAsRoot(): the cursor ends up on the first child group.
    int childrenPos = rootPos;
    const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos);
    newRootNode->initAsRootWithPreviousWord(prevWordLastNode, rootPos, childrenPos,
            childrenCount);
}

/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) {
    // Thin wrapper: all of the work is done by DicNode::initByCopy on the destination.
    DicNode *const target = destNode;
    target->initByCopy(srcNode);
}

///////////////////////////////////
// Traverse node expansion utils //
///////////////////////////////////

/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode,
        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
        DicNodeVector *childDicNodes) {
    // dicNode is passing through a multiple-chars group, so there is no child group to read;
    // the only candidate is dicNode itself, pushed as a passing child.
    const int typedCodePoint = dicNode->getNodeTypedCodePoint();
    const bool rejected = !isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, typedCodePoint)
            && !isIntentionalOmissionCodePoint(toBaseLowerCase(typedCodePoint));
    if (rejected) {
        return;
    }
    childDicNodes->pushPassingChild(dicNode);
}

/**
 * Parses the character group starting at |pos| in |dicRoot|. If the group's first code point
 * passes isDicNodeFilteredOut() and isMatchedNodeCodePoint(), pushes a leaving child of
 * |dicNode| for it onto |childDicNodes|.
 *
 * @param terminalDepth not referenced by this function.
 * @return the position of the group's next sibling, whether or not a child was pushed.
 */
/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos,
        const uint8_t *const dicRoot, const int terminalDepth, const ProximityInfoState *pInfoState,
        const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter,
        const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) {
    // Start of this group; handed to pushLeavingChild as the child's group position.
    const int groupPos = pos;
    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
    const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags));
    const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags));
    const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);

    const int nodeCodePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
    ASSERT(NOT_A_CODE_POINT != nodeCodePoint);
    // TODO: optimize this
    int additionalWordBuf[MAX_WORD_LENGTH];
    uint16_t additionalSubwordLength = 0;
    additionalWordBuf[additionalSubwordLength++] = nodeCodePoint;

    if (hasMultipleChars) {
        // Read every remaining code point up to the NOT_A_CODE_POINT terminator so that |pos|
        // ends up past the group's characters, but store only as many as fit: an over-long
        // group in a malformed dictionary must not overflow this stack buffer.
        for (int codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
                NOT_A_CODE_POINT != codePoint;
                codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos)) {
            if (additionalSubwordLength < MAX_WORD_LENGTH) {
                additionalWordBuf[additionalSubwordLength++] = codePoint;
            }
        }
    }

    const int probability =
            isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(dicRoot, pos) : -1;
    pos = BinaryFormat::skipProbability(flags, pos);
    int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(dicRoot, flags, pos) : 0;
    const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos);
    const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(dicRoot, flags, pos);

    if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) {
        return siblingPos;
    }
    if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) {
        return siblingPos;
    }
    const int childrenCount = hasChildren
            ? BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos) : 0;
    childDicNodes->pushLeavingChild(dicNode, groupPos, flags, childrenPos, attributesPos,
            siblingPos, nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */,
            isTerminal, hasMultipleChars, hasChildren, additionalSubwordLength,
            additionalWordBuf);
    return siblingPos;
}

/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint,
        const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) {
    const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
    if (filterSize <= 0) {
        // No filter at all: nothing is filtered out.
        return false;
    }
    if (!pInfo) {
        // Without proximity info the raw code point is compared, with no normalization.
        for (int idx = 0; idx < filterSize; ++idx) {
            if ((*codePointsFilter)[idx] == nodeCodePoint) {
                return false;
            }
        }
        return true;
    }
    // Code points that are not on the keyboard, or that are skippable, always pass.
    if (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX
            || isIntentionalOmissionCodePoint(nodeCodePoint)) {
        return false;
    }
    // With proximity info, the lower-cased code point or its base form must be in the filter.
    const int lowerCodePoint = toLowerCase(nodeCodePoint);
    const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint);
    // TODO: Avoid linear search
    for (int idx = 0; idx < filterSize; ++idx) {
        const int filterCodePoint = (*codePointsFilter)[idx];
        if (filterCodePoint == lowerCodePoint || filterCodePoint == baseLowerCodePoint) {
            return false;
        }
    }
    return true;
}

/**
 * Walks the child groups of |dicNode| (starting at getChildrenPos(), getChildrenCount() of
 * them) and calls createAndGetLeavingChildNode() on each. That function pushes matching
 * children onto |childDicNodes|.
 */
/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode,
        const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex,
        const bool exactOnly, const std::vector<int> *const codePointsFilter,
        const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) {
    const int terminalDepth = dicNode->getLeavingDepth();
    const int childCount = dicNode->getChildrenCount();
    // The filter does not change during the walk, so its size and the early-exit precondition
    // are computed once instead of on every iteration.
    const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
    const bool mayStopEarly = !pInfo && filterSize > 0;
    int nextPos = dicNode->getChildrenPos();
    for (int i = 0; i < childCount; ++i) {
        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, dicRoot, terminalDepth, pInfoState,
                pointIndex, exactOnly, codePointsFilter, pInfo, childDicNodes);
        if (mayStopEarly && childDicNodes->exceeds(filterSize)) {
            // All code points have been found.
            break;
        }
    }
}

/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot,
        DicNodeVector *childDicNodes) {
    // With a null input state, isMatchedNodeCodePoint() accepts every code point, so no
    // proximity matching is applied.
    const ProximityInfoState *const noInputState = 0;
    getProximityChildDicNodes(dicNode, dicRoot, noInputState, 0 /* pointIndex */,
            false /* exactOnly */, childDicNodes);
}

/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode,
        const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex,
        bool exactOnly, DicNodeVector *childDicNodes) {
    if (dicNode->isTotalInputSizeExceedingLimit()) {
        return;
    }
    if (dicNode->isLeavingNode()) {
        // Leaving node: expand every child group, without a code point filter or keyboard info.
        DicNodeUtils::createAndGetAllLeavingChildNodes(dicNode, dicRoot, pInfoState, pointIndex,
                exactOnly, 0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes);
        return;
    }
    // Still inside a multiple-chars group: at most the node itself is pushed.
    DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly,
            childDicNodes);
}

///////////////////
// Scoring utils //
///////////////////
/**
 * Computes the combined bigram / unigram cost for the given dicNode.
 *
 * Impossible bigram words cost MAX_VALUE_FOR_WEIGHTING. Otherwise the probability from
 * getBigramNodeProbability() is mapped to (MAX_PROBABILITY - probability) / MAX_PROBABILITY.
 */
/* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot,
        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
    if (!node->isImpossibleBigramWord()) {
        const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap);
        const float maxProbability = static_cast<float>(MAX_PROBABILITY);
        // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
        return static_cast<float>(MAX_PROBABILITY - probability) / maxProbability;
    }
    return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
}

/* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot,
        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
    // Combine the unigram probability with the bigram's encoded difference when a bigram
    // exists; otherwise fall back to backoff() of the unigram probability.
    const int unigram = node->getProbability();
    const int encodedDiff = getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap);
    return (NOT_A_PROBABILITY != encodedDiff)
            ? BinaryFormat::computeProbabilityForBigram(unigram, encodedDiff)
            : backoff(unigram);
}

///////////////////////////////////////
// Bigram / Unigram dictionary utils //
///////////////////////////////////////

/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
    // Bigram lookup from the previous word's group to this node's group.
    const int previousWordPos = node->getPrevWordPos();
    const int currentWordPos = node->getPos();
    return getBigramProbability(dicRoot, previousWordPos, currentWordPos, bigramCacheMap);
}

// TODO: Move this to BigramDictionary
/**
 * Returns the encoded bigram probability for the word at |nextPos| following the word at |pos|.
 * The value is the MASK_ATTRIBUTE_PROBABILITY bits of the matching bigram attribute's flags.
 * Returns NOT_A_PROBABILITY when either position is NOT_VALID_WORD, when the group at |pos|
 * has no bigrams, or when no match is found within MAX_BIGRAMS_CONSIDERED_PER_CONTEXT entries.
 *
 * Hits and scanned misses are memoized in |bigramCacheMap| while it holds fewer than
 * MAX_BIGRAM_MAP_SIZE entries. The key is a hash of (pos, nextPos), so collisions are
 * possible (see TODO below).
 */
/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos,
        const int nextPos, hash_map_compat<int, int16_t> *bigramCacheMap) {
    // TODO: this is painfully slow compared to the method used in the previous version of the
    // algorithm. Switch to that method.
    if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY;
    if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY;

    // Create a hash code for the given node pair (based on Josh Bloch's effective Java).
    // TODO: Use a real hash map data structure that deals with collisions.
    int hash = 17;
    hash = hash * 31 + pos;
    hash = hash * 31 + nextPos;

    hash_map_compat<int, int16_t>::const_iterator mapPos = bigramCacheMap->find(hash);
    if (mapPos != bigramCacheMap->end()) {
        return mapPos->second;
    }
    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
    if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) {
        return NOT_A_PROBABILITY;
    }
    // Skip the group's characters, children position and probability to reach the attributes.
    if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
        BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
    } else {
        pos = BinaryFormat::skipOtherCharacters(dicRoot, pos);
    }
    pos = BinaryFormat::skipChildrenPosition(flags, pos);
    pos = BinaryFormat::skipProbability(flags, pos);
    uint8_t bigramFlags;
    int count = 0;
    do {
        bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
        const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot,
                bigramFlags, &pos);
        if (bigramPos == nextPos) {
            const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
            if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
                (*bigramCacheMap)[hash] = probability;
            }
            return probability;
        }
        count++;
    } while ((0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags))
            && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT);
    if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
        // Cache the miss with the very sentinel returned below, so a later cache hit is
        // indistinguishable from a fresh lookup for callers comparing against NOT_A_PROBABILITY.
        (*bigramCacheMap)[hash] = NOT_A_PROBABILITY;
    }
    return NOT_A_PROBABILITY;
}

/* static */ int DicNodeUtils::getWordPos(const uint8_t *const dicRoot, const int *word,
        const int wordLength) {
    // A null word has no position; otherwise do a case-sensitive terminal lookup.
    return word
            ? BinaryFormat::getTerminalPosition(
                    dicRoot, word, wordLength, false /* forceLowerCaseSearch */)
            : NOT_VALID_WORD;
}

/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
        const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
    if (!pInfoState) {
        return true;
    }
    if (exactOnly) {
        return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint;
    }
    const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint,
            true /* checkProximityChars */);
    return isProximityChar(matchedId);
}

////////////////
// Char utils //
////////////////

// TODO: Move to char_utils?
/**
 * Copies |src0| followed by |src1| into |dest| and returns the number of code points written.
 *
 * Each source is read up to its given length or its first 0, whichever comes first. No
 * terminator is written.
 * - The first word is capped at MAX_WORD_LENGTH code points.
 * - The second word is appended only while the total stays at most MAX_WORD_LENGTH - 1.
 * - A null |src0| or |src1| is treated as an empty word.
 *
 * NOTE(review): |dest| is assumed to hold at least MAX_WORD_LENGTH ints — confirm with callers.
 */
/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0,
        const int *const src1, const int16_t length1, int *dest) {
    int actualLength0 = 0;
    if (src0) {
        while (actualLength0 < length0 && src0[actualLength0] != 0) {
            ++actualLength0;
        }
    }
    actualLength0 = min(actualLength0, MAX_WORD_LENGTH);
    if (actualLength0 > 0) {
        memcpy(dest, src0, actualLength0 * sizeof(dest[0]));
    }
    if (!src1 || length1 == 0) {
        return actualLength0;
    }
    // Room left for the second word. This is negative when the first word already fills the
    // buffer; bail out rather than hand memcpy a negative (wrapped-around) size.
    const int capacity1 = MAX_WORD_LENGTH - actualLength0 - 1;
    if (capacity1 <= 0) {
        return actualLength0;
    }
    int actualLength1 = 0;
    while (actualLength1 < length1 && actualLength1 < capacity1 && src1[actualLength1] != 0) {
        ++actualLength1;
    }
    if (actualLength1 > 0) {
        memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0]));
    }
    return actualLength0 + actualLength1;
}
} // namespace latinime
