Porting continues.
author ilb@NIH.GOV <ilb@NIH.GOV@ba61647d-9d00-f842-95cd-605cb4296b96>
Wed, 18 Apr 2018 21:29:51 +0000 (21:29 +0000)
committer ilb@NIH.GOV <ilb@NIH.GOV@ba61647d-9d00-f842-95cd-605cb4296b96>
Wed, 18 Apr 2018 21:29:51 +0000 (21:29 +0000)
git-svn-id: https://citdcbmipav.cit.nih.gov/repos-pub/mipav/trunk@15453 ba61647d-9d00-f842-95cd-605cb4296b96

mipav/src/gov/nih/mipav/model/algorithms/StochasticForests.java

index ce6c13c..729e644 100644 (file)
@@ -129,6 +129,8 @@ public class StochasticForests extends AlgorithmBase {
        // Threshold for q value split method switch\r
        final double Q_THRESHOLD = 0.02;\r
        \r
+       // Named after the C library constant RAND_MAX; 32767 is the value used here.\r
+       // NOTE(review): presumably consumed by ported rand()-based code elsewhere in this file - confirm.\r
+       public static int RAND_MAX = 32767;\r
+       \r
        private class Data {\r
                protected Vector<String> variable_names;\r
                protected int num_rows = 0;\r
@@ -140,18 +142,18 @@ public class StochasticForests extends AlgorithmBase {
                protected boolean externalData = true;\r
                \r
                protected int index_data[] = null;\r
-               protected Vector<Vector<Double>> unique_data_values = null;\r
+               protected Vector<Vector<Double>> unique_data_values = new Vector<Vector<Double>>();\r
                protected int max_num_unique_values = 0;\r
                int i;\r
                \r
                // Variable to not split at (only dependent_varID for non-survival trees)\r
-               protected Vector<Integer> no_split_variables = null;\r
+               protected Vector<Integer> no_split_variables = new Vector<Integer>();\r
                \r
                // For each varID true if ordered\r
-           protected Vector<Boolean> is_ordered_variable = null;\r
+           protected Vector<Boolean> is_ordered_variable = new Vector<Boolean>();\r
            \r
            // Permuted samples for corrected impurity importance\r
-           protected Vector<Integer> permuted_sampleIDs = null;\r
+           protected Vector<Integer> permuted_sampleIDs = new Vector<Integer>();\r
            \r
            public void dispose() {\r
                index_data = null;\r
@@ -173,7 +175,7 @@ public class StochasticForests extends AlgorithmBase {
            }\r
            \r
            // #nocov start\r
-           /*public boolean loadFromFile(String filename) {\r
+           public boolean loadFromFile(String filename) {\r
                boolean result;\r
                \r
                // Open input file\r
@@ -184,7 +186,7 @@ public class StochasticForests extends AlgorithmBase {
                }\r
                catch (FileNotFoundException e) {\r
                        MipavUtil.displayError("Could not find file " + filename);\r
-                       return false;\r
+                       return true;\r
                }\r
                \r
                // Count number of rows\r
@@ -196,7 +198,7 @@ public class StochasticForests extends AlgorithmBase {
                        }\r
                        catch (IOException e) {\r
                                MipavUtil.displayError("IO exception on readLine of " + filename);\r
-                               return false;\r
+                               return true;\r
                        }\r
                        if (line != null) {\r
                                line_count++;\r
@@ -211,13 +213,14 @@ public class StochasticForests extends AlgorithmBase {
                }\r
                catch (IOException e) {\r
                        MipavUtil.displayError("IO exception on close of " + filename);\r
+                       return true;\r
                }\r
                try {\r
                    input_file = new BufferedReader(new FileReader(file));\r
                }\r
                catch (FileNotFoundException e) {\r
                        MipavUtil.displayError("Could not find file " + filename);\r
-                       return false;\r
+                       return true;\r
                }\r
                \r
                // Check if comma, semicolon, or whitespace separated\r
@@ -227,7 +230,7 @@ public class StochasticForests extends AlgorithmBase {
                }\r
                catch (IOException e) {\r
                        MipavUtil.displayError("IO exception reading header line of " + filename);\r
-                       return false;\r
+                       return true;\r
                }\r
                \r
                // Find out if comma, semicolon, or whitespace separated and call appropriate method\r
@@ -238,7 +241,7 @@ public class StochasticForests extends AlgorithmBase {
                        result = loadFromFileOther(input_file, header_line, ";");       \r
                }\r
                else {\r
-                       result = loadFromFileWhitespace(input_file, header_line);\r
+                       result = loadFromFileOther(input_file, header_line, " ");\r
                }\r
                \r
                externalData = false;\r
@@ -251,10 +254,11 @@ public class StochasticForests extends AlgorithmBase {
                return result;\r
            } // loadFromFile\r
            \r
-           public boolean loadFromFileWhitespace(BufferedReader input_file, String header_line) {\r
+           // Use instead of loadFromFileWhitespace by using separator = " ".\r
+           public boolean loadFromFileOther(BufferedReader input_file, String header_line, String separator) {\r
                // Read header\r
                String[] header_tokens;\r
-               header_tokens = header_line.split(" ");\r
+               header_tokens = header_line.split(separator);\r
                for (i = 0; i < header_tokens.length; i++) {\r
                        variable_names.add(header_tokens[i]);\r
                }\r
@@ -262,7 +266,8 @@ public class StochasticForests extends AlgorithmBase {
                num_cols_no_snp = num_cols;\r
                \r
                // Read body\r
-               boolean error = false;\r
+               reserveMemory();\r
+               boolean error[] = new boolean[]{false};\r
                String line;\r
                int row = 0;\r
                while (true) {\r
@@ -277,16 +282,299 @@ public class StochasticForests extends AlgorithmBase {
                                break;\r
                        }\r
                        String tokens[];\r
-                       tokens = line.split(" ");\r
+                       int column = 0;\r
+                       tokens = line.split(separator);\r
                        for (i = 0; i < tokens.length; i++) {\r
                                double dValue = Double.valueOf(tokens[i]).doubleValue();\r
+                               set(column, row, dValue, error);\r
+                    column++;\r
                        }\r
+                       if (separator.equals(" ")) {\r
+                               if (column > num_cols) {\r
+                                       MipavUtil.displayError("Too many columns in a row");\r
+                                       return false;\r
+                               }\r
+                               else if (column < num_cols) {\r
+                                       MipavUtil.displayError("Too few columns in a row.  Are all values numeric?");\r
+                                       return false;\r
+                               }\r
+                       } // if (separator.equals(" "))\r
+                       row++;\r
                } // while true\r
-           } */\r
+               num_rows = row;\r
+               return error[0];\r
+           }\r
+           \r
+           /**\r
+            * Collects the distinct values of one variable over the given samples.\r
+            *\r
+            * For a variable whose unpermuted ID is below num_cols_no_snp, the value of every\r
+            * listed sample is appended to all_values, the vector is sorted ascending and\r
+            * adjacent duplicates are removed.  Elements already present in all_values are kept\r
+            * and take part in the sort and de-duplication.  For any other variable the vector\r
+            * is replaced by the fixed values 0, 1, 2.\r
+            *\r
+            * @param all_values output vector receiving the sorted distinct values\r
+            * @param sampleIDs  row indices to read\r
+            * @param varID      variable (column) ID, possibly a permuted ID\r
+            */\r
+           public void getAllValues(Vector<Double> all_values, Vector<Integer> sampleIDs,\r
+                       int varID) {\r
+               // All values for varID (no duplicates) for given sampleIDs\r
+               if (getUnpermutedVarID(varID) < num_cols_no_snp) {\r
+                   // Only reserve capacity: setSize() would pad the vector with null entries,\r
+                   // which makes the natural-order sort below throw a NullPointerException.\r
+                   all_values.ensureCapacity(all_values.size() + sampleIDs.size());\r
+                   // Local counter: get() may call getUnpermutedVarID(), which must not be\r
+                   // able to disturb this loop through the shared field i.\r
+                   for (int s = 0; s < sampleIDs.size(); s++) {\r
+                       all_values.add(get(sampleIDs.get(s), varID));\r
+                   }\r
+                   all_values.sort(null);\r
+                   // Walk backwards so removals do not shift the elements still to be checked.\r
+                   for (int k = all_values.size() - 1; k >= 1; k--) {\r
+                       // Compare primitive values; == on boxed Double compares references.\r
+                       if (all_values.get(k).doubleValue() == all_values.get(k - 1).doubleValue()) {\r
+                           all_values.removeElementAt(k);\r
+                       }\r
+                   }\r
+               } // if (getUnpermutedVarID(varID) < num_cols_no_snp)\r
+               else {\r
+                   // If GWA data just use 0, 1, 2\r
+                   all_values.clear();\r
+                   all_values.add(0.0);\r
+                   all_values.add(1.0);\r
+                   all_values.add(2.0);\r
+               }\r
+           }\r
+           \r
+           /**\r
+            * Finds the smallest and largest value of one variable over the given samples.\r
+            *\r
+            * The results are written to min[0] and max[0].  When sampleIDs is empty neither\r
+            * array is modified.\r
+            *\r
+            * @param min       one-element output array for the minimum\r
+            * @param max       one-element output array for the maximum\r
+            * @param sampleIDs row indices to read\r
+            * @param varID     variable (column) ID passed through to get()\r
+            */\r
+           public void getMinMaxValues(double min[], double max[], Vector<Integer> sampleIDs, int varID) {\r
+               if (sampleIDs.size() > 0) {\r
+                   min[0] = get(sampleIDs.get(0), varID);\r
+                   max[0] = min[0];\r
+               }\r
+               // Local counter instead of the shared field i: get() can reach\r
+               // getUnpermutedVarID(), and a loop there on the same field would reset this\r
+               // loop's position (skipping samples or never terminating).\r
+               for (int s = 1; s < sampleIDs.size(); s++) {\r
+                   double value = get(sampleIDs.get(s), varID);\r
+                   if (value < min[0]) {\r
+                       min[0] = value;\r
+                   }\r
+                   if (value > max[0]) {\r
+                       max[0] = value;\r
+                   }\r
+               }\r
+           }\r
+           \r
+           /**\r
+            * Builds the per-column index representation of the non-SNP data.\r
+            *\r
+            * For every column below num_cols_no_snp the sorted distinct values are appended to\r
+            * unique_data_values, and index_data[col * num_rows + row] receives the position of\r
+            * the first distinct value that is >= the cell value.  max_num_unique_values is\r
+            * raised to the largest distinct-value count seen.\r
+            */\r
+           public void sort() {\r
+               // Reserve memory\r
+               index_data = new int[num_cols_no_snp * num_rows];\r
+               \r
+               // For all columns, get unique values and save index for each observation\r
+               for (int col = 0; col < num_cols_no_snp; col++) {\r
+                   // Get all unique values.  Capacity is reserved through the constructor;\r
+                   // setSize() followed by add(row, ...) left num_rows null entries in the\r
+                   // vector, which made the sort below throw a NullPointerException.\r
+                   Vector<Double> unique_values = new Vector<Double>(num_rows);\r
+                   for (int row = 0; row < num_rows; row++) {\r
+                       unique_values.add(get(row, col));\r
+                   }\r
+                   unique_values.sort(null);\r
+                   // Backwards walk with a local counter; compare primitive values because\r
+                   // == on boxed Double compares references, not numbers.\r
+                   for (int k = unique_values.size() - 1; k >= 1; k--) {\r
+                       if (unique_values.get(k).doubleValue() == unique_values.get(k - 1).doubleValue()) {\r
+                           unique_values.removeElementAt(k);\r
+                       }\r
+                   }\r
+                   \r
+                   // Get index of unique value\r
+                   for (int row = 0; row < num_rows; row++) {\r
+                       double value = get(row, col);\r
+                       int idx;\r
+                       for (idx = 0; idx < unique_values.size(); idx++) {\r
+                           if (unique_values.get(idx) >= value) {\r
+                               break;\r
+                           }\r
+                       }\r
+                       index_data[col * num_rows + row] = idx;\r
+                   } // for (int row = 0; row < num_rows; row++)\r
+                   \r
+                   // Save unique values\r
+                   unique_data_values.add(unique_values);\r
+                   if (unique_values.size() > max_num_unique_values) {\r
+                       max_num_unique_values = unique_values.size();\r
+                   }\r
+               } // for (int col = 0; col < num_cols_no_snp; col++) {\r
+           } // public void sort()\r
+           \r
+           /**\r
+            * Storage-allocation hook.  This base version does nothing; DoubleData overrides it\r
+            * to allocate its backing array.\r
+            */\r
+           public void reserveMemory() {\r
+               // Intentionally empty.\r
+           }\r
+           \r
+           /**\r
+            * Cell-store hook.  This base version ignores all of its arguments; DoubleData\r
+            * overrides it to write into its backing array.\r
+            *\r
+            * @param col   column index\r
+            * @param row   row index\r
+            * @param value value to store\r
+            * @param error one-element error flag array (not touched here)\r
+            */\r
+           public void set(int col, int row, double value, boolean error[]) {\r
+               // Intentionally empty.\r
+           }\r
+           \r
+           \r
+           /**\r
+            * Cell-read hook.  This base version always yields 0.0; DoubleData overrides it to\r
+            * read real data.\r
+            *\r
+            * @param row row index (unused here)\r
+            * @param col column index (unused here)\r
+            * @return always 0.0 in this base version\r
+            */\r
+           public double get(int row, int col) {\r
+               return 0.0d;\r
+           }\r
+           \r
+           /**\r
+            * Maps a possibly permuted variable ID back to a regular column ID.\r
+            *\r
+            * IDs below num_cols are returned unchanged.  Otherwise num_cols is subtracted and\r
+            * the result is incremented once for every entry of no_split_variables that is\r
+            * less than or equal to the running ID, in list order.\r
+            *\r
+            * @param varID variable ID, possibly >= num_cols\r
+            * @return the corresponding unpermuted variable ID\r
+            */\r
+           public int getUnpermutedVarID(int varID) {\r
+               if (varID >= num_cols) {\r
+                   varID -= num_cols;\r
+\r
+                   // Local counter: this method is called from inside loops (for example via\r
+                   // get()) that use the shared field i as their own index, so iterating on\r
+                   // that field here would corrupt the caller's loop.\r
+                   for (int k = 0; k < no_split_variables.size(); k++) {\r
+                       if (varID >= no_split_variables.get(k)) {\r
+                           ++varID;\r
+                       }\r
+                   }\r
+               }\r
+               return varID;\r
+           }\r
+           \r
+           /**\r
+            * Looks up the permuted row index stored for a sample.\r
+            *\r
+            * @param sampleID position in permuted_sampleIDs\r
+            * @return the value stored at that position\r
+            */\r
+           public int getPermutedSampleID(int sampleID) {\r
+               final Integer permuted = permuted_sampleIDs.get(sampleID);\r
+               return permuted;\r
+           }\r
+\r
+           /**\r
+            * Returns the index code of one cell.\r
+            *\r
+            * Columns at or beyond num_cols are first mapped back through\r
+            * getUnpermutedVarID / getPermutedSampleID.  For columns below num_cols_no_snp the\r
+            * value comes from index_data; otherwise it is decoded from the packed snp_data\r
+            * (four cells per element, selected with mask and offset) minus one.\r
+            *\r
+            * @param row row index\r
+            * @param col column index, possibly a permuted ID\r
+            * @return the index code; decoded SNP values outside 0..2 are reported as 0\r
+            */\r
+           public int getIndex(int row, int col) {\r
+               // Use permuted data for corrected impurity importance\r
+               if (col >= num_cols) {\r
+                   col = getUnpermutedVarID(col);\r
+                   row = getPermutedSampleID(row);\r
+               }\r
+\r
+               if (col < num_cols_no_snp) {\r
+                   return index_data[col * num_rows + row];\r
+               } else {\r
+                   // Get data out of snp storage. -1 because of GenABEL coding.\r
+                   int idx = (col - num_cols_no_snp) * num_rows_rounded + row;\r
+                   int result = (((snp_data[idx / 4] & mask[idx % 4]) >> offset[idx % 4]) - 1);\r
+\r
+                   // TODO: Better way to treat missing values?\r
+                   // A stored code of 0 decodes to -1 with signed int arithmetic, so the lower\r
+                   // bound has to be checked explicitly; otherwise a negative index escapes.\r
+                   if (result < 0 || result > 2) {\r
+                       return 0;\r
+                   } else {\r
+                       return result;\r
+                   }\r
+               }\r
+           }\r
+\r
+           /**\r
+            * Translates an index code back to the data value it stands for.\r
+            *\r
+            * @param varID variable ID, possibly a permuted ID\r
+            * @param index position within the variable's distinct values\r
+            * @return the stored distinct value for columns below num_cols_no_snp, otherwise\r
+            *         the index itself\r
+            */\r
+           public double getUniqueDataValue(int varID, int index) {\r
+               // Permuted IDs (used for corrected impurity importance) map back to real columns.\r
+               final int column = (varID >= num_cols) ? getUnpermutedVarID(varID) : varID;\r
+\r
+               if (column >= num_cols_no_snp) {\r
+                   // For GWAS data the index is the value\r
+                   return index;\r
+               }\r
+               final Vector<Double> distinctValues = unique_data_values.get(column);\r
+               return distinctValues.get(index);\r
+           }\r
+\r
+           /**\r
+            * Reports how many distinct values a variable has.\r
+            *\r
+            * @param varID variable ID, possibly a permuted ID\r
+            * @return the size of the variable's distinct-value list for columns below\r
+            *         num_cols_no_snp, otherwise 3\r
+            */\r
+           public int getNumUniqueDataValues(int varID) {\r
+               // Permuted IDs (used for corrected impurity importance) map back to real columns.\r
+               final int column = (varID >= num_cols) ? getUnpermutedVarID(varID) : varID;\r
+\r
+               if (column >= num_cols_no_snp) {\r
+                   // For GWAS data 0,1,2\r
+                   return 3;\r
+               }\r
+               return unique_data_values.get(column).size();\r
+           }\r
+           \r
+           /**\r
+            * @return the variable_names vector itself (not a copy)\r
+            */\r
+           public Vector<String> getVariableNames() {\r
+               return this.variable_names;\r
+           }\r
+           \r
+           /**\r
+            * @return the current value of num_cols\r
+            */\r
+           public int getNumCols() {\r
+               return this.num_cols;\r
+           }\r
+           \r
+           /**\r
+            * @return the current value of num_rows\r
+            */\r
+           public int getNumRows() {\r
+               return this.num_rows;\r
+           }\r
+           \r
+           /**\r
+            * Reports the largest distinct-value count over all variables.\r
+            *\r
+            * @return max_num_unique_values when there is no snp data or when that count\r
+            *         exceeds 3; otherwise 3\r
+            */\r
+           public int getMaxNumUniqueValues() {\r
+               final boolean hasSnpData = (snp_data != null);\r
+               if (hasSnpData && max_num_unique_values <= 3) {\r
+                   // Snp data present and no variable with more than 3 unique values\r
+                   return 3;\r
+               }\r
+               // No snp data, or some variable with more than 3 unique values\r
+               return max_num_unique_values;\r
+           }\r
+\r
+           /**\r
+            * @return the no_split_variables vector itself (not a copy)\r
+            */\r
+           public Vector<Integer> getNoSplitVariables() {\r
+               return this.no_split_variables;\r
+           }\r
+           \r
+           /**\r
+            * Registers a variable that must not be used for splitting and keeps the list in\r
+            * ascending order.\r
+            *\r
+            * @param varID variable ID to add\r
+            */\r
+           public void addNoSplitVariable(int varID) {\r
+               final Vector<Integer> excluded = no_split_variables;\r
+               excluded.add(varID);\r
+               // null comparator = natural (ascending) order\r
+               excluded.sort(null);\r
+           }\r
            \r
+           /**\r
+            * @return the is_ordered_variable vector itself (not a copy)\r
+            */\r
+           public Vector<Boolean> getIsOrderedVariable() {\r
+               return this.is_ordered_variable;\r
+           }\r
+\r
+           // Original name setIsOrderedVariable\r
+           /**\r
+            * Resizes is_ordered_variable to num_cols entries and marks the named variables as\r
+            * unordered.\r
+            *\r
+            * Surplus entries are dropped, missing entries are filled with true, and existing\r
+            * entries keep their value.  Each listed name is then resolved with getVariableID\r
+            * and its flag is overwritten with false.\r
+            * NOTE(review): getVariableID is not visible here; its result for an unknown name\r
+            * is used directly as an index - confirm it is always in range.\r
+            *\r
+            * @param unordered_variable_names names of the variables to mark as unordered\r
+            */\r
+           public void setIsOrderedVariableString(Vector<String> unordered_variable_names) {\r
+               if (is_ordered_variable.size() > num_cols) {\r
+                   // Truncate.  The previous loop counted upwards (i++) from the last index\r
+                   // and therefore ran out of bounds instead of shrinking the vector.\r
+                   is_ordered_variable.setSize(num_cols);\r
+               }\r
+               while (is_ordered_variable.size() < num_cols) {\r
+                   is_ordered_variable.add(true);\r
+               }\r
+               // Local counter: getVariableID() must not be able to disturb this loop through\r
+               // the shared field i.\r
+               for (int k = 0; k < unordered_variable_names.size(); k++) {\r
+                   int varID = getVariableID(unordered_variable_names.get(k));\r
+                   // Overwrite in place; add(varID, false) would insert a new element and\r
+                   // shift every later flag by one position.\r
+                   is_ordered_variable.set(varID, false);\r
+               }\r
+           }\r
            \r
+           /**\r
+            * Replaces the ordered/unordered flags with the given vector.  The reference is\r
+            * stored as is; no copy is made.\r
+            *\r
+            * @param is_ordered_variable one flag per variable, true meaning ordered\r
+            */\r
+           public void setIsOrderedVariable(final Vector<Boolean> is_ordered_variable) {\r
+               this.is_ordered_variable = is_ordered_variable;\r
+           }\r
+\r
+           /**\r
+            * Tells whether a variable is flagged as ordered.\r
+            *\r
+            * @param varID variable ID, possibly a permuted ID\r
+            * @return the flag stored in is_ordered_variable for the unpermuted ID\r
+            */\r
+           public boolean isOrderedVariable(int varID) {\r
+               // Permuted IDs (used for corrected impurity importance) map back to real columns.\r
+               final int column = (varID >= num_cols) ? getUnpermutedVarID(varID) : varID;\r
+               return is_ordered_variable.get(column);\r
+           }\r
+           \r
+           /**\r
+            * Rebuilds permuted_sampleIDs as a random permutation of 0 .. num_rows - 1.\r
+            *\r
+            * The vector is cleared, filled with the row indices in ascending order and then\r
+            * handed to shuffle().\r
+            */\r
+           public void permuteSampleIDs() {\r
+               permuted_sampleIDs.clear();\r
+               permuted_sampleIDs.ensureCapacity(num_rows);\r
+               // Local counter rather than the shared field i, which other methods of this\r
+               // class also use as a loop index.\r
+               for (int row = 0; row < num_rows; row++) {\r
+                   permuted_sampleIDs.add(row);\r
+               }\r
+               shuffle(permuted_sampleIDs);\r
+           }\r
+\r
+\r
        } // private class Data\r
        \r
+       /**\r
+        * Data implementation that keeps the non-SNP cells in a flat double array, stored\r
+        * column by column (cell (row, col) lives at col * num_rows + row).\r
+        */\r
+       private class DoubleData extends Data {\r
+           // Backing store; allocated by reserveMemory().\r
+           private double data[] = null;\r
+           \r
+           /** Allocates the backing array with num_cols * num_rows cells. */\r
+           @Override\r
+           public void reserveMemory() {\r
+               final int cellCount = num_cols * num_rows;\r
+               data = new double[cellCount];\r
+           }\r
+           \r
+           /**\r
+            * Stores one cell.  The error flag array is not touched.\r
+            *\r
+            * @param col   column index\r
+            * @param row   row index\r
+            * @param value value to store\r
+            * @param error one-element error flag array (unused here)\r
+            */\r
+           @Override\r
+           public void set(int col, int row, double value, boolean error[]) {\r
+               final int position = col * num_rows + row;\r
+               data[position] = value;\r
+           }\r
+           \r
+           /**\r
+            * Reads one cell.\r
+            *\r
+            * @param row row index\r
+            * @param col column index, possibly a permuted ID (>= num_cols)\r
+            * @return the stored double for columns below num_cols_no_snp, otherwise the\r
+            *         value decoded from the packed snp_data minus one\r
+            */\r
+           @Override\r
+           public double get(int row, int col) {\r
+               int sourceRow = row;\r
+               int sourceCol = col;\r
+               // Use permuted data for corrected impurity importance\r
+               if (col >= num_cols) {\r
+                   sourceCol = getUnpermutedVarID(col);\r
+                   sourceRow = getPermutedSampleID(row);\r
+               }\r
+\r
+               if (sourceCol >= num_cols_no_snp) {\r
+                   // Get data out of snp storage. -1 because of GenABEL coding.\r
+                   final int idx = (sourceCol - num_cols_no_snp) * num_rows_rounded + sourceRow;\r
+                   final double decoded = (((snp_data[idx / 4] & mask[idx % 4]) >> offset[idx % 4]) - 1);\r
+                   return decoded;\r
+               }\r
+               return data[sourceCol * num_rows + sourceRow];\r
+           }\r
+\r
+\r
+       } // private class DoubleData extends Data\r
+       \r
        private int roundToNextMultiple(int value, int multiple) {\r
                if (multiple == 0) {\r
                        return value;\r
@@ -298,6 +586,19 @@ public class StochasticForests extends AlgorithmBase {
                }\r
                return value + multiple - remainder;\r
        }\r
+       \r
+       /**\r
+        * Randomly reorders the vector in place: walking from the last position down to\r
+        * position 1, each element is swapped with one at a random position not above it.\r
+        * A new Random is created on every call.\r
+        *\r
+        * @param v vector to shuffle\r
+        */\r
+       private void shuffle(Vector<Integer> v)\r
+       {\r
+           final Random rng = new Random();\r
+           int last = v.size() - 1;\r
+           while (last > 0)\r
+           {\r
+               final int pick = rng.nextInt(last + 1);\r
+               final int picked = v.get(pick);\r
+               v.set(pick, v.get(last));\r
+               v.set(last, picked);\r
+               last--;\r
+           }\r
+       }\r
 \r
        \r
        public void runAlgorithm() {\r