mirror of https://github.com/CGAL/cgal
Add function to check validity of ground truth + preconditions
This commit is contained in:
parent
30e8ac4f7c
commit
12a027b027
|
|
@ -77,6 +77,10 @@ int main (int argc, char** argv)
|
||||||
Label_handle vegetation = labels.add ("vegetation");
|
Label_handle vegetation = labels.add ("vegetation");
|
||||||
Label_handle roof = labels.add ("roof");
|
Label_handle roof = labels.add ("roof");
|
||||||
|
|
||||||
|
// Check if ground truth is valid for this label set
|
||||||
|
if (!labels.is_valid_ground_truth (pts.range(label_map), true))
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
|
||||||
std::vector<int> label_indices(pts.size(), -1);
|
std::vector<int> label_indices(pts.size(), -1);
|
||||||
|
|
||||||
std::cerr << "Using ETHZ Random Forest Classifier" << std::endl;
|
std::cerr << "Using ETHZ Random Forest Classifier" << std::endl;
|
||||||
|
|
|
||||||
|
|
@ -167,6 +167,8 @@ public:
|
||||||
std::size_t num_trees = 25,
|
std::size_t num_trees = 25,
|
||||||
std::size_t max_depth = 20)
|
std::size_t max_depth = 20)
|
||||||
{
|
{
|
||||||
|
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
|
||||||
|
|
||||||
CGAL::internal::liblearning::RandomForest::ForestParams params;
|
CGAL::internal::liblearning::RandomForest::ForestParams params;
|
||||||
params.n_trees = num_trees;
|
params.n_trees = num_trees;
|
||||||
params.max_depth = max_depth;
|
params.max_depth = max_depth;
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,8 @@ public:
|
||||||
void append (const GroundTruthIndexRange& ground_truth,
|
void append (const GroundTruthIndexRange& ground_truth,
|
||||||
const ResultIndexRange& result)
|
const ResultIndexRange& result)
|
||||||
{
|
{
|
||||||
|
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
|
||||||
|
CGAL_precondition (m_labels.is_valid_ground_truth (result));
|
||||||
|
|
||||||
for (const auto& p : CGAL::make_range
|
for (const auto& p : CGAL::make_range
|
||||||
(boost::make_zip_iterator(boost::make_tuple(ground_truth.begin(), result.begin())),
|
(boost::make_zip_iterator(boost::make_tuple(ground_truth.begin(), result.begin())),
|
||||||
|
|
|
||||||
|
|
@ -264,6 +264,67 @@ public:
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
/// \name Validity
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/*!
|
||||||
|
\brief Checks the validity of the ground truth with respect to the
|
||||||
|
label set.
|
||||||
|
|
||||||
|
\param ground_truth range of label indices. This function checks
|
||||||
|
that all these indices are either -1 (for unclassified) or a valid
|
||||||
|
index of one of the labels. If at least one of the indices is out
|
||||||
|
of range, this function returns `false`, otherwise it returns
|
||||||
|
`true`.
|
||||||
|
|
||||||
|
\param verbose if set to `true`, the number of inliers of each
|
||||||
|
label, the number of unclassified items and the potential number
|
||||||
|
of out-of-range items are displayed. Otherwise, this function does
|
||||||
|
not display anything.
|
||||||
|
*/
|
||||||
|
template <typename LabelIndexRange>
|
||||||
|
bool is_valid_ground_truth (const LabelIndexRange& ground_truth,
|
||||||
|
bool verbose = false) const
|
||||||
|
{
|
||||||
|
std::vector<std::size_t> nb_inliers (m_labels.size() + 2, 0);
|
||||||
|
std::size_t total = 0;
|
||||||
|
|
||||||
|
for (const auto& gt : ground_truth)
|
||||||
|
{
|
||||||
|
int g = int(gt);
|
||||||
|
if (g == -1)
|
||||||
|
++ nb_inliers[m_labels.size()];
|
||||||
|
else if (g >= int(m_labels.size()))
|
||||||
|
{
|
||||||
|
++ nb_inliers[m_labels.size() + 1];
|
||||||
|
if (!verbose)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
++ nb_inliers[std::size_t(gt)];
|
||||||
|
++ total;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool valid = (nb_inliers[m_labels.size() + 1] == 0);
|
||||||
|
|
||||||
|
if (verbose)
|
||||||
|
{
|
||||||
|
std::cout << "Ground truth is " << (valid ? "valid" : "invalid") << ":" << std::endl;
|
||||||
|
std::cout << " * " << nb_inliers[m_labels.size()] << " unclassified item(s) ("
|
||||||
|
<< 100. * (nb_inliers[m_labels.size()] / double(total)) << "%)" << std::endl;
|
||||||
|
for (std::size_t i = 0; i < m_labels.size(); ++ i)
|
||||||
|
std::cout << " * " << nb_inliers[i] << " " << m_labels[i]->name() << " inlier(s) ("
|
||||||
|
<< 100. * (nb_inliers[i] / double(total)) << "%)" << std::endl;
|
||||||
|
if (!valid)
|
||||||
|
std::cout << " * " << nb_inliers[m_labels.size() + 1] << " item(s) with out-of-range index ("
|
||||||
|
<< 100. * (nb_inliers[m_labels.size() + 1] / double(total)) << "%)" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
return valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,8 @@ public:
|
||||||
template <typename LabelIndexRange>
|
template <typename LabelIndexRange>
|
||||||
void train (const LabelIndexRange& ground_truth)
|
void train (const LabelIndexRange& ground_truth)
|
||||||
{
|
{
|
||||||
|
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
|
||||||
|
|
||||||
#if (CV_MAJOR_VERSION < 3)
|
#if (CV_MAJOR_VERSION < 3)
|
||||||
if (rtree != nullptr)
|
if (rtree != nullptr)
|
||||||
delete rtree;
|
delete rtree;
|
||||||
|
|
|
||||||
|
|
@ -298,6 +298,8 @@ public:
|
||||||
float train (const LabelIndexRange& ground_truth,
|
float train (const LabelIndexRange& ground_truth,
|
||||||
unsigned int nb_tests = 300)
|
unsigned int nb_tests = 300)
|
||||||
{
|
{
|
||||||
|
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
|
||||||
|
|
||||||
std::vector<std::vector<std::size_t> > training_sets (m_labels.size());
|
std::vector<std::vector<std::size_t> > training_sets (m_labels.size());
|
||||||
std::size_t nb_tot = 0;
|
std::size_t nb_tot = 0;
|
||||||
std::size_t i = 0;
|
std::size_t i = 0;
|
||||||
|
|
|
||||||
|
|
@ -251,6 +251,8 @@ public:
|
||||||
const std::vector<std::size_t>& hidden_layers
|
const std::vector<std::size_t>& hidden_layers
|
||||||
= std::vector<std::size_t>())
|
= std::vector<std::size_t>())
|
||||||
{
|
{
|
||||||
|
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
|
||||||
|
|
||||||
if (restart_from_scratch)
|
if (restart_from_scratch)
|
||||||
clear();
|
clear();
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue