// blob: 7179cfb2569f2d812586365c91c227ac7476b241 [file] [log] [blame]
#ifndef __FST_IO_H__
#define __FST_IO_H__
// fst-io.h
// This is a copy of the OPENFST SDK application sample files ...
// except for the main functions ifdef'ed out
// 2007, 2008 Nuance Communications
//
// print-main.h compile-main.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
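// Classes and functions to print a binary Fst as text (FstPrinter) and to
// compile a binary Fst from textual input (FstReader).
// Includes the helper functions for fstprint.cc and fstcompile.cc (currently
// disabled with #if 0) that template the main on the arc type to support
// multiple and extensible arc types.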
#include <fstream>
#include <sstream>
#include "fst/lib/fst.h"
#include "fst/lib/fstlib.h"
#include "fst/lib/fst-decl.h"
#include "fst/lib/vector-fst.h"
#include "fst/lib/arcsort.h"
#include "fst/lib/invert.h"
namespace fst {
// Prints an Fst in a tab-separated text format.
//
// Each arc produces one line:
//   <source state> \t <dest state> \t <ilabel> [\t <olabel>] [\t <weight>]
// Each state that is final, or that has no arcs, produces one line:
//   <state> [\t <final weight>]
// Weights equal to Weight::One() are omitted. IDs are written as integers
// unless the corresponding symbol table is non-null, in which case the
// symbol found for the ID is written instead.
template <class A> class FstPrinter {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  // fst   : the Fst to print; held by reference, so it must outlive the
  //         printer.
  // isyms : symbol table for arc input labels, or null for numeric output.
  // osyms : symbol table for arc output labels, or null for numeric output.
  // ssyms : symbol table for state IDs, or null for numeric output.
  // accep : request acceptor format (no output-label column). It is only
  //         honoured when the Fst also has the kAcceptor property.
  FstPrinter(const Fst<A> &fst,
             const SymbolTable *isyms,
             const SymbolTable *osyms,
             const SymbolTable *ssyms,
             bool accep)
      : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
        accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {}

  // Prints the Fst to 'ostrm'. 'dest' is a name for the destination that is
  // used only in error messages. Nothing is written if the Fst has no start
  // state.
  void Print(ostream *ostrm, const string &dest) {
    ostrm_ = ostrm;
    dest_ = dest;
    StateId start = fst_.Start();
    if (start == kNoStateId)
      return;
    // The start state is printed first; every other state follows in
    // iteration order.
    PrintState(start);
    for (StateIterator< Fst<A> > siter(fst_);
         !siter.Done();
         siter.Next()) {
      StateId s = siter.Value();
      if (s != start)
        PrintState(s);
    }
  }

 private:
  // Maximum line length in text file (not used by the printer itself).
  static const int kLineLen = 8096;

  // Writes 'id' to the output stream, as its symbol if 'syms' is non-null
  // and as a plain integer otherwise. 'name' describes what kind of ID this
  // is and is reported when the lookup fails. An ID whose symbol lookup
  // yields the empty string is a fatal error: it is logged and the process
  // exits with status 1.
  void PrintId(int64 id, const SymbolTable *syms,
               const char *name) const {
    if (syms) {
      string symbol = syms->Find(id);
      if (symbol.empty()) {
        LOG(ERROR) << "FstPrinter: Integer " << id
                   << " is not mapped to any textual symbol"
                   << ", symbol table = " << syms->Name()
                   << ", destination = " << dest_
                   << ", id type = " << name;
        exit(1);
      }
      *ostrm_ << symbol;
    } else {
      *ostrm_ << id;
    }
  }

  void PrintStateId(StateId s) const {
    PrintId(s, ssyms_, "state ID");
  }

  void PrintILabel(Label l) const {
    PrintId(l, isyms_, "arc input label");
  }

  void PrintOLabel(Label l) const {
    PrintId(l, osyms_, "arc output label");
  }

  // Prints one line per outgoing arc of 's', followed by a line for the
  // state itself when it is final or has no arcs.
  void PrintState(StateId s) const {
    bool output = false;
    for (ArcIterator< Fst<A> > aiter(fst_, s);
         !aiter.Done();
         aiter.Next()) {
      // Bind by const reference; the arc is only read here.
      const Arc &arc = aiter.Value();
      PrintStateId(s);
      *ostrm_ << "\t";
      PrintStateId(arc.nextstate);
      *ostrm_ << "\t";
      PrintILabel(arc.ilabel);
      if (!accep_) {
        *ostrm_ << "\t";
        PrintOLabel(arc.olabel);
      }
      if (arc.weight != Weight::One())
        *ostrm_ << "\t" << arc.weight;
      *ostrm_ << "\n";
      output = true;
    }
    Weight final = fst_.Final(s);
    // A state with no arcs still gets a line so that it appears in the
    // output. If it is also non-final, its Weight::Zero() final weight is
    // written explicitly, because only Weight::One() is omitted.
    if (final != Weight::Zero() || !output) {
      PrintStateId(s);
      if (final != Weight::One()) {
        *ostrm_ << "\t" << final;
      }
      *ostrm_ << "\n";
    }
  }

  const Fst<A> &fst_;
  const SymbolTable *isyms_;     // ilabel symbol table
  const SymbolTable *osyms_;     // olabel symbol table
  const SymbolTable *ssyms_;     // slabel symbol table
  bool accep_;                   // print as acceptor when possible
  ostream *ostrm_;               // text FST destination
  string dest_;                  // text FST destination name
  DISALLOW_EVIL_CONSTRUCTORS(FstPrinter);
};
#if 0
// Main function for fstprint templated on the arc type.
//
// Reads a binary Fst from 'istrm' and prints it as text, to the file named
// by argv[2] when argc == 3 and to standard output otherwise. Symbol tables
// are taken from the FLAGS_isymbols / FLAGS_osymbols / FLAGS_ssymbols files
// when those flags are set, falling back to the tables stored in the Fst for
// input and output labels; FLAGS_numeric disables all symbol tables.
//
// Returns 0 on success and 1 if the Fst cannot be read or the output file
// cannot be opened. A symbol file that cannot be read terminates the process
// with exit(1).
template <class Arc>
int PrintMain(int argc, char **argv, istream &istrm,
              const FstReadOptions &opts) {
  Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts);
  if (!fst) return 1;

  string dest = "standard output";
  ostream *ostrm = &std::cout;
  if (argc == 3) {
    dest = argv[2];
    ostrm = new ofstream(argv[2]);
    if (!*ostrm) {
      LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2];
      // Report the failure to the caller and release what was allocated.
      delete ostrm;
      delete fst;
      return 1;
    }
  }
  ostrm->precision(9);

  const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
  if (!FLAGS_isymbols.empty() && !FLAGS_numeric) {
    isyms = SymbolTable::ReadText(FLAGS_isymbols);
    if (!isyms) exit(1);
  }
  if (!FLAGS_osymbols.empty() && !FLAGS_numeric) {
    osyms = SymbolTable::ReadText(FLAGS_osymbols);
    if (!osyms) exit(1);
  }
  if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) {
    ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
    if (!ssyms) exit(1);
  }
  // Fall back to the symbol tables stored in the Fst, if any.
  if (!isyms && !FLAGS_numeric)
    isyms = fst->InputSymbols();
  if (!osyms && !FLAGS_numeric)
    osyms = fst->OutputSymbols();

  FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor);
  fstprinter.Print(ostrm, dest);

  if (isyms && !FLAGS_save_isymbols.empty())
    isyms->WriteText(FLAGS_save_isymbols);
  if (osyms && !FLAGS_save_osymbols.empty())
    osyms->WriteText(FLAGS_save_osymbols);

  if (ostrm != &std::cout)
    delete ostrm;
  // NOTE(review): tables returned by SymbolTable::ReadText are not freed
  // here; isyms/osyms may instead point at tables owned by 'fst', so they
  // cannot be deleted unconditionally. Confirm ownership before adding
  // cleanup.
  delete fst;
  return 0;
}
#endif
// Compiles a VectorFst from a whitespace-separated text description.
//
// Each non-empty input line has one of these forms:
//   <state>                                        final state, weight One
//   <state> <weight>                               final state
//   <src> <dest> <ilabel>                          arc (acceptor only)
//   <src> <dest> <ilabel> <weight>                 arc (acceptor)
//   <src> <dest> <ilabel> <olabel>                 arc (transducer)
//   <src> <dest> <ilabel> <olabel> <weight>        arc
// The state in the first column of the first non-empty line becomes the
// start state. Any malformed input is logged and terminates the process
// with exit(1).
template <class A> class FstReader {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  // Reads 'istrm' to its end and builds the Fst; the stream is not retained.
  //   source : name of the input, used only in error messages.
  //   isyms  : symbol table for input labels, or null if labels are integers.
  //   osyms  : symbol table for output labels, or null if labels are integers.
  //   ssyms  : symbol table for state IDs, or null if states are integers.
  //   accep  : interpret the input as an acceptor (a fourth column is a
  //            weight rather than an output label).
  //   ikeep  : attach 'isyms' to the resulting Fst.
  //   okeep  : attach 'osyms' to the resulting Fst.
  //   nkeep  : keep state IDs as given instead of renumbering them densely
  //            in order of first appearance.
  FstReader(istream &istrm, const string &source,
            const SymbolTable *isyms, const SymbolTable *osyms,
            const SymbolTable *ssyms, bool accep, bool ikeep,
            bool okeep, bool nkeep)
      : nline_(0), source_(source),
        isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
        nstates_(0), keep_state_numbering_(nkeep) {
    char line[kLineLen];
    while (istrm.getline(line, kLineLen)) {
      ++nline_;
      vector<char *> col;
      // NOTE(review): SplitToVector is declared elsewhere; presumably it
      // splits 'line' in place on the given delimiters and, with the last
      // argument true, omits empty fields -- confirm.
      SplitToVector(line, "\n\t ", &col, true);
      if (col.size() == 0 || col[0][0] == '\0')  // empty line
        continue;
      if (col.size() > 5 ||
          (col.size() > 4 && accep) ||
          (col.size() == 3 && !accep)) {
        LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_
                   << ", line = " << nline_;
        exit(1);
      }
      StateId s = StrToStateId(col[0]);
      while (s >= fst_.NumStates())
        fst_.AddState();
      // Use the first state actually read as the start state. Testing
      // nline_ == 1 would leave the Fst without a start state whenever the
      // input begins with a blank line, since blank lines are skipped above.
      if (fst_.Start() == kNoStateId)
        fst_.SetStart(s);

      Arc arc;
      StateId d = s;
      switch (col.size()) {
        case 1:
          fst_.SetFinal(s, Weight::One());
          break;
        case 2:
          fst_.SetFinal(s, StrToWeight(col[1], true));
          break;
        case 3:  // only reachable for acceptors (see column check above)
          arc.nextstate = d = StrToStateId(col[1]);
          arc.ilabel = StrToILabel(col[2]);
          arc.olabel = arc.ilabel;
          arc.weight = Weight::One();
          fst_.AddArc(s, arc);
          break;
        case 4:
          arc.nextstate = d = StrToStateId(col[1]);
          arc.ilabel = StrToILabel(col[2]);
          if (accep) {
            arc.olabel = arc.ilabel;
            arc.weight = StrToWeight(col[3], false);
          } else {
            arc.olabel = StrToOLabel(col[3]);
            arc.weight = Weight::One();
          }
          fst_.AddArc(s, arc);
          break;
        case 5:
          arc.nextstate = d = StrToStateId(col[1]);
          arc.ilabel = StrToILabel(col[2]);
          arc.olabel = StrToOLabel(col[3]);
          arc.weight = StrToWeight(col[4], false);
          fst_.AddArc(s, arc);
      }
      // Make sure the destination state exists as well.
      while (d >= fst_.NumStates())
        fst_.AddState();
    }
    // getline() also fails when a line does not fit in the buffer. Without
    // this check such a line would silently end the read and the rest of
    // the input would be dropped.
    if (!istrm.eof()) {
      LOG(ERROR) << "FstReader: Read failed or line too long (limit = "
                 << kLineLen - 1 << " characters), source = " << source_
                 << ", line = " << nline_ + 1;
      exit(1);
    }
    if (ikeep)
      fst_.SetInputSymbols(isyms);
    if (okeep)
      fst_.SetOutputSymbols(osyms);
  }

  // Returns the compiled Fst; the reference is valid while the reader lives.
  const VectorFst<A> &Fst() const { return fst_; }

 private:
  // Maximum line length in text file.
  static const int kLineLen = 8096;

  // Converts the field 's' to an ID. With a symbol table the ID is the one
  // found for the symbol; without one, 's' must be a non-negative decimal
  // integer. 'name' describes the field for error messages. Exits on error.
  int64 StrToId(const char *s, const SymbolTable *syms,
                const char *name) const {
    int64 n;
    if (syms) {
      n = syms->Find(s);
      if (n < 0) {
        LOG(ERROR) << "FstReader: Symbol \"" << s
                   << "\" is not mapped to any integer " << name
                   << ", symbol table = " << syms->Name()
                   << ", source = " << source_ << ", line = " << nline_;
        exit(1);
      }
    } else {
      char *p;
      n = strtoll(s, &p, 10);
      // Reject empty fields, trailing garbage and negative values.
      if (p == s || *p != '\0' || n < 0) {
        LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s
                   << "\", source = " << source_ << ", line = " << nline_;
        exit(1);
      }
    }
    return n;
  }

  // Converts a state field to a state ID. Unless state numbering is kept,
  // IDs are remapped to 0, 1, 2, ... in order of first appearance.
  StateId StrToStateId(const char *s) {
    StateId n = StrToId(s, ssyms_, "state ID");
    if (keep_state_numbering_)
      return n;
    // remap state IDs to make dense set
    typename hash_map<StateId, StateId>::const_iterator it = states_.find(n);
    if (it == states_.end()) {
      states_[n] = nstates_;
      return nstates_++;
    } else {
      return it->second;
    }
  }

  // Converts an input-label field to a label.
  Label StrToILabel(const char *s) const {
    return StrToId(s, isyms_, "arc ilabel");
  }

  // Converts an output-label field to a label.
  Label StrToOLabel(const char *s) const {
    return StrToId(s, osyms_, "arc olabel");
  }

  // Parses a weight field with the weight type's stream extraction
  // operator. Weight::Zero() is rejected unless 'allow_zero' is set.
  // Exits on error.
  Weight StrToWeight(const char *s, bool allow_zero) const {
    Weight w;
    istringstream strm(s);
    strm >> w;
    if (strm.fail() || (!allow_zero && w == Weight::Zero())) {
      LOG(ERROR) << "FstReader: Bad weight = \"" << s
                 << "\", source = " << source_ << ", line = " << nline_;
      exit(1);
    }
    return w;
  }

  VectorFst<A> fst_;
  size_t nline_;                       // number of lines read so far
  string source_;                      // text FST source name
  const SymbolTable *isyms_;           // ilabel symbol table
  const SymbolTable *osyms_;           // olabel symbol table
  const SymbolTable *ssyms_;           // slabel symbol table
  hash_map<StateId, StateId> states_;  // state ID map
  StateId nstates_;                    // number of seen states
  bool keep_state_numbering_;
  DISALLOW_EVIL_CONSTRUCTORS(FstReader);
};
#if 0
// Main function for fstcompile templated on the arc type. Last two
// arguments unneeded since fstcompile passes the arc type as a flag
// unlike the other mains, which infer the arc type from an input Fst.
template <class Arc>
int CompileMain(int argc, char **argv, istream& /* strm */,
const FstReadOptions & /* opts */) {
char *ifilename = "standard input";
istream *istrm = &std::cin;
if (argc > 1 && strcmp(argv[1], "-") != 0) {
ifilename = argv[1];
istrm = new ifstream(ifilename);
if (!*istrm) {
LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename;
return 1;
}
}
const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
if (!FLAGS_isymbols.empty()) {
isyms = SymbolTable::ReadText(FLAGS_isymbols);
if (!isyms) exit(1);
}
if (!FLAGS_osymbols.empty()) {
osyms = SymbolTable::ReadText(FLAGS_osymbols);
if (!osyms) exit(1);
}
if (!FLAGS_ssymbols.empty()) {
ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
if (!ssyms) exit(1);
}
FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms,
FLAGS_acceptor, FLAGS_keep_isymbols,
FLAGS_keep_osymbols, FLAGS_keep_state_numbering);
const Fst<Arc> *fst = &fstreader.Fst();
if (FLAGS_fst_type != "vector") {
fst = Convert<Arc>(*fst, FLAGS_fst_type);
if (!fst) return 1;
}
fst->Write(argc > 2 ? argv[2] : "");
if (istrm != &std::cin)
delete istrm;
return 0;
}
#endif
} // namespace fst
#endif /* __FST_IO_H__ */