Improve bigram frequency lookup

Bug: 8592527

Change-Id: I1908bcb552279b9acb140fe4d8d26b10ed9eda72
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h
index 432a56b..06f50dc 100644
--- a/native/jni/src/binary_format.h
+++ b/native/jni/src/binary_format.h
@@ -23,6 +23,7 @@
 
 #include "bloom_filter.h"
 #include "char_utils.h"
+#include "hash_map_compat.h"
 
 namespace latinime {
 
@@ -93,7 +94,13 @@
             const int unigramProbability, const int bigramProbability);
     static int getProbability(const int position, const std::map<int, int> *bigramMap,
             const uint8_t *bigramFilter, const int unigramProbability);
+    static int getBigramProbabilityFromHashMap(const int position,
+            const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
     static float getMultiWordCostMultiplier(const uint8_t *const dict);
+    static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
+            hash_map_compat<int, int> *bigramMap);
+    static int getBigramProbability(const uint8_t *const root, int position,
+            const int nextPosition, const int unigramProbability);
 
     // Flags for special processing
     // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
@@ -105,6 +112,8 @@
 
  private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat);
+    static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);
+
     static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00;
     static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40;
     static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80;
@@ -687,5 +696,68 @@
     }
     return backoff(unigramProbability);
 }
+
+// This returns a probability in log space.
+inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position,
+        const hash_map_compat<int, int> *bigramMap, const int unigramProbability) {
+    if (!bigramMap) return backoff(unigramProbability);
+    const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
+    if (bigramProbabilityIt != bigramMap->end()) {
+        const int bigramProbability = bigramProbabilityIt->second;
+        return computeProbabilityForBigram(unigramProbability, bigramProbability);
+    }
+    return backoff(unigramProbability);
+}
+
+AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap(
+        const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) {
+    position = getBigramListPositionForWordPosition(root, position);
+    if (0 == position) return;
+
+    uint8_t bigramFlags;
+    do {
+        bigramFlags = getFlagsAndForwardPointer(root, &position);
+        const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
+        const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags,
+                &position);
+        (*bigramMap)[bigramPos] = probability;
+    } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
+}
+
+AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position,
+        const int nextPosition, const int unigramProbability) {
+    position = getBigramListPositionForWordPosition(root, position);
+    if (0 == position) return backoff(unigramProbability);
+
+    uint8_t bigramFlags;
+    do {
+        bigramFlags = getFlagsAndForwardPointer(root, &position);
+        const int bigramPos = getAttributeAddressAndForwardPointer(
+                root, bigramFlags, &position);
+        if (bigramPos == nextPosition) {
+            const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
+            return computeProbabilityForBigram(unigramProbability, bigramProbability);
+        }
+    } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
+    return backoff(unigramProbability);
+}
+
+// Returns a pointer to the start of the bigram list.
+AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition(
+        const uint8_t *const root, int position) {
+    if (NOT_VALID_WORD == position) return 0;
+    const uint8_t flags = getFlagsAndForwardPointer(root, &position);
+    if (!(flags & FLAG_HAS_BIGRAMS)) return 0;
+    if (flags & FLAG_HAS_MULTIPLE_CHARS) {
+        position = skipOtherCharacters(root, position);
+    } else {
+        getCodePointAndForwardPointer(root, &position);
+    }
+    position = skipProbability(flags, position);
+    position = skipChildrenPosition(flags, position);
+    position = skipShortcuts(root, flags, position);
+    return position;
+}
+
 } // namespace latinime
 #endif // LATINIME_BINARY_FORMAT_H
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index d3b351f..eb59744 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -379,6 +379,15 @@
 #error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE"
 #endif
 
+// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could
+// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage.
+// Also, there are diminishing returns since the most frequently used bigrams are typically near
+// the beginning of the input and are thus the first ones to be cached. Note that these bigrams
+// are reset for each new composing word.
+#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25
+// Most common previous word contexts currently have 100 bigrams
+#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100
+
 template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; }
 template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; }
 
