Simplify API of random forest classifier

Simon Giraudot 2017-05-31 08:21:52 +02:00
parent 01f6ee38c7
commit 41e12aed8b
1 changed file with 45 additions and 27 deletions


@@ -13,7 +13,7 @@ namespace Classification {
/*!
\ingroup PkgClassificationClassifiers
\brief Classifier based on a random forest algorithm.
\brief %Classifier based on a random forest algorithm.
\note This class requires the \ref thirdpartyOpenCV library.
@@ -23,7 +23,12 @@ class Random_forest_classifier
{
const Label_set& m_labels;
const Feature_set& m_features;
int m_max_depth;
int m_min_sample_count;
int m_max_categories;
int m_max_number_of_trees_in_the_forest;
float m_forest_accuracy;
#if (CV_MAJOR_VERSION < 3)
CvRTrees* rtree;
#else
@@ -34,10 +39,29 @@ public:
/*!
\brief Instantiate the classifier using the sets of `labels` and `features`.
Parameter documentation is adapted from [the official documentation of OpenCV](http://docs.opencv.org/2.4/modules/ml/doc/random_trees.html); please refer to it for more details.
\param labels label set used.
\param features feature set used.
\param max_depth maximum depth of the trees. A low value will likely underfit; conversely, a high value will likely overfit. The optimal value can be obtained using cross-validation or another suitable method.
\param min_sample_count minimum number of samples required at a leaf node for it to be split. A reasonable value is a small percentage of the total data, e.g. 1%.
\param max_categories maximum number of clusters used when splitting on a categorical variable. If a categorical variable on which the training procedure tries to split takes more than max_categories values, finding the exact best subset is exponentially expensive; instead, the samples are clustered into at most max_categories groups (i.e. some categories are merged) and a suboptimal split is used. This clustering is only applied to classification problems with more than 2 classes and to categorical variables with more than max_categories possible values; for regression and 2-class classification, the optimal split can be found efficiently without clustering, so the parameter is unused in those cases.
\param max_number_of_trees_in_the_forest maximum number of trees in the forest. More trees typically yield better accuracy, but the improvement diminishes and plateaus past a certain number of trees. Note also that prediction time grows linearly with the number of trees.
\param forest_accuracy sufficient accuracy (out-of-bag error); training stops once this accuracy is reached.
*/
Random_forest_classifier (const Label_set& labels,
const Feature_set& features)
: m_labels (labels), m_features (features)
const Feature_set& features,
int max_depth = 20,
int min_sample_count = 5,
int max_categories = 15,
int max_number_of_trees_in_the_forest = 100,
float forest_accuracy = 0.01f)
: m_labels (labels), m_features (features),
m_max_depth (max_depth), m_min_sample_count (min_sample_count),
m_max_categories (max_categories),
m_max_number_of_trees_in_the_forest (max_number_of_trees_in_the_forest),
m_forest_accuracy (forest_accuracy)
#if (CV_MAJOR_VERSION < 3)
, rtree (NULL)
#endif
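For illustration, here is a minimal usage sketch of the simplified constructor (editor's example, not part of the commit; it assumes already-populated Label_set and Feature_set objects):

// --- Usage sketch (editor's illustration, not part of this commit) ---
// `labels` and `features` are assumed to be already-populated
// CGAL::Classification::Label_set / Feature_set objects.

// Default hyperparameters:
Random_forest_classifier classifier (labels, features);

// Custom hyperparameters fixed at construction time (values are illustrative):
Random_forest_classifier tuned_classifier (labels, features,
                                           30,      // max_depth
                                           10,      // min_sample_count
                                           15,      // max_categories
                                           200,     // max_number_of_trees_in_the_forest
                                           0.005f); // forest_accuracy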
@@ -60,8 +84,6 @@ public:
sets up the random trees that produce the most accurate result
with respect to this ground truth.
Parameter documentation is adapted from [the official documentation of OpenCV](http://docs.opencv.org/2.4/modules/ml/doc/random_trees.html); please refer to it for more details.
\pre At least one ground truth item should be assigned to each
label.
@@ -70,19 +92,8 @@ public:
the corresponding label in the `Label_set` provided in the
constructor. Input items that do not have a ground truth
information should be given the value `std::size_t(-1)`.
\param max_depth maximum depth of the trees. A low value will likely underfit; conversely, a high value will likely overfit. The optimal value can be obtained using cross-validation or another suitable method.
\param min_sample_count minimum number of samples required at a leaf node for it to be split. A reasonable value is a small percentage of the total data, e.g. 1%.
\param max_categories maximum number of clusters used when splitting on a categorical variable. If a categorical variable on which the training procedure tries to split takes more than max_categories values, finding the exact best subset is exponentially expensive; instead, the samples are clustered into at most max_categories groups (i.e. some categories are merged) and a suboptimal split is used. This clustering is only applied to classification problems with more than 2 classes and to categorical variables with more than max_categories possible values; for regression and 2-class classification, the optimal split can be found efficiently without clustering, so the parameter is unused in those cases.
\param max_number_of_trees_in_the_forest maximum number of trees in the forest. More trees typically yield better accuracy, but the improvement diminishes and plateaus past a certain number of trees. Note also that prediction time grows linearly with the number of trees.
\param forest_accuracy sufficient accuracy (out-of-bag error); training stops once this accuracy is reached.
*/
void train (const std::vector<std::size_t>& ground_truth,
int max_depth = 20,
int min_sample_count = 5,
int max_categories = 15,
int max_number_of_trees_in_the_forest = 100,
float forest_accuracy = 0.01f)
void train (const std::vector<std::size_t>& ground_truth)
{
#if (CV_MAJOR_VERSION < 3)
if (rtree != NULL)
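With the hyperparameters moved into the class, train() now takes only the ground truth vector. A short usage sketch, reusing the classifier constructed above (editor's example; number_of_items and the label indices are illustrative assumptions):

// --- Usage sketch (editor's illustration, not part of this commit) ---
// One entry per input item: the index of its label in the Label_set,
// or std::size_t(-1) if no ground truth is available for that item.
std::vector<std::size_t> ground_truth (number_of_items, std::size_t(-1));
ground_truth[0] = 0;   // first item labeled with label #0
ground_truth[5] = 2;   // sixth item labeled with label #2
// ... assign remaining known labels ...
classifier.train (ground_truth);   // hyperparameter arguments are gone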
@@ -112,10 +123,10 @@ public:
for (std::size_t i = 0; i < m_labels.size(); ++ i)
priors[i] = 1.;
CvRTParams params (max_depth, min_sample_count,
0, false, max_categories, priors, false, 0,
max_number_of_trees_in_the_forest,
forest_accuracy,
CvRTParams params (m_max_depth, m_min_sample_count,
0, false, m_max_categories, priors, false, 0,
m_max_number_of_trees_in_the_forest,
m_forest_accuracy,
CV_TERMCRIT_ITER | CV_TERMCRIT_EPS
);
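For readers unfamiliar with OpenCV 2's positional API, the call above maps onto the CvRTParams constructor roughly as follows; an annotated sketch based on the OpenCV 2.4 signature, not part of the commit:

// --- Annotated sketch (editor's illustration, not part of this commit) ---
// CvRTParams positional arguments, per the OpenCV 2.4 documentation:
CvRTParams params (m_max_depth,                         // max_depth
                   m_min_sample_count,                  // min_sample_count
                   0,                                   // regression_accuracy (unused for classification)
                   false,                               // use_surrogates
                   m_max_categories,                    // max_categories
                   priors,                              // priors (uniform weights built above)
                   false,                               // calc_var_importance
                   0,                                   // nactive_vars (0 = library default)
                   m_max_number_of_trees_in_the_forest, // max_num_of_trees_in_the_forest
                   m_forest_accuracy,                   // forest_accuracy (OOB error threshold)
                   CV_TERMCRIT_ITER | CV_TERMCRIT_EPS); // termcrit_type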
@@ -129,16 +140,16 @@ public:
delete[] priors;
#else
rtree = cv::ml::RTrees::create();
rtree->setMaxDepth (max_depth);
rtree->setMinSampleCount (min_sample_count);
rtree->setMaxCategories (max_categories);
rtree->setMaxDepth (m_max_depth);
rtree->setMinSampleCount (m_min_sample_count);
rtree->setMaxCategories (m_max_categories);
rtree->setCalculateVarImportance (false);
rtree->setRegressionAccuracy (forest_accuracy);
rtree->setRegressionAccuracy (m_forest_accuracy);
rtree->setUseSurrogates(false);
rtree->setPriors(cv::Mat());
cv::TermCriteria criteria (cv::TermCriteria::EPS + cv::TermCriteria::COUNT, max_number_of_trees_in_the_forest, 0.01f);
cv::TermCriteria criteria (cv::TermCriteria::EPS + cv::TermCriteria::COUNT, m_max_number_of_trees_in_the_forest, 0.01f);
rtree->setTermCriteria (criteria);
cv::Ptr<cv::ml::TrainData> tdata = cv::ml::TrainData::create
@@ -150,6 +161,13 @@ public:
}
void set_max_depth (int max_depth) { m_max_depth = max_depth; }
void set_min_sample_count (int min_sample_count) { m_min_sample_count = min_sample_count; }
void set_max_categories (int max_categories) { m_max_categories = max_categories; }
void set_max_number_of_trees_in_the_forest (int max_number_of_trees_in_the_forest)
{ m_max_number_of_trees_in_the_forest = max_number_of_trees_in_the_forest; }
void set_forest_accuracy (float forest_accuracy) { m_forest_accuracy = forest_accuracy; }
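These setters make the hyperparameters adjustable after construction, before calling train(). A brief sketch with illustrative values (editor's example, not part of the commit):

// --- Usage sketch (editor's illustration, not part of this commit) ---
Random_forest_classifier classifier (labels, features);
classifier.set_max_depth (25);
classifier.set_min_sample_count (8);
classifier.set_forest_accuracy (0.005f);
classifier.train (ground_truth);   // uses the updated hyperparameters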
/// \cond SKIP_IN_MANUAL
void operator() (std::size_t item_index, std::vector<float>& out) const
{