Add function to check validity of ground truth + preconditions

This commit is contained in:
Simon Giraudot 2020-06-29 14:04:56 +02:00
parent 30e8ac4f7c
commit 12a027b027
7 changed files with 75 additions and 0 deletions

View File

@ -77,6 +77,10 @@ int main (int argc, char** argv)
Label_handle vegetation = labels.add ("vegetation");
Label_handle roof = labels.add ("roof");
// Check if ground truth is valid for this label set
if (!labels.is_valid_ground_truth (pts.range(label_map), true))
return EXIT_FAILURE;
std::vector<int> label_indices(pts.size(), -1);
std::cerr << "Using ETHZ Random Forest Classifier" << std::endl;

View File

@ -167,6 +167,8 @@ public:
std::size_t num_trees = 25,
std::size_t max_depth = 20)
{
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
CGAL::internal::liblearning::RandomForest::ForestParams params;
params.n_trees = num_trees;
params.max_depth = max_depth;

View File

@ -120,6 +120,8 @@ public:
void append (const GroundTruthIndexRange& ground_truth,
const ResultIndexRange& result)
{
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
CGAL_precondition (m_labels.is_valid_ground_truth (result));
for (const auto& p : CGAL::make_range
(boost::make_zip_iterator(boost::make_tuple(ground_truth.begin(), result.begin())),

View File

@ -264,6 +264,67 @@ public:
/// @}
/// \name Validity
/// @{
/*!
\brief Checks the validity of the ground truth with respect to the
label set.
\param ground_truth range of label indices. This function checks
that all these indices are either -1 (for unclassified) or a valid
index of one of the labels. If at least one of the indices is out
of range, this function returns `false`, otherwise it returns
`true`.
\param verbose if set to `true`, the number of inliers of each
label, the number of unclassified items and the potential number
of out-of-range items are displayed. Otherwise, this function does
not display anything.
*/
template <typename LabelIndexRange>
bool is_valid_ground_truth (const LabelIndexRange& ground_truth,
bool verbose = false) const
{
std::vector<std::size_t> nb_inliers (m_labels.size() + 2, 0);
std::size_t total = 0;
for (const auto& gt : ground_truth)
{
int g = int(gt);
if (g == -1)
++ nb_inliers[m_labels.size()];
else if (g >= int(m_labels.size()))
{
++ nb_inliers[m_labels.size() + 1];
if (!verbose)
break;
}
else
++ nb_inliers[std::size_t(gt)];
++ total;
}
bool valid = (nb_inliers[m_labels.size() + 1] == 0);
if (verbose)
{
std::cout << "Ground truth is " << (valid ? "valid" : "invalid") << ":" << std::endl;
std::cout << " * " << nb_inliers[m_labels.size()] << " unclassified item(s) ("
<< 100. * (nb_inliers[m_labels.size()] / double(total)) << "%)" << std::endl;
for (std::size_t i = 0; i < m_labels.size(); ++ i)
std::cout << " * " << nb_inliers[i] << " " << m_labels[i]->name() << " inlier(s) ("
<< 100. * (nb_inliers[i] / double(total)) << "%)" << std::endl;
if (!valid)
std::cout << " * " << nb_inliers[m_labels.size() + 1] << " item(s) with out-of-range index ("
<< 100. * (nb_inliers[m_labels.size() + 1] / double(total)) << "%)" << std::endl;
}
return valid;
}
/// @}
};

View File

@ -147,6 +147,8 @@ public:
template <typename LabelIndexRange>
void train (const LabelIndexRange& ground_truth)
{
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
#if (CV_MAJOR_VERSION < 3)
if (rtree != nullptr)
delete rtree;

View File

@ -298,6 +298,8 @@ public:
float train (const LabelIndexRange& ground_truth,
unsigned int nb_tests = 300)
{
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
std::vector<std::vector<std::size_t> > training_sets (m_labels.size());
std::size_t nb_tot = 0;
std::size_t i = 0;

View File

@ -251,6 +251,8 @@ public:
const std::vector<std::size_t>& hidden_layers
= std::vector<std::size_t>())
{
CGAL_precondition (m_labels.is_valid_ground_truth (ground_truth));
if (restart_from_scratch)
clear();