Skip to content

Commit d3c0b78

Browse files
authored
getting lgits to client code
1 parent a694187 commit d3c0b78

File tree

6 files changed

+215
-7
lines changed

6 files changed

+215
-7
lines changed

native_client/args.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ bool init_from_array_of_bytes = false;
3939
int json_candidate_transcripts = 3;
4040

4141
int stream_size = 0;
42-
42+
bool keep_emissions = false;
4343
int extended_stream_size = 0;
4444

4545
char* hot_words = NULL;
@@ -59,6 +59,7 @@ void PrintHelp(const char* bin)
5959
"\t--lm_beta LM_BETA\t\tValue for language model beta param (float)\n"
6060
"\t-t\t\t\t\tRun in benchmark mode, output mfcc & inference time\n"
6161
"\t--extended\t\t\tOutput string from extended metadata\n"
62+
"\t--keep_emissions\t\t\tSave the output of the acoustic model\n"
6263
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
6364
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
6465
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
@@ -85,6 +86,7 @@ bool ProcessArgs(int argc, char** argv)
8586
{"lm_beta", required_argument, nullptr, 'd'},
8687
{"t", no_argument, nullptr, 't'},
8788
{"extended", no_argument, nullptr, 'e'},
89+
{"keep_emissions", no_argument, nullptr, 'L'},
8890
{"json", no_argument, nullptr, 'j'},
8991
{"init_from_bytes", no_argument, nullptr, 'B'},
9092
{"candidate_transcripts", required_argument, nullptr, 150},
@@ -139,6 +141,10 @@ bool ProcessArgs(int argc, char** argv)
139141
case 'e':
140142
extended_metadata = true;
141143
break;
144+
145+
case 'L':
146+
keep_emissions = true;
147+
break;
142148

143149
case 'j':
144150
json_output = true;

native_client/client.cc

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,36 @@ MetadataToJSON(Metadata* result)
155155
}
156156
}
157157

158+
if (keep_emissions && result->emissions != NULL) {
159+
int num_timesteps = result->emissions->num_timesteps;
160+
int num_symbols = result->emissions->num_symbols;
161+
int class_dim = num_symbols + 1;
162+
const char **symbol_table = result->emissions->symbols;
163+
out_string << ",\n" << R"("alphabet")" << ":[";
164+
for(int i = 0; i < class_dim; i++) {
165+
out_string << "\"" << symbol_table[i] << "\"";
166+
if(i < class_dim - 1) {
167+
out_string << ", ";
168+
}
169+
}
170+
out_string << "],\n" << R"("emissions")" << ":[\n";
171+
for(int i = 0; i < num_timesteps; i++) {
172+
out_string << "[";
173+
for(int j = 0; j < num_symbols; j++) {
174+
out_string << result->emissions->emissions[i * num_symbols + j];
175+
if(j < num_symbols - 1) {
176+
out_string << ", ";
177+
}
178+
}
179+
out_string << "]";
180+
if(i < num_timesteps - 1) {
181+
out_string << ",";
182+
}
183+
out_string << "\n";
184+
}
185+
out_string << "\n]";
186+
}
187+
158188
out_string << "\n}\n";
159189

160190
return strdup(out_string.str().c_str());
@@ -169,14 +199,18 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
169199
clock_t stt_start_time = clock();
170200

171201
// sphinx-doc: c_ref_inference_start
172-
if (extended_output) {
202+
if (extended_output && !keep_emissions) {
173203
Metadata *result = STT_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1);
174204
res.string = CandidateTranscriptToString(&result->transcripts[0]);
175205
STT_FreeMetadata(result);
176-
} else if (json_output) {
206+
} else if (json_output && !keep_emissions) {
177207
Metadata *result = STT_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, json_candidate_transcripts);
178208
res.string = MetadataToJSON(result);
179209
STT_FreeMetadata(result);
210+
} else if (keep_emissions) {
211+
Metadata *result = STT_SpeechToTextWithEmissions(aCtx, aBuffer, aBufferSize, json_candidate_transcripts);
212+
res.string = MetadataToJSON(result);
213+
STT_FreeMetadata(result);
180214
} else if (stream_size > 0) {
181215
StreamingState* ctx;
182216
int status = STT_CreateStream(aCtx, &ctx);

native_client/coqui-stt.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ typedef struct CandidateTranscript {
5252
*/
5353
const double confidence;
5454
} CandidateTranscript;
55+
/**
56+
* @brief An structure to contain emissions (the softmax output of individual
57+
* timesteps) from the acoustic model.
58+
*
59+
* @member The layout of the emissions member is time major, thus to access the
60+
* probability of symbol j at timestep i you would use
61+
* emissions[i * num_symbols + j]
62+
*/
63+
typedef struct AcousticModelEmissions {
64+
/** number of symbols in the alphabet, including CTC blank */
65+
int num_symbols;
66+
/** num_symbols long array of NUL-terminated strings */
67+
const char **symbols;
68+
/** total number of timesteps */
69+
int num_timesteps;
70+
/** num_timesteps long array, each pointer is a num_symbols long array */
71+
const double *emissions;
72+
} AcousticModelEmissions;
5573

5674
/**
5775
* @brief An array of CandidateTranscript objects computed by the model.
@@ -61,6 +79,8 @@ typedef struct Metadata {
6179
const CandidateTranscript* const transcripts;
6280
/** Size of the transcripts array */
6381
const unsigned int num_transcripts;
82+
/** Logits and information to decode them **/
83+
const AcousticModelEmissions* const emissions;
6484
} Metadata;
6585

