36 #ifndef VIGRA_RANDOM_FOREST_NP_HXX 37 #define VIGRA_RANDOM_FOREST_NP_HXX 42 #include "vigra/mathutil.hxx" 43 #include "vigra/array_vector.hxx" 44 #include "vigra/sized_int.hxx" 45 #include "vigra/matrix.hxx" 46 #include "vigra/random.hxx" 47 #include "vigra/functorexpression.hxx" 58 AllColumns = 0x00000000,
59 ToBePrunedTag = 0x80000000,
60 LeafNodeTag = 0x40000000,
64 i_HypersphereNode = 2,
65 e_ConstProbNode = 0 | LeafNodeTag,
66 e_LogRegProbNode = 1 | LeafNodeTag
93 typedef T_Container_type::iterator Topology_type;
94 typedef P_Container_type::iterator Parameter_type;
97 mutable Topology_type topology_;
100 mutable Parameter_type parameters_;
101 int parameter_size_ ;
141 INT
const &
typeID()
const 161 return topology_ + 4 ;
177 return featureCount_;
197 Topology_type topology_end()
const 201 int topology_size()
const 203 return topology_size_;
211 Parameter_type parameters_end()
const 216 int parameters_size()
const 218 return parameter_size_;
243 vigra_precondition(topology_size_==o.topology_size_,
"Cannot copy nodes of different sizes");
244 vigra_precondition(featureCount_==o.featureCount_,
"Cannot copy nodes with different feature count");
245 vigra_precondition(classCount_==o.classCount_,
"Cannot copy nodes with different class counts");
246 vigra_precondition(parameters_size() ==o.parameters_size(),
"Cannot copy nodes with different parameter sizes");
255 P_Container_type
const & parameter,
258 topology_ (const_cast<Topology_type>(topology.begin()+ n)),
260 parameters_ (const_cast<Parameter_type>(parameter.begin() +
parameter_addr())),
262 featureCount_(topology[0]),
263 classCount_(topology[1]),
274 T_Container_type
const & topology,
275 P_Container_type
const & parameter,
278 topology_ (const_cast<Topology_type>(topology.begin()+ n)),
279 topology_size_(tLen),
280 parameters_ (const_cast<Parameter_type>(parameter.begin() +
parameter_addr())),
281 parameter_size_(pLen),
282 featureCount_(topology[0]),
283 classCount_(topology[1]),
296 topology_ (node.topology_),
297 topology_size_(tLen),
298 parameters_ (node.parameters_),
299 parameter_size_(pLen),
300 featureCount_(node.featureCount_),
301 classCount_(node.classCount_),
318 T_Container_type & topology,
319 P_Container_type & parameter)
321 topology_size_(tLen),
322 parameter_size_(pLen),
323 featureCount_(topology[0]),
324 classCount_(topology[1]),
330 size_t n = topology.
size();
331 for(
int ii = 0; ii < tLen; ++ii)
332 topology.push_back(0);
335 topology_ = topology.
begin()+ n;
341 for(
int ii = 0; ii < pLen; ++ii)
342 parameter.push_back(0);
357 T_Container_type & topology,
358 P_Container_type & parameter)
360 topology_size_(toCopy.topology_size()),
361 parameter_size_(toCopy.parameters_size()),
362 featureCount_(topology[0]),
363 classCount_(topology[1]),
369 size_t n = topology.
size();
370 for(
int ii = 0; ii < toCopy.topology_size(); ++ii)
373 topology_ = topology.
begin()+ n;
375 for(
int ii = 0; ii < toCopy.parameters_size(); ++ii)
383 template<NodeTags NodeType>
387 class Node<i_ThresholdNode>
397 Node( BT::T_Container_type & topology,
398 BT::P_Container_type & param)
399 : BT(5,2,topology, param)
401 BT::typeID() = i_ThresholdNode;
404 Node( BT::T_Container_type
const & topology,
405 BT::P_Container_type
const & param,
407 : BT(5,2,topology, param, n)
416 return BT::parameters_begin()[1];
419 double const & threshold()
const 421 return BT::parameters_begin()[1];
426 return BT::column_data()[0];
428 BT::INT
const & column()
const 430 return BT::column_data()[0];
433 template<
class U,
class C>
436 return (feature(0, column()) < threshold())?
child(0):
child(1);
442 class Node<i_HyperplaneNode>
452 BT::T_Container_type & topology,
453 BT::P_Container_type & split_param)
454 : BT(nCol + 5,nCol + 2,topology, split_param)
456 BT::typeID() = i_HyperplaneNode;
459 Node( BT::T_Container_type
const & topology,
460 BT::P_Container_type
const & split_param,
462 :
NodeBase(5 , 2,topology, split_param, n)
465 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
467 : BT::column_data()[0];
468 BT::parameter_size_ += BT::columns_size();
475 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
477 : BT::column_data()[0];
478 BT::parameter_size_ += BT::columns_size();
482 double const & intercept()
const 484 return BT::parameters_begin()[1];
488 return BT::parameters_begin()[1];
491 BT::Parameter_type
weights()
const 493 return BT::parameters_begin()+2;
498 return BT::parameters_begin()+2;
502 template<
class U,
class C>
505 double result = -1 * intercept();
506 if(*(BT::column_data()) == AllColumns)
508 for(
int ii = 0; ii < BT::columns_size(); ++ii)
510 result +=feature[ii] *
weights()[ii];
515 for(
int ii = 0; ii < BT::columns_size(); ++ii)
517 result +=feature[BT::columns_begin()[ii]] *
weights()[ii];
520 return result < 0 ? BT::child(0)
528 class Node<i_HypersphereNode>
538 BT::T_Container_type & topology,
539 BT::P_Container_type & param)
540 :
NodeBase(nCol + 5,nCol + 1,topology, param)
542 BT::typeID() = i_HypersphereNode;
545 Node( BT::T_Container_type
const & topology,
546 BT::P_Container_type
const & param,
550 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
552 : BT::column_data()[0];
553 BT::parameter_size_ += BT::columns_size();
559 BT::topology_size_ += BT::column_data()[0]== AllColumns ?
561 : BT::column_data()[0];
562 BT::parameter_size_ += BT::columns_size();
566 double const & squaredRadius()
const 568 return BT::parameters_begin()[1];
571 double& squaredRadius()
573 return BT::parameters_begin()[1];
576 BT::Parameter_type center()
const 578 return BT::parameters_begin()+2;
581 BT::Parameter_type center()
583 return BT::parameters_begin()+2;
586 template<
class U,
class C>
589 double result = -1 * squaredRadius();
590 if(*(BT::column_data()) == AllColumns)
592 for(
int ii = 0; ii < BT::columns_size(); ++ii)
594 result += (feature[ii] - center()[ii])*
595 (feature[ii] - center()[ii]);
600 for(
int ii = 0; ii < BT::columns_size(); ++ii)
602 result += (feature[BT::columns_begin()[ii]] - center()[ii])*
603 (feature[BT::columns_begin()[ii]] - center()[ii]);
606 return result < 0 ? BT::child(0)
626 class Node<e_ConstProbNode>
636 BT(2,topology[1]+1, topology, param)
639 BT::typeID() = e_ConstProbNode;
646 : BT(2, topology[1]+1,topology, param, n)
651 : BT(2, node_.classCount_ +1, node_)
653 BT::Parameter_type prob_begin()
const 655 return BT::parameters_begin()+1;
657 BT::Parameter_type prob_end()
const 659 return prob_begin() + prob_size();
661 int prob_size()
const 663 return BT::classCount_;
668 class Node<e_LogRegProbNode>;
672 #endif //RF_nodeproxy Topology_type columns_begin() const
Definition: rf_nodeproxy.hxx:167
INT const & child(Int32 l) const
Definition: rf_nodeproxy.hxx:231
Topology_type column_data() const
Definition: rf_nodeproxy.hxx:159
NodeBase()
Definition: rf_nodeproxy.hxx:237
NodeBase(T_Container_type const &topology, P_Container_type const ¶meter, INT n)
Definition: rf_nodeproxy.hxx:254
bool data() const
Definition: rf_nodeproxy.hxx:128
Parameter_type parameters_begin() const
Definition: rf_nodeproxy.hxx:207
int columns_size() const
Definition: rf_nodeproxy.hxx:174
INT & child(Int32 l)
Definition: rf_nodeproxy.hxx:224
Definition: accessor.hxx:43
NodeBase(int tLen, int pLen, NodeBase &node)
Definition: rf_nodeproxy.hxx:292
NodeBase(int tLen, int pLen, T_Container_type const &topology, P_Container_type const ¶meter, INT n)
Definition: rf_nodeproxy.hxx:272
Definition: rf_nodeproxy.hxx:87
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Topology_type topology_begin() const
Definition: rf_nodeproxy.hxx:193
INT & typeID()
Definition: rf_nodeproxy.hxx:136
const_iterator begin() const
Definition: array_vector.hxx:223
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:652
NodeBase(int tLen, int pLen, T_Container_type &topology, P_Container_type ¶meter)
Definition: rf_nodeproxy.hxx:316
INT & parameter_addr()
Definition: rf_nodeproxy.hxx:148
size_type size() const
Definition: array_vector.hxx:358
double & weights()
Definition: rf_nodeproxy.hxx:115
NodeBase(NodeBase const &toCopy, T_Container_type &topology, P_Container_type ¶meter)
Definition: rf_nodeproxy.hxx:356
Topology_type columns_end() const
Definition: rf_nodeproxy.hxx:184