| // lookahead-matcher.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: riley@google.com (Michael Riley) |
| // |
| // \file |
| // Classes to add lookahead to FST matchers, useful e.g. for improving |
| // composition efficiency with certain inputs. |
| |
| #ifndef FST_LIB_LOOKAHEAD_MATCHER_H__ |
| #define FST_LIB_LOOKAHEAD_MATCHER_H__ |
| |
| #include <fst/add-on.h> |
| #include <fst/const-fst.h> |
| #include <fst/fst.h> |
| #include <fst/label-reachable.h> |
| #include <fst/matcher.h> |
| |
| |
| DECLARE_string(save_relabel_ipairs); |
| DECLARE_string(save_relabel_opairs); |
| |
| namespace fst { |
| |
// LOOKAHEAD MATCHERS - these have the interface of Matchers (see
// matcher.h) and these additional methods:
//
// template <class F>
// class LookAheadMatcher {
//  public:
//   typedef F FST;
//   typedef typename F::Arc Arc;
//   typedef typename Arc::StateId StateId;
//   typedef typename Arc::Label Label;
//   typedef typename Arc::Weight Weight;
//
//   // Required constructors.
//   LookAheadMatcher(const F &fst, MatchType match_type);
//   // If safe=true, the copy is thread-safe (except the lookahead Fst is
//   // preserved). See Fst<>::Copy() for further doc.
//   LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
//
//  Below are methods for looking ahead for a match to a label and
//  more generally, to a rational set. Each returns false if there is
//  definitely not a match and returns true if there possibly is a
//  match.

//   // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
//   // after possibly following epsilon transitions?
//   bool LookAheadLabel(Label label) const;
//
//   // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
//   // arbitrary rational set of strings, specified by an FST and a state
//   // from which to begin the matching. If the lookahead FST is a
//   // transducer, this looks on the side different from the matcher
//   // 'match_type' (cf. composition).
//
//   // Are there paths P from 's' in the lookahead FST that can be read from
//   // the current matcher state?
//   bool LookAheadFst(const Fst<Arc>& fst, StateId s);
//
//   // Gives an estimate of the combined weight of the paths P in the
//   // lookahead and matcher FSTs for the last call to LookAheadFst.
//   // A trivial implementation returns Weight::One(). Non-trivial
//   // implementations are useful for weight-pushing in composition.
//   Weight LookAheadWeight() const;
//
//   // Is there a single non-epsilon arc found in the lookahead FST
//   // that begins P (after possibly following any epsilons) in the last
//   // call to LookAheadFst? If so, return true and copy it to '*arc', o.w.
//   // return false. A trivial implementation returns false. Non-trivial
//   // implementations are useful for label-pushing in composition.
//   bool LookAheadPrefix(Arc *arc);
//
//   // Optionally pre-specifies the lookahead FST that will be passed
//   // to LookAheadFst() for possible precomputation. If copy is true,
//   // then 'fst' is a copy of the FST used in the previous call to
//   // this method (useful to avoid unnecessary updates).
//   void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
//
// };
| |
| // |
| // LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h): |
| // |
| // Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT. |
| const uint32 kInputLookAheadMatcher = 0x00000010; |
| |
| // Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT. |
| const uint32 kOutputLookAheadMatcher = 0x00000020; |
| |
| // A non-trivial implementation of LookAheadWeight() method defined and |
| // should be used? |
| const uint32 kLookAheadWeight = 0x00000040; |
| |
| // A non-trivial implementation of LookAheadPrefix() method defined and |
| // should be used? |
| const uint32 kLookAheadPrefix = 0x00000080; |
| |
| // Look-ahead of matcher FST non-epsilon arcs? |
| const uint32 kLookAheadNonEpsilons = 0x00000100; |
| |
| // Look-ahead of matcher FST epsilon arcs? |
| const uint32 kLookAheadEpsilons = 0x00000200; |
| |
| // Ignore epsilon paths for the lookahead prefix? Note this gives |
| // correct results in composition only with an appropriate composition |
| // filter since it depends on the filter blocking the ignored paths. |
| const uint32 kLookAheadNonEpsilonPrefix = 0x00000400; |
| |
| // For LabelLookAheadMatcher, save relabeling data to file |
| const uint32 kLookAheadKeepRelabelData = 0x00000800; |
| |
| // Flags used for lookahead matchers. |
| const uint32 kLookAheadFlags = 0x00000ff0; |
| |
| // LookAhead Matcher interface, templated on the Arc definition; used |
| // for lookahead matcher specializations that are returned by the |
| // InitMatcher() Fst method. |
| template <class A> |
| class LookAheadMatcherBase : public MatcherBase<A> { |
| public: |
| typedef A Arc; |
| typedef typename A::StateId StateId; |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| |
| LookAheadMatcherBase() |
| : weight_(Weight::One()), |
| prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {} |
| |
| virtual ~LookAheadMatcherBase() {} |
| |
| bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); } |
| |
| bool LookAheadFst(const Fst<Arc> &fst, StateId s) { |
| return LookAheadFst_(fst, s); |
| } |
| |
| Weight LookAheadWeight() const { return weight_; } |
| |
| bool LookAheadPrefix(Arc *arc) const { |
| if (prefix_arc_.nextstate != kNoStateId) { |
| *arc = prefix_arc_; |
| return true; |
| } else { |
| return false; |
| } |
| } |
| |
| virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0; |
| |
| protected: |
| void SetLookAheadWeight(const Weight &w) { weight_ = w; } |
| |
| void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; } |
| |
| void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; } |
| |
| private: |
| virtual bool LookAheadLabel_(Label label) const = 0; |
| virtual bool LookAheadFst_(const Fst<Arc> &fst, |
| StateId s) = 0; // This must set l.a. weight and |
| // prefix if non-trivial. |
| Weight weight_; // Look-ahead weight |
| Arc prefix_arc_; // Look-ahead prefix arc |
| }; |
| |
| |
| // Don't really lookahead, just declare future looks good regardless. |
| template <class M> |
| class TrivialLookAheadMatcher |
| : public LookAheadMatcherBase<typename M::FST::Arc> { |
| public: |
| typedef typename M::FST FST; |
| typedef typename M::Arc Arc; |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Label Label; |
| typedef typename Arc::Weight Weight; |
| |
| TrivialLookAheadMatcher(const FST &fst, MatchType match_type) |
| : matcher_(fst, match_type) {} |
| |
| TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher, |
| bool safe = false) |
| : matcher_(lmatcher.matcher_, safe) {} |
| |
| // General matcher methods |
| TrivialLookAheadMatcher<M> *Copy(bool safe = false) const { |
| return new TrivialLookAheadMatcher<M>(*this, safe); |
| } |
| |
| MatchType Type(bool test) const { return matcher_.Type(test); } |
| void SetState(StateId s) { return matcher_.SetState(s); } |
| bool Find(Label label) { return matcher_.Find(label); } |
| bool Done() const { return matcher_.Done(); } |
| const Arc& Value() const { return matcher_.Value(); } |
| void Next() { matcher_.Next(); } |
| virtual const FST &GetFst() const { return matcher_.GetFst(); } |
| uint64 Properties(uint64 props) const { return matcher_.Properties(props); } |
| uint32 Flags() const { |
| return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher; |
| } |
| |
| // Look-ahead methods. |
| bool LookAheadLabel(Label label) const { return true; } |
| bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; } |
| Weight LookAheadWeight() const { return Weight::One(); } |
| bool LookAheadPrefix(Arc *arc) const { return false; } |
| void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {} |
| |
| private: |
| // This allows base class virtual access to non-virtual derived- |
| // class members of the same name. It makes the derived class more |
| // efficient to use but unsafe to further derive. |
| virtual void SetState_(StateId s) { SetState(s); } |
| virtual bool Find_(Label label) { return Find(label); } |
| virtual bool Done_() const { return Done(); } |
| virtual const Arc& Value_() const { return Value(); } |
| virtual void Next_() { Next(); } |
| |
| bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } |
| |
| bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { |
| return LookAheadFst(fst, s); |
| } |
| |
| Weight LookAheadWeight_() const { return LookAheadWeight(); } |
| bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); } |
| |
| M matcher_; |
| }; |
| |
| // Look-ahead of one transition. Template argument F accepts flags to |
| // control behavior. |
| template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons | |
| kLookAheadWeight | kLookAheadPrefix> |
| class ArcLookAheadMatcher |
| : public LookAheadMatcherBase<typename M::FST::Arc> { |
| public: |
| typedef typename M::FST FST; |
| typedef typename M::Arc Arc; |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Label Label; |
| typedef typename Arc::Weight Weight; |
| typedef NullAddOn MatcherData; |
| |
| using LookAheadMatcherBase<Arc>::LookAheadWeight; |
| using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; |
| using LookAheadMatcherBase<Arc>::SetLookAheadWeight; |
| using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; |
| |
| ArcLookAheadMatcher(const FST &fst, MatchType match_type, |
| MatcherData *data = 0) |
| : matcher_(fst, match_type), |
| fst_(matcher_.GetFst()), |
| lfst_(0), |
| s_(kNoStateId) {} |
| |
| ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher, |
| bool safe = false) |
| : matcher_(lmatcher.matcher_, safe), |
| fst_(matcher_.GetFst()), |
| lfst_(lmatcher.lfst_), |
| s_(kNoStateId) {} |
| |
| // General matcher methods |
| ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const { |
| return new ArcLookAheadMatcher<M, F>(*this, safe); |
| } |
| |
| MatchType Type(bool test) const { return matcher_.Type(test); } |
| |
| void SetState(StateId s) { |
| s_ = s; |
| matcher_.SetState(s); |
| } |
| |
| bool Find(Label label) { return matcher_.Find(label); } |
| bool Done() const { return matcher_.Done(); } |
| const Arc& Value() const { return matcher_.Value(); } |
| void Next() { matcher_.Next(); } |
| const FST &GetFst() const { return fst_; } |
| uint64 Properties(uint64 props) const { return matcher_.Properties(props); } |
| uint32 Flags() const { |
| return matcher_.Flags() | kInputLookAheadMatcher | |
| kOutputLookAheadMatcher | F; |
| } |
| |
| // Writable matcher methods |
| MatcherData *GetData() const { return 0; } |
| |
| // Look-ahead methods. |
| bool LookAheadLabel(Label label) const { return matcher_.Find(label); } |
| |
| // Checks if there is a matching (possibly super-final) transition |
| // at (s_, s). |
| bool LookAheadFst(const Fst<Arc> &fst, StateId s); |
| |
| void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { |
| lfst_ = &fst; |
| } |
| |
| private: |
| // This allows base class virtual access to non-virtual derived- |
| // class members of the same name. It makes the derived class more |
| // efficient to use but unsafe to further derive. |
| virtual void SetState_(StateId s) { SetState(s); } |
| virtual bool Find_(Label label) { return Find(label); } |
| virtual bool Done_() const { return Done(); } |
| virtual const Arc& Value_() const { return Value(); } |
| virtual void Next_() { Next(); } |
| |
| bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } |
| bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { |
| return LookAheadFst(fst, s); |
| } |
| |
| mutable M matcher_; |
| const FST &fst_; // Matcher FST |
| const Fst<Arc> *lfst_; // Look-ahead FST |
| StateId s_; // Matcher state |
| }; |
| |
| template <class M, uint32 F> |
| bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) { |
| if (&fst != lfst_) |
| InitLookAheadFst(fst); |
| |
| bool ret = false; |
| ssize_t nprefix = 0; |
| if (F & kLookAheadWeight) |
| SetLookAheadWeight(Weight::Zero()); |
| if (F & kLookAheadPrefix) |
| ClearLookAheadPrefix(); |
| if (fst_.Final(s_) != Weight::Zero() && |
| lfst_->Final(s) != Weight::Zero()) { |
| if (!(F & (kLookAheadWeight | kLookAheadPrefix))) |
| return true; |
| ++nprefix; |
| if (F & kLookAheadWeight) |
| SetLookAheadWeight(Plus(LookAheadWeight(), |
| Times(fst_.Final(s_), lfst_->Final(s)))); |
| ret = true; |
| } |
| if (matcher_.Find(kNoLabel)) { |
| if (!(F & (kLookAheadWeight | kLookAheadPrefix))) |
| return true; |
| ++nprefix; |
| if (F & kLookAheadWeight) |
| for (; !matcher_.Done(); matcher_.Next()) |
| SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight)); |
| ret = true; |
| } |
| for (ArcIterator< Fst<Arc> > aiter(*lfst_, s); |
| !aiter.Done(); |
| aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| Label label = kNoLabel; |
| switch (matcher_.Type(false)) { |
| case MATCH_INPUT: |
| label = arc.olabel; |
| break; |
| case MATCH_OUTPUT: |
| label = arc.ilabel; |
| break; |
| default: |
| FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type"; |
| return true; |
| } |
| if (label == 0) { |
| if (!(F & (kLookAheadWeight | kLookAheadPrefix))) |
| return true; |
| if (!(F & kLookAheadNonEpsilonPrefix)) |
| ++nprefix; |
| if (F & kLookAheadWeight) |
| SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight)); |
| ret = true; |
| } else if (matcher_.Find(label)) { |
| if (!(F & (kLookAheadWeight | kLookAheadPrefix))) |
| return true; |
| for (; !matcher_.Done(); matcher_.Next()) { |
| ++nprefix; |
| if (F & kLookAheadWeight) |
| SetLookAheadWeight(Plus(LookAheadWeight(), |
| Times(arc.weight, |
| matcher_.Value().weight))); |
| if ((F & kLookAheadPrefix) && nprefix == 1) |
| SetLookAheadPrefix(arc); |
| } |
| ret = true; |
| } |
| } |
| if (F & kLookAheadPrefix) { |
| if (nprefix == 1) |
| SetLookAheadWeight(Weight::One()); // Avoids double counting. |
| else |
| ClearLookAheadPrefix(); |
| } |
| return ret; |
| } |
| |
| |
| // Template argument F accepts flags to control behavior. |
| // It must include precisely one of KInputLookAheadMatcher or |
| // KOutputLookAheadMatcher. |
| template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight | |
| kLookAheadPrefix | kLookAheadNonEpsilonPrefix | |
| kLookAheadKeepRelabelData, |
| class S = DefaultAccumulator<typename M::Arc> > |
| class LabelLookAheadMatcher |
| : public LookAheadMatcherBase<typename M::FST::Arc> { |
| public: |
| typedef typename M::FST FST; |
| typedef typename M::Arc Arc; |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Label Label; |
| typedef typename Arc::Weight Weight; |
| typedef LabelReachableData<Label> MatcherData; |
| |
| using LookAheadMatcherBase<Arc>::LookAheadWeight; |
| using LookAheadMatcherBase<Arc>::SetLookAheadPrefix; |
| using LookAheadMatcherBase<Arc>::SetLookAheadWeight; |
| using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix; |
| |
| LabelLookAheadMatcher(const FST &fst, MatchType match_type, |
| MatcherData *data = 0, S *s = 0) |
| : matcher_(fst, match_type), |
| lfst_(0), |
| label_reachable_(0), |
| s_(kNoStateId), |
| error_(false) { |
| if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) { |
| FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F; |
| error_ = true; |
| } |
| bool reach_input = match_type == MATCH_INPUT; |
| if (data) { |
| if (reach_input == data->ReachInput()) |
| label_reachable_ = new LabelReachable<Arc, S>(data, s); |
| } else if ((reach_input && (F & kInputLookAheadMatcher)) || |
| (!reach_input && (F & kOutputLookAheadMatcher))) { |
| label_reachable_ = new LabelReachable<Arc, S>( |
| fst, reach_input, s, F & kLookAheadKeepRelabelData); |
| } |
| } |
| |
| LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher, |
| bool safe = false) |
| : matcher_(lmatcher.matcher_, safe), |
| lfst_(lmatcher.lfst_), |
| label_reachable_( |
| lmatcher.label_reachable_ ? |
| new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0), |
| s_(kNoStateId), |
| error_(lmatcher.error_) {} |
| |
| ~LabelLookAheadMatcher() { |
| delete label_reachable_; |
| } |
| |
| // General matcher methods |
| LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const { |
| return new LabelLookAheadMatcher<M, F, S>(*this, safe); |
| } |
| |
| MatchType Type(bool test) const { return matcher_.Type(test); } |
| |
| void SetState(StateId s) { |
| if (s_ == s) |
| return; |
| s_ = s; |
| match_set_state_ = false; |
| reach_set_state_ = false; |
| } |
| |
| bool Find(Label label) { |
| if (!match_set_state_) { |
| matcher_.SetState(s_); |
| match_set_state_ = true; |
| } |
| return matcher_.Find(label); |
| } |
| |
| bool Done() const { return matcher_.Done(); } |
| const Arc& Value() const { return matcher_.Value(); } |
| void Next() { matcher_.Next(); } |
| const FST &GetFst() const { return matcher_.GetFst(); } |
| |
| uint64 Properties(uint64 inprops) const { |
| uint64 outprops = matcher_.Properties(inprops); |
| if (error_ || (label_reachable_ && label_reachable_->Error())) |
| outprops |= kError; |
| return outprops; |
| } |
| |
| uint32 Flags() const { |
| if (label_reachable_ && label_reachable_->GetData()->ReachInput()) |
| return matcher_.Flags() | F | kInputLookAheadMatcher; |
| else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) |
| return matcher_.Flags() | F | kOutputLookAheadMatcher; |
| else |
| return matcher_.Flags(); |
| } |
| |
| // Writable matcher methods |
| MatcherData *GetData() const { |
| return label_reachable_ ? label_reachable_->GetData() : 0; |
| }; |
| |
| // Look-ahead methods. |
| bool LookAheadLabel(Label label) const { |
| if (label == 0) |
| return true; |
| |
| if (label_reachable_) { |
| if (!reach_set_state_) { |
| label_reachable_->SetState(s_); |
| reach_set_state_ = true; |
| } |
| return label_reachable_->Reach(label); |
| } else { |
| return true; |
| } |
| } |
| |
| // Checks if there is a matching (possibly super-final) transition |
| // at (s_, s). |
| template <class L> |
| bool LookAheadFst(const L &fst, StateId s); |
| |
| void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { |
| lfst_ = &fst; |
| if (label_reachable_) |
| label_reachable_->ReachInit(fst, copy); |
| } |
| |
| template <class L> |
| void InitLookAheadFst(const L& fst, bool copy = false) { |
| lfst_ = static_cast<const Fst<Arc> *>(&fst); |
| if (label_reachable_) |
| label_reachable_->ReachInit(fst, copy); |
| } |
| |
| private: |
| // This allows base class virtual access to non-virtual derived- |
| // class members of the same name. It makes the derived class more |
| // efficient to use but unsafe to further derive. |
| virtual void SetState_(StateId s) { SetState(s); } |
| virtual bool Find_(Label label) { return Find(label); } |
| virtual bool Done_() const { return Done(); } |
| virtual const Arc& Value_() const { return Value(); } |
| virtual void Next_() { Next(); } |
| |
| bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); } |
| bool LookAheadFst_(const Fst<Arc> &fst, StateId s) { |
| return LookAheadFst(fst, s); |
| } |
| |
| mutable M matcher_; |
| const Fst<Arc> *lfst_; // Look-ahead FST |
| LabelReachable<Arc, S> *label_reachable_; // Label reachability info |
| StateId s_; // Matcher state |
| bool match_set_state_; // matcher_.SetState called? |
| mutable bool reach_set_state_; // reachable_.SetState called? |
| bool error_; |
| }; |
| |
| template <class M, uint32 F, class S> |
| template <class L> inline |
| bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) { |
| if (static_cast<const Fst<Arc> *>(&fst) != lfst_) |
| InitLookAheadFst(fst); |
| |
| SetLookAheadWeight(Weight::One()); |
| ClearLookAheadPrefix(); |
| |
| if (!label_reachable_) |
| return true; |
| |
| label_reachable_->SetState(s_, s); |
| reach_set_state_ = true; |
| |
| bool compute_weight = F & kLookAheadWeight; |
| bool compute_prefix = F & kLookAheadPrefix; |
| |
| bool reach_input = Type(false) == MATCH_OUTPUT; |
| ArcIterator<L> aiter(fst, s); |
| bool reach_arc = label_reachable_->Reach(&aiter, 0, |
| internal::NumArcs(*lfst_, s), |
| reach_input, compute_weight); |
| Weight lfinal = internal::Final(*lfst_, s); |
| bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal(); |
| if (reach_arc) { |
| ssize_t begin = label_reachable_->ReachBegin(); |
| ssize_t end = label_reachable_->ReachEnd(); |
| if (compute_prefix && end - begin == 1 && !reach_final) { |
| aiter.Seek(begin); |
| SetLookAheadPrefix(aiter.Value()); |
| compute_weight = false; |
| } else if (compute_weight) { |
| SetLookAheadWeight(label_reachable_->ReachWeight()); |
| } |
| } |
| if (reach_final && compute_weight) |
| SetLookAheadWeight(reach_arc ? |
| Plus(LookAheadWeight(), lfinal) : lfinal); |
| |
| return reach_arc || reach_final; |
| } |
| |
| |
| // Label-lookahead relabeling class. |
| template <class A> |
| class LabelLookAheadRelabeler { |
| public: |
| typedef typename A::Label Label; |
| typedef LabelReachableData<Label> MatcherData; |
| typedef AddOnPair<MatcherData, MatcherData> D; |
| |
| // Relabels matcher Fst - initialization function object. |
| template <typename I> |
| LabelLookAheadRelabeler(I **impl); |
| |
| // Relabels arbitrary Fst. Class L should be a label-lookahead Fst. |
| template <class L> |
| static void Relabel(MutableFst<A> *fst, const L &mfst, |
| bool relabel_input) { |
| typename L::Impl *impl = mfst.GetImpl(); |
| D *data = impl->GetAddOn(); |
| LabelReachable<A> reachable(data->First() ? |
| data->First() : data->Second()); |
| reachable.Relabel(fst, relabel_input); |
| } |
| |
| // Returns relabeling pairs (cf. relabel.h::Relabel()). |
| // Class L should be a label-lookahead Fst. |
| // If 'avoid_collisions' is true, extra pairs are added to |
| // ensure no collisions when relabeling automata that have |
| // labels unseen here. |
| template <class L> |
| static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs, |
| bool avoid_collisions = false) { |
| typename L::Impl *impl = mfst.GetImpl(); |
| D *data = impl->GetAddOn(); |
| LabelReachable<A> reachable(data->First() ? |
| data->First() : data->Second()); |
| reachable.RelabelPairs(pairs, avoid_collisions); |
| } |
| }; |
| |
| template <class A> |
| template <typename I> inline |
| LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) { |
| Fst<A> &fst = (*impl)->GetFst(); |
| D *data = (*impl)->GetAddOn(); |
| const string name = (*impl)->Type(); |
| bool is_mutable = fst.Properties(kMutable, false); |
| MutableFst<A> *mfst = 0; |
| if (is_mutable) { |
| mfst = static_cast<MutableFst<A> *>(&fst); |
| } else { |
| mfst = new VectorFst<A>(fst); |
| data->IncrRefCount(); |
| delete *impl; |
| } |
| if (data->First()) { // reach_input |
| LabelReachable<A> reachable(data->First()); |
| reachable.Relabel(mfst, true); |
| if (!FLAGS_save_relabel_ipairs.empty()) { |
| vector<pair<Label, Label> > pairs; |
| reachable.RelabelPairs(&pairs, true); |
| WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs); |
| } |
| } else { |
| LabelReachable<A> reachable(data->Second()); |
| reachable.Relabel(mfst, false); |
| if (!FLAGS_save_relabel_opairs.empty()) { |
| vector<pair<Label, Label> > pairs; |
| reachable.RelabelPairs(&pairs, true); |
| WriteLabelPairs(FLAGS_save_relabel_opairs, pairs); |
| } |
| } |
| if (!is_mutable) { |
| *impl = new I(*mfst, name); |
| (*impl)->SetAddOn(data); |
| delete mfst; |
| data->DecrRefCount(); |
| } |
| } |
| |
| |
| // Generic lookahead matcher, templated on the FST definition |
| // - a wrapper around pointer to specific one. |
| template <class F> |
| class LookAheadMatcher { |
| public: |
| typedef F FST; |
| typedef typename F::Arc Arc; |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Label Label; |
| typedef typename Arc::Weight Weight; |
| typedef LookAheadMatcherBase<Arc> LBase; |
| |
| LookAheadMatcher(const F &fst, MatchType match_type) { |
| base_ = fst.InitMatcher(match_type); |
| if (!base_) |
| base_ = new SortedMatcher<F>(fst, match_type); |
| lookahead_ = false; |
| } |
| |
| LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) { |
| base_ = matcher.base_->Copy(safe); |
| lookahead_ = matcher.lookahead_; |
| } |
| |
| ~LookAheadMatcher() { delete base_; } |
| |
| // General matcher methods |
| LookAheadMatcher<F> *Copy(bool safe = false) const { |
| return new LookAheadMatcher<F>(*this, safe); |
| } |
| |
| MatchType Type(bool test) const { return base_->Type(test); } |
| void SetState(StateId s) { base_->SetState(s); } |
| bool Find(Label label) { return base_->Find(label); } |
| bool Done() const { return base_->Done(); } |
| const Arc& Value() const { return base_->Value(); } |
| void Next() { base_->Next(); } |
| const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); } |
| |
| uint64 Properties(uint64 props) const { return base_->Properties(props); } |
| |
| uint32 Flags() const { return base_->Flags(); } |
| |
| // Look-ahead methods |
| bool LookAheadLabel(Label label) const { |
| if (LookAheadCheck()) { |
| LBase *lbase = static_cast<LBase *>(base_); |
| return lbase->LookAheadLabel(label); |
| } else { |
| return true; |
| } |
| } |
| |
| bool LookAheadFst(const Fst<Arc> &fst, StateId s) { |
| if (LookAheadCheck()) { |
| LBase *lbase = static_cast<LBase *>(base_); |
| return lbase->LookAheadFst(fst, s); |
| } else { |
| return true; |
| } |
| } |
| |
| Weight LookAheadWeight() const { |
| if (LookAheadCheck()) { |
| LBase *lbase = static_cast<LBase *>(base_); |
| return lbase->LookAheadWeight(); |
| } else { |
| return Weight::One(); |
| } |
| } |
| |
| bool LookAheadPrefix(Arc *arc) const { |
| if (LookAheadCheck()) { |
| LBase *lbase = static_cast<LBase *>(base_); |
| return lbase->LookAheadPrefix(arc); |
| } else { |
| return false; |
| } |
| } |
| |
| void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) { |
| if (LookAheadCheck()) { |
| LBase *lbase = static_cast<LBase *>(base_); |
| lbase->InitLookAheadFst(fst, copy); |
| } |
| } |
| |
| private: |
| bool LookAheadCheck() const { |
| if (!lookahead_) { |
| lookahead_ = base_->Flags() & |
| (kInputLookAheadMatcher | kOutputLookAheadMatcher); |
| if (!lookahead_) { |
| FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined"; |
| } |
| } |
| return lookahead_; |
| } |
| |
| MatcherBase<Arc> *base_; |
| mutable bool lookahead_; |
| |
| void operator=(const LookAheadMatcher<Arc> &); // disallow |
| }; |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_LOOKAHEAD_MATCHER_H__ |