// blob: 0f82213e0dda9c4099b9576dd6976ea72094018f [file] [log] [blame]
/*---------------------------------------------------------------------------*
* vocab.cpp *
* *
* Copyright 2007, 2008 Nuance Communications, Inc. *
* *
* Licensed under the Apache License, Version 2.0 (the 'License'); *
* you may not use this file except in compliance with the License. *
* *
* You may obtain a copy of the License at *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, software *
* distributed under the License is distributed on an 'AS IS' BASIS, *
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
* See the License for the specific language governing permissions and *
* limitations under the License. *
* *
*---------------------------------------------------------------------------*/
#include <string>
#include <iostream>
#include <stdexcept>
#include "ESR_Locale.h"
#include "LCHAR.h"
#include "pstdio.h"
#include "ESR_Session.h"
#include "SR_Vocabulary.h"
#include "vocab.h"
// Capacity (in LCHARs) of the phrase buffer used by Pronunciation::lookup().
#define MAX_LINE_LENGTH 256
// Capacity (in LCHARs) of the pronunciation buffer used by Pronunciation::lookup().
#define MAX_PRONS_LENGTH 1024
// Set to a non-zero value to compile in the "DEBUG:" trace output.
#define DEBUG 0
// Character used as the left/right context where a phoneme has no neighbour
// on that side. It is assigned to `char` variables, so it must be a character
// literal (not a string literal).
#define GENERIC_CONTEXT '#'
// Loads the vocabulary file named by vocFileName via SR_VocabularyLoad and
// stores the resulting handle in m_hVocab.
//
// On any return code other than ESR_SUCCESS the textual form of the code is
// written to standard output and the process exits with status -1.
Vocabulary::Vocabulary( std::string const & vocFileName )
{
    const ESR_ReturnCode status = SR_VocabularyLoad(vocFileName.c_str(), &m_hVocab);
    if (status == ESR_SUCCESS)
    {
        return;
    }

    std::cout << "Error: " << ESR_rc2str(status) <<std::endl;
    exit (-1);
}
// Hands the handle obtained in the constructor (m_hVocab) to
// SR_VocabularyDestroy. The return code of that call is ignored.
Vocabulary::~Vocabulary()
{
    SR_VocabularyDestroy( m_hVocab );
}
// Creates an empty Pronunciation; the members are left to their own default
// construction and nothing else is done here.
Pronunciation::Pronunciation()
{
}
// Nothing to release explicitly; the body is intentionally empty.
Pronunciation::~Pronunciation()
{
}
// Discards every stored pronunciation and every stored model-ID list, leaving
// both m_Prons and m_ModelIDs empty.
void Pronunciation::clear()
{
    m_Prons.clear();

    // Empty each per-pronunciation ID list, then drop the lists themselves.
    for (size_t listIdx = 0; listIdx != m_ModelIDs.size(); ++listIdx)
    {
        m_ModelIDs[listIdx].clear();
    }
    m_ModelIDs.clear();
}
// Looks up the pronunciations of `phrase` in `vocab` and appends each one to
// this object through addPron().
//
// The result buffer filled by SR_VocabularyGetPronunciation is walked as a
// sequence of consecutive NUL-terminated strings spanning `len` characters;
// every such string becomes one stored pronunciation.
//
// clear() is not called here, so pronunciations from earlier lookups are kept
// and the new ones are added after them.
//
// @param vocab   vocabulary whose SR_Vocabulary handle is queried
// @param phrase  text to look up; must be shorter than MAX_LINE_LENGTH
// @return getPronCount() after the lookup, i.e. the total number of stored
//         pronunciations. On failure (phrase too long, or a return code other
//         than ESR_SUCCESS) an "ERROR: ..." line is written to standard output
//         and nothing is added.
int Pronunciation::lookup( Vocabulary & vocab, std::string & phrase )
{
    // The phrase is copied into a fixed-size buffer below; reject anything
    // that would not fit together with its terminating NUL instead of
    // overflowing the stack buffer.
    if (phrase.size() >= MAX_LINE_LENGTH)
    {
        std::cout <<"ERROR: phrase of " << phrase.size()
                  << " characters exceeds the lookup limit of "
                  << (MAX_LINE_LENGTH - 1) << std::endl;
        return getPronCount();
    }

    LCHAR s[MAX_LINE_LENGTH];
    strcpy (s, phrase.c_str() ); // No conversion for std::string to wchar
    LCHAR* c_phrase = s;

    // Zero the whole result buffer (not just its first element) so that any
    // part the lookup does not write reads as empty strings.
    LCHAR prons[MAX_PRONS_LENGTH];
    memset (prons, 0x00, sizeof(prons));

    // NOTE(review): `len` is assumed to be an in/out parameter carrying the
    // capacity of `prons` on entry -- confirm against SR_Vocabulary.h.
    // Initialising it is harmless if it turns out to be output-only, and it
    // avoids passing an indeterminate value either way.
    size_t len = MAX_PRONS_LENGTH;

    SR_Vocabulary *p_SRVocab = vocab.getSRVocabularyHandle();
#if DEBUG
    std::cout << "DEBUG: " << phrase <<" to be looked up" << std::endl;
#endif
    ESR_ReturnCode rc = SR_VocabularyGetPronunciation( p_SRVocab, c_phrase, prons, &len );
    if (rc != ESR_SUCCESS)
    {
        std::cout <<"ERROR: " << ESR_rc2str(rc) << std::endl;
    }
    else
    {
#if DEBUG
        std::cout <<"OUTPUT: " << prons << " num " << len << std::endl;
#endif
        // Never start reading a pronunciation beyond the end of `prons`.
        if (len > MAX_PRONS_LENGTH)
        {
            len = MAX_PRONS_LENGTH;
        }

        size_t len_used = 0;
        while (len_used < len)
        {
            LCHAR *pron = &prons[0] + len_used;
            // Advance past this string and its terminating NUL.
            len_used += LSTRLEN(pron) + 1;
#if DEBUG
            std::cout << "DEBUG: used " << len_used << " now " << LSTRLEN(pron) << std::endl;
#endif
            std::string pronString( pron ); // wstring conversion if needed
            addPron( pronString );
#if DEBUG
            std::cout << "DEBUG: " << phrase << " " << pron << std::endl;
#endif
        }
    }
    return getPronCount();
}
// Appends a copy of `s` to the stored pronunciations.
//
// @return the number of stored pronunciations after the append
int Pronunciation::addPron( std::string & s )
{
    m_Prons.push_back( s );
    const int countAfterAdd = static_cast<int>( m_Prons.size() );
    return countAfterAdd;
}
// @return how many pronunciations are currently stored in m_Prons
int Pronunciation::getPronCount()
{
    const int storedProns = static_cast<int>( m_Prons.size() );
    return storedProns;
}
// Copies the stored pronunciation at position `index` into `s`.
//
// @param index  position of the pronunciation in m_Prons
// @param s      receives the pronunciation; left untouched on failure
// @return true when `index` was valid and `s` was written; false when
//         m_Prons.at() rejected the index, in which case an "out_of_range: ..."
//         line is written to standard error.
bool Pronunciation::getPron( int index, std::string &s )
{
    try {
        s = m_Prons.at(index);
    }
    catch(const std::out_of_range& err) {
        std::cerr << "out_of_range: " << err.what() << std::endl;
        // Report the failure instead of claiming success with a stale `s`.
        return false;
    }
    return true;
}
// Walks over every stored pronunciation. The "Pron #<n>: <pron>" line for
// each one is only written to standard output when DEBUG is non-zero; with
// DEBUG set to 0 the function produces no output.
void Pronunciation::print()
{
    std::string current;
    int pronIdx = 0;
    while (pronIdx < getPronCount()) {
        getPron(pronIdx, current);
#if DEBUG
        std::cout << "Pron #" << pronIdx << ": " << current << std::endl;
#endif
        ++pronIdx;
    }
}
// Writes the model IDs of every stored pronunciation to standard output,
// each ID preceded by a single space.
//
// Note that only the IDs themselves are printed unconditionally; the
// per-pronunciation heading and the trailing newline are emitted only when
// DEBUG is non-zero.
void Pronunciation::printModelIDs()
{
    std::string current;
    int pronIdx = 0;
    while (pronIdx < getPronCount()) {
        getPron(pronIdx, current);
#if DEBUG
        std::cout << " Pron #" << pronIdx << ": " << current << std::endl;
        std::cout << " Model IDs: ";
#endif
        int modelIdx = 0;
        while (modelIdx < getModelCount(pronIdx)) {
            std::cout << " " << getModelID(pronIdx, modelIdx);
            ++modelIdx;
        }
#if DEBUG
        std::cout << std::endl;
#endif
        ++pronIdx;
    }
}
// @param pronIndex  position of the pronunciation to measure
// @return the number of characters in that pronunciation string. When
//         getPron() does not fill in the string, the local stays empty and
//         0 is returned.
int Pronunciation::getPhonemeCount( int pronIndex )
{
    std::string pron;
    getPron( pronIndex, pron );
    const int phonemeCount = static_cast<int>( pron.size() );
    return phonemeCount;
}
// Fetches the single character at position `picIndex` of pronunciation
// `pronIndex` and stores it in `phoneme` as a one-character string.
//
// @return always true
// @throws std::out_of_range (from std::string::at) when `picIndex` is not a
//         valid position in the pronunciation; `phoneme` is then unchanged.
bool Pronunciation::getPhoneme( int pronIndex, int picIndex , std::string &phoneme )
{
    std::string pron;
    getPron( pronIndex, pron );

    const char symbol = pron.at( picIndex );
    phoneme = symbol;
    return true;
}
// Builds the three-character phoneme-in-context string for one phoneme of a
// stored pronunciation: left neighbour, the phoneme itself, right neighbour.
// Where there is no neighbour on a side, GENERIC_CONTEXT is used instead.
//
// For a pronunciation of exactly one phoneme, `picIndex` is ignored and that
// single phoneme is used with GENERIC_CONTEXT on both sides.
//
// @param pronIndex  position of the pronunciation in this object
// @param picIndex   position of the phoneme within that pronunciation
// @param pic        receives the three-character result
// @return always true
// @throws std::out_of_range (from std::string::at) when `picIndex` is not a
//         valid position, which includes the case where the pronunciation
//         could not be fetched and is therefore empty.
bool Pronunciation::getPIC( int pronIndex, int picIndex, std::string &pic )
{
    std::string pron;
    getPron( pronIndex, pron );
    const int numPhonemes = pron.size();

    char lphon = GENERIC_CONTEXT;
    char rphon = GENERIC_CONTEXT;
    char cphon;

    if ( 1==numPhonemes ) {
        cphon = pron.at(0);
    }
    else {
        cphon = pron.at(picIndex);
        if ( picIndex > 0 ) {
            lphon = pron.at(picIndex-1);
        }
        if ( picIndex < numPhonemes-1 ) {
            rphon = pron.at(picIndex+1);
        }
    }

    // Append character by character: `lphon + cphon + rphon` would add the
    // character codes as integers rather than concatenate them.
    pic.clear();
    pic += lphon;
    pic += cphon;
    pic += rphon;
    return true;
}
// Looks up one model ID per phoneme for every stored pronunciation and
// appends the resulting ID list for each pronunciation to m_ModelIDs.
//
// Each phoneme is looked up together with its left and right neighbour in the
// pronunciation; GENERIC_CONTEXT stands in where there is no neighbour on
// that side (first/last phoneme, or a single-phoneme pronunciation).
//
// The lists are appended without clearing m_ModelIDs first, so the caller is
// responsible for calling clear() if earlier results must not be kept.
//
// @param acoustic  model whose CA handle is passed to
//                  CA_ArbdataGetModelIdsForPIC
// @return the total number of phonemes processed across all pronunciations
int Pronunciation::lookupModelIDs( AcousticModel &acoustic )
{
    std::string pron;
    const int numProns = getPronCount();
    int totalCount = 0;

    for (int ii = 0; ii < numProns; ii++)
    {
        getPron( ii, pron );
        const int numPhonemes = pron.size();
        std::vector<int> idList; // One ID per phoneme of this pronunciation

        for (int jj = 0; jj < numPhonemes; jj++)
        {
            char lphon = GENERIC_CONTEXT;
            char rphon = GENERIC_CONTEXT;
            const char cphon = pron.at(jj);
            if (jj > 0) {
                lphon = pron.at(jj-1);
            }
            if (jj < numPhonemes-1) {
                rphon = pron.at(jj+1);
            }

            int id = CA_ArbdataGetModelIdsForPIC( acoustic.getCAModelHandle(), lphon, cphon, rphon );
#if DEBUG
            std::cout <<"DEBUG model id: " << lphon <<cphon << rphon << " "<< id << std::endl;
#endif
            idList.push_back(id);
        }

        m_ModelIDs.push_back(idList);
        totalCount += numPhonemes;
    }
    return totalCount;
}
// @param pronIndex  position of the pronunciation; not range-checked
// @return the number of model IDs stored for that pronunciation
int Pronunciation::getModelCount( int pronIndex )
{
    const int modelCount = static_cast<int>( m_ModelIDs[pronIndex].size() );
    return modelCount;
}
int Pronunciation::getModelID( int pronIndex, int modelPos )
{
return m_ModelIDs[pronIndex][modelPos];
}
// Loads the arbdata file named by arbFileName via CA_LoadArbdata and keeps
// the returned handle in m_CA_Arbdata.
//
// When the returned handle is null, an error naming the file is written to
// standard output and the process exits with status -1.
AcousticModel::AcousticModel( std::string & arbFileName )
{
    m_CA_Arbdata = CA_LoadArbdata( arbFileName.c_str() );
    if (m_CA_Arbdata)
    {
        return;
    }

    std::cout << "Error: while trying to load " << arbFileName.c_str() << std::endl;
    exit (-1);
}
// Hands the handle obtained in the constructor (m_CA_Arbdata) to
// CA_FreeArbdata.
AcousticModel::~AcousticModel()
{
    CA_FreeArbdata( m_CA_Arbdata );
}
int AcousticModel::getStateIndices(int id, std::vector<int> & stateIDs)
{
srec_arbdata *allotree = (srec_arbdata*) m_CA_Arbdata;
int numStates = allotree->hmm_infos[id].num_states;
#if DEBUG
std::cout << "getStateIndices: count = " << numStates <<std::endl;
#endif
for (int ii=0; ii <numStates; ii++ ) {
stateIDs.push_back( allotree->hmm_infos[id].state_indices[ii] );
#if DEBUG
std::cout << allotree->hmm_infos[id].state_indices[ii] ;
#endif
}
#if DEBUG
std::cout << std::endl;
#endif
return stateIDs.size();
}