Add variant of classify() to get detailed output

This commit is contained in:
Simon Giraudot 2018-05-31 14:05:53 +02:00
parent 5b1e535cb9
commit 642aea115a
1 changed files with 85 additions and 0 deletions

View File

@ -87,6 +87,53 @@ namespace internal {
};
template <typename Classifier, typename LabelIndexRange, typename ProbabilitiesRanges>
class Classify_detailed_output_functor
{
const Label_set& m_labels;
const Classifier& m_classifier;
LabelIndexRange& m_out;
ProbabilitiesRanges& m_prob;
public:
Classify_detailed_output_functor (const Label_set& labels,
const Classifier& classifier,
LabelIndexRange& out,
ProbabilitiesRanges& prob)
: m_labels (labels), m_classifier (classifier), m_out (out), m_prob (prob)
{ }
#ifdef CGAL_LINKED_WITH_TBB
void operator()(const tbb::blocked_range<std::size_t>& r) const
{
for (std::size_t s = r.begin(); s != r.end(); ++ s)
apply(s);
}
#endif // CGAL_LINKED_WITH_TBB
inline void apply (std::size_t s) const
{
std::size_t nb_class_best=0;
std::vector<float> values;
m_classifier (s, values);
float val_class_best = 0.f;
for(std::size_t k = 0; k < m_labels.size(); ++ k)
{
m_prob[k][s] = values[k];
if(val_class_best < values[k])
{
val_class_best = values[k];
nb_class_best = k;
}
}
m_out[s] = static_cast<typename LabelIndexRange::iterator::value_type>(nb_class_best);
}
};
template <typename Classifier>
class Classify_functor_local_smoothing_preprocessing
{
@ -344,6 +391,44 @@ namespace internal {
}
}
/// \cond SKIP_IN_MANUAL
// variant to get a detailed output (not documented yet)
template <typename ConcurrencyTag,
typename ItemRange,
typename Classifier,
typename LabelIndexRange,
typename ProbabilitiesRanges>
void classify (const ItemRange& input,
const Label_set& labels,
const Classifier& classifier,
LabelIndexRange& output,
ProbabilitiesRanges& probabilities)
{
output.resize (input.size());
probabilities.resize (labels.size());
for (std::size_t i = 0; i < probabilities.size(); ++ i)
probabilities[i].resize (input.size());
internal::Classify_detailed_output_functor<Classifier, LabelIndexRange, ProbabilitiesRanges>
f (labels, classifier, output, probabilities);
#ifndef CGAL_LINKED_WITH_TBB
CGAL_static_assertion_msg (!(boost::is_convertible<ConcurrencyTag, Parallel_tag>::value),
"Parallel_tag is enabled but TBB is unavailable.");
#else
if (boost::is_convertible<ConcurrencyTag,Parallel_tag>::value)
{
tbb::parallel_for(tbb::blocked_range<size_t>(0, input.size ()), f);
}
else
#endif
{
for (std::size_t i = 0; i < input.size(); ++ i)
f.apply(i);
}
}
/// \endcond
/*!
\ingroup PkgClassificationMain