OpenCV 2.4.8 components for OpenCVgrabber.
[mmanager-3rdparty.git] / OpenCV2.4.8 / build / include / opencv2 / ml / ml.hpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #ifndef __OPENCV_ML_HPP__
42 #define __OPENCV_ML_HPP__
43
44 #include "opencv2/core/core.hpp"
45 #include <limits.h>
46
47 #ifdef __cplusplus
48
49 #include <map>
50 #include <string>
51 #include <iostream>
52
53 // Apple defines a check() macro somewhere in the debug headers
54 // that interferes with a method definiton in this header
55 #undef check
56
57 /****************************************************************************************\
58 *                               Main struct definitions                                  *
59 \****************************************************************************************/
60
61 /* log(2*PI) */
62 #define CV_LOG2PI (1.8378770664093454835606594728112)
63
64 /* columns of <trainData> matrix are training samples */
65 #define CV_COL_SAMPLE 0
66
67 /* rows of <trainData> matrix are training samples */
68 #define CV_ROW_SAMPLE 1
69
70 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
71
72 struct CvVectors
73 {
74     int type;
75     int dims, count;
76     CvVectors* next;
77     union
78     {
79         uchar** ptr;
80         float** fl;
81         double** db;
82     } data;
83 };
84
85 #if 0
86 /* A structure, representing the lattice range of statmodel parameters.
87    It is used for optimizing statmodel parameters by cross-validation method.
88    The lattice is logarithmic, so <step> must be greater then 1. */
89 typedef struct CvParamLattice
90 {
91     double min_val;
92     double max_val;
93     double step;
94 }
95 CvParamLattice;
96
97 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
98                                          double log_step )
99 {
100     CvParamLattice pl;
101     pl.min_val = MIN( min_val, max_val );
102     pl.max_val = MAX( min_val, max_val );
103     pl.step = MAX( log_step, 1. );
104     return pl;
105 }
106
107 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
108 {
109     CvParamLattice pl = {0,0,0};
110     return pl;
111 }
112 #endif
113
114 /* Variable type */
115 #define CV_VAR_NUMERICAL    0
116 #define CV_VAR_ORDERED      0
117 #define CV_VAR_CATEGORICAL  1
118
119 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
120 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
121 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
122 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
123 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
124 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
125 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
126 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
127 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
128 #define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
129 #define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"
130
131 #define CV_TRAIN_ERROR  0
132 #define CV_TEST_ERROR   1
133
134 class CV_EXPORTS_W CvStatModel
135 {
136 public:
137     CvStatModel();
138     virtual ~CvStatModel();
139
140     virtual void clear();
141
142     CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
143     CV_WRAP virtual void load( const char* filename, const char* name=0 );
144
145     virtual void write( CvFileStorage* storage, const char* name ) const;
146     virtual void read( CvFileStorage* storage, CvFileNode* node );
147
148 protected:
149     const char* default_model_name;
150 };
151
152 /****************************************************************************************\
153 *                                 Normal Bayes Classifier                                *
154 \****************************************************************************************/
155
156 /* The structure, representing the grid range of statmodel parameters.
157    It is used for optimizing statmodel accuracy by varying model parameters,
158    the accuracy estimate being computed by cross-validation.
159    The grid is logarithmic, so <step> must be greater then 1. */
160
161 class CvMLData;
162
163 struct CV_EXPORTS_W_MAP CvParamGrid
164 {
165     // SVM params type
166     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
167
168     CvParamGrid()
169     {
170         min_val = max_val = step = 0;
171     }
172
173     CvParamGrid( double min_val, double max_val, double log_step );
174     //CvParamGrid( int param_id );
175     bool check() const;
176
177     CV_PROP_RW double min_val;
178     CV_PROP_RW double max_val;
179     CV_PROP_RW double step;
180 };
181
182 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
183 {
184     min_val = _min_val;
185     max_val = _max_val;
186     step = _log_step;
187 }
188
189 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
190 {
191 public:
192     CV_WRAP CvNormalBayesClassifier();
193     virtual ~CvNormalBayesClassifier();
194
195     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
196         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
197
198     virtual bool train( const CvMat* trainData, const CvMat* responses,
199         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
200
201     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
202     CV_WRAP virtual void clear();
203
204     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
205                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
206     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
207                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
208                        bool update=false );
209     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
210
211     virtual void write( CvFileStorage* storage, const char* name ) const;
212     virtual void read( CvFileStorage* storage, CvFileNode* node );
213
214 protected:
215     int     var_count, var_all;
216     CvMat*  var_idx;
217     CvMat*  cls_labels;
218     CvMat** count;
219     CvMat** sum;
220     CvMat** productsum;
221     CvMat** avg;
222     CvMat** inv_eigen_values;
223     CvMat** cov_rotate_mats;
224     CvMat*  c;
225 };
226
227
228 /****************************************************************************************\
229 *                          K-Nearest Neighbour Classifier                                *
230 \****************************************************************************************/
231
232 // k Nearest Neighbors
233 class CV_EXPORTS_W CvKNearest : public CvStatModel
234 {
235 public:
236
237     CV_WRAP CvKNearest();
238     virtual ~CvKNearest();
239
240     CvKNearest( const CvMat* trainData, const CvMat* responses,
241                 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
242
243     virtual bool train( const CvMat* trainData, const CvMat* responses,
244                         const CvMat* sampleIdx=0, bool is_regression=false,
245                         int maxK=32, bool updateBase=false );
246
247     virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
248         const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
249
250     CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
251                const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
252
253     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
254                        const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
255                        int maxK=32, bool updateBase=false );
256
257     virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
258                                 const float** neighbors=0, cv::Mat* neighborResponses=0,
259                                 cv::Mat* dist=0 ) const;
260     CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
261                                         CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
262
263     virtual void clear();
264     int get_max_k() const;
265     int get_var_count() const;
266     int get_sample_count() const;
267     bool is_regression() const;
268
269     virtual float write_results( int k, int k1, int start, int end,
270         const float* neighbor_responses, const float* dist, CvMat* _results,
271         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
272
273     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
274         float* neighbor_responses, const float** neighbors, float* dist ) const;
275
276 protected:
277
278     int max_k, var_count;
279     int total;
280     bool regression;
281     CvVectors* samples;
282 };
283
284 /****************************************************************************************\
285 *                                   Support Vector Machines                              *
286 \****************************************************************************************/
287
288 // SVM training parameters
289 struct CV_EXPORTS_W_MAP CvSVMParams
290 {
291     CvSVMParams();
292     CvSVMParams( int svm_type, int kernel_type,
293                  double degree, double gamma, double coef0,
294                  double Cvalue, double nu, double p,
295                  CvMat* class_weights, CvTermCriteria term_crit );
296
297     CV_PROP_RW int         svm_type;
298     CV_PROP_RW int         kernel_type;
299     CV_PROP_RW double      degree; // for poly
300     CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
301     CV_PROP_RW double      coef0;  // for poly/sigmoid
302
303     CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
304     CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
305     CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
306     CvMat*      class_weights; // for CV_SVM_C_SVC
307     CV_PROP_RW CvTermCriteria term_crit; // termination criteria
308 };
309
310
311 struct CV_EXPORTS CvSVMKernel
312 {
313     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
314                                        const float* another, float* results );
315     CvSVMKernel();
316     CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
317     virtual bool create( const CvSVMParams* params, Calc _calc_func );
318     virtual ~CvSVMKernel();
319
320     virtual void clear();
321     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
322
323     const CvSVMParams* params;
324     Calc calc_func;
325
326     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
327                                     const float* another, float* results,
328                                     double alpha, double beta );
329
330     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
331                               const float* another, float* results );
332     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
333                            const float* another, float* results );
334     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
335                             const float* another, float* results );
336     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
337                                const float* another, float* results );
338 };
339
340
341 struct CvSVMKernelRow
342 {
343     CvSVMKernelRow* prev;
344     CvSVMKernelRow* next;
345     float* data;
346 };
347
348
349 struct CvSVMSolutionInfo
350 {
351     double obj;
352     double rho;
353     double upper_bound_p;
354     double upper_bound_n;
355     double r;   // for Solver_NU
356 };
357
358 class CV_EXPORTS CvSVMSolver
359 {
360 public:
361     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
362     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
363     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
364
365     CvSVMSolver();
366
367     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
368                  int alpha_count, double* alpha, double Cp, double Cn,
369                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
370                  SelectWorkingSet select_working_set, CalcRho calc_rho );
371     virtual bool create( int count, int var_count, const float** samples, schar* y,
372                  int alpha_count, double* alpha, double Cp, double Cn,
373                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
374                  SelectWorkingSet select_working_set, CalcRho calc_rho );
375     virtual ~CvSVMSolver();
376
377     virtual void clear();
378     virtual bool solve_generic( CvSVMSolutionInfo& si );
379
380     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
381                               double Cp, double Cn, CvMemStorage* storage,
382                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
383     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
384                                CvMemStorage* storage, CvSVMKernel* kernel,
385                                double* alpha, CvSVMSolutionInfo& si );
386     virtual bool solve_one_class( int count, int var_count, const float** samples,
387                                   CvMemStorage* storage, CvSVMKernel* kernel,
388                                   double* alpha, CvSVMSolutionInfo& si );
389
390     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
391                                 CvMemStorage* storage, CvSVMKernel* kernel,
392                                 double* alpha, CvSVMSolutionInfo& si );
393
394     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
395                                CvMemStorage* storage, CvSVMKernel* kernel,
396                                double* alpha, CvSVMSolutionInfo& si );
397
398     virtual float* get_row_base( int i, bool* _existed );
399     virtual float* get_row( int i, float* dst );
400
401     int sample_count;
402     int var_count;
403     int cache_size;
404     int cache_line_size;
405     const float** samples;
406     const CvSVMParams* params;
407     CvMemStorage* storage;
408     CvSVMKernelRow lru_list;
409     CvSVMKernelRow* rows;
410
411     int alpha_count;
412
413     double* G;
414     double* alpha;
415
416     // -1 - lower bound, 0 - free, 1 - upper bound
417     schar* alpha_status;
418
419     schar* y;
420     double* b;
421     float* buf[2];
422     double eps;
423     int max_iter;
424     double C[2];  // C[0] == Cn, C[1] == Cp
425     CvSVMKernel* kernel;
426
427     SelectWorkingSet select_working_set_func;
428     CalcRho calc_rho_func;
429     GetRow get_row_func;
430
431     virtual bool select_working_set( int& i, int& j );
432     virtual bool select_working_set_nu_svm( int& i, int& j );
433     virtual void calc_rho( double& rho, double& r );
434     virtual void calc_rho_nu_svm( double& rho, double& r );
435
436     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
437     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
438     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
439 };
440
441
442 struct CvSVMDecisionFunc
443 {
444     double rho;
445     int sv_count;
446     double* alpha;
447     int* sv_index;
448 };
449
450
451 // SVM model
452 class CV_EXPORTS_W CvSVM : public CvStatModel
453 {
454 public:
455     // SVM type
456     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
457
458     // SVM kernel type
459     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
460
461     // SVM params type
462     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
463
464     CV_WRAP CvSVM();
465     virtual ~CvSVM();
466
467     CvSVM( const CvMat* trainData, const CvMat* responses,
468            const CvMat* varIdx=0, const CvMat* sampleIdx=0,
469            CvSVMParams params=CvSVMParams() );
470
471     virtual bool train( const CvMat* trainData, const CvMat* responses,
472                         const CvMat* varIdx=0, const CvMat* sampleIdx=0,
473                         CvSVMParams params=CvSVMParams() );
474
475     virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
476         const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
477         int kfold = 10,
478         CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
479         CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
480         CvParamGrid pGrid      = get_default_grid(CvSVM::P),
481         CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
482         CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
483         CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
484         bool balanced=false );
485
486     virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
487     virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
488
489     CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
490           const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
491           CvSVMParams params=CvSVMParams() );
492
493     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
494                        const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
495                        CvSVMParams params=CvSVMParams() );
496
497     CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
498                             const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
499                             int k_fold = 10,
500                             CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
501                             CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
502                             CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
503                             CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
504                             CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
505                             CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
506                             bool balanced=false);
507     CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
508     CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;
509
510     CV_WRAP virtual int get_support_vector_count() const;
511     virtual const float* get_support_vector(int i) const;
512     virtual CvSVMParams get_params() const { return params; };
513     CV_WRAP virtual void clear();
514
515     static CvParamGrid get_default_grid( int param_id );
516
517     virtual void write( CvFileStorage* storage, const char* name ) const;
518     virtual void read( CvFileStorage* storage, CvFileNode* node );
519     CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
520
521 protected:
522
523     virtual bool set_params( const CvSVMParams& params );
524     virtual bool train1( int sample_count, int var_count, const float** samples,
525                     const void* responses, double Cp, double Cn,
526                     CvMemStorage* _storage, double* alpha, double& rho );
527     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
528                     const CvMat* responses, CvMemStorage* _storage, double* alpha );
529     virtual void create_kernel();
530     virtual void create_solver();
531
532     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
533
534     virtual void write_params( CvFileStorage* fs ) const;
535     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
536
537     void optimize_linear_svm();
538
539     CvSVMParams params;
540     CvMat* class_labels;
541     int var_all;
542     float** sv;
543     int sv_total;
544     CvMat* var_idx;
545     CvMat* class_weights;
546     CvSVMDecisionFunc* decision_func;
547     CvMemStorage* storage;
548
549     CvSVMSolver* solver;
550     CvSVMKernel* kernel;
551
552 private:
553     CvSVM(const CvSVM&);
554     CvSVM& operator = (const CvSVM&);
555 };
556
557 /****************************************************************************************\
558 *                              Expectation - Maximization                                *
559 \****************************************************************************************/
560 namespace cv
561 {
562 class CV_EXPORTS_W EM : public Algorithm
563 {
564 public:
565     // Type of covariation matrices
566     enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
567
568     // Default parameters
569     enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};
570
571     // The initial step
572     enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
573
574     CV_WRAP EM(int nclusters=EM::DEFAULT_NCLUSTERS, int covMatType=EM::COV_MAT_DIAGONAL,
575        const TermCriteria& termCrit=TermCriteria(TermCriteria::COUNT+TermCriteria::EPS,
576                                                  EM::DEFAULT_MAX_ITERS, FLT_EPSILON));
577
578     virtual ~EM();
579     CV_WRAP virtual void clear();
580
581     CV_WRAP virtual bool train(InputArray samples,
582                        OutputArray logLikelihoods=noArray(),
583                        OutputArray labels=noArray(),
584                        OutputArray probs=noArray());
585
586     CV_WRAP virtual bool trainE(InputArray samples,
587                         InputArray means0,
588                         InputArray covs0=noArray(),
589                         InputArray weights0=noArray(),
590                         OutputArray logLikelihoods=noArray(),
591                         OutputArray labels=noArray(),
592                         OutputArray probs=noArray());
593
594     CV_WRAP virtual bool trainM(InputArray samples,
595                         InputArray probs0,
596                         OutputArray logLikelihoods=noArray(),
597                         OutputArray labels=noArray(),
598                         OutputArray probs=noArray());
599
600     CV_WRAP Vec2d predict(InputArray sample,
601                 OutputArray probs=noArray()) const;
602
603     CV_WRAP bool isTrained() const;
604
605     AlgorithmInfo* info() const;
606     virtual void read(const FileNode& fn);
607
608 protected:
609
610     virtual void setTrainData(int startStep, const Mat& samples,
611                               const Mat* probs0,
612                               const Mat* means0,
613                               const vector<Mat>* covs0,
614                               const Mat* weights0);
615
616     bool doTrain(int startStep,
617                  OutputArray logLikelihoods,
618                  OutputArray labels,
619                  OutputArray probs);
620     virtual void eStep();
621     virtual void mStep();
622
623     void clusterTrainSamples();
624     void decomposeCovs();
625     void computeLogWeightDivDet();
626
627     Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
628
629     // all inner matrices have type CV_64FC1
630     CV_PROP_RW int nclusters;
631     CV_PROP_RW int covMatType;
632     CV_PROP_RW int maxIters;
633     CV_PROP_RW double epsilon;
634
635     Mat trainSamples;
636     Mat trainProbs;
637     Mat trainLogLikelihoods;
638     Mat trainLabels;
639
640     CV_PROP Mat weights;
641     CV_PROP Mat means;
642     CV_PROP vector<Mat> covs;
643
644     vector<Mat> covsEigenValues;
645     vector<Mat> covsRotateMats;
646     vector<Mat> invCovsEigenValues;
647     Mat logWeightDivDet;
648 };
649 } // namespace cv
650
651 /****************************************************************************************\
652 *                                      Decision Tree                                     *
653 \****************************************************************************************/\
654 struct CvPair16u32s
655 {
656     unsigned short* u;
657     int* i;
658 };
659
660
661 #define CV_DTREE_CAT_DIR(idx,subset) \
662     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
663
664 struct CvDTreeSplit
665 {
666     int var_idx;
667     int condensed_idx;
668     int inversed;
669     float quality;
670     CvDTreeSplit* next;
671     union
672     {
673         int subset[2];
674         struct
675         {
676             float c;
677             int split_point;
678         }
679         ord;
680     };
681 };
682
683 struct CvDTreeNode
684 {
685     int class_idx;
686     int Tn;
687     double value;
688
689     CvDTreeNode* parent;
690     CvDTreeNode* left;
691     CvDTreeNode* right;
692
693     CvDTreeSplit* split;
694
695     int sample_count;
696     int depth;
697     int* num_valid;
698     int offset;
699     int buf_idx;
700     double maxlr;
701
702     // global pruning data
703     int complexity;
704     double alpha;
705     double node_risk, tree_risk, tree_error;
706
707     // cross-validation pruning data
708     int* cv_Tn;
709     double* cv_node_risk;
710     double* cv_node_error;
711
712     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
713     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
714 };
715
716
717 struct CV_EXPORTS_W_MAP CvDTreeParams
718 {
719     CV_PROP_RW int   max_categories;
720     CV_PROP_RW int   max_depth;
721     CV_PROP_RW int   min_sample_count;
722     CV_PROP_RW int   cv_folds;
723     CV_PROP_RW bool  use_surrogates;
724     CV_PROP_RW bool  use_1se_rule;
725     CV_PROP_RW bool  truncate_pruned_tree;
726     CV_PROP_RW float regression_accuracy;
727     const float* priors;
728
729     CvDTreeParams();
730     CvDTreeParams( int max_depth, int min_sample_count,
731                    float regression_accuracy, bool use_surrogates,
732                    int max_categories, int cv_folds,
733                    bool use_1se_rule, bool truncate_pruned_tree,
734                    const float* priors );
735 };
736
737
738 struct CV_EXPORTS CvDTreeTrainData
739 {
740     CvDTreeTrainData();
741     CvDTreeTrainData( const CvMat* trainData, int tflag,
742                       const CvMat* responses, const CvMat* varIdx=0,
743                       const CvMat* sampleIdx=0, const CvMat* varType=0,
744                       const CvMat* missingDataMask=0,
745                       const CvDTreeParams& params=CvDTreeParams(),
746                       bool _shared=false, bool _add_labels=false );
747     virtual ~CvDTreeTrainData();
748
749     virtual void set_data( const CvMat* trainData, int tflag,
750                           const CvMat* responses, const CvMat* varIdx=0,
751                           const CvMat* sampleIdx=0, const CvMat* varType=0,
752                           const CvMat* missingDataMask=0,
753                           const CvDTreeParams& params=CvDTreeParams(),
754                           bool _shared=false, bool _add_labels=false,
755                           bool _update_data=false );
756     virtual void do_responses_copy();
757
758     virtual void get_vectors( const CvMat* _subsample_idx,
759          float* values, uchar* missing, float* responses, bool get_class_idx=false );
760
761     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
762
763     virtual void write_params( CvFileStorage* fs ) const;
764     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
765
766     // release all the data
767     virtual void clear();
768
769     int get_num_classes() const;
770     int get_var_type(int vi) const;
771     int get_work_var_count() const {return work_var_count;}
772
773     virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
774     virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
775     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
776     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
777     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
778     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
779                                    const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
780     virtual int get_child_buf_idx( CvDTreeNode* n );
781
782     ////////////////////////////////////
783
784     virtual bool set_params( const CvDTreeParams& params );
785     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
786                                    int storage_idx, int offset );
787
788     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
789                 int split_point, int inversed, float quality );
790     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
791     virtual void free_node_data( CvDTreeNode* node );
792     virtual void free_train_data();
793     virtual void free_node( CvDTreeNode* node );
794
795     int sample_count, var_all, var_count, max_c_count;
796     int ord_var_count, cat_var_count, work_var_count;
797     bool have_labels, have_priors;
798     bool is_classifier;
799     int tflag;
800
801     const CvMat* train_data;
802     const CvMat* responses;
803     CvMat* responses_copy; // used in Boosting
804
805     int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
806     bool shared;
807     int is_buf_16u;
808
809     CvMat* cat_count;
810     CvMat* cat_ofs;
811     CvMat* cat_map;
812
813     CvMat* counts;
814     CvMat* buf;
815     inline size_t get_length_subbuf() const
816     {
817         size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
818         return res;
819     }
820
821     CvMat* direction;
822     CvMat* split_buf;
823
824     CvMat* var_idx;
825     CvMat* var_type; // i-th element =
826                      //   k<0  - ordered
827                      //   k>=0 - categorical, see k-th element of cat_* arrays
828     CvMat* priors;
829     CvMat* priors_mult;
830
831     CvDTreeParams params;
832
833     CvMemStorage* tree_storage;
834     CvMemStorage* temp_storage;
835
836     CvDTreeNode* data_root;
837
838     CvSet* node_heap;
839     CvSet* split_heap;
840     CvSet* cv_heap;
841     CvSet* nv_heap;
842
843     cv::RNG* rng;
844 };
845
846 class CvDTree;
847 class CvForestTree;
848
849 namespace cv
850 {
851     struct DTreeBestSplitFinder;
852     struct ForestTreeBestSplitFinder;
853 }
854
855 class CV_EXPORTS_W CvDTree : public CvStatModel
856 {
857 public:
858     CV_WRAP CvDTree();
859     virtual ~CvDTree();
860
861     virtual bool train( const CvMat* trainData, int tflag,
862                         const CvMat* responses, const CvMat* varIdx=0,
863                         const CvMat* sampleIdx=0, const CvMat* varType=0,
864                         const CvMat* missingDataMask=0,
865                         CvDTreeParams params=CvDTreeParams() );
866
867     virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
868
869     // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
870     virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
871
872     virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
873
874     virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
875                                   bool preprocessedInput=false ) const;
876
877     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
878                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
879                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
880                        const cv::Mat& missingDataMask=cv::Mat(),
881                        CvDTreeParams params=CvDTreeParams() );
882
883     CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
884                                   bool preprocessedInput=false ) const;
885     CV_WRAP virtual cv::Mat getVarImportance();
886
887     virtual const CvMat* get_var_importance();
888     CV_WRAP virtual void clear();
889
890     virtual void read( CvFileStorage* fs, CvFileNode* node );
891     virtual void write( CvFileStorage* fs, const char* name ) const;
892
893     // special read & write methods for trees in the tree ensembles
894     virtual void read( CvFileStorage* fs, CvFileNode* node,
895                        CvDTreeTrainData* data );
896     virtual void write( CvFileStorage* fs ) const;
897
898     const CvDTreeNode* get_root() const;
899     int get_pruned_tree_idx() const;
900     CvDTreeTrainData* get_data();
901
902 protected:
903     friend struct cv::DTreeBestSplitFinder;
904
905     virtual bool do_train( const CvMat* _subsample_idx );
906
907     virtual void try_split_node( CvDTreeNode* n );
908     virtual void split_node_data( CvDTreeNode* n );
909     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
910     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
911                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
912     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
913                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
914     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
915                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
916     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
917                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
918     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
919     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
920     virtual double calc_node_dir( CvDTreeNode* node );
921     virtual void complete_node_dir( CvDTreeNode* node );
922     virtual void cluster_categories( const int* vectors, int vector_count,
923         int var_count, int* sums, int k, int* cluster_labels );
924
925     virtual void calc_node_value( CvDTreeNode* node );
926
927     virtual void prune_cv();
928     virtual double update_tree_rnc( int T, int fold );
929     virtual int cut_tree( int T, int fold, double min_alpha );
930     virtual void free_prune_data(bool cut_tree);
931     virtual void free_tree();
932
933     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
934     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
935     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
936     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
937     virtual void write_tree_nodes( CvFileStorage* fs ) const;
938     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
939
940     CvDTreeNode* root;
941     CvMat* var_importance;
942     CvDTreeTrainData* data;
943
944 public:
945     int pruned_tree_idx;
946 };
947
948
949 /****************************************************************************************\
950 *                                   Random Trees Classifier                              *
951 \****************************************************************************************/
952
953 class CvRTrees;
954
955 class CV_EXPORTS CvForestTree: public CvDTree
956 {
957 public:
958     CvForestTree();
959     virtual ~CvForestTree();
960
961     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
962
963     virtual int get_var_count() const {return data ? data->var_count : 0;}
964     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
965
966     /* dummy methods to avoid warnings: BEGIN */
967     virtual bool train( const CvMat* trainData, int tflag,
968                         const CvMat* responses, const CvMat* varIdx=0,
969                         const CvMat* sampleIdx=0, const CvMat* varType=0,
970                         const CvMat* missingDataMask=0,
971                         CvDTreeParams params=CvDTreeParams() );
972
973     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
974     virtual void read( CvFileStorage* fs, CvFileNode* node );
975     virtual void read( CvFileStorage* fs, CvFileNode* node,
976                        CvDTreeTrainData* data );
977     /* dummy methods to avoid warnings: END */
978
979 protected:
980     friend struct cv::ForestTreeBestSplitFinder;
981
982     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
983     CvRTrees* forest;
984 };
985
986
987 struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
988 {
989     //Parameters for the forest
990     CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
991     CV_PROP_RW int nactive_vars;
992     CV_PROP_RW CvTermCriteria term_crit;
993
994     CvRTParams();
995     CvRTParams( int max_depth, int min_sample_count,
996                 float regression_accuracy, bool use_surrogates,
997                 int max_categories, const float* priors, bool calc_var_importance,
998                 int nactive_vars, int max_num_of_trees_in_the_forest,
999                 float forest_accuracy, int termcrit_type );
1000 };
1001
1002
1003 class CV_EXPORTS_W CvRTrees : public CvStatModel
1004 {
1005 public:
1006     CV_WRAP CvRTrees();
1007     virtual ~CvRTrees();
1008     virtual bool train( const CvMat* trainData, int tflag,
1009                         const CvMat* responses, const CvMat* varIdx=0,
1010                         const CvMat* sampleIdx=0, const CvMat* varType=0,
1011                         const CvMat* missingDataMask=0,
1012                         CvRTParams params=CvRTParams() );
1013
1014     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1015     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
1016     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
1017
1018     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1019                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1020                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1021                        const cv::Mat& missingDataMask=cv::Mat(),
1022                        CvRTParams params=CvRTParams() );
1023     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1024     CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
1025     CV_WRAP virtual cv::Mat getVarImportance();
1026
1027     CV_WRAP virtual void clear();
1028
1029     virtual const CvMat* get_var_importance();
1030     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
1031         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
1032
1033     virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1034
1035     virtual float get_train_error();
1036
1037     virtual void read( CvFileStorage* fs, CvFileNode* node );
1038     virtual void write( CvFileStorage* fs, const char* name ) const;
1039
1040     CvMat* get_active_var_mask();
1041     CvRNG* get_rng();
1042
1043     int get_tree_count() const;
1044     CvForestTree* get_tree(int i) const;
1045
1046 protected:
1047     virtual std::string getName() const;
1048
1049     virtual bool grow_forest( const CvTermCriteria term_crit );
1050
1051     // array of the trees of the forest
1052     CvForestTree** trees;
1053     CvDTreeTrainData* data;
1054     int ntrees;
1055     int nclasses;
1056     double oob_error;
1057     CvMat* var_importance;
1058     int nsamples;
1059
1060     cv::RNG* rng;
1061     CvMat* active_var_mask;
1062 };
1063
1064 /****************************************************************************************\
1065 *                           Extremely randomized trees Classifier                        *
1066 \****************************************************************************************/
1067 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
1068 {
1069     virtual void set_data( const CvMat* trainData, int tflag,
1070                           const CvMat* responses, const CvMat* varIdx=0,
1071                           const CvMat* sampleIdx=0, const CvMat* varType=0,
1072                           const CvMat* missingDataMask=0,
1073                           const CvDTreeParams& params=CvDTreeParams(),
1074                           bool _shared=false, bool _add_labels=false,
1075                           bool _update_data=false );
1076     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
1077                                    const float** ord_values, const int** missing, int* sample_buf = 0 );
1078     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
1079     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
1080     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
1081     virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
1082                               float* responses, bool get_class_idx=false );
1083     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
1084     const CvMat* missing_mask;
1085 };
1086
1087 class CV_EXPORTS CvForestERTree : public CvForestTree
1088 {
1089 protected:
1090     virtual double calc_node_dir( CvDTreeNode* node );
1091     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1092         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1093     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1094         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1095     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1096         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1097     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1098         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1099     virtual void split_node_data( CvDTreeNode* n );
1100 };
1101
1102 class CV_EXPORTS_W CvERTrees : public CvRTrees
1103 {
1104 public:
1105     CV_WRAP CvERTrees();
1106     virtual ~CvERTrees();
1107     virtual bool train( const CvMat* trainData, int tflag,
1108                         const CvMat* responses, const CvMat* varIdx=0,
1109                         const CvMat* sampleIdx=0, const CvMat* varType=0,
1110                         const CvMat* missingDataMask=0,
1111                         CvRTParams params=CvRTParams());
1112     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1113                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1114                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1115                        const cv::Mat& missingDataMask=cv::Mat(),
1116                        CvRTParams params=CvRTParams());
1117     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1118 protected:
1119     virtual std::string getName() const;
1120     virtual bool grow_forest( const CvTermCriteria term_crit );
1121 };
1122
1123
1124 /****************************************************************************************\
1125 *                                   Boosted tree classifier                              *
1126 \****************************************************************************************/
1127
1128 struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
1129 {
1130     CV_PROP_RW int boost_type;
1131     CV_PROP_RW int weak_count;
1132     CV_PROP_RW int split_criteria;
1133     CV_PROP_RW double weight_trim_rate;
1134
1135     CvBoostParams();
1136     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1137                    int max_depth, bool use_surrogates, const float* priors );
1138 };
1139
1140
1141 class CvBoost;
1142
1143 class CV_EXPORTS CvBoostTree: public CvDTree
1144 {
1145 public:
1146     CvBoostTree();
1147     virtual ~CvBoostTree();
1148
1149     virtual bool train( CvDTreeTrainData* trainData,
1150                         const CvMat* subsample_idx, CvBoost* ensemble );
1151
1152     virtual void scale( double s );
1153     virtual void read( CvFileStorage* fs, CvFileNode* node,
1154                        CvBoost* ensemble, CvDTreeTrainData* _data );
1155     virtual void clear();
1156
1157     /* dummy methods to avoid warnings: BEGIN */
1158     virtual bool train( const CvMat* trainData, int tflag,
1159                         const CvMat* responses, const CvMat* varIdx=0,
1160                         const CvMat* sampleIdx=0, const CvMat* varType=0,
1161                         const CvMat* missingDataMask=0,
1162                         CvDTreeParams params=CvDTreeParams() );
1163     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
1164
1165     virtual void read( CvFileStorage* fs, CvFileNode* node );
1166     virtual void read( CvFileStorage* fs, CvFileNode* node,
1167                        CvDTreeTrainData* data );
1168     /* dummy methods to avoid warnings: END */
1169
1170 protected:
1171
1172     virtual void try_split_node( CvDTreeNode* n );
1173     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1174     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1175     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1176         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1177     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1178         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1179     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1180         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1181     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1182         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1183     virtual void calc_node_value( CvDTreeNode* n );
1184     virtual double calc_node_dir( CvDTreeNode* n );
1185
1186     CvBoost* ensemble;
1187 };
1188
1189
1190 class CV_EXPORTS_W CvBoost : public CvStatModel
1191 {
1192 public:
1193     // Boosting type
1194     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1195
1196     // Splitting criteria
1197     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1198
1199     CV_WRAP CvBoost();
1200     virtual ~CvBoost();
1201
1202     CvBoost( const CvMat* trainData, int tflag,
1203              const CvMat* responses, const CvMat* varIdx=0,
1204              const CvMat* sampleIdx=0, const CvMat* varType=0,
1205              const CvMat* missingDataMask=0,
1206              CvBoostParams params=CvBoostParams() );
1207
1208     virtual bool train( const CvMat* trainData, int tflag,
1209              const CvMat* responses, const CvMat* varIdx=0,
1210              const CvMat* sampleIdx=0, const CvMat* varType=0,
1211              const CvMat* missingDataMask=0,
1212              CvBoostParams params=CvBoostParams(),
1213              bool update=false );
1214
1215     virtual bool train( CvMLData* data,
1216              CvBoostParams params=CvBoostParams(),
1217              bool update=false );
1218
1219     virtual float predict( const CvMat* sample, const CvMat* missing=0,
1220                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1221                            bool raw_mode=false, bool return_sum=false ) const;
1222
1223     CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
1224             const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1225             const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1226             const cv::Mat& missingDataMask=cv::Mat(),
1227             CvBoostParams params=CvBoostParams() );
1228
1229     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1230                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1231                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1232                        const cv::Mat& missingDataMask=cv::Mat(),
1233                        CvBoostParams params=CvBoostParams(),
1234                        bool update=false );
1235
1236     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1237                                    const cv::Range& slice=cv::Range::all(), bool rawMode=false,
1238                                    bool returnSum=false ) const;
1239
1240     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1241
1242     CV_WRAP virtual void prune( CvSlice slice );
1243
1244     CV_WRAP virtual void clear();
1245
1246     virtual void write( CvFileStorage* storage, const char* name ) const;
1247     virtual void read( CvFileStorage* storage, CvFileNode* node );
1248     virtual const CvMat* get_active_vars(bool absolute_idx=true);
1249
1250     CvSeq* get_weak_predictors();
1251
1252     CvMat* get_weights();
1253     CvMat* get_subtree_weights();
1254     CvMat* get_weak_response();
1255     const CvBoostParams& get_params() const;
1256     const CvDTreeTrainData* get_data() const;
1257
1258 protected:
1259
1260     void update_weights_impl( CvBoostTree* tree, double initial_weights[2] );
1261
1262     virtual bool set_params( const CvBoostParams& params );
1263     virtual void update_weights( CvBoostTree* tree );
1264     virtual void trim_weights();
1265     virtual void write_params( CvFileStorage* fs ) const;
1266     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1267
1268     CvDTreeTrainData* data;
1269     CvBoostParams params;
1270     CvSeq* weak;
1271
1272     CvMat* active_vars;
1273     CvMat* active_vars_abs;
1274     bool have_active_cat_vars;
1275
1276     CvMat* orig_response;
1277     CvMat* sum_response;
1278     CvMat* weak_eval;
1279     CvMat* subsample_mask;
1280     CvMat* weights;
1281     CvMat* subtree_weights;
1282     bool have_subsample;
1283 };
1284
1285
1286 /****************************************************************************************\
1287 *                                   Gradient Boosted Trees                               *
1288 \****************************************************************************************/
1289
1290 // DataType: STRUCT CvGBTreesParams
1291 // Parameters of GBT (Gradient Boosted trees model), including single
1292 // tree settings and ensemble parameters.
1293 //
1294 // weak_count          - count of trees in the ensemble
1295 // loss_function_type  - loss function used for ensemble training
1296 // subsample_portion   - portion of whole training set used for
1297 //                       every single tree training.
1298 //                       subsample_portion value is in (0.0, 1.0].
1299 //                       subsample_portion == 1.0 when whole dataset is
1300 //                       used on each step. Count of sample used on each
1301 //                       step is computed as
1302 //                       int(total_samples_count * subsample_portion).
1303 // shrinkage           - regularization parameter.
1304 //                       Each tree prediction is multiplied on shrinkage value.
1305
1306
1307 struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
1308 {
1309     CV_PROP_RW int weak_count;
1310     CV_PROP_RW int loss_function_type;
1311     CV_PROP_RW float subsample_portion;
1312     CV_PROP_RW float shrinkage;
1313
1314     CvGBTreesParams();
1315     CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
1316         float subsample_portion, int max_depth, bool use_surrogates );
1317 };
1318
1319 // DataType: CLASS CvGBTrees
1320 // Gradient Boosting Trees (GBT) algorithm implementation.
1321 //
1322 // data             - training dataset
1323 // params           - parameters of the CvGBTrees
1324 // weak             - array[0..(class_count-1)] of CvSeq
1325 //                    for storing tree ensembles
1326 // orig_response    - original responses of the training set samples
1327 // sum_response     - predicitons of the current model on the training dataset.
1328 //                    this matrix is updated on every iteration.
1329 // sum_response_tmp - predicitons of the model on the training set on the next
1330 //                    step. On every iteration values of sum_responses_tmp are
1331 //                    computed via sum_responses values. When the current
1332 //                    step is complete sum_response values become equal to
1333 //                    sum_responses_tmp.
1334 // sampleIdx       - indices of samples used for training the ensemble.
1335 //                    CvGBTrees training procedure takes a set of samples
1336 //                    (train_data) and a set of responses (responses).
1337 //                    Only pairs (train_data[i], responses[i]), where i is
1338 //                    in sample_idx are used for training the ensemble.
1339 // subsample_train  - indices of samples used for training a single decision
1340 //                    tree on the current step. This indices are countered
1341 //                    relatively to the sample_idx, so that pairs
1342 //                    (train_data[sample_idx[i]], responses[sample_idx[i]])
1343 //                    are used for training a decision tree.
1344 //                    Training set is randomly splited
1345 //                    in two parts (subsample_train and subsample_test)
1346 //                    on every iteration accordingly to the portion parameter.
1347 // subsample_test   - relative indices of samples from the training set,
1348 //                    which are not used for training a tree on the current
1349 //                    step.
1350 // missing          - mask of the missing values in the training set. This
1351 //                    matrix has the same size as train_data. 1 - missing
1352 //                    value, 0 - not a missing value.
1353 // class_labels     - output class labels map.
1354 // rng              - random number generator. Used for spliting the
1355 //                    training set.
1356 // class_count      - count of output classes.
1357 //                    class_count == 1 in the case of regression,
1358 //                    and > 1 in the case of classification.
1359 // delta            - Huber loss function parameter.
1360 // base_value       - start point of the gradient descent procedure.
1361 //                    model prediction is
1362 //                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
1363 //                    f_0 is the base value.
1364
1365
1366
1367 class CV_EXPORTS_W CvGBTrees : public CvStatModel
1368 {
1369 public:
1370
1371     /*
1372     // DataType: ENUM
1373     // Loss functions implemented in CvGBTrees.
1374     //
1375     // SQUARED_LOSS
1376     // problem: regression
1377     // loss = (x - x')^2
1378     //
1379     // ABSOLUTE_LOSS
1380     // problem: regression
1381     // loss = abs(x - x')
1382     //
1383     // HUBER_LOSS
1384     // problem: regression
1385     // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
1386     //           1/2*(x - x')^2, if abs(x - x') <= delta,
1387     //           where delta is the alpha-quantile of pseudo responses from
1388     //           the training set.
1389     //
1390     // DEVIANCE_LOSS
1391     // problem: classification
1392     //
1393     */
1394     enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
1395
1396
1397     /*
1398     // Default constructor. Creates a model only (without training).
1399     // Should be followed by one form of the train(...) function.
1400     //
1401     // API
1402     // CvGBTrees();
1403
1404     // INPUT
1405     // OUTPUT
1406     // RESULT
1407     */
1408     CV_WRAP CvGBTrees();
1409
1410
1411     /*
1412     // Full form constructor. Creates a gradient boosting model and does the
1413     // train.
1414     //
1415     // API
1416     // CvGBTrees( const CvMat* trainData, int tflag,
1417              const CvMat* responses, const CvMat* varIdx=0,
1418              const CvMat* sampleIdx=0, const CvMat* varType=0,
1419              const CvMat* missingDataMask=0,
1420              CvGBTreesParams params=CvGBTreesParams() );
1421
1422     // INPUT
1423     // trainData    - a set of input feature vectors.
1424     //                  size of matrix is
1425     //                  <count of samples> x <variables count>
1426     //                  or <variables count> x <count of samples>
1427     //                  depending on the tflag parameter.
1428     //                  matrix values are float.
1429     // tflag         - a flag showing how do samples stored in the
1430     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1431     //                  or column by column (tflag=CV_COL_SAMPLE).
1432     // responses     - a vector of responses corresponding to the samples
1433     //                  in trainData.
1434     // varIdx       - indices of used variables. zero value means that all
1435     //                  variables are active.
1436     // sampleIdx    - indices of used samples. zero value means that all
1437     //                  samples from trainData are in the training set.
1438     // varType      - vector of <variables count> length. gives every
1439     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1440     //                  varType = 0 means all variables are numerical.
1441     // missingDataMask  - a mask of misiing values in trainData.
1442     //                  missingDataMask = 0 means that there are no missing
1443     //                  values.
1444     // params         - parameters of GTB algorithm.
1445     // OUTPUT
1446     // RESULT
1447     */
1448     CvGBTrees( const CvMat* trainData, int tflag,
1449              const CvMat* responses, const CvMat* varIdx=0,
1450              const CvMat* sampleIdx=0, const CvMat* varType=0,
1451              const CvMat* missingDataMask=0,
1452              CvGBTreesParams params=CvGBTreesParams() );
1453
1454
1455     /*
1456     // Destructor.
1457     */
1458     virtual ~CvGBTrees();
1459
1460
1461     /*
1462     // Gradient tree boosting model training
1463     //
1464     // API
1465     // virtual bool train( const CvMat* trainData, int tflag,
1466              const CvMat* responses, const CvMat* varIdx=0,
1467              const CvMat* sampleIdx=0, const CvMat* varType=0,
1468              const CvMat* missingDataMask=0,
1469              CvGBTreesParams params=CvGBTreesParams(),
1470              bool update=false );
1471
1472     // INPUT
1473     // trainData    - a set of input feature vectors.
1474     //                  size of matrix is
1475     //                  <count of samples> x <variables count>
1476     //                  or <variables count> x <count of samples>
1477     //                  depending on the tflag parameter.
1478     //                  matrix values are float.
1479     // tflag         - a flag showing how do samples stored in the
1480     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1481     //                  or column by column (tflag=CV_COL_SAMPLE).
1482     // responses     - a vector of responses corresponding to the samples
1483     //                  in trainData.
1484     // varIdx       - indices of used variables. zero value means that all
1485     //                  variables are active.
1486     // sampleIdx    - indices of used samples. zero value means that all
1487     //                  samples from trainData are in the training set.
1488     // varType      - vector of <variables count> length. gives every
1489     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1490     //                  varType = 0 means all variables are numerical.
1491     // missingDataMask  - a mask of misiing values in trainData.
1492     //                  missingDataMask = 0 means that there are no missing
1493     //                  values.
1494     // params         - parameters of GTB algorithm.
1495     // update         - is not supported now. (!)
1496     // OUTPUT
1497     // RESULT
1498     // Error state.
1499     */
1500     virtual bool train( const CvMat* trainData, int tflag,
1501              const CvMat* responses, const CvMat* varIdx=0,
1502              const CvMat* sampleIdx=0, const CvMat* varType=0,
1503              const CvMat* missingDataMask=0,
1504              CvGBTreesParams params=CvGBTreesParams(),
1505              bool update=false );
1506
1507
1508     /*
1509     // Gradient tree boosting model training
1510     //
1511     // API
1512     // virtual bool train( CvMLData* data,
1513              CvGBTreesParams params=CvGBTreesParams(),
1514              bool update=false ) {return false;};
1515
1516     // INPUT
1517     // data          - training set.
1518     // params        - parameters of GTB algorithm.
1519     // update        - is not supported now. (!)
1520     // OUTPUT
1521     // RESULT
1522     // Error state.
1523     */
1524     virtual bool train( CvMLData* data,
1525              CvGBTreesParams params=CvGBTreesParams(),
1526              bool update=false );
1527
1528
1529     /*
1530     // Response value prediction
1531     //
1532     // API
1533     // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1534              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1535              int k=-1 ) const;
1536
1537     // INPUT
1538     // sample         - input sample of the same type as in the training set.
1539     // missing        - missing values mask. missing=0 if there are no
1540     //                   missing values in sample vector.
1541     // weak_responses  - predictions of all of the trees.
1542     //                   not implemented (!)
1543     // slice           - part of the ensemble used for prediction.
1544     //                   slice = CV_WHOLE_SEQ when all trees are used.
1545     // k               - number of ensemble used.
1546     //                   k is in {-1,0,1,..,<count of output classes-1>}.
1547     //                   in the case of classification problem
1548     //                   <count of output classes-1> ensembles are built.
1549     //                   If k = -1 ordinary prediction is the result,
1550     //                   otherwise function gives the prediction of the
1551     //                   k-th ensemble only.
1552     // OUTPUT
1553     // RESULT
1554     // Predicted value.
1555     */
1556     virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1557             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1558             int k=-1 ) const;
1559
1560     /*
1561     // Response value prediction.
1562     // Parallel version (in the case of TBB existence)
1563     //
1564     // API
1565     // virtual float predict( const CvMat* sample, const CvMat* missing=0,
1566              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1567              int k=-1 ) const;
1568
1569     // INPUT
1570     // sample         - input sample of the same type as in the training set.
1571     // missing        - missing values mask. missing=0 if there are no
1572     //                   missing values in sample vector.
1573     // weak_responses  - predictions of all of the trees.
1574     //                   not implemented (!)
1575     // slice           - part of the ensemble used for prediction.
1576     //                   slice = CV_WHOLE_SEQ when all trees are used.
1577     // k               - number of ensemble used.
1578     //                   k is in {-1,0,1,..,<count of output classes-1>}.
1579     //                   in the case of classification problem
1580     //                   <count of output classes-1> ensembles are built.
1581     //                   If k = -1 ordinary prediction is the result,
1582     //                   otherwise function gives the prediction of the
1583     //                   k-th ensemble only.
1584     // OUTPUT
1585     // RESULT
1586     // Predicted value.
1587     */
1588     virtual float predict( const CvMat* sample, const CvMat* missing=0,
1589             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1590             int k=-1 ) const;
1591
1592     /*
1593     // Deletes all the data.
1594     //
1595     // API
1596     // virtual void clear();
1597
1598     // INPUT
1599     // OUTPUT
1600     // delete data, weak, orig_response, sum_response,
1601     //        weak_eval, subsample_train, subsample_test,
1602     //        sample_idx, missing, lass_labels
1603     // delta = 0.0
1604     // RESULT
1605     */
1606     CV_WRAP virtual void clear();
1607
1608     /*
1609     // Compute error on the train/test set.
1610     //
1611     // API
1612     // virtual float calc_error( CvMLData* _data, int type,
1613     //        std::vector<float> *resp = 0 );
1614     //
1615     // INPUT
1616     // data  - dataset
1617     // type  - defines which error is to compute: train (CV_TRAIN_ERROR) or
1618     //         test (CV_TEST_ERROR).
1619     // OUTPUT
1620     // resp  - vector of predicitons
1621     // RESULT
1622     // Error value.
1623     */
1624     virtual float calc_error( CvMLData* _data, int type,
1625             std::vector<float> *resp = 0 );
1626
1627     /*
1628     //
1629     // Write parameters of the gtb model and data. Write learned model.
1630     //
1631     // API
1632     // virtual void write( CvFileStorage* fs, const char* name ) const;
1633     //
1634     // INPUT
1635     // fs     - file storage to read parameters from.
1636     // name   - model name.
1637     // OUTPUT
1638     // RESULT
1639     */
1640     virtual void write( CvFileStorage* fs, const char* name ) const;
1641
1642
1643     /*
1644     //
1645     // Read parameters of the gtb model and data. Read learned model.
1646     //
1647     // API
1648     // virtual void read( CvFileStorage* fs, CvFileNode* node );
1649     //
1650     // INPUT
1651     // fs     - file storage to read parameters from.
1652     // node   - file node.
1653     // OUTPUT
1654     // RESULT
1655     */
1656     virtual void read( CvFileStorage* fs, CvFileNode* node );
1657
1658
1659     // new-style C++ interface
1660     CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
1661               const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1662               const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1663               const cv::Mat& missingDataMask=cv::Mat(),
1664               CvGBTreesParams params=CvGBTreesParams() );
1665
1666     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1667                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1668                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1669                        const cv::Mat& missingDataMask=cv::Mat(),
1670                        CvGBTreesParams params=CvGBTreesParams(),
1671                        bool update=false );
1672
1673     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1674                            const cv::Range& slice = cv::Range::all(),
1675                            int k=-1 ) const;
1676
1677 protected:
1678
1679     /*
1680     // Compute the gradient vector components.
1681     //
1682     // API
1683     // virtual void find_gradient( const int k = 0);
1684
1685     // INPUT
1686     // k        - used for classification problem, determining current
1687     //            tree ensemble.
1688     // OUTPUT
1689     // changes components of data->responses
1690     // which correspond to samples used for training
1691     // on the current step.
1692     // RESULT
1693     */
1694     virtual void find_gradient( const int k = 0);
1695
1696
1697     /*
1698     //
1699     // Change values in tree leaves according to the used loss function.
1700     //
1701     // API
1702     // virtual void change_values(CvDTree* tree, const int k = 0);
1703     //
1704     // INPUT
1705     // tree      - decision tree to change.
1706     // k         - used for classification problem, determining current
1707     //             tree ensemble.
1708     // OUTPUT
1709     // changes 'value' fields of the trees' leaves.
1710     // changes sum_response_tmp.
1711     // RESULT
1712     */
1713     virtual void change_values(CvDTree* tree, const int k = 0);
1714
1715
1716     /*
1717     //
1718     // Find optimal constant prediction value according to the used loss
1719     // function.
1720     // The goal is to find a constant which gives the minimal summary loss
1721     // on the _Idx samples.
1722     //
1723     // API
1724     // virtual float find_optimal_value( const CvMat* _Idx );
1725     //
1726     // INPUT
1727     // _Idx        - indices of the samples from the training set.
1728     // OUTPUT
1729     // RESULT
1730     // optimal constant value.
1731     */
1732     virtual float find_optimal_value( const CvMat* _Idx );
1733
1734
1735     /*
1736     //
1737     // Randomly split the whole training set in two parts according
1738     // to params.portion.
1739     //
1740     // API
1741     // virtual void do_subsample();
1742     //
1743     // INPUT
1744     // OUTPUT
1745     // subsample_train - indices of samples used for training
1746     // subsample_test  - indices of samples used for test
1747     // RESULT
1748     */
1749     virtual void do_subsample();
1750
1751
1752     /*
1753     //
1754     // Internal recursive function giving an array of subtree tree leaves.
1755     //
1756     // API
1757     // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1758     //
1759     // INPUT
1760     // node         - current leaf.
1761     // OUTPUT
1762     // count        - count of leaves in the subtree.
1763     // leaves       - array of pointers to leaves.
1764     // RESULT
1765     */
1766     void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1767
1768
1769     /*
1770     //
1771     // Get leaves of the tree.
1772     //
1773     // API
1774     // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1775     //
1776     // INPUT
1777     // dtree            - decision tree.
1778     // OUTPUT
1779     // len              - count of the leaves.
1780     // RESULT
1781     // CvDTreeNode**    - array of pointers to leaves.
1782     */
1783     CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1784
1785
1786     /*
1787     //
1788     // Is it a regression or a classification.
1789     //
1790     // API
1791     // bool problem_type();
1792     //
1793     // INPUT
1794     // OUTPUT
1795     // RESULT
1796     // false if it is a classification problem,
1797     // true - if regression.
1798     */
1799     virtual bool problem_type() const;
1800
1801
1802     /*
1803     //
1804     // Write parameters of the gtb model.
1805     //
1806     // API
1807     // virtual void write_params( CvFileStorage* fs ) const;
1808     //
1809     // INPUT
1810     // fs           - file storage to write parameters to.
1811     // OUTPUT
1812     // RESULT
1813     */
1814     virtual void write_params( CvFileStorage* fs ) const;
1815
1816
1817     /*
1818     //
1819     // Read parameters of the gtb model and data.
1820     //
1821     // API
1822     // virtual void read_params( CvFileStorage* fs );
1823     //
1824     // INPUT
1825     // fs           - file storage to read parameters from.
1826     // OUTPUT
1827     // params       - parameters of the gtb model.
1828     // data         - contains information about the structure
1829     //                of the data set (count of variables,
1830     //                their types, etc.).
1831     // class_labels - output class labels map.
1832     // RESULT
1833     */
1834     virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1835     int get_len(const CvMat* mat) const;
1836
1837
1838     CvDTreeTrainData* data;
1839     CvGBTreesParams params;
1840
1841     CvSeq** weak;
1842     CvMat* orig_response;
1843     CvMat* sum_response;
1844     CvMat* sum_response_tmp;
1845     CvMat* sample_idx;
1846     CvMat* subsample_train;
1847     CvMat* subsample_test;
1848     CvMat* missing;
1849     CvMat* class_labels;
1850
1851     cv::RNG* rng;
1852
1853     int class_count;
1854     float delta;
1855     float base_value;
1856
1857 };
1858
1859
1860
1861 /****************************************************************************************\
1862 *                              Artificial Neural Networks (ANN)                          *
1863 \****************************************************************************************/
1864
1865 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1866
1867 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
1868 {
1869     CvANN_MLP_TrainParams();
1870     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1871                            double param1, double param2=0 );
1872     ~CvANN_MLP_TrainParams();
1873
1874     enum { BACKPROP=0, RPROP=1 };
1875
1876     CV_PROP_RW CvTermCriteria term_crit;
1877     CV_PROP_RW int train_method;
1878
1879     // backpropagation parameters
1880     CV_PROP_RW double bp_dw_scale, bp_moment_scale;
1881
1882     // rprop parameters
1883     CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1884 };
1885
1886
1887 class CV_EXPORTS_W CvANN_MLP : public CvStatModel
1888 {
1889 public:
1890     CV_WRAP CvANN_MLP();
1891     CvANN_MLP( const CvMat* layerSizes,
1892                int activateFunc=CvANN_MLP::SIGMOID_SYM,
1893                double fparam1=0, double fparam2=0 );
1894
1895     virtual ~CvANN_MLP();
1896
1897     virtual void create( const CvMat* layerSizes,
1898                          int activateFunc=CvANN_MLP::SIGMOID_SYM,
1899                          double fparam1=0, double fparam2=0 );
1900
1901     virtual int train( const CvMat* inputs, const CvMat* outputs,
1902                        const CvMat* sampleWeights, const CvMat* sampleIdx=0,
1903                        CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1904                        int flags=0 );
1905     virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
1906
1907     CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
1908               int activateFunc=CvANN_MLP::SIGMOID_SYM,
1909               double fparam1=0, double fparam2=0 );
1910
1911     CV_WRAP virtual void create( const cv::Mat& layerSizes,
1912                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
1913                         double fparam1=0, double fparam2=0 );
1914
1915     CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
1916                       const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
1917                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1918                       int flags=0 );
1919
1920     CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
1921
1922     CV_WRAP virtual void clear();
1923
1924     // possible activation functions
1925     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1926
1927     // available training flags
1928     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1929
1930     virtual void read( CvFileStorage* fs, CvFileNode* node );
1931     virtual void write( CvFileStorage* storage, const char* name ) const;
1932
1933     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
1934     const CvMat* get_layer_sizes() { return layer_sizes; }
1935     double* get_weights(int layer)
1936     {
1937         return layer_sizes && weights &&
1938             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1939     }
1940
1941     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1942
1943 protected:
1944
1945     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1946             const CvMat* _sample_weights, const CvMat* sampleIdx,
1947             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1948
1949     // sequential random backpropagation
1950     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1951
1952     // RPROP algorithm
1953     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1954
1955     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1956     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1957                                  double _f_param1=0, double _f_param2=0 );
1958     virtual void init_weights();
1959     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1960     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1961     virtual void calc_input_scale( const CvVectors* vecs, int flags );
1962     virtual void calc_output_scale( const CvVectors* vecs, int flags );
1963
1964     virtual void write_params( CvFileStorage* fs ) const;
1965     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1966
1967     CvMat* layer_sizes;
1968     CvMat* wbuf;
1969     CvMat* sample_weights;
1970     double** weights;
1971     double f_param1, f_param2;
1972     double min_val, max_val, min_val1, max_val1;
1973     int activ_func;
1974     int max_count, max_buf_sz;
1975     CvANN_MLP_TrainParams params;
1976     cv::RNG* rng;
1977 };
1978
1979 /****************************************************************************************\
1980 *                           Auxilary functions declarations                              *
1981 \****************************************************************************************/
1982
1983 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
1984    average row vector, <cov> - symmetric covariation matrix */
1985 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1986                            CvRNG* rng CV_DEFAULT(0) );
1987
1988 /* Generates sample from gaussian mixture distribution */
1989 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1990                                CvMat* covs[],
1991                                float weights[],
1992                                int clsnum,
1993                                CvMat* sample,
1994                                CvMat* sampClasses CV_DEFAULT(0) );
1995
1996 #define CV_TS_CONCENTRIC_SPHERES 0
1997
1998 /* creates test set */
1999 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
2000                  int num_samples,
2001                  int num_features,
2002                  CvMat** responses,
2003                  int num_classes, ... );
2004
2005 /****************************************************************************************\
2006 *                                      Data                                             *
2007 \****************************************************************************************/
2008
2009 #define CV_COUNT     0
2010 #define CV_PORTION   1
2011
2012 struct CV_EXPORTS CvTrainTestSplit
2013 {
2014     CvTrainTestSplit();
2015     CvTrainTestSplit( int train_sample_count, bool mix = true);
2016     CvTrainTestSplit( float train_sample_portion, bool mix = true);
2017
2018     union
2019     {
2020         int count;
2021         float portion;
2022     } train_sample_part;
2023     int train_sample_part_mode;
2024
2025     bool mix;
2026 };
2027
2028 class CV_EXPORTS CvMLData
2029 {
2030 public:
2031     CvMLData();
2032     virtual ~CvMLData();
2033
2034     // returns:
2035     // 0 - OK
2036     // -1 - file can not be opened or is not correct
2037     int read_csv( const char* filename );
2038
2039     const CvMat* get_values() const;
2040     const CvMat* get_responses();
2041     const CvMat* get_missing() const;
2042
2043     void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
2044                                       // if idx < 0 there will be no response
2045     int get_response_idx() const;
2046
2047     void set_train_test_split( const CvTrainTestSplit * spl );
2048     const CvMat* get_train_sample_idx() const;
2049     const CvMat* get_test_sample_idx() const;
2050     void mix_train_and_test_idx();
2051
2052     const CvMat* get_var_idx();
2053     void chahge_var_idx( int vi, bool state ); // misspelled (saved for back compitability),
2054                                                // use change_var_idx
2055     void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
2056
2057     const CvMat* get_var_types();
2058     int get_var_type( int var_idx ) const;
2059     // following 2 methods enable to change vars type
2060     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
2061     // with numerical labels; in the other cases var types are correctly determined automatically
2062     void set_var_types( const char* str );  // str examples:
2063                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
2064                                             // "cat", "ord" (all vars are categorical/ordered)
2065     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
2066
2067     void set_delimiter( char ch );
2068     char get_delimiter() const;
2069
2070     void set_miss_ch( char ch );
2071     char get_miss_ch() const;
2072
2073     const std::map<std::string, int>& get_class_labels_map() const;
2074
2075 protected:
2076     virtual void clear();
2077
2078     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
2079     void free_train_test_idx();
2080
2081     char delimiter;
2082     char miss_ch;
2083     //char flt_separator;
2084
2085     CvMat* values;
2086     CvMat* missing;
2087     CvMat* var_types;
2088     CvMat* var_idx_mask;
2089
2090     CvMat* response_out; // header
2091     CvMat* var_idx_out; // mat
2092     CvMat* var_types_out; // mat
2093
2094     int response_idx;
2095
2096     int train_sample_count;
2097     bool mix;
2098
2099     int total_class_count;
2100     std::map<std::string, int> class_map;
2101
2102     CvMat* train_sample_idx;
2103     CvMat* test_sample_idx;
2104     int* sample_idx; // data of train_sample_idx and test_sample_idx
2105
2106     cv::RNG* rng;
2107 };
2108
2109
2110 namespace cv
2111 {
2112
2113 typedef CvStatModel StatModel;
2114 typedef CvParamGrid ParamGrid;
2115 typedef CvNormalBayesClassifier NormalBayesClassifier;
2116 typedef CvKNearest KNearest;
2117 typedef CvSVMParams SVMParams;
2118 typedef CvSVMKernel SVMKernel;
2119 typedef CvSVMSolver SVMSolver;
2120 typedef CvSVM SVM;
2121 typedef CvDTreeParams DTreeParams;
2122 typedef CvMLData TrainData;
2123 typedef CvDTree DecisionTree;
2124 typedef CvForestTree ForestTree;
2125 typedef CvRTParams RandomTreeParams;
2126 typedef CvRTrees RandomTrees;
2127 typedef CvERTreeTrainData ERTreeTRainData;
2128 typedef CvForestERTree ERTree;
2129 typedef CvERTrees ERTrees;
2130 typedef CvBoostParams BoostParams;
2131 typedef CvBoostTree BoostTree;
2132 typedef CvBoost Boost;
2133 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
2134 typedef CvANN_MLP NeuralNet_MLP;
2135 typedef CvGBTreesParams GradientBoostingTreeParams;
2136 typedef CvGBTrees GradientBoostingTrees;
2137
2138 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
2139
2140 CV_EXPORTS bool initModule_ml(void);
2141
2142 }
2143
2144 #endif // __cplusplus
2145 #endif // __OPENCV_ML_HPP__
2146
2147 /* End of file. */