@@ -63,7 +63,9 @@ struct StreamingState {
6363 vector<float > batch_buffer_;
6464 vector<float > previous_state_c_;
6565 vector<float > previous_state_h_;
66+ bool keep_emissions_ = false ;
6667
68+ vector<double > probs_;
6769 ModelState* model_;
6870 DecoderState decoder_state_;
6971
@@ -134,7 +136,42 @@ StreamingState::intermediateDecode() const
134136Metadata*
135137StreamingState::intermediateDecodeWithMetadata (unsigned int num_results) const
136138{
137- return model_->decode_metadata (decoder_state_, num_results);
139+ Metadata *m = model_->decode_metadata (decoder_state_, num_results);
140+
141+ if (keep_emissions_) {
142+
143+ const size_t alphabet_size = model_->alphabet_ .GetSize ();
144+ const int num_timesteps = probs_.size () / (ModelState::BATCH_SIZE * (alphabet_size + 1 ));
145+
146+ AcousticModelEmissions* emissions = (AcousticModelEmissions*)malloc (sizeof (AcousticModelEmissions));
147+
148+ emissions->num_symbols = alphabet_size;
149+ emissions->num_timesteps = num_timesteps;
150+ emissions->symbols = (const char **)malloc (sizeof (char *)*alphabet_size + 1 );
151+ for (int i = 0 ; i < alphabet_size; i++) {
152+ emissions->symbols [i] = strdup (model_->alphabet_ .DecodeSingle (i).c_str ());
153+ }
154+ emissions->symbols [alphabet_size] = strdup (" \t " );
155+
156+ double * probs = (double *)malloc (sizeof (double )*(alphabet_size + 1 )*num_timesteps);
157+ memcpy (probs, probs_.data (), sizeof (double )*(alphabet_size + 1 )*num_timesteps);
158+
159+ emissions->emissions = probs;
160+
161+ Metadata* ret = (Metadata*)malloc (sizeof (Metadata));
162+
163+ Metadata metadata {
164+ m->transcripts , // transcripts
165+ m->num_transcripts , // num_transcripts
166+ emissions,
167+ };
168+
169+ memcpy (ret, &metadata, sizeof (Metadata));
170+
171+ return ret;
172+ }
173+
174+ return m;
138175}
139176
140177char *
@@ -148,7 +185,42 @@ Metadata*
148185StreamingState::finishStreamWithMetadata (unsigned int num_results)
149186{
150187 flushBuffers (true );
151- return model_->decode_metadata (decoder_state_, num_results);
188+ Metadata *m = model_->decode_metadata (decoder_state_, num_results);
189+
190+ if (keep_emissions_) {
191+
192+ const size_t alphabet_size = model_->alphabet_ .GetSize ();
193+ const int num_timesteps = probs_.size () / (ModelState::BATCH_SIZE * (alphabet_size + 1 ));
194+
195+ AcousticModelEmissions* emissions = (AcousticModelEmissions*)malloc (sizeof (AcousticModelEmissions));
196+
197+ emissions->num_symbols = alphabet_size;
198+ emissions->num_timesteps = num_timesteps;
199+ emissions->symbols = (const char **)malloc (sizeof (char *)*alphabet_size + 1 );
200+ for (int i = 0 ; i < alphabet_size; i++) {
201+ emissions->symbols [i] = strdup (model_->alphabet_ .DecodeSingle (i).c_str ());
202+ }
203+ emissions->symbols [alphabet_size] = strdup (" \t " );
204+
205+ double * probs = (double *)malloc (sizeof (double )*(alphabet_size + 1 )*num_timesteps);
206+ memcpy (probs, probs_.data (), sizeof (double )*(alphabet_size + 1 )*num_timesteps);
207+
208+ emissions->emissions = probs;
209+
210+ Metadata* ret = (Metadata*)malloc (sizeof (Metadata));
211+
212+ Metadata metadata {
213+ m->transcripts , // transcripts
214+ m->num_transcripts , // num_transcripts
215+ emissions,
216+ };
217+
218+ memcpy (ret, &metadata, sizeof (Metadata));
219+
220+ return ret;
221+ }
222+
223+ return m;
152224}
153225
154226void
@@ -253,7 +325,9 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
253325
254326 // Convert logits to double
255327 vector<double > inputs (logits.begin (), logits.end ());
256-
328+ if (keep_emissions_) {
329+ probs_ = inputs;
330+ }
257331 decoder_state_.next (inputs.data (),
258332 n_frames,
259333 num_classes);
@@ -476,6 +550,41 @@ STT_CreateStream(ModelState* aCtx,
476550 return STT_ERR_OK;
477551}
478552
553+ int
554+ CreateStreamWithEmissions (ModelState* aCtx,
555+ StreamingState** retval)
556+ {
557+ *retval = nullptr ;
558+
559+ std::unique_ptr<StreamingState> ctx (new StreamingState ());
560+ if (!ctx) {
561+ std::cerr << " Could not allocate streaming state." << std::endl;
562+ return STT_ERR_FAIL_CREATE_STREAM;
563+ }
564+
565+ ctx->audio_buffer_ .reserve (aCtx->audio_win_len_ );
566+ ctx->mfcc_buffer_ .reserve (aCtx->mfcc_feats_per_timestep_ );
567+ ctx->mfcc_buffer_ .resize (aCtx->n_features_ *aCtx->n_context_ , 0 .f );
568+ ctx->batch_buffer_ .reserve (aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_ );
569+ ctx->previous_state_c_ .resize (aCtx->state_size_ , 0 .f );
570+ ctx->previous_state_h_ .resize (aCtx->state_size_ , 0 .f );
571+ ctx->model_ = aCtx;
572+ ctx->keep_emissions_ = true ;
573+
574+ const int cutoff_top_n = 40 ;
575+ const double cutoff_prob = 1.0 ;
576+
577+ ctx->decoder_state_ .init (aCtx->alphabet_ ,
578+ aCtx->beam_width_ ,
579+ cutoff_prob,
580+ cutoff_top_n,
581+ aCtx->scorer_ ,
582+ aCtx->hot_words_ );
583+
584+ *retval = ctx.release ();
585+ return STT_ERR_OK;
586+ }
587+
479588void
480589STT_FeedAudioContent (StreamingState* aSctx,
481590 const short * aBuffer,
@@ -562,6 +671,22 @@ STT_SpeechToTextWithMetadata(ModelState* aCtx,
562671 return STT_FinishStreamWithMetadata (ctx, aNumResults);
563672}
564673
674+ Metadata*
675+ STT_SpeechToTextWithEmissions (ModelState* aCtx,
676+ const short * aBuffer,
677+ unsigned int aBufferSize,
678+ unsigned int aNumResults)
679+ {
680+ StreamingState* ctx;
681+ int status = CreateStreamWithEmissions (aCtx, &ctx);
682+ if (status != STT_ERR_OK) {
683+ return nullptr ;
684+ }
685+ STT_FeedAudioContent (ctx, aBuffer, aBufferSize);
686+
687+ return STT_FinishStreamWithMetadata (ctx, aNumResults);
688+ }
689+
565690void
566691STT_FreeStream (StreamingState* aSctx)
567692{
@@ -581,10 +706,24 @@ STT_FreeMetadata(Metadata* m)
581706 }
582707
583708 free ((void *)m->transcripts );
709+
710+ // Clean up logits if they are not NULL
711+ if (m->emissions ) {
712+
713+ if (m->emissions ->symbols ) {
714+ for (int i = 0 ; i < m->emissions ->num_symbols + 1 ; i++) {
715+ free ((void *)m->emissions ->symbols [i]);
716+ }
717+ free ((void *)m->emissions ->symbols );
718+ }
719+ if (m->emissions ->emissions ) {
720+ free ((void *)m->emissions ->emissions );
721+ }
722+ free ((void *)m->emissions );
723+ }
584724 free (m);
585725 }
586726}
587-
588727void
589728STT_FreeString (char * str)
590729{
0 commit comments