From 7ac0eccbb9cbcdda0afe6ea051103bcedf33b7d6 Mon Sep 17 00:00:00 2001 From: Simon Giraudot Date: Tue, 11 Dec 2018 12:52:37 +0100 Subject: [PATCH] Add function to copy random forest classifier --- .../ETHZ_random_forest_classifier.h | 22 ++++++++++++++++++- .../Classification/test_classification_io.cpp | 7 +++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h b/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h index 6ea47a2a6e9..12db1189a0b 100644 --- a/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h +++ b/Classification/include/CGAL/Classification/ETHZ_random_forest_classifier.h @@ -90,6 +90,26 @@ public: : m_labels (labels), m_features (features), m_rfc (NULL) { } + /*! + \brief Copies the `other` classifier's configuration using another + set of `features`. + + This constructor can be used to apply a trained random forest to + another data set. + + \warning The feature set should be composed of the same features + than the ones used by `other`, and in the same order. + + */ + ETHZ_random_forest_classifier (const ETHZ_random_forest_classifier& other, + const Feature_set& features) + : m_labels (other.m_labels), m_features (features), m_rfc (NULL) + { + std::stringstream stream; + other.save_configuration(stream); + this->load_configuration(stream); + } + /// \cond SKIP_IN_MANUAL ~ETHZ_random_forest_classifier () { @@ -267,7 +287,7 @@ public: The output file is written in an GZIP container that is readable by the `load_configuration()` method. */ - void save_configuration (std::ostream& output) + void save_configuration (std::ostream& output) const { boost::iostreams::filtering_ostream outs; outs.push(boost::iostreams::gzip_compressor()); diff --git a/Classification/test/Classification/test_classification_io.cpp b/Classification/test/Classification/test_classification_io.cpp index cc05ccfb110..bb1d770e47c 100644 --- a/Classification/test/Classification/test_classification_io.cpp +++ b/Classification/test/Classification/test_classification_io.cpp @@ -87,13 +87,18 @@ int main (int, char**) std::ifstream inf ("output_config.gz", std::ios::binary); classifier2.load_configuration(inf); + Classifier classifier3 (classifier, features); + std::vector label_indices; std::vector label_indices_2; + std::vector label_indices_3; Classification::classify (points, labels, classifier, label_indices); Classification::classify (points, labels, classifier2, label_indices_2); + Classification::classify (points, labels, classifier3, label_indices_3); assert (label_indices == label_indices_2); - + assert (label_indices == label_indices_3); + return EXIT_SUCCESS; }