diff --git a/native/jni/src/multi_bigram_map.h b/native/jni/src/multi_bigram_map.h
new file mode 100644
index 0000000..7e1b630
--- /dev/null
+++ b/native/jni/src/multi_bigram_map.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2013 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_MULTI_BIGRAM_MAP_H
+#define LATINIME_MULTI_BIGRAM_MAP_H
+
+#include <cstring>
+#include <stdint.h>
+
+#include "defines.h"
+#include "binary_format.h"
+#include "hash_map_compat.h"
+
+namespace latinime {
+
+// Class for caching bigram maps for multiple previous word contexts. This is useful since the
+// algorithm needs to look up the set of bigrams for every word pair that occurs in every
+// multi-word suggestion.
+class MultiBigramMap {
+ public:
+    MultiBigramMap() : mBigramMaps() {}
+    ~MultiBigramMap() {}
+
+    // Look up the bigram probability for the given word pair from the cached bigram maps.
+    // Also caches the bigrams if there is space remaining and they have not been cached already.
+    int getBigramProbability(const uint8_t *const dicRoot, const int wordPosition,
+            const int nextWordPosition, const int unigramProbability) {
+        hash_map_compat<int, BigramMap>::const_iterator mapPosition =
+                mBigramMaps.find(wordPosition);
+        if (mapPosition != mBigramMaps.end()) {
+            return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability);
+        }
+        if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) {
+            addBigramsForWordPosition(dicRoot, wordPosition);
+            return mBigramMaps[wordPosition].getBigramProbability(
+                    nextWordPosition, unigramProbability);
+        }
+        return BinaryFormat::getBigramProbability(
+                dicRoot, wordPosition, nextWordPosition, unigramProbability);
+    }
+
+    void clear() {
+        mBigramMaps.clear();
+    }
+
+ private:
+    DISALLOW_COPY_AND_ASSIGN(MultiBigramMap);
+
+    class BigramMap {
+     public:
+        BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {}
+        ~BigramMap() {}
+
+        void init(const uint8_t *const dicRoot, int position) {
+            BinaryFormat::fillBigramProbabilityToHashMap(dicRoot, position, &mBigramMap);
+        }
+
+        inline int getBigramProbability(const int nextWordPosition, const int unigramProbability)
+                const {
+           return BinaryFormat::getBigramProbabilityFromHashMap(
+                   nextWordPosition, &mBigramMap, unigramProbability);
+        }
+
+     private:
+        // Note: Default copy constructor needed for use in hash_map.
+        hash_map_compat<int, int> mBigramMap;
+    };
+
+    void addBigramsForWordPosition(const uint8_t *const dicRoot, const int position) {
+        mBigramMaps[position].init(dicRoot, position);
+    }
+
+    hash_map_compat<int, BigramMap> mBigramMaps;
+};
+} // namespace latinime
+#endif // LATINIME_MULTI_BIGRAM_MAP_H
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index a048122..5357c37 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -21,6 +21,7 @@
 #include "dic_node.h"
 #include "dic_node_utils.h"
 #include "dic_node_vector.h"
+#include "multi_bigram_map.h"
 #include "proximity_info.h"
 #include "proximity_info_state.h"
 
