| // encode.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: johans@google.com (Johan Schalkwyk) |
| // |
| // \file |
| // Class to encode and decoder an fst. |
| |
| #ifndef FST_LIB_ENCODE_H__ |
| #define FST_LIB_ENCODE_H__ |
| |
| #include <climits> |
| #include <unordered_map> |
| using std::tr1::unordered_map; |
| using std::tr1::unordered_multimap; |
| #include <string> |
| #include <vector> |
| using std::vector; |
| |
| #include <fst/arc-map.h> |
| #include <fst/rmfinalepsilon.h> |
| |
| |
| namespace fst { |
| |
| static const uint32 kEncodeLabels = 0x0001; |
| static const uint32 kEncodeWeights = 0x0002; |
| static const uint32 kEncodeFlags = 0x0003; // All non-internal flags |
| |
| static const uint32 kEncodeHasISymbols = 0x0004; // For internal use |
| static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use |
| |
| enum EncodeType { ENCODE = 1, DECODE = 2 }; |
| |
| // Identifies stream data as an encode table (and its endianity) |
| static const int32 kEncodeMagicNumber = 2129983209; |
| |
| |
| // The following class encapsulates implementation details for the |
| // encoding and decoding of label/weight tuples used for encoding |
| // and decoding of Fsts. The EncodeTable is bidirectional. I.E it |
| // stores both the Tuple of encode labels and weights to a unique |
| // label, and the reverse. |
| template <class A> class EncodeTable { |
| public: |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| |
| // Encoded data consists of arc input/output labels and arc weight |
| struct Tuple { |
| Tuple() {} |
| Tuple(Label ilabel_, Label olabel_, Weight weight_) |
| : ilabel(ilabel_), olabel(olabel_), weight(weight_) {} |
| Tuple(const Tuple& tuple) |
| : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {} |
| |
| Label ilabel; |
| Label olabel; |
| Weight weight; |
| }; |
| |
| // Comparison object for hashing EncodeTable Tuple(s). |
| class TupleEqual { |
| public: |
| bool operator()(const Tuple* x, const Tuple* y) const { |
| return (x->ilabel == y->ilabel && |
| x->olabel == y->olabel && |
| x->weight == y->weight); |
| } |
| }; |
| |
| // Hash function for EncodeTabe Tuples. Based on the encode flags |
| // we either hash the labels, weights or combination of them. |
| class TupleKey { |
| public: |
| TupleKey() |
| : encode_flags_(kEncodeLabels | kEncodeWeights) {} |
| |
| TupleKey(const TupleKey& key) |
| : encode_flags_(key.encode_flags_) {} |
| |
| explicit TupleKey(uint32 encode_flags) |
| : encode_flags_(encode_flags) {} |
| |
| size_t operator()(const Tuple* x) const { |
| size_t hash = x->ilabel; |
| const int lshift = 5; |
| const int rshift = CHAR_BIT * sizeof(size_t) - 5; |
| if (encode_flags_ & kEncodeLabels) |
| hash = hash << lshift ^ hash >> rshift ^ x->olabel; |
| if (encode_flags_ & kEncodeWeights) |
| hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash(); |
| return hash; |
| } |
| |
| private: |
| int32 encode_flags_; |
| }; |
| |
| typedef unordered_map<const Tuple*, |
| Label, |
| TupleKey, |
| TupleEqual> EncodeHash; |
| |
| explicit EncodeTable(uint32 encode_flags) |
| : flags_(encode_flags), |
| encode_hash_(1024, TupleKey(encode_flags)), |
| isymbols_(0), osymbols_(0) {} |
| |
| ~EncodeTable() { |
| for (size_t i = 0; i < encode_tuples_.size(); ++i) { |
| delete encode_tuples_[i]; |
| } |
| delete isymbols_; |
| delete osymbols_; |
| } |
| |
| // Given an arc encode either input/ouptut labels or input/costs or both |
| Label Encode(const A &arc) { |
| const Tuple tuple(arc.ilabel, |
| flags_ & kEncodeLabels ? arc.olabel : 0, |
| flags_ & kEncodeWeights ? arc.weight : Weight::One()); |
| typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); |
| if (it == encode_hash_.end()) { |
| encode_tuples_.push_back(new Tuple(tuple)); |
| encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); |
| return encode_tuples_.size(); |
| } else { |
| return it->second; |
| } |
| } |
| |
| // Given an arc, look up its encoded label. Returns kNoLabel if not found. |
| Label GetLabel(const A &arc) const { |
| const Tuple tuple(arc.ilabel, |
| flags_ & kEncodeLabels ? arc.olabel : 0, |
| flags_ & kEncodeWeights ? arc.weight : Weight::One()); |
| typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); |
| if (it == encode_hash_.end()) { |
| return kNoLabel; |
| } else { |
| return it->second; |
| } |
| } |
| |
| // Given an encode arc Label decode back to input/output labels and costs |
| const Tuple* Decode(Label key) const { |
| if (key < 1 || key > encode_tuples_.size()) { |
| LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key; |
| return 0; |
| } |
| return encode_tuples_[key - 1]; |
| } |
| |
| size_t Size() const { return encode_tuples_.size(); } |
| |
| bool Write(ostream &strm, const string &source) const; |
| |
| static EncodeTable<A> *Read(istream &strm, const string &source); |
| |
| const uint32 flags() const { return flags_ & kEncodeFlags; } |
| |
| int RefCount() const { return ref_count_.count(); } |
| int IncrRefCount() { return ref_count_.Incr(); } |
| int DecrRefCount() { return ref_count_.Decr(); } |
| |
| |
| SymbolTable *InputSymbols() const { return isymbols_; } |
| |
| SymbolTable *OutputSymbols() const { return osymbols_; } |
| |
| void SetInputSymbols(const SymbolTable* syms) { |
| if (isymbols_) delete isymbols_; |
| if (syms) { |
| isymbols_ = syms->Copy(); |
| flags_ |= kEncodeHasISymbols; |
| } else { |
| isymbols_ = 0; |
| flags_ &= ~kEncodeHasISymbols; |
| } |
| } |
| |
| void SetOutputSymbols(const SymbolTable* syms) { |
| if (osymbols_) delete osymbols_; |
| if (syms) { |
| osymbols_ = syms->Copy(); |
| flags_ |= kEncodeHasOSymbols; |
| } else { |
| osymbols_ = 0; |
| flags_ &= ~kEncodeHasOSymbols; |
| } |
| } |
| |
| private: |
| uint32 flags_; |
| vector<Tuple*> encode_tuples_; |
| EncodeHash encode_hash_; |
| RefCounter ref_count_; |
| SymbolTable *isymbols_; // Pre-encoded ilabel symbol table |
| SymbolTable *osymbols_; // Pre-encoded olabel symbol table |
| |
| DISALLOW_COPY_AND_ASSIGN(EncodeTable); |
| }; |
| |
| template <class A> inline |
| bool EncodeTable<A>::Write(ostream &strm, const string &source) const { |
| WriteType(strm, kEncodeMagicNumber); |
| WriteType(strm, flags_); |
| int64 size = encode_tuples_.size(); |
| WriteType(strm, size); |
| for (size_t i = 0; i < size; ++i) { |
| const Tuple* tuple = encode_tuples_[i]; |
| WriteType(strm, tuple->ilabel); |
| WriteType(strm, tuple->olabel); |
| tuple->weight.Write(strm); |
| } |
| |
| if (flags_ & kEncodeHasISymbols) |
| isymbols_->Write(strm); |
| |
| if (flags_ & kEncodeHasOSymbols) |
| osymbols_->Write(strm); |
| |
| strm.flush(); |
| if (!strm) { |
| LOG(ERROR) << "EncodeTable::Write: write failed: " << source; |
| return false; |
| } |
| return true; |
| } |
| |
| template <class A> inline |
| EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) { |
| int32 magic_number = 0; |
| ReadType(strm, &magic_number); |
| if (magic_number != kEncodeMagicNumber) { |
| LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; |
| return 0; |
| } |
| uint32 flags; |
| ReadType(strm, &flags); |
| EncodeTable<A> *table = new EncodeTable<A>(flags); |
| |
| int64 size; |
| ReadType(strm, &size); |
| if (!strm) { |
| LOG(ERROR) << "EncodeTable::Read: read failed: " << source; |
| return 0; |
| } |
| |
| for (size_t i = 0; i < size; ++i) { |
| Tuple* tuple = new Tuple(); |
| ReadType(strm, &tuple->ilabel); |
| ReadType(strm, &tuple->olabel); |
| tuple->weight.Read(strm); |
| if (!strm) { |
| LOG(ERROR) << "EncodeTable::Read: read failed: " << source; |
| return 0; |
| } |
| table->encode_tuples_.push_back(tuple); |
| table->encode_hash_[table->encode_tuples_.back()] = |
| table->encode_tuples_.size(); |
| } |
| |
| if (flags & kEncodeHasISymbols) |
| table->isymbols_ = SymbolTable::Read(strm, source); |
| |
| if (flags & kEncodeHasOSymbols) |
| table->osymbols_ = SymbolTable::Read(strm, source); |
| |
| return table; |
| } |
| |
| |
| // A mapper to encode/decode weighted transducers. Encoding of an |
| // Fst is useful for performing classical determinization or minimization |
| // on a weighted transducer by treating it as an unweighted acceptor over |
| // encoded labels. |
| // |
| // The Encode mapper stores the encoding in a local hash table (EncodeTable) |
| // This table is shared (and reference counted) between the encoder and |
| // decoder. A decoder has read only access to the EncodeTable. |
| // |
| // The EncodeMapper allows on the fly encoding of the machine. As the |
| // EncodeTable is generated the same table may by used to decode the machine |
| // on the fly. For example in the following sequence of operations |
| // |
| // Encode -> Determinize -> Decode |
| // |
| // we will use the encoding table generated during the encode step in the |
| // decode, even though the encoding is not complete. |
| // |
| template <class A> class EncodeMapper { |
| typedef typename A::Weight Weight; |
| typedef typename A::Label Label; |
| public: |
| EncodeMapper(uint32 flags, EncodeType type) |
| : flags_(flags), |
| type_(type), |
| table_(new EncodeTable<A>(flags)), |
| error_(false) {} |
| |
| EncodeMapper(const EncodeMapper& mapper) |
| : flags_(mapper.flags_), |
| type_(mapper.type_), |
| table_(mapper.table_), |
| error_(false) { |
| table_->IncrRefCount(); |
| } |
| |
| // Copy constructor but setting the type, typically to DECODE |
| EncodeMapper(const EncodeMapper& mapper, EncodeType type) |
| : flags_(mapper.flags_), |
| type_(type), |
| table_(mapper.table_), |
| error_(mapper.error_) { |
| table_->IncrRefCount(); |
| } |
| |
| ~EncodeMapper() { |
| if (!table_->DecrRefCount()) delete table_; |
| } |
| |
| A operator()(const A &arc); |
| |
| MapFinalAction FinalAction() const { |
| return (type_ == ENCODE && (flags_ & kEncodeWeights)) ? |
| MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL; |
| } |
| |
| MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; } |
| |
| MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} |
| |
| uint64 Properties(uint64 inprops) { |
| uint64 outprops = inprops; |
| if (error_) outprops |= kError; |
| |
| uint64 mask = kFstProperties; |
| if (flags_ & kEncodeLabels) |
| mask &= kILabelInvariantProperties & kOLabelInvariantProperties; |
| if (flags_ & kEncodeWeights) |
| mask &= kILabelInvariantProperties & kWeightInvariantProperties & |
| (type_ == ENCODE ? kAddSuperFinalProperties : |
| kRmSuperFinalProperties); |
| |
| return outprops & mask; |
| } |
| |
| const uint32 flags() const { return flags_; } |
| const EncodeType type() const { return type_; } |
| const EncodeTable<A> &table() const { return *table_; } |
| |
| bool Write(ostream &strm, const string& source) { |
| return table_->Write(strm, source); |
| } |
| |
| bool Write(const string& filename) { |
| ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); |
| if (!strm) { |
| LOG(ERROR) << "EncodeMap: Can't open file: " << filename; |
| return false; |
| } |
| return Write(strm, filename); |
| } |
| |
| static EncodeMapper<A> *Read(istream &strm, |
| const string& source, |
| EncodeType type = ENCODE) { |
| EncodeTable<A> *table = EncodeTable<A>::Read(strm, source); |
| return table ? new EncodeMapper(table->flags(), type, table) : 0; |
| } |
| |
| static EncodeMapper<A> *Read(const string& filename, |
| EncodeType type = ENCODE) { |
| ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); |
| if (!strm) { |
| LOG(ERROR) << "EncodeMap: Can't open file: " << filename; |
| return NULL; |
| } |
| return Read(strm, filename, type); |
| } |
| |
| SymbolTable *InputSymbols() const { return table_->InputSymbols(); } |
| |
| SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); } |
| |
| void SetInputSymbols(const SymbolTable* syms) { |
| table_->SetInputSymbols(syms); |
| } |
| |
| void SetOutputSymbols(const SymbolTable* syms) { |
| table_->SetOutputSymbols(syms); |
| } |
| |
| private: |
| uint32 flags_; |
| EncodeType type_; |
| EncodeTable<A>* table_; |
| bool error_; |
| |
| explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table) |
| : flags_(flags), type_(type), table_(table) {} |
| void operator=(const EncodeMapper &); // Disallow. |
| }; |
| |
| template <class A> inline |
| A EncodeMapper<A>::operator()(const A &arc) { |
| if (type_ == ENCODE) { // labels and/or weights to single label |
| if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || |
| (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && |
| arc.weight == Weight::Zero())) { |
| return arc; |
| } else { |
| Label label = table_->Encode(arc); |
| return A(label, |
| flags_ & kEncodeLabels ? label : arc.olabel, |
| flags_ & kEncodeWeights ? Weight::One() : arc.weight, |
| arc.nextstate); |
| } |
| } else { // type_ == DECODE |
| if (arc.nextstate == kNoStateId) { |
| return arc; |
| } else { |
| if (arc.ilabel == 0) return arc; |
| if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) { |
| FSTERROR() << "EncodeMapper: Label-encoded arc has different " |
| "input and output labels"; |
| error_ = true; |
| } |
| if (flags_ & kEncodeWeights && arc.weight != Weight::One()) { |
| FSTERROR() << |
| "EncodeMapper: Weight-encoded arc has non-trivial weight"; |
| error_ = true; |
| } |
| const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel); |
| if (!tuple) { |
| FSTERROR() << "EncodeMapper: decode failed"; |
| error_ = true; |
| return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate); |
| } else { |
| return A(tuple->ilabel, |
| flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, |
| flags_ & kEncodeWeights ? tuple->weight : arc.weight, |
| arc.nextstate); |
| } |
| } |
| } |
| } |
| |
| |
| // Complexity: O(nstates + narcs) |
| template<class A> inline |
| void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) { |
| mapper->SetInputSymbols(fst->InputSymbols()); |
| mapper->SetOutputSymbols(fst->OutputSymbols()); |
| ArcMap(fst, mapper); |
| } |
| |
| template<class A> inline |
| void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) { |
| ArcMap(fst, EncodeMapper<A>(mapper, DECODE)); |
| RmFinalEpsilon(fst); |
| fst->SetInputSymbols(mapper.InputSymbols()); |
| fst->SetOutputSymbols(mapper.OutputSymbols()); |
| } |
| |
| |
| // On the fly label and/or weight encoding of input Fst |
| // |
| // Complexity: |
| // - Constructor: O(1) |
| // - Traversal: O(nstates_visited + narcs_visited), assuming constant |
| // time to visit an input state or arc. |
| template <class A> |
| class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > { |
| public: |
| typedef A Arc; |
| typedef EncodeMapper<A> C; |
| typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl; |
| using ImplToFst<Impl>::GetImpl; |
| |
| EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder) |
| : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) { |
| encoder->SetInputSymbols(fst.InputSymbols()); |
| encoder->SetOutputSymbols(fst.OutputSymbols()); |
| } |
| |
| EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) |
| : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {} |
| |
| // See Fst<>::Copy() for doc. |
| EncodeFst(const EncodeFst<A> &fst, bool copy = false) |
| : ArcMapFst<A, A, C>(fst, copy) {} |
| |
| // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc. |
| virtual EncodeFst<A> *Copy(bool safe = false) const { |
| if (safe) { |
| FSTERROR() << "EncodeFst::Copy(true): not allowed."; |
| GetImpl()->SetProperties(kError, kError); |
| } |
| return new EncodeFst(*this); |
| } |
| }; |
| |
| |
| // On the fly label and/or weight encoding of input Fst |
| // |
| // Complexity: |
| // - Constructor: O(1) |
| // - Traversal: O(nstates_visited + narcs_visited), assuming constant |
| // time to visit an input state or arc. |
| template <class A> |
| class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > { |
| public: |
| typedef A Arc; |
| typedef EncodeMapper<A> C; |
| typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl; |
| using ImplToFst<Impl>::GetImpl; |
| |
| DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) |
| : ArcMapFst<A, A, C>(fst, |
| EncodeMapper<A>(encoder, DECODE), |
| ArcMapFstOptions()) { |
| GetImpl()->SetInputSymbols(encoder.InputSymbols()); |
| GetImpl()->SetOutputSymbols(encoder.OutputSymbols()); |
| } |
| |
| // See Fst<>::Copy() for doc. |
| DecodeFst(const DecodeFst<A> &fst, bool safe = false) |
| : ArcMapFst<A, A, C>(fst, safe) {} |
| |
| // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc. |
| virtual DecodeFst<A> *Copy(bool safe = false) const { |
| return new DecodeFst(*this, safe); |
| } |
| }; |
| |
| |
| // Specialization for EncodeFst. |
| template <class A> |
| class StateIterator< EncodeFst<A> > |
| : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > { |
| public: |
| explicit StateIterator(const EncodeFst<A> &fst) |
| : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {} |
| }; |
| |
| |
| // Specialization for EncodeFst. |
| template <class A> |
| class ArcIterator< EncodeFst<A> > |
| : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > { |
| public: |
| ArcIterator(const EncodeFst<A> &fst, typename A::StateId s) |
| : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {} |
| }; |
| |
| |
| // Specialization for DecodeFst. |
| template <class A> |
| class StateIterator< DecodeFst<A> > |
| : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > { |
| public: |
| explicit StateIterator(const DecodeFst<A> &fst) |
| : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {} |
| }; |
| |
| |
| // Specialization for DecodeFst. |
| template <class A> |
| class ArcIterator< DecodeFst<A> > |
| : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > { |
| public: |
| ArcIterator(const DecodeFst<A> &fst, typename A::StateId s) |
| : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {} |
| }; |
| |
| |
| // Useful aliases when using StdArc. |
| typedef EncodeFst<StdArc> StdEncodeFst; |
| |
| typedef DecodeFst<StdArc> StdDecodeFst; |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_ENCODE_H__ |