From ddf85031bf85bc0f6552dc61022de32b16b89013 Mon Sep 17 00:00:00 2001 From: Simon Giraudot Date: Fri, 6 Oct 2017 10:01:59 +0200 Subject: [PATCH] ETHZ random forest classifier --- Classification/include/CGAL/Classification.h | 1 + .../ETHZ_random_forest_classifier.h | 179 ++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h diff --git a/Classification/include/CGAL/Classification.h b/Classification/include/CGAL/Classification.h index b1e0a1fb3d4..61ca972dce3 100644 --- a/Classification/include/CGAL/Classification.h +++ b/Classification/include/CGAL/Classification.h @@ -25,6 +25,7 @@ #include #include +#include #ifdef CGAL_LINKED_WITH_OPENCV #include diff --git a/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h b/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h new file mode 100644 index 00000000000..c6515f23ce0 --- /dev/null +++ b/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h @@ -0,0 +1,179 @@ +// Copyright (c) 2017 GeometryFactory Sarl (France). +// All rights reserved. +// +// This file is part of CGAL (www.cgal.org). +// You can redistribute it and/or modify it under the terms of the GNU +// General Public License as published by the Free Software Foundation, +// either version 3 of the License, or (at your option) any later version. +// +// Licensees holding a valid commercial license may use this file in +// accordance with the commercial license agreement provided with the software. +// +// This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE +// WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. +// +// $URL$ +// $Id$ +// +// Author(s) : Simon Giraudot + +#ifndef CGAL_CLASSIFICATION_ETHZ_RANDOM_FOREST_CLASSIFIER_H +#define CGAL_CLASSIFICATION_ETHZ_RANDOM_FOREST_CLASSIFIER_H + +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + + +namespace CGAL { + +namespace Classification { + +/*! + \ingroup PkgClassificationClassifiers + + \brief %Classifier based on the ETHZ version of random forest algorithm. + + \cgalModels `CGAL::Classification::Classifier` +*/ +class ETHZ_random_forest_classifier +{ + typedef liblearning::RandomForest::RandomForest + < liblearning::RandomForest::NodeGini + < liblearning::RandomForest::AxisAlignedSplitter> > Forest; + + const Label_set& m_labels; + const Feature_set& m_features; + Forest* m_rfc; + +public: + +/*! + \brief Instantiate the classifier using the sets of `labels` and `features`. + +*/ + ETHZ_random_forest_classifier (const Label_set& labels, + const Feature_set& features) + : m_labels (labels), m_features (features), m_rfc (NULL) + { } + + /// \cond SKIP_IN_MANUAL + ~ETHZ_random_forest_classifier () + { + if (m_rfc != NULL) + delete m_rfc; + } + /// \endcond + + /*! + \brief Runs the training algorithm. + + From the set of provided ground truth, this algorithm estimates + sets up the random trees that produce the most accurate result + with respect to this ground truth. + + \pre At least one ground truth item should be assigned to each + label. + + \param ground_truth vector of label indices. It should contain for + each input item, in the same order as the input set, the index of + the corresponding label in the `Label_set` provided in the + constructor. Input items that do not have a ground truth + information should be given the value `-1`. + */ + template + void train (const LabelIndexRange& ground_truth, + std::size_t num_trees = 25, + std::size_t max_depth = 20) + { + liblearning::RandomForest::ForestParams params; + params.n_trees = num_trees; + params.max_depth = max_depth; + + std::vector gt; + std::vector ft; + + for (std::size_t i = 0; i < ground_truth.size(); ++ i) + if (ground_truth[i] != std::size_t(-1)) + { + for (std::size_t f = 0; f < m_features.size(); ++ f) + ft.push_back(m_features[f]->value(i)); + gt.push_back(ground_truth[i]); + } + + liblearning::DataView2D label_vector (&(gt[0]), gt.size(), 1); + liblearning::DataView2D feature_vector(&(ft[0]), gt.size(), ft.size() / gt.size()); + + if (m_rfc != NULL) + delete m_rfc; + m_rfc = new Forest (params); + + liblearning::RandomForest::AxisAlignedRandomSplitGenerator generator; + + m_rfc->train(feature_vector, label_vector, liblearning::DataView2D(), generator, 0, false); + } + + /// \cond SKIP_IN_MANUAL + void operator() (std::size_t item_index, std::vector& out) const + { + out.resize (m_labels.size(), 0.); + + std::vector ft; + ft.reserve (m_features.size()); + for (std::size_t f = 0; f < m_features.size(); ++ f) + ft.push_back (m_features[f]->value(item_index)); + + std::vector prob (m_labels.size()); + + m_rfc->evaluate (ft.data(), prob.data()); + + for (std::size_t i = 0; i < out.size(); ++ i) + out[i] = - std::log (prob[i]); + } + + void save_configuration (const char* filename) + { + std::ofstream ofs(filename, std::ios_base::out | std::ios_base::binary); + boost::iostreams::filtering_ostream outs; + outs.push(boost::iostreams::gzip_compressor()); + outs.push(ofs); + boost::archive::text_oarchive oas(outs); + oas << BOOST_SERIALIZATION_NVP(*m_rfc); + } + + void load_configuration (const char* filename, + std::size_t num_trees = 25, + std::size_t max_depth = 20) + { + liblearning::RandomForest::ForestParams params; + params.n_trees = num_trees; + params.max_depth = max_depth; + if (m_rfc != NULL) + delete m_rfc; + m_rfc = new Forest (params); + + std::ifstream ifs(filename, std::ios_base::in | std::ios_base::binary); + boost::iostreams::filtering_istream ins; + ins.push(boost::iostreams::gzip_decompressor()); + ins.push(ifs); + boost::archive::text_iarchive ias(ins); + ias >> BOOST_SERIALIZATION_NVP(*m_rfc); + } + /// \endcond + +}; + +} + +} + +#endif // CGAL_CLASSIFICATION_ETHZ_RANDOM_FOREST_CLASSIFIER_H