@@ -191,11 +192,11 @@
  * Computes the combined bigram / unigram cost for the given dicNode.
  */
 /* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot,
-        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
+        const DicNode *const node, MultiBigramMap *multiBigramMap) {
     if (node->isImpossibleBigramWord()) {
         return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
     }
-    const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap);
+    const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap);
     // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
     const float cost = static_cast<float>(MAX_PROBABILITY - probability)
             / static_cast<float>(MAX_PROBABILITY);
@@ -203,83 +204,25 @@
 }
 
 /* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot,
-        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
+        const DicNode *const node, MultiBigramMap *multiBigramMap) {
     const int unigramProbability = node->getProbability();
-    const int encodedDiffOfBigramProbability =
-            getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap);
-    if (NOT_A_PROBABILITY == encodedDiffOfBigramProbability) {
+    const int wordPos = node->getPos();
+    const int prevWordPos = node->getPrevWordPos();
+    if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
+        // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
         return backoff(unigramProbability);
     }
-    return BinaryFormat::computeProbabilityForBigram(
-            unigramProbability, encodedDiffOfBigramProbability);
+    if (multiBigramMap) {
+        return multiBigramMap->getBigramProbability(
+                dicRoot, prevWordPos, wordPos, unigramProbability);
+    }
+    return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability);
 }
 
 ///////////////////////////////////////
 // Bigram / Unigram dictionary utils //
 ///////////////////////////////////////
 
-/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
-        const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
-    const int wordPos = node->getPos();
-    const int prevWordPos = node->getPrevWordPos();
-    return getBigramProbability(dicRoot, prevWordPos, wordPos, bigramCacheMap);
-}
-
-// TODO: Move this to BigramDictionary
-/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos,
-        const int nextPos, hash_map_compat<int, int16_t> *bigramCacheMap) {
-    // TODO: this is painfully slow compared to the method used in the previous version of the
-    // algorithm. Switch to that method.
-    if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY;
-    if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY;
-
-    // Create a hash code for the given node pair (based on Josh Bloch's effective Java).
-    // TODO: Use a real hash map data structure that deals with collisions.
-    int hash = 17;
-    hash = hash * 31 + pos;
-    hash = hash * 31 + nextPos;
-
-    hash_map_compat<int, int16_t>::const_iterator mapPos = bigramCacheMap->find(hash);
-    if (mapPos != bigramCacheMap->end()) {
-        return mapPos->second;
-    }
-    if (NOT_VALID_WORD == pos) {
-        return NOT_A_PROBABILITY;
-    }
-    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
-    if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) {
-        return NOT_A_PROBABILITY;
-    }
-    if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
-        BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
-    } else {
-        pos = BinaryFormat::skipOtherCharacters(dicRoot, pos);
-    }
-    pos = BinaryFormat::skipChildrenPosition(flags, pos);
-    pos = BinaryFormat::skipProbability(flags, pos);
-    uint8_t bigramFlags;
-    int count = 0;
-    do {
-        bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
-        const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot,
-                bigramFlags, &pos);
-        if (bigramPos == nextPos) {
-            const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
-            if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
-                (*bigramCacheMap)[hash] = probability;
-            }
-            return probability;
-        }
-        count++;
-    } while ((BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags)
-            && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT);
-    if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
-        // TODO: does this -1 mean NOT_VALID_WORD?
-        (*bigramCacheMap)[hash] = -1;
-    }
-    return NOT_A_PROBABILITY;
-}
-
 /* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
         const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
     if (!pInfoState) {
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
index 2e6361d..5bc542d 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
@@ -21,7 +21,6 @@
 #include <vector>
 
 #include "defines.h"
-#include "hash_map_compat.h"
 
 namespace latinime {
 
@@ -29,6 +28,7 @@
 class DicNodeVector;
 class ProximityInfo;
 class ProximityInfoState;
+class MultiBigramMap;
 
 class DicNodeUtils {
  public:
@@ -42,7 +42,7 @@
     static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot,
             DicNodeVector *childDicNodes);
     static float getBigramNodeImprobability(const uint8_t *const dicRoot,
-            const DicNode *const node, hash_map_compat<int, int16_t> *const bigramCacheMap);
+            const DicNode *const node, MultiBigramMap *const multiBigramMap);
     static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo,
             const std::vector<int> *const codePointsFilter);
     // TODO: Move to private
@@ -57,15 +57,11 @@
 
  private:
     DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
-    // Max cache size for the space omission error correction bigram lookup
-    static const int MAX_BIGRAM_MAP_SIZE = 20000;
     // Max number of bigrams to look up
     static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500;
 
     static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node,
-            hash_map_compat<int, int16_t> *bigramCacheMap);
-    static int16_t getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
-            const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap);
+            MultiBigramMap *multiBigramMap);
     static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState,
             const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes);
     static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot,
@@ -76,8 +72,6 @@
             const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex,
             const bool exactOnly, const std::vector<int> *const codePointsFilter,
             const ProximityInfo *const pInfo, DicNodeVector *childDicNodes);
-    static int16_t getBigramProbability(const uint8_t *const dicRoot, int pos, const int nextPos,
-            hash_map_compat<int, int16_t> *bigramCacheMap);
 
     // TODO: Move to proximity info
     static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex,
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index 4912b22..d01531f 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -18,7 +18,6 @@
 
 #include "char_utils.h"
 #include "defines.h"
-#include "hash_map_compat.h"
 #include "suggest/core/dicnode/dic_node.h"
 #include "suggest/core/dicnode/dic_node_profiler.h"
 #include "suggest/core/dicnode/dic_node_utils.h"
@@ -26,6 +25,8 @@
 
 namespace latinime {
 
+class MultiBigramMap;
+
 static inline void profile(const CorrectionType correctionType, DicNode *const node) {
 #if DEBUG_DICT
     switch (correctionType) {
@@ -71,14 +72,14 @@
 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
         const DicNode *const parentDicNode, DicNode *const dicNode,
-        hash_map_compat<int, int16_t> *const bigramCacheMap) {
+        MultiBigramMap *const multiBigramMap) {
     const int inputSize = traverseSession->getInputSize();
     DicNode_InputStateG inputStateG;
     inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
     const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
             traverseSession, parentDicNode, dicNode, &inputStateG);
     const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
-            traverseSession, parentDicNode, dicNode, bigramCacheMap);
+            traverseSession, parentDicNode, dicNode, multiBigramMap);
     const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
             parentDicNode, dicNode);
     profile(correctionType, dicNode);
@@ -127,14 +128,14 @@
 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
         const DicNode *const parentDicNode, const DicNode *const dicNode,
-        hash_map_compat<int, int16_t> *const bigramCacheMap) {
+        MultiBigramMap *const multiBigramMap) {
     switch(correctionType) {
     case CT_OMISSION:
         return 0.0f;
     case CT_SUBSTITUTION:
         return 0.0f;
     case CT_NEW_WORD_SPACE_OMITTION:
-        return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
+        return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap);
     case CT_MATCH:
         return 0.0f;
     case CT_COMPLETION:
@@ -142,11 +143,11 @@
     case CT_TERMINAL: {
         const float languageImprobability =
                 DicNodeUtils::getBigramNodeImprobability(
-                        traverseSession->getOffsetDict(), dicNode, bigramCacheMap);
+                        traverseSession->getOffsetDict(), dicNode, multiBigramMap);
         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
     }
     case CT_NEW_WORD_SPACE_SUBSTITUTION:
-        return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
+        return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap);
     case CT_INSERTION:
         return 0.0f;
     case CT_TRANSPOSITION:
diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h
index 6e740d9..0d2745b 100644
--- a/native/jni/src/suggest/core/policy/weighting.h
+++ b/native/jni/src/suggest/core/policy/weighting.h
@@ -18,13 +18,13 @@
 #define LATINIME_WEIGHTING_H
 
 #include "defines.h"
-#include "hash_map_compat.h"
 
 namespace latinime {
 
 class DicNode;
 class DicTraverseSession;
 struct DicNode_InputStateG;
+class MultiBigramMap;
 
 class Weighting {
  public:
@@ -32,7 +32,7 @@
             const CorrectionType correctionType,
             const DicTraverseSession *const traverseSession,
             const DicNode *const parentDicNode, DicNode *const dicNode,
-            hash_map_compat<int, int16_t> *const bigramCacheMap);
+            MultiBigramMap *const multiBigramMap);
 
  protected:
     virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
@@ -61,7 +61,7 @@
 
     virtual float getNewWordBigramCost(
             const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
-            hash_map_compat<int, int16_t> *const bigramCacheMap) const = 0;
+            MultiBigramMap *const multiBigramMap) const = 0;
 
     virtual float getCompletionCost(
             const DicTraverseSession *const traverseSession,
@@ -97,7 +97,7 @@
     static float getLanguageCost(const Weighting *const weighting,
             const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
             const DicNode *const parentDicNode, const DicNode *const dicNode,
-            hash_map_compat<int, int16_t> *const bigramCacheMap);
+            MultiBigramMap *const multiBigramMap);
     // TODO: Move to TypingWeighting and GestureWeighting?
     static int getForwardInputCount(const CorrectionType correctionType);
 };
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index b3d4732..5116585 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -100,7 +100,7 @@
 
 void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) {
     mDicNodesCache.reset(nextActiveCacheSize, maxWords);
-    mBigramCacheMap.clear();
+    mMultiBigramMap.clear();
     mPartiallyCommited = false;
 }
 
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h
index d9c2a51..d88be5b 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.h
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.h
@@ -21,8 +21,8 @@
 #include <vector>
 
 #include "defines.h"
-#include "hash_map_compat.h"
 #include "jni.h"
+#include "multi_bigram_map.h"
 #include "proximity_info_state.h"
 #include "suggest/core/dicnode/dic_nodes_cache.h"
 
@@ -35,7 +35,7 @@
  public:
     AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr)
             : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0),
-              mDictionary(0), mDicNodesCache(), mBigramCacheMap(),
+              mDictionary(0), mDicNodesCache(), mMultiBigramMap(),
               mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1),
               mMultiWordCostMultiplier(1.0f) {
         // NOTE: mProximityInfoStates is an array of instances.
@@ -67,7 +67,7 @@
     // TODO: Use proper parameter when changed
     int getDicRootPos() const { return 0; }
     DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
-    hash_map_compat<int, int16_t> *getBigramCacheMap() { return &mBigramCacheMap; }
+    MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
     const ProximityInfoState *getProximityInfoState(int id) const {
         return &mProximityInfoStates[id];
     }
@@ -170,7 +170,7 @@
 
     DicNodesCache mDicNodesCache;
     // Temporary cache for bigram frequencies
-    hash_map_compat<int, int16_t> mBigramCacheMap;
+    MultiBigramMap mMultiBigramMap;
     ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G];
 
     int mInputSize;
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 4f94a9a..3221dee 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -359,7 +359,7 @@
     DicNode terminalDicNode;
     DicNodeUtils::initByCopy(dicNode, &terminalDicNode);
     Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
-            &terminalDicNode, traverseSession->getBigramCacheMap());
+            &terminalDicNode, traverseSession->getMultiBigramMap());
     traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
 }
 
@@ -391,8 +391,10 @@
 
 void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession,
         DicNode *dicNode, DicNode *childDicNode) const {
+    // Note: Most types of corrections don't need to look up the bigram information since they do
+    // not treat the node as a terminal. There is no need to pass the bigram map in these cases.
     Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY,
-            traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */);
+            traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */);
     weightChildNode(traverseSession, childDicNode);
     processExpandedDicNode(traverseSession, childDicNode);
 }
