Skip to content

Commit

Permalink
Run LSTM recognition in multiple threads
Browse files Browse the repository at this point in the history
Init time option lstm_num_threads should be used to set the number of LSTM threads
  • Loading branch information
jkarthic committed Jun 27, 2024
1 parent 2991d36 commit 6a2e239
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 59 deletions.
128 changes: 88 additions & 40 deletions src/ccmain/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cstdint> // for int16_t, int32_t
#include <cstdio> // for fclose, fopen, FILE
#include <ctime> // for clock
#include <future>
#include "control.h"
#ifndef DISABLED_LEGACY_ENGINE
# include "docqual.h"
Expand Down Expand Up @@ -194,36 +195,42 @@ void Tesseract::SetupWordPassN(int pass_n, WordData *word) {
}
}

// Runs word recognition on all the words.
bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES_IT *pr_it,
std::vector<WordData> *words) {
// TODO(rays) Before this loop can be parallelized (it would yield a massive
// speed-up) all remaining member globals need to be converted to local/heap
// (eg set_pass1 and set_pass2) and an intermediate adaption pass needs to be
// added. The results will be significantly different with adaption on, and
// deterioration will need investigation.
pr_it->restart_page();
for (unsigned w = 0; w < words->size(); ++w) {
WordData *word = &(*words)[w];
if (w > 0) {
word->prev_word = &(*words)[w - 1];
bool Tesseract::RecogWordsSegment(std::vector<WordData>::iterator start,
std::vector<WordData>::iterator end,
int pass_n,
ETEXT_DESC *monitor,
PAGE_RES *page_res,
LSTMRecognizer *lstm_recognizer,
std::atomic<int>& words_done,
int total_words,
std::mutex& monitor_mutex) {
PAGE_RES_IT pr_it(page_res);
// Process a segment of the words vector
pr_it.restart_page();

for (auto it = start; it != end; ++it, ++words_done) {
WordData *word = &(*it);
if (it != start) {
word->prev_word = &(*(it - 1));
}
if (monitor != nullptr) {
std::lock_guard<std::mutex> lock(monitor_mutex);
monitor->ocr_alive = true;
if (pass_n == 1) {
monitor->progress = 70 * w / words->size();
monitor->progress = 70 * words_done / total_words;
} else {
monitor->progress = 70 + 30 * w / words->size();
monitor->progress = 70 + 30 * words_done / total_words;
}
// Only call the progress callback for the first thread.
if (monitor->progress_callback2 != nullptr) {
TBOX box = pr_it->word()->word->bounding_box();
TBOX box = pr_it.word()->word->bounding_box();
(*monitor->progress_callback2)(monitor, box.left(), box.right(), box.top(), box.bottom());
}
if (monitor->deadline_exceeded() ||
(monitor->cancel != nullptr && (*monitor->cancel)(monitor->cancel_this, words->size()))) {
(monitor->cancel != nullptr && (*monitor->cancel)(monitor->cancel_this, total_words))) {
// Timeout. Fake out the rest of the words.
for (; w < words->size(); ++w) {
(*words)[w].word->SetupFake(unicharset);
for (; it != end; ++it) {
it->word->SetupFake(unicharset);
}
return false;
}
Expand All @@ -238,31 +245,69 @@ bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES_IT
}
}
// Sync pr_it with the WordData.
while (pr_it->word() != nullptr && pr_it->word() != word->word) {
pr_it->forward();
while (pr_it.word() != nullptr && pr_it.word() != word->word) {
pr_it.forward();
}
ASSERT_HOST(pr_it->word() != nullptr);
ASSERT_HOST(pr_it.word() != nullptr);
bool make_next_word_fuzzy = false;
#ifndef DISABLED_LEGACY_ENGINE
if (!AnyLSTMLang() && ReassignDiacritics(pass_n, pr_it, &make_next_word_fuzzy)) {
if (!AnyLSTMLang() && ReassignDiacritics(pass_n, &pr_it, &make_next_word_fuzzy)) {
// Needs to be setup again to see the new outlines in the chopped_word.
SetupWordPassN(pass_n, word);
}
#endif // ndef DISABLED_LEGACY_ENGINE

classify_word_and_language(pass_n, pr_it, word);
classify_word_and_language(pass_n, &pr_it, word, lstm_recognizer);
if (tessedit_dump_choices || debug_noise_removal) {
tprintf("Pass%d: %s [%s]\n", pass_n, word->word->best_choice->unichar_string().c_str(),
word->word->best_choice->debug_string().c_str());
}
pr_it->forward();
if (make_next_word_fuzzy && pr_it->word() != nullptr) {
pr_it->MakeCurrentWordFuzzy();
pr_it.forward();
if (make_next_word_fuzzy && pr_it.word() != nullptr) {
pr_it.MakeCurrentWordFuzzy();
}
}
return true;
}

// Runs word recognition on all the words.
bool Tesseract::RecogAllWordsPassN(int pass_n, ETEXT_DESC *monitor, PAGE_RES *page_res,
std::vector<WordData> *words) {
int total_words = words->size();
int segment_size = total_words / lstm_num_threads;
std::atomic<int> words_done(0);
std::mutex monitor_mutex;
std::vector<std::future<bool>> futures;

// Launch multiple threads to recognize the words in parallel
auto segment_start = words->begin() + segment_size;
for (int i = 1; i < lstm_num_threads; ++i) {
auto segment_end = (i == lstm_num_threads - 1) ? words->end() : segment_start + segment_size;
futures.push_back(std::async(std::launch::async, &Tesseract::RecogWordsSegment,
this, segment_start, segment_end, pass_n, monitor, page_res,
lstm_recognizers_[i], std::ref(words_done), total_words, std::ref(monitor_mutex)));
segment_start = segment_end;
}

// Process the first segment in this thread
bool overall_result = RecogWordsSegment(words->begin(),
words->begin() + segment_size,
pass_n,
monitor,
page_res,
lstm_recognizers_[0],
std::ref(words_done),
total_words,
std::ref(monitor_mutex));

// Wait for all threads to complete and aggregate results
for (auto &f : futures) {
overall_result &= f.get();
}

return overall_result;
}

/**
* recog_all_words()
*
Expand Down Expand Up @@ -340,7 +385,7 @@ bool Tesseract::recog_all_words(PAGE_RES *page_res, ETEXT_DESC *monitor,

most_recently_used_ = this;
// Run pass 1 word recognition.
if (!RecogAllWordsPassN(1, monitor, &page_res_it, &words)) {
if (!RecogAllWordsPassN(1, monitor, page_res, &words)) {
return false;
}
// Pass 1 post-processing.
Expand Down Expand Up @@ -380,11 +425,10 @@ bool Tesseract::recog_all_words(PAGE_RES *page_res, ETEXT_DESC *monitor,
}
most_recently_used_ = this;
// Run pass 2 word recognition.
if (!RecogAllWordsPassN(2, monitor, &page_res_it, &words)) {
if (!RecogAllWordsPassN(2, monitor, page_res, &words)) {
return false;
}
}

// The next passes are only required for Tess-only.
if (AnyTessLang() && !AnyLSTMLang()) {
// ****************** Pass 3 *******************
Expand Down Expand Up @@ -871,14 +915,15 @@ static int SelectBestWords(double rating_ratio, double certainty_margin, bool de
// Returns positive if this recognizer found more new best words than the
// number kept from best_words.
int Tesseract::RetryWithLanguage(const WordData &word_data, WordRecognizer recognizer, bool debug,
WERD_RES **in_word, PointerVector<WERD_RES> *best_words) {
WERD_RES **in_word, PointerVector<WERD_RES> *best_words,
LSTMRecognizer *lstm_recognizer) {
if (debug) {
tprintf("Trying word using lang %s, oem %d\n", lang.c_str(),
static_cast<int>(tessedit_ocr_engine_mode));
}
// Run the recognizer on the word.
PointerVector<WERD_RES> new_words;
(this->*recognizer)(word_data, in_word, &new_words);
(this->*recognizer)(word_data, in_word, &new_words, lstm_recognizer);
if (new_words.empty()) {
// Transfer input word to new_words, as the classifier must have put
// the result back in the input.
Expand Down Expand Up @@ -1300,7 +1345,10 @@ float Tesseract::ClassifyBlobAsWord(int pass_n, PAGE_RES_IT *pr_it, C_BLOB *blob
// Recognizes in the current language, and if successful that is all.
// If recognition was not successful, tries all available languages until
// it gets a successful result or runs out of languages. Keeps the best result.
void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordData *word_data) {
void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordData *word_data,
LSTMRecognizer *lstm_recognizer_thread_local) {
LSTMRecognizer *lstm_recognizer = lstm_recognizer_thread_local ? lstm_recognizer_thread_local
: lstm_recognizer_;
#ifdef DISABLED_LEGACY_ENGINE
WordRecognizer recognizer = &Tesseract::classify_word_pass1;
#else
Expand Down Expand Up @@ -1333,19 +1381,19 @@ void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordD
}
}
most_recently_used_->RetryWithLanguage(*word_data, recognizer, debug, &word_data->lang_words[sub],
&best_words);
&best_words, lstm_recognizer);
Tesseract *best_lang_tess = most_recently_used_;
if (!WordsAcceptable(best_words)) {
// Try all the other languages to see if they are any better.
if (most_recently_used_ != this &&
this->RetryWithLanguage(*word_data, recognizer, debug,
&word_data->lang_words[sub_langs_.size()], &best_words) > 0) {
&word_data->lang_words[sub_langs_.size()], &best_words, lstm_recognizer) > 0) {
best_lang_tess = this;
}
for (unsigned i = 0; !WordsAcceptable(best_words) && i < sub_langs_.size(); ++i) {
if (most_recently_used_ != sub_langs_[i] &&
sub_langs_[i]->RetryWithLanguage(*word_data, recognizer, debug, &word_data->lang_words[i],
&best_words) > 0) {
&best_words, lstm_recognizer) > 0) {
best_lang_tess = sub_langs_[i];
}
}
Expand Down Expand Up @@ -1378,7 +1426,7 @@ void Tesseract::classify_word_and_language(int pass_n, PAGE_RES_IT *pr_it, WordD
*/

void Tesseract::classify_word_pass1(const WordData &word_data, WERD_RES **in_word,
PointerVector<WERD_RES> *out_words) {
PointerVector<WERD_RES> *out_words, LSTMRecognizer *lstm_recognizer) {
ROW *row = word_data.row;
BLOCK *block = word_data.block;
prev_word_best_choice_ =
Expand All @@ -1390,14 +1438,14 @@ void Tesseract::classify_word_pass1(const WordData &word_data, WERD_RES **in_wor
tessedit_ocr_engine_mode == OEM_TESSERACT_LSTM_COMBINED) {
#endif // def DISABLED_LEGACY_ENGINE
if (!(*in_word)->odd_size || tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
LSTMRecognizeWord(*block, row, *in_word, out_words);
LSTMRecognizeWord(*block, row, *in_word, out_words, lstm_recognizer);
if (!out_words->empty()) {
return; // Successful lstm recognition.
}
}
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
// No fallback allowed, so use a fake.
(*in_word)->SetupFake(lstm_recognizer_->GetUnicharset());
(*in_word)->SetupFake(lstm_recognizer->GetUnicharset());
return;
}

Expand Down Expand Up @@ -1534,7 +1582,7 @@ bool Tesseract::TestNewNormalization(int original_misfits, float baseline_shift,
*/

void Tesseract::classify_word_pass2(const WordData &word_data, WERD_RES **in_word,
PointerVector<WERD_RES> *out_words) {
PointerVector<WERD_RES> *out_words, LSTMRecognizer *lstm_recognizer) {
// Return if we do not want to run Tesseract.
if (tessedit_ocr_engine_mode == OEM_LSTM_ONLY) {
return;
Expand Down
16 changes: 8 additions & 8 deletions src/ccmain/linerec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ ImageData *Tesseract::GetRectImage(const TBOX &box, const BLOCK &block, int padd
// Recognizes a word or group of words, converting to WERD_RES in *words.
// Analogous to classify_word_pass1, but can handle a group of words as well.
void Tesseract::LSTMRecognizeWord(const BLOCK &block, ROW *row, WERD_RES *word,
PointerVector<WERD_RES> *words) {
PointerVector<WERD_RES> *words, LSTMRecognizer *lstm_recognizer) {
TBOX word_box = word->word->bounding_box();
// Get the word image - no frills.
if (tessedit_pageseg_mode == PSM_SINGLE_WORD || tessedit_pageseg_mode == PSM_RAW_LINE) {
Expand All @@ -251,30 +251,30 @@ void Tesseract::LSTMRecognizeWord(const BLOCK &block, ROW *row, WERD_RES *word,

bool do_invert = tessedit_do_invert;
float threshold = do_invert ? double(invert_threshold) : 0.0f;
lstm_recognizer_->RecognizeLine(*im_data, threshold, classify_debug_level > 0,
kWorstDictCertainty / kCertaintyScale, word_box, words,
lstm_choice_mode, lstm_choice_iterations);
lstm_recognizer->RecognizeLine(*im_data, threshold, classify_debug_level > 0,
kWorstDictCertainty / kCertaintyScale, word_box, words,
lstm_choice_mode, lstm_choice_iterations);
delete im_data;
SearchWords(words);
SearchWords(words, lstm_recognizer);
}

// Apply segmentation search to the given set of words, within the constraints
// of the existing ratings matrix. If there is already a best_choice on a word
// leaves it untouched and just sets the done/accepted etc flags.
void Tesseract::SearchWords(PointerVector<WERD_RES> *words) {
void Tesseract::SearchWords(PointerVector<WERD_RES> *words, LSTMRecognizer *lstm_recognizer) {
// Run the segmentation search on the network outputs and make a BoxWord
// for each of the output words.
// If we drop a word as junk, then there is always a space in front of the
// next.
const Dict *stopper_dict = lstm_recognizer_->GetDict();
const Dict *stopper_dict = lstm_recognizer->GetDict();
if (stopper_dict == nullptr) {
stopper_dict = &getDict();
}
for (unsigned w = 0; w < words->size(); ++w) {
WERD_RES *word = (*words)[w];
if (word->best_choice == nullptr) {
// It is a dud.
word->SetupFake(lstm_recognizer_->GetUnicharset());
word->SetupFake(lstm_recognizer->GetUnicharset());
} else {
// Set the best state.
for (unsigned i = 0; i < word->best_choice->length(); ++i) {
Expand Down
7 changes: 5 additions & 2 deletions src/ccmain/tessedit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ bool Tesseract::init_tesseract_lang_data(const std::string &arg0,
tessedit_ocr_engine_mode == OEM_TESSERACT_LSTM_COMBINED) {
#endif // ndef DISABLED_LEGACY_ENGINE
if (mgr->IsComponentAvailable(TESSDATA_LSTM)) {
lstm_recognizer_ = new LSTMRecognizer(language_data_path_prefix.c_str());
ASSERT_HOST(lstm_recognizer_->Load(this->params(), lstm_use_matrix ? language : "", mgr));
for (int i = 0; i < lstm_num_threads; ++i) {
lstm_recognizers_.push_back(new LSTMRecognizer(language_data_path_prefix.c_str()));
lstm_recognizers_.back()->Load(this->params(), lstm_use_matrix ? language : "", mgr);
}
lstm_recognizer_ = lstm_recognizers_[0];
} else {
tprintf("Error: LSTM requested, but not present!! Loading tesseract.\n");
tessedit_ocr_engine_mode.set_value(OEM_TESSERACT_ONLY);
Expand Down
9 changes: 8 additions & 1 deletion src/ccmain/tesseractclass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ Tesseract::Tesseract()
"lstm_choice_mode. Note that lstm_choice_mode must be set to a "
"value greater than 0 to produce results.",
this->params())
, INT_INIT_MEMBER(lstm_num_threads, 1,
"Sets the number of threads used by the LSTM recognizer. The "
"default value is 1.",
this->params())
, double_MEMBER(lstm_rating_coefficient, 5,
"Sets the rating coefficient for the lstm choices. The smaller the "
"coefficient, the better are the ratings for each choice and less "
Expand Down Expand Up @@ -477,7 +481,10 @@ Tesseract::~Tesseract() {
for (auto *lang : sub_langs_) {
delete lang;
}
delete lstm_recognizer_;
for (int i = 0; i < lstm_recognizers_.size(); ++i) {
delete lstm_recognizers_[i];
}
lstm_recognizers_.clear();
lstm_recognizer_ = nullptr;
}

Expand Down
Loading

0 comments on commit 6a2e239

Please sign in to comment.