6686
#endif /* SWIG_ERRORS_ONLY */
@@ -306,6 +326,13 @@ Metadata* STT_SpeechToTextWithMetadata(ModelState* aCtx,
306326
*
307327
* @return Zero for success, non-zero on failure.
308328
*/
329+
330+
STT_EXPORT
331+
Metadata* STT_SpeechToTextWithEmissions(ModelState* aCtx,
332+
const short* aBuffer,
333+
unsigned int aBufferSize,
334+
unsigned int aNumResults);
335+
309336
STT_EXPORT
310337
int STT_CreateStream(ModelState* aCtx,
311338
StreamingState** retval);

native_client/ctcdecode/output.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct Output {
1010
double confidence;
1111
std::vector<unsigned int> tokens;
1212
std::vector<unsigned int> timesteps;
13+
std::vector<std::vector<std::pair<int, double>>> probs;
1314
};
1415

1516
struct FlashlightOutput {

native_client/modelstate.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ ModelState::decode_metadata(const DecoderState& state,
6969
Metadata metadata {
7070
transcripts, // transcripts
7171
num_returned, // num_transcripts
72+
NULL,
7273
};
7374
memcpy(ret, &metadata, sizeof(Metadata));
7475
return ret;

native_client/stt.cc

Lines changed: 143 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ struct StreamingState {
6363
vector<float> batch_buffer_;
6464
vector<float> previous_state_c_;
6565
vector<float> previous_state_h_;
66+
bool keep_emissions_ = false;
6667

68+
vector<double> probs_;
6769
ModelState* model_;
6870
DecoderState decoder_state_;
6971

@@ -134,7 +136,42 @@ StreamingState::intermediateDecode() const
134136
Metadata*
135137
StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const
136138
{
137-
return model_->decode_metadata(decoder_state_, num_results);
139+
Metadata *m = model_->decode_metadata(decoder_state_, num_results);
140+
141+
if (keep_emissions_) {
142+
143+
const size_t alphabet_size = model_->alphabet_.GetSize();
144+
const int num_timesteps = probs_.size() / (ModelState::BATCH_SIZE * (alphabet_size + 1));
145+
146+
AcousticModelEmissions* emissions = (AcousticModelEmissions*)malloc(sizeof(AcousticModelEmissions));
147+
148+
emissions->num_symbols = alphabet_size;
149+
emissions->num_timesteps = num_timesteps;
150+
emissions->symbols = (const char**)malloc(sizeof(char*)*alphabet_size + 1);
151+
for (int i = 0; i < alphabet_size; i++) {
152+
emissions->symbols[i] = strdup(model_->alphabet_.DecodeSingle(i).c_str());
153+
}
154+
emissions->symbols[alphabet_size] = strdup("\t");
155+
156+
double* probs = (double*)malloc(sizeof(double)*(alphabet_size + 1)*num_timesteps);
157+
memcpy(probs, probs_.data(), sizeof(double)*(alphabet_size + 1)*num_timesteps);
158+
159+
emissions->emissions = probs;
160+
161+
Metadata* ret = (Metadata*)malloc(sizeof(Metadata));
162+
163+
Metadata metadata {
164+
m->transcripts, // transcripts
165+
m->num_transcripts, // num_transcripts
166+
emissions,
167+
};
168+
169+
memcpy(ret, &metadata, sizeof(Metadata));
170+
171+
return ret;
172+
}
173+
174+
return m;
138175
}
139176

140177
char*
@@ -148,7 +185,42 @@ Metadata*
148185
StreamingState::finishStreamWithMetadata(unsigned int num_results)
149186
{
150187
flushBuffers(true);
151-
return model_->decode_metadata(decoder_state_, num_results);
188+
Metadata *m = model_->decode_metadata(decoder_state_, num_results);
189+
190+
if (keep_emissions_) {
191+
192+
const size_t alphabet_size = model_->alphabet_.GetSize();
193+
const int num_timesteps = probs_.size() / (ModelState::BATCH_SIZE * (alphabet_size + 1));
194+
195+
AcousticModelEmissions* emissions = (AcousticModelEmissions*)malloc(sizeof(AcousticModelEmissions));
196+
197+
emissions->num_symbols = alphabet_size;
198+
emissions->num_timesteps = num_timesteps;
199+
emissions->symbols = (const char**)malloc(sizeof(char*)*alphabet_size + 1);
200+
for (int i = 0; i < alphabet_size; i++) {
201+
emissions->symbols[i] = strdup(model_->alphabet_.DecodeSingle(i).c_str());
202+
}
203+
emissions->symbols[alphabet_size] = strdup("\t");
204+
205+
double* probs = (double*)malloc(sizeof(double)*(alphabet_size + 1)*num_timesteps);
206+
memcpy(probs, probs_.data(), sizeof(double)*(alphabet_size + 1)*num_timesteps);
207+
208+
emissions->emissions = probs;
209+
210+
Metadata* ret = (Metadata*)malloc(sizeof(Metadata));
211+
212+
Metadata metadata {
213+
m->transcripts, // transcripts
214+
m->num_transcripts, // num_transcripts
215+
emissions,
216+
};
217+
218+
memcpy(ret, &metadata, sizeof(Metadata));
219+
220+
return ret;
221+
}
222+
223+
return m;
152224
}
153225

154226
void
@@ -253,7 +325,9 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
253325

254326
// Convert logits to double
255327
vector<double> inputs(logits.begin(), logits.end());
256-
328+
if (keep_emissions_) {
329+
probs_ = inputs;
330+
}
257331
decoder_state_.next(inputs.data(),
258332
n_frames,
259333
num_classes);
@@ -476,6 +550,41 @@ STT_CreateStream(ModelState* aCtx,
476550
return STT_ERR_OK;
477551
}
478552

553+
int
554+
CreateStreamWithEmissions(ModelState* aCtx,
555+
StreamingState** retval)
556+
{
557+
*retval = nullptr;
558+
559+
std::unique_ptr<StreamingState> ctx(new StreamingState());
560+
if (!ctx) {
561+
std::cerr << "Could not allocate streaming state." << std::endl;
562+
return STT_ERR_FAIL_CREATE_STREAM;
563+
}
564+
565+
ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
566+
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
567+
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
568+
ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
569+
ctx->previous_state_c_.resize(aCtx->state_size_, 0.f);
570+
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
571+
ctx->model_ = aCtx;
572+
ctx->keep_emissions_ = true;
573+
574+
const int cutoff_top_n = 40;
575+
const double cutoff_prob = 1.0;
576+
577+
ctx->decoder_state_.init(aCtx->alphabet_,
578+
aCtx->beam_width_,
579+
cutoff_prob,
580+
cutoff_top_n,
581+
aCtx->scorer_,
582+
aCtx->hot_words_);
583+
584+
*retval = ctx.release();
585+
return STT_ERR_OK;
586+
}
587+
479588
void
480589
STT_FeedAudioContent(StreamingState* aSctx,
481590
const short* aBuffer,
@@ -562,6 +671,22 @@ STT_SpeechToTextWithMetadata(ModelState* aCtx,
562671
return STT_FinishStreamWithMetadata(ctx, aNumResults);
563672
}
564673

674+
Metadata*
675+
STT_SpeechToTextWithEmissions(ModelState* aCtx,
676+
const short* aBuffer,
677+
unsigned int aBufferSize,
678+
unsigned int aNumResults)
679+
{
680+
StreamingState* ctx;
681+
int status = CreateStreamWithEmissions(aCtx, &ctx);
682+
if (status != STT_ERR_OK) {
683+
return nullptr;
684+
}
685+
STT_FeedAudioContent(ctx, aBuffer, aBufferSize);
686+
687+
return STT_FinishStreamWithMetadata(ctx, aNumResults);
688+
}
689+
565690
void
566691
STT_FreeStream(StreamingState* aSctx)
567692
{
@@ -581,10 +706,24 @@ STT_FreeMetadata(Metadata* m)
581706
}
582707

583708
free((void*)m->transcripts);
709+
710+
// Clean up logits if they are not NULL
711+
if (m->emissions) {
712+
713+
if (m->emissions->symbols) {
714+
for (int i = 0; i < m->emissions->num_symbols + 1; i++) {
715+
free((void*)m->emissions->symbols[i]);
716+
}
717+
free((void*)m->emissions->symbols);
718+
}
719+
if (m->emissions->emissions) {
720+
free((void*)m->emissions->emissions);
721+
}
722+
free((void*)m->emissions);
723+
}
584724
free(m);
585725
}
586726
}
587-
588727
void
589728
STT_FreeString(char* str)
590729
{

0 commit comments

Comments
 (0)