@@ -400,7 +402,7 @@
 void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
         DicNode *dicNode, DicNode *childDicNode) const {
     Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession,
-            dicNode, childDicNode, 0 /* bigramCacheMap */);
+            dicNode, childDicNode, 0 /* multiBigramMap */);
     weightChildNode(traverseSession, childDicNode);
     processExpandedDicNode(traverseSession, childDicNode);
 }
@@ -432,7 +434,7 @@
         DicNode *const childDicNode = childDicNodes[i];
         // Treat this word as omission
         Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
-                dicNode, childDicNode, 0 /* bigramCacheMap */);
+                dicNode, childDicNode, 0 /* multiBigramMap */);
         weightChildNode(traverseSession, childDicNode);
 
         if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
@@ -456,7 +458,7 @@
     for (int i = 0; i < size; i++) {
         DicNode *const childDicNode = childDicNodes[i];
         Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession,
-                dicNode, childDicNode, 0 /* bigramCacheMap */);
+                dicNode, childDicNode, 0 /* multiBigramMap */);
         processExpandedDicNode(traverseSession, childDicNode);
     }
 }
@@ -481,7 +483,7 @@
             for (int j = 0; j < childSize2; j++) {
                 DicNode *const childDicNode2 = childDicNodes2[j];
                 Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION,
-                        traverseSession, childDicNodes1[i], childDicNode2, 0 /* bigramCacheMap */);
+                        traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */);
                 processExpandedDicNode(traverseSession, childDicNode2);
             }
         }
@@ -496,10 +498,10 @@
     const int inputSize = traverseSession->getInputSize();
     if (dicNode->isCompletion(inputSize)) {
         Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession,
-                0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */);
+                0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
     } else { // completion
         Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession,
-                0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */);
+                0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
     }
 }
 
@@ -520,7 +522,7 @@
     const CorrectionType correctionType = spaceSubstitution ?
             CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION;
     Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode,
-            &newDicNode, traverseSession->getBigramCacheMap());
+            &newDicNode, traverseSession->getMultiBigramMap());
     traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode);
 }
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 9efcc17..e6fa1bd 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -28,6 +28,7 @@
 
 class DicNode;
 struct DicNode_InputStateG;
+class MultiBigramMap;
 
 class TypingWeighting : public Weighting {
  public:
@@ -136,9 +137,9 @@
 
     float getNewWordBigramCost(const DicTraverseSession *const traverseSession,
             const DicNode *const dicNode,
-            hash_map_compat<int, int16_t> *const bigramCacheMap) const {
+            MultiBigramMap *const multiBigramMap) const {
         return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(),
-                dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
+                dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
     }
 
     float getCompletionCost(const DicTraverseSession *const traverseSession,