001/*
002 * Copyright (c) 2009 The openGion Project.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
013 * either express or implied. See the License for the specific language
014 * governing permissions and limitations under the License.
015 */
016package org.opengion.penguin.math.statistics;
017
018import java.util.Arrays;
019
020import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
021
022/**
023 * apache.commons.mathを利用したOLS重回帰計算のクラスです。
024 * y = c0 + x1c1 + x2c2 + x3c3 ...の係数を求めます。
025 * c0の切片を考慮するかどうかはnoInterceptで決めます。
026 *
027 */
028public class HybsMultiRegression implements HybsRegression {
029        private double cnst[];                  // 各係数(xの種類+1になる?)
030        private double rsquare;                 // 決定係数
031        private boolean noIntercept;    //切片を利用するかどうか
032
033        /**
034         * コンストラクタ。
035         * 与えた二次元データを元に重回帰を計算します。
036         * xデータとして二次元配列を与えます。
037         * noInterceptで切片有り無しを選択します。
038         *
039         * @param in_x 説明変数
040         * @param in_y 目的変数
041         * @param noIntercept 切片利用有無(trueで利用しない)
042         */
043        public HybsMultiRegression( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
044                train( in_x, in_y, noIntercept );
045        }
046
047        /**
048         * 与えた二次元データを元に重回帰を計算します。
049         * xデータとして二次元配列を与えます。
050         * noInterceptで切片有り無しを選択します。
051         *
052         * @param in_x 説明変数
053         * @param in_y 目的変数
054         * @param noIntercept 切片利用有無(trueで利用しない)
055         */
056        private void train( final double[][] in_x, final double[] in_y, final boolean noIntercept ) {
057                this.noIntercept = noIntercept;
058
059                // ここで重回帰計算
060                final OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
061                regression.setNoIntercept(noIntercept);
062                regression.newSampleData(in_y, in_x);
063
064                cnst    = regression.estimateRegressionParameters();
065                rsquare = regression.calculateRSquared();
066        }
067
068        /**
069         * 係数をセットした配列を返します。
070         *
071         * @return 係数の配列
072         */
073        @Override       // HybsRegression
074        public double[] getCoefficient() {
075                return Arrays.copyOf( cnst,cnst.length );
076        }
077
078        /**
079         * 決定係数の取得。
080         * @return 決定係数
081         */
082        @Override       // HybsRegression
083        public double getRSquare() {
084                return rsquare;
085        }
086
087        /**
088         * 計算( c0 + c1x1...)を行う。
089         * noInterceptによってc0の利用を決める。
090         * xの大きさが足りない場合は0を返す。
091         *
092         * @param in_x 必要な大きさの変数配列
093         * @return 計算結果
094         */
095        @Override       // HybsRegression
096        public double predict( final double... in_x ) {
097                double rtn = 0;
098                final int itr = noIntercept ? 0 : 1;
099                if( in_x.length < cnst.length-itr ) {
100                        return rtn;
101                }
102
103                for( int i=0; i < in_x.length; i++ ) {
104                        rtn = rtn + in_x[i] * cnst[i+itr];
105                }
106                if( !noIntercept ) { rtn = rtn + cnst[0]; }
107
108                return rtn;
109        }
110
111        // ================ ここまでが本体 ================
112
113        /**
114         * ここからテスト用mainメソッド 。
115         *
116         * @param args 引数
117         */
118        public static void main( final String[] args ) {
119                // データはhttp://mjin.doshisha.ac.jp/R/14.htmlより
120                final double[] y = new double[] { 50, 60, 65, 65, 70, 75, 80, 85, 90, 95 };
121                double[][] x = new double[10][];
122                x[0] = new double[] { 165, 65 };
123                x[1] = new double[] { 170, 68 };
124                x[2] = new double[] { 172, 70 };
125                x[3] = new double[] { 175, 65 };
126                x[4] = new double[] { 170, 80 };
127                x[5] = new double[] { 172, 85 };
128                x[6] = new double[] { 183, 78 };
129                x[7] = new double[] { 187, 79 };
130                x[8] = new double[] { 180, 95 };
131                x[9] = new double[] { 185, 97 };
132
133                final HybsMultiRegression mr = new HybsMultiRegression(x,y,true);
134
135                System.out.println( mr.getRSquare() );
136                System.out.println( Arrays.toString( mr.getCoefficient()) );
137
138                System.out.println( mr.predict( new double[] { 169,85 } ));
139        }
140}
141