37 #ifndef VIGRA_RANDOM_FOREST_HXX 38 #define VIGRA_RANDOM_FOREST_HXX 46 #include "mathutil.hxx" 47 #include "array_vector.hxx" 48 #include "sized_int.hxx" 50 #include "metaprogramming.hxx" 52 #include "functorexpression.hxx" 53 #include "random_forest/rf_common.hxx" 54 #include "random_forest/rf_nodeproxy.hxx" 55 #include "random_forest/rf_split.hxx" 56 #include "random_forest/rf_decisionTree.hxx" 57 #include "random_forest/rf_visitors.hxx" 58 #include "random_forest/rf_region.hxx" 59 #include "sampling.hxx" 60 #include "random_forest/rf_preprocessing.hxx" 61 #include "random_forest/rf_online_prediction_set.hxx" 62 #include "random_forest/rf_earlystopping.hxx" 63 #include "random_forest/rf_ridge_split.hxx" 83 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
85 SamplerOptions return_opt;
87 return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
146 template <
class LabelType =
double ,
class PreprocessorTag = ClassificationTag >
153 typedef detail::DecisionTree DecisionTree_t;
160 typedef LabelType LabelT;
168 ProblemSpec_t ext_param_;
198 ProblemSpec_t
const & ext_param = ProblemSpec_t())
201 ext_param_(ext_param)
227 template<
class TopologyIterator,
class ParameterIterator>
229 TopologyIterator topology_begin,
230 ParameterIterator parameter_begin,
231 ProblemSpec_t
const & problem_spec,
232 Options_t
const & options = Options_t())
234 trees_(treeCount, DecisionTree_t(problem_spec)),
235 ext_param_(problem_spec),
241 for(
int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
243 trees_[k].topology_ = *topology_begin;
244 trees_[k].parameters_ = *parameter_begin;
262 vigra_precondition(ext_param_.used() ==
true,
263 "RandomForest::ext_param(): " 264 "Random forest has not been trained yet.");
281 vigra_precondition(ext_param_.used() ==
false,
282 "RandomForest::set_ext_param():" 283 "Random forest has been trained! Call reset()" 284 "before specifying new extrinsic parameters.");
308 DecisionTree_t
const &
tree(
int index)
const 310 return trees_[index];
315 DecisionTree_t &
tree(
int index)
317 return trees_[index];
325 return ext_param_.column_count_;
336 return ext_param_.column_count_;
344 return ext_param_.class_count_;
351 return options_.tree_count_;
392 template <
class U,
class C1,
403 Random_t
const & random);
405 template <
class U,
class C1,
426 template <
class U,
class C1,
class U2,
class C2,
class Visitor_t>
438 template <
class U,
class C1,
class U2,
class C2,
439 class Visitor_t,
class Split_t>
470 template <
class U,
class C1,
class U2,
class C2>
482 template<
class U,
class C1,
495 bool adjust_thresholds=
false);
497 template <
class U,
class C1,
class U2,
class C2>
502 onlineLearn(features,
512 template<
class U,
class C1,
526 template<
class U,
class C1,
class U2,
class C2>
532 reLearnTree(features,
561 template <
class U,
class C,
class Stop>
564 template <
class U,
class C>
575 template <
class U,
class C>
589 template <
class U,
class C1,
class T,
class C2>
593 vigra_precondition(features.
shape(0) == labels.
shape(0),
594 "RandomForest::predictLabels(): Label array has wrong size.");
595 for(
int k=0; k<features.
shape(0); ++k)
597 vigra_precondition(!detail::contains_nan(
rowVector(features, k)),
598 "RandomForest::predictLabels(): NaN in feature matrix.");
599 labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(
rowVector(features, k),
rf_default()));
613 template <
class U,
class C1,
class T,
class C2>
616 LabelType nanLabel)
const 618 vigra_precondition(features.
shape(0) == labels.
shape(0),
619 "RandomForest::predictLabels(): Label array has wrong size.");
620 for(
int k=0; k<features.
shape(0); ++k)
622 if(detail::contains_nan(
rowVector(features, k)))
623 labels(k,0) = nanLabel;
625 labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(
rowVector(features, k),
rf_default()));
638 template <
class U,
class C1,
class T,
class C2,
class Stop>
643 vigra_precondition(features.
shape(0) == labels.
shape(0),
644 "RandomForest::predictLabels(): Label array has wrong size.");
645 for(
int k=0; k<features.
shape(0); ++k)
646 labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(
rowVector(features, k), stop));
660 template <
class U,
class C1,
class T,
class C2,
class Stop>
664 template <
class T1,
class T2,
class C>
665 void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
674 template <
class U,
class C1,
class T,
class C2>
678 predictProbabilities(features, prob,
rf_default());
681 template <
class U,
class C1,
class T,
class C2>
691 template <
class LabelType,
class PreprocessorTag>
692 template<
class U,
class C1,
705 bool adjust_thresholds)
707 online_visitor_.activate();
708 online_visitor_.adjust_thresholds=adjust_thresholds;
720 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 721 Default_Stop_t default_stop(options_);
722 typename RF_CHOOSER(Stop_t)::type stop
723 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
724 Default_Split_t default_split;
725 typename RF_CHOOSER(Split_t)::type split
726 = RF_CHOOSER(Split_t)::choose(split_, default_split);
730 typename RF_CHOOSER(Visitor_t)::type>
733 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
735 vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
741 ext_param_.class_count_=0;
742 Preprocessor_t preprocessor( features, response,
743 options_, ext_param_);
746 RandFunctor_t randint ( random);
749 split.set_external_parameters(ext_param_);
750 stop.set_external_parameters(ext_param_);
754 PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
760 for(
int ii = 0; ii < static_cast<int>(trees_.
size()); ++ii)
762 online_visitor_.tree_id=ii;
763 poisson_sampler.sample();
764 std::map<int,int> leaf_parents;
765 leaf_parents.clear();
767 for(
int s=0;s<poisson_sampler.numOfSamples();++s)
769 int sample=poisson_sampler[s];
770 online_visitor_.current_label=preprocessor.response()(sample,0);
771 online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772 int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
776 online_visitor_.add_to_index_list(ii,leaf,sample);
779 if(
Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
781 leaf_parents[leaf]=online_visitor_.last_node_id;
786 std::map<int,int>::iterator leaf_iterator;
787 for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
789 int leaf=leaf_iterator->first;
790 int parent=leaf_iterator->second;
791 int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
794 indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
795 StackEntry_t stack_entry(indeces.
begin(),
797 ext_param_.class_count_);
802 if(
NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).
child(0)==leaf)
808 vigra_assert(
NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
809 stack_entry.rightParent=parent;
813 trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
815 online_visitor_.move_exterior_node(ii,trees_[ii].topology_.
size(),ii,leaf);
828 online_visitor_.deactivate();
831 template<
class LabelType,
class PreprocessorTag>
832 template<
class U,
class C1,
853 ext_param_.class_count_=0;
861 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 862 Default_Stop_t default_stop(options_);
863 typename RF_CHOOSER(Stop_t)::type stop
864 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
865 Default_Split_t default_split;
866 typename RF_CHOOSER(Split_t)::type split
867 = RF_CHOOSER(Split_t)::choose(split_, default_split);
871 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
873 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
875 vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876 online_visitor_.activate();
879 RandFunctor_t randint ( random);
885 Preprocessor_t preprocessor( features, response,
886 options_, ext_param_);
889 split.set_external_parameters(ext_param_);
890 stop.set_external_parameters(ext_param_);
897 preprocessor.strata().end(),
898 detail::make_sampler_opt(options_)
899 .sampleSize(ext_param().actual_msample_),
906 first_stack_entry( sampler.sampledIndices().begin(),
907 sampler.sampledIndices().end(),
908 ext_param_.class_count_);
910 .set_oob_range( sampler.oobIndices().begin(),
911 sampler.oobIndices().end());
913 online_visitor_.tree_id=treeId;
914 trees_[treeId].reset();
916 .learn( preprocessor.features(),
917 preprocessor.response(),
924 .visit_after_tree( *
this,
930 online_visitor_.deactivate();
933 template <
class LabelType,
class PreprocessorTag>
934 template <
class U,
class C1,
946 Random_t
const & random)
957 vigra_precondition(features.
shape(0) == response.
shape(0),
958 "RandomForest::learn(): shape mismatch between features and response.");
965 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 966 Default_Stop_t default_stop(options_);
967 typename RF_CHOOSER(Stop_t)::type stop
968 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
969 Default_Split_t default_split;
970 typename RF_CHOOSER(Split_t)::type split
971 = RF_CHOOSER(Split_t)::choose(split_, default_split);
975 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
977 visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
979 if(options_.prepare_online_learning_)
980 online_visitor_.activate();
982 online_visitor_.deactivate();
986 RandFunctor_t randint ( random);
993 Preprocessor_t preprocessor( features, response,
994 options_, ext_param_);
997 split.set_external_parameters(ext_param_);
998 stop.set_external_parameters(ext_param_);
1002 trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
1005 preprocessor.strata().end(),
1006 detail::make_sampler_opt(options_)
1007 .sampleSize(ext_param().actual_msample_),
1010 visitor.visit_at_beginning(*
this, preprocessor);
1013 for(
int ii = 0; ii < static_cast<int>(trees_.
size()); ++ii)
1019 first_stack_entry( sampler.sampledIndices().begin(),
1020 sampler.sampledIndices().end(),
1021 ext_param_.class_count_);
1023 .set_oob_range( sampler.oobIndices().begin(),
1024 sampler.oobIndices().end());
1026 .learn( preprocessor.features(),
1027 preprocessor.response(),
1034 .visit_after_tree( *
this,
1041 visitor.visit_at_end(*
this, preprocessor);
1043 online_visitor_.deactivate();
1049 template <
class LabelType,
class Tag>
1050 template <
class U,
class C,
class Stop>
1054 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1055 "RandomForestn::predictLabel():" 1056 " Too few columns in feature matrix.");
1057 vigra_precondition(
rowCount(features) == 1,
1058 "RandomForestn::predictLabel():" 1059 " Feature matrix must have a singlerow.");
1062 predictProbabilities(features, probabilities, stop);
1063 ext_param_.to_classlabel(
argMax(probabilities), d);
1069 template <
class LabelType,
class PreprocessorTag>
1070 template <
class U,
class C>
1075 using namespace functor;
1076 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1077 "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078 vigra_precondition(
rowCount(features) == 1,
1079 "RandomForestn::predictLabel():" 1080 " Feature matrix must have a single row.");
1082 predictProbabilities(features, prob);
1083 std::transform( prob.begin(), prob.end(),
1084 priors.
begin(), prob.begin(),
1087 ext_param_.to_classlabel(
argMax(prob), d);
1091 template<
class LabelType,
class PreprocessorTag>
1092 template <
class T1,
class T2,
class C>
1101 "RandomFroest::predictProbabilities():" 1102 " Feature matrix and probability matrix size mismatch.");
1105 vigra_precondition(
columnCount(predictionSet.features) >= ext_param_.column_count_,
1106 "RandomForestn::predictProbabilities():" 1107 " Too few columns in feature matrix.");
1109 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1110 "RandomForestn::predictProbabilities():" 1111 " Probability matrix must have as many columns as there are classes.");
1114 std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1117 for(
int k=0; k<options_.tree_count_; ++k)
1119 set_id=(set_id+1) % predictionSet.indices[0].size();
1120 typedef std::set<SampleRange<T1> > my_set;
1121 typedef typename my_set::iterator set_it;
1124 std::vector<std::pair<int,set_it> > stack;
1126 for(set_it i=predictionSet.ranges[set_id].begin();
1127 i!=predictionSet.ranges[set_id].end();++i)
1128 stack.push_back(std::pair<int,set_it>(2,i));
1130 int num_decisions=0;
1131 while(!stack.empty())
1133 set_it range=stack.back().second;
1134 int index=stack.back().first;
1138 if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1141 trees_[k].parameters_,
1142 index).prob_begin();
1143 for(
int i=range->start;i!=range->end;++i)
1146 for(
int l=0; l<ext_param_.class_count_; ++l)
1148 prob(predictionSet.indices[set_id][i], l) +=
static_cast<T2
>(weights[l]);
1150 totalWeights[predictionSet.indices[set_id][i]] +=
static_cast<T1
>(weights[l]);
1157 if(trees_[k].topology_[index]!=i_ThresholdNode)
1159 throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
1161 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1162 if(range->min_boundaries[node.column()]>=node.threshold())
1165 stack.push_back(std::pair<int,set_it>(node.child(1),range));
1168 if(range->max_boundaries[node.column()]<node.threshold())
1171 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1175 SampleRange<T1> new_range=*range;
1176 new_range.min_boundaries[node.column()]=FLT_MAX;
1177 range->max_boundaries[node.column()]=-FLT_MAX;
1178 new_range.start=new_range.end=range->end;
1180 while(i!=range->end)
1183 if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1185 new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1186 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1189 std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1194 range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1195 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1200 if(range->start==range->end)
1202 predictionSet.ranges[set_id].erase(range);
1206 stack.push_back(std::pair<int,set_it>(node.child(0),range));
1209 if(new_range.start!=new_range.end)
1211 std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1212 stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1216 predictionSet.cumulativePredTime[k]=num_decisions;
1218 for(
unsigned int i=0;i<totalWeights.size();++i)
1222 for(
int l=0; l<ext_param_.class_count_; ++l)
1225 prob(i, l) /= totalWeights[i];
1227 assert(test==totalWeights[i]);
1228 assert(totalWeights[i]>0.0);
1232 template <
class LabelType,
class PreprocessorTag>
1233 template <
class U,
class C1,
class T,
class C2,
class Stop_t>
1237 Stop_t & stop_)
const 1243 "RandomForestn::predictProbabilities():" 1244 " Feature matrix and probability matrix size mismatch.");
1248 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1249 "RandomForestn::predictProbabilities():" 1250 " Too few columns in feature matrix.");
1252 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1253 "RandomForestn::predictProbabilities():" 1254 " Probability matrix must have as many columns as there are classes.");
1256 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 1257 Default_Stop_t default_stop(options_);
1258 typename RF_CHOOSER(Stop_t)::type & stop
1259 = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1261 stop.set_external_parameters(ext_param_, tree_count());
1262 prob.
init(NumericTraits<T>::zero());
1272 for(
int row=0; row <
rowCount(features); ++row)
1278 if(detail::contains_nan(currentRow))
1287 double totalWeight = 0.0;
1290 for(
int k=0; k<options_.tree_count_; ++k)
1293 weights = trees_[k ].predict(currentRow);
1296 int weighted = options_.predict_weighted_;
1297 for(
int l=0; l<ext_param_.class_count_; ++l)
1299 double cur_w = weights[l] * (weighted * (*(weights-1))
1301 prob(row, l) +=
static_cast<T
>(cur_w);
1303 totalWeight += cur_w;
1305 if(stop.after_prediction(weights,
1315 for(
int l=0; l< ext_param_.class_count_; ++l)
1317 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1323 template <
class LabelType,
class PreprocessorTag>
1324 template <
class U,
class C1,
class T,
class C2>
1333 "RandomForestn::predictProbabilities():" 1334 " Feature matrix and probability matrix size mismatch.");
1338 vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
1339 "RandomForestn::predictProbabilities():" 1340 " Too few columns in feature matrix.");
1342 == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1343 "RandomForestn::predictProbabilities():" 1344 " Probability matrix must have as many columns as there are classes.");
1346 #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 1347 prob.
init(NumericTraits<T>::zero());
1357 for(
int row=0; row <
rowCount(features); ++row)
1362 double totalWeight = 0.0;
1365 for(
int k=0; k<options_.tree_count_; ++k)
1368 weights = trees_[k ].predict(
rowVector(features, row));
1371 int weighted = options_.predict_weighted_;
1372 for(
int l=0; l<ext_param_.class_count_; ++l)
1374 double cur_w = weights[l] * (weighted * (*(weights-1))
1376 prob(row, l) +=
static_cast<T
>(cur_w);
1378 totalWeight += cur_w;
1382 prob/= options_.tree_count_;
1388 #include "random_forest/rf_algorithm.hxx" 1389 #endif // VIGRA_RANDOM_FOREST_HXX int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:323
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:278
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
Definition: rf_preprocessing.hxx:63
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:342
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:334
Create random samples from a sequence of indices.
Definition: sampling.hxx:232
Int32 leftParent
Definition: rf_region.hxx:69
void sample()
Definition: sampling.hxx:467
Definition: rf_split.hxx:993
Definition: matrix.hxx:123
problem specification class for the random forest.
Definition: rf_common.hxx:538
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:197
INT & child(Int32 l)
Definition: rf_nodeproxy.hxx:224
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:260
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple samples (one row of features per sample), using the given stop criterion
Standard early stopping criterion.
Definition: rf_common.hxx:885
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:315
Definition: rf_nodeproxy.hxx:87
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict one label per row of the feature matrix, using the given stop criterion
Definition: random_forest.hxx:639
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:291
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:471
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:635
Random forest version 2 (see also vigra::rf3::RandomForest for version 3)
Definition: random_forest.hxx:147
const difference_type & shape() const
Definition: multi_array.hxx:1648
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
Definition: rf_visitors.hxx:583
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple samples (one row of features per sample)
Definition: random_forest.hxx:675
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:141
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition: sampling.hxx:83
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
const_iterator begin() const
Definition: array_vector.hxx:223
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:838
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:228
int tree_count() const
return number of trees
Definition: random_forest.hxx:349
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Options object for the random forest.
Definition: rf_common.hxx:170
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1206
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:941
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict one label per row of the feature matrix; rows containing a NaN are assigned nanLabel instead
Definition: random_forest.hxx:614
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict one label per row of the feature matrix; a NaN in any row triggers a precondition failure
Definition: random_forest.hxx:590
size_type size() const
Definition: array_vector.hxx:358
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:308
const_iterator end() const
Definition: array_vector.hxx:237
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a single feature vector (a feature matrix with exactly one row).
Definition: random_forest.hxx:1052
Definition: rf_visitors.hxx:234
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:301