diff --git a/Classification/include/CGAL/Classification/Random_forest_predicate.h b/Classification/include/CGAL/Classification/Random_forest_predicate.h
new file mode 100644
index 00000000000..7e531559e5f
--- /dev/null
+++ b/Classification/include/CGAL/Classification/Random_forest_predicate.h
@@ -0,0 +1,139 @@
+#ifndef CGAL_CLASSIFICATION_RANDOM_FOREST_PREDICATE_H
+#define CGAL_CLASSIFICATION_RANDOM_FOREST_PREDICATE_H
+
+#include <CGAL/Classification/Label_set.h>
+#include <CGAL/Classification/Feature_set.h>
+
+#include <opencv2/core/core.hpp> // opencv general include file
+#include <opencv2/ml/ml.hpp>     // opencv machine learning include file
+
+namespace CGAL {
+
+namespace Classification {
+
+/*!
+  \ingroup PkgClassificationPredicates
+
+  \brief %Classification predicate based on a random forest algorithm.
+
+  \note This class requires the \ref thirdpartyOpenCV library.
+
+  \cgalModels `CGAL::Classification::Predicate`
+*/
+class Random_forest_predicate
+{
+  Label_set& m_labels;
+  Feature_set& m_features;
+  CvRTrees* rtree;
+
+public:
+
+/*!
+  \brief Instantiates the predicate using the sets of `labels` and `features`.
+*/
+  Random_forest_predicate (Label_set& labels,
+                           Feature_set& features)
+    : m_labels (labels), m_features (features), rtree (NULL)
+  { }
+
+  /// \cond SKIP_IN_MANUAL
+  ~Random_forest_predicate ()
+  {
+    if (rtree != NULL)
+      delete rtree;
+  }
+  /// \endcond
+
+  /*!
+    \brief Runs the training algorithm.
+
+    From the provided ground truth, this algorithm sets up the random
+    trees that produce the most accurate result with respect to this
+    ground truth.
+
+    For more details on the parameters of this algorithm, please refer
+    to [the official documentation of OpenCV](http://docs.opencv.org/2.4/modules/ml/doc/random_trees.html).
+
+    \note Each label should be assigned at least one ground truth
+    item.
+
+    \param ground_truth vector of label indices. It should contain,
+    for each input item and in the same order as the input set, the
+    index of the corresponding label in the `Label_set` provided in
+    the constructor. Input items that do not have ground truth
+    information should be given the value `std::size_t(-1)`.
+
+  */
+  void train (const std::vector<std::size_t>& ground_truth,
+              int max_depth = 20,
+              int min_sample_count = 5,
+              int max_categories = 15,
+              int max_number_of_trees_in_the_forest = 100,
+              float forest_accuracy = 0.01f)
+
+  {
+    if (rtree != NULL)
+      delete rtree;
+
+    std::size_t nb_samples = 0;
+    for (std::size_t i = 0; i < ground_truth.size(); ++ i)
+      if (ground_truth[i] != std::size_t(-1))
+        ++ nb_samples;
+
+
+    cv::Mat training_features (nb_samples, m_features.size(), CV_32FC1);
+    cv::Mat training_labels (nb_samples, 1, CV_32FC1);
+
+    for (std::size_t i = 0, index = 0; i < ground_truth.size(); ++ i)
+      if (ground_truth[i] != std::size_t(-1))
+      {
+        for (std::size_t f = 0; f < m_features.size(); ++ f)
+          training_features.at<float>(index, f) = m_features[f]->value(i);
+        training_labels.at<float>(index, 0) = ground_truth[i];
+        ++ index;
+      }
+
+    float* priors = new float[m_labels.size()];
+    for (std::size_t i = 0; i < m_labels.size(); ++ i)
+      priors[i] = 1.;
+
+    CvRTParams params (max_depth, min_sample_count,
+                       0, false, max_categories, priors, false, 0,
+                       max_number_of_trees_in_the_forest,
+                       forest_accuracy,
+                       CV_TERMCRIT_ITER | CV_TERMCRIT_EPS
+                       );
+
+    cv::Mat var_type (m_features.size() + 1, 1, CV_8U);
+    var_type.setTo (cv::Scalar(CV_VAR_NUMERICAL));
+
+    rtree = new CvRTrees;
+    rtree->train (training_features, CV_ROW_SAMPLE, training_labels,
+                  cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
+
+    delete[] priors;
+  }
+
+  /// \cond SKIP_IN_MANUAL
+  void probabilities (std::size_t item_index, std::vector<float>& out) const
+  {
+    out.resize (m_labels.size(), 1.);
+
+    cv::Mat feature (1, m_features.size(), CV_32FC1);
+    for (std::size_t f = 0; f < m_features.size(); ++ f)
+      feature.at<float>(0, f) = m_features[f]->value(item_index);
+
+    float result = rtree->predict (feature, cv::Mat());
+    std::size_t label = std::size_t(result);
+    if (label < out.size())
+      out[label] = 0.;
+  }
+  /// \endcond
+
+};
+
+}
+
+}
+
+#endif // CGAL_CLASSIFICATION_RANDOM_FOREST_PREDICATE_H
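For reviewers, a minimal usage sketch of the new predicate follows (not part of the patch). It assumes that `labels` and `features` have already been filled by the caller, that every feature can be evaluated for `nb_items` input items, and that `ground_truth` stores one label index per item (`std::size_t(-1)` for unlabeled items). The function name `classify_with_random_forest` is hypothetical, and `probabilities()` is called directly here only for illustration; in practice it is invoked internally by the classification functions.

// Hypothetical usage sketch (not part of the patch): train the predicate and
// pick, for each item, the label whose returned value is minimal (the
// predicted label gets 0, all other labels get 1).
#include <CGAL/Classification/Random_forest_predicate.h>

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<std::size_t>
classify_with_random_forest (CGAL::Classification::Label_set& labels,
                             CGAL::Classification::Feature_set& features,
                             const std::vector<std::size_t>& ground_truth,
                             std::size_t nb_items)
{
  CGAL::Classification::Random_forest_predicate predicate (labels, features);

  // Train with the default OpenCV random trees parameters.
  predicate.train (ground_truth);

  std::vector<std::size_t> result (nb_items);
  for (std::size_t i = 0; i < nb_items; ++ i)
  {
    std::vector<float> values; // fresh vector each time: probabilities() only resizes
    predicate.probabilities (i, values);
    result[i] = std::size_t (std::min_element (values.begin(), values.end())
                             - values.begin());
  }
  return result;
}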