Parallelized RF in plugin

This commit is contained in:
Simon Giraudot 2018-12-06 14:58:26 +01:00
parent 2145a29297
commit 291982ee67
4 changed files with 11 additions and 10 deletions

View File

@ -818,9 +818,9 @@ void Cluster_classification::train(int classifier, const QMultipleInputDialog& d
if (m_ethz != NULL) if (m_ethz != NULL)
delete m_ethz; delete m_ethz;
m_ethz = new ETHZ_random_forest (m_labels, m_features); m_ethz = new ETHZ_random_forest (m_labels, m_features);
m_ethz->train(training, true, m_ethz->train<Concurrency_tag>(training, true,
dialog.get<QSpinBox>("num_trees")->value(), dialog.get<QSpinBox>("num_trees")->value(),
dialog.get<QSpinBox>("max_depth")->value()); dialog.get<QSpinBox>("max_depth")->value());
CGAL::Classification::classify<Concurrency_tag> (m_clusters, CGAL::Classification::classify<Concurrency_tag> (m_clusters,
m_labels, *m_ethz, m_labels, *m_ethz,
indices, m_label_probabilities); indices, m_label_probabilities);

View File

@ -727,9 +727,9 @@ void Point_set_item_classification::train(int classifier, const QMultipleInputDi
if (m_ethz != NULL) if (m_ethz != NULL)
delete m_ethz; delete m_ethz;
m_ethz = new ETHZ_random_forest (m_labels, m_features); m_ethz = new ETHZ_random_forest (m_labels, m_features);
m_ethz->train(training, true, m_ethz->train<Concurrency_tag>(training, true,
dialog.get<QSpinBox>("num_trees")->value(), dialog.get<QSpinBox>("num_trees")->value(),
dialog.get<QSpinBox>("max_depth")->value()); dialog.get<QSpinBox>("max_depth")->value());
CGAL::Classification::classify<Concurrency_tag> (*(m_points->point_set()), CGAL::Classification::classify<Concurrency_tag> (*(m_points->point_set()),
m_labels, *m_ethz, m_labels, *m_ethz,
indices, m_label_probabilities); indices, m_label_probabilities);

View File

@ -333,10 +333,11 @@ class Point_set_item_classification : public Item_classification_base
{ {
std::vector<int> indices (m_points->point_set()->size(), -1); std::vector<int> indices (m_points->point_set()->size(), -1);
m_label_probabilities.clear();
if (method == 0) if (method == 0)
CGAL::Classification::classify<Concurrency_tag> (*(m_points->point_set()), CGAL::Classification::classify<Concurrency_tag> (*(m_points->point_set()),
m_labels, classifier, m_labels, classifier,
indices); indices, m_label_probabilities);
else if (method == 1) else if (method == 1)
{ {
if (m_clusters.empty()) // Use real local smoothing if (m_clusters.empty()) // Use real local smoothing

View File

@ -293,9 +293,9 @@ void Surface_mesh_item_classification::train (int classifier, const QMultipleInp
if (m_ethz != NULL) if (m_ethz != NULL)
delete m_ethz; delete m_ethz;
m_ethz = new ETHZ_random_forest (m_labels, m_features); m_ethz = new ETHZ_random_forest (m_labels, m_features);
m_ethz->train(training, true, m_ethz->train<Concurrency_tag>(training, true,
dialog.get<QSpinBox>("num_trees")->value(), dialog.get<QSpinBox>("num_trees")->value(),
dialog.get<QSpinBox>("max_depth")->value()); dialog.get<QSpinBox>("max_depth")->value());
CGAL::Classification::classify<Concurrency_tag> (m_mesh->polyhedron()->faces(), CGAL::Classification::classify<Concurrency_tag> (m_mesh->polyhedron()->faces(),
m_labels, *m_ethz, m_labels, *m_ethz,
indices, m_label_probabilities); indices, m_label_probabilities);