diff --git a/Classification/examples/Classification/example_ethz_random_forest.cpp b/Classification/examples/Classification/example_ethz_random_forest.cpp index 0075b870b06..b0333f58ec4 100644 --- a/Classification/examples/Classification/example_ethz_random_forest.cpp +++ b/Classification/examples/Classification/example_ethz_random_forest.cpp @@ -77,6 +77,10 @@ int main (int argc, char** argv) Label_handle vegetation = labels.add ("vegetation"); Label_handle roof = labels.add ("roof"); + // Check if ground truth is valid for this label set + if (!labels.is_valid_ground_truth (pts.range(label_map), true)) + return EXIT_FAILURE; + std::vector label_indices(pts.size(), -1); std::cerr << "Using ETHZ Random Forest Classifier" << std::endl; diff --git a/Classification/include/CGAL/Classification/ETHZ/Random_forest_classifier.h b/Classification/include/CGAL/Classification/ETHZ/Random_forest_classifier.h index be3fc63638f..8a0894bdf1c 100644 --- a/Classification/include/CGAL/Classification/ETHZ/Random_forest_classifier.h +++ b/Classification/include/CGAL/Classification/ETHZ/Random_forest_classifier.h @@ -167,6 +167,8 @@ public: std::size_t num_trees = 25, std::size_t max_depth = 20) { + CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth)); + CGAL::internal::liblearning::RandomForest::ForestParams params; params.n_trees = num_trees; params.max_depth = max_depth; diff --git a/Classification/include/CGAL/Classification/Evaluation.h b/Classification/include/CGAL/Classification/Evaluation.h index 740e61e14fd..fba0ee02fc8 100644 --- a/Classification/include/CGAL/Classification/Evaluation.h +++ b/Classification/include/CGAL/Classification/Evaluation.h @@ -120,6 +120,8 @@ public: void append (const GroundTruthIndexRange& ground_truth, const ResultIndexRange& result) { + CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth)); + CGAL_precondition (m_labels.is_valid_ground_truth (result)); for (const auto& p : CGAL::make_range (boost::make_zip_iterator(boost::make_tuple(ground_truth.begin(), result.begin())), diff --git a/Classification/include/CGAL/Classification/Label_set.h b/Classification/include/CGAL/Classification/Label_set.h index 3102060ad34..adcd032f0fd 100644 --- a/Classification/include/CGAL/Classification/Label_set.h +++ b/Classification/include/CGAL/Classification/Label_set.h @@ -264,6 +264,67 @@ public: /// @} + /// \name Validity + /// @{ + + /*! + \brief Checks the validity of the ground truth with respect to the + label set. + + \param ground_truth range of label indices. This function checks + that all these indices are either -1 (for unclassified) or a valid + index of one of the labels. If at least one of the indices is out + of range, this function returns `false`, otherwise it returns + `true`. + + \param verbose if set to `true`, the number of inliers of each + label, the number of unclassified items and the potential number + of out-of-range items are displayed. Otherwise, this function does + not display anything. + */ + template + bool is_valid_ground_truth (const LabelIndexRange& ground_truth, + bool verbose = false) const + { + std::vector nb_inliers (m_labels.size() + 2, 0); + std::size_t total = 0; + + for (const auto& gt : ground_truth) + { + int g = int(gt); + if (g == -1) + ++ nb_inliers[m_labels.size()]; + else if (g >= int(m_labels.size())) + { + ++ nb_inliers[m_labels.size() + 1]; + if (!verbose) + break; + } + else + ++ nb_inliers[std::size_t(gt)]; + ++ total; + } + + bool valid = (nb_inliers[m_labels.size() + 1] == 0); + + if (verbose) + { + std::cout << "Ground truth is " << (valid ? "valid" : "invalid") << ":" << std::endl; + std::cout << " * " << nb_inliers[m_labels.size()] << " unclassified item(s) (" + << 100. * (nb_inliers[m_labels.size()] / double(total)) << "%)" << std::endl; + for (std::size_t i = 0; i < m_labels.size(); ++ i) + std::cout << " * " << nb_inliers[i] << " " << m_labels[i]->name() << " inlier(s) (" + << 100. * (nb_inliers[i] / double(total)) << "%)" << std::endl; + if (!valid) + std::cout << " * " << nb_inliers[m_labels.size() + 1] << " item(s) with out-of-range index (" + << 100. * (nb_inliers[m_labels.size() + 1] / double(total)) << "%)" << std::endl; + } + + return valid; + } + + /// @} + }; diff --git a/Classification/include/CGAL/Classification/OpenCV/Random_forest_classifier.h b/Classification/include/CGAL/Classification/OpenCV/Random_forest_classifier.h index 4c062c568bd..543fc34843d 100644 --- a/Classification/include/CGAL/Classification/OpenCV/Random_forest_classifier.h +++ b/Classification/include/CGAL/Classification/OpenCV/Random_forest_classifier.h @@ -147,6 +147,8 @@ public: template void train (const LabelIndexRange& ground_truth) { + CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth)); + #if (CV_MAJOR_VERSION < 3) if (rtree != nullptr) delete rtree; diff --git a/Classification/include/CGAL/Classification/Sum_of_weighted_features_classifier.h b/Classification/include/CGAL/Classification/Sum_of_weighted_features_classifier.h index 02c67cc7eb4..979a0968eca 100644 --- a/Classification/include/CGAL/Classification/Sum_of_weighted_features_classifier.h +++ b/Classification/include/CGAL/Classification/Sum_of_weighted_features_classifier.h @@ -298,6 +298,8 @@ public: float train (const LabelIndexRange& ground_truth, unsigned int nb_tests = 300) { + CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth)); + std::vector > training_sets (m_labels.size()); std::size_t nb_tot = 0; std::size_t i = 0; diff --git a/Classification/include/CGAL/Classification/TensorFlow/Neural_network_classifier.h b/Classification/include/CGAL/Classification/TensorFlow/Neural_network_classifier.h index 6ec0c3bfde7..ef7d6897560 100644 --- a/Classification/include/CGAL/Classification/TensorFlow/Neural_network_classifier.h +++ b/Classification/include/CGAL/Classification/TensorFlow/Neural_network_classifier.h @@ -251,6 +251,8 @@ public: const std::vector& hidden_layers = std::vector()) { + CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth)); + if (restart_from_scratch) clear();