mirror of https://github.com/CGAL/cgal
Add random forest predicate
This commit is contained in:
parent
34070bfd52
commit
735c3d5d54
|
|
@ -0,0 +1,139 @@
|
||||||
|
#ifndef CGAL_CLASSIFICATION_RANDOM_FOREST_PREDICATE_H
|
||||||
|
#define CGAL_CLASSIFICATION_RANDOM_FOREST_PREDICATE_H
|
||||||
|
|
||||||
|
#include <CGAL/Classification/Feature_set.h>
|
||||||
|
#include <CGAL/Classification/Label_set.h>
|
||||||
|
|
||||||
|
#include <cv.h> // opencv general include file
|
||||||
|
#include <ml.h> // opencv machine learning include file
|
||||||
|
|
||||||
|
namespace CGAL {
|
||||||
|
|
||||||
|
namespace Classification {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
\ingroup PkgClassificationPredicates
|
||||||
|
|
||||||
|
\brief %Classification predicate based on a random forest algorithm.
|
||||||
|
|
||||||
|
\note This class requires the \ref thirdpartyOpenCV library.
|
||||||
|
|
||||||
|
\cgalModels `CGAL::Classification::Predicate`
|
||||||
|
*/
|
||||||
|
class Random_forest_predicate
|
||||||
|
{
|
||||||
|
Label_set& m_labels;
|
||||||
|
Feature_set& m_features;
|
||||||
|
CvRTrees* rtree;
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
/*!
|
||||||
|
\brief Instantiate the predicate using the sets of `labels` and `features`.
|
||||||
|
*/
|
||||||
|
Random_forest_predicate (Label_set& labels,
|
||||||
|
Feature_set& features)
|
||||||
|
: m_labels (labels), m_features (features), rtree (NULL)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
/// \cond SKIP_IN_MANUAL
|
||||||
|
~Random_forest_predicate ()
|
||||||
|
{
|
||||||
|
if (rtree != NULL)
|
||||||
|
delete rtree;
|
||||||
|
}
|
||||||
|
/// \endcond
|
||||||
|
|
||||||
|
/*!
|
||||||
|
\brief Runs the training algorithm.
|
||||||
|
|
||||||
|
From the set of provided ground truth, this algorithm estimates
|
||||||
|
sets up the random trees that produce the most accurate result
|
||||||
|
with respect to this ground truth.
|
||||||
|
|
||||||
|
For more details on the parameters of this algorithm, please refer
|
||||||
|
to [the official documentation of OpenCV](http://docs.opencv.org/2.4/modules/ml/doc/random_trees.html).
|
||||||
|
|
||||||
|
\note Each label should be assigned at least one ground truth
|
||||||
|
item.
|
||||||
|
|
||||||
|
\param ground_truth vector of label indices. It should contain for
|
||||||
|
each input item, in the same order as the input set, the index of
|
||||||
|
the corresponding label in the `Label_set` provided in the
|
||||||
|
constructor. Input items that do not have a ground truth
|
||||||
|
information should be given the value `std::size_t(-1)`.
|
||||||
|
|
||||||
|
*/
|
||||||
|
void train (const std::vector<std::size_t>& ground_truth,
|
||||||
|
int max_depth = 20,
|
||||||
|
int min_sample_count = 5,
|
||||||
|
int max_categories = 15,
|
||||||
|
int max_number_of_trees_in_the_forest = 100,
|
||||||
|
float forest_accuracy = 0.01f)
|
||||||
|
|
||||||
|
{
|
||||||
|
if (rtree != NULL)
|
||||||
|
delete rtree;
|
||||||
|
|
||||||
|
std::size_t nb_samples = 0;
|
||||||
|
for (std::size_t i = 0; i < ground_truth.size(); ++ i)
|
||||||
|
if (ground_truth[i] != std::size_t(-1))
|
||||||
|
++ nb_samples;
|
||||||
|
|
||||||
|
|
||||||
|
cv::Mat training_features (nb_samples, m_features.size(), CV_32FC1);
|
||||||
|
cv::Mat training_labels (nb_samples, 1, CV_32FC1);
|
||||||
|
|
||||||
|
for (std::size_t i = 0, index = 0; i < ground_truth.size(); ++ i)
|
||||||
|
if (ground_truth[i] != std::size_t(-1))
|
||||||
|
{
|
||||||
|
for (std::size_t f = 0; f < m_features.size(); ++ f)
|
||||||
|
training_features.at<float>(index, f) = m_features[f]->value(i);
|
||||||
|
training_labels.at<float>(index, 0) = ground_truth[i];
|
||||||
|
++ index;
|
||||||
|
}
|
||||||
|
|
||||||
|
float* priors = new float[m_labels.size()];
|
||||||
|
for (std::size_t i = 0; i < m_labels.size(); ++ i)
|
||||||
|
priors[i] = 1.;
|
||||||
|
|
||||||
|
CvRTParams params (max_depth, min_sample_count,
|
||||||
|
0, false, max_categories, priors, false, 0,
|
||||||
|
max_number_of_trees_in_the_forest,
|
||||||
|
forest_accuracy,
|
||||||
|
CV_TERMCRIT_ITER | CV_TERMCRIT_EPS
|
||||||
|
);
|
||||||
|
|
||||||
|
cv::Mat var_type (m_features.size() + 1, 1, CV_8U);
|
||||||
|
var_type.setTo (cv::Scalar(CV_VAR_NUMERICAL));
|
||||||
|
|
||||||
|
rtree = new CvRTrees;
|
||||||
|
rtree->train (training_features, CV_ROW_SAMPLE, training_labels,
|
||||||
|
cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
|
||||||
|
|
||||||
|
delete[] priors;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \cond SKIP_IN_MANUAL
|
||||||
|
void probabilities (std::size_t item_index, std::vector<float>& out) const
|
||||||
|
{
|
||||||
|
out.resize (m_labels.size(), 1.);
|
||||||
|
|
||||||
|
cv::Mat feature (1, m_features.size(), CV_32FC1);
|
||||||
|
for (std::size_t f = 0; f < m_features.size(); ++ f)
|
||||||
|
feature.at<float>(0, f) = m_features[f]->value(item_index);
|
||||||
|
|
||||||
|
float result = rtree->predict (feature, cv::Mat());
|
||||||
|
std::size_t label = std::size_t(result);
|
||||||
|
if (label < out.size())
|
||||||
|
out[label] = 0.;
|
||||||
|
}
|
||||||
|
/// \endcond
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // CGAL_CLASSIFICATION_RANDOM_FOREST_PREDICATE_H
|
||||||
Loading…
Reference in New Issue