diff --git a/Classification/include/CGAL/Classification/Trainer.h b/Classification/include/CGAL/Classification/Trainer.h index 616c7f27e4c..133a26ef8ee 100644 --- a/Classification/include/CGAL/Classification/Trainer.h +++ b/Classification/include/CGAL/Classification/Trainer.h @@ -35,6 +35,29 @@ namespace CGAL { namespace Classification { +/*! +\ingroup PkgClassification + +\brief Training algorithm to set up weights and effects of the +features used for classification. + +Each label must have ben given a small set of user-defined inliers to +provide the training algorithm with a ground truth (see `set_inlier()` +and `set_inliers()`). + +This methods estimates the set of feature weights and of +[effects](@ref Classification::Feature::Effect) that make the +classifier succeed in correctly classifying the sets of inliers given +by the user. These parameters are directly modified within the +`Classification::Feature_base` and `Classification::Label` objects. + +\tparam ItemRange model of `ConstRange`. Its iterator type is +`RandomAccessIterator`. + +\tparam ItemMap model of `ReadablePropertyMap` whose key +type is the value type of the iterator of `ItemRange` and value type is +the type of the items that are classified. +*/ template class Trainer { @@ -49,15 +72,25 @@ private: std::vector > m_training_sets; std::vector m_precision; std::vector m_recall; + std::vector m_iou; // intersection over union public: + /// \name Constructor + /// @{ + Trainer (Classifier& classifier) : m_classifier (&classifier) , m_training_sets (classifier.number_of_labels()) { } + /// @} + + + /// \name Inliers + /// @{ + /*! \brief Adds the item at position `index` as an inlier of `label` for the training algorithm. @@ -132,21 +165,19 @@ public: std::vector >().swap (m_training_sets); } + + /// @} + + + /// \name Training + /// @{ + /*! \brief Runs the training algorithm. All the `Classification::Label` and `Classification::Feature` necessary for classification should have been added before running - this function. Each label must have ben given a small set - of user-defined inliers to provide the training algorithm with a - ground truth (see `set_inlier()` and `set_inliers()`). - - This methods estimates the set of feature weights and of - [effects](@ref Classification::Feature::Effect) that make the - classifier succeed in correctly classifying the sets of inliers - given by the user. These parameters are directly modified within - the `Classification::Feature_base` and `Classification::Label` - objects. After training, the user can call `run()`, + this function. After training, the user can call `run()`, `run_with_local_smoothing()` or `run_with_graphcut()` to compute the classification using the estimated parameters. @@ -347,14 +378,115 @@ public: } CGAL_CLASSIFICATION_CERR << nb_removed << " feature(s) out of " << m_classifier->number_of_features() << " are useless" << std::endl; - + + compute_precision_recall(); return best_score; } - /// @} - + + /// \name Evaluation + /// @{ + + /*! + + \brief Returns the precision of the training for the given label. + + Precision is the number of true positives divided by the sum of + the true positives and the false positives. + + */ + double precision (Label_handle label) + { + std::size_t label_idx = (std::size_t)(-1); + for (std::size_t i = 0; i < m_classifier->number_of_labels(); ++ i) + if (m_classifier->label(i) == label) + { + label_idx = i; + break; + } + if (label_idx == (std::size_t)(-1)) + return 0.; + + return m_precision[label_idx]; + } + + /*! + + \brief Returns the recall of the training for the given label. + + Recall is the number of true positives divided by the sum of + the true positives and the false negatives. + + */ + double recall (Label_handle label) + { + std::size_t label_idx = (std::size_t)(-1); + for (std::size_t i = 0; i < m_classifier->number_of_labels(); ++ i) + if (m_classifier->label(i) == label) + { + label_idx = i; + break; + } + if (label_idx == (std::size_t)(-1)) + return 0.; + + return m_recall[label_idx]; + } + + /*! + + \brief Returns the \f$F_1\f$ score of the training for the given label. + + \f$F_1\f$ score is the harmonic mean of `precision()` and `recall()`: + + \f[ + F_1 = 2 \times \frac{precision \times recall}{precision + recall} + \f] + + */ + double f1_score (Label_handle label) + { + std::size_t label_idx = (std::size_t)(-1); + for (std::size_t i = 0; i < m_classifier->number_of_labels(); ++ i) + if (m_classifier->label(i) == label) + { + label_idx = i; + break; + } + if (label_idx == (std::size_t)(-1)) + return 0.; + + return 2. * (m_precision[label_idx] * m_recall[label_idx]) + / (m_precision[label_idx] + m_recall[label_idx]); + } + +/*! + + \brief Returns the intersection over union of the training for the + given label. + + Intersection over union is the number of true positives divided by + the sum of the true positives, of the false positives and of the + false negatives. + */ + double IoU (Label_handle label) + { + std::size_t label_idx = (std::size_t)(-1); + for (std::size_t i = 0; i < m_classifier->number_of_labels(); ++ i) + if (m_classifier->label(i) == label) + { + label_idx = i; + break; + } + if (label_idx == (std::size_t)(-1)) + return 0.; + + return m_iou[label_idx]; + } + /// @} + /// \cond SKIP_IN_MANUAL Label_handle training_label_of (std::size_t index) const { @@ -364,57 +496,6 @@ public: return Label_handle(); } - void compute_precision_recall () - { - std::vector true_positives (m_classifier->number_of_labels()); - std::vector false_positives (m_classifier->number_of_labels()); - std::vector false_negatives (m_classifier->number_of_labels()); - - for (std::size_t j = 0; j < m_classifier->number_of_labels(); ++ j) - { - for (std::size_t k = 0; k < m_training_sets[j].size(); ++ k) - { - std::size_t nb_class_best=0; - double val_class_best = (std::numeric_limits::max)(); - - for(std::size_t l = 0; l < m_classifier->number_of_labels(); ++ l) - { - double value = m_classifier->classification_value (m_classifier->label(l), - m_training_sets[j][k]); - - if(val_class_best > value) - { - val_class_best = value; - nb_class_best = l; - } - } - - if (nb_class_best == j) - ++ true_positives[j]; - else - { - ++ false_negatives[j]; - for(std::size_t l = 0; l < m_classifier->number_of_labels(); ++ l) - if (nb_class_best == l) - ++ false_positives[l]; - } - } - } - - - m_precision.clear(); - m_recall.clear(); - - for (std::size_t j = 0; j < m_classifier->number_of_labels(); ++ j) - { - m_precision.push_back (true_positives[j] / double(true_positives[j] + false_positives[j])); - m_recall.push_back (true_positives[j] / double(true_positives[j] + false_negatives[j])); - std::cerr << m_classifier->label(j)->name() << ": " << std::endl - << " * precision = " << m_precision.back() << std::endl - << " * recall = " << m_recall.back() << std::endl - << " * F_1 = " << 2. * (m_precision.back() * m_recall.back()) / (m_precision.back() + m_recall.back()) << std::endl; - } - } /// \endcond private: @@ -549,6 +630,56 @@ private: } + void compute_precision_recall () + { + std::vector true_positives (m_classifier->number_of_labels()); + std::vector false_positives (m_classifier->number_of_labels()); + std::vector false_negatives (m_classifier->number_of_labels()); + + for (std::size_t j = 0; j < m_classifier->number_of_labels(); ++ j) + { + for (std::size_t k = 0; k < m_training_sets[j].size(); ++ k) + { + std::size_t nb_class_best=0; + double val_class_best = (std::numeric_limits::max)(); + + for(std::size_t l = 0; l < m_classifier->number_of_labels(); ++ l) + { + double value = m_classifier->classification_value (m_classifier->label(l), + m_training_sets[j][k]); + + if(val_class_best > value) + { + val_class_best = value; + nb_class_best = l; + } + } + + if (nb_class_best == j) + ++ true_positives[j]; + else + { + ++ false_negatives[j]; + for(std::size_t l = 0; l < m_classifier->number_of_labels(); ++ l) + if (nb_class_best == l) + ++ false_positives[l]; + } + } + } + + + m_precision.clear(); + m_recall.clear(); + + for (std::size_t j = 0; j < m_classifier->number_of_labels(); ++ j) + { + m_precision.push_back (true_positives[j] / double(true_positives[j] + false_positives[j])); + m_recall.push_back (true_positives[j] / double(true_positives[j] + false_negatives[j])); + m_iou.push_back (true_positives[j] / double(true_positives[j] + false_positives[j] + false_negatives[j])); + } + } + + };