001package org.opengion.penguin.math.statistics;
002
003import java.util.Arrays;
004
005import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
006
007/**
008 * apache.commons.mathを利用したOLS重回帰計算のクラスです。
009 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。
010 * c0の切片を考慮するかどうかはnoInterceptで決めます。
011 * 
012 */
013public class HybsMultiRegression {
014        private double coe[];                   // 各係数(xの種類+1になる?)
015        private double rsquare;         // 決定係数
016        private boolean noIntercept; //切片を利用するかどうか
017                
018        /**
019         * コンストラクタ。
020         * 与えた二次元データを元に重回帰を計算します。
021         * xデータとして二次元配列を与えます。
022         * noInterceptで切片有り無しを選択します。
023         * @param in_x 説明変数
024         * @param in_y 目的変数
025         * @param noIntercept 切片利用有無(trueで利用しない)
026         */
027        public HybsMultiRegression(final double[][] in_x, final double[] in_y, final boolean noIntercept){
028                this.noIntercept = noIntercept;
029                
030                // ここで重回帰計算
031                OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
032                regression.setNoIntercept(noIntercept);
033        regression.newSampleData(in_y, in_x);
034                
035                coe = regression.estimateRegressionParameters();
036                rsquare = regression.calculateRSquared();
037        }
038        
039        /**
040         * コンストラクタ。
041         * 係数配列を与えられるようにしておきます。
042         * (以前に計算したものを利用)
043         * @param in_c 係数配列
044         * @param noIntercept 切片利用有無(trueで利用しない)
045         * 
046         */
047        public HybsMultiRegression( final double[] in_c, final boolean noIntercept){
048                this.coe = in_c;
049                this.noIntercept = noIntercept;
050        }
051        
052        
053        /**
054         * 係数の取得。
055         * @return 係数配列
056         */
057        public double[] getParam(){
058                return coe;
059        }
060        
061        /**
062         * 決定係数の取得。
063         * @return 決定係数
064         */
065        public double getRSquare(){
066                return rsquare;
067        }
068        
069        /**
070         * 計算( c0 + c1x1...)を行う。
071         * noInterceptによってc0の利用を決める。
072         * xの大きさが足りない場合は0を返す。
073         * 
074         * @param in_x 必要な大きさの変数配列
075         * @return 計算結果
076         */
077        public double predict(final double[] in_x){
078                double rtn = 0;
079                int itr = noIntercept ? 0 : 1;
080                if( in_x.length < coe.length-itr ){
081                        return 0;
082                }
083                
084                for( int i=0; i < in_x.length; i++ ){
085                        rtn = rtn + in_x[i] * coe[i+itr];
086                }
087                if( !noIntercept ){ rtn = rtn + coe[0]; }
088                
089                return rtn;
090        }
091
092        /*** ここまでが本体 ***/
093        /*** ここからテスト用mainメソッド ***/
094        /**
095         * @param args *****************************************/
096        public static void main(final String [] args) {
097                // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより
098                double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 };
099                double[][] x = new double[10][];
100                x[0] = new double[] { 165, 65 };
101                x[1] = new double[] { 170, 68 };
102                x[2] = new double[] { 172, 70 };
103                x[3] = new double[] { 175, 65 };
104                x[4] = new double[] { 170, 80 };
105                x[5] = new double[] { 172, 85 };
106                x[6] = new double[] { 183, 78 };
107                x[7] = new double[] { 187, 79 };
108                x[8] = new double[] { 180, 95 };
109                x[9] = new double[] { 185, 97 };
110                
111                
112                HybsMultiRegression mr = new HybsMultiRegression(x,y,true);
113                
114                System.out.println( mr.getRSquare() );
115                System.out.println( Arrays.toString( mr.getParam()) );
116                
117                System.out.println( mr.predict( new double[] { 169,85 } ));
118        }
119}
120