@@ -428,7 +428,8 @@ void signal_callback_handler(int /* signum */) {
428428 requested_stop = true ;
429429}
430430
431- unsigned compute_correct (const map<int ,int >& ref, const map<int ,int >& hyp, unsigned len) {
431+ template <typename T>
432+ unsigned compute_correct (const map<int ,T>& ref, const map<int ,T>& hyp, unsigned len) {
432433 unsigned res = 0 ;
433434 for (unsigned i = 0 ; i < len; ++i) {
434435 auto ri = ref.find (i);
@@ -440,6 +441,24 @@ unsigned compute_correct(const map<int,int>& ref, const map<int,int>& hyp, unsig
440441 return res;
441442}
442443
444+ template <typename T1, typename T2>
445+ unsigned compute_correct (const map<int ,T1>& ref1, const map<int ,T1>& hyp1,
446+ const map<int ,T2>& ref2, const map<int ,T2>& hyp2, unsigned len) {
447+ unsigned res = 0 ;
448+ for (unsigned i = 0 ; i < len; ++i) {
449+ auto r1 = ref1.find (i);
450+ auto h1 = hyp1.find (i);
451+ auto r2 = ref2.find (i);
452+ auto h2 = hyp2.find (i);
453+ assert (r1 != ref1.end ());
454+ assert (h1 != hyp1.end ());
455+ assert (r2 != ref2.end ());
456+ assert (h2 != hyp2.end ());
457+ if (r1->second == h1->second && r2->second == h2->second ) ++res;
458+ }
459+ return res;
460+ }
461+
443462void output_conll (const vector<unsigned >& sentence, const vector<unsigned >& pos,
444463 const vector<string>& sentenceUnkStrings,
445464 const map<unsigned , string>& intToWords,
@@ -714,7 +733,8 @@ int main(int argc, char** argv) {
714733 double llh = 0 ;
715734 double trs = 0 ;
716735 double right = 0 ;
717- double correct_heads = 0 ;
736+ double correct_heads_unlabeled = 0 ;
737+ double correct_heads_labeled = 0 ;
718738 double total_heads = 0 ;
719739 auto t_start = std::chrono::high_resolution_clock::now ();
720740 unsigned corpus_size = corpus.nsentencesDev ;
@@ -736,11 +756,12 @@ int main(int argc, char** argv) {
736756 map<int ,int > ref = parser.compute_heads (sentence.size (), actions, corpus.actions , &rel_ref);
737757 map<int ,int > hyp = parser.compute_heads (sentence.size (), pred, corpus.actions , &rel_hyp);
738758 output_conll (sentence, sentencePos, sentenceUnkStr, corpus.intToWords , corpus.intToPos , hyp, rel_hyp);
739- correct_heads += compute_correct (ref, hyp, sentence.size () - 1 );
759+ correct_heads_unlabeled += compute_correct (ref, hyp, sentence.size () - 1 );
760+ correct_heads_labeled += compute_correct (ref, hyp, rel_ref, rel_hyp, sentence.size () - 1 );
740761 total_heads += sentence.size () - 1 ;
741762 }
742763 auto t_end = std::chrono::high_resolution_clock::now ();
743- cerr << " TEST llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << " \t [" << corpus_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
764+ cerr << " TEST llh=" << llh << " ppl: " << exp (llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads_unlabeled / total_heads) << " las: " << (correct_heads_labeled / total_heads) << " \t [" << corpus_size << " sents in " << std::chrono::duration<double , std::milli>(t_end-t_start).count () << " ms]" << endl;
744765 }
745766 for (unsigned i = 0 ; i < corpus.actions .size (); ++i) {
746767 // cerr << corpus.actions[i] << '\t' << parser.p_r->values[i].transpose() << endl;
0 commit comments