From 98402a7cd95bcd6a40c9bd38a42e257fd526301e Mon Sep 17 00:00:00 2001 From: Simon Giraudot Date: Mon, 24 Sep 2018 14:43:09 +0200 Subject: [PATCH] Add test for IO functions of Classification --- .../test/Classification/CMakeLists.txt | 8 ++ .../test/Classification/test_io.cpp | 98 +++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 Classification/test/Classification/test_io.cpp diff --git a/Classification/test/Classification/CMakeLists.txt b/Classification/test/Classification/CMakeLists.txt index 845522e5213..7879c214c4c 100644 --- a/Classification/test/Classification/CMakeLists.txt +++ b/Classification/test/Classification/CMakeLists.txt @@ -104,3 +104,11 @@ if(TARGET deprecated_test_classification_point_set) CGAL_target_use_TBB( deprecated_test_classification_point_set ) endif() endif() + +create_single_source_cgal_program( "test_io.cpp" CXX_FEATURES ${needed_cxx_features} ) +if(TARGET test_io) + target_link_libraries(test_io PUBLIC ${classification_linked_libraries}) + if (TBB_FOUND) + CGAL_target_use_TBB( test_io ) + endif() +endif() diff --git a/Classification/test/Classification/test_io.cpp b/Classification/test/Classification/test_io.cpp new file mode 100644 index 00000000000..41c707ab590 --- /dev/null +++ b/Classification/test/Classification/test_io.cpp @@ -0,0 +1,98 @@ +#if defined (_MSC_VER) && !defined (_WIN64) +#pragma warning(disable:4244) // boost::number_distance::distance() + // converts 64 to 32 bits integers +#endif + +#include +#include +#include +#include + +#include +#include +#include +#include + +typedef CGAL::Simple_cartesian Kernel; +typedef Kernel::Point_3 Point; +typedef Kernel::Vector_3 Vector; +typedef CGAL::Point_set_3 Point_set; +typedef Point_set::Point_map Point_map; + +typedef Kernel::Iso_cuboid_3 Iso_cuboid_3; + +namespace Classification = CGAL::Classification; + +typedef Classification::Label_handle Label_handle; +typedef Classification::Feature_handle Feature_handle; +typedef Classification::Label_set Label_set; +typedef Classification::Feature_set Feature_set; + +typedef Classification::ETHZ_random_forest_classifier Classifier; + +typedef Classification::Planimetric_grid Planimetric_grid; +typedef Classification::Point_set_neighborhood Neighborhood; +typedef Classification::Local_eigen_analysis Local_eigen_analysis; + +typedef Classification::Feature::Distance_to_plane Distance_to_plane; +typedef Classification::Feature::Elevation Elevation; + +int main (int, char**) +{ + Point_set points; + + for (std::size_t i = 0; i < 1000; ++ i) + points.insert (Point (CGAL::get_default_random().get_double(), + CGAL::get_default_random().get_double(), + CGAL::get_default_random().get_double())); + + Iso_cuboid_3 bbox = CGAL::bounding_box (points.points().begin(), points.points().end()); + + float grid_resolution = 0.34f; + float radius_neighbors = 1.7f; + float radius_dtm = 15.0f; + + Planimetric_grid grid (points, points.point_map(), bbox, grid_resolution); + Neighborhood neighborhood (points, points.point_map()); + Local_eigen_analysis eigen (points, points.point_map(), neighborhood.k_neighbor_query(6)); + + Feature_set features; + Feature_handle distance_to_plane = features.add (points, points.point_map(), eigen); + Feature_handle elevation = features.add (points, points.point_map(), grid, + radius_dtm); + + Label_set labels; + + std::vector training_set (points.size(), -1); + for (std::size_t i = 0; i < 3; ++ i) + { + std::ostringstream oss; + oss << "label_" << i; + Label_handle lh = labels.add(oss.str().c_str()); + + for (std::size_t j = 0; j < 100; ++ j) + training_set[std::size_t(CGAL::get_default_random().get_int(0, int(training_set.size())))] = int(i); + } + + Classifier classifier (labels, features); + classifier.train (training_set); + + std::ofstream outf ("output_config.gz"); + outf.precision(18); + classifier.save_configuration(outf); + outf.close(); + + Classifier classifier2 (labels, features); + std::ifstream inf ("output_config.gz"); + classifier2.load_configuration(inf); + + std::vector label_indices; + std::vector label_indices_2; + + Classification::classify (points, labels, classifier, label_indices); + Classification::classify (points, labels, classifier2, label_indices_2); + + assert (label_indices == label_indices_2); + + return EXIT_SUCCESS; +}