| #include <algorithm> |
| #include <functional> |
| #include <queue> |
| #include <stdexcept> |
| |
| #include "range.h" |
| #include "trie.h" |
| |
| namespace marisa { |
| |
| void Trie::build(const char * const *keys, std::size_t num_keys, |
| const std::size_t *key_lengths, const double *key_weights, |
| UInt32 *key_ids, int flags) { |
| MARISA_THROW_IF((keys == NULL) && (num_keys != 0), MARISA_PARAM_ERROR); |
| Vector<Key<String> > temp_keys; |
| temp_keys.resize(num_keys); |
| for (std::size_t i = 0; i < temp_keys.size(); ++i) { |
| MARISA_THROW_IF(keys[i] == NULL, MARISA_PARAM_ERROR); |
| std::size_t length = 0; |
| if (key_lengths == NULL) { |
| while (keys[i][length] != '\0') { |
| ++length; |
| } |
| } else { |
| length = key_lengths[i]; |
| } |
| MARISA_THROW_IF(length > MARISA_MAX_LENGTH, MARISA_SIZE_ERROR); |
| temp_keys[i].set_str(String(keys[i], length)); |
| temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0); |
| } |
| build_trie(temp_keys, key_ids, flags); |
| } |
| |
| void Trie::build(const std::vector<std::string> &keys, |
| std::vector<UInt32> *key_ids, int flags) { |
| Vector<Key<String> > temp_keys; |
| temp_keys.resize(keys.size()); |
| for (std::size_t i = 0; i < temp_keys.size(); ++i) { |
| MARISA_THROW_IF(keys[i].length() > MARISA_MAX_LENGTH, MARISA_SIZE_ERROR); |
| temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length())); |
| temp_keys[i].set_weight(1.0); |
| } |
| build_trie(temp_keys, key_ids, flags); |
| } |
| |
| void Trie::build(const std::vector<std::pair<std::string, double> > &keys, |
| std::vector<UInt32> *key_ids, int flags) { |
| Vector<Key<String> > temp_keys; |
| temp_keys.resize(keys.size()); |
| for (std::size_t i = 0; i < temp_keys.size(); ++i) { |
| MARISA_THROW_IF(keys[i].first.length() > MARISA_MAX_LENGTH, |
| MARISA_SIZE_ERROR); |
| temp_keys[i].set_str(String( |
| keys[i].first.c_str(), keys[i].first.length())); |
| temp_keys[i].set_weight(keys[i].second); |
| } |
| build_trie(temp_keys, key_ids, flags); |
| } |
| |
| void Trie::build_trie(Vector<Key<String> > &keys, |
| std::vector<UInt32> *key_ids, int flags) { |
| if (key_ids == NULL) { |
| build_trie(keys, static_cast<UInt32 *>(NULL), flags); |
| return; |
| } |
| std::vector<UInt32> temp_key_ids(keys.size()); |
| build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags); |
| key_ids->swap(temp_key_ids); |
| } |
| |
| void Trie::build_trie(Vector<Key<String> > &keys, |
| UInt32 *key_ids, int flags) { |
| Trie temp; |
| Vector<UInt32> terminals; |
| Progress progress(flags); |
| MARISA_THROW_IF(!progress.is_valid(), MARISA_PARAM_ERROR); |
| temp.build_trie(keys, &terminals, progress); |
| |
| typedef std::pair<UInt32, UInt32> TerminalIdPair; |
| Vector<TerminalIdPair> pairs; |
| pairs.resize(terminals.size()); |
| for (UInt32 i = 0; i < pairs.size(); ++i) { |
| pairs[i].first = terminals[i]; |
| pairs[i].second = i; |
| } |
| terminals.clear(); |
| std::sort(pairs.begin(), pairs.end()); |
| |
| UInt32 node = 0; |
| for (UInt32 i = 0; i < pairs.size(); ++i) { |
| while (node < pairs[i].first) { |
| temp.terminal_flags_.push_back(false); |
| ++node; |
| } |
| if (node == pairs[i].first) { |
| temp.terminal_flags_.push_back(true); |
| ++node; |
| } |
| } |
| while (node < temp.labels_.size()) { |
| temp.terminal_flags_.push_back(false); |
| ++node; |
| } |
| terminal_flags_.push_back(false); |
| temp.terminal_flags_.build(); |
| temp.terminal_flags_.clear_select0s(); |
| progress.test_total_size(temp.terminal_flags_.total_size()); |
| |
| if (key_ids != NULL) { |
| for (UInt32 i = 0; i < pairs.size(); ++i) { |
| key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first); |
| } |
| } |
| MARISA_THROW_IF(progress.total_size() != temp.total_size(), |
| MARISA_UNEXPECTED_ERROR); |
| temp.swap(this); |
| } |
| |
| template <typename T> |
| void Trie::build_trie(Vector<Key<T> > &keys, |
| Vector<UInt32> *terminals, Progress &progress) { |
| build_cur(keys, terminals, progress); |
| progress.test_total_size(louds_.total_size()); |
| progress.test_total_size(sizeof(num_first_branches_)); |
| progress.test_total_size(sizeof(num_keys_)); |
| if (link_flags_.empty()) { |
| labels_.shrink(); |
| progress.test_total_size(labels_.total_size()); |
| progress.test_total_size(link_flags_.total_size()); |
| progress.test_total_size(links_.total_size()); |
| progress.test_total_size(tail_.total_size()); |
| return; |
| } |
| |
| Vector<UInt32> next_terminals; |
| build_next(keys, &next_terminals, progress); |
| |
| if (has_trie()) { |
| progress.test_total_size(trie_->terminal_flags_.total_size()); |
| } else if (tail_.mode() == MARISA_BINARY_TAIL) { |
| labels_.push_back('\0'); |
| link_flags_.push_back(true); |
| } |
| link_flags_.build(); |
| |
| for (UInt32 i = 0; i < next_terminals.size(); ++i) { |
| labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256); |
| next_terminals[i] /= 256; |
| } |
| link_flags_.clear_select0s(); |
| if (has_trie() || (tail_.mode() == MARISA_TEXT_TAIL)) { |
| link_flags_.clear_select1s(); |
| } |
| |
| links_.build(next_terminals); |
| labels_.shrink(); |
| progress.test_total_size(labels_.total_size()); |
| progress.test_total_size(link_flags_.total_size()); |
| progress.test_total_size(links_.total_size()); |
| progress.test_total_size(tail_.total_size()); |
| } |
| |
| template <typename T> |
| void Trie::build_cur(Vector<Key<T> > &keys, |
| Vector<UInt32> *terminals, Progress &progress) { |
| num_keys_ = sort_keys(keys); |
| louds_.push_back(true); |
| louds_.push_back(false); |
| labels_.push_back('\0'); |
| link_flags_.push_back(false); |
| |
| Vector<Key<T> > rest_keys; |
| std::queue<Range> queue; |
| Vector<WRange> wranges; |
| queue.push(Range(0, (UInt32)keys.size(), 0)); |
| while (!queue.empty()) { |
| const UInt32 node = (UInt32)(link_flags_.size() - queue.size()); |
| Range range = queue.front(); |
| queue.pop(); |
| |
| while ((range.begin() < range.end()) && |
| (keys[range.begin()].str().length() == range.pos())) { |
| keys[range.begin()].set_terminal(node); |
| range.set_begin(range.begin() + 1); |
| } |
| if (range.begin() == range.end()) { |
| louds_.push_back(false); |
| continue; |
| } |
| |
| wranges.clear(); |
| double weight = keys[range.begin()].weight(); |
| for (UInt32 i = range.begin() + 1; i < range.end(); ++i) { |
| if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) { |
| wranges.push_back(WRange(range.begin(), i, range.pos(), weight)); |
| range.set_begin(i); |
| weight = 0.0; |
| } |
| weight += keys[i].weight(); |
| } |
| wranges.push_back(WRange(range, weight)); |
| if (progress.order() == MARISA_WEIGHT_ORDER) { |
| std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>()); |
| } |
| if (node == 0) { |
| num_first_branches_ = wranges.size(); |
| } |
| for (UInt32 i = 0; i < wranges.size(); ++i) { |
| const WRange &wrange = wranges[i]; |
| UInt32 pos = wrange.pos() + 1; |
| if ((progress.tail() != MARISA_WITHOUT_TAIL) || !progress.is_last()) { |
| while (pos < keys[wrange.begin()].str().length()) { |
| UInt32 j; |
| for (j = wrange.begin() + 1; j < wrange.end(); ++j) { |
| if (keys[j - 1].str()[pos] != keys[j].str()[pos]) { |
| break; |
| } |
| } |
| if (j < wrange.end()) { |
| break; |
| } |
| ++pos; |
| } |
| } |
| if ((progress.trie() != MARISA_PATRICIA_TRIE) && |
| (pos != keys[wrange.end() - 1].str().length())) { |
| pos = wrange.pos() + 1; |
| } |
| louds_.push_back(true); |
| if (pos == wrange.pos() + 1) { |
| labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]); |
| link_flags_.push_back(false); |
| } else { |
| labels_.push_back('\0'); |
| link_flags_.push_back(true); |
| Key<T> rest_key; |
| rest_key.set_str(keys[wrange.begin()].str().substr( |
| wrange.pos(), pos - wrange.pos())); |
| rest_key.set_weight(wrange.weight()); |
| rest_keys.push_back(rest_key); |
| } |
| wranges[i].set_pos(pos); |
| queue.push(wranges[i].range()); |
| } |
| louds_.push_back(false); |
| } |
| louds_.push_back(false); |
| louds_.build(); |
| if (progress.trie_id() != 0) { |
| louds_.clear_select0s(); |
| } |
| if (rest_keys.empty()) { |
| link_flags_.clear(); |
| } |
| |
| build_terminals(keys, terminals); |
| keys.swap(&rest_keys); |
| } |
| |
| void Trie::build_next(Vector<Key<String> > &keys, |
| Vector<UInt32> *terminals, Progress &progress) { |
| if (progress.is_last()) { |
| Vector<String> strs; |
| strs.resize(keys.size()); |
| for (UInt32 i = 0; i < strs.size(); ++i) { |
| strs[i] = keys[i].str(); |
| } |
| tail_.build(strs, terminals, progress.tail()); |
| return; |
| } |
| Vector<Key<RString> > rkeys; |
| rkeys.resize(keys.size()); |
| for (UInt32 i = 0; i < rkeys.size(); ++i) { |
| rkeys[i].set_str(RString(keys[i].str())); |
| rkeys[i].set_weight(keys[i].weight()); |
| } |
| keys.clear(); |
| trie_.reset(new (std::nothrow) Trie); |
| MARISA_THROW_IF(!has_trie(), MARISA_MEMORY_ERROR); |
| trie_->build_trie(rkeys, terminals, ++progress); |
| } |
| |
| void Trie::build_next(Vector<Key<RString> > &rkeys, |
| Vector<UInt32> *terminals, Progress &progress) { |
| if (progress.is_last()) { |
| Vector<String> strs; |
| strs.resize(rkeys.size()); |
| for (UInt32 i = 0; i < strs.size(); ++i) { |
| strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length()); |
| } |
| tail_.build(strs, terminals, progress.tail()); |
| return; |
| } |
| trie_.reset(new (std::nothrow) Trie); |
| MARISA_THROW_IF(!has_trie(), MARISA_MEMORY_ERROR); |
| trie_->build_trie(rkeys, terminals, ++progress); |
| } |
| |
| template <typename T> |
| UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const { |
| if (keys.empty()) { |
| return 0; |
| } |
| for (UInt32 i = 0; i < keys.size(); ++i) { |
| keys[i].set_id(i); |
| } |
| std::sort(keys.begin(), keys.end()); |
| UInt32 count = 1; |
| for (UInt32 i = 1; i < keys.size(); ++i) { |
| if (keys[i - 1].str() != keys[i].str()) { |
| ++count; |
| } |
| } |
| return count; |
| } |
| |
| template <typename T> |
| void Trie::build_terminals(const Vector<Key<T> > &keys, |
| Vector<UInt32> *terminals) const { |
| Vector<UInt32> temp_terminals; |
| temp_terminals.resize(keys.size()); |
| for (UInt32 i = 0; i < keys.size(); ++i) { |
| temp_terminals[keys[i].id()] = keys[i].terminal(); |
| } |
| temp_terminals.swap(terminals); |
| } |
| |
| } // namespace marisa |