diff --git a/Classification/include/CGAL/Classification/Trainer.h b/Classification/include/CGAL/Classification/Trainer.h index 96c6d8f9c08..616c7f27e4c 100644 --- a/Classification/include/CGAL/Classification/Trainer.h +++ b/Classification/include/CGAL/Classification/Trainer.h @@ -47,6 +47,8 @@ public: private: Classifier* m_classifier; std::vector > m_training_sets; + std::vector m_precision; + std::vector m_recall; public: @@ -362,6 +364,57 @@ public: return Label_handle(); } + void compute_precision_recall () + { + std::vector true_positives (m_classifier->number_of_labels()); + std::vector false_positives (m_classifier->number_of_labels()); + std::vector false_negatives (m_classifier->number_of_labels()); + + for (std::size_t j = 0; j < m_classifier->number_of_labels(); ++ j) + { + for (std::size_t k = 0; k < m_training_sets[j].size(); ++ k) + { + std::size_t nb_class_best=0; + double val_class_best = (std::numeric_limits::max)(); + + for(std::size_t l = 0; l < m_classifier->number_of_labels(); ++ l) + { + double value = m_classifier->classification_value (m_classifier->label(l), + m_training_sets[j][k]); + + if(val_class_best > value) + { + val_class_best = value; + nb_class_best = l; + } + } + + if (nb_class_best == j) + ++ true_positives[j]; + else + { + ++ false_negatives[j]; + for(std::size_t l = 0; l < m_classifier->number_of_labels(); ++ l) + if (nb_class_best == l) + ++ false_positives[l]; + } + } + } + + + m_precision.clear(); + m_recall.clear(); + + for (std::size_t j = 0; j < m_classifier->number_of_labels(); ++ j) + { + m_precision.push_back (true_positives[j] / double(true_positives[j] + false_positives[j])); + m_recall.push_back (true_positives[j] / double(true_positives[j] + false_negatives[j])); + std::cerr << m_classifier->label(j)->name() << ": " << std::endl + << " * precision = " << m_precision.back() << std::endl + << " * recall = " << m_recall.back() << std::endl + << " * F_1 = " << 2. * (m_precision.back() * m_recall.back()) / (m_precision.back() + m_recall.back()) << std::endl; + } + } /// \endcond private: