mirror of https://github.com/CGAL/cgal
279 lines
9.8 KiB
C++
279 lines
9.8 KiB
C++
// Copyright (c) 2014 Stefan Walk
|
|
//
|
|
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
// this software and associated documentation files (the "Software"), to deal in
|
|
// the Software without restriction, including without limitation the rights to
|
|
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
|
// of the Software, and to permit persons to whom the Software is furnished to do
|
|
// so, subject to the following conditions:
|
|
//
|
|
// The above copyright notice and this permission notice shall be included in all
|
|
// copies or substantial portions of the Software.
|
|
//
|
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
// SOFTWARE.
|
|
//
|
|
// $URL$
|
|
// $Id$
|
|
// SPDX-License-Identifier: MIT
|
|
//
|
|
// Author(s) : Stefan Walk
|
|
|
|
// Modifications from original library:
|
|
// * changed inclusion protection tag
|
|
// * moved to namespace CGAL::internal::
|
|
// * fix computation of node_dist[label] so that results are always <= 1.0
|
|
// * change serialization functions to avoid a bug with boost and some
|
|
// compilers (that leads to dereferencing a null pointer)
|
|
// * add a method to get feature usage
|
|
|
|
#ifndef CGAL_INTERNAL_LIBLEARNING_RANDOMFORESTS_NODE_H
|
|
#define CGAL_INTERNAL_LIBLEARNING_RANDOMFORESTS_NODE_H
|
|
#include "../dataview.h"
|
|
#include "common-libraries.hpp"
|
|
#include <boost/serialization/scoped_ptr.hpp>
|
|
#include <boost/serialization/vector.hpp>
|
|
#if VERBOSE_NODE_LEARNING
|
|
#include <cstdio>
|
|
#endif
|
|
|
|
namespace CGAL { namespace internal {
|
|
|
|
namespace liblearning {
|
|
namespace RandomForest {
|
|
|
|
template <typename Derived, typename ParamT, typename Splitter>
|
|
class Node {
|
|
public:
|
|
typedef typename Splitter::FeatureType FeatureType;
|
|
bool is_leaf;
|
|
size_t n_samples;
|
|
size_t depth;
|
|
typedef ParamT ParamType;
|
|
ParamType const* params;
|
|
Splitter splitter;
|
|
|
|
boost::scoped_ptr<Derived> left;
|
|
boost::scoped_ptr<Derived> right;
|
|
std::vector<float> node_dist;
|
|
|
|
Node() : is_leaf(true), n_samples(0), depth(-1), params(0) {}
|
|
Node(size_t depth, ParamType const* params) :
|
|
is_leaf(true), n_samples(0), depth(depth), params(params)
|
|
{}
|
|
|
|
bool pure(DataView2D<int> labels, int* sample_idxes) const {
|
|
if (n_samples < 2)
|
|
return true; // an empty node is by definition pure
|
|
int first_sample_idx = sample_idxes[0];
|
|
int seen_class = labels(first_sample_idx, 0);
|
|
// check if all classes are equal to the first class
|
|
for (size_t i_sample = 1; i_sample < n_samples; ++i_sample) {
|
|
int sample_idx = sample_idxes[i_sample];
|
|
if (labels(sample_idx, 0) != seen_class)
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
float const* votes() const {
|
|
return (float const*)&node_dist[0];
|
|
}
|
|
|
|
int partition_samples(DataView2D<FeatureType> samples, int* sample_idxes) {
|
|
// sort samples in bag so that left-samples precede right-samples
|
|
// works like std::partition
|
|
int low = 0;
|
|
int high = n_samples;
|
|
|
|
while (true) {
|
|
while (true) {
|
|
if (low == high) {
|
|
return low;
|
|
} else if (!splitter.classify_sample(samples.row_pointer(sample_idxes[low]))) {
|
|
++low;
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
--high;
|
|
while (true) {
|
|
if (low == high) {
|
|
return low;
|
|
} else if (splitter.classify_sample(samples.row_pointer(sample_idxes[high]))) {
|
|
--high;
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
std::swap(sample_idxes[low], sample_idxes[high]);
|
|
++low;
|
|
}
|
|
}
|
|
|
|
Derived const* split (FeatureType const* sample) const {
|
|
if (splitter.classify_sample(sample)) {
|
|
return right.get();
|
|
} else {
|
|
return left.get();
|
|
}
|
|
}
|
|
|
|
typedef std::list<Derived const*> NodeList;
|
|
|
|
NodeList get_all_childs() {
|
|
NodeList ret;
|
|
ret.push_back(this);
|
|
if (!is_leaf) {
|
|
NodeList left_childs = left->get_all_childs();
|
|
ret.splice(ret.end(), left_childs);
|
|
NodeList right_childs = right->get_all_childs();
|
|
ret.splice(ret.end(), right_childs);
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
template<typename SplitGenerator>
|
|
void determine_best_split(DataView2D<FeatureType> samples,
|
|
DataView2D<int> labels,
|
|
int* sample_idxes,
|
|
SplitGenerator split_generator,
|
|
RandomGen& gen
|
|
)
|
|
{
|
|
typename Splitter::FeatureClassData data_points;
|
|
init_feature_class_data(data_points, params->n_classes, n_samples);
|
|
float best_loss = std::numeric_limits<float>::infinity();
|
|
|
|
std::vector<uint64_t> classes_l;
|
|
std::vector<uint64_t> classes_r;
|
|
|
|
// pass information about data to split generator
|
|
split_generator.init(samples,
|
|
labels,
|
|
sample_idxes,
|
|
n_samples,
|
|
params->n_classes,
|
|
gen);
|
|
|
|
size_t n_proposals = split_generator.num_proposals();
|
|
|
|
std::pair<FeatureType, float> results; // (threshold, loss)
|
|
|
|
for (size_t i_proposal = 0; i_proposal < n_proposals; ++i_proposal) {
|
|
// generate proposal
|
|
Splitter split = split_generator.gen_proposal(gen);
|
|
// map samples to numbers using proposal
|
|
split.map_points(samples, labels, sample_idxes, n_samples, data_points);
|
|
// check best loss using this proposal
|
|
results = static_cast<Derived*>(this)->determine_best_threshold(data_points, classes_l, classes_r, gen);
|
|
if (results.second < best_loss) {
|
|
// Proposal resulted into new optimum
|
|
best_loss = results.second;
|
|
split.set_threshold(results.first);
|
|
splitter = split;
|
|
}
|
|
}
|
|
}
|
|
|
|
template<typename SplitGenerator>
|
|
void train(DataView2D<FeatureType> samples,
|
|
DataView2D<int> labels,
|
|
int* sample_idxes,
|
|
size_t n_samples_,
|
|
SplitGenerator const& split_generator,
|
|
RandomGen& gen
|
|
)
|
|
{
|
|
n_samples = n_samples_;
|
|
node_dist.resize(params->n_classes, 0.0f);
|
|
for (size_t i_sample = 0; i_sample < n_samples; ++i_sample) {
|
|
int label = labels(sample_idxes[i_sample], 0);
|
|
node_dist[label] += 1.0f;
|
|
}
|
|
|
|
if (n_samples != 0)
|
|
for (std::size_t i = 0; i < node_dist.size(); ++ i)
|
|
node_dist[i] /= n_samples;
|
|
|
|
bool do_split = // Only split if ...
|
|
(n_samples >= params->min_samples_per_node) && // enough samples are available
|
|
!pure(labels, sample_idxes) && // this node is not already pure
|
|
(depth < params->max_depth); // we did not reach max depth
|
|
if (!do_split) {
|
|
splitter.threshold = 0.0;
|
|
return;
|
|
}
|
|
|
|
is_leaf = false;
|
|
|
|
#if VERBOSE_NODE_LEARNING
|
|
std::printf("Determining the best split at depth %zu/%zu\n", depth, params->max_depth);
|
|
#endif
|
|
determine_best_split(samples, labels, sample_idxes, split_generator, gen);
|
|
|
|
left.reset(new Derived(depth + 1, params));
|
|
right.reset(new Derived(depth + 1, params));
|
|
|
|
// sort samples in bag so that left-samples precede right-samples
|
|
int low = partition_samples(samples, sample_idxes);
|
|
int n_samples_left = low;
|
|
int n_samples_right = n_samples - low;
|
|
int offset_left = 0;
|
|
int offset_right = low;
|
|
#ifdef TREE_GRAPHVIZ_STREAM
|
|
if (depth <= TREE_GRAPHVIZ_MAX_DEPTH) {
|
|
TREE_GRAPHVIZ_STREAM << "p" << std::hex << (unsigned long)this
|
|
<< " -> "
|
|
<< "p" << std::hex << (unsigned long)left.get()
|
|
<< std::dec << " [label=\"" << n_samples_left << "\"];" << std::endl;
|
|
TREE_GRAPHVIZ_STREAM << "p" << std::hex << (unsigned long)this
|
|
<< " -> "
|
|
<< "p" << std::hex << (unsigned long)right.get()
|
|
<< std::dec << " [label=\"" << n_samples_right << "\"];" << std::endl;
|
|
}
|
|
#endif
|
|
// train left and right side of split
|
|
left->train (samples, labels, sample_idxes + offset_left, n_samples_left, split_generator, gen);
|
|
right->train(samples, labels, sample_idxes + offset_right, n_samples_right, split_generator, gen);
|
|
}
|
|
|
|
template <typename Archive>
|
|
void serialize(Archive& ar, unsigned /*version*/)
|
|
{
|
|
ar & BOOST_SERIALIZATION_NVP(is_leaf);
|
|
ar & BOOST_SERIALIZATION_NVP(n_samples);
|
|
ar & BOOST_SERIALIZATION_NVP(depth);
|
|
ar & BOOST_SERIALIZATION_NVP(params);
|
|
ar & BOOST_SERIALIZATION_NVP(splitter);
|
|
ar & BOOST_SERIALIZATION_NVP(node_dist);
|
|
if (!is_leaf)
|
|
{
|
|
ar & BOOST_SERIALIZATION_NVP(left);
|
|
ar & BOOST_SERIALIZATION_NVP(right);
|
|
}
|
|
}
|
|
|
|
void get_feature_usage (std::vector<std::size_t>& count) const
|
|
{
|
|
if (!is_leaf)
|
|
{
|
|
count[std::size_t(splitter.feature)] ++;
|
|
left->get_feature_usage(count);
|
|
right->get_feature_usage(count);
|
|
}
|
|
}
|
|
};
|
|
|
|
}
|
|
}
|
|
|
|
}} // namespace CGAL::internal::
|
|
|
|
#endif
|