Started work on.
[mipav.git] / mipav / src / gov / nih / mipav / model / algorithms / StochasticForests.java
1 package gov.nih.mipav.model.algorithms;\r
2 \r
3 \r
4 \r
5 import gov.nih.mipav.model.structures.*;\r
6 import gov.nih.mipav.view.*;\r
7 import java.io.*;\r
8 import java.util.*;\r
9 \r
10 public class StochasticForests extends AlgorithmBase {\r
11         // Note that while Random Forests is the name usually applied to this type\r
12         // of algorithm, Random Forests(tm) is a trademark of Leo Breiman and Adele Cutler and is \r
13         // licensed exclusively to Salford Systems for the commercial release of the software.\r
14         // Our trademarks also include RF(tm), RandomForests(tm), RandomForest(tm) and Random Forest(tm).\r
15         \r
16         // This is a port from C++ to Java of the ranger package of version 0.9.7 of A Fast Implementation\r
17         // of Random Forests.  The date of the original code is 3/29/2018.  The authors of the original code\r
18         // are Marvin N. Wright, Stefan Wager, and Philipp Probst.  The maintainer of the original code is\r
19         // Marvin N. Wright at cran@wrig.de.  The license of the C++ core of version 0.9.7 is a MIT license. \r
20         //(The R package which is not used in this port is still a GPL3 license.)\r
21         \r
22         // Copyright <2014-2018> <Marvin N. Wright>\r
23 \r
24         // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and\r
25         // associated documentation files (the "Software"), to deal in the Software without restriction, \r
26         // including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,\r
27         // and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,\r
28         // subject to the following conditions:\r
29 \r
30         // The above copyright notice and this permission notice shall be included in all copies or substantial\r
31         // portions of the Software.\r
32 \r
33         // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT\r
34         // NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\r
35         // IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,\r
36         // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\r
37         // SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\r
38         \r
39         // ### Introduction\r
40         // ranger is a fast implementation of random forest (Breiman 2001) or recursive partitioning, \r
41         // particularly suited for high dimensional data. Classification, regression, probability estimation\r
42         // and survival forests are supported. Classification and regression forests are implemented as in\r
43         // the original Random Forest (Breiman 2001), survival forests as in Random Survival Forests\r
44         // (Ishwaran et al. 2008). For probability estimation forests see Malley et al. (2012). \r
45         \r
46         // ### References\r
47         // 1.) Wright, M. N. & Ziegler, A. (2017). ranger: A Fast Implementation of Random Forests for High\r
48         // Dimensional Data in C++ and R. Journal of Statistical Software 77:1-17.\r
49         // http://dx.doi.org/10.18637/jss.v077.i01.\r
50         // 2.) Schmid, M., Wright, M. N. & Ziegler, A. (2016). On the use of Harrell's C for clinical risk\r
51         // prediction via random survival forests. Expert Systems with Applications 63:450-459. \r
52         // http://dx.doi.org/10.1016/j.eswa.2016.07.018.\r
53         // 3.) Wright, M. N., Dankowski, T. & Ziegler, A. (2017). Unbiased split variable selection for\r
54         // random survival forests using maximally selected rank statistics. Statistics in Medicine. \r
55         // http://dx.doi.org/10.1002/sim.7212.\r
56         // 4.) Breiman, L. (2001). Random forests. Machine learning 45:5-32.\r
57         // 5.) Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008). \r
58         // Random survival forests. The Annals of Applied Statistics 2:841-860.\r
59         // 6.) Malley, J. D., Kruppa, J., Dasgupta, A., Malley, K. G., & Ziegler, A. (2012). \r
60         // Probability machines: consistent probability estimation using nonparametric learning machines. \r
61         // Methods of Information in Medicine 51:74-81.\r
62 \r
63     // Tree types, probability is not selected by ID\r
64         private enum TreeType {\r
65                 TREE_CLASSIFICATION,\r
66                 TREE_REGRESSION,\r
67             TREE_SURVIVAL,\r
68             TREE_PROBABILITY\r
69         };\r
70         \r
71         // Memory modes\r
72         private enum MemoryMode {\r
73                 MEM_DOUBLE,\r
74                 MEM_FLOAT,\r
75                 MEM_CHAR\r
76         };\r
77         \r
78         // Mask and offset to store 2 bit values in bytes\r
79         static final int mask[] = new int[]{192,48,12,3};\r
80         static final int offset[] = new int[]{6,4,2,0};\r
81         \r
82         // Variable importance\r
83         private enum ImportanceMode {\r
84                 IMP_NONE,\r
85                 IMP_GINI,\r
86                 IMP_PERM_BREIMAN,\r
87                 IMP_PERM_LIAW,\r
88                 IMP_PERM_RAW,\r
89                 IMP_GINI_CORRECTED\r
90         };\r
91         \r
92         // Split mode\r
93         private enum SplitRule {\r
94                 LOGRANK,\r
95                 AUC,\r
96                 AUC_IGNORE_TIES,\r
97                 MAXSTAT,\r
98                 EXTRATREES\r
99         };\r
100         \r
101         // Prediction type\r
102         private enum PredictionType {\r
103                 RESPONSE,\r
104                 TERMINALNODES\r
105         };\r
106         \r
107         // Default values\r
108         final int DEFAULT_NUM_TREE = 500;\r
109         final int DEFAULT_NUM_THREADS = 0;\r
110         final ImportanceMode DEFAULT_IMPORTANCE_MODE = ImportanceMode.IMP_NONE;\r
111         \r
112         final int DEFAULT_MIN_NODE_SIZE_CLASSIFICATION = 1;\r
113         final int DEFAULT_MIN_NODE_SIZE_REGRESSION = 5;\r
114         final int DEFAULT_MIN_NODE_SIZE_SURVIVAL = 3;\r
115         final int DEFAULT_MIN_NODE_SIZE_PROBABILITY = 10;\r
116 \r
117         final SplitRule DEFAULT_SPLITRULE = SplitRule.LOGRANK;\r
118         final double DEFAULT_ALPHA = 0.5;\r
119         final double DEFAULT_MINPROP = 0.1;\r
120 \r
121         final PredictionType DEFAULT_PREDICTIONTYPE = PredictionType.RESPONSE;\r
122         final int DEFAULT_NUM_RANDOM_SPLITS = 1;\r
123 \r
124         //const std::vector<double> DEFAULT_SAMPLE_FRACTION = std::vector<double>({1});\r
125 \r
126         // Interval to print progress in seconds\r
127         final double STATUS_INTERVAL = 30.0;\r
128 \r
129         // Threshold for q value split method switch\r
130         final double Q_THRESHOLD = 0.02;\r
131         \r
132         private class Data {\r
133                 protected Vector<String> variable_names;\r
134                 protected int num_rows = 0;\r
135                 protected int num_rows_rounded = 0;\r
136                 protected int num_cols = 0;\r
137                 \r
138                 protected char snp_data[] = null;\r
139                 protected int num_cols_no_snp = 0;\r
140                 protected boolean externalData = true;\r
141                 \r
142                 protected int index_data[] = null;\r
143                 protected Vector<Vector<Double>> unique_data_values = null;\r
144                 protected int max_num_unique_values = 0;\r
145                 int i;\r
146                 \r
147                 // Variable to not split at (only dependent_varID for non-survival trees)\r
148                 protected Vector<Integer> no_split_variables = null;\r
149                 \r
150                 // For each varID true if ordered\r
151             protected Vector<Boolean> is_ordered_variable = null;\r
152             \r
153             // Permuted samples for corrected impurity importance\r
154             protected Vector<Integer> permuted_sampleIDs = null;\r
155             \r
156             public void dispose() {\r
157                 index_data = null;\r
158             }\r
159             \r
160             public int getVariableID(String variable_name) {\r
161                  for (i = 0; i < variable_names.size(); i++) {\r
162                          if(variable_names.get(i).equals(variable_name)) {\r
163                                  return i;\r
164                          }\r
165                  } // for (i = 0; i < variable_names.size(); i++)\r
166                  return -1;\r
167             }\r
168             \r
169             public void addSnpData(char snp_data[], int num_cols_snp) {\r
170                 num_cols = num_cols_no_snp + num_cols_snp;\r
171                 num_rows_rounded = roundToNextMultiple(num_rows, 4);\r
172                 this.snp_data = snp_data;\r
173             }\r
174             \r
175             // #nocov start\r
176             /*public boolean loadFromFile(String filename) {\r
177                 boolean result;\r
178                 \r
179                 // Open input file\r
180                 File file = new File(filename);\r
181                 BufferedReader input_file;\r
182                 try {\r
183                     input_file = new BufferedReader(new FileReader(file));\r
184                 }\r
185                 catch (FileNotFoundException e) {\r
186                         MipavUtil.displayError("Could not find file " + filename);\r
187                         return false;\r
188                 }\r
189                 \r
190                 // Count number of rows\r
191                 int line_count = 0;\r
192                 String line;\r
193                 while (true) {\r
194                         try {\r
195                             line = input_file.readLine();\r
196                         }\r
197                         catch (IOException e) {\r
198                                 MipavUtil.displayError("IO exception on readLine of " + filename);\r
199                                 return false;\r
200                         }\r
201                         if (line != null) {\r
202                                 line_count++;\r
203                         }\r
204                         else {\r
205                                 break;\r
206                         }\r
207                 } // while (true)\r
208                 num_rows = line_count-1;\r
209                 try {\r
210                     input_file.close();\r
211                 }\r
212                 catch (IOException e) {\r
213                         MipavUtil.displayError("IO exception on close of " + filename);\r
214                 }\r
215                 try {\r
216                     input_file = new BufferedReader(new FileReader(file));\r
217                 }\r
218                 catch (FileNotFoundException e) {\r
219                         MipavUtil.displayError("Could not find file " + filename);\r
220                         return false;\r
221                 }\r
222                 \r
223                 // Check if comma, semicolon, or whitespace separated\r
224                 String header_line;\r
225                 try {\r
226                     header_line = input_file.readLine();\r
227                 }\r
228                 catch (IOException e) {\r
229                         MipavUtil.displayError("IO exception reading header line of " + filename);\r
230                         return false;\r
231                 }\r
232                 \r
233                 // Find out if comma, semicolon, or whitespace separated and call appropriate method\r
234                 if (header_line.indexOf(",") != -1) {\r
235                         result = loadFromFileOther(input_file, header_line, ",");\r
236                 }\r
237                 else if (header_line.indexOf(";") != -1) {\r
238                         result = loadFromFileOther(input_file, header_line, ";");       \r
239                 }\r
240                 else {\r
241                         result = loadFromFileWhitespace(input_file, header_line);\r
242                 }\r
243                 \r
244                 externalData = false;\r
245                 try {\r
246                     input_file.close();\r
247                 }\r
248                 catch (IOException e) {\r
249                         MipavUtil.displayError("IO exception on close of " + filename);\r
250                 }\r
251                 return result;\r
252             } // loadFromFile\r
253             \r
254             public boolean loadFromFileWhitespace(BufferedReader input_file, String header_line) {\r
255                 // Read header\r
256                 String[] header_tokens;\r
257                 header_tokens = header_line.split(" ");\r
258                 for (i = 0; i < header_tokens.length; i++) {\r
259                         variable_names.add(header_tokens[i]);\r
260                 }\r
261                 num_cols = variable_names.size();\r
262                 num_cols_no_snp = num_cols;\r
263                 \r
264                 // Read body\r
265                 boolean error = false;\r
266                 String line;\r
267                 int row = 0;\r
268                 while (true) {\r
269                         try {\r
270                         line = input_file.readLine();\r
271                         }\r
272                         catch (IOException e) {\r
273                                 MipavUtil.displayError("IO exception on readLine from input_file");\r
274                                 return false;\r
275                         }\r
276                         if (line == null) {\r
277                                 break;\r
278                         }\r
279                         String tokens[];\r
280                         tokens = line.split(" ");\r
281                         for (i = 0; i < tokens.length; i++) {\r
282                                 double dValue = Double.valueOf(tokens[i]).doubleValue();\r
283                         }\r
284                 } // while true\r
285             } */\r
286             \r
287             \r
288         } // private class Data\r
289         \r
290         private int roundToNextMultiple(int value, int multiple) {\r
291                 if (multiple == 0) {\r
292                         return value;\r
293                 }\r
294                 \r
295                 int remainder = value % multiple;\r
296                 if (remainder == 0) {\r
297                         return value;\r
298                 }\r
299                 return value + multiple - remainder;\r
300         }\r
301 \r
302         \r
303         public void runAlgorithm() {\r
304                 \r
305         }\r
306 }