am 5064ac88: Merge "Be careful about the dictionary size in detection methods"

* commit '5064ac885561d4b6af216d5e96ed94f17ac8e13f':
  Be careful about the dictionary size in detection methods
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
index 11fa3da..1dd68ea 100644
--- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
+++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
@@ -109,7 +109,8 @@
     }
     Dictionary *dictionary = 0;
     if (BinaryFormat::UNKNOWN_FORMAT
-            == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) {
+            == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf),
+                    static_cast<int>(dictSize))) {
         AKLOGE("DICT: dictionary format is unknown, bad magic number");
 #ifdef USE_MMAP_FOR_DICTIONARY
         releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h
index 06f50dc..9824153 100644
--- a/native/jni/src/binary_format.h
+++ b/native/jni/src/binary_format.h
@@ -64,13 +64,14 @@
     static const int UNKNOWN_FORMAT = -1;
     static const int SHORTCUT_LIST_SIZE_SIZE = 2;
 
-    static int detectFormat(const uint8_t *const dict);
-    static int getHeaderSize(const uint8_t *const dict);
-    static int getFlags(const uint8_t *const dict);
+    static int detectFormat(const uint8_t *const dict, const int dictSize);
+    static int getHeaderSize(const uint8_t *const dict, const int dictSize);
+    static int getFlags(const uint8_t *const dict, const int dictSize);
     static bool hasBlacklistedOrNotAWordFlag(const int flags);
-    static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
-            const int outValueSize);
-    static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
+    static void readHeaderValue(const uint8_t *const dict, const int dictSize,
+            const char *const key, int *outValue, const int outValueSize);
+    static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
+            const char *const key);
     static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
     static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
     static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
@@ -96,7 +97,7 @@
             const uint8_t *bigramFilter, const int unigramProbability);
     static int getBigramProbabilityFromHashMap(const int position,
             const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
-    static float getMultiWordCostMultiplier(const uint8_t *const dict);
+    static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
     static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
             hash_map_compat<int, int> *bigramMap);
     static int getBigramProbability(const uint8_t *const root, int position,
@@ -122,6 +123,8 @@
     static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
     static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
 
+    // Any file smaller than this is not a dictionary.
+    static const int DICTIONARY_MINIMUM_SIZE = 4;
     // Originally, format version 1 had a 16-bit magic number, then the version number `01'
     // then options that must be 0. Hence the first 32-bits of the format are always as follow
     // and it's okay to consider them a magic number as a whole.
@@ -131,6 +134,8 @@
     // number, so we had to change it so that version 2 files would be rejected by older
     // implementations. On this occasion, we made the magic number 32 bits long.
     static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
+    // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
+    static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;
 
     static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
     static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
@@ -141,8 +146,11 @@
     static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
 };
 
-AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
+AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
     // The magic number is stored big-endian.
+    // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
+    // understand this format.
+    if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
     const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
     switch (magicNumber) {
     case FORMAT_VERSION_1_MAGIC_NUMBER:
@@ -152,6 +160,10 @@
         // Options (2 bytes) must be 0x00 0x00
         return 1;
     case FORMAT_VERSION_2_MAGIC_NUMBER:
+        // Version 2 dictionaries are at least 12 bytes long (see the header layout below).
+        // If this dictionary has the version 2 magic number but is shorter than 12 bytes, report
+        // it as an unknown format rather than reading bytes past the end of the buffer.
+        if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
         // Format 2 header is as follows:
         // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
         // Version number (2 bytes) 0x00 0x02
@@ -163,8 +175,8 @@
     }
 }
 
-inline int BinaryFormat::getFlags(const uint8_t *const dict) {
-    switch (detectFormat(dict)) {
+inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
+    switch (detectFormat(dict, dictSize)) {
     case 1:
         return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
     default:
@@ -176,8 +188,8 @@
     return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
 }
 
-inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
-    switch (detectFormat(dict)) {
+inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
+    switch (detectFormat(dict, dictSize)) {
     case 1:
         return FORMAT_VERSION_1_HEADER_SIZE;
     case 2:
@@ -188,12 +200,12 @@
     }
 }
 
-inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key,
-        int *outValue, const int outValueSize) {
+inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
+        const char *const key, int *outValue, const int outValueSize) {
     int outValueIndex = 0;
     // Only format 2 and above have header attributes as {key,value} string pairs. For prior
     // formats, we just return an empty string, as if the key wasn't found.
-    if (2 <= detectFormat(dict)) {
+    if (2 <= detectFormat(dict, dictSize)) {
         const int headerOptionsOffset = 4 /* magic number */
                 + 2 /* dictionary version */ + 2 /* flags */;
         const int headerSize =
@@ -236,11 +248,12 @@
     if (outValueIndex >= 0) outValue[outValueIndex] = 0;
 }
 
-inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) {
+inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
+        const char *const key) {
     const int bufferSize = LARGEST_INT_DIGIT_COUNT;
     int intBuffer[bufferSize];
     char charBuffer[bufferSize];
-    BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize);
+    BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
     for (int i = 0; i < bufferSize; ++i) {
         charBuffer[i] = intBuffer[i];
     }
@@ -256,8 +269,10 @@
     return ((msb & 0x7F) << 8) | dict[(*pos)++];
 }
 
-inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
-    const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
+inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
+        const int dictSize) {
+    const int headerValue = readHeaderValueInt(dict, dictSize,
+            "MULTIPLE_WORDS_DEMOTION_RATE");
     if (headerValue == S_INT_MIN) {
         return 1.0f;
     }
diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp
index c998c06..dadb2ba 100644
--- a/native/jni/src/dictionary.cpp
+++ b/native/jni/src/dictionary.cpp
@@ -34,9 +34,11 @@
 
 Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
         : mDict(static_cast<unsigned char *>(dict)),
-          mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)),
+          mOffsetDict((static_cast<unsigned char *>(dict))
+                  + BinaryFormat::getHeaderSize(mDict, dictSize)),
           mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
-          mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))),
+          mUnigramDictionary(new UnigramDictionary(mOffsetDict,
+                  BinaryFormat::getFlags(mDict, dictSize))),
           mBigramDictionary(new BigramDictionary(mOffsetDict)),
           mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
           mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index 5116585..6408f01 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -64,7 +64,8 @@
 void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
         int prevWordLength) {
     mDictionary = dictionary;
-    mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
+    mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
+            mDictionary->getDictSize());
     if (!prevWord) {
         mPrevWordPos = NOT_VALID_WORD;
         return;